From 2213d4b334567d0cb7f283090d42b5fb1b70f66b Mon Sep 17 00:00:00 2001
From: Gian Marco Iodice
Date: Fri, 27 Apr 2018 10:39:06 +0100
Subject: COMPMID-1096 - Add fast_math flag to CLConvolutionLayer

COMPMID-1103 - CLWinogradConvolutionLayer mismatches

Change-Id: Iceaa9482a1790ec39d2720c220261aaea8043978
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/129398
Tested-by: Jenkins
Reviewed-by: Giorgio Arena
Reviewed-by: Georgios Pinitas
---
 tests/validation/CL/ConvolutionLayer.cpp           | 102 ++++++++++-------
 tests/validation/CL/DilatedConvolutionLayer.cpp    |   6 +-
 tests/validation/CL/Winograd.cpp                   |  24 +++-
 .../fixtures/WinogradConvolutionLayerFixture.h     | 122 ++++++++++++++++++++-
 tests/validation/reference/GEMM.cpp                | 102 ++++++++++++-----
 tests/validation/reference/Winograd.cpp            |   7 +-
 tests/validation/reference/Winograd.h              |   2 +-
 7 files changed, 290 insertions(+), 75 deletions(-)

(limited to 'tests/validation')
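The flag this patch threads through the stack reaches CLWinogradConvolutionLayer via validate() and configure(). A minimal sketch of that call pattern, mirroring what the updated fixture below does (tensor creation and OpenCL scheduler setup are assumed to happen elsewhere; this helper is not part of the patch):

#include "arm_compute/core/Error.h"
#include "arm_compute/runtime/CL/CLTensor.h"
#include "arm_compute/runtime/CL/functions/CLWinogradConvolutionLayer.h"

using namespace arm_compute;

void configure_with_fast_math(CLTensor &src, CLTensor &weights, CLTensor &bias, CLTensor &dst,
                              const PadStrideInfo &conv_info, const ActivationLayerInfo &act_info)
{
    // Passing true as the last argument enables fast math: Winograd may then
    // be chosen even where its transforms can cost some FP32 accuracy.
    CLWinogradConvolutionLayer conv;
    ARM_COMPUTE_ERROR_THROW_ON(CLWinogradConvolutionLayer::validate(src.info(), weights.info(), bias.info(), dst.info(),
                                                                    conv_info, act_info, true /* enable fast math */));
    conv.configure(&src, &weights, &bias, &dst, conv_info, act_info, true /* enable fast math */);
}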
diff --git a/tests/validation/CL/ConvolutionLayer.cpp b/tests/validation/CL/ConvolutionLayer.cpp
index 8685e5bbc7..a2b55a8555 100644
--- a/tests/validation/CL/ConvolutionLayer.cpp
+++ b/tests/validation/CL/ConvolutionLayer.cpp
@@ -73,44 +73,72 @@ const auto ActivationFunctionsDataset = framework::dataset::make("ActivationInfo
 TEST_SUITE(CL)
 TEST_SUITE(ConvolutionLayer)
-DATA_TEST_CASE(ValidateConvolutionMethod, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(
-               framework::dataset::make("InputInfo", { TensorInfo(TensorShape(17U, 31U, 2U), 1, DataType::F32, 0),
-                                                       TensorInfo(TensorShape(17U, 31U, 2U), 1, DataType::F32, 0),
-                                                       TensorInfo(TensorShape(23U, 27U, 5U, 4U), 1, DataType::F32, 0),
-                                                       TensorInfo(TensorShape(3U, 3U, 2U, 1U), 1, DataType::F32, 0),
-                                                       TensorInfo(TensorShape(33U, 27U, 7U, 4U), 1, DataType::F32, 0)
-               }),
-               framework::dataset::make("WeightsInfo", { TensorInfo(TensorShape(5U, 5U, 2U, 19U), 1, DataType::F32, 0),
-                                                         TensorInfo(TensorShape(5U, 5U, 2U, 19U), 1, DataType::F32, 0),
-                                                         TensorInfo(TensorShape(3U, 3U, 5U, 21U), 1, DataType::F32, 0),
-                                                         TensorInfo(TensorShape(3U, 3U, 5U, 21U), 1, DataType::F32, 0),
-                                                         TensorInfo(TensorShape(5U, 5U, 7U, 16U), 1, DataType::F16, 0)
-               })),
-               framework::dataset::make("OutputInfo", { TensorInfo(TensorShape(15U, 15U, 19U), 1, DataType::F32, 0),
-                                                        TensorInfo(TensorShape(15U, 15U, 19U), 1, DataType::F32, 0),
-                                                        TensorInfo(TensorShape(21U, 25U, 21U, 4U), 1, DataType::F32, 0),
-                                                        TensorInfo(TensorShape(11U, 25U, 21U), 1, DataType::F32, 0),
-                                                        TensorInfo(TensorShape(11U, 12U, 16U, 4U), 1, DataType::F32, 0)
-               })),
-               framework::dataset::make("ConvInfo", { PadStrideInfo(1, 2, 1, 1),
-                                                      PadStrideInfo(1, 2, 1, 1),
-                                                      PadStrideInfo(1, 1, 0, 0),
-                                                      PadStrideInfo(2, 1, 0, 0),
-                                                      PadStrideInfo(3, 2, 1, 0)
-               })),
-               framework::dataset::make("GpuTarget", { GPUTarget::BIFROST,
-                                                       GPUTarget::MIDGARD,
-                                                       GPUTarget::G71,
-                                                       GPUTarget::MIDGARD,
-                                                       GPUTarget::BIFROST
-               })),
-
-               framework::dataset::make("Expected", { ConvolutionMethod::GEMM, ConvolutionMethod::GEMM, ConvolutionMethod::WINOGRAD, ConvolutionMethod::GEMM, ConvolutionMethod::GEMM })),
-               input_info, weights_info, output_info, conv_info, gpu_target, expected)
+DATA_TEST_CASE(ValidateConvolutionMethod, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(zip(
+               framework::dataset::make("InputInfo", { TensorInfo(TensorShape(17U, 31U, 2U), 1, DataType::F32, 0),
+                                                       TensorInfo(TensorShape(17U, 31U, 2U), 1, DataType::F32, 0),
+                                                       TensorInfo(TensorShape(23U, 27U, 5U, 4U), 1, DataType::F32, 0),
+                                                       TensorInfo(TensorShape(3U, 3U, 2U, 1U), 1, DataType::F32, 0),
+                                                       TensorInfo(TensorShape(33U, 27U, 7U, 4U), 1, DataType::F32, 0),
+                                                       TensorInfo(TensorShape(17U, 31U, 2U), 1, DataType::F32, 0),
+                                                       TensorInfo(TensorShape(17U, 31U, 2U), 1, DataType::F32, 0)
+               }),
+               framework::dataset::make("WeightsInfo", { TensorInfo(TensorShape(5U, 5U, 2U, 19U), 1, DataType::F32, 0),
+                                                         TensorInfo(TensorShape(5U, 5U, 2U, 19U), 1, DataType::F32, 0),
+                                                         TensorInfo(TensorShape(3U, 3U, 5U, 21U), 1, DataType::F32, 0),
+                                                         TensorInfo(TensorShape(3U, 3U, 5U, 21U), 1, DataType::F32, 0),
+                                                         TensorInfo(TensorShape(5U, 5U, 7U, 16U), 1, DataType::F16, 0),
+                                                         TensorInfo(TensorShape(5U, 5U, 2U, 19U), 1, DataType::F32, 0),
+                                                         TensorInfo(TensorShape(5U, 5U, 2U, 19U), 1, DataType::F32, 0)
+               })),
+               framework::dataset::make("OutputInfo", { TensorInfo(TensorShape(15U, 15U, 19U), 1, DataType::F32, 0),
+                                                        TensorInfo(TensorShape(15U, 15U, 19U), 1, DataType::F32, 0),
+                                                        TensorInfo(TensorShape(21U, 25U, 21U, 4U), 1, DataType::F32, 0),
+                                                        TensorInfo(TensorShape(11U, 25U, 21U), 1, DataType::F32, 0),
+                                                        TensorInfo(TensorShape(11U, 12U, 16U, 4U), 1, DataType::F32, 0),
+                                                        TensorInfo(TensorShape(17U, 31U, 19U), 1, DataType::F32, 0),
+                                                        TensorInfo(TensorShape(17U, 31U, 19U), 1, DataType::F32, 0)
+               })),
+               framework::dataset::make("ConvInfo", { PadStrideInfo(1, 2, 1, 1),
+                                                      PadStrideInfo(1, 2, 1, 1),
+                                                      PadStrideInfo(1, 1, 0, 0),
+                                                      PadStrideInfo(2, 1, 0, 0),
+                                                      PadStrideInfo(3, 2, 1, 0),
+                                                      PadStrideInfo(1, 1, 2, 2),
+                                                      PadStrideInfo(1, 1, 2, 2)
+               })),
+               framework::dataset::make("GpuTarget", { GPUTarget::BIFROST,
+                                                       GPUTarget::MIDGARD,
+                                                       GPUTarget::G71,
+                                                       GPUTarget::MIDGARD,
+                                                       GPUTarget::BIFROST,
+                                                       GPUTarget::BIFROST,
+                                                       GPUTarget::BIFROST
+               })),
+               framework::dataset::make("Dilation",
 {
-    ConvolutionMethod is_valid = CLConvolutionLayer::get_convolution_method(&input_info.clone()->set_is_resizable(false),
-                                                                            &weights_info.clone()->set_is_resizable(false),
-                                                                            &output_info.clone()->set_is_resizable(false), conv_info, WeightsInfo(), ActivationLayerInfo(), gpu_target);
+    Size2D(1U, 1U),
+    Size2D(1U, 1U),
+    Size2D(1U, 1U),
+    Size2D(1U, 1U),
+    Size2D(1U, 1U),
+    Size2D(1U, 1U),
+    Size2D(2U, 1U),
+})),
+framework::dataset::make("EnableFastMath", { false, false, false, false, false, true, true })),
+framework::dataset::make("Expected",
+{
+    ConvolutionMethod::GEMM, ConvolutionMethod::GEMM, ConvolutionMethod::WINOGRAD, ConvolutionMethod::GEMM, ConvolutionMethod::GEMM, ConvolutionMethod::WINOGRAD, ConvolutionMethod::GEMM,
+})),
+input_info, weights_info, output_info, conv_info, gpu_target, dilation, enable_fast_math, expected)
+{
+    ConvolutionMethod is_valid = CLConvolutionLayer::get_convolution_method(&input_info.clone()->set_is_resizable(true),
+                                                                            &weights_info.clone()->set_is_resizable(true),
+                                                                            &output_info.clone()->set_is_resizable(true), conv_info,
+                                                                            WeightsInfo(),
+                                                                            ActivationLayerInfo(),
+                                                                            gpu_target,
+                                                                            dilation,
+                                                                            enable_fast_math);
     ARM_COMPUTE_EXPECT(is_valid == expected, framework::LogLevel::ERRORS);
 }
 TEST_SUITE_END()
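Outside the test harness, the extended heuristic can be queried directly. A sketch using one dataset entry from the test above (shapes are illustrative only; the signature is the one exercised by the patch):

#include "arm_compute/runtime/CL/functions/CLConvolutionLayer.h"

using namespace arm_compute;

ConvolutionMethod query_method_fast_math()
{
    const TensorInfo input(TensorShape(17U, 31U, 2U), 1, DataType::F32, 0);
    const TensorInfo weights(TensorShape(5U, 5U, 2U, 19U), 1, DataType::F32, 0);
    const TensorInfo output(TensorShape(17U, 31U, 19U), 1, DataType::F32, 0);

    // With enable_fast_math = true the heuristic may return WINOGRAD for a
    // 5x5 kernel on Bifrost (sixth dataset row); a non-unit dilation still
    // forces GEMM (seventh row).
    return CLConvolutionLayer::get_convolution_method(&input, &weights, &output,
                                                      PadStrideInfo(1, 1, 2, 2),
                                                      WeightsInfo(),
                                                      ActivationLayerInfo(),
                                                      GPUTarget::BIFROST,
                                                      Size2D(1U, 1U),
                                                      true /* enable_fast_math */);
}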
diff --git a/tests/validation/CL/DilatedConvolutionLayer.cpp b/tests/validation/CL/DilatedConvolutionLayer.cpp
index e6a765bbe1..9ee002cc5a 100644
--- a/tests/validation/CL/DilatedConvolutionLayer.cpp
+++ b/tests/validation/CL/DilatedConvolutionLayer.cpp
@@ -104,9 +104,9 @@ DATA_TEST_CASE(ValidateConvolutionMethod, framework::DatasetMode::ALL, zip(zip(z
                framework::dataset::make("Expected", { ConvolutionMethod::GEMM, ConvolutionMethod::GEMM, ConvolutionMethod::WINOGRAD, ConvolutionMethod::GEMM, ConvolutionMethod::GEMM })),
                input_info, weights_info, output_info, conv_info, gpu_target, dilation, expected)
 {
-    ConvolutionMethod is_valid = CLConvolutionLayer::get_convolution_method(&input_info.clone()->set_is_resizable(false),
-                                                                            &weights_info.clone()->set_is_resizable(false),
-                                                                            &output_info.clone()->set_is_resizable(false), conv_info, WeightsInfo(), ActivationLayerInfo(), gpu_target, dilation);
+    ConvolutionMethod is_valid = CLConvolutionLayer::get_convolution_method(&input_info.clone()->set_is_resizable(true),
+                                                                            &weights_info.clone()->set_is_resizable(true),
+                                                                            &output_info.clone()->set_is_resizable(true), conv_info, WeightsInfo(), ActivationLayerInfo(), gpu_target, dilation);
     ARM_COMPUTE_EXPECT(is_valid == expected, framework::LogLevel::ERRORS);
 }
 TEST_SUITE_END()
diff --git a/tests/validation/CL/Winograd.cpp b/tests/validation/CL/Winograd.cpp
index 30d8d751af..d892c9f77f 100644
--- a/tests/validation/CL/Winograd.cpp
+++ b/tests/validation/CL/Winograd.cpp
@@ -51,7 +51,8 @@ namespace validation
 {
 namespace
 {
-constexpr AbsoluteTolerance<float> tolerance_f32(0.001f);
+constexpr AbsoluteTolerance<float> tolerance_f32(0.0001f);
+constexpr AbsoluteTolerance<float> tolerance_fast_math_f32(0.1f);
 } // namespace

 using namespace arm_compute::misc::shape_calculator;
@@ -379,6 +380,27 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLWinogradConvolutionLayerFixture, framework::D
     // Validate output
     validate(CLAccessor(_target), _reference, tolerance_f32);
 }
+
+TEST_SUITE(EnableFastMath)
+using CLWinogradConvolutionLayerFastMathFixture = WinogradConvolutionLayerFastMathValidationFixture<CLTensor, CLAccessor, CLWinogradConvolutionLayer, float>;
+FIXTURE_DATA_TEST_CASE(RunSmall, CLWinogradConvolutionLayerFastMathFixture, framework::DatasetMode::PRECOMMIT,
+                       combine(combine(framework::dataset::concat(datasets::SmallWinogradConvolutionLayer3x3Dataset(), datasets::SmallWinogradConvolutionLayer5x5Dataset()),
+                                       framework::dataset::make("DataType", { DataType::F32 })),
+                               framework::dataset::make("ActivationLayerInfo", { ActivationLayerInfo() })))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference, tolerance_fast_math_f32);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, CLWinogradConvolutionLayerFastMathFixture, framework::DatasetMode::NIGHTLY,
+                       combine(combine(framework::dataset::concat(datasets::LargeWinogradConvolutionLayer3x3Dataset(), datasets::LargeWinogradConvolutionLayer5x5Dataset()),
+                                       framework::dataset::make("DataType", { DataType::F32 })),
+                               framework::dataset::make("ActivationLayerInfo", { ActivationLayerInfo() })))
+{
+    // Validate output
+    validate(CLAccessor(_target), _reference, tolerance_fast_math_f32);
+}
+
+TEST_SUITE_END() // EnableFastMath
 TEST_SUITE_END() // ConvolutionLayer
 TEST_SUITE_END() // Winograd
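On the widened tolerance: an absolute tolerance accepts a result when its distance from the reference is within the given epsilon, so the fast-math suite admits errors three orders of magnitude larger than the exact path. A standalone illustration of that acceptance rule (a sketch of the semantics, not the framework's code):

#include <cmath>

// Exact Winograd path vs. fast-math path, matching the constants set above
constexpr float tolerance_f32_eps           = 0.0001f;
constexpr float tolerance_fast_math_f32_eps = 0.1f;

inline bool within_absolute_tolerance(float target, float reference, float eps)
{
    return std::fabs(target - reference) <= eps;
}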
diff --git a/tests/validation/fixtures/WinogradConvolutionLayerFixture.h b/tests/validation/fixtures/WinogradConvolutionLayerFixture.h
index 249f9d5649..e15931eafb 100644
--- a/tests/validation/fixtures/WinogradConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/WinogradConvolutionLayerFixture.h
@@ -35,6 +35,7 @@
 #include "tests/validation/Helpers.h"
 #include "tests/validation/reference/ActivationLayer.h"
 #include "tests/validation/reference/ConvolutionLayer.h"
+#include "tests/validation/reference/GEMM.h"
 #include "tests/validation/reference/Utils.h"
 #include "tests/validation/reference/Winograd.h"

@@ -152,6 +153,123 @@ protected:
     SimpleTensor<T> _reference{};
 };

+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class WinogradConvolutionLayerFastMathValidationFixture : public framework::Fixture
+{
+public:
+    template <typename...>
+    void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, DataType data_type, ActivationLayerInfo act_info)
+    {
+        ARM_COMPUTE_UNUSED(dilation);
+
+        _target    = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, data_type, act_info);
+        _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, data_type, act_info);
+    }
+
+protected:
+    template <typename U>
+    void fill(U &&tensor, int i, float min, float max)
+    {
+        switch(tensor.data_type())
+        {
+            case DataType::F32:
+            {
+                std::uniform_real_distribution<> distribution(min, max);
+                library->fill(tensor, distribution, i);
+                break;
+            }
+            default:
+            {
+                ARM_COMPUTE_ERROR("Not supported");
+                library->fill_tensor_uniform(tensor, i);
+                break;
+            }
+        }
+    }
+
+    TensorType compute_target(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info,
+                              DataType data_type, ActivationLayerInfo act_info)
+    {
+        // Create tensors
+        TensorType src     = create_tensor<TensorType>(input_shape, data_type, 1);
+        TensorType weights = create_tensor<TensorType>(weights_shape, data_type, 1);
+        TensorType bias    = create_tensor<TensorType>(bias_shape, data_type, 1);
+        TensorType dst     = create_tensor<TensorType>(output_shape, data_type, 1);
+
+        // Create and configure function
+        FunctionType conv;
+        ARM_COMPUTE_EXPECT(static_cast<bool>(conv.validate(src.info(), weights.info(), bias.info(), dst.info(), info, act_info, true /* Enable fast math */)), framework::LogLevel::ERRORS);
+        conv.configure(&src, &weights, &bias, &dst, info, act_info, true /* Enable fast math */);
+
+        ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(weights.info()->is_resizable(), framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(bias.info()->is_resizable(), framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
+
+        // Allocate tensors
+        src.allocator()->allocate();
+        weights.allocator()->allocate();
+        dst.allocator()->allocate();
+        bias.allocator()->allocate();
+
+        ARM_COMPUTE_EXPECT(!src.info()->is_resizable(), framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(!weights.info()->is_resizable(), framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(!bias.info()->is_resizable(), framework::LogLevel::ERRORS);
+        ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS);
+
+        // Fill tensors
+        fill(AccessorType(src), 0, -1.f, 1.f);
+        fill(AccessorType(weights), 1, -1.f, 1.f);
+        fill(AccessorType(bias), 2, -1.f, 1.f);
+
+        // Compute Winograd Convolution function
+        conv.run();
+
+        return dst;
+    }
+
+    SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info,
+                                      DataType data_type, ActivationLayerInfo act_info)
+    {
+        // Create reference
+        SimpleTensor<T> src{ input_shape, data_type, 1 };
+        SimpleTensor<T> weights{ weights_shape, data_type, 1 };
+        SimpleTensor<T> bias{ bias_shape, data_type, 1 };
+
+        // Fill reference
+        fill(src, 0, -1.f, 1.f);
+        fill(weights, 1, -1.f, 1.f);
+        fill(bias, 2, -1.f, 1.f);
+
+        WinogradInfo winograd_info(Size2D(4U, 4U),
+                                   Size2D(weights_shape[0], weights_shape[1]),
+                                   Size2D(input_shape[0], input_shape[1]),
+                                   info,
+                                   src.data_layout());
+
+        // Compute tensor shapes for input, filter and output transforms
+        TensorShape input_transform_shape  = compute_winograd_input_transform_shape(TensorInfo(input_shape, 1, data_type), winograd_info);
+        TensorShape filter_transform_shape = compute_winograd_filter_transform_shape(TensorInfo(weights_shape, 1, data_type), winograd_info);
+        TensorShape batched_gemm_shape     = input_transform_shape;
+        batched_gemm_shape[0]              = filter_transform_shape[0];
+        TensorShape output_transform_shape = compute_winograd_output_transform_shape(TensorInfo(batched_gemm_shape, 1, data_type), winograd_info);
+
+        // Dummy matrix C to perform matrix multiplication
+        SimpleTensor<T> dummy_c{ batched_gemm_shape, data_type, 1 };
+
+        // Compute Winograd-based convolution
+        SimpleTensor<T> input_transform_out  = reference::winograd_input_transform(src, input_transform_shape, winograd_info);
+        SimpleTensor<T> filter_transform_out = reference::winograd_filter_transform(weights, filter_transform_shape, winograd_info);
+        SimpleTensor<T> batched_gemm         = reference::gemm(input_transform_out, filter_transform_out, dummy_c, 1.0f, 0.0f);
+        SimpleTensor<T> conv_out             = reference::winograd_output_transform(batched_gemm, bias, output_transform_shape, winograd_info);
+
+        return (act_info.enabled()) ? reference::activation_layer(conv_out, act_info) : conv_out;
+    }
+
+    TensorType      _target{};
+    SimpleTensor<T> _reference{};
+};
+
 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
 class WinogradInputTransformValidationFixture : public framework::Fixture
 {
@@ -373,11 +491,13 @@ protected:
     {
         // Create reference
         SimpleTensor<T> src{ input_shape, data_type };
+        SimpleTensor<T> bias{ TensorShape(input_shape[0]), data_type };

         // Fill reference
         fill(src, 0, -1.f, 1.f);
+        fill(bias, 1, 0.0f, 0.0f); // Fill with zeros as we validate just the output transform without bias contribution

-        return reference::winograd_output_transform(src, output_shape, winograd_info);
+        return reference::winograd_output_transform(src, bias, output_shape, winograd_info);
     }

     TensorType      _target{};
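The fixture above hands reference::gemm a dummy C tensor with alpha = 1 and beta = 0, so the C term vanishes and the call reduces to a plain matrix product. A self-contained 2x2 check of that contract, dst = alpha * A * B + beta * C (values are illustrative only):

#include <array>
#include <cassert>

int main()
{
    const std::array<float, 4> a{ 1.f, 2.f, 3.f, 4.f }; // row-major 2x2
    const std::array<float, 4> b{ 5.f, 6.f, 7.f, 8.f };
    const std::array<float, 4> c{ 9.f, 9.f, 9.f, 9.f }; // "dummy" C
    const float alpha = 1.f, beta = 0.f;

    std::array<float, 4> dst{};
    for(int row = 0; row < 2; ++row)
    {
        for(int col = 0; col < 2; ++col)
        {
            float acc = 0.f;
            for(int k = 0; k < 2; ++k)
            {
                acc += a[row * 2 + k] * b[k * 2 + col];
            }
            dst[row * 2 + col] = alpha * acc + beta * c[row * 2 + col];
        }
    }

    // A * B = [[19, 22], [43, 50]]; beta = 0, so C contributes nothing
    assert(dst[0] == 19.f && dst[1] == 22.f && dst[2] == 43.f && dst[3] == 50.f);
    return 0;
}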
diff --git a/tests/validation/reference/GEMM.cpp b/tests/validation/reference/GEMM.cpp
index 77d025ec8e..f9dcfcbdd0 100644
--- a/tests/validation/reference/GEMM.cpp
+++ b/tests/validation/reference/GEMM.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -41,23 +41,44 @@ SimpleTensor<T> gemm(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const S
     SimpleTensor<T> dst{ c.shape(), c.data_type(), 1, c.fixed_point_position() };

     // Compute reference
-    const int M = dst.shape().y();
-    const int N = dst.shape().x();
+    const int M = a.shape().y();
+    const int N = b.shape().x();
     const int K = a.shape().x();
+    const int D = a.shape().z(); // Number of matrices in a batch
+    const int W = a.shape()[3];  // Number of batched-gemm (Winograd case)
+
+    const int a_stride_z = K * M;
+    const int a_stride_w = K * M * D;
+
+    const int b_stride_z = b.shape().num_dimensions() > 2 ? N * K : 0;     // Do not slide the matrix B along the 3rd dimension in case matrix B has less than 3 dimensions
+    const int b_stride_w = b.shape().num_dimensions() > 3 ? K * N * D : 0; // Do not slide the matrix B along the 4th dimension in case matrix B has less than 4 dimensions

-    for(int row = 0; row < M; ++row)
+    const int c_stride_z = N * M;
+    const int c_stride_w = N * M * D;
+
+    for(int w = 0; w < W; ++w)
     {
-        for(int col = 0; col < N; ++col)
+        for(int depth = 0; depth < D; ++depth)
         {
-            T acc(0);
+            const int base_addr_a = depth * a_stride_z + w * a_stride_w;
+            const int base_addr_b = depth * b_stride_z + w * b_stride_w;
+            const int base_addr_c = depth * c_stride_z + w * c_stride_w;

-            for(int k = 0; k < K; ++k)
+            for(int row = 0; row < M; ++row)
             {
-                acc += a[row * K + k] * b[k * N + col];
+                for(int col = 0; col < N; ++col)
+                {
+                    T acc(0);
+
+                    for(int k = 0; k < K; ++k)
+                    {
+                        acc += a[base_addr_a + k + row * K] * b[base_addr_b + col + k * N];
+                    }
+
+                    // Finalize the result: alpha * A * B + beta * C
+                    dst[base_addr_c + col + row * N] = alpha * acc + beta * c[base_addr_c + col + row * N];
+                }
             }
-
-            // Finalize the result: alpha * A * B + beta * C
-            dst[col + row * N] = alpha * acc + beta * c[col + row * N];
         }
     }
@@ -75,37 +96,58 @@ SimpleTensor<T> gemm(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const S
     // Compute reference
     using promoted_type = fixed_point_arithmetic::traits::promote_t<T>;

-    const int M                    = dst.shape().y();
-    const int N                    = dst.shape().x();
-    const int K                    = a.shape().x();
-    const int fixed_point_position = a.fixed_point_position();
+    const int M = dst.shape().y();
+    const int N = dst.shape().x();
+    const int K = a.shape().x();
+    const int D = a.shape().z(); // Number of matrices in a batch
+    const int W = a.shape()[3];  // Number of batched-gemm (Winograd case)
+
+    const int a_stride_z = K * M;
+    const int a_stride_w = K * M * D;
+
+    const int b_stride_z = b.shape().num_dimensions() > 2 ? N * K : 0;     // Do not slide the matrix B along the 3rd dimension in case matrix B has less than 3 dimensions
+    const int b_stride_w = b.shape().num_dimensions() > 3 ? K * N * D : 0; // Do not slide the matrix B along the 4th dimension in case matrix B has less than 4 dimensions
+
+    const int c_stride_z = N * M;
+    const int c_stride_w = N * M * D;
+
+    const int fixed_point_position = a.fixed_point_position();
     const fixed_point<T> alpha_q(alpha, fixed_point_position);
     const fixed_point<T> beta_q(beta, fixed_point_position);

-    for(int row = 0; row < M; ++row)
+    for(int w = 0; w < W; ++w)
     {
-        for(int col = 0; col < N; ++col)
+        for(int depth = 0; depth < D; ++depth)
         {
-            fixed_point<promoted_type> acc_q(0, fixed_point_position);
+            const int base_addr_a = depth * a_stride_z + w * a_stride_w;
+            const int base_addr_b = depth * b_stride_z + w * b_stride_w;
+            const int base_addr_c = depth * c_stride_z + w * c_stride_w;

-            for(int k = 0; k < K; ++k)
+            for(int row = 0; row < M; ++row)
             {
-                const fixed_point<promoted_type> a0_q(a[row * K + k], fixed_point_position, true);
-                const fixed_point<promoted_type> b0_q(b[k * N + col], fixed_point_position, true);
+                for(int col = 0; col < N; ++col)
+                {
+                    fixed_point<promoted_type> acc_q(0, fixed_point_position);

-                acc_q = acc_q + (a0_q * b0_q);
-            }
+                    for(int k = 0; k < K; ++k)
+                    {
+                        const fixed_point<promoted_type> a0_q(a[base_addr_a + row * K + k], fixed_point_position, true);
+                        const fixed_point<promoted_type> b0_q(b[base_addr_b + k * N + col], fixed_point_position, true);
+
+                        acc_q = acc_q + (a0_q * b0_q);
+                    }

-            // Finalize the result: alpha * A * B + beta * C
-            const fixed_point<T> c0_q(c[col + row * N], fixed_point_position, true);
+                    // Finalize the result: alpha * A * B + beta * C
+                    const fixed_point<T> c0_q(c[base_addr_c + col + row * N], fixed_point_position, true);

-            fixed_point<T> res_q(acc_q);
-            res_q = alpha_q * res_q;
-            res_q = res_q + (beta_q * c0_q);
+                    fixed_point<T> res_q(acc_q);
+                    res_q = alpha_q * res_q;
+                    res_q = res_q + (beta_q * c0_q);

-            // Store the result
-            dst[col + row * N] = res_q.raw();
+                    // Store the result
+                    dst[base_addr_c + col + row * N] = res_q.raw();
+                }
+            }
         }
     }
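The key idea in the GEMM change above is the zero-stride broadcast: A and dst always advance by one M x K (resp. M x N) slice per (depth, batch) step, while B is reused whenever it has fewer than three or four dimensions. A standalone sketch of the same addressing with plain floats (hypothetical helper, not ACL code; a must hold W*D slices of M x K, dst W*D slices of M x N, b one slice or more depending on b_dims):

#include <vector>

void batched_gemm(const std::vector<float> &a, const std::vector<float> &b, std::vector<float> &dst,
                  int M, int N, int K, int D, int W, unsigned int b_dims)
{
    const int a_stride_z = K * M;
    const int a_stride_w = K * M * D;
    const int b_stride_z = (b_dims > 2) ? N * K : 0;     // 0 => reuse the same B for every depth
    const int b_stride_w = (b_dims > 3) ? K * N * D : 0; // 0 => reuse the same B for every batch
    const int c_stride_z = N * M;
    const int c_stride_w = N * M * D;

    for(int w = 0; w < W; ++w)
    {
        for(int depth = 0; depth < D; ++depth)
        {
            const int base_a = depth * a_stride_z + w * a_stride_w;
            const int base_b = depth * b_stride_z + w * b_stride_w;
            const int base_c = depth * c_stride_z + w * c_stride_w;

            for(int row = 0; row < M; ++row)
            {
                for(int col = 0; col < N; ++col)
                {
                    float acc = 0.f;
                    for(int k = 0; k < K; ++k)
                    {
                        acc += a[base_a + row * K + k] * b[base_b + k * N + col];
                    }
                    dst[base_c + row * N + col] = acc;
                }
            }
        }
    }
}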
diff --git a/tests/validation/reference/Winograd.cpp b/tests/validation/reference/Winograd.cpp
index 75b1b51d46..194a78e95f 100644
--- a/tests/validation/reference/Winograd.cpp
+++ b/tests/validation/reference/Winograd.cpp
@@ -331,7 +331,7 @@ SimpleTensor<T> winograd_filter_transform(const SimpleTensor<T> &in, const Tenso
 }

 template <typename T>
-SimpleTensor<T> winograd_output_transform(const SimpleTensor<T> &in, const TensorShape &output_shape, const WinogradInfo &winograd_info)
+SimpleTensor<T> winograd_output_transform(const SimpleTensor<T> &in, const SimpleTensor<T> &b, const TensorShape &output_shape, const WinogradInfo &winograd_info)
 {
     ARM_COMPUTE_ERROR_ON_MSG(winograd_info.output_data_layout != DataLayout::NCHW, "Only supported NCHW data format");
@@ -444,6 +444,9 @@ SimpleTensor<T> winograd_output_transform(const SimpleTensor<T> &in, const Tenso
                     if((xo + xi < w_out) && (yo + yi < h_out))
                     {
                         out[output_offset + yi * stridey_out + xi] = output_tile[xi + yi * out_tile_w];
+
+                        // Add bias
+                        out[output_offset + yi * stridey_out + xi] += b[zo];
                     }
                 }
             }
@@ -456,7 +459,7 @@
 template SimpleTensor<float> winograd_filter_transform(const SimpleTensor<float> &in, const TensorShape &output_shape, const WinogradInfo &winograd_info);
 template SimpleTensor<float> winograd_input_transform(const SimpleTensor<float> &in, const TensorShape &output_shape, const WinogradInfo &winograd_info);
-template SimpleTensor<float> winograd_output_transform(const SimpleTensor<float> &in, const TensorShape &output_shape, const WinogradInfo &winograd_info);
+template SimpleTensor<float> winograd_output_transform(const SimpleTensor<float> &in, const SimpleTensor<float> &b, const TensorShape &output_shape, const WinogradInfo &winograd_info);
 } // namespace reference
 } // namespace validation
 } // namespace test
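The bias convention added above is one value per output feature map, indexed by the output channel zo and added uniformly across that map's spatial positions. A standalone sketch of the same accumulation for a single NCHW image (hypothetical helper, not the ACL reference implementation):

#include <vector>

void add_channel_bias(std::vector<float> &out, const std::vector<float> &bias,
                      int width, int height, int channels)
{
    for(int zo = 0; zo < channels; ++zo)
    {
        for(int y = 0; y < height; ++y)
        {
            for(int x = 0; x < width; ++x)
            {
                out[(zo * height + y) * width + x] += bias[zo]; // NCHW layout
            }
        }
    }
}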
diff --git a/tests/validation/reference/Winograd.h b/tests/validation/reference/Winograd.h
index 29181f1142..b74c2c3e29 100644
--- a/tests/validation/reference/Winograd.h
+++ b/tests/validation/reference/Winograd.h
@@ -51,7 +51,7 @@ template <typename T>
 SimpleTensor<T> winograd_filter_transform(const SimpleTensor<T> &in, const TensorShape &output_shape, const WinogradInfo &winograd_info);

 template <typename T>
-SimpleTensor<T> winograd_output_transform(const SimpleTensor<T> &in, const TensorShape &output_shape, const WinogradInfo &winograd_info);
+SimpleTensor<T> winograd_output_transform(const SimpleTensor<T> &in, const SimpleTensor<T> &b, const TensorShape &output_shape, const WinogradInfo &winograd_info);
 } // namespace reference
 } // namespace validation
 } // namespace test
-- 
cgit v1.2.1