diff options
Diffstat (limited to 'tests/test_devices_cortexa_data_collection.py')
-rw-r--r-- | tests/test_devices_cortexa_data_collection.py | 30 |
1 files changed, 27 insertions, 3 deletions
diff --git a/tests/test_devices_cortexa_data_collection.py b/tests/test_devices_cortexa_data_collection.py index 7ea3e52..6d3b2ac 100644 --- a/tests/test_devices_cortexa_data_collection.py +++ b/tests/test_devices_cortexa_data_collection.py @@ -11,18 +11,42 @@ from mlia.devices.cortexa.data_collection import CortexAOperatorCompatibility from mlia.devices.cortexa.operators import CortexACompatibilityInfo -def test_cortex_a_data_collection( - monkeypatch: pytest.MonkeyPatch, test_tflite_model: Path, tmpdir: str +def check_cortex_a_data_collection( + monkeypatch: pytest.MonkeyPatch, model: Path, tmpdir: str ) -> None: """Test Cortex-A data collection.""" + assert CortexAOperatorCompatibility.name() + monkeypatch.setattr( "mlia.devices.cortexa.data_collection.get_cortex_a_compatibility_info", MagicMock(return_value=CortexACompatibilityInfo(True, [])), ) + context = ExecutionContext(working_dir=tmpdir) - collector = CortexAOperatorCompatibility(test_tflite_model) + collector = CortexAOperatorCompatibility(model) collector.set_context(context) data_item = collector.collect_data() assert isinstance(data_item, CortexACompatibilityInfo) + + +def test_cortex_a_data_collection_tflite( + monkeypatch: pytest.MonkeyPatch, test_tflite_model: Path, tmpdir: str +) -> None: + """Test Cortex-A data collection with a TensorFlow Lite model.""" + check_cortex_a_data_collection(monkeypatch, test_tflite_model, tmpdir) + + +def test_cortex_a_data_collection_keras( + monkeypatch: pytest.MonkeyPatch, test_keras_model: Path, tmpdir: str +) -> None: + """Test Cortex-A data collection with a Keras model.""" + check_cortex_a_data_collection(monkeypatch, test_keras_model, tmpdir) + + +def test_cortex_a_data_collection_tf( + monkeypatch: pytest.MonkeyPatch, test_tf_model: Path, tmpdir: str +) -> None: + """Test Cortex-A data collection with a SavedModel.""" + check_cortex_a_data_collection(monkeypatch, test_tf_model, tmpdir) |