ArmNN
 22.11
FuseConvertFp32ToBf16IntoConstLayers Class Reference

#include <FuseConvertFp32ToBf16IntoConstLayers.hpp>

Public Member Functions

void Run (Graph &graph, InputSlot &connection) const
 

Protected Member Functions

 FuseConvertFp32ToBf16IntoConstLayers ()=default
 
 ~FuseConvertFp32ToBf16IntoConstLayers ()=default
 

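Detailed Description

Graph optimization that folds a ConvertFp32ToBf16 layer into the Constant layer that feeds it. Run() expects the given connection to go from a Constant layer into a ConvertFp32ToBf16 layer (checked with ARMNN_ASSERT). If the constant's data type is Float32, Run() calls ReplaceConvertFp32ToBf16Layer<DataType::BFloat16> on the pair; for any other data type it throws InvalidArgumentException.

Definition at line 17 of file FuseConvertFp32ToBf16IntoConstLayers.hpp.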

Constructor & Destructor Documentation

◆ FuseConvertFp32ToBf16IntoConstLayers()

Protected, defaulted constructor.

◆ ~FuseConvertFp32ToBf16IntoConstLayers()

Protected, defaulted destructor.

Member Function Documentation

◆ Run()

void Run (Graph &graph, InputSlot &connection) const
inline

Definition at line 20 of file FuseConvertFp32ToBf16IntoConstLayers.hpp.

References ARMNN_ASSERT, armnn::Constant, FloatingPointConverter::ConvertFloat32ToBFloat16(), armnn::ConvertFp32ToBf16, armnn::Float32, FuseConvertFp32ToBf16IntoConstLayers::FuseConvertFp32ToBf16IntoConstLayers(), InputSlot::GetConnectedOutputSlot(), Layer::GetDataType(), armnn::GetDataTypeName(), TensorInfo::GetNumElements(), Layer::GetOutputSlot(), InputSlot::GetOwningLayer(), OutputSlot::GetOwningLayer(), OutputSlot::GetTensorInfo(), Layer::GetType(), armnn::IgnoreUnused(), TensorInfo::IsConstant(), ConstantLayer::m_LayerOutput, OutputSlot::MoveAllConnections(), TensorInfo::SetConstant(), OutputSlot::SetTensorInfo(), and FuseConvertFp32ToBf16IntoConstLayers::~FuseConvertFp32ToBf16IntoConstLayers().

21  {
22  Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
23  Layer& child = connection.GetOwningLayer();
24 
25  ARMNN_ASSERT(base.GetType() == LayerType::Constant);
26  ARMNN_ASSERT(child.GetType() == LayerType::ConvertFp32ToBf16);
27 
28  auto dataType = base.GetDataType();
29  switch (dataType)
30  {
31  case DataType::Float32:
32  ReplaceConvertFp32ToBf16Layer<DataType::BFloat16>(
33  graph,
34  PolymorphicDowncast<ConstantLayer*>(&base),
35  PolymorphicDowncast<ConvertFp32ToBf16Layer*>(&child));
36  break;
37  default:
38  throw InvalidArgumentException(GetDataTypeName(dataType) +
39  std::string(" Constant Layer cannot be fused into ") +
40  GetDataTypeName(child.GetDataType()) +
41  std::string(" conversion layer."));
42  }
43  }
constexpr const char * GetDataTypeName(DataType dataType) — Definition: TypesUtils.hpp:202
#define ARMNN_ASSERT(COND) — Definition: Assert.hpp:14

The documentation for this class was generated from the following file:
FuseConvertFp32ToBf16IntoConstLayers.hpp