# sourcecode/scoring/topic_model.py

"""Assign notes to a set of predetermined topics. The topic assignment process is seeded with a small set of terms which are indicative of the topic. After preliminary topic assignment based on term matching, a logistic regression trained on bag-of-words features model expands the set of in-topic notes for each topic. Note that the logistic regression modeling excludes any tokens containing seed terms. This approach represents a prelimiary approach to topic assignment while Community Notes evaluates the efficacy of per-topic note scoring. """ import logging import re from typing import List, Optional, Tuple from . import constants as c from .enums import Topics import numpy as np import pandas as pd from scipy.special import expit as sigmoid, softmax from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer from sklearn.linear_model import LogisticRegression from sklearn.metrics import balanced_accuracy_score from sklearn.pipeline import Pipeline logger = logging.getLogger("birdwatch.topic_model") logger.setLevel(logging.INFO) class TopicModel(object): def __init__(self, unassignedThreshold=0.85): """Initialize a list of seed terms for each topic.""" self._seedTerms = { Topics.UkraineConflict: { "ukrain", # intentionally shortened for expanded matching "russia", "kiev", "kyiv", "moscow", "zelensky", "putin", }, Topics.GazaConflict: { "israel", "palestin", # intentionally shortened for expanded matching "gaza", "jerusalem", }, Topics.MessiRonaldo: { "messi\s", # intentional whitespace to prevent prefix matches "ronaldo", }, } self._unassignedThreshold = unassignedThreshold self._compiled_regex = self._compile_regex() def _compile_regex(self): """Compile a single regex from all seed terms grouped by topic.""" regex_patterns = {} for topic, patterns in self._seedTerms.items(): group_name = f"{topic.name}" regex_patterns[group_name] = f"(?P<{group_name}>{'|'.join(patterns)})" combined_regex = "|".join(regex_patterns.values()) return 
re.compile(combined_regex, re.IGNORECASE) def _make_seed_labels(self, texts: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: """Produce a label vector based on seed terms. Args: texts: array containing strings for topic assignment Returns: Tuple[0]: array specifying topic labels for texts Tuple[1]: array specifying texts that are unassigned due to conflicting matches. """ labels = np.zeros(texts.shape[0], dtype=np.int64) conflictedTexts = np.zeros(texts.shape[0], dtype=bool) for i, text in enumerate(texts): matches = self._compiled_regex.finditer(text.lower()) found_topics = set() for match in matches: found_topics.update([Topics[grp].value for grp in match.groupdict() if match.group(grp)]) if len(found_topics) == 1: labels[i] = found_topics.pop() elif len(found_topics) > 1: labels[i] = Topics.Unassigned.value conflictedTexts[i] = True unassigned_count = np.sum(conflictedTexts) logger.info(f" Notes unassigned due to multiple matches: {unassigned_count}") return labels, conflictedTexts def _get_stop_words(self, texts: np.ndarray) -> List[str]: """Identify tokens in the extracted vocabulary that contain seed terms. Any token containing a seed term will be treated as a stop word (i.e. removed from the extracted features). This prevents the model from training on the same tokens used to label the data. Args: texts: array containing strings for topic assignment Returns: List specifying which tokens to exclude from the features. 
""" # Extract vocabulary cv = CountVectorizer(strip_accents="unicode") cv.fit(texts) rawVocabulary = cv.vocabulary_.keys() logger.info(f" Initial vocabulary length: {len(rawVocabulary)}") # Identify stop words blockedTokens = set() for terms in self._seedTerms.values(): # Remove whitespace and any escaped characters from terms blockedTokens |= {re.sub(r"\\.", "", t.strip()) for t in terms} logger.info(f" Total tokens to filter: {len(blockedTokens)}") stopWords = [v for v in rawVocabulary if any(t in v for t in blockedTokens)] logger.info(f" Total identified stopwords: {len(stopWords)}") return stopWords def _merge_predictions_and_labels(self, probs: np.ndarray, labels: np.ndarray) -> np.ndarray: """Update predictions based on defined labels when the label is not Unassigned. Args: probs: 2D matrix specifying the likelihood of each class Returns: Updated predictions based on keyword matches when available. """ predictions = np.argmax(probs, axis=1) for label in range(1, len(Topics)): # Update label if (1) note was assigned based on the labeling heuristic, and (2) # p(Unassigned) is below the required uncertainty threshold. predictions[(labels == label) & (probs[:, 0] <= self._unassignedThreshold)] = label return predictions def _prepare_post_text(self, notes: pd.DataFrame) -> pd.DataFrame: """Concatenate all notes within each post into a single row associated with the post. Args: notes: dataframe containing raw note text Returns: DataFrame with one post per row containing note text """ postNoteText = ( notes[[c.tweetIdKey, c.summaryKey]] .fillna({c.summaryKey: ""}) .groupby(c.tweetIdKey)[c.summaryKey] .apply(lambda postNotes: " ".join(postNotes)) .reset_index(drop=False) ) # Default tokenization for CountVectorizer will not split on underscore, which # results in very long tokens containing many words inside of URLs. Removing # underscores allows us to keep default splitting while fixing that problem. 
postNoteText[c.summaryKey] = [ text.replace("_", " ") for text in postNoteText[c.summaryKey].values ] return postNoteText def train_note_topic_classifier( self, notes: pd.DataFrame ) -> Tuple[Pipeline, np.ndarray, np.ndarray]: # Obtain aggregate post text, seed labels and stop words with c.time_block("Get Note Topics: Prepare Post Text"): postText = self._prepare_post_text(notes) with c.time_block("Get Note Topics: Make Seed Labels"): seedLabels, conflictedTexts = self._make_seed_labels(postText[c.summaryKey].values) with c.time_block("Get Note Topics: Get Stop Words"): stopWords = self._get_stop_words(postText[c.summaryKey].values) with c.time_block("Get Note Topics: Train Model"): # Define and fit model pipe = Pipeline( [ ( "UnigramEncoder", CountVectorizer( strip_accents="unicode", stop_words=stopWords, min_df=25, max_df=max(1000, int(0.25 * len(postText))), ), ), ("tfidf", TfidfTransformer()), ("Classifier", LogisticRegression(max_iter=1000, verbose=1)), ], verbose=True, ) pipe.fit( # Notice that we omit posts with an unclear label from training. postText[c.summaryKey].values[~conflictedTexts], seedLabels[~conflictedTexts], ) return pipe, seedLabels, conflictedTexts def get_note_topics( self, notes: pd.DataFrame, noteTopicClassifier: Optional[Pipeline] = None, seedLabels: Optional[np.ndarray] = None, conflictedTextsForAccuracyEval: Optional[np.ndarray] = None, ) -> pd.DataFrame: """Return a DataFrame specifying each {note, topic} pair. Notes that are not assigned to a topic do not appear in the dataframe. Args: notes: DF containing all notes to potentially assign to a topic """ logger.info("Assigning notes to topics:") if noteTopicClassifier is not None: pipe = noteTopicClassifier else: logger.info("Training note topic classifier") pipe, seedLabels, conflictedTextsForAccuracyEval = self.train_note_topic_classifier(notes) postText = self._prepare_post_text(notes) with c.time_block("Get Note Topics: Predict"): # Predict notes. 
Notice that in effect we are looking to see which notes in the # training data the model felt were mis-labeled after the training process # completed, and generating labels for any posts which were omitted from the # original training. logits = pipe.decision_function(postText[c.summaryKey].values) # Transform logits to probabilities, handling the case where logits are 1D because # of unit testing with only 2 topics. if len(logits.shape) == 1: probs = sigmoid(logits) probs = np.vstack([1 - probs, probs]).T else: probs = softmax(logits, axis=1) if seedLabels is None: with c.time_block("Get Note Topics: Make Seed Labels"): seedLabels, _ = self._make_seed_labels(postText[c.summaryKey].values) if conflictedTextsForAccuracyEval is not None: self.validate_note_topic_accuracy_on_seed_labels( np.argmax(probs, axis=1), seedLabels, conflictedTextsForAccuracyEval ) with c.time_block("Get Note Topics: Merge and assign predictions"): topicAssignments = self._merge_predictions_and_labels(probs, seedLabels) logger.info(f" Post Topic assignment results: {np.bincount(topicAssignments)}") # Assign topics to notes based on aggregated note text, and drop any # notes on posts that were unassigned. 
postText[c.noteTopicKey] = [Topics(t).name for t in topicAssignments] postText = postText[postText[c.noteTopicKey] != Topics.Unassigned.name] noteTopics = notes[[c.noteIdKey, c.tweetIdKey]].merge( postText[[c.tweetIdKey, c.noteTopicKey]] ) logger.info( f" Note Topic assignment results:\n{noteTopics[c.noteTopicKey].value_counts(dropna=False)}" ) return noteTopics.drop(columns=c.tweetIdKey) def validate_note_topic_accuracy_on_seed_labels(self, pred, seedLabels, conflictedTexts): balancedAccuracy = balanced_accuracy_score(seedLabels[~conflictedTexts], pred[~conflictedTexts]) logger.info(f" Balanced accuracy on raw predictions: {balancedAccuracy}") assert balancedAccuracy > 0.5, f"Balanced accuracy too low: {balancedAccuracy}" # Validate that any conflicted text is Unassigned in seedLabels assert all(seedLabels[conflictedTexts] == Topics.Unassigned.value)