aboutsummaryrefslogtreecommitdiff
path: root/src/backends/tosaCommon/operatorMappings/SplitOperator.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/tosaCommon/operatorMappings/SplitOperator.cpp')
-rw-r--r--src/backends/tosaCommon/operatorMappings/SplitOperator.cpp43
1 files changed, 19 insertions, 24 deletions
diff --git a/src/backends/tosaCommon/operatorMappings/SplitOperator.cpp b/src/backends/tosaCommon/operatorMappings/SplitOperator.cpp
index f8c60b1b6d..53f4f052bb 100644
--- a/src/backends/tosaCommon/operatorMappings/SplitOperator.cpp
+++ b/src/backends/tosaCommon/operatorMappings/SplitOperator.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2023-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
// Copyright © 2020 The TensorFlow Authors. All Rights Reserved.
@@ -7,6 +7,7 @@
//
#include "SplitOperator.hpp"
+#include <backendsCommon/WorkloadUtils.hpp>
// This function is paraphrased from:
// tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc from function convertSplitOp
@@ -26,7 +27,7 @@ TosaSerializationBasicBlock* ConvertSplitToTosaOperator(const Layer* layer,
throw armnn::Exception("ConvertSplitToTosaOperator: Dynamic input dimensions are unsupported.");
}
- std::string inputName = std::string("input0_");
+ std::string inputName = std::string("input_");
std::vector<std::string> outputNames;
std::string blockName = std::string("Op_SPLIT_block_") + GetUniqueTosaMappingID();
@@ -35,9 +36,7 @@ TosaSerializationBasicBlock* ConvertSplitToTosaOperator(const Layer* layer,
// using the previous and following layers so the graph is connected correctly. For validation this doesn't matter.
if(layer != nullptr)
{
- // Get the layers connected to the input slots and determine unique tensor names.
- Layer& connectedLayer = layer->GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer();
- inputName = GenerateUniqueName(connectedLayer, 0);
+ inputName = GenerateUniqueInputName(layer->GetInputSlot(0));
for (unsigned int i=0; i < numSplit; ++i)
{
@@ -56,26 +55,19 @@ TosaSerializationBasicBlock* ConvertSplitToTosaOperator(const Layer* layer,
}
}
- // Each slice op has a different beginning point.
- // The size is the same for each slice op.
- std::vector<int32_t> beginVals;
- beginVals.reserve(inputs[0]->GetNumDimensions());
- std::vector<int32_t> sizeVals;
- sizeVals.reserve(inputs[0]->GetNumDimensions());
- for (unsigned int j = 0; j < inputs[0]->GetNumDimensions(); ++j)
+ // Configure input and output tensors
+ std::set<unsigned int> splitAxis = ComputeSplitAxis(*splitDescriptor, inputs[0]->GetShape());
+ if (splitAxis.size() != 1)
{
- beginVals.emplace_back(0);
- uint32_t dim = inputs[0]->GetShape()[j];
- sizeVals.emplace_back(dim);
+ throw InvalidArgumentException("Cannot derive split axis from SplitterDescriptor");
}
-
- uint32_t axis = static_cast<uint32_t>(splitDescriptor->GetAxis());
- sizeVals[axis] = sizeVals[axis] / static_cast<int32_t>(numSplit);
+ uint32_t axis = *splitAxis.begin();
std::vector<TosaSerializationOperator*> ops;
- for (unsigned int i=0; i < numSplit; ++i)
+ std::vector<int32_t> beginVals(inputs[0]->GetNumDimensions(), 0);
+ for (unsigned int i = 0; i < numSplit; ++i)
{
- beginVals[axis] = static_cast<int>(i) * sizeVals[axis];
+ std::vector<int32_t> sizeVals = GetTosaTensorShape(outputs[i]->GetShape());
TosaSliceAttribute attribute(beginVals, sizeVals);
auto* op = new TosaSerializationOperator(Op_SLICE,
Attribute_SliceAttribute,
@@ -84,13 +76,16 @@ TosaSerializationBasicBlock* ConvertSplitToTosaOperator(const Layer* layer,
{outputNames[i]});
ops.push_back(op);
+
+ // Update the axis begin value for the next split operation, to be the correct size axis value.
+ beginVals[axis] += sizeVals[axis];
}
std::vector<TosaSerializationTensor*> tensors;
// Only add input tensors if connected layer is an input layer.
// As intermediate or constant tensors will be created separately.
// There also can't be duplicate tensor.
- if(inputName.find("input0_") != std::string::npos)
+ if(inputName.find("input_") != std::string::npos)
{
std::vector<int32_t> inputShape = GetTosaTensorShape(inputs[0]->GetShape());
DType inputDType = ArmNNToDType(inputs[0]->GetDataType());
@@ -98,13 +93,13 @@ TosaSerializationBasicBlock* ConvertSplitToTosaOperator(const Layer* layer,
tensors.push_back(new TosaSerializationTensor(inputName, inputShape, inputDType, {}));
}
- std::vector<int32_t> outputShape = GetTosaTensorShape(outputs[0]->GetShape());
DType outputDType = ArmNNToDType(outputs[0]->GetDataType());
-
- for (unsigned int i=0; i < numSplit; ++i)
+ for (unsigned int i = 0; i < numSplit; ++i)
{
+ std::vector<int32_t> outputShape = GetTosaTensorShape(outputs[i]->GetShape());
tensors.push_back(new TosaSerializationTensor(outputNames[i], outputShape, outputDType, {}));
}
+
// operatorInputNames/operatorOutputNames ends up being the same as
// blockInputNames/blockOutputNames for one-to-one ArmNN to TOSA mappings
return new TosaSerializationBasicBlock(blockName, // name