src/backend/entrypoints/llm_backend/api/session.py (128 lines of code) (raw):
from datetime import datetime
from typing import List
from fastapi import APIRouter
from market_alerts.entrypoints.llm_backend.api.models.common import (
FlowStatusResponseModel,
)
from market_alerts.entrypoints.llm_backend.api.models.session import (
DefaultSession,
SessionInfoResponseModel,
SessionResponseModel,
UpdateCodeRequestModel,
)
from market_alerts.entrypoints.llm_backend.containers import (
celery_app,
session_manager_singleton,
)
from market_alerts.entrypoints.llm_backend.domain.exceptions import (
BacktestingNotPerformedError,
IndicatorsNotGeneratedError,
LLMChatNotSubmittedError,
)
from market_alerts.entrypoints.llm_backend.infrastructure.access_management.context_vars import (
user,
)
from market_alerts.entrypoints.llm_backend.infrastructure.session import Steps
from market_alerts.infrastructure.services.code import update_code_sections
# Router shared by every session endpoint in this module; all routes are mounted under the "/session" prefix.
session_router = APIRouter(prefix="/session")
@session_router.post("", tags=["Session"], response_model=SessionResponseModel, status_code=201)
def create_session():
    """Open a new session for the calling user, seeded with the DefaultSession values."""
    owner_email = user.get().email
    new_session = session_manager_singleton.create(owner_email, data=DefaultSession().model_dump())
    return SessionResponseModel(
        sessionId=new_session.session_id,
        expires_in=new_session.expires_in,
    )
@session_router.get("/{session_id}", response_model=SessionInfoResponseModel, tags=["Session"], status_code=200)
def get_data(session_id: str):
    """Return the flow status and stored data of one of the calling user's sessions.

    Several stored values are re-exposed under response field names
    (datasets_keys -> datasets, time_period -> time_range, interval -> periodicity,
    use_dividends_trading -> account_for_dividends, fill_trade_price -> trade_fill_price),
    and the raw session data is forwarded as extra keyword arguments.
    """
    session = session_manager_singleton.get(user.get().email, session_id)
    response_fields = {
        "sessionId": session_id,
        "flow_status": session.flow_status.to_dict(),
        "datasets": session.get("datasets_keys"),
        "time_range": session.get("time_period"),
        "periodicity": session.get("interval"),
        "account_for_dividends": session.get("use_dividends_trading"),
        "trade_fill_price": session.get("fill_trade_price"),
    }
    # NOTE(review): if session.data ever contains one of the keys above, Python raises
    # TypeError ("got multiple values for keyword argument") here.
    return SessionInfoResponseModel(**response_fields, **session.data)
@session_router.put("/{session_id}", tags=["Session"], response_model=SessionResponseModel, status_code=200)
def prolong_session(session_id: str):
    """Extend one of the calling user's sessions and report the expiry returned by the session manager."""
    owner_email = user.get().email
    prolonged = session_manager_singleton.prolong(owner_email, session_id)
    return SessionResponseModel(
        sessionId=prolonged.session_id,
        expires_in=prolonged.expires_in,
    )
@session_router.patch("/{session_id}/clear_dialogue", tags=["Session"], response_model=SessionResponseModel, status_code=200)
def clear_llm_dialogue(session_id: str):
    """Empty the stored indicators dialogue, mark it cleared in the flow status, and persist the session."""
    owner_email = user.get().email
    target = session_manager_singleton.get(owner_email, session_id)

    # Wipe the whole indicators conversation, then record the clearing in the flow status.
    target.data["indicators_dialogue"] = []
    target.flow_status.set_llm_chat_history_cleared()

    session_manager_singleton.save(owner_email, target)
    return SessionResponseModel(
        sessionId=target.session_id,
        expires_in=target.expires_in,
    )
@session_router.put("/{session_id}/update_code", tags=["Session"], response_model=FlowStatusResponseModel, status_code=200)
def update_code(session_id: str, request_model: UpdateCodeRequestModel):
    """Replace the indicators/trading code in the latest indicators-dialogue entry.

    The last entry of ``indicators_dialogue`` is rewritten with ``update_code_sections``,
    the flow status is promoted with the submitted code, and the session is saved.

    Raises:
        LLMChatNotSubmittedError: if the SUBMIT_LLM_CHAT step is not done, or if the
            dialogue is empty, so there is no entry whose code could be replaced.
    """
    current_user = user.get()
    session = session_manager_singleton.get(current_user.email, session_id)
    if not session.flow_status.is_step_done(Steps.SUBMIT_LLM_CHAT):
        raise LLMChatNotSubmittedError

    dialogue = session["indicators_dialogue"]
    # Without this guard an empty dialogue (e.g. after clear_dialogue) makes dialogue[-1]
    # raise IndexError, surfacing as an unhandled 500 instead of a domain error.
    # NOTE(review): whether clearing the dialogue also un-marks SUBMIT_LLM_CHAT is not
    # visible from this module.
    if not dialogue:
        raise LLMChatNotSubmittedError

    dialogue[-1] = update_code_sections(dialogue[-1], request_model.indicators_code, request_model.trading_code)
    session.flow_status.promote_submit_llm_chat_step(request_model.indicators_code, request_model.trading_code)
    session_manager_singleton.save(current_user.email, session)
    return FlowStatusResponseModel(flow_status=session.flow_status.to_dict())
