diff options
Diffstat (limited to 'src/armnn/layers/BatchMatMulLayer.cpp')
-rw-r--r-- | src/armnn/layers/BatchMatMulLayer.cpp | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/src/armnn/layers/BatchMatMulLayer.cpp b/src/armnn/layers/BatchMatMulLayer.cpp index acd089aef8..0f86b9dc48 100644 --- a/src/armnn/layers/BatchMatMulLayer.cpp +++ b/src/armnn/layers/BatchMatMulLayer.cpp @@ -37,14 +37,14 @@ std::vector<TensorShape> BatchMatMulLayer::InferOutputShapes(const std::vector<T TensorShape inputXShape = inputShapes[0]; TensorShape inputYShape = inputShapes[1]; - // Adjoint will not affect the resultant shape, as you would be permuting two axes of equal size - if(m_Param.m_TransposeX) + // Adjoint is assumed to be square, but we will apply the permute anyway + if(m_Param.m_TransposeX || m_Param.m_AdjointX) { auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(m_Param.m_DataLayoutX, inputXShape); inputXShape = armnnUtils::Permuted(inputXShape, permuteVec); } - if(m_Param.m_TransposeY) + if(m_Param.m_TransposeY || m_Param.m_AdjointY) { auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(m_Param.m_DataLayoutY, inputYShape); |