author    Samuel Yap <samuel.yap@arm.com>    2022-08-08 14:07:42 +0100
committer Nikhil Raj <nikhil.raj@arm.com>    2022-08-30 17:03:33 +0100
commit    dc8ed9d75e54e914a970e137900930fa64a0782b (patch)
tree      8bcaedaae81a6afbdbe3c9a4e69e45840f18cdb4
parent    9c9d5b9d796d243d88bd7a7aebb2e7e6c467e3a4 (diff)
download  armnn-dc8ed9d75e54e914a970e137900930fa64a0782b.tar.gz
IVGCVSW-7105: BatchMatMul Optional Parameter Support
* Added transpose parameters to pre-transpose each input tensor's slices
* Added adjoint parameters to pre-adjoint each input tensor's slices
* Small refactoring (BatchMatMulDescriptor static helpers and BatchMatMulImpl constructor)
* Updated input validation and output shape inference for parameters
* Additional layer unit tests for parameters added
* Versionings incremented

Signed-off-by: Samuel Yap <samuel.yap@arm.com>
Change-Id: Ibe5242a8a5bf604c13de0dc65844fd6c421cc667
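For orientation, a minimal usage sketch of the reworked descriptor (parameter order as in the Descriptors.hpp hunk below; illustrative only, not part of the commit):

```cpp
// Pre-transpose the slices of input X; leave Y untouched.
// Transpose and Adjoint must not both be true for the same input;
// the WorkloadData.cpp validation below rejects that combination.
armnn::BatchMatMulDescriptor descriptor(/*transposeX=*/ true,
                                        /*transposeY=*/ false,
                                        /*adjointX=*/   false,
                                        /*adjointY=*/   false,
                                        /*dataLayoutX=*/ armnn::DataLayout::NCHW,
                                        /*dataLayoutY=*/ armnn::DataLayout::NCHW);
```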
-rw-r--r--InstallationViaAptRepository.md2
-rw-r--r--include/armnn/Descriptors.hpp71
-rw-r--r--include/armnn/Version.hpp2
-rw-r--r--include/armnnOnnxParser/Version.hpp2
-rw-r--r--include/armnnTfLiteParser/Version.hpp2
-rw-r--r--python/pyarmnn/README.md14
-rw-r--r--python/pyarmnn/examples/image_classification/README.md2
-rw-r--r--python/pyarmnn/examples/keyword_spotting/README.md2
-rw-r--r--python/pyarmnn/examples/object_detection/README.md2
-rw-r--r--python/pyarmnn/examples/speech_recognition/README.md2
-rw-r--r--python/pyarmnn/src/pyarmnn/_version.py4
-rw-r--r--python/pyarmnn/test/test_setup.py8
-rw-r--r--python/pyarmnn/test/test_version.py4
-rw-r--r--samples/ObjectDetection/Readme.md4
-rw-r--r--src/armnn/Descriptors.cpp115
-rw-r--r--src/armnn/layers/BatchMatMulLayer.cpp27
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp236
-rw-r--r--src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp364
-rw-r--r--src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp18
-rw-r--r--src/backends/reference/test/RefLayerTests.cpp21
-rw-r--r--src/backends/reference/workloads/BatchMatMulImpl.cpp346
-rw-r--r--src/backends/reference/workloads/BatchMatMulImpl.hpp69
-rw-r--r--src/backends/reference/workloads/RefBatchMatMulWorkload.cpp3
23 files changed, 970 insertions, 350 deletions
diff --git a/InstallationViaAptRepository.md b/InstallationViaAptRepository.md
index 3fa36f62a7..037e5cc7f1 100644
--- a/InstallationViaAptRepository.md
+++ b/InstallationViaAptRepository.md
@@ -117,7 +117,7 @@ The easiest way to install all of the available packages for your systems archit
sudo apt-get install -y python3-pyarmnn armnn-latest-all
# Verify installation via python:
python3 -c "import pyarmnn as ann;print(ann.GetVersion())"
- # Returns '{ARMNN_MAJOR_VERSION}.0.0' e.g. 30.0.0
+ # Returns '{ARMNN_MAJOR_VERSION}.0.0' e.g. 31.0.0
```
This will install PyArmNN and the three backends for Neon (CpuAcc), OpenCL (GpuAcc) and our Reference Backend.
It will also install their dependencies including the arm-compute-library package along with the Tensorflow Lite Parser
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index 38e3c61500..493ce65976 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -1553,55 +1553,74 @@ struct ChannelShuffleDescriptor : BaseDescriptor
/// A BatchMatMulDescriptor for the BatchMatMul operator
struct BatchMatMulDescriptor : BaseDescriptor
{
- BatchMatMulDescriptor(Optional<DataLayout> dataLayoutX = EmptyOptional(),
- Optional<DataLayout> dataLayoutY = EmptyOptional(),
- std::vector<unsigned int> transposeX = {},
- std::vector<unsigned int> transposeY = {},
- std::vector<unsigned int> adjointX = {},
- std::vector<unsigned int> adjointY = {})
- : m_DataLayoutX(dataLayoutX)
- , m_DataLayoutY(dataLayoutY)
- , m_TransposeX(transposeX)
+ BatchMatMulDescriptor(bool transposeX = false,
+ bool transposeY = false,
+ bool adjointX = false,
+ bool adjointY = false,
+ DataLayout dataLayoutX = DataLayout::NCHW,
+ DataLayout dataLayoutY = DataLayout::NCHW)
+ : m_TransposeX(transposeX)
, m_TransposeY(transposeY)
, m_AdjointX(adjointX)
, m_AdjointY(adjointY)
+ , m_DataLayoutX(dataLayoutX)
+ , m_DataLayoutY(dataLayoutY)
{}
bool operator ==(const BatchMatMulDescriptor &rhs) const
{
- return m_DataLayoutX == rhs.m_DataLayoutX &&
- m_DataLayoutY == rhs.m_DataLayoutY &&
- m_TransposeX == rhs.m_TransposeX &&
+ return m_TransposeX == rhs.m_TransposeX &&
m_TransposeY == rhs.m_TransposeY &&
m_AdjointX == rhs.m_AdjointX &&
- m_AdjointY == rhs.m_AdjointY;
+ m_AdjointY == rhs.m_AdjointY &&
+ m_DataLayoutX == rhs.m_DataLayoutX &&
+ m_DataLayoutY == rhs.m_DataLayoutY;
}
- /// Data layout of each input tensor, such as NHWC/NDHWC (or leave as EmptyOptional for arbitrary layout)
- Optional<DataLayout> m_DataLayoutX;
- Optional<DataLayout> m_DataLayoutY;
-
- /// Transpose vector for each input tensor (leave as empty vector for no pre-transposing)
+ /// Transpose the slices of each input tensor
/// Transpose and Adjoint can not both be set to true for the same tensor at the same time
- std::vector<unsigned int> m_TransposeX;
- std::vector<unsigned int> m_TransposeY;
+ bool m_TransposeX;
+ bool m_TransposeY;
- /// Adjoint vector for each input tensor (leave as empty vector for no pre-adjoint)
+ /// Adjoint the slices of each input tensor
/// Transpose and Adjoint can not both be set to true for the same tensor at the same time
- std::vector<unsigned int> m_AdjointX;
- std::vector<unsigned int> m_AdjointY;
+ bool m_AdjointX;
+ bool m_AdjointY;
- /// Static helper to get the two axes (for each input) for multiplication
+ /// Data layout of each input tensor, such as NHWC/NDHWC (leave as default for arbitrary layout)
+ DataLayout m_DataLayoutX;
+ DataLayout m_DataLayoutY;
+
+ ARMNN_DEPRECATED_MSG_REMOVAL_DATE("This method is deprecated. Use ABI Stable "
+ "GetAxesToMul(DataLayout dataLayout, const TensorShape& tensorShape) instead.",
+ "23.05")
static std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>> GetAxesToMul(
const BatchMatMulDescriptor& desc,
const TensorShape& tensorXShape,
const TensorShape& tensorYShape);
- /// Static helper to get the axes (for each input) that will not be multiplied together
+ ARMNN_DEPRECATED_MSG_REMOVAL_DATE("This method is deprecated. Use ABI Stable "
+ "GetAxesNotMul(DataLayout dataLayout, const TensorShape& tensorShape) instead.",
+ "23.05")
static std::pair<std::vector<unsigned int>, std::vector<unsigned int>> GetAxesNotMul(
const BatchMatMulDescriptor& desc,
const TensorShape& inputXShape,
const TensorShape& inputYShape);
+
+ /// Static helper to get the two axes (for each input) for multiplication
+ static std::pair<unsigned int, unsigned int> GetAxesToMul(
+ DataLayout dataLayout,
+ const TensorShape& tensorShape);
+
+ /// Static helper to get the axes (for each input) that will not be multiplied together
+ static std::vector<unsigned int> GetAxesNotMul(
+ DataLayout dataLayout,
+ const TensorShape& tensorShape);
+
+ /// Static helper to get the axes which will be transposed
+ static PermutationVector GetPermuteVec(
+ DataLayout dataLayout,
+ const TensorShape& tensorShape);
};
} // namespace armnn
diff --git a/include/armnn/Version.hpp b/include/armnn/Version.hpp
index 7951eacf1d..7fdb20ade5 100644
--- a/include/armnn/Version.hpp
+++ b/include/armnn/Version.hpp
@@ -10,7 +10,7 @@
#define STRINGIFY_MACRO(s) #s
// ArmNN version components
-#define ARMNN_MAJOR_VERSION 30
+#define ARMNN_MAJOR_VERSION 31
#define ARMNN_MINOR_VERSION 0
#define ARMNN_PATCH_VERSION 0
diff --git a/include/armnnOnnxParser/Version.hpp b/include/armnnOnnxParser/Version.hpp
index 33a2846263..5fbaf0ca0e 100644
--- a/include/armnnOnnxParser/Version.hpp
+++ b/include/armnnOnnxParser/Version.hpp
@@ -14,7 +14,7 @@ namespace armnnOnnxParser
// OnnxParser version components
#define ONNX_PARSER_MAJOR_VERSION 24
-#define ONNX_PARSER_MINOR_VERSION 5
+#define ONNX_PARSER_MINOR_VERSION 6
#define ONNX_PARSER_PATCH_VERSION 0
/// ONNX_PARSER_VERSION: "X.Y.Z"
diff --git a/include/armnnTfLiteParser/Version.hpp b/include/armnnTfLiteParser/Version.hpp
index 5db527ec8c..43fa436454 100644
--- a/include/armnnTfLiteParser/Version.hpp
+++ b/include/armnnTfLiteParser/Version.hpp
@@ -14,7 +14,7 @@ namespace armnnTfLiteParser
// TfLiteParser version components
#define TFLITE_PARSER_MAJOR_VERSION 24
-#define TFLITE_PARSER_MINOR_VERSION 5
+#define TFLITE_PARSER_MINOR_VERSION 6
#define TFLITE_PARSER_PATCH_VERSION 0
/// TFLITE_PARSER_VERSION: "X.Y.Z"
diff --git a/python/pyarmnn/README.md b/python/pyarmnn/README.md
index 547a868316..5e8ceb4b67 100644
--- a/python/pyarmnn/README.md
+++ b/python/pyarmnn/README.md
@@ -91,14 +91,14 @@ This step will put all generated files under `./src/pyarmnn/_generated` folder a
```bash
$ python setup.py sdist
```
-As the result you will get `./dist/pyarmnn-30.0.0.tar.gz` file. As you can see it is platform independent.
+As the result you will get `./dist/pyarmnn-31.0.0.tar.gz` file. As you can see it is platform independent.
##### 5. Build the binary package
```bash
$ python setup.py bdist_wheel
```
-As the result you will get something like `./dist/pyarmnn-30.0.0-cp36-cp36m-linux_x86_64.whl` file. As you can see it
+As the result you will get something like `./dist/pyarmnn-31.0.0-cp36-cp36m-linux_x86_64.whl` file. As you can see it
is platform dependent.
# PyArmNN installation
@@ -107,8 +107,8 @@ PyArmNN can be distributed as a source package or a binary package (wheel).
Binary package is platform dependent, the name of the package will indicate the platform it was built for, e.g.:
-* Linux x86 64bit machine: pyarmnn-30.0.0-cp36-cp36m-*linux_x86_64*.whl
-* Linux Aarch 64 bit machine: pyarmnn-30.0.0-cp36-cp36m-*linux_aarch64*.whl
+* Linux x86 64bit machine: pyarmnn-31.0.0-cp36-cp36m-*linux_x86_64*.whl
+* Linux Aarch 64 bit machine: pyarmnn-31.0.0-cp36-cp36m-*linux_aarch64*.whl
The source package is platform independent but installation involves compilation of Arm NN python extension. You will need to have g++ compatible with C++ 14 standard and a python development library installed on the build machine.
@@ -126,7 +126,7 @@ $ gcc --print-search-dirs
```
Install PyArmNN from binary by pointing to the wheel file:
```bash
-$ pip install /path/to/pyarmnn-30.0.0-cp36-cp36m-linux_aarch64.whl
+$ pip install /path/to/pyarmnn-31.0.0-cp36-cp36m-linux_aarch64.whl
```
## Installing from source package
@@ -145,7 +145,7 @@ $ export ARMNN_INCLUDE=/full/path/to/armnn/include:/full/path/to/armnn/profilin
Install PyArmNN as follows:
```bash
-$ pip install /path/to/pyarmnn-30.0.0.tar.gz
+$ pip install /path/to/pyarmnn-31.0.0.tar.gz
```
If PyArmNN installation script fails to find Arm NN libraries it will raise an error like this
@@ -159,7 +159,7 @@ $ pip show pyarmnn
You can also verify it by running the following and getting output similar to below:
```bash
$ python -c "import pyarmnn as ann;print(ann.GetVersion())"
-'30.0.0'
+'31.0.0'
```
# PyArmNN API overview
diff --git a/python/pyarmnn/examples/image_classification/README.md b/python/pyarmnn/examples/image_classification/README.md
index a360f01148..04718e2bf4 100644
--- a/python/pyarmnn/examples/image_classification/README.md
+++ b/python/pyarmnn/examples/image_classification/README.md
@@ -20,7 +20,7 @@ $ pip show pyarmnn
You can also verify it by running the following and getting output similar to below:
```bash
$ python -c "import pyarmnn as ann;print(ann.GetVersion())"
-'30.0.0'
+'31.0.0'
```
##### Dependencies
diff --git a/python/pyarmnn/examples/keyword_spotting/README.md b/python/pyarmnn/examples/keyword_spotting/README.md
index 1c1deafb26..98158e6c03 100644
--- a/python/pyarmnn/examples/keyword_spotting/README.md
+++ b/python/pyarmnn/examples/keyword_spotting/README.md
@@ -18,7 +18,7 @@ You can also verify it by running the following and getting output similar to be
```bash
$ python -c "import pyarmnn as ann;print(ann.GetVersion())"
-'30.0.0'
+'31.0.0'
```
### Dependencies
diff --git a/python/pyarmnn/examples/object_detection/README.md b/python/pyarmnn/examples/object_detection/README.md
index 215cf772a2..73bafb6038 100644
--- a/python/pyarmnn/examples/object_detection/README.md
+++ b/python/pyarmnn/examples/object_detection/README.md
@@ -54,7 +54,7 @@ $ pip show pyarmnn
You can also verify it by running the following and getting output similar to below:
```bash
$ python -c "import pyarmnn as ann;print(ann.GetVersion())"
-'30.0.0'
+'31.0.0'
```
##### Dependencies
diff --git a/python/pyarmnn/examples/speech_recognition/README.md b/python/pyarmnn/examples/speech_recognition/README.md
index d5fee8a010..e442aad591 100644
--- a/python/pyarmnn/examples/speech_recognition/README.md
+++ b/python/pyarmnn/examples/speech_recognition/README.md
@@ -18,7 +18,7 @@ You can also verify it by running the following and getting output similar to be
```bash
$ python -c "import pyarmnn as ann;print(ann.GetVersion())"
-'30.0.0'
+'31.0.0'
```
### Dependencies
diff --git a/python/pyarmnn/src/pyarmnn/_version.py b/python/pyarmnn/src/pyarmnn/_version.py
index d1b1ca290c..d68a893e9c 100644
--- a/python/pyarmnn/src/pyarmnn/_version.py
+++ b/python/pyarmnn/src/pyarmnn/_version.py
@@ -3,7 +3,7 @@
# SPDX-License-Identifier: MIT
import os
-version_info = (30, 0, 0)
+version_info = (31, 0, 0)
__dev_version_env = os.getenv("PYARMNN_DEV_VER", "")
@@ -24,7 +24,7 @@ def check_armnn_version(installed_armnn_version: str, expected_armnn_version: st
"""Compares expected Arm NN version and Arm NN version used to build the package.
Args:
- installed_armnn_version (str): Arm NN version used to generate the package (e.g. 30.0.0)
+ installed_armnn_version (str): Arm NN version used to generate the package (e.g. 31.0.0)
expected_armnn_version (str): Expected Arm NN version
Returns:
diff --git a/python/pyarmnn/test/test_setup.py b/python/pyarmnn/test/test_setup.py
index 27feda2647..ada96ccad4 100644
--- a/python/pyarmnn/test/test_setup.py
+++ b/python/pyarmnn/test/test_setup.py
@@ -87,15 +87,15 @@ def test_gcc_serch_path():
def test_armnn_version():
- check_armnn_version('30.0.0', '30.0.0')
+ check_armnn_version('31.0.0', '31.0.0')
def test_incorrect_armnn_version():
with pytest.raises(AssertionError) as err:
- check_armnn_version('30.0.0', '30.1.0')
+ check_armnn_version('31.0.0', '31.1.0')
- assert 'Expected ArmNN version is 30.1.0 but installed ArmNN version is 30.0.0' in str(err.value)
+ assert 'Expected ArmNN version is 31.1.0 but installed ArmNN version is 31.0.0' in str(err.value)
def test_armnn_version_patch_does_not_matter():
- check_armnn_version('30.0.0', '30.0.1')
+ check_armnn_version('31.0.0', '31.0.1')
diff --git a/python/pyarmnn/test/test_version.py b/python/pyarmnn/test/test_version.py
index 83606ab15b..f68adff0c7 100644
--- a/python/pyarmnn/test/test_version.py
+++ b/python/pyarmnn/test/test_version.py
@@ -18,7 +18,7 @@ def test_dev_version():
importlib.reload(v)
- assert "30.0.0.dev1" == v.__version__
+ assert "31.0.0.dev1" == v.__version__
del os.environ["PYARMNN_DEV_VER"]
del v
@@ -30,7 +30,7 @@ def test_arm_version_not_affected():
importlib.reload(v)
- assert "30.0.0" == v.__arm_ml_version__
+ assert "31.0.0" == v.__arm_ml_version__
del os.environ["PYARMNN_DEV_VER"]
del v
diff --git a/samples/ObjectDetection/Readme.md b/samples/ObjectDetection/Readme.md
index bd84e26001..169546e038 100644
--- a/samples/ObjectDetection/Readme.md
+++ b/samples/ObjectDetection/Readme.md
@@ -253,8 +253,8 @@ From the build directory, copy the following to the host platform:
The full list of libs after cross-compilation to copy on your board:
```
libarmnn.so
-libarmnn.so.30
-libarmnn.so.30.0
+libarmnn.so.31
+libarmnn.so.31.0
For Arm NN public C++ API mode:
libarmnnTfLiteParser.so
libarmnnTfLiteParser.so.24.4
diff --git a/src/armnn/Descriptors.cpp b/src/armnn/Descriptors.cpp
index f9576271d5..226d121edc 100644
--- a/src/armnn/Descriptors.cpp
+++ b/src/armnn/Descriptors.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "armnn/Descriptors.hpp"
@@ -461,80 +461,79 @@ BatchMatMulDescriptor::GetAxesToMul(
const TensorShape& tensorXShape,
const TensorShape& tensorYShape)
{
- // May refactor to just work on one input per call - makes it less confusing and also
- // allows more flexibility (i.e. in Layer output shape inference)
-
- auto xNumDims = tensorXShape.GetNumDimensions();
- auto yNumDims = tensorYShape.GetNumDimensions();
-
- std::pair<unsigned int, unsigned int> xAxes = { xNumDims-2, xNumDims-1 };
- std::pair<unsigned int, unsigned int> yAxes = { yNumDims-2, yNumDims-1 };
-
- if(desc.m_DataLayoutX.has_value())
- {
- switch(desc.m_DataLayoutX.value())
- {
- case DataLayout::NDHWC:
- case DataLayout::NHWC:
- xAxes.first -= 1;
- xAxes.second -= 1;
- break;
- case DataLayout::NCDHW:
- case DataLayout::NCHW:
- default:
- break;
- }
- }
-
- if(desc.m_DataLayoutY.has_value())
- {
- switch(desc.m_DataLayoutY.value())
- {
- case DataLayout::NDHWC:
- case DataLayout::NHWC:
- yAxes.first -= 1;
- yAxes.second -= 1;
- break;
- case DataLayout::NCDHW:
- case DataLayout::NCHW:
- default:
- break;
- }
- }
-
- return { xAxes, yAxes};
+ return { GetAxesToMul(desc.m_DataLayoutX, tensorXShape),
+ GetAxesToMul(desc.m_DataLayoutY, tensorYShape) };
}
-
std::pair<std::vector<unsigned int>, std::vector<unsigned int>> BatchMatMulDescriptor::GetAxesNotMul(
const BatchMatMulDescriptor& desc,
const TensorShape& inputXShape,
const TensorShape& inputYShape)
{
- // May refactor to just work on one input per call - makes it less confusing and also
- // allows more flexibility (i.e. in Layer output shape inference)
- auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(desc, inputXShape, inputYShape);
+ return { GetAxesNotMul(desc.m_DataLayoutX, inputXShape),
+ GetAxesNotMul(desc.m_DataLayoutY, inputYShape) };
+}
- std::vector<unsigned int> axesXNotMul;
- std::vector<unsigned int> axesYNotMul;
+std::pair<unsigned int, unsigned int> BatchMatMulDescriptor::GetAxesToMul(
+ DataLayout dataLayout,
+ const TensorShape& tensorShape)
+{
+ auto numDims = tensorShape.GetNumDimensions();
+ std::pair<unsigned int, unsigned int> axes = { numDims-2, numDims-1 };
+ switch(dataLayout)
+ {
+ case DataLayout::NDHWC:
+ case DataLayout::NHWC:
+ axes.first -= 1;
+ axes.second -= 1;
+ break;
+ case DataLayout::NCDHW:
+ case DataLayout::NCHW:
+ default:
+ break;
+ }
+ return axes;
+}
- for(unsigned int i = 0; i < inputXShape.GetNumDimensions(); i++)
+std::vector<unsigned int> BatchMatMulDescriptor::GetAxesNotMul(
+ DataLayout dataLayout,
+ const TensorShape& tensorShape)
+{
+ auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(dataLayout, tensorShape);
+ std::vector<unsigned int> axesNotMul;
+ for(unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
{
- if(i == axesToMul.first.first || i == axesToMul.first.second)
+ if(i == axesToMul.first || i == axesToMul.second)
{
continue;
}
- axesXNotMul.push_back(i);
+ axesNotMul.push_back(i);
}
- for(unsigned int i = 0; i < inputYShape.GetNumDimensions(); i++)
+ return axesNotMul;
+}
+
+PermutationVector BatchMatMulDescriptor::GetPermuteVec(
+ DataLayout dataLayout,
+ const TensorShape& tensorShape)
+{
+ std::vector<unsigned int> vec;
+ auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(dataLayout, tensorShape);
+ for(unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
{
- if(i == axesToMul.second.first || i == axesToMul.second.second)
+ if(i == axesToMul.first)
{
- continue;
+ vec.push_back(i+1);
+ }
+ else if(i == axesToMul.second)
+ {
+ vec.push_back(i-1);
+ }
+ else
+ {
+ vec.push_back(i);
}
- axesYNotMul.push_back(i);
}
-
- return { axesXNotMul, axesYNotMul };
+ return PermutationVector(vec.data(),
+ static_cast<unsigned int>(vec.size()));
}
}
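A standalone sketch (plain C++, not armnn code) of the permute vectors the new GetPermuteVec helper produces: it swaps the two matrix axes, which sit one position earlier for the channels-last layouts.

```cpp
#include <cstdio>
#include <vector>

// Mirrors the logic above: the matrix axes are the last two dims for
// NCHW/NCDHW, or one position earlier for NHWC/NDHWC.
std::vector<unsigned int> MakePermuteVec(bool channelsLast, unsigned int numDims)
{
    unsigned int rowAxis = numDims - 2;
    unsigned int colAxis = numDims - 1;
    if (channelsLast) { --rowAxis; --colAxis; } // NHWC/NDHWC
    std::vector<unsigned int> vec(numDims);
    for (unsigned int i = 0; i < numDims; ++i)
    {
        vec[i] = (i == rowAxis) ? colAxis : (i == colAxis) ? rowAxis : i;
    }
    return vec;
}

int main()
{
    for (auto v : MakePermuteVec(false, 4)) { std::printf("%u ", v); } // NCHW 4D: 0 1 3 2
    std::printf("\n");
    for (auto v : MakePermuteVec(true, 4)) { std::printf("%u ", v); }  // NHWC 4D: 0 2 1 3
    std::printf("\n");
    return 0;
}
```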
diff --git a/src/armnn/layers/BatchMatMulLayer.cpp b/src/armnn/layers/BatchMatMulLayer.cpp
index 501de2d091..acd089aef8 100644
--- a/src/armnn/layers/BatchMatMulLayer.cpp
+++ b/src/armnn/layers/BatchMatMulLayer.cpp
@@ -5,6 +5,7 @@
#include "BatchMatMulLayer.hpp"
#include <armnn/backends/WorkloadFactory.hpp>
+#include <armnnUtils/Permute.hpp>
#include "layers/LayerCloneBase.hpp"
namespace armnn
@@ -36,12 +37,24 @@ std::vector<TensorShape> BatchMatMulLayer::InferOutputShapes(const std::vector<T
TensorShape inputXShape = inputShapes[0];
TensorShape inputYShape = inputShapes[1];
- // Note: Take into account what pre-adjoint or pre-transposing will do to the inferred output shape
+ // Adjoint will not affect the resultant shape, as you would be permuting two axes of equal size
+ if(m_Param.m_TransposeX)
+ {
+ auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(m_Param.m_DataLayoutX,
+ inputXShape);
+ inputXShape = armnnUtils::Permuted(inputXShape, permuteVec);
+ }
+ if(m_Param.m_TransposeY)
+ {
+ auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(m_Param.m_DataLayoutY,
+ inputYShape);
+ inputYShape = armnnUtils::Permuted(inputYShape, permuteVec);
+ }
TensorShape& longerInput = inputXShape.GetNumDimensions() >= inputYShape.GetNumDimensions()?
- inputXShape:inputYShape;
+ inputXShape : inputYShape;
TensorShape& shorterInput = inputXShape.GetNumDimensions() >= inputYShape.GetNumDimensions()?
- inputYShape:inputXShape;
+ inputYShape : inputXShape;
unsigned int inputNumDimsOffset = longerInput.GetNumDimensions() - shorterInput.GetNumDimensions();
@@ -49,10 +62,10 @@ std::vector<TensorShape> BatchMatMulLayer::InferOutputShapes(const std::vector<T
std::vector<unsigned int> tensorDimensions(outputNumDimensions, 0);
- auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Param, inputXShape, inputYShape);
- const auto& longerAxesToMul = (axesToMul.first.first >= axesToMul.second.first &&
- axesToMul.first.second >= axesToMul.second.second) ?
- axesToMul.first : axesToMul.second;
+ const auto& longerInputDataLayout = inputXShape.GetNumDimensions() >= inputYShape.GetNumDimensions()?
+ m_Param.m_DataLayoutX : m_Param.m_DataLayoutY;
+ auto longerAxesToMul = BatchMatMulDescriptor::GetAxesToMul(longerInputDataLayout,
+ longerInput);
for (unsigned int i = 0; i < outputNumDimensions; ++i)
{
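A worked example of the shape inference above, assuming the armnn headers from this repository; it mirrors the BatchMatMul2DTranspSimpleTest added later in this commit:

```cpp
// X is {2,3} with m_TransposeX set, Y is {2,3}.
// Pre-transposing X's slices gives {3,2}; multiplying {3,2} x {2,3}
// therefore infers an output shape of {3,3}. Adjoint flags are ignored
// here, since they never change the (square) slice shape.
armnn::TensorShape inputXShape({2, 3});
armnn::TensorShape inputYShape({2, 3});
armnn::TensorShape expectedOutputShape({3, 3});
```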
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 9a4c60f551..f4afbd9a84 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -8,6 +8,7 @@
#include <armnn/backends/WorkloadInfo.hpp>
#include <armnnUtils/DataLayoutIndexed.hpp>
#include <armnnUtils/TensorUtils.hpp>
+#include <armnnUtils/Permute.hpp>
#include <armnn/utility/NumericCast.hpp>
#include <armnn/Logging.hpp>
@@ -4154,9 +4155,10 @@ void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
// For inputs X and Y whose dimensions to be multiplied are (M,N) and (I,J) respectively,
// axes N and I must be the same size
- const auto& inputTensorXInfo = workloadInfo.m_InputTensorInfos[0];
- const auto& inputTensorYInfo = workloadInfo.m_InputTensorInfos[1];
- const auto& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
+ const auto& inputXInfoBeforeParams = workloadInfo.m_InputTensorInfos[0];
+ const auto& inputYInfoBeforeParams = workloadInfo.m_InputTensorInfos[1];
+ const auto& outputInfo = workloadInfo.m_OutputTensorInfos[0];
+ // Output info has already been inferred
std::vector<DataType> supportedTypes =
{
@@ -4168,108 +4170,127 @@ void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
DataType::QSymmS16
};
- ValidateDataTypes(inputTensorXInfo, supportedTypes, descriptorName);
- ValidateDataTypes(inputTensorYInfo, supportedTypes, descriptorName);
- ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
+ ValidateDataTypes(inputXInfoBeforeParams, supportedTypes, descriptorName);
+ ValidateDataTypes(inputYInfoBeforeParams, supportedTypes, descriptorName);
+ ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
- if ((inputTensorXInfo.GetNumDimensions() < 2) ||
- (inputTensorYInfo.GetNumDimensions() < 2))
+ if ((inputXInfoBeforeParams.GetNumDimensions() < 2) ||
+ (inputYInfoBeforeParams.GetNumDimensions() < 2))
{
throw InvalidArgumentException(descriptorName + ": Input tensors are not 2D or greater.");
}
- if(m_Parameters.m_DataLayoutX.has_value())
+ TensorInfo inputXInfoAfterParams;
+ TensorInfo inputYInfoAfterParams;
+
+ if((m_Parameters.m_TransposeX && m_Parameters.m_AdjointX) ||
+ (m_Parameters.m_TransposeY && m_Parameters.m_AdjointY))
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": Invalid descriptor parameters - Transpose and Adjoint "
+ "cannot both be true for a given input tensor.");
+ }
+ if(m_Parameters.m_TransposeX)
+ {
+ inputXInfoAfterParams = armnnUtils::Permuted(inputXInfoBeforeParams,
+ BatchMatMulDescriptor::GetPermuteVec(
+ m_Parameters.m_DataLayoutX,
+ inputXInfoBeforeParams.GetShape()));
+ }
+ else if(m_Parameters.m_AdjointX)
{
- switch(m_Parameters.m_DataLayoutX.value())
+ auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
+ inputXInfoBeforeParams.GetShape());
+ if(inputXInfoBeforeParams.GetShape()[axesToMul.first] !=
+ inputXInfoBeforeParams.GetShape()[axesToMul.second])
{
- case DataLayout::NCHW:
- case DataLayout::NHWC:
- if(inputTensorXInfo.GetNumDimensions() != 4)
- {
- throw InvalidArgumentException(descriptorName +
- ": Input tensor X does not have the correct "
- "number of dimensions for the Data Layout that it has been assigned.");
- }
- break;
- case DataLayout::NCDHW:
- case DataLayout::NDHWC:
- if(inputTensorXInfo.GetNumDimensions() != 5)
- {
- throw InvalidArgumentException(descriptorName +
- ": Input tensor X does not have the correct "
- "number of dimensions for the Data Layout that it has been assigned.");
- }
- break;
- default:
- break;
+ throw InvalidArgumentException(descriptorName +
+ ": Adjoint is set to true for input tensor X, but the axes to be adjointed are not square." );
}
+ // Shape remains the same as it's square
+ inputXInfoAfterParams = inputXInfoBeforeParams;
+ }
+ else
+ {
+ inputXInfoAfterParams = inputXInfoBeforeParams;
}
- if(m_Parameters.m_DataLayoutY.has_value())
+ if(m_Parameters.m_TransposeY)
{
- switch(m_Parameters.m_DataLayoutY.value())
+ inputYInfoAfterParams = armnnUtils::Permuted(inputYInfoBeforeParams,
+ BatchMatMulDescriptor::GetPermuteVec(
+ m_Parameters.m_DataLayoutY,
+ inputYInfoBeforeParams.GetShape()));
+ }
+ else if(m_Parameters.m_AdjointY)
+ {
+ auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
+ inputYInfoBeforeParams.GetShape());
+ if(inputYInfoBeforeParams.GetShape()[axesToMul.first] !=
+ inputYInfoBeforeParams.GetShape()[axesToMul.second])
{
- case DataLayout::NCHW:
- case DataLayout::NHWC:
- if(inputTensorYInfo.GetNumDimensions() != 4)
- {
- throw InvalidArgumentException(descriptorName +
- ": Input tensor Y does not have the correct "
- "number of dimensions for the Data Layout that it has been assigned.");
- }
- break;
- case DataLayout::NCDHW:
- case DataLayout::NDHWC:
- if(inputTensorYInfo.GetNumDimensions() != 5)
- {
- throw InvalidArgumentException(descriptorName +
- ": Input tensor Y does not have the correct "
- "number of dimensions for the Data Layout that it has been assigned.");
- }
- break;
- default:
- break;
+ throw InvalidArgumentException(descriptorName +
+ ": Adjoint is set to true for input tensor Y, but the axes to be adjointed are not square." );
}
+ // Shape remains the same as it's square
+ inputYInfoAfterParams = inputYInfoBeforeParams;
+ }
+ else
+ {
+ inputYInfoAfterParams = inputYInfoBeforeParams;
+ }
+
+ switch(m_Parameters.m_DataLayoutX)
+ {
+ case DataLayout::NCDHW:
+ case DataLayout::NDHWC:
+ if(inputXInfoAfterParams.GetNumDimensions() < 3)
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": Input tensor X does not have the correct "
+ "number of dimensions for the Data Layout that it has been assigned.");
+ }
+ break;
+ case DataLayout::NCHW:
+ case DataLayout::NHWC:
+ default:
+ break;
+ }
+
+ switch(m_Parameters.m_DataLayoutY)
+ {
+ case DataLayout::NCDHW:
+ case DataLayout::NDHWC:
+ if(inputYInfoAfterParams.GetNumDimensions() < 3)
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": Input tensor Y does not have the correct "
+ "number of dimensions for the Data Layout that it has been assigned.");
+ }
+ break;
+ case DataLayout::NCHW:
+ case DataLayout::NHWC:
+ default:
+ break;
}
- auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters,
- inputTensorXInfo.GetShape(),
- inputTensorYInfo.GetShape());
+ auto axesXToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
+ inputXInfoAfterParams.GetShape());
+    auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
+                                                          inputYInfoAfterParams.GetShape());
- if(inputTensorXInfo.GetShape()[axesToMul.first.second]
- != inputTensorYInfo.GetShape()[axesToMul.second.first])
+ if(inputXInfoAfterParams.GetShape()[axesXToMul.second]
+ != inputYInfoAfterParams.GetShape()[axesYToMul.first])
{
throw InvalidArgumentException(descriptorName +
": The final axis of input tensor X must be the same size as "
"the second last axis of input tensor Y.");
}
- auto axesNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters,
- inputTensorXInfo.GetShape(),
- inputTensorYInfo.GetShape());
-
{ // Separate scope so we don't pollute the rest of the scope with our temp variables
// e.g. NHWC isn't compatible with NCHW as of now
- DataLayout xLayout;
- DataLayout yLayout;
-
- if(m_Parameters.m_DataLayoutX == EmptyOptional())
- {
- xLayout = DataLayout::NCHW; // Not equivalent - I'm just concerned with the last 2 axes
- }
- else
- {
- xLayout = m_Parameters.m_DataLayoutX.value();
- }
-
- if(m_Parameters.m_DataLayoutY == EmptyOptional())
- {
- yLayout = DataLayout::NCHW;
- }
- else
- {
- yLayout = m_Parameters.m_DataLayoutY.value();
- }
+ DataLayout xLayout = m_Parameters.m_DataLayoutX;
+ DataLayout yLayout = m_Parameters.m_DataLayoutY;
if(xLayout == DataLayout::NCHW || xLayout == DataLayout::NCDHW)
{
@@ -4290,8 +4311,8 @@ void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
}
// Simulate aligning the ends of the matrix dims and prepending 1's to the beginning of the shorter one
- unsigned int outputTensorDimSize = std::max(inputTensorXInfo.GetNumDimensions(),
- inputTensorYInfo.GetNumDimensions());
+ unsigned int outputTensorDimSize = std::max(inputXInfoAfterParams.GetNumDimensions(),
+ inputYInfoAfterParams.GetNumDimensions());
if(outputTensorDimSize-2 > 0)
{
TensorInfo tiXNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
@@ -4312,12 +4333,17 @@ void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
for(unsigned int i = 0; i < ti.GetNumDimensions(); i++)
{
- ti.GetShape()[i] = inputTensorXInfo.GetShape()[i];
+ ti.GetShape()[i] = inputXInfoAfterParams.GetShape()[i];
}
};
- doAxisExtension(axesNotMul.first, tiXNotMul);
- doAxisExtension(axesNotMul.second, tiYNotMul);
+ auto axesXNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutX,
+ inputXInfoAfterParams.GetShape());
+ auto axesYNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutY,
+ inputYInfoAfterParams.GetShape());
+
+ doAxisExtension(axesXNotMul, tiXNotMul);
+ doAxisExtension(axesYNotMul, tiYNotMul);
for(unsigned int i = 0; i < tiOutNotMul.GetNumDimensions(); i++)
{
@@ -4332,42 +4358,6 @@ void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
"input_X",
"input_Y");
}
-
- // Also check descriptor parameter validity
- // This will eventually be moved to the start of the function as explained below
- if ((!m_Parameters.m_TransposeX.empty() && !m_Parameters.m_AdjointX.empty()) ||
- (!m_Parameters.m_TransposeY.empty() && !m_Parameters.m_AdjointY.empty()))
- {
- throw InvalidArgumentException(descriptorName +
- ": Invalid descriptor parameters - Transpose and Adjoint "
- "vectors cannot both be true for a given input tensor.");
- }
-
- if(m_Parameters.m_TransposeX.size() != 0 && m_Parameters.m_TransposeX.size() != inputTensorXInfo.GetNumDimensions())
- {
- throw InvalidArgumentException(descriptorName +
- ": Invalid descriptor parameter - Transpose X vector must be "
- "the same size as tensor input X's dimensionality.");
- }
- if(m_Parameters.m_AdjointX.size() != 0 && m_Parameters.m_AdjointX.size() != inputTensorXInfo.GetNumDimensions())
- {
- throw InvalidArgumentException(descriptorName +
- ": Invalid descriptor parameter - Adjoint X vector must be "
- "the same size as tensor input X's dimensionality.");
- }
- if(m_Parameters.m_TransposeY.size() != 0 && m_Parameters.m_TransposeY.size() != inputTensorYInfo.GetNumDimensions())
- {
- throw InvalidArgumentException(descriptorName +
- ": Invalid descriptor parameter - Transpose Y vector must be "
- "the same size as tensor input Y's dimensionality.");
- }
- if(m_Parameters.m_AdjointY.size() != 0 && m_Parameters.m_AdjointY.size() != inputTensorXInfo.GetNumDimensions())
- {
- throw InvalidArgumentException(descriptorName +
- ": Invalid descriptor parameter - Adjoint Y vector must be "
- "the same size as tensor input Y's dimensionality.");
- }
- // Note: for adjoint/transpose, you'll need to do the validation atop the resultant permutation.
}
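For illustration, two descriptor configurations the rewritten validation above rejects (a sketch, assuming the armnn headers from this repository):

```cpp
// 1) Transpose and Adjoint both set for the same input:
armnn::BatchMatMulDescriptor bad(/*transposeX=*/ true,
                                 /*transposeY=*/ false,
                                 /*adjointX=*/   true, // conflicts with transposeX
                                 /*adjointY=*/   false);
// Validate() throws InvalidArgumentException: "Transpose and Adjoint
// cannot both be true for a given input tensor."

// 2) Adjoint on non-square matrix axes: with m_AdjointX set, an X of
// shape {5, 2, 3} (matrix axes {1, 2}) throws, since 2 != 3;
// a {5, 3, 3} input passes and keeps its shape.
```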
diff --git a/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp
index 41add6e6da..6fcc35ab52 100644
--- a/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp
+++ b/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp
@@ -191,7 +191,7 @@ LayerTestResult<T, 3> BatchMatMul3DSimpleTest(
std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
19, 22,
43, 50
- },qScale, qOffset);
+ }, qScale, qOffset);
return BatchMatMulTestImpl<ArmnnType, T, 3>(workloadFactory,
memoryManager,
@@ -247,9 +247,7 @@ LayerTestResult<T, 4> BatchMatMulNCHWSimpleTest(
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
const armnn::ITensorHandleFactory& tensorHandleFactory)
{
- auto descriptor = armnn::BatchMatMulDescriptor(
- armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NCHW),
- armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NCHW));
+ auto descriptor = armnn::BatchMatMulDescriptor(); // Default arbitrary layout is treated the same as NCHW
float qScale = 0.0f;
int32_t qOffset = 0;
@@ -282,7 +280,7 @@ LayerTestResult<T, 4> BatchMatMulNCHWSimpleTest(
std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
19, 22,
43, 50
- },qScale, qOffset);
+ }, qScale, qOffset);
return BatchMatMulTestImpl<ArmnnType, T, 4>(workloadFactory,
memoryManager,
@@ -338,9 +336,12 @@ LayerTestResult<T, 4> BatchMatMulNHWCSimpleTest(
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
const armnn::ITensorHandleFactory& tensorHandleFactory)
{
- auto descriptor = armnn::BatchMatMulDescriptor(
- armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NHWC),
- armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NHWC));
+ auto descriptor = armnn::BatchMatMulDescriptor(false,
+ false,
+ false,
+ false,
+ armnn::DataLayout::NHWC,
+ armnn::DataLayout::NHWC);
float qScale = 0.0f;
int32_t qOffset = 0;
@@ -373,7 +374,7 @@ LayerTestResult<T, 4> BatchMatMulNHWCSimpleTest(
std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
19, 22,
43, 50
- },qScale, qOffset);
+ }, qScale, qOffset);
return BatchMatMulTestImpl<ArmnnType, T, 4>(workloadFactory,
memoryManager,
@@ -471,7 +472,7 @@ LayerTestResult<T, 3> BatchMatMul3DBatchTest(
267, 286,
323, 346
- },qScale, qOffset);
+ }, qScale, qOffset);
return BatchMatMulTestImpl<ArmnnType, T, 3>(workloadFactory,
memoryManager,
@@ -566,7 +567,7 @@ LayerTestResult<T, 3> BatchMatMul3DBroadcastTest(
267, 286,
323, 346
- },qScale, qOffset);
+ }, qScale, qOffset);
return BatchMatMulTestImpl<ArmnnType, T, 3>(workloadFactory,
memoryManager,
@@ -661,7 +662,7 @@ LayerTestResult<T, 3> BatchMatMul3D2DBroadcastTest(
267, 286,
323, 346
- },qScale, qOffset);
+ }, qScale, qOffset);
return BatchMatMulTestImpl<ArmnnType, T, 3>(workloadFactory,
memoryManager,
@@ -717,9 +718,12 @@ LayerTestResult<T, 5> BatchMatMulNDHWCNHWCTest(
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
const armnn::ITensorHandleFactory& tensorHandleFactory)
{
- auto descriptor = armnn::BatchMatMulDescriptor(
- armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NDHWC),
- armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NHWC));
+ auto descriptor = armnn::BatchMatMulDescriptor(false,
+ false,
+ false,
+ false,
+ armnn::DataLayout::NDHWC,
+ armnn::DataLayout::NHWC);
float qScale = 0.0f;
int32_t qOffset = 0;
@@ -761,7 +765,7 @@ LayerTestResult<T, 5> BatchMatMulNDHWCNHWCTest(
34, 1079,
46, 1167
- },qScale, qOffset);
+ }, qScale, qOffset);
return BatchMatMulTestImpl<ArmnnType, T, 5>(workloadFactory,
memoryManager,
@@ -959,7 +963,7 @@ LayerTestResult<T, 3> BatchMatMul3DNonSquareTest(
88, 100, 142, 106,
39, 61, 78, 56,
72, 52, 98, 70
- },qScale, qOffset);
+ }, qScale, qOffset);
return BatchMatMulTestImpl<ArmnnType, T, 3>(workloadFactory,
memoryManager,
@@ -1007,4 +1011,330 @@ template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 3>
BatchMatMul3DNonSquareTest<armnn::DataType::QSymmS16>(
armnn::IWorkloadFactory& workloadFactory,
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 2> BatchMatMul2DTranspSimpleTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+ auto descriptor = armnn::BatchMatMulDescriptor(true,
+ false,
+ false,
+ false);
+
+ float qScale = 0.0f;
+ int32_t qOffset = 0;
+
+ switch(ArmnnType)
+ {
+ case armnn::DataType::QAsymmS8:
+ case armnn::DataType::QAsymmU8:
+ case armnn::DataType::QSymmS16:
+ qScale = 1.0f;
+ break;
+ default:
+ break;
+ }
+
+ armnn::TensorInfo inputXInfo({2,3}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo inputYInfo({2,3}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo outputInfo({3,3}, ArmnnType, qScale, qOffset);
+
+ std::vector<T> inputX = armnnUtils::QuantizedVector<T>({
+ 1, 2, 3,
+ 4, 5, 6
+ }, qScale, qOffset);
+
+ std::vector<T> inputY = armnnUtils::QuantizedVector<T>({
+ 7, 8, 9,
+ 10, 11, 12
+ }, qScale, qOffset);
+
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
+ 47, 52, 57,
+ 64, 71, 78,
+ 81, 90, 99
+ }, qScale, qOffset);
+
+ return BatchMatMulTestImpl<ArmnnType, T, 2>(workloadFactory,
+ memoryManager,
+ tensorHandleFactory,
+ descriptor,
+ inputX,
+ inputY,
+ outputExpected,
+ inputXInfo,
+ inputYInfo,
+ outputInfo);
+}
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 2>
+BatchMatMul2DTranspSimpleTest<armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 2>
+BatchMatMul2DTranspSimpleTest<armnn::DataType::Float32>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 2>
+BatchMatMul2DTranspSimpleTest<armnn::DataType::Float16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 2>
+BatchMatMul2DTranspSimpleTest<armnn::DataType::QAsymmS8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 2>
+BatchMatMul2DTranspSimpleTest<armnn::DataType::QAsymmU8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 2>
+BatchMatMul2DTranspSimpleTest<armnn::DataType::QSymmS16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 2> BatchMatMul2DAdjointSimpleTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+ auto descriptor = armnn::BatchMatMulDescriptor(false,
+ false,
+ true,
+ false);
+
+ float qScale = 0.0f;
+ int32_t qOffset = 0;
+
+ switch(ArmnnType)
+ {
+ case armnn::DataType::QAsymmS8:
+ case armnn::DataType::QAsymmU8:
+ case armnn::DataType::QSymmS16:
+ qScale = 1.0f;
+ break;
+ default:
+ break;
+ }
+
+ armnn::TensorInfo inputXInfo({3,3}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo inputYInfo({3,3}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo outputInfo({3,3}, ArmnnType, qScale, qOffset);
+
+ std::vector<T> inputX = armnnUtils::QuantizedVector<T>({
+ 3, 1, 1,
+ 1, 3, -1,
+ 2, 4, 1
+ }, qScale, qOffset);
+
+ std::vector<T> inputY = armnnUtils::QuantizedVector<T>({
+ 1, 0, 0,
+ 0, 1, 0,
+ 0, 0, 1
+ }, qScale, qOffset);
+
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
+ 7, 3, -4,
+ -3, 1, 4,
+ -2, -10, 8
+ }, qScale, qOffset);
+
+ switch (ArmnnType)
+ {
+ case armnn::DataType::QAsymmU8:
+ outputExpected = armnnUtils::QuantizedVector<T>({
+ 3, 3, 0,
+ 0, 1, 1,
+ 0, 0, 8
+ }, qScale, qOffset);
+ break;
+ default:
+ break;
+ }
+
+ return BatchMatMulTestImpl<ArmnnType, T, 2>(workloadFactory,
+ memoryManager,
+ tensorHandleFactory,
+ descriptor,
+ inputX,
+ inputY,
+ outputExpected,
+ inputXInfo,
+ inputYInfo,
+ outputInfo);
+}
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 2>
+BatchMatMul2DAdjointSimpleTest<armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 2>
+BatchMatMul2DAdjointSimpleTest<armnn::DataType::Float32>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 2>
+BatchMatMul2DAdjointSimpleTest<armnn::DataType::Float16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 2>
+BatchMatMul2DAdjointSimpleTest<armnn::DataType::QAsymmS8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 2>
+BatchMatMul2DAdjointSimpleTest<armnn::DataType::QAsymmU8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 2>
+BatchMatMul2DAdjointSimpleTest<armnn::DataType::QSymmS16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 4> BatchMatMulNHWCParamsTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+ auto descriptor = armnn::BatchMatMulDescriptor(false,
+ true,
+ true,
+ false,
+ armnn::DataLayout::NHWC,
+ armnn::DataLayout::NHWC);
+
+ float qScale = 0.0f;
+ int32_t qOffset = 0;
+
+ switch(ArmnnType)
+ {
+ case armnn::DataType::QAsymmS8:
+ case armnn::DataType::QAsymmU8:
+ case armnn::DataType::QSymmS16:
+ qScale = 1.0f;
+ break;
+ default:
+ break;
+ }
+
+ armnn::TensorInfo inputXInfo({1,4,4,2}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo inputYInfo({2,2,4,1}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo outputInfo({2,4,2,2}, ArmnnType, qScale, qOffset);
+
+ std::vector<T> inputX = armnnUtils::QuantizedVector<T>({
+ 1, -3, 1, 4, 4, 9, 1, 2,
+ 2, 4, 2, 2, 10, 7, 6, -5,
+ 3, 8, 9, 9, 21, 1, 17, 7,
+ 5, 11, 11, 8, 29, 3, 23, 6
+ }, qScale, qOffset);
+
+ std::vector<T> inputY = armnnUtils::QuantizedVector<T>({
+ 1, 2, 3, 4,
+ 5, 6, 7, 8,
+
+ 9, 10, 11, 12,
+ 13, 14, 15, 16
+ }, qScale, qOffset);
+
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
+ 28, 625, 140, 585,
+ 8, 110, -8, 1662,
+ -24, 401, -120, 921,
+ 12, 131, 108, -501,
+
+ 252, 545, 364, 505,
+ -24, 3214, -40, 4766,
+ -216, 1441, -312, 1961,
+ 204, -1133, 300, -1765
+ }, qScale, qOffset);
+
+ switch (ArmnnType)
+ {
+ case armnn::DataType::QAsymmU8:
+ outputExpected = armnnUtils::QuantizedVector<T>({
+ 28, 80, 140, 80,
+ 8, 45, 0, 255,
+ 0, 18, 0, 18,
+ 12, 0, 108, 0,
+
+ 252, 80, 255, 80,
+ 0, 255, 0, 255,
+ 0, 18, 0, 18,
+ 204, 0, 255, 0
+ }, qScale, qOffset);
+ break;
+ default:
+ break;
+ }
+
+ return BatchMatMulTestImpl<ArmnnType, T, 4>(workloadFactory,
+ memoryManager,
+ tensorHandleFactory,
+ descriptor,
+ inputX,
+ inputY,
+ outputExpected,
+ inputXInfo,
+ inputYInfo,
+ outputInfo);
+}
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 4>
+BatchMatMulNHWCParamsTest<armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 4>
+BatchMatMulNHWCParamsTest<armnn::DataType::Float32>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 4>
+BatchMatMulNHWCParamsTest<armnn::DataType::Float16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 4>
+BatchMatMulNHWCParamsTest<armnn::DataType::QAsymmS8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 4>
+BatchMatMulNHWCParamsTest<armnn::DataType::QAsymmU8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 4>
+BatchMatMulNHWCParamsTest<armnn::DataType::QSymmS16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
const armnn::ITensorHandleFactory& tensorHandleFactory);
\ No newline at end of file
diff --git a/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp b/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp
index 9e2139667b..0b261fba37 100644
--- a/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp
+++ b/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp
@@ -82,4 +82,22 @@ template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
LayerTestResult<T, 3> BatchMatMul3DNonSquareTest(
armnn::IWorkloadFactory& workloadFactory,
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 2> BatchMatMul2DTranspSimpleTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 2> BatchMatMul2DAdjointSimpleTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 4> BatchMatMulNHWCParamsTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
const armnn::ITensorHandleFactory& tensorHandleFactory);
\ No newline at end of file
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index 593dc7851e..ae40333658 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -1133,6 +1133,27 @@ ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareQAsymmS8, BatchMatMul3DNonSq
ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareQAsymmU8, BatchMatMul3DNonSquareTest<DataType::QAsymmU8>);
ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareQASymmS16, BatchMatMul3DNonSquareTest<DataType::QSymmS16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTranspSimpleBFloat16, BatchMatMul2DTranspSimpleTest<DataType::BFloat16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTranspSimpleFloat32, BatchMatMul2DTranspSimpleTest<DataType::Float32>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTranspSimpleFloat16, BatchMatMul2DTranspSimpleTest<DataType::Float16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTranspSimpleQAsymmS8, BatchMatMul2DTranspSimpleTest<DataType::QAsymmS8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTranspSimpleQAsymmU8, BatchMatMul2DTranspSimpleTest<DataType::QAsymmU8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTranspSimpleQASymmS16, BatchMatMul2DTranspSimpleTest<DataType::QSymmS16>);
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DAdjointSimpleBFloat16, BatchMatMul2DAdjointSimpleTest<DataType::BFloat16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DAdjointSimpleFloat32, BatchMatMul2DAdjointSimpleTest<DataType::Float32>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DAdjointSimpleFloat16, BatchMatMul2DAdjointSimpleTest<DataType::Float16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DAdjointSimpleQAsymmS8, BatchMatMul2DAdjointSimpleTest<DataType::QAsymmS8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DAdjointSimpleQAsymmU8, BatchMatMul2DAdjointSimpleTest<DataType::QAsymmU8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DAdjointSimpleQASymmS16, BatchMatMul2DAdjointSimpleTest<DataType::QSymmS16>);
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCParamsBFloat16, BatchMatMulNHWCParamsTest<DataType::BFloat16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCParamsFloat32, BatchMatMulNHWCParamsTest<DataType::Float32>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCParamsFloat16, BatchMatMulNHWCParamsTest<DataType::Float16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCParamsQAsymmS8, BatchMatMulNHWCParamsTest<DataType::QAsymmS8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCParamsQAsymmU8, BatchMatMulNHWCParamsTest<DataType::QAsymmU8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCParamsQASymmS16, BatchMatMulNHWCParamsTest<DataType::QSymmS16>);
+
// Batch Norm
ARMNN_AUTO_TEST_CASE_WITH_THF(BatchNormFloat32, BatchNormFloat32Test)
ARMNN_AUTO_TEST_CASE_WITH_THF(BatchNormFloat32Nhwc, BatchNormFloat32NhwcTest)
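The Adjoint helper in the BatchMatMulImpl.cpp diff below computes the classical adjoint (adjugate): the transposed cofactor matrix, obtained there via Gauss elimination. A standalone sketch (plain C++, not armnn code) using direct 2x2 determinants reproduces the expected output of the BatchMatMul2DAdjointSimple test above (Y is the identity, so the result is just adj(X)):

```cpp
#include <array>
#include <cstdio>

using Mat3 = std::array<std::array<float, 3>, 3>;

// Adjugate of a 3x3 matrix: adj(A)[c][r] = cofactor C[r][c].
// The cyclic-index form below yields the signed cofactor directly,
// so no explicit (-1)^(r+c) factor is needed.
Mat3 Adjugate(const Mat3& m)
{
    Mat3 adj{};
    for (int r = 0; r < 3; ++r)
    {
        for (int c = 0; c < 3; ++c)
        {
            int r0 = (r + 1) % 3, r1 = (r + 2) % 3;
            int c0 = (c + 1) % 3, c1 = (c + 2) % 3;
            adj[c][r] = m[r0][c0] * m[r1][c1] - m[r0][c1] * m[r1][c0];
        }
    }
    return adj;
}

int main()
{
    // Input X from the BatchMatMul2DAdjointSimple test above
    Mat3 x = {{{3, 1, 1}, {1, 3, -1}, {2, 4, 1}}};
    Mat3 adj = Adjugate(x);
    for (const auto& row : adj)
    {
        std::printf("%6.1f %6.1f %6.1f\n", row[0], row[1], row[2]);
    }
    // Prints the adjoint test's expected output:
    //    7.0    3.0   -4.0
    //   -3.0    1.0    4.0
    //   -2.0  -10.0    8.0
    return 0;
}
```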
diff --git a/src/backends/reference/workloads/BatchMatMulImpl.cpp b/src/backends/reference/workloads/BatchMatMulImpl.cpp
index 6693f15760..c592b3b76c 100644
--- a/src/backends/reference/workloads/BatchMatMulImpl.cpp
+++ b/src/backends/reference/workloads/BatchMatMulImpl.cpp
@@ -7,46 +7,53 @@
#include <armnn/backends/WorkloadData.hpp>
#include <armnn/Logging.hpp>
+#include <armnnUtils/Permute.hpp>
namespace armnn
{
-void BatchMatMul::BatchMatMulImpl()
+BatchMatMul::BatchMatMul(const BatchMatMulDescriptor& params,
+ const TensorInfo& inputXInfo,
+ const TensorInfo& inputYInfo,
+ const TensorInfo& outputInfo,
+ Decoder<float>& inputXDecoder,
+ Decoder<float>& inputYDecoder,
+ Encoder<float>& outputEncoder)
+ : params(params),
+ inputXInfo(inputXInfo),
+ inputYInfo(inputYInfo),
+ outputInfo(outputInfo),
+ inputXDecoder(inputXDecoder),
+ inputYDecoder(inputYDecoder),
+ outputEncoder(outputEncoder)
{
- inputXData = inputXDecoder.DecodeTensor(inputXInfo.GetShape());
- inputYData = inputYDecoder.DecodeTensor(inputYInfo.GetShape());
+ inputXData = this->inputXDecoder.DecodeTensor(inputXInfo.GetShape());
+ inputYData = this->inputYDecoder.DecodeTensor(inputYInfo.GetShape());
// At this point, we don't touch the input decoders - just the resultant vectors
- // Pre-transpose and pre-adjoint if their vectors aren't empty
- // and also DataLayouts which may change with permutations/adjoints
+ ApplyParams();
- // Todo: Have you updated input validation and inferred output shapes to accommodate for these pre-permutes?
-
- auto idx = std::vector<unsigned int>(outputInfo.GetNumDimensions(), 0);
- RecurseBMM(idx, 0);
+ ApplyBatchMatMul();
}
-void BatchMatMul::RecurseBMM(std::vector<unsigned int>& curIdx, unsigned int curDim)
+void BatchMatMul::ApplyBatchMatMul()
{
- // We're working off of the indexes of the output tensor (the max possible shape)
-
- if(!(curDim < outputInfo.GetNumDimensions()))
- {
- // We're at the leaf level of this call tree, so we operate here (each leaf is a data point)
+ auto axesXToMul = BatchMatMulDescriptor::GetAxesToMul(params.m_DataLayoutX,
+ inputXInfo.GetShape());
+ auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(params.m_DataLayoutY,
+ inputYInfo.GetShape());
+ AdjustAxesToMulForUnequalRanks(axesXToMul, axesYToMul);
- auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(params,
- inputXInfo.GetShape(),
- inputYInfo.GetShape());
- AdjustAxesToMulForUnequalRanks(axesToMul);
+ unsigned int inputXColDim = axesXToMul.second;
+ unsigned int inputYRowDim = axesYToMul.first;
- unsigned int inputXColDim = axesToMul.first.second;
- unsigned int inputYRowDim = axesToMul.second.first;
-
- unsigned int inputYRowSize = inputYInfo.GetShape()[inputYRowDim];
+ unsigned int inputYRowSize = inputYInfo.GetShape()[inputYRowDim];
+ auto batchMatMulOperation = [&](const std::vector<unsigned int>& curIdx)
+ {
float sum = 0.0f;
- // You could also use inputXColSize
+ // InputYRowSize is synonymous with inputXColSize
for (unsigned int inputYRowIdx = 0; inputYRowIdx < inputYRowSize; inputYRowIdx++) {
auto xIdx = curIdx;
xIdx[inputXColDim] = inputYRowIdx;
@@ -54,24 +61,271 @@ void BatchMatMul::RecurseBMM(std::vector<unsigned int>& curIdx, unsigned int cur
auto yIdx = curIdx;
yIdx[inputYRowDim] = inputYRowIdx;
- sum += (GetValueAt(DataSlot::InputX, xIdx)
- * GetValueAt(DataSlot::InputY, yIdx));
+ sum += (GetValueAt(DataSlot::InputX, xIdx) * GetValueAt(DataSlot::InputY, yIdx));
}
SetValueAt(sum, DataSlot::Output, curIdx);
+ };
+
+ auto startIdx = std::vector<unsigned int>(outputInfo.GetNumDimensions(), 0);
+ RecurseTensor(outputInfo,
+ batchMatMulOperation,
+ startIdx,
+ 0);
+}
+
+void BatchMatMul::ApplyParams()
+{
+ if(params.m_TransposeX)
+ {
+ Transpose(DataSlot::InputX);
+ }
+ else if(params.m_AdjointX)
+ {
+ Adjoint(DataSlot::InputX);
+ }
+ if(params.m_TransposeY)
+ {
+ Transpose(DataSlot::InputY);
+ }
+ else if(params.m_AdjointY)
+ {
+ Adjoint(DataSlot::InputY);
+ }
+}
+
+void BatchMatMul::Transpose(DataSlot type)
+{
+ // AKA the permute of the tensor
+ // This modifies the tensor's info.
+
+ switch(type)
+ {
+ case DataSlot::InputX:
+ {
+ auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(params.m_DataLayoutX,
+ inputXInfo.GetShape());
+ inputXInfo = armnnUtils::Permuted(inputXInfo, permuteVec);
+ std::vector<float> temp(inputXData.size());
+ armnnUtils::Permute(inputXInfo.GetShape(),
+ permuteVec,
+ inputXData.data(),
+ temp.data(),
+ sizeof(float));
+ inputXData = temp;
+ break;
+ }
+ case DataSlot::InputY:
+ {
+ auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(params.m_DataLayoutY,
+ inputYInfo.GetShape());
+ inputYInfo = armnnUtils::Permuted(inputYInfo, permuteVec);
+ std::vector<float> temp(inputYData.size());
+ armnnUtils::Permute(inputYInfo.GetShape(),
+ permuteVec,
+ inputYData.data(),
+ temp.data(),
+ sizeof(float));
+ inputYData = temp;
+ break;
+ }
+ case DataSlot::Output: // We needn't transpose the output tensor
+ default:
+ break;
+ }
+}
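Typically the permute vector swaps the two innermost dimensions, i.e. each slice gets an ordinary matrix transpose. An equivalent operation on a flat row-major buffer, sketched without armnnUtils:

    #include <cstdio>
    #include <vector>

    int main()
    {
        // 2x3 row-major: [[1,2,3],[4,5,6]]
        const unsigned int rows = 2, cols = 3;
        std::vector<float> src = {1, 2, 3, 4, 5, 6};
        std::vector<float> dst(src.size()); // 3x2: [[1,4],[2,5],[3,6]]
        for (unsigned int r = 0; r < rows; r++)
        {
            for (unsigned int c = 0; c < cols; c++)
            {
                dst[c * rows + r] = src[r * cols + c];
            }
        }
        for (float v : dst) { std::printf("%g ", v); } // 1 4 2 5 3 6
    }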
+
+void BatchMatMul::Adjoint(DataSlot type)
+{
+ // Finding the adjoint of a square matrix:
+ // Calculate the cofactor of each element (using Gauss elimination here)
+ // Apply a transpose to it (this also modifies the tensor's info)
+
+ TensorInfo& inputInfo = (type == DataSlot::InputX) ? inputXInfo : inputYInfo;
+ const auto& dataLayout = (type == DataSlot::InputX) ? params.m_DataLayoutX : params.m_DataLayoutY;
+ const auto axesToAdjoint = BatchMatMulDescriptor::GetAxesToMul(dataLayout, inputInfo.GetShape());
+
+ ARMNN_ASSERT(inputInfo.GetShape()[axesToAdjoint.first] == inputInfo.GetShape()[axesToAdjoint.second]);
+ // We grab a copy of the tensor data so that in-place writes don't corrupt values still to be read
+ std::vector<float> inputDataClone = (type == DataSlot::InputX) ? inputXData : inputYData;
+
+ // The sub-matrix is the resultant matrix when the row and column of the current index are removed
+ unsigned int subMatAxisSize = inputInfo.GetShape()[axesToAdjoint.first] - 1;
+ std::vector<std::vector<float>> subMat(subMatAxisSize,
+ std::vector<float>(subMatAxisSize));
+
+ // Lambdas for each sub-step of the cofactor operation
+ auto almostEquals = [&](const float& a, const float& b, float unitsInLastPlace = 2.0f)
+ {
+ float diff = std::fabs(a-b);
+ float bound = std::fabs(a + b) * std::numeric_limits<float>::epsilon() * unitsInLastPlace;
+ return (diff <= bound) || (diff < std::numeric_limits<float>::min());
+ };
+
+ float swapMultiplier = std::numeric_limits<float>::max();
+ auto swapRows = [&](unsigned int rowIdxA, unsigned int rowIdxB)
+ {
+ // Each row swap flips the determinant's sign via swapMultiplier (reset to 1 at the start of each cofactor run)
+ for(unsigned int colIdx = 0; colIdx < subMatAxisSize; colIdx++)
+ {
+ float tmp = subMat[rowIdxA][colIdx];
+ subMat[rowIdxA][colIdx] = subMat[rowIdxB][colIdx];
+ subMat[rowIdxB][colIdx] = tmp;
+ }
+ swapMultiplier *= -1.0f;
+ };
+
+ auto findNextValidPivotRowIdx = [&](unsigned int colIdx)
+ {
+ unsigned int result = std::numeric_limits<unsigned int>::max();
+
+ // The original diagonal has been checked and is invalid
+ for(unsigned int rowIdx = colIdx+1; rowIdx < subMatAxisSize; rowIdx++)
+ {
+ if(!almostEquals(subMat[rowIdx][colIdx], 0.0f))
+ {
+ result = rowIdx;
+ break;
+ }
+ }
+ return result;
+ };
+
+ auto eliminate = [&](const float& pivot, unsigned int pivotPos)
+ {
+ for(unsigned int rowIdx = pivotPos+1; rowIdx < subMatAxisSize; rowIdx++)
+ {
+ float multiplierNumerator = subMat[rowIdx][pivotPos];
+ if(almostEquals(multiplierNumerator, 0.0f))
+ {
+ continue;
+ }
+ float multiplier = multiplierNumerator / pivot; // Susceptible to floating point inaccuracies
+ // Hence the almostEquals usage to counteract this
+ for(unsigned int colIdx = pivotPos; colIdx < subMatAxisSize; colIdx++)
+ {
+ // We start at col=pivotPos as we have assumed that all elements
+ // to our left have been eliminated to zero already
+
+ // We subtract based on the element directly above us in our pivot row
+ subMat[rowIdx][colIdx] -= multiplier * subMat[pivotPos][colIdx];
+ }
+ }
+ };
+
+ auto cofactorOperation = [&](const std::vector<unsigned int>& curIdx)
+ {
+ auto row = curIdx[axesToAdjoint.first];
+ auto col = curIdx[axesToAdjoint.second];
+
+ float minorMultiplier = static_cast<float>(std::pow(-1, (row + 1 + col + 1)));
+
+ for(unsigned int subRow = 0; subRow < subMatAxisSize; subRow++)
+ {
+ for(unsigned int subCol = 0; subCol < subMatAxisSize; subCol++)
+ {
+ unsigned int outerRow = (subRow >= row)?subRow + 1:subRow;
+ unsigned int outerCol = (subCol >= col)?subCol + 1:subCol;
+ auto cloneIdx = curIdx;
+ cloneIdx[axesToAdjoint.first] = outerRow;
+ cloneIdx[axesToAdjoint.second] = outerCol;
+ subMat[subRow][subCol] = GetValueAt(type,cloneIdx,inputDataClone);
+ }
+ }
+
+ float determinant = 1.0f;
+
+ // Cover the edge cases and simple base cases before resorting to Gauss elimination for larger matrices
+ switch(subMatAxisSize)
+ {
+ case 0:
+ {
+ determinant = GetValueAt(type, curIdx, inputDataClone);
+ break;
+ }
+ case 1:
+ {
+ // If the resultant sub-matrix is just one element - that's the determinant
+ determinant = subMat[0][0];
+ break;
+ }
+ case 2:
+ {
+ // For a 2x2 sub-matrix, the determinant is just a*d-b*c
+ determinant = subMat[0][0] * subMat[1][1] -
+ subMat[0][1] * subMat[1][0];
+ break;
+ }
+ default:
+ {
+ // Gaussian elimination to find the determinant of this sub-matrix
+ swapMultiplier = 1.0f;
+ // March down the diagonal of pivots; if a pivot is invalid (a zero), swap its row with
+ // the nearest row below it that has a non-zero entry in the same column
+ for(unsigned int pivotRow = 0, pivotCol = 0;
+ pivotRow < subMatAxisSize;
+ pivotRow++, pivotCol++)
+ {
+ float& pivot = subMat[pivotRow][pivotCol];
+
+ if(almostEquals(pivot, 0.0f))
+ {
+ unsigned int nextValidPivotRowIdx = findNextValidPivotRowIdx(pivotCol);
+ if(nextValidPivotRowIdx == std::numeric_limits<unsigned int>::max())
+ {
+ // No valid pivot down this column, which means that this pivot remains a zero.
+ // The determinant of this entire sub-matrix is therefore just zero.
+ determinant = 0.0f;
+ break;
+ }
+ swapRows(pivotRow, nextValidPivotRowIdx);
+ }
+ determinant *= pivot;
+ // The actual elimination bit (which will update/propagate to the pivots down the line)
+ eliminate(pivot, pivotRow); // Synonymous with pivotCol
+ }
+
+ determinant *= swapMultiplier;
+ break;
+ }
+ }
+ float cofactor = minorMultiplier * determinant;
+ SetValueAt(cofactor, type, curIdx);
+ };
+
+ auto startIdx = std::vector<unsigned int>(inputInfo.GetNumDimensions(), 0);
+ RecurseTensor(inputInfo,
+ cofactorOperation,
+ startIdx,
+ 0);
+
+ Transpose(type);
+}
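The cofactor-then-transpose sequence above computes the classical adjugate, which satisfies A * adj(A) = det(A) * I; that identity is a handy sanity check for this path. For a 2x2 slice [[a,b],[c,d]] the adjugate is [[d,-b],[-c,a]], verified in miniature below (plain arrays, illustrative only):

    #include <cassert>

    int main()
    {
        // A = [[4,7],[2,6]], det(A) = 4*6 - 7*2 = 10
        const float a = 4, b = 7, c = 2, d = 6;
        const float adj[2][2] = {{d, -b}, {-c, a}}; // cofactors, then transposed
        // A * adj(A) == det(A) * I
        assert(a * adj[0][0] + b * adj[1][0] == 10);
        assert(a * adj[0][1] + b * adj[1][1] == 0);
        assert(c * adj[0][0] + d * adj[1][0] == 0);
        assert(c * adj[0][1] + d * adj[1][1] == 10);
    }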
+void BatchMatMul::RecurseTensor(const TensorInfo& tensorInfo,
+ const std::function<void(const std::vector<unsigned int>&)>& operation,
+ std::vector<unsigned int>& curIdx,
+ unsigned int curDim)
+{
+ if(!(curDim < tensorInfo.GetNumDimensions()))
+ {
+ // We're at the leaf level of this call tree, so we operate here (each leaf is a data point)
+ operation(curIdx);
return;
}
- for (unsigned int i = 0; i < outputInfo.GetShape()[curDim]; i++)
+ for(unsigned int i = 0; i < tensorInfo.GetShape()[curDim]; i++)
{
curIdx[curDim] = i;
- RecurseBMM(curIdx, curDim+1);
+ RecurseTensor(tensorInfo,
+ operation,
+ curIdx,
+ curDim + 1);
}
}
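RecurseTensor is a rank-generic replacement for nested loops: it enumerates every index tuple of a tensor and hands each tuple to the supplied operation. The same idea in a standalone sketch:

    #include <cstdio>
    #include <functional>
    #include <vector>

    void Recurse(const std::vector<unsigned int>& shape,
                 const std::function<void(const std::vector<unsigned int>&)>& op,
                 std::vector<unsigned int>& idx,
                 unsigned int dim)
    {
        if (dim >= shape.size()) { op(idx); return; } // leaf: one complete index tuple
        for (unsigned int i = 0; i < shape[dim]; i++)
        {
            idx[dim] = i;
            Recurse(shape, op, idx, dim + 1);
        }
    }

    int main()
    {
        std::vector<unsigned int> shape = {2, 3};
        std::vector<unsigned int> idx(shape.size(), 0);
        Recurse(shape,
                [](const std::vector<unsigned int>& i) { std::printf("(%u,%u) ", i[0], i[1]); },
                idx, 0);
        // (0,0) (0,1) (0,2) (1,0) (1,1) (1,2)
    }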
-void BatchMatMul::AdjustAxesToMulForUnequalRanks(
- std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>>& axesToMul)
+void BatchMatMul::AdjustAxesToMulForUnequalRanks(std::pair<unsigned int, unsigned int>& axesXToMul,
+ std::pair<unsigned int, unsigned int>& axesYToMul)
{
int rankDiff = static_cast<int>(inputXInfo.GetNumDimensions()) -
static_cast<int>(inputYInfo.GetNumDimensions());
@@ -82,18 +336,18 @@ void BatchMatMul::AdjustAxesToMulForUnequalRanks(
else if(rankDiff < 0)
{
// Y is the larger one
- axesToMul.first.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
- axesToMul.first.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
+ axesXToMul.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
+ axesXToMul.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
}
else if(rankDiff > 0)
{
// X is the larger one
- axesToMul.second.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
- axesToMul.second.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
+ axesYToMul.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
+ axesYToMul.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff));
}
}
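Concretely, every index in this file is expressed against the output (maximum) rank, so the lower-rank input's multiplication axes must be shifted right by the rank difference. A worked case with illustrative shapes:

    #include <cstdio>
    #include <utility>

    int main()
    {
        // X: rank 4, shape [5,2,3,4], axes (M,K) = (2,3)
        // Y: rank 2, shape [4,6],     axes (K,N) = (0,1)
        const unsigned int rankX = 4, rankY = 2;
        std::pair<unsigned int, unsigned int> axesYToMul = {0, 1};
        const int rankDiff = static_cast<int>(rankX) - static_cast<int>(rankY);
        if (rankDiff > 0) // X is the larger one, so Y's axes move
        {
            axesYToMul.first  += static_cast<unsigned int>(rankDiff);
            axesYToMul.second += static_cast<unsigned int>(rankDiff);
        }
        std::printf("(%u,%u)\n", axesYToMul.first, axesYToMul.second); // (2,3)
    }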
-float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx)
+float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx, const std::vector<float>& customData)
{
// This gets the data from the input vector that we have, Not the decoder
// But for the output, it is operating on the encoder itself
@@ -101,14 +355,13 @@ float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx)
AdjustToSafeIdx(type, idx);
unsigned int flatIdx = CalcFlatIdx(type, idx);
float value = 0.0f;
-
switch(type)
{
case DataSlot::InputX:
- value = inputXData[flatIdx];
+ value = customData.empty() ? inputXData[flatIdx] : customData[flatIdx];
break;
case DataSlot::InputY:
- value = inputYData[flatIdx];
+ value = customData.empty() ? inputYData[flatIdx] : customData[flatIdx];
break;
case DataSlot::Output:
outputEncoder[flatIdx];
@@ -124,9 +377,7 @@ float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx)
void BatchMatMul::SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx)
{
AdjustToSafeIdx(type, idx);
-
unsigned int flatIdx = CalcFlatIdx(type, idx);
-
switch(type)
{
case DataSlot::InputX:
@@ -186,9 +437,7 @@ void BatchMatMul::AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx)
unsigned int BatchMatMul::CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx)
{
unsigned int result = idx[idx.size()-1];
-
unsigned int dimMultiplier = 1;
-
unsigned int offset;
// -2 because the final dim is already accounted for (the last dim has an implicit stride of 1)
@@ -215,17 +464,4 @@ unsigned int BatchMatMul::CalcFlatIdx(DataSlot type, const std::vector<unsigned
return result;
}
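CalcFlatIdx is a standard row-major linearisation: strides grow right to left and the last dimension sits at stride 1, hence the size-2 starting offset mentioned in the comment above. A standalone check of the arithmetic:

    #include <cassert>
    #include <vector>

    int main()
    {
        const std::vector<unsigned int> shape = {2, 3, 4};
        const std::vector<unsigned int> idx   = {1, 2, 3};
        unsigned int flat = idx.back(); // last dim contributes at stride 1
        unsigned int stride = 1;
        for (int d = static_cast<int>(shape.size()) - 2; d >= 0; d--)
        {
            stride *= shape[d + 1];
            flat += idx[d] * stride;
        }
        assert(flat == 1 * (3 * 4) + 2 * 4 + 3); // 23
    }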
-template <typename T>
-std::string BatchMatMul::StringifyVec(const std::vector<T>& vec)
-{
- std::string res = "{ ";
- for(auto x : vec)
- {
- res += std::to_string(x);
- res += " ";
- }
- res += "}";
- return res;
-}
-
} // namespace armnn
\ No newline at end of file
diff --git a/src/backends/reference/workloads/BatchMatMulImpl.hpp b/src/backends/reference/workloads/BatchMatMulImpl.hpp
index 25b6c85d77..19971a4af3 100644
--- a/src/backends/reference/workloads/BatchMatMulImpl.hpp
+++ b/src/backends/reference/workloads/BatchMatMulImpl.hpp
@@ -15,6 +15,15 @@ namespace armnn
class BatchMatMul {
public:
+ BatchMatMul(const BatchMatMulDescriptor& params,
+ const TensorInfo& inputXInfo,
+ const TensorInfo& inputYInfo,
+ const TensorInfo& outputInfo,
+ Decoder<float>& inputXDecoder,
+ Decoder<float>& inputYDecoder,
+ Encoder<float>& outputEncoder);
+
+private:
enum DataSlot
{
InputX = 0,
@@ -22,31 +31,35 @@ public:
Output = 2
};
- BatchMatMul(const BatchMatMulDescriptor& params,
- const TensorInfo& inputXInfo,
- const TensorInfo& inputYInfo,
- const TensorInfo& outputInfo,
- Decoder<float>& inputXDecoder,
- Decoder<float>& inputYDecoder,
- Encoder<float>& outputEncoder)
- : params(params),
- inputXInfo(inputXInfo),
- inputYInfo(inputYInfo),
- outputInfo(outputInfo),
- inputXDecoder(inputXDecoder),
- inputYDecoder(inputYDecoder),
- outputEncoder(outputEncoder)
- {}
+ const BatchMatMulDescriptor& params;
+ TensorInfo inputXInfo;
+ TensorInfo inputYInfo;
+ TensorInfo outputInfo;
+ Decoder<float>& inputXDecoder;
+ Decoder<float>& inputYDecoder;
+ Encoder<float>& outputEncoder;
- void BatchMatMulImpl();
+ std::vector<float> inputXData;
+ std::vector<float> inputYData;
+
+ void ApplyBatchMatMul();
+
+ void ApplyParams();
+
+ void Transpose(DataSlot type);
- void RecurseBMM(std::vector<unsigned int>& curIdx, unsigned int curDim);
+ void Adjoint(DataSlot type);
+
+ void RecurseTensor(const TensorInfo& tensorInfo,
+ const std::function<void(const std::vector<unsigned int>&)>& operation,
+ std::vector<unsigned int>& curIdx,
+ unsigned int curDim);
// Adjusts it for when input tensors are of unequal rank
- void AdjustAxesToMulForUnequalRanks(
- std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>>& axesToMul);
+ void AdjustAxesToMulForUnequalRanks(std::pair<unsigned int, unsigned int>& axesXToMul,
+ std::pair<unsigned int, unsigned int>& axesYToMul);
- float GetValueAt(DataSlot type, std::vector<unsigned int> idx);
+ float GetValueAt(DataSlot type, std::vector<unsigned int> idx, const std::vector<float>& customData = {});
void SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx);
@@ -54,22 +67,6 @@ public:
void AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx);
unsigned int CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx);
-
- template <typename T>
- std::string StringifyVec(const std::vector<T>& vec);
-
-private:
- const BatchMatMulDescriptor& params;
- const TensorInfo& inputXInfo;
- const TensorInfo& inputYInfo;
- const TensorInfo& outputInfo;
- Decoder<float>& inputXDecoder;
- Decoder<float>& inputYDecoder;
- Encoder<float>& outputEncoder;
-
- std::vector<float> inputXData;
- std::vector<float> inputYData;
-
};
} // namespace armnn
\ No newline at end of file
diff --git a/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp b/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp
index 388190c4ef..027b93b5d9 100644
--- a/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp
+++ b/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp
@@ -51,9 +51,6 @@ void RefBatchMatMulWorkload::Execute(std::vector<ITensorHandle*> inputs, std::ve
*inputXDecoder,
*inputYDecoder,
*outputEncoder);
-
- bmm.BatchMatMulImpl();
-
}
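With the whole computation now running inside the BatchMatMul constructor, the explicit bmm.BatchMatMulImpl() call is gone: constructing the object is the operation. A sketch of the call-site change (argument names abbreviated):

    // Before: BatchMatMul bmm(params, xInfo, yInfo, outInfo, xDec, yDec, outEnc);
    //         bmm.BatchMatMulImpl();
    // After:  BatchMatMul bmm(params, xInfo, yInfo, outInfo, xDec, yDec, outEnc);
    //         // ctor decodes inputs, applies transpose/adjoint params, writes the result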
} // namespace armnn
\ No newline at end of file