diff options
Diffstat (limited to 'src/mlia/nn/tensorflow/config.py')
-rw-r--r-- | src/mlia/nn/tensorflow/config.py | 20 |
1 files changed, 14 insertions, 6 deletions
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py index b94350a..0a17977 100644 --- a/src/mlia/nn/tensorflow/config.py +++ b/src/mlia/nn/tensorflow/config.py @@ -21,14 +21,13 @@ from mlia.nn.tensorflow.optimizations.quantization import dequantize from mlia.nn.tensorflow.optimizations.quantization import is_quantized from mlia.nn.tensorflow.optimizations.quantization import QuantizationParameters from mlia.nn.tensorflow.optimizations.quantization import quantize +from mlia.nn.tensorflow.tflite_convert import convert_to_tflite from mlia.nn.tensorflow.tflite_graph import load_fb from mlia.nn.tensorflow.tflite_graph import save_fb from mlia.nn.tensorflow.utils import check_tflite_datatypes -from mlia.nn.tensorflow.utils import convert_to_tflite from mlia.nn.tensorflow.utils import is_keras_model from mlia.nn.tensorflow.utils import is_saved_model from mlia.nn.tensorflow.utils import is_tflite_model -from mlia.nn.tensorflow.utils import save_tflite_model from mlia.utils.logging import log_action logger = logging.getLogger(__name__) @@ -67,9 +66,14 @@ class KerasModel(ModelConfiguration): ) -> TFLiteModel: """Convert model to TensorFlow Lite format.""" with log_action("Converting Keras to TensorFlow Lite ..."): - converted_model = convert_to_tflite(self.get_keras_model(), quantized) + convert_to_tflite( + self.get_keras_model(), + quantized, + input_path=Path(self.model_path), + output_path=Path(tflite_model_path), + subprocess=True, + ) - save_tflite_model(converted_model, tflite_model_path) logger.debug( "Model %s converted and saved to %s", self.model_path, tflite_model_path ) @@ -270,8 +274,12 @@ class TfModel(ModelConfiguration): # pylint: disable=abstract-method self, tflite_model_path: str | Path, quantized: bool = False ) -> TFLiteModel: """Convert model to TensorFlow Lite format.""" - converted_model = convert_to_tflite(self.model_path, quantized) - save_tflite_model(converted_model, tflite_model_path) + convert_to_tflite( + self.model_path, + quantized, + input_path=Path(self.model_path), + output_path=Path(tflite_model_path), + ) return TFLiteModel(tflite_model_path) |