diff options
Diffstat (limited to 'verif/tosa_test_gen.py')
-rw-r--r-- | verif/tosa_test_gen.py | 103 |
1 files changed, 93 insertions, 10 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py index dc2d803..302e4f4 100644 --- a/verif/tosa_test_gen.py +++ b/verif/tosa_test_gen.py @@ -489,6 +489,30 @@ class TosaArgGen: return arg_list + @staticmethod + def agMul(testGen, opName, shapeList, dtype): + arg_list = [] + + if dtype is DType.INT32: + for p in range(testGen.args.num_rand_permutations): + + shift = testGen.randInt(0, 32) + + arg_list.append(('perm{}_shift{}'.format(p, shift), [shift])) + else: + arg_list.append(('shift0', [0])) + + return arg_list + + @staticmethod + def agArithmeticRightShift(testGen, opName, shapeList, dtype): + arg_list = [] + + arg_list.append(('roundTrue', [True])) + arg_list.append(('roundFalse', [False])) + + return arg_list + # Helper function for reshape. Gets some factors of a larger number. @staticmethod def getFactors(val, start=1): @@ -647,7 +671,7 @@ class TosaArgGen: arg_list.append(('mode{}_shift{}_odim{}x{}_out{}_st{}x{}_off{}x{}'.format(m, shift, output_dims[0], output_dims[1], testGen.typeStr(outputDType), stride[0], stride[1], offset[0], offset[1]), - [m, stride, offset, shift, output_dims, outputDType])) + [m, stride, offset, shift, output_dims, dtype, outputDType])) return arg_list @@ -850,7 +874,16 @@ class TosaTestGen: self.ser.addOperator(op, [a.name, b.name], [result_tens.name]) return result_tens - def build_mul(self, op, a, b): + def build_arithmetic_right_shift(self, op, a, b, round): + result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b) + + attr = ts.TosaSerializerAttribute() + attr.ArithmeticRightShiftAttribute(round) + + self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr) + return result_tens + + def build_mul(self, op, a, b, shift): result_tens = OutputShaper.binaryBroadcastOp(self.ser, a, b) # Special for multiply: @@ -858,7 +891,10 @@ class TosaTestGen: if a.dtype != DType.FLOAT: result_tens.setDtype(DType.INT32) - self.ser.addOperator(op, [a.name, b.name], [result_tens.name]) + attr = ts.TosaSerializerAttribute() + attr.MulAttribute(shift) + + self.ser.addOperator(op, [a.name, b.name], [result_tens.name], attr) return result_tens def build_table(self, op, a): @@ -1121,8 +1157,8 @@ class TosaTestGen: return result_tens - def build_resize(self, op, input, mode, stride, offset, shift, output_dims, output_dtype): - result_tens = OutputShaper.resizeOp(self.ser, input, mode, stride, offset, shift, output_dims, output_dtype) + def build_resize(self, op, input, mode, stride, offset, shift, output_dims, input_dtype, output_dtype): + result_tens = OutputShaper.resizeOp(self.ser, input, mode, stride, offset, shift, output_dims, input_dtype, output_dtype) attr = ts.TosaSerializerAttribute() attr.ResizeAttribute(output_dims, stride, offset, shift, mode) @@ -1191,6 +1227,8 @@ class TosaTestGen: for i in range(nc): multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(scale_arr[i], scale32) + if shift_arr[i] < 2 or shift_arr[i] > 62: + self.ser.setExpectedFailure(True, 'OpRescale: invalid shift value') #print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp)) @@ -1413,8 +1451,30 @@ class TosaTestGen: # Build the random tensor operands and the test tens = [] - tens.extend(self.buildPlaceholderTensors(shapeList[0:pCount], dtype)) - tens.extend(self.buildConstTensors(shapeList[pCount:], dtype)) + + # If test is ArithmeticRightShift, force value of operand[1] to be within [0, num_bits] + if op['op'] == Op.ARITHMETIC_RIGHT_SHIFT: + assert pCount == 2 and cCount == 0, 'Op.ArithmeticRightShift must have 2 placeholders, 0 consts' + + placeholders = [] + for idx, shape in enumerate(shapeList[:]): + if idx == 1: + if dtype == DType.INT8: + arr = np.int32(self.rng.integers(low=0, high=8, size=shape)) + elif dtype == DType.INT16: + arr = np.int32(self.rng.integers(low=0, high=16, size=shape)) + elif dtype == DType.INT32: + arr = np.int32(self.rng.integers(low=0, high=32, size=shape)) + else: + raise Exception('OpArithmeticRightShift: invalid input dtype') + else: + arr = self.getRandTensor(shapeList[0], dtype) + placeholders.append(self.ser.addPlaceholder(shape, dtype, Usage.ACTIVATION, [], arr)) + + tens.extend(placeholders) + else: + tens.extend(self.buildPlaceholderTensors(shapeList[0:pCount], dtype)) + tens.extend(self.buildConstTensors(shapeList[pCount:], dtype)) if qgen is not None: qinfo = qgen(self, op, dtype) @@ -1536,7 +1596,7 @@ class TosaTestGen: 'arithmetic_right_shift': { 'op': Op.ARITHMETIC_RIGHT_SHIFT, 'operands': (2, 0), - 'build_fcn': (build_binary_broadcast, TosaTensorGen.tgBroadcastFuzz, None), + 'build_fcn': (build_arithmetic_right_shift, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agArithmeticRightShift), 'types': TYPE_PURE_INT }, 'bitwise_and': @@ -1602,7 +1662,7 @@ class TosaTestGen: 'mul': { 'op': Op.MUL, 'operands': (2, 0), - 'build_fcn': (build_mul, TosaTensorGen.tgBroadcastFuzz, None), + 'build_fcn': (build_mul, TosaTensorGen.tgBroadcastFuzz, TosaArgGen.agMul), 'types': TYPE_PURE_INT_FP }, 'pow': @@ -2271,13 +2331,36 @@ class OutputShaper: return ser.addOutput(input.shape, DType.INT32, input.usage, input.dformat) @staticmethod - def resizeOp(ser, input, mode, stride, offset, shift, output_dims, output_dtype): + def resizeOp(ser, input, mode, stride, offset, shift, output_dims, input_dtype, output_dtype): output_dims = [input.shape[0], output_dims[0], output_dims[1], input.shape[3]] if stride[0] <= 0 or stride[1] <= 0: ser.setExpectedFailure(True, 'Negative or zero stride') + if mode == ResizeMode.BILINEAR: + if input_dtype == DType.INT8: + if output_dtype != DType.INT32: + ser.setExpectedFailure(True, 'Invalid output data type') + elif input_dtype == DType.INT16: + if output_dtype != DType.INT48: + ser.setexpectedfailure(true, 'Invalid output data type') + else: + ser.setexpectedfailure(true, 'Invalid input data type') + + elif mode == ResizeMode.NEAREST: + if input_dtype == DType.INT8: + if output_dtype != DType.INT8: + ser.setExpectedFailure(True, 'Invalid output data type') + elif input_dtype == DType.INT16: + if output_dtype != DType.INT16: + ser.setexpectedfailure(true, 'Invalid output data type') + else: + ser.setexpectedfailure(true, 'Invalid input data type') + + else: + ser.setexpectedfailure(true, 'Invalid resize mode') + return ser.addOutput(output_dims, output_dtype, input.usage, input.dformat) @staticmethod |