aboutsummaryrefslogtreecommitdiff
path: root/tests/test_target_cortex_a_data_collection.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_target_cortex_a_data_collection.py')
-rw-r--r--tests/test_target_cortex_a_data_collection.py52
1 files changed, 52 insertions, 0 deletions
diff --git a/tests/test_target_cortex_a_data_collection.py b/tests/test_target_cortex_a_data_collection.py
new file mode 100644
index 0000000..7504166
--- /dev/null
+++ b/tests/test_target_cortex_a_data_collection.py
@@ -0,0 +1,52 @@
+# SPDX-FileCopyrightText: Copyright 2022, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for Cortex-A data collection module."""
+from pathlib import Path
+from unittest.mock import MagicMock
+
+import pytest
+
+from mlia.core.context import ExecutionContext
+from mlia.target.cortex_a.data_collection import CortexAOperatorCompatibility
+from mlia.target.cortex_a.operators import CortexACompatibilityInfo
+
+
+def check_cortex_a_data_collection(
+ monkeypatch: pytest.MonkeyPatch, model: Path, tmpdir: str
+) -> None:
+ """Test Cortex-A data collection."""
+ assert CortexAOperatorCompatibility.name()
+
+ monkeypatch.setattr(
+ "mlia.target.cortex_a.data_collection.get_cortex_a_compatibility_info",
+ MagicMock(return_value=CortexACompatibilityInfo(True, [])),
+ )
+
+ context = ExecutionContext(working_dir=tmpdir)
+ collector = CortexAOperatorCompatibility(model)
+ collector.set_context(context)
+
+ data_item = collector.collect_data()
+
+ assert isinstance(data_item, CortexACompatibilityInfo)
+
+
+def test_cortex_a_data_collection_tflite(
+ monkeypatch: pytest.MonkeyPatch, test_tflite_model: Path, tmpdir: str
+) -> None:
+ """Test Cortex-A data collection with a TensorFlow Lite model."""
+ check_cortex_a_data_collection(monkeypatch, test_tflite_model, tmpdir)
+
+
+def test_cortex_a_data_collection_keras(
+ monkeypatch: pytest.MonkeyPatch, test_keras_model: Path, tmpdir: str
+) -> None:
+ """Test Cortex-A data collection with a Keras model."""
+ check_cortex_a_data_collection(monkeypatch, test_keras_model, tmpdir)
+
+
+def test_cortex_a_data_collection_tf(
+ monkeypatch: pytest.MonkeyPatch, test_tf_model: Path, tmpdir: str
+) -> None:
+ """Test Cortex-A data collection with a SavedModel."""
+ check_cortex_a_data_collection(monkeypatch, test_tf_model, tmpdir)