From 618493d9823936799501334d06572c5f2d8da319 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Tue, 27 Nov 2018 16:38:33 +0000 Subject: COMPMID-1813: Fix bias == nullptr in NEDirectConvolutionLayerOutputStageKernel Wrong check in the function Change-Id: I38e4d5f01039c8352c0e83f0711455af85f9c3fe --- .../NEDirectConvolutionLayerOutputStageKernel.cpp | 40 +++++++++++++++++++--- 1 file changed, 35 insertions(+), 5 deletions(-) (limited to 'src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp') diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp index a571d54501..d3ab5490a4 100644 --- a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp +++ b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp @@ -498,6 +498,8 @@ void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const ARM_COMPUTE_ERROR_THROW_ON(win_config.first); INEKernel::configure(win_config.second); + const bool has_bias = bias != nullptr; + // Set appropriate function if(input->info()->data_layout() == DataLayout::NCHW) { @@ -511,13 +513,27 @@ void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: { - _func = (output == nullptr) ? &output_stage_nchw : &output_stage_nchw; + if(has_bias) + { + _func = (output == nullptr) ? &output_stage_nchw : &output_stage_nchw; + } + else + { + _func = (output == nullptr) ? &output_stage_nchw : &output_stage_nchw; + } break; } #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ case DataType::F32: { - _func = (output == nullptr) ? &output_stage_nchw : &output_stage_nchw; + if(has_bias) + { + _func = (output == nullptr) ? &output_stage_nchw : &output_stage_nchw; + } + else + { + _func = (output == nullptr) ? &output_stage_nchw : &output_stage_nchw; + } break; } default: @@ -532,19 +548,33 @@ void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const { case DataType::S32: { - _func = (output == nullptr) ? &output_stage_nhwc : &output_stage_nhwc; + _func = (bias == nullptr) ? &output_stage_nhwc : &output_stage_nhwc; break; } #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: { - _func = (output == nullptr) ? &output_stage_nhwc : &output_stage_nhwc; + if(has_bias) + { + _func = (output == nullptr) ? &output_stage_nhwc : &output_stage_nhwc; + } + else + { + _func = (output == nullptr) ? &output_stage_nhwc : &output_stage_nhwc; + } break; } #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ case DataType::F32: { - _func = (output == nullptr) ? &output_stage_nhwc : &output_stage_nhwc; + if(has_bias) + { + _func = (output == nullptr) ? &output_stage_nhwc : &output_stage_nhwc; + } + else + { + _func = (output == nullptr) ? &output_stage_nhwc : &output_stage_nhwc; + } break; } default: -- cgit v1.2.1