aboutsummaryrefslogtreecommitdiff
path: root/tests/test_devices_cortexa_data_collection.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_devices_cortexa_data_collection.py')
-rw-r--r--tests/test_devices_cortexa_data_collection.py30
1 files changed, 27 insertions, 3 deletions
diff --git a/tests/test_devices_cortexa_data_collection.py b/tests/test_devices_cortexa_data_collection.py
index 7ea3e52..6d3b2ac 100644
--- a/tests/test_devices_cortexa_data_collection.py
+++ b/tests/test_devices_cortexa_data_collection.py
@@ -11,18 +11,42 @@ from mlia.devices.cortexa.data_collection import CortexAOperatorCompatibility
from mlia.devices.cortexa.operators import CortexACompatibilityInfo
-def test_cortex_a_data_collection(
- monkeypatch: pytest.MonkeyPatch, test_tflite_model: Path, tmpdir: str
+def check_cortex_a_data_collection(
+ monkeypatch: pytest.MonkeyPatch, model: Path, tmpdir: str
) -> None:
"""Test Cortex-A data collection."""
+ assert CortexAOperatorCompatibility.name()
+
monkeypatch.setattr(
"mlia.devices.cortexa.data_collection.get_cortex_a_compatibility_info",
MagicMock(return_value=CortexACompatibilityInfo(True, [])),
)
+
context = ExecutionContext(working_dir=tmpdir)
- collector = CortexAOperatorCompatibility(test_tflite_model)
+ collector = CortexAOperatorCompatibility(model)
collector.set_context(context)
data_item = collector.collect_data()
assert isinstance(data_item, CortexACompatibilityInfo)
+
+
+def test_cortex_a_data_collection_tflite(
+ monkeypatch: pytest.MonkeyPatch, test_tflite_model: Path, tmpdir: str
+) -> None:
+ """Test Cortex-A data collection with a TensorFlow Lite model."""
+ check_cortex_a_data_collection(monkeypatch, test_tflite_model, tmpdir)
+
+
+def test_cortex_a_data_collection_keras(
+ monkeypatch: pytest.MonkeyPatch, test_keras_model: Path, tmpdir: str
+) -> None:
+ """Test Cortex-A data collection with a Keras model."""
+ check_cortex_a_data_collection(monkeypatch, test_keras_model, tmpdir)
+
+
+def test_cortex_a_data_collection_tf(
+ monkeypatch: pytest.MonkeyPatch, test_tf_model: Path, tmpdir: str
+) -> None:
+ """Test Cortex-A data collection with a SavedModel."""
+ check_cortex_a_data_collection(monkeypatch, test_tf_model, tmpdir)