diff options
author | Manuel Bottini <manuel.bottini@arm.com> | 2018-10-02 16:41:52 +0100 |
---|---|---|
committer | Georgios Pinitas <georgios.pinitas@arm.com> | 2018-12-05 11:37:14 +0000 |
commit | 0d0028ca25a47dd51260e2555b336fc9f09d1df1 (patch) | |
tree | 968e8f126a9c7d5d7d4159fbb7d906d47ad077f2 /tests/validation/fixtures | |
parent | 8bf622a44c70564d6a7c712473cdfac3e50ac62d (diff) | |
download | ComputeLibrary-0d0028ca25a47dd51260e2555b336fc9f09d1df1.tar.gz |
COMPMID-1298: Fuse ReLu activation in CLWinogradOutputTransform
Change-Id: I9e6e43a5839d04c2e4b4552c05446efb0a5074cf
Reviewed-on: https://review.mlplatform.org/232
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r-- | tests/validation/fixtures/WinogradConvolutionLayerFixture.h | 16 |
1 files changed, 9 insertions, 7 deletions
diff --git a/tests/validation/fixtures/WinogradConvolutionLayerFixture.h b/tests/validation/fixtures/WinogradConvolutionLayerFixture.h index 9c9e634205..8f34654c3a 100644 --- a/tests/validation/fixtures/WinogradConvolutionLayerFixture.h +++ b/tests/validation/fixtures/WinogradConvolutionLayerFixture.h @@ -494,10 +494,10 @@ class WinogradOutputTransformValidationFixture : public framework::Fixture { public: template <typename...> - void setup(TensorShape input_shape, WinogradInfo winograd_info, DataType data_type) + void setup(TensorShape input_shape, WinogradInfo winograd_info, DataType data_type, ActivationLayerInfo act_info = ActivationLayerInfo()) { - _target = compute_target(input_shape, winograd_info, data_type); - _reference = compute_reference(input_shape, winograd_info, data_type); + _target = compute_target(input_shape, winograd_info, data_type, act_info); + _reference = compute_reference(input_shape, winograd_info, data_type, act_info); } protected: @@ -522,7 +522,7 @@ protected: } } - TensorType compute_target(const TensorShape &input_shape, const WinogradInfo &winograd_info, DataType data_type) + TensorType compute_target(const TensorShape &input_shape, const WinogradInfo &winograd_info, DataType data_type, ActivationLayerInfo act_info) { TensorShape output_shape = compute_winograd_output_transform_shape(TensorInfo(input_shape, 1, data_type), winograd_info); @@ -533,7 +533,7 @@ protected: // Create and configure function FunctionType output_transform; - output_transform.configure(&src, &bias, &dst, winograd_info); + output_transform.configure(&src, &bias, &dst, winograd_info, act_info); ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(bias.info()->is_resizable(), framework::LogLevel::ERRORS); @@ -557,7 +557,7 @@ protected: return dst; } - SimpleTensor<T> compute_reference(const TensorShape &input_shape, WinogradInfo winograd_info, DataType data_type) + SimpleTensor<T> compute_reference(const TensorShape &input_shape, WinogradInfo winograd_info, DataType data_type, ActivationLayerInfo act_info) { winograd_info.output_data_layout = DataLayout::NCHW; TensorShape output_shape = compute_winograd_output_transform_shape(TensorInfo(input_shape, 1, data_type), winograd_info); @@ -570,7 +570,9 @@ protected: fill(src, 0, -1.f, 1.f); fill(bias, 1, -1.f, 1.f); - return reference::winograd_output_transform<T>(src, bias, output_shape, winograd_info); + const SimpleTensor<T> winograd_output = reference::winograd_output_transform<T>(src, bias, output_shape, winograd_info); + + return (act_info.enabled()) ? reference::activation_layer<T>(winograd_output, act_info) : winograd_output; } TensorType _target{}; |