author     Matthew Haddon <matthew.haddon@arm.com>  2021-10-11 09:38:10 +0100
committer  Eric Kunze <eric.kunze@arm.com>          2021-10-18 17:13:12 +0000
commit     c4cf037c3944ca9f481c00cb5a7d2e96efe48d7c (patch)
tree       5b22dfc335c00ceb0eab5db36a96841584aadd1f
parent     c202521d6943a04e910e0daf5cca86dee536b5c0 (diff)
download   reference_model-c4cf037c3944ca9f481c00cb5a7d2e96efe48d7c.tar.gz
Add negative testing support to fully_connected, matmul, argmax
Change-Id: I75f2a4ab6790dcbdfaec064f42f601d8f44da70b
Signed-off-by: Matthew Haddon <matthew.haddon@arm.com>
-rw-r--r--  verif/tosa_error_if.py    3
-rw-r--r--  verif/tosa_test_gen.py  338
2 files changed, 308 insertions(+), 33 deletions(-)
diff --git a/verif/tosa_error_if.py b/verif/tosa_error_if.py
index 5e219cc..35a391e 100644
--- a/verif/tosa_error_if.py
+++ b/verif/tosa_error_if.py
@@ -31,9 +31,12 @@ class ErrorIf(object):
ChannelMismatch = "ChannelMismatch"
RankMismatch = "RankMismatch"
InputZeroPointNotZero = "InputZeroPointNotZero"
+ WeightZeroPointNotZero = "WeightZeroPointNotZero"
OutputZeroPointNotZero = "OutputZeroPointNotZero"
AxisSmallerZero = "AxisSmallerZero"
AxisLargerRank = "AxisLargerRank"
+ ArgmaxOutputShapeMismatch = "ArgmaxOutputShapeMismatch"
+ ArgmaxOutputRankMismatch = "ArgmaxOutputRankMismatch"
ShapeOfAxisNotOne = "ShapeOfAxisNotOne"
KernelSmallerOne = "KernelSmallerOne"
StrideSmallerOne = "StrideSmallerOne"
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 6780aa7..1ec4a47 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -62,7 +62,7 @@ class TosaQuantGen:
return testGen.randInt(-128, 128)
elif dtype == DType.UINT8:
return testGen.randInt(0, 256)
- elif error_name in [ErrorIf.InputZeroPointNotZero, ErrorIf.OutputZeroPointNotZero]:
+ elif error_name in [ErrorIf.InputZeroPointNotZero, ErrorIf.WeightZeroPointNotZero, ErrorIf.OutputZeroPointNotZero]:
zero_point = testGen.randInt(-128, 128)
if zero_point == 0:
zero_point = 1
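
For illustration, a minimal self-contained sketch of the zero-point selection above: draw from the INT8 range and remap 0 to 1 so a ZeroPointNotZero test always carries an invalid value (pick_nonzero_zero_point is a hypothetical name, not from this patch).

import random

def pick_nonzero_zero_point(rng=random):
    # Draw from the INT8 zero-point range; remap 0 to 1 so the
    # ERROR_IF test always carries an invalid (nonzero) zero point.
    zero_point = rng.randint(-128, 127)
    return 1 if zero_point == 0 else zero_point

assert pick_nonzero_zero_point() != 0
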
@@ -95,17 +95,31 @@ class TosaQuantGen:
else:
# an int, [input, weights, accumulator] dtypes are the same
dtypeList = [dtype_or_dtypeList] * 3
- input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
- weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
+
+ if error_name == ErrorIf.InputZeroPointNotZero:
+ input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0], error_name)
+ weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
+ elif error_name == ErrorIf.WeightZeroPointNotZero:
+ input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
+ weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1], error_name)
+ else:
+ input_zp = TosaQuantGen.getQinfo(testGen, dtypeList[0])
+ weights_zp = TosaQuantGen.getQinfo(testGen, dtypeList[1])
+
qinfo.ConvQuantInfo(input_zp, weights_zp)
return qinfo
@staticmethod
def qgMatmul(testGen, op, dtype, error_name=None):
qinfo = ts.TosaSerializerQuantInfo()
- qinfo.MatMulQuantInfo(
- TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
- )
+ if error_name == ErrorIf.InputZeroPointNotZero:
+ qinfo.MatMulQuantInfo(
+ TosaQuantGen.getQinfo(testGen, dtype, error_name), TosaQuantGen.getQinfo(testGen, dtype, error_name)
+ )
+ else:
+ qinfo.MatMulQuantInfo(
+ TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
+ )
return qinfo
@staticmethod
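
The qgConv and qgMatmul branches above differ only in which getQinfo call receives error_name. A hedged sketch of that routing, with a stub in place of TosaQuantGen.getQinfo and string error names standing in for the ErrorIf members:

def conv_zero_points(get_qinfo, error_name=None):
    # Route error_name to exactly one operand: the input zero point for
    # InputZeroPointNotZero, the weight zero point for WeightZeroPointNotZero;
    # the untouched operand keeps a valid zero point.
    input_zp = get_qinfo(error_name if error_name == "InputZeroPointNotZero" else None)
    weight_zp = get_qinfo(error_name if error_name == "WeightZeroPointNotZero" else None)
    return input_zp, weight_zp

# Stub standing in for TosaQuantGen.getQinfo: nonzero only when an error is requested.
stub = lambda err: 1 if err else 0
assert conv_zero_points(stub, "WeightZeroPointNotZero") == (0, 1)
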
@@ -196,9 +210,9 @@ class TosaTensorGen:
# Constrict the batch size?
if testGen.args.max_batch_size:
shape[0] = (shape[0] % testGen.args.max_batch_size) + 1
- # Constrict dimension size for large ranks
- if rank > 4:
- shape[4] = 1
+
+ # Constrict dimension size for large ranks when creating WrongRank tests
+ shape = TosaErrorIfArgGen.eiRestrictDimension(shape, error_name)
shape_list = []
for i in range(pl + const):
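
eiRestrictDimension itself is not shown in this patch; a plausible sketch of its assumed behavior, based on the inline "if rank > 4: shape[4] = 1" it replaces (the real helper may differ):

def ei_restrict_dimension(shape, error_name):
    # Stand-in for the inline "if rank > 4: shape[4] = 1" this hunk removes:
    # when an error test can produce rank > 4 tensors, clamp the trailing
    # dimensions to 1 so generated test data stays small.
    if error_name is not None and len(shape) > 4:
        shape = list(shape[:4]) + [1] * (len(shape) - 4)
    return shape

assert ei_restrict_dimension([2, 3, 4, 5, 6], "WrongRank") == [2, 3, 4, 5, 1]
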
@@ -383,9 +397,14 @@ class TosaTensorGen:
def tgFullyConnected(testGen, op, rank, error_name=None):
pl, const = op["operands"]
- assert rank == 2
+ if error_name != ErrorIf.WrongRank:
+ assert rank == 2
input_shape = testGen.makeShape(rank)
+
+ # Constrict dimension size for large ranks when creating WrongRank tests
+ input_shape = TosaErrorIfArgGen.eiRestrictDimension(input_shape, error_name)
+
filter_oc = testGen.rng.integers(
low=testGen.args.tensor_shape_range[0],
high=testGen.args.tensor_shape_range[1],
@@ -401,10 +420,15 @@ class TosaTensorGen:
def tgMatmul(testGen, op, rank, error_name=None):
pl, const = op["operands"]
- assert rank == 3
+ if error_name != ErrorIf.WrongRank:
+ assert rank == 3
assert pl == 2 and const == 0
a_shape = testGen.makeShape(rank)
+
+ # Constrict dimension size for large ranks when creating WrongRank tests
+ a_shape = TosaErrorIfArgGen.eiRestrictDimension(a_shape, error_name)
+
# Get a random number for b_oc even if target shape is defined
b_oc = np.int32(
testGen.rng.integers(
@@ -1312,13 +1336,15 @@ class TosaErrorValidator:
@staticmethod
def evWrongInputType(check=False, **kwargs):
- all_dtypes = (DType.BOOL, DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT)
+ all_dtypes = {DType.BOOL, DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT}
# Find the unsupported input data types
assert 'op' in kwargs
op = kwargs['op']
input_dtypes = op['types']
- wrong_input_dtypes = list(set(all_dtypes) - set(input_dtypes))
+
+ allowed_input_dtypes = {t[0] if isinstance(t, list) else t for t in input_dtypes}
+ wrong_input_dtypes = list(all_dtypes - allowed_input_dtypes)
error_name = ErrorIf.WrongInputType
param_reqs = {"rank": None, "dtype": wrong_input_dtypes, "shape": None}
@@ -1327,7 +1353,10 @@ class TosaErrorValidator:
if check:
input_dtype = kwargs['input_dtype']
- if input_dtype not in input_dtypes:
+ if op['op'] == Op.FULLY_CONNECTED:
+ if input_dtype not in allowed_input_dtypes:
+ error_result = True
+ elif input_dtype not in input_dtypes:
error_result = True
info_dict = {
@@ -1373,6 +1402,16 @@ class TosaErrorValidator:
elif input_dtype == DType.UINT8:
if output_dtype != DType.INT8:
error_result = True
+ elif op['op'] in [Op.FULLY_CONNECTED, Op.MATMUL]:
+ if (
+ (input_dtype == DType.INT8 and output_dtype != DType.INT32) or
+ (input_dtype == DType.INT16 and output_dtype != DType.INT48) or
+ (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
+ ):
+ error_result = True
+ elif op['op'] == Op.ARGMAX:
+ if input_dtype in [DType.INT8, DType.INT16, DType.FLOAT] and output_dtype != DType.INT32:
+ error_result = True
else:
if output_dtype != input_dtype:
error_result = True
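
The chained conditions for FULLY_CONNECTED and MATMUL encode the natural accumulator rule (INT8 to INT32, INT16 to INT48, FLOAT to FLOAT), while ARGMAX requires INT32 indices for any valid input. A table-driven sketch of the same check, with strings standing in for DType members:

# Expected accumulator dtype per input dtype for FULLY_CONNECTED / MATMUL.
EXPECTED = {"INT8": "INT32", "INT16": "INT48", "FLOAT": "FLOAT"}

def fc_matmul_wrong_output(input_dtype, output_dtype):
    expected = EXPECTED.get(input_dtype)
    return expected is not None and output_dtype != expected

assert fc_matmul_wrong_output("INT8", "INT48")
assert not fc_matmul_wrong_output("INT16", "INT48")
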
@@ -1408,8 +1447,13 @@ class TosaErrorValidator:
if check:
input_shape = kwargs['input_shape']
+
if op['op'] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D] and len(input_shape) != 4:
error_result = True
+ elif op['op'] == Op.FULLY_CONNECTED and len(input_shape) != 2:
+ error_result = True
+ elif op['op'] == Op.MATMUL and len(input_shape) != 3:
+ error_result = True
else:
if len(input_shape) not in rank_range:
error_result = True
@@ -1778,6 +1822,10 @@ class TosaErrorValidator:
def evInputZeroPointNotZero(check=False, **kwargs):
op = kwargs['op']
inputDtypes = op['types'].copy()
+ # Per-operand dtype lists ([input, weight, accumulator]) put the two
+ # INT8-input combinations first; drop them before the removals below
+ if isinstance(inputDtypes[0], list):
+ inputDtypes = inputDtypes[2:]
+
if DType.INT8 in inputDtypes:
inputDtypes.remove(DType.INT8)
if DType.UINT8 in inputDtypes:
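
For context, a sketch of why the slice above removes exactly two entries, assuming a TYPE_CONV-like layout where the first two [input, weight, accumulator] triples are the INT8-input combinations (strings stand in for DType members):

types = [["INT8", "INT4", "INT32"], ["INT8", "INT8", "INT32"],
         ["INT16", "INT8", "INT48"], ["FLOAT", "FLOAT", "FLOAT"]]
dtypes = types.copy()
if isinstance(dtypes[0], list):      # per-operand triples, not bare dtypes
    dtypes = dtypes[2:]              # drop the two INT8-input triples
assert dtypes == [["INT16", "INT8", "INT48"], ["FLOAT", "FLOAT", "FLOAT"]]
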
@@ -1802,7 +1850,50 @@ class TosaErrorValidator:
qinfo = kwargs['qinfo'].ints
input_zero_point = qinfo[0][1]
- if input_dtype not in [DType.INT8, DType.UINT8] and input_zero_point != 0:
+ if op['op'] == Op.MATMUL:
+ input1_dtype = kwargs['input_dtype']
+ input2_dtype = kwargs['input2_dtype']
+ qinfo = kwargs['qinfo'].ints
+ input1_zero_point = qinfo[0][1]
+ input2_zero_point = qinfo[1][1]
+ if (input1_dtype != DType.INT8 and input1_zero_point != 0) or (input2_dtype != DType.INT8 and input2_zero_point != 0):
+ error_result = True
+ else:
+ if input_dtype not in [DType.INT8, DType.UINT8] and input_zero_point != 0:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
+
+ @staticmethod
+ def evWeightZeroPointNotZero(check=False, **kwargs):
+ op = kwargs['op']
+
+ # exclude inputs with INT8 weights
+ inputDtypes = [t for t in op['types']
+ if not isinstance(t, list) or t[1] != DType.INT8]
+
+ error_name = ErrorIf.WeightZeroPointNotZero
+ param_reqs = {
+ "rank": None,
+ "dtype": inputDtypes,
+ "shape": None
+ }
+ error_result = False
+ error_reason = "Weight DType not INT8 and zero point not 0"
+
+ if check:
+ weight_dtype = kwargs['weight_dtype']
+ # qinfo.ints layout: [0][1] = input_zp, [1][1] = weight_zp
+ qinfo = kwargs['qinfo'].ints
+ weight_zero_point = qinfo[1][1]
+ if weight_dtype != DType.INT8 and weight_zero_point != 0:
error_result = True
info_dict = {
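
A minimal sketch of the qinfo layout this validator relies on, with a plain list of (name, value) pairs in place of the serializer's qinfo.ints:

# qinfo.ints is ordered [(input_zp_name, input_zp), (weight_zp_name, weight_zp)],
# so index [1][1] reads the weight zero point checked above.
qinfo_ints = [("input_zp", 0), ("weight_zp", 17)]
weight_zero_point = qinfo_ints[1][1]
weight_dtype = "INT16"    # any non-INT8 weight dtype must carry zp == 0
assert weight_dtype != "INT8" and weight_zero_point != 0   # ERROR_IF fires
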
@@ -2007,6 +2098,65 @@ class TosaErrorValidator:
}
return info_dict
+ @staticmethod
+ def evArgmaxOutputShapeMismatch(check=False, **kwargs):
+ error_name = ErrorIf.ArgmaxOutputShapeMismatch
+ param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Mismatch between output shape provided and expected output shape"
+
+ if check:
+ output_shape = kwargs['output_shape']
+ input_shape = kwargs['input_shape']
+ axis = kwargs['axis']
+
+ dimension_match = True
+ axis_shift = 0
+
+ # Check that rank is correct before trying to check dimensions
+ if (len(input_shape) - 1) == len(output_shape):
+ for i in range(len(input_shape)):
+ if i == axis:
+ axis_shift = 1
+ continue
+ if input_shape[i] != output_shape[i - axis_shift]:
+ dimension_match = False
+
+ if not dimension_match:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
+ @staticmethod
+ def evArgmaxOutputRankMismatch(check=False, **kwargs):
+ error_name = ErrorIf.ArgmaxOutputRankMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Mismatch between output shape provided and expected output shape"
+
+ if check:
+ output_shape = kwargs['output_shape']
+ input_shape = kwargs['input_shape']
+ axis = kwargs['axis']
+ valid_params = axis >= 0 and axis < len(input_shape)
+
+ if valid_params and (len(input_shape) - 1) != len(output_shape):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
@staticmethod
def evKernelSmallerOne(check=False, **kwargs):
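
A worked example of the axis_shift comparison in evArgmaxOutputShapeMismatch: reducing axis 1 of a [2, 5, 7] input should leave [2, 7], so any other rank-2 shape is a mismatch. The helper name below is illustrative:

def argmax_shapes_mismatch(input_shape, output_shape, axis):
    # Mirrors the loop above: skip the reduced axis, then compare the
    # remaining input dimensions against the rank-(N-1) output shape.
    if len(input_shape) - 1 != len(output_shape):
        return False    # rank errors are evArgmaxOutputRankMismatch's job
    shift = 0
    for i, dim in enumerate(input_shape):
        if i == axis:
            shift = 1
            continue
        if dim != output_shape[i - shift]:
            return True
    return False

assert not argmax_shapes_mismatch([2, 5, 7], [2, 7], axis=1)
assert argmax_shapes_mismatch([2, 5, 7], [2, 8], axis=1)
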
@@ -2525,13 +2675,36 @@ class TosaTestGen:
self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
return result_tens
- def build_argmax(self, op, a, axis):
- result_tens = OutputShaper.argmaxOp(self.ser, a, axis)
+ def build_argmax(self, op, a, axis, validator_fcns, error_name):
+ result_tens = OutputShaper.argmaxOp(self.ser, self.rng, a, axis, error_name)
+
+ # Invalidate Input/Output list for error if checks.
+ input_list = [a.name]
+ output_list = [result_tens.name]
+ pCount, cCount = op["operands"]
+ num_operands = pCount + cCount
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+
+ TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ axis=axis,
+ input_shape = a.shape,
+ input_dtype = a.dtype,
+ output_shape = result_tens.shape,
+ output_dtype = result_tens.dtype,
+ result_tensor = result_tens,
+ input_list=input_list,
+ output_list=output_list,
+ num_operands=num_operands,
+ )
attr = ts.TosaSerializerAttribute()
attr.AxisAttribute(axis)
- self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
+ self.ser.addOperator(op['op'], input_list, output_list, attr)
return result_tens
def build_pool2d(self, op, input, stride, pad, kernel, validator_fcns=None, error_name=None, qinfo=None):
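
build_argmax now follows the same flow the other error-aware builders use: shape the (possibly deliberately wrong) output, corrupt the operand lists if the error case calls for it, run the validators, then serialize. A hedged outline with stub callables (none of these names are from the patch):

def build_with_error_checks(shape_output, invalidate, validate, serialize,
                            input_names, error_name=None):
    # 1. shape the (possibly intentionally wrong) output tensor,
    # 2. let the error-if generator corrupt the operand name lists,
    # 3. run the validators so the expected ERROR_IF verdict is recorded,
    # 4. serialize the operator with the (possibly corrupted) lists.
    result = shape_output(error_name)
    in_list, out_list = invalidate(list(input_names), [result], error_name)
    validate(error_name, in_list, out_list)
    serialize(in_list, out_list)
    return result

# Trivial stubs, just to exercise the flow once.
out = build_with_error_checks(lambda e: "out", lambda i, o, e: (i, o),
                              lambda e, i, o: None, lambda i, o: None, ["a"])
assert out == "out"
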
@@ -2634,17 +2807,67 @@ class TosaTestGen:
)
return result_tens
- def build_fully_connected(self, op, ifm, filter, bias, qinfo):
- result_tens = OutputShaper.fullyConnectedOp(self.ser, ifm, filter)
+ def build_fully_connected(self, op, ifm, filter, bias, validator_fcns=None, error_name=None, qinfo=None):
+ result_tens = OutputShaper.fullyConnectedOp(self.ser, self.rng, ifm, filter, error_name)
+
+ # Invalidate Input/Output list for error if checks.
+ input_list = [ifm.name, filter.name, bias.name]
+ output_list = [result_tens.name]
+ pCount, cCount = op["operands"]
+ num_operands = pCount + cCount
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+
+ TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ input_shape=ifm.shape,
+ input_dtype=ifm.dtype,
+ weight_dtype=filter.dtype,
+ output_shape=result_tens.shape,
+ output_dtype=result_tens.dtype,
+ qinfo = qinfo,
+ result_tensor = result_tens,
+ input_list=input_list,
+ output_list=output_list,
+ num_operands=num_operands,
+ )
self.ser.addOperator(
- op['op'], [ifm.name, filter.name, bias.name], [result_tens.name], None, qinfo
+ op['op'], input_list, output_list, None, qinfo
)
return result_tens
- def build_matmul(self, op, a, b, qinfo):
- result_tens = OutputShaper.matmulOp(self.ser, a, b)
- self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], None, qinfo)
+ def build_matmul(self, op, a, b, validator_fcns=None, error_name=None, qinfo=None):
+ result_tens = OutputShaper.matmulOp(self.ser, self.rng, a, b, error_name)
+
+ # Invalidate Input/Output list for error if checks.
+ input_list = [a.name, b.name]
+ output_list = [result_tens.name]
+ pCount, cCount = op["operands"]
+ num_operands = pCount + cCount
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+
+ TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ input_shape=a.shape,
+ input_dtype=a.dtype,
+ input2_shape=b.shape,
+ input2_dtype=b.dtype,
+ output_shape=result_tens.shape,
+ output_dtype=result_tens.dtype,
+ qinfo = qinfo,
+ result_tensor = result_tens,
+ input_list=input_list,
+ output_list=output_list,
+ num_operands=num_operands,
+ )
+
+ self.ser.addOperator(op['op'], input_list, output_list, None, qinfo)
return result_tens
def build_reduce(self, op, a, axis, validator_fcns, error_name=None):
@@ -3246,7 +3469,6 @@ class TosaTestGen:
for validator in error_if_validators:
if validator is not None:
error_name = validator(check=False, op=op)['error_name']
- #print("error_name: ", error_name)
else:
error_name = None
@@ -3713,8 +3935,12 @@ class TosaTestGen:
"argmax": {
"op": Op.ARGMAX,
"operands": (1, 0),
+ "rank": (1, 4),
"build_fcn": (build_argmax, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
"types": TYPE_NARROW_INT_FP,
+ "error_if_validators": (TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evArgmaxOutputRankMismatch,
+ TosaErrorValidator.evArgmaxOutputShapeMismatch, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"avg_pool2d": {
"op": Op.AVG_POOL2D,
@@ -3773,6 +3999,8 @@ class TosaTestGen:
"build_fcn": (build_fully_connected, TosaTensorGen.tgFullyConnected, None),
"qgen": TosaQuantGen.qgConv,
"types": TYPE_CONV,
+ "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWeightZeroPointNotZero, TosaErrorValidator.evWrongRank,
+ TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"matmul": {
"op": Op.MATMUL,
@@ -3781,6 +4009,8 @@ class TosaTestGen:
"build_fcn": (build_matmul, TosaTensorGen.tgMatmul, None),
"qgen": TosaQuantGen.qgMatmul,
"types": TYPE_NARROW_INT_FP,
+ "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWrongRank, TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"max_pool2d": {
"op": Op.MAX_POOL2D,
@@ -4386,10 +4616,30 @@ class OutputShaper:
return ser.addOutput(shape, outputDType)
@staticmethod
- def argmaxOp(ser, a, axis):
+ def argmaxOp(ser, rng, a, axis, error_name=None):
shape = a.shape.copy()
- del shape[axis]
- return ser.addOutput(shape, DType.INT32)
+
+ if error_name not in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank]:
+ del shape[axis]
+
+ if error_name == ErrorIf.ArgmaxOutputRankMismatch:
+ remove = rng.choice([True, False])
+ if remove and len(shape) > 1:
+ del shape[0]
+ else:
+ shape.append(1)
+ elif error_name == ErrorIf.ArgmaxOutputShapeMismatch:
+ for i in range(len(shape)):
+ shape[i] = shape[i] + rng.integers(1, 10)
+
+ if error_name == ErrorIf.WrongOutputType:
+ all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
+ wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
+ outputDType = rng.choice(wrong_dtypes)
+ else:
+ outputDType = DType.INT32
+
+ return ser.addOutput(shape, outputDType)
@staticmethod
def conv2dOp(ser, ifm, filter, strides, padding, dilations):
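
Concretely, argmaxOp maps a [3, 4, 5] input with axis=1 to a [3, 5] INT32 output; under ArgmaxOutputShapeMismatch every remaining dimension is inflated by 1 to 9 so the mismatch is guaranteed. A sketch of the shape arithmetic (helper name is illustrative):

import random

def argmax_output_shape(input_shape, axis, mismatch=False, rng=random):
    shape = list(input_shape)
    del shape[axis]                     # argmax removes the reduced axis
    if mismatch:
        # inflate every remaining dimension so the output can never match
        shape = [d + rng.randint(1, 9) for d in shape]
    return shape

assert argmax_output_shape([3, 4, 5], axis=1) == [3, 5]
assert argmax_output_shape([3, 4, 5], axis=1, mismatch=True) != [3, 5]
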
@@ -4514,7 +4764,7 @@ class OutputShaper:
def pool2dOp(ser, rng, ifm, kernel, stride, pad, error_name=None):
# input: NHWC
if stride[0] <= 0 or stride[1] <= 0 or min(pad) < 0:
- # If an incorrect stride is used set dimensions to 0, test is invalid anyway.
+ # If an incorrect stride is used, set dimensions to 1; the test is invalid anyway.
h = 1
w = 1
else:
@@ -4538,40 +4788,62 @@ class OutputShaper:
return ser.addOutput(ofm_shape, outputDType)
@staticmethod
- def fullyConnectedOp(ser, input, filter):
+ def fullyConnectedOp(ser, rng, input, filter, error_name=None):
# input: N, IC
# filter: OC, IC
# output: N, OC
output_shape = [input.shape[0], filter.shape[0]]
- if input.dtype == DType.INT8:
+ if error_name == ErrorIf.WrongOutputType:
+ if input.dtype == DType.INT8:
+ incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
+ elif input.dtype == DType.INT16:
+ incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
+ elif input.dtype == DType.FLOAT:
+ incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
+ out_dtype = rng.choice(a=incorrect_types)
+ elif input.dtype == DType.INT8:
out_dtype = DType.INT32
elif input.dtype == DType.INT16:
out_dtype = DType.INT48
elif input.dtype == DType.FLOAT:
out_dtype = DType.FLOAT
+ elif error_name == ErrorIf.WrongInputType:
+ # Pick some potentially correct output dtype if input type is incorrect
+ out_dtype = DType.INT32
else:
raise Exception("Unsupported input dtype: {}".format(input.dtype))
return ser.addOutput(output_shape, out_dtype)
@staticmethod
- def matmulOp(ser, a, b):
+ def matmulOp(ser, rng, a, b, error_name=None):
# a: N, H, C
# b: N, C, W
# out: N, H, W
output_shape = [a.shape[0], a.shape[1], b.shape[2]]
- if a.dtype == DType.INT8:
+ if error_name == ErrorIf.WrongOutputType:
+ if a.dtype == DType.INT8:
+ incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT48, DType.FLOAT)
+ elif a.dtype == DType.INT16:
+ incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.FLOAT)
+ elif a.dtype == DType.FLOAT:
+ incorrect_types = (DType.INT4, DType.INT8, DType.INT16, DType.INT32, DType.INT48)
+ out_dtype = rng.choice(a=incorrect_types)
+ elif a.dtype == DType.INT8:
out_dtype = DType.INT32
elif a.dtype == DType.INT16:
out_dtype = DType.INT48
elif a.dtype == DType.FLOAT:
out_dtype = DType.FLOAT
+ elif error_name == ErrorIf.WrongInputType:
+ # Pick some potentially correct output dtype if input type is incorrect
+ out_dtype = DType.INT32
else:
- raise Exception("UNsupported input dtype for matmul: {}".format(a.dtype))
+ raise Exception("Unsupported input dtype for matmul: {}".format(a.dtype))
return ser.addOutput(output_shape, out_dtype)
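
Both shapers pick the deliberately wrong output dtype the same way: any serializable dtype except the one correct accumulator type. A sketch that derives the incorrect_types tuples above instead of hard-coding them (strings stand in for DType members):

import random

ALL_DTYPES = ["INT4", "INT8", "INT16", "INT32", "INT48", "FLOAT"]
CORRECT_OUT = {"INT8": "INT32", "INT16": "INT48", "FLOAT": "FLOAT"}

def wrong_out_dtype(input_dtype, rng=random):
    # Everything except the spec-correct accumulator type; for INT8 input
    # this reproduces the (INT4, INT8, INT16, INT48, FLOAT) tuple above.
    return rng.choice([d for d in ALL_DTYPES if d != CORRECT_OUT[input_dtype]])

assert wrong_out_dtype("INT8") != "INT32"
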