aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures
diff options
context:
space:
mode:
authorGian Marco Iodice <gianmarco.iodice@arm.com>2018-02-22 16:17:20 +0000
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:49:16 +0000
commit7e4b23953e885e58d655a7d9f35a1afcc38365e4 (patch)
tree4f5a3f6535aae10a36482bd4f996d3427ac77080 /tests/validation/fixtures
parent66c656a1d10831d8311f7797b285faa2c30bcb3f (diff)
downloadComputeLibrary-7e4b23953e885e58d655a7d9f35a1afcc38365e4.tar.gz
COMPMID-935 - Implementing Convolution with Winograd on OpenCL (part 2)
Implemented Winograd Filter Transform 3x3 on OpenCL Change-Id: I8f2b2dd938c5c000ef7ce392a37fb7b8b4202a4e Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/122708 Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r--tests/validation/fixtures/WinogradLayerFixture.h84
1 files changed, 81 insertions, 3 deletions
diff --git a/tests/validation/fixtures/WinogradLayerFixture.h b/tests/validation/fixtures/WinogradLayerFixture.h
index 95e331560d..bfe1efce3b 100644
--- a/tests/validation/fixtures/WinogradLayerFixture.h
+++ b/tests/validation/fixtures/WinogradLayerFixture.h
@@ -27,7 +27,6 @@
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
-#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "tests/AssetsLibrary.h"
#include "tests/Globals.h"
#include "tests/IAccessor.h"
@@ -42,8 +41,6 @@
namespace arm_compute
{
-class NEWinogradLayer;
-
namespace test
{
namespace validation
@@ -224,6 +221,87 @@ protected:
TensorType _target{};
SimpleTensor<T> _reference{};
};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class WinogradFilterTransformValidationFixture : public framework::Fixture
+{
+public:
+ template <typename...>
+ void setup(TensorShape input_shape, bool is_nchw_format, DataType data_type)
+ {
+ TensorShape output_shape = compute_winograd_filter_transform_shape(TensorInfo(input_shape, 1, data_type));
+
+ _target = compute_target(input_shape, output_shape, is_nchw_format, data_type);
+ _reference = compute_reference(input_shape, output_shape, is_nchw_format, data_type);
+ }
+
+protected:
+ template <typename U>
+ void fill(U &&tensor, int i, float min, float max)
+ {
+ switch(tensor.data_type())
+ {
+ case DataType::F32:
+ {
+ std::uniform_real_distribution<> distribution(min, max);
+ library->fill(tensor, distribution, i);
+ break;
+ }
+ default:
+ {
+ ARM_COMPUTE_ERROR("Not supported");
+ library->fill_tensor_uniform(tensor, i);
+ break;
+ }
+ }
+ }
+
+ TensorType compute_target(const TensorShape &input_shape, const TensorShape &output_shape, bool is_nchw_format, DataType data_type)
+ {
+ ARM_COMPUTE_UNUSED(is_nchw_format);
+
+ // Create tensors
+ TensorType src = create_tensor<TensorType>(input_shape, data_type);
+ TensorType dst = create_tensor<TensorType>(output_shape, data_type);
+
+ // Create and configure function
+ FunctionType filter_transform;
+ filter_transform.configure(&src, &dst);
+
+ ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
+
+ // Allocate tensors
+ src.allocator()->allocate();
+ dst.allocator()->allocate();
+
+ ARM_COMPUTE_EXPECT(!src.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS);
+
+ // Fill tensors
+ fill(AccessorType(src), 0, -1.f, 1.f);
+
+ filter_transform.run();
+
+ return dst;
+ }
+
+ SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &output_shape, bool is_nchw_format, DataType data_type)
+ {
+ ARM_COMPUTE_ERROR_ON(!is_nchw_format);
+
+ // Create reference
+ SimpleTensor<T> src{ input_shape, data_type, 1 };
+
+ // Fill reference
+ fill(src, 0, -1.f, 1.f);
+
+ return reference::winograd_filter_transform<T>(src, output_shape);
+ }
+
+ TensorType _target{};
+ SimpleTensor<T> _reference{};
+};
} // namespace validation
} // namespace test
} // namespace arm_compute