diff options
Diffstat (limited to 'src/backends/tosaCommon/operatorMappings/SplitOperator.cpp')
-rw-r--r-- | src/backends/tosaCommon/operatorMappings/SplitOperator.cpp | 43 |
1 files changed, 19 insertions, 24 deletions
diff --git a/src/backends/tosaCommon/operatorMappings/SplitOperator.cpp b/src/backends/tosaCommon/operatorMappings/SplitOperator.cpp index f8c60b1b6d..53f4f052bb 100644 --- a/src/backends/tosaCommon/operatorMappings/SplitOperator.cpp +++ b/src/backends/tosaCommon/operatorMappings/SplitOperator.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2023 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2023-2024 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // // Copyright © 2020 The TensorFlow Authors. All Rights Reserved. @@ -7,6 +7,7 @@ // #include "SplitOperator.hpp" +#include <backendsCommon/WorkloadUtils.hpp> // This function is paraphrased from: // tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc from function convertSplitOp @@ -26,7 +27,7 @@ TosaSerializationBasicBlock* ConvertSplitToTosaOperator(const Layer* layer, throw armnn::Exception("ConvertSplitToTosaOperator: Dynamic input dimensions are unsupported."); } - std::string inputName = std::string("input0_"); + std::string inputName = std::string("input_"); std::vector<std::string> outputNames; std::string blockName = std::string("Op_SPLIT_block_") + GetUniqueTosaMappingID(); @@ -35,9 +36,7 @@ TosaSerializationBasicBlock* ConvertSplitToTosaOperator(const Layer* layer, // using the previous and following layers so the graph is connected correctly. For validation this doesn't matter. if(layer != nullptr) { - // Get the layers connected to the input slots and determine unique tensor names. - Layer& connectedLayer = layer->GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer(); - inputName = GenerateUniqueName(connectedLayer, 0); + inputName = GenerateUniqueInputName(layer->GetInputSlot(0)); for (unsigned int i=0; i < numSplit; ++i) { @@ -56,26 +55,19 @@ TosaSerializationBasicBlock* ConvertSplitToTosaOperator(const Layer* layer, } } - // Each slice op has a different beginning point. - // The size is the same for each slice op. - std::vector<int32_t> beginVals; - beginVals.reserve(inputs[0]->GetNumDimensions()); - std::vector<int32_t> sizeVals; - sizeVals.reserve(inputs[0]->GetNumDimensions()); - for (unsigned int j = 0; j < inputs[0]->GetNumDimensions(); ++j) + // Configure input and output tensors + std::set<unsigned int> splitAxis = ComputeSplitAxis(*splitDescriptor, inputs[0]->GetShape()); + if (splitAxis.size() != 1) { - beginVals.emplace_back(0); - uint32_t dim = inputs[0]->GetShape()[j]; - sizeVals.emplace_back(dim); + throw InvalidArgumentException("Cannot derive split axis from SplitterDescriptor"); } - - uint32_t axis = static_cast<uint32_t>(splitDescriptor->GetAxis()); - sizeVals[axis] = sizeVals[axis] / static_cast<int32_t>(numSplit); + uint32_t axis = *splitAxis.begin(); std::vector<TosaSerializationOperator*> ops; - for (unsigned int i=0; i < numSplit; ++i) + std::vector<int32_t> beginVals(inputs[0]->GetNumDimensions(), 0); + for (unsigned int i = 0; i < numSplit; ++i) { - beginVals[axis] = static_cast<int>(i) * sizeVals[axis]; + std::vector<int32_t> sizeVals = GetTosaTensorShape(outputs[i]->GetShape()); TosaSliceAttribute attribute(beginVals, sizeVals); auto* op = new TosaSerializationOperator(Op_SLICE, Attribute_SliceAttribute, @@ -84,13 +76,16 @@ TosaSerializationBasicBlock* ConvertSplitToTosaOperator(const Layer* layer, {outputNames[i]}); ops.push_back(op); + + // Update the axis begin value for the next split operation, to be the correct size axis value. + beginVals[axis] += sizeVals[axis]; } std::vector<TosaSerializationTensor*> tensors; // Only add input tensors if connected layer is an input layer. // As intermediate or constant tensors will be created separately. // There also can't be duplicate tensor. - if(inputName.find("input0_") != std::string::npos) + if(inputName.find("input_") != std::string::npos) { std::vector<int32_t> inputShape = GetTosaTensorShape(inputs[0]->GetShape()); DType inputDType = ArmNNToDType(inputs[0]->GetDataType()); @@ -98,13 +93,13 @@ TosaSerializationBasicBlock* ConvertSplitToTosaOperator(const Layer* layer, tensors.push_back(new TosaSerializationTensor(inputName, inputShape, inputDType, {})); } - std::vector<int32_t> outputShape = GetTosaTensorShape(outputs[0]->GetShape()); DType outputDType = ArmNNToDType(outputs[0]->GetDataType()); - - for (unsigned int i=0; i < numSplit; ++i) + for (unsigned int i = 0; i < numSplit; ++i) { + std::vector<int32_t> outputShape = GetTosaTensorShape(outputs[i]->GetShape()); tensors.push_back(new TosaSerializationTensor(outputNames[i], outputShape, outputDType, {})); } + // operatorInputNames/operatorOutputNames ends up being the same as // blockInputNames/blockOutputNames for one-to-one ArmNN to TOSA mappings return new TosaSerializationBasicBlock(blockName, // name |