aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py112
1 files changed, 87 insertions, 25 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index ffa3683..4878708 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1976,6 +1976,9 @@ class TosaArgGen:
sparsity += 1
return sparsity
+ # Maximum number of error_if variants to produce
+ MAX_CONV_ERROR_IFS = 3
+
@staticmethod
def agConv(testGen, rng, opName, shapeList, dtypes, error_name=None):
# Used by CONV2D, CONV3D and DEPTHWISE_CONV2D
@@ -2017,17 +2020,60 @@ class TosaArgGen:
k_size *= ifm_shape[-1]
if not testGen.args.level8k:
- # Generate comprehensive argument lists
- # - except for named errors, which use specific invalid value(s)
- if error_name == ErrorIf.PadSmallerZero:
- p_vals = [rng.choice(range(-5, 0))]
+ if error_name in (
+ ErrorIf.PadSmallerZero,
+ ErrorIf.StrideSmallerOne,
+ ErrorIf.DilationSmallerOne,
+ ):
+ # Use specific invalid value(s)
+ if error_name == ErrorIf.PadSmallerZero:
+ # Create negative paddings but with positive opposite paddings
+ neg_pad = rng.choice(range(-5, 0))
+ p_vals = [neg_pad, abs(neg_pad)]
+ else:
+ p_vals = [0, 0]
+ if error_name == ErrorIf.StrideSmallerOne:
+ # Can't use stride=0, as it is used to derive output shape, as a divisor
+ s_vals = [rng.choice(range(-5, 0))]
+ else:
+ s_vals = [1]
+ if error_name == ErrorIf.DilationSmallerOne:
+ d_vals = [rng.choice(range(-5, 1))]
+ else:
+ d_vals = [1]
+ p = p_vals * k_rank
+ s = s_vals * k_rank
+ d = d_vals * k_rank
+
+ # Fix values to produce valid error_if
+ for index in range(k_rank):
+ pad_offset = index * 2
+ fixed = False
+ while not fixed:
+ partial = (
+ ifm_shape[index + 1]
+ - 1
+ + p[pad_offset]
+ + p[pad_offset + 1]
+ - (k_shape[index] - 1) * d[index]
+ )
+ remainder = partial % s[index]
+ if partial <= 0:
+ p[pad_offset + 1] += abs(partial) + 1
+ elif remainder:
+ # Stride will be negative for StrideSmallerOne
+ assert remainder < 0
+ p[pad_offset + 1] += abs(remainder)
+ else:
+ fixed = True
+ paddings = {tuple(p)}
+ strides = {tuple(s)}
+ dilations = {tuple(d)}
+ logger.debug(f"agConv: error pad={p} stride={s} dilation={d}")
else:
+ # Generate comprehensive argument lists
p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
- paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
- if error_name == ErrorIf.StrideSmallerOne:
- # Can't use stride=0, as it is used to derive output shape, as a divisor
- s_vals = [rng.choice(range(-5, 0))]
- else:
+ paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
# Stride must be greater than 1 to force non-integer error
startStride = (
1 if error_name != ErrorIf.ConvOutputShapeNonInteger else 2
@@ -2035,12 +2081,10 @@ class TosaArgGen:
s_vals = [
x for x in range(startStride, testGen.args.max_conv_stride + 1)
]
- strides = {x for x in itertools.product(*([s_vals] * k_rank))}
- if error_name == ErrorIf.DilationSmallerOne:
- d_vals = [rng.choice(range(-5, 1))]
- else:
d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
- dilations = {x for x in itertools.product(*([d_vals] * k_rank))}
+
+ strides = {x for x in itertools.product(*([s_vals] * k_rank))}
+ dilations = {x for x in itertools.product(*([d_vals] * k_rank))}
if not error_name and testGen.args.oversize:
# add some oversize argument values
@@ -2064,12 +2108,15 @@ class TosaArgGen:
)
max_dim_size = None
- # There are too many parameter combinations, so generate them sparsely,
- # very sparse for negative tests
- sparsity_factor = 2 if error_name else 120
- sparsity = TosaArgGen._calculate_sparsity(
- len(paddings) * len(strides) * len(dilations), sparsity_factor
- )
+ if error_name:
+ # Cycle through all error_if tests but we only keep the first few
+ sparsity = 1
+ else:
+ # There are too many parameter combinations, so generate them sparsely,
+ sparsity_factor = 120
+ sparsity = TosaArgGen._calculate_sparsity(
+ len(paddings) * len(strides) * len(dilations), sparsity_factor
+ )
else:
# Only test 8k levels boundaries
bigStride = testGen.TOSA_8K_LEVEL_MAX_STRIDE
@@ -2114,13 +2161,15 @@ class TosaArgGen:
# Currently allow all combinations that are reasonable size
sparsity = 1
+ more_tests = True
n = 0
for a in accum_dtypes:
for s in sorted(list(strides)):
for p in sorted(list(paddings)):
for d in sorted(list(dilations)):
if (
- n % sparsity == 0
+ more_tests
+ and n % sparsity == 0
# the padded shape must exceed the dilation * kernel to get a positive
# sized output shape
and (ifm_shape[1] - 1 + p[0] + p[1])
@@ -2199,6 +2248,15 @@ class TosaArgGen:
args_dict,
)
)
+ if (
+ error_name
+ and len(arg_list) >= TosaArgGen.MAX_CONV_ERROR_IFS
+ ):
+ # Found enough errors
+ logger.debug(
+ f"Skipping creating more conv error tests for {error_name}"
+ )
+ more_tests = False
n += 1
arg_list = TosaArgGen._add_data_generators(
@@ -2482,7 +2540,9 @@ class TosaArgGen:
pad_const_int = 0
pad_const_fp = rng.randNumberDType(dtype)
else:
- return []
+ assert error_name == ErrorIf.WrongInputType
+ pad_const_int = 0
+ pad_const_fp = 0
list_shape_pad_values = list(shape_pad_values)
# If we are producing tests for rank 6 or greater use sparsity
@@ -2523,7 +2583,9 @@ class TosaArgGen:
arg_list.append((name, args_dict))
if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
- logger.info(f"No ErrorIf test created for input shape: {shapeList[0]}")
+ logger.debug(
+ f"agPad: No PadSmallerZero ErrorIf test created for input shape: {shapeList[0]}"
+ )
arg_list = TosaArgGen._add_data_generators(
testGen,
@@ -3106,8 +3168,8 @@ class TosaArgGen:
# Find new shapes up to the number of permutations asked for
# This code is NOT fast. Fortunately, the numbers are fairly small.
for p in range(testGen.args.num_rand_permutations):
- # Rank from 1 to TOSA_TENSOR_MAX_RANK
- newRank = rng.randInt(1, (testGen.TOSA_TENSOR_MAX_RANK + 1))
+ # Rank from 1 to MAX_TENSOR_RANK
+ newRank = rng.randInt(1, (gtu.MAX_TENSOR_RANK + 1))
if len(factors) < newRank:
continue