aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2019-02-21 14:47:56 +0000
committerGeorgios Pinitas <georgios.pinitas@arm.com>2019-03-12 11:21:58 +0000
commitdef2a851abd0d0e1cd748e53b7cb438be15d8f2b (patch)
tree80208efd120d3dbb1dc456ac816b2a6fd53adabe /tests
parent8d04fa08b1c2a71466876b832fc6c6dfaa978a40 (diff)
downloadComputeLibrary-def2a851abd0d0e1cd748e53b7cb438be15d8f2b.tar.gz
COMPMID-1960: Implement DFT reference
Change-Id: I08d47ce2cf7fc833df94420bedd69cedf080fd34 Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com> Reviewed-on: https://review.mlplatform.org/c/822 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Diffstat (limited to 'tests')
-rw-r--r--tests/SimpleTensorAccessor.h174
-rw-r--r--tests/validation/CPP/DFT.cpp183
-rw-r--r--tests/validation/reference/DFT.cpp420
-rw-r--r--tests/validation/reference/DFT.h118
-rw-r--r--tests/validation/reference/Permute.cpp10
5 files changed, 900 insertions, 5 deletions
diff --git a/tests/SimpleTensorAccessor.h b/tests/SimpleTensorAccessor.h
new file mode 100644
index 0000000000..a7ed6f664c
--- /dev/null
+++ b/tests/SimpleTensorAccessor.h
@@ -0,0 +1,174 @@
+/*
+ * Copyright (c) 2019 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef __ARM_COMPUTE_TEST_SIMPLE_TENSOR_ACCESSOR_H__
+#define __ARM_COMPUTE_TEST_SIMPLE_TENSOR_ACCESSOR_H__
+
+#include "SimpleTensor.h"
+#include "tests/IAccessor.h"
+
+namespace arm_compute
+{
+namespace test
+{
+/** Accessor implementation for @ref SimpleTensor objects. */
+template <typename T>
+class SimpleTensorAccessor : public IAccessor
+{
+public:
+ /** Create an accessor for the given @p tensor.
+ *
+ * @param[in, out] tensor To be accessed tensor.
+ */
+ SimpleTensorAccessor(SimpleTensor<T> &tensor);
+
+ /** Prevent instances of this class from being copy constructed */
+ SimpleTensorAccessor(const SimpleTensorAccessor &) = delete;
+ /** Prevent instances of this class from being copied */
+ SimpleTensorAccessor &operator=(const SimpleTensorAccessor &) = delete;
+ /** Allow instances of this class to be move constructed */
+ SimpleTensorAccessor(SimpleTensorAccessor &&) = default;
+ /** Allow instances of this class to be moved */
+ SimpleTensorAccessor &operator=(SimpleTensorAccessor &&) = default;
+
+ /** Get the tensor data.
+ *
+ * @return a constant pointer to the tensor data.
+ */
+ const void *data() const;
+ /** Get the tensor data.
+ *
+ * @return a pointer to the tensor data.
+ */
+ void *data();
+
+ // Inherited methods overridden:
+ TensorShape shape() const override;
+ size_t element_size() const override;
+ size_t size() const override;
+ Format format() const override;
+ DataLayout data_layout() const override;
+ DataType data_type() const override;
+ int num_channels() const override;
+ int num_elements() const override;
+ PaddingSize padding() const override;
+ QuantizationInfo quantization_info() const override;
+ const void *operator()(const Coordinates &coord) const override;
+ void *operator()(const Coordinates &coord) override;
+
+private:
+ SimpleTensor<T> &_tensor;
+};
+
+template <typename T>
+inline SimpleTensorAccessor<T>::SimpleTensorAccessor(SimpleTensor<T> &tensor)
+ : _tensor{ tensor }
+{
+}
+
+template <typename T>
+inline TensorShape SimpleTensorAccessor<T>::shape() const
+{
+ return _tensor.shape();
+}
+
+template <typename T>
+inline size_t SimpleTensorAccessor<T>::element_size() const
+{
+ return _tensor.element_size();
+}
+
+template <typename T>
+inline size_t SimpleTensorAccessor<T>::size() const
+{
+ return _tensor.num_elements() * _tensor.element_size();
+}
+
+template <typename T>
+inline Format SimpleTensorAccessor<T>::format() const
+{
+ return _tensor.format();
+}
+
+template <typename T>
+inline DataLayout SimpleTensorAccessor<T>::data_layout() const
+{
+ return _tensor.data_layout();
+}
+
+template <typename T>
+inline DataType SimpleTensorAccessor<T>::data_type() const
+{
+ return _tensor.data_type();
+}
+
+template <typename T>
+inline int SimpleTensorAccessor<T>::num_channels() const
+{
+ return _tensor.num_channels();
+}
+
+template <typename T>
+inline int SimpleTensorAccessor<T>::num_elements() const
+{
+ return _tensor.num_elements();
+}
+
+template <typename T>
+inline PaddingSize SimpleTensorAccessor<T>::padding() const
+{
+ return _tensor.padding();
+}
+
+template <typename T>
+inline QuantizationInfo SimpleTensorAccessor<T>::quantization_info() const
+{
+ return _tensor.quantization_info();
+}
+
+template <typename T>
+inline const void *SimpleTensorAccessor<T>::data() const
+{
+ return _tensor.data();
+}
+
+template <typename T>
+inline void *SimpleTensorAccessor<T>::data()
+{
+ return _tensor.data();
+}
+
+template <typename T>
+inline const void *SimpleTensorAccessor<T>::operator()(const Coordinates &coord) const
+{
+ return _tensor(coord);
+}
+
+template <typename T>
+inline void *SimpleTensorAccessor<T>::operator()(const Coordinates &coord)
+{
+ return _tensor(coord);
+}
+} // namespace test
+} // namespace arm_compute
+#endif /* __ARM_COMPUTE_TEST_SIMPLE_TENSOR_ACCESSOR_H__ */
diff --git a/tests/validation/CPP/DFT.cpp b/tests/validation/CPP/DFT.cpp
new file mode 100644
index 0000000000..8f1b82371d
--- /dev/null
+++ b/tests/validation/CPP/DFT.cpp
@@ -0,0 +1,183 @@
+/*
+ * Copyright (c) 2019 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/core/TensorShape.h"
+#include "arm_compute/core/Types.h"
+#include "tests/AssetsLibrary.h"
+#include "tests/Globals.h"
+#include "tests/SimpleTensor.h"
+#include "tests/SimpleTensorAccessor.h"
+#include "tests/framework/Asserts.h"
+#include "tests/framework/Macros.h"
+#include "tests/framework/datasets/Datasets.h"
+
+#include "tests/validation/Validation.h"
+#include "tests/validation/reference/ConvolutionLayer.h"
+#include "tests/validation/reference/DFT.h"
+
+#include <random>
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+namespace
+{
+auto shapes_1d_dft = framework::dataset::make("TensorShape", { TensorShape(33U),
+ TensorShape(8U),
+ TensorShape(23U, 7U),
+ TensorShape(16U, 8U, 4U)
+ });
+
+auto shapes_2d_dft = framework::dataset::make("TensorShape", { TensorShape(33U, 14U),
+ TensorShape(8U, 9U),
+ TensorShape(23U, 7U, 3U),
+ TensorShape(16U, 8U, 4U)
+ });
+
+auto conv_dataset_dft = framework::dataset::zip(framework::dataset::zip(framework::dataset::make("InputShape", { TensorShape(8U, 7U, 3U, 2U),
+ TensorShape(18U, 22U, 4U),
+ TensorShape(32U, 48U, 8U)
+ }),
+ framework::dataset::make("WeightShape", { TensorShape(3U, 3U, 3U, 6U),
+ TensorShape(5U, 5U, 4U, 3U),
+ TensorShape(9U, 9U, 8U, 3U)
+ })),
+ framework::dataset::make("ConvInfo", { PadStrideInfo(1, 1, 1, 1),
+ PadStrideInfo(1, 1, 2, 2),
+ PadStrideInfo(1, 1, 4, 4)
+ }));
+} // namespace
+TEST_SUITE(CPP)
+TEST_SUITE(DFT)
+
+TEST_SUITE(DFT1D)
+DATA_TEST_CASE(Real, framework::DatasetMode::ALL, shapes_1d_dft,
+ shape)
+{
+ SimpleTensor<float> src{ shape, DataType::F32, 1 };
+ std::uniform_real_distribution<float> distribution(-5.f, 5.f);
+ library->fill(src, distribution, 0);
+
+ const bool is_odd = shape.x() % 2;
+
+ // Forward pass
+ auto forward = reference::rdft_1d(src);
+ // Backward pass
+ auto backward = reference::ridft_1d(forward, is_odd);
+
+ // Validate with input
+ validate(SimpleTensorAccessor<float>(src), backward, RelativeTolerance<float>(0.1f));
+}
+
+DATA_TEST_CASE(Complex, framework::DatasetMode::ALL, shapes_1d_dft,
+ shape)
+{
+ SimpleTensor<float> src{ shape, DataType::F32, 2 };
+ std::uniform_real_distribution<float> distribution(-5.f, 5.f);
+ library->fill(src, distribution, 0);
+
+ // Forward pass
+ auto forward = reference::dft_1d(src, reference::FFTDirection::Forward);
+ // Backward pass
+ auto backward = reference::dft_1d(forward, reference::FFTDirection::Inverse);
+
+ // Validate with input
+ validate(SimpleTensorAccessor<float>(src), backward, RelativeTolerance<float>(0.1f));
+}
+TEST_SUITE_END() // DFT1D
+
+TEST_SUITE(DFT2D)
+DATA_TEST_CASE(Real, framework::DatasetMode::ALL, shapes_2d_dft,
+ shape)
+{
+ SimpleTensor<float> src{ shape, DataType::F32, 1 };
+ std::uniform_real_distribution<float> distribution(-5.f, 5.f);
+ library->fill(src, distribution, 0);
+
+ const bool is_odd = shape.x() % 2;
+
+ // Forward pass
+ auto forward = reference::rdft_2d(src);
+ // Backward pass
+ auto backward = reference::ridft_2d(forward, is_odd);
+
+ // Validate with input
+ validate(SimpleTensorAccessor<float>(src), backward, RelativeTolerance<float>(0.1f));
+}
+
+DATA_TEST_CASE(Complex, framework::DatasetMode::ALL, shapes_2d_dft,
+ shape)
+{
+ SimpleTensor<float> src{ shape, DataType::F32, 2 };
+ std::uniform_real_distribution<float> distribution(-5.f, 5.f);
+ library->fill(src, distribution, 0);
+
+ // Forward pass
+ auto forward = reference::dft_2d(src, reference::FFTDirection::Forward);
+ // Backward pass
+ auto backward = reference::dft_2d(forward, reference::FFTDirection::Inverse);
+
+ // Validate with input
+ validate(SimpleTensorAccessor<float>(src), backward, RelativeTolerance<float>(0.1f));
+}
+TEST_SUITE_END() // DFT2D
+
+TEST_SUITE(Conv)
+DATA_TEST_CASE(Real2Real, framework::DatasetMode::ALL, conv_dataset_dft,
+ shape_in, shape_w, conv_info)
+{
+ std::uniform_real_distribution<float> distribution(-1.f, 1.f);
+ std::uniform_real_distribution<float> distribution_b(0.f, 0.f);
+
+ SimpleTensor<float> src{ shape_in, DataType::F32, 1 };
+ SimpleTensor<float> w{ shape_w, DataType::F32, 1 };
+ SimpleTensor<float> b{ TensorShape(shape_w[3]), DataType::F32, 1 };
+
+ library->fill(src, distribution, 0);
+ library->fill(w, distribution, 1);
+ library->fill(b, distribution_b, 2);
+
+ const auto output_wh = arm_compute::scaled_dimensions(shape_in.x(), shape_in.y(), shape_w.x(), shape_w.y(), conv_info);
+ TensorShape dst_shape = shape_in;
+ dst_shape.set(0, output_wh.first);
+ dst_shape.set(1, output_wh.second);
+ dst_shape.set(2, shape_w[3]);
+
+ // FFT based convolution
+ auto dst = reference::conv2d_dft(src, w, conv_info);
+ // Reference convolution
+ auto dst_ref = reference::convolution_layer(src, w, b, dst_shape, conv_info);
+
+ // Validate with input
+ validate(SimpleTensorAccessor<float>(dst), dst_ref, RelativeTolerance<float>(0.1f), 0.f, AbsoluteTolerance<float>(0.001f));
+}
+TEST_SUITE_END() // Conv
+
+TEST_SUITE_END() // DFT
+TEST_SUITE_END() // CPP
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
diff --git a/tests/validation/reference/DFT.cpp b/tests/validation/reference/DFT.cpp
new file mode 100644
index 0000000000..6ad1b9e150
--- /dev/null
+++ b/tests/validation/reference/DFT.cpp
@@ -0,0 +1,420 @@
+/*
+ * Copyright (c) 2019 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "DFT.h"
+
+#include "PadLayer.h"
+#include "Permute.h"
+#include "Reverse.h"
+#include "SliceOperations.h"
+
+#include <cmath>
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+namespace reference
+{
+namespace
+{
+/** Performs an one dimensional DFT on a given real sequence.
+ *
+ * @param[in] src_ptr Pointer to the real input sequence.
+ * @param[in] N Size of input sequence.
+ * @param[out] dst_ptr Pointer to the complex output sequence.
+ * @param[out] K Size of the output sequence
+ */
+template <typename T>
+void rdft_1d_step(const T *src_ptr, size_t N, T *dst_ptr, size_t K)
+{
+ for(unsigned int k = 0; k < K; ++k)
+ {
+ float Xr = 0;
+ float Xi = 0;
+ for(unsigned int n = 0; n < N; ++n)
+ {
+ const float alpha = (2 * M_PI * k * n) / N;
+ const float val_r = src_ptr[n];
+ // Assuming DFT from the R domain thus skipping imaginary calculations
+ Xr += val_r * cos(alpha);
+ Xi -= val_r * sin(alpha);
+ }
+
+ dst_ptr[k * 2] = Xr;
+ dst_ptr[k * 2 + 1] = Xi;
+ }
+}
+
+/** Performs an one dimensional DFT on a given complex sequence.
+ *
+ * @param[in] src_ptr Pointer to the complex input sequence.
+ * @param[out] dst_ptr Pointer to the complex output sequence.
+ * @param[in] N Size of the sequences
+ */
+template <typename T>
+void dft_1d_step(const T *src_ptr, T *dst_ptr, size_t N)
+{
+ for(unsigned int k = 0; k < N; ++k)
+ {
+ float Xr = 0;
+ float Xi = 0;
+ for(unsigned int n = 0; n < N; ++n)
+ {
+ const float alpha = (2 * M_PI * k * n) / N;
+ const float val_r = src_ptr[2 * n];
+ const float val_i = src_ptr[2 * n + 1];
+ const float cos_alpha = cos(alpha);
+ const float sin_alpha = sin(alpha);
+
+ Xr += val_r * cos_alpha + val_i * sin_alpha;
+ Xi += val_i * cos_alpha - val_r * sin_alpha;
+ }
+
+ dst_ptr[k * 2] = Xr;
+ dst_ptr[k * 2 + 1] = Xi;
+ }
+}
+
+/** Performs an one dimensional inverse DFT on a given real sequence.
+ *
+ * @param[in] src_ptr Pointer to the real input sequence.
+ * @param[in] K Size of input sequence.
+ * @param[out] dst_ptr Pointer to the complex output sequence.
+ * @param[out] N Size of the output sequence
+ */
+template <typename T>
+void irdft_1d_step(const T *src_ptr, size_t K, T *dst_ptr, size_t N)
+{
+ const bool is_odd = N % 2;
+ const unsigned int Nleft = N - K;
+ const int tail_start = is_odd ? K - 1 : K - 2;
+
+ for(unsigned int n = 0; n < N; ++n)
+ {
+ float xr = 0;
+ for(unsigned int k = 0; k < K; ++k)
+ {
+ const float alpha = (2 * M_PI * k * n) / N;
+ xr += src_ptr[2 * k] * cos(alpha) - src_ptr[2 * k + 1] * sin(alpha);
+ }
+
+ unsigned int j = tail_start;
+ for(unsigned int k = 0; k < Nleft; ++k)
+ {
+ const float alpha = (2 * M_PI * (k + K) * n) / N;
+ xr += src_ptr[2 * j] * cos(alpha) + src_ptr[2 * j + 1] * sin(alpha);
+ --j;
+ }
+
+ dst_ptr[n] = xr;
+ }
+}
+
+/** Performs an one dimensional inverse DFT on a given complex sequence.
+ *
+ * @param[in] src_ptr Pointer to the complex input sequence.
+ * @param[out] dst_ptr Pointer to the complex output sequence.
+ * @param[in] N Size of the sequences
+ */
+template <typename T>
+void idft_1d_step(const T *src_ptr, T *dst_ptr, size_t N)
+{
+ for(unsigned int n = 0; n < N; ++n)
+ {
+ float xr = 0;
+ float xi = 0;
+ for(unsigned int k = 0; k < N; ++k)
+ {
+ const float alpha = (2 * M_PI * k * n) / N;
+ const float cos_alpha = cos(alpha);
+ const float sin_alpha = sin(alpha);
+ const float val_r = src_ptr[2 * k];
+ const float val_i = src_ptr[2 * k + 1];
+
+ xr += val_r * cos_alpha - val_i * sin_alpha;
+ xi += val_i * cos_alpha + val_r * sin_alpha;
+ }
+
+ dst_ptr[2 * n] = xr;
+ dst_ptr[2 * n + 1] = xi;
+ }
+}
+
+template <typename T>
+SimpleTensor<T> rdft_1d_core(const SimpleTensor<T> &src, FFTDirection direction, bool is_odd)
+{
+ // Performs only rdft
+ ARM_COMPUTE_ERROR_ON(direction == FFTDirection::Forward && src.num_channels() != 1);
+ ARM_COMPUTE_ERROR_ON(direction == FFTDirection::Inverse && src.num_channels() != 2);
+
+ const unsigned int inverse_tail = is_odd ? 1 : 0;
+ const unsigned int N = src.shape()[0];
+ const unsigned int K = direction == FFTDirection::Forward ? N / 2 + 1 : (N - 1) * 2 + inverse_tail;
+ const unsigned int num_channels = direction == FFTDirection::Forward ? 2 : 1;
+
+ TensorShape dst_shape = src.shape();
+ dst_shape.set(0, K);
+
+ SimpleTensor<T> dst(dst_shape, src.data_type(), num_channels);
+
+ const unsigned int upper_dims = src.shape().total_size_upper(1);
+ for(unsigned int du = 0; du < upper_dims; ++du)
+ {
+ const T *src_row_ptr = src.data() + du * N * src.num_channels();
+ T *dst_row_ptr = dst.data() + du * K * dst.num_channels();
+ direction == FFTDirection::Forward ? rdft_1d_step(src_row_ptr, N, dst_row_ptr, K) : irdft_1d_step(src_row_ptr, N, dst_row_ptr, K);
+ }
+
+ return dst;
+}
+
+template <typename T>
+SimpleTensor<T> dft_1d_core(const SimpleTensor<T> &src, FFTDirection direction)
+{
+ ARM_COMPUTE_ERROR_ON(src.num_channels() != 2);
+
+ const unsigned int N = src.shape()[0];
+
+ SimpleTensor<T> dst(src.shape(), src.data_type(), src.num_channels());
+
+ const unsigned int upper_dims = src.shape().total_size_upper(1);
+ for(unsigned int du = 0; du < upper_dims; ++du)
+ {
+ const T *src_row_ptr = src.data() + du * N * src.num_channels();
+ T *dst_row_ptr = dst.data() + du * N * dst.num_channels();
+ direction == FFTDirection::Forward ? dft_1d_step(src_row_ptr, dst_row_ptr, N) : idft_1d_step(src_row_ptr, dst_row_ptr, N);
+ }
+
+ return dst;
+}
+
+/** Scale a tensor by a given scaling factor.
+ *
+ * @param[in,out] tensor Tensor to scale.
+ * @param[in] scaling_factor Scaling to scale the tensor data with.
+ */
+template <typename T>
+void scale(SimpleTensor<T> &tensor, T scaling_factor)
+{
+ const int total_elements = tensor.num_elements() * tensor.num_channels();
+ T *data_ptr = tensor.data();
+ for(int i = 0; i < total_elements; ++i)
+ {
+ data_ptr[i] /= scaling_factor;
+ }
+}
+
+/** Performs a complex element-wise multiplication with reduction across the channels axis.
+ *
+ * @param[in] input Input tensor.
+ * @param[in] weights Weights tensor.
+ *
+ * @return Output tensor.
+ */
+template <typename T>
+SimpleTensor<T> complex_mul_and_reduce(const SimpleTensor<T> &input, const SimpleTensor<T> &weights)
+{
+ const int W = input.shape().x();
+ const int H = input.shape().y();
+ const int Ci = input.shape().z();
+ const int Co = weights.shape()[3];
+ const int N = input.shape().total_size() / (W * H * Ci);
+
+ TensorShape output_shape = input.shape();
+ output_shape.set(2, Co);
+ SimpleTensor<T> dst(output_shape, input.data_type(), input.num_channels());
+
+ // MemSet dst memory to zero
+ std::memset(dst.data(), 0, dst.size());
+
+ for(int b = 0; b < N; ++b)
+ {
+ for(int co = 0; co < Co; ++co)
+ {
+ for(int ci = 0; ci < Ci; ++ci)
+ {
+ for(int h = 0; h < H; ++h)
+ {
+ for(int w = 0; w < W; ++w)
+ {
+ size_t i_index = w + h * W + ci * H * W + b * H * W * Ci;
+ size_t w_index = w + h * W + ci * H * W + co * H * W * Ci;
+ size_t o_index = w + h * W + co * H * W + b * H * W * Co;
+ const Coordinates i_coords = index2coords(input.shape(), i_index);
+ const Coordinates w_coords = index2coords(weights.shape(), w_index);
+ const Coordinates o_coords = index2coords(dst.shape(), o_index);
+
+ auto i_ptr = static_cast<const T *>(input(i_coords));
+ auto w_ptr = static_cast<const T *>(weights(w_coords));
+ auto o_ptr = static_cast<T *>(dst(o_coords));
+
+ const T Rin = i_ptr[0];
+ const T Iin = i_ptr[1];
+ const T Rw = w_ptr[0];
+ const T Iw = w_ptr[1];
+
+ o_ptr[0] += Rin * Rw - Iin * Iw;
+ o_ptr[1] += Rin * Iw + Rw * Iin;
+ }
+ }
+ }
+ }
+ }
+ return dst;
+}
+} // namespace
+
+template <typename T>
+SimpleTensor<T> rdft_1d(const SimpleTensor<T> &src)
+{
+ return rdft_1d_core(src, FFTDirection::Forward, false);
+}
+
+template <typename T>
+SimpleTensor<T> ridft_1d(const SimpleTensor<T> &src, bool is_odd)
+{
+ auto dst = rdft_1d_core(src, FFTDirection::Inverse, is_odd);
+
+ const T scaling_factor = dst.shape()[0];
+ scale(dst, scaling_factor);
+
+ return dst;
+}
+
+template <typename T>
+SimpleTensor<T> dft_1d(const SimpleTensor<T> &src, FFTDirection direction)
+{
+ auto dst = dft_1d_core(src, direction);
+ if(direction == FFTDirection::Inverse)
+ {
+ const T scaling_factor = dst.shape()[0];
+ scale(dst, scaling_factor);
+ }
+ return dst;
+}
+
+template <typename T>
+SimpleTensor<T> rdft_2d(const SimpleTensor<T> &src)
+{
+ ARM_COMPUTE_ERROR_ON(src.num_channels() != 1);
+ constexpr FFTDirection direction = FFTDirection::Forward;
+
+ auto first_pass = rdft_1d_core(src, direction, false);
+ auto transposed = permute(first_pass, PermutationVector(1U, 0U));
+ auto second_pass = dft_1d_core(transposed, direction);
+ return permute(second_pass, PermutationVector(1U, 0U));
+}
+
+template <typename T>
+SimpleTensor<T> ridft_2d(const SimpleTensor<T> &src, bool is_odd)
+{
+ ARM_COMPUTE_ERROR_ON(src.num_channels() != 2);
+ constexpr FFTDirection direction = FFTDirection::Inverse;
+
+ auto transposed = permute(src, PermutationVector(1U, 0U));
+ auto first_pass = dft_1d_core(transposed, direction);
+ auto transposed_2 = permute(first_pass, PermutationVector(1U, 0U));
+ auto dst = rdft_1d_core(transposed_2, direction, is_odd);
+
+ const T scaling_factor = dst.shape()[0] * dst.shape()[1];
+ scale(dst, scaling_factor);
+ return dst;
+}
+
+template <typename T>
+SimpleTensor<T> dft_2d(const SimpleTensor<T> &src, FFTDirection direction)
+{
+ ARM_COMPUTE_ERROR_ON(src.num_channels() != 2);
+
+ if(direction == FFTDirection::Forward)
+ {
+ auto first_pass = dft_1d_core(src, direction);
+ auto transposed = permute(first_pass, PermutationVector(1U, 0U));
+ auto second_pass = dft_1d_core(transposed, direction);
+ return permute(second_pass, PermutationVector(1U, 0U));
+ }
+ else
+ {
+ auto transposed = permute(src, PermutationVector(1U, 0U));
+ auto first_pass = dft_1d_core(transposed, direction);
+ auto transposed_2 = permute(first_pass, PermutationVector(1U, 0U));
+ auto dst = dft_1d_core(transposed_2, direction);
+
+ const T scaling_factor = dst.shape()[0] * dst.shape()[1];
+ scale(dst, scaling_factor);
+
+ return dst;
+ }
+}
+
+template <typename T>
+SimpleTensor<T> conv2d_dft(const SimpleTensor<T> &src, const SimpleTensor<T> &w, const PadStrideInfo &conv_info)
+{
+ // Pad input to full padding
+ const PaddingList padding_in = { { 0, w.shape()[0] - 1 }, { 0, w.shape()[1] - 1 } };
+ auto padded_src = pad_layer(src, padding_in);
+
+ // Flip weights
+ std::vector<uint32_t> axis_v = { 0, 1 };
+ SimpleTensor<uint32_t> axis{ TensorShape(2U), DataType::U32 };
+ std::copy(axis_v.begin(), axis_v.begin() + axis.shape().x(), axis.data());
+ auto flipped_w = reverse(w, axis);
+
+ // Pad weights to have the same size as input
+ const PaddingList paddings_w = { { 0, src.shape()[0] - 1 }, { 0, src.shape()[1] - 1 } };
+ auto padded_w = pad_layer(flipped_w, paddings_w);
+
+ // Transform input and weights to frequency domain
+ auto Fsrc = rdft_2d(padded_src);
+ auto Fw = rdft_2d(padded_w);
+
+ // Perform dot product
+ auto Fdst = complex_mul_and_reduce(Fsrc, Fw);
+
+ // Transform output back to frequency domain
+ auto conv_res = ridft_2d(Fdst);
+
+ // Slice output
+ const int start_left = w.shape().x() - conv_info.pad_left() - 1;
+ const int start_top = w.shape().y() - conv_info.pad_top() - 1;
+ const int end_right = conv_res.shape().x() - (w.shape().x() - conv_info.pad_right() - 1);
+ const int end_botton = conv_res.shape().y() - (w.shape().y() - conv_info.pad_bottom() - 1);
+ return slice(conv_res, Coordinates(start_left, start_top), Coordinates(end_right, end_botton));
+}
+
+template SimpleTensor<float> rdft_1d(const SimpleTensor<float> &src);
+template SimpleTensor<float> ridft_1d(const SimpleTensor<float> &src, bool is_odd);
+template SimpleTensor<float> dft_1d(const SimpleTensor<float> &src, FFTDirection direction);
+
+template SimpleTensor<float> rdft_2d(const SimpleTensor<float> &src);
+template SimpleTensor<float> ridft_2d(const SimpleTensor<float> &src, bool is_odd);
+template SimpleTensor<float> dft_2d(const SimpleTensor<float> &src, FFTDirection direction);
+
+template SimpleTensor<float> conv2d_dft(const SimpleTensor<float> &src, const SimpleTensor<float> &w, const PadStrideInfo &conv_info);
+} // namespace reference
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
diff --git a/tests/validation/reference/DFT.h b/tests/validation/reference/DFT.h
new file mode 100644
index 0000000000..a3a10abd7f
--- /dev/null
+++ b/tests/validation/reference/DFT.h
@@ -0,0 +1,118 @@
+/*
+ * Copyright (c) 2019 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef __ARM_COMPUTE_TEST_FFT_H__
+#define __ARM_COMPUTE_TEST_FFT_H__
+
+#include "tests/SimpleTensor.h"
+#include "tests/validation/Helpers.h"
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+namespace reference
+{
+enum class FFTDirection
+{
+ Forward,
+ Inverse
+};
+
+/** Performs an one dimensional DFT on a real input.
+ *
+ * @param[in] src Source tensor.
+ *
+ * @return Complex output of length n/2 + 1 due to symmetry.
+ */
+template <typename T>
+SimpleTensor<T> rdft_1d(const SimpleTensor<T> &src);
+
+/** Performs an one dimensional inverse DFT on a real input.
+ *
+ * @param[in] src Source tensor.
+ * @param[in] is_odd (Optional) Specifies if the output has odd dimensions.
+ * Is used by the inverse variant to reconstruct odd sequences.
+ *
+ * @return Complex output of length n/2 + 1 due to symmetry.
+ */
+template <typename T>
+SimpleTensor<T> ridft_1d(const SimpleTensor<T> &src, bool is_odd = false);
+
+/** Performs an one dimensional DFT on a complex input.
+ *
+ * @param[in] src Source tensor.
+ * @param[in] direction Direction of the DFT.
+ *
+ * @return Complex output of same length as input.
+ */
+template <typename T>
+SimpleTensor<T> dft_1d(const SimpleTensor<T> &src, FFTDirection direction);
+
+/** Performs a two dimensional DFT on a real input.
+ *
+ * @param[in] src Source tensor.
+ *
+ * @return Complex output of length n/2 + 1 across width due to symmetry and height of same size as the input.
+ */
+template <typename T>
+SimpleTensor<T> rdft_2d(const SimpleTensor<T> &src);
+
+/** Performs a two dimensional inverse DFT on a real input.
+ *
+ * @param[in] src Source tensor.
+ * @param[in] is_odd (Optional) Specifies if the output has odd dimensions across width.
+ * Is used by the inverse variant to reconstruct odd sequences.
+ *
+ * @return Complex output of length n/2 + 1 across width due to symmetry and height of same size as the input.
+ */
+template <typename T>
+SimpleTensor<T> ridft_2d(const SimpleTensor<T> &src, bool is_odd = false);
+
+/** Performs a two dimensional DFT on a complex input.
+ *
+ * @param[in] src Source tensor.
+ * @param[in] direction Direction of the DFT.
+ *
+ * @return Complex output of same length as input.
+ */
+template <typename T>
+SimpleTensor<T> dft_2d(const SimpleTensor<T> &src, FFTDirection direction);
+
+/** Performs and DFT based convolution on a real input.
+ *
+ * @param[in] src Source tensor.
+ * @param[in] w Weights tensor.
+ * @param[in] conv_info Convolution related metadata.
+ *
+ * @return The output tensor.
+ */
+template <typename T>
+SimpleTensor<T> conv2d_dft(const SimpleTensor<T> &src, const SimpleTensor<T> &w, const PadStrideInfo &conv_info);
+} // namespace reference
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
+#endif /* __ARM_COMPUTE_TEST_FFT_H__ */
diff --git a/tests/validation/reference/Permute.cpp b/tests/validation/reference/Permute.cpp
index 29c3c5cda8..619a787a05 100644
--- a/tests/validation/reference/Permute.cpp
+++ b/tests/validation/reference/Permute.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -47,11 +47,11 @@ SimpleTensor<T> permute(const SimpleTensor<T> &src, PermutationVector perm)
// Compute reference
for(int i = 0; i < src.num_elements(); ++i)
{
- Coordinates coord = index2coord(src.shape(), i);
- permute(coord, perm);
- const size_t dst_index = coord2index(dst.shape(), coord);
+ const Coordinates src_coords = index2coord(src.shape(), i);
+ Coordinates dst_coords = src_coords;
+ permute(dst_coords, perm);
- dst[dst_index] = src[i];
+ std::copy_n(static_cast<const T *>(src(src_coords)), src.num_channels(), static_cast<T *>(dst(dst_coords)));
}
return dst;