aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py111
1 files changed, 109 insertions, 2 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 5f9e2c1..2b762aa 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -213,6 +213,12 @@ class TosaTestGen:
else:
raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
+ def constrictBatchSize(self, shape):
+ # Limit the batch size unless an explicit target shape set
+ if self.args.max_batch_size and not self.args.target_shapes:
+ shape[0] = min(shape[0], self.args.max_batch_size)
+ return shape
+
# Argument generators
# Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
# Where the string descriptor is used to generate the test name and
@@ -2081,6 +2087,48 @@ class TosaTestGen:
return acc_out
+ def build_fft2d(
+ self, op, val1, val2, inverse, validator_fcns=None, error_name=None
+ ):
+ results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)
+
+ input_names = [val1.name, val2.name]
+ pCount, cCount = op["operands"]
+ num_operands = pCount + cCount
+
+ output_names = [res.name for res in results]
+ output_shapes = [res.shape for res in results]
+ output_dtypes = [res.dtype for res in results]
+
+ input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_names, output_names
+ )
+
+ if not TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ inverse=inverse,
+ input1=val1,
+ input2=val2,
+ input_shape=val1.shape,
+ input_dtype=val1.dtype,
+ output_shape=output_shapes,
+ output_dtype=output_dtypes,
+ result_tensors=results,
+ input_list=input_names,
+ output_list=output_names,
+ num_operands=num_operands,
+ ):
+ return None
+
+ attr = ts.TosaSerializerAttribute()
+ attr.FFTAttribute(inverse)
+
+ self.ser.addOperator(op["op"], input_names, output_names, attr)
+ return results
+
def build_rfft2d(self, op, val, validator_fcns=None, error_name=None):
results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)
@@ -2089,6 +2137,7 @@ class TosaTestGen:
num_operands = pCount + cCount
output_names = [res.name for res in results]
+ output_shapes = [res.shape for res in results]
output_dtypes = [res.dtype for res in results]
input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
@@ -2102,6 +2151,7 @@ class TosaTestGen:
op=op,
input_shape=val.shape,
input_dtype=val.dtype,
+ output_shape=output_shapes,
output_dtype=output_dtypes,
result_tensors=results,
input_list=input_names,
@@ -3927,6 +3977,29 @@ class TosaTestGen:
TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
),
},
+ "fft2d": {
+ "op": Op.FFT2D,
+ "operands": (2, 0),
+ "rank": (3, 3),
+ "build_fcn": (
+ build_fft2d,
+ TosaTensorGen.tgFFT2d,
+ TosaTensorValuesGen.tvgDefault,
+ TosaArgGen.agFFT2d,
+ ),
+ "types": [DType.FP32],
+ "error_if_validators": (
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evWrongRank,
+ TosaErrorValidator.evBatchMismatch,
+ TosaErrorValidator.evKernelNotPowerOfTwo,
+ TosaErrorValidator.evFFTInputShapeMismatch,
+ TosaErrorValidator.evFFTOutputShapeMismatch,
+ ),
+ },
"rfft2d": {
"op": Op.RFFT2D,
"operands": (1, 0),
@@ -3946,6 +4019,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evBatchMismatch,
TosaErrorValidator.evKernelNotPowerOfTwo,
+ TosaErrorValidator.evFFTOutputShapeMismatch,
),
},
}
@@ -4770,6 +4844,37 @@ class OutputShaper:
return ser.addOutput(output_shape, out_dtype)
@staticmethod
+ def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
+ outputs = []
+
+ assert ifm1.dtype == ifm2.dtype
+ input_dtype = ifm1.dtype
+
+ if error_name != ErrorIf.FFTInputShapeMismatch:
+ assert ifm1.shape == ifm2.shape
+
+ input_shape = ifm1.shape
+ if error_name != ErrorIf.WrongRank:
+ assert len(input_shape) == 3
+
+ output_shape = input_shape.copy()
+ output_dtype = input_dtype
+
+ if error_name == ErrorIf.WrongOutputType:
+ excludes = [DType.FP32]
+ wrong_dtypes = list(usableDTypes(excludes=excludes))
+ output_dtype = rng.choice(wrong_dtypes)
+ elif error_name == ErrorIf.BatchMismatch:
+ output_shape[0] += rng.integers(1, 10)
+ elif error_name == ErrorIf.FFTOutputShapeMismatch:
+ modify_dim = rng.choice([1, 2])
+ output_shape[modify_dim] += rng.integers(1, 10)
+
+ outputs.append(serializer.addOutput(output_shape, output_dtype))
+ outputs.append(serializer.addOutput(output_shape, output_dtype))
+ return outputs
+
+ @staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
outputs = []
@@ -4785,8 +4890,10 @@ class OutputShaper:
wrong_dtypes = list(usableDTypes(excludes=excludes))
output_dtype = rng.choice(wrong_dtypes)
elif error_name == ErrorIf.BatchMismatch:
- incorrect_batch = input_shape[0] + rng.integers(1, 10)
- output_shape = [incorrect_batch, *input_shape[1:]]
+ output_shape[0] += rng.integers(1, 10)
+ elif error_name == ErrorIf.FFTOutputShapeMismatch:
+ modify_dim = rng.choice([1, 2])
+ output_shape[modify_dim] += rng.integers(1, 10)
outputs.append(serializer.addOutput(output_shape, output_dtype))
outputs.append(serializer.addOutput(output_shape, output_dtype))