diff options
Diffstat (limited to 'src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp | 40 |
1 file changed, 35 insertions, 5 deletions
diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
index a571d54501..d3ab5490a4 100644
--- a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
+++ b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
@@ -498,6 +498,8 @@ void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const
     ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
     INEKernel::configure(win_config.second);
 
+    const bool has_bias = bias != nullptr;
+
     // Set appropriate function
     if(input->info()->data_layout() == DataLayout::NCHW)
     {
@@ -511,13 +513,27 @@ void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
             case DataType::F16:
             {
-                _func = (output == nullptr) ? &output_stage_nchw<float16_t, float16_t, true, true> : &output_stage_nchw<float16_t, float16_t, false, true>;
+                if(has_bias)
+                {
+                    _func = (output == nullptr) ? &output_stage_nchw<float16_t, float16_t, true, true> : &output_stage_nchw<float16_t, float16_t, false, true>;
+                }
+                else
+                {
+                    _func = (output == nullptr) ? &output_stage_nchw<float16_t, float16_t, true, false> : &output_stage_nchw<float16_t, float16_t, false, false>;
+                }
                 break;
             }
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
             case DataType::F32:
             {
-                _func = (output == nullptr) ? &output_stage_nchw<float, float, true, true> : &output_stage_nchw<float, float, false, true>;
+                if(has_bias)
+                {
+                    _func = (output == nullptr) ? &output_stage_nchw<float, float, true, true> : &output_stage_nchw<float, float, false, true>;
+                }
+                else
+                {
+                    _func = (output == nullptr) ? &output_stage_nchw<float, float, true, false> : &output_stage_nchw<float, float, false, false>;
+                }
                 break;
             }
             default:
@@ -532,19 +548,33 @@ void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const
         {
             case DataType::S32:
             {
-                _func = (output == nullptr) ? &output_stage_nhwc<int32_t, uint8_t, false, false> : &output_stage_nhwc<int32_t, uint8_t, false, true>;
+                _func = (bias == nullptr) ? &output_stage_nhwc<int32_t, uint8_t, false, false> : &output_stage_nhwc<int32_t, uint8_t, false, true>;
                 break;
             }
 #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
             case DataType::F16:
             {
-                _func = (output == nullptr) ? &output_stage_nhwc<float16_t, float16_t, true, true> : &output_stage_nhwc<float16_t, float16_t, false, true>;
+                if(has_bias)
+                {
+                    _func = (output == nullptr) ? &output_stage_nhwc<float16_t, float16_t, true, true> : &output_stage_nhwc<float16_t, float16_t, false, true>;
+                }
+                else
+                {
+                    _func = (output == nullptr) ? &output_stage_nhwc<float16_t, float16_t, true, false> : &output_stage_nhwc<float16_t, float16_t, false, false>;
+                }
                 break;
             }
 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
             case DataType::F32:
             {
-                _func = (output == nullptr) ? &output_stage_nhwc<float, float, true, true> : &output_stage_nhwc<float, float, false, true>;
+                if(has_bias)
+                {
+                    _func = (output == nullptr) ? &output_stage_nhwc<float, float, true, true> : &output_stage_nhwc<float, float, false, true>;
+                }
+                else
+                {
+                    _func = (output == nullptr) ? &output_stage_nhwc<float, float, true, false> : &output_stage_nhwc<float, float, false, false>;
+                }
                 break;
             }
             default: