Add prompt support for endpoints.

main
KKlochko 1 month ago
parent 6588ed2767
commit fc286b34ed
Signed by: KKlochko
GPG Key ID: 572ECCD219BBA91B

@ -97,3 +97,8 @@ def handle_error(e: Exception):
}
}
)
def add_prompt_to_text(text: str, prompt: Optional[str]) -> str:
    """Prepend an optional prompt to the input text.

    Args:
        text: The original text to process.
        prompt: Optional prompt string to prepend. A ``None`` or empty
            string means "no prompt" and leaves the text unchanged.

    Returns:
        ``prompt + text`` when a non-empty prompt is given, otherwise
        ``text`` unchanged.
    """
    # Falsy check covers both None and "" for an Optional[str] argument,
    # replacing the less idiomatic `prompt in [None, ""]` membership test.
    if not prompt:
        return text
    return prompt + text

@ -111,7 +111,16 @@ async def health_check():
@app.post("/general")
async def general_extraction(request: GeneralRequest):
"""Named Entity Recognition endpoint"""
"""General Named Entity Recognition endpoint with prompt support
Example usage:
{
"text": "Apple Inc. is located in Cupertino, California.",
"entities": ["organization", "location"],
"prompt": "Extract business entities:\n",
"threshold": 0.5
}
"""
try:
input_data = validate_single_or_batch(request.text, request.texts)
is_batch = isinstance(input_data, list)
@ -121,23 +130,35 @@ async def general_extraction(request: GeneralRequest):
results = []
for text in texts_to_process:
# Calculate prompt offset for position adjustment
prompt_offset = len(request.prompt) if request.prompt else 0
# Apply prompt to text
text_with_prompt = add_prompt_to_text(text, request.prompt)
# Extract entities with prompt
entities = model.predict_entities(
text,
text_with_prompt,
request.entities,
threshold=request.threshold,
flat_ner=True
)
output = [
{
output = []
for ent in entities:
# Skip entities found in the prompt itself
if ent["start"] < prompt_offset:
continue
# Adjust positions to be relative to original text (without prompt)
output.append({
"entity": ent["label"],
"span": ent["text"],
"start": ent["start"],
"end": ent["end"],
"start": ent["start"] - prompt_offset,
"end": ent["end"] - prompt_offset,
"score": float(ent["score"])
}
for ent in entities
]
})
results.append(output)
# Return based on input mode
@ -160,6 +181,7 @@ async def relation_extraction(request: RelationExtractionRequest):
"text": "Microsoft was founded by Bill Gates",
"relations": ["founder", "inception date"],
"entities": ["organization", "person", "date"],
"prompt": "Extract business relationships:\n",
"threshold": 0.5
}
"""
@ -173,19 +195,36 @@ async def relation_extraction(request: RelationExtractionRequest):
for text in texts_to_process:
output = []
# Calculate prompt offset for position adjustment
prompt_offset = len(request.prompt) if request.prompt else 0
# Apply prompt to text
text_with_prompt = add_prompt_to_text(text, request.prompt)
# Extract entities if specified
entities_dict = {}
if request.entities:
entities = model.predict_entities(
text,
text_with_prompt,
request.entities,
threshold=request.threshold,
flat_ner=True
)
# Index entities by position for quick lookup
# Adjust positions to be relative to original text (without prompt)
for entity in entities:
key = (entity["start"], entity["end"])
entities_dict[key] = entity
# Only include entities that are in the actual text, not in the prompt
if entity["start"] >= prompt_offset:
adjusted_start = entity["start"] - prompt_offset
adjusted_end = entity["end"] - prompt_offset
key = (adjusted_start, adjusted_end)
entities_dict[key] = {
"label": entity["label"],
"text": entity["text"],
"start": adjusted_start,
"end": adjusted_end,
"score": entity.get("score", 0.0)
}
# Form relation labels according to GLiNER documentation
# Format: "entity_type <> relation_type"
@ -216,9 +255,9 @@ async def relation_extraction(request: RelationExtractionRequest):
# Remove duplicates
relation_labels = list(set(relation_labels))
# Extract relations using GLiNER
# Extract relations using GLiNER with prompt
relations_output = model.predict_entities(
text,
text_with_prompt,
relation_labels,
threshold=request.threshold,
flat_ner=True
@ -229,6 +268,15 @@ async def relation_extraction(request: RelationExtractionRequest):
# GLiNER returns a text span that represents the relation
# E.g.: "Bill Gates" for label "Microsoft <> founder"
# Skip relations found in the prompt itself
if rel["start"] < prompt_offset:
continue
# Adjust positions to be relative to original text
target_start = rel["start"] - prompt_offset
target_end = rel["end"] - prompt_offset
target_span = rel["text"]
# Parse the label to get entity type and relation
label_parts = rel["label"].split("<>")
@ -239,11 +287,6 @@ async def relation_extraction(request: RelationExtractionRequest):
source_entity_type = ""
relation_type = rel["label"]
# The extracted span is typically the target entity
target_span = rel["text"]
target_start = rel["start"]
target_end = rel["end"]
# Attempt to find the source entity in context
# Find the closest entity of the specified type before the target
source_entity = None
@ -326,7 +369,15 @@ async def relation_extraction(request: RelationExtractionRequest):
@app.post("/summarization")
async def summarization(request: SummarizationRequest):
"""Summarization endpoint"""
"""Summarization endpoint with prompt support
Example usage:
{
"text": "Long article text here...",
"prompt": "Extract the most important points:\n",
"threshold": 0.5
}
"""
try:
input_data = validate_single_or_batch(request.text, request.texts)
is_batch = isinstance(input_data, list)
@ -335,23 +386,34 @@ async def summarization(request: SummarizationRequest):
results = []
for text in texts_to_process:
# Calculate prompt offset for position adjustment
prompt_offset = len(request.prompt) if request.prompt else 0
# Apply prompt to text
text_with_prompt = add_prompt_to_text(text, request.prompt)
# For summarization, extract key phrases/sentences
summaries = model.predict_entities(
text,
text_with_prompt,
["summary", "key point", "important information"],
threshold=request.threshold,
flat_ner=True
)
output = [
{
output = []
for summ in summaries:
# Skip summaries found in the prompt itself
if summ["start"] < prompt_offset:
continue
# Adjust positions to be relative to original text (without prompt)
output.append({
"text": summ["text"],
"start": summ["start"],
"end": summ["end"],
"start": summ["start"] - prompt_offset,
"end": summ["end"] - prompt_offset,
"score": float(summ["score"])
}
for summ in summaries
]
})
results.append(output)
if is_batch:

Loading…
Cancel
Save