ArmNN
 21.02
ConvertFp32NetworkToBf16Impl Class Reference

#include <ConvertFp32NetworkToBf16.hpp>

Public Member Functions

void Run (Graph &graph, Layer &layer) const
 

Protected Member Functions

 ConvertFp32NetworkToBf16Impl ()=default
 
 ~ConvertFp32NetworkToBf16Impl ()=default
 

Detailed Description

Definition at line 43 of file ConvertFp32NetworkToBf16.hpp.

Constructor & Destructor Documentation

◆ ConvertFp32NetworkToBf16Impl()

◆ ~ConvertFp32NetworkToBf16Impl()

Member Function Documentation

◆ Run()

void Run ( Graph & graph,
Layer & layer 
) const
inline

Definition at line 47 of file ConvertFp32NetworkToBf16.hpp.

References ConvertFp32NetworkToBf16Impl::ConvertFp32NetworkToBf16Impl(), armnn::Convolution2d, armnn::Float32, armnn::FullyConnected, Layer::GetDataType(), Layer::GetType(), armnn::InsertConvertFp32ToBf16LayersBefore(), and ConvertFp32NetworkToBf16Impl::~ConvertFp32NetworkToBf16Impl().

48  {
49  // Only convert Float32 To BFloat16 for the Input of Convolution2d layer and FullyConnected layer.
50  // And also convert weight data type from Float32 to Bfloat16.
51  // Do not convert bias data type.
52  if (layer.GetType() == LayerType::Convolution2d)
53  {
54  if (layer.GetDataType() == DataType::Float32)
55  {
56  InsertConvertFp32ToBf16LayersBefore(graph, layer);
57  ConvertWeight<Convolution2dLayer>(&layer);
58  }
59  }
60  else if (layer.GetType() == LayerType::FullyConnected)
61  {
62  if (layer.GetDataType() == DataType::Float32)
63  {
64  InsertConvertFp32ToBf16LayersBefore(graph, layer);
65  ConvertWeight<FullyConnectedLayer>(&layer);
66  }
67  }
68  }
Referenced function (called at listing lines 56 and 64): std::vector< ConvertFp32ToBf16Layer * > InsertConvertFp32ToBf16LayersBefore(Graph &graph, Layer &layer, bool expectCorrectInputType)

The documentation for this class was generated from the following file: ConvertFp32NetworkToBf16.hpp