author     Jeremy Johnson <jeremy.johnson@arm.com>    2024-02-13 18:25:39 +0000
committer  Eric Kunze <eric.kunze@arm.com>            2024-03-12 15:31:44 +0000
commit     af09018205f476ab12e3ccfc25523f3f939a2aa3 (patch)
tree       777ab4702b011abc48d99c0108e6c9510bf1893b
parent     80fd9b8bf8d6def0a4ce6a3c59bdc598fecbd1d1 (diff)
download   reference_model-af09018205f476ab12e3ccfc25523f3f939a2aa3.tar.gz
Improved test selection before test generation
Add test list output to tosa_verif_build_tests and test list capture to
file for tosa_verif_conformance_generator.
Improve PAD & CONV2D test coverage for tosa-mi conformance.
Change to use logging for output to hide info from test lists.
Tweak verbosity levels of tosa_verif_conformance_generator.

Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: Ic29da5776b02e9ac610db6ee89d0ebfb4994e055
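The new options referenced above can be sketched as follows. This is illustrative only, written as Python argv lists in the same style that build_op_tests assembles its commands in this change; the filter, paths and criteria name are placeholder values, not taken from the patch.

    # Hypothetical invocation of tosa_verif_build_tests: select tests from the
    # conformance config and print the resulting test list instead of generating.
    build_tests_argv = [
        "--filter", "^conv2d",
        "--test-selection-config", "verif/conformance/tosa_main_profile_ops_info.json",
        "--test-selection-criteria", "default",
        "--test-type", "positive",
        "--list-tests",
    ]

    # Hypothetical invocation of tosa_verif_conformance_generator: capture the
    # generated test list to a file instead of building the tests.
    conformance_argv = ["--list-tests-to-file", "conv2d_tests.txt"]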
-rw-r--r--  verif/conformance/README.md                             |  15
-rw-r--r--  verif/conformance/tosa_main_profile_ops_info.json       |  86
-rw-r--r--  verif/conformance/tosa_verif_conformance_generator.py   | 274
-rw-r--r--  verif/generator/tosa_arg_gen.py                         |  27
-rw-r--r--  verif/generator/tosa_error_if.py                        |  10
-rw-r--r--  verif/generator/tosa_test_gen.py                        |  22
-rw-r--r--  verif/generator/tosa_test_select.py                     | 348
-rw-r--r--  verif/generator/tosa_verif_build_tests.py               | 135
8 files changed, 707 insertions(+), 210 deletions(-)
diff --git a/verif/conformance/README.md b/verif/conformance/README.md
index 0869ab6..5bededd 100644
--- a/verif/conformance/README.md
+++ b/verif/conformance/README.md
@@ -17,8 +17,10 @@ Each operator entry contains:
* "group" - name of the group this operator is in, in the spec
* "profile" - list of profiles that this operator covers
-* "support_for" - optional list of supported creation modes out of: lazy_data_gen (data generation just before test run)
* "gen_filter" - optional filter string for op to give to tosa_verif_build_tests - defaults to "^opname$"
+* "support_for" - optional list of supported creation modes out of:
+ * lazy_data_gen - data generation just before test run
+ * generator_select - use generator selector instead of conformance test_select
* "generation" - dictionary of test generation details - see below
* "selection" - dictionary of test selection details - see below
@@ -42,7 +44,16 @@ Each selection criteria is a dictionary that contains:
* "all": "true" - to select all tests (and not use test_select)
-or (more information for each entry in `test_select.py`):
+or for operators that have "support_for" "generator_select":
+
+* "permutes" - optional list of parameters whose values are to be permuted, the default is ["rank", "dtype"]
+* "maximum" - optional number - at most "maximum" tests (not including specific tests) will be captured per permuted "permutes" value, effects "full_params" as well
+* "full_params" - optional list of parameter names used to select tests covering a full range of values for these params up to "maximum"
+* "specifics" - optional dictionary of params with lists of values, tests that meet any of these "specifics" will be selected and kept (even using "post_sparsity")
+* "groups" - optional list of parameters that should be considered as a grouping of tests and treated as one test for "sparsity" and "specifics"
+* "num_errorifs" - optional value of error_if tests to keep per error_if case, the default is 1
+
+or for other operators it defaults to the old test select (more information for each entry in `test_select.py`):
* "params" - optional dictionary with mappings of parameter names to the values to select
* "permutes" - optional list of parameter names to be permuted
diff --git a/verif/conformance/tosa_main_profile_ops_info.json b/verif/conformance/tosa_main_profile_ops_info.json
index 202e7b7..63a2a9c 100644
--- a/verif/conformance/tosa_main_profile_ops_info.json
+++ b/verif/conformance/tosa_main_profile_ops_info.json
@@ -723,7 +723,7 @@
"profile": [
"tosa-mi"
],
- "support_for": [ "lazy_data_gen" ],
+ "support_for": [ "lazy_data_gen", "generator_select" ],
"gen_filter": "^conv2d",
"generation": {
"standard": {
@@ -784,31 +784,11 @@
},
"selection": {
"default": {
- "params": {
- "shape": [],
- "type": [],
- "kernel": [],
- "stride": [],
- "pad": [],
- "dilation": []
- },
- "permutes": [
- "kernel",
- "shape",
- "type",
- "pad"
- ],
- "preselected": [
- {
- "shape": "1x34x19x27",
- "type": "f32xf32",
- "kernel": "3x1",
- "pad": "pad0000",
- "accum_type": "accf32",
- "stride": "st11",
- "dilation": "dilat11"
- }
- ]
+ "permutes": [ "rank", "dtype", "kernel", "acc_type" ],
+ "full_params": [ "stride", "dilation" ],
+ "specifics": { "pad": [ "(0, 0, 0, 0)" ] },
+ "groups": [ "s" ],
+ "maximum": 3
}
}
},
@@ -1901,7 +1881,7 @@
"profile": [
"tosa-mi"
],
- "support_for": [ "lazy_data_gen" ],
+ "support_for": [ "lazy_data_gen", "generator_select" ],
"generation": {
"standard": {
"generator_args": [
@@ -1925,23 +1905,21 @@
],
[
"--target-dtype",
+ "fp32",
+ "--target-dtype",
"fp16",
- "--fp-values-range",
- "-max,max",
- "--tensor-dim-range",
- "1,17",
- "--target-rank",
- "4"
- ],
- [
"--target-dtype",
"bf16",
"--fp-values-range",
"-max,max",
"--tensor-dim-range",
- "1,16",
+ "1,11",
"--target-rank",
- "5"
+ "4",
+ "--target-rank",
+ "5",
+ "--target-rank",
+ "6"
],
[
"--target-dtype",
@@ -1980,30 +1958,16 @@
},
"selection": {
"default": {
- "params": {},
- "permutes": [
- "shape",
- "type"
- ],
- "preselected": [
- {
- "shape": "50",
- "type": "bf16",
- "pad": "pad11"
- },
- {
- "shape": "63x46",
- "type": "bf16",
- "pad": "pad1010"
- },
- {
- "shape": "6",
- "type": "f16",
- "pad": "pad01"
- }
- ],
- "sparsity": {
- "pad": 21
+ "maximum": 5,
+ "specifics": {
+ "pad": [
+ "[[0 0]]",
+ "[[0 0], [0 0]]",
+ "[[0 0], [0 0], [0 0]]",
+ "[[0 0], [0 0], [0 0], [0 0]]",
+ "[[0 0], [0 0], [0 0], [0 0], [0 0]]",
+ "[[0 0], [0 0], [0 0], [0 0], [0 0], [0 0]]"
+ ]
}
}
}
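The "pad" strings in the "specifics" list above match how Test.getArg() in the new tosa_test_select.py stringifies an argument: the value is converted to a string and its lines joined with commas. A small sketch, assuming the padding argument is held as a numpy array:

    import numpy as np

    pad = np.array([[0, 0], [0, 0]])
    # Mirrors Test.getArg(): stringify the value and join its lines with ","
    print(",".join(str(pad).splitlines()))  # -> "[[0 0], [0 0]]"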
diff --git a/verif/conformance/tosa_verif_conformance_generator.py b/verif/conformance/tosa_verif_conformance_generator.py
index 97aba13..7c82f31 100644
--- a/verif/conformance/tosa_verif_conformance_generator.py
+++ b/verif/conformance/tosa_verif_conformance_generator.py
@@ -70,9 +70,12 @@ class GenConformanceError(Exception):
def _run_sh_command(args, cwd, full_cmd):
"""Run an external command and capture stdout/stderr."""
# Quote the command line for printing
- full_cmd_esc = [shlex.quote(x) for x in full_cmd]
+ try:
+ full_cmd_esc = [shlex.quote(x) for x in full_cmd]
+ except Exception as e:
+ raise Exception(f"Error quoting command: {e}")
if args.capture_output:
- logger.debug(f"Command: {full_cmd_esc}")
+ logger.info(f"Command: {full_cmd_esc}")
rc = subprocess.run(
full_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=cwd
@@ -80,7 +83,7 @@ def _run_sh_command(args, cwd, full_cmd):
if args.capture_output:
stdout = rc.stdout.decode("utf-8")
- logger.debug(f"stdout: \n{stdout}")
+ logger.info(f"stdout: \n{stdout}")
if rc.returncode != 0:
raise Exception(
@@ -101,6 +104,7 @@ def build_op_tests(
gen_neg_dim_range,
supports=[],
gen_filter=None,
+ selector_info=None,
):
"""Build tests for a given operator.
@@ -126,9 +130,31 @@ def build_op_tests(
str(args.random_seed),
]
+ if args.verbosity:
+ build_cmd_base.append("-" + ("v" * args.verbosity))
+
+ if args.tests_list_file is not None:
+ build_cmd_base.append("--list-tests")
+
if "lazy_data_gen" in supports and args.lazy_data_generation:
build_cmd_base.append("--lazy-data-generation")
+ if "generator_select" in supports:
+ if selector_info is None:
+ logger.error(
+ "build_op_tests error: generator_select mode without selector information"
+ )
+ raise (GenConformanceError())
+ selector_config, selector_name = selector_info
+ build_cmd_base.extend(
+ [
+ "--test-selection-config",
+ str(selector_config),
+ "--test-selection-criteria",
+ selector_name,
+ ]
+ )
+
build_cmds_list = []
if test_type in ["positive", "both"]:
@@ -159,14 +185,19 @@ def build_op_tests(
build_cmd_neg_test.extend(target_dtypes_args)
build_cmds_list.append(build_cmd_neg_test)
- logger.debug(f"Creating {operator} tests with {len(build_cmds_list)} parameter(s)")
+ logger.info(f"Creating {operator} tests in {len(build_cmds_list)} batch(es)")
error = False
for i, cmd in enumerate(build_cmds_list):
try:
- _run_sh_command(args, args.ref_model_path.parent, cmd)
+ raw_stdout, _ = _run_sh_command(args, args.ref_model_path.parent, cmd)
logger.info(
f"{operator} test batch {(i+1)}/{len(build_cmds_list)} created successfully"
)
+
+ if args.tests_list_file is not None:
+ with args.tests_list_file.open("a") as fd:
+ fd.write(raw_stdout.decode("utf-8"))
+
except Exception as e:
logger.error(
f"{operator} test batch {(i+1)}/{len(build_cmds_list)} unsuccessful, skipping"
@@ -179,20 +210,20 @@ def build_op_tests(
return op_build_dir
-def _check_to_include_test(profile, test_name, exclude_negative_tests=False):
- """Check test name for exclusions, return False to indicate excluded."""
- excludes = ["ERRORIF"] if exclude_negative_tests else []
+def _check_to_include_test(test_type, test_name):
+ """Check test name for inclusion based on test_type, returns True to include."""
- for exclusion in excludes:
- if f"_{exclusion}_" in test_name:
- return False
- return True
+ if test_type == "both":
+ return True
+ else:
+ error_test = "_ERRORIF_" in test_name
+ return (error_test and test_type == "negative") or (
+ not error_test and test_type == "positive"
+ )
-def _get_all_tests_list(
- profile, test_root_dir, operator, exclude_negative_tests=False, include_all=False
-):
- """Create test list based on tests in the test_dir."""
+def _get_all_tests_list(test_type, test_root_dir, operator):
+ """Create test list from tests in the test_dir based on chosen type."""
test_dir = test_root_dir / operator
if not test_dir.is_dir():
# Tests are split into multiple dirs, for example: conv2d_1x1, conv2d_3x3
@@ -209,8 +240,7 @@ def _get_all_tests_list(
[
test
for test in tdir.glob("*")
- if include_all
- or _check_to_include_test(profile, test.name, exclude_negative_tests)
+ if _check_to_include_test(test_type, test.name)
]
)
return tests
@@ -240,16 +270,16 @@ def generate_results(args, profile, operator, op_build_dir, supports=[], tests=N
if not tests:
# Do not need to run ERRORIF tests as they don't have result files
- tests = _get_all_tests_list(
- profile, op_build_dir, operator, exclude_negative_tests=True
- )
+ tests = _get_all_tests_list("positive", op_build_dir, operator)
+ skipped = 0
for test in tests:
desc = test / "desc.json"
with desc.open("r") as fd:
test_desc = json.load(fd)
if "meta" in test_desc and "compliance" in test_desc["meta"]:
- logger.info(
+ skipped += 1
+ logger.debug(
f"Skipping generating results for new compliance test - {str(test)}"
)
continue
@@ -257,6 +287,9 @@ def generate_results(args, profile, operator, op_build_dir, supports=[], tests=N
ref_cmd.append(str(test.absolute()))
ref_cmds.append(ref_cmd)
+ if skipped:
+ logger.info(f"{skipped} new compliance tests skipped for results generation")
+
fail_string = "UNEXPECTED_FAILURE"
failed_counter = 0
@@ -272,7 +305,7 @@ def generate_results(args, profile, operator, op_build_dir, supports=[], tests=N
logger.error(f"Test {i+1}/{len(ref_cmds)}: {ref_cmds[i][-1]} failed.")
failed_counter += 1
else:
- logger.info(f"Test {i+1}/{len(ref_cmds)}: {ref_cmds[i][-1]} passed.")
+ logger.debug(f"Test {i+1}/{len(ref_cmds)}: {ref_cmds[i][-1]} passed.")
logger.info(f"{len(ref_cmds)-failed_counter}/{len(ref_cmds)} tests passed")
logger.info("Ran tests on model and saved results of passing tests")
@@ -280,6 +313,7 @@ def generate_results(args, profile, operator, op_build_dir, supports=[], tests=N
def convert_tests(
args,
+ test_type,
profile,
operator,
op_build_dir,
@@ -315,13 +349,13 @@ def convert_tests(
c2c_args_list = []
if not tests:
- tests = _get_all_tests_list(profile, op_build_dir, operator)
- logger.info(f"Converting all {profile} profile tests")
+ tests = _get_all_tests_list(test_type, op_build_dir, operator)
+ logger.info(f"Converting all {profile} profile tests of type {test_type}")
# Controls if we copy the tests in their operator sub-directory or not
output_dir_relative_pos = -1 if trim_op_subdir else -2
for test in tests:
- logger.info(f"Test chosen: {test}")
+ logger.debug(f"Test chosen: {test}")
c2c_args = c2c_args_base.copy()
full_output_directory = output_dir / test.relative_to(
*test.parts[:output_dir_relative_pos]
@@ -350,7 +384,7 @@ def convert_tests(
)
failed_counter += 1
else:
- logger.info(
+ logger.debug(
f"test {i+1}/{len(c2c_args_list)}: {c2c_args_list[i][-1]} converted"
)
logger.info(
@@ -394,7 +428,8 @@ def check_op_tests(args, profile, operator, output_dir):
"""Move test folders than contain files larger than 30MB to new directory."""
destination_dir = str(args.output_dir) + "_large_files"
- tests = _get_all_tests_list(profile, output_dir, operator, include_all=True)
+ # Include all tests - both positive and negative
+ tests = _get_all_tests_list("both", output_dir, operator)
if not tests:
logger.error(
f"Couldn't find any tests to size check for {operator} in {output_dir}"
@@ -617,6 +652,12 @@ def parse_args(argv=None):
help="Converts all tests instead of those picked by test_select",
)
parser.add_argument(
+ "--list-tests-to-file",
+ dest="tests_list_file",
+ type=Path,
+ help="Lists out the tests to be generated to a file instead of generating them",
+ )
+ parser.add_argument(
"--keep-large-files",
action="store_true",
help="Keeps tests that contain files larger than 30MB in output directory",
@@ -642,45 +683,6 @@ def parse_args(argv=None):
)
args = parser.parse_args(argv)
- return args
-
-
-def in_version(test_version, gen_dict):
- """Check if the selected test_version is compatible with the tests."""
-
- def version_string_to_numbers(verstr):
- # Turn the "vM.mm.pp" string into Major, Minor, Patch versions
- if verstr == TEST_VERSION_LATEST:
- return (TOSA_VERSION[0], TOSA_VERSION[1], TOSA_VERSION[2])
- else:
- match = re.match(REGEX_VERSION, verstr)
- if match is None:
- raise KeyError(f"Invalid version string {verstr}")
- return (int(v) for v in match.groups())
-
- if "from_version" in gen_dict:
- selected_version = version_string_to_numbers(test_version)
- from_version = version_string_to_numbers(gen_dict["from_version"])
-
- # Check the Major version is compatible, then Minor, and lastly Patch
- # Unless the versions match, we can exit early due to obvious precedence
- for sel, fro in zip(selected_version, from_version):
- if sel < fro:
- # From version is later than selected version
- return False
- elif sel > fro:
- # From version is earlier than selected version
- return True
- # If we get here, the version numbers match exactly
- return True
- else:
- # No specific version info
- return True
-
-
-def main():
- args = parse_args()
-
if args.ref_model_dir is not None:
# Assume the ref model exe path based on the ref model directory
args.ref_model_path = cmf.find_tosa_file(
@@ -690,7 +692,7 @@ def main():
logger.error(
f"Missing reference model binary (--ref-model-path): {args.ref_model_path}"
)
- return 2
+ return None
args.ref_model_path = args.ref_model_path.absolute()
if args.generate_lib_path is None:
@@ -701,7 +703,7 @@ def main():
logger.error(
f"Missing TOSA generate data library (--generate-lib-path): {args.generate_lib_path}"
)
- return 2
+ return None
args.generate_lib_path = args.generate_lib_path.absolute()
if args.schema_path is None:
@@ -712,7 +714,7 @@ def main():
logger.error(
f"Missing reference model schema (--schema-path): {args.schema_path}"
)
- return 2
+ return None
args.schema_path = args.schema_path.absolute()
if args.flatc_path is None:
@@ -721,9 +723,11 @@ def main():
)
if not args.flatc_path.is_file():
logger.error(f"Missing flatc binary (--flatc-path): {args.flatc_path}")
- return 2
+ return None
args.flatc_path = args.flatc_path.absolute()
+ args.param_json_dir = args.param_json_dir.absolute()
+
if args.unit_tests in ["framework", "both"]:
logger.warning(
"DEPRECATION - Framework tests are not part of TOSA conformance testing"
@@ -732,17 +736,65 @@ def main():
logger.error(
"Need to supply location of Framework flatbuffers schema via --framework-schema"
)
- return 2
+ return None
if not args.framework_tests_dir.is_dir():
logger.error(
f"Missing or invalid framework tests directory: {args.framework_tests_dir}"
)
- return 2
+ return None
+
+ return args
+
+
+def in_version(test_version, gen_dict):
+ """Check if the selected test_version is compatible with the tests."""
+
+ def version_string_to_numbers(verstr):
+ # Turn the "vM.mm.pp" string into Major, Minor, Patch versions
+ if verstr == TEST_VERSION_LATEST:
+ return (TOSA_VERSION[0], TOSA_VERSION[1], TOSA_VERSION[2])
+ else:
+ match = re.match(REGEX_VERSION, verstr)
+ if match is None:
+ raise KeyError(f"Invalid version string {verstr}")
+ return (int(v) for v in match.groups())
+
+ if "from_version" in gen_dict:
+ selected_version = version_string_to_numbers(test_version)
+ from_version = version_string_to_numbers(gen_dict["from_version"])
+
+ # Check the Major version is compatible, then Minor, and lastly Patch
+ # Unless the versions match, we can exit early due to obvious precedence
+ for sel, fro in zip(selected_version, from_version):
+ if sel < fro:
+ # From version is later than selected version
+ return False
+ elif sel > fro:
+ # From version is earlier than selected version
+ return True
+ # If we get here, the version numbers match exactly
+ return True
+ else:
+ # No specific version info
+ return True
+
+def _get_log_level(verbosity):
loglevels = (logging.WARNING, logging.INFO, logging.DEBUG)
- loglevel = loglevels[min(args.verbosity, len(loglevels) - 1)]
+ verbosity = max(verbosity, 0)
+ return loglevels[min(verbosity, len(loglevels) - 1)]
+
+
+def main():
+ args = parse_args()
+ if args is None:
+ # Argument processing error
+ return 2
+
+ loglevel = _get_log_level(args.verbosity)
logger.setLevel(loglevel)
- # Set other loggers the same
+ # Set other loggers to a quieter level
+ loglevel = _get_log_level(args.verbosity - 1)
logging.getLogger("test_select").setLevel(loglevel)
logging.getLogger("convert2conformance").setLevel(loglevel)
@@ -757,6 +809,11 @@ def main():
logger.debug(f"Creating build directory: {args.build_dir}")
args.build_dir.mkdir(parents=True, exist_ok=True)
+ if args.tests_list_file is not None:
+ # Try creating tests list file
+ with args.tests_list_file.open("w") as fd:
+ fd.write("")
+
# TODO: For tosa-mi should really generate tosa-bi profile as well
# - for now leave it as subset instead of as superset (for testing)
if args.profile == PROFILES_ALL:
@@ -822,6 +879,7 @@ def main():
)
convert_tests(
args,
+ "positive",
profile,
op,
framework_test_dir,
@@ -846,6 +904,7 @@ def main():
f"Couldn't load operator test params - {test_params_file}: {e}"
)
return 1
+ logger.debug(f"Using config file: {str(test_params_file)}")
operators = args.operators
if not operators:
@@ -913,23 +972,6 @@ def main():
else None
)
- ignore_missing = gen_name != STANDARD_GENERATOR_GROUP
- tags = (
- [gen_name] if gen_name != STANDARD_GENERATOR_GROUP else None
- )
-
- op_build_dir = build_op_tests(
- args,
- test_type,
- profile,
- op,
- gen_name,
- gen_dict["generator_args"],
- gen_neg_dim_range,
- supports=supports,
- gen_filter=gen_filter,
- )
-
# Work out which selection criteria we are using
if "selector" in gen_dict:
selector_name = gen_dict["selector"]
@@ -946,19 +988,38 @@ def main():
)
raise (GenConformanceError())
- # Selection criteria
- selection_config = test_params[op]["selection"][selector_name]
+ op_build_dir = build_op_tests(
+ args,
+ test_type,
+ profile,
+ op,
+ gen_name,
+ gen_dict["generator_args"],
+ gen_neg_dim_range,
+ supports=supports,
+ gen_filter=gen_filter,
+ selector_info=(test_params_file, selector_name),
+ )
+
+ if args.tests_list_file is not None:
+ logger.info("Tests list file extended")
+ continue
- if args.convert_all_tests:
- logger.debug(f"Running and converting all {op} tests")
- generate_results(
- args, profile, op, op_build_dir, supports=supports
- )
+ if args.convert_all_tests or "generator_select" in supports:
+ if test_type in ["positive", "both"]:
+ logger.info(f"Running and converting all {op} tests")
+ generate_results(
+ args, profile, op, op_build_dir, supports=supports
+ )
operator_test_list = None
else:
- logger.debug(
+ logger.info(
f"Running and converting selection of {op} tests"
)
+ # Selection criteria
+ selection_config = test_params[op]["selection"][
+ selector_name
+ ]
if test_type in ["positive", "both"]:
if (
"all" in selection_config
@@ -967,13 +1028,16 @@ def main():
# Just get all the positive tests
tests_gen, tests_gen2 = tee(
_get_all_tests_list(
- profile,
+ "positive",
op_build_dir,
op,
- exclude_negative_tests=True,
)
)
else:
+ ignore_missing = (
+ gen_name != STANDARD_GENERATOR_GROUP
+ )
+
# Get a selection of positive tests
tests_gen, tests_gen2 = tee(
get_op_tests_selection(
@@ -1007,8 +1071,14 @@ def main():
negative=True,
)
)
+
+ tags = (
+ [gen_name] if gen_name != STANDARD_GENERATOR_GROUP else None
+ )
+
output_dir = convert_tests(
args,
+ test_type,
profile,
op,
op_build_dir,
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index c596645..253e8ee 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1,8 +1,8 @@
# Copyright (c) 2021-2024, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import itertools
+import logging
import math
-import warnings
import generator.tosa_utils as gtu
import numpy as np
@@ -16,6 +16,9 @@ from tosa.ResizeMode import ResizeMode
# DTypeNames, DType, Op and ResizeMode are convenience variables to the
# flatc-generated types that should be enums, but aren't
+logging.basicConfig()
+logger = logging.getLogger("tosa_verif_build_tests")
+
class TosaQuantGen:
"""QuantizedInfo random generator helper functions.
@@ -131,8 +134,9 @@ class TosaQuantGen:
shift = shift + 1
shift = (-shift) + scaleBits
- # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(
- # scaleFp, scaleBits, m, multiplier, shift))
+ logger.debug(
+ f"computeMultiplierAndShift: scalefp={scaleFp} scaleBits={scaleBits} m={m} mult={multiplier} shift={shift}"
+ )
# Adjust multiplier such that shift is in allowed value range.
if shift == 0:
@@ -690,8 +694,9 @@ class TosaTensorValuesGen:
# Invalid data range from low to high created due to user
# constraints revert to using internal ranges as they are
# known to work
- msg = f"Using safe data range ({low_val} to {high_val}) instead of supplied ({type_range[0]} to {type_range[1]})"
- warnings.warn(msg)
+ logger.info(
+ f"Using safe data range ({low_val} to {high_val}) instead of supplied ({type_range[0]} to {type_range[1]})"
+ )
data_range = (low_val, high_val)
return data_range
return None
@@ -1856,7 +1861,7 @@ class TosaArgGen:
if "shape" in args_dict
else ""
)
- print(
+ logger.info(
f"Skipping {opName}{shape_info} dot product test as too few calculations {dot_products} < {testGen.TOSA_MI_DOT_PRODUCT_MIN}"
)
continue
@@ -2503,7 +2508,7 @@ class TosaArgGen:
arg_list.append((name, args_dict))
if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
- warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}")
+ logger.info(f"No ErrorIf test created for input shape: {shapeList[0]}")
arg_list = TosaArgGen._add_data_generators(
testGen,
@@ -2683,7 +2688,9 @@ class TosaArgGen:
remainder_w = partial_w % s[1]
output_h = partial_h // s[0] + 1
output_w = partial_w // s[1] + 1
- # debug print(shape, remainder_h, remainder_w, "/", output_h, output_w)
+ logger.debug(
+ f"agPooling: {shape} remainder=({remainder_h}, {remainder_w}) output=({output_h}, {output_w})"
+ )
if (
# the parameters must produce integer exact output
error_name != ErrorIf.PoolingOutputShapeNonInteger
@@ -2920,7 +2927,9 @@ class TosaArgGen:
# Cap the scaling at 2^15 - 1 for scale16
scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), 32767.0)
- # print('{} {} -> {}'.format(out_type_width, in_type_width, scale_arr))
+ logger.debug(
+ f"agRescale: {out_type_width} {in_type_width} -> {scale_arr}"
+ )
multiplier_arr = np.int32(np.zeros(shape=[nc]))
shift_arr = np.int32(np.zeros(shape=[nc]))
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index 7a4d0d6..3972edd 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -1,5 +1,6 @@
# Copyright (c) 2021-2024, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
+import logging
import math
import numpy as np
@@ -11,6 +12,9 @@ from tosa.DType import DType
from tosa.Op import Op
from tosa.ResizeMode import ResizeMode
+logging.basicConfig()
+logger = logging.getLogger("tosa_verif_build_tests")
+
class ErrorIf(object):
MaxDimExceeded = "MaxDimExceeded"
@@ -386,12 +390,12 @@ class TosaErrorValidator:
if expected_result and error_result:
serializer.setExpectedReturnCode(2, True, desc=error_reason)
elif error_result: # and not expected_result
- print(
+ logger.error(
f"Unexpected ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
f" Expected: {error_name}, Got: {validator_name}"
)
elif not expected_result: # and not error_result
- print(
+ logger.error(
f"Missed ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
f" Expected: {error_name}"
)
@@ -401,7 +405,7 @@ class TosaErrorValidator:
if k != "op":
if k.endswith("dtype"):
v = valueToName(DType, v)
- print(f" {k} = {v}")
+ logger.error(f" {k} = {v}")
return overall_result
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 0f68999..e7704f1 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -1,6 +1,7 @@
# Copyright (c) 2020-2024, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import json
+import logging
import os
from copy import deepcopy
from datetime import datetime
@@ -27,6 +28,9 @@ TOSA_AUTOGENERATED_HEADER = f"""// Copyright (c) {datetime.today().year}, ARM Li
// AUTO-GENERATED FILE CREATED BY tosa_verif_build_tests
"""
+logging.basicConfig()
+logger = logging.getLogger("tosa_verif_build_tests")
+
class TosaTestGen:
# Maximum rank of tensor supported by test generator.
@@ -2134,6 +2138,7 @@ class TosaTestGen:
double_round = args_dict["double_round"]
per_channel = args_dict["per_channel"]
shift_arr = args_dict["shift"]
+ multiplier_arr = args_dict["multiplier"]
result_tensor = OutputShaper.typeConversionOp(
self.ser, self.rng, val, out_dtype, error_name
@@ -2203,7 +2208,9 @@ class TosaTestGen:
min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1
- # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
+ logger.debug(
+ f"build_rescale: multiplier={multiplier_arr} shift={shift_arr} inzp={input_zp} outzp={output_zp}"
+ )
if scale32 and error_name is None:
# Make sure random values are within apply_scale_32 specification
# REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
@@ -2907,7 +2914,9 @@ class TosaTestGen:
cleanRankFilter = filterDict["rankFilter"]
cleanDtypeFilter = filterDict["dtypeFilter"]
cleanShapeFilter = filterDict["shapeFilter"]
- # print(f"Error: {error_name}, Filters: S {cleanShapeFilter}, R {cleanRankFilter}, T {cleanDtypeFilter}")
+ logger.debug(
+ f"genOpTestList: Error={error_name}, Filters S={cleanShapeFilter}, R={cleanRankFilter}, T={cleanDtypeFilter}"
+ )
for r in cleanRankFilter:
for t in cleanDtypeFilter:
@@ -2981,8 +2990,7 @@ class TosaTestGen:
except KeyError:
raise Exception("Cannot find op with name {}".format(opName))
- if self.args.verbose:
- print(f"Creating {testStr}")
+ logger.info(f"Creating {testStr}")
# Create a serializer
self.createSerializer(opName, testStr)
@@ -3062,7 +3070,7 @@ class TosaTestGen:
self.serialize("test", tensMeta)
else:
# The test is not valid
- print(f"Invalid ERROR_IF test created: {opName} {testStr}")
+ logger.error(f"Invalid ERROR_IF test created: {opName} {testStr}")
def createDynamicOpLists(self):
@@ -3084,6 +3092,7 @@ class TosaTestGen:
self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv2d_TEMPLATE"].copy()
self.TOSA_OP_LIST[testName]["filter"] = k
self.TOSA_OP_LIST[testName]["template"] = False
+ self.TOSA_OP_LIST[testName]["real_name"] = "conv2d"
testName = "depthwise_conv2d_{}x{}".format(k[0], k[1])
self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
@@ -3091,6 +3100,7 @@ class TosaTestGen:
].copy()
self.TOSA_OP_LIST[testName]["filter"] = k
self.TOSA_OP_LIST[testName]["template"] = False
+ self.TOSA_OP_LIST[testName]["real_name"] = "depthwise_conv2d"
testName = "transpose_conv2d_{}x{}".format(k[0], k[1])
self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST[
@@ -3098,12 +3108,14 @@ class TosaTestGen:
].copy()
self.TOSA_OP_LIST[testName]["filter"] = k
self.TOSA_OP_LIST[testName]["template"] = False
+ self.TOSA_OP_LIST[testName]["real_name"] = "transpose_conv2d"
for k in KERNELS_3D:
testName = "conv3d_{}x{}x{}".format(k[0], k[1], k[2])
self.TOSA_OP_LIST[testName] = self.TOSA_OP_LIST["conv3d_TEMPLATE"].copy()
self.TOSA_OP_LIST[testName]["filter"] = k
self.TOSA_OP_LIST[testName]["template"] = False
+ self.TOSA_OP_LIST[testName]["real_name"] = "conv3d"
# Delete any templates after having created any dynamic ops
# This is a two-pass operation because it's bad practice to delete
diff --git a/verif/generator/tosa_test_select.py b/verif/generator/tosa_test_select.py
new file mode 100644
index 0000000..5a13178
--- /dev/null
+++ b/verif/generator/tosa_test_select.py
@@ -0,0 +1,348 @@
+# Copyright (c) 2024, ARM Limited.
+# SPDX-License-Identifier: Apache-2.0
+import copy
+import logging
+
+logging.basicConfig()
+logger = logging.getLogger("tosa_verif_build_tests")
+
+
+class Test:
+ """Test container to allow group and permute selection."""
+
+ def __init__(
+ self, opName, testStr, dtype, error, shapeList, argsDict, testOpName=None
+ ):
+ self.opName = opName
+ self.testStr = testStr
+ self.dtype = dtype
+ self.error = error
+ self.shapeList = shapeList
+ self.argsDict = argsDict
+ # Given test op name used for look up in TOSA_OP_LIST for "conv2d_1x1" for example
+ self.testOpName = testOpName if testOpName is not None else opName
+
+ self.key = None
+ self.groupKey = None
+ self.mark = False
+
+ def __str__(self):
+ return self.testStr
+
+ def __lt__(self, other):
+ return self.testStr < str(other)
+
+ def getArg(self, param):
+ # Get parameter values (arguments) for this test
+ if param == "rank":
+ return len(self.shapeList[0])
+ elif param == "dtype":
+ if isinstance(self.dtype, list):
+ return tuple(self.dtype)
+ return self.dtype
+ elif param == "shape" and "shape" not in self.argsDict:
+ return str(self.shapeList[0])
+
+ if param in self.argsDict:
+ # Turn other args into hashable string without newlines
+ val = str(self.argsDict[param])
+ return ",".join(str(val).splitlines())
+ else:
+ return None
+
+ def setKey(self, keyParams):
+ if self.error is None:
+ # Create the main key based on primary parameters
+ key = [self.getArg(param) for param in keyParams]
+ self.key = tuple(key)
+ else:
+ # Use the error as the key
+ self.key = self.error
+ return self.key
+
+ def getKey(self):
+ return self.key
+
+ def setGroupKey(self, groupParams):
+ # Create the group key based on arguments that do not define the group
+ # Therefore this test will match other tests that have the same arguments
+ # that are NOT the group arguments (group arguments like test set number)
+ paramsList = sorted(["shape", "dtype"] + list(self.argsDict.keys()))
+ key = []
+ for param in paramsList:
+ if param in groupParams:
+ continue
+ key.append(self.getArg(param))
+ self.groupKey = tuple(key)
+ return self.groupKey
+
+ def getGroupKey(self):
+ return self.groupKey
+
+ def inGroup(self, groupKey):
+ return self.groupKey == groupKey
+
+ def setMark(self):
+ # Marks the test as important
+ self.mark = True
+
+ def getMark(self):
+ return self.mark
+
+ def isError(self):
+ return self.error is not None
+
+
+def _get_selection_info_from_op(op, selectionCriteria, item, default):
+ # Get selection info from the op
+ if (
+ "selection" in op
+ and selectionCriteria in op["selection"]
+ and item in op["selection"][selectionCriteria]
+ ):
+ return op["selection"][selectionCriteria][item]
+ else:
+ return default
+
+
+def _get_tests_by_group(tests):
+ # Create simple structures to record the tests in groups
+ groups = []
+ group_tests = {}
+
+ for test in tests:
+ key = test.getGroupKey()
+ if key in group_tests:
+ group_tests[key].append(test)
+ else:
+ group_tests[key] = [test]
+ groups.append(key)
+
+ # Return list of test groups (group keys) and a dictionary with a list of tests
+ # associated with each group key
+ return groups, group_tests
+
+
+def _get_specific_op_info(opName, opSelectionInfo, testOpName):
+ # Get the op specific section from the selection config
+ name = opName if opName in opSelectionInfo else testOpName
+ if name not in opSelectionInfo:
+ logger.info(f"No op entry found for {opName} in test selection config")
+ return {}
+ return opSelectionInfo[name]
+
+
+class TestOpList:
+ """All the tests for one op grouped by permutations."""
+
+ def __init__(self, opName, opSelectionInfo, selectionCriteria, testOpName):
+ self.opName = opName
+ self.testOpName = testOpName
+ op = _get_specific_op_info(opName, opSelectionInfo, testOpName)
+
+ # See verif/conformance/README.md for more information on
+ # these selection arguments
+ self.permuteArgs = _get_selection_info_from_op(
+ op, selectionCriteria, "permutes", ["rank", "dtype"]
+ )
+ self.paramArgs = _get_selection_info_from_op(
+ op, selectionCriteria, "full_params", []
+ )
+ self.specificArgs = _get_selection_info_from_op(
+ op, selectionCriteria, "specifics", {}
+ )
+ self.groupArgs = _get_selection_info_from_op(
+ op, selectionCriteria, "groups", ["s"]
+ )
+ self.maximumPerPermute = _get_selection_info_from_op(
+ op, selectionCriteria, "maximum", None
+ )
+ self.numErrorIfs = _get_selection_info_from_op(
+ op, selectionCriteria, "num_errorifs", 1
+ )
+ self.selectAll = _get_selection_info_from_op(
+ op, selectionCriteria, "all", False
+ )
+
+ if self.paramArgs and self.maximumPerPermute > 1:
+ logger.warning(f"Unsupported - selection params AND maximum for {opName}")
+
+ self.tests = []
+ self.testStrings = set()
+ self.shapes = set()
+
+ self.permutes = set()
+ self.testsPerPermute = {}
+ self.paramsPerPermute = {}
+ self.specificsPerPermute = {}
+
+ self.selectionDone = False
+
+ def __len__(self):
+ return len(self.tests)
+
+ def add(self, test):
+ # Add a test to this op group and set up the permutations/group for it
+ assert test.opName.startswith(self.opName)
+ if str(test) in self.testStrings:
+ logger.info(f"Skipping duplicate test: {str(test)}")
+ return
+
+ self.tests.append(test)
+ self.testStrings.add(str(test))
+
+ self.shapes.add(test.getArg("shape"))
+
+ # Work out the permutation key for this test
+ permute = test.setKey(self.permuteArgs)
+ # Set up the group key for the test (for pulling out groups during selection)
+ test.setGroupKey(self.groupArgs)
+
+ if permute not in self.permutes:
+ # New permutation
+ self.permutes.add(permute)
+ # Set up area to record the selected tests
+ self.testsPerPermute[permute] = []
+ if self.paramArgs:
+ # Set up area to record the unique test params found
+ self.paramsPerPermute[permute] = {}
+ for param in self.paramArgs:
+ self.paramsPerPermute[permute][param] = set()
+ # Set up copy of the specific test args for selecting these
+ self.specificsPerPermute[permute] = copy.deepcopy(self.specificArgs)
+
+ def _init_select(self):
+ # Can only perform the selection process once as it alters the permute
+ # information set at init
+ assert not self.selectionDone
+
+ # Count of non-specific tests added to each permute (not error)
+ if not self.selectAll:
+ countPerPermute = {permute: 0 for permute in self.permutes}
+
+ # Go through each test looking for permutes, unique params & specifics
+ for test in self.tests:
+ permute = test.getKey()
+ append = False
+ possible_append = False
+
+ if test.isError():
+ # Error test, choose up to number of tests
+ if len(self.testsPerPermute[permute]) < self.numErrorIfs:
+ append = True
+ else:
+ if self.selectAll:
+ append = True
+ else:
+ # See if this is a specific test to add
+ for param, values in self.specificsPerPermute[permute].items():
+ arg = test.getArg(param)
+ # Iterate over a copy of the values, so we can remove them from the original
+ if arg in values.copy():
+ # Found a match, remove it, so we don't look for it later
+ values.remove(arg)
+ # Mark the test as special (and so shouldn't be removed)
+ test.setMark()
+ append = True
+
+ if self.paramArgs:
+ # See if this test contains any new params we should keep
+ # Perform this check even if we have already selected the test
+ # so we can record the params found
+ for param in self.paramArgs:
+ arg = test.getArg(param)
+ if arg not in self.paramsPerPermute[permute][param]:
+ # We have found a new value for this arg, record it
+ self.paramsPerPermute[permute][param].add(arg)
+ possible_append = True
+ else:
+ # No params set, so possible test to add up to maximum
+ possible_append = True
+
+ if (not append and possible_append) and (
+ self.maximumPerPermute is None
+ or countPerPermute[permute] < self.maximumPerPermute
+ ):
+ # Not selected but could be added and we have space left if
+ # a maximum is set.
+ append = True
+ countPerPermute[permute] += 1
+
+ # Check for grouping with chosen tests
+ if not append:
+ # We will keep any tests together that form a group
+ key = test.getGroupKey()
+ for t in self.testsPerPermute[permute]:
+ if t.getGroupKey() == key:
+ if t.getMark():
+ test.setMark()
+ append = True
+
+ if append:
+ self.testsPerPermute[permute].append(test)
+
+ self.selectionDone = True
+
+ def select(self, rng=None):
+ # Create selection of tests with optional shuffle
+ if not self.selectionDone:
+ if rng:
+ rng.shuffle(self.tests)
+
+ self._init_select()
+
+ # Now create the full list of selected tests per permute
+ selection = []
+
+ for permute, tests in self.testsPerPermute.items():
+ selection.extend(tests)
+
+ return selection
+
+ def all(self):
+ # Un-selected list of tests - i.e. all of them
+ return self.tests
+
+
+class TestList:
+ """List of all tests grouped by operator."""
+
+ def __init__(self, opSelectionInfo, selectionCriteria="default"):
+ self.opLists = {}
+ self.opSelectionInfo = opSelectionInfo
+ self.selectionCriteria = selectionCriteria
+
+ def __len__(self):
+ length = 0
+ for opName in self.opLists.keys():
+ length += len(self.opLists[opName])
+ return length
+
+ def add(self, test):
+ if test.opName not in self.opLists:
+ self.opLists[test.opName] = TestOpList(
+ test.opName,
+ self.opSelectionInfo,
+ self.selectionCriteria,
+ test.testOpName,
+ )
+ self.opLists[test.opName].add(test)
+
+ def _get_tests(self, selectMode, rng):
+ selection = []
+
+ for opList in self.opLists.values():
+ if selectMode:
+ tests = opList.select(rng=rng)
+ else:
+ tests = opList.all()
+ selection.extend(tests)
+
+ selection = sorted(selection)
+ return selection
+
+ def select(self, rng=None):
+ return self._get_tests(True, rng)
+
+ def all(self):
+ return self._get_tests(False, None)
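The selection behaviour of this new module can be exercised in isolation. A minimal sketch, assuming the module is importable as below and using placeholder op/dtype/shape values rather than real TosaTestGen output:

    import generator.tosa_test_select as tts

    # Hypothetical selection criteria mirroring the JSON config format
    selection_cfg = {
        "add": {"selection": {"default": {"permutes": ["rank", "dtype"], "maximum": 2}}}
    }

    test_list = tts.TestList(selection_cfg, selectionCriteria="default")
    for shape in ([1, 2, 3], [4, 5, 6], [7, 8, 9]):
        test_list.add(
            tts.Test("add", f"add_{'x'.join(map(str, shape))}_f32", "f32", None, [shape], {})
        )

    # All three tests share the same (rank, dtype) permutation, so "maximum"
    # caps the selection at two of them; ERRORIF tests and "specifics" matches
    # would always be kept.
    print([str(t) for t in test_list.select()])

When driven from tosa_verif_build_tests, select() is passed the generator's rng so the capped choice within each permutation is shuffled rather than order-dependent.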
diff --git a/verif/generator/tosa_verif_build_tests.py b/verif/generator/tosa_verif_build_tests.py
index 8012d93..c32993a 100644
--- a/verif/generator/tosa_verif_build_tests.py
+++ b/verif/generator/tosa_verif_build_tests.py
@@ -1,17 +1,23 @@
-# Copyright (c) 2020-2023, ARM Limited.
+# Copyright (c) 2020-2024, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import argparse
+import json
+import logging
import re
import sys
from pathlib import Path
import conformance.model_files as cmf
+import generator.tosa_test_select as tts
from generator.tosa_test_gen import TosaTestGen
from serializer.tosa_serializer import dtype_str_to_val
from serializer.tosa_serializer import DTypeNames
OPTION_FP_VALUES_RANGE = "--fp-values-range"
+logging.basicConfig()
+logger = logging.getLogger("tosa_verif_build_tests")
+
# Used for parsing a comma-separated list of integers/floats in a string
# to an actual list of integers/floats with special case max
@@ -58,6 +64,7 @@ def parseArgs(argv):
parser = argparse.ArgumentParser()
+ filter_group = parser.add_argument_group("test filter options")
ops_group = parser.add_argument_group("operator options")
tens_group = parser.add_argument_group("tensor options")
@@ -73,7 +80,7 @@ def parseArgs(argv):
help="Random seed for test generation",
)
- parser.add_argument(
+ filter_group.add_argument(
"--filter",
dest="filter",
default="",
@@ -82,7 +89,12 @@ def parseArgs(argv):
)
parser.add_argument(
- "-v", "--verbose", dest="verbose", action="count", help="Verbose operation"
+ "-v",
+ "--verbose",
+ dest="verbose",
+ action="count",
+ default=0,
+ help="Verbose operation",
)
parser.add_argument(
@@ -226,7 +238,7 @@ def parseArgs(argv):
help="Allow constant input tensors for concat operator",
)
- parser.add_argument(
+ filter_group.add_argument(
"--test-type",
dest="test_type",
choices=["positive", "negative", "both"],
@@ -235,6 +247,26 @@ def parseArgs(argv):
help="type of tests produced, positive, negative, or both",
)
+ filter_group.add_argument(
+ "--test-selection-config",
+ dest="selection_config",
+ type=Path,
+ help="enables test selection, this is the path to the JSON test selection config file, will use the default selection specified for each op unless --selection-criteria is supplied",
+ )
+
+ filter_group.add_argument(
+ "--test-selection-criteria",
+ dest="selection_criteria",
+ help="enables test selection, this is the selection criteria to use from the selection config",
+ )
+
+ parser.add_argument(
+ "--list-tests",
+ dest="list_tests",
+ action="store_true",
+ help="lists the tests that will be generated and then exits",
+ )
+
ops_group.add_argument(
"--allow-pooling-and-conv-oversizes",
dest="oversize",
@@ -281,6 +313,10 @@ def main(argv=None):
args = parseArgs(argv)
+ loglevels = (logging.WARNING, logging.INFO, logging.DEBUG)
+ loglevel = loglevels[min(args.verbose, len(loglevels) - 1)]
+ logger.setLevel(loglevel)
+
if not args.lazy_data_gen:
if args.generate_lib_path is None:
args.generate_lib_path = cmf.find_tosa_file(
@@ -290,55 +326,98 @@ def main(argv=None):
print(
f"Argument error: Generate library (--generate-lib-path) not found - {str(args.generate_lib_path)}"
)
- exit(2)
+ return 2
ttg = TosaTestGen(args)
+ # Determine if test selection mode is enabled or not
+ selectionMode = (
+ args.selection_config is not None or args.selection_criteria is not None
+ )
+ selectionCriteria = (
+ "default" if args.selection_criteria is None else args.selection_criteria
+ )
+ if args.selection_config is not None:
+ # Try loading the selection config
+ if not args.selection_config.is_file():
+ print(
+ f"Argument error: Test selection config (--test-selection-config) not found {str(args.selection_config)}"
+ )
+ return 2
+ with args.selection_config.open("r") as fd:
+ selectionCfg = json.load(fd)
+ else:
+ # Fallback to using anything defined in the TosaTestGen list
+ # by default this will mean only selecting tests using a
+ # permutation of rank by type for each op
+ selectionCfg = ttg.TOSA_OP_LIST
+
if args.test_type == "both":
testType = ["positive", "negative"]
else:
testType = [args.test_type]
+
results = []
for test_type in testType:
- testList = []
+ testList = tts.TestList(selectionCfg, selectionCriteria=selectionCriteria)
try:
for opName in ttg.TOSA_OP_LIST:
if re.match(args.filter + ".*", opName):
- testList.extend(
- ttg.genOpTestList(
- opName,
- shapeFilter=args.target_shapes,
- rankFilter=args.target_ranks,
- dtypeFilter=args.target_dtypes,
- testType=test_type,
- )
+ tests = ttg.genOpTestList(
+ opName,
+ shapeFilter=args.target_shapes,
+ rankFilter=args.target_ranks,
+ dtypeFilter=args.target_dtypes,
+ testType=test_type,
)
+ for testOpName, testStr, dtype, error, shapeList, argsDict in tests:
+ if "real_name" in ttg.TOSA_OP_LIST[testOpName]:
+ name = ttg.TOSA_OP_LIST[testOpName]["real_name"]
+ else:
+ name = testOpName
+ test = tts.Test(
+ name, testStr, dtype, error, shapeList, argsDict, testOpName
+ )
+ testList.add(test)
except Exception as e:
- print(f"INTERNAL ERROR: Failure generating test lists for {opName}")
+ logger.error(f"INTERNAL ERROR: Failure generating test lists for {opName}")
raise e
- print("{} matching {} tests".format(len(testList), test_type))
+ if not selectionMode:
+ # Allow all tests to be selected
+ tests = testList.all()
+ else:
+ # Use the random number generator to shuffle the test list
+ # and select the per op tests from it
+ tests = testList.select(ttg.rng)
- testStrings = []
- try:
- for opName, testStr, dtype, error, shapeList, argsDict in testList:
- # Check for and skip duplicate tests
- if testStr in testStrings:
- print(f"Skipping duplicate test: {testStr}")
- continue
- else:
- testStrings.append(testStr)
+ if args.list_tests:
+ for test in tests:
+ print(test)
+ continue
+
+ print(f"{len(tests)} matching {test_type} tests")
+ try:
+ for test in tests:
+ opName = test.testOpName
results.append(
ttg.serializeTest(
- opName, testStr, dtype, error, shapeList, argsDict
+ opName,
+ str(test),
+ test.dtype,
+ test.error,
+ test.shapeList,
+ test.argsDict,
)
)
except Exception as e:
- print(f"INTERNAL ERROR: Failure creating test output for {opName}")
+ logger.error(f"INTERNAL ERROR: Failure creating test output for {opName}")
raise e
- print(f"Done creating {len(results)} tests")
+ if not args.list_tests:
+ print(f"Done creating {len(results)} tests")
+ return 0
if __name__ == "__main__":