|
|
|
|
@ -111,7 +111,16 @@ async def health_check():
|
|
|
|
|
|
|
|
|
|
@app.post("/general")
|
|
|
|
|
async def general_extraction(request: GeneralRequest):
|
|
|
|
|
"""Named Entity Recognition endpoint"""
|
|
|
|
|
"""General Named Entity Recognition endpoint with prompt support
|
|
|
|
|
|
|
|
|
|
Example usage:
|
|
|
|
|
{
|
|
|
|
|
"text": "Apple Inc. is located in Cupertino, California.",
|
|
|
|
|
"entities": ["organization", "location"],
|
|
|
|
|
"prompt": "Extract business entities:\n",
|
|
|
|
|
"threshold": 0.5
|
|
|
|
|
}
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
input_data = validate_single_or_batch(request.text, request.texts)
|
|
|
|
|
is_batch = isinstance(input_data, list)
|
|
|
|
|
@ -121,23 +130,35 @@ async def general_extraction(request: GeneralRequest):
|
|
|
|
|
|
|
|
|
|
results = []
|
|
|
|
|
for text in texts_to_process:
|
|
|
|
|
# Calculate prompt offset for position adjustment
|
|
|
|
|
prompt_offset = len(request.prompt) if request.prompt else 0
|
|
|
|
|
|
|
|
|
|
# Apply prompt to text
|
|
|
|
|
text_with_prompt = add_prompt_to_text(text, request.prompt)
|
|
|
|
|
|
|
|
|
|
# Extract entities with prompt
|
|
|
|
|
entities = model.predict_entities(
|
|
|
|
|
text,
|
|
|
|
|
text_with_prompt,
|
|
|
|
|
request.entities,
|
|
|
|
|
threshold=request.threshold,
|
|
|
|
|
flat_ner=True
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
output = [
|
|
|
|
|
{
|
|
|
|
|
output = []
|
|
|
|
|
for ent in entities:
|
|
|
|
|
# Skip entities found in the prompt itself
|
|
|
|
|
if ent["start"] < prompt_offset:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
# Adjust positions to be relative to original text (without prompt)
|
|
|
|
|
output.append({
|
|
|
|
|
"entity": ent["label"],
|
|
|
|
|
"span": ent["text"],
|
|
|
|
|
"start": ent["start"],
|
|
|
|
|
"end": ent["end"],
|
|
|
|
|
"start": ent["start"] - prompt_offset,
|
|
|
|
|
"end": ent["end"] - prompt_offset,
|
|
|
|
|
"score": float(ent["score"])
|
|
|
|
|
}
|
|
|
|
|
for ent in entities
|
|
|
|
|
]
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
results.append(output)
|
|
|
|
|
|
|
|
|
|
# Return based on input mode
|
|
|
|
|
@ -160,6 +181,7 @@ async def relation_extraction(request: RelationExtractionRequest):
|
|
|
|
|
"text": "Microsoft was founded by Bill Gates",
|
|
|
|
|
"relations": ["founder", "inception date"],
|
|
|
|
|
"entities": ["organization", "person", "date"],
|
|
|
|
|
"prompt": "Extract business relationships:\n",
|
|
|
|
|
"threshold": 0.5
|
|
|
|
|
}
|
|
|
|
|
"""
|
|
|
|
|
@ -173,19 +195,36 @@ async def relation_extraction(request: RelationExtractionRequest):
|
|
|
|
|
for text in texts_to_process:
|
|
|
|
|
output = []
|
|
|
|
|
|
|
|
|
|
# Calculate prompt offset for position adjustment
|
|
|
|
|
prompt_offset = len(request.prompt) if request.prompt else 0
|
|
|
|
|
|
|
|
|
|
# Apply prompt to text
|
|
|
|
|
text_with_prompt = add_prompt_to_text(text, request.prompt)
|
|
|
|
|
|
|
|
|
|
# Extract entities if specified
|
|
|
|
|
entities_dict = {}
|
|
|
|
|
if request.entities:
|
|
|
|
|
entities = model.predict_entities(
|
|
|
|
|
text,
|
|
|
|
|
text_with_prompt,
|
|
|
|
|
request.entities,
|
|
|
|
|
threshold=request.threshold,
|
|
|
|
|
flat_ner=True
|
|
|
|
|
)
|
|
|
|
|
# Index entities by position for quick lookup
|
|
|
|
|
# Adjust positions to be relative to original text (without prompt)
|
|
|
|
|
for entity in entities:
|
|
|
|
|
key = (entity["start"], entity["end"])
|
|
|
|
|
entities_dict[key] = entity
|
|
|
|
|
# Only include entities that are in the actual text, not in the prompt
|
|
|
|
|
if entity["start"] >= prompt_offset:
|
|
|
|
|
adjusted_start = entity["start"] - prompt_offset
|
|
|
|
|
adjusted_end = entity["end"] - prompt_offset
|
|
|
|
|
key = (adjusted_start, adjusted_end)
|
|
|
|
|
entities_dict[key] = {
|
|
|
|
|
"label": entity["label"],
|
|
|
|
|
"text": entity["text"],
|
|
|
|
|
"start": adjusted_start,
|
|
|
|
|
"end": adjusted_end,
|
|
|
|
|
"score": entity.get("score", 0.0)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
# Form relation labels according to GLiNER documentation
|
|
|
|
|
# Format: "entity_type <> relation_type"
|
|
|
|
|
@ -216,9 +255,9 @@ async def relation_extraction(request: RelationExtractionRequest):
|
|
|
|
|
# Remove duplicates
|
|
|
|
|
relation_labels = list(set(relation_labels))
|
|
|
|
|
|
|
|
|
|
# Extract relations using GLiNER
|
|
|
|
|
# Extract relations using GLiNER with prompt
|
|
|
|
|
relations_output = model.predict_entities(
|
|
|
|
|
text,
|
|
|
|
|
text_with_prompt,
|
|
|
|
|
relation_labels,
|
|
|
|
|
threshold=request.threshold,
|
|
|
|
|
flat_ner=True
|
|
|
|
|
@ -229,6 +268,15 @@ async def relation_extraction(request: RelationExtractionRequest):
|
|
|
|
|
# GLiNER returns a text span that represents the relation
|
|
|
|
|
# E.g.: "Bill Gates" for label "Microsoft <> founder"
|
|
|
|
|
|
|
|
|
|
# Skip relations found in the prompt itself
|
|
|
|
|
if rel["start"] < prompt_offset:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
# Adjust positions to be relative to original text
|
|
|
|
|
target_start = rel["start"] - prompt_offset
|
|
|
|
|
target_end = rel["end"] - prompt_offset
|
|
|
|
|
target_span = rel["text"]
|
|
|
|
|
|
|
|
|
|
# Parse the label to get entity type and relation
|
|
|
|
|
label_parts = rel["label"].split("<>")
|
|
|
|
|
|
|
|
|
|
@ -239,11 +287,6 @@ async def relation_extraction(request: RelationExtractionRequest):
|
|
|
|
|
source_entity_type = ""
|
|
|
|
|
relation_type = rel["label"]
|
|
|
|
|
|
|
|
|
|
# The extracted span is typically the target entity
|
|
|
|
|
target_span = rel["text"]
|
|
|
|
|
target_start = rel["start"]
|
|
|
|
|
target_end = rel["end"]
|
|
|
|
|
|
|
|
|
|
# Attempt to find the source entity in context
|
|
|
|
|
# Find the closest entity of the specified type before the target
|
|
|
|
|
source_entity = None
|
|
|
|
|
@ -326,7 +369,15 @@ async def relation_extraction(request: RelationExtractionRequest):
|
|
|
|
|
|
|
|
|
|
@app.post("/summarization")
|
|
|
|
|
async def summarization(request: SummarizationRequest):
|
|
|
|
|
"""Summarization endpoint"""
|
|
|
|
|
"""Summarization endpoint with prompt support
|
|
|
|
|
|
|
|
|
|
Example usage:
|
|
|
|
|
{
|
|
|
|
|
"text": "Long article text here...",
|
|
|
|
|
"prompt": "Extract the most important points:\n",
|
|
|
|
|
"threshold": 0.5
|
|
|
|
|
}
|
|
|
|
|
"""
|
|
|
|
|
try:
|
|
|
|
|
input_data = validate_single_or_batch(request.text, request.texts)
|
|
|
|
|
is_batch = isinstance(input_data, list)
|
|
|
|
|
@ -335,23 +386,34 @@ async def summarization(request: SummarizationRequest):
|
|
|
|
|
|
|
|
|
|
results = []
|
|
|
|
|
for text in texts_to_process:
|
|
|
|
|
# Calculate prompt offset for position adjustment
|
|
|
|
|
prompt_offset = len(request.prompt) if request.prompt else 0
|
|
|
|
|
|
|
|
|
|
# Apply prompt to text
|
|
|
|
|
text_with_prompt = add_prompt_to_text(text, request.prompt)
|
|
|
|
|
|
|
|
|
|
# For summarization, extract key phrases/sentences
|
|
|
|
|
summaries = model.predict_entities(
|
|
|
|
|
text,
|
|
|
|
|
text_with_prompt,
|
|
|
|
|
["summary", "key point", "important information"],
|
|
|
|
|
threshold=request.threshold,
|
|
|
|
|
flat_ner=True
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
output = [
|
|
|
|
|
{
|
|
|
|
|
output = []
|
|
|
|
|
for summ in summaries:
|
|
|
|
|
# Skip summaries found in the prompt itself
|
|
|
|
|
if summ["start"] < prompt_offset:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
# Adjust positions to be relative to original text (without prompt)
|
|
|
|
|
output.append({
|
|
|
|
|
"text": summ["text"],
|
|
|
|
|
"start": summ["start"],
|
|
|
|
|
"end": summ["end"],
|
|
|
|
|
"start": summ["start"] - prompt_offset,
|
|
|
|
|
"end": summ["end"] - prompt_offset,
|
|
|
|
|
"score": float(summ["score"])
|
|
|
|
|
}
|
|
|
|
|
for summ in summaries
|
|
|
|
|
]
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
results.append(output)
|
|
|
|
|
|
|
|
|
|
if is_batch:
|
|
|
|
|
|