diff options
Diffstat (limited to 'python/scripts/utils/model_identification.py')
-rw-r--r-- | python/scripts/utils/model_identification.py | 76 |
1 files changed, 76 insertions, 0 deletions
diff --git a/python/scripts/utils/model_identification.py b/python/scripts/utils/model_identification.py new file mode 100644 index 0000000000..43e7d20f61 --- /dev/null +++ b/python/scripts/utils/model_identification.py @@ -0,0 +1,76 @@ +# Copyright (c) 2021 Arm Limited. +# +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +import logging +import os + + +def is_tflite_model(model_path): + """Check if a model is of TfLite type + + Parameters: + ---------- + model_path: str + Path to model + + Returns + ---------- + bool: + True if given path is a valid TfLite model + """ + + try: + with open(model_path, "rb") as f: + hdr_bytes = f.read(8) + hdr_str = hdr_bytes[4:].decode("utf-8") + if hdr_str == "TFL3": + return True + else: + return False + except: + return False + + +def identify_model_type(model_path): + """Identify the type of a given deep learning model + + Parameters: + ---------- + model_path: str + Path to model + + Returns + ---------- + model_type: str + String representation of model type or 'None' if type could not be retrieved. + """ + + if not os.path.exists(model_path): + logging.warn(f"Provided model {model_path} does not exist!") + return None + + if is_tflite_model(model_path): + model_type = "tflite" + else: + logging.warn(logging.warn(f"Provided model {model_path} is not of supported type!")) + model_type = None + + return model_type |