diff options
author | Jeremy Johnson <jeremy.johnson@arm.com> | 2023-09-14 17:02:09 +0100 |
---|---|---|
committer | Jeremy Johnson <jeremy.johnson@arm.com> | 2023-10-02 12:04:44 +0100 |
commit | e2b5e87804e158cb3e5d06a131c317b3890b87b3 (patch) | |
tree | fd8b5a4d56dfcea4be4e6ced73f2d4d5b2e1d92d /verif/checker | |
parent | bb0935f868a5ab09403cf3628848655b06ac1dec (diff) | |
download | reference_model-e2b5e87804e158cb3e5d06a131c317b3890b87b3.tar.gz |
Support for compliance checking testing
Updated to conformance generator to not generate tests with results for
compliance tests.
Updated test runner to run compliance mode version (precise & abs mode)
of reference model to create test results to use against SUT results.
Updated reference model to enable abs_mode on correct desc.json flags.
Updated test checker to support compliance checking using verifier lib.
Seperated color printing from test checker.
Change-Id: I7e2fbfc6883916caa5d94d4ece122c48bf45f530
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Diffstat (limited to 'verif/checker')
-rw-r--r-- | verif/checker/color_print.py | 33 | ||||
-rw-r--r-- | verif/checker/tosa_result_checker.py | 231 |
2 files changed, 197 insertions, 67 deletions
diff --git a/verif/checker/color_print.py b/verif/checker/color_print.py new file mode 100644 index 0000000..1563b92 --- /dev/null +++ b/verif/checker/color_print.py @@ -0,0 +1,33 @@ +"""Color printing module.""" +# Copyright (c) 2020-2023, ARM Limited. +# SPDX-License-Identifier: Apache-2.0 +from enum import Enum +from enum import unique + +color_printing = True + + +@unique +class LogColors(Enum): + """Shell escape sequence colors for logging.""" + + NONE = "\u001b[0m" + GREEN = "\u001b[32;1m" + RED = "\u001b[31;1m" + YELLOW = "\u001b[33;1m" + BOLD_WHITE = "\u001b[1m" + + +def set_print_in_color(enabled): + """Set color printing to enabled or disabled.""" + global color_printing + color_printing = enabled + + +def print_color(color, msg): + """Print color status messages if enabled.""" + global color_printing + if not color_printing: + print(msg) + else: + print("{}{}{}".format(color.value, msg, LogColors.NONE.value)) diff --git a/verif/checker/tosa_result_checker.py b/verif/checker/tosa_result_checker.py index 1169a95..38ed510 100644 --- a/verif/checker/tosa_result_checker.py +++ b/verif/checker/tosa_result_checker.py @@ -1,43 +1,19 @@ """TOSA result checker script.""" -# Copyright (c) 2020-2022, ARM Limited. +# Copyright (c) 2020-2023, ARM Limited. # SPDX-License-Identifier: Apache-2.0 import argparse -from enum import Enum +import json from enum import IntEnum from enum import unique from pathlib import Path import numpy as np +from checker.color_print import LogColors +from checker.color_print import print_color +from checker.verifier import VerifierError +from checker.verifier import VerifierLibrary from generator.tosa_utils import float32_is_valid_bfloat16 - -################################## -color_printing = True - - -@unique -class LogColors(Enum): - """Shell escape sequence colors for logging.""" - - NONE = "\u001b[0m" - GREEN = "\u001b[32;1m" - RED = "\u001b[31;1m" - YELLOW = "\u001b[33;1m" - BOLD_WHITE = "\u001b[1m" - - -def set_print_in_color(enabled): - """Set color printing to enabled or disabled.""" - global color_printing - color_printing = enabled - - -def print_color(color, msg): - """Print color status messages if enabled.""" - global color_printing - if not color_printing: - print(msg) - else: - print("{}{}{}".format(color.value, msg, LogColors.NONE.value)) +from schemavalidation.schemavalidation import TestDescSchemaValidator @unique @@ -62,46 +38,120 @@ TestResultErrorStr = [ ################################## DEFAULT_FP_TOLERANCE = 1e-3 +result_printing = True + + +def set_print_result(enabled): + """Set whether to print out or not.""" + global result_printing + result_printing = enabled + + +def _print_result(color, msg): + """Print out result.""" + global result_printing + if result_printing: + print_color(color, msg) + + +def compliance_check( + imp_result_path, + ref_result_path, + bnd_result_path, + test_name, + compliance_config, + ofm_name, + verify_lib_path, +): + try: + vlib = VerifierLibrary(verify_lib_path) + except VerifierError as e: + _print_result(LogColors.RED, f"INTERNAL ERROR {test_name}") + msg = f"Could not load verfier library: {str(e)}" + return (TestResult.INTERNAL_ERROR, 0.0, msg) + + success = vlib.verify_data( + ofm_name, compliance_config, imp_result_path, ref_result_path, bnd_result_path + ) + if success: + _print_result(LogColors.GREEN, f"Results PASS {test_name}") + return (TestResult.PASS, 0.0, "") + else: + _print_result(LogColors.RED, f"Results NON-COMPLIANT {test_name}") + return (TestResult.MISMATCH, 0.0, "Non-compliance implementation results found") def test_check( - reference_path, - result_path, - test_name="test", + ref_result_path, + imp_result_path, + test_name=None, quantize_tolerance=0, float_tolerance=DEFAULT_FP_TOLERANCE, misc_checks=[], + test_desc=None, + bnd_result_path=None, + ofm_name=None, + verify_lib_path=None, ): """Check if the result is the same as the expected reference.""" - if not reference_path.is_file(): - print_color(LogColors.RED, "Reference MISSING FILE {}".format(test_name)) - msg = "Missing reference file: {}".format(reference_path) - return (TestResult.MISSING_FILE, 0.0, msg) - if not result_path.is_file(): - print_color(LogColors.RED, "Results MISSING FILE {}".format(test_name)) - msg = "Missing result file: {}".format(result_path) - return (TestResult.MISSING_FILE, 0.0, msg) + if test_desc: + # New compliance method - first get test details + try: + TestDescSchemaValidator().validate_config(test_desc) + except Exception as e: + _print_result(LogColors.RED, f"Test INCORRECT FORMAT {test_name}") + msg = f"Incorrect test format: {e}" + return (TestResult.INCORRECT_FORMAT, 0.0, msg) - try: - test_result = np.load(result_path) - except Exception as e: - print_color(LogColors.RED, "Results INCORRECT FORMAT {}".format(test_name)) - msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format( - result_path, e - ) - return (TestResult.INCORRECT_FORMAT, 0.0, msg) - try: - reference_result = np.load(reference_path) - except Exception as e: - print_color(LogColors.RED, "Reference INCORRECT FORMAT {}".format(test_name)) - msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format( - reference_path, e + if test_name is None: + test_name = "test" + + paths = [imp_result_path, ref_result_path, bnd_result_path] + names = ["Implementation", "Reference", "Bounds"] + arrays = [None, None, None] + + # Check the files exist and are in the right format + for idx, path in enumerate(paths): + name = names[idx] + if path is None and name == "Bounds": + # Bounds can be None - skip it + continue + if not path.is_file(): + _print_result(LogColors.RED, f"{name} MISSING FILE {test_name}") + msg = f"Missing {name} file: {str(path)}" + return (TestResult.MISSING_FILE, 0.0, msg) + try: + arrays[idx] = np.load(path) + except Exception as e: + _print_result(LogColors.RED, f"{name} INCORRECT FORMAT {test_name}") + msg = f"Incorrect numpy format of {str(path)}\nnumpy.load exception: {e}" + return (TestResult.INCORRECT_FORMAT, 0.0, msg) + + if test_desc and "meta" in test_desc and "compliance" in test_desc["meta"]: + # Switch to using the verifier library for full compliance + if ofm_name is None: + ofm_name = test_desc["ofm_name"][0] + if len(test_desc["ofm_name"]) > 1: + _print_result(LogColors.RED, f"Output Name MISSING FILE {test_name}") + msg = "Must specify output name (ofm_name) to check as multiple found in desc.json" + return (TestResult.MISSING_FILE, 0.0, msg) + + compliance_json = test_desc["meta"]["compliance"] + + return compliance_check( + *arrays, + test_name, + compliance_json, + ofm_name, + verify_lib_path, ) - return (TestResult.INCORRECT_FORMAT, 0.0, msg) + + # Else continue with original checking method + test_result, reference_result, _ = arrays # Type comparison if test_result.dtype != reference_result.dtype: - print_color(LogColors.RED, "Results TYPE MISMATCH {}".format(test_name)) + _print_result(LogColors.RED, "Results TYPE MISMATCH {}".format(test_name)) msg = "Mismatch results type: Expected {}, got {}".format( reference_result.dtype, test_result.dtype ) @@ -115,7 +165,7 @@ def test_check( difference = None if np.shape(test_result) != np.shape(reference_result): - print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name)) + _print_result(LogColors.RED, "Results MISCOMPARE {}".format(test_name)) msg = "Shapes mismatch: Reference {} vs {}".format( np.shape(test_result), np.shape(reference_result) ) @@ -139,7 +189,7 @@ def test_check( if reference_result.dtype == np.int32 or reference_result.dtype == np.int64: if np.all(np.absolute(reference_result - test_result) <= quantize_tolerance): - print_color(LogColors.GREEN, "Results PASS {}".format(test_name)) + _print_result(LogColors.GREEN, "Results PASS {}".format(test_name)) return (TestResult.PASS, 0.0, "") else: tolerance = quantize_tolerance + 1 @@ -166,7 +216,7 @@ def test_check( # All boolean values must match, xor will show up differences test = np.array_equal(reference_result, test_result) if np.all(test): - print_color(LogColors.GREEN, "Results PASS {}".format(test_name)) + _print_result(LogColors.GREEN, "Results PASS {}".format(test_name)) return (TestResult.PASS, 0.0, "") msg = "Boolean result does not match" tolerance = 0.0 @@ -177,18 +227,18 @@ def test_check( elif reference_result.dtype == np.float32 or reference_result.dtype == np.float16: tolerance = float_tolerance if np.allclose(reference_result, test_result, atol=tolerance, equal_nan=True): - print_color(LogColors.GREEN, "Results PASS {}".format(test_name)) + _print_result(LogColors.GREEN, "Results PASS {}".format(test_name)) return (TestResult.PASS, tolerance, "") msg = "Float result does not match within tolerance of {}".format(tolerance) difference = reference_result - test_result # Fall-through to below to add failure values else: - print_color(LogColors.RED, "Results UNSUPPORTED TYPE {}".format(test_name)) + _print_result(LogColors.RED, "Results UNSUPPORTED TYPE {}".format(test_name)) msg = "Unsupported results type: {}".format(reference_result.dtype) return (TestResult.MISMATCH, 0.0, msg) # Fall-through for mismatch failure to add values to msg - print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name)) + _print_result(LogColors.RED, "Results MISCOMPARE {}".format(test_name)) np.set_printoptions(threshold=128, edgeitems=2) if difference is not None: @@ -209,18 +259,65 @@ def main(argv=None): """Check that the supplied reference and result files have the same contents.""" parser = argparse.ArgumentParser() parser.add_argument( - "reference_path", type=Path, help="the path to the reference file to test" + "ref_result_path", + type=Path, + help="path to the reference model result file to check", ) parser.add_argument( - "result_path", type=Path, help="the path to the result file to test" + "imp_result_path", + type=Path, + help="path to the implementation result file to check", ) parser.add_argument( "--fp-tolerance", type=float, default=DEFAULT_FP_TOLERANCE, help="FP tolerance" ) + parser.add_argument( + "--test_path", type=Path, help="path to the test that produced the results" + ) + parser.add_argument( + "--bnd-result-path", + type=Path, + help="path to the reference model bounds result file for the dot product compliance check", + ) + parser.add_argument( + "--ofm-name", + type=str, + help="name of the output tensor to check, defaults to the first ofm_name listed in the test", + ) + parser.add_argument( + "--verify-lib-path", + type=Path, + help="path to TOSA verify library", + ) args = parser.parse_args(argv) + if args.test_path: + # Get details from the test path + test_desc_path = args.test_path / "desc.json" + if not args.test_path.is_dir() or not test_desc_path.is_file(): + print(f"Invalid test directory {str(args.test_path)}") + return TestResult.MISSING_FILE + + try: + with test_desc_path.open("r") as fd: + test_desc = json.load(fd) + except Exception as e: + print(f"Invalid test description file {str(test_desc_path)}: {e}") + return TestResult.INCORRECT_FORMAT + test_name = args.test_path.name + else: + test_desc = None + test_name = None + result, tolerance, msg = test_check( - args.reference_path, args.result_path, float_tolerance=args.fp_tolerance + args.ref_result_path, + args.imp_result_path, + float_tolerance=args.fp_tolerance, + test_name=test_name, + test_desc=test_desc, + bnd_result_path=args.bnd_result_path, + ofm_name=args.ofm_name, + verify_lib_path=args.verify_lib_path, ) if result != TestResult.PASS: print(msg) |