diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/armnn/Network.cpp | 44 |
1 file changed, 43 insertions, 1 deletion
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp index 42d7ae33ac..db7b4c9bb3 100644 --- a/src/armnn/Network.cpp +++ b/src/armnn/Network.cpp @@ -709,12 +709,54 @@ OptimizationResult AttemptBackendAssignment(BackendSettings& backendSettings, && layer->GetType() != LayerType::ConvertFp32ToFp16 && layer->GetType() != LayerType::ConvertFp16ToFp32) { + auto ConstantLayerFromFp16ToFp32 = [](Layer& layer) + { + if (layer.GetType() == LayerType::Constant) + { + ConstantLayer* constantLayer = PolymorphicDowncast<ConstantLayer*>(&layer); + + auto& info = constantLayer->m_LayerOutput->GetTensorInfo(); + + if (info.GetDataType() == DataType::Float16) + { + std::vector<float> newValues(info.GetNumElements()); + + armnnUtils::FloatingPointConverter::ConvertFloat16To32( + constantLayer->m_LayerOutput->GetConstTensor<Half>(), + info.GetNumElements(), + newValues.data()); + + TensorInfo newInfo(info); + newInfo.SetDataType(DataType::Float32); + ConstTensor newInput(newInfo, newValues); + constantLayer->m_LayerOutput.reset(new ScopedTensorHandle(newInput)); + + layer.GetOutputSlot(0).SetTensorInfo(newInfo); + } + } + }; + + bool checkType = false; + + for (auto inputSlot : layer->GetInputSlots()) + { + auto connectedOutputSlot = inputSlot.GetConnectedOutputSlot(); + if (connectedOutputSlot->GetOwningLayer().GetType() == LayerType::Constant) + { + if (connectedOutputSlot->GetNumConnections() == 1) + { + checkType = true; + ConstantLayerFromFp16ToFp32(connectedOutputSlot->GetOwningLayer()); + } + } + } + // Insert FP16 -> FP32 conversion layer before current layer std::vector<ConvertFp16ToFp32Layer*> convertFp16ToFp32Layers; if (dataTypeIn == DataType::Float16) { convertFp16ToFp32Layers = - InsertConvertFp16ToFp32LayersBefore(graph, *layer); + InsertConvertFp16ToFp32LayersBefore(graph, *layer, checkType); } // Insert FP32 -> FP16 conversion layer after current layer |