aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthew Haddon <matthew.haddon@arm.com>2021-09-28 11:38:21 +0100
committerMatthew Haddon <matthew.haddon@arm.com>2021-10-07 17:26:25 +0100
commite4ecdb2ee8471cc713e7562fbec4118820f81a72 (patch)
treedb730f692d8aff73168faca7700d61a721fe8fbc
parenteacff9ae50b645ec9a293fd58082bacfdbe1e868 (diff)
downloadreference_model-e4ecdb2ee8471cc713e7562fbec4118820f81a72.tar.gz
Add negative testing support for ew_unary operators
* Added negative testing support for the following operators: abs, bitwise_not, ceil, clz, exp, floor, log, logical_not, negate, reciprocal, rsqrt Change-Id: Icc6f146c6407502520330678420951749ba2a9ef Signed-off-by: Matthew Haddon <matthew.haddon@arm.com>
-rw-r--r--verif/tosa_error_if.py2
-rw-r--r--verif/tosa_test_gen.py187
2 files changed, 162 insertions, 27 deletions
diff --git a/verif/tosa_error_if.py b/verif/tosa_error_if.py
index c28591d..6656645 100644
--- a/verif/tosa_error_if.py
+++ b/verif/tosa_error_if.py
@@ -30,4 +30,6 @@ class ErrorIf(object):
BatchMismatch = "BatchMismatch"
ChannelMismatch = "ChannelMismatch"
RankMismatch = "RankMismatch"
+ InputZeroPointNotZero = "InputZeroPointNotZero"
+ OutputZeroPointNotZero = "OutputZeroPointNotZero"
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index f5f7fff..2478331 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -56,23 +56,38 @@ class TosaQuantGen:
pass
@staticmethod
- def getQinfo(testGen, dtype):
+ def getQinfo(testGen, dtype, error_name=None):
+
if dtype == DType.INT8:
return testGen.randInt(-128, 128)
- if dtype == DType.UINT8:
+ elif dtype == DType.UINT8:
return testGen.randInt(0, 256)
+ elif error_name in [ErrorIf.InputZeroPointNotZero, ErrorIf.OutputZeroPointNotZero]:
+ zero_point = testGen.randInt(-128, 128)
+ if zero_point == 0:
+ zero_point = 1
+ return zero_point
return 0
@staticmethod
- def qgUnary(testGen, op, dtype):
+ def qgUnary(testGen, op, dtype, error_name=None):
qinfo = ts.TosaSerializerQuantInfo()
- qinfo.UnaryQuantInfo(
- TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
- )
+ if error_name == ErrorIf.InputZeroPointNotZero:
+ qinfo.UnaryQuantInfo(
+ TosaQuantGen.getQinfo(testGen, dtype, error_name), TosaQuantGen.getQinfo(testGen, dtype)
+ )
+ elif error_name == ErrorIf.OutputZeroPointNotZero:
+ qinfo.UnaryQuantInfo(
+ TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype, error_name)
+ )
+ else:
+ qinfo.UnaryQuantInfo(
+ TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
+ )
return qinfo
@staticmethod
- def qgConv(testGen, op, dtype_or_dtypeList):
+ def qgConv(testGen, op, dtype_or_dtypeList, error_name=None):
qinfo = ts.TosaSerializerQuantInfo()
if isinstance(dtype_or_dtypeList, list):
# a list of [input, weights, accumulator] dtypes
@@ -86,7 +101,7 @@ class TosaQuantGen:
return qinfo
@staticmethod
- def qgMatmul(testGen, op, dtype):
+ def qgMatmul(testGen, op, dtype, error_name=None):
qinfo = ts.TosaSerializerQuantInfo()
qinfo.MatMulQuantInfo(
TosaQuantGen.getQinfo(testGen, dtype), TosaQuantGen.getQinfo(testGen, dtype)
@@ -94,7 +109,7 @@ class TosaQuantGen:
return qinfo
@staticmethod
- def qgPad(testGen, op, dtype):
+ def qgPad(testGen, op, dtype, error_name=None):
qinfo = ts.TosaSerializerQuantInfo()
qinfo.PadQuantInfo(TosaQuantGen.getQinfo(testGen, dtype))
return qinfo
@@ -1645,6 +1660,61 @@ class TosaErrorValidator:
}
return info_dict
+ @staticmethod
+ def evInputZeroPointNotZero(check=False, **kwargs):
+ error_name = ErrorIf.InputZeroPointNotZero
+ param_reqs = {
+ "rank": None,
+ "dtype": [DType.INT16, DType.INT32, DType.FLOAT],
+ "shape": None
+ }
+ error_result = False
+ error_reason = "Input DType not INT8 and zero point not 0"
+
+ if check:
+ input_dtype = kwargs['input_dtype']
+ # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
+ qinfo = kwargs['qinfo'].ints
+ input_zero_point = qinfo[0][1]
+ if input_dtype not in [DType.INT8, DType.UINT8] and input_zero_point != 0:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
+
+ @staticmethod
+ def evOutputZeroPointNotZero(check=False, **kwargs):
+ error_name = ErrorIf.OutputZeroPointNotZero
+ param_reqs = {
+ "rank": None,
+ "dtype": [DType.INT16, DType.INT32, DType.FLOAT],
+ "shape": None
+ }
+ error_result = False
+ error_reason = "Output DType not INT8 and zero point not 0"
+
+ if check:
+ output_dtype = kwargs['output_dtype']
+ # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
+ qinfo = kwargs['qinfo'].ints
+ output_zero_point = qinfo[1][1]
+ if output_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
class TosaInvalidValidator:
@@ -1949,13 +2019,47 @@ class TosaTestGen:
# The build_fcn_arg_list is expanded and passed to the operator test
# build function
- def build_unary(self, op, a, qinfo=None):
- result_tens = OutputShaper.unaryOp(self.ser, a)
+ def build_unary(self, op, a, validator_fcns=None, error_name=None, qinfo=None):
+ result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
+
# build_placeholder returns an int, ABS/other ops does not
if isinstance(op, int):
- self.ser.addOperator(op, [a.name], [result_tens.name], None, qinfo)
- else:
- self.ser.addOperator(op['op'], [a.name], [result_tens.name], None, qinfo)
+ self.ser.addOperator(op, a.name, result_tens.name, None, qinfo)
+ return result_tens
+ elif op['op'] == Op.IDENTITY:
+ self.ser.addOperator(op['op'], a.name, result_tens.name, None, qinfo)
+ return result_tens
+
+ # Ensure new output type has correct qinfo
+ if error_name == ErrorIf.WrongOutputType:
+ if result_tens.dtype not in [DType.INT8, DType.UINT8]:
+ qinfo = ts.TosaSerializerQuantInfo()
+ qinfo.UnaryQuantInfo(
+ TosaQuantGen.getQinfo(self, a.dtype), TosaQuantGen.getQinfo(self, result_tens.dtype)
+ )
+
+ # Invalidate Input/Output list for error if checks.
+ input_list = [a.name]
+ output_list = [result_tens.name]
+ pCount, cCount = op["operands"]
+ num_operands = pCount + cCount
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+
+ TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ input_dtype=a.dtype,
+ output_dtype=result_tens.dtype,
+ qinfo = qinfo,
+ result_tensor = result_tens,
+ input_list=input_list,
+ output_list=output_list,
+ num_operands=num_operands,
+ )
+
+ self.ser.addOperator(op['op'], input_list, output_list, None, qinfo)
return result_tens
def build_binary_broadcast(self, op, a, b, validator_fcns, error_name=None):
@@ -2139,7 +2243,7 @@ class TosaTestGen:
return result_tens
def build_clamp(self, op, a):
- result_tens = OutputShaper.unaryOp(self.ser, a)
+ result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
attr = ts.TosaSerializerAttribute()
v = [self.getRandNumberDType(a.dtype), self.getRandNumberDType(a.dtype)]
@@ -2153,7 +2257,7 @@ class TosaTestGen:
return result_tens
def build_leaky_relu(self, op, a):
- result_tens = OutputShaper.unaryOp(self.ser, a)
+ result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
attr = ts.TosaSerializerAttribute()
attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))
@@ -2163,18 +2267,18 @@ class TosaTestGen:
# Needs an additional type/input
def build_prelu(self, op, a):
- result_tens = OutputShaper.unaryOp(self.ser, a)
+ result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
self.ser.addOperator(op['op'], [a.name], [result_tens.name])
return result_tens
def build_sigmoid(self, op, a):
- result_tens = OutputShaper.unaryOp(self.ser, a)
+ result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
self.ser.addOperator(op['op'], [a.name], [result_tens.name])
return result_tens
def build_tanh(self, op, a):
- result_tens = OutputShaper.unaryOp(self.ser, a)
+ result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
self.ser.addOperator(op['op'], [a.name], [result_tens.name])
return result_tens
@@ -2220,7 +2324,7 @@ class TosaTestGen:
return result_tens
def build_reverse(self, op, a, axis):
- result_tens = OutputShaper.unaryOp(self.ser, a)
+ result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, None)
attr = ts.TosaSerializerAttribute()
attr.AxisAttribute(axis)
@@ -2365,9 +2469,8 @@ class TosaTestGen:
return result_tens
def build_identityn(self, op, val, val2):
-
- result_tens = OutputShaper.unaryOp(self.ser, val)
- result_tens2 = OutputShaper.unaryOp(self.ser, val2)
+ result_tens = OutputShaper.unaryOp(self.ser, self.rng, val, None)
+ result_tens2 = OutputShaper.unaryOp(self.ser, self.rng, val2, None)
self.ser.addOperator(
op, [val.name, val2.name], [result_tens.name, result_tens2.name]
)
@@ -2784,7 +2887,7 @@ class TosaTestGen:
tens = self.generate_tensors(op, dtypeList, shapeList, testArgs, error_name)
if qgen is not None:
- qinfo = qgen(self, op, dtype_or_dtypeList)
+ qinfo = qgen(self, op, dtype_or_dtypeList, error_name)
else:
qinfo = None
@@ -2796,7 +2899,7 @@ class TosaTestGen:
resultName = build_fcn(self, op, *tens, *testArgs)
else:
if qinfo is not None:
- resultName = build_fcn(self, op, *tens, *testArgs, qinfo, error_if_validators, error_name)
+ resultName = build_fcn(self, op, *tens, *testArgs, error_if_validators, error_name, qinfo)
else:
resultName = build_fcn(self, op, *tens, *testArgs, error_if_validators, error_name)
except TypeError as e:
@@ -3388,48 +3491,64 @@ class TosaTestGen:
"operands": (1, 0),
"build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
"types": TYPE_FI32,
+ "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"bitwise_not": {
"op": Op.BITWISE_NOT,
"operands": (1, 0),
"build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
"types": TYPE_INT,
+ "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"ceil": {
"op": Op.CEIL,
"operands": (1, 0),
"build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
"types": TYPE_FP,
+ "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"clz": {
"op": Op.CLZ,
"operands": (1, 0),
"build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
"types": [DType.INT32],
+ "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"exp": {
"op": Op.EXP,
"operands": (1, 0),
"build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
"types": TYPE_FP,
+ "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"floor": {
"op": Op.FLOOR,
"operands": (1, 0),
"build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
"types": TYPE_FP,
+ "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"log": {
"op": Op.LOG,
"operands": (1, 0),
"build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
"types": TYPE_FP,
+ "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"logical_not": {
"op": Op.LOGICAL_NOT,
"operands": (1, 0),
"build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
"types": TYPE_BOOL,
+ "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"negate": {
"op": Op.NEGATE,
@@ -3437,18 +3556,25 @@ class TosaTestGen:
"build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
"qgen": TosaQuantGen.qgUnary,
"types": TYPE_INT_FP,
+ "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero,
+ TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList)
},
"reciprocal": {
"op": Op.RECIPROCAL,
"operands": (1, 0),
"build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
"types": TYPE_FP,
+ "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
"rsqrt": {
"op": Op.RSQRT,
"operands": (1, 0),
"build_fcn": (build_unary, TosaTensorGen.tgBasic, None),
"types": TYPE_FP,
+ "error_if_validators": (TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
# Elementwise Ternary operators
"select": {
@@ -3703,8 +3829,15 @@ class OutputShaper:
return ser.addOutput(shape, a.dtype)
@staticmethod
- def unaryOp(ser, a):
- return ser.addOutput(a.shape, a.dtype)
+ def unaryOp(ser, rng, a, error_name=None):
+ if error_name == ErrorIf.WrongOutputType:
+ all_dtypes = [DType.INT8, DType.INT16, DType.INT32, DType.INT48, DType.FLOAT]
+ wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
+ outputDType = rng.choice(wrong_dtypes)
+ else:
+ outputDType = a.dtype
+
+ return ser.addOutput(a.shape, outputDType)
@staticmethod
def selectOp(ser, cond, a, b):