diff options
author | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2019-11-13 15:16:28 +0000 |
---|---|---|
committer | Áron Virginás-Tar <aron.virginas-tar@arm.com> | 2019-11-15 16:54:47 +0000 |
commit | 87972be8d838f6fde6f6e98dd81c422e85457a5e (patch) | |
tree | 78e8a9abfefc6db67f9a71f6c1fddb0444daac5f /src/armnn/Network.cpp | |
parent | 5716de25c6981d004e32b81dc65b4869eda25f7c (diff) | |
download | armnn-87972be8d838f6fde6f6e98dd81c422e85457a5e.tar.gz |
IVGCVSW-4119 Fix FP16 to FP32 fallback mechanism in optimizer to work with Dequantize
* Check for output data type as well as input data type when determining
whether we should attempt to fall back to FP32 if FP16 is not supported
* Override output type for Dequantize in IsLayerSupported() instead of
input type
* Updated original input type from FP16 to FP32 in InsertConvertFp32ToFp16LayersAfter()
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Change-Id: Ic6477fd17cea5a91bd8bf9ae0cf836520897d5b7
Diffstat (limited to 'src/armnn/Network.cpp')
-rw-r--r-- | src/armnn/Network.cpp | 32 |
1 files changed, 22 insertions, 10 deletions
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp index 573f6a19e8..1797baf78e 100644 --- a/src/armnn/Network.cpp +++ b/src/armnn/Network.cpp @@ -71,8 +71,6 @@ Status OptimizedNetwork::SerializeToDot(std::ostream& stream) const return m_Graph->SerializeToDot(stream); } - - void ReportError(const std::string& errorMessage, Optional<std::vector<std::string>&> errorMessages) { @@ -166,7 +164,12 @@ OptimizationResult AssignBackends(OptimizedNetwork* optNetObjPtr, for (auto it = firstLayer; it != lastLayer; ++it) { auto layer = *it; - DataType dataType = layer->GetDataType(); + + DataType dataTypeIn = layer->GetNumInputSlots() == 0 ? DataType::Float32 : + layer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo().GetDataType(); + DataType dataTypeOut = layer->GetNumOutputSlots() == 0 ? DataType::Float32 : + layer->GetOutputSlot(0).GetTensorInfo().GetDataType(); + std::string reasonIfUnsupported; bool found = false; if (!CheckScaleSetOnQuantizedType(layer, errMessages)) @@ -181,21 +184,29 @@ OptimizationResult AssignBackends(OptimizedNetwork* optNetObjPtr, // need to set the compute device on the layer // before we can check if it is supported layer->SetBackendId(backend); - if (!IWorkloadFactory::IsLayerSupported(*layer, dataType, reasonIfUnsupported)) + if (!IWorkloadFactory::IsLayerSupported(*layer, EmptyOptional(), reasonIfUnsupported)) { - if (dataType == DataType::Float16) + if (dataTypeIn == DataType::Float16 || dataTypeOut == DataType::Float16) { if (IWorkloadFactory::IsLayerSupported(*layer, DataType::Float32, reasonIfUnsupported) && layer->GetType() != LayerType::ConvertFp32ToFp16 && layer->GetType() != LayerType::ConvertFp16ToFp32) { // Insert FP16 -> FP32 conversion layer before current layer - std::vector<ConvertFp16ToFp32Layer*> convertFp16ToFp32Layers = - InsertConvertFp16ToFp32LayersBefore(optNetObjPtr->GetGraph(), *layer); + std::vector<ConvertFp16ToFp32Layer*> convertFp16ToFp32Layers; + if (dataTypeIn == DataType::Float16) + { + convertFp16ToFp32Layers = + InsertConvertFp16ToFp32LayersBefore(optNetObjPtr->GetGraph(), *layer); + } // Insert FP32 -> FP16 conversion layer after current layer - std::vector<ConvertFp32ToFp16Layer*> convertFp32ToFp16Layers = - InsertConvertFp32ToFp16LayersAfter(optNetObjPtr->GetGraph(), *layer); + std::vector<ConvertFp32ToFp16Layer*> convertFp32ToFp16Layers; + if (dataTypeOut == DataType::Float16) + { + convertFp32ToFp16Layers = + InsertConvertFp32ToFp16LayersAfter(optNetObjPtr->GetGraph(), *layer); + } // Assign a supported backend to the newly introduced conversion layers auto AssignFirstSupportedBackend = [&](Layer* layer, BackendId preferredBackend) @@ -258,7 +269,8 @@ OptimizationResult AssignBackends(OptimizedNetwork* optNetObjPtr, std::stringstream warningMsg; warningMsg << "Layer of type " << GetLayerTypeAsCString(layer->GetType()) << " is not supported on requested backend " << layer->GetBackendId().Get() - << " for data type " << GetDataTypeName(dataType) + << " for input data type " << GetDataTypeName(dataTypeIn) + << " and output data type " << GetDataTypeName(dataTypeOut) << " (reason: " << reasonIfUnsupported << "), falling back to the next backend."; ReportWarning(warningMsg.str(), errMessages); |