diff options
Diffstat (limited to 'src/armnn/layers')
-rw-r--r-- | src/armnn/layers/ConcatLayer.cpp | 3 | ||||
-rw-r--r-- | src/armnn/layers/ReverseV2Layer.cpp | 2 | ||||
-rw-r--r-- | src/armnn/layers/SplitterLayer.cpp | 5 |
3 files changed, 6 insertions, 4 deletions
diff --git a/src/armnn/layers/ConcatLayer.cpp b/src/armnn/layers/ConcatLayer.cpp index 7a1b689b2c..4629bf245e 100644 --- a/src/armnn/layers/ConcatLayer.cpp +++ b/src/armnn/layers/ConcatLayer.cpp @@ -120,7 +120,7 @@ void ConcatLayer::CreateTensors(const TensorHandleFactoryRegistry& registry, // 3) the input does not come from a Constant layer or input layer // 4) the input is only read by this concat layer // 5) if concat along x or y (2 innermost dimensions) and the previous layers do not require padding - // 6) the input does not have an Overridden TensorInfo + // 6) neither the inputs nor the output have an Overridden TensorInfo if (slot && parentInfo.IsTypeSpaceMatch(info) && //(1) factoryId == slot->GetTensorHandleFactoryId() && //(2) @@ -128,6 +128,7 @@ void ConcatLayer::CreateTensors(const TensorHandleFactoryRegistry& registry, slot->GetOwningLayer().GetType() != LayerType::Input && //(3) slot->GetNumConnections() == 1 && canUseSubTensorOnXorY && //(5) + !GetOutputSlot(0).GetConnection(0)->IsTensorInfoOverridden() && //(6) !currentLayer->GetInputSlot(i).IsTensorInfoOverridden()) //(6) { ARMNN_NO_DEPRECATE_WARN_BEGIN diff --git a/src/armnn/layers/ReverseV2Layer.cpp b/src/armnn/layers/ReverseV2Layer.cpp index 29f8b1b781..201e19819b 100644 --- a/src/armnn/layers/ReverseV2Layer.cpp +++ b/src/armnn/layers/ReverseV2Layer.cpp @@ -40,7 +40,7 @@ void ReverseV2Layer::ValidateTensorShapesFromInputs() VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod); auto inferredShapes = InferOutputShapes({ - GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape() }); + GetInputSlot(0).GetTensorInfo().GetShape() }); ARMNN_ASSERT(inferredShapes.size() == 1); diff --git a/src/armnn/layers/SplitterLayer.cpp b/src/armnn/layers/SplitterLayer.cpp index 86a42305ff..dc8864a736 100644 --- a/src/armnn/layers/SplitterLayer.cpp +++ b/src/armnn/layers/SplitterLayer.cpp @@ -131,13 +131,14 @@ void SplitterLayer::CreateTensors(const TensorHandleFactoryRegistry& registry, // 2) the same TensorHandleFactory is used for input and split layer output // 3) the output does not go to a Constant layer or input layer // 4) if split along x or y (2 innermost dimensions) and the next layers do not require padding - // 5) none of the outputs have an Overridden TensorInfo + // 5) neither the input nor the outputs have an Overridden TensorInfo if (parentInfo.IsTypeSpaceMatch(info) && //(1) factoryId == slot->GetTensorHandleFactoryId() && //(2) GetOutputSlot(i).GetConnection(0)->GetOwningLayer().GetType() != LayerType::Constant && //(3) GetOutputSlot(i).GetConnection(0)->GetOwningLayer().GetType() != LayerType::Input && //(3) canUseSubTensorOnXorY && //(4) - !GetOutputSlot(i).GetConnection(0)->IsTensorInfoOverridden()) //(5) + !GetOutputSlot(i).GetConnection(0)->IsTensorInfoOverridden() && //(5) + !GetInputSlot(0).IsTensorInfoOverridden()) //(5) { ARMNN_NO_DEPRECATE_WARN_BEGIN return factory.CreateSubTensorHandle(*inputData, |