aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp
diff options
context:
space:
mode:
authorAron Virginas-Tar <Aron.Virginas-Tar@arm.com>2019-11-13 15:16:28 +0000
committerÁron Virginás-Tar <aron.virginas-tar@arm.com>2019-11-15 16:54:47 +0000
commit87972be8d838f6fde6f6e98dd81c422e85457a5e (patch)
tree78e8a9abfefc6db67f9a71f6c1fddb0444daac5f /src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp
parent5716de25c6981d004e32b81dc65b4869eda25f7c (diff)
downloadarmnn-87972be8d838f6fde6f6e98dd81c422e85457a5e.tar.gz
IVGCVSW-4119 Fix FP16 to FP32 fallback mechanism in optimizer to work with Dequantize
* Check for output data type as well as input data type when determining whether we should attempt to fall back to FP32 if FP16 is not supported * Override output type for Dequantize in IsLayerSupported() instead of input type * Updated original input type from FP16 to FP32 in InsertConvertFp32ToFp16LayersAfter() Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> Change-Id: Ic6477fd17cea5a91bd8bf9ae0cf836520897d5b7
Diffstat (limited to 'src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp')
-rw-r--r--src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp5
1 files changed, 3 insertions, 2 deletions
diff --git a/src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp b/src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp
index 729b76ad6b..9658a35560 100644
--- a/src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp
+++ b/src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp
@@ -15,7 +15,6 @@ namespace optimizations
class ConvertFp32NetworkToFp16Impl
{
public:
-
void Run(Graph& graph, Layer& layer) const
{
if(layer.GetType() == LayerType::Input)
@@ -33,7 +32,9 @@ public:
// add a ConvertFloat16ToFloat32 layer before each of the inputs
if (layer.GetDataType() == DataType::Float32)
{
- InsertConvertFp16ToFp32LayersBefore(graph, layer);
+ // NOTE: We need to call InsertConvertFp16ToFp32LayersBefore with expectCorrectInputType = false
+ // here, otherwise it will expect the inputs to be DataType::Float16
+ InsertConvertFp16ToFp32LayersBefore(graph, layer, false);
}
}
else if (layer.GetType() != LayerType::ConvertFp32ToFp16 && layer.GetType() != LayerType::ConvertFp16ToFp32)