author    Georgios Pinitas <georgios.pinitas@arm.com>   2018-01-22 16:29:17 +0000
committer Anthony Barbier <anthony.barbier@arm.com>     2018-11-02 16:45:00 +0000
commit    d05dce46a14a7b67f322328ecd95bf96bdd30bae (patch)
tree      6e001f539969a1a669241a72e78ff5a62998a984
parent    5d9d019b2c7ca3dc59bfbb44b3169ee5cd71dc79 (diff)
COMPMID-791: Generic Depthwise Convolution Layer NEON QASYMM8
Change-Id: I33cf54e68f6c097ac58b6f16c3f9a720978f09cd
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/117289
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
 arm_compute/core/NEON/kernels/NEDepthwiseIm2ColKernel.h          |  26
 arm_compute/core/NEON/kernels/NEDepthwiseVectorToTensorKernel.h  |  20
 arm_compute/core/NEON/kernels/NEDepthwiseWeightsReshapeKernel.h  |  12
 arm_compute/core/NEON/kernels/NEGEMMMatrixVectorMultiplyKernel.h |  29
 arm_compute/core/utils/misc/ShapeCalculator.h                    |   5
 arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h |  23
 src/core/NEON/kernels/NEDepthwiseIm2ColKernel.cpp                | 109
 src/core/NEON/kernels/NEDepthwiseVectorToTensorKernel.cpp        |  89
 src/core/NEON/kernels/NEDepthwiseWeightsReshapeKernel.cpp        | 110
 src/core/NEON/kernels/NEGEMMMatrixVectorMultiplyKernel.cpp       | 191
 src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp       |  70
 tests/datasets/DepthwiseConvolutionLayerDataset.h                |   4
 tests/validation/NEON/DepthwiseConvolutionLayer.cpp              |  10
 tests/validation/reference/DepthwiseConvolutionLayer.cpp         |  25
 14 files changed, 528 insertions(+), 195 deletions(-)
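For orientation, the generic path added by this commit decomposes depthwise convolution into an im2col step, a per-channel matrix-vector product (GEMV), a vector-to-tensor reshape and, for QASYMM8, a final requantization stage. The scalar sketch below illustrates the quantized arithmetic those stages implement; the helper names are hypothetical and are not part of the patch or of the library API.

// Scalar sketch of the QASYMM8 depthwise arithmetic (illustrative only).
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// One output value for a given channel: dot product of the im2col'd input patch
// with the reshaped weights row, accumulated in int32 with the asymmetric
// quantization offsets removed.
int32_t depthwise_dot(const std::vector<uint8_t> &patch, const std::vector<uint8_t> &weights,
                      int32_t input_offset, int32_t weights_offset)
{
    int32_t acc = 0;
    for(std::size_t i = 0; i < patch.size(); ++i)
    {
        acc += (static_cast<int32_t>(patch[i]) - input_offset) * (static_cast<int32_t>(weights[i]) - weights_offset);
    }
    return acc;
}

// Output stage: add the bias, rescale by in_scale * w_scale / out_scale and
// shift back into the uint8 output domain.
uint8_t requantize(int32_t acc, int32_t bias, float multiplier, int32_t output_offset)
{
    const int32_t rescaled = static_cast<int32_t>(std::lround((acc + bias) * multiplier));
    return static_cast<uint8_t>(std::max<int32_t>(0, std::min<int32_t>(255, rescaled + output_offset)));
}

The kernels below vectorise this with NEON, but the per-element arithmetic is the same.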
diff --git a/arm_compute/core/NEON/kernels/NEDepthwiseIm2ColKernel.h b/arm_compute/core/NEON/kernels/NEDepthwiseIm2ColKernel.h
index 8d59ba3248..ca10bfaab2 100644
--- a/arm_compute/core/NEON/kernels/NEDepthwiseIm2ColKernel.h
+++ b/arm_compute/core/NEON/kernels/NEDepthwiseIm2ColKernel.h
@@ -55,7 +55,7 @@ public:
/** Set the input and output of the kernel.
*
* @param[in] input The input tensor to convert. 3 lower dimensions represent a single input [width, height, IFM],
- * while every optional dimension from 4 and above represent a batch of inputs. Data types supported: F32
+ * while every optional dimension from 4 and above represent a batch of inputs. Data types supported: QASYMM8, F32
* @param[out] output The output tensor. First 3 lower dimensions represent a transform of each 3D input,
* while every dimension above 3 represents a batch. Data types supported: Same as @p input
* @param[in] kernel_dims The kernel dimensions (width and height).
@@ -68,11 +68,25 @@ public:
void run(const Window &window, const ThreadInfo &info) override;
private:
- const ITensor *_input;
- ITensor *_output;
- Size2D _kernel_dims;
- PadStrideInfo _conv_info;
- bool _has_bias;
+ /** Template function to run the im2col used for the depthwise convolution layer case
+ *
+ * @param[in] window Region on which to execute the kernel. (Must be a valid region of the window returned by window()).
+ */
+ template <typename T>
+ void run_generic(const Window &window);
+ /** Common signature for all the specialised depthwise im2col functions
+ *
+ * @param[in] window Region on which to execute the kernel.
+ */
+ using DepthwiseIm2ColFunctionPtr = void (NEDepthwiseIm2ColKernel::*)(const Window &window);
+
+private:
+ DepthwiseIm2ColFunctionPtr _func;
+ const ITensor *_input;
+ ITensor *_output;
+ Size2D _kernel_dims;
+ PadStrideInfo _conv_info;
+ bool _has_bias;
};
} // arm_compute
#endif /*__ARM_COMPUTE_NEDEPTHWISEIM2COLKERNEL_H__ */
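The private section above introduces the dispatch pattern this patch applies to all four kernels: a pointer-to-member function chosen once in configure() according to the data type and invoked from run(). A standalone sketch of that pattern, using a hypothetical MiniKernel class rather than the real kernel interface:

// Minimal sketch of configure-time dispatch through a member-function pointer.
#include <cstdint>

struct Window { }; // stand-in for arm_compute::Window

class MiniKernel
{
public:
    void configure(bool is_quantized)
    {
        // Pick the type-specialised implementation once, up front.
        _func = is_quantized ? &MiniKernel::run_impl<uint8_t> : &MiniKernel::run_impl<float>;
    }
    void run(const Window &window)
    {
        if(_func != nullptr)
        {
            (this->*_func)(window); // dispatch without a per-call switch
        }
    }

private:
    template <typename T>
    void run_impl(const Window &window)
    {
        static_cast<void>(window); // type-specialised work would go here
    }

    using KernelFunctionPtr = void (MiniKernel::*)(const Window &);
    KernelFunctionPtr _func{ nullptr };
};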
diff --git a/arm_compute/core/NEON/kernels/NEDepthwiseVectorToTensorKernel.h b/arm_compute/core/NEON/kernels/NEDepthwiseVectorToTensorKernel.h
index 19000905b0..458cbd7812 100644
--- a/arm_compute/core/NEON/kernels/NEDepthwiseVectorToTensorKernel.h
+++ b/arm_compute/core/NEON/kernels/NEDepthwiseVectorToTensorKernel.h
@@ -56,7 +56,7 @@ public:
NEDepthwiseVectorToTensorKernel &operator=(NEDepthwiseVectorToTensorKernel &&) = default;
/** Set the input and output of the kernel.
*
- * @param[in] input The input vector to convert. Data type supported: F32.
+ * @param[in] input The input vector to convert. Data type supported: QASYMM8/S32/F32.
* @param[out] output The output tensor. 3 lower dimensions represent a single input [width, height, IFM]. Data type supported: same as @p input.
* @param[in] conv_w The converted tensor's width.
* @param[in] conv_h The converted tensor's height.
@@ -67,8 +67,22 @@ public:
void run(const Window &window, const ThreadInfo &info) override;
private:
- const ITensor *_input;
- ITensor *_output;
+ /** Template function to run the vector to tensor reshape used for the depthwise convolution layer case
+ *
+ * @param[in] window Region on which to execute the kernel. (Must be a valid region of the window returned by window()).
+ */
+ template <typename T>
+ void vector_to_tensor(const Window &window);
+ /** Common signature for all the specialised depthwise vector to tensor functions
+ *
+ * @param[in] window Region on which to execute the kernel.
+ */
+ using DepthwiseVectorToTensorFunctionPtr = void (NEDepthwiseVectorToTensorKernel::*)(const Window &window);
+
+private:
+ DepthwiseVectorToTensorFunctionPtr _func;
+ const ITensor *_input;
+ ITensor *_output;
std::pair<size_t, size_t> _conv_dims;
};
} // arm_compute
diff --git a/arm_compute/core/NEON/kernels/NEDepthwiseWeightsReshapeKernel.h b/arm_compute/core/NEON/kernels/NEDepthwiseWeightsReshapeKernel.h
index 4d23b8bd65..d00e8a46ed 100644
--- a/arm_compute/core/NEON/kernels/NEDepthwiseWeightsReshapeKernel.h
+++ b/arm_compute/core/NEON/kernels/NEDepthwiseWeightsReshapeKernel.h
@@ -53,7 +53,7 @@ public:
NEDepthwiseWeightsReshapeKernel &operator=(NEDepthwiseWeightsReshapeKernel &&) = default;
/** Set the input and output of the kernel.
*
- * @param[in] input The input tensor to convert. 3 lower dimensions represent a single input [width, height, IFM]. Data type supported: F32.
+ * @param[in] input The input tensor to convert. 3 lower dimensions represent a single input [width, height, IFM]. Data type supported: QASYMM8, F32.
* @param[out] output The output tensor. Data type supported: same as @p input.
* @param[in] biases (Optional) The input biases to add. Shape [IFM]. Data type supported: same as @p input.
*/
@@ -63,9 +63,13 @@ public:
void run(const Window &window, const ThreadInfo &info) override;
private:
- const ITensor *_input;
- ITensor *_output;
- const ITensor *_biases;
+ using DepthwiseWeightsReshapeFunction = void(const ITensor *input, const ITensor *bias, ITensor *output, const Window &window);
+
+private:
+ DepthwiseWeightsReshapeFunction *_func;
+ const ITensor *_input;
+ ITensor *_output;
+ const ITensor *_biases;
};
} // arm_compute
#endif /*__ARM_COMPUTE_NEDEPTHWISEWEIGHTSRESHAPEKERNEL_H__ */
diff --git a/arm_compute/core/NEON/kernels/NEGEMMMatrixVectorMultiplyKernel.h b/arm_compute/core/NEON/kernels/NEGEMMMatrixVectorMultiplyKernel.h
index 5ea83901f4..95fe916a3c 100644
--- a/arm_compute/core/NEON/kernels/NEGEMMMatrixVectorMultiplyKernel.h
+++ b/arm_compute/core/NEON/kernels/NEGEMMMatrixVectorMultiplyKernel.h
@@ -49,7 +49,7 @@ public:
NEGEMMMatrixVectorMultiplyKernel &operator=(NEGEMMMatrixVectorMultiplyKernel &&) = default;
/** Initialise the kernel's input and output.
*
- * @param[in] input0 First Input tensor. Data types supported: F16/F32
+ * @param[in] input0 First Input tensor. Data types supported: QASYMM8/F32
* @param[in] input1 Second Input tensor. Data types supported: same as @p input.
* @param[out] output Output tensor which stores the interleaved matrix. Data type supported: same as @p input.
*/
@@ -57,11 +57,32 @@ public:
// Inherited methods overridden:
void run(const Window &window, const ThreadInfo &info) override;
+ BorderSize border_size() const override;
private:
- const ITensor *_input0;
- const ITensor *_input1;
- ITensor *_output;
+ /** Template function to run the matrix vector multiplication
+ *
+ * @tparam I0 Input 0 type
+ * @tparam I1 Input 1 type
+ * @tparam O Output type
+ *
+ * @param[in] window_in Input region. (Must be a valid region of the window returned by window()).
+ * @param[in] window_w Weights region. (Must be a valid region of the window returned by window()).
+ * @param[in] window_out Output region.(Must be a valid region of the window returned by window()).
+ */
+ template <typename I0, typename I1, typename O>
+ void matrix_vector_multiply(const Window &window_in, const Window &window_w, const Window &window_out);
+ /** Common signature for all the specialised matrix vector multiplication functions */
+ using GEMMMatrixVectorMultiplyFunctionPtr = void (NEGEMMMatrixVectorMultiplyKernel::*)(const Window &window_in,
+ const Window &window_w,
+ const Window &window_out);
+
+private:
+ GEMMMatrixVectorMultiplyFunctionPtr _func;
+ const ITensor *_input0;
+ const ITensor *_input1;
+ ITensor *_output;
+ BorderSize _border_size;
};
} // namespace arm_compute
#endif /*__ARM_COMPUTE_NEGEMMMATRIXVECTORMULTIPLYKERNEL_H_*/
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index 6ecfdf0323..26384651f1 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -116,8 +116,9 @@ inline TensorShape compute_depthwise_convolution_shape(const ITensorInfo &input,
unsigned int output_width = 0;
unsigned int output_height = 0;
- std::tie(output_width, output_height) = scaled_dimensions(input_shape.x(), input_shape.y(), weights_shape.x(),
- weights_shape.y(), conv_info);
+ std::tie(output_width, output_height) = scaled_dimensions(input_shape.x(), input_shape.y(),
+ weights_shape.x(), weights_shape.y(),
+ conv_info);
TensorShape output_shape{ input_shape };
output_shape.set(0, output_width);
diff --git a/arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h b/arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h
index 2100828f0d..e89ef88562 100644
--- a/arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h
@@ -54,7 +54,7 @@ public:
NEDepthwiseConvolutionLayer3x3();
/** Initialize the function's source, destination, kernels and border_size.
*
- * @param[in, out] input Source tensor. Data type supported: QASYMM8, F32. (Written to only for border filling).
+ * @param[in, out] input Source tensor. Data type supported: QASYMM8/F32. (Written to only for border filling).
* @param[in] weights Weights tensor. These are 3D tensors with shape [3, 3, IFM]. Data type supported: Same as @p input.
* @param[in] biases (Optional) Biases tensor. A 1D tensor with shape [IFM]. Must be nullptr if not needed.
* Data type supported: Same as @p input.
@@ -90,7 +90,7 @@ public:
NEDepthwiseConvolutionLayer();
/** Initialize the function's source, destination, weights and convolution information.
*
- * @param[in, out] input Source tensor. Data type supported: F32. (Written to only for border filling).
+ * @param[in, out] input Source tensor. Data type supported: QASYMM8/F32. (Written to only for border filling).
* @param[out] output Destination tensor. Data type supported: same as @p input.
* @param[in] weights Weights tensor. These are 3D tensors with shape [kernel_x, kernel_y, IFM]. Data type supported: Same as @p input.
* @param[in] biases (Optional) Biases tensor. A 1D tensor with shape [IFM]. Must be nullptr if not needed.
@@ -103,13 +103,18 @@ public:
void run() override;
private:
- NEDepthwiseIm2ColKernel _im2col_kernel;
- NEDepthwiseWeightsReshapeKernel _weights_reshape_kernel;
- NEGEMMMatrixVectorMultiplyKernel _v2mm_kernel;
- NEDepthwiseVectorToTensorKernel _vector_to_tensor_kernel;
- Tensor _input_reshaped;
- Tensor _weights_reshaped;
- Tensor _v2mm_output;
+ NEDepthwiseIm2ColKernel _im2col_kernel;
+ NEDepthwiseWeightsReshapeKernel _weights_reshape_kernel;
+ NEGEMMMatrixVectorMultiplyKernel _v2mm_kernel;
+ NEDepthwiseVectorToTensorKernel _vector_to_tensor_kernel;
+ NEDirectConvolutionLayerOutputStageKernel _output_stage_kernel;
+ NEFillBorderKernel _v2mm_input_fill_border;
+ NEFillBorderKernel _v2mm_weights_fill_border;
+ Tensor _input_reshaped;
+ Tensor _weights_reshaped;
+ Tensor _v2mm_output;
+ Tensor _output_reshaped;
+ bool _is_quantized;
};
}
#endif /* __ARM_COMPUTE_NEDEPTHWISECONVOLUTION_H__ */
\ No newline at end of file
diff --git a/src/core/NEON/kernels/NEDepthwiseIm2ColKernel.cpp b/src/core/NEON/kernels/NEDepthwiseIm2ColKernel.cpp
index 2ceb39d217..b924d9f8bd 100644
--- a/src/core/NEON/kernels/NEDepthwiseIm2ColKernel.cpp
+++ b/src/core/NEON/kernels/NEDepthwiseIm2ColKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -37,40 +37,9 @@
using namespace arm_compute;
-NEDepthwiseIm2ColKernel::NEDepthwiseIm2ColKernel()
- : _input(nullptr), _output(nullptr), _kernel_dims(), _conv_info(), _has_bias()
-{
-}
-
-void NEDepthwiseIm2ColKernel::configure(const ITensor *input, ITensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias)
-{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
- ARM_COMPUTE_ERROR_ON(input->info()->dimension(2) != output->info()->dimension(2));
- ARM_COMPUTE_ERROR_ON(output->info()->dimension(0) != (kernel_dims.width * kernel_dims.height + ((has_bias) ? 1 : 0)));
-
- _input = input;
- _output = output;
- _kernel_dims = kernel_dims;
- _conv_info = conv_info;
- _has_bias = has_bias;
-
- // Configure kernel window
- Window win = calculate_max_window(*input->info(), Steps());
-
- // The NEDepthwiseIm2ColKernel doesn't need padding so update_window_and_padding() can be skipped
- output->info()->set_valid_region(ValidRegion(Coordinates(), output->info()->tensor_shape()));
-
- INEKernel::configure(win);
-}
-
-void NEDepthwiseIm2ColKernel::run(const Window &window, const ThreadInfo &info)
+template <typename T>
+void NEDepthwiseIm2ColKernel::run_generic(const Window &window)
{
- ARM_COMPUTE_UNUSED(info);
- ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
-
- //const int kernel_depth = _input->info()->dimension(2);
const int input_w = _input->info()->dimension(0);
const int input_h = _input->info()->dimension(1);
const int input_stride_x = _input->info()->strides_in_bytes().x();
@@ -101,6 +70,13 @@ void NEDepthwiseIm2ColKernel::run(const Window &window, const ThreadInfo &info)
const int full_length = input_w + pad_left + pad_right;
const int max_initial_x = stride_x * (((full_length - _kernel_dims.width) / stride_x) + 1);
+ // Define pad value
+ auto zero = static_cast<T>(0);
+ if(std::is_same<T, uint8_t>::value)
+ {
+ zero = _input->info()->quantization_info().offset;
+ }
+
execute_window_loop(window_out, [&](const Coordinates & id)
{
const int src_pixel_linear = id.y() * stride_x;
@@ -110,7 +86,7 @@ void NEDepthwiseIm2ColKernel::run(const Window &window, const ThreadInfo &info)
// Get pointers
const uint8_t *const input_ptr = in.ptr() + id.z() * input_stride_z;
- auto output_ptr = reinterpret_cast<float *>(out.ptr());
+ auto output_ptr = reinterpret_cast<T *>(out.ptr());
const int height = src_y + _kernel_dims.height;
const int width = src_x + _kernel_dims.width;
@@ -120,19 +96,76 @@ void NEDepthwiseIm2ColKernel::run(const Window &window, const ThreadInfo &info)
{
if(x < 0 || x >= input_w || y < 0 || y >= input_h)
{
- *output_ptr = 0;
+ *output_ptr = zero;
}
else
{
- *output_ptr = *(reinterpret_cast<const float *>(input_ptr + x * input_stride_x + y * input_stride_y));
+ *output_ptr = *(reinterpret_cast<const T *>(input_ptr + x * input_stride_x + y * input_stride_y));
}
}
}
if(_has_bias)
{
- *output_ptr = static_cast<float>(1);
+ *output_ptr = static_cast<T>(1);
}
},
in, out);
}
+
+NEDepthwiseIm2ColKernel::NEDepthwiseIm2ColKernel()
+ : _func(nullptr), _input(nullptr), _output(nullptr), _kernel_dims(), _conv_info(), _has_bias()
+{
+}
+
+void NEDepthwiseIm2ColKernel::configure(const ITensor *input, ITensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias)
+{
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+ ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
+ ARM_COMPUTE_ERROR_ON(is_data_type_quantized_asymmetric(input->info()->data_type()) && has_bias);
+ ARM_COMPUTE_ERROR_ON(input->info()->dimension(2) != output->info()->dimension(2));
+ ARM_COMPUTE_ERROR_ON(output->info()->dimension(0) != (kernel_dims.width * kernel_dims.height + ((has_bias) ? 1 : 0)));
+
+ _input = input;
+ _output = output;
+ _kernel_dims = kernel_dims;
+ _conv_info = conv_info;
+ _has_bias = has_bias;
+
+ // Configure kernel window
+ Window win = calculate_max_window(*input->info(), Steps());
+
+ // Set appropriate function to run
+ switch(input->info()->data_type())
+ {
+ case DataType::QASYMM8:
+ _func = &NEDepthwiseIm2ColKernel::run_generic<uint8_t>;
+ break;
+ case DataType::F16:
+ _func = &NEDepthwiseIm2ColKernel::run_generic<half>;
+ break;
+ case DataType::F32:
+ _func = &NEDepthwiseIm2ColKernel::run_generic<float>;
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Unsupported data type");
+ }
+
+ // The NEDepthwiseIm2ColKernel doesn't need padding so update_window_and_padding() can be skipped
+ output->info()->set_valid_region(ValidRegion(Coordinates(), output->info()->tensor_shape()));
+
+ INEKernel::configure(win);
+}
+
+void NEDepthwiseIm2ColKernel::run(const Window &window, const ThreadInfo &info)
+{
+ ARM_COMPUTE_UNUSED(info);
+ ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
+ ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
+
+ if(_func != nullptr)
+ {
+ (this->*_func)(window);
+ }
+}
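One detail in run_generic<T> above is easy to miss: for QASYMM8 the value written into out-of-bounds (padded) positions is the tensor's quantization offset rather than a literal 0, because in asymmetric quantization the real value 0.0f is encoded as the offset. A small sketch of that choice, with a hypothetical helper that is not part of the patch:

// Sketch: selecting the im2col padding value per element type.
#include <cstdint>
#include <type_traits>

template <typename T>
T pad_value(int32_t qasymm8_offset)
{
    // uint8_t means QASYMM8 here: real 0.0f quantizes to the offset.
    // For float/half a plain zero is already the neutral value.
    return std::is_same<T, uint8_t>::value ? static_cast<T>(qasymm8_offset) : static_cast<T>(0);
}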
diff --git a/src/core/NEON/kernels/NEDepthwiseVectorToTensorKernel.cpp b/src/core/NEON/kernels/NEDepthwiseVectorToTensorKernel.cpp
index 9b36df3c39..8960d8a8af 100644
--- a/src/core/NEON/kernels/NEDepthwiseVectorToTensorKernel.cpp
+++ b/src/core/NEON/kernels/NEDepthwiseVectorToTensorKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -37,14 +37,46 @@
using namespace arm_compute;
+template <typename T>
+void NEDepthwiseVectorToTensorKernel::vector_to_tensor(const Window &window)
+{
+ // const int input_w = _input->info()->dimension(0);
+ const int output_stride_x = _output->info()->strides_in_bytes().x();
+ const int output_stride_y = _output->info()->strides_in_bytes().y();
+ const int output_stride_z = _output->info()->strides_in_bytes().z();
+
+ // Setup output window
+ Window window_out(window);
+ window_out.set(Window::DimX, Window::Dimension(0, 0, 0));
+ window_out.set(Window::DimY, Window::Dimension(0, 0, 0));
+ window_out.set(Window::DimZ, Window::Dimension(0, 0, 0));
+
+ Iterator in(_input, window);
+ Iterator out(_output, window_out);
+
+ const int patch_size = _conv_dims.first * _conv_dims.second;
+
+ execute_window_loop(window, [&](const Coordinates & id)
+ {
+ const int z = id.x() / patch_size;
+ const int index2D = id.x() - z * patch_size;
+
+ auto input_ptr = reinterpret_cast<T *>(in.ptr());
+ auto output_ptr = reinterpret_cast<T *>(out.ptr() + index2D % _conv_dims.first * output_stride_x + index2D / _conv_dims.first * output_stride_y + z * output_stride_z);
+
+ *output_ptr = *input_ptr;
+ },
+ in, out);
+}
+
NEDepthwiseVectorToTensorKernel::NEDepthwiseVectorToTensorKernel()
- : _input(nullptr), _output(nullptr), _conv_dims()
+ : _func(nullptr), _input(nullptr), _output(nullptr), _conv_dims()
{
}
void NEDepthwiseVectorToTensorKernel::configure(const ITensor *input, ITensor *output, size_t conv_w, size_t conv_h)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::S32, DataType::F16, DataType::F32);
ARM_COMPUTE_ERROR_ON_NULLPTR(output);
TensorShape output_shape = input->info()->tensor_shape();
@@ -53,7 +85,7 @@ void NEDepthwiseVectorToTensorKernel::configure(const ITensor *input, ITensor *o
output_shape.set(2, input->info()->tensor_shape()[0] / (conv_w * conv_h));
// Output auto inizialitation if not yet initialized
- auto_init_if_empty(*output->info(), output_shape, 1, input->info()->data_type(), input->info()->fixed_point_position());
+ auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(output_shape));
ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
@@ -63,6 +95,25 @@ void NEDepthwiseVectorToTensorKernel::configure(const ITensor *input, ITensor *o
_output = output;
_conv_dims = std::pair<size_t, size_t>(conv_w, conv_h);
+ // Set appropriate function to run
+ switch(input->info()->data_type())
+ {
+ case DataType::QASYMM8:
+ _func = &NEDepthwiseVectorToTensorKernel::vector_to_tensor<uint8_t>;
+ break;
+ case DataType::S32:
+ _func = &NEDepthwiseVectorToTensorKernel::vector_to_tensor<int32_t>;
+ break;
+ case DataType::F16:
+ _func = &NEDepthwiseVectorToTensorKernel::vector_to_tensor<half>;
+ break;
+ case DataType::F32:
+ _func = &NEDepthwiseVectorToTensorKernel::vector_to_tensor<float>;
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Unsupported data type");
+ }
+
// Configure kernel window
Window win = calculate_max_window(*input->info(), Steps());
// The NEDepthwisevectorToTensorKernel doesn't need padding so update_window_and_padding() can be skipped
@@ -75,32 +126,10 @@ void NEDepthwiseVectorToTensorKernel::run(const Window &window, const ThreadInfo
{
ARM_COMPUTE_UNUSED(info);
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
+ ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
- // const int input_w = _input->info()->dimension(0);
- const int output_stride_x = _output->info()->strides_in_bytes().x();
- const int output_stride_y = _output->info()->strides_in_bytes().y();
- const int output_stride_z = _output->info()->strides_in_bytes().z();
-
- // Setup output window
- Window window_out(window);
- window_out.set(Window::DimX, Window::Dimension(0, 0, 0));
- window_out.set(Window::DimY, Window::Dimension(0, 0, 0));
- window_out.set(Window::DimZ, Window::Dimension(0, 0, 0));
-
- Iterator in(_input, window);
- Iterator out(_output, window_out);
-
- const int patch_size = _conv_dims.first * _conv_dims.second;
-
- execute_window_loop(window, [&](const Coordinates & id)
+ if(_func != nullptr)
{
- const int z = id.x() / patch_size;
- const int index2D = id.x() - z * patch_size;
-
- auto input_ptr = reinterpret_cast<float *>(in.ptr());
- auto output_ptr = reinterpret_cast<float *>(out.ptr() + index2D % _conv_dims.first * output_stride_x + index2D / _conv_dims.first * output_stride_y + z * output_stride_z);
-
- *output_ptr = *input_ptr;
- },
- in, out);
+ (this->*_func)(window);
+ }
}
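For reference, the index arithmetic in vector_to_tensor<T> above maps element x of the flat GEMV output back onto a [conv_w, conv_h, channels] volume: each channel owns a contiguous patch of conv_w * conv_h results. A small sketch of that mapping with a hypothetical helper (not part of the patch):

// Sketch: flat GEMV-output index -> (x, y, z) coordinate in the output tensor.
#include <cstddef>

struct Coord3D
{
    std::size_t x, y, z;
};

Coord3D vector_to_tensor_coord(std::size_t flat_idx, std::size_t conv_w, std::size_t conv_h)
{
    const std::size_t patch_size = conv_w * conv_h;            // results per channel
    const std::size_t z          = flat_idx / patch_size;      // channel
    const std::size_t index2D    = flat_idx - z * patch_size;  // position inside the channel's patch
    return { index2D % conv_w, index2D / conv_w, z };
}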
diff --git a/src/core/NEON/kernels/NEDepthwiseWeightsReshapeKernel.cpp b/src/core/NEON/kernels/NEDepthwiseWeightsReshapeKernel.cpp
index 6585fdb8b8..36b17bfc4c 100644
--- a/src/core/NEON/kernels/NEDepthwiseWeightsReshapeKernel.cpp
+++ b/src/core/NEON/kernels/NEDepthwiseWeightsReshapeKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -37,16 +37,59 @@
using namespace arm_compute;
+namespace
+{
+template <typename T>
+void weights_reshape(const ITensor *input, const ITensor *bias, ITensor *output, const Window &window)
+{
+ const int input_w = input->info()->dimension(0);
+ const int output_stride_x = output->info()->strides_in_bytes().x();
+ const int output_stride_y = output->info()->strides_in_bytes().y();
+
+ Window window_in(window);
+ // The first three dimensions of the input are increased by the inner loops
+ window_in.set(Window::DimX, Window::Dimension(0, input->info()->dimension(0), input->info()->dimension(0)));
+ window_in.set(Window::DimY, Window::Dimension(0, input->info()->dimension(1), 1));
+ window_in.set(Window::DimZ, Window::Dimension(0, input->info()->dimension(2), 1));
+
+ // Setup output window
+ Window window_out;
+ window_out.set(Window::DimX, Window::Dimension(0, 0, 0));
+ window_out.set(Window::DimY, Window::Dimension(0, 0, 0));
+
+ Iterator in(input, window_in);
+ Iterator out(output, window_out);
+
+ execute_window_loop(window_in, [&](const Coordinates & id)
+ {
+ auto input_ptr = reinterpret_cast<T *>(in.ptr());
+ auto output_ptr = reinterpret_cast<T *>(out.ptr() + id.y() * input_w * output_stride_x + id.z() * output_stride_y);
+
+ for(int i = 0; i < input_w; ++i, ++input_ptr)
+ {
+ *(output_ptr + i) = *input_ptr;
+ }
+
+ if(bias != nullptr)
+ {
+ *(output_ptr + input_w) = *(reinterpret_cast<T *>(bias->ptr_to_element(Coordinates(id.z()))));
+ }
+ },
+ in, out);
+}
+} // namespace
+
NEDepthwiseWeightsReshapeKernel::NEDepthwiseWeightsReshapeKernel()
- : _input(nullptr), _output(nullptr), _biases(nullptr)
+ : _func(nullptr), _input(nullptr), _output(nullptr), _biases(nullptr)
{
}
void NEDepthwiseWeightsReshapeKernel::configure(const ITensor *input, ITensor *output, const ITensor *biases)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
+ ARM_COMPUTE_ERROR_ON(is_data_type_quantized_asymmetric(input->info()->data_type()) && (biases != nullptr));
ARM_COMPUTE_ERROR_ON(input->info()->dimension(2) != output->info()->dimension(1));
ARM_COMPUTE_ERROR_ON(output->info()->dimension(0) != (input->info()->dimension(0) * input->info()->dimension(1) + ((biases != nullptr) ? 1 : 0)));
@@ -62,6 +105,30 @@ void NEDepthwiseWeightsReshapeKernel::configure(const ITensor *input, ITensor *o
_output = output;
_biases = biases;
+ switch(_input->info()->element_size())
+ {
+ case 4:
+ {
+ _func = &weights_reshape<uint32_t>;
+ break;
+ }
+ case 2:
+ {
+ _func = &weights_reshape<uint16_t>;
+ break;
+ }
+ case 1:
+ {
+ _func = &weights_reshape<uint8_t>;
+ break;
+ }
+ default:
+ {
+ ARM_COMPUTE_ERROR_ON("Element size not supported");
+ break;
+ }
+ }
+
// Configure kernel window
Window win = calculate_max_window(*input->info(), Steps());
// The NEDepthwiseWeightsReshapeKernel doesn't need padding so update_window_and_padding() can be skipped
@@ -74,39 +141,10 @@ void NEDepthwiseWeightsReshapeKernel::run(const Window &window, const ThreadInfo
{
ARM_COMPUTE_UNUSED(info);
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
+ ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
- const int input_w = _input->info()->dimension(0);
- const int output_stride_x = _output->info()->strides_in_bytes().x();
- const int output_stride_y = _output->info()->strides_in_bytes().y();
-
- Window window_in(window);
- // The first three dimensions of the input are increased by the inner loops
- window_in.set(Window::DimX, Window::Dimension(0, _input->info()->dimension(0), _input->info()->dimension(0)));
- window_in.set(Window::DimY, Window::Dimension(0, _input->info()->dimension(1), 1));
- window_in.set(Window::DimZ, Window::Dimension(0, _input->info()->dimension(2), 1));
-
- // Setup output window
- Window window_out;
- window_out.set(Window::DimX, Window::Dimension(0, 0, 0));
- window_out.set(Window::DimY, Window::Dimension(0, 0, 0));
-
- Iterator in(_input, window_in);
- Iterator out(_output, window_out);
-
- execute_window_loop(window_in, [&](const Coordinates & id)
+ if(_func != nullptr)
{
- auto input_ptr = reinterpret_cast<float *>(in.ptr());
- auto output_ptr = reinterpret_cast<float *>(out.ptr() + id.y() * input_w * output_stride_x + id.z() * output_stride_y);
-
- for(int i = 0; i < input_w; ++i, ++input_ptr)
- {
- *(output_ptr + i) = *input_ptr;
- }
-
- if(_biases != nullptr)
- {
- *(output_ptr + input_w) = *(reinterpret_cast<float *>(_biases->ptr_to_element(Coordinates(id.z()))));
- }
- },
- in, out);
+ (*_func)(_input, _biases, _output, window);
+ }
}
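Since the reshape above only moves bytes and performs no arithmetic, configure() dispatches on element size rather than on data type, so the uint8_t/uint16_t/uint32_t instantiations cover QASYMM8, F16 and F32 alike. A standalone sketch of the same selection idea (hypothetical helpers, not the library API):

// Sketch: choosing a bit-exact copy routine by element size.
#include <cstddef>
#include <cstdint>
#include <cstring>

template <typename T>
void copy_row(const void *src, void *dst, std::size_t count)
{
    std::memcpy(dst, src, count * sizeof(T)); // pure copy, values are never interpreted
}

using CopyFn = void (*)(const void *, void *, std::size_t);

CopyFn select_copy(std::size_t element_size)
{
    switch(element_size)
    {
        case 4:
            return &copy_row<uint32_t>; // F32
        case 2:
            return &copy_row<uint16_t>; // F16
        case 1:
            return &copy_row<uint8_t>;  // QASYMM8
        default:
            return nullptr; // unsupported element size
    }
}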
diff --git a/src/core/NEON/kernels/NEGEMMMatrixVectorMultiplyKernel.cpp b/src/core/NEON/kernels/NEGEMMMatrixVectorMultiplyKernel.cpp
index fe79df2528..c1e975e77e 100644
--- a/src/core/NEON/kernels/NEGEMMMatrixVectorMultiplyKernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMMatrixVectorMultiplyKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -39,24 +39,170 @@
using namespace arm_compute;
+template <typename I0, typename I1, typename O>
+void NEGEMMMatrixVectorMultiplyKernel::matrix_vector_multiply(const Window &window_in, const Window &window_w, const Window &window_out)
+{
+ ARM_COMPUTE_ERROR("Unsupported data types");
+ ARM_COMPUTE_UNUSED(window_in);
+ ARM_COMPUTE_UNUSED(window_w);
+ ARM_COMPUTE_UNUSED(window_out);
+}
+
+namespace arm_compute
+{
+template <>
+void NEGEMMMatrixVectorMultiplyKernel::matrix_vector_multiply<float, float, float>(const Window &window_in,
+ const Window &window_w,
+ const Window &window_out)
+{
+ Iterator in(_input0, window_in);
+ Iterator in2(_input1, window_w);
+ Iterator out(_output, window_out);
+
+ const int input_w = _input0->info()->dimension(0);
+ const int input_h = _input0->info()->dimension(1);
+ const int input_stride_x = _input0->info()->strides_in_bytes().x();
+ const int weights_stride_x = _input1->info()->strides_in_bytes().x();
+ const int weights_stride_y = _input1->info()->strides_in_bytes().y();
+ const int output_stride_x = _output->info()->strides_in_bytes().x();
+
+ execute_window_loop(window_in, [&](const Coordinates & id)
+ {
+ // Get pointers
+ const uint8_t *const input_ptr = in.ptr();
+ const uint8_t *const weights_ptr = in2.ptr() + id.z() * weights_stride_y;
+ auto output_ptr = reinterpret_cast<float *>(out.ptr() + (id.y() + id.z() * input_h) * output_stride_x);
+
+ float32x4_t row_dot = vdupq_n_f32(0.f);
+ for(int i = 0; i < input_w; i += 4)
+ {
+ const auto input = vld1q_f32(reinterpret_cast<const float *>(input_ptr + i * input_stride_x));
+ const auto weights = vld1q_f32(reinterpret_cast<const float *>(weights_ptr + i * weights_stride_x));
+ row_dot = vaddq_f32(row_dot, vmulq_f32(input, weights));
+ }
+
+ auto temp = vadd_f32(vget_high_f32(row_dot), vget_low_f32(row_dot));
+ temp = vpadd_f32(temp, temp);
+
+ *output_ptr = vget_lane_f32(temp, 0);
+ },
+ in, in2, out);
+}
+
+template <>
+void NEGEMMMatrixVectorMultiplyKernel::matrix_vector_multiply<uint8_t, uint8_t, int32_t>(const Window &window_in,
+ const Window &window_w,
+ const Window &window_out)
+{
+ Iterator in(_input0, window_in);
+ Iterator in2(_input1, window_w);
+ Iterator out(_output, window_out);
+
+ const int input_offset = -_input0->info()->quantization_info().offset;
+ const int weights_offset = -_input1->info()->quantization_info().offset;
+
+ const int input_w = _input0->info()->dimension(0);
+ const int input_h = _input0->info()->dimension(1);
+ const int input_stride_x = _input0->info()->strides_in_bytes().x();
+ const int weights_stride_x = _input1->info()->strides_in_bytes().x();
+ const int weights_stride_y = _input1->info()->strides_in_bytes().y();
+ const int output_stride_x = _output->info()->strides_in_bytes().x();
+ const int read_step = 16 / _input0->info()->element_size();
+
+ const int32x4_t v_input_offset = vdupq_n_s32(input_offset);
+ const int32x4_t v_weights_offset = vdupq_n_s32(weights_offset);
+
+ execute_window_loop(window_in, [&](const Coordinates & id)
+ {
+ // Get pointers
+ const uint8_t *const input_ptr = in.ptr();
+ const uint8_t *const weights_ptr = in2.ptr() + id.z() * weights_stride_y;
+ auto output_ptr = reinterpret_cast<int32_t *>(out.ptr() + (id.y() + id.z() * input_h) * output_stride_x);
+
+ int32x4_t row_dot = vdupq_n_s32(0);
+ for(int i = 0; i < input_w; i += read_step)
+ {
+ // Read values
+ const auto input = vld1q_u8(reinterpret_cast<const uint8_t *>(input_ptr + i * input_stride_x));
+ const auto weights = vld1q_u8(reinterpret_cast<const uint8_t *>(weights_ptr + i * weights_stride_x));
+
+ // Add offsets
+ const int32x4x4_t input_s32 =
+ {
+ {
+ vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vget_low_u8(input))))),
+ vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(vget_low_u8(input))))),
+ vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vget_high_u8(input))))),
+ vaddw_s16(v_input_offset, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(vget_high_u8(input)))))
+ }
+ };
+ const int32x4x4_t weights_s32 =
+ {
+ {
+ vaddw_s16(v_weights_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vget_low_u8(weights))))),
+ vaddw_s16(v_weights_offset, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(vget_low_u8(weights))))),
+ vaddw_s16(v_weights_offset, vreinterpret_s16_u16(vget_low_u16(vmovl_u8(vget_high_u8(weights))))),
+ vaddw_s16(v_weights_offset, vreinterpret_s16_u16(vget_high_u16(vmovl_u8(vget_high_u8(weights)))))
+ }
+ };
+
+ // Dot
+ row_dot = vaddq_s32(row_dot, vmulq_s32(input_s32.val[0], weights_s32.val[0]));
+ row_dot = vaddq_s32(row_dot, vmulq_s32(input_s32.val[1], weights_s32.val[1]));
+ row_dot = vaddq_s32(row_dot, vmulq_s32(input_s32.val[2], weights_s32.val[2]));
+ row_dot = vaddq_s32(row_dot, vmulq_s32(input_s32.val[3], weights_s32.val[3]));
+ }
+
+ // Reduction
+ auto temp = vadd_s32(vget_high_s32(row_dot), vget_low_s32(row_dot));
+ temp = vpadd_s32(temp, temp);
+
+ *output_ptr = vget_lane_s32(temp, 0);
+ },
+ in, in2, out);
+}
+} //namespace arm_compute
+
NEGEMMMatrixVectorMultiplyKernel::NEGEMMMatrixVectorMultiplyKernel()
- : _input0(nullptr), _input1(nullptr), _output(nullptr)
+ : _func(nullptr), _input0(nullptr), _input1(nullptr), _output(nullptr), _border_size(0)
+{
+}
+
+BorderSize NEGEMMMatrixVectorMultiplyKernel::border_size() const
{
+ return _border_size;
}
void NEGEMMMatrixVectorMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1, output);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::QASYMM8, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1);
ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input0, input1, output);
+ ARM_COMPUTE_ERROR_ON(is_data_type_quantized_asymmetric(input0->info()->data_type()) && (output->info()->data_type() != DataType::S32));
ARM_COMPUTE_ERROR_ON(input0->info()->dimension(2) != input1->info()->dimension(1));
_input0 = input0;
_input1 = input1;
_output = output;
+ // Set appropriate function to run
+ switch(input0->info()->data_type())
+ {
+ case DataType::QASYMM8:
+ _func = &NEGEMMMatrixVectorMultiplyKernel::matrix_vector_multiply<uint8_t, uint8_t, int32_t>;
+ break;
+ case DataType::F32:
+ _func = &NEGEMMMatrixVectorMultiplyKernel::matrix_vector_multiply<float, float, float>;
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Unsupported data type");
+ }
+
// Configure kernel window
- const unsigned int num_elems_read_per_iteration = 4;
+ const unsigned int num_elems_read_per_iteration = 16 / _input0->info()->element_size();
+
+ const unsigned int border_x = ceil_to_multiple(input0->info()->dimension(0), num_elems_read_per_iteration) - input0->info()->dimension(0);
+ _border_size = BorderSize(0, border_x);
Window win = calculate_max_window(*input0->info(), Steps(num_elems_read_per_iteration));
@@ -75,6 +221,7 @@ void NEGEMMMatrixVectorMultiplyKernel::run(const Window &window, const ThreadInf
{
ARM_COMPUTE_UNUSED(info);
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
+ ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
Window window_slice = window.first_slice_window_3D();
@@ -96,36 +243,8 @@ void NEGEMMMatrixVectorMultiplyKernel::run(const Window &window, const ThreadInf
window_out.set(Window::DimY, Window::Dimension(0, 0, 0));
window_out.set(Window::DimZ, Window::Dimension(0, 0, 0));
- Iterator in(_input0, window_in);
- Iterator in2(_input1, window_weights);
- Iterator out(_output, window_out);
-
- const int input_w = _input0->info()->dimension(0);
- const int input_h = _input0->info()->dimension(1);
- const int input_stride_x = _input0->info()->strides_in_bytes().x();
- const int weights_stride_x = _input1->info()->strides_in_bytes().x();
- const int weights_stride_y = _input1->info()->strides_in_bytes().y();
- const int output_stride_x = _output->info()->strides_in_bytes().x();
-
- execute_window_loop(window_in, [&](const Coordinates & id)
+ if(_func != nullptr)
{
- // Get pointers
- const uint8_t *const input_ptr = in.ptr();
- const uint8_t *const weights_ptr = in2.ptr() + id.z() * weights_stride_y;
- auto output_ptr = reinterpret_cast<float *>(out.ptr() + (id.y() + id.z() * input_h) * output_stride_x);
-
- float32x4_t row_dot = vdupq_n_f32(0.f);
- for(int i = 0; i < input_w; i += 4)
- {
- const auto input = vld1q_f32(reinterpret_cast<const float *>(input_ptr + i * input_stride_x));
- const auto weights = vld1q_f32(reinterpret_cast<const float *>(weights_ptr + i * weights_stride_x));
- row_dot = vaddq_f32(row_dot, vmulq_f32(input, weights));
- }
-
- auto temp = vadd_f32(vget_high_f32(row_dot), vget_low_f32(row_dot));
- temp = vpadd_f32(temp, temp);
-
- *output_ptr = vget_lane_f32(temp, 0);
- },
- in, in2, out);
+ (this->*_func)(window_in, window_weights, window_out);
+ }
}
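The NEON QASYMM8 specialisation above widens both operands from u8 to s32, adds the negated quantization offsets, multiplies, and accumulates four lanes at a time; a scalar equivalent is sketched below (illustrative only, not part of the patch). The 16-byte read step is also why the kernel now reports a border_size(): the reshaped rows are padded up to a multiple of the read step and, in the quantized case, filled with the quantization offset, so every extra lane contributes (offset + (-offset)) = 0 to the dot product.

// Scalar equivalent of the QASYMM8 row dot product in the kernel above.
#include <cstddef>
#include <cstdint>

int32_t row_dot_qasymm8(const uint8_t *input, const uint8_t *weights, std::size_t len,
                        int32_t input_offset, int32_t weights_offset)
{
    // As in the kernel, input_offset and weights_offset are the *negated*
    // quantization offsets, so padded lanes holding the raw offset cancel to 0.
    int32_t acc = 0;
    for(std::size_t i = 0; i < len; ++i)
    {
        acc += (static_cast<int32_t>(input[i]) + input_offset) * (static_cast<int32_t>(weights[i]) + weights_offset);
    }
    return acc;
}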
diff --git a/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp b/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp
index 2d08b45210..1af0b18933 100644
--- a/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp
@@ -26,11 +26,13 @@
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/PixelValue.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "support/ToolchainSupport.h"
using namespace arm_compute;
+using namespace arm_compute::misc;
NEDepthwiseConvolutionLayer3x3::NEDepthwiseConvolutionLayer3x3()
: _kernel(), _output_stage_kernel(), _border_handler(), _accumulator(), _has_bias(false), _is_quantized(false)
@@ -90,13 +92,14 @@ void NEDepthwiseConvolutionLayer3x3::run()
}
NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayer()
- : _im2col_kernel(), _weights_reshape_kernel(), _v2mm_kernel(), _vector_to_tensor_kernel(), _input_reshaped(), _weights_reshaped(), _v2mm_output()
+ : _im2col_kernel(), _weights_reshape_kernel(), _v2mm_kernel(), _vector_to_tensor_kernel(), _output_stage_kernel(), _v2mm_input_fill_border(), _v2mm_weights_fill_border(), _input_reshaped(),
+ _weights_reshaped(), _v2mm_output(), _output_reshaped(), _is_quantized(false)
{
}
void NEDepthwiseConvolutionLayer::configure(ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F32);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
ARM_COMPUTE_ERROR_ON(input->info()->dimension(2) != weights->info()->dimension(2));
@@ -104,14 +107,20 @@ void NEDepthwiseConvolutionLayer::configure(ITensor *input, const ITensor *weigh
const size_t weights_h = weights->info()->dimension(1);
const size_t weights_z = weights->info()->dimension(2);
- bool has_bias = (biases != nullptr);
+ _is_quantized = is_data_type_quantized_asymmetric(input->info()->data_type());
+
+ // Should bias be appended ?
+ bool append_bias = (biases != nullptr) && !_is_quantized;
- unsigned int conv_w = 0;
- unsigned int conv_h = 0;
- std::tie(conv_w, conv_h) = scaled_dimensions(input->info()->dimension(0), input->info()->dimension(1), weights_w, weights_h, conv_info);
+ // Calculate output shape
+ TensorShape dwc_output_shape = shape_calculator::compute_depthwise_convolution_shape(*input->info(), *weights->info(), conv_info);
+
+ // Output width and height
+ const unsigned int conv_w = dwc_output_shape.x();
+ const unsigned int conv_h = dwc_output_shape.y();
// Set up intermediate tensors
- const size_t patch_size = weights_w * weights_h + ((has_bias) ? 1 : 0);
+ const size_t patch_size = weights_w * weights_h + (append_bias ? 1 : 0);
const size_t conv_size = conv_w * conv_h;
// Im2Col configuration
@@ -119,25 +128,48 @@ void NEDepthwiseConvolutionLayer::configure(ITensor *input, const ITensor *weigh
shape_im2col.set(0, patch_size);
shape_im2col.set(1, conv_size);
shape_im2col.set(2, weights_z);
- const TensorInfo info_im2col(shape_im2col, 1, input->info()->data_type(), input->info()->fixed_point_position());
- _input_reshaped.allocator()->init(info_im2col);
- _im2col_kernel.configure(input, &_input_reshaped, Size2D(weights_w, weights_h), conv_info, has_bias);
+ _input_reshaped.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_im2col));
+ _im2col_kernel.configure(input, &_input_reshaped, Size2D(weights_w, weights_h), conv_info, append_bias);
// Weights reshape configuration
const TensorShape shape_weights_reshape(patch_size, weights_z);
- const TensorInfo info_weights_reshape(shape_weights_reshape, 1, weights->info()->data_type(), weights->info()->fixed_point_position());
- _weights_reshaped.allocator()->init(info_weights_reshape);
- _weights_reshape_kernel.configure(weights, &_weights_reshaped, biases);
+ _weights_reshaped.allocator()->init(weights->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(shape_weights_reshape));
+ _weights_reshape_kernel.configure(weights, &_weights_reshaped, append_bias ? biases : nullptr);
// GEMV configuration
+ DataType v2mm_dt = (input->info()->data_type() == DataType::QASYMM8) ? DataType::S32 : input->info()->data_type();
TensorShape shape_v2mm_out = input->info()->tensor_shape();
shape_v2mm_out.set(0, conv_size * weights_z);
shape_v2mm_out.set(1, 1);
shape_v2mm_out.set(2, 1);
- const TensorInfo info_v2mm_out(shape_v2mm_out, 1, input->info()->data_type(), input->info()->fixed_point_position());
- _v2mm_output.allocator()->init(info_v2mm_out);
+ _v2mm_output.allocator()->init(input->info()->clone()->set_is_resizable(true).reset_padding().set_data_type(v2mm_dt).set_tensor_shape(shape_v2mm_out));
_v2mm_kernel.configure(&_input_reshaped, &_weights_reshaped, &_v2mm_output);
- _vector_to_tensor_kernel.configure(&_v2mm_output, output, conv_w, conv_h);
+ _output_reshaped.allocator()->init(_v2mm_output.info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(dwc_output_shape));
+ _vector_to_tensor_kernel.configure(&_v2mm_output, (_is_quantized) ? &_output_reshaped : output, conv_w, conv_h);
+
+ // Output staged configuration
+ if(_is_quantized)
+ {
+ float multiplier = input->info()->quantization_info().scale * weights->info()->quantization_info().scale / output->info()->quantization_info().scale;
+ int output_multiplier, output_shift;
+ quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
+ _output_stage_kernel.configure(&_output_reshaped, biases, output, output_multiplier, output_shift, output->info()->quantization_info().offset);
+ _output_reshaped.allocator()->allocate();
+ }
+
+ // Fill borders on inputs
+ PixelValue zero_in(0);
+ PixelValue zero_w(0);
+ if(_is_quantized)
+ {
+ zero_in = PixelValue(static_cast<int32_t>(input->info()->quantization_info().offset));
+ zero_w = PixelValue(static_cast<int32_t>(weights->info()->quantization_info().offset));
+ }
+ BorderSize border_size = _v2mm_kernel.border_size();
+ _v2mm_input_fill_border.configure(&_input_reshaped, border_size, BorderMode::CONSTANT, zero_in);
+
+ border_size.bottom = 0;
+ _v2mm_weights_fill_border.configure(&_weights_reshaped, border_size, BorderMode::CONSTANT, zero_w);
// Allocate intermediate tensors
_input_reshaped.allocator()->allocate();
@@ -149,6 +181,12 @@ void NEDepthwiseConvolutionLayer::run()
{
NEScheduler::get().schedule(&_im2col_kernel, Window::DimX);
NEScheduler::get().schedule(&_weights_reshape_kernel, Window::DimX);
+ NEScheduler::get().schedule(&_v2mm_input_fill_border, Window::DimX);
+ NEScheduler::get().schedule(&_v2mm_weights_fill_border, Window::DimX);
NEScheduler::get().schedule(&_v2mm_kernel, Window::DimX);
NEScheduler::get().schedule(&_vector_to_tensor_kernel, Window::DimX);
+ if(_is_quantized)
+ {
+ NEScheduler::get().schedule(&_output_stage_kernel, Window::DimX);
+ }
}
\ No newline at end of file
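The quantized output stage configured above relies on expressing the real rescale factor in_scale * w_scale / out_scale (assumed to be less than 1) as an integer multiplier plus a right shift, which is what quantization::calculate_quantized_multiplier_less_than_one produces. The sketch below is an assumption about that decomposition for illustration, not a copy of the library code:

// Sketch: decomposing a real multiplier < 1 into a Q0.31 integer multiplier and a right shift.
#include <cmath>
#include <cstdint>

void decompose_multiplier(float real_multiplier, int32_t *quant_multiplier, int *right_shift)
{
    int         exponent = 0;
    const float mantissa = std::frexp(real_multiplier, &exponent); // mantissa in [0.5, 1)
    *right_shift = -exponent;                                      // >= 0 when real_multiplier < 1

    int64_t q = std::llround(static_cast<double>(mantissa) * (1ll << 31));
    if(q == (1ll << 31)) // mantissa rounded up to 1.0: renormalise
    {
        q /= 2;
        --(*right_shift);
    }
    *quant_multiplier = static_cast<int32_t>(q); // the S32 accumulator is scaled by this, then shifted right
}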
diff --git a/tests/datasets/DepthwiseConvolutionLayerDataset.h b/tests/datasets/DepthwiseConvolutionLayerDataset.h
index c5a9f96a2e..d3eb2c5d9e 100644
--- a/tests/datasets/DepthwiseConvolutionLayerDataset.h
+++ b/tests/datasets/DepthwiseConvolutionLayerDataset.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -118,7 +118,7 @@ class SmallDepthwiseConvolutionLayerDataset final : public DepthwiseConvolutionL
public:
SmallDepthwiseConvolutionLayerDataset()
{
- add_config(TensorShape(7U, 7U, 3U), TensorShape(3U, 3U, 3U), TensorShape(5U, 5U, 3U), PadStrideInfo(1, 1, 0, 0));
+ add_config(TensorShape(7U, 7U, 1U), TensorShape(3U, 3U, 1U), TensorShape(5U, 5U, 1U), PadStrideInfo(1, 1, 0, 0));
add_config(TensorShape(23U, 27U, 5U), TensorShape(3U, 5U, 5U), TensorShape(11U, 23U, 5U), PadStrideInfo(2, 1, 0, 0));
add_config(TensorShape(33U, 27U, 7U), TensorShape(7U, 3U, 7U), TensorShape(10U, 13U, 7U), PadStrideInfo(3, 2, 1, 0));
add_config(TensorShape(33U, 27U, 11U), TensorShape(3U, 3U, 11U), TensorShape(31U, 14U, 11U), PadStrideInfo(1, 2, 0, 1));
diff --git a/tests/validation/NEON/DepthwiseConvolutionLayer.cpp b/tests/validation/NEON/DepthwiseConvolutionLayer.cpp
index e8c771595e..f8c04dab3e 100644
--- a/tests/validation/NEON/DepthwiseConvolutionLayer.cpp
+++ b/tests/validation/NEON/DepthwiseConvolutionLayer.cpp
@@ -128,9 +128,19 @@ TEST_SUITE_END()
template <typename T>
using NEDepthwiseConvolutionLayerQuantizedFixture3x3 = DepthwiseConvolutionLayerValidationQuantizedFixture<Tensor, Accessor, NEDepthwiseConvolutionLayer3x3, T>;
+template <typename T>
+using NEDepthwiseConvolutionLayerQuantizedFixture = DepthwiseConvolutionLayerValidationQuantizedFixture<Tensor, Accessor, NEDepthwiseConvolutionLayer, T>;
TEST_SUITE(Quantized)
TEST_SUITE(QASYMM8)
+TEST_SUITE(Generic)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEDepthwiseConvolutionLayerQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallDepthwiseConvolutionLayerDataset(),
+ framework::dataset::make("DataType", DataType::QASYMM8)),
+ framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.5f, 10) })))
+{
+ validate(Accessor(_target), _reference, tolerance_qasymm8);
+}
+TEST_SUITE_END()
TEST_SUITE(W3x3)
FIXTURE_DATA_TEST_CASE(RunSmall, NEDepthwiseConvolutionLayerQuantizedFixture3x3<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallDepthwiseConvolutionLayerDataset3x3(),
framework::dataset::make("DataType", DataType::QASYMM8)),
diff --git a/tests/validation/reference/DepthwiseConvolutionLayer.cpp b/tests/validation/reference/DepthwiseConvolutionLayer.cpp
index 6ca347f1d4..66e3a4b783 100644
--- a/tests/validation/reference/DepthwiseConvolutionLayer.cpp
+++ b/tests/validation/reference/DepthwiseConvolutionLayer.cpp
@@ -140,11 +140,18 @@ SimpleTensor<uint8_t> depthwise_convolution(const SimpleTensor<uint8_t> &src, co
const int input_depth = src.shape().z();
const int num_batches = src.shape().total_size() / (input_width * input_height * input_depth);
- const int filter_half_size = filter_width / 2;
- const int pad_x = std::min(filter_half_size, static_cast<int>(conv_info.pad().first));
- const int pad_y = std::min(filter_half_size, static_cast<int>(conv_info.pad().second));
- const int minimum_x = -pad_x + filter_half_size;
- const int minimum_y = -pad_y + filter_half_size;
+ const int filter_half_width = filter_width / 2;
+ const int filter_half_height = filter_height / 2;
+
+ const int pad_left = std::min(static_cast<int>(conv_info.pad_left()), filter_half_width);
+ const int pad_top = std::min(static_cast<int>(conv_info.pad_top()), filter_half_height);
+ const int pad_right = std::min(static_cast<int>(conv_info.pad_right()), filter_half_width);
+ const int pad_bottom = std::min(static_cast<int>(conv_info.pad_bottom()), filter_half_height);
+
+ const int minimum_x = -pad_left + filter_half_width;
+ const int minimum_y = -pad_top + filter_half_height;
+ const int maximum_x = input_width + pad_left - filter_half_width + pad_right - filter_half_width;
+ const int maximum_y = input_height + pad_top - filter_half_height + pad_bottom - filter_half_height;
int out_pos = 0;
for(int r = 0; r < num_batches; ++r)
@@ -152,17 +159,17 @@ SimpleTensor<uint8_t> depthwise_convolution(const SimpleTensor<uint8_t> &src, co
for(int z = 0; z < input_depth; ++z)
{
int32_t bias_val = *static_cast<const int32_t *>(biases(Coordinates(z)));
- for(int y = minimum_y; y < input_height + pad_y - filter_half_size; y += conv_info.stride().second)
+ for(int y = minimum_y; y < minimum_y + maximum_y; y += conv_info.stride().second)
{
- for(int x = minimum_x; x < input_width + pad_x - filter_half_size; x += conv_info.stride().first)
+ for(int x = minimum_x; x < minimum_x + maximum_x; x += conv_info.stride().first)
{
Coordinates coords(x, y, z, r);
int filter_offset = filter_plane * z;
int32_t val = 0;
- for(int j = y - filter_half_size; j <= (y + filter_half_size); ++j)
+ for(int j = y - filter_half_height; j <= (y + filter_half_height); ++j)
{
- for(int i = x - filter_half_size; i <= (x + filter_half_size); ++i)
+ for(int i = x - filter_half_width; i <= (x + filter_half_width); ++i)
{
coords.set(0, i);
coords.set(1, j);