author    Jeremy Johnson <jeremy.johnson@arm.com> 2023-09-14 17:02:09 +0100
committer Jeremy Johnson <jeremy.johnson@arm.com> 2023-10-02 12:04:44 +0100
commit    e2b5e87804e158cb3e5d06a131c317b3890b87b3 (patch)
tree      fd8b5a4d56dfcea4be4e6ced73f2d4d5b2e1d92d /verif/checker
parent    bb0935f868a5ab09403cf3628848655b06ac1dec (diff)
download  reference_model-e2b5e87804e158cb3e5d06a131c317b3890b87b3.tar.gz
Support for compliance checking in testing

Updated the conformance generator to not generate tests with results for compliance tests.
Updated the test runner to run the compliance mode versions (precise and abs modes) of the reference model to create test results to check against SUT results.
Updated the reference model to enable abs_mode on the correct desc.json flags.
Updated the test checker to support compliance checking using the verifier library.
Separated color printing from the test checker.

Change-Id: I7e2fbfc6883916caa5d94d4ece122c48bf45f530
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Diffstat (limited to 'verif/checker')
-rw-r--r--  verif/checker/color_print.py          33
-rw-r--r--  verif/checker/tosa_result_checker.py  231
2 files changed, 197 insertions, 67 deletions
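For context, a minimal sketch of how a test runner might drive the new compliance path introduced below. The test directory, result file names, and verify library path are illustrative assumptions, not taken from this change; only test_check and TestResult come from the diff:

    import json
    from pathlib import Path

    from checker.tosa_result_checker import TestResult, test_check

    test_dir = Path("tests/my_op_test")  # hypothetical test directory
    with (test_dir / "desc.json").open("r") as fd:
        test_desc = json.load(fd)

    result, tolerance, msg = test_check(
        test_dir / "ref_result_0.npy",   # reference model result (illustrative name)
        test_dir / "imp_result_0.npy",   # implementation (SUT) result
        test_name=test_dir.name,
        test_desc=test_desc,             # enables schema validation and compliance mode
        bnd_result_path=test_dir / "bnd_result_0.npy",  # abs-mode bounds result
        verify_lib_path=Path("lib/libverify.so"),       # hypothetical library path
    )
    if result != TestResult.PASS:
        print(msg)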
diff --git a/verif/checker/color_print.py b/verif/checker/color_print.py
new file mode 100644
index 0000000..1563b92
--- /dev/null
+++ b/verif/checker/color_print.py
@@ -0,0 +1,33 @@
+"""Color printing module."""
+# Copyright (c) 2020-2023, ARM Limited.
+# SPDX-License-Identifier: Apache-2.0
+from enum import Enum
+from enum import unique
+
+color_printing = True
+
+
+@unique
+class LogColors(Enum):
+ """Shell escape sequence colors for logging."""
+
+ NONE = "\u001b[0m"
+ GREEN = "\u001b[32;1m"
+ RED = "\u001b[31;1m"
+ YELLOW = "\u001b[33;1m"
+ BOLD_WHITE = "\u001b[1m"
+
+
+def set_print_in_color(enabled):
+ """Set color printing to enabled or disabled."""
+ global color_printing
+ color_printing = enabled
+
+
+def print_color(color, msg):
+ """Print color status messages if enabled."""
+ global color_printing
+ if not color_printing:
+ print(msg)
+ else:
+ print("{}{}{}".format(color.value, msg, LogColors.NONE.value))
diff --git a/verif/checker/tosa_result_checker.py b/verif/checker/tosa_result_checker.py
index 1169a95..38ed510 100644
--- a/verif/checker/tosa_result_checker.py
+++ b/verif/checker/tosa_result_checker.py
@@ -1,43 +1,19 @@
"""TOSA result checker script."""
-# Copyright (c) 2020-2022, ARM Limited.
+# Copyright (c) 2020-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import argparse
-from enum import Enum
+import json
from enum import IntEnum
from enum import unique
from pathlib import Path
import numpy as np
+from checker.color_print import LogColors
+from checker.color_print import print_color
+from checker.verifier import VerifierError
+from checker.verifier import VerifierLibrary
from generator.tosa_utils import float32_is_valid_bfloat16
-
-##################################
-color_printing = True
-
-
-@unique
-class LogColors(Enum):
- """Shell escape sequence colors for logging."""
-
- NONE = "\u001b[0m"
- GREEN = "\u001b[32;1m"
- RED = "\u001b[31;1m"
- YELLOW = "\u001b[33;1m"
- BOLD_WHITE = "\u001b[1m"
-
-
-def set_print_in_color(enabled):
- """Set color printing to enabled or disabled."""
- global color_printing
- color_printing = enabled
-
-
-def print_color(color, msg):
- """Print color status messages if enabled."""
- global color_printing
- if not color_printing:
- print(msg)
- else:
- print("{}{}{}".format(color.value, msg, LogColors.NONE.value))
+from schemavalidation.schemavalidation import TestDescSchemaValidator
@unique
@@ -62,46 +38,120 @@ TestResultErrorStr = [
##################################
DEFAULT_FP_TOLERANCE = 1e-3
+result_printing = True
+
+
+def set_print_result(enabled):
+ """Set whether to print out or not."""
+ global result_printing
+ result_printing = enabled
+
+
+def _print_result(color, msg):
+ """Print out result."""
+ global result_printing
+ if result_printing:
+ print_color(color, msg)
+
+
+def compliance_check(
+ imp_result_path,
+ ref_result_path,
+ bnd_result_path,
+ test_name,
+ compliance_config,
+ ofm_name,
+ verify_lib_path,
+):
+ try:
+ vlib = VerifierLibrary(verify_lib_path)
+ except VerifierError as e:
+ _print_result(LogColors.RED, f"INTERNAL ERROR {test_name}")
+ msg = f"Could not load verfier library: {str(e)}"
+ return (TestResult.INTERNAL_ERROR, 0.0, msg)
+
+ success = vlib.verify_data(
+ ofm_name, compliance_config, imp_result_path, ref_result_path, bnd_result_path
+ )
+ if success:
+ _print_result(LogColors.GREEN, f"Results PASS {test_name}")
+ return (TestResult.PASS, 0.0, "")
+ else:
+ _print_result(LogColors.RED, f"Results NON-COMPLIANT {test_name}")
+ return (TestResult.MISMATCH, 0.0, "Non-compliant implementation results found")
def test_check(
- reference_path,
- result_path,
- test_name="test",
+ ref_result_path,
+ imp_result_path,
+ test_name=None,
quantize_tolerance=0,
float_tolerance=DEFAULT_FP_TOLERANCE,
misc_checks=[],
+ test_desc=None,
+ bnd_result_path=None,
+ ofm_name=None,
+ verify_lib_path=None,
):
"""Check if the result is the same as the expected reference."""
- if not reference_path.is_file():
- print_color(LogColors.RED, "Reference MISSING FILE {}".format(test_name))
- msg = "Missing reference file: {}".format(reference_path)
- return (TestResult.MISSING_FILE, 0.0, msg)
- if not result_path.is_file():
- print_color(LogColors.RED, "Results MISSING FILE {}".format(test_name))
- msg = "Missing result file: {}".format(result_path)
- return (TestResult.MISSING_FILE, 0.0, msg)
+ if test_desc:
+ # New compliance method - first get test details
+ try:
+ TestDescSchemaValidator().validate_config(test_desc)
+ except Exception as e:
+ _print_result(LogColors.RED, f"Test INCORRECT FORMAT {test_name}")
+ msg = f"Incorrect test format: {e}"
+ return (TestResult.INCORRECT_FORMAT, 0.0, msg)
- try:
- test_result = np.load(result_path)
- except Exception as e:
- print_color(LogColors.RED, "Results INCORRECT FORMAT {}".format(test_name))
- msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format(
- result_path, e
- )
- return (TestResult.INCORRECT_FORMAT, 0.0, msg)
- try:
- reference_result = np.load(reference_path)
- except Exception as e:
- print_color(LogColors.RED, "Reference INCORRECT FORMAT {}".format(test_name))
- msg = "Incorrect numpy format of {}\nnumpy.load exception: {}".format(
- reference_path, e
+ if test_name is None:
+ test_name = "test"
+
+ paths = [imp_result_path, ref_result_path, bnd_result_path]
+ names = ["Implementation", "Reference", "Bounds"]
+ arrays = [None, None, None]
+
+ # Check the files exist and are in the right format
+ for idx, path in enumerate(paths):
+ name = names[idx]
+ if path is None and name == "Bounds":
+ # Bounds can be None - skip it
+ continue
+ if not path.is_file():
+ _print_result(LogColors.RED, f"{name} MISSING FILE {test_name}")
+ msg = f"Missing {name} file: {str(path)}"
+ return (TestResult.MISSING_FILE, 0.0, msg)
+ try:
+ arrays[idx] = np.load(path)
+ except Exception as e:
+ _print_result(LogColors.RED, f"{name} INCORRECT FORMAT {test_name}")
+ msg = f"Incorrect numpy format of {str(path)}\nnumpy.load exception: {e}"
+ return (TestResult.INCORRECT_FORMAT, 0.0, msg)
+
+ if test_desc and "meta" in test_desc and "compliance" in test_desc["meta"]:
+ # Switch to using the verifier library for full compliance
+ if ofm_name is None:
+ ofm_name = test_desc["ofm_name"][0]
+ if len(test_desc["ofm_name"]) > 1:
+ _print_result(LogColors.RED, f"Output Name MISSING FILE {test_name}")
+ msg = "Must specify output name (ofm_name) to check as multiple found in desc.json"
+ return (TestResult.MISSING_FILE, 0.0, msg)
+
+ compliance_json = test_desc["meta"]["compliance"]
+
+ return compliance_check(
+ *arrays,
+ test_name,
+ compliance_json,
+ ofm_name,
+ verify_lib_path,
)
- return (TestResult.INCORRECT_FORMAT, 0.0, msg)
+
+ # Else continue with original checking method
+ test_result, reference_result, _ = arrays
# Type comparison
if test_result.dtype != reference_result.dtype:
- print_color(LogColors.RED, "Results TYPE MISMATCH {}".format(test_name))
+ _print_result(LogColors.RED, "Results TYPE MISMATCH {}".format(test_name))
msg = "Mismatch results type: Expected {}, got {}".format(
reference_result.dtype, test_result.dtype
)
@@ -115,7 +165,7 @@ def test_check(
difference = None
if np.shape(test_result) != np.shape(reference_result):
- print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
+ _print_result(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
msg = "Shapes mismatch: Reference {} vs {}".format(
np.shape(reference_result), np.shape(test_result)
)
@@ -139,7 +189,7 @@ def test_check(
if reference_result.dtype == np.int32 or reference_result.dtype == np.int64:
if np.all(np.absolute(reference_result - test_result) <= quantize_tolerance):
- print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
+ _print_result(LogColors.GREEN, "Results PASS {}".format(test_name))
return (TestResult.PASS, 0.0, "")
else:
tolerance = quantize_tolerance + 1
@@ -166,7 +216,7 @@ def test_check(
# All boolean values must match exactly
test = np.array_equal(reference_result, test_result)
if np.all(test):
- print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
+ _print_result(LogColors.GREEN, "Results PASS {}".format(test_name))
return (TestResult.PASS, 0.0, "")
msg = "Boolean result does not match"
tolerance = 0.0
@@ -177,18 +227,18 @@ def test_check(
elif reference_result.dtype == np.float32 or reference_result.dtype == np.float16:
tolerance = float_tolerance
if np.allclose(reference_result, test_result, atol=tolerance, equal_nan=True):
- print_color(LogColors.GREEN, "Results PASS {}".format(test_name))
+ _print_result(LogColors.GREEN, "Results PASS {}".format(test_name))
return (TestResult.PASS, tolerance, "")
msg = "Float result does not match within tolerance of {}".format(tolerance)
difference = reference_result - test_result
# Fall-through to below to add failure values
else:
- print_color(LogColors.RED, "Results UNSUPPORTED TYPE {}".format(test_name))
+ _print_result(LogColors.RED, "Results UNSUPPORTED TYPE {}".format(test_name))
msg = "Unsupported results type: {}".format(reference_result.dtype)
return (TestResult.MISMATCH, 0.0, msg)
# Fall-through for mismatch failure to add values to msg
- print_color(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
+ _print_result(LogColors.RED, "Results MISCOMPARE {}".format(test_name))
np.set_printoptions(threshold=128, edgeitems=2)
if difference is not None:
@@ -209,18 +259,65 @@ def main(argv=None):
"""Check that the supplied reference and result files have the same contents."""
parser = argparse.ArgumentParser()
parser.add_argument(
- "reference_path", type=Path, help="the path to the reference file to test"
+ "ref_result_path",
+ type=Path,
+ help="path to the reference model result file to check",
)
parser.add_argument(
- "result_path", type=Path, help="the path to the result file to test"
+ "imp_result_path",
+ type=Path,
+ help="path to the implementation result file to check",
)
parser.add_argument(
"--fp-tolerance", type=float, default=DEFAULT_FP_TOLERANCE, help="FP tolerance"
)
+ parser.add_argument(
+ "--test_path", type=Path, help="path to the test that produced the results"
+ )
+ parser.add_argument(
+ "--bnd-result-path",
+ type=Path,
+ help="path to the reference model bounds result file for the dot product compliance check",
+ )
+ parser.add_argument(
+ "--ofm-name",
+ type=str,
+ help="name of the output tensor to check, defaults to the first ofm_name listed in the test",
+ )
+ parser.add_argument(
+ "--verify-lib-path",
+ type=Path,
+ help="path to TOSA verify library",
+ )
args = parser.parse_args(argv)
+ if args.test_path:
+ # Get details from the test path
+ test_desc_path = args.test_path / "desc.json"
+ if not args.test_path.is_dir() or not test_desc_path.is_file():
+ print(f"Invalid test directory {str(args.test_path)}")
+ return TestResult.MISSING_FILE
+
+ try:
+ with test_desc_path.open("r") as fd:
+ test_desc = json.load(fd)
+ except Exception as e:
+ print(f"Invalid test description file {str(test_desc_path)}: {e}")
+ return TestResult.INCORRECT_FORMAT
+ test_name = args.test_path.name
+ else:
+ test_desc = None
+ test_name = None
+
result, tolerance, msg = test_check(
- args.reference_path, args.result_path, float_tolerance=args.fp_tolerance
+ args.ref_result_path,
+ args.imp_result_path,
+ float_tolerance=args.fp_tolerance,
+ test_name=test_name,
+ test_desc=test_desc,
+ bnd_result_path=args.bnd_result_path,
+ ofm_name=args.ofm_name,
+ verify_lib_path=args.verify_lib_path,
)
if result != TestResult.PASS:
print(msg)
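Assuming the script is runnable as a module with the main() entry point shown above (and run from a directory where the checker package is importable), a command line exercising the new options might look like the following; all paths and the ofm name are illustrative:

    python -m checker.tosa_result_checker \
        ref_result_0.npy imp_result_0.npy \
        --test_path tests/my_op_test \
        --bnd-result-path bnd_result_0.npy \
        --ofm-name result-0 \
        --verify-lib-path lib/libverify.so \
        --fp-tolerance 1e-3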