diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2021-08-17 12:54:59 +0100 |
---|---|---|
committer | Freddie Liardet <frederick.liardet@arm.com> | 2021-09-27 16:29:06 +0000 |
commit | a71711008dad9a786a66dcd734b19cb102d65ec5 (patch) | |
tree | d452f581cedde72a61867680be4d290adb03beba /python/scripts/utils/model_identification.py | |
parent | 93d6cf0028aea111f624b320027576a26354e998 (diff) | |
download | ComputeLibrary-a71711008dad9a786a66dcd734b19cb102d65ec5.tar.gz |
Generate an operator configuration file from a list of tflite models
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Change-Id: I1b13da6558bd11d49747162d66c81255ccec1498
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6166
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: SiCong Li <sicong.li@arm.com>
Reviewed-by: Sheri Zhang <sheri.zhang@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'python/scripts/utils/model_identification.py')
-rw-r--r-- | python/scripts/utils/model_identification.py | 76 |
1 files changed, 76 insertions, 0 deletions
diff --git a/python/scripts/utils/model_identification.py b/python/scripts/utils/model_identification.py new file mode 100644 index 0000000000..43e7d20f61 --- /dev/null +++ b/python/scripts/utils/model_identification.py @@ -0,0 +1,76 @@ +# Copyright (c) 2021 Arm Limited. +# +# SPDX-License-Identifier: MIT +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +import logging +import os + + +def is_tflite_model(model_path): + """Check if a model is of TfLite type + + Parameters: + ---------- + model_path: str + Path to model + + Returns + ---------- + bool: + True if given path is a valid TfLite model + """ + + try: + with open(model_path, "rb") as f: + hdr_bytes = f.read(8) + hdr_str = hdr_bytes[4:].decode("utf-8") + if hdr_str == "TFL3": + return True + else: + return False + except: + return False + + +def identify_model_type(model_path): + """Identify the type of a given deep learning model + + Parameters: + ---------- + model_path: str + Path to model + + Returns + ---------- + model_type: str + String representation of model type or 'None' if type could not be retrieved. + """ + + if not os.path.exists(model_path): + logging.warn(f"Provided model {model_path} does not exist!") + return None + + if is_tflite_model(model_path): + model_type = "tflite" + else: + logging.warn(logging.warn(f"Provided model {model_path} is not of supported type!")) + model_type = None + + return model_type |