Diffstat (limited to 'tests/validation/fixtures/WinogradLayerFixture.h')
-rw-r--r--  tests/validation/fixtures/WinogradLayerFixture.h  16
1 file changed, 9 insertions(+), 7 deletions(-)
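The change below threads an ActivationLayerInfo argument through WinogradConvolutionLayerValidationFixture: setup(), compute_target() and compute_reference() gain the extra parameter, the target path forwards it to the function's configure() call, and the reference path applies reference::activation_layer<T> on top of reference::convolution_layer<T> whenever the activation is enabled. A minimal sketch of that reference logic follows; the helper name reference_winograd_conv is hypothetical, and the committed code inlines the same logic as a single ternary expression inside compute_reference().

#include "tests/validation/reference/ActivationLayer.h"
#include "tests/validation/reference/ConvolutionLayer.h"

// Hypothetical helper illustrating the new reference path (a sketch, not the
// committed code; assumes the same namespaces as the fixture file): compute
// the plain convolution reference first, then apply the activation reference
// only when act_info is enabled.
template <typename T>
SimpleTensor<T> reference_winograd_conv(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<T> &bias,
                                        const TensorShape &output_shape, const PadStrideInfo &info, const ActivationLayerInfo &act_info)
{
    SimpleTensor<T> conv_out = reference::convolution_layer<T>(src, weights, bias, output_shape, info);
    return act_info.enabled() ? reference::activation_layer<T>(conv_out, act_info) : conv_out;
}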
diff --git a/tests/validation/fixtures/WinogradLayerFixture.h b/tests/validation/fixtures/WinogradLayerFixture.h
index a86f24f35e..5210cbf720 100644
--- a/tests/validation/fixtures/WinogradLayerFixture.h
+++ b/tests/validation/fixtures/WinogradLayerFixture.h
@@ -33,6 +33,7 @@
#include "tests/framework/Asserts.h"
#include "tests/framework/Fixture.h"
#include "tests/validation/Helpers.h"
+#include "tests/validation/reference/ActivationLayer.h"
#include "tests/validation/reference/ConvolutionLayer.h"
#include "tests/validation/reference/Utils.h"
#include "tests/validation/reference/Winograd.h"
@@ -52,12 +53,12 @@ class WinogradConvolutionLayerValidationFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, DataType data_type)
+ void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, DataType data_type, ActivationLayerInfo act_info)
{
ARM_COMPUTE_UNUSED(dilation);
- _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, data_type);
- _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, data_type);
+ _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, data_type, act_info);
+ _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, data_type, act_info);
}
protected:
@@ -82,7 +83,7 @@ protected:
}
TensorType compute_target(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info,
- DataType data_type)
+ DataType data_type, ActivationLayerInfo act_info)
{
// Create tensors
TensorType src = create_tensor<TensorType>(input_shape, data_type, 1);
@@ -92,7 +93,7 @@ protected:
// Create and configure function
FunctionType conv;
- conv.configure(&src, &weights, &bias, &dst, info);
+ conv.configure(&src, &weights, &bias, &dst, info, act_info);
ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(weights.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -122,7 +123,7 @@ protected:
}
SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info,
- DataType data_type)
+ DataType data_type, ActivationLayerInfo act_info)
{
// Create reference
SimpleTensor<T> src{ input_shape, data_type, 1 };
@@ -134,7 +135,8 @@ protected:
fill(weights, 1, -1.f, 1.f);
fill(bias, 2, -1.f, 1.f);
- return reference::convolution_layer<T>(src, weights, bias, output_shape, info);
+ return (act_info.enabled()) ? reference::activation_layer<T>(reference::convolution_layer<T>(src, weights, bias, output_shape, info), act_info) : reference::convolution_layer<T>(src, weights, bias,
+ output_shape, info);
}
TensorType _target{};