From 0efca3cadbad5517a59884576ddb90cfe7ac30f8 Mon Sep 17 00:00:00 2001 From: Diego Russo Date: Mon, 30 May 2022 13:34:14 +0100 Subject: Add MLIA codebase Add MLIA codebase including sources and tests. Change-Id: Id41707559bd721edd114793618d12ccd188d8dbd --- src/mlia/nn/tensorflow/config.py | 134 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 src/mlia/nn/tensorflow/config.py (limited to 'src/mlia/nn/tensorflow/config.py') diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py new file mode 100644 index 0000000..d3235d7 --- /dev/null +++ b/src/mlia/nn/tensorflow/config.py @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-License-Identifier: Apache-2.0 +"""Model configuration.""" +import logging +from pathlib import Path +from typing import cast +from typing import Dict +from typing import List +from typing import Union + +import tensorflow as tf + +from mlia.core.context import Context +from mlia.nn.tensorflow.utils import convert_tf_to_tflite +from mlia.nn.tensorflow.utils import convert_to_tflite +from mlia.nn.tensorflow.utils import is_keras_model +from mlia.nn.tensorflow.utils import is_tf_model +from mlia.nn.tensorflow.utils import is_tflite_model +from mlia.nn.tensorflow.utils import save_tflite_model + +logger = logging.getLogger(__name__) + + +class ModelConfiguration: + """Base class for model configuration.""" + + def __init__(self, model_path: Union[str, Path]) -> None: + """Init model configuration instance.""" + self.model_path = str(model_path) + + def convert_to_tflite( + self, tflite_model_path: Union[str, Path], quantized: bool = False + ) -> "TFLiteModel": + """Convert model to TFLite format.""" + raise NotImplementedError() + + def convert_to_keras(self, keras_model_path: Union[str, Path]) -> "KerasModel": + """Convert model to Keras format.""" + raise NotImplementedError() + + +class KerasModel(ModelConfiguration): + """Keras model configuration. + + Supports all models supported by Keras API: saved model, H5, HDF5 + """ + + def get_keras_model(self) -> tf.keras.Model: + """Return associated Keras model.""" + return tf.keras.models.load_model(self.model_path) + + def convert_to_tflite( + self, tflite_model_path: Union[str, Path], quantized: bool = False + ) -> "TFLiteModel": + """Convert model to TFLite format.""" + logger.info("Converting Keras to TFLite ...") + + converted_model = convert_to_tflite(self.get_keras_model(), quantized) + logger.info("Done\n") + + save_tflite_model(converted_model, tflite_model_path) + logger.debug( + "Model %s converted and saved to %s", self.model_path, tflite_model_path + ) + + return TFLiteModel(tflite_model_path) + + def convert_to_keras(self, keras_model_path: Union[str, Path]) -> "KerasModel": + """Convert model to Keras format.""" + return self + + +class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method + """TFLite model configuration.""" + + def input_details(self) -> List[Dict]: + """Get model's input details.""" + interpreter = tf.lite.Interpreter(model_path=self.model_path) + return cast(List[Dict], interpreter.get_input_details()) + + def convert_to_tflite( + self, tflite_model_path: Union[str, Path], quantized: bool = False + ) -> "TFLiteModel": + """Convert model to TFLite format.""" + return self + + +class TfModel(ModelConfiguration): # pylint: disable=abstract-method + """TensorFlow model configuration. + + Supports models supported by TensorFlow API (not Keras) + """ + + def convert_to_tflite( + self, tflite_model_path: Union[str, Path], quantized: bool = False + ) -> "TFLiteModel": + """Convert model to TFLite format.""" + converted_model = convert_tf_to_tflite(self.model_path, quantized) + save_tflite_model(converted_model, tflite_model_path) + + return TFLiteModel(tflite_model_path) + + +def get_model(model: Union[Path, str]) -> "ModelConfiguration": + """Return the model object.""" + if is_tflite_model(model): + return TFLiteModel(model) + + if is_keras_model(model): + return KerasModel(model) + + if is_tf_model(model): + return TfModel(model) + + raise Exception( + "The input model format is not supported" + "(supported formats: TFLite, Keras, TensorFlow saved model)!" + ) + + +def get_tflite_model(model: Union[str, Path], ctx: Context) -> "TFLiteModel": + """Convert input model to TFLite and returns TFLiteModel object.""" + tflite_model_path = ctx.get_model_path("converted_model.tflite") + converted_model = get_model(model) + + return converted_model.convert_to_tflite(tflite_model_path, True) + + +def get_keras_model(model: Union[str, Path], ctx: Context) -> "KerasModel": + """Convert input model to Keras and returns KerasModel object.""" + keras_model_path = ctx.get_model_path("converted_model.h5") + converted_model = get_model(model) + + return converted_model.convert_to_keras(keras_model_path) -- cgit v1.2.1