aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/WinogradConvolutionLayerFixture.h
diff options
context:
space:
mode:
authorManuel Bottini <manuel.bottini@arm.com>2018-10-02 16:41:52 +0100
committerGeorgios Pinitas <georgios.pinitas@arm.com>2018-12-05 11:37:14 +0000
commit0d0028ca25a47dd51260e2555b336fc9f09d1df1 (patch)
tree968e8f126a9c7d5d7d4159fbb7d906d47ad077f2 /tests/validation/fixtures/WinogradConvolutionLayerFixture.h
parent8bf622a44c70564d6a7c712473cdfac3e50ac62d (diff)
downloadComputeLibrary-0d0028ca25a47dd51260e2555b336fc9f09d1df1.tar.gz
COMPMID-1298: Fuse ReLu activation in CLWinogradOutputTransform
Change-Id: I9e6e43a5839d04c2e4b4552c05446efb0a5074cf Reviewed-on: https://review.mlplatform.org/232 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Diffstat (limited to 'tests/validation/fixtures/WinogradConvolutionLayerFixture.h')
-rw-r--r--tests/validation/fixtures/WinogradConvolutionLayerFixture.h16
1 files changed, 9 insertions, 7 deletions
diff --git a/tests/validation/fixtures/WinogradConvolutionLayerFixture.h b/tests/validation/fixtures/WinogradConvolutionLayerFixture.h
index 9c9e634205..8f34654c3a 100644
--- a/tests/validation/fixtures/WinogradConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/WinogradConvolutionLayerFixture.h
@@ -494,10 +494,10 @@ class WinogradOutputTransformValidationFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape input_shape, WinogradInfo winograd_info, DataType data_type)
+ void setup(TensorShape input_shape, WinogradInfo winograd_info, DataType data_type, ActivationLayerInfo act_info = ActivationLayerInfo())
{
- _target = compute_target(input_shape, winograd_info, data_type);
- _reference = compute_reference(input_shape, winograd_info, data_type);
+ _target = compute_target(input_shape, winograd_info, data_type, act_info);
+ _reference = compute_reference(input_shape, winograd_info, data_type, act_info);
}
protected:
@@ -522,7 +522,7 @@ protected:
}
}
- TensorType compute_target(const TensorShape &input_shape, const WinogradInfo &winograd_info, DataType data_type)
+ TensorType compute_target(const TensorShape &input_shape, const WinogradInfo &winograd_info, DataType data_type, ActivationLayerInfo act_info)
{
TensorShape output_shape = compute_winograd_output_transform_shape(TensorInfo(input_shape, 1, data_type), winograd_info);
@@ -533,7 +533,7 @@ protected:
// Create and configure function
FunctionType output_transform;
- output_transform.configure(&src, &bias, &dst, winograd_info);
+ output_transform.configure(&src, &bias, &dst, winograd_info, act_info);
ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(bias.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -557,7 +557,7 @@ protected:
return dst;
}
- SimpleTensor<T> compute_reference(const TensorShape &input_shape, WinogradInfo winograd_info, DataType data_type)
+ SimpleTensor<T> compute_reference(const TensorShape &input_shape, WinogradInfo winograd_info, DataType data_type, ActivationLayerInfo act_info)
{
winograd_info.output_data_layout = DataLayout::NCHW;
TensorShape output_shape = compute_winograd_output_transform_shape(TensorInfo(input_shape, 1, data_type), winograd_info);
@@ -570,7 +570,9 @@ protected:
fill(src, 0, -1.f, 1.f);
fill(bias, 1, -1.f, 1.f);
- return reference::winograd_output_transform<T>(src, bias, output_shape, winograd_info);
+ const SimpleTensor<T> winograd_output = reference::winograd_output_transform<T>(src, bias, output_shape, winograd_info);
+
+ return (act_info.enabled()) ? reference::activation_layer<T>(winograd_output, act_info) : winograd_output;
}
TensorType _target{};