aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py36
1 files changed, 22 insertions, 14 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 3a85961..79d4e78 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -2545,13 +2545,17 @@ class TosaArgGen:
def agPad(testGen, rng, opName, shapeList, dtype, error_name=None):
rank = len(shapeList[0])
- # Exhaustively test combinations of padding on each side of each dimension
- # - the range of padding values is defined by pad_min and pad_max
- # - for padding >9, the name format needs to be more distinctive
- pad_min, pad_max = 0, 1
- pad_values = [x for x in range(pad_min, pad_max + 1)]
- if error_name == ErrorIf.PadSmallerZero:
+ if error_name is None and testGen.args.oversize:
+ pad_values = [6, 7, 10, 13]
+ elif error_name == ErrorIf.PadSmallerZero:
pad_values = [x for x in range(-2, 0)]
+ else:
+ # Exhaustively test combinations of padding on each side of each dimension
+ # - the range of padding values is defined by pad_min and pad_max
+ pad_min, pad_max = 0, 1
+ pad_values = [x for x in range(pad_min, pad_max + 1)]
+
+ # Calculate pad combinations
axis_pad_values = [x for x in itertools.product(pad_values, pad_values)]
shape_pad_values = itertools.product(*([axis_pad_values] * rank))
@@ -2593,15 +2597,19 @@ class TosaArgGen:
if all([p > -1 for p in paddings[i]]):
args_valid = False
if args_valid and n % sparsity == 0:
- name = "pad"
+ # Work out name
+ pad_list = []
for r in range(rank):
- before, after = paddings[r]
- name = f"{name}{before}{after}"
- args_dict = {
- "pad": np.array(paddings),
- "pad_const_int": pad_const_int,
- "pad_const_fp": pad_const_fp,
- }
+ pad_list.extend(paddings[r])
+
+ delim = "" if max(pad_list) <= 9 else "x"
+ name = "pad{}".format(delim.join([str(x) for x in pad_list]))
+
+ args_dict = {
+ "pad": np.array(paddings),
+ "pad_const_int": pad_const_int,
+ "pad_const_fp": pad_const_fp,
+ }
arg_list.append((name, args_dict))
if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0: