ArmNN
 23.05
Concatenate.cpp
Go to the documentation of this file.
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "Concatenate.hpp"
7 #include "RefWorkloadUtils.hpp"
8 #include "Decoders.hpp"
9 #include "Encoders.hpp"
10 
11 namespace armnn
12 {
13 
15  std::vector<ITensorHandle*> inputs,
16  std::vector<ITensorHandle*> outputs)
17 {
18  const TensorInfo& outputInfo0 = GetTensorInfo(outputs[0]);
19 
20  std::unique_ptr<Encoder<float>> encoderPtr = MakeEncoder<float>(outputInfo0, outputs[0]->Map());
21  Encoder<float>& encoder = *encoderPtr;
22 
23  for (unsigned int index = 0 ; index < outputInfo0.GetNumElements(); ++index)
24  {
25  unsigned int indices[MaxNumOfTensorDimensions] = { 0 };
26 
27  unsigned int indexRemainder = index;
28  unsigned int dimensionStride = outputInfo0.GetNumElements();
29 
30  for (unsigned int i = 0; i < outputInfo0.GetNumDimensions(); i++)
31  {
32  dimensionStride /= outputInfo0.GetShape()[i];
33  indices[i] = indexRemainder / dimensionStride; // Use integer division to round down.
34  indexRemainder -= indices[i] * dimensionStride;
35  }
36 
37  for (unsigned int viewIdx = 0; viewIdx < data.m_ViewOrigins.size(); ++viewIdx)
38  {
39  ConcatQueueDescriptor::ViewOrigin const& view = data.m_ViewOrigins[viewIdx];
40 
41  //Split view extents are defined by the size of (the corresponding) input tensor.
42  const TensorInfo& inputInfo = GetTensorInfo(inputs[viewIdx]);
43  ARMNN_ASSERT(inputInfo.GetNumDimensions() == outputInfo0.GetNumDimensions());
44 
45  // Check all dimensions to see if this element is inside the given input view.
46  bool insideView = true;
47  for (unsigned int i = 0; i < inputInfo.GetNumDimensions(); i++)
48  {
49  if (indices[i] < view.m_Origin[i])
50  {
51  insideView = false;
52  }
53  if (indices[i] >= view.m_Origin[i] + inputInfo.GetShape()[i])
54  {
55  insideView = false;
56  }
57  }
58 
59  if (insideView)
60  {
61  std::unique_ptr<Decoder<float>> decoderPtr =
62  MakeDecoder<float>(inputInfo,inputs[viewIdx]->Map());
63  Decoder<float>& decoder = *decoderPtr;
64  unsigned int inIndex = 0;
65  unsigned int dimensionStride = 1;
66 
67  for (unsigned int i = inputInfo.GetNumDimensions(); i-- > 0;)
68  {
69  inIndex += dimensionStride * (indices[i] - view.m_Origin[i]);
70  dimensionStride *= inputInfo.GetShape()[i];
71  }
72  decoder += inIndex;
73  encoder.Set(decoder.Get());
74 
75  //What should we do if input views overlap on the output tensor?
76  //We could error, take the average, or shm else...
77  //For now just stop after finding first view (input) that matches.
78  break;
79  }
80  }
81  ++encoder;
82  }
83 }
84 
85 } //namespace armnn
armnn::GetTensorInfo
const TensorInfo & GetTensorInfo(const ITensorHandle *tensorHandle)
float32 helpers
Definition: RefWorkloadUtils.hpp:27
armnn::ConcatQueueDescriptor::ViewOrigin
Definition: WorkloadData.hpp:132
armnn::Concatenate
void Concatenate(const ConcatQueueDescriptor &data, std::vector< ITensorHandle * > inputs, std::vector< ITensorHandle * > outputs)
Definition: Concatenate.cpp:14
armnn::LayerType::Map
@ Map
armnn::ConcatQueueDescriptor::ViewOrigin::m_Origin
std::vector< unsigned int > m_Origin
Definition: WorkloadData.hpp:138
armnn::Encoder< float >
armnn
Copyright (c) 2021 ARM Limited and Contributors.
Definition: 01_00_quick_start.dox:6
armnn::TensorInfo::GetNumDimensions
unsigned int GetNumDimensions() const
Definition: Tensor.hpp:195
RefWorkloadUtils.hpp
armnn::TensorInfo::GetNumElements
unsigned int GetNumElements() const
Definition: Tensor.hpp:196
Encoders.hpp
armnn::Decoder< float >
armnn::ConcatQueueDescriptor
Definition: WorkloadData.hpp:130
armnn::TensorInfo
Definition: Tensor.hpp:152
armnn::TensorInfo::GetShape
const TensorShape & GetShape() const
Definition: Tensor.hpp:191
Concatenate.hpp
armnn::Encoder::Set
virtual void Set(IType right)=0
ARMNN_ASSERT
#define ARMNN_ASSERT(COND)
Definition: Assert.hpp:14
Decoders.hpp
armnn::Decoder::Get
virtual IType Get() const =0
armnn::MaxNumOfTensorDimensions
constexpr unsigned int MaxNumOfTensorDimensions
Definition: Types.hpp:31
armnn::ConcatQueueDescriptor::m_ViewOrigins
std::vector< ViewOrigin > m_ViewOrigins
Definition: WorkloadData.hpp:143