aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2024-03-25 09:46:02 +0000
committerEric Kunze <eric.kunze@arm.com>2024-03-25 16:42:46 +0000
commit8746026555c37c8d208fcbedeb04d0ae6d25d53e (patch)
treef0142ed228bb54477c7d2f4a78485115d383607c
parentc44483694a6b3476d89968f6ee9463886d433211 (diff)
downloadreference_model-8746026555c37c8d208fcbedeb04d0ae6d25d53e.tar.gz
Fix missing Pooling ERRORIFs
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I749b1b9fbfa32954d8748a860280c86087c08e7f
-rw-r--r--verif/generator/tosa_arg_gen.py33
1 files changed, 24 insertions, 9 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 504bfa9..41ef4df 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1979,7 +1979,7 @@ class TosaArgGen:
return sparsity
# Maximum number of error_if variants to produce
- MAX_CONV_ERROR_IFS = 3
+ MAX_TESTS_ERROR_IFS = 3
@staticmethod
def agConv(testGen, rng, opName, shapeList, dtypes, error_name=None):
@@ -2252,7 +2252,7 @@ class TosaArgGen:
)
if (
error_name
- and len(arg_list) >= TosaArgGen.MAX_CONV_ERROR_IFS
+ and len(arg_list) >= TosaArgGen.MAX_TESTS_ERROR_IFS
):
# Found enough errors
logger.debug(
@@ -2691,12 +2691,15 @@ class TosaArgGen:
{x for x in itertools.product(*([[startPad, bigPadding]] * 4))}
)
- # There are too many parameter combinations, so generate them sparsely,
- # very sparse for negative tests
- sparsity_factor = 2 if error_name else 500
- sparsity = (
- len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
- )
+ if error_name:
+ # Cycle through all error_if tests but we only keep the first few
+ sparsity = 1
+ else:
+ # There are too many parameter combinations, so generate them sparsely
+ sparsity_factor = 500
+ sparsity = (
+ len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
+ )
else:
# We have already limited test output combinations for 8k tests
sparsity = 1
@@ -2732,6 +2735,7 @@ class TosaArgGen:
args_dict["acc_type"] = accum
return (arg_str.format(*arg_str_elems), args_dict)
+ more_tests = True
n = 0
for a in accum_dtypes:
for s in sorted(list(strides)):
@@ -2751,7 +2755,8 @@ class TosaArgGen:
get_arg_list_element(a, sNew, pNew, kNew, shape)
)
elif (
- n % sparsity == 0
+ more_tests
+ and n % sparsity == 0
# padding must not exceed the kernel size
and p[0] < k[0]
and p[1] < k[0]
@@ -2792,6 +2797,16 @@ class TosaArgGen:
arg_list.append(
get_arg_list_element(a, s, p, k, dp, shape)
)
+ if (
+ error_name
+ and len(arg_list) >= TosaArgGen.MAX_TESTS_ERROR_IFS
+ ):
+ # Found enough errors
+ logger.debug(
+ f"Skipping creating more pooling error tests for {error_name}"
+ )
+ more_tests = False
+
n += 1
# Now add data generator types