aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/WinogradLayerFixture.h
diff options
context:
space:
mode:
authorIsabella Gottardi <isabella.gottardi@arm.com>2018-02-12 14:59:19 +0000
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:49:16 +0000
commit3f217ec4ff11e20fe686beb9a28d0bbd80a56cd6 (patch)
tree81db8baab925af5b416b66d0328be2eb49543824 /tests/validation/fixtures/WinogradLayerFixture.h
parentd9eb27597eabe5b7c17520f4f9b3f8a282d72573 (diff)
downloadComputeLibrary-3f217ec4ff11e20fe686beb9a28d0bbd80a56cd6.tar.gz
COMPMID-908 - Merge Activation layer with Convolution Layer (NEON. CL, GLES)
Change-Id: Iab06d0768ecf805b841e601185608aae88cf9166 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/120874 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'tests/validation/fixtures/WinogradLayerFixture.h')
-rw-r--r--tests/validation/fixtures/WinogradLayerFixture.h16
1 files changed, 9 insertions, 7 deletions
diff --git a/tests/validation/fixtures/WinogradLayerFixture.h b/tests/validation/fixtures/WinogradLayerFixture.h
index a86f24f35e..5210cbf720 100644
--- a/tests/validation/fixtures/WinogradLayerFixture.h
+++ b/tests/validation/fixtures/WinogradLayerFixture.h
@@ -33,6 +33,7 @@
#include "tests/framework/Asserts.h"
#include "tests/framework/Fixture.h"
#include "tests/validation/Helpers.h"
+#include "tests/validation/reference/ActivationLayer.h"
#include "tests/validation/reference/ConvolutionLayer.h"
#include "tests/validation/reference/Utils.h"
#include "tests/validation/reference/Winograd.h"
@@ -52,12 +53,12 @@ class WinogradConvolutionLayerValidationFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, DataType data_type)
+ void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, DataType data_type, ActivationLayerInfo act_info)
{
ARM_COMPUTE_UNUSED(dilation);
- _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, data_type);
- _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, data_type);
+ _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, data_type, act_info);
+ _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, data_type, act_info);
}
protected:
@@ -82,7 +83,7 @@ protected:
}
TensorType compute_target(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info,
- DataType data_type)
+ DataType data_type, ActivationLayerInfo act_info)
{
// Create tensors
TensorType src = create_tensor<TensorType>(input_shape, data_type, 1);
@@ -92,7 +93,7 @@ protected:
// Create and configure function
FunctionType conv;
- conv.configure(&src, &weights, &bias, &dst, info);
+ conv.configure(&src, &weights, &bias, &dst, info, act_info);
ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(weights.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -122,7 +123,7 @@ protected:
}
SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info,
- DataType data_type)
+ DataType data_type, ActivationLayerInfo act_info)
{
// Create reference
SimpleTensor<T> src{ input_shape, data_type, 1 };
@@ -134,7 +135,8 @@ protected:
fill(weights, 1, -1.f, 1.f);
fill(bias, 2, -1.f, 1.f);
- return reference::convolution_layer<T>(src, weights, bias, output_shape, info);
+ return (act_info.enabled()) ? reference::activation_layer<T>(reference::convolution_layer<T>(src, weights, bias, output_shape, info), act_info) : reference::convolution_layer<T>(src, weights, bias,
+ output_shape, info);
}
TensorType _target{};