aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/WinogradConvolutionLayerFixture.h
diff options
context:
space:
mode:
authorGiorgio Arena <giorgio.arena@arm.com>2018-04-23 17:41:22 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:52:54 +0000
commit3695f9af9db2c14acee9af2fd68c44c737faa6ce (patch)
tree87aa336d6263cb00b01d5277b19178e80782f57a /tests/validation/fixtures/WinogradConvolutionLayerFixture.h
parentb62280aca3148dd6762e57e5af3da0cb0a9e2db5 (diff)
downloadComputeLibrary-3695f9af9db2c14acee9af2fd68c44c737faa6ce.tar.gz
COMPMID-1048 Add NHWC data format support to Winograd output transform 4x4_3x3
https://confluence.arm.com/display/MLENG/Winograd+Output+Transform%3A+NCHW+vs+NHWC+on+OpenCL Change-Id: I6995f5cef759ba70ebd96d545b952041b6f1f36e Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/128729 Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/fixtures/WinogradConvolutionLayerFixture.h')
-rw-r--r--tests/validation/fixtures/WinogradConvolutionLayerFixture.h16
1 files changed, 10 insertions, 6 deletions
diff --git a/tests/validation/fixtures/WinogradConvolutionLayerFixture.h b/tests/validation/fixtures/WinogradConvolutionLayerFixture.h
index ef596e0bae..e23368add6 100644
--- a/tests/validation/fixtures/WinogradConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/WinogradConvolutionLayerFixture.h
@@ -36,6 +36,7 @@
#include "tests/validation/reference/ActivationLayer.h"
#include "tests/validation/reference/ConvolutionLayer.h"
#include "tests/validation/reference/GEMM.h"
+#include "tests/validation/reference/Permute.h"
#include "tests/validation/reference/Utils.h"
#include "tests/validation/reference/Winograd.h"
@@ -440,10 +441,8 @@ public:
template <typename...>
void setup(TensorShape input_shape, WinogradInfo winograd_info, DataType data_type)
{
- TensorShape output_shape = compute_winograd_output_transform_shape(TensorInfo(input_shape, 1, data_type), winograd_info);
-
- _target = compute_target(input_shape, output_shape, winograd_info, data_type);
- _reference = compute_reference(input_shape, output_shape, winograd_info, data_type);
+ _target = compute_target(input_shape, winograd_info, data_type);
+ _reference = compute_reference(input_shape, winograd_info, data_type);
}
protected:
@@ -467,8 +466,10 @@ protected:
}
}
- TensorType compute_target(const TensorShape &input_shape, const TensorShape &output_shape, const WinogradInfo &winograd_info, DataType data_type)
+ TensorType compute_target(const TensorShape &input_shape, const WinogradInfo &winograd_info, DataType data_type)
{
+ TensorShape output_shape = compute_winograd_output_transform_shape(TensorInfo(input_shape, 1, data_type), winograd_info);
+
// Create tensors
TensorType src = create_tensor<TensorType>(input_shape, data_type);
TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1, 0, QuantizationInfo(), winograd_info.output_data_layout);
@@ -495,8 +496,11 @@ protected:
return dst;
}
- SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &output_shape, const WinogradInfo &winograd_info, DataType data_type)
+ SimpleTensor<T> compute_reference(const TensorShape &input_shape, WinogradInfo winograd_info, DataType data_type)
{
+ winograd_info.output_data_layout = DataLayout::NCHW;
+ TensorShape output_shape = compute_winograd_output_transform_shape(TensorInfo(input_shape, 1, data_type), winograd_info);
+
// Create reference
SimpleTensor<T> src{ input_shape, data_type };
SimpleTensor<T> bias{ TensorShape(input_shape[0]), data_type };