diff options
Diffstat (limited to 'src/armnn/layers/BatchMatMulLayer.cpp')
-rw-r--r-- | src/armnn/layers/BatchMatMulLayer.cpp | 27 |
1 files changed, 20 insertions, 7 deletions
diff --git a/src/armnn/layers/BatchMatMulLayer.cpp b/src/armnn/layers/BatchMatMulLayer.cpp index 501de2d091..acd089aef8 100644 --- a/src/armnn/layers/BatchMatMulLayer.cpp +++ b/src/armnn/layers/BatchMatMulLayer.cpp @@ -5,6 +5,7 @@ #include "BatchMatMulLayer.hpp" #include <armnn/backends/WorkloadFactory.hpp> +#include <armnnUtils/Permute.hpp> #include "layers/LayerCloneBase.hpp" namespace armnn @@ -36,12 +37,24 @@ std::vector<TensorShape> BatchMatMulLayer::InferOutputShapes(const std::vector<T TensorShape inputXShape = inputShapes[0]; TensorShape inputYShape = inputShapes[1]; - // Note: Take into account what pre-adjoint or pre-transposing will do to the inferred output shape + // Adjoint will not affect the resultant shape, as you would be permuting two axes of equal size + if(m_Param.m_TransposeX) + { + auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(m_Param.m_DataLayoutX, + inputXShape); + inputXShape = armnnUtils::Permuted(inputXShape, permuteVec); + } + if(m_Param.m_TransposeY) + { + auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(m_Param.m_DataLayoutY, + inputYShape); + inputYShape = armnnUtils::Permuted(inputYShape, permuteVec); + } TensorShape& longerInput = inputXShape.GetNumDimensions() >= inputYShape.GetNumDimensions()? - inputXShape:inputYShape; + inputXShape : inputYShape; TensorShape& shorterInput = inputXShape.GetNumDimensions() >= inputYShape.GetNumDimensions()? - inputYShape:inputXShape; + inputYShape : inputXShape; unsigned int inputNumDimsOffset = longerInput.GetNumDimensions() - shorterInput.GetNumDimensions(); @@ -49,10 +62,10 @@ std::vector<TensorShape> BatchMatMulLayer::InferOutputShapes(const std::vector<T std::vector<unsigned int> tensorDimensions(outputNumDimensions, 0); - auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Param, inputXShape, inputYShape); - const auto& longerAxesToMul = (axesToMul.first.first >= axesToMul.second.first && - axesToMul.first.second >= axesToMul.second.second) ? - axesToMul.first : axesToMul.second; + const auto& longerInputDataLayout = inputXShape.GetNumDimensions() >= inputYShape.GetNumDimensions()? + m_Param.m_DataLayoutX : m_Param.m_DataLayoutY; + auto longerAxesToMul = BatchMatMulDescriptor::GetAxesToMul(longerInputDataLayout, + longerInput); for (unsigned int i = 0; i < outputNumDimensions; ++i) { |