From 13edbff0820c3b41e7dd766db5a9d6ff65fcda2a Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Mon, 26 Jun 2017 17:20:16 +0100 Subject: COMPMID-432 - Extended Convolution Layer to support rectangular kernels Change-Id: I99be1efede4de6dd63ce103fb11196c413757621 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/79252 Tested-by: Kaizen Reviewed-by: Moritz Pflanzer --- src/core/NEON/kernels/NEWeightsReshapeKernel.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'src/core/NEON/kernels/NEWeightsReshapeKernel.cpp') diff --git a/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp b/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp index e9b76e7967..ac688e1381 100644 --- a/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp +++ b/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp @@ -97,13 +97,13 @@ void NEWeightsReshapeKernel::configure(const ITensor *input, const ITensor *bias { ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::F16, DataType::F32); ARM_COMPUTE_ERROR_ON_NULLPTR(output); - ARM_COMPUTE_ERROR_ON(input->info()->dimension(0) != input->info()->dimension(1)); - const DataType dt = input->info()->data_type(); - const int fixed_point_position = input->info()->fixed_point_position(); - - TensorShape output_shape{ input->info()->tensor_shape() }; + const int fixed_point_position = input->info()->fixed_point_position(); + const DataType dt = input->info()->data_type(); + const TensorShape &input_shape = input->info()->tensor_shape(); + TensorShape output_shape{ input_shape }; output_shape.collapse(3); + const size_t tmp_dim = output_shape[0]; output_shape.set(0, output_shape[1]); output_shape.set(1, tmp_dim + (bias != nullptr ? 1 : 0)); -- cgit v1.2.1