aboutsummaryrefslogtreecommitdiff
path: root/tests/test_backend_repo.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_backend_repo.py')
-rw-r--r--tests/test_backend_repo.py140
1 files changed, 140 insertions, 0 deletions
diff --git a/tests/test_backend_repo.py b/tests/test_backend_repo.py
new file mode 100644
index 0000000..5071989
--- /dev/null
+++ b/tests/test_backend_repo.py
@@ -0,0 +1,140 @@
+# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for backend repository."""
+from __future__ import annotations
+
+import json
+from pathlib import Path
+
+import pytest
+
+from mlia.backend.repo import BackendRepository
+from mlia.backend.repo import get_backend_repository
+
+
+def test_get_backend_repository(tmp_path: Path) -> None:
+ """Test function get_backend_repository."""
+ repo_path = tmp_path / "repo"
+ repo = get_backend_repository(repo_path)
+
+ assert isinstance(repo, BackendRepository)
+
+ backends_dir = repo_path / "backends"
+ assert backends_dir.is_dir()
+ assert not list(backends_dir.iterdir())
+
+ config_file = repo_path / "mlia_config.json"
+ assert config_file.is_file()
+ assert json.loads(config_file.read_text()) == {"backends": []}
+
+
+def test_backend_repository_wrong_directory(tmp_path: Path) -> None:
+ """Test that repository instance should throw error for the wrong directory."""
+ with pytest.raises(
+ Exception, match=f"Directory {tmp_path} could not be used as MLIA repository."
+ ):
+ BackendRepository(tmp_path)
+
+
+def test_empty_backend_repository(tmp_path: Path) -> None:
+ """Test empty backend repository."""
+ repo_path = tmp_path / "repo"
+ repo = get_backend_repository(repo_path)
+
+ assert not repo.is_backend_installed("sample_backend")
+
+ with pytest.raises(Exception, match="Backend sample_backend is not installed."):
+ repo.remove_backend("sample_backend")
+
+ with pytest.raises(Exception, match="Backend sample_backend is not installed."):
+ repo.get_backend_settings("sample_backend")
+
+
+def test_adding_backend(tmp_path: Path) -> None:
+ """Test adding backend to the repository."""
+ repo_path = tmp_path / "repo"
+ repo = get_backend_repository(repo_path)
+
+ backend_path = tmp_path.joinpath("backend")
+ backend_path.mkdir()
+
+ settings = {"param": "value"}
+ repo.add_backend("sample_backend", backend_path, settings)
+
+ backends_dir = repo_path / "backends"
+ assert backends_dir.is_dir()
+ assert not list(backends_dir.iterdir())
+
+ assert repo.is_backend_installed("sample_backend")
+ expected_settings = {
+ "param": "value",
+ "backend_path": backend_path.as_posix(),
+ }
+ assert repo.get_backend_settings("sample_backend") == (
+ backend_path,
+ expected_settings,
+ )
+
+ config_file = repo_path / "mlia_config.json"
+ expected_content = {
+ "backends": [
+ {
+ "name": "sample_backend",
+ "settings": {
+ "backend_path": backend_path.as_posix(),
+ "param": "value",
+ },
+ }
+ ]
+ }
+ assert json.loads(config_file.read_text()) == expected_content
+
+ with pytest.raises(Exception, match="Backend sample_backend already installed"):
+ repo.add_backend("sample_backend", backend_path, settings)
+
+ repo.remove_backend("sample_backend")
+ assert not repo.is_backend_installed("sample_backend")
+
+
+def test_copy_backend(tmp_path: Path) -> None:
+ """Test copying backend to the repository."""
+ repo_path = tmp_path / "repo"
+ repo = get_backend_repository(repo_path)
+
+ backend_path = tmp_path.joinpath("backend")
+ backend_path.mkdir()
+
+ backend_path.joinpath("sample.txt").touch()
+
+ settings = {"param": "value"}
+ repo.copy_backend("sample_backend", backend_path, "sample_backend_dir", settings)
+
+ repo_backend_path = repo_path.joinpath("backends", "sample_backend_dir")
+ assert repo_backend_path.is_dir()
+ assert repo_backend_path.joinpath("sample.txt").is_file()
+
+ config_file = repo_path / "mlia_config.json"
+ expected_content = {
+ "backends": [
+ {
+ "name": "sample_backend",
+ "settings": {
+ "backend_dir": "sample_backend_dir",
+ "param": "value",
+ },
+ }
+ ]
+ }
+ assert json.loads(config_file.read_text()) == expected_content
+
+ expected_settings = {
+ "param": "value",
+ "backend_dir": "sample_backend_dir",
+ }
+ assert repo.get_backend_settings("sample_backend") == (
+ repo_backend_path,
+ expected_settings,
+ )
+
+ repo.remove_backend("sample_backend")
+ assert not repo_backend_path.exists()