aboutsummaryrefslogtreecommitdiff
path: root/tests/test_backend_repo.py
blob: 50719893ed07578d4d64f129242abeed4e203cfd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# SPDX-FileCopyrightText: Copyright 2023, Arm Limited and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
"""Tests for backend repository."""
from __future__ import annotations

import json
import re
from pathlib import Path

import pytest

from mlia.backend.repo import BackendRepository
from mlia.backend.repo import get_backend_repository


def test_get_backend_repository(tmp_path: Path) -> None:
    """Check that get_backend_repository initialises an empty repository on disk."""
    root = tmp_path / "repo"
    backend_repo = get_backend_repository(root)
    assert isinstance(backend_repo, BackendRepository)

    # A fresh repository has an empty "backends" directory ...
    installed_dir = root.joinpath("backends")
    assert installed_dir.is_dir()
    assert list(installed_dir.iterdir()) == []

    # ... and a config file that lists no backends.
    cfg = root.joinpath("mlia_config.json")
    assert cfg.is_file()
    assert json.loads(cfg.read_text()) == {"backends": []}


def test_backend_repository_wrong_directory(tmp_path: Path) -> None:
    """Test that repository instance should throw error for the wrong directory.

    ``tmp_path`` is a fresh directory that was never initialised via
    ``get_backend_repository``, so constructing a ``BackendRepository`` from it
    is expected to fail with a message naming the directory.
    """
    # pytest's ``match`` is a regular expression searched against the error
    # message. Escape it so characters in the interpolated path (e.g. Windows
    # backslashes, which would form invalid escapes such as ``\U``, or dots)
    # are matched literally.
    expected_message = re.escape(
        f"Directory {tmp_path} could not be used as MLIA repository."
    )
    with pytest.raises(Exception, match=expected_message):
        BackendRepository(tmp_path)


def test_empty_backend_repository(tmp_path: Path) -> None:
    """Check that a freshly created repository reports no installed backends.

    Both removing a backend and reading its settings must fail for a backend
    that was never added.
    """
    backend_name = "sample_backend"
    backend_repo = get_backend_repository(tmp_path / "repo")

    assert not backend_repo.is_backend_installed(backend_name)

    # Each lookup-style operation rejects the unknown backend the same way.
    for operation in (backend_repo.remove_backend, backend_repo.get_backend_settings):
        with pytest.raises(Exception, match="Backend sample_backend is not installed."):
            operation(backend_name)


def test_adding_backend(tmp_path: Path) -> None:
    """Check registering an external backend directory in the repository.

    The backend is registered in place: nothing is placed under the
    repository's ``backends`` directory, and the backend location is stored
    (as ``backend_path``) in its settings and in ``mlia_config.json``.
    Re-adding the same name fails, and removal unregisters it.
    """
    root = tmp_path / "repo"
    backend_repo = get_backend_repository(root)

    external_dir = tmp_path / "backend"
    external_dir.mkdir()

    user_settings = {"param": "value"}
    backend_repo.add_backend("sample_backend", external_dir, user_settings)

    # Adding must leave <repo>/backends empty.
    installed_dir = root / "backends"
    assert installed_dir.is_dir()
    assert not list(installed_dir.iterdir())

    # Stored settings are the user settings plus the backend location.
    stored_settings = {"param": "value", "backend_path": external_dir.as_posix()}
    assert backend_repo.is_backend_installed("sample_backend")
    assert backend_repo.get_backend_settings("sample_backend") == (
        external_dir,
        stored_settings,
    )

    persisted = json.loads((root / "mlia_config.json").read_text())
    assert persisted == {
        "backends": [{"name": "sample_backend", "settings": stored_settings}]
    }

    # Registering the same name a second time is rejected.
    with pytest.raises(Exception, match="Backend sample_backend already installed"):
        backend_repo.add_backend("sample_backend", external_dir, user_settings)

    backend_repo.remove_backend("sample_backend")
    assert not backend_repo.is_backend_installed("sample_backend")


def test_copy_backend(tmp_path: Path) -> None:
    """Check copying a backend directory into the repository and removing it.

    The copy lands under ``<repo>/backends/<backend_dir>``. Its settings record
    the directory name (``backend_dir``), and removing the backend deletes the
    copied directory.
    """
    root = tmp_path / "repo"
    backend_repo = get_backend_repository(root)

    source_dir = tmp_path / "backend"
    source_dir.mkdir()
    (source_dir / "sample.txt").touch()

    backend_repo.copy_backend(
        "sample_backend", source_dir, "sample_backend_dir", {"param": "value"}
    )

    # The source tree is copied into <repo>/backends/sample_backend_dir.
    copied_dir = root / "backends" / "sample_backend_dir"
    assert copied_dir.is_dir()
    assert (copied_dir / "sample.txt").is_file()

    stored_settings = {"backend_dir": "sample_backend_dir", "param": "value"}
    persisted = json.loads((root / "mlia_config.json").read_text())
    assert persisted == {
        "backends": [{"name": "sample_backend", "settings": stored_settings}]
    }
    assert backend_repo.get_backend_settings("sample_backend") == (
        copied_dir,
        stored_settings,
    )

    # Removing a copied backend also deletes its directory from the repository.
    backend_repo.remove_backend("sample_backend")
    assert not copied_dir.exists()