aboutsummaryrefslogtreecommitdiff
path: root/verif/checker/tosa_result_checker.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/checker/tosa_result_checker.py')
-rw-r--r--verif/checker/tosa_result_checker.py56
1 files changed, 37 insertions, 19 deletions
diff --git a/verif/checker/tosa_result_checker.py b/verif/checker/tosa_result_checker.py
index b7a76b6..1169a95 100644
--- a/verif/checker/tosa_result_checker.py
+++ b/verif/checker/tosa_result_checker.py
@@ -2,7 +2,6 @@
# Copyright (c) 2020-2022, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import argparse
-import os
from enum import Enum
from enum import IntEnum
from enum import unique
@@ -62,37 +61,41 @@ TestResultErrorStr = [
]
##################################
+DEFAULT_FP_TOLERANCE = 1e-3
+
def test_check(
- reference,
- result,
+ reference_path,
+ result_path,
test_name="test",
quantize_tolerance=0,
- float_tolerance=1e-3,
+ float_tolerance=DEFAULT_FP_TOLERANCE,
misc_checks=[],
):
"""Check if the result is the same as the expected reference."""
- if not os.path.isfile(reference):
+ if not reference_path.is_file():
print_color(LogColors.RED, "Reference MISSING FILE {}".format(test_name))
- msg = "Missing reference file: {}".format(reference)
+ msg = "Missing reference file: {}".format(reference_path)
return (TestResult.MISSING_FILE, 0.0, msg)
- if not os.path.isfile(result):
+ if not result_path.is_file():
print_color(LogColors.RED, "Results MISSING FILE {}".format(test_name))
- msg = "Missing result file: {}".format(result)
+ msg = "Missing result file: {}".format(result_path)
return (TestResult.MISSING_FILE, 0.0, msg)
try:
- test_result = np.load(result)
+ test_result = np.load(result_path)
except Exception as e:
print_color(LogColors.RED, "Results INCORRECT FORMAT {}".format(test_name))
- msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format(result, e)
+ msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format(
+ result_path, e
+ )
return (TestResult.INCORRECT_FORMAT, 0.0, msg)
try:
- reference_result = np.load(reference)
+ reference_result = np.load(reference_path)
except Exception as e:
print_color(LogColors.RED, "Reference INCORRECT FORMAT {}".format(test_name))
msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format(
- reference, e
+ reference_path, e
)
return (TestResult.INCORRECT_FORMAT, 0.0, msg)
@@ -109,6 +112,7 @@ def test_check(
# >= 0, allow that special case
test_result = np.squeeze(test_result)
reference_result = np.squeeze(reference_result)
+ difference = None
if np.shape(test_result) != np.shape(reference_result):
print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
@@ -155,6 +159,7 @@ def test_check(
)
)
# Fall-through to below to add failure values
+ difference = reference_result - test_result
elif reference_result.dtype == bool:
assert test_result.dtype == bool
@@ -165,6 +170,7 @@ def test_check(
return (TestResult.PASS, 0.0, "")
msg = "Boolean result does not match"
tolerance = 0.0
+ difference = None
# Fall-through to below to add failure values
# TODO: update for fp16 tolerance
@@ -174,6 +180,7 @@ def test_check(
print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
return (TestResult.PASS, tolerance, "")
msg = "Float result does not match within tolerance of {}".format(tolerance)
+ difference = reference_result - test_result
# Fall-through to below to add failure values
else:
print_color(LogColors.RED, "Results UNSUPPORTED TYPE {}".format(test_name))
@@ -182,16 +189,24 @@ def test_check(
# Fall-through for mismatch failure to add values to msg
print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
- np.set_printoptions(threshold=128)
- msg = "{}\ntest_result: {}\n{}".format(msg, test_result.shape, test_result)
- msg = "{}\nreference_result: {}\n{}".format(
+ np.set_printoptions(threshold=128, edgeitems=2)
+
+ if difference is not None:
+ tolerance_needed = np.amax(np.absolute(difference))
+ msg = "{}\n-- tolerance_needed: {}".format(msg, tolerance_needed)
+
+ msg = "{}\n>> reference_result: {}\n{}".format(
msg, reference_result.shape, reference_result
)
+ msg = "{}\n<< test_result: {}\n{}".format(msg, test_result.shape, test_result)
+
+ if difference is not None:
+ msg = "{}\n!! difference_result: \n{}".format(msg, difference)
return (TestResult.MISMATCH, tolerance, msg)
def main(argv=None):
- """Check that the supplied reference and result files are the same."""
+ """Check that the supplied reference and result files have the same contents."""
parser = argparse.ArgumentParser()
parser.add_argument(
"reference_path", type=Path, help="the path to the reference file to test"
@@ -199,11 +214,14 @@ def main(argv=None):
parser.add_argument(
"result_path", type=Path, help="the path to the result file to test"
)
+ parser.add_argument(
+ "--fp-tolerance", type=float, default=DEFAULT_FP_TOLERANCE, help="FP tolerance"
+ )
args = parser.parse_args(argv)
- ref_path = args.reference_path
- res_path = args.result_path
- result, tolerance, msg = test_check(ref_path, res_path)
+ result, tolerance, msg = test_check(
+ args.reference_path, args.result_path, float_tolerance=args.fp_tolerance
+ )
if result != TestResult.PASS:
print(msg)