diff options
Diffstat (limited to 'src/mlia/target/cortex_a/data_analysis.py')
-rw-r--r-- | src/mlia/target/cortex_a/data_analysis.py | 13 |
1 files changed, 6 insertions, 7 deletions
diff --git a/src/mlia/target/cortex_a/data_analysis.py b/src/mlia/target/cortex_a/data_analysis.py index 4a3a068..089c1a2 100644 --- a/src/mlia/target/cortex_a/data_analysis.py +++ b/src/mlia/target/cortex_a/data_analysis.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates. +# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates. # SPDX-License-Identifier: Apache-2.0 """Cortex-A data analysis module.""" from __future__ import annotations @@ -13,7 +13,6 @@ from mlia.core.data_analysis import Fact from mlia.core.data_analysis import FactExtractor from mlia.nn.tensorflow.tflite_compat import TFLiteCompatibilityInfo from mlia.target.cortex_a.operators import CortexACompatibilityInfo -from mlia.target.cortex_a.operators import Operator class CortexADataAnalyzer(FactExtractor): @@ -28,7 +27,7 @@ class CortexADataAnalyzer(FactExtractor): self, data_item: CortexACompatibilityInfo ) -> None: """Analyse operator compatibility information.""" - if data_item.cortex_a_compatible: + if data_item.is_cortex_a_compatible: self.add_fact(ModelIsCortexACompatible(data_item.backend_info)) else: unsupported_ops = set() @@ -36,17 +35,17 @@ class CortexADataAnalyzer(FactExtractor): str, ModelIsNotCortexACompatible.ActivationFunctionSupport ] = defaultdict(ModelIsNotCortexACompatible.ActivationFunctionSupport) for oper in data_item.operators: - if oper.support_type == Operator.SupportType.OP_NOT_SUPPORTED: + support_type = data_item.get_support_type(oper) + if support_type == data_item.SupportType.OP_NOT_SUPPORTED: unsupported_ops.add(oper.full_name) - - if oper.support_type == Operator.SupportType.ACTIVATION_NOT_SUPPORTED: + elif support_type == data_item.SupportType.ACTIVATION_NOT_SUPPORTED: # Add used but unsupported actication functions activation_func_support[oper.full_name].used_unsupported.add( oper.activation_func.name ) # Add supported activation functions activation_func_support[oper.full_name].supported.update( - oper.supported_activation_functions + data_item.supported_activation_functions(oper) ) assert ( |