path: root/src/armnn/Descriptors.cpp
author     Samuel Yap <samuel.yap@arm.com>    2022-07-06 15:36:03 +0100
committer  Nikhil Raj <nikhil.raj@arm.com>    2022-07-27 15:58:31 +0100
commit     6b47809e7d6c55d20a05d863ce2f09159f381f85 (patch)
tree       c33e5820f89e359c80d8773288e8adb075735039 /src/armnn/Descriptors.cpp
parent     919ec71ea7f44bb2d284eb88cda511c2424358b2 (diff)
download   armnn-6b47809e7d6c55d20a05d863ce2f09159f381f85.tar.gz
IVGCVSW-7109: Add Batch MatMul front end support - Reference
* Descriptors added for BatchMatMul
* Layer definition added
* Input validation added (will likely change when opt. param support comes in)
* Ref workload implementation for BatchMatMul added (will also change with opt. param support)
* Ref layer tests made for BatchMatMul
* CMake and other build files updated

Signed-off-by: Samuel Yap <samuel.yap@arm.com>
Change-Id: Ic885301da543ee0fbe7922b85e7f9658c4efc617
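As an illustration of what the new descriptor helpers compute, here is a minimal usage sketch, not code from this patch. It assumes only what is visible in the diff below: a default-constructible BatchMatMulDescriptor whose optional data layouts are left unset, and an armnn::TensorShape built from an initializer list. With no data layout set, the last two dimensions of each input are the ones multiplied and the remaining dimensions are batch axes.

// Minimal sketch, not ArmNN test code: assumes BatchMatMulDescriptor is
// default-constructible with m_DataLayoutX/m_DataLayoutY left unset.
#include <armnn/Descriptors.hpp>
#include <armnn/Tensor.hpp>

#include <iostream>

int main()
{
    armnn::BatchMatMulDescriptor desc;   // optional data layouts left empty (assumption)

    armnn::TensorShape x({2, 3, 4});     // batch of 2 matrices, each 3x4
    armnn::TensorShape y({2, 4, 5});     // batch of 2 matrices, each 4x5

    // With no data layout set, the last two dimensions are multiplied:
    // X -> axes {1, 2}, Y -> axes {1, 2}.
    auto mulAxes = armnn::BatchMatMulDescriptor::GetAxesToMul(desc, x, y);
    std::cout << "X mul axes: " << mulAxes.first.first  << "," << mulAxes.first.second  << "\n"
              << "Y mul axes: " << mulAxes.second.first << "," << mulAxes.second.second << "\n";

    // The remaining axes are treated as batch dimensions: {0} for both inputs here.
    auto batchAxes = armnn::BatchMatMulDescriptor::GetAxesNotMul(desc, x, y);
    std::cout << "X batch axis count: " << batchAxes.first.size()  << "\n"
              << "Y batch axis count: " << batchAxes.second.size() << "\n";
}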
Diffstat (limited to 'src/armnn/Descriptors.cpp')
-rw-r--r--  src/armnn/Descriptors.cpp | 82 +
1 file changed, 82 insertions(+), 0 deletions(-)
diff --git a/src/armnn/Descriptors.cpp b/src/armnn/Descriptors.cpp
index c740fd03ad..f9576271d5 100644
--- a/src/armnn/Descriptors.cpp
+++ b/src/armnn/Descriptors.cpp
@@ -455,4 +455,86 @@ uint32_t DepthwiseConvolution2dDescriptor::GetNumInputs() const
return armnn::GetNumInputs(m_BiasEnabled);
}
+std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>>
+BatchMatMulDescriptor::GetAxesToMul(
+ const BatchMatMulDescriptor& desc,
+ const TensorShape& tensorXShape,
+ const TensorShape& tensorYShape)
+{
+ // May refactor to just work on one input per call - makes it less confusing and also
+ // allows more flexibility (i.e. in Layer output shape inference)
+
+ auto xNumDims = tensorXShape.GetNumDimensions();
+ auto yNumDims = tensorYShape.GetNumDimensions();
+
+ std::pair<unsigned int, unsigned int> xAxes = { xNumDims-2, xNumDims-1 };
+ std::pair<unsigned int, unsigned int> yAxes = { yNumDims-2, yNumDims-1 };
+
+ if(desc.m_DataLayoutX.has_value())
+ {
+ switch(desc.m_DataLayoutX.value())
+ {
+ case DataLayout::NDHWC:
+ case DataLayout::NHWC:
+ xAxes.first -= 1;
+ xAxes.second -= 1;
+ break;
+ case DataLayout::NCDHW:
+ case DataLayout::NCHW:
+ default:
+ break;
+ }
+ }
+
+ if(desc.m_DataLayoutY.has_value())
+ {
+ switch(desc.m_DataLayoutY.value())
+ {
+ case DataLayout::NDHWC:
+ case DataLayout::NHWC:
+ yAxes.first -= 1;
+ yAxes.second -= 1;
+ break;
+ case DataLayout::NCDHW:
+ case DataLayout::NCHW:
+ default:
+ break;
+ }
+ }
+
+ return { xAxes, yAxes};
+}
+
+std::pair<std::vector<unsigned int>, std::vector<unsigned int>> BatchMatMulDescriptor::GetAxesNotMul(
+ const BatchMatMulDescriptor& desc,
+ const TensorShape& inputXShape,
+ const TensorShape& inputYShape)
+{
+ // May refactor to just work on one input per call - makes it less confusing and also
+ // allows more flexibility (i.e. in Layer output shape inference)
+ auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(desc, inputXShape, inputYShape);
+
+ std::vector<unsigned int> axesXNotMul;
+ std::vector<unsigned int> axesYNotMul;
+
+ for(unsigned int i = 0; i < inputXShape.GetNumDimensions(); i++)
+ {
+ if(i == axesToMul.first.first || i == axesToMul.first.second)
+ {
+ continue;
+ }
+ axesXNotMul.push_back(i);
+ }
+ for(unsigned int i = 0; i < inputYShape.GetNumDimensions(); i++)
+ {
+ if(i == axesToMul.second.first || i == axesToMul.second.second)
+ {
+ continue;
+ }
+ axesYNotMul.push_back(i);
+ }
+
+ return { axesXNotMul, axesYNotMul };
+}
+
}
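Continuing the sketch above (again an assumption-laden illustration, not code from this patch), the fragment below shows the effect of the NDHWC/NHWC branches in GetAxesToMul: both multiplied axes shift one place to the left so the trailing channel dimension is skipped, and GetAxesNotMul then reports the leading batch axis plus the channel axis. The armnn::Optional<DataLayout> member type is inferred from the has_value()/value() calls in the diff and is an assumption.

// Sketch only: the armnn::Optional<DataLayout> member type and the way it is
// populated here are assumptions, not confirmed by this diff.
#include <armnn/Descriptors.hpp>
#include <armnn/Optional.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>

void NhwcAxesExample()   // hypothetical helper, not part of the ArmNN API
{
    armnn::BatchMatMulDescriptor desc;
    desc.m_DataLayoutX = armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NHWC);
    desc.m_DataLayoutY = armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NHWC);

    armnn::TensorShape x({1, 3, 4, 8});  // N x H x W x C
    armnn::TensorShape y({1, 4, 5, 8});

    // Default axes for a 4D tensor would be {2, 3}; the NHWC case shifts them to
    // {1, 2}, so H and W are multiplied and C is left untouched.
    auto mulAxes = armnn::BatchMatMulDescriptor::GetAxesToMul(desc, x, y);

    // GetAxesNotMul then reports {0, 3} (the N and C axes) for both inputs.
    auto batchAxes = armnn::BatchMatMulDescriptor::GetAxesNotMul(desc, x, y);

    (void)mulAxes;
    (void)batchAxes;
}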