From 0d0028ca25a47dd51260e2555b336fc9f09d1df1 Mon Sep 17 00:00:00 2001 From: Manuel Bottini Date: Tue, 2 Oct 2018 16:41:52 +0100 Subject: COMPMID-1298: Fuse ReLu activation in CLWinogradOutputTransform Change-Id: I9e6e43a5839d04c2e4b4552c05446efb0a5074cf Reviewed-on: https://review.mlplatform.org/232 Tested-by: Arm Jenkins Reviewed-by: Georgios Pinitas --- .../fixtures/WinogradConvolutionLayerFixture.h | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) (limited to 'tests/validation/fixtures/WinogradConvolutionLayerFixture.h') diff --git a/tests/validation/fixtures/WinogradConvolutionLayerFixture.h b/tests/validation/fixtures/WinogradConvolutionLayerFixture.h index 9c9e634205..8f34654c3a 100644 --- a/tests/validation/fixtures/WinogradConvolutionLayerFixture.h +++ b/tests/validation/fixtures/WinogradConvolutionLayerFixture.h @@ -494,10 +494,10 @@ class WinogradOutputTransformValidationFixture : public framework::Fixture { public: template - void setup(TensorShape input_shape, WinogradInfo winograd_info, DataType data_type) + void setup(TensorShape input_shape, WinogradInfo winograd_info, DataType data_type, ActivationLayerInfo act_info = ActivationLayerInfo()) { - _target = compute_target(input_shape, winograd_info, data_type); - _reference = compute_reference(input_shape, winograd_info, data_type); + _target = compute_target(input_shape, winograd_info, data_type, act_info); + _reference = compute_reference(input_shape, winograd_info, data_type, act_info); } protected: @@ -522,7 +522,7 @@ protected: } } - TensorType compute_target(const TensorShape &input_shape, const WinogradInfo &winograd_info, DataType data_type) + TensorType compute_target(const TensorShape &input_shape, const WinogradInfo &winograd_info, DataType data_type, ActivationLayerInfo act_info) { TensorShape output_shape = compute_winograd_output_transform_shape(TensorInfo(input_shape, 1, data_type), winograd_info); @@ -533,7 +533,7 @@ protected: // Create and configure function FunctionType output_transform; - output_transform.configure(&src, &bias, &dst, winograd_info); + output_transform.configure(&src, &bias, &dst, winograd_info, act_info); ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(bias.info()->is_resizable(), framework::LogLevel::ERRORS); @@ -557,7 +557,7 @@ protected: return dst; } - SimpleTensor compute_reference(const TensorShape &input_shape, WinogradInfo winograd_info, DataType data_type) + SimpleTensor compute_reference(const TensorShape &input_shape, WinogradInfo winograd_info, DataType data_type, ActivationLayerInfo act_info) { winograd_info.output_data_layout = DataLayout::NCHW; TensorShape output_shape = compute_winograd_output_transform_shape(TensorInfo(input_shape, 1, data_type), winograd_info); @@ -570,7 +570,9 @@ protected: fill(src, 0, -1.f, 1.f); fill(bias, 1, -1.f, 1.f); - return reference::winograd_output_transform(src, bias, output_shape, winograd_info); + const SimpleTensor winograd_output = reference::winograd_output_transform(src, bias, output_shape, winograd_info); + + return (act_info.enabled()) ? reference::activation_layer(winograd_output, act_info) : winograd_output; } TensorType _target{}; -- cgit v1.2.1