diff options
Diffstat (limited to 'tests/validation/fixtures/WinogradConvolutionLayerFixture.h')
-rw-r--r-- | tests/validation/fixtures/WinogradConvolutionLayerFixture.h | 16 |
1 files changed, 9 insertions, 7 deletions
diff --git a/tests/validation/fixtures/WinogradConvolutionLayerFixture.h b/tests/validation/fixtures/WinogradConvolutionLayerFixture.h index 9c9e634205..8f34654c3a 100644 --- a/tests/validation/fixtures/WinogradConvolutionLayerFixture.h +++ b/tests/validation/fixtures/WinogradConvolutionLayerFixture.h @@ -494,10 +494,10 @@ class WinogradOutputTransformValidationFixture : public framework::Fixture { public: template <typename...> - void setup(TensorShape input_shape, WinogradInfo winograd_info, DataType data_type) + void setup(TensorShape input_shape, WinogradInfo winograd_info, DataType data_type, ActivationLayerInfo act_info = ActivationLayerInfo()) { - _target = compute_target(input_shape, winograd_info, data_type); - _reference = compute_reference(input_shape, winograd_info, data_type); + _target = compute_target(input_shape, winograd_info, data_type, act_info); + _reference = compute_reference(input_shape, winograd_info, data_type, act_info); } protected: @@ -522,7 +522,7 @@ protected: } } - TensorType compute_target(const TensorShape &input_shape, const WinogradInfo &winograd_info, DataType data_type) + TensorType compute_target(const TensorShape &input_shape, const WinogradInfo &winograd_info, DataType data_type, ActivationLayerInfo act_info) { TensorShape output_shape = compute_winograd_output_transform_shape(TensorInfo(input_shape, 1, data_type), winograd_info); @@ -533,7 +533,7 @@ protected: // Create and configure function FunctionType output_transform; - output_transform.configure(&src, &bias, &dst, winograd_info); + output_transform.configure(&src, &bias, &dst, winograd_info, act_info); ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(bias.info()->is_resizable(), framework::LogLevel::ERRORS); @@ -557,7 +557,7 @@ protected: return dst; } - SimpleTensor<T> compute_reference(const TensorShape &input_shape, WinogradInfo winograd_info, DataType data_type) + SimpleTensor<T> compute_reference(const TensorShape &input_shape, WinogradInfo winograd_info, DataType data_type, ActivationLayerInfo act_info) { winograd_info.output_data_layout = DataLayout::NCHW; TensorShape output_shape = compute_winograd_output_transform_shape(TensorInfo(input_shape, 1, data_type), winograd_info); @@ -570,7 +570,9 @@ protected: fill(src, 0, -1.f, 1.f); fill(bias, 1, -1.f, 1.f); - return reference::winograd_output_transform<T>(src, bias, output_shape, winograd_info); + const SimpleTensor<T> winograd_output = reference::winograd_output_transform<T>(src, bias, output_shape, winograd_info); + + return (act_info.enabled()) ? reference::activation_layer<T>(winograd_output, act_info) : winograd_output; } TensorType _target{}; |