aboutsummaryrefslogtreecommitdiff
path: root/delegate/common/src/DelegateUtils.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'delegate/common/src/DelegateUtils.hpp')
-rw-r--r--delegate/common/src/DelegateUtils.hpp29
1 files changed, 29 insertions, 0 deletions
diff --git a/delegate/common/src/DelegateUtils.hpp b/delegate/common/src/DelegateUtils.hpp
index b953699016..51c70f9ba1 100644
--- a/delegate/common/src/DelegateUtils.hpp
+++ b/delegate/common/src/DelegateUtils.hpp
@@ -109,4 +109,33 @@ void UpdateConstantTensorOutputs(const armnn::TensorInfo& inputInfo, armnn::Tens
}
}
+void SetupConcatViewOrigin(const armnn::TensorInfo& inputTensorInfo,
+ armnn::OriginsDescriptor& concatDescriptor,
+ const unsigned int concatAxis,
+ unsigned int inputIndex,
+ unsigned int& mergeDimOrigin)
+{
+ const uint32_t inputRank = concatDescriptor.GetNumDimensions();
+
+ // double check dimensions of the tensors
+ if (inputTensorInfo.GetNumDimensions() != inputRank)
+ {
+ throw armnn::ParseException("The number of dimensions for input tensors "
+ "of the concatenation operator should be: " + std::to_string(inputRank));
+ }
+
+ for (unsigned int j = 0; j < concatAxis; ++j)
+ {
+ concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
+ }
+
+ concatDescriptor.SetViewOriginCoord(inputIndex, concatAxis, mergeDimOrigin);
+ mergeDimOrigin += inputTensorInfo.GetShape()[concatAxis];
+
+ for (unsigned int j = concatAxis + 1; j < inputRank; ++j)
+ {
+ concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
+ }
+}
+
} // namespace anonymous