aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_select.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_select.py')
-rw-r--r--verif/generator/tosa_test_select.py348
1 files changed, 348 insertions, 0 deletions
diff --git a/verif/generator/tosa_test_select.py b/verif/generator/tosa_test_select.py
new file mode 100644
index 0000000..5a13178
--- /dev/null
+++ b/verif/generator/tosa_test_select.py
@@ -0,0 +1,348 @@
+# Copyright (c) 2024, ARM Limited.
+# SPDX-License-Identifier: Apache-2.0
+import copy
+import logging
+
+logging.basicConfig()
+logger = logging.getLogger("tosa_verif_build_tests")
+
+
+class Test:
+ """Test container to allow group and permute selection."""
+
+ def __init__(
+ self, opName, testStr, dtype, error, shapeList, argsDict, testOpName=None
+ ):
+ self.opName = opName
+ self.testStr = testStr
+ self.dtype = dtype
+ self.error = error
+ self.shapeList = shapeList
+ self.argsDict = argsDict
+ # Given test op name used for look up in TOSA_OP_LIST for "conv2d_1x1" for example
+ self.testOpName = testOpName if testOpName is not None else opName
+
+ self.key = None
+ self.groupKey = None
+ self.mark = False
+
+ def __str__(self):
+ return self.testStr
+
+ def __lt__(self, other):
+ return self.testStr < str(other)
+
+ def getArg(self, param):
+ # Get parameter values (arguments) for this test
+ if param == "rank":
+ return len(self.shapeList[0])
+ elif param == "dtype":
+ if isinstance(self.dtype, list):
+ return tuple(self.dtype)
+ return self.dtype
+ elif param == "shape" and "shape" not in self.argsDict:
+ return str(self.shapeList[0])
+
+ if param in self.argsDict:
+ # Turn other args into hashable string without newlines
+ val = str(self.argsDict[param])
+ return ",".join(str(val).splitlines())
+ else:
+ return None
+
+ def setKey(self, keyParams):
+ if self.error is None:
+ # Create the main key based on primary parameters
+ key = [self.getArg(param) for param in keyParams]
+ self.key = tuple(key)
+ else:
+ # Use the error as the key
+ self.key = self.error
+ return self.key
+
+ def getKey(self):
+ return self.key
+
+ def setGroupKey(self, groupParams):
+ # Create the group key based on arguments that do not define the group
+ # Therefore this test will match other tests that have the same arguments
+ # that are NOT the group arguments (group arguments like test set number)
+ paramsList = sorted(["shape", "dtype"] + list(self.argsDict.keys()))
+ key = []
+ for param in paramsList:
+ if param in groupParams:
+ continue
+ key.append(self.getArg(param))
+ self.groupKey = tuple(key)
+ return self.groupKey
+
+ def getGroupKey(self):
+ return self.groupKey
+
+ def inGroup(self, groupKey):
+ return self.groupKey == groupKey
+
+ def setMark(self):
+ # Marks the test as important
+ self.mark = True
+
+ def getMark(self):
+ return self.mark
+
+ def isError(self):
+ return self.error is not None
+
+
+def _get_selection_info_from_op(op, selectionCriteria, item, default):
+ # Get selection info from the op
+ if (
+ "selection" in op
+ and selectionCriteria in op["selection"]
+ and item in op["selection"][selectionCriteria]
+ ):
+ return op["selection"][selectionCriteria][item]
+ else:
+ return default
+
+
+def _get_tests_by_group(tests):
+ # Create simple structures to record the tests in groups
+ groups = []
+ group_tests = {}
+
+ for test in tests:
+ key = test.getGroupKey()
+ if key in group_tests:
+ group_tests[key].append(test)
+ else:
+ group_tests[key] = [test]
+ groups.append(key)
+
+ # Return list of test groups (group keys) and a dictionary with a list of tests
+ # associated with each group key
+ return groups, group_tests
+
+
+def _get_specific_op_info(opName, opSelectionInfo, testOpName):
+ # Get the op specific section from the selection config
+ name = opName if opName in opSelectionInfo else testOpName
+ if name not in opSelectionInfo:
+ logger.info(f"No op entry found for {opName} in test selection config")
+ return {}
+ return opSelectionInfo[name]
+
+
+class TestOpList:
+ """All the tests for one op grouped by permutations."""
+
+ def __init__(self, opName, opSelectionInfo, selectionCriteria, testOpName):
+ self.opName = opName
+ self.testOpName = testOpName
+ op = _get_specific_op_info(opName, opSelectionInfo, testOpName)
+
+ # See verif/conformance/README.md for more information on
+ # these selection arguments
+ self.permuteArgs = _get_selection_info_from_op(
+ op, selectionCriteria, "permutes", ["rank", "dtype"]
+ )
+ self.paramArgs = _get_selection_info_from_op(
+ op, selectionCriteria, "full_params", []
+ )
+ self.specificArgs = _get_selection_info_from_op(
+ op, selectionCriteria, "specifics", {}
+ )
+ self.groupArgs = _get_selection_info_from_op(
+ op, selectionCriteria, "groups", ["s"]
+ )
+ self.maximumPerPermute = _get_selection_info_from_op(
+ op, selectionCriteria, "maximum", None
+ )
+ self.numErrorIfs = _get_selection_info_from_op(
+ op, selectionCriteria, "num_errorifs", 1
+ )
+ self.selectAll = _get_selection_info_from_op(
+ op, selectionCriteria, "all", False
+ )
+
+ if self.paramArgs and self.maximumPerPermute > 1:
+ logger.warning(f"Unsupported - selection params AND maximum for {opName}")
+
+ self.tests = []
+ self.testStrings = set()
+ self.shapes = set()
+
+ self.permutes = set()
+ self.testsPerPermute = {}
+ self.paramsPerPermute = {}
+ self.specificsPerPermute = {}
+
+ self.selectionDone = False
+
+ def __len__(self):
+ return len(self.tests)
+
+ def add(self, test):
+ # Add a test to this op group and set up the permutations/group for it
+ assert test.opName.startswith(self.opName)
+ if str(test) in self.testStrings:
+ logger.info(f"Skipping duplicate test: {str(test)}")
+ return
+
+ self.tests.append(test)
+ self.testStrings.add(str(test))
+
+ self.shapes.add(test.getArg("shape"))
+
+ # Work out the permutation key for this test
+ permute = test.setKey(self.permuteArgs)
+ # Set up the group key for the test (for pulling out groups during selection)
+ test.setGroupKey(self.groupArgs)
+
+ if permute not in self.permutes:
+ # New permutation
+ self.permutes.add(permute)
+ # Set up area to record the selected tests
+ self.testsPerPermute[permute] = []
+ if self.paramArgs:
+ # Set up area to record the unique test params found
+ self.paramsPerPermute[permute] = {}
+ for param in self.paramArgs:
+ self.paramsPerPermute[permute][param] = set()
+ # Set up copy of the specific test args for selecting these
+ self.specificsPerPermute[permute] = copy.deepcopy(self.specificArgs)
+
+ def _init_select(self):
+ # Can only perform the selection process once as it alters the permute
+ # information set at init
+ assert not self.selectionDone
+
+ # Count of non-specific tests added to each permute (not error)
+ if not self.selectAll:
+ countPerPermute = {permute: 0 for permute in self.permutes}
+
+ # Go through each test looking for permutes, unique params & specifics
+ for test in self.tests:
+ permute = test.getKey()
+ append = False
+ possible_append = False
+
+ if test.isError():
+ # Error test, choose up to number of tests
+ if len(self.testsPerPermute[permute]) < self.numErrorIfs:
+ append = True
+ else:
+ if self.selectAll:
+ append = True
+ else:
+ # See if this is a specific test to add
+ for param, values in self.specificsPerPermute[permute].items():
+ arg = test.getArg(param)
+ # Iterate over a copy of the values, so we can remove them from the original
+ if arg in values.copy():
+ # Found a match, remove it, so we don't look for it later
+ values.remove(arg)
+ # Mark the test as special (and so shouldn't be removed)
+ test.setMark()
+ append = True
+
+ if self.paramArgs:
+ # See if this test contains any new params we should keep
+ # Perform this check even if we have already selected the test
+ # so we can record the params found
+ for param in self.paramArgs:
+ arg = test.getArg(param)
+ if arg not in self.paramsPerPermute[permute][param]:
+ # We have found a new value for this arg, record it
+ self.paramsPerPermute[permute][param].add(arg)
+ possible_append = True
+ else:
+ # No params set, so possible test to add up to maximum
+ possible_append = True
+
+ if (not append and possible_append) and (
+ self.maximumPerPermute is None
+ or countPerPermute[permute] < self.maximumPerPermute
+ ):
+ # Not selected but could be added and we have space left if
+ # a maximum is set.
+ append = True
+ countPerPermute[permute] += 1
+
+ # Check for grouping with chosen tests
+ if not append:
+ # We will keep any tests together than form a group
+ key = test.getGroupKey()
+ for t in self.testsPerPermute[permute]:
+ if t.getGroupKey() == key:
+ if t.getMark():
+ test.setMark()
+ append = True
+
+ if append:
+ self.testsPerPermute[permute].append(test)
+
+ self.selectionDone = True
+
+ def select(self, rng=None):
+ # Create selection of tests with optional shuffle
+ if not self.selectionDone:
+ if rng:
+ rng.shuffle(self.tests)
+
+ self._init_select()
+
+ # Now create the full list of selected tests per permute
+ selection = []
+
+ for permute, tests in self.testsPerPermute.items():
+ selection.extend(tests)
+
+ return selection
+
+ def all(self):
+ # Un-selected list of tests - i.e. all of them
+ return self.tests
+
+
+class TestList:
+ """List of all tests grouped by operator."""
+
+ def __init__(self, opSelectionInfo, selectionCriteria="default"):
+ self.opLists = {}
+ self.opSelectionInfo = opSelectionInfo
+ self.selectionCriteria = selectionCriteria
+
+ def __len__(self):
+ length = 0
+ for opName in self.opLists.keys():
+ length += len(self.opLists[opName])
+ return length
+
+ def add(self, test):
+ if test.opName not in self.opLists:
+ self.opLists[test.opName] = TestOpList(
+ test.opName,
+ self.opSelectionInfo,
+ self.selectionCriteria,
+ test.testOpName,
+ )
+ self.opLists[test.opName].add(test)
+
+ def _get_tests(self, selectMode, rng):
+ selection = []
+
+ for opList in self.opLists.values():
+ if selectMode:
+ tests = opList.select(rng=rng)
+ else:
+ tests = opList.all()
+ selection.extend(tests)
+
+ selection = sorted(selection)
+ return selection
+
+ def select(self, rng=None):
+ return self._get_tests(True, rng)
+
+ def all(self):
+ return self._get_tests(False, None)