diff options
Diffstat (limited to 'src/mlia/target/ethos_u/performance.py')
-rw-r--r-- | src/mlia/target/ethos_u/performance.py | 26 |
1 files changed, 10 insertions, 16 deletions
diff --git a/src/mlia/target/ethos_u/performance.py b/src/mlia/target/ethos_u/performance.py index 0d791a1..be1a287 100644 --- a/src/mlia/target/ethos_u/performance.py +++ b/src/mlia/target/ethos_u/performance.py @@ -11,20 +11,17 @@ from typing import Union import mlia.backend.vela.compiler as vela_comp import mlia.backend.vela.performance as vela_perf -from mlia.backend.corstone.performance import DeviceInfo +from mlia.backend.corstone import is_corstone_backend from mlia.backend.corstone.performance import estimate_performance -from mlia.backend.corstone.performance import ModelInfo -from mlia.backend.registry import get_supported_backends from mlia.core.context import Context from mlia.core.performance import PerformanceEstimator from mlia.nn.tensorflow.config import get_tflite_model from mlia.nn.tensorflow.config import ModelConfiguration from mlia.nn.tensorflow.optimizations.select import OptimizationSettings from mlia.target.ethos_u.config import EthosUConfiguration -from mlia.target.registry import is_supported +from mlia.target.registry import supported_backends from mlia.utils.logging import log_action - logger = logging.getLogger(__name__) @@ -186,14 +183,11 @@ class CorstonePerformanceEstimator( model_path, self.device.compiler_options, optimized_model_path ) - model_info = ModelInfo(model_path=optimized_model_path) - device_info = DeviceInfo( - device_type=self.device.target, # type: ignore - mac=self.device.mac, - ) - corstone_perf_metrics = estimate_performance( - model_info, device_info, self.backend + self.device.target, + self.device.mac, + optimized_model_path, + self.backend, ) return NPUCycles( @@ -222,11 +216,12 @@ class EthosUPerformanceEstimator( self.device = device if backends is None: backends = ["Vela"] # Only Vela is always available as default + ethos_u_backends = supported_backends(device.target) for backend in backends: - if backend != "Vela" and not is_supported(backend): + if backend != "Vela" and backend not in ethos_u_backends: raise ValueError( f"Unsupported backend '{backend}'. " - f"Only 'Vela' and {get_supported_backends()} " + f"Only 'Vela' and {ethos_u_backends} " "are supported." ) self.backends = set(backends) @@ -241,12 +236,11 @@ class EthosUPerformanceEstimator( memory_usage = None npu_cycles = None - for backend in self.backends: if backend == "Vela": vela_estimator = VelaPerformanceEstimator(self.context, self.device) memory_usage = vela_estimator.estimate(tflite_model) - elif backend in get_supported_backends(): + elif is_corstone_backend(backend): corstone_estimator = CorstonePerformanceEstimator( self.context, self.device, backend ) |