From 261b7b62b959a6c7312d810d9152069fdff69f3e Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Tue, 10 Jan 2023 14:50:31 +0000 Subject: Add RFFT2d to the reference model Includes: * RFFT2d reference implementation * TFLite framework tests * Basic TOSA tests * Serialization submodule upgrade with support for FFT/RFFT Signed-off-by: Luke Hutton Change-Id: I2a687e9cf87fb62a26160ea52439ba9830bea36e --- verif/generator/tosa_test_gen.py | 130 +++++++++++++++++++++++++++++++-------- 1 file changed, 103 insertions(+), 27 deletions(-) (limited to 'verif/generator/tosa_test_gen.py') diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index c29763b..fddf942 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -255,7 +255,7 @@ class TosaTestGen: input_dtype=a.dtype, output_dtype=result_tens.dtype, qinfo=qinfo, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -293,7 +293,7 @@ class TosaTestGen: input2=b, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -333,7 +333,7 @@ class TosaTestGen: input2=b, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -378,7 +378,7 @@ class TosaTestGen: input2=b, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -414,7 +414,7 @@ class TosaTestGen: input_shape=a.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -448,7 +448,7 @@ class TosaTestGen: input_shape=a.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -487,7 +487,7 @@ class TosaTestGen: input_dtype=a.dtype, output_shape=result_tens.shape, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -523,7 +523,7 @@ class TosaTestGen: input_dtype=a.dtype, output_shape=result_tens.shape, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -582,7 +582,7 @@ class TosaTestGen: stride=stride, pad=pad, qinfo=qinfo, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -938,7 +938,7 @@ class TosaTestGen: output_shape=result_tens.shape, output_dtype=result_tens.dtype, qinfo=qinfo, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -980,7 +980,7 @@ class TosaTestGen: output_shape=result_tens.shape, output_dtype=result_tens.dtype, qinfo=qinfo, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1016,7 +1016,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1064,7 +1064,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1122,7 +1122,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1153,7 +1153,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1199,7 +1199,7 @@ class TosaTestGen: input_dtype=a[0].dtype, output_dtype=result_tens.dtype, inputs=a, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1250,7 +1250,7 @@ class TosaTestGen: output_dtype=result_tens.dtype, pad=padding, qinfo=qinfo, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1283,7 +1283,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1318,7 +1318,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1356,7 +1356,7 @@ class TosaTestGen: perms=perms, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1391,7 +1391,7 @@ class TosaTestGen: output_dtype=result_tens.dtype, start=start, size=size, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1425,7 +1425,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=a.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1474,7 +1474,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=values.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1519,7 +1519,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=values_in.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1580,7 +1580,7 @@ class TosaTestGen: border=border, input_list=input_list, output_list=output_list, - result_tensor=result_tens, + result_tensors=[result_tens], num_operands=num_operands, ): return None @@ -1628,7 +1628,7 @@ class TosaTestGen: output_shape=result_tens.shape, input_dtype=val.dtype, output_dtype=result_tens.dtype, - result_tensor=result_tens, + result_tensors=[result_tens], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1774,7 +1774,7 @@ class TosaTestGen: double_round=double_round, input_list=input_list, output_list=output_list, - result_tensor=result_tens, + result_tensors=[result_tens], num_operands=num_operands, ): return None @@ -2083,6 +2083,38 @@ class TosaTestGen: return acc_out + def build_rfft2d(self, op, val, validator_fcns=None, error_name=None): + results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name) + + input_names = [val.name] + pCount, cCount = op["operands"] + num_operands = pCount + cCount + + output_names = [res.name for res in results] + output_dtypes = [res.dtype for res in results] + + input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList( + self, error_name, input_names, output_names + ) + + if not TosaErrorValidator.evValidateErrorIfs( + self.ser, + validator_fcns, + error_name, + op=op, + input_shape=val.shape, + input_dtype=val.dtype, + output_dtype=output_dtypes, + result_tensors=results, + input_list=input_names, + output_list=output_names, + num_operands=num_operands, + ): + return None + + self.ser.addOperator(op["op"], input_names, output_names) + return results + def create_filter_lists( self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None ): @@ -3897,6 +3929,27 @@ class TosaTestGen: TosaErrorValidator.evCondGraphOutputShapeNotSizeOne, ), }, + "rfft2d": { + "op": Op.RFFT2D, + "operands": (1, 0), + "rank": (3, 3), + "build_fcn": ( + build_rfft2d, + TosaTensorGen.tgRFFT2d, + TosaTensorValuesGen.tvgDefault, + TosaArgGen.agNone, + ), + "types": [DType.FP32], + "error_if_validators": ( + TosaErrorValidator.evWrongInputType, + TosaErrorValidator.evWrongOutputType, + TosaErrorValidator.evWrongInputList, + TosaErrorValidator.evWrongOutputList, + TosaErrorValidator.evWrongRank, + TosaErrorValidator.evBatchMismatch, + TosaErrorValidator.evKernelNotPowerOfTwo, + ), + }, } @@ -4717,3 +4770,26 @@ class OutputShaper: out_dtype = rng.choice(wrong_dtypes) return ser.addOutput(output_shape, out_dtype) + + @staticmethod + def rfft2dOp(serializer, rng, value, error_name=None): + outputs = [] + + input_shape = value.shape + if error_name != ErrorIf.WrongRank: + assert len(input_shape) == 3 + + output_shape = [*input_shape[:-1], input_shape[-1] // 2 + 1] + + output_dtype = value.dtype + if error_name == ErrorIf.WrongOutputType: + excludes = [DType.FP32] + wrong_dtypes = list(usableDTypes(excludes=excludes)) + output_dtype = rng.choice(wrong_dtypes) + elif error_name == ErrorIf.BatchMismatch: + incorrect_batch = input_shape[0] + rng.integers(1, 10) + output_shape = [incorrect_batch, *input_shape[1:]] + + outputs.append(serializer.addOutput(output_shape, output_dtype)) + outputs.append(serializer.addOutput(output_shape, output_dtype)) + return outputs -- cgit v1.2.1