# basic_pitch/inference.py — Model.__init__ (excerpt)
def __init__(self, model_path: Union[pathlib.Path, str]):
    """Load a serialized model, auto-detecting its serialization format.

    Attempts each backend whose runtime is importable, in priority order:
    TensorFlow SavedModel, CoreML, TensorFlow Lite, then ONNX. The first
    backend that loads the file wins: ``self.model_type`` and ``self.model``
    (plus ``self.interpreter`` for TFLite) are set and the constructor
    returns immediately.

    Args:
        model_path: Path to the serialized model — a SavedModel directory,
            ``.mlpackage``, ``.tflite``, or ``.onnx`` file.

    Raises:
        ValueError: If none of the installed backends can load the model.
    """
    present = []  # names of installed backends, reported in the final error

    def _warn_load_failure(looks_like_format: bool, message: str, err: Exception) -> None:
        # Warn only when the path *looks like* this backend's format, so a
        # surprising failure is surfaced while normal fallthrough stays quiet.
        # Lazy %-style args keep formatting cost off the happy path.
        if looks_like_format:
            logging.warning(message, model_path, repr(err))

    if TF_PRESENT:
        present.append("TensorFlow")
        try:
            self.model_type = Model.MODEL_TYPES.TENSORFLOW
            self.model = tf.saved_model.load(str(model_path))
            return
        except Exception as e:
            # A SavedModel is a directory containing saved_model.pb and/or variables/.
            _warn_load_failure(
                os.path.isdir(model_path)
                and bool({"saved_model.pb", "variables"} & set(os.listdir(model_path))),
                "Could not load TensorFlow saved model %s even "
                "though it looks like a saved model file with error %s. "
                "Are you sure it's a TensorFlow saved model?",
                e,
            )

    if CT_PRESENT:
        present.append("CoreML")
        try:
            self.model_type = Model.MODEL_TYPES.COREML
            # CPU_ONLY pins CoreML to the CPU compute unit.
            self.model = ct.models.MLModel(str(model_path), compute_units=ct.ComputeUnit.CPU_ONLY)
            return
        except Exception as e:
            _warn_load_failure(
                str(model_path).endswith(".mlpackage"),
                "Could not load CoreML file %s even "
                "though it looks like a CoreML file with error %s. "
                "Are you sure it's a CoreML file?",
                e,
            )

    # TFLite can be served either by the standalone tflite runtime or by
    # full TensorFlow (tflite is expected to alias tf.lite in that case).
    if TFLITE_PRESENT or TF_PRESENT:
        present.append("TensorFlowLite")
        try:
            self.model_type = Model.MODEL_TYPES.TFLITE
            self.interpreter = tflite.Interpreter(str(model_path))
            self.model = self.interpreter.get_signature_runner()
            return
        except Exception as e:
            _warn_load_failure(
                str(model_path).endswith(".tflite"),
                "Could not load TensorFlowLite file %s even "
                "though it looks like a TFLite file with error %s. "
                "Are you sure it's a TFLite file?",
                e,
            )

    if ONNX_PRESENT:
        present.append("ONNX")
        try:
            self.model_type = Model.MODEL_TYPES.ONNX
            # Prefer CUDA when onnxruntime exposes it; always keep CPU as fallback.
            providers = ["CPUExecutionProvider"]
            if "CUDAExecutionProvider" in ort.get_available_providers():
                providers.insert(0, "CUDAExecutionProvider")
            self.model = ort.InferenceSession(str(model_path), providers=providers)
            return
        except Exception as e:
            _warn_load_failure(
                str(model_path).endswith(".onnx"),
                "Could not load ONNX file %s even "
                "though it looks like a ONNX file with error %s. "
                "Are you sure it's a ONNX file?",
                e,
            )

    # Every available backend failed (or none is installed at all).
    raise ValueError(
        f"File {model_path} cannot be loaded into either "
        "TensorFlow, CoreML, TFLite or ONNX. "
        "Please check if it is a supported and valid serialized model "
        "and that one of these packages are installed. On this system, "
        f"{present} is installed."
    )