aboutsummaryrefslogtreecommitdiff
path: root/tests/test_target_cortex_a_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_target_cortex_a_operators.py')
-rw-r--r--tests/test_target_cortex_a_operators.py24
1 files changed, 15 insertions, 9 deletions
diff --git a/tests/test_target_cortex_a_operators.py b/tests/test_target_cortex_a_operators.py
index 262ebc8..8bc48e6 100644
--- a/tests/test_target_cortex_a_operators.py
+++ b/tests/test_target_cortex_a_operators.py
@@ -1,7 +1,8 @@
-# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for Cortex-A operator compatibility."""
from pathlib import Path
+from typing import cast
import pytest
import tensorflow as tf
@@ -9,18 +10,19 @@ import tensorflow as tf
from mlia.backend.armnn_tflite_delegate import compat
from mlia.nn.tensorflow.tflite_graph import TFL_OP
from mlia.nn.tensorflow.utils import convert_to_tflite
+from mlia.target.cortex_a.config import CortexAConfiguration
from mlia.target.cortex_a.operators import CortexACompatibilityInfo
from mlia.target.cortex_a.operators import get_cortex_a_compatibility_info
-from mlia.target.cortex_a.operators import Operator
def test_compat_data() -> None:
"""Make sure all data contains the necessary items."""
builtin_tfl_ops = {op.name for op in TFL_OP}
- for data in [compat.ARMNN_TFLITE_DELEGATE]:
- assert "metadata" in data
- assert "backend" in data["metadata"]
- assert "version" in data["metadata"]
+ assert "backend" in compat.ARMNN_TFLITE_DELEGATE
+ assert "ops" in compat.ARMNN_TFLITE_DELEGATE
+
+ ops = cast(dict, compat.ARMNN_TFLITE_DELEGATE["ops"])
+ for data in ops.values():
assert "builtin_ops" in data
for comp in data["builtin_ops"]:
assert comp in builtin_tfl_ops
@@ -32,14 +34,18 @@ def check_get_cortex_a_compatibility_info(
expected_success: bool,
) -> None:
"""Check the function 'get_cortex_a_compatibility_info'."""
- compat_info = get_cortex_a_compatibility_info(model_path)
+ compat_info = get_cortex_a_compatibility_info(
+ model_path, CortexAConfiguration.load_profile("cortex-a")
+ )
assert isinstance(compat_info, CortexACompatibilityInfo)
- assert expected_success == compat_info.cortex_a_compatible
+ assert expected_success == compat_info.is_cortex_a_compatible
assert compat_info.operators
for oper in compat_info.operators:
assert oper.name
assert oper.location
- assert oper.support_type in Operator.SupportType
+ assert (
+ compat_info.get_support_type(oper) in CortexACompatibilityInfo.SupportType
+ )
def test_get_cortex_a_compatibility_info_compatible(