aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_verif_build_tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_verif_build_tests.py')
-rw-r--r--verif/generator/tosa_verif_build_tests.py135
1 files changed, 107 insertions, 28 deletions
diff --git a/verif/generator/tosa_verif_build_tests.py b/verif/generator/tosa_verif_build_tests.py
index 8012d93..c32993a 100644
--- a/verif/generator/tosa_verif_build_tests.py
+++ b/verif/generator/tosa_verif_build_tests.py
@@ -1,17 +1,23 @@
-# Copyright (c) 2020-2023, ARM Limited.
+# Copyright (c) 2020-2024, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import argparse
+import json
+import logging
import re
import sys
from pathlib import Path
import conformance.model_files as cmf
+import generator.tosa_test_select as tts
from generator.tosa_test_gen import TosaTestGen
from serializer.tosa_serializer import dtype_str_to_val
from serializer.tosa_serializer import DTypeNames
OPTION_FP_VALUES_RANGE = "--fp-values-range"
+logging.basicConfig()
+logger = logging.getLogger("tosa_verif_build_tests")
+
# Used for parsing a comma-separated list of integers/floats in a string
# to an actual list of integers/floats with special case max
@@ -58,6 +64,7 @@ def parseArgs(argv):
parser = argparse.ArgumentParser()
+ filter_group = parser.add_argument_group("test filter options")
ops_group = parser.add_argument_group("operator options")
tens_group = parser.add_argument_group("tensor options")
@@ -73,7 +80,7 @@ def parseArgs(argv):
help="Random seed for test generation",
)
- parser.add_argument(
+ filter_group.add_argument(
"--filter",
dest="filter",
default="",
@@ -82,7 +89,12 @@ def parseArgs(argv):
)
parser.add_argument(
- "-v", "--verbose", dest="verbose", action="count", help="Verbose operation"
+ "-v",
+ "--verbose",
+ dest="verbose",
+ action="count",
+ default=0,
+ help="Verbose operation",
)
parser.add_argument(
@@ -226,7 +238,7 @@ def parseArgs(argv):
help="Allow constant input tensors for concat operator",
)
- parser.add_argument(
+ filter_group.add_argument(
"--test-type",
dest="test_type",
choices=["positive", "negative", "both"],
@@ -235,6 +247,26 @@ def parseArgs(argv):
help="type of tests produced, positive, negative, or both",
)
+ filter_group.add_argument(
+ "--test-selection-config",
+ dest="selection_config",
+ type=Path,
+ help="enables test selection, this is the path to the JSON test selection config file, will use the default selection specified for each op unless --selection-criteria is supplied",
+ )
+
+ filter_group.add_argument(
+ "--test-selection-criteria",
+ dest="selection_criteria",
+ help="enables test selection, this is the selection criteria to use from the selection config",
+ )
+
+ parser.add_argument(
+ "--list-tests",
+ dest="list_tests",
+ action="store_true",
+ help="lists the tests that will be generated and then exits",
+ )
+
ops_group.add_argument(
"--allow-pooling-and-conv-oversizes",
dest="oversize",
@@ -281,6 +313,10 @@ def main(argv=None):
args = parseArgs(argv)
+ loglevels = (logging.WARNING, logging.INFO, logging.DEBUG)
+ loglevel = loglevels[min(args.verbose, len(loglevels) - 1)]
+ logger.setLevel(loglevel)
+
if not args.lazy_data_gen:
if args.generate_lib_path is None:
args.generate_lib_path = cmf.find_tosa_file(
@@ -290,55 +326,98 @@ def main(argv=None):
print(
f"Argument error: Generate library (--generate-lib-path) not found - {str(args.generate_lib_path)}"
)
- exit(2)
+ return 2
ttg = TosaTestGen(args)
+ # Determine if test selection mode is enabled or not
+ selectionMode = (
+ args.selection_config is not None or args.selection_criteria is not None
+ )
+ selectionCriteria = (
+ "default" if args.selection_criteria is None else args.selection_criteria
+ )
+ if args.selection_config is not None:
+ # Try loading the selection config
+        if not args.selection_config.is_file():
+ print(
+ f"Argument error: Test selection config (--test-selection-config) not found {str(args.selection_config)}"
+ )
+ return 2
+ with args.selection_config.open("r") as fd:
+ selectionCfg = json.load(fd)
+ else:
+ # Fallback to using anything defined in the TosaTestGen list
+ # by default this will mean only selecting a tests using a
+ # permutation of rank by type for each op
+ selectionCfg = ttg.TOSA_OP_LIST
+
if args.test_type == "both":
testType = ["positive", "negative"]
else:
testType = [args.test_type]
+
results = []
for test_type in testType:
- testList = []
+ testList = tts.TestList(selectionCfg, selectionCriteria=selectionCriteria)
try:
for opName in ttg.TOSA_OP_LIST:
if re.match(args.filter + ".*", opName):
- testList.extend(
- ttg.genOpTestList(
- opName,
- shapeFilter=args.target_shapes,
- rankFilter=args.target_ranks,
- dtypeFilter=args.target_dtypes,
- testType=test_type,
- )
+ tests = ttg.genOpTestList(
+ opName,
+ shapeFilter=args.target_shapes,
+ rankFilter=args.target_ranks,
+ dtypeFilter=args.target_dtypes,
+ testType=test_type,
)
+ for testOpName, testStr, dtype, error, shapeList, argsDict in tests:
+ if "real_name" in ttg.TOSA_OP_LIST[testOpName]:
+ name = ttg.TOSA_OP_LIST[testOpName]["real_name"]
+ else:
+ name = testOpName
+ test = tts.Test(
+ name, testStr, dtype, error, shapeList, argsDict, testOpName
+ )
+ testList.add(test)
except Exception as e:
- print(f"INTERNAL ERROR: Failure generating test lists for {opName}")
+ logger.error(f"INTERNAL ERROR: Failure generating test lists for {opName}")
raise e
- print("{} matching {} tests".format(len(testList), test_type))
+ if not selectionMode:
+ # Allow all tests to be selected
+ tests = testList.all()
+ else:
+ # Use the random number generator to shuffle the test list
+ # and select the per op tests from it
+ tests = testList.select(ttg.rng)
- testStrings = []
- try:
- for opName, testStr, dtype, error, shapeList, argsDict in testList:
- # Check for and skip duplicate tests
- if testStr in testStrings:
- print(f"Skipping duplicate test: {testStr}")
- continue
- else:
- testStrings.append(testStr)
+ if args.list_tests:
+ for test in tests:
+ print(test)
+ continue
+
+ print(f"{len(tests)} matching {test_type} tests")
+ try:
+ for test in tests:
+ opName = test.testOpName
results.append(
ttg.serializeTest(
- opName, testStr, dtype, error, shapeList, argsDict
+ opName,
+ str(test),
+ test.dtype,
+ test.error,
+ test.shapeList,
+ test.argsDict,
)
)
except Exception as e:
- print(f"INTERNAL ERROR: Failure creating test output for {opName}")
+ logger.error(f"INTERNAL ERROR: Failure creating test output for {opName}")
raise e
- print(f"Done creating {len(results)} tests")
+ if not args.list_tests:
+ print(f"Done creating {len(results)} tests")
+ return 0
if __name__ == "__main__":