aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/NetworkUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/NetworkUtils.cpp')
-rw-r--r--src/armnn/NetworkUtils.cpp50
1 files changed, 49 insertions, 1 deletions
diff --git a/src/armnn/NetworkUtils.cpp b/src/armnn/NetworkUtils.cpp
index 7597798fa4..5ff0e6c4e1 100644
--- a/src/armnn/NetworkUtils.cpp
+++ b/src/armnn/NetworkUtils.cpp
@@ -1,10 +1,12 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "NetworkUtils.hpp"
+#include <armnnUtils/FloatingPointConverter.hpp>
+#include <BFloat16.hpp>
#include "SubgraphViewSelector.hpp"
#include <armnn/Exceptions.hpp>
@@ -272,4 +274,50 @@ std::vector<DebugLayer*> InsertDebugLayerAfter(Graph& graph, Layer& layer)
return debugLayers;
}
+bool RevertConstantWeightsToFP32(Layer* layer) // returns true only if BF16 weights were reverted and input 0 rewired
+{
+ if (layer->GetType() == LayerType::Convolution2d || layer->GetType() == LayerType::FullyConnected)
+ {
+ // Revert BFloat16 weights held by the Constant layer back to FP32 so that Conv2d / FullyConnected
+ // can access them directly. This prevents a conversion layer being added during backend assignment,
+ // which would block the RedirectMembersToConstantInputs backward compatibility workaround/optimization.
+ auto constantLayerInfo = layer->GetInputSlot(1).GetConnection()->GetTensorInfo();
+ // NOTE(review): GetConnection() is dereferenced above without a null check - confirm input 1 is always connected.
+ if (constantLayerInfo.IsConstant() && constantLayerInfo.GetDataType() == DataType::BFloat16)
+ {
+ std::vector<float> newValues(constantLayerInfo.GetNumElements());
+ // The producer of input 1 is downcast to ConstantLayer; only the IsConstant() flag was checked above.
+ auto weightLayer = PolymorphicDowncast<ConstantLayer*>(
+ &layer->GetInputSlot(1).GetConnection()->GetOwningIConnectableLayer());
+ armnnUtils::FloatingPointConverter::ConvertBFloat16ToFloat32(
+ weightLayer->m_LayerOutput->GetConstTensor<BFloat16>(),
+ constantLayerInfo.GetNumElements(),
+ newValues.data());
+ // Build the FP32 replacement: same shape as the BF16 tensor, flagged constant, then swap it into the layer.
+ TensorInfo newInfo(constantLayerInfo.GetShape(), DataType::Float32);
+ newInfo.SetConstant(true);
+ ConstTensor newInput(newInfo, newValues);
+ weightLayer->m_LayerOutput.reset(new ScopedTensorHandle(newInput));
+ weightLayer->GetOutputSlot(0).SetTensorInfo(newInfo);
+ // Rewire input 0: connect Conv2d/FullyConnected straight to whatever feeds the layer currently
+ // producing input 0 (treated as the ConversionLayer), leaving that layer disconnected to be cleaned
+ // up later. NOTE(review): that layer's type is not checked here - confirm callers guarantee it.
+ auto& conversionLayer = layer->GetInputSlot(0).GetConnection()->GetOwningIConnectableLayer();
+ auto actualInputOutputSlot = conversionLayer.GetInputSlot(0).GetConnection();
+ // Same layer as conversionLayer, re-derived through the connection to fetch its slot 0 on each side.
+ auto& conversionLayerOutputSlot =
+ layer->GetInputSlot(0).GetConnection()->GetOwningIConnectableLayer().GetOutputSlot(0);
+ auto& conversionLayerInputSlot =
+ layer->GetInputSlot(0).GetConnection()->GetOwningIConnectableLayer().GetInputSlot(0);
+ actualInputOutputSlot->Disconnect(conversionLayerInputSlot);
+ conversionLayerOutputSlot.Disconnect(layer->GetInputSlot(0));
+
+ actualInputOutputSlot->Connect(layer->GetInputSlot(0));
+
+ return true;
+ }
+ }
+ return false;
+}
+
} // namespace armnn