aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/ConcatLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/ConcatLayer.cpp')
-rw-r--r--src/armnn/layers/ConcatLayer.cpp11
1 files changed, 6 insertions, 5 deletions
diff --git a/src/armnn/layers/ConcatLayer.cpp b/src/armnn/layers/ConcatLayer.cpp
index 69660dd04f..7a1b689b2c 100644
--- a/src/armnn/layers/ConcatLayer.cpp
+++ b/src/armnn/layers/ConcatLayer.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "ConcatLayer.hpp"
@@ -104,14 +104,13 @@ void ConcatLayer::CreateTensors(const TensorHandleFactoryRegistry& registry,
}
}
}
-
// First go through all the input slots and verify that we can sub-tensor all the inputs.
std::vector<std::unique_ptr<ITensorHandle>> subTensors(0);
subTensors.reserve(numInputSlots);
for (unsigned int i = 0; i < numInputSlots; ++i)
{
OutputSlot* slot = currentLayer->GetInputSlot(i).GetConnectedOutputSlot();
- const TensorInfo& info = slot->GetTensorInfo();
+ const TensorInfo& info = currentLayer->GetInputSlot(i).GetTensorInfo();
auto CreateSubTensor = [&]()
{
@@ -121,13 +120,15 @@ void ConcatLayer::CreateTensors(const TensorHandleFactoryRegistry& registry,
// 3) the input does not come from a Constant layer or input layer
// 4) the input is only read by this concat layer
// 5) if concat along x or y (2 innermost dimensions) and the previous layers do not require padding
+ // 6) the input does not have an Overridden TensorInfo
if (slot &&
parentInfo.IsTypeSpaceMatch(info) && //(1)
factoryId == slot->GetTensorHandleFactoryId() && //(2)
slot->GetOwningLayer().GetType() != LayerType::Constant && //(3)
slot->GetOwningLayer().GetType() != LayerType::Input && //(3)
slot->GetNumConnections() == 1 &&
- canUseSubTensorOnXorY) //(5)
+ canUseSubTensorOnXorY && //(5)
+ !currentLayer->GetInputSlot(i).IsTensorInfoOverridden()) //(6)
{
ARMNN_NO_DEPRECATE_WARN_BEGIN
return factory.CreateSubTensorHandle(*parentTensor,
@@ -308,7 +309,7 @@ void ConcatLayer::ValidateTensorShapesFromInputs()
std::vector<TensorShape> inputShapes;
for (unsigned int i = 0; i < GetNumInputSlots(); ++i)
{
- inputShapes.push_back(GetInputSlot(i).GetConnection()->GetTensorInfo().GetShape());
+ inputShapes.push_back(GetInputSlot(i).GetTensorInfo().GetShape());
}
auto inferredShapes = InferOutputShapes(inputShapes);