aboutsummaryrefslogtreecommitdiff
path: root/python/scripts/utils/model_identification.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/scripts/utils/model_identification.py')
-rw-r--r--python/scripts/utils/model_identification.py76
1 files changed, 76 insertions, 0 deletions
diff --git a/python/scripts/utils/model_identification.py b/python/scripts/utils/model_identification.py
new file mode 100644
index 0000000000..84a6e1a097
--- /dev/null
+++ b/python/scripts/utils/model_identification.py
@@ -0,0 +1,76 @@
+# Copyright (c) 2021 Arm Limited.
+#
+# SPDX-License-Identifier: MIT
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to
+# deal in the Software without restriction, including without limitation the
+# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+# sell copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+import logging
+import os
+
+
+def is_tflite_model(model_path):
+ """Check if a model is of TFLite type
+
+ Parameters:
+ ----------
+ model_path: str
+ Path to model
+
+ Returns
+ ----------
+ bool:
+ True if given path is a valid TFLite model
+ """
+
+ try:
+ with open(model_path, "rb") as f:
+ hdr_bytes = f.read(8)
+ hdr_str = hdr_bytes[4:].decode("utf-8")
+ if hdr_str == "TFL3":
+ return True
+ else:
+ return False
+ except:
+ return False
+
+
+def identify_model_type(model_path):
+ """Identify the type of a given deep learning model
+
+ Parameters:
+ ----------
+ model_path: str
+ Path to model
+
+ Returns
+ ----------
+ model_type: str
+ String representation of model type or 'None' if type could not be retrieved.
+ """
+
+ if not os.path.exists(model_path):
+ logging.warn(f"Provided model {model_path} does not exist!")
+ return None
+
+ if is_tflite_model(model_path):
+ model_type = "tflite"
+ else:
+ logging.warn(logging.warn(f"Provided model {model_path} is not of supported type!"))
+ model_type = None
+
+ return model_type