databricks/lib/spark_helper/predictions.py
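"""Helpers for storing and loading model predictions on Databricks.

TemporaryStorage keeps per-file JSON artifacts in a volume; PermanentStorage
writes rows to a predictions table via SparkDBService.
"""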
import json
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

from lib.spark_helper.db_service import SparkDBService
from lib.spark_helper.storage_service import SparkStorageService


@dataclass
class ModelParams:
    """Parameters of the model run that produced a prediction."""

    run_name: Optional[str] = None
    model_name: Optional[str] = None
    model_version: Optional[str] = None
    temperature: Optional[float] = None
    prompt: Optional[Any] = None
    json_schema: Optional[Any] = None


@dataclass
class Prediction:
    """A single model prediction for one file within a job."""

    job_id: int
    file_id: int
    ground_truth_revision_id: Optional[str]
    model_params: ModelParams
    prediction_result: Dict[str, Any]
    created_date: datetime


class TemporaryStorage:
    """Stores predictions as JSON files in a volume, one file per (job_id, file_id)."""

    VOLUME_NAME = "predictions"
    STORAGE_PATH = VOLUME_NAME + "/{job_id}/{file_id}.json"

    def __init__(self, storage_service: SparkStorageService) -> None:
        self.storage_service = storage_service
        self.storage_service.create_volume_if_not_exists(self.VOLUME_NAME)

    def store(self, predictions: List[Prediction]) -> None:
        for prediction in predictions:
            prediction_dict = asdict(prediction)
            # datetime is not JSON-serializable; persist it as an ISO-8601 string.
            prediction_dict["created_date"] = prediction.created_date.isoformat()
            self.storage_service.write_text(
                data=json.dumps(prediction_dict, indent=4),
                file_path=Path(
                    self.STORAGE_PATH.format(
                        job_id=prediction.job_id, file_id=prediction.file_id
                    )
                ),
            )
    def load_predictions(self, job_id: int) -> List[Prediction]:
        prediction_file_paths = self.storage_service.list_files(
            Path(f"{self.VOLUME_NAME}/{job_id}")
        )
        predictions = []
        for file_path in prediction_file_paths:
            prediction = json.loads(self.storage_service.read_text(file_path))
            model_params = prediction["model_params"]
            predictions.append(
                Prediction(
                    job_id=prediction["job_id"],
                    file_id=prediction["file_id"],
                    ground_truth_revision_id=prediction["ground_truth_revision_id"],
                    model_params=ModelParams(
                        run_name=model_params["run_name"],
                        model_name=model_params["model_name"],
                        model_version=model_params["model_version"],
                        temperature=model_params["temperature"],
                        prompt=model_params["prompt"],
                        # json_schema is expected to be stored as a JSON string;
                        # decode it back into an object here.
                        json_schema=(
                            json.loads(model_params["json_schema"])
                            if model_params["json_schema"]
                            else None
                        ),
                    ),
                    prediction_result=prediction["prediction_result"],
                    created_date=datetime.fromisoformat(prediction["created_date"]),
                )
            )
        return predictions
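
# Round-trip usage sketch (assumes an already-constructed SparkStorageService;
# its constructor lives in lib.spark_helper.storage_service):
#
#     temp = TemporaryStorage(storage_service)
#     temp.store(predictions)
#     restored = temp.load_predictions(job_id=42)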


class PermanentStorage:
    """Persists predictions to a database table via SparkDBService."""

    TABLE = "predictions"
    COLUMNS = {
        "job_id": "INT",
        "file_id": "INT",
        "revision_id": "STRING",
        "model_parameters": "STRING",
        "prediction_results": "STRING",
        "create_date": "TIMESTAMP",
    }

    def __init__(self, db_service: SparkDBService):
        self.db_service = db_service
        self.db_service.create_table_if_not_exists(self.TABLE, self.COLUMNS)

    def store(self, predictions: List[Prediction]) -> None:
        for prediction in predictions:
            # Values are positional and must match the COLUMNS order above.
            self.db_service.insert_table(
                self.TABLE,
                [
                    prediction.job_id,
                    prediction.file_id,
                    prediction.ground_truth_revision_id,
                    json.dumps(asdict(prediction.model_params)),
                    json.dumps(prediction.prediction_result),
                    prediction.created_date,
                ],
            )
    def load_by_job_id(
        self,
        job_id: int,
        file_id: Optional[int] = None,
        model_params_run_name: Optional[str] = None,
    ) -> List[Prediction]:
        """Load predictions for a job, optionally filtered by file_id and run
        name, ordered by created_date ascending.

        Not yet implemented; see the sketch below for one possible shape.
        """
        raise NotImplementedError
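

# A minimal sketch of what load_by_job_id could look like. It assumes a
# hypothetical SparkDBService.execute_query(sql) method that returns rows with
# attribute access (e.g. Spark Row objects); the real query API of
# SparkDBService is not visible from this module, so adapt accordingly, and
# parameterize the SQL rather than interpolating values in production code.
#
#     def load_by_job_id(
#         self,
#         job_id: int,
#         file_id: Optional[int] = None,
#         model_params_run_name: Optional[str] = None,
#     ) -> List[Prediction]:
#         sql = f"SELECT * FROM {self.TABLE} WHERE job_id = {job_id}"
#         if file_id is not None:
#             sql += f" AND file_id = {file_id}"
#         sql += " ORDER BY create_date ASC"
#         rows = self.db_service.execute_query(sql)  # hypothetical method
#         predictions = []
#         for row in rows:
#             params = ModelParams(**json.loads(row.model_parameters))
#             if model_params_run_name and params.run_name != model_params_run_name:
#                 continue
#             predictions.append(
#                 Prediction(
#                     job_id=row.job_id,
#                     file_id=row.file_id,
#                     ground_truth_revision_id=row.revision_id,
#                     model_params=params,
#                     prediction_result=json.loads(row.prediction_results),
#                     created_date=row.create_date,
#                 )
#             )
#         return predictions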