You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
87 lines
2.9 KiB
87 lines
2.9 KiB
from pydantic import BaseModel, Field, field_validator
|
|
from typing import List, Optional, Union, Dict, Any, Tuple
|
|
import os
|
|
|
|
|
|
# ==================== Request Models ====================
|
|
class GeneralRequest(BaseModel):
    """Request body carrying input text(s) plus a non-empty list of entity labels.

    Attributes:
        text: A single input text, or ``None`` when not supplied.
        texts: A batch of input texts, or ``None`` when not supplied.
        entities: Entity labels to use; required and must be non-empty.
        prompt: Optional free-form prompt string; ``None`` by default.
        threshold: A float constrained to ``[0.0, 1.0]``; defaults to ``0.5``.
            Because the annotation is ``Optional``, an explicit ``null`` is
            also accepted.

    Both ``text`` and each element of ``texts`` are limited to
    ``MAX_TEXT_LENGTH`` characters (environment variable, default 10000).

    NOTE(review): no validator here requires ``text`` or ``texts`` to be set,
    so a request with neither (or with both) passes validation — confirm the
    endpoint consuming this model handles that case.
    """

    text: Optional[str] = None
    texts: Optional[List[str]] = None
    entities: List[str]
    prompt: Optional[str] = None
    threshold: Optional[float] = Field(default=0.5, ge=0.0, le=1.0)

    @field_validator('entities')
    @classmethod
    def validate_entities(cls, v):
        """Reject an empty ``entities`` list.

        Args:
            v: The already type-validated list of entity labels.

        Returns:
            ``v`` unchanged.

        Raises:
            ValueError: if ``v`` is empty (pydantic surfaces this as a
                validation error on the ``entities`` field).
        """
        # An empty list is falsy, so ``not v`` already covers ``len(v) == 0``.
        if not v:
            raise ValueError("entities list cannot be empty")
        return v

    @field_validator('text', 'texts')
    @classmethod
    def validate_text_length(cls, v):
        """Enforce the ``MAX_TEXT_LENGTH`` character limit on the input text(s).

        Args:
            v: ``None``, a single string (``text``), or a list of strings
                (``texts``). This is an after-mode validator, so list elements
                have already been validated as ``str``.

        Returns:
            ``v`` unchanged.

        Raises:
            ValueError: if the string, or any list element, is longer than the
                configured limit.
        """
        # Read on every validation call, so a changed environment is picked up
        # without a reload.
        # NOTE(review): a non-integer MAX_TEXT_LENGTH makes int() raise
        # ValueError here, which pydantic reports as a client validation error
        # on this field rather than as a server misconfiguration.
        max_length = int(os.getenv("MAX_TEXT_LENGTH", "10000"))
        if isinstance(v, str) and len(v) > max_length:
            raise ValueError(f"text exceeds maximum length of {max_length}")
        if isinstance(v, list):
            for text in v:
                if len(text) > max_length:
                    raise ValueError(f"text exceeds maximum length of {max_length}")
        return v
|
|
|
|
|
|
class RelationRule(BaseModel):
    """A single relation rule: a relation name, string pairs, and an integer distance.

    All three fields are required; none carries additional constraints.
    """

    # Name of the relation this rule applies to.
    relation: str
    # List of 2-tuples of strings; presumably (head, tail) entity-label pairs
    # the relation is restricted to — TODO confirm against the consumer.
    pairs_filter: List[Tuple[str, str]]
    # Plain int with no bounds, so negative values are accepted. Unit and
    # meaning (presumably a maximum separation) are not visible here — confirm.
    distance: int
