aidial_analytics_realtime/topic_model.py (32 lines of code) (raw):

import os from bertopic import BERTopic from aidial_analytics_realtime.utils.concurrency import ( run_in_cpu_tasks_executor, ) class TopicModel: def __init__( self, topic_model_name: str | None = None, topic_embeddings_model_name: str | None = None, ): if not topic_model_name: topic_model_name = os.environ.get("TOPIC_MODEL", "./topic_model") topic_embeddings_model_name = os.environ.get( "TOPIC_EMBEDDINGS_MODEL", None ) assert topic_model_name is not None self.model = BERTopic.load( topic_model_name, topic_embeddings_model_name ) # Make sure the model is loaded self._get_topic_by_text("test") async def get_topic_by_text(self, text: str) -> str | None: return await run_in_cpu_tasks_executor(self._get_topic_by_text, text) def _get_topic_by_text(self, text: str) -> str | None: text = text.strip() if not text: return None topics, _ = self.model.transform([text]) topic = self.model.get_topic_info(topics[0]) if "GeneratedName" in topic: # "GeneratedName" is an expected name for the human readable topic representation return topic["GeneratedName"][0][0][0] # type: ignore return topic["Name"][0] # type: ignore