blob: cb12918cbfd55199930c8b93993e4f76abceb25f (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
|
# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Module for performance estimation."""
from __future__ import annotations

from abc import ABC
from abc import abstractmethod
from typing import Callable
from typing import Generic
from typing import TypeVar
M = TypeVar("M")  # model type
P = TypeVar("P")  # performance metrics


class PerformanceEstimator(ABC, Generic[M, P]):
    """Base class for the performance estimation.

    The class is generic over the model representation ``M`` and the
    performance metrics type ``P``.

    It derives from ``ABC`` so that ``@abstractmethod`` is actually
    enforced: instantiating this class, or a subclass that does not
    override :meth:`estimate`, raises ``TypeError`` instead of yielding
    an object whose ``estimate`` silently returns ``None``.
    """

    @abstractmethod
    def estimate(self, model: M) -> P:
        """Estimate performance.

        :param model: object that represents the model to be evaluated
        :return: performance metrics produced for the model
        """
def estimate_performance(
    original_model: M,
    estimator: PerformanceEstimator[M, P],
    model_transformations: list[Callable[[M], M]],
) -> list[P]:
    """Estimate performance impact.

    Measure the original model first, then apply every provided
    transformation/optimization to the original model and measure each
    result with the same estimator.

    :param original_model: object that represents a model, could be
           instance of the model or path to the model. This depends on
           provided performance estimator.
    :param estimator: performance estimator
    :param model_transformations: list of the callables each of those
           returns object that represents optimized model
    :return: list of metrics where the first item belongs to the original
             model and the following items correspond to the
             transformations, in the order they were provided
    """
    # The baseline is measured before any transformation is invoked.
    collected_metrics = [estimator.estimate(original_model)]

    for apply_transformation in model_transformations:
        # Every transformation receives the original model, not the
        # output of the previous transformation.
        candidate_model = apply_transformation(original_model)
        collected_metrics.append(estimator.estimate(candidate_model))

    return collected_metrics
|