Diffstat (limited to 'src/armnn/Network.cpp')
-rw-r--r--  src/armnn/Network.cpp | 32
1 file changed, 22 insertions(+), 10 deletions(-)
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index 573f6a19e8..1797baf78e 100644
--- a/src/armnn/Network.cpp
+++ b/src/armnn/Network.cpp
@@ -71,8 +71,6 @@ Status OptimizedNetwork::SerializeToDot(std::ostream& stream) const
     return m_Graph->SerializeToDot(stream);
 }
 
-
-
 void ReportError(const std::string& errorMessage,
                  Optional<std::vector<std::string>&> errorMessages)
 {
@@ -166,7 +164,12 @@ OptimizationResult AssignBackends(OptimizedNetwork* optNetObjPtr,
     for (auto it = firstLayer; it != lastLayer; ++it)
     {
         auto layer = *it;
-        DataType dataType = layer->GetDataType();
+
+        DataType dataTypeIn = layer->GetNumInputSlots() == 0 ? DataType::Float32 :
+            layer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo().GetDataType();
+        DataType dataTypeOut = layer->GetNumOutputSlots() == 0 ? DataType::Float32 :
+            layer->GetOutputSlot(0).GetTensorInfo().GetDataType();
+
         std::string reasonIfUnsupported;
         bool found = false;
         if (!CheckScaleSetOnQuantizedType(layer, errMessages))
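
The hunk above stops relying on a single per-layer data type and instead probes the tensors on either side of the layer. Factored out, the probing amounts to the following (a sketch only; the helper names are hypothetical, while the types and calls are the ones already used in this file):

    // Hypothetical helpers equivalent to the probing added above (sketch,
    // not part of the patch). A layer with no input slots, such as an Input
    // layer, is treated as Float32; otherwise the type comes from the output
    // slot feeding input slot 0. Output types are read from output slot 0.
    DataType GetLayerInDataType(const Layer* layer)
    {
        return layer->GetNumInputSlots() == 0 ? DataType::Float32 :
            layer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo().GetDataType();
    }

    DataType GetLayerOutDataType(const Layer* layer)
    {
        return layer->GetNumOutputSlots() == 0 ? DataType::Float32 :
            layer->GetOutputSlot(0).GetTensorInfo().GetDataType();
    }
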
@@ -181,21 +184,29 @@ OptimizationResult AssignBackends(OptimizedNetwork* optNetObjPtr,
             // need to set the compute device on the layer
             // before we can check if it is supported
             layer->SetBackendId(backend);
-            if (!IWorkloadFactory::IsLayerSupported(*layer, dataType, reasonIfUnsupported))
+            if (!IWorkloadFactory::IsLayerSupported(*layer, EmptyOptional(), reasonIfUnsupported))
             {
-                if (dataType == DataType::Float16)
+                if (dataTypeIn == DataType::Float16 || dataTypeOut == DataType::Float16)
                 {
                     if (IWorkloadFactory::IsLayerSupported(*layer, DataType::Float32, reasonIfUnsupported)
                         && layer->GetType() != LayerType::ConvertFp32ToFp16
                         && layer->GetType() != LayerType::ConvertFp16ToFp32)
                     {
                         // Insert FP16 -> FP32 conversion layer before current layer
-                        std::vector<ConvertFp16ToFp32Layer*> convertFp16ToFp32Layers =
-                            InsertConvertFp16ToFp32LayersBefore(optNetObjPtr->GetGraph(), *layer);
+                        std::vector<ConvertFp16ToFp32Layer*> convertFp16ToFp32Layers;
+                        if (dataTypeIn == DataType::Float16)
+                        {
+                            convertFp16ToFp32Layers =
+                                InsertConvertFp16ToFp32LayersBefore(optNetObjPtr->GetGraph(), *layer);
+                        }
 
                         // Insert FP32 -> FP16 conversion layer after current layer
-                        std::vector<ConvertFp32ToFp16Layer*> convertFp32ToFp16Layers =
-                            InsertConvertFp32ToFp16LayersAfter(optNetObjPtr->GetGraph(), *layer);
+                        std::vector<ConvertFp32ToFp16Layer*> convertFp32ToFp16Layers;
+                        if (dataTypeOut == DataType::Float16)
+                        {
+                            convertFp32ToFp16Layers =
+                                InsertConvertFp32ToFp16LayersAfter(optNetObjPtr->GetGraph(), *layer);
+                        }
 
                         // Assign a supported backend to the newly introduced conversion layers
                         auto AssignFirstSupportedBackend = [&](Layer* layer, BackendId preferredBackend)
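
With the guards above, conversion layers are inserted only on the sides of the layer where FP16 tensors actually flow, rather than unconditionally on both. A self-contained sketch of just that decision (plain C++, not ArmNN code; names are illustrative):

    #include <iostream>

    enum class DataType { Float16, Float32 };

    // Sketch of the FP32-fallback decision added above: a layer that fails
    // the FP16 support check is wrapped in conversions only where needed.
    void SketchFallback(DataType dataTypeIn, DataType dataTypeOut)
    {
        if (dataTypeIn == DataType::Float16)
        {
            std::cout << "insert ConvertFp16ToFp32 before the layer\n";
        }
        if (dataTypeOut == DataType::Float16)
        {
            std::cout << "insert ConvertFp32ToFp16 after the layer\n";
        }
    }

    int main()
    {
        // fp16 in, fp32 out: only the leading conversion is inserted
        SketchFallback(DataType::Float16, DataType::Float32);
    }
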
@@ -258,7 +269,8 @@ OptimizationResult AssignBackends(OptimizedNetwork* optNetObjPtr,
                 std::stringstream warningMsg;
                 warningMsg << "Layer of type " << GetLayerTypeAsCString(layer->GetType())
                            << " is not supported on requested backend " << layer->GetBackendId().Get()
-                           << " for data type " << GetDataTypeName(dataType)
+                           << " for input data type " << GetDataTypeName(dataTypeIn)
+                           << " and output data type " << GetDataTypeName(dataTypeOut)
                            << " (reason: " << reasonIfUnsupported
                            << "), falling back to the next backend.";
                 ReportWarning(warningMsg.str(), errMessages);
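
With this final hunk, the fallback warning names both tensor types. An illustrative message (layer, backend, and reason made up for the example) would read:

    Layer of type Convolution2d is not supported on requested backend GpuAcc for input data type Float16 and output data type Float16 (reason: ...), falling back to the next backend.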