aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py230
1 files changed, 163 insertions, 67 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index a65e220..69968d3 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -2,10 +2,13 @@
# SPDX-License-Identifier: Apache-2.0
import itertools
import math
+import warnings
import numpy as np
from generator.tosa_error_if import ErrorIf
from generator.tosa_error_if import TosaErrorIfArgGen
+from generator.tosa_utils import get_accum_dtype_from_tgTypes
+from generator.tosa_utils import get_wrong_output_type
from generator.tosa_utils import MAX_RESIZE_DIMENSION
from serializer.tosa_serializer import DTypeNames
from tosa.DType import DType
@@ -773,7 +776,7 @@ class TosaTensorValuesGen:
), "Op.MUL must have 2 placeholders, 0 consts"
tens = []
- if dtypeList[0] == DType.FLOAT:
+ if dtypeList[0] in (DType.FP16, DType.FLOAT):
tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
else:
placeholders = []
@@ -982,7 +985,7 @@ class TosaArgGen:
return axes
@staticmethod
- def agConv(testGen, opName, shapeList, dtype, error_name=None):
+ def agConv(testGen, opName, shapeList, dtypes, error_name=None):
arg_list = []
ifm_shape = shapeList[0]
@@ -990,6 +993,8 @@ class TosaArgGen:
# determine the kernel shape from operator name (e.g. "conv2d_3x3" => [3,3])
k = [int(x) for x in opName.split("_")[-1].split("x")]
+ accum_dtype = get_accum_dtype_from_tgTypes(dtypes)
+
# Check the rank
rank = 5 if opName.startswith("conv3d") else 4
if error_name != ErrorIf.WrongRank:
@@ -1089,12 +1094,13 @@ class TosaArgGen:
):
arg_list.append(
(
- "st{}_pad{}_dilat{}".format(
+ "acc{}_st{}_pad{}_dilat{}".format(
+ testGen.typeStr(accum_dtype),
"".join([str(x) for x in s]),
"".join([str(x) for x in p]),
"".join([str(x) for x in d]),
),
- [s, p, d],
+ [accum_dtype, s, p, d],
)
)
n += 1
@@ -1102,12 +1108,55 @@ class TosaArgGen:
return arg_list
@staticmethod
- def agTransposeConv2D(testGen, opName, shapeList, dtype, error_name=None):
+ def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None):
+
+ if isinstance(dtypes, list) or isinstance(dtypes, tuple):
+ input_dtype = dtypes[0]
+ else:
+ input_dtype = dtypes
+
+ if error_name == ErrorIf.WrongOutputType:
+ accum_dtype = get_wrong_output_type(opName, testGen.rng, input_dtype)
+ elif error_name == ErrorIf.WrongInputType:
+ # Pick some potentially correct output dtype if input type is incorrect
+ accum_dtype = DType.INT32
+ else:
+ accum_dtype = get_accum_dtype_from_tgTypes(dtypes)
+
+ return [(f"acc{testGen.typeStr(accum_dtype)}", [accum_dtype])]
+
+ @staticmethod
+ def agMatMul(testGen, opName, shapeList, dtype, error_name=None):
+ # Get valid accumulate type(s)
+ if dtype == DType.INT8:
+ accum_dtypes = [DType.INT32]
+ elif dtype == DType.INT16:
+ accum_dtypes = [DType.INT48]
+ elif dtype == DType.FP16:
+ accum_dtypes = [DType.FP16, DType.FLOAT]
+ elif dtype == DType.FLOAT:
+ accum_dtypes = [DType.FLOAT]
+ elif error_name is None:
+ assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
+
+ if error_name == ErrorIf.WrongOutputType:
+ # Get incorrect output dtype for ErrorIf case
+ accum_dtypes = [get_wrong_output_type(opName, testGen.rng, dtype)]
+ elif error_name == ErrorIf.WrongInputType:
+ # Pick some potentially correct output dtype if input type is incorrect
+ accum_dtypes = [DType.INT32]
+
+ return [(f"acc{testGen.typeStr(a)}", [a]) for a in accum_dtypes]
+
+ @staticmethod
+ def agTransposeConv2D(testGen, opName, shapeList, dtypes, error_name=None):
arg_list = []
ifm_shape = shapeList[0]
filter_shape = shapeList[1]
+ accum_dtype = get_accum_dtype_from_tgTypes(dtypes)
+
# Must be rank 4
if error_name != ErrorIf.WrongRank:
assert len(ifm_shape) == 4
@@ -1169,12 +1218,13 @@ class TosaArgGen:
os = [ifm_shape[0], oh, ow, filter_shape[0]]
arg_list.append(
(
- "st{}_pad{}_os{}".format(
+ "acc{}_st{}_pad{}_os{}".format(
+ testGen.typeStr(accum_dtype),
"".join([str(x) for x in s]),
"".join([str(x) for x in p]),
"x".join([str(x) for x in os]),
),
- [s, p, os],
+ [accum_dtype, s, p, os],
)
)
n += 1
@@ -1199,18 +1249,38 @@ class TosaArgGen:
if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
pad_const_int = testGen.getRandNumberDType(dtype)
pad_const_fp = 0
- elif dtype == DType.FLOAT:
+ elif dtype in (DType.FP16, DType.FLOAT):
pad_const_int = 0
pad_const_fp = testGen.getRandNumberDType(dtype)
else:
return []
for paddings in shape_pad_values:
- name = "pad"
- for r in range(rank):
- before, after = paddings[r]
- name = f"{name}{before}{after}"
- arg_list.append((name, [np.array(paddings), pad_const_int, pad_const_fp]))
+ paddings = list(paddings)
+ args_valid = True
+
+ if error_name == ErrorIf.PadSmallerZero:
+ # Prevent negative output shapes while ensuring still testing for negative padding
+ for i in range(rank):
+ dim_after_padding = (
+ paddings[i][0] + paddings[i][1] + shapeList[0][i]
+ )
+ if dim_after_padding < 1:
+ paddings[i] = (0, 0)
+ if all([p > -1 for p in paddings[i]]):
+ args_valid = False
+
+ if args_valid:
+ name = "pad"
+ for r in range(rank):
+ before, after = paddings[r]
+ name = f"{name}{before}{after}"
+ arg_list.append(
+ (name, [np.array(paddings), pad_const_int, pad_const_fp])
+ )
+
+ if error_name == ErrorIf.PadSmallerZero and len(arg_list) == 0:
+ warnings.warn(f"No ErrorIf test created for input shape: {shapeList[0]}")
return arg_list
@@ -1232,6 +1302,21 @@ class TosaArgGen:
k_vals = [x for x in range(2, testGen.args.max_pooling_kernel + 1)]
kernels = {x for x in itertools.product(*([k_vals] * 2))}
+ if opName == "max_pool2d":
+ accum_dtypes = [None] # max_pool has no accumulate dtype
+ elif dtype == DType.INT8 or dtype == DType.INT16:
+ accum_dtypes = [DType.INT32]
+ elif dtype == DType.FP16:
+ accum_dtypes = [DType.FP16, DType.FLOAT]
+ elif dtype == DType.FLOAT:
+ accum_dtypes = [DType.FLOAT]
+ elif error_name is None:
+ assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
+ else:
+ # Set to something for the ErrorIf case which has
+ # incorrect input data-type
+ accum_dtypes = [DType.INT32]
+
if testGen.args.oversize:
# add some oversize argument values
bigStride = 7
@@ -1252,63 +1337,70 @@ class TosaArgGen:
sparsity_factor = 2 if error_name else 500
sparsity = len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
+ arg_str = (
+ "acc{}_st{}_kern{}_pad{}"
+ if accum_dtypes[0] is not None
+ else "st{}_kern{}_pad{}"
+ )
+
+ def get_arg_list_element(accum, stride, pad, kern):
+ # Return tuple containing the formatted argument string and
+ # the corresponding argument values
+ arg_str_elems = [
+ "".join([str(x) for x in stride]),
+ "".join([str(x) for x in kern]),
+ "".join([str(x) for x in pad]),
+ ]
+ # Note: different order to string
+ arg_val_elems = [stride, pad, kern]
+
+ if accum is not None:
+ arg_str_elems.insert(0, testGen.typeStr(accum))
+ arg_val_elems.insert(0, accum)
+ return (arg_str.format(*arg_str_elems), arg_val_elems)
+
n = 0
- for s in sorted(list(strides)):
- for p in sorted(list(paddings)):
- for k in sorted(list(kernels)):
- if error_name in [
- ErrorIf.StrideSmallerOne,
- ErrorIf.KernelSmallerOne,
- ErrorIf.PadSmallerZero,
- ErrorIf.PadLargerEqualKernel,
- ]:
- sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
- testGen, error_name, s, p, k
- )
- if None not in [sNew, pNew, kNew] and n % sparsity == 0:
- arg_list.append(
- (
- "st{}_kern{}_pad{}".format(
- "".join([str(x) for x in sNew]),
- "".join([str(x) for x in kNew]),
- "".join([str(x) for x in pNew]),
- ),
- [sNew, pNew, kNew],
- )
+ for a in accum_dtypes:
+ for s in sorted(list(strides)):
+ for p in sorted(list(paddings)):
+ for k in sorted(list(kernels)):
+ if error_name in [
+ ErrorIf.StrideSmallerOne,
+ ErrorIf.KernelSmallerOne,
+ ErrorIf.PadSmallerZero,
+ ErrorIf.PadLargerEqualKernel,
+ ]:
+ sNew, pNew, kNew = TosaErrorIfArgGen.eiPoolingErrorIf(
+ testGen, error_name, s, p, k
)
- elif (
- n % sparsity == 0
- # padding must not exceed the kernel size
- and p[0] < k[0]
- and p[1] < k[0]
- and p[2] < k[1]
- and p[3] < k[1]
- # the padded shape must exceed the kernel size
- and (shape[1] + p[0] + p[1]) > k[0]
- and (shape[2] + p[2] + p[3]) > k[1]
- ):
- remainder_h = (shape[1] + p[0] + p[1] - k[0]) % s[0]
- remainder_w = (shape[2] + p[2] + p[3] - k[1]) % s[1]
- if (
- # the parameters must produce integer exact output
- error_name != ErrorIf.PoolingOutputShapeNonInteger
- and remainder_h == 0
- and remainder_w == 0
- ) or (
- error_name == ErrorIf.PoolingOutputShapeNonInteger
- and (remainder_h != 0 or remainder_w != 0)
+ if None not in [sNew, pNew, kNew] and n % sparsity == 0:
+ arg_vals = [a, sNew, pNew, kNew]
+ arg_list.append(get_arg_list_element(*arg_vals))
+ elif (
+ n % sparsity == 0
+ # padding must not exceed the kernel size
+ and p[0] < k[0]
+ and p[1] < k[0]
+ and p[2] < k[1]
+ and p[3] < k[1]
+ # the padded shape must exceed the kernel size
+ and (shape[1] + p[0] + p[1]) > k[0]
+ and (shape[2] + p[2] + p[3]) > k[1]
):
- arg_list.append(
- (
- "st{}_kern{}_pad{}".format(
- "".join([str(x) for x in s]),
- "".join([str(x) for x in k]),
- "".join([str(x) for x in p]),
- ),
- [s, p, k],
- )
- )
- n += 1
+ remainder_h = (shape[1] + p[0] + p[1] - k[0]) % s[0]
+ remainder_w = (shape[2] + p[2] + p[3] - k[1]) % s[1]
+ if (
+ # the parameters must produce integer exact output
+ error_name != ErrorIf.PoolingOutputShapeNonInteger
+ and remainder_h == 0
+ and remainder_w == 0
+ ) or (
+ error_name == ErrorIf.PoolingOutputShapeNonInteger
+ and (remainder_h != 0 or remainder_w != 0)
+ ):
+ arg_vals = [a, s, p, k]
+ arg_list.append(get_arg_list_element(*arg_vals))
+ n += 1
return arg_list
@@ -1327,6 +1419,8 @@ class TosaArgGen:
dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
elif inDtype == DType.BOOL:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
+ elif inDtype == DType.FP16:
+ dtypeList = [DType.INT8, DType.INT16, DType.INT32]
elif inDtype == DType.FLOAT:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
elif error_name == ErrorIf.WrongInputType:
@@ -1734,6 +1828,8 @@ class TosaArgGen:
outputDTypeList = [DType.INT32]
elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
outputDTypeList = [DType.INT48]
+ elif dtype == DType.FP16:
+ outputDTypeList = [DType.FP16]
elif dtype == DType.FLOAT:
outputDTypeList = [DType.FLOAT]
elif error_name == ErrorIf.WrongInputType: