aboutsummaryrefslogtreecommitdiff
path: root/tests/mlia/test_core_performance.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/mlia/test_core_performance.py')
-rw-r--r--tests/mlia/test_core_performance.py29
1 files changed, 29 insertions, 0 deletions
diff --git a/tests/mlia/test_core_performance.py b/tests/mlia/test_core_performance.py
new file mode 100644
index 0000000..0d28fe8
--- /dev/null
+++ b/tests/mlia/test_core_performance.py
@@ -0,0 +1,29 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for the module performance."""
+from pathlib import Path
+
+from mlia.core.performance import estimate_performance
+from mlia.core.performance import PerformanceEstimator
+
+
+def test_estimate_performance(tmp_path: Path) -> None:
+ """Test function estimate_performance."""
+ model_path = tmp_path / "original.tflite"
+
+ class SampleEstimator(PerformanceEstimator[Path, int]):
+ """Sample estimator."""
+
+ def estimate(self, model: Path) -> int:
+ """Estimate performance."""
+ if model.name == "original.tflite":
+ return 1
+
+ return 2
+
+ def optimized_model(_original: Path) -> Path:
+ """Return path to the 'optimized' model."""
+ return tmp_path / "optimized.tflite"
+
+ results = estimate_performance(model_path, SampleEstimator(), [optimized_model])
+ assert results == [1, 2]