aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py130
1 files changed, 103 insertions, 27 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index c29763b..fddf942 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -255,7 +255,7 @@ class TosaTestGen:
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
qinfo=qinfo,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -293,7 +293,7 @@ class TosaTestGen:
input2=b,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -333,7 +333,7 @@ class TosaTestGen:
input2=b,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -378,7 +378,7 @@ class TosaTestGen:
input2=b,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -414,7 +414,7 @@ class TosaTestGen:
input_shape=a.shape,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -448,7 +448,7 @@ class TosaTestGen:
input_shape=a.shape,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -487,7 +487,7 @@ class TosaTestGen:
input_dtype=a.dtype,
output_shape=result_tens.shape,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -523,7 +523,7 @@ class TosaTestGen:
input_dtype=a.dtype,
output_shape=result_tens.shape,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -582,7 +582,7 @@ class TosaTestGen:
stride=stride,
pad=pad,
qinfo=qinfo,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -938,7 +938,7 @@ class TosaTestGen:
output_shape=result_tens.shape,
output_dtype=result_tens.dtype,
qinfo=qinfo,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -980,7 +980,7 @@ class TosaTestGen:
output_shape=result_tens.shape,
output_dtype=result_tens.dtype,
qinfo=qinfo,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1016,7 +1016,7 @@ class TosaTestGen:
output_shape=result_tens.shape,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1064,7 +1064,7 @@ class TosaTestGen:
output_shape=result_tens.shape,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1122,7 +1122,7 @@ class TosaTestGen:
output_shape=result_tens.shape,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1153,7 +1153,7 @@ class TosaTestGen:
output_shape=result_tens.shape,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1199,7 +1199,7 @@ class TosaTestGen:
input_dtype=a[0].dtype,
output_dtype=result_tens.dtype,
inputs=a,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1250,7 +1250,7 @@ class TosaTestGen:
output_dtype=result_tens.dtype,
pad=padding,
qinfo=qinfo,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1283,7 +1283,7 @@ class TosaTestGen:
output_shape=result_tens.shape,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1318,7 +1318,7 @@ class TosaTestGen:
output_shape=result_tens.shape,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1356,7 +1356,7 @@ class TosaTestGen:
perms=perms,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1391,7 +1391,7 @@ class TosaTestGen:
output_dtype=result_tens.dtype,
start=start,
size=size,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1425,7 +1425,7 @@ class TosaTestGen:
output_shape=result_tens.shape,
input_dtype=a.dtype,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1474,7 +1474,7 @@ class TosaTestGen:
output_shape=result_tens.shape,
input_dtype=values.dtype,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1519,7 +1519,7 @@ class TosaTestGen:
output_shape=result_tens.shape,
input_dtype=values_in.dtype,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1580,7 +1580,7 @@ class TosaTestGen:
border=border,
input_list=input_list,
output_list=output_list,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
num_operands=num_operands,
):
return None
@@ -1628,7 +1628,7 @@ class TosaTestGen:
output_shape=result_tens.shape,
input_dtype=val.dtype,
output_dtype=result_tens.dtype,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1774,7 +1774,7 @@ class TosaTestGen:
double_round=double_round,
input_list=input_list,
output_list=output_list,
- result_tensor=result_tens,
+ result_tensors=[result_tens],
num_operands=num_operands,
):
return None
@@ -2083,6 +2083,38 @@ class TosaTestGen:
return acc_out
+ def build_rfft2d(self, op, val, validator_fcns=None, error_name=None):
+ results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)
+
+ input_names = [val.name]
+ pCount, cCount = op["operands"]
+ num_operands = pCount + cCount
+
+ output_names = [res.name for res in results]
+ output_dtypes = [res.dtype for res in results]
+
+ input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_names, output_names
+ )
+
+ if not TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ input_shape=val.shape,
+ input_dtype=val.dtype,
+ output_dtype=output_dtypes,
+ result_tensors=results,
+ input_list=input_names,
+ output_list=output_names,
+ num_operands=num_operands,
+ ):
+ return None
+
+ self.ser.addOperator(op["op"], input_names, output_names)
+ return results
+
def create_filter_lists(
self, op, shapeFilter, rankFilter, dtypeFilter, testType, validator=None
):
@@ -3897,6 +3929,27 @@ class TosaTestGen:
TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
),
},
+ "rfft2d": {
+ "op": Op.RFFT2D,
+ "operands": (1, 0),
+ "rank": (3, 3),
+ "build_fcn": (
+ build_rfft2d,
+ TosaTensorGen.tgRFFT2d,
+ TosaTensorValuesGen.tvgDefault,
+ TosaArgGen.agNone,
+ ),
+ "types": [DType.FP32],
+ "error_if_validators": (
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evWrongRank,
+ TosaErrorValidator.evBatchMismatch,
+ TosaErrorValidator.evKernelNotPowerOfTwo,
+ ),
+ },
}
@@ -4717,3 +4770,26 @@ class OutputShaper:
out_dtype = rng.choice(wrong_dtypes)
return ser.addOutput(output_shape, out_dtype)
+
+ @staticmethod
+ def rfft2dOp(serializer, rng, value, error_name=None):
+ outputs = []
+
+ input_shape = value.shape
+ if error_name != ErrorIf.WrongRank:
+ assert len(input_shape) == 3
+
+ output_shape = [*input_shape[:-1], input_shape[-1] // 2 + 1]
+
+ output_dtype = value.dtype
+ if error_name == ErrorIf.WrongOutputType:
+ excludes = [DType.FP32]
+ wrong_dtypes = list(usableDTypes(excludes=excludes))
+ output_dtype = rng.choice(wrong_dtypes)
+ elif error_name == ErrorIf.BatchMismatch:
+ incorrect_batch = input_shape[0] + rng.integers(1, 10)
+ output_shape = [incorrect_batch, *input_shape[1:]]
+
+ outputs.append(serializer.addOutput(output_shape, output_dtype))
+ outputs.append(serializer.addOutput(output_shape, output_dtype))
+ return outputs