aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Descriptors.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/Descriptors.cpp')
-rw-r--r--src/armnn/Descriptors.cpp82
1 files changed, 82 insertions, 0 deletions
diff --git a/src/armnn/Descriptors.cpp b/src/armnn/Descriptors.cpp
index c740fd03ad..f9576271d5 100644
--- a/src/armnn/Descriptors.cpp
+++ b/src/armnn/Descriptors.cpp
@@ -455,4 +455,86 @@ uint32_t DepthwiseConvolution2dDescriptor::GetNumInputs() const
return armnn::GetNumInputs(m_BiasEnabled);
}
+std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>>
+BatchMatMulDescriptor::GetAxesToMul(
+ const BatchMatMulDescriptor& desc,
+ const TensorShape& tensorXShape,
+ const TensorShape& tensorYShape)
+{
+ // May refactor to just work on one input per call - makes it less confusing and also
+ // allows more flexibility (i.e. in Layer output shape inference)
+
+ auto xNumDims = tensorXShape.GetNumDimensions();
+ auto yNumDims = tensorYShape.GetNumDimensions();
+
+ std::pair<unsigned int, unsigned int> xAxes = { xNumDims-2, xNumDims-1 };
+ std::pair<unsigned int, unsigned int> yAxes = { yNumDims-2, yNumDims-1 };
+
+ if(desc.m_DataLayoutX.has_value())
+ {
+ switch(desc.m_DataLayoutX.value())
+ {
+ case DataLayout::NDHWC:
+ case DataLayout::NHWC:
+ xAxes.first -= 1;
+ xAxes.second -= 1;
+ break;
+ case DataLayout::NCDHW:
+ case DataLayout::NCHW:
+ default:
+ break;
+ }
+ }
+
+ if(desc.m_DataLayoutY.has_value())
+ {
+ switch(desc.m_DataLayoutY.value())
+ {
+ case DataLayout::NDHWC:
+ case DataLayout::NHWC:
+ yAxes.first -= 1;
+ yAxes.second -= 1;
+ break;
+ case DataLayout::NCDHW:
+ case DataLayout::NCHW:
+ default:
+ break;
+ }
+ }
+
+ return { xAxes, yAxes};
+}
+
+std::pair<std::vector<unsigned int>, std::vector<unsigned int>> BatchMatMulDescriptor::GetAxesNotMul(
+ const BatchMatMulDescriptor& desc,
+ const TensorShape& inputXShape,
+ const TensorShape& inputYShape)
+{
+ // May refactor to just work on one input per call - makes it less confusing and also
+ // allows more flexibility (i.e. in Layer output shape inference)
+ auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(desc, inputXShape, inputYShape);
+
+ std::vector<unsigned int> axesXNotMul;
+ std::vector<unsigned int> axesYNotMul;
+
+ for(unsigned int i = 0; i < inputXShape.GetNumDimensions(); i++)
+ {
+ if(i == axesToMul.first.first || i == axesToMul.first.second)
+ {
+ continue;
+ }
+ axesXNotMul.push_back(i);
+ }
+ for(unsigned int i = 0; i < inputYShape.GetNumDimensions(); i++)
+ {
+ if(i == axesToMul.second.first || i == axesToMul.second.second)
+ {
+ continue;
+ }
+ axesYNotMul.push_back(i);
+ }
+
+ return { axesXNotMul, axesYNotMul };
+}
+
}