databricks/notebooks/gpt_prediction.py:

# Databricks notebook source
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List

import lib.spark_helper.predictions as predictions_helper
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import SecretStr
from langchain_openai import AzureChatOpenAI
from lib.repository.configs.service import load_config
from lib.repository.ground_truth.helpers import GroundTruthHelper
from lib.spark_helper.files import FilesStorage
from lib.spark_helper.storage_service import SparkStorageService

from databricks.sdk.runtime import dbutils

# COMMAND ----------

# Chat prompt: a fixed system instruction plus the document text, which is
# substituted into the "{text}" placeholder at invocation time.
system_prompt = [
    (
        "system",
        "You are an assistant that extracts product information",
    ),
    ("human", "{text}"),
]

# JSON Schema describing the structured output the model must return.
json_schema = {
    "title": "ProductInfo",
    "description": "Information about a product",
    "type": "object",
    "properties": {
        "price": {
            "type": "string",
            "description": "Price of a product in USD",
        },
    },
}

# COMMAND ----------


def get_model(
    credentials: Dict[str, str], parameters: predictions_helper.ModelParams
) -> AzureChatOpenAI:
    return AzureChatOpenAI(
        azure_endpoint=credentials["azure_endpoint"],
        api_key=SecretStr(credentials["api_key"]),
        azure_deployment=parameters.model_name,
        api_version=parameters.model_version,
        # Treat 0 as a valid temperature; fall back to 1.0 only when unset.
        # (A plain truthiness check would silently replace 0 with 1.0.)
        temperature=(
            parameters.temperature if parameters.temperature is not None else 1.0
        ),
    )


def predict(
    model: AzureChatOpenAI, text: str, output_schema: Dict[Any, Any]
) -> Any:
    # Pipe the prompt into a model constrained to emit `output_schema`;
    # with a dict schema, the invocation returns a plain dict.
    prompt = ChatPromptTemplate.from_messages(system_prompt)
    runnable = prompt | model.with_structured_output(schema=output_schema)
    return runnable.invoke({"text": text})


# COMMAND ----------

configs = load_config(project_name=dbutils.widgets.get("project_name"))
storage_service = SparkStorageService(configs)


def predict_file(
    model: AzureChatOpenAI, file: Dict[Any, Any], output_schema: Dict[Any, Any]
) -> Any:
    file_id = file["file_id"]
    print(f"Predicting file: {file_id}")
    text = storage_service.read_text(
        Path(FilesStorage.TXT_STORAGE_PATH.format(file_id=file_id))
    )
    return predict(model, text, output_schema)


def predict_files_parallel(
    model: AzureChatOpenAI,
    files: List[Dict[Any, Any]],
    output_schema: Dict[Any, Any],
) -> List[Dict[Any, Any]]:
    # Fan the files out across a thread pool; results are collected in
    # completion order, not submission order. Note that future.result()
    # re-raises any exception from the worker, so one failed file aborts
    # the whole batch.
    predictions = []
    with ThreadPoolExecutor(max_workers=20) as executor:
        future_to_file = {
            executor.submit(predict_file, model, file, output_schema): file
            for file in files
        }
        for future in as_completed(future_to_file):
            file = future_to_file[future]
            predictions.append(
                {
                    "file_id": int(file["file_id"]),
                    "prediction": future.result(),
                }
            )
    return predictions
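
# COMMAND ----------

# Illustrative sketch only: render the prompt with hypothetical sample text
# to inspect the exact messages `predict` will send. `format_messages` fills
# the "{text}" placeholder without calling the model, so this is side-effect
# free apart from the print.
example_messages = ChatPromptTemplate.from_messages(system_prompt).format_messages(
    text="Acme anvil, heavy duty, now only $19.99."
)
print(example_messages)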

# COMMAND ----------

secrets_scope = dbutils.widgets.get("secrets_scope")
model_credentials = {
    "azure_endpoint": dbutils.secrets.get(
        scope=secrets_scope, key="gpt_endpoint"
    ),
    "api_key": dbutils.secrets.get(
        scope=secrets_scope, key="azure_openai_api_key"
    ),
}

model_parameters = predictions_helper.ModelParams(
    model_name="gpt-4",
    model_version="2023-12-01-preview",
    temperature=0,
    prompt=system_prompt,
    json_schema=json.dumps(json_schema),
)

# COMMAND ----------

helper = GroundTruthHelper(configs)

job_parameters = json.loads(dbutils.widgets.get("badgerdoc_job_parameters"))
files = job_parameters["files_data"]

model = get_model(model_credentials, model_parameters)
predicted_values = predict_files_parallel(model, files, json_schema)

# Wrap each raw prediction in a Prediction record, linking it to the job
# and to the latest ground-truth revision for that file.
predictions: List[predictions_helper.Prediction] = []
for predicted_value in predicted_values:
    file_id = predicted_value["file_id"]
    predictions.append(
        predictions_helper.Prediction(
            job_id=int(job_parameters["job_id"]),
            file_id=file_id,
            ground_truth_revision_id=helper.get_latest_revision_id(file_id),
            model_params=model_parameters,
            prediction_result=predicted_value["prediction"],
            created_date=datetime.now(),
        )
    )

temporary_storage = predictions_helper.TemporaryStorage(storage_service)
temporary_storage.store(predictions)
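
# COMMAND ----------

# For reference, the "badgerdoc_job_parameters" widget is expected to carry
# JSON shaped like the hypothetical example below (inferred from the keys
# read above; the real BadgerDoc payload may include additional fields):
#
# {
#     "job_id": "42",
#     "files_data": [
#         {"file_id": "1001"},
#         {"file_id": "1002"}
#     ]
# }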