sourcecode/scoring/mf_expansion_plus_scorer.py (73 lines of code) (raw):

from typing import Dict, List, Optional

from . import constants as c
from .mf_base_scorer import MFBaseScorer


class MFExpansionPlusScorer(MFBaseScorer):
  """Matrix-factorization scorer configured for the expansion-plus rater population.

  Thin configuration of MFBaseScorer:
    - includedGroups is the union of c.coreGroups, c.expansionGroups and
      c.expansionPlusGroups,
    - includeUnassigned=True,
    - pseudoraters=False.

  The remaining methods rename the base scorer's generic internal note/user
  columns to expansionPlus-specific names and list which columns appear in,
  or are dropped from, each output table.
  """

  def __init__(
    self,
    seed: Optional[int] = None,
    useStableInitialization: bool = True,
    saveIntermediateState: bool = False,
    threads: int = c.defaultNumThreads,
  ) -> None:
    """Configure MFExpansionPlusScorer object.

    Args:
      seed: if not None, seed value to ensure deterministic execution
      useStableInitialization: forwarded unchanged to MFBaseScorer; see its
        documentation for the initialization behaviour this toggles.
      saveIntermediateState: forwarded unchanged to MFBaseScorer; see its
        documentation for what intermediate state is retained.
      threads: number of threads to use for intra-op parallelism in pytorch
    """
    super().__init__(
      # Set union of the core, expansion and expansion-plus group collections.
      includedGroups=(c.coreGroups | c.expansionGroups | c.expansionPlusGroups),
      includeUnassigned=True,
      seed=seed,
      pseudoraters=False,
      useStableInitialization=useStableInitialization,
      saveIntermediateState=saveIntermediateState,
      threads=threads,
    )

  def get_name(self) -> str:
    """Return the identifier string for this scorer."""
    return "MFExpansionPlusScorer"

  def _get_note_col_mapping(self) -> Dict[str, str]:
    """Returns a dict mapping default note column names to custom names for a specific model.

    Keys are the base scorer's generic column names; values are the names this
    scorer emits them under.
    """
    return {
      c.internalNoteInterceptKey: c.expansionPlusNoteInterceptKey,
      c.internalNoteFactor1Key: c.expansionPlusNoteFactor1Key,
      c.internalRatingStatusKey: c.expansionPlusRatingStatusKey,
      c.internalActiveRulesKey: c.expansionPlusInternalActiveRulesKey,
      c.numFinalRoundRatingsKey: c.expansionPlusNumFinalRoundRatingsKey,
      # NOTE(review): unlike the entries above, this target is a "legacy"
      # low-diligence name rather than an expansionPlus-prefixed one -- confirm
      # this is intentional.
      c.lowDiligenceNoteInterceptKey: c.lowDiligenceLegacyNoteInterceptKey,
    }

  def _get_user_col_mapping(self) -> Dict[str, str]:
    """Returns a dict mapping default user column names to custom names for a specific model.

    Keys are the base scorer's generic rater column names; values are the
    expansionPlus-specific names this scorer emits them under.
    """
    return {
      c.internalRaterInterceptKey: c.expansionPlusRaterInterceptKey,
      c.internalRaterFactor1Key: c.expansionPlusRaterFactor1Key,
    }

  def get_scored_notes_cols(self) -> List[str]:
    """Returns a list of columns which should be present in the scoredNotes output.

    These are the note id plus the renamed targets from _get_note_col_mapping,
    excluding the low-diligence entry.
    """
    return [
      c.noteIdKey,
      c.expansionPlusNoteInterceptKey,
      c.expansionPlusNoteFactor1Key,
      c.expansionPlusRatingStatusKey,
      c.expansionPlusInternalActiveRulesKey,
      c.expansionPlusNumFinalRoundRatingsKey,
    ]

  def get_helpfulness_scores_cols(self) -> List[str]:
    """Returns a list of columns which should be present in the helpfulnessScores output.

    These are the rater id plus the renamed targets from _get_user_col_mapping.
    """
    return [
      c.raterParticipantIdKey,
      c.expansionPlusRaterInterceptKey,
      c.expansionPlusRaterFactor1Key,
    ]

  def get_auxiliary_note_info_cols(self) -> List[str]:
    """Returns a list of columns which should be present in the auxiliaryNoteInfo output.

    This scorer contributes no auxiliaryNoteInfo columns, so the list is empty.
    """
    return []

  def _get_dropped_note_cols(self) -> List[str]:
    """Returns a list of columns which should be excluded from scoredNotes and auxiliaryNoteInfo.

    The result is the parent's dropped columns followed by this scorer's
    additional exclusions, in that order.
    """
    # Extend (rather than replace) the parent's list; order is preserved.
    return super()._get_dropped_note_cols() + (
      [
        c.activeFilterTagsKey,
        c.ratingWeightKey,
        c.noteInterceptMinKey,
        c.noteInterceptMaxKey,
      ]
      + c.notHelpfulTagsAdjustedColumns
      + c.notHelpfulTagsAdjustedRatioColumns
      + c.incorrectFilterColumns
      + c.noteParameterUncertaintyTSVAuxColumns
    )

  def _get_dropped_user_cols(self) -> List[str]:
    """Returns a list of columns which should be excluded from helpfulnessScores output.

    The result is the parent's dropped columns followed by this scorer's
    additional exclusions, in that order.
    """
    return super()._get_dropped_user_cols() + [
      c.crhCrnhRatioDifferenceKey,
      c.meanNoteScoreKey,
      c.raterAgreeRatioKey,
      c.aboveHelpfulnessThresholdKey,
    ]