aboutsummaryrefslogtreecommitdiff
path: root/verif/tests/test_tosa_result_checker.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/tests/test_tosa_result_checker.py')
-rw-r--r--verif/tests/test_tosa_result_checker.py197
1 files changed, 197 insertions, 0 deletions
diff --git a/verif/tests/test_tosa_result_checker.py b/verif/tests/test_tosa_result_checker.py
new file mode 100644
index 0000000..bc8a2fc
--- /dev/null
+++ b/verif/tests/test_tosa_result_checker.py
@@ -0,0 +1,197 @@
+"""Tests for tosa_result_checker.py."""
+# Copyright (c) 2021-2022, ARM Limited.
+# SPDX-License-Identifier: Apache-2.0
+from pathlib import Path
+
+import numpy as np
+import pytest
+
+import checker.tosa_result_checker as trc
+
+
+def _create_data_file(name, npy_data):
+ """Create numpy data file."""
+ file = Path(__file__).parent / name
+ with open(file, "wb") as f:
+ np.save(f, npy_data)
+ return file
+
+
+def _create_empty_file(name):
+ """Create numpy data file."""
+ file = Path(__file__).parent / name
+ f = open(file, "wb")
+ f.close()
+ return file
+
+
+def _delete_data_file(file: Path):
+ """Delete numpy data file."""
+ file.unlink()
+
+
+@pytest.mark.parametrize(
+ "data_type,expected",
+ [
+ (np.int8, trc.TestResult.MISMATCH),
+ (np.int16, trc.TestResult.MISMATCH),
+ (np.int32, trc.TestResult.PASS),
+ (np.int64, trc.TestResult.PASS),
+ (np.uint8, trc.TestResult.MISMATCH),
+ (np.uint16, trc.TestResult.MISMATCH),
+ (np.uint32, trc.TestResult.MISMATCH),
+ (np.uint64, trc.TestResult.MISMATCH),
+ (np.float16, trc.TestResult.MISMATCH),
+ (np.float32, trc.TestResult.PASS),
+ (np.float64, trc.TestResult.MISMATCH),
+ (bool, trc.TestResult.PASS),
+ ],
+)
+def test_supported_types(data_type, expected):
+ """Check which data types are supported."""
+ # Generate data
+ npy_data = np.ndarray(shape=(2, 3), dtype=data_type)
+
+ # Save data as reference and result files to compare.
+ reference_file = _create_data_file("reference.npy", npy_data)
+ result_file = _create_data_file("result.npy", npy_data)
+
+ args = [str(reference_file), str(result_file)]
+ """Compares reference and result npy files, returns zero if it passes."""
+ assert trc.main(args) == expected
+
+ # Remove files created
+ _delete_data_file(reference_file)
+ _delete_data_file(result_file)
+
+
+@pytest.mark.parametrize(
+ "data_type,expected",
+ [
+ (np.int32, trc.TestResult.MISMATCH),
+ (np.int64, trc.TestResult.MISMATCH),
+ (np.float32, trc.TestResult.MISMATCH),
+ (bool, trc.TestResult.MISMATCH),
+ ],
+)
+def test_shape_mismatch(data_type, expected):
+ """Check that mismatch shapes do not pass."""
+ # Generate and save data as reference and result files to compare.
+ npy_data = np.ones(shape=(3, 2), dtype=data_type)
+ reference_file = _create_data_file("reference.npy", npy_data)
+ npy_data = np.ones(shape=(2, 3), dtype=data_type)
+ result_file = _create_data_file("result.npy", npy_data)
+
+ args = [str(reference_file), str(result_file)]
+ """Compares reference and result npy files, returns zero if it passes."""
+ assert trc.main(args) == expected
+
+ # Remove files created
+ _delete_data_file(reference_file)
+ _delete_data_file(result_file)
+
+
+@pytest.mark.parametrize(
+ "data_type,expected",
+ [
+ (np.int32, trc.TestResult.MISMATCH),
+ (np.int64, trc.TestResult.MISMATCH),
+ (np.float32, trc.TestResult.MISMATCH),
+ (bool, trc.TestResult.MISMATCH),
+ ],
+)
+def test_results_mismatch(data_type, expected):
+ """Check that different results do not pass."""
+ # Generate and save data as reference and result files to compare.
+ npy_data = np.zeros(shape=(2, 3), dtype=data_type)
+ reference_file = _create_data_file("reference.npy", npy_data)
+ npy_data = np.ones(shape=(2, 3), dtype=data_type)
+ result_file = _create_data_file("result.npy", npy_data)
+
+ args = [str(reference_file), str(result_file)]
+ """Compares reference and result npy files, returns zero if it passes."""
+ assert trc.main(args) == expected
+
+ # Remove files created
+ _delete_data_file(reference_file)
+ _delete_data_file(result_file)
+
+
+@pytest.mark.parametrize(
+ "data_type1,data_type2,expected",
+ [ # Pairwise testing of all supported types
+ (np.int32, np.int64, trc.TestResult.MISMATCH),
+ (bool, np.float32, trc.TestResult.MISMATCH),
+ ],
+)
+def test_types_mismatch(data_type1, data_type2, expected):
+ """Check that different types in results do not pass."""
+ # Generate and save data as reference and result files to compare.
+ npy_data = np.ones(shape=(3, 2), dtype=data_type1)
+ reference_file = _create_data_file("reference.npy", npy_data)
+ npy_data = np.ones(shape=(3, 2), dtype=data_type2)
+ result_file = _create_data_file("result.npy", npy_data)
+
+ args = [str(reference_file), str(result_file)]
+ """Compares reference and result npy files, returns zero if it passes."""
+ assert trc.main(args) == expected
+
+ # Remove files created
+ _delete_data_file(reference_file)
+ _delete_data_file(result_file)
+
+
+@pytest.mark.parametrize(
+ "reference_exists,result_exists,expected",
+ [
+ (True, False, trc.TestResult.MISSING_FILE),
+ (False, True, trc.TestResult.MISSING_FILE),
+ ],
+)
+def test_missing_files(reference_exists, result_exists, expected):
+ """Check that missing files are caught."""
+ # Generate and save data
+ npy_data = np.ndarray(shape=(2, 3), dtype=bool)
+ reference_file = _create_data_file("reference.npy", npy_data)
+ result_file = _create_data_file("result.npy", npy_data)
+ if not reference_exists:
+ _delete_data_file(reference_file)
+ if not result_exists:
+ _delete_data_file(result_file)
+
+ args = [str(reference_file), str(result_file)]
+ assert trc.main(args) == expected
+
+ if reference_exists:
+ _delete_data_file(reference_file)
+ if result_exists:
+ _delete_data_file(result_file)
+
+
+@pytest.mark.parametrize(
+ "reference_numpy,result_numpy,expected",
+ [
+ (True, False, trc.TestResult.INCORRECT_FORMAT),
+ (False, True, trc.TestResult.INCORRECT_FORMAT),
+ ],
+)
+def test_incorrect_format_files(reference_numpy, result_numpy, expected):
+ """Check that incorrect format files are caught."""
+ # Generate and save data
+ npy_data = np.ndarray(shape=(2, 3), dtype=bool)
+ reference_file = (
+ _create_data_file("reference.npy", npy_data)
+ if reference_numpy
+ else _create_empty_file("empty.npy")
+ )
+ result_file = (
+ _create_data_file("result.npy", npy_data)
+ if result_numpy
+ else _create_empty_file("empty.npy")
+ )
+
+ args = [str(reference_file), str(result_file)]
+ assert trc.main(args) == expected
+
+ _delete_data_file(reference_file)
+ _delete_data_file(result_file)