aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLes Bell <les.bell@arm.com>2021-11-09 14:42:14 +0000
committerLes Bell <les.bell@arm.com>2021-11-18 08:12:45 +0000
commit0e027d43bf9964fb4c7e6187ccd3e8dfbcf9522f (patch)
tree6b13b8de5337285d99ed85bd3ba329ba1ef704c8
parent3aea3c36634c8ddfd65b17c4466caf154a721b2a (diff)
downloadreference_model-0e027d43bf9964fb4c7e6187ccd3e8dfbcf9522f.tar.gz
Convolutions ERROR_IF tests
Signed-off-by: Les Bell <les.bell@arm.com> Change-Id: I68a13e1b337b1afc2ab5e0edcffda2b4b0cecdda
-rw-r--r--verif/tosa_error_if.py2
-rw-r--r--verif/tosa_test_gen.py744
2 files changed, 554 insertions, 192 deletions
diff --git a/verif/tosa_error_if.py b/verif/tosa_error_if.py
index eb67ea8..7c162be 100644
--- a/verif/tosa_error_if.py
+++ b/verif/tosa_error_if.py
@@ -41,6 +41,7 @@ class ErrorIf(object):
ShapeOfAxisNotOne = "ShapeOfAxisNotOne"
KernelSmallerOne = "KernelSmallerOne"
StrideSmallerOne = "StrideSmallerOne"
+ DilationSmallerOne = "DilationSmallerOne"
PadSmallerZero = "PadSmallerZero"
PadLargerEqualKernel = "PadLargerEqualKernel"
PoolingOutputShapeMismatch = "PoolingOutputShapeMismatch"
@@ -68,4 +69,3 @@ class ErrorIf(object):
InputListBodyGraphOutputMismatch = "InputListBodyGraphOutputMismatch"
CondGraphOutputNotMatchingBool = "CondGraphOutputNotMatchingBool"
-
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index db44328..1bd1b5a 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -45,17 +45,82 @@ import tosa
from tosa_error_if import ErrorIf
# Convenience variables to the flatc-generated types that should be enums, but aren't
-DType = tosa.DType.DType()
-Op = tosa.Op.Op()
-ResizeMode = tosa.ResizeMode.ResizeMode()
+from tosa.DType import DType
+from tosa.Op import Op
+from tosa.ResizeMode import ResizeMode
+def valueToName(item, value):
+ """Get the name of an attribute with the given value.
+
+ This convenience function is needed to print meaningful names for
+ the values of the tosa.Op.Op and tosa.DType.DType classes.
+ This would not be necessary if they were subclasses of Enum, or
+ IntEnum, which, sadly, they are not.
+
+ Args:
+ item: The class, or object, to find the value in
+ value: The value to find
+
+ Example, to get the name of a DType value:
+
+ name = valueToName(DType, DType.INT8) # returns 'INT8'
+ name = valueToName(DType, 4) # returns 'INT8'
+
+ Returns:
+ The name of the first attribute found with a matching value,
+
+ Raises:
+ ValueError if the value is not found
+ """
+ for attr in dir(item):
+ if getattr(item, attr) == value:
+ return attr
+ raise ValueError(f'value ({value}) not found')
+
+def allDTypes(*, excludes=None):
+ """Get a set of all DType values, optionally excluding some values.
+
+ This convenience function is needed to provide a sequence of DType values.
+ This would be much easier if DType was a subclass of Enum, or IntEnum,
+ as we could then iterate over the values directly, instead of using
+ dir() to find the attributes and then check if they are what we want.
+
+ Args:
+ excludes: iterable of DTYPE values (e.g. [DType.INT8, DType.BOOL])
+
+ Returns:
+ A set of DType values
+ """
+ excludes = () if not excludes else excludes
+ return {getattr(DType, t) for t in dir(DType)
+ if not callable(getattr(DType, t)) and not t.startswith('__')
+ and getattr(DType, t) not in excludes}
+
+def usableDTypes(*, excludes=None):
+ """Get a set of usable DType values, optionally excluding some values.
+
+ Excludes (DType.UNKNOWN, DType.UINT8) in addition to the excludes
+ specified by the caller, as the serializer lib does not support them.
+ If you wish to include 'UNKNOWN' or 'UINT8' use allDTypes instead.
+
+ Args:
+ excludes: iterable of DType values (e.g. [DType.INT8, DType.BOOL])
+
+ Returns:
+ A set of DType values
+ """
+ omit = {DType.UNKNOWN, DType.UINT8}
+ omit.update(excludes if excludes else ())
+ return allDTypes(excludes=omit)
+
def product(shape):
value = 1
for n in shape:
value *= n
return value
+
class TosaQuantGen:
"""QuantizedInfo random generator helper functions. Specify with 'qgen': in the operator defintion"""
@@ -299,7 +364,8 @@ class TosaTensorGen:
def tgConv2D(testGen, op, rank, error_name=None):
pl, const = op["operands"]
- assert rank == 4
+ if error_name != ErrorIf.WrongRank:
+ assert rank == 4
# IFM dimensions are NHWC
ifm_shape = testGen.makeShape(rank)
@@ -308,6 +374,10 @@ class TosaTensorGen:
if testGen.args.max_batch_size:
ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
+ # Constrict the overall size of the shape when creating ERROR_IF tests
+ if error_name:
+ ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape, max_dim=24, max_items=10000)
+
# Get the filter height/width from the operator parameters
filter_hw = op["filter"]
@@ -326,7 +396,8 @@ class TosaTensorGen:
def tgConv3D(testGen, op, rank, error_name=None):
pl, const = op["operands"]
- assert rank == 5
+ if error_name != ErrorIf.WrongRank:
+ assert rank == 5
# IFM dimensions are NDHWC
ifm_shape = testGen.makeShape(rank)
@@ -335,6 +406,10 @@ class TosaTensorGen:
if testGen.args.max_batch_size:
ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
+ # Constrict the overall size of the shape when creating ERROR_IF tests
+ if error_name:
+ ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape, max_dim=24, max_items=10000)
+
# Get the filter depth/height/width from the operator parameters
filter_dhw = op["filter"]
@@ -355,7 +430,8 @@ class TosaTensorGen:
def tgTransposeConv2D(testGen, op, rank, error_name=None):
pl, const = op["operands"]
- assert rank == 4
+ if error_name != ErrorIf.WrongRank:
+ assert rank == 4
# IFM dimensions are NHWC
ifm_shape = testGen.makeShape(rank)
@@ -364,6 +440,10 @@ class TosaTensorGen:
if testGen.args.max_batch_size:
ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
+ # Constrict the overall size of the shape when creating ERROR_IF tests
+ if error_name:
+ ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape, max_dim=24, max_items=10000)
+
# Get the filter height/width from the operator parameters
filter_hw = op["filter"]
@@ -382,7 +462,8 @@ class TosaTensorGen:
def tgDepthwiseConv2D(testGen, op, rank, error_name=None):
pl, const = op["operands"]
- assert rank == 4
+ if error_name != ErrorIf.WrongRank:
+ assert rank == 4
assert pl == 1 and const == 2
# IFM dimensions are NHWC
@@ -392,6 +473,10 @@ class TosaTensorGen:
if testGen.args.max_batch_size:
ifm_shape[0] = (ifm_shape[0] % testGen.args.max_batch_size) + 1
+ # Constrict the overall size of the shape when creating ERROR_IF tests
+ if error_name:
+ ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape, max_dim=24, max_items=10000)
+
# Get the filter height/width from the operator parameters
# Filter is KH, HW, C, M
filter_hw = op["filter"]
@@ -421,7 +506,7 @@ class TosaTensorGen:
# Constrict the overall size of the shape when creating ERROR_IF tests
if error_name:
- shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)
+ input_shape = TosaErrorIfArgGen.eiRestrictDimensions(input_shape)
filter_oc = testGen.rng.integers(
low=testGen.args.tensor_shape_range[0],
@@ -446,7 +531,7 @@ class TosaTensorGen:
# Constrict the overall size of the shape when creating ERROR_IF tests
if error_name:
- shape = TosaErrorIfArgGen.eiRestrictDimensions(shape)
+ a_shape = TosaErrorIfArgGen.eiRestrictDimensions(a_shape)
# Get a random number for b_oc even if target shape is defined
b_oc = np.int32(
@@ -576,36 +661,54 @@ class TosaArgGen:
# Check the rank
rank = 5 if opName.startswith("conv3d") else 4
- assert len(ifm_shape) == rank
- assert len(filter_shape) == rank
+ if error_name != ErrorIf.WrongRank:
+ assert len(ifm_shape) == rank
+ assert len(filter_shape) == rank
# kernel rank omits batch and channels
k_rank = rank - 2
+ assert len(k) == k_rank
# Generate comprehensive argument lists
- p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
+ # - except for named errors, which use specific invalid value(s)
+ if error_name == ErrorIf.PadSmallerZero:
+ p_vals = [testGen.rng.choice(range(-5, 0))]
+ else:
+ p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
paddings = {x for x in itertools.product(*([p_vals] * k_rank * 2))}
- s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
+ if error_name == ErrorIf.StrideSmallerOne:
+ # Can't use stride=0, as it is used to derive output shape, as a divisor
+ s_vals = [testGen.rng.choice(range(-5, 0))]
+ else:
+ s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
strides = {x for x in itertools.product(*([s_vals] * k_rank))}
- d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
+ if error_name == ErrorIf.DilationSmallerOne:
+ d_vals = [testGen.rng.choice(range(-5, 1))]
+ else:
+ d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
dilations = {x for x in itertools.product(*([d_vals] * k_rank))}
- # add some oversize argument values
- if max(ifm_shape) < 64:
- bigPadding = 9
- paddings.update({x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))})
- bigStride = 8
- strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
- bigDilation = 7
- dilations.update({x for x in itertools.product(*([[1, bigDilation]] * k_rank))})
-
- # There are too many parameter combinations, so generate them sparsely
- # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
- sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
+ if not error_name:
+ # add some oversize argument values
+ if max(ifm_shape) < 64:
+ bigPadding = 9
+ paddings.update({x for x in itertools.product(*([[0, bigPadding]] * (k_rank * 2)))})
+ bigStride = 8
+ strides.update({x for x in itertools.product(*([[1, bigStride]] * k_rank))})
+ bigDilation = 7
+ dilations.update({x for x in itertools.product(*([[1, bigDilation]] * k_rank))})
+
+ # There are too many parameter combinations, so generate them sparsely,
+ # very sparse for negative tests
+ sparsity_factor = 2 if error_name else 100
+ sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1
+ # If there are only a small number of tests, just select them all
if sparsity < 13:
sparsity = 1
+ # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
sparsity += 1
+
n = 0
for s in sorted(list(strides)):
for p in sorted(list(paddings)):
@@ -643,33 +746,50 @@ class TosaArgGen:
filter_shape = shapeList[1]
# Must be rank 4
- assert len(ifm_shape) == 4
- assert len(filter_shape) == 4
+ if error_name != ErrorIf.WrongRank:
+ assert len(ifm_shape) == 4
+ assert len(filter_shape) == 4
# Generate comprehensive argument lists
- p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
+ # - except for named errors, which use specific invalid value(s)
+ if error_name == ErrorIf.PadSmallerZero:
+ p_vals = [testGen.rng.choice(range(-5, 0))]
+ else:
+ p_vals = [x for x in range(0, testGen.args.max_conv_padding + 1)]
paddings = {x for x in itertools.product(*([p_vals] * 2))}
- s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
+ if error_name == ErrorIf.StrideSmallerOne:
+ # Can't use stride=0, as it is used to derive output shape, as a divisor
+ s_vals = [testGen.rng.choice(range(-5, 0))]
+ else:
+ s_vals = [x for x in range(1, testGen.args.max_conv_stride + 1)]
strides = {x for x in itertools.product(*([s_vals] * 2))}
- d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
+ if error_name == ErrorIf.DilationSmallerOne:
+ d_vals = [testGen.rng.choice(range(-5, 1))]
+ else:
+ d_vals = [x for x in range(1, testGen.args.max_conv_dilation + 1)]
dilations = {x for x in itertools.product(*([d_vals] * 2))}
- # add some oversize argument values
- if max(ifm_shape) < 64:
- bigPadding = 9
- paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 2))})
- bigStride = 8
- strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
- bigDilation = 7
- dilations.update({x for x in itertools.product(*([[1, bigDilation]] * 2))})
-
- # There are too many parameter combinations, so generate them sparsely
- # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
- sparsity = len(paddings) * len(strides) * len(dilations) // 100 + 1
+ if not error_name:
+ # add some oversize argument values
+ if max(ifm_shape) < 64:
+ bigPadding = 9
+ paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 2))})
+ bigStride = 8
+ strides.update({x for x in itertools.product(*([[1, bigStride]] * 2))})
+ bigDilation = 7
+ dilations.update({x for x in itertools.product(*([[1, bigDilation]] * 2))})
+
+ # There are too many parameter combinations, so generate them sparsely,
+ # very sparse for negative tests
+ sparsity_factor = 2 if error_name else 100
+ sparsity = len(paddings) * len(strides) * len(dilations) // sparsity_factor + 1
+ # If there are only a small number of tests, just select them all
if sparsity < 13:
sparsity = 1
+ # To get a variety of parameter combinations sparsity should not be a multiple of 2, 3 or 5
while sparsity % 2 == 0 or sparsity % 3 == 0 or sparsity % 5 == 0:
sparsity += 1
+
n = 0
for s in sorted(list(strides)):
for p in sorted(list(paddings)):
@@ -763,8 +883,11 @@ class TosaArgGen:
bigPadding = bigKernel - 1
paddings.update({x for x in itertools.product(*([[0, bigPadding]] * 4))})
- # There are too many parameter combinations, so generate them sparsely
- sparsity = len(paddings) * len(strides) * len(kernels) // 500 + 1
+ # There are too many parameter combinations, so generate them sparsely,
+ # very sparse for negative tests
+ sparsity_factor = 2 if error_name else 500
+ sparsity = len(paddings) * len(strides) * len(kernels) // sparsity_factor + 1
+
n = 0
for s in sorted(list(strides)):
for p in sorted(list(paddings)):
@@ -1414,7 +1537,7 @@ class TosaErrorIfArgGen:
input_list.append('eiDummyInput')
else:
input_list = input_list[:-1]
- if error_name == "WrongOutputList":
+ elif error_name == "WrongOutputList":
add_output = testGen.rng.choice([True, False])
if add_output:
output_list.append('eiDummyOutput')
@@ -1477,67 +1600,60 @@ class TosaErrorValidator:
@staticmethod
def evValidateErrorIfs(serializer, validator_fcns, error_name, **kwargs):
# Check ERROR_IF statements
-
for val_fcn in validator_fcns:
val_result = val_fcn(True, **kwargs)
-
validator_name = val_result['error_name']
error_result = val_result['error_result']
error_reason = val_result['error_reason']
- if error_result:
- if error_name == validator_name:
- serializer.setExpectedReturnCode(2, error_reason)
- else:
- print(f"Multiple ERROR_IF checks hit \nError required: {error_name}, Error_produced: {validator_name}")
- return None # Return None to delete test if wrong ERROR_IF is hit
- else:
- if error_name == validator_name:
- print(f"No ERROR_IF hit for {error_name}")
- return None
+ # expect an error IFF the error_name and validator_name match
+ expected_result = error_result == (error_name == validator_name)
+
+ if expected_result and error_result:
+ serializer.setExpectedReturnCode(2, error_reason)
+ elif error_result: # and not expected_result
+ print(f"Unexpected ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
+ f" Expected: {error_name}, Got: {validator_name}")
+ elif not expected_result: # and not error_result
+ print(f"Missed ERROR_IF: Op: {valueToName(Op, kwargs['op']['op'])}"
+ f" Expected: {error_name}")
+
+ if not expected_result:
+ for k, v in sorted(kwargs.items()):
+ if k != 'op':
+ if k.endswith('dtype'):
+ v = valueToName(DType, v)
+ print(f' {k} = {v}')
@staticmethod
def evWrongInputType(check=False, **kwargs):
- all_dtypes = {DType.BOOL, DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT}
+ error_result = False
# Find the unsupported input data types
- assert 'op' in kwargs
op = kwargs['op']
input_dtypes = op['types']
-
allowed_input_dtypes = {t[0] if isinstance(t, list) else t for t in input_dtypes}
- wrong_input_dtypes = list(all_dtypes - allowed_input_dtypes)
+ wrong_input_dtypes = list(usableDTypes(excludes=allowed_input_dtypes))
if op['op'] == Op.CLAMP:
wrong_input_dtypes.remove(DType.INT48)
- error_name = ErrorIf.WrongInputType
- param_reqs = {"rank": None, "dtype": wrong_input_dtypes, "shape": None}
- error_result = False
- error_reason = "Input data type not supported for this operator"
-
if check:
input_dtype = kwargs['input_dtype']
- if op['op'] == Op.FULLY_CONNECTED:
- if input_dtype not in allowed_input_dtypes:
- error_result = True
- elif input_dtype not in input_dtypes:
+ if input_dtype not in allowed_input_dtypes:
error_result = True
info_dict = {
- "error_name": error_name,
+ "error_name": ErrorIf.WrongInputType,
"error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs
+ "error_reason": f"Input data type not supported for this operator",
+ "param_reqs": {"rank": None, "dtype": wrong_input_dtypes, "shape": None}
}
return info_dict
@staticmethod
def evWrongOutputType(check=False, **kwargs):
- error_name = ErrorIf.WrongOutputType
- param_reqs = {"rank": None, "dtype": None, "shape": None}
error_result = False
- error_reason = "Output data type not supported for this configuration of operator"
if check:
input_dtype = kwargs['input_dtype']
@@ -1607,15 +1723,24 @@ class TosaErrorValidator:
):
error_result = True
+ elif op['op'] in {Op.CONV2D, Op.CONV3D, Op.DEPTHWISE_CONV2D, Op.TRANSPOSE_CONV2D}:
+ if (
+ input_dtype == DType.INT8 and output_dtype != DType.INT32
+ or input_dtype == DType.INT16 and output_dtype != DType.INT48
+ or input_dtype == DType.FLOAT and output_dtype != DType.FLOAT
+ ):
+ error_result = True
+ # invalid input types are ignored, to avoid reporting multiple errors
+
else:
if output_dtype != input_dtype:
error_result = True
info_dict = {
- "error_name": error_name,
+ "error_name": ErrorIf.WrongOutputType,
"error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs
+ "error_reason": "Output data type not supported for this configuration of operator",
+ "param_reqs": {"rank": None, "dtype": None, "shape": None}
}
return info_dict
@@ -1634,8 +1759,10 @@ class TosaErrorValidator:
# Set minimum incorrect rank to 3 to avoid index error
if op['op'] in [Op.RESIZE]:
incorrect_ranks = [3, 5]
- if op['op'] in [Op.TRANSPOSE]:
+ elif op['op'] in [Op.TRANSPOSE]:
incorrect_ranks = [7, 8]
+ elif op['op'] in [Op.CONV3D]:
+ incorrect_ranks = [6, 7]
error_name = ErrorIf.WrongRank
param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
@@ -2056,24 +2183,17 @@ class TosaErrorValidator:
@staticmethod
def evInputZeroPointNotZero(check=False, **kwargs):
op = kwargs['op']
- inputDtypes = op['types'].copy()
- # If inputDtypes is a list then only the first two elements are INT8 inputs
- if isinstance(inputDtypes, list):
- inputDtypes = inputDtypes[2:]
+ error_result = False
- if DType.INT8 in inputDtypes:
- inputDtypes.remove(DType.INT8)
- if DType.UINT8 in inputDtypes:
- inputDtypes.remove(DType.UINT8)
+ # Quantizable types
+ qTypes = (DType.INT8, DType.UINT8)
- error_name = ErrorIf.InputZeroPointNotZero
- param_reqs = {
- "rank": None,
- "dtype": inputDtypes,
- "shape": None
- }
- error_result = False
- error_reason = "Input DType not INT8 and zero point not 0"
+ # This does not apply to quantizable types
+ inputDtypes = [
+ dtype for dtype in op['types']
+ if (isinstance(dtype, list) and dtype[0] not in qTypes) or
+ (not isinstance(dtype, list) and dtype not in qTypes)
+ ]
if check:
input_dtype = kwargs['input_dtype']
@@ -2086,22 +2206,22 @@ class TosaErrorValidator:
input_zero_point = qinfo[0][1]
if op['op'] == Op.MATMUL:
- input1_dtype = kwargs['input_dtype']
- input2_dtype = kwargs['input2_dtype']
qinfo = kwargs['qinfo'].ints
- input1_zero_point = qinfo[0][1]
- input2_zero_point = qinfo[1][1]
- if (input1_dtype != DType.INT8 and input1_zero_point != 0) or (input2_dtype != DType.INT8 and input2_zero_point != 0):
- error_result = True
+ for dtype, zp in (
+ (kwargs['input_dtype'], qinfo[0][1]),
+ (kwargs['input2_dtype'], qinfo[1][1]),
+ ):
+ if dtype not in qTypes and zp != 0:
+ error_result = True
+ break
else:
- if input_dtype not in [DType.INT8, DType.UINT8] and input_zero_point != 0:
- error_result = True
+ error_result = input_dtype not in qTypes and input_zero_point != 0
info_dict = {
- "error_name": error_name,
+ "error_name": ErrorIf.InputZeroPointNotZero,
"error_result": error_result,
- "error_reason": error_reason,
- "param_reqs": param_reqs
+ "error_reason": "Input DType not INT8 and zero point not 0",
+ "param_reqs": {"rank": None, "dtype": inputDtypes, "shape": None}
}
return info_dict
@@ -2440,6 +2560,16 @@ class TosaErrorValidator:
return info_dict
@staticmethod
+ def evDilationSmallerOne(check=False, **kwargs):
+ error_result = check and min(kwargs['dilation']) < 1
+ return {
+ "error_name": ErrorIf.DilationSmallerOne,
+ "error_reason": "At least one dilation is smaller than one",
+ "param_reqs": {"rank": None, "dtype": None, "shape": None},
+ "error_result": error_result
+ }
+
+ @staticmethod
def evScaleTrue(check=False, **kwargs):
error_name = ErrorIf.ScaleTrue
param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}
@@ -3079,65 +3209,87 @@ class TosaInvalidValidator:
return True
return False
-
@staticmethod
- def ivHeightWidthSmallerZero(**kwargs):
+ def ivHeightWidthInvalid(**kwargs):
opName = kwargs['opName']
inputShapes = kwargs['shapeList']
- input = inputShapes[0]
- if not opName.endswith("pool2d"):
- filter = inputShapes[1]
+ input_shape = inputShapes[0]
args = kwargs['args']
strides = args[0]
padding = args[1]
- dilations = args[2]
- if opName.endswith("pool2d"):
- kernel = args[2]
-
- if opName.startswith('conv2d'):
- h = (
- input[1]
- - filter[1]
- - (filter[1] - 1) * (dilations[0] - 1)
- + padding[0]
- + padding[1]
- ) // strides[0] + 1
-
- w = (
- input[2]
- - filter[2]
- - (filter[2] - 1) * (dilations[1] - 1)
- + padding[2]
- + padding[3]
- ) // strides[1] + 1
- elif opName.startswith("depthwise_conv2d"):
- h = (
- input[1]
- - filter[0]
- - (filter[0] - 1) * (dilations[0] - 1)
- + padding[0]
- + padding[1]
- ) // strides[0] + 1
-
- w = (
- input[2]
- - filter[1]
- - (filter[1] - 1) * (dilations[1] - 1)
- + padding[2]
- + padding[3]
- ) // strides[1] + 1
- elif opName.endswith("pool2d"):
- h = (input[1] + padding[0] + padding[1] + strides[0] - kernel[0]) // strides[0]
- w = (input[2] + padding[2] + padding[3] + strides[1] - kernel[1]) // strides[1]
- else:
- assert False, "Unrecognized Op"
- if h <= 0 or w <= 0:
- # Invalid parameter combination
+ if opName.endswith("pool2d"):
+ # avg_pool2d, max_pool2d
+ kernel_shape = args[2]
+ h = (input_shape[1] + padding[0] + padding[1] + strides[0] - kernel_shape[0]) // strides[0]
+ w = (input_shape[2] + padding[2] + padding[3] + strides[1] - kernel_shape[1]) // strides[1]
+ # return True if any dimension is < 1
+ return h < 1 or w < 1
+
+ if opName.startswith("transpose_conv2d"):
+ # transpose_conv2d
+ dilations = args[2]
+ output_shape = args[3]
+ filter_shape = inputShapes[1]
+ kernel_shape = filter_shape[1:-1]
+
+ def get_out_size(in_size, stride, kernel_size, dilation, out_pad, in_pad):
+ """Calculate the transpose_conv2d output size for a dimension.
+
+ Based on the keras function deconv_output_length, in
+ https://github.com/keras-team/keras/blob/master/keras/utils/conv_utils.py
+
+ Args:
+ in_size: the input size - int
+ stride: the stride - int
+ kernel_size: the kernel size - int
+ dilation: the kernel dilation - int
+ out_pad: the output padding - int
+ in_pad: the input padding - int
+
+ Returns:
+ the output size
+ """
+ dilated_kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
+ return (in_size - 1) * stride + dilated_kernel_size - 2 * in_pad + out_pad
+
+ for pad_h, pad_w in (
+ (kernel_shape[0] - 1, kernel_shape[1] - 1), # FULL padding
+ (kernel_shape[0] // 2, kernel_shape[1] // 2), # SAME padding
+ (0, 0) # VALID padding
+ ):
+ h = get_out_size(input_shape[1], strides[0], kernel_shape[0], dilations[0],
+ padding[0], pad_h)
+ w = get_out_size(input_shape[2], strides[1], kernel_shape[1], dilations[1],
+ padding[1], pad_w)
+ if output_shape[1] == h and output_shape[2] == w:
+ return False
+
+ # output shape does not match the expected shape for any padding option
return True
- return False
+
+ if "conv2d" in opName or "conv3d" in opName:
+ # conv2d, conv3d, depthwise_conv2d
+ dilations = args[2]
+ filter_shape = inputShapes[1]
+ kernel_shape = filter_shape[0:2] if opName.startswith("depthwise_conv2d") else filter_shape[1:-1]
+
+ for i in range(len(kernel_shape)):
+ dim = (
+ input_shape[i + 1]
+ - kernel_shape[i]
+ - (kernel_shape[i] - 1) * (dilations[i] - 1)
+ + padding[i * 2 + 0]
+ + padding[i * 2 + 1]
+ ) // strides[i] + 1
+ # return True if any dimension is < 1
+ if dim < 1:
+ return True
+ return False
+
+ assert False, f"Unrecognized Op: {opName}"
@staticmethod
def ivNonPositiveOutputShape(**kwargs):
@@ -3149,7 +3301,6 @@ class TosaInvalidValidator:
return False
-
class TosaTestGen:
# Maximum rank of tensor supported by test generator.
TOSA_TENSOR_MAX_RANK = 6
@@ -3617,7 +3768,7 @@ class TosaTestGen:
if input.dtype not in [DType.INT8, DType.UINT8]:
qinfo = ts.TosaSerializerQuantInfo()
qinfo.UnaryQuantInfo(
- TosaQuantGen.getQinfo(self, input.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
+ TosaQuantGen.getQinfo(self, input.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
)
# Invalidate Input/Output list for error if checks.
@@ -3652,60 +3803,184 @@ class TosaTestGen:
self.ser.addOperator(op['op'], input_list, output_list, attr, qinfo)
return result_tens
- def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
+ def build_conv2d(self, op, ifm, filter, bias, strides, padding, dilations, validator_fcns=None, error_name=None, qinfo=None):
assert len(padding) == 4
result_tens = OutputShaper.conv2dOp(
- self.ser, ifm, filter, strides, padding, dilations
+ self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
+ )
+
+ # Ensure new output type has correct qinfo
+ if error_name == ErrorIf.WrongInputType and ifm.dtype not in (DType.INT8, DType.UINT8):
+ qinfo = ts.TosaSerializerQuantInfo()
+ qinfo.ConvQuantInfo(
+ TosaQuantGen.getQinfo(self, ifm.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
+ )
+
+ # Invalidate Input/Output list for error_if checks.
+ input_list = [ifm.name, filter.name, bias.name]
+ output_list = [result_tens.name]
+ num_operands = sum(op["operands"])
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+
+ TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ input_dtype=ifm.dtype,
+ weight_dtype=filter.dtype,
+ output_dtype=result_tens.dtype,
+ qinfo=qinfo,
+ input_list=input_list,
+ num_operands=num_operands,
+ output_list=output_list,
+ pad=padding,
+ stride=strides,
+ dilation=dilations,
+ input_shape=ifm.shape,
)
attr = ts.TosaSerializerAttribute()
attr.ConvAttribute(padding, strides, dilations)
self.ser.addOperator(
- op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
+ op['op'], input_list, output_list, attr, qinfo
)
return result_tens
- def build_conv3d(self, op, ifm, filter, bias, strides, padding, dilations, qinfo):
+ def build_conv3d(self, op, ifm, filter, bias, strides, padding, dilations, validator_fcns=None, error_name=None, qinfo=None):
assert len(padding) == 6
result_tens = OutputShaper.conv3dOp(
- self.ser, ifm, filter, strides, padding, dilations
+ self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
+ )
+
+ # Ensure new output type has correct qinfo
+ if error_name == ErrorIf.WrongInputType and ifm.dtype not in (DType.INT8, DType.UINT8):
+ qinfo = ts.TosaSerializerQuantInfo()
+ qinfo.ConvQuantInfo(
+ TosaQuantGen.getQinfo(self, ifm.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
+ )
+
+ # Invalidate Input/Output list for error_if checks.
+ input_list = [ifm.name, filter.name, bias.name]
+ output_list = [result_tens.name]
+ num_operands = sum(op["operands"])
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+
+ TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ input_dtype=ifm.dtype,
+ weight_dtype=filter.dtype,
+ output_dtype=result_tens.dtype,
+ qinfo=qinfo,
+ input_list=input_list,
+ num_operands=num_operands,
+ output_list=output_list,
+ pad=padding,
+ stride=strides,
+ dilation=dilations,
+ input_shape=ifm.shape,
)
attr = ts.TosaSerializerAttribute()
attr.ConvAttribute(padding, strides, dilations)
self.ser.addOperator(
- op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
+ op['op'], input_list, output_list, attr, qinfo
)
return result_tens
def build_transpose_conv2d(
- self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, qinfo
+ self, op, ifm, filter, bias, stride, outpad, dilation, output_shape, validator_fcns=None, error_name=None, qinfo=None
):
assert len(outpad) == 2
- result_tens = OutputShaper.transposeConv2DOp(self.ser, ifm, output_shape)
+ result_tens = OutputShaper.transposeConv2DOp(self.ser, self.rng, ifm, output_shape, error_name)
+
+ # Ensure new output type has correct qinfo
+ if error_name == ErrorIf.WrongInputType and ifm.dtype not in (DType.INT8, DType.UINT8):
+ qinfo = ts.TosaSerializerQuantInfo()
+ qinfo.ConvQuantInfo(
+ TosaQuantGen.getQinfo(self, ifm.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
+ )
+
+ # Invalidate Input/Output list for error_if checks.
+ input_list = [ifm.name, filter.name, bias.name]
+ output_list = [result_tens.name]
+ num_operands = sum(op["operands"])
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+
+ TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ input_dtype=ifm.dtype,
+ weight_dtype=filter.dtype,
+ output_dtype=result_tens.dtype,
+ qinfo=qinfo,
+ input_list=input_list,
+ num_operands=num_operands,
+ output_list=output_list,
+ pad=outpad,
+ stride=stride,
+ dilation=dilation,
+ input_shape=ifm.shape,
+ )
attr = ts.TosaSerializerAttribute()
attr.TransposeConvAttribute(outpad, stride, dilation, output_shape)
self.ser.addOperator(
- op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
+ op['op'], input_list, output_list, attr, qinfo
)
return result_tens
def build_depthwise_conv2d(
- self, op, ifm, filter, bias, strides, padding, dilations, qinfo
+ self, op, ifm, filter, bias, strides, padding, dilations, validator_fcns=None, error_name=None, qinfo=None
):
result_tens = OutputShaper.depthwiseConv2dOp(
- self.ser, ifm, filter, strides, padding, dilations
+ self.ser, self.rng, ifm, filter, strides, padding, dilations, error_name
+ )
+
+ # Ensure new output type has correct qinfo
+ if error_name == ErrorIf.WrongInputType and ifm.dtype not in (DType.INT8, DType.UINT8):
+ qinfo = ts.TosaSerializerQuantInfo()
+ qinfo.ConvQuantInfo(
+ TosaQuantGen.getQinfo(self, ifm.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
+ )
+
+ # Invalidate Input/Output list for error_if checks.
+ input_list = [ifm.name, filter.name, bias.name]
+ output_list = [result_tens.name]
+ num_operands = sum(op["operands"])
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+
+ TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ input_dtype=ifm.dtype,
+ weight_dtype=filter.dtype,
+ output_dtype=result_tens.dtype,
+ qinfo=qinfo,
+ input_list=input_list,
+ num_operands=num_operands,
+ output_list=output_list,
+ pad=padding,
+ stride=strides,
+ dilation=dilations,
+ input_shape=ifm.shape,
)
attr = ts.TosaSerializerAttribute()
attr.ConvAttribute(padding, strides, dilations)
self.ser.addOperator(
- op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], attr, qinfo
+ op['op'], input_list, output_list, attr, qinfo
)
return result_tens
@@ -4795,8 +5070,6 @@ class TosaTestGen:
#print(f"Filters: S {shapeFilter}, R {cleanRankFilter}, T {cleanDtypeFilter}")
for r in cleanRankFilter:
- if opName.startswith("conv3d"):
- assert r == 5, "conv3d test must have input rank == 5"
for t in cleanDtypeFilter:
for shape in cleanShapeFilter:
# Filter out by rank
@@ -4914,11 +5187,7 @@ class TosaTestGen:
else:
resultName = build_fcn(self, op, *tens, *testArgs, validator_fcns=error_if_validators, error_name=error_name)
except TypeError as e:
- print(
- "build_fcn: {}\nTensors: {}\nArgs: {}\n".format(
- build_fcn, tens, testArgs
- )
- )
+ print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
raise e
if resultName is None:
@@ -5287,7 +5556,7 @@ class TosaTestGen:
"build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
"qgen": TosaQuantGen.qgUnary,
"types": TYPE_NARROW_INT_FP,
- "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
+ "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
"error_if_validators": (TosaErrorValidator.evKernelSmallerOne, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evPadSmallerZero,
TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero,
@@ -5301,7 +5570,19 @@ class TosaTestGen:
"build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv),
"qgen": TosaQuantGen.qgConv,
"types": TYPE_CONV,
- "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
+ "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
+ "error_if_validators": (
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evInputZeroPointNotZero,
+ TosaErrorValidator.evWeightZeroPointNotZero,
+ TosaErrorValidator.evPadSmallerZero,
+ TosaErrorValidator.evStrideSmallerOne,
+ TosaErrorValidator.evDilationSmallerOne,
+ TosaErrorValidator.evWrongRank,
+ ),
"template": True,
},
# Templated operator. Filled in by createDynamicOpLists
@@ -5312,6 +5593,19 @@ class TosaTestGen:
"build_fcn": (build_conv3d, TosaTensorGen.tgConv3D, TosaArgGen.agConv),
"qgen": TosaQuantGen.qgConv,
"types": TYPE_CONV,
+ "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
+ "error_if_validators": (
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evInputZeroPointNotZero,
+ TosaErrorValidator.evWeightZeroPointNotZero,
+ TosaErrorValidator.evPadSmallerZero,
+ TosaErrorValidator.evStrideSmallerOne,
+ TosaErrorValidator.evDilationSmallerOne,
+ TosaErrorValidator.evWrongRank,
+ ),
"template": True,
},
# Templated operator. Filled in by createDynamicOpLists
@@ -5327,7 +5621,19 @@ class TosaTestGen:
),
"qgen": TosaQuantGen.qgConv,
"types": TYPE_CONV,
- "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
+ "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
+ "error_if_validators": (
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evInputZeroPointNotZero,
+ TosaErrorValidator.evWeightZeroPointNotZero,
+ TosaErrorValidator.evPadSmallerZero,
+ TosaErrorValidator.evStrideSmallerOne,
+ TosaErrorValidator.evDilationSmallerOne,
+ TosaErrorValidator.evWrongRank,
+ ),
"template": True,
},
"fully_connected": {
@@ -5356,7 +5662,7 @@ class TosaTestGen:
"rank": (4, 4),
"build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
"types": TYPE_NARROW_INT_FP,
- "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
+ "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
"error_if_validators": (TosaErrorValidator.evKernelSmallerOne, TosaErrorValidator.evStrideSmallerOne, TosaErrorValidator.evPadSmallerZero,
TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
TosaErrorValidator.evWrongOutputList, TosaErrorValidator.evPadLargerEqualKernel, TosaErrorValidator.evPoolingOutputShapeMismatch)
@@ -5373,7 +5679,22 @@ class TosaTestGen:
),
"qgen": TosaQuantGen.qgConv,
"types": TYPE_CONV,
- "invalid_test_validators": (TosaInvalidValidator.ivNonPositiveOutputShape,),
+ "invalid_test_validators": (
+ TosaInvalidValidator.ivHeightWidthInvalid,
+ TosaInvalidValidator.ivNonPositiveOutputShape,
+ ),
+ "error_if_validators": (
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evInputZeroPointNotZero,
+ TosaErrorValidator.evWeightZeroPointNotZero,
+ TosaErrorValidator.evPadSmallerZero,
+ TosaErrorValidator.evStrideSmallerOne,
+ TosaErrorValidator.evDilationSmallerOne,
+ TosaErrorValidator.evWrongRank,
+ ),
"template": True,
},
# Activation functions
@@ -6047,7 +6368,7 @@ class OutputShaper:
return ser.addOutput(shape, outputDType)
@staticmethod
- def conv2dOp(ser, ifm, filter, strides, padding, dilations):
+ def conv2dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
# IFM: NHWC
# Filter: OHWI
@@ -6074,6 +6395,10 @@ class OutputShaper:
+ padding[3]
) // strides[1] + 1
+ # Avoid illegal dimensions, which can be generated in error_if tests
+ h = max(h, 1)
+ w = max(w, 1)
+
ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
if ifm.dtype == DType.INT8:
@@ -6082,13 +6407,20 @@ class OutputShaper:
out_dtype = DType.INT48
elif ifm.dtype == DType.FLOAT:
out_dtype = DType.FLOAT
+ elif error_name == ErrorIf.WrongInputType:
+ # Pick some potentially correct output dtype if input type is incorrect
+ out_dtype = DType.INT32
else:
- raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
+ raise Exception(f"Unsupported input dtype: {ifm.dtype}")
+
+ if error_name == ErrorIf.WrongOutputType:
+ wrong_dtypes = list(usableDTypes(excludes=[out_dtype]))
+ out_dtype = rng.choice(wrong_dtypes)
return ser.addOutput(ofm_shape, out_dtype)
@staticmethod
- def conv3dOp(ser, ifm, filter, strides, padding, dilations):
+ def conv3dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
# IFM: NDHWC
# Filter: ODHWI
@@ -6118,6 +6450,11 @@ class OutputShaper:
+ padding[5]
) // strides[2] + 1
+ # Avoid illegal dimensions, which can be generated in error_if tests
+ d = max(d, 1)
+ h = max(h, 1)
+ w = max(w, 1)
+
ofm_shape = [ifm.shape[0], d, h, w, filter.shape[0]]
if ifm.dtype == DType.INT8:
@@ -6126,13 +6463,20 @@ class OutputShaper:
out_dtype = DType.INT48
elif ifm.dtype == DType.FLOAT:
out_dtype = DType.FLOAT
+ elif error_name == ErrorIf.WrongInputType:
+ # Pick some potentially correct output dtype if input type is incorrect
+ out_dtype = DType.INT32
else:
- raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
+ raise Exception(f"Unsupported input dtype: {ifm.dtype}")
+
+ if error_name == ErrorIf.WrongOutputType:
+ wrong_dtypes = list(usableDTypes(excludes=[out_dtype]))
+ out_dtype = rng.choice(wrong_dtypes)
return ser.addOutput(ofm_shape, out_dtype)
@staticmethod
- def depthwiseConv2dOp(ser, ifm, filter, strides, padding, dilations):
+ def depthwiseConv2dOp(ser, rng, ifm, filter, strides, padding, dilations, error_name=None):
# IFM: NHWC
# Filter: HWCM
# OFM: NHW C*M
@@ -6152,6 +6496,10 @@ class OutputShaper:
+ padding[3]
) // strides[1] + 1
+ # Avoid illegal dimensions, which can be generated in error_if tests
+ h = max(h, 1)
+ w = max(w, 1)
+
ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
if ifm.dtype == DType.INT8:
@@ -6160,8 +6508,15 @@ class OutputShaper:
out_dtype = DType.INT48
elif ifm.dtype == DType.FLOAT:
out_dtype = DType.FLOAT
+ elif error_name == ErrorIf.WrongInputType:
+ # Pick some potentially correct output dtype if input type is incorrect
+ out_dtype = DType.INT32
else:
- raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
+ raise Exception(f"Unsupported input dtype: {ifm.dtype}")
+
+ if error_name == ErrorIf.WrongOutputType:
+ wrong_dtypes = list(usableDTypes(excludes=[out_dtype]))
+ out_dtype = rng.choice(wrong_dtypes)
return ser.addOutput(ofm_shape, out_dtype)
@@ -6477,14 +6832,21 @@ class OutputShaper:
return ser.addOutput(val.shape, out_dtype)
@staticmethod
- def transposeConv2DOp(ser, ifm, output_shape):
+ def transposeConv2DOp(ser, rng, ifm, output_shape, error_name=None):
if ifm.dtype == DType.INT8:
out_dtype = DType.INT32
elif ifm.dtype == DType.INT16:
out_dtype = DType.INT48
elif ifm.dtype == DType.FLOAT:
out_dtype = DType.FLOAT
+ elif error_name == ErrorIf.WrongInputType:
+ # Pick some potentially correct output dtype if input type is incorrect
+ out_dtype = DType.INT32
else:
- raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
+ raise Exception(f"Unsupported input dtype: {ifm.dtype}")
+
+ if error_name == ErrorIf.WrongOutputType:
+ wrong_dtypes = list(usableDTypes(excludes=[out_dtype]))
+ out_dtype = rng.choice(wrong_dtypes)
return ser.addOutput(output_shape, out_dtype)