aidial_analytics_realtime/topic_model.py (32 lines of code) (raw):
import os
from bertopic import BERTopic
from aidial_analytics_realtime.utils.concurrency import (
run_in_cpu_tasks_executor,
)
class TopicModel:
def __init__(
self,
topic_model_name: str | None = None,
topic_embeddings_model_name: str | None = None,
):
if not topic_model_name:
topic_model_name = os.environ.get("TOPIC_MODEL", "./topic_model")
topic_embeddings_model_name = os.environ.get(
"TOPIC_EMBEDDINGS_MODEL", None
)
assert topic_model_name is not None
self.model = BERTopic.load(
topic_model_name, topic_embeddings_model_name
)
# Make sure the model is loaded
self._get_topic_by_text("test")
async def get_topic_by_text(self, text: str) -> str | None:
return await run_in_cpu_tasks_executor(self._get_topic_by_text, text)
def _get_topic_by_text(self, text: str) -> str | None:
text = text.strip()
if not text:
return None
topics, _ = self.model.transform([text])
topic = self.model.get_topic_info(topics[0])
if "GeneratedName" in topic:
# "GeneratedName" is an expected name for the human readable topic representation
return topic["GeneratedName"][0][0][0] # type: ignore
return topic["Name"][0] # type: ignore