Diffstat (limited to 'tests/validation/fixtures/ConvolutionLayerFixture.h')
-rw-r--r--  tests/validation/fixtures/ConvolutionLayerFixture.h | 32
1 file changed, 19 insertions(+), 13 deletions(-)
diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h
index 3d073e3f79..1bcffed526 100644
--- a/tests/validation/fixtures/ConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/ConvolutionLayerFixture.h
@@ -33,6 +33,7 @@
#include "tests/framework/Asserts.h"
#include "tests/framework/Fixture.h"
#include "tests/validation/Helpers.h"
+#include "tests/validation/reference/ActivationLayer.h"
#include "tests/validation/reference/ConvolutionLayer.h"
#include "tests/validation/reference/Utils.h"
@@ -55,7 +56,7 @@ public:
public:
template <typename...>
void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, bool reshape_weights,
- DataType data_type, int fractional_bits, QuantizationInfo quantization_info)
+ DataType data_type, int fractional_bits, QuantizationInfo quantization_info, ActivationLayerInfo act_info)
{
_data_type = data_type;
_is_quantized = is_data_type_quantized_asymmetric(data_type);
@@ -63,8 +64,8 @@ public:
_fractional_bits = fractional_bits;
_quantization_info = quantization_info;
- _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, reshape_weights, dilation);
- _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, dilation);
+ _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, reshape_weights, dilation, act_info);
+ _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, dilation, act_info);
}
protected:
@@ -98,7 +99,7 @@ protected:
}
TensorType compute_target(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info,
- bool reshape_weights, const Size2D &dilation)
+ bool reshape_weights, const Size2D &dilation, const ActivationLayerInfo act_info)
{
const bool is_optimised = std::is_same<FunctionType, NEConvolutionLayer>::value && _data_type == DataType::F32;
@@ -140,7 +141,7 @@ protected:
// Create and configure function
FunctionType conv;
- conv.configure(&src, &weights, &bias, &dst, info, weights_info, dilation);
+ conv.configure(&src, &weights, &bias, &dst, info, weights_info, dilation, act_info);
ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(weights.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -210,7 +211,7 @@ protected:
}
SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info,
- const Size2D &dilation)
+ const Size2D &dilation, const ActivationLayerInfo act_info)
{
// Create reference
SimpleTensor<T> src{ input_shape, _data_type, 1, _fractional_bits, _quantization_info };
@@ -222,7 +223,9 @@ protected:
fill(weights, 1);
fill(bias, 2);
- return reference::convolution_layer<T>(src, weights, bias, output_shape, info, dilation);
+ return (act_info.enabled()) ? reference::activation_layer<T>(reference::convolution_layer<T>(src, weights, bias, output_shape, info, dilation),
+ act_info) :
+ reference::convolution_layer<T>(src, weights, bias, output_shape, info, dilation);
}
TensorType _target{};
@@ -283,10 +286,12 @@ class ConvolutionValidationFixture : public ConvolutionValidationGenericFixture<
{
public:
template <typename...>
- void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, bool reshape_weights, DataType data_type)
+ void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, bool reshape_weights, DataType data_type,
+ ActivationLayerInfo act_info)
{
ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights, data_type, 0,
- QuantizationInfo());
+ QuantizationInfo(),
+ act_info);
}
};
@@ -296,10 +301,11 @@ class ConvolutionValidationFixedPointFixture : public ConvolutionValidationGener
public:
template <typename...>
void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, bool reshape_weights, DataType data_type,
- int fractional_bits)
+ int fractional_bits,
+ ActivationLayerInfo act_info)
{
ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights, data_type, fractional_bits,
- QuantizationInfo());
+ QuantizationInfo(), act_info);
}
};
@@ -309,10 +315,10 @@ class ConvolutionValidationQuantizedFixture : public ConvolutionValidationGeneri
public:
template <typename...>
void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, bool reshape_weights, DataType data_type,
- QuantizationInfo quantization_info)
+ QuantizationInfo quantization_info, ActivationLayerInfo act_info)
{
ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights, data_type, 0,
- quantization_info);
+ quantization_info, act_info);
}
};
} // namespace validation
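
Below is a minimal standalone sketch of the fused-activation pattern this patch introduces into compute_reference(): run the convolution reference and apply the activation reference only when act_info is enabled. The real fixture uses reference::convolution_layer<T>() and reference::activation_layer<T>() from the Compute Library test suite; the conv()/relu() helpers and the simplified ActivationLayerInfo below are hypothetical stand-ins for illustration only.

// Hedged sketch: stand-in types and functions, not the Compute Library API.
#include <algorithm>
#include <vector>

struct ActivationLayerInfo
{
    bool enabled_{ false };
    bool enabled() const { return enabled_; }
};

// Hypothetical reference convolution: identity pass-through for brevity.
std::vector<float> conv(const std::vector<float> &src)
{
    return src;
}

// Hypothetical reference activation (ReLU).
std::vector<float> relu(std::vector<float> t)
{
    std::transform(t.begin(), t.end(), t.begin(), [](float v) { return std::max(v, 0.0f); });
    return t;
}

std::vector<float> compute_reference(const std::vector<float> &src, const ActivationLayerInfo &act_info)
{
    // Mirrors the change in the fixture: apply the activation reference only when
    // act_info is enabled, so the reference matches a target that fuses the
    // activation into the convolution layer.
    return act_info.enabled() ? relu(conv(src)) : conv(src);
}

int main()
{
    const std::vector<float> src{ -1.0f, 2.0f, -3.0f };
    const auto without_act = compute_reference(src, ActivationLayerInfo{ false });
    const auto with_act    = compute_reference(src, ActivationLayerInfo{ true });
    return (without_act[0] == -1.0f && with_act[0] == 0.0f) ? 0 : 1;
}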