@session_router.get("/{session_id}/indicators_logs", response_model=List[str], tags=["Session"], status_code=200)
def get_indicators_logs(session_id: str):
    """Return the stored indicators-code log; requires the CALCULATE_INDICATORS step to be done.

    Raises:
        IndicatorsNotGeneratedError: if indicators have not been calculated yet.
    """
    session = session_manager_singleton.get(user.get().email, session_id)
    indicators_ready = session.flow_status.is_step_done(Steps.CALCULATE_INDICATORS)
    if indicators_ready:
        return session["indicators_code_log"]
    raise IndicatorsNotGeneratedError
@session_router.get("/{session_id}/trading_logs", response_model=List[str], tags=["Session"], status_code=200)
def get_trading_logs(session_id: str):
    """Return the stored trading-code log; requires the PERFORM_BACKTESTING step to be done.

    Raises:
        BacktestingNotPerformedError: if backtesting has not been performed yet.
    """
    session = session_manager_singleton.get(user.get().email, session_id)
    backtest_done = session.flow_status.is_step_done(Steps.PERFORM_BACKTESTING)
    if backtest_done:
        return session["trading_code_log"]
    raise BacktestingNotPerformedError
@session_router.put("/{session_id}/reset", response_model=SessionInfoResponseModel, tags=["Session"], status_code=200)
def reset_session(session_id: str):
    """Restore a session to its initial state and return the resulting session info.

    The flow status is reset and the data is replaced by a fresh DefaultSession dump;
    the session is saved before the response is built.
    """
    owner_email = user.get().email
    session = session_manager_singleton.get(owner_email, session_id)

    session.reset_flow_status()
    session.data = DefaultSession().model_dump()
    session_manager_singleton.save(owner_email, session)

    response_fields = {
        "sessionId": session_id,
        "flow_status": session.flow_status.to_dict(),
        "datasets": session.get("datasets_keys"),
        "time_range": session.get("time_period"),
        "periodicity": session.get("interval"),
        "account_for_dividends": session.get("use_dividends_trading"),
        "trade_fill_price": session.get("fill_trade_price"),
    }
    # NOTE(review): a key shared between session.data and the fields above makes
    # Python raise TypeError ("got multiple values for keyword argument").
    return SessionInfoResponseModel(**response_fields, **session.data)
@session_router.delete("/{session_id}", tags=["Session"], status_code=204)
def delete_session(session_id: str):
    """Delete one of the calling user's sessions; the handler returns nothing (204)."""
    session_manager_singleton.delete(user.get().email, session_id)
def _format_task_call(name, args, kwargs) -> str:
    """Render a task invocation as ``name(a1, a2, key="text", key2=value)``.

    ``args``/``kwargs`` may be ``None``: Celery reports ``None`` when the result backend
    holds no stored arguments for a task (e.g. still pending, or ``result_extended``
    disabled). They are treated as empty rather than crashing the caller.
    """
    args_str = ", ".join(str(arg) for arg in args or ())
    # String values are quoted so they are distinguishable from other values.
    kwargs_str = ", ".join(
        f'{key}="{value}"' if isinstance(value, str) else f"{key}={value}"
        for key, value in (kwargs or {}).items()
    )
    all_args_str = ", ".join(filter(None, [args_str, kwargs_str]))
    return f"{name}({all_args_str})"


@session_router.get("/action_history/{session_id}", tags=["Session"], status_code=200)
def get_action_history(session_id: str):
    """Describe every Celery task recorded in the session's actions history.

    Each entry contains:
      - start_date: the recorded timestamp, formatted with ``datetime.fromtimestamp``
        (naive local time);
      - end_date__: the task's ``date_done``;
      - action: the rendered call, e.g. ``task_name(1, key="v")``;
      - result: the task result, stringified if it is an exception;
      - traceback: only present when the task reports one.
    """
    current_user = user.get()
    session = session_manager_singleton.get(current_user.email, session_id)
    actions = []
    for task_id, timestamp in session.actions_history:
        task = celery_app.AsyncResult(task_id)
        # Read result/traceback once: for non-ready tasks each property access may query the backend.
        result = task.result
        task_traceback = task.traceback
        action = dict(
            start_date=datetime.fromtimestamp(timestamp).strftime("%Y-%m-%dT%H:%M:%S.%f"),
            # NOTE(review): the trailing "__" in this key looks unintentional, but it is
            # part of the response contract, so it is kept as-is.
            end_date__=task.date_done,
            action=_format_task_call(task.name, task.args, task.kwargs),
            result=str(result) if isinstance(result, Exception) else result,
        )
        if task_traceback:
            action["traceback"] = task_traceback
        actions.append(action)
    return actions