You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
88 lines
3.0 KiB
88 lines
3.0 KiB
import pytest
|
|
import requests
|
|
from typing import Dict, Any
|
|
from .conftest import BASE_URL
|
|
|
|
|
|
class TestRelationExtractionEndpoint:
    """Tests the /relation-extraction endpoint.

    Each test POSTs a JSON payload through the ``api_client`` fixture
    (defined outside this file, presumably in conftest -- confirm) and
    checks the HTTP status code. Successful requests additionally check
    that the JSON body carries a list under ``"outputs"``.
    """

    @staticmethod
    def _post(api_client, payload):
        """POST *payload* as JSON to ``{BASE_URL}/relation-extraction``."""
        return api_client.post(f"{BASE_URL}/relation-extraction", json=payload)

    @staticmethod
    def _assert_outputs(response):
        """Assert a 200 response whose JSON body has a list under ``"outputs"``.

        Returns:
            The ``"outputs"`` list, so callers can make further assertions.
        """
        # Put the body in the failure message so server-side error details
        # show up in the pytest report instead of just "500 != 200".
        assert response.status_code == 200, response.text
        data = response.json()
        # Check the key before indexing so a malformed body fails as an
        # assertion rather than a KeyError.
        assert "outputs" in data
        assert isinstance(data["outputs"], list)
        return data["outputs"]

    def test_single_text_with_entities(self, api_client):
        """Test with a single text and entities"""
        payload = {
            "text": "Steve Jobs founded Apple Inc. which is located in Cupertino.",
            "relations": ["founded", "located_in"],
            "entities": ["person", "organization", "location"],
            "threshold": 0.3,
        }
        self._assert_outputs(self._post(api_client, payload))

    def test_single_text_without_entities(self, api_client):
        """Test without explicit entities"""
        payload = {
            "text": "Bill Gates founded Microsoft.",
            "relations": ["founded"],
            "threshold": 0.3,
        }
        # Check the body shape too, so a 200 with a malformed body cannot pass.
        self._assert_outputs(self._post(api_client, payload))

    def test_batch_texts(self, api_client):
        """Test with a batch of texts"""
        payload = {
            "texts": [
                "Bill Gates founded Microsoft.",
                "Tesla was founded by Elon Musk.",
                "Apple is based in California.",
            ],
            "relations": ["founded", "founded_by", "based_in"],
            "threshold": 0.3,
        }
        outputs = self._assert_outputs(self._post(api_client, payload))
        # One output entry is expected per input text.
        assert len(outputs) == len(payload["texts"])

    def test_with_optional_rules(self, api_client):
        """Test with optional rules"""
        payload = {
            "text": "Steve Jobs founded Apple Inc.",
            "relations": ["founded"],
            "rules": [
                {
                    "relation": "founded",
                    "pairs_filter": [["person", "organization"]],
                    "distance": 10,
                }
            ],
            "threshold": 0.3,
        }
        # Check the body shape too, so a 200 with a malformed body cannot pass.
        self._assert_outputs(self._post(api_client, payload))

    def test_error_empty_relations(self, api_client):
        """Test error with an empty relations list"""
        payload = {
            "text": "Some text",
            "relations": [],
        }
        response = self._post(api_client, payload)
        assert response.status_code == 400, response.text

    def test_error_both_text_and_texts(self, api_client):
        """Test error when both text and texts are present"""
        payload = {
            "text": "Some text",
            "texts": ["Another text"],
            "relations": ["founded"],
        }
        response = self._post(api_client, payload)
        assert response.status_code == 400, response.text
|
|
|