diff options
author | Jeremy Johnson <jeremy.johnson@arm.com> | 2023-10-30 10:28:21 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2023-10-31 20:58:07 +0000 |
commit | 72dcab775c7a84037135bf365086ca976f3220ef (patch) | |
tree | b184d731582734795e77492290c1595ec8ed13d5 /verif | |
parent | a7f5b995d618d26724812fc27011f87600e958dc (diff) | |
download | reference_model-72dcab775c7a84037135bf365086ca976f3220ef.tar.gz |
Allow more integer types in result check.
Added int8, int16, uint8 and uint16.
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: I6b1174aa7369c34a613b322c4b05bf4fc586ee6e
Diffstat (limited to 'verif')
-rw-r--r-- | verif/checker/tosa_result_checker.py | 9 | ||||
-rw-r--r-- | verif/tests/test_tosa_result_checker.py | 8 |
2 files changed, 12 insertions, 5 deletions
diff --git a/verif/checker/tosa_result_checker.py b/verif/checker/tosa_result_checker.py index 38ed510..4ba5d4c 100644 --- a/verif/checker/tosa_result_checker.py +++ b/verif/checker/tosa_result_checker.py @@ -186,7 +186,14 @@ def test_check( return (TestResult.INCORRECT_FORMAT, 0.0, msg) # for quantized test, allow +-(quantize_tolerance) error - if reference_result.dtype == np.int32 or reference_result.dtype == np.int64: + if reference_result.dtype in ( + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + ): if np.all(np.absolute(reference_result - test_result) <= quantize_tolerance): _print_result(LogColors.GREEN, "Results PASS {}".format(test_name)) diff --git a/verif/tests/test_tosa_result_checker.py b/verif/tests/test_tosa_result_checker.py index d78d158..f9bda39 100644 --- a/verif/tests/test_tosa_result_checker.py +++ b/verif/tests/test_tosa_result_checker.py @@ -32,12 +32,12 @@ def _delete_data_file(file: Path): @pytest.mark.parametrize( "data_type,expected", [ - (np.int8, trc.TestResult.MISMATCH), - (np.int16, trc.TestResult.MISMATCH), + (np.int8, trc.TestResult.PASS), + (np.int16, trc.TestResult.PASS), (np.int32, trc.TestResult.PASS), (np.int64, trc.TestResult.PASS), - (np.uint8, trc.TestResult.MISMATCH), - (np.uint16, trc.TestResult.MISMATCH), + (np.uint8, trc.TestResult.PASS), + (np.uint16, trc.TestResult.PASS), (np.uint32, trc.TestResult.MISMATCH), (np.uint64, trc.TestResult.MISMATCH), (np.float16, trc.TestResult.PASS), |