aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py240
1 files changed, 127 insertions, 113 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 83487a1..ffa3683 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1990,7 +1990,12 @@ class TosaArgGen:
# Shape: (OFM channels), (KD), KH, KW, IFM channels
filter_shape = shapeList[1]
- accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
+ accum_dtypes = gtu.get_accum_dtypes_from_tgTypes(dtypes)
+
+ if error_name == ErrorIf.WrongAccumulatorType:
+ accum_dtypes = (
+ [DType.BF16] if gtu.dtypeIsFloat(dtypes[0]) else [DType.INT16]
+ )
# Op type checks
conv3d = opName.startswith("conv3d")
@@ -2110,88 +2115,91 @@ class TosaArgGen:
sparsity = 1
n = 0
- for s in sorted(list(strides)):
- for p in sorted(list(paddings)):
- for d in sorted(list(dilations)):
- if (
- n % sparsity == 0
- # the padded shape must exceed the dilation * kernel to get a positive
- # sized output shape
- and (ifm_shape[1] - 1 + p[0] + p[1]) > d[0] * (k_shape[0] - 1)
- and (ifm_shape[2] - 1 + p[2] + p[3]) > d[1] * (k_shape[1] - 1)
- and (
- k_rank < 3
- or (
- (ifm_shape[3] - 1 + p[4] + p[5])
- > d[2] * (k_shape[2] - 1)
- )
- )
- ):
- remainders = []
- outputs = []
- for index in range(k_rank):
- pad_offset = index * 2
- partial = (
- ifm_shape[index + 1]
- - 1
- + p[pad_offset]
- + p[pad_offset + 1]
- - (k_shape[index] - 1) * d[index]
- )
- remainders.append(partial % s[index])
- outputs.append((partial // s[index]) + 1)
-
+ for a in accum_dtypes:
+ for s in sorted(list(strides)):
+ for p in sorted(list(paddings)):
+ for d in sorted(list(dilations)):
if (
- # the parameters must produce integer exact output
- error_name != ErrorIf.ConvOutputShapeNonInteger
- and max(remainders) == 0
- ) or (
- error_name == ErrorIf.ConvOutputShapeNonInteger
- and max(remainders) > 0
+ n % sparsity == 0
+ # the padded shape must exceed the dilation * kernel to get a positive
+ # sized output shape
+ and (ifm_shape[1] - 1 + p[0] + p[1])
+ > d[0] * (k_shape[0] - 1)
+ and (ifm_shape[2] - 1 + p[2] + p[3])
+ > d[1] * (k_shape[1] - 1)
+ and (
+ k_rank < 3
+ or (
+ (ifm_shape[3] - 1 + p[4] + p[5])
+ > d[2] * (k_shape[2] - 1)
+ )
+ )
):
+ remainders = []
+ outputs = []
+ for index in range(k_rank):
+ pad_offset = index * 2
+ partial = (
+ ifm_shape[index + 1]
+ - 1
+ + p[pad_offset]
+ + p[pad_offset + 1]
+ - (k_shape[index] - 1) * d[index]
+ )
+ remainders.append(partial % s[index])
+ outputs.append((partial // s[index]) + 1)
+
if (
- max_dim_size is not None
- and max(outputs) >= max_dim_size
+ # the parameters must produce integer exact output
+ error_name != ErrorIf.ConvOutputShapeNonInteger
+ and max(remainders) == 0
+ ) or (
+ error_name == ErrorIf.ConvOutputShapeNonInteger
+ and max(remainders) > 0
):
- # Test will consume too much memory - skip it
- continue
-
- # Compliance - number of dot product calculations
- if depthwise:
- # N*OH*OW*C*M
- dots = gtu.product(
- (ifm_shape[0], *outputs, *filter_shape[2:])
- )
- else:
- # N*OH*OW*OC or N*OD*OH*OW*OC
- dots = gtu.product(
- (ifm_shape[0], *outputs, filter_shape[0])
- )
- args_dict = {
- "acc_type": accum_dtype,
- "stride": s,
- "pad": p,
- "dilation": d,
- "kernel": k_shape,
- "ks": k_size,
- "dot_products": dots,
- "shape": ifm_shape,
- }
-
- # Support for larger values than 9 needs different delimiter
- delim = "" if max(s + p + d) <= 9 else "x"
- arg_list.append(
- (
- "acc{}_st{}_pad{}_dilat{}".format(
- testGen.typeStr(accum_dtype),
- delim.join([str(x) for x in s]),
- delim.join([str(x) for x in p]),
- delim.join([str(x) for x in d]),
- ),
- args_dict,
+ if (
+ max_dim_size is not None
+ and max(outputs) >= max_dim_size
+ ):
+ # Test will consume too much memory - skip it
+ continue
+
+ # Compliance - number of dot product calculations
+ if depthwise:
+ # N*OH*OW*C*M
+ dots = gtu.product(
+ (ifm_shape[0], *outputs, *filter_shape[2:])
+ )
+ else:
+ # N*OH*OW*OC or N*OD*OH*OW*OC
+ dots = gtu.product(
+ (ifm_shape[0], *outputs, filter_shape[0])
+ )
+ args_dict = {
+ "acc_type": a,
+ "stride": s,
+ "pad": p,
+ "dilation": d,
+ "kernel": k_shape,
+ "ks": k_size,
+ "dot_products": dots,
+ "shape": ifm_shape,
+ }
+
+ # Support for larger values than 9 needs different delimiter
+ delim = "" if max(s + p + d) <= 9 else "x"
+ arg_list.append(
+ (
+ "acc{}_st{}_pad{}_dilat{}".format(
+ testGen.typeStr(a),
+ delim.join([str(x) for x in s]),
+ delim.join([str(x) for x in p]),
+ delim.join([str(x) for x in d]),
+ ),
+ args_dict,
+ )
)
- )
- n += 1
+ n += 1
arg_list = TosaArgGen._add_data_generators(
testGen,
@@ -2216,7 +2224,7 @@ class TosaArgGen:
# Pick some potentially correct output dtype if input type is incorrect
accum_dtype = DType.INT32
else:
- accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
+ accum_dtype = dtypes[-1] # use output dtype as accum_dtype
# Set up compliance info
args_dict = {
@@ -2303,7 +2311,12 @@ class TosaArgGen:
ifm_shape = shapeList[0]
filter_shape = shapeList[1]
- accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
+ accum_dtypes = gtu.get_accum_dtypes_from_tgTypes(dtypes)
+
+ if error_name == ErrorIf.WrongAccumulatorType:
+ accum_dtypes = (
+ [DType.BF16] if gtu.dtypeIsFloat(dtypes[0]) else [DType.INT16]
+ )
# Must be rank 4
if error_name != ErrorIf.WrongRank:
@@ -2400,41 +2413,42 @@ class TosaArgGen:
sparsity = 1
n = 0
- for s in sorted(list(strides)):
- for p in sorted(list(paddings)):
- if n % sparsity == 0:
- # Determine the output shape
- oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
- ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
- os = [ifm_shape[0], oh, ow, filter_shape[0]]
-
- # N*OH*OW*OC
- dots = gtu.product((ifm_shape[0], oh, ow, filter_shape[0]))
- args_dict = {
- "acc_type": accum_dtype,
- "stride": s,
- "pad": p,
- "kernel": k_shape,
- "ks": k_size,
- "dot_products": dots,
- "shape": ifm_shape,
- "out_shape": os,
- }
+ for a in accum_dtypes:
+ for s in sorted(list(strides)):
+ for p in sorted(list(paddings)):
+ if n % sparsity == 0:
+ # Determine the output shape
+ oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
+ ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
+ os = [ifm_shape[0], oh, ow, filter_shape[0]]
+
+ # N*OH*OW*OC
+ dots = gtu.product((ifm_shape[0], oh, ow, filter_shape[0]))
+ args_dict = {
+ "acc_type": a,
+ "stride": s,
+ "pad": p,
+ "kernel": k_shape,
+ "ks": k_size,
+ "dot_products": dots,
+ "shape": ifm_shape,
+ "out_shape": os,
+ }
- # Support for larger values than 9 needs different delimiter
- delim = "" if max(s + p) <= 9 else "x"
- arg_list.append(
- (
- "acc{}_st{}_pad{}_os{}".format(
- testGen.typeStr(accum_dtype),
- delim.join([str(x) for x in s]),
- delim.join([str(x) for x in p]),
- "x".join([str(x) for x in os]),
- ),
- args_dict,
+ # Support for larger values than 9 needs different delimiter
+ delim = "" if max(s + p) <= 9 else "x"
+ arg_list.append(
+ (
+ "acc{}_st{}_pad{}_os{}".format(
+ testGen.typeStr(a),
+ delim.join([str(x) for x in s]),
+ delim.join([str(x) for x in p]),
+ "x".join([str(x) for x in os]),
+ ),
+ args_dict,
+ )
)
- )
- n += 1
+ n += 1
arg_list = TosaArgGen._add_data_generators(
testGen,