aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNarumol Prangnawarat <narumol.prangnawarat@arm.com>2020-03-26 09:20:43 +0000
committerNarumol Prangnawarat <narumol.prangnawarat@arm.com>2020-03-26 16:16:55 +0000
commit57ef0088d20dd708ff92222d244ea02f1e1e5216 (patch)
treeae11f55f6bac939a51d5182eae441d322efb3e0e
parent9272f8b9050096f39796227c5d89ed7b9905146d (diff)
downloadarmnn-57ef0088d20dd708ff92222d244ea02f1e1e5216.tar.gz
IVGCVSW-4597 Modify BF16 optimizer to Convert only inputs and weights of
Convolution2d and FullyConnected layers * Add InsertConvertFp32ToBf16LayersBefore * Add ConvertWeight to ConvertFp32NetworkToBf16Impl for Conv2d and FullyConnected * Allow different input and output when input is BF16 and output is FP32 Conv2d and FullyConnected layers * Unit tests Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com> Change-Id: Ic8f92ff28edcae08a72a3114a28f50c4619f919b
-rw-r--r--src/armnn/Network.cpp3
-rw-r--r--src/armnn/NetworkUtils.cpp39
-rw-r--r--src/armnn/NetworkUtils.hpp4
-rw-r--r--src/armnn/optimizations/ConvertFp32NetworkToBf16.hpp78
-rw-r--r--src/armnn/test/optimizations/Fp32NetworkToBf16ConverterTests.cpp148
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp33
-rw-r--r--src/backends/reference/RefLayerSupport.cpp32
7 files changed, 282 insertions, 55 deletions
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index 5f7719730b..0272b3da65 100644
--- a/src/armnn/Network.cpp
+++ b/src/armnn/Network.cpp
@@ -1020,10 +1020,11 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork,
}
// If Fp32 to Bf16 optimization is set convert Fp32 network to Bf16
+ // Convert input of Convolution2d and FullyConnected from Fp32 to Bf16
+ // Only Constant weight of Convolution2d and FullyConnected are converted from Fp32 to Bf16
if (options.m_ReduceFp32ToBf16)
{
Optimizer::Pass(optGraph, MakeOptimizations(Fp32NetworkToBf16Converter()));
- Optimizer::Pass(optGraph, MakeOptimizations(ConvertConstantsFloatToBFloat()));
}
// Initialize backend settings
diff --git a/src/armnn/NetworkUtils.cpp b/src/armnn/NetworkUtils.cpp
index 8653a08510..0549a115d4 100644
--- a/src/armnn/NetworkUtils.cpp
+++ b/src/armnn/NetworkUtils.cpp
@@ -87,6 +87,45 @@ std::vector<ConvertBf16ToFp32Layer*> InsertConvertBf16ToFp32LayersBefore(Graph&
return convertLayers;
}
+std::vector<ConvertFp32ToBf16Layer*> InsertConvertFp32ToBf16LayersBefore(Graph& graph,
+ Layer& layer,
+ bool expectCorrectInputType)
+{
+ std::vector<ConvertFp32ToBf16Layer*> convertLayers;
+ convertLayers.reserve(layer.GetNumInputSlots());
+
+ // Insert a ConvertFp32ToBf16Layer before each input slot
+ for (auto&& inputSlot = layer.BeginInputSlots(); inputSlot != layer.EndInputSlots(); ++inputSlot)
+ {
+ bool allowInsert = true;
+ if (expectCorrectInputType)
+ {
+ // Only insert ConvertFp32ToBf16Layer before FP32 input slots
+ OutputSlot* connectedOutputSlot = inputSlot->GetConnectedOutputSlot();
+ allowInsert =
+ connectedOutputSlot && connectedOutputSlot->GetTensorInfo().GetDataType() == DataType::Float32;
+ }
+
+ if (allowInsert)
+ {
+ const std::string name =
+ std::string("convert_fp32_to_bf16-" + std::to_string(inputSlot->GetSlotIndex()) + "-") +
+ layer.GetName();
+ ConvertFp32ToBf16Layer* convertLayer =
+ graph.InsertNewLayer<ConvertFp32ToBf16Layer>(*inputSlot, name.c_str());
+
+ TensorInfo convertInfo = convertLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo();
+ convertInfo.SetDataType(DataType::BFloat16);
+
+ convertLayer->GetOutputSlot().SetTensorInfo(convertInfo);
+
+ convertLayers.emplace_back(convertLayer);
+ }
+ }
+
+ return convertLayers;
+}
+
std::vector<ConvertFp16ToFp32Layer*> InsertConvertFp16ToFp32LayersBefore(Graph& graph,
Layer& layer,
bool expectCorrectInputType)
diff --git a/src/armnn/NetworkUtils.hpp b/src/armnn/NetworkUtils.hpp
index 064545aac5..a922770285 100644
--- a/src/armnn/NetworkUtils.hpp
+++ b/src/armnn/NetworkUtils.hpp
@@ -15,6 +15,10 @@ std::vector<ConvertBf16ToFp32Layer*> InsertConvertBf16ToFp32LayersBefore(Graph&
Layer& layer,
bool expectCorrectInputType = true);
+std::vector<ConvertFp32ToBf16Layer*> InsertConvertFp32ToBf16LayersBefore(Graph& graph,
+ Layer& layer,
+ bool expectCorrectInputType = true);
+
std::vector<ConvertFp32ToBf16Layer*> InsertConvertFp32ToBf16LayersAfter(Graph& graph, Layer& layer);
std::vector<ConvertFp16ToFp32Layer*> InsertConvertFp16ToFp32LayersBefore(Graph& graph,
diff --git a/src/armnn/optimizations/ConvertFp32NetworkToBf16.hpp b/src/armnn/optimizations/ConvertFp32NetworkToBf16.hpp
index d6350c3af3..222414c8c5 100644
--- a/src/armnn/optimizations/ConvertFp32NetworkToBf16.hpp
+++ b/src/armnn/optimizations/ConvertFp32NetworkToBf16.hpp
@@ -4,68 +4,62 @@
//
#pragma once
-#include "Optimization.hpp"
#include "NetworkUtils.hpp"
+#include "Optimization.hpp"
namespace armnn
{
namespace optimizations
{
+template <typename LayerT>
+inline LayerT* ConvertWeight(Layer* l)
+{
+ LayerT* layer = boost::polymorphic_downcast<LayerT*>(l);
+ if ((layer->GetType() == LayerType::Convolution2d || layer->GetType() == LayerType::FullyConnected)
+ && layer->m_Weight)
+ {
+ const TensorInfo& info = layer->m_Weight->GetTensorInfo();
+
+ if (info.GetDataType() == DataType::Float32)
+ {
+ std::vector<BFloat16> newValues(info.GetNumElements());
+
+ armnnUtils::FloatingPointConverter::ConvertFloat32ToBFloat16(layer->m_Weight->template GetTensor<float>(),
+ info.GetNumElements(),
+ newValues.data());
+
+ TensorInfo newInfo(info.GetShape(), DataType::BFloat16);
+ ConstTensor newInput(newInfo, newValues);
+ layer->m_Weight.reset(new ScopedCpuTensorHandle(newInput));
+ }
+ }
+ return layer;
+}
+
class ConvertFp32NetworkToBf16Impl
{
public:
+
void Run(Graph& graph, Layer& layer) const
{
- if(layer.GetType() == LayerType::Input)
+ // Only convert Float32 To BFloat16 for the Input of Convolution2d layer and FullyConnected layer.
+ // And also convert weight data type from Float32 to Bfloat16.
+ // Do not convert bias data type.
+ if (layer.GetType() == LayerType::Convolution2d)
{
- // if the outputs of this layer are DataType::Float32
- // add a ConvertFloat32ToBFloat16 layer after each of the outputs
if (layer.GetDataType() == DataType::Float32)
{
- InsertConvertFp32ToBf16LayersAfter(graph, layer);
+ InsertConvertFp32ToBf16LayersBefore(graph,layer);
+ ConvertWeight<Convolution2dLayer>(&layer);
}
}
- else if (layer.GetType() == LayerType::Output)
+ else if (layer.GetType() == LayerType::FullyConnected)
{
- // if the inputs of this layer are DataType::Float32
- // add a ConvertBFloat16ToFloat32 layer before each of the inputs
if (layer.GetDataType() == DataType::Float32)
{
- // NOTE: We need to call InsertConvertBf16ToFp32LayersBefore with expectCorrectInputType = false
- // here, otherwise it will expect the inputs to be DataType::BFloat16
- InsertConvertBf16ToFp32LayersBefore(graph, layer, false);
- }
- }
- else if (layer.GetType() != LayerType::ConvertFp32ToBf16 && layer.GetType() != LayerType::ConvertBf16ToFp32)
- {
- // if the inputs/outputs of this layer are DataType::Float32
- // change the data type for all inputs and outputs to DataType::BFloat16
- for (auto&& input = layer.BeginInputSlots(); input != layer.EndInputSlots(); ++input)
- {
- // if it is connected to OutputSlot of the InputLayer do not change the DataType of connection
- // InputSlots of the current layer will be updated when conversion layer is inserted after InputLayer
- Layer& base = input->GetConnectedOutputSlot()->GetOwningLayer();
- if (base.GetType() != LayerType::Input)
- {
- TensorInfo convertInfo = input->GetConnection()->GetTensorInfo();
- if (convertInfo.GetDataType() == DataType::Float32)
- {
- convertInfo.SetDataType(DataType::BFloat16);
- input->GetConnection()->SetTensorInfo(convertInfo);
- }
- }
- }
-
- // change outputs to DataType::BFloat16
- for (auto&& output = layer.BeginOutputSlots(); output != layer.EndOutputSlots(); ++output)
- {
- TensorInfo convertInfo = output->GetTensorInfo();
- if (convertInfo.GetDataType() == DataType::Float32)
- {
- convertInfo.SetDataType(DataType::BFloat16);
- output->SetTensorInfo(convertInfo);
- }
+ InsertConvertFp32ToBf16LayersBefore(graph,layer);
+ ConvertWeight<FullyConnectedLayer>(&layer);
}
}
}
diff --git a/src/armnn/test/optimizations/Fp32NetworkToBf16ConverterTests.cpp b/src/armnn/test/optimizations/Fp32NetworkToBf16ConverterTests.cpp
index 90a15487ac..b35f983434 100644
--- a/src/armnn/test/optimizations/Fp32NetworkToBf16ConverterTests.cpp
+++ b/src/armnn/test/optimizations/Fp32NetworkToBf16ConverterTests.cpp
@@ -12,13 +12,13 @@
BOOST_AUTO_TEST_SUITE(Optimizer)
using namespace armnn::optimizations;
-BOOST_AUTO_TEST_CASE(Fp32NetworkToBf16OptimizationTest)
+BOOST_AUTO_TEST_CASE(Fp32NetworkToBf16OptimizationNoConversionTest)
{
armnn::Graph graph;
const armnn::TensorInfo infoFP32({ 2, 2, 1, 3 }, armnn::DataType::Float32);
- // Create the simple test network
+ // Create the simple test network without Conv2D/FullyConnected.
auto input = graph.AddLayer<armnn::InputLayer>(0, "input");
input->GetOutputSlot().SetTensorInfo(infoFP32);
@@ -38,8 +38,148 @@ BOOST_AUTO_TEST_CASE(Fp32NetworkToBf16OptimizationTest)
armnn::Optimizer::Pass(graph, armnn::MakeOptimizations(Fp32NetworkToBf16Converter()));
BOOST_TEST(CheckSequence(graph.cbegin(), graph.cend(), &IsLayerOfType<armnn::InputLayer>,
- &IsLayerOfType<armnn::ConvertFp32ToBf16Layer>, &IsLayerOfType<armnn::FloorLayer>,
- &IsLayerOfType<armnn::ConvertBf16ToFp32Layer>, &IsLayerOfType<armnn::OutputLayer>));
+ &IsLayerOfType<armnn::FloorLayer>,
+ &IsLayerOfType<armnn::OutputLayer>));
+}
+
+BOOST_AUTO_TEST_CASE(Fp32NetworkToBf16OptimizationConv2DTest)
+{
+ armnn::Graph graph;
+
+ const armnn::TensorInfo infoFP32({ 2, 3, 8, 1 }, armnn::DataType::Float32);
+
+ // Create const tensor fp32 data
+ unsigned int dims[] = { 4, 2, 1, 1 };
+ std::vector<float> floatWeights{ 0.0f, -1.0f,
+ 3.8f, // 0x40733333 Round down
+ 3.1055E+29f, // 0x707ADC3C Round up
+ 9.149516E-10f, // 0x307B7FFF Round down
+ -3.8f, // 0xC0733333 Round down
+ -3.1055E+29f, // 0xF07ADC3C Round up
+ -9.149516E-10f // 0xB07B7FFF Round down
+ };
+ armnn::ConstTensor weights(armnn::TensorInfo(4, dims, armnn::DataType::Float32), floatWeights);
+
+ // Create const bias fp32 data
+ unsigned int biasDims[] {4};
+ std::vector<float> floatBias{ 1.0f, 2.0f, 3.0f, 4.0f };
+ armnn::ConstTensor bias(armnn::TensorInfo(1, biasDims, armnn::DataType::Float32), floatBias);
+
+ // A network with Convolution2d layer
+ auto input = graph.AddLayer<armnn::InputLayer>(0, "input");
+ input->GetOutputSlot().SetTensorInfo(infoFP32);
+
+ armnn::Convolution2dDescriptor descriptor;
+
+ auto conv = graph.AddLayer<armnn::Convolution2dLayer>(descriptor, "conv2d");
+ conv->m_Weight = std::make_unique<armnn::ScopedCpuTensorHandle>(weights);
+ conv->m_Bias = std::make_unique<armnn::ScopedCpuTensorHandle>(bias);
+ conv->GetOutputSlot().SetTensorInfo(infoFP32);
+
+ auto output = graph.AddLayer<armnn::OutputLayer>(1, "output");
+
+ // Connect up the layers
+ input->GetOutputSlot().Connect(conv->GetInputSlot(0));
+ conv->GetOutputSlot().Connect(output->GetInputSlot(0));
+
+ BOOST_TEST(CheckSequence(graph.cbegin(), graph.cend(), &IsLayerOfType<armnn::InputLayer>,
+ &IsLayerOfType<armnn::Convolution2dLayer>, &IsLayerOfType<armnn::OutputLayer>));
+
+ // Run the optimizer
+ armnn::Optimizer::Pass(graph, armnn::MakeOptimizations(Fp32NetworkToBf16Converter()));
+
+ BOOST_TEST(CheckSequence(graph.cbegin(), graph.cend(), &IsLayerOfType<armnn::InputLayer>,
+ &IsLayerOfType<armnn::ConvertFp32ToBf16Layer>, &IsLayerOfType<armnn::Convolution2dLayer>,
+ &IsLayerOfType<armnn::OutputLayer>));
+
+ armnn::TensorInfo inputTensor = conv->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo();
+ armnn::TensorInfo outputTensor = conv->GetOutputSlot(0).GetTensorInfo();
+ BOOST_TEST((conv->GetDataType() == armnn::DataType::BFloat16));
+ BOOST_TEST((conv->m_Weight->GetTensorInfo().GetDataType() == armnn::DataType::BFloat16));
+ BOOST_TEST((conv->m_Bias->GetTensorInfo().GetDataType() == armnn::DataType::Float32));
+ BOOST_TEST((inputTensor.GetDataType() == armnn::DataType::BFloat16));
+ BOOST_TEST((outputTensor.GetDataType() == armnn::DataType::Float32));
+
+ // Check whether data matches expected Bf16 data
+ armnn::BFloat16* data = conv->m_Weight->GetTensor<armnn::BFloat16>();
+ BOOST_CHECK(data[0] == armnn::BFloat16(0.0f));
+ BOOST_CHECK(data[1] == armnn::BFloat16(-1.0f));
+ BOOST_CHECK(data[2] == armnn::BFloat16(3.796875f)); // 0x4073
+ BOOST_CHECK(data[3] == armnn::BFloat16(3.1072295E29f)); // 0x707B
+ BOOST_CHECK(data[4] == armnn::BFloat16(9.131327E-10f)); // 0x307B
+ BOOST_CHECK(data[5] == armnn::BFloat16(-3.796875f)); // 0xC073
+ BOOST_CHECK(data[6] == armnn::BFloat16(-3.1072295E29f)); // 0xF07B
+ BOOST_CHECK(data[7] == armnn::BFloat16(-9.131327E-10f)); // 0xB07B
+}
+
+BOOST_AUTO_TEST_CASE(Fp32NetworkToBf16OptimizationFullyConnectedTest)
+{
+ armnn::Graph graph;
+
+ const armnn::TensorInfo infoFP32({ 2, 3, 8, 1 }, armnn::DataType::Float32);
+
+ // Create const tensor fp32 data
+ unsigned int dims[] = { 4, 2, 1, 1 };
+ std::vector<float> floatWeights{ 0.0f, -1.0f,
+ 3.8f, // 0x40733333 Round down
+ 3.1055E+29f, // 0x707ADC3C Round up
+ 9.149516E-10f, // 0x307B7FFF Round down
+ -3.8f, // 0xC0733333 Round down
+ -3.1055E+29f, // 0xF07ADC3C Round up
+ -9.149516E-10f // 0xB07B7FFF Round down
+ };
+ armnn::ConstTensor weights(armnn::TensorInfo(4, dims, armnn::DataType::Float32), floatWeights);
+
+ // Create const bias fp32 data
+ unsigned int biasDims[] {4};
+ std::vector<float> floatBias{ 1.0f, 2.0f, 3.0f, 4.0f };
+ armnn::ConstTensor bias(armnn::TensorInfo(1, biasDims, armnn::DataType::Float32), floatBias);
+
+ // A network with FullyConnected layer
+ auto input = graph.AddLayer<armnn::InputLayer>(0, "input");
+ input->GetOutputSlot().SetTensorInfo(infoFP32);
+
+ armnn::FullyConnectedDescriptor descriptor;
+
+ auto fc = graph.AddLayer<armnn::FullyConnectedLayer>(descriptor, "fully");
+ fc->m_Weight = std::make_unique<armnn::ScopedCpuTensorHandle>(weights);
+ fc->m_Bias = std::make_unique<armnn::ScopedCpuTensorHandle>(bias);
+ fc->GetOutputSlot().SetTensorInfo(infoFP32);
+
+ auto output = graph.AddLayer<armnn::OutputLayer>(1, "output");
+
+ // Connect up the layers
+ input->GetOutputSlot().Connect(fc->GetInputSlot(0));
+ fc->GetOutputSlot().Connect(output->GetInputSlot(0));
+
+ BOOST_TEST(CheckSequence(graph.cbegin(), graph.cend(), &IsLayerOfType<armnn::InputLayer>,
+ &IsLayerOfType<armnn::FullyConnectedLayer>, &IsLayerOfType<armnn::OutputLayer>));
+
+ // Run the optimizer
+ armnn::Optimizer::Pass(graph, armnn::MakeOptimizations(Fp32NetworkToBf16Converter()));
+
+ BOOST_TEST(CheckSequence(graph.cbegin(), graph.cend(), &IsLayerOfType<armnn::InputLayer>,
+ &IsLayerOfType<armnn::ConvertFp32ToBf16Layer>, &IsLayerOfType<armnn::FullyConnectedLayer>,
+ &IsLayerOfType<armnn::OutputLayer>));
+
+ armnn::TensorInfo inputTensor = fc->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo();
+ armnn::TensorInfo outputTensor = fc->GetOutputSlot(0).GetTensorInfo();
+ BOOST_TEST((fc->GetDataType() == armnn::DataType::BFloat16));
+ BOOST_TEST((fc->m_Weight->GetTensorInfo().GetDataType() == armnn::DataType::BFloat16));
+ BOOST_TEST((fc->m_Bias->GetTensorInfo().GetDataType() == armnn::DataType::Float32));
+ BOOST_TEST((inputTensor.GetDataType() == armnn::DataType::BFloat16));
+ BOOST_TEST((outputTensor.GetDataType() == armnn::DataType::Float32));
+
+ // Check whether data matches expected Bf16 data
+ armnn::BFloat16* data = fc->m_Weight->GetTensor<armnn::BFloat16>();
+ BOOST_CHECK(data[0] == armnn::BFloat16(0.0f));
+ BOOST_CHECK(data[1] == armnn::BFloat16(-1.0f));
+ BOOST_CHECK(data[2] == armnn::BFloat16(3.796875f)); // 0x4073
+ BOOST_CHECK(data[3] == armnn::BFloat16(3.1072295E29f)); // 0x707B
+ BOOST_CHECK(data[4] == armnn::BFloat16(9.131327E-10f)); // 0x307B
+ BOOST_CHECK(data[5] == armnn::BFloat16(-3.796875f)); // 0xC073
+ BOOST_CHECK(data[6] == armnn::BFloat16(-3.1072295E29f)); // 0xF07B
+ BOOST_CHECK(data[7] == armnn::BFloat16(-9.131327E-10f)); // 0xB07B
}
BOOST_AUTO_TEST_SUITE_END() \ No newline at end of file
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 85c074a500..f968ad78f7 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -26,10 +26,9 @@ DataType GetBiasDataType(DataType inputDataType)
{
switch (inputDataType)
{
- case DataType::BFloat16:
- return DataType::BFloat16;
case DataType::Float16:
return DataType::Float16;
+ case DataType::BFloat16:
case DataType::Float32:
return DataType::Float32;
case DataType::QAsymmS8:
@@ -1009,7 +1008,20 @@ void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
};
ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
- ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
+
+ // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
+ if (inputTensorInfo.GetDataType() == DataType::BFloat16)
+ {
+ if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
+ {
+ throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
+ "for BFloat16 input.");
+ }
+ }
+ else
+ {
+ ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
+ }
}
void NormalizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
@@ -1206,7 +1218,20 @@ void Convolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co
};
ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
- ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
+
+ // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
+ if (inputTensorInfo.GetDataType() == DataType::BFloat16)
+ {
+ if (outputTensorInfo.GetDataType() != DataType::BFloat16 && outputTensorInfo.GetDataType() != DataType::Float32)
+ {
+ throw InvalidArgumentException(descriptorName + ": " + " Output tensor type must be BFloat16 or Float32 "
+ "for BFloat16 input.");
+ }
+ }
+ else
+ {
+ ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
+ }
}
void DepthwiseConvolution2dQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 551a7b5867..7b25a436e9 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -474,8 +474,20 @@ bool RefLayerSupport::IsConvolution2dSupported(const TensorInfo& input,
supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
"Reference Convolution2d: output is not a supported type.");
- supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
+ // For Convolution2d, we allow to have BFloat16 input with Float32 output for optimization.
+ if (input.GetDataType() == DataType::BFloat16)
+ {
+ if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
+ {
+ reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
+ supported = false;
+ }
+ }
+ else
+ {
+ supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
"Reference Convolution2d: input and output types mismatched.");
+ }
const DataType inputType = input.GetDataType();
if (IsQuantized8BitType(inputType))
@@ -882,12 +894,24 @@ bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
"Reference Fully Connected: output type not supported.");
- supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
- "Reference Fully Connected: input and output types mismatched.");
-
supported &= CheckSupportRule(TypeAnyOf(weights, supportedTypes), reasonIfUnsupported,
"Reference Fully Connected: weights type not supported.");
+ // For FullyConnected, we allow to have BFloat16 input with Float32 output for optimization.
+ if (input.GetDataType() == DataType::BFloat16)
+ {
+ if (output.GetDataType() != DataType::BFloat16 && output.GetDataType() != DataType::Float32)
+ {
+ reasonIfUnsupported.value() += "Output tensor type must be BFloat16 or Float32 for BFloat16 input.\n";
+ supported = false;
+ }
+ }
+ else
+ {
+ supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
+ "Reference Fully Connected: input and output types mismatched.");
+ }
+
ARMNN_NO_DEPRECATE_WARN_BEGIN
std::array<DataType, 3> supportedWeightTypes =
{