aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/NEON/functions/NEConvolutionLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/NEON/functions/NEConvolutionLayer.cpp')
-rw-r--r--src/runtime/NEON/functions/NEConvolutionLayer.cpp7
1 files changed, 2 insertions, 5 deletions
diff --git a/src/runtime/NEON/functions/NEConvolutionLayer.cpp b/src/runtime/NEON/functions/NEConvolutionLayer.cpp
index b38d6617d5..dc8652747f 100644
--- a/src/runtime/NEON/functions/NEConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEConvolutionLayer.cpp
@@ -24,6 +24,7 @@
#include "arm_compute/runtime/NEON/functions/NEConvolutionLayer.h"
#include "arm_compute/core/PixelValue.h"
+#include "arm_compute/core/Size2D.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
@@ -41,7 +42,6 @@ NEConvolutionLayerReshapeWeights::NEConvolutionLayerReshapeWeights()
void NEConvolutionLayerReshapeWeights::configure(const ITensor *weights, const ITensor *biases, ITensor *output, bool transpose1xW)
{
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weights, 1, DataType::QS8, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS8, DataType::F16, DataType::F32);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(weights, output);
ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(weights, output);
ARM_COMPUTE_ERROR_ON(weights->info()->num_dimensions() > 4);
@@ -97,8 +97,6 @@ NEConvolutionLayer::NEConvolutionLayer()
void NEConvolutionLayer::configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info, const WeightsInfo &weights_info)
{
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weights, 1, DataType::QS8, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS8, DataType::F16, DataType::F32);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights, output);
ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, weights, output);
ARM_COMPUTE_ERROR_ON(!weights_info.are_reshaped() && weights->info()->dimension(2) != input->info()->dimension(2));
@@ -106,7 +104,6 @@ void NEConvolutionLayer::configure(const ITensor *input, const ITensor *weights,
if(biases != nullptr)
{
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(biases, 1, DataType::QS8, DataType::F16, DataType::F32);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, biases);
ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input, biases);
ARM_COMPUTE_ERROR_ON(!weights_info.are_reshaped() && biases->info()->dimension(0) != weights->info()->dimension(3));
@@ -197,7 +194,7 @@ void NEConvolutionLayer::configure(const ITensor *input, const ITensor *weights,
_gemm_output.allocator()->init(TensorInfo(shape_gemm, 1, dt, fixed_point_position));
// Configure kernels
- _input_im2col_kernel.configure(input, &_input_im2col_reshaped, std::make_pair(conv_w, conv_h), conv_info, _has_bias);
+ _input_im2col_kernel.configure(input, &_input_im2col_reshaped, Size2D(kernel_width, kernel_height), conv_info, _has_bias);
if(_is_fully_connected_convolution)
{
_mm_kernel.configure(&_input_im2col_reshaped, weights, &_gemm_output, 1.0f);