-rw-r--r--    src/armnn/Network.cpp    44
1 file changed, 43 insertions(+), 1 deletion(-)
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index 42d7ae33ac..db7b4c9bb3 100644
--- a/src/armnn/Network.cpp
+++ b/src/armnn/Network.cpp
@@ -709,12 +709,54 @@ OptimizationResult AttemptBackendAssignment(BackendSettings& backendSettings,
                 && layer->GetType() != LayerType::ConvertFp32ToFp16
                 && layer->GetType() != LayerType::ConvertFp16ToFp32)
             {
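+                // Converts a Constant layer's FP16 weights to FP32 in place and
+                // updates its output slot's TensorInfo to match.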
+                auto ConstantLayerFromFp16ToFp32 = [](Layer& layer)
+                {
+                    if (layer.GetType() == LayerType::Constant)
+                    {
+                        ConstantLayer* constantLayer = PolymorphicDowncast<ConstantLayer*>(&layer);
+
+                        auto& info = constantLayer->m_LayerOutput->GetTensorInfo();
+
+                        if (info.GetDataType() == DataType::Float16)
+                        {
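+                            // Widen the raw FP16 payload into a temporary FP32 buffer.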
+                            std::vector<float> newValues(info.GetNumElements());
+
+                            armnnUtils::FloatingPointConverter::ConvertFloat16To32(
+                                constantLayer->m_LayerOutput->GetConstTensor<Half>(),
+                                info.GetNumElements(),
+                                newValues.data());
+
+                            TensorInfo newInfo(info);
+                            newInfo.SetDataType(DataType::Float32);
+                            ConstTensor newInput(newInfo, newValues);
+                            constantLayer->m_LayerOutput.reset(new ScopedTensorHandle(newInput));
+
+                            layer.GetOutputSlot(0).SetTensorInfo(newInfo);
+                        }
+                    }
+                };
+
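+                // Convert Constant inputs in place when this layer is their only
+                // consumer. checkType then tells the insertion step below to check
+                // input data types, so already-converted inputs are skipped.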
+                bool checkType = false;
+
+                for (auto inputSlot : layer->GetInputSlots())
+                {
+                    auto connectedOutputSlot = inputSlot.GetConnectedOutputSlot();
+                    if (connectedOutputSlot->GetOwningLayer().GetType() == LayerType::Constant)
+                    {
+                        if (connectedOutputSlot->GetNumConnections() == 1)
+                        {
+                            checkType = true;
+                            ConstantLayerFromFp16ToFp32(connectedOutputSlot->GetOwningLayer());
+                        }
+                    }
+                }
+
                 // Insert FP16 -> FP32 conversion layer before current layer
                 std::vector<ConvertFp16ToFp32Layer*> convertFp16ToFp32Layers;
                 if (dataTypeIn == DataType::Float16)
                 {
                     convertFp16ToFp32Layers =
-                        InsertConvertFp16ToFp32LayersBefore(graph, *layer);
+                        InsertConvertFp16ToFp32LayersBefore(graph, *layer, checkType);
                 }
                 // Insert FP32 -> FP16 conversion layer after current layer