aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/nn/tensorflow/config.py
diff options
context:
space:
mode:
authorDmitrii Agibov <dmitrii.agibov@arm.com>2022-09-08 14:24:39 +0100
committerDmitrii Agibov <dmitrii.agibov@arm.com>2022-09-09 17:21:48 +0100
commitf5b293d0927506c2a979a091bf0d07ecc78fa181 (patch)
tree4de585b7cb6ed34da8237063752270189a730a41 /src/mlia/nn/tensorflow/config.py
parentcde0c6ee140bd108849bff40467d8f18ffc332ef (diff)
downloadmlia-f5b293d0927506c2a979a091bf0d07ecc78fa181.tar.gz
MLIA-386 Simplify typing in the source code
- Enable deferred annotations evaluation - Use builtin types for type hints whenever possible - Use | syntax for union types - Rename mlia.core._typing into mlia.core.typing Change-Id: I3f6ffc02fa069c589bdd9e8bddbccd504285427a
Diffstat (limited to 'src/mlia/nn/tensorflow/config.py')
-rw-r--r--src/mlia/nn/tensorflow/config.py36
1 files changed, 18 insertions, 18 deletions
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py
index d3235d7..6ee32e7 100644
--- a/src/mlia/nn/tensorflow/config.py
+++ b/src/mlia/nn/tensorflow/config.py
@@ -1,12 +1,12 @@
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Model configuration."""
+from __future__ import annotations
+
import logging
from pathlib import Path
from typing import cast
-from typing import Dict
from typing import List
-from typing import Union
import tensorflow as tf
@@ -24,17 +24,17 @@ logger = logging.getLogger(__name__)
class ModelConfiguration:
"""Base class for model configuration."""
- def __init__(self, model_path: Union[str, Path]) -> None:
+ def __init__(self, model_path: str | Path) -> None:
"""Init model configuration instance."""
self.model_path = str(model_path)
def convert_to_tflite(
- self, tflite_model_path: Union[str, Path], quantized: bool = False
- ) -> "TFLiteModel":
+ self, tflite_model_path: str | Path, quantized: bool = False
+ ) -> TFLiteModel:
"""Convert model to TFLite format."""
raise NotImplementedError()
- def convert_to_keras(self, keras_model_path: Union[str, Path]) -> "KerasModel":
+ def convert_to_keras(self, keras_model_path: str | Path) -> KerasModel:
"""Convert model to Keras format."""
raise NotImplementedError()
@@ -50,8 +50,8 @@ class KerasModel(ModelConfiguration):
return tf.keras.models.load_model(self.model_path)
def convert_to_tflite(
- self, tflite_model_path: Union[str, Path], quantized: bool = False
- ) -> "TFLiteModel":
+ self, tflite_model_path: str | Path, quantized: bool = False
+ ) -> TFLiteModel:
"""Convert model to TFLite format."""
logger.info("Converting Keras to TFLite ...")
@@ -65,7 +65,7 @@ class KerasModel(ModelConfiguration):
return TFLiteModel(tflite_model_path)
- def convert_to_keras(self, keras_model_path: Union[str, Path]) -> "KerasModel":
+ def convert_to_keras(self, keras_model_path: str | Path) -> KerasModel:
"""Convert model to Keras format."""
return self
@@ -73,14 +73,14 @@ class KerasModel(ModelConfiguration):
class TFLiteModel(ModelConfiguration): # pylint: disable=abstract-method
"""TFLite model configuration."""
- def input_details(self) -> List[Dict]:
+ def input_details(self) -> list[dict]:
"""Get model's input details."""
interpreter = tf.lite.Interpreter(model_path=self.model_path)
- return cast(List[Dict], interpreter.get_input_details())
+ return cast(List[dict], interpreter.get_input_details())
def convert_to_tflite(
- self, tflite_model_path: Union[str, Path], quantized: bool = False
- ) -> "TFLiteModel":
+ self, tflite_model_path: str | Path, quantized: bool = False
+ ) -> TFLiteModel:
"""Convert model to TFLite format."""
return self
@@ -92,8 +92,8 @@ class TfModel(ModelConfiguration): # pylint: disable=abstract-method
"""
def convert_to_tflite(
- self, tflite_model_path: Union[str, Path], quantized: bool = False
- ) -> "TFLiteModel":
+ self, tflite_model_path: str | Path, quantized: bool = False
+ ) -> TFLiteModel:
"""Convert model to TFLite format."""
converted_model = convert_tf_to_tflite(self.model_path, quantized)
save_tflite_model(converted_model, tflite_model_path)
@@ -101,7 +101,7 @@ class TfModel(ModelConfiguration): # pylint: disable=abstract-method
return TFLiteModel(tflite_model_path)
-def get_model(model: Union[Path, str]) -> "ModelConfiguration":
+def get_model(model: str | Path) -> ModelConfiguration:
"""Return the model object."""
if is_tflite_model(model):
return TFLiteModel(model)
@@ -118,7 +118,7 @@ def get_model(model: Union[Path, str]) -> "ModelConfiguration":
)
-def get_tflite_model(model: Union[str, Path], ctx: Context) -> "TFLiteModel":
+def get_tflite_model(model: str | Path, ctx: Context) -> TFLiteModel:
"""Convert input model to TFLite and returns TFLiteModel object."""
tflite_model_path = ctx.get_model_path("converted_model.tflite")
converted_model = get_model(model)
@@ -126,7 +126,7 @@ def get_tflite_model(model: Union[str, Path], ctx: Context) -> "TFLiteModel":
return converted_model.convert_to_tflite(tflite_model_path, True)
-def get_keras_model(model: Union[str, Path], ctx: Context) -> "KerasModel":
+def get_keras_model(model: str | Path, ctx: Context) -> KerasModel:
"""Convert input model to Keras and returns KerasModel object."""
keras_model_path = ctx.get_model_path("converted_model.h5")
converted_model = get_model(model)