aboutsummaryrefslogtreecommitdiff
path: root/src/armnn
diff options
context:
space:
mode:
authorSamuel Yap <samuel.yap@arm.com>2022-08-24 17:04:34 +0100
committerNikhil Raj <nikhil.raj@arm.com>2022-09-05 10:47:21 +0100
commitfd3ba5a2f3630dc34094912b1a2c057f790f3092 (patch)
tree546406ab4199637a4374d859d0a9ad328a63c97d /src/armnn
parenta04f4a15575ddd778d3a330dbce629412e1ffc0c (diff)
downloadarmnn-fd3ba5a2f3630dc34094912b1a2c057f790f3092.tar.gz
IVGCVSW-6497: BatchMatMul TfLite Parser
* Added armnnTfLiteParser for BatchMatMul * Added unit testing for parser * Updated CMakeLists Signed-off-by: Samuel Yap <samuel.yap@arm.com> Change-Id: If6842aaf7cf08f688093b714e2ecea6e8cd87161
Diffstat (limited to 'src/armnn')
-rw-r--r--src/armnn/layers/BatchMatMulLayer.cpp6
1 files changed, 3 insertions, 3 deletions
diff --git a/src/armnn/layers/BatchMatMulLayer.cpp b/src/armnn/layers/BatchMatMulLayer.cpp
index acd089aef8..0f86b9dc48 100644
--- a/src/armnn/layers/BatchMatMulLayer.cpp
+++ b/src/armnn/layers/BatchMatMulLayer.cpp
@@ -37,14 +37,14 @@ std::vector<TensorShape> BatchMatMulLayer::InferOutputShapes(const std::vector<T
TensorShape inputXShape = inputShapes[0];
TensorShape inputYShape = inputShapes[1];
- // Adjoint will not affect the resultant shape, as you would be permuting two axes of equal size
- if(m_Param.m_TransposeX)
+ // Adjoint is assumed to be square, but we will apply the permute anyway
+ if(m_Param.m_TransposeX || m_Param.m_AdjointX)
{
auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(m_Param.m_DataLayoutX,
inputXShape);
inputXShape = armnnUtils::Permuted(inputXShape, permuteVec);
}
- if(m_Param.m_TransposeY)
+ if(m_Param.m_TransposeY || m_Param.m_AdjointY)
{
auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(m_Param.m_DataLayoutY,
inputYShape);