From e4b08ffbe457c8932740e3171964cf2e7cd69b4f Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Thu, 15 Sep 2022 10:38:17 +0100 Subject: Initial set up of Main Inference conformance test gen tosa-verif-build-tests - option for setting FP values range - option for recursively finding tests - change from os.path to Path tosa_verif_result_check - option to supply FP tolerance - output difference and max tolerance on contents mismatch - change from os.path to Path MI conformance - contains examples of AVG_POOL2D and CONV2D tests Signed-off-by: Jeremy Johnson Change-Id: I8e1645cd8f10308604400ea53eef723ca163eed7 --- verif/checker/tosa_result_checker.py | 56 ++++++++++++++++++++++++------------ 1 file changed, 37 insertions(+), 19 deletions(-) (limited to 'verif/checker/tosa_result_checker.py') diff --git a/verif/checker/tosa_result_checker.py b/verif/checker/tosa_result_checker.py index b7a76b6..1169a95 100644 --- a/verif/checker/tosa_result_checker.py +++ b/verif/checker/tosa_result_checker.py @@ -2,7 +2,6 @@ # Copyright (c) 2020-2022, ARM Limited. # SPDX-License-Identifier: Apache-2.0 import argparse -import os from enum import Enum from enum import IntEnum from enum import unique @@ -62,37 +61,41 @@ TestResultErrorStr = [ ] ################################## +DEFAULT_FP_TOLERANCE = 1e-3 + def test_check( - reference, - result, + reference_path, + result_path, test_name="test", quantize_tolerance=0, - float_tolerance=1e-3, + float_tolerance=DEFAULT_FP_TOLERANCE, misc_checks=[], ): """Check if the result is the same as the expected reference.""" - if not os.path.isfile(reference): + if not reference_path.is_file(): print_color(LogColors.RED, "Reference MISSING FILE {}".format(test_name)) - msg = "Missing reference file: {}".format(reference) + msg = "Missing reference file: {}".format(reference_path) return (TestResult.MISSING_FILE, 0.0, msg) - if not os.path.isfile(result): + if not result_path.is_file(): print_color(LogColors.RED, "Results MISSING FILE {}".format(test_name)) - msg = "Missing result file: {}".format(result) + msg = "Missing result file: {}".format(result_path) return (TestResult.MISSING_FILE, 0.0, msg) try: - test_result = np.load(result) + test_result = np.load(result_path) except Exception as e: print_color(LogColors.RED, "Results INCORRECT FORMAT {}".format(test_name)) - msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format(result, e) + msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format( + result_path, e + ) return (TestResult.INCORRECT_FORMAT, 0.0, msg) try: - reference_result = np.load(reference) + reference_result = np.load(reference_path) except Exception as e: print_color(LogColors.RED, "Reference INCORRECT FORMAT {}".format(test_name)) msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format( - reference, e + reference_path, e ) return (TestResult.INCORRECT_FORMAT, 0.0, msg) @@ -109,6 +112,7 @@ def test_check( # >= 0, allow that special case test_result = np.squeeze(test_result) reference_result = np.squeeze(reference_result) + difference = None if np.shape(test_result) != np.shape(reference_result): print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name)) @@ -155,6 +159,7 @@ def test_check( ) ) # Fall-through to below to add failure values + difference = reference_result - test_result elif reference_result.dtype == bool: assert test_result.dtype == bool @@ -165,6 +170,7 @@ def test_check( return (TestResult.PASS, 0.0, "") msg = "Boolean result does not match" tolerance = 0.0 + difference = None # Fall-through to below to add failure values # TODO: update for fp16 tolerance @@ -174,6 +180,7 @@ def test_check( print_color(LogColors.GREEN, "Results PASS {}".format(test_name)) return (TestResult.PASS, tolerance, "") msg = "Float result does not match within tolerance of {}".format(tolerance) + difference = reference_result - test_result # Fall-through to below to add failure values else: print_color(LogColors.RED, "Results UNSUPPORTED TYPE {}".format(test_name)) @@ -182,16 +189,24 @@ def test_check( # Fall-through for mismatch failure to add values to msg print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name)) - np.set_printoptions(threshold=128) - msg = "{}\ntest_result: {}\n{}".format(msg, test_result.shape, test_result) - msg = "{}\nreference_result: {}\n{}".format( + np.set_printoptions(threshold=128, edgeitems=2) + + if difference is not None: + tolerance_needed = np.amax(np.absolute(difference)) + msg = "{}\n-- tolerance_needed: {}".format(msg, tolerance_needed) + + msg = "{}\n>> reference_result: {}\n{}".format( msg, reference_result.shape, reference_result ) + msg = "{}\n<< test_result: {}\n{}".format(msg, test_result.shape, test_result) + + if difference is not None: + msg = "{}\n!! difference_result: \n{}".format(msg, difference) return (TestResult.MISMATCH, tolerance, msg) def main(argv=None): - """Check that the supplied reference and result files are the same.""" + """Check that the supplied reference and result files have the same contents.""" parser = argparse.ArgumentParser() parser.add_argument( "reference_path", type=Path, help="the path to the reference file to test" @@ -199,11 +214,14 @@ def main(argv=None): parser.add_argument( "result_path", type=Path, help="the path to the result file to test" ) + parser.add_argument( + "--fp-tolerance", type=float, default=DEFAULT_FP_TOLERANCE, help="FP tolerance" + ) args = parser.parse_args(argv) - ref_path = args.reference_path - res_path = args.result_path - result, tolerance, msg = test_check(ref_path, res_path) + result, tolerance, msg = test_check( + args.reference_path, args.result_path, float_tolerance=args.fp_tolerance + ) if result != TestResult.PASS: print(msg) -- cgit v1.2.1