aboutsummaryrefslogtreecommitdiff
path: root/verif
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2023-10-30 10:28:21 +0000
committerEric Kunze <eric.kunze@arm.com>2023-10-31 20:58:07 +0000
commit72dcab775c7a84037135bf365086ca976f3220ef (patch)
treeb184d731582734795e77492290c1595ec8ed13d5 /verif
parenta7f5b995d618d26724812fc27011f87600e958dc (diff)
downloadreference_model-72dcab775c7a84037135bf365086ca976f3220ef.tar.gz
Allow more integer types in result check.
Added int8, int16, uint8 and uint16. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I6b1174aa7369c34a613b322c4b05bf4fc586ee6e
Diffstat (limited to 'verif')
-rw-r--r--verif/checker/tosa_result_checker.py9
-rw-r--r--verif/tests/test_tosa_result_checker.py8
2 files changed, 12 insertions, 5 deletions
diff --git a/verif/checker/tosa_result_checker.py b/verif/checker/tosa_result_checker.py
index 38ed510..4ba5d4c 100644
--- a/verif/checker/tosa_result_checker.py
+++ b/verif/checker/tosa_result_checker.py
@@ -186,7 +186,14 @@ def test_check(
return (TestResult.INCORRECT_FORMAT, 0.0, msg)
# for quantized test, allow +-(quantize_tolerance) error
- if reference_result.dtype == np.int32 or reference_result.dtype == np.int64:
+ if reference_result.dtype in (
+ np.int8,
+ np.int16,
+ np.int32,
+ np.int64,
+ np.uint8,
+ np.uint16,
+ ):
if np.all(np.absolute(reference_result - test_result) <= quantize_tolerance):
_print_result(LogColors.GREEN, "Results PASS {}".format(test_name))
diff --git a/verif/tests/test_tosa_result_checker.py b/verif/tests/test_tosa_result_checker.py
index d78d158..f9bda39 100644
--- a/verif/tests/test_tosa_result_checker.py
+++ b/verif/tests/test_tosa_result_checker.py
@@ -32,12 +32,12 @@ def _delete_data_file(file: Path):
@pytest.mark.parametrize(
"data_type,expected",
[
- (np.int8, trc.TestResult.MISMATCH),
- (np.int16, trc.TestResult.MISMATCH),
+ (np.int8, trc.TestResult.PASS),
+ (np.int16, trc.TestResult.PASS),
(np.int32, trc.TestResult.PASS),
(np.int64, trc.TestResult.PASS),
- (np.uint8, trc.TestResult.MISMATCH),
- (np.uint16, trc.TestResult.MISMATCH),
+ (np.uint8, trc.TestResult.PASS),
+ (np.uint16, trc.TestResult.PASS),
(np.uint32, trc.TestResult.MISMATCH),
(np.uint64, trc.TestResult.MISMATCH),
(np.float16, trc.TestResult.PASS),