aboutsummaryrefslogtreecommitdiff
path: root/tests/test_nn_tensorflow_tflite_metrics.py
diff options
context:
space:
mode:
authorBenjamin Klimczak <benjamin.klimczak@arm.com>2022-07-11 12:33:42 +0100
committerBenjamin Klimczak <benjamin.klimczak@arm.com>2022-07-26 14:08:21 +0100
commit5d81f37de09efe10f90512e50252be9c36925fcf (patch)
treeb4d7cdfd051da0a6e882bdfcf280fd7ca7b39e57 /tests/test_nn_tensorflow_tflite_metrics.py
parent7899b908c1fe6d86b92a80f3827ddd0ac05b674b (diff)
downloadmlia-5d81f37de09efe10f90512e50252be9c36925fcf.tar.gz
MLIA-551 Rework remains of AIET architecture
Re-factoring the code base to further merge the old AIET code into MLIA. - Remove last traces of the backend type 'tool' - Controlled systems removed, including SSH protocol, controller, RunningCommand, locks etc. - Build command / build dir and deploy functionality removed from Applications and Systems - Moving working_dir() - Replace module 'output_parser' with new module 'output_consumer' and merge Base64 parsing into it - Change the output consumption to optionally remove (i.e. actually consume) lines - Use Base64 parsing in GenericInferenceOutputParser, replacing the regex-based parsing and remove the now unused regex parsing - Remove AIET reporting - Pre-install applications by moving them to src/mlia/resources/backends - Rename aiet-config.json to backend-config.json - Move tests from tests/mlia/ to tests/ - Adapt unit tests to code changes - Dependencies removed: paramiko, filelock, psutil - Fix bug in corstone.py: The wrong resource directory was used which broke the functionality to download backends. - Use f-string formatting. - Use logging instead of print. Change-Id: I768bc3bb6b2eda57d219ad01be4a8e0a74167d76
Diffstat (limited to 'tests/test_nn_tensorflow_tflite_metrics.py')
-rw-r--r--tests/test_nn_tensorflow_tflite_metrics.py133
1 files changed, 133 insertions, 0 deletions
diff --git a/tests/test_nn_tensorflow_tflite_metrics.py b/tests/test_nn_tensorflow_tflite_metrics.py
new file mode 100644
index 0000000..00eacef
--- /dev/null
+++ b/tests/test_nn_tensorflow_tflite_metrics.py
@@ -0,0 +1,133 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Test for module utils/tflite_metrics."""
+import os
+import tempfile
+from math import isclose
+from pathlib import Path
+from typing import Generator
+from typing import List
+
+import numpy as np
+import pytest
+import tensorflow as tf
+
+from mlia.nn.tensorflow.tflite_metrics import ReportClusterMode
+from mlia.nn.tensorflow.tflite_metrics import TFLiteMetrics
+
+
+def _dummy_keras_model() -> tf.keras.Model:
+ # Create a dummy model
+ keras_model = tf.keras.Sequential(
+ [
+ tf.keras.Input(shape=(8, 8, 3)),
+ tf.keras.layers.Conv2D(4, 3),
+ tf.keras.layers.DepthwiseConv2D(3),
+ tf.keras.layers.Flatten(),
+ tf.keras.layers.Dense(8),
+ ]
+ )
+ return keras_model
+
+
+def _sparse_binary_keras_model() -> tf.keras.Model:
+ def get_sparse_weights(shape: List[int]) -> np.ndarray:
+ weights = np.zeros(shape)
+ with np.nditer(weights, op_flags=["writeonly"]) as weight_iterator:
+ for idx, value in enumerate(weight_iterator):
+ if idx % 2 == 0:
+ value[...] = 1.0
+ return weights
+
+ keras_model = _dummy_keras_model()
+ # Assign weights to have 0.5 sparsity
+ for layer in keras_model.layers:
+ if not isinstance(layer, tf.keras.layers.Flatten):
+ weight = layer.weights[0]
+ weight.assign(get_sparse_weights(weight.shape))
+ print(layer)
+ print(weight.numpy())
+ return keras_model
+
+
+@pytest.fixture(scope="class", name="tflite_file")
+def fixture_tflite_file() -> Generator:
+ """Generate temporary TFLite file for tests."""
+ converter = tf.lite.TFLiteConverter.from_keras_model(_sparse_binary_keras_model())
+ tflite_model = converter.convert()
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ file = os.path.join(tmp_dir, "test.tflite")
+ Path(file).write_bytes(tflite_model)
+ yield file
+
+
+@pytest.fixture(scope="function", name="metrics")
+def fixture_metrics(tflite_file: str) -> TFLiteMetrics:
+ """Generate metrics file for a given TFLite model."""
+ return TFLiteMetrics(tflite_file)
+
+
+class TestTFLiteMetrics:
+ """Tests for module TFLite_metrics."""
+
+ @staticmethod
+ def test_sparsity(metrics: TFLiteMetrics) -> None:
+ """Test sparsity."""
+ # Create new instance with a dummy TFLite file
+ # Check sparsity calculation
+ sparsity_per_layer = metrics.sparsity_per_layer()
+ for name, sparsity in sparsity_per_layer.items():
+ assert isclose(sparsity, 0.5), f"Layer '{name}' has incorrect sparsity."
+ assert isclose(metrics.sparsity_overall(), 0.5)
+
+ @staticmethod
+ def test_clusters(metrics: TFLiteMetrics) -> None:
+ """Test clusters."""
+ # NUM_CLUSTERS_PER_AXIS and NUM_CLUSTERS_MIN_MAX can be handled together
+ for mode in [
+ ReportClusterMode.NUM_CLUSTERS_PER_AXIS,
+ ReportClusterMode.NUM_CLUSTERS_MIN_MAX,
+ ]:
+ num_unique_weights = metrics.num_unique_weights(mode)
+ for name, num_unique_per_axis in num_unique_weights.items():
+ for num_unique in num_unique_per_axis:
+ assert (
+ num_unique == 2
+ ), f"Layer '{name}' has incorrect number of clusters."
+ # NUM_CLUSTERS_HISTOGRAM
+ hists = metrics.num_unique_weights(ReportClusterMode.NUM_CLUSTERS_HISTOGRAM)
+ assert hists
+ for name, hist in hists.items():
+ assert hist
+ for idx, num_axes in enumerate(hist):
+ # The histogram starts with the bin for for num_clusters == 1
+ num_clusters = idx + 1
+ msg = (
+ f"Histogram of layer '{name}': There are {num_axes} axes "
+ f"with {num_clusters} clusters"
+ )
+ if num_clusters == 2:
+ assert num_axes > 0, f"{msg}, but there should be at least one."
+ else:
+ assert num_axes == 0, f"{msg}, but there should be none."
+
+ @staticmethod
+ @pytest.mark.parametrize("report_sparsity", (False, True))
+ @pytest.mark.parametrize("report_cluster_mode", ReportClusterMode)
+ @pytest.mark.parametrize("max_num_clusters", (-1, 8))
+ @pytest.mark.parametrize("verbose", (False, True))
+ def test_summary(
+ tflite_file: str,
+ report_sparsity: bool,
+ report_cluster_mode: ReportClusterMode,
+ max_num_clusters: int,
+ verbose: bool,
+ ) -> None:
+ """Test the summary function."""
+ for metrics in [TFLiteMetrics(tflite_file), TFLiteMetrics(tflite_file, [])]:
+ metrics.summary(
+ report_sparsity=report_sparsity,
+ report_cluster_mode=report_cluster_mode,
+ max_num_clusters=max_num_clusters,
+ verbose=verbose,
+ )