aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp33
-rw-r--r--src/backends/reference/RefLayerSupport.cpp3
2 files changed, 23 insertions, 13 deletions
diff --git a/src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp b/src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp
index 9658a35560..6aa618f7b4 100644
--- a/src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp
+++ b/src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp
@@ -28,13 +28,18 @@ public:
}
else if (layer.GetType() == LayerType::Output)
{
- // if the inputs of this layer are DataType::Float32
- // add a ConvertFloat16ToFloat32 layer before each of the inputs
- if (layer.GetDataType() == DataType::Float32)
+ // For DetectionPostProcess Layer output is always Float32 regardless of input type
+ Layer& connectedLayer = layer.GetInputSlots()[0].GetConnectedOutputSlot()->GetOwningLayer();
+ if (connectedLayer.GetType() != LayerType::DetectionPostProcess)
{
- // NOTE: We need to call InsertConvertFp16ToFp32LayersBefore with expectCorrectInputType = false
- // here, otherwise it will expect the inputs to be DataType::Float16
- InsertConvertFp16ToFp32LayersBefore(graph, layer, false);
+ // if the inputs of this layer are DataType::Float32
+ // add a ConvertFloat16ToFloat32 layer before each of the inputs
+ if (layer.GetDataType() == DataType::Float32)
+ {
+ // NOTE: We need to call InsertConvertFp16ToFp32LayersBefore with expectCorrectInputType = false
+ // here, otherwise it will expect the inputs to be DataType::Float16
+ InsertConvertFp16ToFp32LayersBefore(graph, layer, false);
+ }
}
}
else if (layer.GetType() != LayerType::ConvertFp32ToFp16 && layer.GetType() != LayerType::ConvertFp16ToFp32)
@@ -57,14 +62,18 @@ public:
}
}
- // change outputs to DataType::Float16
- for (auto&& output = layer.BeginOutputSlots(); output != layer.EndOutputSlots(); ++output)
+ // For DetectionPostProcess Layer output is always Float32 regardless of input type
+ if (layer.GetType() != LayerType::DetectionPostProcess)
{
- TensorInfo convertInfo = output->GetTensorInfo();
- if (convertInfo.GetDataType() == DataType::Float32)
+ // change outputs to DataType::Float16
+ for (auto&& output = layer.BeginOutputSlots(); output != layer.EndOutputSlots(); ++output)
{
- convertInfo.SetDataType(DataType::Float16);
- output->SetTensorInfo(convertInfo);
+ TensorInfo convertInfo = output->GetTensorInfo();
+ if (convertInfo.GetDataType() == DataType::Float32)
+ {
+ convertInfo.SetDataType(DataType::Float16);
+ output->SetTensorInfo(convertInfo);
+ }
}
}
}
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index f48c120203..b3feae6713 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -722,10 +722,11 @@ bool RefLayerSupport::IsDetectionPostProcessSupported(const TensorInfo& boxEncod
bool supported = true;
- std::array<DataType,5> supportedInputTypes =
+ std::array<DataType,6> supportedInputTypes =
{
DataType::BFloat16,
DataType::Float32,
+ DataType::Float16,
DataType::QAsymmS8,
DataType::QAsymmU8,
DataType::QSymmS16