diff options
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r-- | verif/generator/tosa_test_gen.py | 91 |
1 files changed, 89 insertions, 2 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index b5e71ac..8c18e67 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -88,7 +88,9 @@ class TosaTestGen: return np.int32(self.rng.integers(low=-32768, high=32768, size=shape)) elif dtype == DType.UINT16: return np.int32(self.rng.integers(low=0, high=65536, size=shape)) - elif dtype == DType.INT32: + elif ( + dtype == DType.INT32 or dtype == DType.SHAPE + ): # restricting too large value for SHAPE return np.int32( self.rng.integers(low=-(1 << 31), high=(1 << 31), size=shape) ) @@ -181,7 +183,9 @@ class TosaTestGen: low, high = (-128, 128) elif dtype == DType.INT16: low, high = (-32768, 32768) - elif dtype == DType.INT32: + elif ( + dtype == DType.INT32 or dtype == DType.SHAPE + ): # restricting too large value for SHAPE low, high = (-(1 << 31), (1 << 31)) elif dtype == DType.INT48: low, high = (-(1 << 47), (1 << 47)) @@ -1310,6 +1314,49 @@ class TosaTestGen: self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens + def build_dim( + self, + op, + a, + axis, + validator_fcns=None, + error_name=None, + qinfo=None, + ): + result_tens = OutputShaper.dimOp(self.ser, self.rng, a, axis, error_name) + + # Invalidate Input/Output list for error if checks. + input_list = [a.name] + output_list = [result_tens.name] + pCount, cCount = op["operands"] + num_operands = pCount + cCount + input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_list, output_list + ) + + if not TosaErrorValidator.evValidateErrorIfs( + self.ser, + validator_fcns, + error_name, + op=op, + axis=axis, + input_shape=a.shape, + input_dtype=a.dtype, + output_shape=result_tens.shape, + output_dtype=result_tens.dtype, + result_tensors=[result_tens], + input_list=input_list, + output_list=output_list, + num_operands=num_operands, + ): + return None + + attr = ts.TosaSerializerAttribute() + attr.AxisAttribute(axis) + + self.ser.addOperator(op["op"], input_list, output_list, attr) + return result_tens + def build_reshape(self, op, a, newShape, validator_fcns=None, error_name=None): result_tens = OutputShaper.reshapeOp( self.ser, self.rng, a, newShape, error_name @@ -3749,6 +3796,25 @@ class TosaTestGen: TosaErrorValidator.evWrongRank, ), }, + "dim": { + "op": Op.DIM, + "operands": (1, 0), + "build_fcn": ( + build_dim, + TosaTensorGen.tgBasic, + TosaTensorValuesGen.tvgDefault, + TosaArgGen.agAxis, + ), + "types": TYPE_FIB, + "error_if_validators": ( + TosaErrorValidator.evAxisLargerRank, + TosaErrorValidator.evAxisSmallerZero, + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evWrongRank, + ), + }, "reshape": { "op": Op.RESHAPE, "operands": (1, 0), @@ -4665,6 +4731,27 @@ class OutputShaper: return ser.addOutput(output_shape, outputDType) @staticmethod + def dimOp(ser, rng, a, axis, error_name=None): + output_shape = [1] + + if error_name == ErrorIf.WrongOutputType: + all_dtypes = [ + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.FP32, + DType.FP16, + DType.BF16, + ] + wrong_dtypes = list(set(all_dtypes)) + outputDType = rng.choice(wrong_dtypes) + else: + outputDType = DType.SHAPE + + return ser.addOutput(output_shape, outputDType) + + @staticmethod def reshapeOp(ser, rng, a, shape, error_name=None): output_shape = shape.copy() |