src/backend/entrypoints/llm_backend/api/ws.py (49 lines of code) (raw):

from fastapi import APIRouter, Body, Query, WebSocket, status from market_alerts.entrypoints.llm_backend.api.models.tasks import TaskIDRequestModel from market_alerts.entrypoints.llm_backend.api.models.ws import ( WSAuthTicketResponseModel, ) from market_alerts.entrypoints.llm_backend.infrastructure.services.ws import ( BaseTaskProgressHandler, OptimizationTaskProgressHandler, OptimizationTaskWSMetaInfo, PipelineWorkUnitBasedWSTaskMetaInfo, WorkUnitBasedProgressHandler, WorkUnitBasedTaskWSMetaInfo, consume_progress_updates, generate_ws_auth_ticket, get_task_ws_meta_info_async, ) ws_router = APIRouter(prefix="/websockets") @ws_router.post("/auth", tags=["Websockets"], response_model=WSAuthTicketResponseModel, status_code=201) async def create_ws_auth_ticket(request_model: TaskIDRequestModel = Body(...)): ticket_id, creation_timestamp = await generate_ws_auth_ticket( task_id=request_model.task_id, ttl=60, ) return WSAuthTicketResponseModel( ticket_id=ticket_id, task_id=request_model.task_id, creation_timestamp=creation_timestamp, ) @ws_router.websocket("/progress") async def websocket_task_progress_updates(websocket: WebSocket, task_id: str = Query(...)): await websocket.accept() task_ids = None task_info = await get_task_ws_meta_info_async(task_id) if isinstance(task_info, OptimizationTaskWSMetaInfo): task_handler = OptimizationTaskProgressHandler( task_info.studies_names, task_info.work_units, task_info.start_time, websocket ) task_ids = task_info.task_ids elif isinstance(task_info, PipelineWorkUnitBasedWSTaskMetaInfo): task_handler = WorkUnitBasedProgressHandler(task_info.work_units, task_info.start_time, websocket) task_ids = task_info.task_ids elif isinstance(task_info, WorkUnitBasedTaskWSMetaInfo): task_handler = WorkUnitBasedProgressHandler(task_info.work_units, task_info.start_time, websocket) else: task_handler = BaseTaskProgressHandler(task_info.start_time, websocket) try: await consume_progress_updates(websocket, task_id, task_handler, task_ids) except Exception: await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason="Internal error")