about summary refs log tree commit diff
path: root/src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp')
-rw-r--r-- src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp 5
1 file changed, 3 insertions, 2 deletions
diff --git a/src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp b/src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp
index 729b76ad6b..9658a35560 100644
--- a/src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp
+++ b/src/armnn/optimizations/ConvertFp32NetworkToFp16.hpp
@@ -15,7 +15,6 @@ namespace optimizations
class ConvertFp32NetworkToFp16Impl
{
public:
-
void Run(Graph& graph, Layer& layer) const
{
if(layer.GetType() == LayerType::Input)
@@ -33,7 +32,9 @@ public:
// add a ConvertFloat16ToFloat32 layer before each of the inputs
if (layer.GetDataType() == DataType::Float32)
{
- InsertConvertFp16ToFp32LayersBefore(graph, layer);
+ // NOTE: We need to call InsertConvertFp16ToFp32LayersBefore with expectCorrectInputType = false
+ // here, otherwise it will expect the inputs to be DataType::Float16
+ InsertConvertFp16ToFp32LayersBefore(graph, layer, false);
}
}
else if (layer.GetType() != LayerType::ConvertFp32ToFp16 && layer.GetType() != LayerType::ConvertFp16ToFp32)