diff options
Diffstat (limited to 'delegate/common/src/DelegateUtils.hpp')
-rw-r--r-- | delegate/common/src/DelegateUtils.hpp | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/delegate/common/src/DelegateUtils.hpp b/delegate/common/src/DelegateUtils.hpp index b953699016..51c70f9ba1 100644 --- a/delegate/common/src/DelegateUtils.hpp +++ b/delegate/common/src/DelegateUtils.hpp @@ -109,4 +109,33 @@ void UpdateConstantTensorOutputs(const armnn::TensorInfo& inputInfo, armnn::Tens } } +void SetupConcatViewOrigin(const armnn::TensorInfo& inputTensorInfo, + armnn::OriginsDescriptor& concatDescriptor, + const unsigned int concatAxis, + unsigned int inputIndex, + unsigned int& mergeDimOrigin) +{ + const uint32_t inputRank = concatDescriptor.GetNumDimensions(); + + // double check dimensions of the tensors + if (inputTensorInfo.GetNumDimensions() != inputRank) + { + throw armnn::ParseException("The number of dimensions for input tensors " + "of the concatenation operator should be: " + std::to_string(inputRank)); + } + + for (unsigned int j = 0; j < concatAxis; ++j) + { + concatDescriptor.SetViewOriginCoord(inputIndex, j, 0); + } + + concatDescriptor.SetViewOriginCoord(inputIndex, concatAxis, mergeDimOrigin); + mergeDimOrigin += inputTensorInfo.GetShape()[concatAxis]; + + for (unsigned int j = concatAxis + 1; j < inputRank; ++j) + { + concatDescriptor.SetViewOriginCoord(inputIndex, j, 0); + } +} + } // namespace anonymous |