aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h
diff options
context:
space:
mode:
authorGiorgio Arena <giorgio.arena@arm.com>2018-04-04 17:44:26 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:50:48 +0000
commit7657224de2b697a8a92cccf26d98e53ccd7c1a03 (patch)
tree1dcfa4541dbaf753854a628c93991652158d373e /tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h
parente74b201ca1abca040ca9f30837fdf19aa610e7c4 (diff)
downloadComputeLibrary-7657224de2b697a8a92cccf26d98e53ccd7c1a03.tar.gz
COMPMID-926 Add depth multiplier support to NEON/CL/GLES depthwise convolution
Change-Id: I03f32c62350e5ea43e77bb15fc5a832d83719e3b Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/126657 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Michele DiGiorgio <michele.digiorgio@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Diffstat (limited to 'tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h')
-rw-r--r--tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h37
1 files changed, 25 insertions, 12 deletions
diff --git a/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h b/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h
index bb756f806e..b7bca8dbf3 100644
--- a/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h
@@ -26,6 +26,7 @@
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "tests/AssetsLibrary.h"
#include "tests/Globals.h"
#include "tests/IAccessor.h"
@@ -44,6 +45,8 @@ namespace test
{
namespace validation
{
+using namespace arm_compute::misc::shape_calculator;
+
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
class DepthwiseConvolutionLayerValidationGenericFixture : public framework::Fixture
{
@@ -52,12 +55,20 @@ public:
public:
template <typename...>
- void setup(TensorShape in_shape, TensorShape weights_shape, TensorShape out_shape, PadStrideInfo pad_stride_info, DataType data_type, QuantizationInfo quantization_info, DataLayout data_layout)
+ void setup(TensorShape in_shape, Size2D kernel_size, PadStrideInfo pad_stride_info, unsigned int depth_multiplier, DataType data_type, QuantizationInfo quantization_info, DataLayout data_layout)
{
- _quantization_info = quantization_info;
- _data_type = data_type;
+ _quantization_info = quantization_info;
+ _data_type = data_type;
+ const DataType bias_data_type = is_data_type_quantized_asymmetric(data_type) ? DataType::S32 : data_type;
+
+ TensorShape weights_shape(kernel_size.width, kernel_size.height);
+
+ const TensorInfo in_info(in_shape, 1, data_type);
+ const TensorInfo we_info(weights_shape, 1, data_type);
+ TensorShape out_shape = compute_depthwise_convolution_shape(in_info, we_info, pad_stride_info, depth_multiplier);
+
+ weights_shape.set(2, out_shape.z());
const TensorShape biases_shape(weights_shape[2]);
- const DataType bias_data_type = is_data_type_quantized_asymmetric(data_type) ? DataType::S32 : data_type;
if(data_layout == DataLayout::NHWC)
{
@@ -66,8 +77,8 @@ public:
permute(out_shape, PermutationVector(2U, 0U, 1U));
}
- _target = compute_target(in_shape, weights_shape, biases_shape, out_shape, pad_stride_info, data_type, bias_data_type, quantization_info, data_layout);
- _reference = compute_reference(in_shape, weights_shape, biases_shape, out_shape, pad_stride_info, data_type, bias_data_type, quantization_info, data_layout);
+ _target = compute_target(in_shape, weights_shape, biases_shape, out_shape, pad_stride_info, depth_multiplier, data_type, bias_data_type, quantization_info, data_layout);
+ _reference = compute_reference(in_shape, weights_shape, biases_shape, out_shape, pad_stride_info, depth_multiplier, data_type, bias_data_type, quantization_info, data_layout);
}
protected:
@@ -101,6 +112,7 @@ protected:
}
TensorType compute_target(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &biases_shape, const TensorShape &output_shape, PadStrideInfo &pad_stride_info,
+ unsigned int depth_multiplier,
const DataType data_type, const DataType bias_data_type, const QuantizationInfo quantization_info, const DataLayout data_layout)
{
// Create tensors
@@ -111,7 +123,7 @@ protected:
// Create Depthwise Convolution configure function
FunctionType dwc;
- dwc.configure(&src, &weights, &biases, &dst, pad_stride_info);
+ dwc.configure(&src, &weights, &biases, &dst, pad_stride_info, depth_multiplier);
ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(weights.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -141,6 +153,7 @@ protected:
}
SimpleTensor<T> compute_reference(const TensorShape &in_shape, const TensorShape &weights_shape, const TensorShape &biases_shape, const TensorShape &out_shape, const PadStrideInfo &pad_stride_info,
+ unsigned int depth_multiplier,
const DataType data_type, const DataType bias_data_type, const QuantizationInfo quantization_info, const DataLayout data_layout)
{
SimpleTensor<T> src{ in_shape, data_type, 1, 0, quantization_info, data_layout };
@@ -151,7 +164,7 @@ protected:
fill(weights, 1);
fill(biases, 2);
- return reference::depthwise_convolution(src, weights, biases, out_shape, pad_stride_info);
+ return reference::depthwise_convolution(src, weights, biases, out_shape, pad_stride_info, depth_multiplier);
}
TensorType _target{};
@@ -165,9 +178,9 @@ class DepthwiseConvolutionLayerValidationFixture : public DepthwiseConvolutionLa
{
public:
template <typename...>
- void setup(TensorShape in_shape, TensorShape weights_shape, TensorShape out_shape, PadStrideInfo pad_stride_info, DataType data_type, DataLayout data_layout)
+ void setup(TensorShape in_shape, Size2D kernel_size, PadStrideInfo pad_stride_info, unsigned int depth_multiplier, DataType data_type, DataLayout data_layout)
{
- DepthwiseConvolutionLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(in_shape, weights_shape, out_shape, pad_stride_info,
+ DepthwiseConvolutionLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(in_shape, kernel_size, pad_stride_info, depth_multiplier,
data_type, QuantizationInfo(), data_layout);
}
};
@@ -177,9 +190,9 @@ class DepthwiseConvolutionLayerValidationQuantizedFixture : public DepthwiseConv
{
public:
template <typename...>
- void setup(TensorShape in_shape, TensorShape weights_shape, TensorShape out_shape, PadStrideInfo pad_stride_info, DataType data_type, QuantizationInfo quantization_info, DataLayout data_layout)
+ void setup(TensorShape in_shape, Size2D kernel_size, PadStrideInfo pad_stride_info, unsigned int depth_multiplier, DataType data_type, QuantizationInfo quantization_info, DataLayout data_layout)
{
- DepthwiseConvolutionLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(in_shape, weights_shape, out_shape, pad_stride_info,
+ DepthwiseConvolutionLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(in_shape, kernel_size, pad_stride_info, depth_multiplier,
data_type, quantization_info, data_layout);
}
};