about summary refs log tree commit diff
path: root/src/armnn/layers/BatchMatMulLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/BatchMatMulLayer.cpp')
-rw-r--r-- src/armnn/layers/BatchMatMulLayer.cpp 27
1 files changed, 20 insertions, 7 deletions
diff --git a/src/armnn/layers/BatchMatMulLayer.cpp b/src/armnn/layers/BatchMatMulLayer.cpp
index 501de2d091..acd089aef8 100644
--- a/src/armnn/layers/BatchMatMulLayer.cpp
+++ b/src/armnn/layers/BatchMatMulLayer.cpp
@@ -5,6 +5,7 @@
#include "BatchMatMulLayer.hpp"
#include <armnn/backends/WorkloadFactory.hpp>
+#include <armnnUtils/Permute.hpp>
#include "layers/LayerCloneBase.hpp"
namespace armnn
@@ -36,12 +37,24 @@ std::vector<TensorShape> BatchMatMulLayer::InferOutputShapes(const std::vector<T
TensorShape inputXShape = inputShapes[0];
TensorShape inputYShape = inputShapes[1];
- // Note: Take into account what pre-adjoint or pre-transposing will do to the inferred output shape
+ // Adjoint will not affect the resultant shape, as you would be permuting two axes of equal size
+ if(m_Param.m_TransposeX)
+ {
+ auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(m_Param.m_DataLayoutX,
+ inputXShape);
+ inputXShape = armnnUtils::Permuted(inputXShape, permuteVec);
+ }
+ if(m_Param.m_TransposeY)
+ {
+ auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(m_Param.m_DataLayoutY,
+ inputYShape);
+ inputYShape = armnnUtils::Permuted(inputYShape, permuteVec);
+ }
TensorShape& longerInput = inputXShape.GetNumDimensions() >= inputYShape.GetNumDimensions()?
- inputXShape:inputYShape;
+ inputXShape : inputYShape;
TensorShape& shorterInput = inputXShape.GetNumDimensions() >= inputYShape.GetNumDimensions()?
- inputYShape:inputXShape;
+ inputYShape : inputXShape;
unsigned int inputNumDimsOffset = longerInput.GetNumDimensions() - shorterInput.GetNumDimensions();
@@ -49,10 +62,10 @@ std::vector<TensorShape> BatchMatMulLayer::InferOutputShapes(const std::vector<T
std::vector<unsigned int> tensorDimensions(outputNumDimensions, 0);
- auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Param, inputXShape, inputYShape);
- const auto& longerAxesToMul = (axesToMul.first.first >= axesToMul.second.first &&
- axesToMul.first.second >= axesToMul.second.second) ?
- axesToMul.first : axesToMul.second;
+ const auto& longerInputDataLayout = inputXShape.GetNumDimensions() >= inputYShape.GetNumDimensions()?
+ m_Param.m_DataLayoutX : m_Param.m_DataLayoutY;
+ auto longerAxesToMul = BatchMatMulDescriptor::GetAxesToMul(longerInputDataLayout,
+ longerInput);
for (unsigned int i = 0; i < outputNumDimensions; ++i)
{