aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthew Haddon <matthew.haddon@arm.com>2021-09-07 16:12:21 +0100
committerMatthew Haddon <matthew.haddon@arm.com>2021-09-27 12:03:33 +0100
commite86fd34cb3881d5a9c65c1efdbda437314fb83cb (patch)
treed14633fa4c8a93f8e0ea2af461a36b90bc499450
parent7aa69f4bfb91ff662a3d7fceaf81aa215c8e40d2 (diff)
downloadreference_model-e86fd34cb3881d5a9c65c1efdbda437314fb83cb.tar.gz
Add ERROR_IF support for RESIZE
* TosaErrorValidator implemented to produce and test for ERROR_IF conditions * RESIZE specific ERROR_IF test support added * Set rank and type parameters before test generation loop to avoid multiple checks for valid parameters * Increase output dimensions if IFM/OFM ratio smaller than 1/16 Signed-off-by: Matthew Haddon <matthew.haddon@arm.com> Change-Id: I430e13383d99c2e25354f53d3703fb9be973f6d4
-rw-r--r--verif/tosa_error_if.py25
-rw-r--r--verif/tosa_test_gen.py507
-rwxr-xr-xverif/tosa_verif_build_tests.py4
3 files changed, 487 insertions, 49 deletions
diff --git a/verif/tosa_error_if.py b/verif/tosa_error_if.py
new file mode 100644
index 0000000..e310804
--- /dev/null
+++ b/verif/tosa_error_if.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2021, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+class ErrorIf(object):
+ MaxDimExceeded = "MaxDimExceeded"
+ StrideSmallerEqualZero = "StrideSmallerEqualZero"
+ StrideLargerEqualMax = "StrideLargerEqualMax"
+ StrideLargerDimension = "StrideLargerDimension"
+ OffsetSmallerEqualMin = "OffsetSmallerEqualMin"
+ OffsetLargerEqualMax = "OffsetLargerEqualMax"
+ ShiftNotZero = "ShiftNotZero"
+ ShiftSmallerOne = "ShiftSmallerOne"
+ ShiftLargerEleven = "ShiftLargerEleven"
+
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 1f28000..f55d892 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -42,6 +42,7 @@ sys.path.append(
import tosa_serializer as ts
from tosa_serializer import *
import tosa
+from tosa_error_if import ErrorIf
# Convenience variables to the flatc-generated types that should be enums, but aren't
DType = tosa.DType.DType()
@@ -161,7 +162,7 @@ class TosaTensorGen:
return shape_list
@staticmethod
- def tgNHWC(testGen, opName, rank):
+ def tgNHWC(testGen, opName, rank, error_name=None):
pl, const = opName["operands"]
assert rank == 4
@@ -892,7 +893,7 @@ class TosaArgGen:
return arg_list
@staticmethod
- def agResize(testGen, opName, shapeList, dtype):
+ def agResize(testGen, opName, shapeList, dtype, error_name=None):
arg_list = []
ifm_shape = shapeList[0]
@@ -918,7 +919,14 @@ class TosaArgGen:
# Randomly generate legal output dimensions and shift
# and then compute the stride and offset based on them
- output_dims = [testGen.randInt(1), testGen.randInt(1)]
+ # A output_dim of 1 will cause offset to exceed allowed range
+ # so minimum value 2 produced below
+ output_dims = [testGen.randInt(1) + 1, testGen.randInt(1) + 1]
+ while ((float(ifm_shape[1]) / float(output_dims[0])) >= 16):
+ output_dims[0] += 1
+ while ((float(ifm_shape[2]) / float(output_dims[1])) >= 16):
+ output_dims[1] += 1
+
in_center_h = (ifm_shape[1] - 1) / 2.0
in_center_w = (ifm_shape[2] - 1) / 2.0
out_center_h = (output_dims[0] - 1) / 2.0
@@ -935,6 +943,20 @@ class TosaArgGen:
offset = [0, 0]
stride_fp = [fp_stride_y, fp_stride_x]
offset_fp = [fp_offset_y, fp_offset_x]
+
+ if error_name is not None:
+ shift, stride, stride_fp, offset, offset_fp = TosaErrorIfArgGen.eiResizeErrorIf(
+ testGen,
+ error_name,
+ shapeList,
+ outputDType,
+ shift,
+ stride,
+ stride_fp,
+ offset,
+ offset_fp
+ )
+
arg_list.append(
(
"mode{}_odim{}x{}_out{}_st{:.2f}x{:.2f}_off{:.2f}x{:.2f}".format(
@@ -969,12 +991,12 @@ class TosaArgGen:
offset_x = int(round(fp_offset_x * unit))
while (
- stride_y >= 32768
- or stride_x >= 32768
- or offset_y >= 32768
- or offset_x >= 32768
- or offset_y < -32768
- or offset_x < -32768
+ stride_y >= (16 << shift)
+ or stride_x >= (16 << shift)
+ or offset_y >= (16 << shift)
+ or offset_x >= (16 << shift)
+ or offset_y <= (-16 << shift)
+ or offset_x <= (-16 << shift)
):
shift = shift - 1
unit = float(1 << shift)
@@ -989,6 +1011,19 @@ class TosaArgGen:
stride_fp = [0.0, 0.0]
offset_fp = [0.0, 0.0]
+ if error_name is not None:
+ shift, stride, stride_fp, offset, offset_fp = TosaErrorIfArgGen.eiResizeErrorIf(
+ testGen,
+ error_name,
+ shapeList,
+ outputDType,
+ shift,
+ stride,
+ stride_fp,
+ offset,
+ offset_fp
+ )
+
arg_list.append(
(
"mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}".format(
@@ -1038,6 +1073,279 @@ class TosaArgGen:
return arg_list
+class TosaErrorIfArgGen:
+
+ @staticmethod
+ def eiResizeErrorIf(testGen, error_name, shapeList, outputDType, shift, stride, stride_fp, offset, offset_fp):
+
+ if outputDType == DType.FLOAT:
+ if error_name == ErrorIf.StrideSmallerEqualZero:
+ stride_fp = testGen.rng.random(size=[2]) - 2
+ elif error_name == ErrorIf.ShiftNotZero:
+ shift = testGen.rng.integers(1, 5)
+ elif error_name == ErrorIf.StrideLargerDimension:
+ shape = shapeList[0]
+ transform_height = testGen.rng.choice([False, True])
+ if transform_height:
+ stride_fp[0] = shape[1] + testGen.rng.integers(1, 10)
+ else:
+ stride_fp[1] = shape[2] + testGen.rng.integers(1, 10)
+ else:
+ if error_name == ErrorIf.StrideSmallerEqualZero:
+ stride = np.int16(testGen.rng.integers(-1, 1, size=[2]))
+ elif error_name == ErrorIf.ShiftSmallerOne:
+ shift = testGen.rng.integers(-3, 1)
+ if shift <= 0:
+ stride = [(16 >> -shift) - 1, (16 >> -shift) - 1] # avoids other ERROR_IF checks
+ offset = [(16 >> -shift) - 1, (16 >> -shift) - 1] # avoids other ERROR_IF checks
+ else:
+ stride = [(16 << shift) - 1, (16 << shift) - 1] # avoids other ERROR_IF checks
+ offset = [(16 << shift) - 1, (16 << shift) - 1] # avoids other ERROR_IF checks
+ elif error_name == ErrorIf.ShiftLargerEleven:
+ shift = np.int16(testGen.rng.integers(12, 15))
+ elif error_name == ErrorIf.StrideLargerDimension:
+ shape = shapeList[0]
+ transform_height = testGen.rng.choice([False, True])
+ if transform_height:
+ stride[0] = shape[1] + testGen.rng.integers(1, 10)
+ else:
+ stride[1] = shape[2] + testGen.rng.integers(1, 10)
+ elif error_name == ErrorIf.StrideLargerEqualMax:
+ stride = [(16 << shift) + 1, (16 << shift) + 1]
+ elif error_name == ErrorIf.OffsetLargerEqualMax:
+ offset = [(16 << shift) + 1, (16 << shift) + 1]
+ elif error_name == ErrorIf.OffsetSmallerEqualMin:
+ offset = [(-16 << shift) - 1, (-16 << shift) - 1]
+
+ return shift, stride, stride_fp, offset, offset_fp
+
+
+class TosaErrorValidator:
+
+
+ @staticmethod
+ def evMaxDimExceeded(check=False, **kwargs):
+ error_name = ErrorIf.MaxDimExceeded
+ param_reqs = {"rank": [4,4], "dtype": [DType.INT8], "shape": [[1, 16584, 5, 1]]}
+ error_result = False
+ error_reason = "At least one maximum dimension is larger than 16384"
+
+ if check:
+ input_shape = kwargs['input_shape'].shape
+ output_shape = kwargs['output_shape'] # Note this is just (OH, OW)
+ if ((input_shape[1] > 16384) or
+ (input_shape[2] > 16384) or
+ (output_shape[0] > 16384) or
+ (output_shape[1] > 16384)):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
+ @staticmethod
+ def evStrideSmallerEqualZero(check=False, **kwargs):
+ error_name = ErrorIf.StrideSmallerEqualZero
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Stride value smaller than or equal zero"
+
+ if check:
+ input_dtype = kwargs['input_dtype']
+ if input_dtype == DType.FLOAT:
+ stride = kwargs['stride_fp']
+ else:
+ stride = kwargs['stride']
+
+ if min(stride) <= 0:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
+ @staticmethod
+ def evStrideLargerEqualMax(check=False, **kwargs):
+ error_name = ErrorIf.StrideLargerEqualMax
+ param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
+ error_result = False
+ error_reason = "Stride value larger than or equal to maximum value"
+
+ if check:
+ shift = kwargs['shift']
+ input_dtype = kwargs['input_dtype']
+ stride = kwargs['stride']
+ if input_dtype in [DType.INT8, DType.INT16]:
+ if shift >= 0 and (stride[0] >= (16 << shift) or stride[1] >= (16 << shift)):
+ error_result = True
+ elif shift < 0 and (stride[0] >= (16 >> -shift) or stride[1] >= (16 >> -shift)):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
+
+ @staticmethod
+ def evStrideLargerDimension(check=False, **kwargs):
+ error_name = ErrorIf.StrideLargerDimension
+ param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
+ error_result = False
+ error_reason = "Stride value larger than or equal to H/W dimension"
+
+ if check:
+ shape = kwargs['input_shape'].shape
+ input_dtype = kwargs['input_dtype']
+ stride = kwargs['stride_fp']
+
+ if input_dtype == DType.FLOAT and (stride[0] > shape[1]) or (stride[1] > shape[2]):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
+
+ @staticmethod
+ def evOffsetSmallerEqualMin(check=False, **kwargs):
+ error_name = ErrorIf.OffsetSmallerEqualMin
+ param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
+ error_result = False
+ error_reason = "Offset value smaller than or equal to minimum value"
+
+ if check:
+ shift = kwargs['shift']
+ input_dtype = kwargs['input_dtype']
+ if input_dtype == DType.FLOAT:
+ offset = kwargs['offset_fp']
+ else:
+ offset = kwargs['offset']
+
+ if shift >= 0 and (offset[0] <= (-16 << shift) or offset[1] <= (-16 << shift)):
+ error_result = True
+ elif shift < 0 and (offset[0] <= (-16 >> -shift) or offset[1] <= (-16 >> -shift)):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
+ @staticmethod
+ def evOffsetLargerEqualMax(check=False, **kwargs):
+ error_name = ErrorIf.OffsetLargerEqualMax
+ param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
+ error_result = False
+ error_reason = "Offset value larger than or equal to maximum value"
+
+ if check:
+ shift = kwargs['shift']
+ input_dtype = kwargs['input_dtype']
+ if input_dtype == DType.FLOAT:
+ offset = kwargs['offset_fp']
+ else:
+ offset = kwargs['offset']
+
+ if shift >= 0:
+ if offset[0] >= (16 << shift) or offset[1] >= (16 << shift):
+ error_result = True
+
+ if shift >= 0 and (offset[0] >= (16 << shift) or offset[1] >= (16 << shift)):
+ error_result = True
+ elif shift < 0 and (offset[0] >= (16 >> -shift) or offset[1] >= (16 >> -shift)):
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
+ @staticmethod
+ def evShiftNotZero(check=False, **kwargs):
+ error_name = ErrorIf.ShiftNotZero
+ param_reqs = {"rank": None, "dtype": [DType.FLOAT], "shape": None}
+ error_result = False
+ error_reason = "Shift value must be zero for float input"
+
+ if check:
+ shift = kwargs['shift']
+ input_dtype = kwargs['input_dtype']
+ if input_dtype == DType.FLOAT and shift != 0:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
+
+ @staticmethod
+ def evShiftSmallerOne(check=False, **kwargs):
+ error_name = ErrorIf.ShiftSmallerOne
+ param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
+ error_result = False
+ error_reason = "Shift value smaller than one"
+
+ if check:
+ shift = kwargs['shift']
+ input_dtype = kwargs['input_dtype']
+ if shift < 1 and input_dtype != DType.FLOAT:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
+ @staticmethod
+ def evShiftLargerEleven(check=False, **kwargs):
+ error_name = ErrorIf.ShiftLargerEleven
+ param_reqs = {"rank": None, "dtype": [DType.INT8, DType.INT16], "shape": None}
+ error_result = False
+ error_reason = "Shift value larger than eleven"
+
+ if check:
+ shift = kwargs['shift']
+ if shift > 11:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs
+ }
+ return info_dict
+
+
class TosaInvalidValidator:
@staticmethod
@@ -1586,6 +1894,7 @@ class TosaTestGen:
self.ser.addOperator(
op, [a.name, padding_tens.name], [result_tens.name], None, qinfo
)
+ return result_tens
def build_reshape(self, op, a, newShape):
result_tens = OutputShaper.reshapeOp(self.ser, a, newShape)
@@ -1684,6 +1993,8 @@ class TosaTestGen:
output_dims,
input_dtype,
output_dtype,
+ validator_fcns,
+ error_name = None,
):
result_tens = OutputShaper.resizeOp(
self.ser,
@@ -1697,8 +2008,34 @@ class TosaTestGen:
output_dims,
input_dtype,
output_dtype,
+ error_name
)
+ # Check ERROR_IF statements
+ for val_fcn in validator_fcns:
+ val_result = val_fcn(
+ check=True,
+ shift=shift,
+ input_dtype=input_dtype,
+ input_shape=input,
+ output_shape=output_dims,
+ offset=offset,
+ offset_fp=offset_fp,
+ stride=stride,
+ stride_fp=stride_fp)
+
+ validator_name = val_result['error_name']
+ error_result = val_result['error_result']
+ error_reason = val_result['error_reason']
+
+ if error_result:
+ if error_name == validator_name:
+ self.ser.setExpectedReturnCode(2, error_reason)
+ else:
+ print(f"Multiple ERROR_IF checks hit \nError required: {error_name}, Error_produced: {validator_name}")
+ return None # Return None to delete test if wrong ERROR_IF is hit
+
+
attr = ts.TosaSerializerAttribute()
attr.ResizeAttribute(
@@ -1935,45 +2272,52 @@ class TosaTestGen:
build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
- # Generate the lists of arguments
- rmin, rmax = op["rank"]
-
# Create a default testing rank range, 1-4 inclusive to keep test sizes reasonably small.
default_test_rank_range = range(1, 5)
+ if not shapeFilter:
+ shapeFilter = [None]
+
+ # Generate the lists of arguments
+ rmin, rmax = op["rank"]
+ if rankFilter is not None:
+ cleanRankFilter = []
+ # Ensure rankFilter values are allowed by operator
+ for rank in rankFilter:
+ if rank >= rmin and rank <= rmax:
+ cleanRankFilter.append(rank)
+ rankFilter = cleanRankFilter
+ elif rankFilter is None and shapeFilter[0] is None:
+ cleanRankFilter = []
+ # Ensure default behaviour is bounded by default range or by operator, whichever is smaller.
+ rankRange = range(rmin, rmax + 1)
+ for rank in rankRange:
+ if rank >= min(default_test_rank_range) and rank <= max(default_test_rank_range):
+ cleanRankFilter.append(rank)
+ rankFilter = cleanRankFilter
+ else:
+ rankFilter = range(rmin, rmax + 1)
+
+ dtypes = op["types"]
+ if dtypeFilter is not None:
+ cleanDtypeFilter = []
+ # Ensure filtered dtypes are allowed by operator
+ for dtype in dtypeFilter:
+ if dtype in dtypes:
+ cleanDtypeFilter.append(dtype)
+ dtypeFilter = cleanDtypeFilter
+ else:
+ dtypeFilter = dtypes
# Test list consists of a tuple of:
# (opName, testNameStr, dtype, shapeList, argumentsList)
testList = []
- if not shapeFilter:
- shapeFilter = [None]
-
# Positive test loop
if testType in ['positive', 'both']:
- for r in range(rmin, rmax + 1):
-
- # Filter out the rank?
- if rankFilter is not None and r not in rankFilter:
- continue
+ for r in rankFilter:
if opName.startswith("conv3d"):
assert r == 5, "conv3d test must have input rank == 5"
- elif (
- rankFilter is None
- and shapeFilter[0] is None
- and r not in default_test_rank_range
- ):
- continue
-
- for t in op["types"]:
-
- # Filter tests based on dtype?
- if dtypeFilter is not None:
- if not (
- t in dtypeFilter
- or (isinstance(t, list) and t[0] in dtypeFilter)
- ):
- continue
-
+ for t in dtypeFilter:
# Create the placeholder and const tensors
for shape in shapeFilter:
# A None shape chooses a random shape of a given rank
@@ -1981,7 +2325,6 @@ class TosaTestGen:
# Filter out by rank
if shape is not None and len(shape) != r:
continue
-
self.setTargetShape(shape)
shapeList = tgen_fcn(self, op, r)
@@ -2003,7 +2346,7 @@ class TosaTestGen:
else:
testStr = "{}_{}_{}".format(opName, shapeStr, typeStr)
- testList.append((opName, testStr, t, shapeList, args))
+ testList.append((opName, testStr, t, None, shapeList, args))
# Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
if "invalid_test_validators" in op:
@@ -2012,21 +2355,73 @@ class TosaTestGen:
for test in testList:
for validator_fcn in invalid_test_validators:
remove_test = False
- if validator_fcn(opName=test[0], input_dtype=test[2], shapeList=test[3], args=test[4]):
+ if validator_fcn(opName=test[0], input_dtype=test[2], shapeList=test[4], args=test[5]):
remove_test = True
if not remove_test:
clean_testList.append(test)
testList = clean_testList
+ # Store the original filters so they can be reused if required
+ base_rankFilter = rankFilter
+ base_dtypeFilter = dtypeFilter
+ base_shapeFilter = shapeFilter
# Reset RNG so both positive and negative tests are reproducible
self.resetRNG()
+
# Negative test loop
- if testType in ['negative', 'both']:
- print("Negative tests unsupported")
+ if testType in ['negative', 'both'] and "error_if_validators" in op:
+ error_if_validators = op["error_if_validators"]
+ for validator in error_if_validators:
+ validator_info = validator()
+ error_name = validator_info['error_name']
+ error_arguments = validator_info['param_reqs']
+
+ #Set parameters as required
+ if error_arguments['rank'] != None:
+ rmin, rmax = error_arguments['rank']
+ rankFilter = range(rmin, rmax + 1)
+ else:
+ rankFilter = base_rankFilter
+ if error_arguments['dtype'] != None:
+ dtypeFilter = error_arguments['dtype']
+ else:
+ dtypeFilter = base_dtypeFilter
+ if error_arguments['shape'] != None:
+ shapes = error_arguments['shape']
+ else:
+ shapes = base_shapeFilter[:2] # Reduce number of shapes to keep test numbers small
+
+ for r in range(rmin, rmax + 1):
+ for t in dtypeFilter:
+ # Create the placeholder and const tensors
+ for shape in shapes:
+ # A None shape chooses a random shape of a given rank
+ # Filter out by rank
+ if shape is not None and len(shape) != r:
+ continue
+ self.setTargetShape(shape)
+ shapeList = tgen_fcn(self, op, r, error_name)
+ shapeStr = self.shapeStr(shapeList[0])
+ typeStr = self.typeStr(t)
+ # Argument lists consists of tuples of the (str, []) string representation and the build function argument list
+ argList = []
+ if agen_fcn:
+ argList = agen_fcn(self, opName, shapeList, t, error_name)
+ else:
+ argList = [("", [])]
+ for argStr, args in argList:
+ if argStr:
+ testStr = "{}_ERRORIF_{}_{}_{}_{}".format(
+ opName, error_name, shapeStr, typeStr, argStr
+ )
+ else:
+ testStr = "{}_ERRORIF_{}_{}_{}".format(opName, error_name, shapeStr, typeStr)
+ testList.append((opName, testStr, t, error_name, shapeList, args))
return testList
- def serializeTest(self, opName, testStr, dtype_or_dtypeList, shapeList, testArgs):
+
+ def serializeTest(self, opName, testStr, dtype_or_dtypeList, error_name, shapeList, testArgs):
try:
op = self.TOSA_OP_LIST[opName]
except KeyError as e:
@@ -2036,6 +2431,11 @@ class TosaTestGen:
self.createSerializer(opName, testStr)
build_fcn, tgen_fcn, agen_fcn = op["build_fcn"]
+ if "error_if_validators" in op:
+ error_if_validators = op["error_if_validators"]
+ else:
+ error_if_validators = None
+
pCount, cCount = op["operands"]
num_operands = pCount + cCount
@@ -2268,10 +2668,16 @@ class TosaTestGen:
qinfo = None
try:
- if qinfo is not None:
- resultName = build_fcn(self, op["op"], *tens, *testArgs, qinfo)
+ if error_if_validators is None:
+ if qinfo is not None:
+ resultName = build_fcn(self, op["op"], *tens, *testArgs, qinfo)
+ else:
+ resultName = build_fcn(self, op["op"], *tens, *testArgs)
else:
- resultName = build_fcn(self, op["op"], *tens, *testArgs)
+ if qinfo is not None:
+ resultName = build_fcn(self, op["op"], *tens, *testArgs, qinfo, error_if_validators, error_name)
+ else:
+ resultName = build_fcn(self, op["op"], *tens, *testArgs, error_if_validators, error_name)
except TypeError as e:
print(
"build_fcn: {}\nTensors: {}\nArgs: {}\n".format(
@@ -2280,6 +2686,9 @@ class TosaTestGen:
)
raise e
+ if resultName is None:
+ print("Invalid ERROR_IF tests created")
+
# Save the serialized test
self.serialize("test")
@@ -2846,7 +3255,10 @@ class TosaTestGen:
"rank": (4, 4),
"build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
"types": [DType.INT8, DType.INT16, DType.FLOAT],
- "invalid_test_validators": (TosaInvalidValidator.ivWrongDataTypeOrModeResize, TosaInvalidValidator.ivBadStride)
+ "invalid_test_validators": (TosaInvalidValidator.ivWrongDataTypeOrModeResize, TosaInvalidValidator.ivBadStride),
+ "error_if_validators": (TosaErrorValidator.evMaxDimExceeded, TosaErrorValidator.evStrideSmallerEqualZero, TosaErrorValidator.evStrideLargerDimension,
+ TosaErrorValidator.evStrideLargerEqualMax, TosaErrorValidator.evOffsetSmallerEqualMin, TosaErrorValidator.evOffsetLargerEqualMax,
+ TosaErrorValidator.evShiftNotZero, TosaErrorValidator.evShiftSmallerOne, TosaErrorValidator.evShiftLargerEleven)
},
# Type conversion
"cast": {
@@ -3262,6 +3674,7 @@ class OutputShaper:
output_dims,
input_dtype,
output_dtype,
+ error_name = None
):
output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]
diff --git a/verif/tosa_verif_build_tests.py b/verif/tosa_verif_build_tests.py
index 02d1934..5b84314 100755
--- a/verif/tosa_verif_build_tests.py
+++ b/verif/tosa_verif_build_tests.py
@@ -235,10 +235,10 @@ def main():
print("{} matching tests".format(len(testList)))
results = []
- for opName, testStr, dtype, shapeList, testArgs in testList:
+ for opName, testStr, dtype, error, shapeList, testArgs in testList:
if args.verbose:
print(testStr)
- results.append(ttg.serializeTest(opName, testStr, dtype, shapeList, testArgs))
+ results.append(ttg.serializeTest(opName, testStr, dtype, error, shapeList, testArgs))
print(f"Done creating {len(results)} tests")