diff options
author | Giorgio Arena <giorgio.arena@arm.com> | 2018-05-03 15:57:48 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:52:35 +0000 |
commit | a3221e6772dc371cf5de7e525bf5c22b58ad6d08 (patch) | |
tree | 14d224e07d92dbbd97966de0b6b0aa8e6a288022 /tests/validation/fixtures | |
parent | 20b4313365ea2ed31f59fd757f68f791f076e6bc (diff) | |
download | ComputeLibrary-a3221e6772dc371cf5de7e525bf5c22b58ad6d08.tar.gz |
COMPMID-1106 Add fast math support in NEWinogradConvolutionLayer
Change-Id: I5fcbbb3b6f22204f0aaebbc319dfdf03593577e8
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/130067
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r-- | tests/validation/fixtures/WinogradConvolutionLayerFixture.h | 16 |
1 files changed, 12 insertions, 4 deletions
diff --git a/tests/validation/fixtures/WinogradConvolutionLayerFixture.h b/tests/validation/fixtures/WinogradConvolutionLayerFixture.h index e15931eafb..ef596e0bae 100644 --- a/tests/validation/fixtures/WinogradConvolutionLayerFixture.h +++ b/tests/validation/fixtures/WinogradConvolutionLayerFixture.h @@ -153,7 +153,7 @@ protected: SimpleTensor<T> _reference{}; }; -template <typename TensorType, typename AccessorType, typename FunctionType, typename T> +template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool use_bias = true> class WinogradConvolutionLayerFastMathValidationFixture : public framework::Fixture { public: @@ -198,8 +198,9 @@ protected: // Create and configure function FunctionType conv; - ARM_COMPUTE_EXPECT(static_cast<bool>(conv.validate(src.info(), weights.info(), bias.info(), dst.info(), info, act_info, true /* Enable fast math */)), framework::LogLevel::ERRORS); - conv.configure(&src, &weights, &bias, &dst, info, act_info, true /* Enable fast math */); + ARM_COMPUTE_EXPECT(static_cast<bool>(conv.validate(src.info(), weights.info(), (use_bias) ? bias.info() : nullptr, dst.info(), info, act_info, true /* Enable fast math */)), + framework::LogLevel::ERRORS); + conv.configure(&src, &weights, (use_bias) ? &bias : nullptr, &dst, info, act_info, true /* Enable fast math */); ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(weights.info()->is_resizable(), framework::LogLevel::ERRORS); @@ -239,7 +240,14 @@ protected: // Fill reference fill(src, 0, -1.f, 1.f); fill(weights, 1, -1.f, 1.f); - fill(bias, 2, -1.f, 1.f); + if(use_bias) + { + fill(bias, 2, -1.f, 1.f); + } + else + { + fill(bias, 2, 0.f, 0.f); + } WinogradInfo winograd_info(Size2D(4U, 4U), Size2D(weights_shape[0], weights_shape[1]), |