databricks/lib/repository/ground_truth/stats.py (441 lines of code) (raw):
import logging
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

from sklearn.metrics import accuracy_score, precision_score, recall_score
from tabulate import tabulate

import lib.spark_helper.predictions as predictions_helper
from lib.spark_helper.ground_truth import GroundTruthFileStorage
from lib.spark_helper.predictions import Prediction
def replace_values(
    input_list: List[Any], replacements: Dict[Any, Any]
) -> List[Any]:
    """Return a new list with mapped items substituted.

    Every item of ``input_list`` that is a key of ``replacements`` is swapped
    for the mapped value; all other items are carried over unchanged. The
    input list itself is not modified.
    """
    substituted: List[Any] = []
    for element in input_list:
        substituted.append(replacements.get(element, element))
    return substituted
def convert_classes_to_binary(
    ground_truth: List[str], predicted: List[str]
) -> Tuple[List[int], List[int]]:
    """Collapse the three class labels of both lists into binary labels.

    ``"Compliant"`` becomes ``0``; ``"Minor Issues"`` and ``"Red Flag"`` both
    become ``1``. Any other value is passed through unchanged.

    Returns:
        ``(ground_truth_binary, predicted_binary)``, in the same order as the
        inputs.
    """
    label_to_binary = {"Compliant": 0, "Minor Issues": 1, "Red Flag": 1}
    binary_truth = replace_values(ground_truth, label_to_binary)
    binary_predicted = replace_values(predicted, label_to_binary)
    return binary_truth, binary_predicted
def get_annotation_from_revision_by_category(
    revision: Dict[str, List[Dict[str, List[Dict[str, str]]]]], category: str
) -> Optional[str]:
    """Return the text of the first annotation tagged with ``category``.

    Only the first page of the revision (``revision["pages"][0]``) is
    searched, and only its ``"objs"`` list.

    Args:
        revision: Parsed revision document with a ``"pages"`` list.
        category: Annotation category to look for.

    Returns:
        The ``"text"`` value of the first matching annotation, or ``None``
        when the revision has no pages, the first page has no ``"objs"``, or
        no annotation matches.
    """
    pages = revision.get("pages") or []
    if not pages:
        # An empty or missing page list means "no ground truth", not an error.
        return None
    for annotation in pages[0].get("objs") or []:
        if annotation.get("category") == category:
            return annotation.get("text")
    return None
