aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthew Haddon <matthew.haddon@arm.com>2021-10-08 21:21:05 +0100
committerEric Kunze <eric.kunze@arm.com>2021-10-18 17:13:12 +0000
commitc202521d6943a04e910e0daf5cca86dee536b5c0 (patch)
treec1af41b4c7311ca2e81716fe4074b67d3d9f5b79
parent8a0a663e7a8160674deeec17c6f8ad04c0391313 (diff)
downloadreference_model-c202521d6943a04e910e0daf5cca86dee536b5c0.tar.gz
Add negative testing support to RESCALE
* Negative tests for rescale op added Signed-off-by: Matthew Haddon <matthew.haddon@arm.com> Change-Id: I70aead1c6a67f159c7b7c9a05f7d5f0b92521584
-rw-r--r--verif/tosa_error_if.py2
-rw-r--r--verif/tosa_test_gen.py183
2 files changed, 170 insertions, 15 deletions
diff --git a/verif/tosa_error_if.py b/verif/tosa_error_if.py
index 2daeb9d..5e219cc 100644
--- a/verif/tosa_error_if.py
+++ b/verif/tosa_error_if.py
@@ -40,5 +40,7 @@ class ErrorIf(object):
PadSmallerZero = "PadSmallerZero"
PadLargerEqualKernel = "PadLargerEqualKernel"
PoolingOutputShapeMismatch = "PoolingOutputShapeMismatch"
+ ScaleNotTrue = "ScaleNotTrue"
+ ScaleTrue = "ScaleTrue"
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index a03c66f..6780aa7 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -169,6 +169,9 @@ class TosaTensorGen:
pl, const = opName["operands"]
shape = testGen.makeShape(rank)
+ # Constrict dimension size for large ranks when creating WrongRank tests
+ shape = TosaErrorIfArgGen.eiRestrictDimension(shape, error_name)
+
shape_list = []
for i in range(pl + const):
shape_list.append(shape.copy())
@@ -754,21 +757,31 @@ class TosaArgGen:
# Enumerate the output types here
for dtype in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
- if inDtype == DType.UINT8 and dtype != DType.INT8:
+ if dtype in [DType.UINT8, DType.INT8] and error_name == ErrorIf.OutputZeroPointNotZero:
+ continue
+ if inDtype == DType.UINT8 and dtype != DType.INT8 and error_name != ErrorIf.WrongOutputType:
# The only output dtype for UINT8 is INT8, skip all other combinations
continue
- if inDtype != DType.INT8 and dtype == DType.UINT8:
+ if inDtype != DType.INT8 and dtype == DType.UINT8 and error_name != ErrorIf.WrongOutputType:
# The only input dtype for UINT8 is INT8, skip all other combinations
continue
+ if error_name == ErrorIf.WrongOutputType and not TosaErrorIfArgGen.eiRescaleWrongOutputType(inDtype, dtype):
+ continue
for scale32 in [False, True]:
+ if error_name == ErrorIf.ScaleTrue and scale32 == False:
+ continue
+ elif error_name == ErrorIf.ScaleNotTrue and scale32 == True:
+ continue
for double_round in [False, True]:
+ if error_name == ErrorIf.ScaleNotTrue and double_round == False:
+ continue
for per_channel in [False, True]:
- if inDtype == DType.INT48 and scale32:
+ if inDtype == DType.INT48 and scale32 and error_name != ErrorIf.ScaleTrue:
# Illegal condition. Must be scale32=False
continue
- if double_round and not scale32:
+ if double_round and not scale32 and error_name != ErrorIf.ScaleNotTrue:
# Illegal condition. ERROR_IF(!scale32 && double_round)
continue
@@ -1229,6 +1242,22 @@ class TosaErrorIfArgGen:
else:
return None, None, None
+ @staticmethod
+ def eiRescaleWrongOutputType(input_dtype, output_dtype):
+ if input_dtype == DType.INT8:
+ if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
+ return True
+ if input_dtype in [DType.INT16, DType.INT32]:
+ if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
+ return True
+ elif input_dtype == DType.INT48:
+ if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
+ return True
+ elif input_dtype == DType.UINT8:
+ if output_dtype != DType.INT8:
+ return True
+ return False
+
@staticmethod
def eiInvalidateInputOutputList(testGen, error_name, input_list, output_list):
@@ -1247,6 +1276,16 @@ class TosaErrorIfArgGen:
output_list = []
return input_list, output_list
+ @staticmethod
+ def eiRestrictDimension(shape, error_name):
+ # Restrict dimension size if rank is large for WrongRank Error_If
+ # This will keep the test sizes reasonably small
+ if error_name == ErrorIf.WrongRank:
+ if len(shape) > 4:
+ shape[4] = 1
+
+ return shape
+
class TosaErrorValidator:
@staticmethod
@@ -1321,6 +1360,19 @@ class TosaErrorValidator:
(input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
):
error_result = True
+ elif op['op'] == Op.RESCALE:
+ if input_dtype == DType.INT8:
+ if output_dtype not in [DType.UINT8, DType.INT8, DType.INT16, DType.INT32]:
+ error_result = True
+ if input_dtype in [DType.INT16, DType.INT32]:
+ if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
+ error_result = True
+ elif input_dtype == DType.INT48:
+ if output_dtype not in [DType.INT8, DType.INT16, DType.INT32]:
+ error_result = True
+ elif input_dtype == DType.UINT8:
+ if output_dtype != DType.INT8:
+ error_result = True
else:
if output_dtype != input_dtype:
error_result = True
@@ -1343,11 +1395,11 @@ class TosaErrorValidator:
rmin, rmax = op['rank']
rank_range = range(rmin, rmax + 1)
incorrect_ranks = list(set(all_ranks) - set(rank_range))
+ # Remove small incorrect ranks to avoid index errors
+ incorrect_ranks = [rank for rank in incorrect_ranks if rank > rmin]
# Set minimum incorrect rank to 3 to avoid index error
if op['op'] in [Op.RESIZE]:
incorrect_ranks = [3, 5]
- elif op['op'] in [Op.AVG_POOL2D, Op.MAX_POOL2D]:
- incorrect_ranks = [5]
error_name = ErrorIf.WrongRank
param_reqs = {"rank": incorrect_ranks, "dtype": None, "shape": None}
@@ -1358,6 +1410,9 @@ class TosaErrorValidator:
input_shape = kwargs['input_shape']
if op['op'] in [Op.RESIZE, Op.AVG_POOL2D, Op.MAX_POOL2D] and len(input_shape) != 4:
error_result = True
+ else:
+ if len(input_shape) not in rank_range:
+ error_result = True
info_dict = {
"error_name": error_name,
@@ -1739,9 +1794,14 @@ class TosaErrorValidator:
if check:
input_dtype = kwargs['input_dtype']
- # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
- qinfo = kwargs['qinfo'].ints
- input_zero_point = qinfo[0][1]
+ if isinstance(kwargs['qinfo'], tuple):
+ qinfo = kwargs['qinfo']
+ input_zero_point = qinfo[0]
+ else:
+ # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
+ qinfo = kwargs['qinfo'].ints
+ input_zero_point = qinfo[0][1]
+
if input_dtype not in [DType.INT8, DType.UINT8] and input_zero_point != 0:
error_result = True
@@ -1774,10 +1834,18 @@ class TosaErrorValidator:
if check:
input_dtype = kwargs['input_dtype']
- # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
- qinfo = kwargs['qinfo'].ints
- output_zero_point = qinfo[1][1]
- if input_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0:
+ output_dtype = kwargs['output_dtype']
+ if isinstance(kwargs['qinfo'], tuple):
+ qinfo = kwargs['qinfo']
+ output_zero_point = qinfo[1]
+ else:
+ # For use: qinfo.ints[0][1] = input_zp, qinfo.ints[1][1] = output_zp
+ qinfo = kwargs['qinfo'].ints
+ output_zero_point = qinfo[1][1]
+ if op['op'] == Op.AVG_POOL2D:
+ if input_dtype != DType.INT8 and output_zero_point != 0:
+ error_result = True
+ elif output_dtype not in [DType.INT8, DType.UINT8] and output_zero_point != 0:
error_result = True
info_dict = {
@@ -1980,6 +2048,48 @@ class TosaErrorValidator:
}
return info_dict
+ @staticmethod
+ def evScaleTrue(check=False, **kwargs):
+ error_name = ErrorIf.ScaleTrue
+ param_reqs = {"rank": None, "dtype": [DType.INT48], "shape": None}
+ error_result = False
+ error_reason = "Scale set to true but input type is INT48"
+
+ if check:
+ input_dtype = kwargs['input_dtype']
+ scale32 = kwargs['scale32']
+ if scale32 and input_dtype == DType.INT48:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
+ @staticmethod
+ def evScaleNotTrue(check=False, **kwargs):
+ error_name = ErrorIf.ScaleNotTrue
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Scale set to false but double round set to true"
+
+ if check:
+ scale32 = kwargs['scale32']
+ double_round = kwargs['double_round']
+ if not scale32 and double_round:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
class TosaInvalidValidator:
@@ -2276,6 +2386,10 @@ class TosaTestGen:
return 32
elif t == DType.INT48:
return 48
+ elif t == DType.FLOAT:
+ return 32
+ elif t == DType.BOOL:
+ return 1
else:
raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
@@ -2809,7 +2923,7 @@ class TosaTestGen:
self.ser.addOperator(op['op'], [val.name], [result_tens.name])
return result_tens
- def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel):
+ def build_rescale(self, op, val, out_dtype, scale32, double_round, per_channel, validator_fcns, error_name):
result_tens = OutputShaper.typeConversionOp(self.ser, val, out_dtype)
if per_channel:
@@ -2826,6 +2940,11 @@ class TosaTestGen:
elif val.dtype == DType.UINT8:
input_zp = self.randInt(0, 256)
in_type_width = in_type_width + 1
+ elif error_name == ErrorIf.InputZeroPointNotZero:
+ input_zp = self.randInt(-128, 128)
+ if input_zp == 0:
+ input_zp = input_zp + self.rng.integers(1, 10)
+ in_type_width = in_type_width + 1
else:
input_zp = 0
@@ -2835,6 +2954,11 @@ class TosaTestGen:
elif out_dtype == DType.UINT8:
output_zp = self.randInt(0, 256)
out_type_width = out_type_width + 1
+ elif error_name == ErrorIf.OutputZeroPointNotZero:
+ output_zp = self.randInt(-128, 128)
+ if output_zp == 0:
+ output_zp = output_zp + self.rng.integers(1, 10)
+ out_type_width = out_type_width + 1
else:
output_zp = 0
@@ -2864,6 +2988,31 @@ class TosaTestGen:
# print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
+ # Invalidate Input/Output list for error if checks.
+ input_list = [val.name]
+ output_list = [result_tens.name]
+ pCount, cCount = op["operands"]
+ num_operands = pCount + cCount
+ input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(self, error_name, input_list, output_list)
+
+ qinfo = (input_zp, output_zp)
+ TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ input_dtype=val.dtype,
+ output_dtype=out_dtype,
+ input_shape=val.shape,
+ qinfo=qinfo,
+ scale32 = scale32,
+ double_round = double_round,
+ input_list=input_list,
+ output_list=output_list,
+ result_tensor=result_tens,
+ num_operands=num_operands,
+ )
+
attr = ts.TosaSerializerAttribute()
attr.RescaleAttribute(
input_zp,
@@ -2875,7 +3024,7 @@ class TosaTestGen:
per_channel,
)
- self.ser.addOperator(op['op'], [val.name], [result_tens.name], attr)
+ self.ser.addOperator(op['op'], input_list, output_list, attr)
return result_tens
def build_cond_if_const(self, op, then_tens, else_tens, cond):
@@ -4092,8 +4241,12 @@ class TosaTestGen:
"rescale": {
"op": Op.RESCALE,
"operands": (1, 0),
+ "rank": (1,4),
"build_fcn": (build_rescale, TosaTensorGen.tgBasic, TosaArgGen.agRescale),
"types": [DType.UINT8, DType.INT8, DType.INT16, DType.INT32, DType.INT48],
+ "error_if_validators": (TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evOutputZeroPointNotZero, TosaErrorValidator.evScaleTrue,
+ TosaErrorValidator.evScaleNotTrue, TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, TosaErrorValidator.evWrongRank,
+ TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList)
},
# Custom
# Not implemented.