diff options
Diffstat (limited to 'verif/generator/tosa_verif_build_tests.py')
-rw-r--r-- | verif/generator/tosa_verif_build_tests.py | 135 |
1 files changed, 107 insertions, 28 deletions
diff --git a/verif/generator/tosa_verif_build_tests.py b/verif/generator/tosa_verif_build_tests.py index 8012d93..c32993a 100644 --- a/verif/generator/tosa_verif_build_tests.py +++ b/verif/generator/tosa_verif_build_tests.py @@ -1,17 +1,23 @@ -# Copyright (c) 2020-2023, ARM Limited. +# Copyright (c) 2020-2024, ARM Limited. # SPDX-License-Identifier: Apache-2.0 import argparse +import json +import logging import re import sys from pathlib import Path import conformance.model_files as cmf +import generator.tosa_test_select as tts from generator.tosa_test_gen import TosaTestGen from serializer.tosa_serializer import dtype_str_to_val from serializer.tosa_serializer import DTypeNames OPTION_FP_VALUES_RANGE = "--fp-values-range" +logging.basicConfig() +logger = logging.getLogger("tosa_verif_build_tests") + # Used for parsing a comma-separated list of integers/floats in a string # to an actual list of integers/floats with special case max @@ -58,6 +64,7 @@ def parseArgs(argv): parser = argparse.ArgumentParser() + filter_group = parser.add_argument_group("test filter options") ops_group = parser.add_argument_group("operator options") tens_group = parser.add_argument_group("tensor options") @@ -73,7 +80,7 @@ def parseArgs(argv): help="Random seed for test generation", ) - parser.add_argument( + filter_group.add_argument( "--filter", dest="filter", default="", @@ -82,7 +89,12 @@ def parseArgs(argv): ) parser.add_argument( - "-v", "--verbose", dest="verbose", action="count", help="Verbose operation" + "-v", + "--verbose", + dest="verbose", + action="count", + default=0, + help="Verbose operation", ) parser.add_argument( @@ -226,7 +238,7 @@ def parseArgs(argv): help="Allow constant input tensors for concat operator", ) - parser.add_argument( + filter_group.add_argument( "--test-type", dest="test_type", choices=["positive", "negative", "both"], @@ -235,6 +247,26 @@ def parseArgs(argv): help="type of tests produced, positive, negative, or both", ) + filter_group.add_argument( + "--test-selection-config", + dest="selection_config", + type=Path, + help="enables test selection, this is the path to the JSON test selection config file, will use the default selection specified for each op unless --selection-criteria is supplied", + ) + + filter_group.add_argument( + "--test-selection-criteria", + dest="selection_criteria", + help="enables test selection, this is the selection criteria to use from the selection config", + ) + + parser.add_argument( + "--list-tests", + dest="list_tests", + action="store_true", + help="lists the tests that will be generated and then exits", + ) + ops_group.add_argument( "--allow-pooling-and-conv-oversizes", dest="oversize", @@ -281,6 +313,10 @@ def main(argv=None): args = parseArgs(argv) + loglevels = (logging.WARNING, logging.INFO, logging.DEBUG) + loglevel = loglevels[min(args.verbose, len(loglevels) - 1)] + logger.setLevel(loglevel) + if not args.lazy_data_gen: if args.generate_lib_path is None: args.generate_lib_path = cmf.find_tosa_file( @@ -290,55 +326,98 @@ def main(argv=None): print( f"Argument error: Generate library (--generate-lib-path) not found - {str(args.generate_lib_path)}" ) - exit(2) + return 2 ttg = TosaTestGen(args) + # Determine if test selection mode is enabled or not + selectionMode = ( + args.selection_config is not None or args.selection_criteria is not None + ) + selectionCriteria = ( + "default" if args.selection_criteria is None else args.selection_criteria + ) + if args.selection_config is not None: + # Try loading the selection config + if not args.generate_lib_path.is_file(): + print( + f"Argument error: Test selection config (--test-selection-config) not found {str(args.selection_config)}" + ) + return 2 + with args.selection_config.open("r") as fd: + selectionCfg = json.load(fd) + else: + # Fallback to using anything defined in the TosaTestGen list + # by default this will mean only selecting a tests using a + # permutation of rank by type for each op + selectionCfg = ttg.TOSA_OP_LIST + if args.test_type == "both": testType = ["positive", "negative"] else: testType = [args.test_type] + results = [] for test_type in testType: - testList = [] + testList = tts.TestList(selectionCfg, selectionCriteria=selectionCriteria) try: for opName in ttg.TOSA_OP_LIST: if re.match(args.filter + ".*", opName): - testList.extend( - ttg.genOpTestList( - opName, - shapeFilter=args.target_shapes, - rankFilter=args.target_ranks, - dtypeFilter=args.target_dtypes, - testType=test_type, - ) + tests = ttg.genOpTestList( + opName, + shapeFilter=args.target_shapes, + rankFilter=args.target_ranks, + dtypeFilter=args.target_dtypes, + testType=test_type, ) + for testOpName, testStr, dtype, error, shapeList, argsDict in tests: + if "real_name" in ttg.TOSA_OP_LIST[testOpName]: + name = ttg.TOSA_OP_LIST[testOpName]["real_name"] + else: + name = testOpName + test = tts.Test( + name, testStr, dtype, error, shapeList, argsDict, testOpName + ) + testList.add(test) except Exception as e: - print(f"INTERNAL ERROR: Failure generating test lists for {opName}") + logger.error(f"INTERNAL ERROR: Failure generating test lists for {opName}") raise e - print("{} matching {} tests".format(len(testList), test_type)) + if not selectionMode: + # Allow all tests to be selected + tests = testList.all() + else: + # Use the random number generator to shuffle the test list + # and select the per op tests from it + tests = testList.select(ttg.rng) - testStrings = [] - try: - for opName, testStr, dtype, error, shapeList, argsDict in testList: - # Check for and skip duplicate tests - if testStr in testStrings: - print(f"Skipping duplicate test: {testStr}") - continue - else: - testStrings.append(testStr) + if args.list_tests: + for test in tests: + print(test) + continue + + print(f"{len(tests)} matching {test_type} tests") + try: + for test in tests: + opName = test.testOpName results.append( ttg.serializeTest( - opName, testStr, dtype, error, shapeList, argsDict + opName, + str(test), + test.dtype, + test.error, + test.shapeList, + test.argsDict, ) ) except Exception as e: - print(f"INTERNAL ERROR: Failure creating test output for {opName}") + logger.error(f"INTERNAL ERROR: Failure creating test output for {opName}") raise e - print(f"Done creating {len(results)} tests") + if not args.list_tests: + print(f"Done creating {len(results)} tests") + return 0 if __name__ == "__main__": |