diff options
Diffstat (limited to 'src/armnn/NetworkUtils.cpp')
-rw-r--r-- | src/armnn/NetworkUtils.cpp | 50 |
1 files changed, 49 insertions, 1 deletions
diff --git a/src/armnn/NetworkUtils.cpp b/src/armnn/NetworkUtils.cpp index 7597798fa4..5ff0e6c4e1 100644 --- a/src/armnn/NetworkUtils.cpp +++ b/src/armnn/NetworkUtils.cpp @@ -1,10 +1,12 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "NetworkUtils.hpp" +#include <armnnUtils/FloatingPointConverter.hpp> +#include <BFloat16.hpp> #include "SubgraphViewSelector.hpp" #include <armnn/Exceptions.hpp> @@ -272,4 +274,50 @@ std::vector<DebugLayer*> InsertDebugLayerAfter(Graph& graph, Layer& layer) return debugLayers; } +bool RevertConstantWeightsToFP32(Layer* layer) +{ + if (layer->GetType() == LayerType::Convolution2d || layer->GetType() == LayerType::FullyConnected) + { + // Revert Weights on Constant Layer to FP32 so they can be accessed by Conv2d or FullyConnected + // This prevents a conversion layer being added in during backend assignment which blocks + // the RedirectMembersToConstantInputs backward compatibility workaround/optimization. + auto constantLayerInfo = layer->GetInputSlot(1).GetConnection()->GetTensorInfo(); + + if (constantLayerInfo.IsConstant() && constantLayerInfo.GetDataType() == DataType::BFloat16) + { + std::vector<float> newValues(constantLayerInfo.GetNumElements()); + + auto weightLayer = PolymorphicDowncast<ConstantLayer*>( + &layer->GetInputSlot(1).GetConnection()->GetOwningIConnectableLayer()); + armnnUtils::FloatingPointConverter::ConvertBFloat16ToFloat32( + weightLayer->m_LayerOutput->GetConstTensor<BFloat16>(), + constantLayerInfo.GetNumElements(), + newValues.data()); + + TensorInfo newInfo(constantLayerInfo.GetShape(), DataType::Float32); + newInfo.SetConstant(true); + ConstTensor newInput(newInfo, newValues); + weightLayer->m_LayerOutput.reset(new ScopedTensorHandle(newInput)); + weightLayer->GetOutputSlot(0).SetTensorInfo(newInfo); + + // Connect Conv2d/FullyConnected to InputLayer directly leaving out + // the ConversionLayer to be cleaned up later + auto& conversionLayer = layer->GetInputSlot(0).GetConnection()->GetOwningIConnectableLayer(); + auto actualInputOutputSlot = conversionLayer.GetInputSlot(0).GetConnection(); + + auto& conversionLayerOutputSlot = + layer->GetInputSlot(0).GetConnection()->GetOwningIConnectableLayer().GetOutputSlot(0); + auto& conversionLayerInputSlot = + layer->GetInputSlot(0).GetConnection()->GetOwningIConnectableLayer().GetInputSlot(0); + actualInputOutputSlot->Disconnect(conversionLayerInputSlot); + conversionLayerOutputSlot.Disconnect(layer->GetInputSlot(0)); + + actualInputOutputSlot->Connect(layer->GetInputSlot(0)); + + return true; + } + } + return false; +} + } // namespace armnn |