aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2024-01-10 14:16:39 +0000
committerJeremy Johnson <jeremy.johnson@arm.com>2024-01-30 11:50:54 +0000
commit95a6710ffb8cadcb8658a967ab29cac1bffad930 (patch)
tree6320e5d34441626b1e7a956886bd1fee88dbf4a1 /verif/generator/tosa_arg_gen.py
parent4f931307a6319d9d99b3afce4ca6e1cd30d77f01 (diff)
downloadreference_model-95a6710ffb8cadcb8658a967ab29cac1bffad930.tar.gz
Main Compliance: TRANSPOSE_CONV2D support
Update data generator for main compliance values. Add test generation support. Fixed test set by including large 65k tests that were missing. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I8668c774e01c17e5d999aadf99c317e2dd893857
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py25
1 files changed, 24 insertions, 1 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 8501caa..91d2d62 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -2186,6 +2186,8 @@ class TosaArgGen:
assert len(filter_shape) == 4
k_shape = tuple(filter_shape[1:3])
+ # compliance size - KS
+ k_size = gtu.product((*k_shape, ifm_shape[3]))
if not testGen.args.level8k:
# Generate comprehensive argument lists
@@ -2283,6 +2285,19 @@ class TosaArgGen:
ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
os = [ifm_shape[0], oh, ow, filter_shape[0]]
+ # N*OH*OW*OC
+ dots = gtu.product((ifm_shape[0], oh, ow, filter_shape[0]))
+ args_dict = {
+ "acc_type": accum_dtype,
+ "stride": s,
+ "pad": p,
+ "kernel": k_shape,
+ "ks": k_size,
+ "dot_products": dots,
+ "shape": ifm_shape,
+ "out_shape": os,
+ }
+
# Support for larger values than 9 needs different delimiter
delim = "" if max(s + p) <= 9 else "x"
arg_list.append(
@@ -2293,11 +2308,19 @@ class TosaArgGen:
delim.join([str(x) for x in p]),
"x".join([str(x) for x in os]),
),
- [accum_dtype, s, p, os],
+ args_dict,
)
)
n += 1
+ arg_list = TosaArgGen._add_data_generators(
+ testGen,
+ opName,
+ dtypes[0],
+ arg_list,
+ error_name,
+ )
+ # Return list of tuples: (arg_str, args_dict)
return arg_list
@staticmethod