aboutsummaryrefslogtreecommitdiff
path: root/ConversionUtils.hpp
diff options
context:
space:
mode:
authorMike Kelly <mike.kelly@arm.com>2019-07-31 17:25:43 +0100
committerMike Kelly <mike.kelly@arm.com>2019-07-31 17:25:43 +0100
commitb8805204fb0ea64079735921bcc0cc2b1aedfcf6 (patch)
treea3cf2901ca6c2b582c9a078b631cccddc9c95453 /ConversionUtils.hpp
parentd74c50550f797e9a5df0e379b5b49d9bd3b29bbd (diff)
downloadandroid-nn-driver-b8805204fb0ea64079735921bcc0cc2b1aedfcf6.tar.gz
IVGCVSW-3601 Fix skipped VTS Concatenate Tests
* Fixed Skipped VTS Concatenate Tests. Signed-off-by: Mike Kelly <mike.kelly@arm.com> Change-Id: I29e7dcdedefc0e9c54f86fa5de23aa714c469585
Diffstat (limited to 'ConversionUtils.hpp')
-rw-r--r--ConversionUtils.hpp235
1 files changed, 235 insertions, 0 deletions
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp
index 0349999d..2fa8a072 100644
--- a/ConversionUtils.hpp
+++ b/ConversionUtils.hpp
@@ -1412,6 +1412,241 @@ bool ConvertPooling2d(const HalOperation& operation,
}
template<typename HalPolicy,
+ typename Operation = typename HalPolicy::Operation,
+ typename Model = typename HalPolicy::Model>
+bool ConvertConcatenation(const Operation& operation, const Model& model, ConversionData& data)
+{
+ using HalOperand = typename HalPolicy::Operand;
+ using HalOperandType = typename HalPolicy::OperandType;
+
+ // The first N (0..N-1) inputs are tensors. The Nth input is the concatenation axis.
+ if (operation.inputs.size() <= 1)
+ {
+ return Fail("%s: Operation has insufficient arguments", __func__);
+ }
+
+ // Get inputs and outputs
+ const std::size_t numInputTensors = operation.inputs.size() - 1;
+
+ int32_t concatDim;
+ if (!GetInputScalar<HalPolicy>(operation, numInputTensors, HalOperandType::INT32, concatDim, model, data))
+ {
+ return Fail("%s: Operation has invalid inputs", __func__);
+ }
+
+ const HalOperand* outputOperand = GetOutputOperand<HalPolicy>(operation, 0, model);
+ if (!outputOperand)
+ {
+ return Fail("%s: Operation has no outputs", __func__);
+ }
+
+
+ armnn::TensorInfo outputInfo = GetTensorInfoForOperand(*outputOperand);
+ armnn::TensorShape outputShape = outputInfo.GetShape();
+
+ //
+ // handle negative concat dims along the lines of tensorflow as described here:
+ // https://www.tensorflow.org/api_docs/python/tf/concat
+ // "negative axis refers to axis + rank(values)-th dimension"
+ //
+ if (concatDim < 0)
+ {
+ concatDim += outputShape.GetNumDimensions();
+ }
+
+ if (concatDim >= static_cast<int32_t>(outputShape.GetNumDimensions()) || concatDim < 0)
+ {
+ return Fail("%s: Operation has invalid concat axis: %d", __func__, concatDim);
+ }
+
+ std::vector<LayerInputHandle> inputHandles;
+ std::vector<armnn::TensorShape> inputShapes;
+
+ inputHandles.reserve(numInputTensors);
+ inputShapes.reserve(numInputTensors);
+
+ bool inputsHaveBeenReshaped = false;
+ unsigned int tensorDimensionsAdded = 0;
+
+ for (uint32_t i = 0; i < numInputTensors; ++i)
+ {
+ const HalOperand* operand = GetInputOperand<HalPolicy>(operation, i, model);
+ if (!operand)
+ {
+ return Fail("%s: Operation has invalid inputs", __func__);
+ }
+
+ armnn::TensorShape operandShape = GetTensorShapeForOperand(*operand);
+ LayerInputHandle operandInputHandle =
+ ConvertToLayerInputHandle<HalPolicy>(operation, i, model, data);
+
+ if (operandShape.GetNumDimensions() == 0)
+ {
+ return Fail("%s: Operands with rank 0 are not supported", __func__);
+ }
+
+ if (RequiresReshape(operandShape))
+ {
+ inputsHaveBeenReshaped = true;
+
+ armnn::TensorInfo reshapeInfo = operandInputHandle.GetTensorInfo();
+
+ // Expand the tensor to three dimensions
+ if (operandShape.GetNumDimensions() == 2)
+ {
+ reshapeInfo.SetShape(armnn::TensorShape({1, operandShape[0], operandShape[1]}));
+ tensorDimensionsAdded = 1;
+ }
+ else
+ {
+ reshapeInfo.SetShape(armnn::TensorShape({1, 1, operandShape[0]}));
+ tensorDimensionsAdded = 2;
+ }
+
+ armnn::IConnectableLayer& newReshape = AddReshapeLayer(
+ *data.m_Network,
+ operandInputHandle,
+ reshapeInfo
+ );
+
+ // Point to the reshape operation rather then the input operation
+ operandShape = reshapeInfo.GetShape();
+ operandInputHandle = LayerInputHandle(true, &newReshape.GetOutputSlot(0), reshapeInfo);
+ }
+
+ inputShapes.emplace_back(operandShape);
+ inputHandles.emplace_back(operandInputHandle);
+
+ if (!inputHandles.back().IsValid())
+ {
+ return Fail("%s: Operation has invalid inputs", __func__);
+ }
+ }
+
+ BOOST_ASSERT(inputShapes.size() == inputHandles.size());
+
+ if (inputsHaveBeenReshaped)
+ {
+ // Adjust the concatenation dimension by the amount of dimensions added (if any)
+ concatDim += tensorDimensionsAdded;
+
+ // Add extra dimensions to the output shape to reflect the addition of the reshape layers
+ if (tensorDimensionsAdded == 1)
+ {
+ outputShape = armnn::TensorShape({1, outputShape[0], outputShape[1]});
+ }
+ else if (tensorDimensionsAdded == 2)
+ {
+ outputShape = armnn::TensorShape({1, 1, outputShape[0]});
+ }
+ }
+
+ // Check if permutations is required and get the pair of permutations required for the concatenation.
+ // Permutation is required when the concat dimension is 2 for a 4D tensor or 1 for a 3D tensor.
+ std::pair<armnn::PermutationVector, armnn::PermutationVector> permutationPair =
+ std::make_pair(IdentityPermutation4D, IdentityPermutation4D);
+
+ bool needPermute =
+ CreateConcatPermutationParameters(inputShapes[0].GetNumDimensions(), concatDim, permutationPair);
+
+ if (needPermute)
+ {
+ outputShape = armnnUtils::Permuted(outputShape, permutationPair.first);
+ }
+
+ outputInfo.SetShape(outputShape);
+
+ // this is no-op for identity swizzles, otherwise it replaces both
+ // the handles and shapes with the swizzled layer output handles and shapes
+ SwizzleInputs(*data.m_Network, inputHandles, inputShapes, permutationPair.first);
+
+ // Create an armnn concat layer descriptor - this will also perform validation on the input shapes
+ armnn::OriginsDescriptor concatDescriptor;
+
+ try
+ {
+ // The concat descriptor is always created across the only supported concat dimension
+ // which is 0, 1 or 3 for a 4-D tensor, or 0 or 2 for a 3-D tensor.
+ concatDescriptor =
+ armnn::CreateDescriptorForConcatenation(inputShapes.begin(), inputShapes.end(), concatDim);
+ }
+ catch (const armnn::Exception& error)
+ {
+ return Fail("%s: Error preparing concat descriptor. %s", __func__, error.what());
+ }
+
+ // Validate the output shape is correct given the input shapes based on the
+ // only valid concat dimension which is 0, 1 or 3 for a 4-D tensor, or 0 or 2 for a 3-D tensor.
+ if (!ValidateConcatOutputShape(inputShapes, outputShape, concatDim))
+ {
+ return Fail("%s: Error validating the output shape for concat", __func__);
+ }
+
+ std::vector<const armnn::TensorInfo*> inputTensorInfos;
+ std::transform(inputHandles.begin(), inputHandles.end(), std::back_inserter(inputTensorInfos),
+ [](const LayerInputHandle& h) -> const armnn::TensorInfo*{ return &h.GetTensorInfo(); });
+
+ bool isSupported = false;
+ FORWARD_LAYER_SUPPORT_FUNC(__func__,
+ IsConcatSupported,
+ data.m_Backends,
+ isSupported,
+ inputTensorInfos,
+ outputInfo,
+ concatDescriptor);
+ if (!isSupported)
+ {
+ return false;
+ }
+
+ armnn::IConnectableLayer* layer = data.m_Network->AddConcatLayer(concatDescriptor);
+ assert(layer != nullptr);
+ layer->GetOutputSlot(0).SetTensorInfo(outputInfo);
+
+ // Connect inputs to the layer
+ const int numInputSlots = layer->GetNumInputSlots();
+ assert(static_cast<std::size_t>(numInputSlots) == inputHandles.size());
+ for (int i = 0; i < numInputSlots; ++i)
+ {
+ // connect the input directly to the merge (concat) layer
+ inputHandles[static_cast<unsigned int>(i)].Connect(layer->GetInputSlot(i));
+ }
+
+ if (needPermute)
+ {
+ // Add permutation layer and connect the output to it, the permutation becomes the output layer
+ armnn::IConnectableLayer& deswizzleLayer = AddPermuteLayer(*data.m_Network,
+ layer->GetOutputSlot(0),
+ permutationPair.second);
+ layer = &deswizzleLayer;
+ }
+
+ if (inputsHaveBeenReshaped)
+ {
+ armnn::TensorInfo afterConcatInfo = layer->GetOutputSlot(0).GetTensorInfo();
+
+ // Undo the reshape knowing the amount of dimensions added
+ if (tensorDimensionsAdded == 1)
+ {
+ afterConcatInfo.SetShape(armnn::TensorShape({ afterConcatInfo.GetShape()[1],
+ afterConcatInfo.GetShape()[2] }));
+ }
+ else if (tensorDimensionsAdded == 2)
+ {
+ afterConcatInfo.SetShape(armnn::TensorShape({ afterConcatInfo.GetShape()[2] }));
+ }
+
+ layer = &AddReshapeLayer(
+ *data.m_Network,
+ layer->GetOutputSlot(0),
+ afterConcatInfo
+ );
+ }
+
+ return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data);
+}
+
+template<typename HalPolicy,
typename HalOperation = typename HalPolicy::Operation,
typename HalModel = typename HalPolicy::Model>
bool ConvertConv2d(const HalOperation& operation, const HalModel& model, ConversionData& data)