src/backend/entrypoints/llm_backend/tasks.py (753 lines of code) (raw):

import json import time from collections import ChainMap from typing import Any, Dict, List, Optional import dill as pickle from celery import Task, current_app from celery.exceptions import SoftTimeLimitExceeded from celery.utils.log import get_task_logger from optuna import Study from optuna.study import StudyDirection from optuna.trial import FrozenTrial from redis import RedisError from market_alerts.containers import ( alerts_backend_proxy_singleton, optimization_samplers, optimization_target_funcs, ) from market_alerts.domain.constants import ADDITIONAL_COLUMNS, PUBSUB_END_OF_DATA from market_alerts.domain.exceptions import DataNotFoundError, LLMBadResponseError from market_alerts.domain.services import ( define_empty_indicators_step, define_useful_strings, get_actual_currency_fx_rates, get_combined_trading_statistics, get_sparse_dividends_for_each_tradable_symbol, indicator_chat, indicator_step, optimize, symbol_step, trading_step, ) from market_alerts.domain.services.charts import get_lines_per_symbol_mapping from market_alerts.domain.services.steps import ( delete_optimization_study, get_optimization_results, ) from market_alerts.entrypoints.llm_backend import middleware from market_alerts.entrypoints.llm_backend.api.models.common import ( FlowStatusResponseModel, ) from market_alerts.entrypoints.llm_backend.api.models.llm import LLMResponseModel from market_alerts.entrypoints.llm_backend.api.models.optimization import ( OptimizationResponseModel, OptimizationResult, ) from market_alerts.entrypoints.llm_backend.api.models.tickers import ( TickersFetchInfoResponseModel, ) from market_alerts.entrypoints.llm_backend.api.tasks import get_task_error_msg from market_alerts.entrypoints.llm_backend.containers import ( get_optimization_storage, session_manager_singleton, settings, sync_redis_client, ) from market_alerts.entrypoints.llm_backend.infrastructure.access_management.context_vars import ( user, ) from market_alerts.entrypoints.llm_backend.infrastructure.services.ws import ( OptimizationTaskWSMetaInfo, get_task_ws_meta_info, ) from market_alerts.entrypoints.llm_backend.infrastructure.session import Session, Steps from market_alerts.infrastructure.services.code import get_code_sections from market_alerts.infrastructure.services.proxy.alerts_backend.exceptions import ( LimitsDisabled, ) from market_alerts.utils import convert_date, progress_and_time_generator, time_profile logger = get_task_logger(__name__) @current_app.task(name="fetch_tickers", bind=True) def fetch_tickers( self, data_provider: str, datasets: List[str], periodicity: int, tradable_symbols_prompt: str, supplementary_symbols_prompt: str, economic_indicators: List[str], dividend_fields: List[str], time_range: int, is_chained: bool = False, ) -> Dict[str, Any]: current_user = user.get() session = session_manager_singleton.get(current_user.email, self.session_id) if is_chained: session.flow_status.pipeline_start_step(Steps.FETCH_TICKERS) session_manager_singleton.save(current_user.email, session) if ( len(session["datasets_keys"]) != len(datasets) or data_provider != session["data_provider"] or len(dividend_fields) != len(session["dividend_fields"]) or any([x != y for x, y in zip(session["datasets_keys"], datasets)]) ): session["data_by_symbol"] = {} session["meta"] = {} session["data_provider"] = data_provider session["datasets_keys"] = datasets session["interval"] = periodicity session["tradable_symbols_prompt"] = tradable_symbols_prompt session["supplementary_symbols_prompt"] = supplementary_symbols_prompt session["economic_indicators"] = economic_indicators session["dividend_fields"] = dividend_fields session["time_period"] = time_range task_id = self.request.id try: progress_callback = lambda _: _safe_progress_publish( pubsub=sync_redis_client, channel=f"task-{task_id}-progress", message=json.dumps({}), error_flag=False, ) (request_timestamp, fetched_symbols_meta, synth_formulas, error_message), execution_time = time_profile( symbol_step, session, progress_callback ) payload = {"progress": PUBSUB_END_OF_DATA} _safe_progress_publish( pubsub=sync_redis_client, channel=f"task-{task_id}-progress", message=json.dumps(payload), error_flag=False, ) define_useful_strings(session) define_empty_indicators_step(session) plots_meta = _build_prices_plots_meta(session) session.flow_status.promote_fetch_step( TickersFetchInfoResponseModel( data_provider=data_provider, datasets=datasets, periodicity=periodicity, time_range=time_range, fetched_symbols_meta=fetched_symbols_meta, plots_meta=plots_meta, synth_formulas=synth_formulas, request_timestamp=request_timestamp, execution_time=execution_time, ).model_dump(), error_message, ) if is_chained: session.flow_status.pipeline_finish_step(Steps.FETCH_TICKERS) session_manager_singleton.save(current_user.email, session) return FlowStatusResponseModel( flow_status=session.flow_status.to_dict(), ).model_dump() # TODO: handle only specific exceptions in prod except (SoftTimeLimitExceeded, DataNotFoundError, Exception) as e: _handle_error( session_id=session.session_id, email=current_user.email, step=Steps.FETCH_TICKERS, exc_to_raise=e, ) def _build_prices_plots_meta(session) -> dict[str, dict[str, dict[str, list[list[dict[str, str]]]]]]: symbols_meta = {} plots_meta = { "symbols": symbols_meta, "start_date": convert_date(session["start_date"]), "end_date": convert_date(session["end_date"]), } lines_per_symbol = get_lines_per_symbol_mapping( { **session["data_by_symbol"], **session.get("data_by_synth", {}), } ) for symbol, lines in lines_per_symbol.items(): symbols_meta[symbol] = {"charts": []} for line in lines: symbols_meta[symbol]["charts"].append([{"name": line, "type": _determine_prices_plot_line_type(line)}]) symbols_meta[symbol]["type"] = _determine_plot_type(symbol, session) return plots_meta def _determine_prices_plot_line_type(line: str) -> str: if line in ADDITIONAL_COLUMNS: return "dividend" return "price" def _determine_plot_type(symbol: str, session) -> str: if symbol in session["tradable_symbols"] or symbol in session["synth_formulas_to_trade"]: return "tradable" if symbol in session["supplementary_symbols"] or symbol in session["synth_formulas_not_to_trade"]: return "supplementary" elif symbol in session["economic_indicator_symbols"]: return "economic_indicator" raise RuntimeError(f"unexpected symbol: {symbol}") @current_app.task(name="submit_llm_chat", bind=True) def submit_llm_chat( self, llm_query: str, user_prompt_ids: List[int], engine: str, ) -> Dict[str, Any]: current_user = user.get() session = session_manager_singleton.get(current_user.email, self.session_id) session.setdefault("indicators_dialogue", []).append(llm_query) try: (llm_response, token_usage, request_timestamp), execution_time = time_profile( indicator_chat, session, user_prompt_ids, engine ) try: alerts_backend_proxy_singleton.send_used_resources_info( token_usage["prompt_tokens"], token_usage["completion_tokens"], 0 ) except LimitsDisabled: logger.warning("Tried sending used resources info, but limits were disabled") indicators_code, trading_code = get_code_sections(llm_response) session.flow_status.promote_submit_llm_chat_step( indicators_code, trading_code, LLMResponseModel( flow_status=session.flow_status.to_dict(), llm_response=llm_response, engine=engine, token_usage=token_usage, request_timestamp=request_timestamp, execution_time=execution_time, ).model_dump(exclude={"flow_status"}), ) session_manager_singleton.save(current_user.email, session) return FlowStatusResponseModel( flow_status=session.flow_status.to_dict(), ).model_dump() # TODO: handle only specific exceptions in prod except (SoftTimeLimitExceeded, Exception) as e: _handle_error( session_id=session.session_id, email=current_user.email, step=Steps.SUBMIT_LLM_CHAT, exc_to_raise=e, ) @current_app.task(name="calculate_indicators", bind=True) def calculate_indicators( self, is_chained: bool = False, ) -> Dict[str, Any]: current_user = user.get() session = session_manager_singleton.get(current_user.email, self.session_id) if is_chained: session.flow_status.pipeline_start_step(Steps.CALCULATE_INDICATORS) session_manager_singleton.save(current_user.email, session) try: _, execution_time = time_profile(indicator_step, session) plots_meta = _build_indicators_plots_meta(session) session.flow_status.promote_indicators_step( { "execution_time": execution_time, "logs_present": True if session["indicators_code_log"] else False, "plots_meta": plots_meta, } ) try: alerts_backend_proxy_singleton.send_used_resources_info(0, 0, execution_time) except LimitsDisabled: logger.warning("Tried sending used resources info, but limits were disabled") if is_chained: session.flow_status.pipeline_finish_step(Steps.CALCULATE_INDICATORS) session_manager_singleton.save(current_user.email, session) return FlowStatusResponseModel( flow_status=session.flow_status.to_dict(), ).model_dump() # TODO: handle only specific exceptions in prod except (SoftTimeLimitExceeded, LLMBadResponseError, Exception) as e: _handle_error( session_id=session.session_id, email=current_user.email, step=Steps.CALCULATE_INDICATORS, exc_to_raise=e, indicators_code_log=session.get("indicators_code_log", []), ) def _build_indicators_plots_meta(session): plots_meta = {} for symbol in session["main_roots"]: charts = [] composite_chart = [{"name": "close", "type": "price"}] for ind in session["roots"][symbol]: composite_chart.append({"name": ind, "type": "indicator"}) charts.append(composite_chart) for ind in session["main_roots"][symbol]: charts.append([{"name": ind, "type": "indicator"}]) plots_meta[symbol] = {"charts": charts, "type": _determine_plot_type(symbol, session)} return plots_meta @current_app.task(name="calculate_backtesting", bind=True) def calculate_backtesting( self, actual_currency: str, bet_size: float, per_instrument_gross_limit: float, total_gross_limit: float, nop_limit: float, account_for_dividends: bool, trade_fill_price: str, execution_cost_bps: float, ) -> Dict[str, Any]: current_user = user.get() session = session_manager_singleton.get(current_user.email, self.session_id) session["actual_currency"] = actual_currency session["bet_size"] = bet_size session["per_instrument_gross_limit"] = per_instrument_gross_limit session["total_gross_limit"] = total_gross_limit session["nop_limit"] = nop_limit session["use_dividends_trading"] = account_for_dividends session["fill_trade_price"] = trade_fill_price session["execution_cost_bps"] = execution_cost_bps lclsglbls_before = session["lclsglbls"] try: elapsed_time = _run_backtesting_with_progress( session=session, account_for_dividends=account_for_dividends, progress_channel=f"task-{self.request.id}-progress", ) try: alerts_backend_proxy_singleton.send_used_resources_info(0, 0, elapsed_time) except LimitsDisabled: logger.warning("Tried sending used resources info, but limits were disabled") plots_meta = _build_backtesting_plots_meta(session) session.flow_status.promote_backtesting_step( { "execution_time": elapsed_time, "logs_present": True if session["trading_code_log"] else False, "plots_meta": plots_meta, } ) session["lclsglbls"] = lclsglbls_before session_manager_singleton.save(current_user.email, session) payload = {"progress": PUBSUB_END_OF_DATA} _safe_progress_publish( pubsub=sync_redis_client, channel=f"task-{self.request.id}-progress", message=json.dumps(payload), error_flag=False, ) return FlowStatusResponseModel( flow_status=session.flow_status.to_dict(), ).model_dump() # TODO: handle only specific exceptions in prod except (SoftTimeLimitExceeded, LLMBadResponseError, Exception) as e: _handle_error( session_id=session.session_id, email=current_user.email, step=Steps.PERFORM_BACKTESTING, exc_to_raise=e, trading_code_log=session.get("trading_code_log", []), ) def _run_backtesting_with_progress( session: Session, account_for_dividends: bool, progress_channel: str, ) -> float: publish_error_encountered = False elapsed_time = 0.0 for progress, elapsed_time, remaining_time in progress_and_time_generator()( trading_step, session, apply_dividends=account_for_dividends, ): payload = {"progress": progress, "elapsed_time": elapsed_time, "remaining_time": remaining_time} publish_error_encountered = _safe_progress_publish( pubsub=sync_redis_client, channel=progress_channel, message=json.dumps(payload), error_flag=publish_error_encountered, ) return elapsed_time def _build_backtesting_plots_meta(session): return { "symbols": list( filter( lambda s: _determine_plot_type(s, session) == "tradable", ChainMap(session["data_by_symbol"], session.get("data_by_synth", {})), ) ), "start_date": convert_date(session["start_date"]), "end_date": convert_date(session["end_date"]), } @current_app.task(name="run_optimization_prepare", bind=True) def run_optimization_prepare( self, actual_currency: str, bet_size: float, per_instrument_gross_limit: float, total_gross_limit: float, nop_limit: float, account_for_dividends: bool, trade_fill_price: str, execution_cost_bps: float, n_trials: int, train_size: float, params: list[dict[str, Any]], minimize: bool, maximize: bool, sampler: str, target_func: str, studies_names: list[str], ) -> None: current_user = user.get() session = session_manager_singleton.get(current_user.email, self.session_id) session["actual_currency"] = actual_currency session["bet_size"] = bet_size session["per_instrument_gross_limit"] = per_instrument_gross_limit session["total_gross_limit"] = total_gross_limit session["nop_limit"] = nop_limit session["use_dividends_trading"] = account_for_dividends session["fill_trade_price"] = trade_fill_price session["execution_cost_bps"] = execution_cost_bps session["optimization_trials"] = n_trials session["optimization_train_size"] = train_size session["optimization_params"] = params session["optimization_minimize"] = minimize session["optimization_maximize"] = maximize session["optimization_sampler"] = sampler session["optimization_target_func"] = target_func if old_studies_names := session.get("optimization_studies_names"): for study_name in old_studies_names: try: delete_optimization_study(get_optimization_storage(pool_size=3, max_overflow=3), study_name) except KeyError: # Study doesn't exist, perhaps it was revoked pass session["optimization_studies_names"] = studies_names session["range_by_param"] = {p["name"]: p["values"] for p in params} try: session["fx_rates"] = get_actual_currency_fx_rates(session, actual_currency) if account_for_dividends: session["dividends_by_symbol"] = get_sparse_dividends_for_each_tradable_symbol(session) session_manager_singleton.save(current_user.email, session) # TODO: handle only specific exceptions in prod except (SoftTimeLimitExceeded, Exception) as e: _handle_error( session_id=session.session_id, email=current_user.email, step=Steps.OPTIMIZE, exc_to_raise=e, ) # TODO: need to handle study removal if pipeline dropped with exception # try: # if minimization_study: # delete_study(study_name=minimization_study.study_name, storage=...) # if maximization_study: # delete_study(study_name=maximization_study.study_name, storage=...) # except Exception as e: # logger.warning(f"Couldn't remove optuna studies from DB due to error: %s", e) @current_app.task(name="run_optimization", bind=True) def run_optimization( self, study_name: str, target_func: str, sampler: str, account_for_dividends: bool, n_trials: int, train_size: float, progress_channel: str, pipeline_id: str, ) -> None: current_user = user.get() session = session_manager_singleton.get(current_user.email, self.session_id) try: _run_optimization_with_progress( session=session, study_name=study_name, target_func=target_func, sampler=sampler, account_for_dividends=account_for_dividends, n_trials=n_trials, train_size=train_size, progress_channel=progress_channel, pipeline_id=pipeline_id, ) # session_manager_singleton.save(current_user.email, session) # TODO: handle only specific exceptions in prod except (SoftTimeLimitExceeded, LLMBadResponseError, Exception) as e: _handle_error( session_id=session.session_id, email=current_user.email, step=Steps.OPTIMIZE, exc_to_raise=e, ) def _run_optimization_with_progress( session: Session, study_name: str, target_func: str, sampler: str, account_for_dividends: bool, n_trials: int, train_size: float, progress_channel: str, pipeline_id: str, ) -> None: optimize( session, target_function=optimization_target_funcs[target_func]["value"], storage=get_optimization_storage(pool_size=3, max_overflow=3), sampler=optimization_samplers[sampler]["value"](), study_name=study_name, study_direction=StudyDirection.MAXIMIZE, study_load_if_exists=True, trial_callbacks=[OptimizationTrialCallback(sync_redis_client, progress_channel, pipeline_id)], n_trials=n_trials, train_size=train_size, apply_dividends=account_for_dividends, is_trades_stats_needed=optimization_target_funcs[target_func]["is_trades_stats_needed"], ) class OptimizationTrialCallback: def __init__(self, pubsub, channel: str, pipeline_id: str) -> None: self._pubsub = pubsub self._channel = channel self._pipeline_id = pipeline_id self._publish_error_encountered = False def __call__(self, study: Study, trial: FrozenTrial) -> None: message = { "trial_in_sample": trial.value, "trial_out_of_sample": trial.user_attrs["test_value"], "trial_params": trial.params, "direction": "minimization" if study.direction == StudyDirection.MINIMIZE else "maximization", "duration": trial.duration.total_seconds(), } try: alerts_backend_proxy_singleton.send_used_resources_info(0, 0, message["duration"]) except LimitsDisabled: logger.debug("Tried sending used resources info, but limits were disabled") self._publish_error_encountered = _safe_progress_publish( self._pubsub, self._channel, json.dumps(message), self._publish_error_encountered ) task_info = get_task_ws_meta_info(self._pipeline_id) if isinstance(task_info, OptimizationTaskWSMetaInfo): if task_info.stop_flag: study.stop() logger.info("Stopped '%s' study", study.study_name) else: raise RuntimeError( f"optimization trial callback expected {OptimizationTaskWSMetaInfo.__name__}, but got {type(task_info).__name__}" ) @current_app.task(name="run_optimization_chord", bind=True) def run_optimization_chord(self, pipeline_start_time: int, studies_names: list[str]): current_user = user.get() session = session_manager_singleton.get(current_user.email, self.session_id) try: optimization_results = [ get_optimization_results(settings.optimization.storage_url, study_name) for study_name in studies_names ] minimization_result = _get_optimization_result(StudyDirection.MINIMIZE, optimization_results) maximization_result = _get_optimization_result(StudyDirection.MAXIMIZE, optimization_results) session.flow_status.promote_optimization_step( { # TODO: safe only if running processes are within the same machine/pod, fine for one celery pod "execution_time": time.time() - pipeline_start_time, **OptimizationResponseModel( minimization=minimization_result, maximization=maximization_result, sampler=session.get("optimization_sampler"), target_func=session.get("optimization_target_func"), ).model_dump(), } ) session_manager_singleton.save(current_user.email, session) payload = {"progress": PUBSUB_END_OF_DATA} _safe_progress_publish( pubsub=sync_redis_client, channel=f"task-{self.request.id}-progress", message=json.dumps(payload), error_flag=False, ) return FlowStatusResponseModel( flow_status=session.flow_status.to_dict(), ).model_dump() # TODO: handle only specific exceptions in prod except (SoftTimeLimitExceeded, LLMBadResponseError, Exception) as e: _handle_error( session_id=session.session_id, email=current_user.email, step=Steps.OPTIMIZE, exc_to_raise=e, ) def _get_optimization_result(study_direction: StudyDirection, all_results) -> Optional[OptimizationResult]: try: best_params, _, trials = next(filter(lambda r: r[1].direction == study_direction, all_results)) except StopIteration: return None return OptimizationResult( best_params=best_params, trials=trials, ) @current_app.task(name="run_after_optimization_prepare", bind=True) def run_after_optimization_prepare( self, actual_currency: str, bet_size: float, per_instrument_gross_limit: float, total_gross_limit: float, nop_limit: float, account_for_dividends: bool, trade_fill_price: str, execution_cost_bps: float, ): current_user = user.get() session = session_manager_singleton.get(current_user.email, self.session_id) session["actual_currency"] = actual_currency session["bet_size"] = bet_size session["per_instrument_gross_limit"] = per_instrument_gross_limit session["total_gross_limit"] = total_gross_limit session["nop_limit"] = nop_limit session["use_dividends_trading"] = account_for_dividends session["fill_trade_price"] = trade_fill_price session["execution_cost_bps"] = execution_cost_bps try: session_manager_singleton.save(current_user.email, session) # TODO: handle only specific exceptions in prod except (SoftTimeLimitExceeded, Exception) as e: _handle_error( session_id=session.session_id, email=current_user.email, step=Steps.OPTIMIZE, exc_to_raise=e, ) @current_app.task(name="run_after_optimization", bind=True) def run_after_optimization( self, params: list[dict[str, Any]], account_for_dividends: bool, progress_channel: str ) -> tuple[dict[str, Any], dict[str, Any]]: current_user = user.get() session = session_manager_singleton.get(current_user.email, self.session_id) constructed_llm_response = """ ```python %s ``` ```python %s ``` """ % ( session.flow_status.get_interpolated_indicators_code_template(ChainMap(*params)), session.flow_status.trading_code, ) session["indicators_dialogue"][-1] = constructed_llm_response try: indicator_step(session) _run_backtesting_with_progress( session=session, account_for_dividends=account_for_dividends, progress_channel=progress_channel, ) return pickle.dumps((session["trading_stats_by_symbol"], session["strategy_stats"])) # TODO: handle only specific exceptions in prod except (SoftTimeLimitExceeded, LLMBadResponseError, Exception) as e: _handle_error( session_id=session.session_id, email=current_user.email, step=Steps.OPTIMIZE, exc_to_raise=e, ) @current_app.task(name="run_after_optimization_chord", bind=True) def run_after_optimization_chord(self, results, pipeline_start_time: int, account_for_dividends: bool): current_user = user.get() session = session_manager_singleton.get(current_user.email, self.session_id) try: list_trading_stats_by_symbol, list_strategy_stats = [], [] for result in results: trading_stats_by_symbol, strategy_stats = pickle.loads(result) list_trading_stats_by_symbol.append(trading_stats_by_symbol) list_strategy_stats.append(strategy_stats) trading_statistics = get_combined_trading_statistics( session, list_trading_stats_by_symbol, list_strategy_stats, apply_dividends=account_for_dividends ) session.update(trading_statistics) session.flow_status.promote_backtesting_step( { # TODO: safe only if running processes are within the same machine/pod, fine for one celery pod "execution_time": time.time() - pipeline_start_time, } ) session_manager_singleton.save(current_user.email, session) payload = {"progress": PUBSUB_END_OF_DATA} _safe_progress_publish( pubsub=sync_redis_client, channel=f"task-{self.request.id}-progress", message=json.dumps(payload), error_flag=False, ) return FlowStatusResponseModel( flow_status=session.flow_status.to_dict(), ).model_dump() # TODO: handle only specific exceptions in prod except (SoftTimeLimitExceeded, LLMBadResponseError, Exception) as e: _handle_error( session_id=session.session_id, email=current_user.email, step=Steps.OPTIMIZE, exc_to_raise=e, ) @current_app.task(name="chained_load_model_into_session", bind=True) def chained_load_model_into_session( self, model_id: Optional[int], is_public: bool, indicators_dialogue: List[str], data_provider: str, datasets: List[str], periodicity: int, tradable_symbols_prompt: str, supplementary_symbols_prompt: str, economic_indicators: List[str], dividend_fields: List[str], time_range: int, strategy_title: str, strategy_description: str, actual_currency: str, bet_size: float, per_instrument_gross_limit: float, total_gross_limit: float, nop_limit: float, account_for_dividends: bool, trade_fill_price: str, execution_cost_bps: float, optimization_trials: int, optimization_train_size: float, optimization_params: List[dict[str, Any]], optimization_minimize: bool, optimization_maximize: bool, optimization_sampler: str, optimization_target_func: str, ) -> None: current_user = user.get() session = session_manager_singleton.get(current_user.email, self.session_id) session.reset_flow_status() if not is_public and model_id is not None: session.flow_status.set_model_opened(model_id) session["indicators_dialogue"] = indicators_dialogue session["data_provider"] = data_provider session["datasets_keys"] = datasets session["interval"] = periodicity session["tradable_symbols_prompt"] = tradable_symbols_prompt session["supplementary_symbols_prompt"] = supplementary_symbols_prompt session["economic_indicators"] = economic_indicators session["dividend_fields"] = dividend_fields session["time_period"] = time_range session["strategy_title"] = strategy_title session["strategy_description"] = strategy_description session["actual_currency"] = actual_currency session["bet_size"] = bet_size session["per_instrument_gross_limit"] = per_instrument_gross_limit session["total_gross_limit"] = total_gross_limit session["nop_limit"] = nop_limit session["use_dividends_trading"] = account_for_dividends session["fill_trade_price"] = trade_fill_price session["execution_cost_bps"] = execution_cost_bps session["optimization_trials"] = optimization_trials session["optimization_train_size"] = optimization_train_size session["optimization_params"] = optimization_params session["optimization_minimize"] = optimization_minimize session["optimization_maximize"] = optimization_maximize session["optimization_sampler"] = optimization_sampler session["optimization_target_func"] = optimization_target_func session_manager_singleton.save(current_user.email, session) @current_app.task(name="chained_submit_llm_chat", bind=True) def chained_submit_llm_chat( self, indicators_code: str, trading_code: str, ) -> None: middleware.check_limits_middleware() current_user = user.get() session = session_manager_singleton.get(current_user.email, self.session_id) session.flow_status.promote_submit_llm_chat_step(indicators_code, trading_code) session_manager_singleton.save(current_user.email, session) @current_app.task(name="clear_expired_sessions", base=Task) def clear_expired_sessions() -> None: session_manager_singleton.delete_expired_sessions(settings.optimization.storage_url) def _handle_error(session_id: str, email: str, step: Steps, exc_to_raise: Exception, **additional_session_values): session = session_manager_singleton.get(email, session_id) for key, value in additional_session_values.items(): session[key] = value session.flow_status.add_error_for_step(step, get_task_error_msg(exc_to_raise)) session_manager_singleton.save(email, session) raise exc_to_raise def _safe_progress_publish(pubsub, channel: str, message: str, error_flag: bool) -> bool: try: pubsub.publish(channel, message) if error_flag: logger.info(f"Redis is back up. Resuming progress updates to '{channel}'") return False except RedisError as e: if not error_flag: logger.warning(f"Some problem occurred while pushing progress updates to '{channel}' Redis channel: {e}") return True return error_flag