aboutsummaryrefslogtreecommitdiff
path: root/src/mlia/core/performance.py
blob: 5433d5c8626097c265427938be5cc131444bba04 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for performance estimation."""
from abc import ABC
from abc import abstractmethod
from typing import Callable
from typing import Generic
from typing import List
from typing import TypeVar


# Type variable for the object that represents a model; it is the type of
# the argument accepted by ``PerformanceEstimator.estimate``.
ModelType = TypeVar("ModelType")  # pylint: disable=invalid-name
# Type variable for the performance metrics object; it is the return type
# of ``PerformanceEstimator.estimate``.
PerfMetricsType = TypeVar("PerfMetricsType")  # pylint: disable=invalid-name


class PerformanceEstimator(ABC, Generic[ModelType, PerfMetricsType]):
    """Base class for the performance estimation.

    The class is generic over the model representation (``ModelType``)
    and over the metrics produced for it (``PerfMetricsType``).

    Deriving from ``ABC`` is what makes ``@abstractmethod`` effective:
    the base class itself, and any subclass that does not override
    ``estimate``, cannot be instantiated.
    """

    @abstractmethod
    def estimate(self, model: ModelType) -> PerfMetricsType:
        """Estimate performance.

        :param model: object that represents the model to be evaluated
        :return: performance metrics for the provided model
        """


def estimate_performance(
    original_model: ModelType,
    estimator: PerformanceEstimator[ModelType, PerfMetricsType],
    model_transformations: List[Callable[[ModelType], ModelType]],
) -> List[PerfMetricsType]:
    """Estimate performance of a model and of its transformed variants.

    The original model is measured first. After that every provided
    transformation is applied to the original model (transformations are
    not chained with each other) and the result is measured with the
    same estimator.

    :param original_model: object that represents a model, could be
           instance of the model or path to the model. This depends on
           provided performance estimator.
    :param estimator: performance estimator used for every measurement
    :param model_transformations: list of the callables, each of them
           takes the original model and returns object that represents
           optimized model
    :return: list of metrics where the first item belongs to the original
           model and the remaining items follow the order of the provided
           transformations
    """
    # The baseline measurement always occupies the first position.
    collected_metrics = [estimator.estimate(original_model)]

    for apply_transformation in model_transformations:
        transformed_model = apply_transformation(original_model)
        collected_metrics.append(estimator.estimate(transformed_model))

    return collected_metrics