ArmNN
 24.02
SplitOperator.cpp File Reference
#include "SplitOperator.hpp"
Include dependency graph for SplitOperator.cpp:

Go to the source code of this file.

Functions

TosaSerializationBasicBlock * ConvertSplitToTosaOperator (const Layer *layer, const std::vector< const TensorInfo * > &inputs, const std::vector< const TensorInfo * > &outputs, const SplitterDescriptor *splitDescriptor)
 

Function Documentation

◆ ConvertSplitToTosaOperator()

TosaSerializationBasicBlock* ConvertSplitToTosaOperator ( const Layer * layer,
const std::vector< const TensorInfo * > &  inputs,
const std::vector< const TensorInfo * > &  outputs,
const SplitterDescriptor * splitDescriptor 
)

Definition at line 13 of file SplitOperator.cpp.

17 {
18  ARMNN_THROW_INVALIDARG_MSG_IF_FALSE( inputs.size() == 1,
19  "ConvertSplitToTosaOperator: Split must have only one input" );
20 
21  ARMNN_THROW_INVALIDARG_MSG_IF_FALSE( outputs.size() >= 1,
22  "ConvertSplitToTosaOperator: Split must have at least one output" );
23 
24  if (!inputs[0]->GetShape().AreAllDimensionsSpecified())
25  {
26  throw armnn::Exception("ConvertSplitToTosaOperator: Dynamic input dimensions are unsupported.");
27  }
28 
29  std::string inputName = std::string("input0_");
30  std::vector<std::string> outputNames;
31  std::string blockName = std::string("Op_SPLIT_block_") + GetUniqueTosaMappingID();
32 
33  unsigned int numSplit = splitDescriptor->GetNumViews();
34  // If a layer is present then the block will be used for execution, so input and output names need to be determined
35  // using the previous and following layers so the graph is connected correctly. For validation this doesn't matter.
36  if(layer != nullptr)
37  {
38  // Get the layers connected to the input slots and determine unique tensor names.
39  Layer& connectedLayer = layer->GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer();
40  inputName = GenerateUniqueName(connectedLayer, 0);
41 
42  for (unsigned int i=0; i < numSplit; ++i)
43  {
44  // Determine unique output(s) tensor name.
45  std::string outputName = GenerateUniqueOutputName(*layer, i);
46  outputNames.push_back(outputName);
47  }
48  }
49  else
50  {
51  for (unsigned int i=0; i < numSplit; ++i)
52  {
53  // Determine unique output(s) tensor name.
54  std::string outputName = "output" + std::to_string(i) + "_";
55  outputNames.push_back(outputName);
56  }
57  }
58 
59  // Each slice op has a different beginning point.
60  // The size is the same for each slice op.
61  std::vector<int32_t> beginVals;
62  beginVals.reserve(inputs[0]->GetNumDimensions());
63  std::vector<int32_t> sizeVals;
64  sizeVals.reserve(inputs[0]->GetNumDimensions());
65  for (unsigned int j = 0; j < inputs[0]->GetNumDimensions(); ++j)
66  {
67  beginVals.emplace_back(0);
68  uint32_t dim = inputs[0]->GetShape()[j];
69  sizeVals.emplace_back(dim);
70  }
71 
72  uint32_t axis = static_cast<uint32_t>(splitDescriptor->GetAxis());
73  sizeVals[axis] = sizeVals[axis] / static_cast<int32_t>(numSplit);
74 
75  std::vector<TosaSerializationOperator*> ops;
76  for (unsigned int i=0; i < numSplit; ++i)
77  {
78  beginVals[axis] = static_cast<int>(i) * sizeVals[axis];
79  TosaSliceAttribute attribute(beginVals, sizeVals);
80  auto* op = new TosaSerializationOperator(Op_SLICE,
81  Attribute_SliceAttribute,
82  &attribute,
83  {inputName},
84  {outputNames[i]});
85 
86  ops.push_back(op);
87  }
88 
89  std::vector<TosaSerializationTensor*> tensors;
90  // Only add input tensors if connected layer is an input layer.
91  // As intermediate or constant tensors will be created separately.
92  // There also can't be duplicate tensor.
93  if(inputName.find("input0_") != std::string::npos)
94  {
95  std::vector<int32_t> inputShape = GetTosaTensorShape(inputs[0]->GetShape());
96  DType inputDType = ArmNNToDType(inputs[0]->GetDataType());
97 
98  tensors.push_back(new TosaSerializationTensor(inputName, inputShape, inputDType, {}));
99  }
100 
101  std::vector<int32_t> outputShape = GetTosaTensorShape(outputs[0]->GetShape());
102  DType outputDType = ArmNNToDType(outputs[0]->GetDataType());
103 
104  for (unsigned int i=0; i < numSplit; ++i)
105  {
106  tensors.push_back(new TosaSerializationTensor(outputNames[i], outputShape, outputDType, {}));
107  }
108  // operatorInputNames/operatorOutputNames ends up being the same as
109  // blockInputNames/blockOutputNames for one-to-one ArmNN to TOSA mappings
110  return new TosaSerializationBasicBlock(blockName, // name
111  mainName, // region name
112  ops, // operators
113  tensors, // tensors
114  {inputName}, // inputs
115  outputNames); // outputs
116 }

References ARMNN_THROW_INVALIDARG_MSG_IF_FALSE.

Referenced by GetTosaMapping().

armnn::Layer::GetInputSlot
const InputSlot & GetInputSlot(unsigned int index) const override
Get a const input slot handle by slot index.
Definition: Layer.hpp:337
armnn::Layer
Definition: Layer.hpp:230
mainName
const std::string mainName
Definition: TosaOperatorUtils.hpp:19
armnn::OutputSlot::GetOwningLayer
Layer & GetOwningLayer() const
Definition: Layer.hpp:132
ArmNNToDType
DType ArmNNToDType(const DataType &type)
Definition: TosaOperatorUtils.hpp:22
armnn::ViewsDescriptor::GetAxis
int32_t GetAxis() const
Get the axis value.
Definition: Descriptors.cpp:381
GenerateUniqueOutputName
std::string GenerateUniqueOutputName(const Layer &layer, uint32_t layerSlot)
Definition: TosaOperatorUtils.hpp:82
armnn::Exception
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:46
GenerateUniqueName
std::string GenerateUniqueName(const Layer &layer, uint32_t layerSlot)
Definition: TosaOperatorUtils.hpp:63
GetTosaTensorShape
std::vector< int32_t > GetTosaTensorShape(const TensorShape &shape)
Definition: TosaOperatorUtils.hpp:52
armnn::InputSlot::GetConnectedOutputSlot
const OutputSlot * GetConnectedOutputSlot() const
Definition: Layer.hpp:56
armnn::ViewsDescriptor::GetNumViews
uint32_t GetNumViews() const
Get the number of views.
Definition: Descriptors.cpp:301
GetUniqueTosaMappingID
std::string GetUniqueTosaMappingID()
Definition: TosaOperatorUtils.hpp:100
ARMNN_THROW_INVALIDARG_MSG_IF_FALSE
#define ARMNN_THROW_INVALIDARG_MSG_IF_FALSE(_cond, _str)
Definition: Exceptions.hpp:210