diff options
Diffstat (limited to 'verif/checker/tosa_result_checker.py')
-rw-r--r-- | verif/checker/tosa_result_checker.py | 187 |
1 files changed, 187 insertions, 0 deletions
diff --git a/verif/checker/tosa_result_checker.py b/verif/checker/tosa_result_checker.py new file mode 100644 index 0000000..3a15de9 --- /dev/null +++ b/verif/checker/tosa_result_checker.py @@ -0,0 +1,187 @@ +"""TOSA result checker script.""" +# Copyright (c) 2020-2022, ARM Limited. +# SPDX-License-Identifier: Apache-2.0 +import argparse +import os +from enum import Enum +from enum import IntEnum +from enum import unique +from pathlib import Path + +import numpy as np + +################################## +no_color_printing = False + + +@unique +class LogColors(Enum): + """Shell escape sequence colors for logging.""" + + NONE = "\u001b[0m" + GREEN = "\u001b[32;1m" + RED = "\u001b[31;1m" + YELLOW = "\u001b[33;1m" + BOLD_WHITE = "\u001b[1m" + + +def print_color(color, msg): + """Print color status messages if enabled.""" + if no_color_printing: + print(msg) + else: + print("{}{}{}".format(color.value, msg, LogColors.NONE.value)) + + +@unique +class TestResult(IntEnum): + """Test result values.""" + + # Note: PASS must be 0 for command line return success + PASS = 0 + MISSING_FILE = 1 + INCORRECT_FORMAT = 2 + MISMATCH = 3 + INTERNAL_ERROR = 4 + + +TestResultErrorStr = [ + "", + "Missing file", + "Incorrect format", + "Mismatch", + "Internal error", +] +################################## + + +def test_check( + reference, result, test_name="test", quantize_tolerance=0, float_tolerance=1e-3 +): + """Check if the result is the same as the expected reference.""" + if not os.path.isfile(reference): + print_color(LogColors.RED, "Reference MISSING FILE {}".format(test_name)) + msg = "Missing reference file: {}".format(reference) + return (TestResult.MISSING_FILE, 0.0, msg) + if not os.path.isfile(result): + print_color(LogColors.RED, "Results MISSING FILE {}".format(test_name)) + msg = "Missing result file: {}".format(result) + return (TestResult.MISSING_FILE, 0.0, msg) + + try: + test_result = np.load(result) + except Exception as e: + print_color(LogColors.RED, "Results INCORRECT FORMAT {}".format(test_name)) + msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format(result, e) + return (TestResult.INCORRECT_FORMAT, 0.0, msg) + try: + reference_result = np.load(reference) + except Exception as e: + print_color(LogColors.RED, "Reference INCORRECT FORMAT {}".format(test_name)) + msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format( + reference, e + ) + return (TestResult.INCORRECT_FORMAT, 0.0, msg) + + # Type comparison + if test_result.dtype != reference_result.dtype: + print_color(LogColors.RED, "Results TYPE MISMATCH {}".format(test_name)) + msg = "Mismatch results type: Expected {}, got {}".format( + reference_result.dtype, test_result.dtype + ) + return (TestResult.MISMATCH, 0.0, msg) + + # Size comparison + # Size = 1 tensors can be equivalently represented as having rank 0 or rank + # >= 0, allow that special case + test_result = np.squeeze(test_result) + reference_result = np.squeeze(reference_result) + + if np.shape(test_result) != np.shape(reference_result): + print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name)) + msg = "Shapes mismatch: Reference {} vs {}".format( + np.shape(test_result), np.shape(reference_result) + ) + return (TestResult.MISMATCH, 0.0, msg) + + # for quantized test, allow +-(quantize_tolerance) error + if reference_result.dtype == np.int32 or reference_result.dtype == np.int64: + + if np.all(np.absolute(reference_result - test_result) <= quantize_tolerance): + print_color(LogColors.GREEN, "Results PASS {}".format(test_name)) + return (TestResult.PASS, 0.0, "") + else: + tolerance = quantize_tolerance + 1 + while not np.all( + np.absolute(reference_result - test_result) <= quantize_tolerance + ): + tolerance = tolerance + 1 + if tolerance > 10: + break + + if tolerance > 10: + msg = "Integer result does not match and is greater than 10 difference" + else: + msg = ( + "Integer result does not match but is within {} difference".format( + tolerance + ) + ) + # Fall-through to below to add failure values + + elif reference_result.dtype == bool: + assert test_result.dtype == bool + # All boolean values must match, xor will show up differences + test = np.array_equal(reference_result, test_result) + if np.all(test): + print_color(LogColors.GREEN, "Results PASS {}".format(test_name)) + return (TestResult.PASS, 0.0, "") + msg = "Boolean result does not match" + tolerance = 0.0 + # Fall-through to below to add failure values + + elif reference_result.dtype == np.float32: + tolerance = float_tolerance + if np.allclose(reference_result, test_result, atol=tolerance, equal_nan=True): + print_color(LogColors.GREEN, "Results PASS {}".format(test_name)) + return (TestResult.PASS, tolerance, "") + msg = "Float result does not match within tolerance of {}".format(tolerance) + # Fall-through to below to add failure values + + else: + print_color(LogColors.RED, "Results UNSUPPORTED TYPE {}".format(test_name)) + msg = "Unsupported results type: {}".format(reference_result.dtype) + return (TestResult.MISMATCH, 0.0, msg) + + # Fall-through for mismatch failure to add values to msg + print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name)) + np.set_printoptions(threshold=128) + msg = "{}\ntest_result: {}\n{}".format(msg, test_result.shape, test_result) + msg = "{}\nreference_result: {}\n{}".format( + msg, reference_result.shape, reference_result + ) + return (TestResult.MISMATCH, tolerance, msg) + + +def main(argv=None): + """Check that the supplied reference and result files are the same.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "reference_path", type=Path, help="the path to the reference file to test" + ) + parser.add_argument( + "result_path", type=Path, help="the path to the result file to test" + ) + args = parser.parse_args(argv) + ref_path = args.reference_path + res_path = args.result_path + + result, tolerance, msg = test_check(ref_path, res_path) + if result != TestResult.PASS: + print(msg) + + return result + + +if __name__ == "__main__": + exit(main()) |