aboutsummaryrefslogtreecommitdiff
path: root/include/armnn/Descriptors.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'include/armnn/Descriptors.hpp')
-rw-r--r--include/armnn/Descriptors.hpp71
1 files changed, 45 insertions, 26 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index 38e3c61500..493ce65976 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -1553,55 +1553,74 @@ struct ChannelShuffleDescriptor : BaseDescriptor
/// A BatchMatMulDescriptor for the BatchMatMul operator
struct BatchMatMulDescriptor : BaseDescriptor
{
- BatchMatMulDescriptor(Optional<DataLayout> dataLayoutX = EmptyOptional(),
- Optional<DataLayout> dataLayoutY = EmptyOptional(),
- std::vector<unsigned int> transposeX = {},
- std::vector<unsigned int> transposeY = {},
- std::vector<unsigned int> adjointX = {},
- std::vector<unsigned int> adjointY = {})
- : m_DataLayoutX(dataLayoutX)
- , m_DataLayoutY(dataLayoutY)
- , m_TransposeX(transposeX)
+ BatchMatMulDescriptor(bool transposeX = false,
+ bool transposeY = false,
+ bool adjointX = false,
+ bool adjointY = false,
+ DataLayout dataLayoutX = DataLayout::NCHW,
+ DataLayout dataLayoutY = DataLayout::NCHW)
+ : m_TransposeX(transposeX)
, m_TransposeY(transposeY)
, m_AdjointX(adjointX)
, m_AdjointY(adjointY)
+ , m_DataLayoutX(dataLayoutX)
+ , m_DataLayoutY(dataLayoutY)
{}
bool operator ==(const BatchMatMulDescriptor &rhs) const
{
- return m_DataLayoutX == rhs.m_DataLayoutX &&
- m_DataLayoutY == rhs.m_DataLayoutY &&
- m_TransposeX == rhs.m_TransposeX &&
+ return m_TransposeX == rhs.m_TransposeX &&
m_TransposeY == rhs.m_TransposeY &&
m_AdjointX == rhs.m_AdjointX &&
- m_AdjointY == rhs.m_AdjointY;
+ m_AdjointY == rhs.m_AdjointY &&
+ m_DataLayoutX == rhs.m_DataLayoutX &&
+ m_DataLayoutY == rhs.m_DataLayoutY;
}
- /// Data layout of each input tensor, such as NHWC/NDHWC (or leave as EmptyOptional for arbitrary layout)
- Optional<DataLayout> m_DataLayoutX;
- Optional<DataLayout> m_DataLayoutY;
-
- /// Transpose vector for each input tensor (leave as empty vector for no pre-transposing)
+ /// Transpose the slices of each input tensor
/// Transpose and Adjoint can not both be set to true for the same tensor at the same time
- std::vector<unsigned int> m_TransposeX;
- std::vector<unsigned int> m_TransposeY;
+ bool m_TransposeX;
+ bool m_TransposeY;
- /// Adjoint vector for each input tensor (leave as empty vector for no pre-adjoint)
+ /// Adjoint the slices of each input tensor
/// Transpose and Adjoint can not both be set to true for the same tensor at the same time
- std::vector<unsigned int> m_AdjointX;
- std::vector<unsigned int> m_AdjointY;
+ bool m_AdjointX;
+ bool m_AdjointY;
- /// Static helper to get the two axes (for each input) for multiplication
+ /// Data layout of each input tensor, such as NHWC/NDHWC (leave as default for arbitrary layout)
+ DataLayout m_DataLayoutX;
+ DataLayout m_DataLayoutY;
+
+ ARMNN_DEPRECATED_MSG_REMOVAL_DATE("This method is deprecated. Use ABI Stable "
+ "GetAxesToMul(DataLayout dataLayout, const TensorShape& tensorShape) instead.",
+ "23.05")
static std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>> GetAxesToMul(
const BatchMatMulDescriptor& desc,
const TensorShape& tensorXShape,
const TensorShape& tensorYShape);
- /// Static helper to get the axes (for each input) that will not be multiplied together
+ ARMNN_DEPRECATED_MSG_REMOVAL_DATE("This method is deprecated. Use ABI Stable "
+ "GetAxesNotMul(DataLayout dataLayout, const TensorShape& tensorShape) instead.",
+ "23.05")
static std::pair<std::vector<unsigned int>, std::vector<unsigned int>> GetAxesNotMul(
const BatchMatMulDescriptor& desc,
const TensorShape& inputXShape,
const TensorShape& inputYShape);
+
+ /// Static helper to get the two axes (for each input) for multiplication
+ static std::pair<unsigned int, unsigned int> GetAxesToMul(
+ DataLayout dataLayout,
+ const TensorShape& tensorShape);
+
+ /// Static helper to get the axes (for each input) that will not be multiplied together
+ static std::vector<unsigned int> GetAxesNotMul(
+ DataLayout dataLayout,
+ const TensorShape& tensorShape);
+
+ /// Static helper to get the axes which will be transposed
+ static PermutationVector GetPermuteVec(
+ DataLayout dataLayout,
+ const TensorShape& tensorShape);
};
} // namespace armnn