diff options
author | Dmitrii Agibov <dmitrii.agibov@arm.com> | 2022-09-08 14:24:39 +0100 |
---|---|---|
committer | Dmitrii Agibov <dmitrii.agibov@arm.com> | 2022-09-09 17:21:48 +0100 |
commit | f5b293d0927506c2a979a091bf0d07ecc78fa181 (patch) | |
tree | 4de585b7cb6ed34da8237063752270189a730a41 /src/mlia/nn/tensorflow/config.py | |
parent | cde0c6ee140bd108849bff40467d8f18ffc332ef (diff) | |
download | mlia-f5b293d0927506c2a979a091bf0d07ecc78fa181.tar.gz |
MLIA-386 Simplify typing in the source code
- Enable deferred annotations evaluation
- Use builtin types for type hints whenever possible
- Use | syntax for union types
- Rename mlia.core._typing into mlia.core.typing
Change-Id: I3f6ffc02fa069c589bdd9e8bddbccd504285427a
Diffstat (limited to 'src/mlia/nn/tensorflow/config.py')
-rw-r--r-- | src/mlia/nn/tensorflow/config.py | 36 |
1 files changed, 18 insertions, 18 deletions
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py index d3235d7..6ee32e7 100644 --- a/src/mlia/nn/tensorflow/config.py +++ b/src/mlia/nn/tensorflow/config.py @@ -1,12 +1,12 @@ # SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Model configuration.""" +from __future__ import annotations + import logging from pathlib import Path from typing import cast -from typing import Dict from typing import List -from typing import Union import tensorflow as tf @@ -24,17 +24,17 @@ logger = logging.getLogger(__name__) class ModelConfiguration: """Base class for model configuration.""" - def __init__(self, model_path: Union[str, Path]) -> None: + def __init__(self, model_path: str | Path) -> None: """Init model configuration instance.""" self.model_path = str(model_path) def convert_to_tflite( - self, tflite_model_path: Union[str, Path], quantized: bool = False - ) -> "TFLiteModel": + self, tflite_model_path: str | Path, quantized: bool = False + ) -> TFLiteModel: """Convert model to TFLite format.""" raise NotImplementedError() - def convert_to_keras(self, keras_model_path: Union[str, Path]) -> "KerasModel": + def convert_to_keras(self, keras_model_path: str | Path) -> KerasModel: """Convert model to Keras format.""" raise NotImplementedError() @@ -50,8 +50,8 @@ class KerasModel(ModelConfiguration): return tf.keras.models.load_model(self.model_path) def convert_to_tflite( - self, tflite_model_path: Union[str, Path], quantized: bool = False - ) -> "TFLiteModel": + self, tflite_model_path: str | Path, quantized: bool = False + ) -> TFLiteModel: """Convert model to TFLite format.""" logger.info("Converting Keras to TFLite ...") @@ -65,7 +65,7 @@ class KerasModel(ModelConfiguration): return TFLiteModel(tflite_model_path) - def convert_to_keras(self, keras_model_path: Union[str, Path]) -> "KerasModel": + def convert_to_keras(self, keras_model_path: str | Path) -> KerasModel: """Convert model to Keras format.""" return self @@ -73,14 +73,14 @@ class KerasModel(ModelConfiguration): class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method """TFLite model configuration.""" - def input_details(self) -> List[Dict]: + def input_details(self) -> list[dict]: """Get model's input details.""" interpreter = tf.lite.Interpreter(model_path=self.model_path) - return cast(List[Dict], interpreter.get_input_details()) + return cast(List[dict], interpreter.get_input_details()) def convert_to_tflite( - self, tflite_model_path: Union[str, Path], quantized: bool = False - ) -> "TFLiteModel": + self, tflite_model_path: str | Path, quantized: bool = False + ) -> TFLiteModel: """Convert model to TFLite format.""" return self @@ -92,8 +92,8 @@ class TfModel(ModelConfiguration): # pylint: disable=abstract-method """ def convert_to_tflite( - self, tflite_model_path: Union[str, Path], quantized: bool = False - ) -> "TFLiteModel": + self, tflite_model_path: str | Path, quantized: bool = False + ) -> TFLiteModel: """Convert model to TFLite format.""" converted_model = convert_tf_to_tflite(self.model_path, quantized) save_tflite_model(converted_model, tflite_model_path) @@ -101,7 +101,7 @@ class TfModel(ModelConfiguration): # pylint: disable=abstract-method return TFLiteModel(tflite_model_path) -def get_model(model: Union[Path, str]) -> "ModelConfiguration": +def get_model(model: str | Path) -> ModelConfiguration: """Return the model object.""" if is_tflite_model(model): return TFLiteModel(model) @@ -118,7 +118,7 @@ def get_model(model: Union[Path, str]) -> "ModelConfiguration": ) -def get_tflite_model(model: Union[str, Path], ctx: Context) -> "TFLiteModel": +def get_tflite_model(model: str | Path, ctx: Context) -> TFLiteModel: """Convert input model to TFLite and returns TFLiteModel object.""" tflite_model_path = ctx.get_model_path("converted_model.tflite") converted_model = get_model(model) @@ -126,7 +126,7 @@ def get_tflite_model(model: Union[str, Path], ctx: Context) -> "TFLiteModel": return converted_model.convert_to_tflite(tflite_model_path, True) -def get_keras_model(model: Union[str, Path], ctx: Context) -> "KerasModel": +def get_keras_model(model: str | Path, ctx: Context) -> KerasModel: """Convert input model to Keras and returns KerasModel object.""" keras_model_path = ctx.get_model_path("converted_model.h5") converted_model = get_model(model) |