aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKeith Davis <keith.davis@arm.com>2020-09-03 13:17:21 +0100
committerKeith Davis <keith.davis@arm.com>2020-09-08 18:43:21 +0100
commit6e4081f67e35427212bd3505180c9abb1ac52b23 (patch)
tree87d64c3472c5f6786393c0c601461ef8860cce94
parent34db1872566a1737fd94305d0b3f3e7741d99b60 (diff)
downloadandroid-nn-driver-6e4081f67e35427212bd3505180c9abb1ac52b23.tar.gz
IVGCVSW-5270 Update ConvertConcatenation function to use ShapeInferenceMethod
Signed-off-by: Keith Davis <keith.davis@arm.com> Change-Id: I13e16d271ba55217b98a439aa82931f809fdeeb8
-rw-r--r--ConversionUtils.hpp185
1 files changed, 121 insertions, 64 deletions
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp
index fe8e026e..fa67f791 100644
--- a/ConversionUtils.hpp
+++ b/ConversionUtils.hpp
@@ -1968,7 +1968,7 @@ template<typename HalPolicy,
typename HalModel = typename HalPolicy::Model>
bool ConvertConcatenation(const HalOperation& operation, const HalModel& model, ConversionData& data)
{
- using HalOperand = typename HalPolicy::Operand;
+ using HalOperand = typename HalPolicy::Operand;
using HalOperandType = typename HalPolicy::OperandType;
// The first N (0..N-1) inputs are tensors. The Nth input is the concatenation axis.
@@ -1992,9 +1992,9 @@ bool ConvertConcatenation(const HalOperation& operation, const HalModel& model,
return Fail("%s: Operation has no outputs", __func__);
}
- armnn::TensorInfo outputInfo = GetTensorInfoForOperand(*outputOperand);
- armnn::TensorShape outputShape = outputInfo.GetShape();
-
+ armnn::TensorInfo outputInfo = GetTensorInfoForOperand(*outputOperand);
+ armnn::TensorShape outputShape = outputInfo.GetShape();
+ const bool isDynamicTensor = IsDynamicTensor(outputInfo);
//
// handle negative concat dims along the lines of tensorflow as described here:
// https://www.tensorflow.org/api_docs/python/tf/concat
@@ -2016,9 +2016,8 @@ bool ConvertConcatenation(const HalOperation& operation, const HalModel& model,
inputHandles.reserve(numInputTensors);
inputShapes.reserve(numInputTensors);
- bool inputsHaveBeenReshaped = false;
- unsigned int tensorDimensionsAdded = 0;
-
+ bool inputsHaveBeenReshaped = false;
+ unsigned int tensorDimensionsAdded = 0;
for (uint32_t i = 0; i < numInputTensors; ++i)
{
const HalOperand* operand = GetInputOperand<HalPolicy>(operation, i, model);
@@ -2033,7 +2032,7 @@ bool ConvertConcatenation(const HalOperation& operation, const HalModel& model,
return Fail("%s: Operation has invalid inputs", __func__);
}
- armnn::TensorShape operandShape = GetTensorShapeForOperand(*operand);
+ armnn::TensorShape operandShape = GetTensorShapeForOperand(*operand);
if (operandShape.GetNumDimensions() == 0)
{
return Fail("%s: Operands with rank 0 are not supported", __func__);
@@ -2068,19 +2067,15 @@ bool ConvertConcatenation(const HalOperation& operation, const HalModel& model,
operandInputHandle.GetTensorInfo(),
reshapeInfo,
reshapeDescriptor);
+
if (!isSupported)
{
return false;
}
-
- armnn::IConnectableLayer& newReshape = AddReshapeLayer(
- *data.m_Network,
- operandInputHandle,
- reshapeInfo
- );
+ armnn::IConnectableLayer& newReshape = AddReshapeLayer(*data.m_Network, operandInputHandle, reshapeInfo);
// Point to the reshape operation rather then the input operation
- operandShape = reshapeInfo.GetShape();
+ operandShape = reshapeInfo.GetShape();
operandInputHandle = LayerInputHandle(true, &newReshape.GetOutputSlot(0), reshapeInfo);
}
@@ -2103,29 +2098,47 @@ bool ConvertConcatenation(const HalOperation& operation, const HalModel& model,
// Add extra dimensions to the output shape to reflect the addition of the reshape layers
if (tensorDimensionsAdded == 1)
{
- outputShape = armnn::TensorShape({1, outputShape[0], outputShape[1]});
+ if (IsDynamicTensor(outputInfo))
+ {
+ outputShape = armnn::TensorShape({1, 0, 0}, {true, false, false});
+ }
+ else
+ {
+ outputShape = armnn::TensorShape({1, outputShape[0], outputShape[1]});
+ }
}
else if (tensorDimensionsAdded == 2)
{
- outputShape = armnn::TensorShape({1, 1, outputShape[0]});
+ if (IsDynamicTensor(outputInfo))
+ {
+ outputShape = armnn::TensorShape({1, 1, 0}, {true, true, false});
+ }
+ else
+ {
+ outputShape = armnn::TensorShape({1, 1, outputShape[0]});
+ }
}
}
// Check if permutations is required and get the pair of permutations required for the concatenation.
// Permutation is required when the concat dimension is 2 for a 4D tensor or 1 for a 3D tensor.
std::pair<armnn::PermutationVector, armnn::PermutationVector> permutationPair =
- std::make_pair(IdentityPermutation4D, IdentityPermutation4D);
+ std::make_pair(IdentityPermutation4D, IdentityPermutation4D);
- bool needPermute =
- CreateConcatPermutationParameters(inputShapes[0].GetNumDimensions(), concatDim, permutationPair);
+ bool needPermute = CreateConcatPermutationParameters(inputShapes[0].GetNumDimensions(),
+ concatDim,
+ permutationPair);
- if (needPermute)
+ // Only relevant to static tensors as dynamic output tensors will be transposed as a result of inferring from input
+ if (!isDynamicTensor)
{
- outputShape = armnnUtils::TransposeTensorShape(outputShape, permutationPair.first);
- }
-
- outputInfo.SetShape(outputShape);
+ if (needPermute)
+ {
+ outputShape = armnnUtils::TransposeTensorShape(outputShape, permutationPair.first);
+ }
+ outputInfo.SetShape(outputShape);
+ }
// this is no-op for identity swizzles, otherwise it replaces both
// the handles and shapes with the swizzled layer output handles and shapes
if (!TransposeInputTensors(data, inputHandles, inputShapes, permutationPair.first))
@@ -2140,33 +2153,43 @@ bool ConvertConcatenation(const HalOperation& operation, const HalModel& model,
{
// The concat descriptor is always created across the only supported concat dimension
// which is 0, 1 or 3 for a 4-D tensor, or 0 or 2 for a 3-D tensor.
- concatDescriptor =
- armnn::CreateDescriptorForConcatenation(inputShapes.begin(), inputShapes.end(), concatDim);
- }
- catch (std::exception& error)
+ concatDescriptor = armnn::CreateDescriptorForConcatenation(inputShapes.begin(),
+ inputShapes.end(),
+ concatDim);
+ } catch (std::exception& error)
{
return Fail("%s: Error preparing concat descriptor. %s", __func__, error.what());
}
// Validate the output shape is correct given the input shapes based on the
// only valid concat dimension which is 0, 1 or 3 for a 4-D tensor, or 0 or 2 for a 3-D tensor.
- if (!ValidateConcatOutputShape(inputShapes, outputShape, concatDim))
+ if (!isDynamicTensor)
{
- return Fail("%s: Error validating the output shape for concat", __func__);
+ if (!ValidateConcatOutputShape(inputShapes, outputShape, concatDim))
+ {
+ return Fail("%s: Error validating the output shape for concat", __func__);
+ }
}
std::vector<const armnn::TensorInfo*> inputTensorInfos;
std::transform(inputHandles.begin(), inputHandles.end(), std::back_inserter(inputTensorInfos),
- [](const LayerInputHandle& h) -> const armnn::TensorInfo*{ return &h.GetTensorInfo(); });
+ [](const LayerInputHandle& h)->const armnn::TensorInfo*{ return &h.GetTensorInfo(); });
+
+ bool isSupported = false;
+ auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported){
+ FORWARD_LAYER_SUPPORT_FUNC(__func__, IsConcatSupported, data.m_Backends, isSupported, inputTensorInfos,
+ outputInfo, concatDescriptor);
+ };
+
+ if (!isDynamicTensor)
+ {
+ validateFunc(outputInfo, isSupported);
+ }
+ else
+ {
+ isSupported = AreDynamicTensorsSupported();
+ }
- bool isSupported = false;
- FORWARD_LAYER_SUPPORT_FUNC(__func__,
- IsConcatSupported,
- data.m_Backends,
- isSupported,
- inputTensorInfos,
- outputInfo,
- concatDescriptor);
if (!isSupported)
{
return false;
@@ -2175,7 +2198,6 @@ bool ConvertConcatenation(const HalOperation& operation, const HalModel& model,
armnn::IConnectableLayer* layer = data.m_Network->AddConcatLayer(concatDescriptor);
assert(layer != nullptr);
layer->GetOutputSlot(0).SetTensorInfo(outputInfo);
-
// Connect inputs to the layer
const int numInputSlots = layer->GetNumInputSlots();
assert(static_cast<std::size_t>(numInputSlots) == inputHandles.size());
@@ -2185,15 +2207,14 @@ bool ConvertConcatenation(const HalOperation& operation, const HalModel& model,
inputHandles[static_cast<unsigned int>(i)].Connect(layer->GetInputSlot(i));
}
- if (needPermute)
- {
+ // Transpose the output shape
+ auto transposeOutputShape = [&](){
armnn::TransposeDescriptor transposeDesc;
transposeDesc.m_DimMappings = permutationPair.second;
armnn::TensorInfo inputTransposeInfo = layer->GetOutputSlot(0).GetTensorInfo();
armnn::TensorInfo outputTransposeInfo = armnnUtils::TransposeTensorShape(inputTransposeInfo,
permutationPair.second);
-
- bool isSupported = false;
+ isSupported = false;
FORWARD_LAYER_SUPPORT_FUNC(__func__,
IsTransposeSupported,
data.m_Backends,
@@ -2201,56 +2222,92 @@ bool ConvertConcatenation(const HalOperation& operation, const HalModel& model,
inputTransposeInfo,
outputTransposeInfo,
transposeDesc);
+
if (!isSupported)
{
return false;
}
// Add permutation layer and connect the output to it, the permutation becomes the output layer
- armnn::IConnectableLayer& deswizzleLayer = AddTransposeLayer(*data.m_Network,
- layer->GetOutputSlot(0),
+ armnn::IConnectableLayer& deswizzleLayer = AddTransposeLayer(*data.m_Network, layer->GetOutputSlot(0),
permutationPair.second);
layer = &deswizzleLayer;
+
+ return true;
+ };
+
+ if (needPermute && !isDynamicTensor)
+ {
+ transposeOutputShape();
}
if (inputsHaveBeenReshaped)
{
+ if (isDynamicTensor)
+ {
+ // Infer the output shapes of concat if outputs are type 1 dynamic
+ layer->GetOutputSlot(0).IsTensorInfoSet();
+ if (!ValidateConcatOutputShape(inputShapes,
+ layer->GetOutputSlot(0).GetTensorInfo().GetShape(),
+ concatDim))
+ {
+ return Fail("%s: Error validating the output shape for concat", __func__);
+ }
+ transposeOutputShape();
+ }
+
armnn::TensorInfo afterConcatInfo = layer->GetOutputSlot(0).GetTensorInfo();
// Undo the reshape knowing the amount of dimensions added
if (tensorDimensionsAdded == 1)
{
- afterConcatInfo.SetShape(armnn::TensorShape({ afterConcatInfo.GetShape()[1],
- afterConcatInfo.GetShape()[2] }));
+ afterConcatInfo.SetShape(
+ armnn::TensorShape({afterConcatInfo.GetShape()[1], afterConcatInfo.GetShape()[2]}));
}
else if (tensorDimensionsAdded == 2)
{
- afterConcatInfo.SetShape(armnn::TensorShape({ afterConcatInfo.GetShape()[2] }));
+ afterConcatInfo.SetShape(armnn::TensorShape({afterConcatInfo.GetShape()[2]}));
}
armnn::ReshapeDescriptor reshapeDescriptor;
reshapeDescriptor.m_TargetShape = afterConcatInfo.GetShape();
+ armnn::TensorInfo concatInfo = layer->GetOutputSlot(0).GetTensorInfo();
+
+ isSupported = false;
+ auto validateReshapeFunc = [&](const armnn::TensorInfo& afterConcatInfo, bool& isSupported){
+ FORWARD_LAYER_SUPPORT_FUNC(__func__,
+ IsReshapeSupported,
+ data.m_Backends,
+ isSupported,
+ concatInfo,
+ afterConcatInfo,
+ reshapeDescriptor);
+ };
+
+ if (!IsDynamicTensor(afterConcatInfo))
+ {
+ validateReshapeFunc(afterConcatInfo, isSupported);
+ }
+ else
+ {
+ isSupported = AreDynamicTensorsSupported();
+ }
- bool isSupported = false;
- FORWARD_LAYER_SUPPORT_FUNC(__func__,
- IsReshapeSupported,
- data.m_Backends,
- isSupported,
- layer->GetOutputSlot(0).GetTensorInfo(),
- afterConcatInfo,
- reshapeDescriptor);
if (!isSupported)
{
return false;
}
- layer = &AddReshapeLayer(
- *data.m_Network,
- layer->GetOutputSlot(0),
- afterConcatInfo
- );
+ layer = &AddReshapeLayer(*data.m_Network, layer->GetOutputSlot(0), afterConcatInfo);
+ return SetupAndTrackLayerOutputSlot<HalPolicy>(operation,
+ 0,
+ *layer,
+ model,
+ data,
+ nullptr,
+ validateReshapeFunc);
}
- return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data);
+ return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data, nullptr, validateFunc);
}
template<typename HalPolicy,