diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2019-03-26 17:23:28 +0000 |
---|---|---|
committer | Gian Marco Iodice <gianmarco.iodice@arm.com> | 2019-04-11 09:34:26 +0000 |
commit | 8be9148814b88e5b0cabd5a4d2b1f4ff470a8c1c (patch) | |
tree | 760658b8c7b8917379467bd3fc119a5502faa850 /tests/validation/fixtures/FFTFixture.h | |
parent | a50e702289af66944e860eafc7f3b32f6c5f30be (diff) | |
download | ComputeLibrary-8be9148814b88e5b0cabd5a4d2b1f4ff470a8c1c.tar.gz |
COMPMID-1959: Implements 2D FFT on OpenCL
Change-Id: I73cf3984a5463acc854c8a59dc2bd9a5234cd99c
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-on: https://review.mlplatform.org/c/936
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Diffstat (limited to 'tests/validation/fixtures/FFTFixture.h')
-rw-r--r-- | tests/validation/fixtures/FFTFixture.h | 138 |
1 files changed, 133 insertions, 5 deletions
diff --git a/tests/validation/fixtures/FFTFixture.h b/tests/validation/fixtures/FFTFixture.h
index 8e3c01eaff..1aaa5965b2 100644
--- a/tests/validation/fixtures/FFTFixture.h
+++ b/tests/validation/fixtures/FFTFixture.h
@@ -31,6 +31,8 @@
 #include "tests/IAccessor.h"
 #include "tests/framework/Asserts.h"
 #include "tests/framework/Fixture.h"
+#include "tests/validation/reference/ActivationLayer.h"
+#include "tests/validation/reference/ConvolutionLayer.h"
 #include "tests/validation/reference/DFT.h"
 
 #include <random>
@@ -41,7 +43,7 @@ namespace test
 {
 namespace validation
 {
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename InfoType, typename T>
 class FFTValidationFixture : public framework::Fixture
 {
 public:
@@ -68,8 +70,8 @@ protected:
         TensorType dst = create_tensor<TensorType>(shape, data_type, 2);
 
         // Create and configure function
-        FunctionType fft1d;
-        fft1d.configure(&src, &dst, FFT1DInfo());
+        FunctionType fft;
+        fft.configure(&src, &dst, InfoType());
 
         ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
         ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -85,7 +87,7 @@ protected:
         fill(AccessorType(src));
 
         // Compute function
-        fft1d.run();
+        fft.run();
 
         return dst;
     }
@@ -97,12 +99,138 @@ protected:
 
         // Fill reference
         fill(src);
+        if(std::is_same<InfoType, FFT1DInfo>::value)
+        {
+            return reference::dft_1d(src, reference::FFTDirection::Forward);
+        }
+        else
+        {
+            return reference::dft_2d(src, reference::FFTDirection::Forward);
+        }
+    }
+
+    TensorType      _target{};
+    SimpleTensor<T> _reference{};
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class FFTConvolutionValidationGenericFixture : public framework::Fixture
+{
+public:
+    template <typename...>
+    void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation,
+               DataType data_type, DataLayout data_layout, ActivationLayerInfo act_info)
+    {
+        _data_type   = data_type;
+        _data_layout = data_layout;
+
+        _target    = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, dilation, act_info);
+        _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, dilation, act_info);
+    }
+
+protected:
+    template <typename U>
+    void fill(U &&tensor, int i)
+    {
+        switch(tensor.data_type())
+        {
+            case DataType::F32:
+            {
+                std::uniform_real_distribution<> distribution(-1.0f, 1.0f);
+                library->fill(tensor, distribution, i);
+                break;
+            }
+            default:
+                library->fill_tensor_uniform(tensor, i);
+        }
+    }
+
+    TensorType compute_target(TensorShape input_shape, TensorShape weights_shape, const TensorShape &bias_shape, TensorShape output_shape, const PadStrideInfo &info,
+                              const Size2D &dilation, const ActivationLayerInfo act_info)
+    {
+        ARM_COMPUTE_UNUSED(dilation);
+        ARM_COMPUTE_ERROR_ON((input_shape[2] % weights_shape[2]) != 0);
+
+        if(_data_layout == DataLayout::NHWC)
+        {
+            permute(input_shape, PermutationVector(2U, 0U, 1U));
+            permute(weights_shape, PermutationVector(2U, 0U, 1U));
+            permute(output_shape, PermutationVector(2U, 0U, 1U));
+        }
+
+        // Create tensors
+        TensorType src     = create_tensor<TensorType>(input_shape, _data_type, 1, QuantizationInfo(), _data_layout);
+        TensorType weights = create_tensor<TensorType>(weights_shape, _data_type, 1, QuantizationInfo(), _data_layout);
+        TensorType bias    = create_tensor<TensorType>(bias_shape, _data_type, 1, QuantizationInfo(), _data_layout);
+        TensorType dst     = create_tensor<TensorType>(output_shape, _data_type, 1, QuantizationInfo(), _data_layout);
+
+        // Create and configure function
+        FunctionType conv;
+        conv.configure(&src, &weights, &bias, &dst, info, act_info);
+
+        ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(weights.info()->is_resizable(), framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(bias.info()->is_resizable(), framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
+
+        // Allocate tensors
+        src.allocator()->allocate();
+        weights.allocator()->allocate();
+        bias.allocator()->allocate();
+        dst.allocator()->allocate();
+
+        ARM_COMPUTE_EXPECT(!src.info()->is_resizable(), framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(!weights.info()->is_resizable(), framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(!bias.info()->is_resizable(), framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS);
+
+        // Fill tensors
+        fill(AccessorType(src), 0);
+        fill(AccessorType(weights), 1);
+        fill(AccessorType(bias), 2);
+
+        // Compute convolution function
+        conv.run();
+
+        return dst;
+    }
 
-        return reference::dft_1d(src, reference::FFTDirection::Forward);
+    SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info,
+                                      const Size2D &dilation, const ActivationLayerInfo act_info)
+    {
+        ARM_COMPUTE_ERROR_ON((input_shape[2] % weights_shape[2]) != 0);
+
+        // Create reference
+        SimpleTensor<T> src{ input_shape, _data_type, 1 };
+        SimpleTensor<T> weights{ weights_shape, _data_type, 1 };
+        SimpleTensor<T> bias{ bias_shape, _data_type, 1 };
+
+        // Fill reference
+        fill(src, 0);
+        fill(weights, 1);
+        fill(bias, 2);
+
+        return (act_info.enabled()) ? reference::activation_layer<T>(reference::convolution_layer<T>(src, weights, bias, output_shape, info, dilation), act_info) : reference::convolution_layer<T>(src,
+                                                                                                                                                                                                   weights, bias, output_shape, info, dilation);
     }
 
     TensorType      _target{};
     SimpleTensor<T> _reference{};
+    DataType        _data_type{};
+    DataLayout      _data_layout{};
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class FFTConvolutionValidationFixture : public FFTConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+    template <typename...>
+    void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation,
+               DataType data_type, DataLayout data_layout, ActivationLayerInfo act_info)
+    {
+        FFTConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation,
+                                                                                                 data_type, data_layout, act_info);
+    }
 };
 } // namespace validation
 } // namespace test