aboutsummaryrefslogtreecommitdiff
path: root/verif/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/tosa_test_gen.py')
-rw-r--r--verif/tosa_test_gen.py103
1 files changed, 93 insertions, 10 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index dc2d803..302e4f4 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -489,6 +489,30 @@ class TosaArgGen:
return arg_list
+ @staticmethod
+ def agMul(testGen, opName, shapeList, dtype):
+ arg_list = []
+
+ if dtype is DType.INT32:
+ for p in range(testGen.args.num_rand_permutations):
+
+ shift = testGen.randInt(0, 32)
+
+ arg_list.append(('perm{}_shift{}'.format(p, shift), [shift]))
+ else:
+ arg_list.append(('shift0', [0]))
+
+ return arg_list
+
+ @staticmethod
+ def agArithmeticRightShift(testGen, opName, shapeList, dtype):
+ arg_list = []
+
+ arg_list.append(('roundTrue', [True]))
+ arg_list.append(('roundFalse', [False]))
+
+ return arg_list
+
# Helper function for reshape. Gets some factors of a larger number.
@staticmethod
def getFactors(val, start=1):
@@ -647,7 +671,7 @@ class TosaArgGen:
arg_list.append(('mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}'.format(m, shift, output_dims[0], output_dims[1],
testGen.typeStr(outputDType), stride[0], stride[1],
offset[0], offset[1]),
- [m, stride, offset, shift, output_dims, outputDType]))
+ [m, stride, offset, shift, output_dims, dtype, outputDType]))
return arg_list
@@ -850,7 +874,16 @@ class TosaTestGen:
self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
return result_tens
- def build_mul(self, op, a, b):
+ def build_arithmetic_right_shift(self, op, a, b, round):
+ result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
+
+ attr = ts.TosaSerializerAttribute()
+ attr.ArithmeticRightShiftAttribute(round)
+
+ self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
+ return result_tens
+
+ def build_mul(self, op, a, b, shift):
result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b)
# Special for multiply:
@@ -858,7 +891,10 @@ class TosaTestGen:
if a.dtype != DType.FLOAT:
result_tens.setDtype(DType.INT32)
- self.ser.addOperator(op, [a.name, b.name], [result_tens.name])
+ attr = ts.TosaSerializerAttribute()
+ attr.MulAttribute(shift)
+
+ self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr)
return result_tens
def build_table(self, op, a):
@@ -1121,8 +1157,8 @@ class TosaTestGen:
return result_tens
- def build_resize(self, op, input, mode, stride, offset, shift, output_dims, output_dtype):
- result_tens = OutputShaper.resizeOp(self.ser, input, mode, stride, offset, shift, output_dims, output_dtype)
+ def build_resize(self, op, input, mode, stride, offset, shift, output_dims, input_dtype, output_dtype):
+ result_tens = OutputShaper.resizeOp(self.ser, input, mode, stride, offset, shift, output_dims, input_dtype, output_dtype)
attr = ts.TosaSerializerAttribute()
attr.ResizeAttribute(output_dims, stride, offset, shift, mode)
@@ -1191,6 +1227,8 @@ class TosaTestGen:
for i in range(nc):
multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(scale_arr[i], scale32)
+ if shift_arr[i] < 2 or shift_arr[i] > 62:
+ self.ser.setExpectedFailure(True, 'OpRescale: invalid shift value')
#print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
@@ -1413,8 +1451,30 @@ class TosaTestGen:
# Build the random tensor operands and the test
tens = []
- tens.extend(self.buildPlaceholderTensors(shapeList[0:pCount], dtype))
- tens.extend(self.buildConstTensors(shapeList[pCount:], dtype))
+
+ # If test is ArithmeticRightShift, force value of operand[1] to be within [0, num_bits]
+ if op['op'] == Op.ARITHMETIC_RIGHT_SHIFT:
+ assert pCount == 2 and cCount == 0, 'Op.ArithmeticRightShift must have 2 placeholders, 0 consts'
+
+ placeholders = []
+ for idx, shape in enumerate(shapeList[:]):
+ if idx == 1:
+ if dtype == DType.INT8:
+ arr = np.int32(self.rng.integers(low=0, high=8, size=shape))
+ elif dtype == DType.INT16:
+ arr = np.int32(self.rng.integers(low=0, high=16, size=shape))
+ elif dtype == DType.INT32:
+ arr = np.int32(self.rng.integers(low=0, high=32, size=shape))
+ else:
+ raise Exception('OpArithmeticRightShift: invalid input dtype')
+ else:
+ arr = self.getRandTensor(shapeList[0], dtype)
+ placeholders.append(self.ser.addPlaceholder(shape, dtype, Usage.ACTIVATION, [], arr))
+
+ tens.extend(placeholders)
+ else:
+ tens.extend(self.buildPlaceholderTensors(shapeList[0:pCount], dtype))
+ tens.extend(self.buildConstTensors(shapeList[pCount:], dtype))
if qgen is not None:
qinfo = qgen(self, op, dtype)
@@ -1536,7 +1596,7 @@ class TosaTestGen:
'arithmetic_right_shift':
{ 'op': Op.ARITHMETIC_RIGHT_SHIFT,
'operands': (2, 0),
- 'build_fcn': (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None),
+ 'build_fcn': (build_arithmetic_right_shift, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agArithmeticRightShift),
'types': TYPE_PURE_INT },
'bitwise_and':
@@ -1602,7 +1662,7 @@ class TosaTestGen:
'mul':
{ 'op': Op.MUL,
'operands': (2, 0),
- 'build_fcn': (build_mul, TosaTensorGen.tgBroadcastFuzz, None),
+ 'build_fcn': (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul),
'types': TYPE_PURE_INT_FP },
'pow':
@@ -2271,13 +2331,36 @@ class OutputShaper:
return ser.addOutput(input.shape, DType.INT32, input.usage, input.dformat)
@staticmethod
- def resizeOp(ser, input, mode, stride, offset, shift, output_dims, output_dtype):
+ def resizeOp(ser, input, mode, stride, offset, shift, output_dims, input_dtype, output_dtype):
output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]]
if stride[0] <= 0 or stride[1] <= 0:
ser.setExpectedFailure(True, 'Negative or zero stride')
+ if mode == ResizeMode.BILINEAR:
+ if input_dtype == DType.INT8:
+ if output_dtype != DType.INT32:
+ ser.setExpectedFailure(True, 'Invalid output data type')
+ elif input_dtype == DType.INT16:
+ if output_dtype != DType.INT48:
+ ser.setexpectedfailure(true, 'Invalid output data type')
+ else:
+ ser.setexpectedfailure(true, 'Invalid input data type')
+
+ elif mode == ResizeMode.NEAREST:
+ if input_dtype == DType.INT8:
+ if output_dtype != DType.INT8:
+ ser.setExpectedFailure(True, 'Invalid output data type')
+ elif input_dtype == DType.INT16:
+ if output_dtype != DType.INT16:
+ ser.setexpectedfailure(true, 'Invalid output data type')
+ else:
+ ser.setexpectedfailure(true, 'Invalid input data type')
+
+ else:
+ ser.setexpectedfailure(true, 'Invalid resize mode')
+
return ser.addOutput(output_dims, output_dtype, input.usage, input.dformat)
@staticmethod