ArmNN 22.05
ConvertFp32NetworkToBf16Impl Class Reference

#include <ConvertFp32NetworkToBf16.hpp>

Public Member Functions

void Run (Graph &graph, Layer &layer) const
 

Protected Member Functions

 ConvertFp32NetworkToBf16Impl ()=default
 
 ~ConvertFp32NetworkToBf16Impl ()=default
 

Detailed Description

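Applies the Float32-to-BFloat16 conversion to a single layer of a graph. Run() acts only on Float32 Convolution2d and FullyConnected layers. For those layers it converts the layer's input from Float32 to BFloat16 and converts the weight data type from Float32 to BFloat16. The bias data type is left unchanged.

Definition at line 44 of file ConvertFp32NetworkToBf16.hpp.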

Constructor & Destructor Documentation

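◆ ConvertFp32NetworkToBf16Impl()

Protected, defaulted constructor: ConvertFp32NetworkToBf16Impl() = default.

◆ ~ConvertFp32NetworkToBf16Impl()

Protected, defaulted destructor: ~ConvertFp32NetworkToBf16Impl() = default.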

Member Function Documentation

◆ Run()

void Run (Graph &graph, Layer &layer) const   [inline]

Definition at line 48 of file ConvertFp32NetworkToBf16.hpp.

References ConvertFp32NetworkToBf16Impl::ConvertFp32NetworkToBf16Impl(), armnn::Convolution2d, armnn::Float32, armnn::FullyConnected, Layer::GetDataType(), Layer::GetType(), armnn::InsertConvertFp32ToBf16LayersBefore(), and ConvertFp32NetworkToBf16Impl::~ConvertFp32NetworkToBf16Impl().

49  {
50  // Only convert Float32 To BFloat16 for the Input of Convolution2d layer and FullyConnected layer.
51  // And also convert weight data type from Float32 to Bfloat16.
52  // Do not convert bias data type.
53  if (layer.GetType() == LayerType::Convolution2d)
54  {
55  if (layer.GetDataType() == DataType::Float32)
56  {
57  InsertConvertFp32ToBf16LayersBefore(graph,layer);
58  ConvertWeight<Convolution2dLayer>(&layer);
59  }
60  }
61  else if (layer.GetType() == LayerType::FullyConnected)
62  {
63  if (layer.GetDataType() == DataType::Float32)
64  {
65  InsertConvertFp32ToBf16LayersBefore(graph,layer);
66  ConvertWeight<FullyConnectedLayer>(&layer);
67  }
68  }
69  }
InsertConvertFp32ToBf16LayersBefore, called on lines 57 and 65 above, is declared as: std::vector< ConvertFp32ToBf16Layer * > InsertConvertFp32ToBf16LayersBefore(Graph &graph, Layer &layer, bool expectCorrectInputType)

The documentation for this class was generated from the following file: ConvertFp32NetworkToBf16.hpp