diff --git a/gliner_inference_server/main.py b/gliner_inference_server/main.py
index d14d5dd..f22a45e 100644
--- a/gliner_inference_server/main.py
+++ b/gliner_inference_server/main.py
@@ -1,17 +1,33 @@
 from fastapi import FastAPI, HTTPException, status
-from pydantic import BaseModel, Field, field_validator
 from typing import List, Optional, Union, Dict, Any, Tuple
 from contextlib import asynccontextmanager
 import os
+import torch
+from gliner import GLiNER
+from .models import *
 
 # Global model instance
 model = None
 
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    """Load model on startup, cleanup on shutdown"""
+    global model
+    print("Loading GLiNER model...")
+    model_name = os.getenv("MODEL_NAME", "knowledgator/gliner-multitask-large-v0.5")
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model = GLiNER.from_pretrained(model_name).to(device)
+    print(f"Model loaded on {device}")
+    yield
+    print("Shutting down...")
+
+
 app = FastAPI(
     title="GLiNER Inference Server",
     description="Named Entity Recognition, Relation Extraction, and Summarization API",
     version="1.0.0",
+    lifespan=lifespan
 )
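For context, a minimal standalone sketch of what the model loaded in `lifespan` does. The sample sentence, labels, and threshold below are illustrative assumptions, not part of this diff; `predict_entities` is the GLiNER library call that the server's endpoints would presumably wrap.

```python
# Illustrative only: exercises the same model the server loads in `lifespan`.
from gliner import GLiNER

model = GLiNER.from_pretrained("knowledgator/gliner-multitask-large-v0.5")

# predict_entities(text, labels, threshold) returns a list of dicts with the
# matched span text, its label, character offsets, and a confidence score.
entities = model.predict_entities(
    "Ada Lovelace worked with Charles Babbage in London.",  # sample input
    ["person", "location"],                                  # sample labels
    threshold=0.5,
)
for ent in entities:
    print(ent["start"], ent["end"], ent["text"], ent["label"], ent["score"])
```

Running the model once outside the server like this is also a quick way to confirm that the `MODEL_NAME` weights download and fit on the target device before wiring them into FastAPI.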