aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/BatchMatMulLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/BatchMatMulLayer.cpp')
-rw-r--r--src/armnn/layers/BatchMatMulLayer.cpp6
1 files changed, 3 insertions, 3 deletions
diff --git a/src/armnn/layers/BatchMatMulLayer.cpp b/src/armnn/layers/BatchMatMulLayer.cpp
index acd089aef8..0f86b9dc48 100644
--- a/src/armnn/layers/BatchMatMulLayer.cpp
+++ b/src/armnn/layers/BatchMatMulLayer.cpp
@@ -37,14 +37,14 @@ std::vector<TensorShape> BatchMatMulLayer::InferOutputShapes(const std::vector<T
TensorShape inputXShape = inputShapes[0];
TensorShape inputYShape = inputShapes[1];
- // Adjoint will not affect the resultant shape, as you would be permuting two axes of equal size
- if(m_Param.m_TransposeX)
+ // Adjoint is assumed to be square, but we will apply the permute anyway
+ if(m_Param.m_TransposeX || m_Param.m_AdjointX)
{
auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(m_Param.m_DataLayoutX,
inputXShape);
inputXShape = armnnUtils::Permuted(inputXShape, permuteVec);
}
- if(m_Param.m_TransposeY)
+ if(m_Param.m_TransposeY || m_Param.m_AdjointY)
{
auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(m_Param.m_DataLayoutY,
inputYShape);