diff options
Diffstat (limited to 'src/mlia/nn/tensorflow/utils.py')
-rw-r--r-- | src/mlia/nn/tensorflow/utils.py | 30 |
1 files changed, 30 insertions, 0 deletions
diff --git a/src/mlia/nn/tensorflow/utils.py b/src/mlia/nn/tensorflow/utils.py index 77ac529..b8d45c6 100644 --- a/src/mlia/nn/tensorflow/utils.py +++ b/src/mlia/nn/tensorflow/utils.py @@ -123,3 +123,33 @@ def get_tflite_converter( converter.inference_output_type = tf.int8 return converter + + +def get_tflite_model_type_map(model_filename: str | Path) -> dict[str, type]: + """Get type map from tflite model.""" + model_type_map: dict[str, Any] = {} + interpreter = tf.lite.Interpreter(str(model_filename)) + interpreter.allocate_tensors() + input_details = interpreter.get_input_details() + model_type_map = { + input_detail["name"]: input_detail["dtype"] for input_detail in input_details + } + return model_type_map + + +def check_tflite_datatypes(model_filename: str | Path, *allowed_types: type) -> None: + """Check if the model only has the given allowed datatypes.""" + type_map = get_tflite_model_type_map(model_filename) + types = set(type_map.values()) + allowed = set(allowed_types) + unexpected = types - allowed + + def cls_to_str(types: set[type]) -> list[str]: + return [t.__name__ for t in types] + + if len(unexpected) > 0: + raise TypeError( + f"Model {model_filename} has " + f"unexpected data types: {cls_to_str(unexpected)}. " + f"Only {cls_to_str(allowed)} are allowed." + ) |