ArmNN
 21.02
AddBroadcastReshapeLayer.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "Optimization.hpp"
8 
11 
12 namespace armnn
13 {
14 namespace optimizations
15 {
16 
17 static const std::set<armnn::LayerType> broadcastOps {
24 };
25 
27 {
28 public:
29  /// Run for every ElementwiseBaseLayer. Add Broadcast reshape layer if the inputs shape are different.
30  void Run(Graph& graph, Layer& layer) const
31  {
32  if (std::find(broadcastOps.begin(), broadcastOps.end(), layer.GetType()) != broadcastOps.end())
33  {
36 
37  const TensorInfo &inputInfo0 = layer.GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo();
38  const TensorInfo &inputInfo1 = layer.GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo();
39 
40  if (inputInfo0.GetNumDimensions() == inputInfo1.GetNumDimensions())
41  {
42  return;
43  }
44 
45  unsigned int reshapeSlot = 1;
46  TensorInfo reshapeInfo = inputInfo1;
47  TensorInfo inputInfo = inputInfo0;
48 
49  if (inputInfo0.GetNumDimensions() < inputInfo1.GetNumDimensions())
50  {
51  reshapeSlot = 0;
52  reshapeInfo = inputInfo0;
53  inputInfo = inputInfo1;
54  }
55 
56  uint32_t numDimensions = inputInfo.GetNumDimensions();
57 
58  std::vector<unsigned> reshapedDim;
59  for (unsigned int i = 0; i < reshapeInfo.GetNumDimensions(); ++i)
60  {
61  reshapedDim.push_back(reshapeInfo.GetShape()[i]);
62  }
63 
64  std::vector<unsigned int> reshapedDimensions(numDimensions, 1);
65  std::copy_backward (reshapedDim.begin(), reshapedDim.end(), reshapedDimensions.end());
66 
67  reshapeInfo.SetShape(armnn::TensorShape{ numDimensions, reshapedDimensions.data() });
68  const std::string layerName = "Reshape_for:" + layer.GetNameStr() + "-" + std::to_string(reshapeSlot);
69  const ReshapeDescriptor descriptor{reshapeInfo.GetShape()};
70  ReshapeLayer *reshapeLayer = graph.InsertNewLayer<ReshapeLayer>(layer.GetInputSlot(reshapeSlot),
71  descriptor,
72  layerName.c_str());
73  reshapeLayer->GetOutputSlot().SetTensorInfo(reshapeInfo);
74  }
75  }
76 
77 protected:
78  AddBroadcastReshapeLayerImpl() = default;
80 };
81 
83 
84 } // namespace optimizations
85 } // namespace armnn
const TensorShape & GetShape() const
Definition: Tensor.hpp:187
A ReshapeDescriptor for the ReshapeLayer.
void Run(Graph &graph, Layer &layer) const
Run for every ElementwiseBaseLayer. Adds a broadcast Reshape layer if the input shapes have different ranks.
This layer represents a reshape operation.
Copyright (c) 2021 ARM Limited and Contributors.
void SetShape(const TensorShape &newShape)
Definition: Tensor.hpp:189
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
Definition: Layer.hpp:316
const std::string & GetNameStr() const
Definition: Layer.hpp:220
LayerType GetType() const override
Returns the armnn::LayerType of this layer.
Definition: Layer.hpp:265
const OutputSlot * GetConnectedOutputSlot() const
Definition: Layer.hpp:55
void SetTensorInfo(const TensorInfo &tensorInfo) override
Definition: Layer.cpp:58
const OutputSlot & GetOutputSlot(unsigned int index=0) const override
Get the const output slot handle by slot index.
Definition: Layer.hpp:318
bool IsTensorInfoSet() const override
Definition: Layer.cpp:68
LayerT * InsertNewLayer(InputSlot &insertBefore, Args &&... args)
Inserts a new layer between the output slot currently connected to insertBefore and insertBefore itself.
Definition: Graph.hpp:416
const TensorInfo & GetTensorInfo() const override
Definition: Layer.cpp:63
unsigned int GetNumDimensions() const
Definition: Tensor.hpp:191