diff options
Diffstat (limited to 'tests/test_nn_tensorflow_utils.py')
-rw-r--r-- | tests/test_nn_tensorflow_utils.py | 31 |
1 files changed, 31 insertions, 0 deletions
diff --git a/tests/test_nn_tensorflow_utils.py b/tests/test_nn_tensorflow_utils.py index 14b06c4..dab8b4e 100644 --- a/tests/test_nn_tensorflow_utils.py +++ b/tests/test_nn_tensorflow_utils.py @@ -1,14 +1,17 @@ # SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Test for module utils/test_utils.""" +import re from pathlib import Path import numpy as np import pytest import tensorflow as tf +from mlia.nn.tensorflow.utils import check_tflite_datatypes from mlia.nn.tensorflow.utils import convert_to_tflite from mlia.nn.tensorflow.utils import get_tf_tensor_shape +from mlia.nn.tensorflow.utils import get_tflite_model_type_map from mlia.nn.tensorflow.utils import is_keras_model from mlia.nn.tensorflow.utils import is_tflite_model from mlia.nn.tensorflow.utils import representative_dataset @@ -109,3 +112,31 @@ def test_is_keras_model(model_path: Path, expected_result: bool) -> None: def test_get_tf_tensor_shape(test_tf_model: Path) -> None: """Test get_tf_tensor_shape with test model.""" assert get_tf_tensor_shape(str(test_tf_model)) == [1, 28, 28, 1] + + +def test_tflite_model_type_map( + test_tflite_model_fp32: Path, test_tflite_model: Path +) -> None: + """Test the model type map function.""" + assert get_tflite_model_type_map(test_tflite_model_fp32) == { + "serving_default_input:0": np.float32 + } + assert get_tflite_model_type_map(test_tflite_model) == { + "serving_default_input:0": np.int8 + } + + +def test_check_tflite_datatypes( + test_tflite_model_fp32: Path, test_tflite_model: Path +) -> None: + """Test the model type map function.""" + check_tflite_datatypes(test_tflite_model_fp32, np.float32) + check_tflite_datatypes(test_tflite_model, np.int8) + + with pytest.raises( + Exception, + match=re.escape( + "unexpected data types: ['float32']. Only ['int8'] are allowed" + ), + ): + check_tflite_datatypes(test_tflite_model_fp32, np.int8) |