aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
diff options
context:
space:
mode:
author Gian Marco Iodice <gianmarco.iodice@arm.com> 2018-11-27 16:38:33 +0000
committer Anthony Barbier <Anthony.barbier@arm.com> 2018-11-28 16:56:19 +0000
commit 618493d9823936799501334d06572c5f2d8da319 (patch)
tree aa7576bf8b920559a40fdb1b5a95669a440d5027 /src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
parent 7c7b2e6c22119ad9db130c2c8bba60eb01bf10a3 (diff)
download ComputeLibrary-618493d9823936799501334d06572c5f2d8da319.tar.gz
COMPMID-1813: Fix bias == nullptr in NEDirectConvolutionLayerOutputStageKernel
Wrong check in the function.
Change-Id: I38e4d5f01039c8352c0e83f0711455af85f9c3fe
Diffstat (limited to 'src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp')
-rw-r--r-- src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp | 40
1 files changed, 35 insertions, 5 deletions
diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
index a571d54501..d3ab5490a4 100644
--- a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
+++ b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
@@ -498,6 +498,8 @@ void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
INEKernel::configure(win_config.second);
+ const bool has_bias = bias != nullptr;
+
// Set appropriate function
if(input->info()->data_layout() == DataLayout::NCHW)
{
@@ -511,13 +513,27 @@ void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
{
- _func = (output == nullptr) ? &output_stage_nchw<float16_t, float16_t, true, true> : &output_stage_nchw<float16_t, float16_t, false, true>;
+ if(has_bias)
+ {
+ _func = (output == nullptr) ? &output_stage_nchw<float16_t, float16_t, true, true> : &output_stage_nchw<float16_t, float16_t, false, true>;
+ }
+ else
+ {
+ _func = (output == nullptr) ? &output_stage_nchw<float16_t, float16_t, true, false> : &output_stage_nchw<float16_t, float16_t, false, false>;
+ }
break;
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
case DataType::F32:
{
- _func = (output == nullptr) ? &output_stage_nchw<float, float, true, true> : &output_stage_nchw<float, float, false, true>;
+ if(has_bias)
+ {
+ _func = (output == nullptr) ? &output_stage_nchw<float, float, true, true> : &output_stage_nchw<float, float, false, true>;
+ }
+ else
+ {
+ _func = (output == nullptr) ? &output_stage_nchw<float, float, true, false> : &output_stage_nchw<float, float, false, false>;
+ }
break;
}
default:
@@ -532,19 +548,33 @@ void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const
{
case DataType::S32:
{
- _func = (output == nullptr) ? &output_stage_nhwc<int32_t, uint8_t, false, false> : &output_stage_nhwc<int32_t, uint8_t, false, true>;
+ _func = (bias == nullptr) ? &output_stage_nhwc<int32_t, uint8_t, false, false> : &output_stage_nhwc<int32_t, uint8_t, false, true>;
break;
}
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
{
- _func = (output == nullptr) ? &output_stage_nhwc<float16_t, float16_t, true, true> : &output_stage_nhwc<float16_t, float16_t, false, true>;
+ if(has_bias)
+ {
+ _func = (output == nullptr) ? &output_stage_nhwc<float16_t, float16_t, true, true> : &output_stage_nhwc<float16_t, float16_t, false, true>;
+ }
+ else
+ {
+ _func = (output == nullptr) ? &output_stage_nhwc<float16_t, float16_t, true, false> : &output_stage_nhwc<float16_t, float16_t, false, false>;
+ }
break;
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
case DataType::F32:
{
- _func = (output == nullptr) ? &output_stage_nhwc<float, float, true, true> : &output_stage_nhwc<float, float, false, true>;
+ if(has_bias)
+ {
+ _func = (output == nullptr) ? &output_stage_nhwc<float, float, true, true> : &output_stage_nhwc<float, float, false, true>;
+ }
+ else
+ {
+ _func = (output == nullptr) ? &output_stage_nhwc<float, float, true, false> : &output_stage_nhwc<float, float, false, false>;
+ }
break;
}
default: