dial-docker-compose/ollama/ollama_setup/app.py (83 lines of code) (raw):

import asyncio from contextlib import asynccontextmanager import os import asyncio from fastapi import FastAPI from ollama import AsyncClient from tqdm import tqdm from tenacity import retry, stop_after_attempt from utils import Writer, print_info, timer OLLAMA_URL = os.getenv("OLLAMA_URL") if OLLAMA_URL is None: raise RuntimeError("OLLAMA_URL env var isn't set") OLLAMA_CHAT_MODEL = os.getenv("OLLAMA_CHAT_MODEL") OLLAMA_VISION_MODEL = os.getenv("OLLAMA_VISION_MODEL") OLLAMA_EMBEDDING_MODEL = os.getenv("OLLAMA_EMBEDDING_MODEL") async def wait_for_startup(): attempts = 0 while True: attempts += 1 try: await AsyncClient(host=OLLAMA_URL, timeout=5).ps() except Exception: print_info(f"[{attempts:>3}] Waiting for Ollama to start...") await asyncio.sleep(5) else: break @retry(stop=stop_after_attempt(5)) async def pull_model(client: AsyncClient, model: str): response = await client.pull(model, stream=True) progress_bar = None prev_status = None async for chunk in response: status = chunk["status"] total = chunk.get("total") completed = chunk.get("completed") if status != prev_status and total: prev_status = status progress_bar = tqdm( total=total, unit="B", unit_scale=True, desc=f"[{status}]", mininterval=1, file=Writer, ) if completed and total and progress_bar: progress_bar.n = completed progress_bar.update(n=0) if total and total == completed and progress_bar: progress_bar.close() progress_bar = None if not completed and not total: print_info(f"[{status}]") async def startup(): print_info(f"OLLAMA_URL = {OLLAMA_URL}") print_info(f"OLLAMA_CHAT_MODEL = {OLLAMA_CHAT_MODEL}") print_info(f"OLLAMA_VISION_MODEL = {OLLAMA_VISION_MODEL}") print_info(f"OLLAMA_EMBEDDING_MODEL = {OLLAMA_EMBEDDING_MODEL}") client = AsyncClient(host=OLLAMA_URL, timeout=300) async with timer("Waiting for Ollama to start"): await wait_for_startup() for model, alias in [ (OLLAMA_CHAT_MODEL, "chat-model"), (OLLAMA_VISION_MODEL, "vision-model"), (OLLAMA_EMBEDDING_MODEL, "embedding-model"), ]: if model: async with timer(f"Pulling model {model}"): await pull_model(client, model) async with timer(f"Creating alias for {model}: {alias}"): await client.copy(model, alias) if model_to_load := (OLLAMA_CHAT_MODEL or OLLAMA_VISION_MODEL): async with timer(f"Loading model {model_to_load} into memory"): await client.generate(model_to_load) print_info("The Ollama server is up and running.") @asynccontextmanager async def lifespan(app): await startup() yield app = FastAPI(lifespan=lifespan) @app.get("/health") def health_check(): return {"status": "ok"}