aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Descriptors.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/Descriptors.cpp')
-rw-r--r--src/armnn/Descriptors.cpp115
1 files changed, 57 insertions, 58 deletions
diff --git a/src/armnn/Descriptors.cpp b/src/armnn/Descriptors.cpp
index f9576271d5..226d121edc 100644
--- a/src/armnn/Descriptors.cpp
+++ b/src/armnn/Descriptors.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "armnn/Descriptors.hpp"
@@ -461,80 +461,79 @@ BatchMatMulDescriptor::GetAxesToMul(
const TensorShape& tensorXShape,
const TensorShape& tensorYShape)
{
- // May refactor to just work on one input per call - makes it less confusing and also
- // allows more flexibility (i.e. in Layer output shape inference)
-
- auto xNumDims = tensorXShape.GetNumDimensions();
- auto yNumDims = tensorYShape.GetNumDimensions();
-
- std::pair<unsigned int, unsigned int> xAxes = { xNumDims-2, xNumDims-1 };
- std::pair<unsigned int, unsigned int> yAxes = { yNumDims-2, yNumDims-1 };
-
- if(desc.m_DataLayoutX.has_value())
- {
- switch(desc.m_DataLayoutX.value())
- {
- case DataLayout::NDHWC:
- case DataLayout::NHWC:
- xAxes.first -= 1;
- xAxes.second -= 1;
- break;
- case DataLayout::NCDHW:
- case DataLayout::NCHW:
- default:
- break;
- }
- }
-
- if(desc.m_DataLayoutY.has_value())
- {
- switch(desc.m_DataLayoutY.value())
- {
- case DataLayout::NDHWC:
- case DataLayout::NHWC:
- yAxes.first -= 1;
- yAxes.second -= 1;
- break;
- case DataLayout::NCDHW:
- case DataLayout::NCHW:
- default:
- break;
- }
- }
-
- return { xAxes, yAxes};
+ return { GetAxesToMul(desc.m_DataLayoutX, tensorXShape),
+ GetAxesToMul(desc.m_DataLayoutY, tensorYShape) };
}
-
std::pair<std::vector<unsigned int>, std::vector<unsigned int>> BatchMatMulDescriptor::GetAxesNotMul(
const BatchMatMulDescriptor& desc,
const TensorShape& inputXShape,
const TensorShape& inputYShape)
{
- // May refactor to just work on one input per call - makes it less confusing and also
- // allows more flexibility (i.e. in Layer output shape inference)
- auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(desc, inputXShape, inputYShape);
+ return { GetAxesNotMul(desc.m_DataLayoutX, inputXShape),
+ GetAxesNotMul(desc.m_DataLayoutY, inputYShape) };
+}
- std::vector<unsigned int> axesXNotMul;
- std::vector<unsigned int> axesYNotMul;
+std::pair<unsigned int, unsigned int> BatchMatMulDescriptor::GetAxesToMul(
+ DataLayout dataLayout,
+ const TensorShape& tensorShape)
+{
+ auto numDims = tensorShape.GetNumDimensions();
+ std::pair<unsigned int, unsigned int> axes = { numDims-2, numDims-1 };
+ switch(dataLayout)
+ {
+ case DataLayout::NDHWC:
+ case DataLayout::NHWC:
+ axes.first -= 1;
+ axes.second -= 1;
+ break;
+ case DataLayout::NCDHW:
+ case DataLayout::NCHW:
+ default:
+ break;
+ }
+ return axes;
+}
- for(unsigned int i = 0; i < inputXShape.GetNumDimensions(); i++)
+std::vector<unsigned int> BatchMatMulDescriptor::GetAxesNotMul(
+ DataLayout dataLayout,
+ const TensorShape& tensorShape)
+{
+ auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(dataLayout, tensorShape);
+ std::vector<unsigned int> axesNotMul;
+ for(unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
{
- if(i == axesToMul.first.first || i == axesToMul.first.second)
+ if(i == axesToMul.first || i == axesToMul.second)
{
continue;
}
- axesXNotMul.push_back(i);
+ axesNotMul.push_back(i);
}
- for(unsigned int i = 0; i < inputYShape.GetNumDimensions(); i++)
+ return axesNotMul;
+}
+
+PermutationVector BatchMatMulDescriptor::GetPermuteVec(
+ DataLayout dataLayout,
+ const TensorShape& tensorShape)
+{
+ std::vector<unsigned int> vec;
+ auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(dataLayout, tensorShape);
+ for(unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
{
- if(i == axesToMul.second.first || i == axesToMul.second.second)
+ if(i == axesToMul.first)
{
- continue;
+ vec.push_back(i+1);
+ }
+ else if(i == axesToMul.second)
+ {
+ vec.push_back(i-1);
+ }
+ else
+ {
+ vec.push_back(i);
}
- axesYNotMul.push_back(i);
}
-
- return { axesXNotMul, axesYNotMul };
+ return PermutationVector(vec.data(),
+ static_cast<unsigned int>(vec.size()));
}
}