Diffstat (limited to 'verif/checker/tosa_result_checker.py')
-rw-r--r-- verif/checker/tosa_result_checker.py | 187
1 file changed, 187 insertions, 0 deletions
diff --git a/verif/checker/tosa_result_checker.py b/verif/checker/tosa_result_checker.py
new file mode 100644
index 0000000..3a15de9
--- /dev/null
+++ b/verif/checker/tosa_result_checker.py
@@ -0,0 +1,187 @@
+"""TOSA result checker script."""
+# Copyright (c) 2020-2022, ARM Limited.
+# SPDX-License-Identifier: Apache-2.0
+import argparse
+import os
+from enum import Enum
+from enum import IntEnum
+from enum import unique
+from pathlib import Path
+
+import numpy as np
+
+##################################
+no_color_printing = False
+
+
+@unique
+class LogColors(Enum):
+ """Shell escape sequence colors for logging."""
+
+ NONE = "\u001b[0m"
+ GREEN = "\u001b[32;1m"
+ RED = "\u001b[31;1m"
+ YELLOW = "\u001b[33;1m"
+ BOLD_WHITE = "\u001b[1m"
+
+
+def print_color(color, msg):
+ """Print color status messages if enabled."""
+ if no_color_printing:
+ print(msg)
+ else:
+ print("{}{}{}".format(color.value, msg, LogColors.NONE.value))
+
+
+@unique
+class TestResult(IntEnum):
+ """Test result values."""
+
+ # Note: PASS must be 0 for command line return success
+ PASS = 0
+ MISSING_FILE = 1
+ INCORRECT_FORMAT = 2
+ MISMATCH = 3
+ INTERNAL_ERROR = 4
+
+
+TestResultErrorStr = [
+ "",
+ "Missing file",
+ "Incorrect format",
+ "Mismatch",
+ "Internal error",
+]
+##################################
+
+
+def test_check(
+    reference, result, test_name="test", quantize_tolerance=0, float_tolerance=1e-3
+):
+ """Check if the result is the same as the expected reference."""
+ if not os.path.isfile(reference):
+ print_color(LogColors.RED, "Reference MISSING FILE {}".format(test_name))
+ msg = "Missing reference file: {}".format(reference)
+ return (TestResult.MISSING_FILE, 0.0, msg)
+ if not os.path.isfile(result):
+ print_color(LogColors.RED, "Results MISSING FILE {}".format(test_name))
+ msg = "Missing result file: {}".format(result)
+ return (TestResult.MISSING_FILE, 0.0, msg)
+
+ try:
+ test_result = np.load(result)
+ except Exception as e:
+ print_color(LogColors.RED, "Results INCORRECT FORMAT {}".format(test_name))
+ msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format(result, e)
+ return (TestResult.INCORRECT_FORMAT, 0.0, msg)
+ try:
+ reference_result = np.load(reference)
+ except Exception as e:
+ print_color(LogColors.RED, "Reference INCORRECT FORMAT {}".format(test_name))
+ msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format(
+ reference, e
+ )
+ return (TestResult.INCORRECT_FORMAT, 0.0, msg)
+
+    # Type comparison
+    if test_result.dtype != reference_result.dtype:
+        print_color(LogColors.RED, "Results TYPE MISMATCH {}".format(test_name))
+        msg = "Mismatch results type: Expected {}, got {}".format(
+            reference_result.dtype, test_result.dtype
+        )
+        return (TestResult.MISMATCH, 0.0, msg)
+
+    # Shape comparison
+    # Size = 1 tensors can be equivalently represented with rank 0 or rank >= 1,
+    # so squeeze both tensors to allow that special case
+    test_result = np.squeeze(test_result)
+    reference_result = np.squeeze(reference_result)
+
+    if np.shape(test_result) != np.shape(reference_result):
+        print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
+        msg = "Shapes mismatch: Reference {} vs result {}".format(
+            np.shape(reference_result), np.shape(test_result)
+        )
+        return (TestResult.MISMATCH, 0.0, msg)
+
+    # For quantized tests, allow +-(quantize_tolerance) error
+    if reference_result.dtype == np.int32 or reference_result.dtype == np.int64:
+
+        if np.all(np.absolute(reference_result - test_result) <= quantize_tolerance):
+            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
+            return (TestResult.PASS, 0.0, "")
+        else:
+            tolerance = quantize_tolerance + 1
+            while not np.all(
+                np.absolute(reference_result - test_result) <= tolerance
+            ):
+                tolerance = tolerance + 1
+                if tolerance > 10:
+                    break
+
+            if tolerance > 10:
+                msg = "Integer result does not match and differs by more than 10"
+            else:
+                msg = (
+                    "Integer result does not match but is within {} difference".format(
+                        tolerance
+                    )
+                )
+            # Fall-through to below to add failure values
+
+    elif reference_result.dtype == bool:
+        assert test_result.dtype == bool
+        # All boolean values must match exactly
+        test = np.array_equal(reference_result, test_result)
+        if test:
+            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
+            return (TestResult.PASS, 0.0, "")
+        msg = "Boolean result does not match"
+        tolerance = 0.0
+        # Fall-through to below to add failure values
+
+    elif reference_result.dtype == np.float32:
+        tolerance = float_tolerance
+        if np.allclose(reference_result, test_result, atol=tolerance, equal_nan=True):
+            print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
+            return (TestResult.PASS, tolerance, "")
+        msg = "Float result does not match within tolerance of {}".format(tolerance)
+        # Fall-through to below to add failure values
+
+    else:
+        print_color(LogColors.RED, "Results UNSUPPORTED TYPE {}".format(test_name))
+        msg = "Unsupported results type: {}".format(reference_result.dtype)
+        return (TestResult.MISMATCH, 0.0, msg)
+
+    # Fall-through for mismatch failure to add values to msg
+    print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
+    np.set_printoptions(threshold=128)
+    msg = "{}\ntest_result: {}\n{}".format(msg, test_result.shape, test_result)
+    msg = "{}\nreference_result: {}\n{}".format(
+        msg, reference_result.shape, reference_result
+    )
+    return (TestResult.MISMATCH, tolerance, msg)
+
+
+def main(argv=None):
+ """Check that the supplied reference and result files are the same."""
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "reference_path", type=Path, help="the path to the reference file to test"
+ )
+ parser.add_argument(
+ "result_path", type=Path, help="the path to the result file to test"
+ )
+ args = parser.parse_args(argv)
+ ref_path = args.reference_path
+ res_path = args.result_path
+
+ result, tolerance, msg = test_check(ref_path, res_path)
+ if result != TestResult.PASS:
+ print(msg)
+
+ return result
+
+
+if __name__ == "__main__":
+    exit(main())
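
Below is a minimal usage sketch of the checker, an illustration rather than part of the commit above. It assumes the module is importable as checker.tosa_result_checker (for example with verif on PYTHONPATH), and the .npy file names are made up for the example.

    import numpy as np

    from checker.tosa_result_checker import TestResult, test_check

    # Write a float32 reference tensor and a result that differs by less than
    # the default float_tolerance of 1e-3 (file names are illustrative only).
    ref = np.linspace(0.0, 1.0, 16, dtype=np.float32).reshape(4, 4)
    np.save("ref.npy", ref)
    np.save("out.npy", ref + np.float32(5e-4))

    status, tolerance, msg = test_check("ref.npy", "out.npy", test_name="example")
    assert status == TestResult.PASS  # float32 path uses np.allclose with atol=1e-3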