diff options
Diffstat (limited to 'src/armnn/NetworkUtils.cpp')
-rw-r--r-- | src/armnn/NetworkUtils.cpp | 9 |
1 files changed, 9 insertions, 0 deletions
diff --git a/src/armnn/NetworkUtils.cpp b/src/armnn/NetworkUtils.cpp index 666ce3d069..7597798fa4 100644 --- a/src/armnn/NetworkUtils.cpp +++ b/src/armnn/NetworkUtils.cpp @@ -98,6 +98,15 @@ std::vector<ConvertFp32ToBf16Layer*> InsertConvertFp32ToBf16LayersBefore(Graph& for (auto&& inputSlot = layer.BeginInputSlots(); inputSlot != layer.EndInputSlots(); ++inputSlot) { bool allowInsert = true; + + if ((layer.GetType() == LayerType::Convolution2d || + layer.GetType() == LayerType::FullyConnected || + layer.GetType() == LayerType::DepthwiseConvolution2d) + && inputSlot->GetSlotIndex() == 2) + { + // Refrain from reducing bias to Bf16 + continue; + } if (expectCorrectInputType) { // Only insert ConvertFp32ToBf16Layer before FP32 input slots |