class StatsCalculator:
    """Score stored predictions against ground truth and print metric tables.

    Predictions are loaded per job from ``temporary_storage`` and compared,
    category by category, with the annotation of the same category found in
    the matching ground-truth revision. Labels are collapsed to binary with
    ``convert_classes_to_binary`` before accuracy, precision and recall are
    computed with scikit-learn.

    A prediction whose revision cannot be read is skipped (and logged) instead
    of aborting the report. A metric that cannot be computed because no
    annotated category was found is reported as ``None``.
    """

    def __init__(
        self,
        temporary_storage: predictions_helper.TemporaryStorage,
        ground_truth_storage: GroundTruthFileStorage,
    ) -> None:
        """Store both backends; no predictions are loaded yet."""
        self.temporary_storage = temporary_storage
        self.ground_truth_storage = ground_truth_storage
        # Predictions loaded by the most recent get_predictions() call.
        self.predictions: List[Prediction] = []

    def get_predictions(self, job_ids: Sequence[int]) -> None:
        """Replace ``self.predictions`` with the predictions of ``job_ids``."""
        self.predictions = []
        for job_id in job_ids:
            self.predictions.extend(
                self.temporary_storage.load_predictions(job_id=job_id)
            )

    def get_predictions_by_job_id(self, job_id: int) -> List[Prediction]:
        """Return the loaded predictions that belong to ``job_id``."""
        return [
            prediction
            for prediction in self.predictions
            if prediction.job_id == job_id
        ]

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _read_revision(self, prediction: Prediction) -> Optional[Any]:
        """Load the ground-truth revision for ``prediction``, or ``None``.

        Any failure while reading is logged and reported as ``None`` so one
        unreadable file does not abort a whole report.
        """
        try:
            return self.ground_truth_storage.read_revision_file(
                file_id=prediction.file_id,
                # A missing revision id is passed to the storage as "".
                revision_id=prediction.ground_truth_revision_id or "",
            )
        except Exception:
            logging.getLogger(__name__).warning(
                "Skipping prediction: cannot read ground-truth revision %r "
                "of file %r",
                getattr(prediction, "ground_truth_revision_id", None),
                getattr(prediction, "file_id", None),
                exc_info=True,
            )
            return None

    def _annotated_pairs(
        self, prediction: Prediction
    ) -> List[Tuple[str, str, str]]:
        """Return ``(category, annotation, predicted)`` per annotated category.

        Categories with no (or an empty) ground-truth annotation are left
        out. The list is empty when the revision cannot be read.
        """
        revision = self._read_revision(prediction)
        if revision is None:
            return []
        pairs: List[Tuple[str, str, str]] = []
        for category in prediction.prediction_result:
            annotation = get_annotation_from_revision_by_category(
                revision, category
            )
            if not annotation:
                continue
            pairs.append(
                (category, annotation, prediction.prediction_result[category])
            )
        return pairs

    def _binary_labels(
        self, predictions: List[Prediction]
    ) -> Optional[Tuple[List[int], List[int]]]:
        """Return binary ``(ground_truth, predicted)`` or ``None`` if empty."""
        ground_truth, predicted = self.generate_prediction_and_truth_lists(
            predictions
        )
        if not ground_truth or not predicted:
            return None
        return convert_classes_to_binary(ground_truth, predicted)

    def _calculate_files_metric(
        self,
        job_ids: List[int],
        metric: Callable[[List[Prediction]], Any],
    ) -> Dict[int, Dict[int, Any]]:
        """Apply ``metric`` to every single prediction of ``job_ids``.

        Returns ``{file_id: {job_id: value}}``. Only ``None`` results (nothing
        to score) are omitted; a genuine ``0.0`` is kept.
        """
        files_metric: Dict[int, Dict[int, Any]] = {}
        for job_id in job_ids:
            for prediction in self.get_predictions_by_job_id(job_id):
                value = metric([prediction])
                if value is None:
                    continue
                per_job = files_metric.setdefault(prediction.file_id, {})
                per_job[prediction.job_id] = value
        return files_metric

    @staticmethod
    def _file_rows(
        files_metric: Dict[int, Dict[int, Any]], job_ids: List[int]
    ) -> List[List[Any]]:
        """Build one ``"- <file_id>"`` row per file; missing cells are None."""
        return [
            ["- " + str(file_id)]
            + [per_job.get(job_id) for job_id in job_ids]
            for file_id, per_job in files_metric.items()
        ]

    @staticmethod
    def _category_rows(
        stats: Dict[str, Dict[int, Dict[str, Any]]],
        metric: str,
        job_ids: List[int],
    ) -> List[List[Any]]:
        """Build one ``"- <category>"`` row per category for ``metric``."""
        return [
            ["- " + category]
            + [per_job.get(job_id, {}).get(metric) for job_id in job_ids]
            for category, per_job in stats.items()
        ]

    def _print_table(self, job_ids: List[int], rows: List[List[Any]]) -> None:
        """Print ``rows`` under a ``job_id`` header row using tabulate."""
        header = ["job_id"] + [str(job_id) for job_id in job_ids]
        print(
            tabulate(
                [header] + rows,
                headers="firstrow",
                tablefmt="simple_grid",
                floatfmt=".2f",
            )
        )

    # ------------------------------------------------------------------
    # Metrics over a list of predictions
    # ------------------------------------------------------------------

    def calculate_accuracy_from_predictions(
        self, predictions: List[Prediction]
    ) -> Any:
        """Return binary accuracy, or ``None`` when nothing can be scored."""
        labels = self._binary_labels(predictions)
        if labels is None:
            return None
        ground_truth_binary, predicted_binary = labels
        return accuracy_score(ground_truth_binary, predicted_binary)

    def calculate_precision_from_predictions(
        self, predictions: List[Prediction]
    ) -> Any:
        """Return binary precision, or ``None`` when nothing can be scored.

        ``zero_division=0`` is passed to scikit-learn.
        """
        labels = self._binary_labels(predictions)
        if labels is None:
            return None
        ground_truth_binary, predicted_binary = labels
        return precision_score(
            ground_truth_binary, predicted_binary, zero_division=0
        )

    def calculate_recall_from_predictions(
        self, predictions: List[Prediction]
    ) -> Any:
        """Return binary recall, or ``None`` when nothing can be scored.

        ``zero_division=0`` is passed to scikit-learn.
        """
        labels = self._binary_labels(predictions)
        if labels is None:
            return None
        ground_truth_binary, predicted_binary = labels
        return recall_score(
            ground_truth_binary, predicted_binary, zero_division=0
        )

    def generate_prediction_and_truth_lists(
        self, predictions: List[Prediction]
    ) -> Tuple[List[str], List[str]]:
        """Collect aligned ground-truth and predicted labels.

        One entry is produced per (prediction, category) that has a non-empty
        ground-truth annotation. Predictions whose revision cannot be read
        contribute nothing.

        Returns:
            ``(ground_truth, predicted)`` lists of equal length.
        """
        predicted: List[str] = []
        ground_truth: List[str] = []
        for prediction in predictions:
            for _category, annotation, label in self._annotated_pairs(
                prediction
            ):
                ground_truth.append(annotation)
                predicted.append(label)
        return ground_truth, predicted

    # ------------------------------------------------------------------
    # Accuracy
    # ------------------------------------------------------------------

    def calculate_jobs_accuracy(self, job_ids: List[int]) -> Dict[int, Any]:
        """Return ``{job_id: accuracy}`` over each job's loaded predictions."""
        return {
            job_id: self.calculate_accuracy_from_predictions(
                self.get_predictions_by_job_id(job_id)
            )
            for job_id in job_ids
        }

    def calculate_files_accuracy(
        self, job_ids: List[int]
    ) -> Dict[int, Dict[int, Dict[str, Any]]]:
        """Return per-file accuracy and per-category matches.

        The result has the shape
        ``{file_id: {job_id: {"accuracy": float, "categories": {cat: 0|1}}}}``.
        ``"accuracy"`` is computed on binary labels, whereas each
        ``"categories"`` value is ``1`` only when the raw predicted label
        equals the raw annotation. Predictions with nothing to score are
        omitted; an accuracy of ``0.0`` is kept.
        """
        files_accuracy: Dict[int, Dict[int, Dict[str, Any]]] = {}
        for job_id in job_ids:
            for prediction in self.get_predictions_by_job_id(job_id):
                # The revision is read once and reused for both the file
                # accuracy and the per-category breakdown.
                pairs = self._annotated_pairs(prediction)
                if not pairs:
                    continue
                truth_binary, predicted_binary = convert_classes_to_binary(
                    [annotation for _, annotation, _ in pairs],
                    [label for _, _, label in pairs],
                )
                per_job = files_accuracy.setdefault(prediction.file_id, {})
                per_job[prediction.job_id] = {
                    "accuracy": accuracy_score(
                        truth_binary, predicted_binary
                    ),
                    "categories": {
                        category: int(label == annotation)
                        for category, annotation, label in pairs
                    },
                }
        return files_accuracy

    def get_accuracy_rows(
        self, job_ids: List[int], include_files: bool, include_categories: bool
    ) -> List[List[Any]]:
        """Build the accuracy table rows.

        The first row holds the per-job accuracy. With ``include_files`` one
        ``"- <file_id>"`` row follows per file, and with
        ``include_categories`` each file row is followed by one
        ``"-- <category>"`` row per category seen for that file in any job.
        Cells with no value for a job are ``None``.
        """
        jobs_accuracy = self.calculate_jobs_accuracy(job_ids)
        rows: List[List[Any]] = [
            ["accuracy"] + [jobs_accuracy[job_id] for job_id in job_ids]
        ]
        if not include_files:
            return rows
        files_accuracy = self.calculate_files_accuracy(job_ids)
        for file_id, per_job in files_accuracy.items():
            rows.append(
                ["- " + str(file_id)]
                + [
                    per_job.get(job_id, {}).get("accuracy")
                    for job_id in job_ids
                ]
            )
            if not include_categories:
                continue
            # Union of categories over all jobs, in first-seen order, so no
            # job's categories are lost.
            categories: Dict[str, None] = {}
            for job_id in job_ids:
                for category in per_job.get(job_id, {}).get("categories", {}):
                    categories.setdefault(category, None)
            for category in categories:
                rows.append(
                    ["-- " + category]
                    + [
                        per_job.get(job_id, {})
                        .get("categories", {})
                        .get(category)
                        for job_id in job_ids
                    ]
                )
        return rows

    # ------------------------------------------------------------------
    # Precision / recall
    # ------------------------------------------------------------------

    def calculate_jobs_precision(self, job_ids: List[int]) -> Dict[int, Any]:
        """Return ``{job_id: precision}`` over each job's predictions."""
        return {
            job_id: self.calculate_precision_from_predictions(
                self.get_predictions_by_job_id(job_id)
            )
            for job_id in job_ids
        }

    def calculate_jobs_recall(self, job_ids: List[int]) -> Dict[int, Any]:
        """Return ``{job_id: recall}`` over each job's predictions."""
        return {
            job_id: self.calculate_recall_from_predictions(
                self.get_predictions_by_job_id(job_id)
            )
            for job_id in job_ids
        }

    def calculate_files_precision(
        self, job_ids: List[int]
    ) -> Dict[int, Dict[int, Any]]:
        """Return ``{file_id: {job_id: precision}}`` per single prediction."""
        return self._calculate_files_metric(
            job_ids, self.calculate_precision_from_predictions
        )

    def calculate_files_recall(
        self, job_ids: List[int]
    ) -> Dict[int, Dict[int, Any]]:
        """Return ``{file_id: {job_id: recall}}`` per single prediction."""
        return self._calculate_files_metric(
            job_ids, self.calculate_recall_from_predictions
        )

    def get_precision_rows(
        self, job_ids: List[int], include_files: bool
    ) -> List[List[Any]]:
        """Build the precision rows: one per-job row, optionally per file."""
        jobs_precision = self.calculate_jobs_precision(job_ids)
        rows: List[List[Any]] = [
            ["precision"] + [jobs_precision[job_id] for job_id in job_ids]
        ]
        if include_files:
            # Per-file metrics hit the storage again, so compute on demand.
            rows.extend(
                self._file_rows(
                    self.calculate_files_precision(job_ids), job_ids
                )
            )
        return rows

    def get_recall_rows(
        self, job_ids: List[int], include_files: bool
    ) -> List[List[Any]]:
        """Build the recall rows: one per-job row, optionally per file."""
        jobs_recall = self.calculate_jobs_recall(job_ids)
        rows: List[List[Any]] = [
            ["recall"] + [jobs_recall[job_id] for job_id in job_ids]
        ]
        if include_files:
            # Per-file metrics hit the storage again, so compute on demand.
            rows.extend(
                self._file_rows(self.calculate_files_recall(job_ids), job_ids)
            )
        return rows

    def get_precision_recall_rows(
        self, job_ids: List[int], include_files: bool
    ) -> List[List[Any]]:
        """Return the precision rows followed by the recall rows."""
        rows: List[List[Any]] = []
        rows.extend(self.get_precision_rows(job_ids, include_files))
        rows.extend(self.get_recall_rows(job_ids, include_files))
        return rows

    # ------------------------------------------------------------------
    # Per-category statistics
    # ------------------------------------------------------------------

    def calculate_stats_by_category(
        self, job_ids: List[int]
    ) -> Dict[str, Dict[int, Dict[str, Any]]]:
        """Return per-category, per-job labels and metrics.

        The result has the shape ``{category: {job_id: {...}}}`` where each
        entry holds the raw ``"ground_truth"`` and ``"predicted"`` label
        lists plus ``"precision"``, ``"recall"`` and ``"accuracy"``. A
        category that was predicted but never annotated keeps empty lists
        and gets ``None`` for all three metrics.
        """
        stats: Dict[str, Dict[int, Dict[str, Any]]] = {}
        for job_id in job_ids:
            for prediction in self.get_predictions_by_job_id(job_id):
                revision = self._read_revision(prediction)
                if revision is None:
                    continue
                for category in prediction.prediction_result:
                    # Register the category even when it has no annotation,
                    # so it still shows up in the table.
                    entry = stats.setdefault(category, {}).setdefault(
                        job_id, {"ground_truth": [], "predicted": []}
                    )
                    annotation = get_annotation_from_revision_by_category(
                        revision, category
                    )
                    if not annotation:
                        continue
                    entry["ground_truth"].append(annotation)
                    entry["predicted"].append(
                        prediction.prediction_result[category]
                    )
        for per_job in stats.values():
            for entry in per_job.values():
                if not entry["ground_truth"]:
                    # Nothing to score: same None convention as the other
                    # calculate_* methods.
                    entry["precision"] = None
                    entry["recall"] = None
                    entry["accuracy"] = None
                    continue
                ground_truth, predicted = convert_classes_to_binary(
                    entry["ground_truth"], entry["predicted"]
                )
                entry["precision"] = precision_score(
                    ground_truth, predicted, zero_division=0
                )
                entry["recall"] = recall_score(
                    ground_truth, predicted, zero_division=0
                )
                entry["accuracy"] = accuracy_score(ground_truth, predicted)
        return stats

    def get_category_rows(self, job_ids: List[int]) -> List[List[Any]]:
        """Build precision, recall and accuracy sections broken down by category.

        Each section starts with the per-job summary row, followed by one
        ``"- <category>"`` row per category. Cells with no value for a job
        are ``None``.
        """
        stats = self.calculate_stats_by_category(job_ids)
        rows: List[List[Any]] = []
        rows.extend(self.get_precision_rows(job_ids, include_files=False))
        rows.extend(self._category_rows(stats, "precision", job_ids))
        rows.extend(self.get_recall_rows(job_ids, include_files=False))
        rows.extend(self._category_rows(stats, "recall", job_ids))
        rows.extend(
            self.get_accuracy_rows(
                job_ids, include_files=False, include_categories=False
            )
        )
        rows.extend(self._category_rows(stats, "accuracy", job_ids))
        return rows

    # ------------------------------------------------------------------
    # Printed reports (each reloads the predictions for ``job_ids`` first)
    # ------------------------------------------------------------------

    def avg_summary_by_jobs(self, job_ids: List[int]) -> None:
        """Print per-job precision, recall and accuracy."""
        self.get_predictions(job_ids)
        table_rows: List[List[Any]] = []
        table_rows.extend(
            self.get_precision_recall_rows(job_ids, include_files=False)
        )
        table_rows.extend(
            self.get_accuracy_rows(
                job_ids, include_files=False, include_categories=False
            )
        )
        self._print_table(job_ids, table_rows)

    def avg_category_by_jobs(self, job_ids: List[int]) -> None:
        """Print per-job metrics broken down by category."""
        self.get_predictions(job_ids)
        self._print_table(job_ids, self.get_category_rows(job_ids))

    def avg_file_by_jobs(self, job_ids: List[int]) -> None:
        """Print per-job metrics broken down by file."""
        self.get_predictions(job_ids)
        table_rows: List[List[Any]] = []
        table_rows.extend(
            self.get_precision_recall_rows(job_ids, include_files=True)
        )
        table_rows.extend(
            self.get_accuracy_rows(
                job_ids, include_files=True, include_categories=False
            )
        )
        self._print_table(job_ids, table_rows)

    def avg_file_and_category_by_jobs(self, job_ids: List[int]) -> None:
        """Print per-job metrics by file, with per-category accuracy rows."""
        self.get_predictions(job_ids)
        table_rows: List[List[Any]] = []
        table_rows.extend(
            self.get_precision_recall_rows(job_ids, include_files=True)
        )
        table_rows.extend(
            self.get_accuracy_rows(
                job_ids, include_files=True, include_categories=True
            )
        )
        self._print_table(job_ids, table_rows)