|
|
|
|
|
|
class RelationExtractionRequest(BaseModel):
    """Request body carrying input text(s) plus a non-empty list of relation names.

    Attributes:
        text: A single input text, or ``None`` when not supplied.
        texts: A batch of input texts, or ``None`` when not supplied.
        relations: Relation names to use; required and must be non-empty.
        entities: Optional list of entity labels; ``None`` by default.
        rules: Optional list of ``RelationRule`` objects; ``None`` by default.
        prompt: Optional free-form prompt string; ``None`` by default.
        threshold: A float constrained to ``[0.0, 1.0]``; defaults to ``0.5``.
            Because the annotation is ``Optional``, an explicit ``null`` is
            also accepted.

    Both ``text`` and each element of ``texts`` are limited to
    ``MAX_TEXT_LENGTH`` characters (environment variable, default 10000).

    NOTE(review): no validator here requires ``text`` or ``texts`` to be set,
    so a request with neither (or with both) passes validation — confirm the
    endpoint consuming this model handles that case.
    """

    text: Optional[str] = None
    texts: Optional[List[str]] = None
    relations: List[str]
    entities: Optional[List[str]] = None
    rules: Optional[List[RelationRule]] = None
    prompt: Optional[str] = None
    threshold: Optional[float] = Field(default=0.5, ge=0.0, le=1.0)

    @field_validator('relations')
    @classmethod
    def validate_relations(cls, v):
        """Reject an empty ``relations`` list.

        Args:
            v: The already type-validated list of relation names.

        Returns:
            ``v`` unchanged.

        Raises:
            ValueError: if ``v`` is empty (pydantic surfaces this as a
                validation error on the ``relations`` field).
        """
        # An empty list is falsy, so ``not v`` already covers ``len(v) == 0``.
        if not v:
            raise ValueError("relations list cannot be empty")
        return v

    @field_validator('text', 'texts')
    @classmethod
    def validate_text_length(cls, v):
        """Enforce the ``MAX_TEXT_LENGTH`` character limit on the input text(s).

        Args:
            v: ``None``, a single string (``text``), or a list of strings
                (``texts``). This is an after-mode validator, so list elements
                have already been validated as ``str``.

        Returns:
            ``v`` unchanged.

        Raises:
            ValueError: if the string, or any list element, is longer than the
                configured limit.
        """
        # Read on every validation call, so a changed environment is picked up
        # without a reload.
        # NOTE(review): a non-integer MAX_TEXT_LENGTH makes int() raise
        # ValueError here, which pydantic reports as a client validation error
        # on this field rather than as a server misconfiguration.
        max_length = int(os.getenv("MAX_TEXT_LENGTH", "10000"))
        if isinstance(v, str) and len(v) > max_length:
            raise ValueError(f"text exceeds maximum length of {max_length}")
        if isinstance(v, list):
            for text in v:
                if len(text) > max_length:
                    raise ValueError(f"text exceeds maximum length of {max_length}")
        return v
|
|
|
|
|
|
class SummarizationRequest(BaseModel):
    """Request body carrying input text(s) with an optional prompt and threshold.

    Attributes:
        text: A single input text, or ``None`` when not supplied.
        texts: A batch of input texts, or ``None`` when not supplied.
        prompt: Optional free-form prompt string; ``None`` by default.
        threshold: A float constrained to ``[0.0, 1.0]``; defaults to ``0.5``.
            Because the annotation is ``Optional``, an explicit ``null`` is
            also accepted.

    NOTE(review): no validator here requires ``text`` or ``texts`` to be set —
    confirm the consuming endpoint handles a request carrying neither.
    """

    text: Optional[str] = None
    texts: Optional[List[str]] = None
    prompt: Optional[str] = None
    threshold: Optional[float] = Field(default=0.5, ge=0.0, le=1.0)

    @field_validator('text', 'texts')
    @classmethod
    def validate_text_length(cls, v):
        """Fail if any supplied text exceeds ``MAX_TEXT_LENGTH`` characters.

        ``v`` is ``None``, a single string (``text``), or a list of strings
        (``texts``). The limit comes from the ``MAX_TEXT_LENGTH`` environment
        variable (default ``"10000"``) and is re-read on every call.

        Returns:
            ``v`` unchanged.

        Raises:
            ValueError: when an over-long text is found.
        """
        limit = int(os.getenv("MAX_TEXT_LENGTH", "10000"))

        # Normalise the three possible shapes to one iterable of strings.
        if isinstance(v, str):
            pieces = (v,)
        elif isinstance(v, list):
            pieces = v
        else:
            pieces = ()

        if any(len(piece) > limit for piece in pieces):
            raise ValueError(f"text exceeds maximum length of {limit}")
        return v
|
|
|