ArmNN
 21.05
AddBroadcastReshapeLayer.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "Optimization.hpp"
8 
12 
13 namespace armnn
14 {
15 namespace optimizations
16 {
17 
18 static const std::set<armnn::LayerType> broadcastOps{ LayerType::Addition, LayerType::Division,
21 
23 {
24 public:
25  /// Run for every ElementwiseBaseLayer. Add Broadcast reshape layer if the inputs shape are different.
26  void Run(Graph& graph, Layer& layer) const
27  {
28  if (std::find(broadcastOps.begin(), broadcastOps.end(), layer.GetType()) != broadcastOps.end())
29  {
32 
33  const TensorInfo& inputInfo0 = layer.GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo();
34  const TensorInfo& inputInfo1 = layer.GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo();
35 
36  if (inputInfo0.GetNumDimensions() == inputInfo1.GetNumDimensions())
37  {
38  return;
39  }
40 
41  unsigned int reshapeSlot = 1;
42  TensorInfo reshapeInfo = inputInfo1;
43  TensorInfo inputInfo = inputInfo0;
44 
45  if (inputInfo0.GetNumDimensions() < inputInfo1.GetNumDimensions())
46  {
47  reshapeSlot = 0;
48  reshapeInfo = inputInfo0;
49  inputInfo = inputInfo1;
50  }
51 
52  uint32_t numDimensions = inputInfo.GetNumDimensions();
53 
54  std::vector<unsigned> reshapedDim;
55  for (unsigned int i = 0; i < reshapeInfo.GetNumDimensions(); ++i)
56  {
57  reshapedDim.push_back(reshapeInfo.GetShape()[i]);
58  }
59 
60  std::vector<unsigned int> reshapedDimensions(numDimensions, 1);
61  std::copy_backward(reshapedDim.begin(), reshapedDim.end(), reshapedDimensions.end());
62 
63  reshapeInfo.SetShape(armnn::TensorShape{ numDimensions, reshapedDimensions.data() });
64 
65  // If the parent layer is a Constant layer and it is only used once we can short circuit by just
66  // changing the tensor info rather than adding a reshape layer.
67  Layer& parentLayer = layer.GetInputSlot(reshapeSlot).GetConnectedOutputSlot()->GetOwningLayer();
68  if ((parentLayer.GetType() == armnn::LayerType::Constant) &&
69  (parentLayer.GetOutputSlot(0).GetNumConnections() == 1))
70  {
71  ConstantLayer& constantLayer = static_cast<ConstantLayer&>(parentLayer);
72 
73  constantLayer.m_LayerOutput = std::make_unique<ScopedTensorHandle>(
74  ConstTensor(reshapeInfo, constantLayer.m_LayerOutput.get()->GetConstTensor<void>()));
75  constantLayer.GetOutputSlot().SetTensorInfo(reshapeInfo);
76  }
77  else
78  {
79  const std::string layerName = "Reshape_for:" + layer.GetNameStr() + "-" + std::to_string(reshapeSlot);
80  const ReshapeDescriptor descriptor{ reshapeInfo.GetShape() };
81  ReshapeLayer* reshapeLayer =
82  graph.InsertNewLayer<ReshapeLayer>(layer.GetInputSlot(reshapeSlot), descriptor, layerName.c_str());
83  reshapeLayer->GetOutputSlot().SetTensorInfo(reshapeInfo);
84  }
85  }
86  }
87 
88 protected:
// Protected default constructor: this implementation class cannot be constructed directly, only by a
// derived/wrapping type (presumably the OptimizeForType<> wrapper declared elsewhere in this file - not visible here).
89  AddBroadcastReshapeLayerImpl() = default;
91 };
92 
94 
95 } // namespace optimizations
96 } // namespace armnn
A layer that the constant data can be bound to.
const TensorShape & GetShape() const
Definition: Tensor.hpp:187
A ReshapeDescriptor for the ReshapeLayer.
std::shared_ptr< ConstTensorHandle > m_LayerOutput
void Run(Graph &graph, Layer &layer) const
Runs for every ElementwiseBaseLayer. Adds a broadcast Reshape layer if the inputs' ranks (numbers of dimensions) differ...
Layer & GetOwningLayer() const
Definition: Layer.hpp:115
This layer represents a reshape operation.
Copyright (c) 2021 ARM Limited and Contributors.
void SetShape(const TensorShape &newShape)
Definition: Tensor.hpp:189
unsigned int GetNumConnections() const override
Definition: Layer.hpp:138
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
Definition: Layer.hpp:316
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
Definition: Tensor.hpp:314
const std::string & GetNameStr() const
Definition: Layer.hpp:220
LayerType GetType() const override
Returns the armnn::LayerType of this layer.
Definition: Layer.hpp:265
const OutputSlot * GetConnectedOutputSlot() const
Definition: Layer.hpp:55
void SetTensorInfo(const TensorInfo &tensorInfo) override
Definition: Layer.cpp:58
const OutputSlot & GetOutputSlot(unsigned int index=0) const override
Get the const output slot handle by slot index.
Definition: Layer.hpp:318
bool IsTensorInfoSet() const override
Definition: Layer.cpp:68
LayerT * InsertNewLayer(InputSlot &insertBefore, Args &&... args)
Inserts a new layer between the output slot currently connected to insertBefore and insertBefore itself.
Definition: Graph.hpp:416
const TensorInfo & GetTensorInfo() const override
Definition: Layer.cpp:63
unsigned int GetNumDimensions() const
Definition: Tensor.hpp:191