diff --git a/tests/test_relation_extraction_endpoint.py b/tests/test_relation_extraction_endpoint.py new file mode 100644 index 0000000..2e7eb88 --- /dev/null +++ b/tests/test_relation_extraction_endpoint.py @@ -0,0 +1,87 @@ +import pytest +import requests +from typing import Dict, Any +from .conftest import BASE_URL + + +class TestRelationExtractionEndpoint: + """Tests the /relation-extraction endpoint""" + + def test_single_text_with_entities(self, api_client): + """Test with a single text and entities""" + payload = { + "text": "Steve Jobs founded Apple Inc. which is located in Cupertino.", + "relations": ["founded", "located_in"], + "entities": ["person", "organization", "location"], + "threshold": 0.3 + } + response = api_client.post(f"{BASE_URL}/relation-extraction", json=payload) + assert response.status_code == 200 + + data = response.json() + assert "outputs" in data + assert isinstance(data["outputs"], list) + + def test_single_text_without_entities(self, api_client): + """Test without explicit entities""" + payload = { + "text": "Bill Gates founded Microsoft.", + "relations": ["founded"], + "threshold": 0.3 + } + response = api_client.post(f"{BASE_URL}/relation-extraction", json=payload) + assert response.status_code == 200 + + def test_batch_texts(self, api_client): + """Test with a batch of texts""" + payload = { + "texts": [ + "Bill Gates founded Microsoft.", + "Tesla was founded by Elon Musk.", + "Apple is based in California." + ], + "relations": ["founded", "founded_by", "based_in"], + "threshold": 0.3 + } + response = api_client.post(f"{BASE_URL}/relation-extraction", json=payload) + assert response.status_code == 200 + + data = response.json() + assert len(data["outputs"]) == 3 + + def test_with_optional_rules(self, api_client): + """Test with optional rules""" + payload = { + "text": "Steve Jobs founded Apple Inc.", + "relations": ["founded"], + "rules": [ + { + "relation": "founded", + "pairs_filter": [["person", "organization"]], + "distance": 10 + } + ], + "threshold": 0.3 + } + response = api_client.post(f"{BASE_URL}/relation-extraction", json=payload) + assert response.status_code == 200 + + def test_error_empty_relations(self, api_client): + """Test error with an empty relations list""" + payload = { + "text": "Some text", + "relations": [] + } + response = api_client.post(f"{BASE_URL}/relation-extraction", json=payload) + assert response.status_code == 422 + + def test_error_both_text_and_texts(self, api_client): + """Test error when both text and texts are present""" + payload = { + "text": "Some text", + "texts": ["Another text"], + "relations": ["founded"] + } + response = api_client.post(f"{BASE_URL}/relation-extraction", json=payload) + assert response.status_code == 400 +