diff options
Diffstat (limited to 'src/mlia/target/ethos_u/advice_generation.py')
-rw-r--r-- | src/mlia/target/ethos_u/advice_generation.py | 34 |
1 files changed, 19 insertions, 15 deletions
diff --git a/src/mlia/target/ethos_u/advice_generation.py b/src/mlia/target/ethos_u/advice_generation.py index edd78fd..daae4f4 100644 --- a/src/mlia/target/ethos_u/advice_generation.py +++ b/src/mlia/target/ethos_u/advice_generation.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Ethos-U advice generation.""" from __future__ import annotations @@ -26,7 +26,7 @@ class EthosUAdviceProducer(FactBasedAdviceProducer): """Produce advice.""" @produce_advice.register - @advice_category(AdviceCategory.OPERATORS, AdviceCategory.ALL) + @advice_category(AdviceCategory.COMPATIBILITY) def handle_cpu_only_ops(self, data_item: HasCPUOnlyOperators) -> None: """Advice for CPU only operators.""" cpu_only_ops = ",".join(sorted(set(data_item.cpu_only_ops))) @@ -40,11 +40,10 @@ class EthosUAdviceProducer(FactBasedAdviceProducer): "Using operators that are supported by the NPU will " "improve performance.", ] - + self.context.action_resolver.supported_operators_info() ) @produce_advice.register - @advice_category(AdviceCategory.OPERATORS, AdviceCategory.ALL) + @advice_category(AdviceCategory.COMPATIBILITY) def handle_unsupported_operators( self, data_item: HasUnsupportedOnNPUOperators ) -> None: @@ -60,21 +59,25 @@ class EthosUAdviceProducer(FactBasedAdviceProducer): ) @produce_advice.register - @advice_category(AdviceCategory.OPERATORS, AdviceCategory.ALL) + @advice_category(AdviceCategory.COMPATIBILITY) def handle_all_operators_supported( self, _data_item: AllOperatorsSupportedOnNPU ) -> None: """Advice if all operators supported.""" - self.add_advice( - [ - "You don't have any unsupported operators, your model will " - "run completely on NPU." - ] - + self.context.action_resolver.check_performance() - ) + advice = [ + "You don't have any unsupported operators, your model will " + "run completely on NPU." + ] + if self.context.advice_category != ( + AdviceCategory.COMPATIBILITY, + AdviceCategory.PERFORMANCE, + ): + advice += self.context.action_resolver.check_performance() + + self.add_advice(advice) @produce_advice.register - @advice_category(AdviceCategory.OPTIMIZATION, AdviceCategory.ALL) + @advice_category(AdviceCategory.OPTIMIZATION) def handle_optimization_results(self, data_item: OptimizationResults) -> None: """Advice based on optimization results.""" if not data_item.diffs or len(data_item.diffs) != 1: @@ -202,5 +205,6 @@ class EthosUStaticAdviceProducer(ContextAwareAdviceProducer): ) ], } - - return advice_per_category.get(self.context.advice_category, []) + if len(self.context.advice_category) == 1: + return advice_per_category.get(list(self.context.advice_category)[0], []) + return [] |