You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
75 lines
2.6 KiB
75 lines
2.6 KiB
import pytest
|
|
import requests
|
|
from typing import Dict, Any
|
|
from .conftest import BASE_URL
|
|
|
|
|
|
class TestOutputStructure:
    """Response structure tests.

    Each test POSTs a small payload to one endpoint under ``BASE_URL`` and
    validates the shape of *every* item in the response's ``outputs`` list.
    An empty ``outputs`` list is reported as a skip rather than a silent pass,
    so a run that checked nothing is visible in the test report.
    """

    @staticmethod
    def _post_outputs(api_client, endpoint, payload):
        """POST *payload* to ``BASE_URL + endpoint`` and return the ``outputs`` list.

        Args:
            api_client: HTTP client fixture; only ``.post(url, json=...)`` is used.
            endpoint: path appended to ``BASE_URL`` (e.g. ``"/general"``).
            payload: JSON-serialisable request body.

        Returns:
            The non-empty ``outputs`` list from the JSON response.

        Fails the test on a non-2xx status or when the body has no ``outputs``
        list; skips the test when ``outputs`` is empty.
        """
        response = api_client.post(f"{BASE_URL}{endpoint}", json=payload)
        # Check the status first so an error response fails with a readable
        # message instead of a KeyError on the missing "outputs" field.
        assert 200 <= response.status_code < 300, (
            f"{endpoint} returned HTTP {response.status_code}: {response.text}"
        )
        data = response.json()
        assert "outputs" in data, f"{endpoint} response has no 'outputs' key: {data!r}"
        outputs = data["outputs"]
        assert isinstance(outputs, list), (
            f"{endpoint} 'outputs' is not a list: {outputs!r}"
        )
        if not outputs:
            # Nothing to validate: make that explicit instead of passing.
            pytest.skip(f"{endpoint} returned no outputs; structure not checked")
        return outputs

    @staticmethod
    def _assert_offsets_and_score(item):
        """Assert ``start``/``end`` are ints and ``score`` is a float in *item*."""
        for key in ("start", "end"):
            value = item[key]
            # bool is a subclass of int; a JSON true/false must not pass as an offset.
            assert isinstance(value, int) and not isinstance(value, bool), (
                f"{key!r} should be an int, got {value!r}"
            )
        assert isinstance(item["score"], float), (
            f"'score' should be a float, got {item['score']!r}"
        )

    @classmethod
    def _assert_entity_structure(cls, entity):
        """Assert *entity* is a dict with typed entity/span/start/end/score fields."""
        assert isinstance(entity, dict), f"entity should be a dict, got {entity!r}"
        for key in ("entity", "span", "start", "end", "score"):
            assert key in entity, f"missing key {key!r} in entity {entity!r}"
        cls._assert_offsets_and_score(entity)

    def test_general_output_structure(self, api_client):
        """Verify the structure of every entity returned by /general."""
        payload = {
            "text": "Apple Inc. in California",
            "entities": ["organization", "location"],
            "threshold": 0.3,
        }
        for entity in self._post_outputs(api_client, "/general", payload):
            self._assert_entity_structure(entity)

    def test_relation_output_structure(self, api_client):
        """Verify the structure of every relation returned by /relation-extraction.

        Each relation carries its own score plus ``source``/``target`` entities,
        which must follow the same schema as /general entities.
        """
        payload = {
            "text": "Steve Jobs founded Apple",
            "relations": ["founded"],
            "threshold": 0.3,
        }
        for relation in self._post_outputs(api_client, "/relation-extraction", payload):
            for key in ("source", "relation", "target", "score"):
                assert key in relation, f"missing key {key!r} in relation {relation!r}"
            assert isinstance(relation["score"], float), (
                f"relation 'score' should be a float, got {relation['score']!r}"
            )
            # Verify the structure of the source and target entities.
            for entity_key in ("source", "target"):
                self._assert_entity_structure(relation[entity_key])

    def test_summarization_output_structure(self, api_client):
        """Verify the structure of every summary span returned by /summarization."""
        payload = {
            "text": "AI is changing the world with machine learning.",
            "threshold": 0.3,
        }
        for summary in self._post_outputs(api_client, "/summarization", payload):
            for key in ("text", "start", "end", "score"):
                assert key in summary, f"missing key {key!r} in summary {summary!r}"
            self._assert_offsets_and_score(summary)