aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/tensorflow/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/mlia/nn/tensorflow/utils.py')
-rw-r--r--src/mlia/nn/tensorflow/utils.py30
1 files changed, 30 insertions, 0 deletions
diff --git a/src/mlia/nn/tensorflow/utils.py b/src/mlia/nn/tensorflow/utils.py
index 77ac529..b8d45c6 100644
--- a/src/mlia/nn/tensorflow/utils.py
+++ b/src/mlia/nn/tensorflow/utils.py
@@ -123,3 +123,33 @@ def get_tflite_converter(
converter.inference_output_type = tf.int8
return converter
+
+
+def get_tflite_model_type_map(model_filename: str | Path) -> dict[str, type]:
+ """Get type map from tflite model."""
+ model_type_map: dict[str, Any] = {}
+ interpreter = tf.lite.Interpreter(str(model_filename))
+ interpreter.allocate_tensors()
+ input_details = interpreter.get_input_details()
+ model_type_map = {
+ input_detail["name"]: input_detail["dtype"] for input_detail in input_details
+ }
+ return model_type_map
+
+
+def check_tflite_datatypes(model_filename: str | Path, *allowed_types: type) -> None:
+ """Check if the model only has the given allowed datatypes."""
+ type_map = get_tflite_model_type_map(model_filename)
+ types = set(type_map.values())
+ allowed = set(allowed_types)
+ unexpected = types - allowed
+
+ def cls_to_str(types: set[type]) -> list[str]:
+ return [t.__name__ for t in types]
+
+ if len(unexpected) > 0:
+ raise TypeError(
+ f"Model {model_filename} has "
+ f"unexpected data types: {cls_to_str(unexpected)}. "
+ f"Only {cls_to_str(allowed)} are allowed."
+ )