diff options
author | Gian Marco Iodice <gianmarco.iodice@arm.com> | 2018-02-22 16:17:20 +0000 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:49:16 +0000 |
commit | 7e4b23953e885e58d655a7d9f35a1afcc38365e4 (patch) | |
tree | 4f5a3f6535aae10a36482bd4f996d3427ac77080 /tests/validation/fixtures | |
parent | 66c656a1d10831d8311f7797b285faa2c30bcb3f (diff) | |
download | ComputeLibrary-7e4b23953e885e58d655a7d9f35a1afcc38365e4.tar.gz |
COMPMID-935 - Implementing Convolution with Winograd on OpenCL (part 2)
Implemented Winograd Filter Transform 3x3 on OpenCL
Change-Id: I8f2b2dd938c5c000ef7ce392a37fb7b8b4202a4e
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/122708
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r-- | tests/validation/fixtures/WinogradLayerFixture.h | 84 |
1 files changed, 81 insertions, 3 deletions
diff --git a/tests/validation/fixtures/WinogradLayerFixture.h b/tests/validation/fixtures/WinogradLayerFixture.h index 95e331560d..bfe1efce3b 100644 --- a/tests/validation/fixtures/WinogradLayerFixture.h +++ b/tests/validation/fixtures/WinogradLayerFixture.h @@ -27,7 +27,6 @@ #include "arm_compute/core/TensorShape.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" -#include "arm_compute/runtime/NEON/NEScheduler.h" #include "tests/AssetsLibrary.h" #include "tests/Globals.h" #include "tests/IAccessor.h" @@ -42,8 +41,6 @@ namespace arm_compute { -class NEWinogradLayer; - namespace test { namespace validation @@ -224,6 +221,87 @@ protected: TensorType _target{}; SimpleTensor<T> _reference{}; }; + +template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +class WinogradFilterTransformValidationFixture : public framework::Fixture +{ +public: + template <typename...> + void setup(TensorShape input_shape, bool is_nchw_format, DataType data_type) + { + TensorShape output_shape = compute_winograd_filter_transform_shape(TensorInfo(input_shape, 1, data_type)); + + _target = compute_target(input_shape, output_shape, is_nchw_format, data_type); + _reference = compute_reference(input_shape, output_shape, is_nchw_format, data_type); + } + +protected: + template <typename U> + void fill(U &&tensor, int i, float min, float max) + { + switch(tensor.data_type()) + { + case DataType::F32: + { + std::uniform_real_distribution<> distribution(min, max); + library->fill(tensor, distribution, i); + break; + } + default: + { + ARM_COMPUTE_ERROR("Not supported"); + library->fill_tensor_uniform(tensor, i); + break; + } + } + } + + TensorType compute_target(const TensorShape &input_shape, const TensorShape &output_shape, bool is_nchw_format, DataType data_type) + { + ARM_COMPUTE_UNUSED(is_nchw_format); + + // Create tensors + TensorType src = create_tensor<TensorType>(input_shape, data_type); + TensorType dst = create_tensor<TensorType>(output_shape, data_type); + + // Create and configure function + FunctionType filter_transform; + filter_transform.configure(&src, &dst); + + ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS); + + // Allocate tensors + src.allocator()->allocate(); + dst.allocator()->allocate(); + + ARM_COMPUTE_EXPECT(!src.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS); + + // Fill tensors + fill(AccessorType(src), 0, -1.f, 1.f); + + filter_transform.run(); + + return dst; + } + + SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &output_shape, bool is_nchw_format, DataType data_type) + { + ARM_COMPUTE_ERROR_ON(!is_nchw_format); + + // Create reference + SimpleTensor<T> src{ input_shape, data_type, 1 }; + + // Fill reference + fill(src, 0, -1.f, 1.f); + + return reference::winograd_filter_transform<T>(src, output_shape); + } + + TensorType _target{}; + SimpleTensor<T> _reference{}; +}; } // namespace validation } // namespace test } // namespace arm_compute |