diff options
author | Gian Marco Iodice <gianmarco.iodice@arm.com> | 2018-11-27 16:38:33 +0000 |
---|---|---|
committer | Anthony Barbier <Anthony.barbier@arm.com> | 2018-11-28 16:56:19 +0000 |
commit | 618493d9823936799501334d06572c5f2d8da319 (patch) | |
tree | aa7576bf8b920559a40fdb1b5a95669a440d5027 /src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp | |
parent | 7c7b2e6c22119ad9db130c2c8bba60eb01bf10a3 (diff) | |
download | ComputeLibrary-618493d9823936799501334d06572c5f2d8da319.tar.gz |
COMPMID-1813: Fix bias == nullptr in NEDirectConvolutionLayerOutputStageKernel
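Wrong check in the function: the NHWC S32 path chose its output-stage function by testing `output == nullptr` instead of `bias == nullptr`, and the F16/F32 paths (NCHW and NHWC) always selected the variant used when a bias is present, even when `bias` was `nullptr`.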
Change-Id: I38e4d5f01039c8352c0e83f0711455af85f9c3fe
Diffstat (limited to 'src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp | 40 |
1 file changed, 35 insertions, 5 deletions
diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp index a571d54501..d3ab5490a4 100644 --- a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp +++ b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp @@ -498,6 +498,8 @@ void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const ARM_COMPUTE_ERROR_THROW_ON(win_config.first); INEKernel::configure(win_config.second); + const bool has_bias = bias != nullptr; + // Set appropriate function if(input->info()->data_layout() == DataLayout::NCHW) { @@ -511,13 +513,27 @@ void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: { - _func = (output == nullptr) ? &output_stage_nchw<float16_t, float16_t, true, true> : &output_stage_nchw<float16_t, float16_t, false, true>; + if(has_bias) + { + _func = (output == nullptr) ? &output_stage_nchw<float16_t, float16_t, true, true> : &output_stage_nchw<float16_t, float16_t, false, true>; + } + else + { + _func = (output == nullptr) ? &output_stage_nchw<float16_t, float16_t, true, false> : &output_stage_nchw<float16_t, float16_t, false, false>; + } break; } #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ case DataType::F32: { - _func = (output == nullptr) ? &output_stage_nchw<float, float, true, true> : &output_stage_nchw<float, float, false, true>; + if(has_bias) + { + _func = (output == nullptr) ? &output_stage_nchw<float, float, true, true> : &output_stage_nchw<float, float, false, true>; + } + else + { + _func = (output == nullptr) ? &output_stage_nchw<float, float, true, false> : &output_stage_nchw<float, float, false, false>; + } break; } default: @@ -532,19 +548,33 @@ void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const { case DataType::S32: { - _func = (output == nullptr) ? 
&output_stage_nhwc<int32_t, uint8_t, false, false> : &output_stage_nhwc<int32_t, uint8_t, false, true>; + _func = (bias == nullptr) ? &output_stage_nhwc<int32_t, uint8_t, false, false> : &output_stage_nhwc<int32_t, uint8_t, false, true>; break; } #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: { - _func = (output == nullptr) ? &output_stage_nhwc<float16_t, float16_t, true, true> : &output_stage_nhwc<float16_t, float16_t, false, true>; + if(has_bias) + { + _func = (output == nullptr) ? &output_stage_nhwc<float16_t, float16_t, true, true> : &output_stage_nhwc<float16_t, float16_t, false, true>; + } + else + { + _func = (output == nullptr) ? &output_stage_nhwc<float16_t, float16_t, true, false> : &output_stage_nhwc<float16_t, float16_t, false, false>; + } break; } #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ case DataType::F32: { - _func = (output == nullptr) ? &output_stage_nhwc<float, float, true, true> : &output_stage_nhwc<float, float, false, true>; + if(has_bias) + { + _func = (output == nullptr) ? &output_stage_nhwc<float, float, true, true> : &output_stage_nhwc<float, float, false, true>; + } + else + { + _func = (output == nullptr) ? &output_stage_nhwc<float, float, true, false> : &output_stage_nhwc<float, float, false, false>; + } break; } default: |