aboutsummaryrefslogtreecommitdiff
path: root/src/core
diff options
context:
space:
mode:
authorGian Marco Iodice <gianmarco.iodice@arm.com>2017-06-26 17:20:16 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-09-17 14:15:39 +0100
commit13edbff0820c3b41e7dd766db5a9d6ff65fcda2a (patch)
treeb17fbea676fe0a77153b1610f88ebc6faa30e023 /src/core
parent238cfc06ed377045df9b76c2047081d27ab9ff66 (diff)
downloadComputeLibrary-13edbff0820c3b41e7dd766db5a9d6ff65fcda2a.tar.gz
COMPMID-432 - Extended Convolution Layer to support rectangular kernels
Change-Id: I99be1efede4de6dd63ce103fb11196c413757621 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/79252 Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com> Reviewed-by: Moritz Pflanzer <moritz.pflanzer@arm.com>
Diffstat (limited to 'src/core')
-rw-r--r--src/core/CL/cl_kernels/convolution_layer.cl50
-rw-r--r--src/core/CL/kernels/CLCol2ImKernel.cpp6
-rw-r--r--src/core/CL/kernels/CLIm2ColKernel.cpp66
-rw-r--r--src/core/CL/kernels/CLWeightsReshapeKernel.cpp17
-rw-r--r--src/core/NEON/kernels/NEIm2ColKernel.cpp31
-rw-r--r--src/core/NEON/kernels/NEWeightsReshapeKernel.cpp10
6 files changed, 78 insertions, 102 deletions
diff --git a/src/core/CL/cl_kernels/convolution_layer.cl b/src/core/CL/cl_kernels/convolution_layer.cl
index bd5dfaff68..837fdd70fe 100644
--- a/src/core/CL/cl_kernels/convolution_layer.cl
+++ b/src/core/CL/cl_kernels/convolution_layer.cl
@@ -27,7 +27,7 @@
*
* @note Datatype should be given as a preprocessor argument using -DDATA_TYPE=type. e.g. -DDATA_TYPE=short
*
- * @param[in] src_ptr Pointer to the source tensor. Supported data types: F16, F32
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F16/F32
* @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
@@ -35,13 +35,13 @@
* @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
* @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
- * @param[out] dst_ptr Pointer to the destination tensor. Same as input
+ * @param[out] dst_ptr Pointer to the destination tensor. Same as @p src_ptr
* @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
- * @param[in] bias_ptr Pointer to the bias tensor. Same as input
+ * @param[in] bias_ptr Pointer to the bias tensor. Same as @p src_ptr
* @param[in] bias_stride_x Stride of the bias tensor in X dimension (in bytes)
* @param[in] bias_step_x bias_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] bias_offset_first_element_in_bytes The offset of the first element in the source tensor
@@ -93,12 +93,13 @@ __kernel void reshape_to_columns(
}
}
+#if(defined CONVOLVED_WIDTH && defined STRIDE_X && defined STRIDE_Y && defined PAD_X && defined PAD_Y && defined KERNEL_WIDTH && defined KERNEL_HEIGHT && defined KERNEL_DEPTH && SRC_WIDTH && SRC_HEIGHT)
/** This kernel performs a reshaping of the input tensor to a tensor used to perform convolution using GEMM.
*
* @note The data type must be passed at compile time using -DDATA_TYPE: e.g. -DDATA_TYPE=float
* @note In case biases will be added to the convolution -DHAS_BIAS has to be passed to append the final matrix with 1 in each row.
*
- * @param[in] src_ptr Pointer to the source tensor. Supported data types: F16, F32
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F16/F32
* @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
@@ -106,48 +107,36 @@ __kernel void reshape_to_columns(
* @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
* @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
- * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: F16, F32
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
* @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
* @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
- * @param[in] kernel_size The convolution kernel size
- * @param[in] kernel_depth The kernel depth
- * @param[in] width The output tensor width
- * @param[in] input_dims The input tensor dimensions
- * @param[in] strides The strides of the im2col operation
- * @param[in] paddings The input tensor paddings
*/
__kernel void im2col_generic(
TENSOR3D_DECLARATION(src),
- IMAGE_DECLARATION(dst),
- int kernel_size,
- int kernel_depth,
- int width,
- int2 input_dims,
- int2 strides,
- int2 paddings)
+ IMAGE_DECLARATION(dst))
{
Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
Image dst = CONVERT_TO_IMAGE_STRUCT_NO_STEP(dst);
// Determine output index
- uint idx = (get_global_id(1) * width + get_global_id(0)) * dst.stride_y;
+ uint idx = (get_global_id(1) * CONVOLVED_WIDTH + get_global_id(0)) * dst.stride_y;
__global uchar *output_ptr = dst.ptr + idx;
// Determine current input index
- const int top_left_x = get_global_id(0) * strides.x - paddings.x;
- const int top_left_y = get_global_id(1) * strides.y - paddings.y;
+ const int top_left_x = get_global_id(0) * STRIDE_X - PAD_X;
+ const int top_left_y = get_global_id(1) * STRIDE_Y - PAD_Y;
// Linearize convolution elements
- for(int d = 0; d < kernel_depth; ++d)
+ for(int d = 0; d < KERNEL_DEPTH; ++d)
{
- for(int y = top_left_y, y_e = top_left_y + kernel_size; y < y_e; ++y)
+ for(int y = top_left_y, y_e = top_left_y + KERNEL_HEIGHT; y < y_e; ++y)
{
- for(int x = top_left_x, x_e = top_left_x + kernel_size; x < x_e; ++x, output_ptr += dst.stride_x)
+ for(int x = top_left_x, x_e = top_left_x + KERNEL_WIDTH; x < x_e; ++x, output_ptr += dst.stride_x)
{
- if(x < 0 || x >= input_dims.x || y < 0 || y >= input_dims.y)
+ if(x < 0 || x >= SRC_WIDTH || y < 0 || y >= SRC_HEIGHT)
{
*((__global DATA_TYPE *)output_ptr) = 0;
}
@@ -160,21 +149,22 @@ __kernel void im2col_generic(
}
#if defined HAS_BIAS
- *((__global DATA_TYPE *)output_ptr) = 1;
+ *((__global DATA_TYPE *)output_ptr) = (DATA_TYPE)1;
#endif
}
+#endif //(defined CONVOLVED_WIDTH && defined STRIDE_X && defined STRIDE_Y && defined PAD_X && defined PAD_Y && defined KERNEL_WIDTH && defined KERNEL_HEIGHT && defined KERNEL_DEPTH && SRC_WIDTH && SRC_HEIGHT)
/** This kernel performs a reshaping of the output of the convolution layer.
*
* @note The data type must be passed at compile time using -DDATA_TYPE: e.g. -DDATA_TYPE=float
*
- * @param[in] src_ptr Pointer to the source tensor. Supported data types: F16, F32
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F16/F32
* @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
* @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
- * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: F16, F32
+ * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
* @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
@@ -202,7 +192,7 @@ __kernel void col2im(
* @note Datatype should be given as a preprocessor argument using -DDATA_TYPE=type. e.g. -DDATA_TYPE=float
* @note In case biases will be added in late stage, -DHAS_BIAS has to be passed to append the final matrix with 1 in each row.
*
- * @param[in] src_ptr Pointer to the source tensor. Supported data types: F16, F32
+ * @param[in] src_ptr Pointer to the source tensor. Supported data types: F16/F32
* @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
* @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
@@ -210,7 +200,7 @@ __kernel void col2im(
* @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
* @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
* @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
- * @param[out] dst_ptr Pointer to the destination tensor. Same as input.
+ * @param[out] dst_ptr Pointer to the destination tensor. Same as @p src_ptr
* @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
diff --git a/src/core/CL/kernels/CLCol2ImKernel.cpp b/src/core/CL/kernels/CLCol2ImKernel.cpp
index ad66c39483..679943ba3e 100644
--- a/src/core/CL/kernels/CLCol2ImKernel.cpp
+++ b/src/core/CL/kernels/CLCol2ImKernel.cpp
@@ -61,8 +61,12 @@ void CLCol2ImKernel::configure(const ICLTensor *input, ICLTensor *output, std::p
// Configure window
Window win = calculate_max_window(*input->info(), Steps());
+
// The CLCol2ImKernel doesn't need padding so update_window_and_padding() can be skipped
- output->info()->set_valid_region(ValidRegion(Coordinates(), output->info()->tensor_shape()));
+ Coordinates coord;
+ coord.set_num_dimensions(output->info()->num_dimensions());
+ output->info()->set_valid_region(ValidRegion(coord, output->info()->tensor_shape()));
+
ICLKernel::configure(win);
}
diff --git a/src/core/CL/kernels/CLIm2ColKernel.cpp b/src/core/CL/kernels/CLIm2ColKernel.cpp
index 8c0fe26666..092f495f92 100644
--- a/src/core/CL/kernels/CLIm2ColKernel.cpp
+++ b/src/core/CL/kernels/CLIm2ColKernel.cpp
@@ -29,8 +29,10 @@
#include "arm_compute/core/CL/OpenCL.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/Size2D.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"
+#include "support/ToolchainSupport.h"
#include <cmath>
#include <tuple>
@@ -38,14 +40,14 @@
using namespace arm_compute;
CLIm2ColKernel::CLIm2ColKernel()
- : _input(nullptr), _output(nullptr), _convolved_dims(), _conv_info(), _kernel_size(0), _num_elems_processed_per_iteration(1), _run_func(nullptr)
+ : _input(nullptr), _output(nullptr), _convolved_dims(), _num_elems_processed_per_iteration(1), _run_func(nullptr)
{
}
-void CLIm2ColKernel::configure(const ICLTensor *input, ICLTensor *output, std::pair<unsigned int, unsigned int> convolved_dims, const PadStrideInfo &conv_info, bool has_bias)
+void CLIm2ColKernel::configure(const ICLTensor *input, ICLTensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
+ ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
_input = input;
_output = output;
@@ -70,44 +72,23 @@ void CLIm2ColKernel::configure(const ICLTensor *input, ICLTensor *output, std::p
if(!run_img2col_reduced)
{
- _convolved_dims = convolved_dims;
- _conv_info = conv_info;
- _kernel_size = std::sqrt((output->info()->dimension(0) - (has_bias ? 1 : 0)) / input->info()->dimension(2));
+ _convolved_dims = scaled_dimensions(input->info()->dimension(0), input->info()->dimension(1),
+ kernel_dims.width, kernel_dims.height,
+ conv_info);
_num_elems_processed_per_iteration = output->info()->dimension(0);
- _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("im2col_generic", build_opts));
+ build_opts.emplace("-DKERNEL_WIDTH=" + support::cpp11::to_string(kernel_dims.width));
+ build_opts.emplace("-DKERNEL_HEIGHT=" + support::cpp11::to_string(kernel_dims.height));
+ build_opts.emplace("-DKERNEL_DEPTH=" + support::cpp11::to_string(input->info()->dimension(2)));
+ build_opts.emplace("-DCONVOLVED_WIDTH=" + support::cpp11::to_string(_convolved_dims.first));
+ build_opts.emplace("-DSTRIDE_X=" + support::cpp11::to_string(conv_info.stride().first));
+ build_opts.emplace("-DSTRIDE_Y=" + support::cpp11::to_string(conv_info.stride().second));
+ build_opts.emplace("-DPAD_X=" + support::cpp11::to_string(conv_info.pad().first));
+ build_opts.emplace("-DPAD_Y=" + support::cpp11::to_string(conv_info.pad().second));
+ build_opts.emplace("-DSRC_WIDTH=" + support::cpp11::to_string(input->info()->dimension(0)));
+ build_opts.emplace("-DSRC_HEIGHT=" + support::cpp11::to_string(input->info()->dimension(1)));
- // Create static kernel arguments
- const cl_int2 input_dims =
- {
- {
- static_cast<cl_int>(input->info()->dimension(0)),
- static_cast<cl_int>(input->info()->dimension(1)),
- }
- };
- const cl_int2 strides =
- {
- {
- stride_x,
- stride_y,
- }
- };
- const cl_int2 paddings =
- {
- {
- pad_x,
- pad_y,
- }
- };
-
- // Set static kernel arguments
- unsigned int idx = num_arguments_per_2D_tensor() + num_arguments_per_3D_tensor();
- _kernel.setArg<cl_int>(idx++, _kernel_size);
- _kernel.setArg<cl_int>(idx++, input->info()->dimension(2) /* depth */);
- _kernel.setArg<cl_int>(idx++, _convolved_dims.first /* output width */);
- _kernel.setArg<cl_int2>(idx++, input_dims);
- _kernel.setArg<cl_int2>(idx++, strides);
- _kernel.setArg<cl_int2>(idx++, paddings);
+ _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("im2col_generic", build_opts));
_run_func = &CLIm2ColKernel::run_generic;
}
@@ -136,13 +117,6 @@ void CLIm2ColKernel::run_generic(const Window &window, cl::CommandQueue &queue)
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_MISMATCHING_WINDOWS(ICLKernel::window(), window);
- int pad_x = 0;
- int pad_y = 0;
- int stride_x = 0;
- int stride_y = 0;
- std::tie(pad_x, pad_y) = _conv_info.pad();
- std::tie(stride_x, stride_y) = _conv_info.stride();
-
// Get initial windows
Window slice = window.first_slice_window_3D();
Window slice_in = window.first_slice_window_3D();
diff --git a/src/core/CL/kernels/CLWeightsReshapeKernel.cpp b/src/core/CL/kernels/CLWeightsReshapeKernel.cpp
index 845bd3799d..82634164de 100644
--- a/src/core/CL/kernels/CLWeightsReshapeKernel.cpp
+++ b/src/core/CL/kernels/CLWeightsReshapeKernel.cpp
@@ -42,14 +42,7 @@ CLWeightsReshapeKernel::CLWeightsReshapeKernel()
void CLWeightsReshapeKernel::configure(const ICLTensor *input, const ICLTensor *biases, ICLTensor *output)
{
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, biases);
- ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, biases);
- ARM_COMPUTE_ERROR_ON((input->info()->num_dimensions() == 4) && (biases->info()->num_dimensions() != 1));
- ARM_COMPUTE_ERROR_ON((input->info()->num_dimensions() == 5) && (biases->info()->num_dimensions() != 2));
- ARM_COMPUTE_ERROR_ON((input->info()->num_dimensions() == 4) && (biases->info()->dimension(0) != input->info()->tensor_shape()[3]));
- ARM_COMPUTE_ERROR_ON((input->info()->num_dimensions() == 5) && (biases->info()->dimension(0) != input->info()->tensor_shape()[3] || biases->info()->dimension(1) != input->info()->tensor_shape()[4]));
ARM_COMPUTE_ERROR_ON_NULLPTR(output);
- ARM_COMPUTE_ERROR_ON(input->info()->dimension(0) != input->info()->dimension(1));
const DataType dt = input->info()->data_type();
const int fixed_point_position = input->info()->fixed_point_position();
@@ -67,6 +60,16 @@ void CLWeightsReshapeKernel::configure(const ICLTensor *input, const ICLTensor *
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, output);
+ if(biases != nullptr)
+ {
+ ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, biases);
+ ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, biases);
+ ARM_COMPUTE_ERROR_ON((input->info()->num_dimensions() == 4) && (biases->info()->num_dimensions() != 1));
+ ARM_COMPUTE_ERROR_ON((input->info()->num_dimensions() == 5) && (biases->info()->num_dimensions() != 2));
+ ARM_COMPUTE_ERROR_ON((input->info()->num_dimensions() == 4) && (biases->info()->dimension(0) != input->info()->tensor_shape()[3]));
+ ARM_COMPUTE_ERROR_ON((input->info()->num_dimensions() == 5) && (biases->info()->dimension(0) != input->info()->tensor_shape()[3] || biases->info()->dimension(1) != input->info()->tensor_shape()[4]));
+ }
+
_biases = biases;
_output = output;
_input = input;
diff --git a/src/core/NEON/kernels/NEIm2ColKernel.cpp b/src/core/NEON/kernels/NEIm2ColKernel.cpp
index 875c08ed42..99daa2e5e7 100644
--- a/src/core/NEON/kernels/NEIm2ColKernel.cpp
+++ b/src/core/NEON/kernels/NEIm2ColKernel.cpp
@@ -27,6 +27,7 @@
#include "arm_compute/core/FixedPoint.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/Size2D.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"
@@ -47,7 +48,8 @@ inline void linearize_volume(const uint8_t *const in_ptr,
bool has_bias,
int top_left_x,
int top_left_y,
- int kernel_size,
+ int kernel_width,
+ int kernel_height,
int kernel_depth,
int input_w,
int input_h,
@@ -56,9 +58,9 @@ inline void linearize_volume(const uint8_t *const in_ptr,
int input_stride_z,
int fixed_point_position)
{
- const int kernel_size2 = kernel_size * kernel_size;
- const int x_e = top_left_x + kernel_size;
- const int y_e = top_left_y + kernel_size;
+ const int kernel_size2 = kernel_width * kernel_height;
+ const int x_e = top_left_x + kernel_width;
+ const int y_e = top_left_y + kernel_height;
// Linearize volume
int d = 0;
@@ -109,8 +111,8 @@ inline void linearize_volume(const uint8_t *const in_ptr,
if((y < 0 || y >= input_h) && has_pads)
{
// All the values will be zeros
- memset(out_ptr, 0, kernel_size * sizeof(T));
- out_ptr += kernel_size;
+ memset(out_ptr, 0, kernel_width * sizeof(T));
+ out_ptr += kernel_width;
}
else
{
@@ -199,7 +201,8 @@ void NEIm2ColKernel::run_generic(const Window &window)
_has_bias,
top_left_x,
top_left_y,
- static_cast<int>(_kernel_size),
+ static_cast<int>(_kernel_width),
+ static_cast<int>(_kernel_height),
kernel_depth,
input_w,
input_h,
@@ -260,22 +263,24 @@ void NEIm2ColKernel::run_reduced(const Window &window)
}
NEIm2ColKernel::NEIm2ColKernel()
- : _func(), _input(nullptr), _output(nullptr), _convolved_dims(), _conv_info(), _kernel_size(0), _has_bias(false)
+ : _func(), _input(nullptr), _output(nullptr), _convolved_dims(), _conv_info(), _kernel_width(0), _kernel_height(0), _has_bias(false)
{
}
-void NEIm2ColKernel::configure(const ITensor *input, ITensor *output, std::pair<unsigned int, unsigned int> convolved_dims, const PadStrideInfo &conv_info, bool has_bias)
+void NEIm2ColKernel::configure(const ITensor *input, ITensor *output, const Size2D &kernel_dims, const PadStrideInfo &conv_info, bool has_bias)
{
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32, DataType::QS8);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::F16, DataType::F32, DataType::QS8);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
_input = input;
_output = output;
- _convolved_dims = convolved_dims;
_conv_info = conv_info;
- _kernel_size = std::sqrt((output->info()->dimension(0) - (has_bias ? 1 : 0)) / input->info()->dimension(2));
- _has_bias = has_bias;
+ _kernel_width = kernel_dims.width;
+ _kernel_height = kernel_dims.height,
+ _convolved_dims = scaled_dimensions(input->info()->dimension(0), input->info()->dimension(1),
+ _kernel_width, _kernel_height,
+ _conv_info);
+ _has_bias = has_bias;
unsigned int pad_x, pad_y, stride_x, stride_y = 0;
std::tie(pad_x, pad_y) = conv_info.pad();
diff --git a/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp b/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp
index e9b76e7967..ac688e1381 100644
--- a/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp
+++ b/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp
@@ -97,13 +97,13 @@ void NEWeightsReshapeKernel::configure(const ITensor *input, const ITensor *bias
{
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::F16, DataType::F32);
ARM_COMPUTE_ERROR_ON_NULLPTR(output);
- ARM_COMPUTE_ERROR_ON(input->info()->dimension(0) != input->info()->dimension(1));
- const DataType dt = input->info()->data_type();
- const int fixed_point_position = input->info()->fixed_point_position();
-
- TensorShape output_shape{ input->info()->tensor_shape() };
+ const int fixed_point_position = input->info()->fixed_point_position();
+ const DataType dt = input->info()->data_type();
+ const TensorShape &input_shape = input->info()->tensor_shape();
+ TensorShape output_shape{ input_shape };
output_shape.collapse(3);
+
const size_t tmp_dim = output_shape[0];
output_shape.set(0, output_shape[1]);
output_shape.set(1, tmp_dim + (bias != nullptr ? 1 : 0));