diff options
Diffstat (limited to 'verif/generator/tosa_test_select.py')
-rw-r--r-- | verif/generator/tosa_test_select.py | 348 |
1 files changed, 348 insertions, 0 deletions
diff --git a/verif/generator/tosa_test_select.py b/verif/generator/tosa_test_select.py new file mode 100644 index 0000000..5a13178 --- /dev/null +++ b/verif/generator/tosa_test_select.py @@ -0,0 +1,348 @@ +# Copyright (c) 2024, ARM Limited. +# SPDX-License-Identifier: Apache-2.0 +import copy +import logging + +logging.basicConfig() +logger = logging.getLogger("tosa_verif_build_tests") + + +class Test: + """Test container to allow group and permute selection.""" + + def __init__( + self, opName, testStr, dtype, error, shapeList, argsDict, testOpName=None + ): + self.opName = opName + self.testStr = testStr + self.dtype = dtype + self.error = error + self.shapeList = shapeList + self.argsDict = argsDict + # Given test op name used for look up in TOSA_OP_LIST for "conv2d_1x1" for example + self.testOpName = testOpName if testOpName is not None else opName + + self.key = None + self.groupKey = None + self.mark = False + + def __str__(self): + return self.testStr + + def __lt__(self, other): + return self.testStr < str(other) + + def getArg(self, param): + # Get parameter values (arguments) for this test + if param == "rank": + return len(self.shapeList[0]) + elif param == "dtype": + if isinstance(self.dtype, list): + return tuple(self.dtype) + return self.dtype + elif param == "shape" and "shape" not in self.argsDict: + return str(self.shapeList[0]) + + if param in self.argsDict: + # Turn other args into hashable string without newlines + val = str(self.argsDict[param]) + return ",".join(str(val).splitlines()) + else: + return None + + def setKey(self, keyParams): + if self.error is None: + # Create the main key based on primary parameters + key = [self.getArg(param) for param in keyParams] + self.key = tuple(key) + else: + # Use the error as the key + self.key = self.error + return self.key + + def getKey(self): + return self.key + + def setGroupKey(self, groupParams): + # Create the group key based on arguments that do not define the group + # Therefore this test will match other tests that have the same arguments + # that are NOT the group arguments (group arguments like test set number) + paramsList = sorted(["shape", "dtype"] + list(self.argsDict.keys())) + key = [] + for param in paramsList: + if param in groupParams: + continue + key.append(self.getArg(param)) + self.groupKey = tuple(key) + return self.groupKey + + def getGroupKey(self): + return self.groupKey + + def inGroup(self, groupKey): + return self.groupKey == groupKey + + def setMark(self): + # Marks the test as important + self.mark = True + + def getMark(self): + return self.mark + + def isError(self): + return self.error is not None + + +def _get_selection_info_from_op(op, selectionCriteria, item, default): + # Get selection info from the op + if ( + "selection" in op + and selectionCriteria in op["selection"] + and item in op["selection"][selectionCriteria] + ): + return op["selection"][selectionCriteria][item] + else: + return default + + +def _get_tests_by_group(tests): + # Create simple structures to record the tests in groups + groups = [] + group_tests = {} + + for test in tests: + key = test.getGroupKey() + if key in group_tests: + group_tests[key].append(test) + else: + group_tests[key] = [test] + groups.append(key) + + # Return list of test groups (group keys) and a dictionary with a list of tests + # associated with each group key + return groups, group_tests + + +def _get_specific_op_info(opName, opSelectionInfo, testOpName): + # Get the op specific section from the selection config + name = opName if opName in opSelectionInfo else testOpName + if name not in opSelectionInfo: + logger.info(f"No op entry found for {opName} in test selection config") + return {} + return opSelectionInfo[name] + + +class TestOpList: + """All the tests for one op grouped by permutations.""" + + def __init__(self, opName, opSelectionInfo, selectionCriteria, testOpName): + self.opName = opName + self.testOpName = testOpName + op = _get_specific_op_info(opName, opSelectionInfo, testOpName) + + # See verif/conformance/README.md for more information on + # these selection arguments + self.permuteArgs = _get_selection_info_from_op( + op, selectionCriteria, "permutes", ["rank", "dtype"] + ) + self.paramArgs = _get_selection_info_from_op( + op, selectionCriteria, "full_params", [] + ) + self.specificArgs = _get_selection_info_from_op( + op, selectionCriteria, "specifics", {} + ) + self.groupArgs = _get_selection_info_from_op( + op, selectionCriteria, "groups", ["s"] + ) + self.maximumPerPermute = _get_selection_info_from_op( + op, selectionCriteria, "maximum", None + ) + self.numErrorIfs = _get_selection_info_from_op( + op, selectionCriteria, "num_errorifs", 1 + ) + self.selectAll = _get_selection_info_from_op( + op, selectionCriteria, "all", False + ) + + if self.paramArgs and self.maximumPerPermute > 1: + logger.warning(f"Unsupported - selection params AND maximum for {opName}") + + self.tests = [] + self.testStrings = set() + self.shapes = set() + + self.permutes = set() + self.testsPerPermute = {} + self.paramsPerPermute = {} + self.specificsPerPermute = {} + + self.selectionDone = False + + def __len__(self): + return len(self.tests) + + def add(self, test): + # Add a test to this op group and set up the permutations/group for it + assert test.opName.startswith(self.opName) + if str(test) in self.testStrings: + logger.info(f"Skipping duplicate test: {str(test)}") + return + + self.tests.append(test) + self.testStrings.add(str(test)) + + self.shapes.add(test.getArg("shape")) + + # Work out the permutation key for this test + permute = test.setKey(self.permuteArgs) + # Set up the group key for the test (for pulling out groups during selection) + test.setGroupKey(self.groupArgs) + + if permute not in self.permutes: + # New permutation + self.permutes.add(permute) + # Set up area to record the selected tests + self.testsPerPermute[permute] = [] + if self.paramArgs: + # Set up area to record the unique test params found + self.paramsPerPermute[permute] = {} + for param in self.paramArgs: + self.paramsPerPermute[permute][param] = set() + # Set up copy of the specific test args for selecting these + self.specificsPerPermute[permute] = copy.deepcopy(self.specificArgs) + + def _init_select(self): + # Can only perform the selection process once as it alters the permute + # information set at init + assert not self.selectionDone + + # Count of non-specific tests added to each permute (not error) + if not self.selectAll: + countPerPermute = {permute: 0 for permute in self.permutes} + + # Go through each test looking for permutes, unique params & specifics + for test in self.tests: + permute = test.getKey() + append = False + possible_append = False + + if test.isError(): + # Error test, choose up to number of tests + if len(self.testsPerPermute[permute]) < self.numErrorIfs: + append = True + else: + if self.selectAll: + append = True + else: + # See if this is a specific test to add + for param, values in self.specificsPerPermute[permute].items(): + arg = test.getArg(param) + # Iterate over a copy of the values, so we can remove them from the original + if arg in values.copy(): + # Found a match, remove it, so we don't look for it later + values.remove(arg) + # Mark the test as special (and so shouldn't be removed) + test.setMark() + append = True + + if self.paramArgs: + # See if this test contains any new params we should keep + # Perform this check even if we have already selected the test + # so we can record the params found + for param in self.paramArgs: + arg = test.getArg(param) + if arg not in self.paramsPerPermute[permute][param]: + # We have found a new value for this arg, record it + self.paramsPerPermute[permute][param].add(arg) + possible_append = True + else: + # No params set, so possible test to add up to maximum + possible_append = True + + if (not append and possible_append) and ( + self.maximumPerPermute is None + or countPerPermute[permute] < self.maximumPerPermute + ): + # Not selected but could be added and we have space left if + # a maximum is set. + append = True + countPerPermute[permute] += 1 + + # Check for grouping with chosen tests + if not append: + # We will keep any tests together than form a group + key = test.getGroupKey() + for t in self.testsPerPermute[permute]: + if t.getGroupKey() == key: + if t.getMark(): + test.setMark() + append = True + + if append: + self.testsPerPermute[permute].append(test) + + self.selectionDone = True + + def select(self, rng=None): + # Create selection of tests with optional shuffle + if not self.selectionDone: + if rng: + rng.shuffle(self.tests) + + self._init_select() + + # Now create the full list of selected tests per permute + selection = [] + + for permute, tests in self.testsPerPermute.items(): + selection.extend(tests) + + return selection + + def all(self): + # Un-selected list of tests - i.e. all of them + return self.tests + + +class TestList: + """List of all tests grouped by operator.""" + + def __init__(self, opSelectionInfo, selectionCriteria="default"): + self.opLists = {} + self.opSelectionInfo = opSelectionInfo + self.selectionCriteria = selectionCriteria + + def __len__(self): + length = 0 + for opName in self.opLists.keys(): + length += len(self.opLists[opName]) + return length + + def add(self, test): + if test.opName not in self.opLists: + self.opLists[test.opName] = TestOpList( + test.opName, + self.opSelectionInfo, + self.selectionCriteria, + test.testOpName, + ) + self.opLists[test.opName].add(test) + + def _get_tests(self, selectMode, rng): + selection = [] + + for opList in self.opLists.values(): + if selectMode: + tests = opList.select(rng=rng) + else: + tests = opList.all() + selection.extend(tests) + + selection = sorted(selection) + return selection + + def select(self, rng=None): + return self._get_tests(True, rng) + + def all(self): + return self._get_tests(False, None) |