clns-CORONET_tool/CORONET_functions.py (490 lines of code) (raw):

import numpy as np import joblib import matplotlib.pyplot as plt import os import pickle import pandas as pd import seaborn as sns from matplotlib.colors import LinearSegmentedColormap """ List of functions: predict_and_explain - main function that generates prediction and explanation load_predictive_model load_predictive_model_using_pickle load_explainer transform_x_values get_prediction_for_x get_shap_values_for_x plot_local_explanation_shap generate_colorbar generate_plot_all_patients calculate_NEWS2 """ def predict_and_explain(x, model, explainer, plot_expl_barplot = True, path_to_save_plots=''): """ Predicts and explains the prediction using pre-trained model and explainer. It outputs two dictionaries: 1) prediction, 2) explanation Optionally, it can plot and save a barplot with shap values explaining the prediction. Before running the function, the model and the explainer should be loaded. Contains admission_threshold and severe_condition_threshold which are specified inside this function. Parameters ---------- x : dict A dictionary with keys = features, values = patient's parameters values. Dictionary format: {'NEWS2': value, 'CRP': value, 'Albumin': value, 'Age': value, 'Platelets': value, 'Neutrophil': value, 'Lymphocyte': value, 'Performance status: value', 'Total no. comorbidities': value } model : sklearn predictive model explainer : shap.TreeExplainer object plot_expl_barplot : bool - default True, if False the function does not generate the barplot with the explanation path_to_save_plots : str - directory to save png file with the figure Returns ------ prediction : dict a dictionary with keys: 'Predicted_score' and 'Recommendation' explanation : dict a dictionary with shap values for each feature sorted by absolute value """ admission_threshold = 1.0 severe_condition_threshold = 2.3 x_trans = transform_x_values(x) prediction = get_prediction_for_x(x_trans, model, admission_threshold, severe_condition_threshold) explanation = get_shap_values_for_x(x_trans, explainer, sort_explanation=True) if plot_expl_barplot: plot_local_explanation_shap(explanation, x_trans, path_to_save_plots) return prediction, explanation def load_predictive_model(file_path): """ Loads predictive model stored in a .pkl file from 'file_path' The model is a Random Forest model trained using sklearn library and saved to .pkl file using joblib library (using joblib.dump command) required libraries: joblib :param file_path: :return: ----------- model """ model = joblib.load(file_path) return model def load_predictive_model_using_pickle(file_path): """ Loads predictive model stored in a .pkl file from 'file_path' The model is a Random Forest model trained using sklearn library and saved to .pkl or to .sav file using pickle library Commands used for saving the model: pickle_file = ...\CORONET_model\RF_model_pickle.pkl' pickle_file = ...\CORONET_model\RF_model_pickle.sav' pickle.dump(coronet_RF_model, open(pickle_file, 'wb')) (using pickle.dumps command) required libraries: pickle :param file_path: :return: ----------- model """ model = pickle.load(open(file_path, 'rb')) return model def load_explainer(file_path): """ Loads explainer stored in a .pkl file from 'file_path' The explainer is a TreeExplainer from SHAP library (https://github.com/slundberg/shap) created using function shap.TreeExplainer(model), where 'model' is the predictive model used in CORONET and saved to .pkl file using joblib library (using joblib.dump command) required libraries: joblib :param file_path: :return: """ explainer = joblib.load(file_path) return explainer def load_explainer_using_pickle(file_path): """ Loads explainer stored in a .pkl file from 'file_path' The explainer is a TreeExplainer from SHAP library (https://github.com/slundberg/shap) created using function shap.TreeExplainer(model), where 'model' is the predictive model used in CORONET and saved to .pkl file using pickle library required libraries: pickle :param file_path: :return: """ explainer = pickle.load(open(file_path, 'rb')) return explainer def transform_x_values(x): """ Calculates NLR (Neutrophil:Lymphocyte Ratio) If Lymphocyte < 0.1, then Lymhpcyte is set to 0.1 before calculating the NLR Parameters: ---------- param x : dict x : dict A dictionary with keys = features, values = patient's parameters values. Dictionary format: {'NEWS2': value, 'CRP': value, 'Albumin': value, 'Age': value, 'Platelets': value, 'Neutrophil': value, 'Lymphocyte': value, 'Performance status: value', 'Total no. comorbidities': value } Return: ------ x_transformed : dict A dictionary with keys = features, values = patient's parameters values. Dictionary format: {'NEWS2': value, 'CRP': value, 'Albumin': value, 'Age': value, 'Platelets': value, 'Neutrophil': value, 'Lymphocyte': value, 'Performance status: value', 'Total no. comorbidities': value 'NLR': value} """ x_transformed = x.copy() if x_transformed['Lymphocyte'] < 0.1: x_transformed['Lymphocyte'] = 0.1 x_transformed['NLR'] = x_transformed['Neutrophil']/x_transformed['Lymphocyte'] # Reorder dict after adding NLR - this is desired_order_list = ['NEWS2', 'CRP', 'Albumin', 'Age', 'Platelets', 'Neutrophil', 'Performance status', 'Lymphocyte', 'Total no. comorbidities', 'NLR'] x_transformed = {k: x_transformed[k] for k in desired_order_list} return x_transformed def get_prediction_for_x(x, model, admission_threshold, severe_condition_threshold): """ Calculates the score and assigns recommendation based on given thresholds. Uses transformed x (with calculated NLR) and predictive model and calculates the score (range 0.0-3.0), It also outputs a string with a recommendation from the list of three: - 'consider discharge' - 'consider admission' - 'high risk of severe condition' Parameters: ----------- x : dict A dictionary with keys = features, values = patient's parameters values. Dictionary format: {'NEWS2': value, 'CRP': value, 'Albumin': value, 'Age': value, 'Platelets': value, 'Neutrophil': value, 'Lymphocyte': value, 'Performance status: value', 'Total no. comorbidities': value 'NLR': value} model : sklearn predictive model admission_threshold : float A threshold defined by the researcher. Above this value all recommendation will be 'consider admission' or 'high risk of severe condition'. Below this value all recommendation will be 'consider discharge'. severe_condition_threshold : float A threshold defined by the researcher. Above this value all recommendation will be 'high risk of severe condition'. Below this value all recommendation will be 'consider discharge' or 'consider admission'. Return ------ prediction : dict a dictionary with predicted score (str, the score rounded to 2 decimals) and textual recommendation (str). Dictionary format (example values): {'Predicted_score': '0.95', 'Recommendation': 'consider discharge'} """ x_to_model = np.array(list(x.values())).reshape(1, -1) predicted_score = np.round(model.predict(x_to_model)[0], 2) recommendations = ['consider discharge', 'consider admission', 'high risk of severe condition'] if predicted_score < admission_threshold: recommendation = recommendations[0] elif predicted_score > severe_condition_threshold: recommendation = recommendations[2] else: recommendation = recommendations[1] # convert to string with 2 decimal places (for consitency in showing the coronet score to the user) predicted_score = f'{predicted_score:.2f}' prediction = {'Predicted_score': predicted_score, 'Recommendation': recommendation} return prediction def get_shap_values_for_x(x, explainer, sort_explanation=True): """ Computes shapley values of local explanation for 'x'. Uses 'explainer' which is an explainer object from shap library. Generated 'explanation' can sorted (default) or in the same order as 'x'. Parameters: ---------- x : dict A dictionary with keys = features, values = patient's parameters values. CRP and NLR values should be transformed. Dictionary format: {'NEWS2': value, 'CRP': value, 'Albumin': value, 'Age': value, 'Platelets': value, 'Neutrophil': value, 'Lymphocyte': value, 'Performance status: value', 'Total no. comorbidities': value 'NLR': value} explainer : shap.Explainer object sort_explanation : bool default True, if False the keys of explanation dict will be in the same order as x. If True, the explanation dict will be sorted by absolute value of shap value (the highest - most important - are at the bottom) Return: ------- explanation : dict a dictionary with shap values for each feature sorted by absolute value (sorting is optional but default True) Dictionary format: {'NEWS2': shap_value, 'CRP': shap_value, 'Albumin': shap_value, 'Age': shap_value, 'Platelets': shap_value, 'Neutrophil': shap_value, 'Lymphocyte': shap_value, 'Performance status: shap_value', 'Total no. comorbidities': shap_value 'NLR': shap_value} """ x_to_model = np.array(list(x.values())) features = list(x.keys()) shap_values = np.round(explainer.shap_values(x_to_model), 4) explanation = {} for i, feature in enumerate(features): explanation[feature] = shap_values[i] if sort_explanation: explanation = {k: v for k, v in sorted(explanation.items(), key=lambda item: np.abs(item[1]), reverse=False)} return explanation def plot_local_explanation_shap(shap_dict, x, path_to_save, example_no=0): """ Plot a red-green barplot showing the contribution of each feature to the prediction. The contribution is equal to shap value for given feature. Negative shap values contribute to the 'consider discharge' recommendation and are represented as green bars on the left side of the plot. Positive shap values contribute to the 'consider admission' or 'high risk of severe condition' recommendation and are represented as red bars on the right side of the plot. Next to the bars a value of given parameter is shown in a textbox. Important: bar width corresponds to the shap value, not to the parameter value displayed in the textbox. Saves the figure as 'local_explanation_shap.png' Parameters: ----------- shap_dict : : dict a dictionary with shap values for each feature sorted by absolute value (sorting is optional but default True) Dictionary format: {'NEWS2': shap_value, 'CRP': shap_value, 'Albumin': shap_value, 'Age': shap_value, 'Platelets': shap_value, 'Neutrophil': shap_value, 'Lymphocyte': shap_value, 'Performance status: shap_value', 'Total no. comorbidities': shap_value 'NLR': shap_value} x : dict A dictionary with keys = features, values = patient's parameters values. Dictionary format: {'NEWS2': value, 'CRP': value, 'Albumin': value, 'Age': value, 'Platelets': value, 'Neutrophil': value, 'Lymphocyte': value, 'Performance status: value', 'Total no. comorbidities': value 'NLR': value} path_to_save : str Directory where the png file with figure should be saved. Returns: ------- """ # sort shap_values dictionary by absolute value shap_dict_sorted = shap_dict#{k: v for k, v in sorted(shap_dict.items(), key=lambda item: np.abs(item[1]), reverse=False)} fig, ax = plt.subplots(figsize=(13, 6)) features = list(shap_dict_sorted.keys()) values = list(shap_dict_sorted.values()) # plot barplot bars = ax.barh(width=values, y=features, linewidth=1, edgecolor='black') # assing bar colors (red for features voting for 'admission', green for features voting for 'discharge') for j, bar in enumerate(bars): if values[j] < 0: bar.set_color('green') else: bar.set_color('red') bar.set_edgecolor('k') ax.set_xticklabels([None]) # add arrows at the top props = dict(boxstyle="larrow,pad=0.3", facecolor='white', alpha=1) text = 'DISCHARGE' ax.text(0.45, 1.05, text, bbox=props, transform=ax.transAxes, va='bottom', ha='right', fontsize=15) props = dict(boxstyle="rarrow,pad=0.3", facecolor='white', alpha=1) text = 'ADMISSION' ax.text(0.55, 1.05, text, bbox=props, transform=ax.transAxes, va='bottom', ha='left', fontsize=15) ax.set_yticklabels([None]) # add bars description (feature name and its value, i.e. real value, not the shap value) for m in range(len(values)): parameter = features[m] shap_value = values[m] if parameter == 'NLR': text = 'Neutrophil:Lymphocyte Ratio' + ' = ' + str(np.round(x['NLR'], 1)) elif parameter == 'CRP': # RF model uses transformed CRP, but to show the CRP value on the plot, we need to refer to initial 'x' instead of 'x_transformed' unit = 'mg/L' text = 'C-reactive protein' + ' (' + unit + ') = ' + str(np.round(x['CRP'], 1)) elif parameter == 'Albumin': unit = 'g/L' text = parameter + ' (' + unit + ') = ' + str(np.round(x[parameter], 0)) elif parameter == 'Lymphocyte': unit = 'x10^9/L' text = parameter + ' (' + unit + ') = ' + str(np.round(x[parameter], 1)) elif parameter == 'Neutrophil': unit = 'x10^9/L' text = parameter + ' (' + unit + ') = ' + str(np.round(x[parameter], 1)) elif parameter == 'Platelets': unit = 'x10^9/L' text = parameter + ' (' + unit + ') = ' + str(np.round(x[parameter], 0)) else: text = parameter + ' = ' + str(np.int(np.round(x[parameter], 0))) if shap_value > 0: ha = 'left' else: ha = 'right' ax.text(shap_value + 0.008 * np.sign(shap_value), m, text, ha=ha, va='center', fontsize=18) #ax.set_xlim([-np.abs(values).max() - 0.1, np.abs(values).max() + .1]) ax.set_xlim([-.65, .65]) # remove axes lines ax.set_frame_on(False) ax.axes.get_yaxis().set_visible(False) ax.axes.get_xaxis().set_visible(False) plt.subplots_adjust(top=1.1) plt.tight_layout() #path_to_save = os.path.join(path_to_save, 'local_explanation_shap.png') path_to_save = os.path.join(path_to_save, 'local_explanation_shap_example_{}.png'.format(example_no)) plt.savefig(path_to_save, dpi=400) def generate_colorbar(admission_threshold, severe_condition_threshold, path_to_save): """ Generates and saves a colorbar with a gradient colors green-->yellow-->red-->black. The color transition values are defined by 'admission_threshold 'severe_condition_threshold: - the center of the yellow field is defined by 'admission threshold' value - the center of the red field is defined by 'severe_condition_threshold' value Saves the figure as 'colorbar_03score.png' Parameters: ----------- admission_threshold : float Defined by the researcher. severe_condition_threshold : float Defined by the researcher. path_to_save : str Directory to save the figure with the colorbar Return ------ """ thresh1 = admission_threshold / 3 thresh2 = severe_condition_threshold / 3 nodes = [0, thresh1, thresh2, 1.0] colors = ["green", "yellow", "red", "black"] cmap = LinearSegmentedColormap.from_list("", list(zip(nodes, colors))) gradient = np.linspace(0, 3, 300).reshape(1, -1) fig, ax = plt.subplots(figsize=(15, 3)) ax.imshow(gradient, extent=[-0.0, 3, -1, 1], aspect='auto', cmap=cmap) # 'RdYlGn_r') ax.set_yticks([]) ax.tick_params(axis='both', which='major', size=15) plt.xticks(fontsize=14) plt.tight_layout() path_to_save = os.path.join(path_to_save, 'colorbar_03score.png') plt.savefig(path_to_save + '', dpi=300) def generate_plot_all_patients(df, path_to_save): """ Generates a dot plot (i.e. swarmplot from seaborn library) for all the patients from the training set. The score for each patient is calculated according to LOOCV - model trained on all samples except one. Each dot represents a score predicted by the model for given patient. Dots are colored by the true outcome. It serves as an explanation 'where my patient is in the whole cohort in terms of predicted score' Saves the figure as 'plot_all_patients.png' Saves the figure as 'plot_all_patients_separated.png' Parameters: ----------- df : DataFrame Dataframe with 3 columns: index, y_pred, y_true, constant 0, 0.633, 0, 1 1, 0.635, 0, 1 2, 0.651, 0, 1 3, 0.652, 0, 1 ... ... ... ... path_to_save : str Returns ------- """ colors = ['green', 'gold', 'red', 'black'] sns.set_palette(sns.color_palette(colors)) fig, ax = plt.subplots(figsize=(13, 6)) g = sns.swarmplot(x='y_pred', y='constant', hue='y_true', data=df, ax=ax, zorder=1, orient='h', size=7) handles, labels = g.get_legend_handles_labels() g.legend(handles, ['Discharged', 'Admitted', 'Required O2', 'Death due to COVID'], title='Outcome', fontsize=12, framealpha=1, bbox_to_anchor=(1.01, 0.5), loc=6, borderaxespad=0., edgecolor='k') ax.set_xticks([1, 2, 3]) ax.set_xticklabels([1, 2, 3]) ax.set_ylabel(None) ax.set_xlabel(None) ax.set_yticks([]) ax.set_xlim([0.1, 3]) plt.tight_layout() path_to_save_joined = os.path.join(path_to_save, 'plot_all_patients.png') plt.savefig(path_to_save_joined, dpi=300) fig, ax = plt.subplots(figsize=(13, 7)) g = sns.swarmplot(x='y_pred', y='y_true', hue='y_true', data=df, ax=ax, zorder=1, orient='h', size=7) handles, labels = g.get_legend_handles_labels() g.legend(handles, ['Discharged', 'Admitted', 'Required O2', 'Death due to COVID'], title='Outcome', fontsize=12, framealpha=1, bbox_to_anchor=(1.01, 0.5), loc=6, borderaxespad=0., edgecolor='k') ax.set_xticks([1, 2, 3]) ax.set_xticklabels([1, 2, 3]) ax.set_ylabel(None) ax.set_xlabel(None) ax.set_yticks([]) ax.set_xlim([0.1, 3]) plt.tight_layout() path_to_save_separated = os.path.join(path_to_save, 'plot_all_patients_separated.png') plt.savefig(path_to_save_separated, dpi=300) def calculate_NEWS2(x): ''' Function to calculate NEWS2 score https://www.mdcalc.com/national-early-warning-score-news-2 INPUT: dictionary x = {'Respiratory Rate (bpm)': int, 'Hypercapnic respiratory failure': str, -string 'Yes' or 'No' 'SpO2 (%)': int, 'Supplemental O2': str, -string 'Yes' or 'No' 'Systolic BP (mmHg)': int, 'Heart Rate (bpm)':int 'Consciousness': 'str, -string 'Yes' or 'No' 'Temperature (degrees of C)': float, - decimal } OUTPUT: NEWS2_score - int ''' rr_score = 0 sat_score = 0 supp_o2_score = 0 Systolic_BP_score = 0 Heart_rate_score = 0 Consciousness_score = 0 Temperature_score = 0 # Respiratory Rate (bpm) if (11 >= x['Respiratory Rate (bpm)']) & (x['Respiratory Rate (bpm)'] >= 9): rr_score = 1 elif x['Respiratory Rate (bpm)'] <= 8: rr_score = 3 elif (24 >= x['Respiratory Rate (bpm)']) & (x['Respiratory Rate (bpm)'] >= 21): rr_score = 2 elif x['Respiratory Rate (bpm)'] >= 25: rr_score = 3 # SpO2 if x['Hypercapnic respiratory failure'] == 'No': # SpO2 (%) if x['SpO2 (%)'] <= 91: sat_score = 3 elif (92 <= x['SpO2 (%)']) & (x['SpO2 (%)'] <= 93): sat_score = 2 elif (94 <= x['SpO2 (%)']) & (x['SpO2 (%)'] <= 95): sat_score = 1 elif x['Hypercapnic respiratory failure'] == 'Yes': # SpO2 (%) if x['SpO2 (%)'] <= 83: sat_score = 3 elif (84 <= x['SpO2 (%)']) & (x['SpO2 (%)'] <= 85): sat_score = 2 elif (86 <= x['SpO2 (%)']) & (x['SpO2 (%)'] <= 87): sat_score = 1 elif (93 <= x['SpO2 (%)']) & (x['SpO2 (%)'] <= 94): sat_score = 1 elif (95<= x['SpO2 (%)']) & (x['SpO2 (%)'] <= 96): sat_score = 2 elif x['SpO2 (%)'] >= 97: sat_score = 3 # Supplemental O2 if x['Supplemental O2'] == 'Yes': supp_o2_score = 2 # Systolic BP (mmHg) if x['Systolic BP (mmHg)'] <= 90: Systolic_BP_score = 3 elif (91 <= x['Systolic BP (mmHg)']) & (x['Systolic BP (mmHg)'] <= 100): Systolic_BP_score = 2 elif (101 <= x['Systolic BP (mmHg)']) & (x['Systolic BP (mmHg)'] <= 110): Systolic_BP_score = 1 elif x['Systolic BP (mmHg)'] >= 220: Systolic_BP_score = 3 # Consciousness if x['Consciousness'] == 'No': Consciousness_score = 3 # Temperature if x['Temperature (degrees of C)'] <= 35: Temperature_score = 3 elif (35.1 <= x['Temperature (degrees of C)']) & (x['Temperature (degrees of C)'] <= 36): Temperature_score = 1 elif (38.1 <= x['Temperature (degrees of C)']) & (x['Temperature (degrees of C)'] <= 39): Temperature_score = 1 elif x['Temperature (degrees of C)'] >= 39.1: Temperature_score = 2 # Temperature if x['Heart Rate (bpm)'] <= 40: Heart_rate_score = 3 elif (41 <= x['Heart Rate (bpm)']) & (x['Heart Rate (bpm)'] <= 50): Heart_rate_score = 1 elif (91 <= x['Heart Rate (bpm)']) & (x['Heart Rate (bpm)'] <= 110): Heart_rate_score = 1 elif (111 <= x['Heart Rate (bpm)']) & (x['Heart Rate (bpm)'] <= 130): Heart_rate_score = 2 elif x['Heart Rate (bpm)'] >= 131: Heart_rate_score = 3 NEWS2_score = rr_score + sat_score + supp_o2_score + Systolic_BP_score + Consciousness_score + Temperature_score + Heart_rate_score #print(rr_score, sat_score, Systolic_BP_score, Consciousness_score, Temperature_score, Heart_rate_score) return NEWS2_score #------------------------------------------------------------------------------------------------------------------------ # Functions for finding N nearest patients from the dataset X based on input x def load_scaler(file_path): """ Loads StandardScaler stored in a .pkl file from 'file_path' required libraries: joblib :param: file_path: :return: ----------- scaler """ scaler = joblib.load(file_path) return scaler def prepare_input_x_to_KNearest(x, scaler, shap_weights, cols_from_shap_weights=None): ''' :param: x: dict scaler: scaler object from sklearn StandardScaler cols_from_shap_weights: optional, list of strings; used when defining custom search in the future shap_weigths: DataFrame, with absolute shap values normalized (all values divided by max) :return: x_to_query: dict, with scaled values using scaler and multiplied by shap weights ''' cols_in_scaler = ['NEWS2', 'CRP', 'Albumin', 'Age', 'Platelets', 'Neutrophil_log', 'Lymphocyte_log', 'Performance status', 'Total no. comorbidities', 'NLR_log' ] if cols_from_shap_weights==None: cols_from_shap_weights = cols_in_scaler else: pass x_trans = transform_x_values(x) x_trans['Neutrophil_log'] = np.log(x_trans['Neutrophil'] + 1) x_trans['Lymphocyte_log'] = np.log(x_trans['Lymphocyte'] + 1) x_trans['NLR_log'] = np.log(x_trans['NLR'] + 0.01) x_trans_df = pd.DataFrame.from_dict([x_trans])[cols_in_scaler] # scale x x_to_query = scaler.transform(x_trans_df.values) x_to_query = pd.DataFrame(x_to_query, columns = cols_in_scaler)[cols_from_shap_weights] # apply shap weights to scaled x x_to_query = x_to_query*shap_weights x_to_query =x_to_query[shap_weights.columns] return x_to_query def get_local_shap_importances(x_trans, explainer, cols_to_KNN=None): """ :param: x_trans: dict, patient data already transformed by transform_x_values function explainer: shap explainer object cols_to_KNN: list of strings, optional :return: shap_weigths: DataFrame with absolute shap values normalized (all values divided by max) """ if cols_to_KNN==None: cols_to_KNN = ['NEWS2', 'CRP', 'Albumin', 'Age', 'Platelets', 'Neutrophil', 'Lymphocyte', 'Performance status', 'Total no. comorbidities', 'NLR' ] else: pass shap_weigths = get_shap_values_for_x(x_trans, explainer, sort_explanation=False) shap_weigths = pd.DataFrame.from_dict([shap_weigths]) shap_weigths = shap_weigths.abs() shap_weigths = shap_weigths[cols_to_KNN] shap_weigths = shap_weigths / shap_weigths.max(axis=1).values[0] if 'Neutrophil' in shap_weigths.columns: shap_weigths['Neutrophil_log'] = shap_weigths['Neutrophil'] shap_weigths = shap_weigths.drop(columns='Neutrophil') if 'Lymphocyte' in shap_weigths.columns: shap_weigths['Lymphocyte_log'] = shap_weigths['Lymphocyte'] shap_weigths = shap_weigths.drop(columns = 'Lymphocyte') if 'NLR' in shap_weigths.columns: shap_weigths['NLR_log'] = shap_weigths['NLR'] shap_weigths = shap_weigths.drop(columns = 'NLR') return shap_weigths def transform_X_using_weigths(df_scaled, shap_weigths): """ :param: df_scaled: DataFrame, dataset with scaled values of 10 parameters shap_weigths: DataFrame, with absolute shap values normalized (all values divided by max) :return: X_weigthed: DataFrame, dataset with scaled values multiplied by shap weights """ cols_to_KNN = shap_weigths.columns X = df_scaled[cols_to_KNN].values X_weigthed = X * shap_weigths.values X_weigthed = pd.DataFrame(X_weigthed, columns=cols_to_KNN) return X_weigthed from sklearn.neighbors import BallTree def create_distance_BallTree(X): kdt = BallTree(X, leaf_size=30, metric='euclidean') return kdt def find_K_nearest(df_masked, x, kdt_BallTree, k=5, cols_to_show=None, index_filtered = None): """ :param: df_masked: DataFrame, dataset with masked patient data used to present details to the user x: dict, your patient data, prepared by using function 'prepare_input_x_to_KNearest' kdt_BallTree: BallTree object k: int, number of similar patients to look for cols_to_show: list of strings, optional index_filtered: optional, for the future... :return: df_k_nearest: DataFrame, with k columns, rows represent patients' parameters """ k_nearest_index = kdt_BallTree.query(x, k=100, return_distance=False)[0] if cols_to_show is None: cols_to_show = ['Outcome', 'Admitted/Discharged', 'Required O2', 'Death due to COVID19', 'CORONET score', 'CORONET recommendation', 'Biological Sex','Age', 'NEWS2','CRP','Albumin', 'Platelets', 'Lymphocyte', 'Neutrophil', 'NLR', 'LDH', 'Consciousness', 'Respiratory Rate', 'Oxygen saturation (SAT)', 'Chemotherapy', 'Immunotherapy', 'Targeted therapy', 'Radiotherapy', 'Total no. comorbidities', 'Performance status', 'Treatment intent', 'Early/advanced stage', 'Cancer type', 'Solid cancer stage'] #k_nearest_index = list(set.difference(set(np.where(index_filtered==True)[0]), set(k_nearest_index)))[:k] if index_filtered is None: k_nearest_index = k_nearest_index[:k] elif index_filtered.sum()>0: index_filtered=list(np.where(index_filtered==True)[0]) k_nearest_index = [i for i in k_nearest_index if i in index_filtered][:k] df_k_nearest = df_masked.loc[k_nearest_index, cols_to_show] df_k_nearest = df_k_nearest.T return df_k_nearest def prepate_df_to_show(x_trans, df_nearest): """ :param: x_trans: dict, patient data already transformed by transform_x_values function df_nearest: DataFrame, with k columns, rows represent patients' parameters :return: df_to_show: DataFrame, x_trans and df_nearest concatenated and with ordered index """ x_to_nearest = pd.DataFrame.from_dict([x_trans]).T x_to_nearest = x_to_nearest.rename(columns={0: 'Your patient'}) df_to_show = pd.concat((x_to_nearest, df_nearest), axis=1) df_to_show = df_to_show.fillna('-') df_to_show = df_to_show.loc[df_nearest.index, :] df_to_show.columns = ['Your patient','1st','2nd','3rd','4th','5th'] return df_to_show def mask_patients_in_df(df): df_masked = df.copy() # mask age, e.g.: 65 to '60s' df_masked['Age'] = ((df_masked['Age'] / 10).apply(np.floor) * 10).astype('int').astype('str') + 's' df_masked.loc[df_masked['Age']=='0s', 'Age'] = '<10' # crp_list = [list(np.arange(0, 10, 2)),list(np.arange(10, 100, 5)), list(np.arange(100, 1200, 20))] crp_list = [item for sublist in crp_list for item in sublist] df_masked['CRP'] = group_parameter(df['CRP'], groups_ranges=crp_list) df_masked['Platelets'] = group_parameter(df['Platelets'], groups_ranges=list(np.arange(0, 700, 10))) df_masked['Albumin'] = group_parameter(df['Albumin'], groups_ranges=list(np.arange(0, 120, 5))) neutro_list = [list(np.arange(0, 5, 0.5)),list(np.arange(5, 20, 1)), list(np.arange(20, 150, 2))] neutro_list = [float(item) for sublist in neutro_list for item in sublist] df_masked['Neutrophil'] = group_parameter(df['Neutrophil'], groups_ranges=neutro_list) lympho_list = [list(np.arange(0, 5, 0.5)),list(np.arange(5, 20, 1)), list(np.arange(20, 102, 2))] lympho_list = [item for sublist in lympho_list for item in sublist] df_masked['Lymphocyte'] = group_parameter(df['Lymphocyte'], groups_ranges=lympho_list) df_masked['NLR'] = group_parameter(df['NLR'], groups_ranges=list(np.arange(0, 120, 2))) df_masked['LDH'] = group_parameter(df['LDH'], groups_ranges=list(np.arange(0, 5000, 25))) df_masked['respiratory rate'] = group_parameter(df['respiratory rate'], groups_ranges=list(np.arange(0, 120, 2))) df_masked['SAT'] = group_parameter(df['SAT'], groups_ranges=list(np.arange(70, 106, 2))) yes_no_columns = ['Chemotherapy', 'Immunotherapy', 'Targetted_therapy', 'Radiotherapy','haematological_cancer'] for col in yes_no_columns: df_masked.loc[df_masked[col] == 0, col] = 'No' df_masked.loc[df_masked[col] == 1, col] = 'Yes' df_masked.loc[df_masked['Treatment_intent'] == 1, 'Treatment_intent'] = 'Curative' df_masked.loc[df_masked['Treatment_intent'] == 2, 'Treatment_intent'] = 'Palliative' df_masked['Outcome'] = '' df_masked.loc[df_masked['coronet_true_score'] == 0, 'Outcome'] = 'Discharged' df_masked.loc[df_masked['coronet_true_score'] == 1, 'Outcome'] = 'Admitted' df_masked.loc[df_masked['coronet_true_score'] == 2, 'Outcome'] = 'Admitted+O2' df_masked.loc[df_masked['coronet_true_score'] == 3, 'Outcome'] = 'Admitted+O2+died' df_masked.loc[df_masked['Total_nb_comorbidities']>4, 'Total_nb_comorbidities'] = '>4' df_masked['Cancer type'] = '-' df_masked.loc[df_masked['Cancer_stage'].isin([1,2,3,4]), 'Cancer type'] = 'Solid' df_masked.loc[df_masked['haematological_cancer']=='Yes', 'Cancer type'] = 'Haematological' df_masked.loc[df_masked['Cancer type'] == 'Haematological', 'Cancer_stage'] = '-' df_masked.loc[df_masked['Cancer_stage']==0, 'Cancer_stage'] = '-' df_masked.loc[df_masked['GCS_below15'] == 0, 'GCS_below15'] = 'fully conscious' df_masked.loc[df_masked['GCS_below15'] == 1, 'GCS_below15'] = 'GCS<15' # rename columns df_masked= df_masked.rename(columns = {'Sex':'Biological Sex', 'Targetted_therapy':'Targeted therapy', 'Treatment_intent':'Treatment intent', 'Total_nb_comorbidities':'Total no. comorbidities', 'Early_advanced_stage':'Early/advanced stage', 'Death_day_since_admission':'Death day since admission', 'respiratory rate': 'Respiratory Rate', 'SAT':'Oxygen saturation (SAT)', 'Cancer_stage':'Solid cancer stage', 'haematological_cancer':'Haematological cancer', '4C':'4C score', 'GCS_below15':'Consciousness', 'Performance_status':'Performance status' }) df_masked['Admitted/Discharged'] = 'Discharged' df_masked['Required O2'] = 'No' df_masked['Death due to COVID19'] = 'No' df_masked.loc[df_masked['coronet_true_score'] >0 , 'Admitted/Discharged'] = 'Admitted' df_masked.loc[df_masked['coronet_true_score'] >=2 ,'Required O2'] = 'Yes' df_masked.loc[df_masked['coronet_true_score'] ==3 ,'Death due to COVID19'] = 'Yes' df_masked = df_masked.fillna('-') return df_masked def group_parameter(series, groups_ranges=[0, 5, 15, 50, 150, 1000]): series_of_labels = series.copy() label = np.nan for i in range(len(groups_ranges) - 1): range_start = groups_ranges[i] range_end = groups_ranges[i + 1] index_ = (series >= range_start) & (series < range_end) if i == 0: label = '<' + str(range_end) elif i == (len(groups_ranges) - 2): label = '>' + str(range_start) else: label = str(range_start) + '-' + str(range_end) series_of_labels[index_] = label return series_of_labels def filter_cancer_type(df, haem_cancer=True, solid_cancer=True, missing_cancer=True): if haem_cancer: index_haem = df['Cancer type'] == 'Haematological' else: index_haem = df['Cancer type'].astype('bool') * False if solid_cancer: index_solid = df['Cancer type'] == 'Solid' else: index_solid = df['Cancer type'].astype('bool') * False if missing_cancer: index_missing = df['Cancer type'] == '-' else: index_missing = df['Cancer type'].astype('bool') * False index_ = index_haem | index_solid | index_missing return index_