// Assigns output tensor handles for this ConcatLayer and, where possible, replaces each
// input's full tensor with a sub-tensor view into this layer's output so the concat needs
// no copy at runtime.  Uses an explicit queue (m_ConcatLayers) to walk stacked Concat
// layers iteratively instead of recursing.
// NOTE(review): this chunk is non-contiguous -- interior source lines are missing between
// the numbered fragments, so comments below describe only what is visible.
39 template<
typename FactoryType>
41 const FactoryType& factory,
// Sub-tensor substitution is only attempted when the backend factory supports it and the
// concat axis is one of the two innermost dimensions (visible condition below).
49 if (factory.SupportsSubTensors())
55 && ((concatAxis == numberOfDimensions - 1) || (concatAxis == numberOfDimensions - 2));
// Work queue of Concat layers still to process; seeded with this layer.
59 std::queue<ConcatLayer*> m_ConcatLayers;
61 m_ConcatLayers.push(
this);
62 while (!m_ConcatLayers.empty())
// Assume sub-tensors are usable until a disqualifying condition is found.
72 bool canUseSubTensorOnXorY =
true;
// Capability checks only apply when FactoryType is the generic ITensorHandleFactory;
// presumably the legacy IWorkloadFactory path skips them -- TODO confirm from full file.
73 bool isTensorHandleFactory = std::is_same<armnn::ITensorHandleFactory, FactoryType>::value;
74 if (isTensorHandleFactory)
76 for (
unsigned int i = 0; i < numInputSlots; ++i)
// Query the factory's capabilities for each input slot (query details are in the
// elided lines); an empty capability list re-enables sub-tensor use below.
80 std::vector<Capability> capabilities =
86 canUseSubTensorOnXorY =
false;
87 if (capabilities.empty())
89 canUseSubTensorOnXorY =
true;
// A Concat feeding a Concat disqualifies sub-tensor use on this path (visible check).
96 && (PolymorphicDowncast<const Layer*>(currentLayer))->GetType() ==
LayerType::Concat)
98 canUseSubTensorOnXorY =
false;
101 if (!canUseSubTensorOnXorY)
// Build one candidate sub-tensor per input slot; reserve avoids reallocation.
109 std::vector<std::unique_ptr<ITensorHandle>> subTensors(0);
110 subTensors.reserve(numInputSlots);
111 for (
unsigned int i = 0; i < numInputSlots; ++i)
114 const TensorInfo&
info = slot->GetTensorInfo();
// Lambda returns a sub-tensor handle when all visible preconditions hold
// (matching factory id, single connection, axis/capability check above),
// otherwise a null handle to signal "fall back to a full tensor".
116 auto CreateSubTensor = [&]()
126 factoryId == slot->GetTensorHandleFactoryId() &&
129 slot->GetNumConnections() == 1 &&
130 canUseSubTensorOnXorY)
133 return factory.CreateSubTensorHandle(*parentTensor,
138 return std::unique_ptr<ITensorHandle>();
141 auto subTensor = CreateSubTensor();
148 subTensors.push_back(std::move(subTensor));
// If any input failed to produce a sub-tensor we cannot substitute (the elided
// branch presumably bails out to full-tensor allocation -- TODO confirm).
153 if (subTensors.size() < numInputSlots)
// All inputs produced sub-tensors: install each one as its producing layer's
// output data, and queue any producing Concat layers for the same treatment.
160 for (
auto& subTensor : subTensors)
165 ARMNN_ASSERT_MSG(subTensor,
"ConcatLayer: Expected a valid sub-tensor for substitution.");
166 outputHandler.SetData(std::move(subTensor));
168 Layer& inputLayer = slot->GetOwningLayer();
172 m_ConcatLayers.push(PolymorphicDowncast<ConcatLayer*>(&inputLayer));
// Entry point for tensor-handle creation: forwards to the CreateTensors template above,
// using either the supplied workload factory or a handle factory (presumably resolved
// from the registry by an elided factory-id check -- TODO confirm from full file).
182 const bool isMemoryManaged)
189 CreateTensors(registry, workloadFactory, isMemoryManaged);
195 CreateTensors(registry, *handleFactory, isMemoryManaged);
// Infers the concatenated output shape from the per-input view origins/shapes, validating
// along the way that the views tile the output exactly: same rank, origins starting at 0,
// no pairwise overlap, and total element count equal to the output volume.
// NOTE(review): fragments only -- interior lines (e.g. where 'origin'/'shape' and the
// per-view origins come from) are elided from this view.
// Every input must have the same number of dimensions.
209 for (
unsigned int i=0; i< inputShapes.size(); i++)
211 auto& inputShape = inputShapes[i];
213 ConditionalThrowIfNotEqual<LayerValidationException>(
214 "ConcatLayer: Num Dimensions must match all inputs.",
216 inputShape.GetNumDimensions());
// Per-dimension bounding box over all views.
// NOTE(review): extentMin is value-initialized to 0 and only ever lowered via std::min,
// so the all-zero check below appears vacuous as written -- verify against full source.
220 std::vector<unsigned int> extentMin(numDims);
221 std::vector<unsigned int> extentMax(numDims);
222 for (
unsigned int i = 0; i < inputShapes.size(); i++)
226 for (
unsigned int d = 0; d < numDims; d++)
228 extentMin[d] = std::min(extentMin[d], origin[d]);
229 extentMax[d] = std::max(extentMax[d], origin[d] + shape[d]);
// Views must start at the origin in every dimension.
234 if (!std::all_of(extentMin.begin(), extentMin.end(), [](
unsigned int s) { return s == 0; }))
// Pairwise overlap check: two views overlap only if their intervals intersect in
// EVERY dimension; separation in any single dimension (a2 <= b1 || b2 <= a1) clears it.
242 for (
unsigned int a = 0; a < inputShapes.size(); a++)
246 for (
unsigned int b = 0; b < a; b++)
251 bool allAxesOverlap =
true;
252 for (
unsigned int d = 0; d < numDims && allAxesOverlap; d++)
254 unsigned int a1 = aOrigin[d];
255 unsigned int a2 = aOrigin[d] + aShape[d];
257 unsigned int b1 = bOrigin[d];
258 unsigned int b2 = bOrigin[d] + bShape[d];
260 if (a2 <= b1 || b2 <= a1)
262 allAxesOverlap =
false;
// Gap check: the sum of all view volumes must equal the output bounding-box volume.
275 unsigned int totalViewsVolume = 0;
276 for (
unsigned int i = 0; i < inputShapes.size(); i++)
278 totalViewsVolume += inputShapes[i].GetNumElements();
280 unsigned int outputVolume = 1;
281 for (
unsigned int d = 0; d < numDims; d++)
283 outputVolume *= (extentMax[d] - extentMin[d]);
286 ConditionalThrowIfNotEqual<LayerValidationException>(
287 "ConcatLayer: there are some gaps between views",
// The inferred output shape is the per-dimension maximum extent.
291 return std::vector<TensorShape>({
TensorShape({numDims, extentMax.data()}) });
297 ConditionalThrowIfNotEqual<LayerValidationException>(
298 "ConcatLayer: Num Inputs must match num views.",
308 std::vector<TensorShape> inputShapes;