aboutsummaryrefslogtreecommitdiff
path: root/tests/mlia/test_core_performance.py
blob: 0d28fe85e890e9c8b72db252c95cea4976e0265d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for the module performance."""
from pathlib import Path

from mlia.core.performance import estimate_performance
from mlia.core.performance import PerformanceEstimator


def test_estimate_performance(tmp_path: Path) -> None:
    """Check the ordering of results returned by estimate_performance.

    The fake estimator scores the original model as 1 and any other model
    as 2. With a single optimization supplied, the expected result is the
    original model's estimate followed by the optimized model's estimate.
    """
    original_model = tmp_path / "original.tflite"
    optimized_path = tmp_path / "optimized.tflite"

    class FakeEstimator(PerformanceEstimator[Path, int]):
        """Estimator returning a fixed score chosen by model file name."""

        def estimate(self, model: Path) -> int:
            """Return 1 for the original model file, 2 for anything else."""
            return 1 if model.name == "original.tflite" else 2

    def apply_optimization(_original: Path) -> Path:
        """Ignore the input model and return the fixed 'optimized' path."""
        return optimized_path

    estimates = estimate_performance(
        original_model, FakeEstimator(), [apply_optimization]
    )
    assert estimates == [1, 2]