ArmNN
 21.02
AddBroadcastReshapeLayerImpl Class Reference

#include <AddBroadcastReshapeLayer.hpp>

Public Member Functions

void Run (Graph &graph, Layer &layer) const
 Runs for every ElementwiseBaseLayer. If the two inputs have a different number of dimensions, inserts a broadcast Reshape layer before the lower-rank input.
 

Protected Member Functions

 AddBroadcastReshapeLayerImpl ()=default
 
 ~AddBroadcastReshapeLayerImpl ()=default
 

Detailed Description

Definition at line 26 of file AddBroadcastReshapeLayer.hpp.

Constructor & Destructor Documentation

◆ AddBroadcastReshapeLayerImpl()

◆ ~AddBroadcastReshapeLayerImpl()

Member Function Documentation

◆ Run()

void Run (Graph &graph, Layer &layer) const
inline

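Runs for every ElementwiseBaseLayer. If the two inputs have a different number of dimensions, inserts a broadcast Reshape layer before the lower-rank input. The Reshape pads that input's shape with leading dimensions of size 1 until both inputs have the same rank.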

Definition at line 30 of file AddBroadcastReshapeLayer.hpp.

References AddBroadcastReshapeLayerImpl::AddBroadcastReshapeLayerImpl(), InputSlot::GetConnectedOutputSlot(), Layer::GetInputSlot(), Layer::GetNameStr(), TensorInfo::GetNumDimensions(), Layer::GetOutputSlot(), TensorInfo::GetShape(), OutputSlot::GetTensorInfo(), Layer::GetType(), Graph::InsertNewLayer(), OutputSlot::IsTensorInfoSet(), TensorInfo::SetShape(), OutputSlot::SetTensorInfo(), and AddBroadcastReshapeLayerImpl::~AddBroadcastReshapeLayerImpl().

31  {
32  if (std::find(broadcastOps.begin(), broadcastOps.end(), layer.GetType()) != broadcastOps.end())
33  {
34  layer.GetInputSlot(0).GetConnectedOutputSlot()->IsTensorInfoSet();
35  layer.GetInputSlot(1).GetConnectedOutputSlot()->IsTensorInfoSet();
36 
37  const TensorInfo &inputInfo0 = layer.GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo();
38  const TensorInfo &inputInfo1 = layer.GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo();
39 
40  if (inputInfo0.GetNumDimensions() == inputInfo1.GetNumDimensions())
41  {
42  return;
43  }
44 
45  unsigned int reshapeSlot = 1;
46  TensorInfo reshapeInfo = inputInfo1;
47  TensorInfo inputInfo = inputInfo0;
48 
49  if (inputInfo0.GetNumDimensions() < inputInfo1.GetNumDimensions())
50  {
51  reshapeSlot = 0;
52  reshapeInfo = inputInfo0;
53  inputInfo = inputInfo1;
54  }
55 
56  uint32_t numDimensions = inputInfo.GetNumDimensions();
57 
58  std::vector<unsigned> reshapedDim;
59  for (unsigned int i = 0; i < reshapeInfo.GetNumDimensions(); ++i)
60  {
61  reshapedDim.push_back(reshapeInfo.GetShape()[i]);
62  }
63 
64  std::vector<unsigned int> reshapedDimensions(numDimensions, 1);
65  std::copy_backward (reshapedDim.begin(), reshapedDim.end(), reshapedDimensions.end());
66 
67  reshapeInfo.SetShape(armnn::TensorShape{ numDimensions, reshapedDimensions.data() });
68  const std::string layerName = "Reshape_for:" + layer.GetNameStr() + "-" + std::to_string(reshapeSlot);
69  const ReshapeDescriptor descriptor{reshapeInfo.GetShape()};
70  ReshapeLayer *reshapeLayer = graph.InsertNewLayer<ReshapeLayer>(layer.GetInputSlot(reshapeSlot),
71  descriptor,
72  layerName.c_str());
73  reshapeLayer->GetOutputSlot().SetTensorInfo(reshapeInfo);
74  }
75  }

The documentation for this class was generated from the following file:
AddBroadcastReshapeLayer.hpp