aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthew Haddon <matthew.haddon@arm.com>2021-10-13 11:30:30 +0100
committerEric Kunze <eric.kunze@arm.com>2021-11-09 15:11:29 +0000
commitbb5676f55df0d14be7e07981c39645971a587ed2 (patch)
tree504be59a83169347331acac32afcf17ab59ca656
parent25fbe521a1d64b4edc985386fc493683f1a08e60 (diff)
downloadreference_model-bb5676f55df0d14be7e07981c39645971a587ed2.tar.gz
Add ERROR_IF checks to operators without specific ERROR_IFs
* Operators implemented: sigmoid, tanh, arthmetic_right_shift, mul, table, select, equal, greater, greater_equal, concat, reverse, tile, scatter, gather, case * Note that over the course of implementation some specific ERROR_IF checks have been added for some of the above operators Change-Id: I80595e6eb9a3e5efd1cc6fd7aa28bbc2dd614980 Signed-off-by: Matthew Haddon <matthew.haddon@arm.com> Signed-off-by: Les Bell <les.bell@arm.com> Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
-rw-r--r--verif/tosa_error_if.py3
-rw-r--r--verif/tosa_test_gen.py751
2 files changed, 666 insertions, 88 deletions
diff --git a/verif/tosa_error_if.py b/verif/tosa_error_if.py
index 93a35b3..9fcc374 100644
--- a/verif/tosa_error_if.py
+++ b/verif/tosa_error_if.py
@@ -53,5 +53,8 @@ class ErrorIf(object):
InputSizeStartLengthMismatch = "InputSizeStartLengthMismatch"
IndexOutsideBounds = "IndexOutsideBounds"
IndexUsedTwice = "IndexUsedTwice"
+ MaxSmallerMin = "MaxSmallerMin"
+ ConcatInputRankMismatch = "ConcatInputRankMismatch"
+ ConcatInputDimMismatch = "ConcatInputDimMismatch"
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 3702142..cd59898 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import numpy as np
import argparse
import sys
@@ -457,15 +456,36 @@ class TosaTensorGen:
num_tensors = testGen.randInt(0, 4)
shape_list = []
for i in range(pl + const + num_tensors):
- shape_list.append(shape.copy())
+ if error_name == ErrorIf.ConcatInputRankMismatch and i != 0:
+ remove = testGen.rng.choice([True, False])
+ wrongShape = shape.copy()
+
+ if remove and len(shape) > 1:
+ wrongShape = wrongShape[1:]
+ else:
+ wrongShape = list(wrongShape)
+ wrongShape.append(testGen.rng.integers(1, 10))
+
+ shape_list.append(wrongShape)
+ else:
+ shape_list.append(shape.copy())
return shape_list
@staticmethod
def tgConcatConstInput(testGen, shapeList, axis, error_name=None):
+ if error_name in [ErrorIf.AxisSmallerZero, ErrorIf.AxisLargerRank, ErrorIf.ConcatInputRankMismatch]:
+ return shapeList
+
# Split concat shape along axis to allow for multiple const inputs
# without making too many large tensors
if len(shapeList) == 2 or shapeList[0][axis] < len(shapeList):
+ # If axis can't be split we still need to invalidate other dimensions
+ if error_name == ErrorIf.ConcatInputDimMismatch:
+ for shape in shapeList[1:]:
+ # Negative test shapeLists are created individually for each test,
+ # so no need to copy the shape before altering it.
+ shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
return shapeList
# Create copy of shape we are going to split (so we don't alter shapeList)
@@ -482,7 +502,13 @@ class TosaTensorGen:
# Append new shape, and set remaining shape
shape[axis] = split_shape_val
new_shapeList.append(shape.copy())
- shape[axis] = remaining_length
+
+ # invalidate dimensions
+ if error_name == ErrorIf.ConcatInputDimMismatch:
+ shape[(axis + 1) % len(shape)] += testGen.rng.integers(5, 10)
+ else:
+ shape[axis] = remaining_length
+
if i == len(shapeList) - 3:
new_shapeList.append(shape.copy())
@@ -764,7 +790,9 @@ class TosaArgGen:
arg_list = []
# Enumerate the output types here
- if inDtype == DType.INT8:
+ if error_name == ErrorIf.WrongOutputType:
+ dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
+ elif inDtype == DType.INT8:
dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
elif inDtype == DType.INT16:
dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
@@ -774,6 +802,9 @@ class TosaArgGen:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
elif inDtype == DType.FLOAT:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
+ elif error_name == ErrorIf.WrongInputType:
+ # Pick some potentially correct output type for incorrect input type
+ dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
else:
raise Exception("Unexpected input dtype: {}".format(inDtype))
@@ -1417,6 +1448,17 @@ class TosaErrorIfArgGen:
else:
return start, size
+ @staticmethod
+ def eiCastErrorIf(testGen, input_dtype):
+ if input_dtype in [DType.BOOL, DType.FLOAT]:
+ outputDType = [DType.BOOL, DType.INT48, DType.FLOAT]
+ elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
+ outputDType = [DType.INT48]
+ else:
+ assert True, f"input_dtype ({input_dtype}) not supported"
+ return outputDType
+
+
class TosaErrorValidator:
@staticmethod
@@ -1453,6 +1495,9 @@ class TosaErrorValidator:
allowed_input_dtypes = {t[0] if isinstance(t, list) else t for t in input_dtypes}
wrong_input_dtypes = list(all_dtypes - allowed_input_dtypes)
+ if op['op'] == Op.CLAMP:
+ wrong_input_dtypes.remove(DType.INT48)
+
error_name = ErrorIf.WrongInputType
param_reqs = {"rank": None, "dtype": wrong_input_dtypes, "shape": None}
error_result = False
@@ -1496,6 +1541,7 @@ class TosaErrorValidator:
(input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
):
error_result = True
+
elif op['op'] == Op.RESCALE:
if input_dtype == DType.INT8:
if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
@@ -1509,6 +1555,7 @@ class TosaErrorValidator:
elif input_dtype == DType.UINT8:
if output_dtype != DType.INT8:
error_result = True
+
elif op['op'] in [Op.FULLY_CONNECTED, Op.MATMUL]:
if (
(input_dtype == DType.INT8 and output_dtype != DType.INT32) or
@@ -1516,9 +1563,37 @@ class TosaErrorValidator:
(input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
):
error_result = True
+
elif op['op'] == Op.ARGMAX:
if input_dtype in [DType.INT8, DType.INT16, DType.FLOAT] and output_dtype != DType.INT32:
error_result = True
+
+ elif op['op'] == Op.MUL:
+ if input_dtype != DType.FLOAT and output_dtype != DType.INT32:
+ error_result = True
+ elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
+ error_result = True
+
+ elif op['op'] == Op.TABLE:
+ if input_dtype == DType.INT8 and output_dtype != DType.INT8:
+ error_result = True
+ elif input_dtype == DType.INT16 and output_dtype != DType.INT32:
+ error_result = True
+
+ elif op['op'] in [Op.EQUAL, Op.GREATER_EQUAL, Op.GREATER]:
+ if output_dtype != DType.BOOL:
+ error_result = True
+
+ elif op['op'] == Op.CAST:
+ if (
+ (input_dtype == DType.BOOL and output_dtype not in [DType.INT8, DType.INT16, DType.INT32])
+ or (input_dtype == DType.INT8 and output_dtype not in [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT])
+ or (input_dtype == DType.INT16 and output_dtype not in [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT])
+ or (input_dtype == DType.INT32 and output_dtype not in [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT])
+ or (input_dtype == DType.FLOAT and output_dtype not in [DType.INT8, DType.INT16, DType.INT32])
+ ):
+ error_result = True
+
else:
if output_dtype != input_dtype:
error_result = True
@@ -1584,6 +1659,9 @@ class TosaErrorValidator:
op = kwargs['op']
input_list = kwargs['input_list']
num_operands = kwargs['num_operands']
+ if op['op'] in [Op.SCATTER, Op.GATHER]:
+ # SCATTER/GATHER add an indices input tensor in their build functions
+ num_operands += 1
if len(input_list) != num_operands:
error_result = True
@@ -2551,6 +2629,82 @@ class TosaErrorValidator:
}
return info_dict
+ @staticmethod
+ def evMaxSmallerMin(check=False, **kwargs):
+ error_name = ErrorIf.MaxSmallerMin
+ param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Max value smaller than min value"
+
+ if check:
+ max_val = kwargs['max_val']
+ min_val = kwargs['min_val']
+ if max_val < min_val:
+ error_result = True
+
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
+ @staticmethod
+ def evConcatInputRankMismatch(check=False, **kwargs):
+ error_name = ErrorIf.ConcatInputRankMismatch
+ param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Input ranks are not identical"
+
+ if check:
+ inputs = kwargs['inputs']
+ input_shape = kwargs['input_shape']
+ for input in inputs:
+ if len(input.shape) != len(input_shape):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
+ @staticmethod
+ def evConcatInputDimMismatch(check=False, **kwargs):
+ error_name = ErrorIf.ConcatInputDimMismatch
+ param_reqs = {"rank": [2,4], "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Input dimensions differ on too many axes"
+
+ if check:
+ inputs = kwargs['inputs']
+ input_shape = kwargs['input_shape']
+ axis = kwargs['axis']
+
+ # Ensure rank is valid before checking dims.
+ valid_rank = True
+ for input in inputs:
+ if len(input.shape) != len(input_shape):
+ valid_rank = False
+
+ if valid_rank:
+ for input in inputs:
+ for i, dim in enumerate(input.shape):
+ if dim != input_shape[i] and axis != i:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
class TosaInvalidValidator:
@@ -2931,52 +3085,166 @@ class TosaTestGen:
self.ser.addOperator(op['op'], input_list, output_list)
return result_tens
- def build_binary_nonbroadcast(self, op, a, b):
+ def build_binary_nonbroadcast(self, op, a, b, validator_fcns=None, error_name=None):
result_tens = OutputShaper.binaryNonBroadcastOp(self.ser, a, b)
self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
return result_tens
- def build_arithmetic_right_shift(self, op, a, b, round):
- result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b)
+ def build_arithmetic_right_shift(self, op, a, b, round, validator_fcns=None, error_name=None):
+ result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)
+
+ # Invalidate Input/Output list for error if checks.
+ input_list = [a.name, b.name]
+ output_list = [result_tens.name]
+ pCount, cCount = op["operands"]
+ num_operands = pCount + cCount
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+
+ TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ input1 = a,
+ input2 = b,
+ input_dtype = a.dtype,
+ output_dtype = result_tens.dtype,
+ result_tensor = result_tens,
+ input_list=input_list,
+ output_list=output_list,
+ num_operands=num_operands,
+ )
attr = ts.TosaSerializerAttribute()
attr.ArithmeticRightShiftAttribute(round)
- self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], attr)
+ self.ser.addOperator(op['op'], input_list, output_list, attr)
return result_tens
- def build_mul(self, op, a, b, shift):
- result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b)
+ def build_mul(self, op, a, b, shift, validator_fcns=None, error_name=None):
+ result_tens = OutputShaper.binaryBroadcastOp(self.ser, self.rng, a, b, error_name)
# Special for multiply:
# Force the result to INT32 for INT types
if a.dtype != DType.FLOAT:
result_tens.setDtype(DType.INT32)
+ if error_name == ErrorIf.WrongOutputType:
+ all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
+ outputDType = self.rng.choice(all_dtypes)
+ result_tens.setDtype(outputDType)
+
+ # Invalidate Input/Output list for error if checks.
+ input_list = [a.name, b.name]
+ output_list = [result_tens.name]
+ pCount, cCount = op["operands"]
+ num_operands = pCount + cCount
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+
+ TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ input1 = a,
+ input2 = b,
+ input_dtype = a.dtype,
+ output_dtype = result_tens.dtype,
+ result_tensor = result_tens,
+ input_list=input_list,
+ output_list=output_list,
+ num_operands=num_operands,
+ )
attr = ts.TosaSerializerAttribute()
attr.MulAttribute(shift)
- self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name], attr)
+ self.ser.addOperator(op['op'], input_list, output_list, attr)
return result_tens
- def build_table(self, op, a, table):
- result_tens = OutputShaper.tableOp(self.ser, a)
+ def build_table(self, op, a, table, validator_fcns=None, error_name=None):
+ result_tens = OutputShaper.tableOp(self.ser, self.rng, a, error_name)
attr = ts.TosaSerializerAttribute()
attr.TableAttribute(table)
- self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
+ # Invalidate Input/Output list for error if checks.
+ input_list = [a.name]
+ output_list = [result_tens.name]
+ pCount, cCount = op["operands"]
+ num_operands = pCount + cCount
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+
+ TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ input_shape = a.shape,
+ input_dtype = a.dtype,
+ output_dtype = result_tens.dtype,
+ result_tensor = result_tens,
+ input_list=input_list,
+ output_list=output_list,
+ num_operands=num_operands,
+ )
+
+ self.ser.addOperator(op['op'], input_list, output_list, attr)
return result_tens
- def build_select(self, op, cond, a, b):
- result_tens = OutputShaper.selectOp(self.ser, cond, a, b)
- self.ser.addOperator(op['op'], [cond.name, a.name, b.name], [result_tens.name])
+ def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None):
+ result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name)
+
+ # Invalidate Input/Output list for error if checks.
+ input_list = [cond.name, a.name, b.name]
+ output_list = [result_tens.name]
+ pCount, cCount = op["operands"]
+ num_operands = pCount + cCount
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+
+ TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ input_shape = a.shape,
+ input_dtype = a.dtype,
+ output_dtype = result_tens.dtype,
+ result_tensor = result_tens,
+ input_list=input_list,
+ output_list=output_list,
+ num_operands=num_operands,
+ )
+
+ self.ser.addOperator(op['op'], input_list, output_list,)
return result_tens
- def build_comparison(self, op, a, b):
- result_tens = OutputShaper.binaryComparisonOp(self.ser, a, b)
- self.ser.addOperator(op['op'], [a.name, b.name], [result_tens.name])
+ def build_comparison(self, op, a, b, validator_fcns=None, error_name=None):
+ result_tens = OutputShaper.binaryComparisonOp(self.ser, self.rng, a, b, error_name)
+
+ # Invalidate Input/Output list for error if checks.
+ input_list = [a.name, b.name]
+ output_list = [result_tens.name]
+ pCount, cCount = op["operands"]
+ num_operands = pCount + cCount
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+
+ TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ input_shape = a.shape,
+ input_dtype = a.dtype,
+ output_shape = result_tens.shape,
+ output_dtype = result_tens.dtype,
+ result_tensor = result_tens,
+ input_list=input_list,
+ output_list=output_list,
+ num_operands=num_operands,
+ )
+
+ self.ser.addOperator(op['op'], input_list, output_list,)
return result_tens
def build_argmax(self, op, a, axis, validator_fcns, error_name):
@@ -3206,22 +3474,56 @@ class TosaTestGen:
self.ser.addOperator(op['op'], input_list, output_list, attr)
return result_tens
- def build_clamp(self, op, a):
- result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
+ def build_clamp(self, op, a, validator_fcns=None, error_name=None):
+ result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
- attr = ts.TosaSerializerAttribute()
v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
+ if error_name == ErrorIf.MaxSmallerMin:
+ # Make sure the numbers are different to invoke this error
+ while v[0] == v[1]:
+ v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
+ max_val = min(v)
+ min_val = max(v)
+ else:
+ max_val = max(v)
+ min_val = min(v)
+
+ # Invalidate Input/Output list for error if checks.
+ input_list = [a.name]
+ output_list = [result_tens.name]
+ pCount, cCount = op["operands"]
+ num_operands = pCount + cCount
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+
+ TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ max_val=max_val,
+ min_val=min_val,
+ input_shape = a.shape,
+ output_shape = result_tens.shape,
+ input_dtype = a.dtype,
+ output_dtype = result_tens.dtype,
+ result_tensor = result_tens,
+ input_list=input_list,
+ output_list=output_list,
+ num_operands=num_operands,
+ )
+
+ attr = ts.TosaSerializerAttribute()
if a.dtype == DType.FLOAT:
- attr.ClampAttribute(0, 0, min(v), max(v))
+ attr.ClampAttribute(0, 0, min_val, max_val)
else:
- attr.ClampAttribute(min(v), max(v), 0, 0)
+ attr.ClampAttribute(min_val, max_val, 0, 0)
- self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
+ self.ser.addOperator(op['op'], input_list, output_list, attr)
return result_tens
- def build_leaky_relu(self, op, a):
- result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
+ def build_leaky_relu(self, op, a, validator_fcns=None, error_name=None):
+ result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
attr = ts.TosaSerializerAttribute()
attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))
@@ -3230,39 +3532,111 @@ class TosaTestGen:
return result_tens
# Needs an additional type/input
- def build_prelu(self, op, a):
- result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
+ def build_prelu(self, op, a, validator_fcns=None, error_name=None):
+ result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
self.ser.addOperator(op['op'], [a.name], [result_tens.name])
return result_tens
- def build_sigmoid(self, op, a):
- result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
- self.ser.addOperator(op['op'], [a.name], [result_tens.name])
+ def build_sigmoid(self, op, a, validator_fcns=None, error_name=None):
+ result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
+
+ # Invalidate Input/Output list for error if checks.
+ input_list = [a.name]
+ output_list = [result_tens.name]
+ pCount, cCount = op["operands"]
+ num_operands = pCount + cCount
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+
+ TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ input_shape = a.shape,
+ output_shape = result_tens.shape,
+ input_dtype = a.dtype,
+ output_dtype = result_tens.dtype,
+ result_tensor = result_tens,
+ input_list=input_list,
+ output_list=output_list,
+ num_operands=num_operands,
+ )
+
+ self.ser.addOperator(op['op'], input_list, output_list)
return result_tens
- def build_tanh(self, op, a):
- result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
- self.ser.addOperator(op['op'], [a.name], [result_tens.name])
+ def build_tanh(self, op, a, validator_fcns=None, error_name=None):
+ result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
+
+ # Invalidate Input/Output list for error if checks.
+ input_list = [a.name]
+ output_list = [result_tens.name]
+ pCount, cCount = op["operands"]
+ num_operands = pCount + cCount
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+
+ TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ input_shape = a.shape,
+ output_shape = result_tens.shape,
+ input_dtype = a.dtype,
+ output_dtype = result_tens.dtype,
+ result_tensor = result_tens,
+ input_list=input_list,
+ output_list=output_list,
+ num_operands=num_operands,
+ )
+
+ self.ser.addOperator(op['op'], input_list, output_list)
return result_tens
- def build_concat(self, op, *a):
- assert type(a[-1]) == int
+ def build_concat(self, op, *a, validator_fcns=None, error_name=None):
+ if error_name != ErrorIf.WrongInputType:
+ assert type(a[-1]) == int
# To store variable length list of input tensors we need to store axis along with it
axis = a[-1]
a = a[:-1]
- result_tens = OutputShaper.concatOp(self.ser, axis, *a)
-
- attr = ts.TosaSerializerAttribute()
- attr.AxisAttribute(axis)
+ result_tens = OutputShaper.concatOp(self.ser, self.rng, axis, *a, error_name=error_name)
input_tensor_names = []
for tensor in a:
input_tensor_names.append(tensor.name)
- self.ser.addOperator(op['op'], input_tensor_names, [result_tens.name], attr)
+ # Invalidate Input/Output list for error if checks.
+ input_list = input_tensor_names
+ output_list = [result_tens.name]
+ pCount, cCount = op["operands"]
+ num_operands = pCount + cCount
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+
+ TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ axis=axis,
+ input_shape = a[0].shape,
+ output_shape = result_tens.shape,
+ input_dtype = a[0].dtype,
+ output_dtype = result_tens.dtype,
+ inputs=a,
+ result_tensor = result_tens,
+ input_list=input_list,
+ output_list=output_list,
+ num_operands=num_operands,
+ )
+
+ attr = ts.TosaSerializerAttribute()
+ attr.AxisAttribute(axis)
+
+
+ self.ser.addOperator(op['op'], input_list, output_list, attr)
return result_tens
def build_pad(self, op, a, padding, pad_const_int, pad_const_float, validator_fcns=None, error_name=None, qinfo=None):
@@ -3331,13 +3705,36 @@ class TosaTestGen:
self.ser.addOperator(op['op'], input_list, output_list, attr)
return result_tens
- def build_reverse(self, op, a, axis):
- result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
+ def build_reverse(self, op, a, axis, validator_fcns=None, error_name=None):
+ result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
+
+ # Invalidate Input/Output list for error if checks.
+ input_list = [a.name]
+ output_list = [result_tens.name]
+ pCount, cCount = op["operands"]
+ num_operands = pCount + cCount
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+
+ TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ axis=axis,
+ input_shape = a.shape,
+ output_shape = result_tens.shape,
+ input_dtype = a.dtype,
+ output_dtype = result_tens.dtype,
+ result_tensor = result_tens,
+ input_list=input_list,
+ output_list=output_list,
+ num_operands=num_operands,
+ )
attr = ts.TosaSerializerAttribute()
attr.AxisAttribute(axis)
- self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
+ self.ser.addOperator(op['op'], input_list, output_list, attr)
return result_tens
def build_transpose(self, op, a, perms, validator_fcns=None, error_name=None):
@@ -3406,16 +3803,38 @@ class TosaTestGen:
self.ser.addOperator(op['op'], input_list, output_list, attr)
return result_tens
- def build_tile(self, op, a, multiples):
- result_tens = OutputShaper.tileOp(self.ser, a, multiples)
+ def build_tile(self, op, a, multiples, validator_fcns=None, error_name=None):
+ result_tens = OutputShaper.tileOp(self.ser, self.rng, a, multiples, error_name)
+
+ # Invalidate Input/Output list for error if checks.
+ input_list = [a.name]
+ output_list = [result_tens.name]
+ pCount, cCount = op["operands"]
+ num_operands = pCount + cCount
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+
+ TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ input_shape = a.shape,
+ output_shape = result_tens.shape,
+ input_dtype = a.dtype,
+ output_dtype = result_tens.dtype,
+ result_tensor = result_tens,
+ input_list=input_list,
+ output_list=output_list,
+ num_operands=num_operands,
+ )
attr = ts.TosaSerializerAttribute()
attr.TileAttribute(multiples)
- self.ser.addOperator(op['op'], [a.name], [result_tens.name], attr)
+ self.ser.addOperator(op['op'], input_list, output_list, attr)
return result_tens
- def build_gather(self, op, values):
+ def build_gather(self, op, values, validator_fcns=None, error_name=None):
# Create a new indicies tensor
# here with data that doesn't exceed the dimensions of the values tensor
@@ -3429,13 +3848,35 @@ class TosaTestGen:
) # (N, W)
indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
- result_tens = OutputShaper.gatherOp(self.ser, values, indicies)
+ result_tens = OutputShaper.gatherOp(self.ser, self.rng, values, indicies, error_name)
+
+ # Invalidate Input/Output list for error if checks.
+ input_list = [values.name, indicies.name]
+ output_list = [result_tens.name]
+ pCount, cCount = op["operands"]
+ num_operands = pCount + cCount
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+
+ TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ input_shape = values.shape,
+ output_shape = result_tens.shape,
+ input_dtype = values.dtype,
+ output_dtype = result_tens.dtype,
+ result_tensor = result_tens,
+ input_list=input_list,
+ output_list=output_list,
+ num_operands=num_operands,
+ )
- self.ser.addOperator(op['op'], [values.name, indicies.name], [result_tens.name])
+ self.ser.addOperator(op['op'], input_list, output_list)
return result_tens
- def build_scatter(self, op, values_in, input):
+ def build_scatter(self, op, values_in, input, validator_fcns=None, error_name=None):
# Create a new indicies tensor
# here with data that doesn't exceed the dimensions of the values_in tensor
@@ -3447,12 +3888,32 @@ class TosaTestGen:
) # (N, W)
indicies = self.ser.addConst(indicies_arr.shape, DType.INT32, indicies_arr)
- result_tens = OutputShaper.scatterOp(self.ser, values_in, indicies, input)
+ result_tens = OutputShaper.scatterOp(self.ser, self.rng, values_in, indicies, input, error_name)
- self.ser.addOperator(
- op['op'], [values_in.name, indicies.name, input.name], [result_tens.name]
+ # Invalidate Input/Output list for error if checks.
+ input_list = [values_in.name, indicies.name, input.name]
+ output_list = [result_tens.name]
+ pCount, cCount = op["operands"]
+ num_operands = pCount + cCount
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+
+ TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ input_shape = input.shape,
+ output_shape = result_tens.shape,
+ input_dtype = input.dtype,
+ output_dtype = result_tens.dtype,
+ result_tensor = result_tens,
+ input_list=input_list,
+ output_list=output_list,
+ num_operands=num_operands,
)
+ self.ser.addOperator(op['op'], input_list, output_list)
+
return result_tens
@@ -3525,26 +3986,49 @@ class TosaTestGen:
self.ser.addOperator(op['op'], input_list, output_list, attr)
return result_tens
- def build_identityn(self, op, val, val2):
- result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, None)
- result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, None)
+ def build_identityn(self, op, val, val2, validator_fcns=None, error_name=None):
+ result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, error_name)
+ result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, error_name)
self.ser.addOperator(
op, [val.name, val2.name], [result_tens.name, result_tens2.name]
)
return result_tens
- def build_const(self, op, val):
+ def build_const(self, op, val, validator_fcns=None, error_name=None):
self.ser.addOutputTensor(val)
return val
# Type Conversion
- def build_cast(self, op, val, out_dtype):
- result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
- self.ser.addOperator(op['op'], [val.name], [result_tens.name])
+ def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None):
+ result_tens = OutputShaper.typeConversionOp(self.ser, self.rng, val, out_dtype, error_name)
+
+ # Invalidate Input/Output list for error if checks.
+ input_list = [val.name]
+ output_list = [result_tens.name]
+ pCount, cCount = op["operands"]
+ num_operands = pCount + cCount
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+
+ TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ input_shape = val.shape,
+ output_shape = result_tens.shape,
+ input_dtype = val.dtype,
+ output_dtype = result_tens.dtype,
+ result_tensor = result_tens,
+ input_list=input_list,
+ output_list=output_list,
+ num_operands=num_operands,
+ )
+
+ self.ser.addOperator(op['op'], input_list, output_list)
return result_tens
def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel, validator_fcns, error_name):
- result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
+ result_tens = OutputShaper.typeConversionOp(self.ser, self.rng, val, out_dtype, error_name)
if per_channel:
nc = val.shape[-1]
@@ -3997,9 +4481,9 @@ class TosaTestGen:
resultName = build_fcn(self, op, *tens, *testArgs)
else:
if qinfo is not None:
- resultName = build_fcn(self, op, *tens, *testArgs, error_if_validators, error_name, qinfo)
+ resultName = build_fcn(self, op, *tens, *testArgs, validator_fcns=error_if_validators, error_name=error_name, qinfo=qinfo)
else:
- resultName = build_fcn(self, op, *tens, *testArgs, error_if_validators, error_name)
+ resultName = build_fcn(self, op, *tens, *testArgs, validator_fcns=error_if_validators, error_name=error_name)
except TypeError as e:
print(
"build_fcn: {}\nTensors: {}\nArgs: {}\n".format(
@@ -4105,6 +4589,8 @@ class TosaTestGen:
arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
elif dtypeList[idx] == DType.INT32:
arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
+ elif error_name == ErrorIf.WrongInputType:
+ arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
else:
raise Exception("OpArithmeticRightShift: invalid input dtype")
else:
@@ -4167,6 +4653,8 @@ class TosaTestGen:
num_bits = 16
elif dtypeList[0] == DType.INT32:
num_bits = 32
+ elif error_name == ErrorIf.WrongInputType:
+ num_bits = 8
else:
raise Exception("OpMul: invalid input dtype")
@@ -4217,7 +4705,11 @@ class TosaTestGen:
if self.args.num_const_inputs_concat == 0:
count = len(shapeList)
- shapeList = TosaTensorGen.tgConcatConstInput(self, shapeList, testArgs[0])
+ # Ensure axis is an int
+ testArgs[0] = int(testArgs[0])
+
+ shapeList = TosaTensorGen.tgConcatConstInput(self, shapeList, testArgs[0], error_name)
+
tens.extend(
self.buildPlaceholderTensors(shapeList[0:count], dtypeList[0:count])
)
@@ -4461,18 +4953,24 @@ class TosaTestGen:
"operands": (1, 0),
"build_fcn": (build_clamp, TosaTensorGen.tgBasic, None),
"types": TYPE_NARROW_INT_FP,
+ "error_if_validators": (TosaErrorValidator.evMaxSmallerMin, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"sigmoid": {
"op": Op.SIGMOID,
"operands": (1, 0),
"build_fcn": (build_sigmoid, TosaTensorGen.tgBasic, None),
"types": TYPE_FP,
+ "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList)
},
"tanh": {
"op": Op.TANH,
"operands": (1, 0),
"build_fcn": (build_tanh, TosaTensorGen.tgBasic, None),
"types": TYPE_FP,
+ "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList)
},
# Elementwise Binary Operators
"add": {
@@ -4492,6 +4990,8 @@ class TosaTestGen:
TosaArgGen.agArithmeticRightShift,
),
"types": TYPE_INT,
+ "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList)
},
"bitwise_and": {
"op": Op.BITWISE_AND,
@@ -4586,6 +5086,8 @@ class TosaTestGen:
"operands": (2, 0),
"build_fcn": (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
"types": TYPE_INT_FP,
+ "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList)
},
"pow": {
"op": Op.POW,
@@ -4611,6 +5113,8 @@ class TosaTestGen:
"operands": (1, 0),
"build_fcn": (build_table, TosaTensorGen.tgBasic, TosaArgGen.agTable),
"types": [DType.INT8, DType.INT16],
+ "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList)
},
# Elementwise Unary operators
"abs": {
@@ -4709,6 +5213,8 @@ class TosaTestGen:
"operands": (3, 0),
"build_fcn": (build_select, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_FIB,
+ "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
# Comparison operators
"equal": {
@@ -4716,18 +5222,24 @@ class TosaTestGen:
"operands": (2, 0),
"build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_FI32,
+ "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"greater_equal": {
"op": Op.GREATER_EQUAL,
"operands": (2, 0),
"build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_FI32,
+ "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"greater": {
"op": Op.GREATER,
"operands": (2, 0),
"build_fcn": (build_comparison, TosaTensorGen.tgBroadcastFuzz, None),
"types": TYPE_FI32,
+ "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
# Reduction operators
"reduce_all": {
@@ -4790,6 +5302,8 @@ class TosaTestGen:
"operands": (2, 0),
"build_fcn": (build_concat, TosaTensorGen.tgConcat, TosaArgGen.agAxis),
"types": TYPE_FIB,
+ "error_if_validators": (TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evConcatInputRankMismatch,
+ TosaErrorValidator.evConcatInputDimMismatch, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongOutputList)
},
"pad": {
"op": Op.PAD,
@@ -4814,6 +5328,8 @@ class TosaTestGen:
"operands": (1, 0),
"build_fcn": (build_reverse, TosaTensorGen.tgBasic, TosaArgGen.agAxis),
"types": TYPE_FIB,
+ "error_if_validators": (TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"slice": {
"op": Op.SLICE,
@@ -4830,6 +5346,8 @@ class TosaTestGen:
"operands": (1, 0),
"build_fcn": (build_tile, TosaTensorGen.tgBasic, TosaArgGen.agTile),
"types": TYPE_FIB,
+ "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"transpose": {
"op": Op.TRANSPOSE,
@@ -4865,6 +5383,8 @@ class TosaTestGen:
"rank": (3, 3),
"build_fcn": (build_gather, TosaTensorGen.tgBasic, None),
"types": TYPE_INT_FP,
+ "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"scatter": {
"op": Op.SCATTER,
@@ -4874,6 +5394,8 @@ class TosaTestGen:
"rank": (3, 3),
"build_fcn": (build_scatter, TosaTensorGen.tgScatter, None),
"types": TYPE_INT_FP,
+ "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
# Image operations
"resize": {
@@ -4895,6 +5417,8 @@ class TosaTestGen:
"operands": (1, 0),
"build_fcn": (build_cast, TosaTensorGen.tgBasic, TosaArgGen.agCast),
"types": [DType.FLOAT, DType.INT8, DType.INT16, DType.INT32, DType.BOOL],
+ "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"rescale": {
"op": Op.RESCALE,
@@ -5000,7 +5524,7 @@ class OutputShaper:
return ser.addOutput(a.shape, outputDType)
@staticmethod
- def selectOp(ser, cond, a, b):
+ def selectOp(ser, rng, cond, a, b, error_name=None):
assert len(a.shape) == len(b.shape) and len(a.shape) == len(cond.shape)
assert a.dtype == b.dtype
@@ -5008,10 +5532,17 @@ class OutputShaper:
for i in range(len(a.shape)):
shape.append(max(cond.shape[i], a.shape[i], b.shape[i]))
- return ser.addOutput(shape, a.dtype)
+ if error_name == ErrorIf.WrongOutputType:
+ all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
+ wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
+ outputDType = rng.choice(wrong_dtypes)
+ else:
+ outputDType = a.dtype
+
+ return ser.addOutput(shape, outputDType)
@staticmethod
- def binaryComparisonOp(ser, a, b):
+ def binaryComparisonOp(ser, rng, a, b , error_name=None):
assert len(a.shape) == len(b.shape)
assert a.dtype == b.dtype
@@ -5023,8 +5554,13 @@ class OutputShaper:
else:
shape.append(a.shape[i])
- # Force the output type to bool
- return ser.addOutput(shape, DType.BOOL)
+ if error_name == ErrorIf.WrongOutputType:
+ wrong_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
+ outputDType = rng.choice(wrong_dtypes)
+ else:
+ outputDType = DType.BOOL
+
+ return ser.addOutput(shape, outputDType)
@staticmethod
def reduceOp(ser, rng, a, axis, error_name=None):
@@ -5276,18 +5812,31 @@ class OutputShaper:
return ser.addOutput(output_shape, out_dtype)
@staticmethod
- def concatOp(ser, axis, *a):
+ def concatOp(ser, rng, axis, *a, error_name=None):
input1 = a[0]
remaining_inputs = a[1:]
+ # calculate the output shape, if possible, otherwise just use the first input shape
output_shape = input1.shape.copy()
+ if not (
+ # unable to concat tensors of different ranks
+ error_name == ErrorIf.ConcatInputRankMismatch
+ # unable to concat tensors along an invalid axis
+ or error_name in [ErrorIf.AxisLargerRank, ErrorIf.AxisSmallerZero]
+ # unable to concat tensors of different dimensions
+ or error_name == ErrorIf.ConcatInputDimMismatch
+ ):
+ for tensor in remaining_inputs:
+ output_shape[axis] += tensor.shape[axis]
- output_shape[axis] = input1.shape[axis]
-
- for tensor in remaining_inputs:
- output_shape[axis] += tensor.shape[axis]
+ if error_name == ErrorIf.WrongOutputType:
+ all_dtypes = {DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT}
+ wrong_dtypes = list(all_dtypes - set([input1.dtype]))
+ outputDType = rng.choice(wrong_dtypes)
+ else:
+ outputDType = input1.dtype
- return ser.addOutput(output_shape, input1.dtype)
+ return ser.addOutput(output_shape, outputDType)
@staticmethod
def padOp(ser, rng, a, padding, error_name=None):
@@ -5365,7 +5914,7 @@ class OutputShaper:
return ser.addOutput(output_shape, outputDType)
@staticmethod
- def tileOp(ser, a, multiples):
+ def tileOp(ser, rng, a, multiples, error_name=None):
output_shape = a.shape.copy()
assert len(multiples) == len(output_shape)
@@ -5373,7 +5922,14 @@ class OutputShaper:
for i in range(len(output_shape)):
output_shape[i] = a.shape[i] * multiples[i]
- return ser.addOutput(output_shape, a.dtype)
+ if error_name == ErrorIf.WrongOutputType:
+ all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
+ wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
+ outputDType = rng.choice(wrong_dtypes)
+ else:
+ outputDType = a.dtype
+
+ return ser.addOutput(output_shape, outputDType)
@staticmethod
def transposeOp(ser, rng, a, perms, error_name=None):
@@ -5398,17 +5954,24 @@ class OutputShaper:
return ser.addOutput(output_shape, outputDType)
@staticmethod
- def gatherOp(ser, values, indices):
+ def gatherOp(ser, rng, values, indices, error_name=None):
assert len(values.shape) == 3
assert len(indices.shape) == 2
assert values.shape[0] == indices.shape[0]
output_shape = [values.shape[0], indices.shape[1], values.shape[2]]
- return ser.addOutput(output_shape, values.dtype)
+ if error_name == ErrorIf.WrongOutputType:
+ all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
+ wrong_dtypes = list(set(all_dtypes) - set([values.dtype]))
+ outputDType = rng.choice(wrong_dtypes)
+ else:
+ outputDType = values.dtype
+
+ return ser.addOutput(output_shape, outputDType)
@staticmethod
- def scatterOp(ser, values_in, indices, input):
+ def scatterOp(ser, rng, values_in, indices, input, error_name=None):
assert len(values_in.shape) == 3
assert len(indices.shape) == 2
assert len(input.shape) == 3
@@ -5418,13 +5981,25 @@ class OutputShaper:
output_shape = values_in.shape
- return ser.addOutput(output_shape, values_in.dtype)
+ if error_name == ErrorIf.WrongOutputType:
+ all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
+ wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
+ outputDType = rng.choice(wrong_dtypes)
+ else:
+ outputDType = values_in.dtype
+
+ return ser.addOutput(output_shape, outputDType)
@staticmethod
- def tableOp(ser, input):
- # Same shape as the input, but dtype dependent on table dtype
- assert input.dtype == DType.INT16 or input.dtype == DType.INT8
+ def tableOp(ser, rng, input, error_name=None):
+ # Same shape as the input, dtype dependent on input dtype
+ if error_name != ErrorIf.WrongInputType:
+ assert input.dtype == DType.INT16 or input.dtype == DType.INT8
output_dtype = DType.INT32 if input.dtype == DType.INT16 else DType.INT8
+ if error_name == ErrorIf.WrongOutputType:
+ wrong_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
+ wrong_dtypes.remove(output_dtype)
+ output_dtype = rng.choice(wrong_dtypes)
return ser.addOutput(input.shape, output_dtype)
@staticmethod
@@ -5456,7 +6031,7 @@ class OutputShaper:
return serializer.addOutput(output_dims, output_dtype)
@staticmethod
- def typeConversionOp(ser, val, out_dtype):
+ def typeConversionOp(ser, rng, val, out_dtype, error_name=None):
return ser.addOutput(val.shape, out_dtype)
@staticmethod