aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthew Haddon <matthew.haddon@arm.com>2021-08-25 16:40:29 +0100
committerMatthew Haddon <matthew.haddon@arm.com>2021-09-08 12:19:02 +0100
commitb724efc086b77b59944e9a69eecb260a3d3d2f26 (patch)
tree418060cb78045cdceb32d7ec47789b459e52c1f3
parent74567097e161ce9cdb0f09d474da6d9d36aa7476 (diff)
downloadreference_model-b724efc086b77b59944e9a69eecb260a3d3d2f26.tar.gz
Remove invalid tests from test generator
* Implemented InvalidValidator to remove existing invalid tests. * Removed invalid tests for resize, rescale, conv2d, depthwise_conv2d, transpose_conv2d, avg_pool2d, and max_pool2d (note default avg/max_pool never produced negative tests, but theoretically could). * Changed behaviour of computerMultiplierAndShift to produce the allowed range of shift values. Signed-off-by: Matthew Haddon <matthew.haddon@arm.com> Change-Id: I5e7b11030deb5322e2ca08fd4f4467fb02b7740d
-rw-r--r--verif/tosa_test_gen.py249
1 files changed, 155 insertions, 94 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 777c059..760ed06 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -123,10 +123,21 @@ class TosaQuantGen:
shift = shift + 1
shift = (-shift) + scaleBits
- # print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(scaleFp, scaleBits, m, multiplier, shift))
+ #print('scalefp {} scaleBits {} m {} mult {} shift {}'.format(scaleFp, scaleBits, m, multiplier, shift))
+
+ # Adjust multiplier such that shift is in allowed value range.
+ if shift == 0:
+ multiplier = multiplier // 4
+ shift = shift + 2
+ elif shift == 1:
+ multiplier = multiplier // 2
+ shift = shift + 1
+ elif shift == 63:
+ multiplier = multiplier * 2
+ shift = shift - 1
assert multiplier <= (1 << scaleBits)
- assert shift >= 0 and shift <= 63
+ assert shift >= 2 and shift <= 62
return multiplier, shift
@@ -566,7 +577,7 @@ class TosaArgGen:
"st{}{}_kern{}{}_pad{}{}{}{}".format(
s[0], s[1], k[0], k[1], p[0], p[1], p[2], p[3]
),
- [k, s, p],
+ [s, p, k],
)
)
return arg_list
@@ -946,6 +957,126 @@ class TosaArgGen:
return arg_list
+class TosaInvalidValidator:
+
+ @staticmethod
+ def ivWrongDataTypeOrModeResize(**kwargs):
+ input_dtype = kwargs["input_dtype"]
+ args = kwargs["args"]
+ mode = args[0]
+ stride = args[1]
+ stride_fp = args[4]
+ output_dtype = args[8]
+
+ if mode == ResizeMode.BILINEAR:
+ # Invalid output data type / Invalid input datatype
+ return (
+ not (input_dtype == DType.INT8 and output_dtype == DType.INT32) or
+ not (input_dtype == DType.INT16 and output_dtype == DType.INT48) or
+ not (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT) or
+ (input_dtype not in [DType.INT8, DType.INT32, DType.FLOAT])
+ )
+ elif mode == ResizeMode.NEAREST:
+ # Invalid output data type / Invalid input datatype
+ return (
+ (input_dtype != output_dtype) or
+ (input_dtype not in [DType.INT8, DType.INT32, DType.FLOAT])
+ )
+ else:
+ # Invalid resize mode
+ return True
+
+ @staticmethod
+ def ivBadStride(**kwargs):
+ input_dtype = kwargs["input_dtype"]
+ args = kwargs["args"]
+ stride_x = args[1][0]
+ stride_y = args[1][1]
+ stride_fp_x = args[4][0]
+ stride_fp_y = args[4][1]
+
+ if input_dtype == DType.FLOAT:
+ if stride_fp_x <= 0 or stride_fp_y <= 0:
+ # Negative or zero stride
+ return True
+ else:
+ if stride_x <= 0 or stride_y <= 0:
+ # Negative or zero stride
+ return True
+ return False
+
+
+
+
+ @staticmethod
+ def ivHeightWidthSmallerZero(**kwargs):
+ opName = kwargs['opName']
+
+ inputShapes = kwargs['shapeList']
+ input = inputShapes[0]
+ if not opName.endswith("pool2d"):
+ filter = inputShapes[1]
+
+ args = kwargs['args']
+ strides = args[0]
+ padding = args[1]
+ dilations = args[2]
+ if opName.endswith("pool2d"):
+ kernel = args[2]
+
+ if opName.startswith('conv2d'):
+ h = (
+ input[1]
+ - filter[1]
+ - (filter[1] - 1) * (dilations[0] - 1)
+ + padding[0]
+ + padding[1]
+ ) // strides[0] + 1
+
+ w = (
+ input[2]
+ - filter[2]
+ - (filter[2] - 1) * (dilations[1] - 1)
+ + padding[2]
+ + padding[3]
+ ) // strides[1] + 1
+ elif opName.startswith("depthwise_conv2d"):
+ h = (
+ input[1]
+ - filter[0]
+ - (filter[0] - 1) * (dilations[0] - 1)
+ + padding[0]
+ + padding[1]
+ ) // strides[0] + 1
+
+ w = (
+ input[2]
+ - filter[1]
+ - (filter[1] - 1) * (dilations[1] - 1)
+ + padding[2]
+ + padding[3]
+ ) // strides[1] + 1
+ elif opName.endswith("pool2d"):
+ h = (input[1] + padding[0] + padding[1] + strides[0] - kernel[0]) // strides[0]
+ w = (input[2] + padding[2] + padding[3] + strides[1] - kernel[1]) // strides[1]
+ else:
+ assert False, "Unrecognized Op"
+
+ if h <= 0 or w <= 0:
+ # Invalid parameter combination
+ return True
+ return False
+
+ @staticmethod
+ def ivNonPositiveOutputShape(**kwargs):
+ args = kwargs['args']
+ output_shape = args[3]
+ if output_shape[1] <= 0 or output_shape[2] <= 0:
+ # Negative output shape
+ return True
+ return False
+
+
class TosaTestGen:
# Maximum rank of tensor supported by test generator.
@@ -1204,7 +1335,7 @@ class TosaTestGen:
self.ser.addOperator(op, [a.name], [result_tens.name], attr)
return result_tens
- def build_pool2d(self, op, input, kernel, stride, pad, qinfo=None):
+ def build_pool2d(self, op, input, stride, pad, kernel, qinfo=None):
result_tens = OutputShaper.pool2dOp(self.ser, input, kernel, stride, pad)
attr = ts.TosaSerializerAttribute()
@@ -1538,7 +1669,7 @@ class TosaTestGen:
if scale32:
pass
- # Cap the scaling at 2^15 - 1 for scale16
+ # Cap the scaling at 2^31 - 1 for scale32
scale_arr = np.clip(scale_arr, 1.0 / (1 << 31), (1 << 31) - 1)
else:
# Cap the scaling at 2^15 - 1 for scale16
@@ -1553,10 +1684,6 @@ class TosaTestGen:
multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
scale_arr[i], scale32
)
- if shift_arr[i] < 2 or shift_arr[i] > 62:
- self.ser.setExpectedReturnCode(
- TosaReturnCode.UNPREDICTABLE, "OpRescale: invalid shift value"
- )
# print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
@@ -1780,6 +1907,19 @@ class TosaTestGen:
testList.append((opName, testStr, t, shapeList, args))
+ # Remove tests which are expected to fail but don't correlate to a ERROR_IF statement
+ if "invalid_test_validators" in op:
+ invalid_test_validators = op["invalid_test_validators"]
+ clean_testList = []
+ for test in testList:
+ for validator_fcn in invalid_test_validators:
+ remove_test = False
+ if validator_fcn(opName=test[0], input_dtype=test[2], shapeList=test[3], args=test[4]):
+ remove_test = True
+ if not remove_test:
+ clean_testList.append(test)
+ testList = clean_testList
+
# Reset RNG so both positive and negative tests are reproducible
self.resetRNG()
# Negative test loop
@@ -2112,6 +2252,7 @@ class TosaTestGen:
"build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
"qgen": TosaQuantGen.qgUnary,
"types": TYPE_NARROW_INT_FP,
+ "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,)
},
# Templated operator. Filled in by createDynamicOpLists
"conv2d_TEMPLATE": {
@@ -2121,6 +2262,7 @@ class TosaTestGen:
"build_fcn": (build_conv2d, TosaTensorGen.tgConv2D, TosaArgGen.agConv2D),
"qgen": TosaQuantGen.qgConv,
"types": TYPE_CONV2D,
+ "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
"template": True,
},
# Conv3d TBD
@@ -2137,6 +2279,7 @@ class TosaTestGen:
),
"qgen": TosaQuantGen.qgConv,
"types": TYPE_CONV2D,
+ "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,),
"template": True,
},
"fully_connected": {
@@ -2161,6 +2304,7 @@ class TosaTestGen:
"rank": (4, 4),
"build_fcn": (build_pool2d, TosaTensorGen.tgNHWC, TosaArgGen.agPooling),
"types": TYPE_NARROW_INT_FP,
+ "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthSmallerZero,)
},
# Templated operator. Filled in by createDynamicOpLists
"transpose_conv2d_TEMPLATE": {
@@ -2174,6 +2318,7 @@ class TosaTestGen:
),
"qgen": TosaQuantGen.qgConv,
"types": TYPE_CONV2D,
+ "invalid_test_validators": (TosaInvalidValidator.ivNonPositiveOutputShape,),
"template": True,
},
# Activation functions
@@ -2529,6 +2674,7 @@ class TosaTestGen:
"rank": (4, 4),
"build_fcn": (build_resize, TosaTensorGen.tgNHWC, TosaArgGen.agResize),
"types": [DType.INT8, DType.INT16, DType.FLOAT],
+ "invalid_test_validators": (TosaInvalidValidator.ivWrongDataTypeOrModeResize, TosaInvalidValidator.ivBadStride)
},
# Type conversion
"cast": {
@@ -2691,14 +2837,6 @@ class OutputShaper:
+ padding[3]
) // strides[1] + 1
- if h <= 0 or w <= 0:
- # Invalid test parameters?
- h = 0
- w = 0
- ser.setExpectedReturnCode(
- TosaReturnCode.UNPREDICTABLE, "Invalid combination of conv2d parameters"
- )
-
ofm_shape = [ifm.shape[0], h, w, filter.shape[0]]
if ifm.dtype == DType.INT8:
@@ -2733,14 +2871,6 @@ class OutputShaper:
+ padding[3]
) // strides[1] + 1
- if h <= 0 or w <= 0:
- # Invalid test parameters?
- h = 0
- w = 0
- ser.setExpectedReturnCode(
- TosaReturnCode.UNPREDICTABLE, "Invalid combination of conv2d parameters"
- )
-
ofm_shape = [ifm.shape[0], h, w, filter.shape[2] * filter.shape[3]]
if ifm.dtype == DType.INT8:
@@ -2760,14 +2890,6 @@ class OutputShaper:
h = (ifm.shape[1] + pad[0] + pad[1] + stride[0] - kernel[0]) // stride[0]
w = (ifm.shape[2] + pad[2] + pad[3] + stride[1] - kernel[1]) // stride[1]
- if h <= 0 or w <= 0:
- # Invalid test parameters?
- h = 0
- w = 0
- ser.setExpectedReturnCode(
- TosaReturnCode.UNPREDICTABLE, "Invalid combination of pool2d parameters"
- )
-
ofm_shape = [ifm.shape[0], h, w, ifm.shape[3]]
return ser.addOutput(ofm_shape, ifm.dtype)
@@ -2928,62 +3050,6 @@ class OutputShaper:
output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]
- if input_dtype == DType.FLOAT:
- if stride_fp[0] <= 0 or stride_fp[1] <= 0:
- ser.setExpectedReturnCode(
- TosaReturnCode.ERROR, "Negative or zero stride"
- )
- else:
- if stride[0] <= 0 or stride[1] <= 0:
- ser.setExpectedReturnCode(
- TosaReturnCode.ERROR, "Negative or zero stride"
- )
-
- if mode == ResizeMode.BILINEAR:
- if input_dtype == DType.INT8:
- if output_dtype != DType.INT32:
- ser.setExpectedReturnCode(
- TosaReturnCode.ERROR, "Invalid output data type"
- )
- elif input_dtype == DType.INT16:
- if output_dtype != DType.INT48:
- ser.setExpectedReturnCode(
- TosaReturnCode.ERROR, "Invalid output data type"
- )
- elif input_dtype == DType.FLOAT:
- if output_dtype != DType.FLOAT:
- ser.setExpectedReturnCode(
- TosaReturnCode.ERROR, "Invalid output data type"
- )
- else:
- ser.setExpectedReturnCode(
- TosaReturnCode.ERROR, "Invalid input data type"
- )
-
- elif mode == ResizeMode.NEAREST:
- if input_dtype == DType.INT8:
- if output_dtype != DType.INT8:
- ser.setExpectedReturnCode(
- TosaReturnCode.ERROR, "Invalid output data type"
- )
- elif input_dtype == DType.INT16:
- if output_dtype != DType.INT16:
- ser.setExpectedReturnCode(
- TosaReturnCode.ERROR, "Invalid output data type"
- )
- elif input_dtype == DType.FLOAT:
- if output_dtype != DType.FLOAT:
- ser.setExpectedReturnCode(
- TosaReturnCode.ERROR, "Invalid output data type"
- )
- else:
- ser.setExpectedReturnCode(
- TosaReturnCode.ERROR, "Invalid input data type"
- )
-
- else:
- ser.setExpectedReturnCode(TosaReturnCode.ERROR, "Invalid resize mode")
-
return ser.addOutput(output_dims, output_dtype)
@staticmethod
@@ -3001,9 +3067,4 @@ class OutputShaper:
else:
raise Exception("Unsupported input dtype: {}".format(ifm.dtype))
- if output_shape[1] <= 0 or output_shape[2] <= 0:
- ser.setExpectedReturnCode(
- TosaReturnCode.UNPREDICTABLE, "Negative output shape"
- )
-
return ser.addOutput(output_shape, out_dtype)