diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp | 33
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 3
2 files changed, 23 insertions, 13 deletions
diff --git a/src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp b/src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp
index 9658a35560..6aa618f7b4 100644
--- a/src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp
+++ b/src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp
@@ -28,13 +28,18 @@ public:
         }
         else if (layer.GetType() == LayerType::Output)
         {
-            // if the inputs of this layer are DataType::Float32
-            // add a ConvertFloat16ToFloat32 layer before each of the inputs
-            if (layer.GetDataType() == DataType::Float32)
+            // For DetectionPostProcess Layer output is always Float32 regardless of input type
+            Layer& connectedLayer = layer.GetInputSlots()[0].GetConnectedOutputSlot()->GetOwningLayer();
+            if (connectedLayer.GetType() != LayerType::DetectionPostProcess)
             {
-                // NOTE: We need to call InsertConvertFp16ToFp32LayersBefore with expectCorrectInputType = false
-                // here, otherwise it will expect the inputs to be DataType::Float16
-                InsertConvertFp16ToFp32LayersBefore(graph, layer, false);
+                // if the inputs of this layer are DataType::Float32
+                // add a ConvertFloat16ToFloat32 layer before each of the inputs
+                if (layer.GetDataType() == DataType::Float32)
+                {
+                    // NOTE: We need to call InsertConvertFp16ToFp32LayersBefore with expectCorrectInputType = false
+                    // here, otherwise it will expect the inputs to be DataType::Float16
+                    InsertConvertFp16ToFp32LayersBefore(graph, layer, false);
+                }
             }
         }
         else if (layer.GetType() != LayerType::ConvertFp32ToFp16 && layer.GetType() != LayerType::ConvertFp16ToFp32)
@@ -57,14 +62,18 @@ public:
             }
         }

-        // change outputs to DataType::Float16
-        for (auto&& output = layer.BeginOutputSlots(); output != layer.EndOutputSlots(); ++output)
+        // For DetectionPostProcess Layer output is always Float32 regardless of input type
+        if (layer.GetType() != LayerType::DetectionPostProcess)
         {
-            TensorInfo convertInfo = output->GetTensorInfo();
-            if (convertInfo.GetDataType() == DataType::Float32)
+            // change outputs to DataType::Float16
+            for (auto&& output = layer.BeginOutputSlots(); output != layer.EndOutputSlots(); ++output)
             {
-                convertInfo.SetDataType(DataType::Float16);
-                output->SetTensorInfo(convertInfo);
+                TensorInfo convertInfo = output->GetTensorInfo();
+                if (convertInfo.GetDataType() == DataType::Float32)
+                {
+                    convertInfo.SetDataType(DataType::Float16);
+                    output->SetTensorInfo(convertInfo);
+                }
             }
         }
     }
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index f48c120203..b3feae6713 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -722,10 +722,11 @@ bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncod
     bool supported = true;

-    std::array<DataType,5> supportedInputTypes =
+    std::array<DataType,6> supportedInputTypes =
     {
         DataType::BFloat16,
         DataType::Float32,
+        DataType::Float16,
         DataType::QAsymmS8,
         DataType::QAsymmU8,
         DataType::QSymmS16