diff options
Diffstat (limited to 'src/core/NEON/kernels/NEWeightsReshapeKernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/NEWeightsReshapeKernel.cpp | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp b/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp index e9b76e7967..ac688e1381 100644 --- a/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp +++ b/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp @@ -97,13 +97,13 @@ void NEWeightsReshapeKernel::configure(const ITensor *input, const ITensor *bias { ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::F16, DataType::F32); ARM_COMPUTE_ERROR_ON_NULLPTR(output); - ARM_COMPUTE_ERROR_ON(input->info()->dimension(0) != input->info()->dimension(1)); - const DataType dt = input->info()->data_type(); - const int fixed_point_position = input->info()->fixed_point_position(); - - TensorShape output_shape{ input->info()->tensor_shape() }; + const int fixed_point_position = input->info()->fixed_point_position(); + const DataType dt = input->info()->data_type(); + const TensorShape &input_shape = input->info()->tensor_shape(); + TensorShape output_shape{ input_shape }; output_shape.collapse(3); + const size_t tmp_dim = output_shape[0]; output_shape.set(0, output_shape[1]); output_shape.set(1, tmp_dim + (bias != nullptr ? 1 : 0)); |