diff options
author | Gergely Nagy <gergely.nagy@arm.com> | 2023-11-21 12:29:38 +0000 |
---|---|---|
committer | Gergely Nagy <gergely.nagy@arm.com> | 2023-12-07 17:09:31 +0000 |
commit | 54eec806272b7574a0757c77a913a369a9ecdc70 (patch) | |
tree | 2e6484b857b2a68279a2707dbb21e5c26685f4e2 /src/mlia/nn/tensorflow/config.py | |
parent | 7c50f1d6367186c03a282ac7ecb8fca0f905ba30 (diff) | |
download | mlia-54eec806272b7574a0757c77a913a369a9ecdc70.tar.gz |
MLIA-835 Invalid JSON output
TFLiteConverter was producing log messages in the output that was not
possible to capture and redirect to logging.
The solution/workaround is to run it as a subprocess.
This change required some refactoring around existing invocations of
the converter.
Change-Id: I394bd0d49d36e6686cfcb9d658e4aad05326cb87
Signed-off-by: Gergely Nagy <gergely.nagy@arm.com>
Diffstat (limited to 'src/mlia/nn/tensorflow/config.py')
-rw-r--r-- | src/mlia/nn/tensorflow/config.py | 20 |
1 files changed, 14 insertions, 6 deletions
diff --git a/src/mlia/nn/tensorflow/config.py b/src/mlia/nn/tensorflow/config.py index b94350a..0a17977 100644 --- a/src/mlia/nn/tensorflow/config.py +++ b/src/mlia/nn/tensorflow/config.py @@ -21,14 +21,13 @@ from mlia.nn.tensorflow.optimizations.quantization import dequantize from mlia.nn.tensorflow.optimizations.quantization import is_quantized from mlia.nn.tensorflow.optimizations.quantization import QuantizationParameters from mlia.nn.tensorflow.optimizations.quantization import quantize +from mlia.nn.tensorflow.tflite_convert import convert_to_tflite from mlia.nn.tensorflow.tflite_graph import load_fb from mlia.nn.tensorflow.tflite_graph import save_fb from mlia.nn.tensorflow.utils import check_tflite_datatypes -from mlia.nn.tensorflow.utils import convert_to_tflite from mlia.nn.tensorflow.utils import is_keras_model from mlia.nn.tensorflow.utils import is_saved_model from mlia.nn.tensorflow.utils import is_tflite_model -from mlia.nn.tensorflow.utils import save_tflite_model from mlia.utils.logging import log_action logger = logging.getLogger(__name__) @@ -67,9 +66,14 @@ class KerasModel(ModelConfiguration): ) -> TFLiteModel: """Convert model to TensorFlow Lite format.""" with log_action("Converting Keras to TensorFlow Lite ..."): - converted_model = convert_to_tflite(self.get_keras_model(), quantized) + convert_to_tflite( + self.get_keras_model(), + quantized, + input_path=Path(self.model_path), + output_path=Path(tflite_model_path), + subprocess=True, + ) - save_tflite_model(converted_model, tflite_model_path) logger.debug( "Model %s converted and saved to %s", self.model_path, tflite_model_path ) @@ -270,8 +274,12 @@ class TfModel(ModelConfiguration): # pylint: disable=abstract-method self, tflite_model_path: str | Path, quantized: bool = False ) -> TFLiteModel: """Convert model to TensorFlow Lite format.""" - converted_model = convert_to_tflite(self.model_path, quantized) - save_tflite_model(converted_model, tflite_model_path) + convert_to_tflite( + self.model_path, + quantized, + input_path=Path(self.model_path), + output_path=Path(tflite_model_path), + ) return TFLiteModel(tflite_model_path) |