aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON')
-rw-r--r--src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp40
1 files changed, 35 insertions, 5 deletions
diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
index a571d54501..d3ab5490a4 100644
--- a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
+++ b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
@@ -498,6 +498,8 @@ void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
INEKernel::configure(win_config.second);
+ const bool has_bias = bias != nullptr;
+
// Set appropriate function
if(input->info()->data_layout() == DataLayout::NCHW)
{
@@ -511,13 +513,27 @@ void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
{
- _func = (output == nullptr) ? &output_stage_nchw<float16_t, float16_t, true, true> : &output_stage_nchw<float16_t, float16_t, false, true>;
+ if(has_bias)
+ {
+ _func = (output == nullptr) ? &output_stage_nchw<float16_t, float16_t, true, true> : &output_stage_nchw<float16_t, float16_t, false, true>;
+ }
+ else
+ {
+ _func = (output == nullptr) ? &output_stage_nchw<float16_t, float16_t, true, false> : &output_stage_nchw<float16_t, float16_t, false, false>;
+ }
break;
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
case DataType::F32:
{
- _func = (output == nullptr) ? &output_stage_nchw<float, float, true, true> : &output_stage_nchw<float, float, false, true>;
+ if(has_bias)
+ {
+ _func = (output == nullptr) ? &output_stage_nchw<float, float, true, true> : &output_stage_nchw<float, float, false, true>;
+ }
+ else
+ {
+ _func = (output == nullptr) ? &output_stage_nchw<float, float, true, false> : &output_stage_nchw<float, float, false, false>;
+ }
break;
}
default:
@@ -532,19 +548,33 @@ void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const
{
case DataType::S32:
{
- _func = (output == nullptr) ? &output_stage_nhwc<int32_t, uint8_t, false, false> : &output_stage_nhwc<int32_t, uint8_t, false, true>;
+ _func = (bias == nullptr) ? &output_stage_nhwc<int32_t, uint8_t, false, false> : &output_stage_nhwc<int32_t, uint8_t, false, true>;
break;
}
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
{
- _func = (output == nullptr) ? &output_stage_nhwc<float16_t, float16_t, true, true> : &output_stage_nhwc<float16_t, float16_t, false, true>;
+ if(has_bias)
+ {
+ _func = (output == nullptr) ? &output_stage_nhwc<float16_t, float16_t, true, true> : &output_stage_nhwc<float16_t, float16_t, false, true>;
+ }
+ else
+ {
+ _func = (output == nullptr) ? &output_stage_nhwc<float16_t, float16_t, true, false> : &output_stage_nhwc<float16_t, float16_t, false, false>;
+ }
break;
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
case DataType::F32:
{
- _func = (output == nullptr) ? &output_stage_nhwc<float, float, true, true> : &output_stage_nhwc<float, float, false, true>;
+ if(has_bias)
+ {
+ _func = (output == nullptr) ? &output_stage_nhwc<float, float, true, true> : &output_stage_nhwc<float, float, false, true>;
+ }
+ else
+ {
+ _func = (output == nullptr) ? &output_stage_nhwc<float, float, true, false> : &output_stage_nhwc<float, float, false, false>;
+ }
break;
}
default: