diff options
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r-- | verif/generator/tosa_test_gen.py | 111 |
1 files changed, 109 insertions, 2 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 5f9e2c1..2b762aa 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -213,6 +213,12 @@ class TosaTestGen: else: raise Exception(f"Unknown dtype, cannot determine width: {dtype}") + def constrictBatchSize(self, shape): + # Limit the batch size unless an explicit target shape set + if self.args.max_batch_size and not self.args.target_shapes: + shape[0] = min(shape[0], self.args.max_batch_size) + return shape + # Argument generators # Returns a list of tuples (stringDescriptor, [build_fcn_arg_list]) # Where the string descriptor is used to generate the test name and @@ -2081,6 +2087,48 @@ class TosaTestGen: return acc_out + def build_fft2d( + self, op, val1, val2, inverse, validator_fcns=None, error_name=None + ): + results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name) + + input_names = [val1.name, val2.name] + pCount, cCount = op["operands"] + num_operands = pCount + cCount + + output_names = [res.name for res in results] + output_shapes = [res.shape for res in results] + output_dtypes = [res.dtype for res in results] + + input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_names, output_names + ) + + if not TosaErrorValidator.evValidateErrorIfs( + self.ser, + validator_fcns, + error_name, + op=op, + inverse=inverse, + input1=val1, + input2=val2, + input_shape=val1.shape, + input_dtype=val1.dtype, + output_shape=output_shapes, + output_dtype=output_dtypes, + result_tensors=results, + input_list=input_names, + output_list=output_names, + num_operands=num_operands, + ): + return None + + attr = ts.TosaSerializerAttribute() + attr.FFTAttribute(inverse) + + self.ser.addOperator(op["op"], input_names, output_names, attr) + return results + def build_rfft2d(self, op, val, validator_fcns=None, error_name=None): results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name) @@ -2089,6 +2137,7 @@ class TosaTestGen: num_operands = pCount + cCount output_names = [res.name for res in results] + output_shapes = [res.shape for res in results] output_dtypes = [res.dtype for res in results] input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList( @@ -2102,6 +2151,7 @@ class TosaTestGen: op=op, input_shape=val.shape, input_dtype=val.dtype, + output_shape=output_shapes, output_dtype=output_dtypes, result_tensors=results, input_list=input_names, @@ -3927,6 +3977,29 @@ class TosaTestGen: TosaErrorValidator.evCondGraphOutputShapeNotSizeOne, ), }, + "fft2d": { + "op": Op.FFT2D, + "operands": (2, 0), + "rank": (3, 3), + "build_fcn": ( + build_fft2d, + TosaTensorGen.tgFFT2d, + TosaTensorValuesGen.tvgDefault, + TosaArgGen.agFFT2d, + ), + "types": [DType.FP32], + "error_if_validators": ( + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evWrongRank, + TosaErrorValidator.evBatchMismatch, + TosaErrorValidator.evKernelNotPowerOfTwo, + TosaErrorValidator.evFFTInputShapeMismatch, + TosaErrorValidator.evFFTOutputShapeMismatch, + ), + }, "rfft2d": { "op": Op.RFFT2D, "operands": (1, 0), @@ -3946,6 +4019,7 @@ class TosaTestGen: TosaErrorValidator.evWrongRank, TosaErrorValidator.evBatchMismatch, TosaErrorValidator.evKernelNotPowerOfTwo, + TosaErrorValidator.evFFTOutputShapeMismatch, ), }, } @@ -4770,6 +4844,37 @@ class OutputShaper: return ser.addOutput(output_shape, out_dtype) @staticmethod + def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None): + outputs = [] + + assert ifm1.dtype == ifm2.dtype + input_dtype = ifm1.dtype + + if error_name != ErrorIf.FFTInputShapeMismatch: + assert ifm1.shape == ifm2.shape + + input_shape = ifm1.shape + if error_name != ErrorIf.WrongRank: + assert len(input_shape) == 3 + + output_shape = input_shape.copy() + output_dtype = input_dtype + + if error_name == ErrorIf.WrongOutputType: + excludes = [DType.FP32] + wrong_dtypes = list(usableDTypes(excludes=excludes)) + output_dtype = rng.choice(wrong_dtypes) + elif error_name == ErrorIf.BatchMismatch: + output_shape[0] += rng.integers(1, 10) + elif error_name == ErrorIf.FFTOutputShapeMismatch: + modify_dim = rng.choice([1, 2]) + output_shape[modify_dim] += rng.integers(1, 10) + + outputs.append(serializer.addOutput(output_shape, output_dtype)) + outputs.append(serializer.addOutput(output_shape, output_dtype)) + return outputs + + @staticmethod def rfft2dOp(serializer, rng, value, error_name=None): outputs = [] @@ -4785,8 +4890,10 @@ class OutputShaper: wrong_dtypes = list(usableDTypes(excludes=excludes)) output_dtype = rng.choice(wrong_dtypes) elif error_name == ErrorIf.BatchMismatch: - incorrect_batch = input_shape[0] + rng.integers(1, 10) - output_shape = [incorrect_batch, *input_shape[1:]] + output_shape[0] += rng.integers(1, 10) + elif error_name == ErrorIf.FFTOutputShapeMismatch: + modify_dim = rng.choice([1, 2]) + output_shape[modify_dim] += rng.integers(1, 10) outputs.append(serializer.addOutput(output_shape, output_dtype)) outputs.append(serializer.addOutput(output_shape, output_dtype)) |