Add prompt support for endpoints.

main
KKlochko 1 month ago
parent 6588ed2767
commit fc286b34ed
Signed by: KKlochko
GPG Key ID: 572ECCD219BBA91B

@ -97,3 +97,8 @@ def handle_error(e: Exception):
} }
} }
) )
def add_prompt_to_text(text: str, prompt: Optional[str]) -> str:
    """Prepend an optional prompt to the input text.

    Args:
        text: The original input text.
        prompt: An optional instruction prefix. If ``None`` or an empty
            string, the text is returned unchanged.

    Returns:
        ``prompt + text`` when a non-empty prompt is given, otherwise
        ``text`` unmodified.
    """
    # Pythonic truthiness check covers both None and "" in one test.
    if not prompt:
        return text
    return prompt + text

@ -111,7 +111,16 @@ async def health_check():
@app.post("/general") @app.post("/general")
async def general_extraction(request: GeneralRequest): async def general_extraction(request: GeneralRequest):
"""Named Entity Recognition endpoint""" """General Named Entity Recognition endpoint with prompt support
Example usage:
{
"text": "Apple Inc. is located in Cupertino, California.",
"entities": ["organization", "location"],
"prompt": "Extract business entities:\n",
"threshold": 0.5
}
"""
try: try:
input_data = validate_single_or_batch(request.text, request.texts) input_data = validate_single_or_batch(request.text, request.texts)
is_batch = isinstance(input_data, list) is_batch = isinstance(input_data, list)
@ -121,23 +130,35 @@ async def general_extraction(request: GeneralRequest):
results = [] results = []
for text in texts_to_process: for text in texts_to_process:
# Calculate prompt offset for position adjustment
prompt_offset = len(request.prompt) if request.prompt else 0
# Apply prompt to text
text_with_prompt = add_prompt_to_text(text, request.prompt)
# Extract entities with prompt
entities = model.predict_entities( entities = model.predict_entities(
text, text_with_prompt,
request.entities, request.entities,
threshold=request.threshold, threshold=request.threshold,
flat_ner=True flat_ner=True
) )
output = [ output = []
{ for ent in entities:
# Skip entities found in the prompt itself
if ent["start"] < prompt_offset:
continue
# Adjust positions to be relative to original text (without prompt)
output.append({
"entity": ent["label"], "entity": ent["label"],
"span": ent["text"], "span": ent["text"],
"start": ent["start"], "start": ent["start"] - prompt_offset,
"end": ent["end"], "end": ent["end"] - prompt_offset,
"score": float(ent["score"]) "score": float(ent["score"])
} })
for ent in entities
]
results.append(output) results.append(output)
# Return based on input mode # Return based on input mode
@ -160,6 +181,7 @@ async def relation_extraction(request: RelationExtractionRequest):
"text": "Microsoft was founded by Bill Gates", "text": "Microsoft was founded by Bill Gates",
"relations": ["founder", "inception date"], "relations": ["founder", "inception date"],
"entities": ["organization", "person", "date"], "entities": ["organization", "person", "date"],
"prompt": "Extract business relationships:\n",
"threshold": 0.5 "threshold": 0.5
} }
""" """
@ -173,19 +195,36 @@ async def relation_extraction(request: RelationExtractionRequest):
for text in texts_to_process: for text in texts_to_process:
output = [] output = []
# Calculate prompt offset for position adjustment
prompt_offset = len(request.prompt) if request.prompt else 0
# Apply prompt to text
text_with_prompt = add_prompt_to_text(text, request.prompt)
# Extract entities if specified # Extract entities if specified
entities_dict = {} entities_dict = {}
if request.entities: if request.entities:
entities = model.predict_entities( entities = model.predict_entities(
text, text_with_prompt,
request.entities, request.entities,
threshold=request.threshold, threshold=request.threshold,
flat_ner=True flat_ner=True
) )
# Index entities by position for quick lookup # Index entities by position for quick lookup
# Adjust positions to be relative to original text (without prompt)
for entity in entities: for entity in entities:
key = (entity["start"], entity["end"]) # Only include entities that are in the actual text, not in the prompt
entities_dict[key] = entity if entity["start"] >= prompt_offset:
adjusted_start = entity["start"] - prompt_offset
adjusted_end = entity["end"] - prompt_offset
key = (adjusted_start, adjusted_end)
entities_dict[key] = {
"label": entity["label"],
"text": entity["text"],
"start": adjusted_start,
"end": adjusted_end,
"score": entity.get("score", 0.0)
}
# Form relation labels according to GLiNER documentation # Form relation labels according to GLiNER documentation
# Format: "entity_type <> relation_type" # Format: "entity_type <> relation_type"
@ -216,9 +255,9 @@ async def relation_extraction(request: RelationExtractionRequest):
# Remove duplicates # Remove duplicates
relation_labels = list(set(relation_labels)) relation_labels = list(set(relation_labels))
# Extract relations using GLiNER # Extract relations using GLiNER with prompt
relations_output = model.predict_entities( relations_output = model.predict_entities(
text, text_with_prompt,
relation_labels, relation_labels,
threshold=request.threshold, threshold=request.threshold,
flat_ner=True flat_ner=True
@ -229,6 +268,15 @@ async def relation_extraction(request: RelationExtractionRequest):
# GLiNER returns a text span that represents the relation # GLiNER returns a text span that represents the relation
# E.g.: "Bill Gates" for label "Microsoft <> founder" # E.g.: "Bill Gates" for label "Microsoft <> founder"
# Skip relations found in the prompt itself
if rel["start"] < prompt_offset:
continue
# Adjust positions to be relative to original text
target_start = rel["start"] - prompt_offset
target_end = rel["end"] - prompt_offset
target_span = rel["text"]
# Parse the label to get entity type and relation # Parse the label to get entity type and relation
label_parts = rel["label"].split("<>") label_parts = rel["label"].split("<>")
@ -239,11 +287,6 @@ async def relation_extraction(request: RelationExtractionRequest):
source_entity_type = "" source_entity_type = ""
relation_type = rel["label"] relation_type = rel["label"]
# The extracted span is typically the target entity
target_span = rel["text"]
target_start = rel["start"]
target_end = rel["end"]
# Attempt to find the source entity in context # Attempt to find the source entity in context
# Find the closest entity of the specified type before the target # Find the closest entity of the specified type before the target
source_entity = None source_entity = None
@ -326,7 +369,15 @@ async def relation_extraction(request: RelationExtractionRequest):
@app.post("/summarization") @app.post("/summarization")
async def summarization(request: SummarizationRequest): async def summarization(request: SummarizationRequest):
"""Summarization endpoint""" """Summarization endpoint with prompt support
Example usage:
{
"text": "Long article text here...",
"prompt": "Extract the most important points:\n",
"threshold": 0.5
}
"""
try: try:
input_data = validate_single_or_batch(request.text, request.texts) input_data = validate_single_or_batch(request.text, request.texts)
is_batch = isinstance(input_data, list) is_batch = isinstance(input_data, list)
@ -335,23 +386,34 @@ async def summarization(request: SummarizationRequest):
results = [] results = []
for text in texts_to_process: for text in texts_to_process:
# Calculate prompt offset for position adjustment
prompt_offset = len(request.prompt) if request.prompt else 0
# Apply prompt to text
text_with_prompt = add_prompt_to_text(text, request.prompt)
# For summarization, extract key phrases/sentences # For summarization, extract key phrases/sentences
summaries = model.predict_entities( summaries = model.predict_entities(
text, text_with_prompt,
["summary", "key point", "important information"], ["summary", "key point", "important information"],
threshold=request.threshold, threshold=request.threshold,
flat_ner=True flat_ner=True
) )
output = [ output = []
{ for summ in summaries:
# Skip summaries found in the prompt itself
if summ["start"] < prompt_offset:
continue
# Adjust positions to be relative to original text (without prompt)
output.append({
"text": summ["text"], "text": summ["text"],
"start": summ["start"], "start": summ["start"] - prompt_offset,
"end": summ["end"], "end": summ["end"] - prompt_offset,
"score": float(summ["score"]) "score": float(summ["score"])
} })
for summ in summaries
]
results.append(output) results.append(output)
if is_batch: if is_batch:

Loading…
Cancel
Save