aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLuke Hutton <luke.hutton@arm.com>2023-02-06 14:54:18 +0000
committerEric Kunze <eric.kunze@arm.com>2023-02-10 20:01:04 +0000
commit5728713fca4f6e2dff60dad3689e471545e563d2 (patch)
tree848421100f82a33ff57ee3205c369ad75737f7d3
parentc1e25f5755997e65ac1a360ec1e875db06040d8d (diff)
downloadreference_model-5728713fca4f6e2dff60dad3689e471545e563d2.tar.gz
Add FFT2d to the reference model
Includes: * FFT2d reference implementation * Basic TOSA tests Change-Id: Ie79fcb713542345d550ec013646810c1e890e388 Signed-off-by: Luke Hutton <luke.hutton@arm.com>
-rw-r--r--reference_model/src/ops/op_factory.cc3
-rw-r--r--reference_model/src/ops/tensor_ops.cc243
-rw-r--r--reference_model/src/ops/tensor_ops.h23
-rw-r--r--verif/generator/tosa_arg_gen.py48
-rw-r--r--verif/generator/tosa_error_if.py64
-rw-r--r--verif/generator/tosa_test_gen.py111
6 files changed, 450 insertions, 42 deletions
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index b1a405a..8d84135 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -89,6 +89,9 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
break;
+ case Op_FFT2D:
+ DEF_FACTORY_ONE_TYPE(OpFFT2d, FP32);
+ break;
case Op_FULLY_CONNECTED:
DEF_FACTORY_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16);
DEF_FACTORY_THREE_TYPE(OpFullyConnected, FP16, FP16, FP32);
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index 4663c47..af808e8 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -238,6 +238,86 @@ int check_conv_attribute(tosa::TosaConvAttribute* attribute,
return 0;
}
+int check_fft_shape(const std::vector<int32_t>& in_real,
+ const std::vector<int32_t>& in_imag,
+ const std::vector<int32_t>& out_real,
+ const std::vector<int32_t>& out_imag,
+ std::string& msg) {
+ const bool is_rfft = in_imag.empty();
+ auto is_power_of_two = [](int32_t n) -> bool
+ {
+ return (n & (n-1)) == 0 && n > 0;
+ };
+
+ if (!is_power_of_two(in_real[1]) || !is_power_of_two(in_real[2]))
+ {
+ msg = "Input height and width must be a power of two";
+ return 1;
+ }
+
+ // RFFT does not have a second input
+ if (!is_rfft)
+ {
+ bool input_check = true;
+ for (size_t i = 0; i < in_real.size(); i++)
+ {
+ if (in_real[i] != in_imag[i])
+ {
+ input_check = false;
+ break;
+ }
+ }
+ if (!input_check)
+ {
+ msg = "Mismatch between real input shape and imaginary input shape";
+ return 1;
+ }
+ }
+
+ bool output_check = true;
+ for (size_t i = 0; i < out_real.size(); i++)
+ {
+ if (out_real[i] != out_imag[i])
+ {
+ output_check = false;
+ break;
+ }
+ }
+ if (!output_check)
+ {
+ msg = "Mismatch between real output shape and imaginary output shape";
+ return 1;
+ }
+
+ if (in_real[0] != out_real[0])
+ {
+ msg = "Input and output batch size don't match";
+ return 1;
+ }
+ if (in_real[1] != out_real[1])
+ {
+ msg = "Input and output height don't match";
+ return 1;
+ }
+
+ if (is_rfft)
+ {
+ if (in_real[2] / 2 + 1 != out_real[2])
+ {
+ msg = "Output width is expected to match input width / 2 + 1";
+ return 1;
+ }
+ } else {
+ if (in_real[2] != out_real[2])
+ {
+ msg = "Input and output width don't match";
+ return 1;
+ }
+ }
+
+ return 0;
+}
+
template <int Rank, DType Dtype>
OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
@@ -1448,82 +1528,167 @@ int OpMaxPool2d<Dtype>::eval()
}
template <DType Dtype>
-OpRFFT2d<Dtype>::OpRFFT2d(SubgraphTraverser* sgt_,
- TosaAttributeBase* attribute_,
- uint64_t id_)
- : GraphNode(sgt_, Op_RFFT2D, id_)
+OpFFT2d<Dtype>::OpFFT2d(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ uint64_t id_)
+ : GraphNode(sgt_, Op_FFT2D, id_)
{
- setRequiredOperands(1, 2);
+ setRequiredOperands(2, 2);
setRequiredRank(3);
+
+ INIT_ATTRIBUTE(FFT);
}
template <DType Dtype>
-OpRFFT2d<Dtype>::~OpRFFT2d() {}
+OpFFT2d<Dtype>::~OpFFT2d() {
+ if (attribute)
+ delete attribute;
+}
template <DType Dtype>
-int OpRFFT2d<Dtype>::checkTensorAttributes()
+int OpFFT2d<Dtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
return 1;
- if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]) ||
- validateRequiredRank(outputs[1]))
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) ||
+ validateRequiredRank(outputs[0]) || validateRequiredRank(outputs[1]))
{
return 1;
}
- if (inputs[0]->matchType(*outputs[0]) || inputs[0]->matchType(*outputs[1]))
+ if (inputs[0]->matchType(*outputs[0]) || inputs[1]->matchType(*outputs[1]) ||
+ inputs[0]->matchType(*inputs[1]))
{
- printNodeValidationError("OpRFFT2d: input and output tensor type mismatch");
+ printNodeValidationError("OpFFT2d: input and output tensor type mismatch");
return 1;
}
- in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ in_real = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ in_imag = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
out_real = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
out_imag = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[1]);
- ASSERT_MEM(in && out_real && out_imag);
+ ASSERT_MEM(in_real && in_imag && out_real && out_imag);
- auto is_power_of_two = [](int32_t n) -> bool
- {
- return (n & (n-1)) == 0 && n > 0;
- };
-
- // Input shape: [N, H, W]
- if (!is_power_of_two(in->getShape()[1]) || !is_power_of_two(in->getShape()[2]))
+ std::string msg;
+ if (check_fft_shape(in_real->getShape(), in_imag->getShape(),
+ out_real->getShape(), out_imag->getShape(), msg))
{
- printNodeValidationError("OpRFFT2d: input height and width must be a power of two");
+ msg = "OpFFT2d: " + msg;
+ printNodeValidationError(msg.c_str());
return 1;
}
- // Output shape: [N, H, W / 2 + 1]
- bool output_check = true;
- for (int32_t i = 0; i < out_real->getRank(); i++)
+ return 0;
+}
+
+template <DType Dtype>
+int OpFFT2d<Dtype>::eval()
+{
+ int in_real_batch = this->in_real->getShape()[0];
+ int in_real_height = this->in_real->getShape()[1];
+ int in_real_width = this->in_real->getShape()[2];
+
+ int in_imag_batch = this->in_imag->getShape()[0];
+ int in_imag_height = this->in_imag->getShape()[1];
+ int in_imag_width = this->in_imag->getShape()[2];
+
+ int out_real_batch = this->out_real->getShape()[0];
+ int out_real_height = this->out_real->getShape()[1];
+ int out_real_width = this->out_real->getShape()[2];
+
+ int out_imag_batch = this->out_imag->getShape()[0];
+ int out_imag_height = this->out_imag->getShape()[1];
+ int out_imag_width = this->out_imag->getShape()[2];
+
+ DEBUG_INFO(OP,
+ "perform OpFFT2d, input.shapes=[[%d,%d,%d],[%d,%d,%d]], output.shapes=[[%d,%d,%d],[%d,%d,%d]]",
+ in_real_batch, in_real_height, in_real_width,
+ in_imag_batch, in_imag_height, in_imag_width,
+ out_real_batch, out_real_height, out_real_width,
+ out_imag_batch, out_imag_height, out_imag_width);
+
+ OutEigenType sum_real, sum_imag, a, sign_val = 1.0;
+
+ if (attribute->inverse()) {
+ sign_val = -1.0;
+ }
+
+ for (int n = 0; n < in_real_batch; n++)
{
- if (out_real->getShape()[i] != out_imag->getShape()[i])
+ for (int oy = 0; oy < out_real_height; oy++)
{
- output_check = false;
- break;
+ for (int ox = 0; ox < out_real_width; ox++)
+ {
+ sum_real = 0.0;
+ sum_imag = 0.0;
+ for (int iy = 0; iy < in_real_height; iy++)
+ {
+ for (int ix = 0; ix < in_real_width; ix++)
+ {
+ OutEigenType val_real = this->in_real->getTensor()(n, iy, ix);
+ OutEigenType val_imag = this->in_imag->getTensor()(n, iy, ix);
+ // Use explicit cast to ensure intermmediate calculations are completed using OutEigenType
+ a = sign_val * 2 * M_PI * ((iy * (OutEigenType)oy) / in_real_height + (ix * (OutEigenType)ox) / in_real_width);
+ sum_real += val_real * cos(a) + val_imag * sin(a);
+ sum_imag += -val_real * sin(a) + val_imag * cos(a);
+ }
+ }
+ this->out_real->getTensor()(n, oy, ox) = sum_real;
+ this->out_imag->getTensor()(n, oy, ox) = sum_imag;
+ }
}
}
- if (!output_check)
- {
- printNodeValidationError(
- "OpRFFT2d: Mismatch between real output shape and imaginary output shape");
+
+ return GraphNode::eval();
+}
+
+template <DType Dtype>
+OpRFFT2d<Dtype>::OpRFFT2d(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ uint64_t id_)
+ : GraphNode(sgt_, Op_RFFT2D, id_)
+{
+ setRequiredOperands(1, 2);
+ setRequiredRank(3);
+}
+
+template <DType Dtype>
+OpRFFT2d<Dtype>::~OpRFFT2d() {}
+
+
+template <DType Dtype>
+int OpRFFT2d<Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
return 1;
- }
- if (in->getShape()[0] != out_real->getShape()[0]) {
- printNodeValidationError("OpRFFT2d: input and output batch size don't match");
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]) ||
+ validateRequiredRank(outputs[1]))
+ {
return 1;
}
- if (in->getShape()[1] != out_real->getShape()[1]) {
- printNodeValidationError("OpRFFT2d: input and output height don't match");
+
+ if (inputs[0]->matchType(*outputs[0]) || inputs[0]->matchType(*outputs[1]))
+ {
+ printNodeValidationError("OpRFFT2d: input and output tensor type mismatch");
return 1;
}
- if (in->getShape()[2] / 2 + 1 != out_real->getShape()[2]) {
- printNodeValidationError("OpRFFT2d: output width is expected to match input width / 2 + 1");
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out_real = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+ out_imag = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[1]);
+
+ ASSERT_MEM(in && out_real && out_imag);
+
+ std::string msg;
+ if (check_fft_shape(in->getShape(), {},
+ out_real->getShape(), out_imag->getShape(), msg))
+ {
+ msg = "OpRFFT2d: " + msg;
+ printNodeValidationError(msg.c_str());
return 1;
}
@@ -1843,6 +2008,8 @@ DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
+DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP32);
+
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, BF16, BF16, FP32);
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
index 0d2b3eb..9ef4a58 100644
--- a/reference_model/src/ops/tensor_ops.h
+++ b/reference_model/src/ops/tensor_ops.h
@@ -249,6 +249,29 @@ protected:
};
template <DType Dtype>
+class OpFFT2d : public GraphNode
+{
+public:
+ OpFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_);
+ virtual ~OpFFT2d();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval() final;
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, 3>;
+ using TOut = Eigen::Tensor<OutEigenType, 3>;
+
+protected:
+ TosaReference::TensorTemplate<TIn>* in_real;
+ TosaReference::TensorTemplate<TIn>* in_imag;
+ TosaReference::TensorTemplate<TOut>* out_real;
+ TosaReference::TensorTemplate<TOut>* out_imag;
+ tosa::TosaFFTAttribute* attribute;
+};
+
+template <DType Dtype>
class OpRFFT2d : public GraphNode
{
public:
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 05a7d2b..370570c 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -417,6 +417,45 @@ class TosaTensorGen:
return [ifm_shape, filter_shape, bias_shape]
@staticmethod
+ def tgFFT2d(testGen, op, rank, error_name=None):
+ pl, const = op["operands"]
+
+ if error_name != ErrorIf.WrongRank:
+ assert rank == 3
+ assert pl == 2 and const == 0
+
+ # IFM dimensions are NHW
+ ifm_shape = testGen.makeShape(rank)
+
+ # Select nearest lower power of two from input height and width
+ ifm_shape[1] = 2 ** int(math.log(ifm_shape[1], 2))
+ ifm_shape[2] = 2 ** int(math.log(ifm_shape[2], 2))
+
+ # Constrict the overall size of the shape when creating ERROR_IF tests
+ if error_name:
+ ifm_shape = TosaErrorIfArgGen.eiRestrictDimensions(ifm_shape)
+
+ # Generate an invalid kernel that is not a power of two
+ if error_name == ErrorIf.KernelNotPowerOfTwo:
+ inc_h = 2 if ifm_shape[1] == 1 else 1
+ inc_w = 2 if ifm_shape[2] == 1 else 1
+ inc_choices = [(inc_h, 0), (0, inc_w), (inc_h, inc_w)]
+ selected_inc = testGen.rng.choice(inc_choices)
+ ifm_shape[1] += selected_inc[0]
+ ifm_shape[2] += selected_inc[1]
+
+ ifm_shape = testGen.constrictBatchSize(ifm_shape)
+
+ ifm_shapes = [ifm_shape.copy(), ifm_shape.copy()]
+ if error_name == ErrorIf.FFTInputShapeMismatch:
+ modify_shape = testGen.rng.choice([0, 1])
+ # Only modify kernel (H, W)
+ modify_dim = testGen.rng.choice([1, 2])
+ ifm_shapes[modify_shape][modify_dim] *= 2
+
+ return [ifm_shapes[0], ifm_shapes[1]]
+
+ @staticmethod
def tgRFFT2d(testGen, op, rank, error_name=None):
pl, const = op["operands"]
@@ -1613,6 +1652,15 @@ class TosaArgGen:
return arg_list
+ @staticmethod
+ def agFFT2d(testGen, opName, shapeList, dtype, error_name=None):
+ arg_list = []
+
+ arg_list.append(("inverseTrue", [True]))
+ arg_list.append(("inverseFalse", [False]))
+
+ return arg_list
+
# Helper function for reshape. Gets some factors of a larger number.
@staticmethod
def getFactors(val, start=1):
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index 93f975d..ee227b3 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -79,6 +79,8 @@ class ErrorIf(object):
CondIfCondShapeNotSizeOne = "CondIfCondShapeNotSizeOne"
CondGraphOutputShapeNotSizeOne = "CondGraphOutputShapeNotSizeOne"
KernelNotPowerOfTwo = "KernelNotPowerOfTwo"
+ FFTInputShapeMismatch = "FFTInputShapeMismatch"
+ FFTOutputShapeMismatch = "FFTOutputShapeMismatch"
class TosaErrorIfArgGen:
@@ -562,7 +564,7 @@ class TosaErrorValidator:
):
error_result = True
- elif op["op"] == Op.RFFT2D:
+ elif op["op"] in [Op.FFT2D, Op.RFFT2D]:
if not all([ty == input_dtype for ty in output_dtype]):
error_result = True
@@ -686,7 +688,7 @@ class TosaErrorValidator:
op = kwargs["op"]
output_list = kwargs["output_list"]
expected_length = 1
- if op["op"] == Op.RFFT2D:
+ if op["op"] in [Op.FFT2D, Op.RFFT2D]:
expected_length = 2
if len(output_list) != expected_length:
@@ -2446,6 +2448,64 @@ class TosaErrorValidator:
}
return info_dict
+ @staticmethod
+ def evFFTInputShapeMismatch(check=False, **kwargs):
+ error_name = ErrorIf.FFTInputShapeMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = "Mismatch between real and imaginary input shapes"
+
+ if check:
+ input1 = kwargs["input1"]
+ input2 = kwargs["input2"]
+
+ if input1.shape != input2.shape:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
+ @staticmethod
+ def evFFTOutputShapeMismatch(check=False, **kwargs):
+ error_name = ErrorIf.FFTOutputShapeMismatch
+ param_reqs = {"rank": None, "dtype": None, "shape": None}
+ error_result = False
+ error_reason = (
+ "Mismatch between provided and expected output kernel (H, W) shape"
+ )
+
+ if check:
+ op = kwargs["op"]
+ input_shape = kwargs["input_shape"]
+
+ if len(input_shape) == 3:
+ output_shapes = kwargs["output_shape"]
+
+ # Ignoring batch size (N) from input shape
+ expected_shape = input_shape[1:]
+ if op["op"] == Op.RFFT2D:
+ expected_shape[1] = expected_shape[1] // 2 + 1
+
+ # Ignoring batch size (N) from output shapes
+ output_shape_0 = output_shapes[0][1:]
+ output_shape_1 = output_shapes[1][1:]
+ # Ensure sure the kernel sizes (H, W) of both outputs match the expected
+ if output_shape_0 != output_shape_1 or output_shape_0 != expected_shape:
+ error_result = True
+
+ info_dict = {
+ "error_name": error_name,
+ "error_result": error_result,
+ "error_reason": error_reason,
+ "param_reqs": param_reqs,
+ }
+ return info_dict
+
class TosaInvalidValidator:
@staticmethod
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 5f9e2c1..2b762aa 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -213,6 +213,12 @@ class TosaTestGen:
else:
raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
+ def constrictBatchSize(self, shape):
+ # Limit the batch size unless an explicit target shape set
+ if self.args.max_batch_size and not self.args.target_shapes:
+ shape[0] = min(shape[0], self.args.max_batch_size)
+ return shape
+
# Argument generators
# Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
# Where the string descriptor is used to generate the test name and
@@ -2081,6 +2087,48 @@ class TosaTestGen:
return acc_out
+ def build_fft2d(
+ self, op, val1, val2, inverse, validator_fcns=None, error_name=None
+ ):
+ results = OutputShaper.fft2dOp(self.ser, self.rng, val1, val2, error_name)
+
+ input_names = [val1.name, val2.name]
+ pCount, cCount = op["operands"]
+ num_operands = pCount + cCount
+
+ output_names = [res.name for res in results]
+ output_shapes = [res.shape for res in results]
+ output_dtypes = [res.dtype for res in results]
+
+ input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
+ self, error_name, input_names, output_names
+ )
+
+ if not TosaErrorValidator.evValidateErrorIfs(
+ self.ser,
+ validator_fcns,
+ error_name,
+ op=op,
+ inverse=inverse,
+ input1=val1,
+ input2=val2,
+ input_shape=val1.shape,
+ input_dtype=val1.dtype,
+ output_shape=output_shapes,
+ output_dtype=output_dtypes,
+ result_tensors=results,
+ input_list=input_names,
+ output_list=output_names,
+ num_operands=num_operands,
+ ):
+ return None
+
+ attr = ts.TosaSerializerAttribute()
+ attr.FFTAttribute(inverse)
+
+ self.ser.addOperator(op["op"], input_names, output_names, attr)
+ return results
+
def build_rfft2d(self, op, val, validator_fcns=None, error_name=None):
results = OutputShaper.rfft2dOp(self.ser, self.rng, val, error_name)
@@ -2089,6 +2137,7 @@ class TosaTestGen:
num_operands = pCount + cCount
output_names = [res.name for res in results]
+ output_shapes = [res.shape for res in results]
output_dtypes = [res.dtype for res in results]
input_names, output_names = TosaErrorIfArgGen.eiInvalidateInputOutputList(
@@ -2102,6 +2151,7 @@ class TosaTestGen:
op=op,
input_shape=val.shape,
input_dtype=val.dtype,
+ output_shape=output_shapes,
output_dtype=output_dtypes,
result_tensors=results,
input_list=input_names,
@@ -3927,6 +3977,29 @@ class TosaTestGen:
TosaErrorValidator.evCondGraphOutputShapeNotSizeOne,
),
},
+ "fft2d": {
+ "op": Op.FFT2D,
+ "operands": (2, 0),
+ "rank": (3, 3),
+ "build_fcn": (
+ build_fft2d,
+ TosaTensorGen.tgFFT2d,
+ TosaTensorValuesGen.tvgDefault,
+ TosaArgGen.agFFT2d,
+ ),
+ "types": [DType.FP32],
+ "error_if_validators": (
+ TosaErrorValidator.evWrongInputType,
+ TosaErrorValidator.evWrongOutputType,
+ TosaErrorValidator.evWrongInputList,
+ TosaErrorValidator.evWrongOutputList,
+ TosaErrorValidator.evWrongRank,
+ TosaErrorValidator.evBatchMismatch,
+ TosaErrorValidator.evKernelNotPowerOfTwo,
+ TosaErrorValidator.evFFTInputShapeMismatch,
+ TosaErrorValidator.evFFTOutputShapeMismatch,
+ ),
+ },
"rfft2d": {
"op": Op.RFFT2D,
"operands": (1, 0),
@@ -3946,6 +4019,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evBatchMismatch,
TosaErrorValidator.evKernelNotPowerOfTwo,
+ TosaErrorValidator.evFFTOutputShapeMismatch,
),
},
}
@@ -4770,6 +4844,37 @@ class OutputShaper:
return ser.addOutput(output_shape, out_dtype)
@staticmethod
+ def fft2dOp(serializer, rng, ifm1, ifm2, error_name=None):
+ outputs = []
+
+ assert ifm1.dtype == ifm2.dtype
+ input_dtype = ifm1.dtype
+
+ if error_name != ErrorIf.FFTInputShapeMismatch:
+ assert ifm1.shape == ifm2.shape
+
+ input_shape = ifm1.shape
+ if error_name != ErrorIf.WrongRank:
+ assert len(input_shape) == 3
+
+ output_shape = input_shape.copy()
+ output_dtype = input_dtype
+
+ if error_name == ErrorIf.WrongOutputType:
+ excludes = [DType.FP32]
+ wrong_dtypes = list(usableDTypes(excludes=excludes))
+ output_dtype = rng.choice(wrong_dtypes)
+ elif error_name == ErrorIf.BatchMismatch:
+ output_shape[0] += rng.integers(1, 10)
+ elif error_name == ErrorIf.FFTOutputShapeMismatch:
+ modify_dim = rng.choice([1, 2])
+ output_shape[modify_dim] += rng.integers(1, 10)
+
+ outputs.append(serializer.addOutput(output_shape, output_dtype))
+ outputs.append(serializer.addOutput(output_shape, output_dtype))
+ return outputs
+
+ @staticmethod
def rfft2dOp(serializer, rng, value, error_name=None):
outputs = []
@@ -4785,8 +4890,10 @@ class OutputShaper:
wrong_dtypes = list(usableDTypes(excludes=excludes))
output_dtype = rng.choice(wrong_dtypes)
elif error_name == ErrorIf.BatchMismatch:
- incorrect_batch = input_shape[0] + rng.integers(1, 10)
- output_shape = [incorrect_batch, *input_shape[1:]]
+ output_shape[0] += rng.integers(1, 10)
+ elif error_name == ErrorIf.FFTOutputShapeMismatch:
+ modify_dim = rng.choice([1, 2])
+ output_shape[modify_dim] += rng.integers(1, 10)
outputs.append(serializer.addOutput(output_shape, output_dtype))
outputs.append(serializer.addOutput(output_shape, output_dtype))