diff options
Diffstat (limited to 'src/armnn/NetworkUtils.cpp')
-rw-r--r-- | src/armnn/NetworkUtils.cpp | 39 |
1 file changed, 39 insertions, 0 deletions
diff --git a/src/armnn/NetworkUtils.cpp b/src/armnn/NetworkUtils.cpp index 8653a08510..0549a115d4 100644 --- a/src/armnn/NetworkUtils.cpp +++ b/src/armnn/NetworkUtils.cpp @@ -87,6 +87,45 @@ std::vector<ConvertBf16ToFp32Layer*> InsertConvertBf16ToFp32LayersBefore(Graph& return convertLayers; } +std::vector<ConvertFp32ToBf16Layer*> InsertConvertFp32ToBf16LayersBefore(Graph& graph, + Layer& layer, + bool expectCorrectInputType) +{ + std::vector<ConvertFp32ToBf16Layer*> convertLayers; + convertLayers.reserve(layer.GetNumInputSlots()); + + // Insert a ConvertFp32ToBf16Layer before each input slot + for (auto&& inputSlot = layer.BeginInputSlots(); inputSlot != layer.EndInputSlots(); ++inputSlot) + { + bool allowInsert = true; + if (expectCorrectInputType) + { + // Only insert ConvertFp32ToBf16Layer before FP32 input slots + OutputSlot* connectedOutputSlot = inputSlot->GetConnectedOutputSlot(); + allowInsert = + connectedOutputSlot && connectedOutputSlot->GetTensorInfo().GetDataType() == DataType::Float32; + } + + if (allowInsert) + { + const std::string name = + std::string("convert_fp32_to_bf16-" + std::to_string(inputSlot->GetSlotIndex()) + "-") + + layer.GetName(); + ConvertFp32ToBf16Layer* convertLayer = + graph.InsertNewLayer<ConvertFp32ToBf16Layer>(*inputSlot, name.c_str()); + + TensorInfo convertInfo = convertLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(); + convertInfo.SetDataType(DataType::BFloat16); + + convertLayer->GetOutputSlot().SetTensorInfo(convertInfo); + + convertLayers.emplace_back(convertLayer); + } + } + + return convertLayers; +} + std::vector<ConvertFp16ToFp32Layer*> InsertConvertFp16ToFp32LayersBefore(Graph& graph, Layer& layer, bool expectCorrectInputType) |