aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEIm2ColKernel.cpp
diff options
context:
space:
mode:
authorIoan-Cristian Szabo <ioan-cristian.szabo@arm.com>2017-11-30 17:17:17 +0000
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:47:40 +0000
commitb4e3e1c371d8091e86ee1c6e704057559bbe1554 (patch)
treed072c9f9d7471e4df9ef5aa6b50cb09c35b0c361 /src/core/NEON/kernels/NEIm2ColKernel.cpp
parentc1b6e37233e0ebd21cb44bf8863a09c0ba5feeb1 (diff)
downloadComputeLibrary-b4e3e1c371d8091e86ee1c6e704057559bbe1554.tar.gz
COMPMID-617: Add validate support for NEON FullyConnectedLayer
Change-Id: I08987022c8d4cc335c00b8af27bd3edb8fe64d3b Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/111596 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Alexander Gilday <alexander.gilday@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/NEIm2ColKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NEIm2ColKernel.cpp40
1 files changed, 33 insertions, 7 deletions
diff --git a/src/core/NEON/kernels/NEIm2ColKernel.cpp b/src/core/NEON/kernels/NEIm2ColKernel.cpp
index 633f78de4b..4fa329bf44 100644
--- a/src/core/NEON/kernels/NEIm2ColKernel.cpp
+++ b/src/core/NEON/kernels/NEIm2ColKernel.cpp
@@ -32,6 +32,8 @@
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
+
#include <arm_neon.h>
#include <cstddef>
#include <cstdint>
@@ -42,14 +44,34 @@ using namespace arm_compute;
namespace
{
-Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias)
+Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info,
+ bool has_bias, bool is_fully_connected, bool is_flatten)
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() == DataType::QASYMM8 && has_bias);
- ARM_COMPUTE_UNUSED(kernel_dims);
- ARM_COMPUTE_UNUSED(conv_info);
+
+ if(is_flatten) /* Called by FlattenLayer */
+ {
+ size_t flatten_shape = input->tensor_shape().x() * input->tensor_shape().y() * input->tensor_shape().z();
+ ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(0) != flatten_shape);
+ }
+ else if(!is_fully_connected) /* Called by ConvolutionLayer */
+ {
+ std::pair<unsigned int, unsigned int> out_dims = scaled_dimensions(input->dimension(0), input->dimension(1), kernel_dims.width, kernel_dims.height, conv_info);
+ ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(0) != (input->dimension(2) * kernel_dims.area() + (has_bias ? 1 : 0)));
+ ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(1) != (out_dims.first * out_dims.second));
+ ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(2) != 1);
+ }
+ else /* Called by FullyConnectedLayer */
+ {
+ const int num_batch_dimensions = std::max(0, static_cast<int>(output->tensor_shape().num_dimensions()) - 1);
+ const int num_input_dimensions = input->tensor_shape().num_dimensions() - num_batch_dimensions;
+
+ TensorInfo expected_output = output->clone()->set_tensor_shape(misc::shape_calculator::compute_im2col_shape(input, num_input_dimensions));
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(&expected_output, output);
+ }
return Status{};
}
@@ -291,12 +313,15 @@ NEIm2ColKernel::NEIm2ColKernel()
{
}
-void NEIm2ColKernel::configure(const ITensor *input, ITensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias)
+void NEIm2ColKernel::configure(const ITensor *input, ITensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info,
+ bool has_bias, bool is_fully_connected, bool is_flatten)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
// Perform validation step
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), kernel_dims, conv_info, has_bias));
+ ARM_COMPUTE_UNUSED(is_fully_connected);
+ ARM_COMPUTE_UNUSED(is_flatten);
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), kernel_dims, conv_info, has_bias, is_fully_connected, is_flatten));
_input = input;
_output = output;
@@ -382,9 +407,10 @@ void NEIm2ColKernel::configure(const ITensor *input, ITensor *output, const Size
IKernel::configure(window);
}
-Status NEIm2ColKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias)
+Status NEIm2ColKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info,
+ bool has_bias, bool is_fully_connected, bool is_flatten)
{
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, kernel_dims, conv_info, has_bias));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, kernel_dims, conv_info, has_bias, is_fully_connected, is_flatten));
return Status{};
}