From dc8ed9d75e54e914a970e137900930fa64a0782b Mon Sep 17 00:00:00 2001
From: Samuel Yap
Date: Mon, 8 Aug 2022 14:07:42 +0100
Subject: IVGCVSW-7105: BatchMatMul Optional Parameter Support

* Added transpose parameters to pre-transpose each input tensor's slices
* Added adjoint parameters to pre-adjoint each input tensor's slices
* Small refactoring (BatchMatMulDescriptor static helpers and BatchMatMulImpl constructor)
* Updated input validation and output shape inference for the new parameters
* Added further layer unit tests covering the new parameters
* Incremented version numbers

Signed-off-by: Samuel Yap
Change-Id: Ibe5242a8a5bf604c13de0dc65844fd6c421cc667
---
 InstallationViaAptRepository.md                    |   2 +-
 include/armnn/Descriptors.hpp                      |  71 ++--
 include/armnn/Version.hpp                          |   2 +-
 include/armnnOnnxParser/Version.hpp                |   2 +-
 include/armnnTfLiteParser/Version.hpp              |   2 +-
 python/pyarmnn/README.md                           |  14 +-
 .../examples/image_classification/README.md        |   2 +-
 python/pyarmnn/examples/keyword_spotting/README.md |   2 +-
 python/pyarmnn/examples/object_detection/README.md |   2 +-
 .../pyarmnn/examples/speech_recognition/README.md  |   2 +-
 python/pyarmnn/src/pyarmnn/_version.py             |   4 +-
 python/pyarmnn/test/test_setup.py                  |   8 +-
 python/pyarmnn/test/test_version.py                |   4 +-
 samples/ObjectDetection/Readme.md                  |   4 +-
 src/armnn/Descriptors.cpp                          | 115 ++++---
 src/armnn/layers/BatchMatMulLayer.cpp              |  27 +-
 src/backends/backendsCommon/WorkloadData.cpp       | 236 +++++------
 .../test/layerTests/BatchMatMulTestImpl.cpp        | 364 ++++++++++++++++++++-
 .../test/layerTests/BatchMatMulTestImpl.hpp        |  18 +
 src/backends/reference/test/RefLayerTests.cpp      |  21 ++
 .../reference/workloads/BatchMatMulImpl.cpp        | 346 ++++++++++++++++----
 .../reference/workloads/BatchMatMulImpl.hpp        |  69 ++--
 .../reference/workloads/RefBatchMatMulWorkload.cpp |   3 -
 23 files changed, 970 insertions(+), 350 deletions(-)

diff --git a/InstallationViaAptRepository.md b/InstallationViaAptRepository.md
index 3fa36f62a7..037e5cc7f1 100644
--- a/InstallationViaAptRepository.md
+++ b/InstallationViaAptRepository.md
@@ -117,7 +117,7 @@ The easiest way to install all of the available packages for your system's architecture
  sudo apt-get install -y python3-pyarmnn armnn-latest-all
  # Verify installation via python:
  python3 -c "import pyarmnn as ann;print(ann.GetVersion())"
- # Returns '{ARMNN_MAJOR_VERSION}.0.0' e.g. 30.0.0
+ # Returns '{ARMNN_MAJOR_VERSION}.0.0' e.g. 31.0.0
 ```
 This will install PyArmNN and the three backends for Neon (CpuAcc), OpenCL (GpuAcc) and our Reference Backend.
 It will also install their dependencies including the arm-compute-library package along with the Tensorflow Lite Parser
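The Descriptors.hpp change below replaces the old Optional/vector descriptor fields with plain booleans and concrete data layouts. A minimal usage sketch of the reworked API, assuming the usual INetwork entry point (AddBatchMatMulLayer is not shown in this patch):

```cpp
#include <armnn/Descriptors.hpp>
#include <armnn/INetwork.hpp>

void AddTransposedBatchMatMul(armnn::INetwork& network)
{
    // Pre-transpose the slices of input X; leave Y untouched.
    // Transpose and Adjoint must not both be set for the same input.
    armnn::BatchMatMulDescriptor desc(/*transposeX=*/true,
                                      /*transposeY=*/false,
                                      /*adjointX=*/false,
                                      /*adjointY=*/false);
    network.AddBatchMatMulLayer(desc, "batchMatMul");
}
```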
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index 38e3c61500..493ce65976 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -1,5 +1,5 @@
 //
-// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 #pragma once
@@ -1553,55 +1553,74 @@ struct ChannelShuffleDescriptor : BaseDescriptor
 /// A BatchMatMulDescriptor for the BatchMatMul operator
 struct BatchMatMulDescriptor : BaseDescriptor
 {
-    BatchMatMulDescriptor(Optional<DataLayout> dataLayoutX = EmptyOptional(),
-                          Optional<DataLayout> dataLayoutY = EmptyOptional(),
-                          std::vector<unsigned int> transposeX = {},
-                          std::vector<unsigned int> transposeY = {},
-                          std::vector<unsigned int> adjointX = {},
-                          std::vector<unsigned int> adjointY = {})
-        : m_DataLayoutX(dataLayoutX)
-        , m_DataLayoutY(dataLayoutY)
-        , m_TransposeX(transposeX)
+    BatchMatMulDescriptor(bool transposeX = false,
+                          bool transposeY = false,
+                          bool adjointX = false,
+                          bool adjointY = false,
+                          DataLayout dataLayoutX = DataLayout::NCHW,
+                          DataLayout dataLayoutY = DataLayout::NCHW)
+        : m_TransposeX(transposeX)
         , m_TransposeY(transposeY)
         , m_AdjointX(adjointX)
         , m_AdjointY(adjointY)
+        , m_DataLayoutX(dataLayoutX)
+        , m_DataLayoutY(dataLayoutY)
     {}

     bool operator ==(const BatchMatMulDescriptor &rhs) const
     {
-        return m_DataLayoutX == rhs.m_DataLayoutX &&
-               m_DataLayoutY == rhs.m_DataLayoutY &&
-               m_TransposeX == rhs.m_TransposeX &&
+        return m_TransposeX == rhs.m_TransposeX &&
                m_TransposeY == rhs.m_TransposeY &&
                m_AdjointX == rhs.m_AdjointX &&
-               m_AdjointY == rhs.m_AdjointY;
+               m_AdjointY == rhs.m_AdjointY &&
+               m_DataLayoutX == rhs.m_DataLayoutX &&
+               m_DataLayoutY == rhs.m_DataLayoutY;
     }

-    /// Data layout of each input tensor, such as NHWC/NDHWC (or leave as EmptyOptional for arbitrary layout)
-    Optional<DataLayout> m_DataLayoutX;
-    Optional<DataLayout> m_DataLayoutY;
-
-    /// Transpose vector for each input tensor (leave as empty vector for no pre-transposing)
+    /// Transpose the slices of each input tensor
     /// Transpose and Adjoint cannot both be set to true for the same tensor at the same time
-    std::vector<unsigned int> m_TransposeX;
-    std::vector<unsigned int> m_TransposeY;
+    bool m_TransposeX;
+    bool m_TransposeY;

-    /// Adjoint vector for each input tensor (leave as empty vector for no pre-adjoint)
+    /// Adjoint the slices of each input tensor
     /// Transpose and Adjoint cannot both be set to true for the same tensor at the same time
-    std::vector<unsigned int> m_AdjointX;
-    std::vector<unsigned int> m_AdjointY;
+    bool m_AdjointX;
+    bool m_AdjointY;

-    /// Static helper to get the two axes (for each input) for multiplication
+    /// Data layout of each input tensor, such as NHWC/NDHWC (leave as default for arbitrary layout)
+    DataLayout m_DataLayoutX;
+    DataLayout m_DataLayoutY;
+
+    ARMNN_DEPRECATED_MSG_REMOVAL_DATE("This method is deprecated. Use ABI Stable "
+                                      "GetAxesToMul(DataLayout dataLayout, const TensorShape& tensorShape) instead.",
+                                      "23.05")
     static std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>> GetAxesToMul(
         const BatchMatMulDescriptor& desc,
         const TensorShape& tensorXShape,
         const TensorShape& tensorYShape);

-    /// Static helper to get the axes (for each input) that will not be multiplied together
+    ARMNN_DEPRECATED_MSG_REMOVAL_DATE("This method is deprecated. 
Use ABI Stable " + "GetAxesNotMul(DataLayout dataLayout, const TensorShape& tensorShape) instead.", + "23.05") static std::pair, std::vector> GetAxesNotMul( const BatchMatMulDescriptor& desc, const TensorShape& inputXShape, const TensorShape& inputYShape); + + /// Static helper to get the two axes (for each input) for multiplication + static std::pair GetAxesToMul( + DataLayout dataLayout, + const TensorShape& tensorShape); + + /// Static helper to get the axes (for each input) that will not be multiplied together + static std::vector GetAxesNotMul( + DataLayout dataLayout, + const TensorShape& tensorShape); + + /// Static helper to get the axes which will be transposed + static PermutationVector GetPermuteVec( + DataLayout dataLayout, + const TensorShape& tensorShape); }; } // namespace armnn diff --git a/include/armnn/Version.hpp b/include/armnn/Version.hpp index 7951eacf1d..7fdb20ade5 100644 --- a/include/armnn/Version.hpp +++ b/include/armnn/Version.hpp @@ -10,7 +10,7 @@ #define STRINGIFY_MACRO(s) #s // ArmNN version components -#define ARMNN_MAJOR_VERSION 30 +#define ARMNN_MAJOR_VERSION 31 #define ARMNN_MINOR_VERSION 0 #define ARMNN_PATCH_VERSION 0 diff --git a/include/armnnOnnxParser/Version.hpp b/include/armnnOnnxParser/Version.hpp index 33a2846263..5fbaf0ca0e 100644 --- a/include/armnnOnnxParser/Version.hpp +++ b/include/armnnOnnxParser/Version.hpp @@ -14,7 +14,7 @@ namespace armnnOnnxParser // OnnxParser version components #define ONNX_PARSER_MAJOR_VERSION 24 -#define ONNX_PARSER_MINOR_VERSION 5 +#define ONNX_PARSER_MINOR_VERSION 6 #define ONNX_PARSER_PATCH_VERSION 0 /// ONNX_PARSER_VERSION: "X.Y.Z" diff --git a/include/armnnTfLiteParser/Version.hpp b/include/armnnTfLiteParser/Version.hpp index 5db527ec8c..43fa436454 100644 --- a/include/armnnTfLiteParser/Version.hpp +++ b/include/armnnTfLiteParser/Version.hpp @@ -14,7 +14,7 @@ namespace armnnTfLiteParser // TfLiteParser version components #define TFLITE_PARSER_MAJOR_VERSION 24 -#define TFLITE_PARSER_MINOR_VERSION 5 +#define TFLITE_PARSER_MINOR_VERSION 6 #define TFLITE_PARSER_PATCH_VERSION 0 /// TFLITE_PARSER_VERSION: "X.Y.Z" diff --git a/python/pyarmnn/README.md b/python/pyarmnn/README.md index 547a868316..5e8ceb4b67 100644 --- a/python/pyarmnn/README.md +++ b/python/pyarmnn/README.md @@ -91,14 +91,14 @@ This step will put all generated files under `./src/pyarmnn/_generated` folder a ```bash $ python setup.py sdist ``` -As the result you will get `./dist/pyarmnn-30.0.0.tar.gz` file. As you can see it is platform independent. +As the result you will get `./dist/pyarmnn-31.0.0.tar.gz` file. As you can see it is platform independent. ##### 5. Build the binary package ```bash $ python setup.py bdist_wheel ``` -As the result you will get something like `./dist/pyarmnn-30.0.0-cp36-cp36m-linux_x86_64.whl` file. As you can see it +As the result you will get something like `./dist/pyarmnn-31.0.0-cp36-cp36m-linux_x86_64.whl` file. As you can see it is platform dependent. # PyArmNN installation @@ -107,8 +107,8 @@ PyArmNN can be distributed as a source package or a binary package (wheel). 
Binary package is platform dependent, the name of the package will indicate the platform it was built for, e.g.: -* Linux x86 64bit machine: pyarmnn-30.0.0-cp36-cp36m-*linux_x86_64*.whl -* Linux Aarch 64 bit machine: pyarmnn-30.0.0-cp36-cp36m-*linux_aarch64*.whl +* Linux x86 64bit machine: pyarmnn-31.0.0-cp36-cp36m-*linux_x86_64*.whl +* Linux Aarch 64 bit machine: pyarmnn-31.0.0-cp36-cp36m-*linux_aarch64*.whl The source package is platform independent but installation involves compilation of Arm NN python extension. You will need to have g++ compatible with C++ 14 standard and a python development library installed on the build machine. @@ -126,7 +126,7 @@ $ gcc --print-search-dirs ``` Install PyArmNN from binary by pointing to the wheel file: ```bash -$ pip install /path/to/pyarmnn-30.0.0-cp36-cp36m-linux_aarch64.whl +$ pip install /path/to/pyarmnn-31.0.0-cp36-cp36m-linux_aarch64.whl ``` ## Installing from source package @@ -145,7 +145,7 @@ $ export ARMNN_INCLUDE=/full/path/to/armnn/include:/full/path/to/armnn/profilin Install PyArmNN as follows: ```bash -$ pip install /path/to/pyarmnn-30.0.0.tar.gz +$ pip install /path/to/pyarmnn-31.0.0.tar.gz ``` If PyArmNN installation script fails to find Arm NN libraries it will raise an error like this @@ -159,7 +159,7 @@ $ pip show pyarmnn You can also verify it by running the following and getting output similar to below: ```bash $ python -c "import pyarmnn as ann;print(ann.GetVersion())" -'30.0.0' +'31.0.0' ``` # PyArmNN API overview diff --git a/python/pyarmnn/examples/image_classification/README.md b/python/pyarmnn/examples/image_classification/README.md index a360f01148..04718e2bf4 100644 --- a/python/pyarmnn/examples/image_classification/README.md +++ b/python/pyarmnn/examples/image_classification/README.md @@ -20,7 +20,7 @@ $ pip show pyarmnn You can also verify it by running the following and getting output similar to below: ```bash $ python -c "import pyarmnn as ann;print(ann.GetVersion())" -'30.0.0' +'31.0.0' ``` ##### Dependencies diff --git a/python/pyarmnn/examples/keyword_spotting/README.md b/python/pyarmnn/examples/keyword_spotting/README.md index 1c1deafb26..98158e6c03 100644 --- a/python/pyarmnn/examples/keyword_spotting/README.md +++ b/python/pyarmnn/examples/keyword_spotting/README.md @@ -18,7 +18,7 @@ You can also verify it by running the following and getting output similar to be ```bash $ python -c "import pyarmnn as ann;print(ann.GetVersion())" -'30.0.0' +'31.0.0' ``` ### Dependencies diff --git a/python/pyarmnn/examples/object_detection/README.md b/python/pyarmnn/examples/object_detection/README.md index 215cf772a2..73bafb6038 100644 --- a/python/pyarmnn/examples/object_detection/README.md +++ b/python/pyarmnn/examples/object_detection/README.md @@ -54,7 +54,7 @@ $ pip show pyarmnn You can also verify it by running the following and getting output similar to below: ```bash $ python -c "import pyarmnn as ann;print(ann.GetVersion())" -'30.0.0' +'31.0.0' ``` ##### Dependencies diff --git a/python/pyarmnn/examples/speech_recognition/README.md b/python/pyarmnn/examples/speech_recognition/README.md index d5fee8a010..e442aad591 100644 --- a/python/pyarmnn/examples/speech_recognition/README.md +++ b/python/pyarmnn/examples/speech_recognition/README.md @@ -18,7 +18,7 @@ You can also verify it by running the following and getting output similar to be ```bash $ python -c "import pyarmnn as ann;print(ann.GetVersion())" -'30.0.0' +'31.0.0' ``` ### Dependencies diff --git a/python/pyarmnn/src/pyarmnn/_version.py 
b/python/pyarmnn/src/pyarmnn/_version.py index d1b1ca290c..d68a893e9c 100644 --- a/python/pyarmnn/src/pyarmnn/_version.py +++ b/python/pyarmnn/src/pyarmnn/_version.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: MIT import os -version_info = (30, 0, 0) +version_info = (31, 0, 0) __dev_version_env = os.getenv("PYARMNN_DEV_VER", "") @@ -24,7 +24,7 @@ def check_armnn_version(installed_armnn_version: str, expected_armnn_version: st """Compares expected Arm NN version and Arm NN version used to build the package. Args: - installed_armnn_version (str): Arm NN version used to generate the package (e.g. 30.0.0) + installed_armnn_version (str): Arm NN version used to generate the package (e.g. 31.0.0) expected_armnn_version (str): Expected Arm NN version Returns: diff --git a/python/pyarmnn/test/test_setup.py b/python/pyarmnn/test/test_setup.py index 27feda2647..ada96ccad4 100644 --- a/python/pyarmnn/test/test_setup.py +++ b/python/pyarmnn/test/test_setup.py @@ -87,15 +87,15 @@ def test_gcc_serch_path(): def test_armnn_version(): - check_armnn_version('30.0.0', '30.0.0') + check_armnn_version('31.0.0', '31.0.0') def test_incorrect_armnn_version(): with pytest.raises(AssertionError) as err: - check_armnn_version('30.0.0', '30.1.0') + check_armnn_version('31.0.0', '31.1.0') - assert 'Expected ArmNN version is 30.1.0 but installed ArmNN version is 30.0.0' in str(err.value) + assert 'Expected ArmNN version is 31.1.0 but installed ArmNN version is 31.0.0' in str(err.value) def test_armnn_version_patch_does_not_matter(): - check_armnn_version('30.0.0', '30.0.1') + check_armnn_version('31.0.0', '31.0.1') diff --git a/python/pyarmnn/test/test_version.py b/python/pyarmnn/test/test_version.py index 83606ab15b..f68adff0c7 100644 --- a/python/pyarmnn/test/test_version.py +++ b/python/pyarmnn/test/test_version.py @@ -18,7 +18,7 @@ def test_dev_version(): importlib.reload(v) - assert "30.0.0.dev1" == v.__version__ + assert "31.0.0.dev1" == v.__version__ del os.environ["PYARMNN_DEV_VER"] del v @@ -30,7 +30,7 @@ def test_arm_version_not_affected(): importlib.reload(v) - assert "30.0.0" == v.__arm_ml_version__ + assert "31.0.0" == v.__arm_ml_version__ del os.environ["PYARMNN_DEV_VER"] del v diff --git a/samples/ObjectDetection/Readme.md b/samples/ObjectDetection/Readme.md index bd84e26001..169546e038 100644 --- a/samples/ObjectDetection/Readme.md +++ b/samples/ObjectDetection/Readme.md @@ -253,8 +253,8 @@ From the build directory, copy the following to the host platform: The full list of libs after cross-compilation to copy on your board: ``` libarmnn.so -libarmnn.so.30 -libarmnn.so.30.0 +libarmnn.so.31 +libarmnn.so.31.0 For Arm NN public C++ API mode: libarmnnTfLiteParser.so libarmnnTfLiteParser.so.24.4 diff --git a/src/armnn/Descriptors.cpp b/src/armnn/Descriptors.cpp index f9576271d5..226d121edc 100644 --- a/src/armnn/Descriptors.cpp +++ b/src/armnn/Descriptors.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "armnn/Descriptors.hpp" @@ -461,80 +461,79 @@ BatchMatMulDescriptor::GetAxesToMul( const TensorShape& tensorXShape, const TensorShape& tensorYShape) { - // May refactor to just work on one input per call - makes it less confusing and also - // allows more flexibility (i.e. 
in Layer output shape inference)
-
-    auto xNumDims = tensorXShape.GetNumDimensions();
-    auto yNumDims = tensorYShape.GetNumDimensions();
-
-    std::pair<unsigned int, unsigned int> xAxes = { xNumDims-2, xNumDims-1 };
-    std::pair<unsigned int, unsigned int> yAxes = { yNumDims-2, yNumDims-1 };
-
-    if(desc.m_DataLayoutX.has_value())
-    {
-        switch(desc.m_DataLayoutX.value())
-        {
-            case DataLayout::NDHWC:
-            case DataLayout::NHWC:
-                xAxes.first -= 1;
-                xAxes.second -= 1;
-                break;
-            case DataLayout::NCDHW:
-            case DataLayout::NCHW:
-            default:
-                break;
-        }
-    }
-
-    if(desc.m_DataLayoutY.has_value())
-    {
-        switch(desc.m_DataLayoutY.value())
-        {
-            case DataLayout::NDHWC:
-            case DataLayout::NHWC:
-                yAxes.first -= 1;
-                yAxes.second -= 1;
-                break;
-            case DataLayout::NCDHW:
-            case DataLayout::NCHW:
-            default:
-                break;
-        }
-    }
-
-    return { xAxes, yAxes};
+    return { GetAxesToMul(desc.m_DataLayoutX, tensorXShape),
+             GetAxesToMul(desc.m_DataLayoutY, tensorYShape) };
 }
-
 std::pair<std::vector<unsigned int>, std::vector<unsigned int>> BatchMatMulDescriptor::GetAxesNotMul(
     const BatchMatMulDescriptor& desc,
     const TensorShape& inputXShape,
     const TensorShape& inputYShape)
 {
-    // May refactor to just work on one input per call - makes it less confusing and also
-    // allows more flexibility (i.e. in Layer output shape inference)
-
-    auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(desc, inputXShape, inputYShape);
+    return { GetAxesNotMul(desc.m_DataLayoutX, inputXShape),
+             GetAxesNotMul(desc.m_DataLayoutY, inputYShape) };
+}

-    std::vector<unsigned int> axesXNotMul;
-    std::vector<unsigned int> axesYNotMul;
+std::pair<unsigned int, unsigned int> BatchMatMulDescriptor::GetAxesToMul(
+    DataLayout dataLayout,
+    const TensorShape& tensorShape)
+{
+    auto numDims = tensorShape.GetNumDimensions();
+    std::pair<unsigned int, unsigned int> axes = { numDims-2, numDims-1 };
+    switch(dataLayout)
+    {
+        case DataLayout::NDHWC:
+        case DataLayout::NHWC:
+            axes.first -= 1;
+            axes.second -= 1;
+            break;
+        case DataLayout::NCDHW:
+        case DataLayout::NCHW:
+        default:
+            break;
+    }
+    return axes;
+}

-    for(unsigned int i = 0; i < inputXShape.GetNumDimensions(); i++)
+std::vector<unsigned int> BatchMatMulDescriptor::GetAxesNotMul(
+    DataLayout dataLayout,
+    const TensorShape& tensorShape)
+{
+    auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(dataLayout, tensorShape);
+    std::vector<unsigned int> axesNotMul;
+    for(unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
     {
-        if(i == axesToMul.first.first || i == axesToMul.first.second)
+        if(i == axesToMul.first || i == axesToMul.second)
         {
             continue;
         }
-        axesXNotMul.push_back(i);
+        axesNotMul.push_back(i);
     }
-    for(unsigned int i = 0; i < inputYShape.GetNumDimensions(); i++)
+    return axesNotMul;
+}
+
+PermutationVector BatchMatMulDescriptor::GetPermuteVec(
+    DataLayout dataLayout,
+    const TensorShape& tensorShape)
+{
+    std::vector<unsigned int> vec;
+    auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(dataLayout, tensorShape);
+    for(unsigned int i = 0; i < tensorShape.GetNumDimensions(); i++)
     {
-        if(i == axesToMul.second.first || i == axesToMul.second.second)
+        if(i == axesToMul.first)
         {
-            continue;
+            vec.push_back(i+1);
+        }
+        else if(i == axesToMul.second)
+        {
+            vec.push_back(i-1);
+        }
+        else
+        {
+            vec.push_back(i);
         }
-        axesYNotMul.push_back(i);
     }
-
-    return { axesXNotMul, axesYNotMul };
+    return PermutationVector(vec.data(),
+                             static_cast<unsigned int>(vec.size()));
 }
 }
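To make the two new helpers concrete: for a rank-4 tensor in the default (NCHW/arbitrary) layout, GetAxesToMul picks the last two axes, and GetPermuteVec swaps exactly those two. A small illustrative check, not taken from the patch's tests:

```cpp
#include <armnn/Descriptors.hpp>
#include <armnn/Tensor.hpp>
#include <cassert>

void BatchMatMulHelperExample()
{
    armnn::TensorShape shape({2, 2, 3, 4});

    // Default layout: multiply over the last two axes, i.e. axes 2 and 3.
    auto axes = armnn::BatchMatMulDescriptor::GetAxesToMul(armnn::DataLayout::NCHW, shape);
    assert(axes.first == 2 && axes.second == 3);

    // The permutation maps axis 2 to 3 and axis 3 to 2, leaving the batch axes
    // alone, i.e. {0, 1, 3, 2} - a transpose of each innermost slice.
    armnn::PermutationVector perm =
        armnn::BatchMatMulDescriptor::GetPermuteVec(armnn::DataLayout::NCHW, shape);
    assert(perm[2] == 3 && perm[3] == 2);
}
```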
diff --git a/src/armnn/layers/BatchMatMulLayer.cpp b/src/armnn/layers/BatchMatMulLayer.cpp
index 501de2d091..acd089aef8 100644
--- a/src/armnn/layers/BatchMatMulLayer.cpp
+++ b/src/armnn/layers/BatchMatMulLayer.cpp
@@ -5,6 +5,7 @@
 #include "BatchMatMulLayer.hpp"
 
 #include <armnn/backends/WorkloadFactory.hpp>
+#include <armnnUtils/Permute.hpp>
 #include "layers/LayerCloneBase.hpp"
 
 namespace armnn
@@ -36,12 +37,24 @@ std::vector<TensorShape> BatchMatMulLayer::InferOutputShapes(const std::vector<
     TensorShape inputXShape = inputShapes[0];
     TensorShape inputYShape = inputShapes[1];
 
+    // Pre-permute the input shapes when their transpose parameters are set
+    if (m_Param.m_TransposeX)
+    {
+        auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(m_Param.m_DataLayoutX, inputXShape);
+        inputXShape = armnnUtils::Permuted(inputXShape, permuteVec);
+    }
+    if (m_Param.m_TransposeY)
+    {
+        auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(m_Param.m_DataLayoutY, inputYShape);
+        inputYShape = armnnUtils::Permuted(inputYShape, permuteVec);
+    }
+
     TensorShape& longerInput = inputXShape.GetNumDimensions() >= inputYShape.GetNumDimensions()?
-                               inputXShape:inputYShape;
+                               inputXShape : inputYShape;
     TensorShape& shorterInput = inputXShape.GetNumDimensions() >= inputYShape.GetNumDimensions()?
-                                inputYShape:inputXShape;
+                                inputYShape : inputXShape;
 
     unsigned int inputNumDimsOffset = longerInput.GetNumDimensions() - shorterInput.GetNumDimensions();
@@ -49,10 +62,10 @@ std::vector<TensorShape> BatchMatMulLayer::InferOutputShapes(const std::vector<
     std::vector<unsigned int> tensorDimensions(outputNumDimensions, 0);
 
-    auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Param, inputXShape, inputYShape);
-    const auto& longerAxesToMul = (axesToMul.first.first >= axesToMul.second.first &&
-                                   axesToMul.first.second >= axesToMul.second.second) ?
-                                  axesToMul.first : axesToMul.second;
+    const auto& longerInputDataLayout = inputXShape.GetNumDimensions() >= inputYShape.GetNumDimensions()?
+                                        m_Param.m_DataLayoutX : m_Param.m_DataLayoutY;
+    auto longerAxesToMul = BatchMatMulDescriptor::GetAxesToMul(longerInputDataLayout,
+                                                               longerInput);
 
     for (unsigned int i = 0; i < outputNumDimensions; ++i)
     {
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 9a4c60f551..f4afbd9a84 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -1,5 +1,5 @@
 //
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
 
@@ -8,6 +8,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 
@@ -4154,9 +4155,10 @@ void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
 
     // For inputs X and Y whose dimensions to be multiplied are (M,N) and (I,J) respectively,
     // axes N and I must be the same size
 
-    const auto& inputTensorXInfo = workloadInfo.m_InputTensorInfos[0];
-    const auto& inputTensorYInfo = workloadInfo.m_InputTensorInfos[1];
-    const auto& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
+    const auto& inputXInfoBeforeParams = workloadInfo.m_InputTensorInfos[0];
+    const auto& inputYInfoBeforeParams = workloadInfo.m_InputTensorInfos[1];
+    const auto& outputInfo = workloadInfo.m_OutputTensorInfos[0];
+    // Output info has already been inferred
 
     std::vector<DataType> supportedTypes =
     {
         DataType::BFloat16,
         DataType::Float16,
         DataType::Float32,
         DataType::QAsymmS8,
         DataType::QAsymmU8,
         DataType::QSymmS16
     };
 
-    ValidateDataTypes(inputTensorXInfo, supportedTypes, descriptorName);
-    ValidateDataTypes(inputTensorYInfo, supportedTypes, descriptorName);
-    ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
+    ValidateDataTypes(inputXInfoBeforeParams, supportedTypes, descriptorName);
+    ValidateDataTypes(inputYInfoBeforeParams, supportedTypes, descriptorName);
+    ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
 
-    if ((inputTensorXInfo.GetNumDimensions() < 2) ||
-        (inputTensorYInfo.GetNumDimensions() < 2))
+    if ((inputXInfoBeforeParams.GetNumDimensions() < 2) ||
+        (inputYInfoBeforeParams.GetNumDimensions() < 2))
     {
         throw InvalidArgumentException(descriptorName + ": Input tensors are not 2D or greater.");
     }
 
-    if(m_Parameters.m_DataLayoutX.has_value())
+    TensorInfo inputXInfoAfterParams;
+    TensorInfo inputYInfoAfterParams;
+
+    if((m_Parameters.m_TransposeX && m_Parameters.m_AdjointX) ||
+       (m_Parameters.m_TransposeY && m_Parameters.m_AdjointY))
+    {
+        throw InvalidArgumentException(descriptorName +
+            ": Invalid descriptor parameters - Transpose and Adjoint "
+            "cannot 
both be true for a given input tensor."); + } + if(m_Parameters.m_TransposeX) + { + inputXInfoAfterParams = armnnUtils::Permuted(inputXInfoBeforeParams, + BatchMatMulDescriptor::GetPermuteVec( + m_Parameters.m_DataLayoutX, + inputXInfoBeforeParams.GetShape())); + } + else if(m_Parameters.m_AdjointX) { - switch(m_Parameters.m_DataLayoutX.value()) + auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX, + inputXInfoBeforeParams.GetShape()); + if(inputXInfoBeforeParams.GetShape()[axesToMul.first] != + inputXInfoBeforeParams.GetShape()[axesToMul.second]) { - case DataLayout::NCHW: - case DataLayout::NHWC: - if(inputTensorXInfo.GetNumDimensions() != 4) - { - throw InvalidArgumentException(descriptorName + - ": Input tensor X does not have the correct " - "number of dimensions for the Data Layout that it has been assigned."); - } - break; - case DataLayout::NCDHW: - case DataLayout::NDHWC: - if(inputTensorXInfo.GetNumDimensions() != 5) - { - throw InvalidArgumentException(descriptorName + - ": Input tensor X does not have the correct " - "number of dimensions for the Data Layout that it has been assigned."); - } - break; - default: - break; + throw InvalidArgumentException(descriptorName + + ": Adjoint is set to true for input tensor X, but the axes to be adjointed are not square." ); } + // Shape remains the same as it's square + inputXInfoAfterParams = inputXInfoBeforeParams; + } + else + { + inputXInfoAfterParams = inputXInfoBeforeParams; } - if(m_Parameters.m_DataLayoutY.has_value()) + if(m_Parameters.m_TransposeY) { - switch(m_Parameters.m_DataLayoutY.value()) + inputYInfoAfterParams = armnnUtils::Permuted(inputYInfoBeforeParams, + BatchMatMulDescriptor::GetPermuteVec( + m_Parameters.m_DataLayoutY, + inputYInfoBeforeParams.GetShape())); + } + else if(m_Parameters.m_AdjointY) + { + auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY, + inputYInfoBeforeParams.GetShape()); + if(inputYInfoBeforeParams.GetShape()[axesToMul.first] != + inputYInfoBeforeParams.GetShape()[axesToMul.second]) { - case DataLayout::NCHW: - case DataLayout::NHWC: - if(inputTensorYInfo.GetNumDimensions() != 4) - { - throw InvalidArgumentException(descriptorName + - ": Input tensor Y does not have the correct " - "number of dimensions for the Data Layout that it has been assigned."); - } - break; - case DataLayout::NCDHW: - case DataLayout::NDHWC: - if(inputTensorYInfo.GetNumDimensions() != 5) - { - throw InvalidArgumentException(descriptorName + - ": Input tensor Y does not have the correct " - "number of dimensions for the Data Layout that it has been assigned."); - } - break; - default: - break; + throw InvalidArgumentException(descriptorName + + ": Adjoint is set to true for input tensor Y, but the axes to be adjointed are not square." 
);
+        }
+        // Shape remains the same as it's square
+        inputYInfoAfterParams = inputYInfoBeforeParams;
+    }
+    else
+    {
+        inputYInfoAfterParams = inputYInfoBeforeParams;
+    }
+
+    switch(m_Parameters.m_DataLayoutX)
+    {
+        case DataLayout::NCDHW:
+        case DataLayout::NDHWC:
+            if(inputXInfoAfterParams.GetNumDimensions() < 3)
+            {
+                throw InvalidArgumentException(descriptorName +
+                    ": Input tensor X does not have the correct "
+                    "number of dimensions for the Data Layout that it has been assigned.");
+            }
+            break;
+        case DataLayout::NCHW:
+        case DataLayout::NHWC:
+        default:
+            break;
+    }
+
+    switch(m_Parameters.m_DataLayoutY)
+    {
+        case DataLayout::NCDHW:
+        case DataLayout::NDHWC:
+            if(inputYInfoAfterParams.GetNumDimensions() < 3)
+            {
+                throw InvalidArgumentException(descriptorName +
+                    ": Input tensor Y does not have the correct "
+                    "number of dimensions for the Data Layout that it has been assigned.");
+            }
+            break;
+        case DataLayout::NCHW:
+        case DataLayout::NHWC:
+        default:
+            break;
+    }
 
-    auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters,
-                                                         inputTensorXInfo.GetShape(),
-                                                         inputTensorYInfo.GetShape());
+    auto axesXToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutX,
+                                                          inputXInfoAfterParams.GetShape());
+    auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters.m_DataLayoutY,
+                                                          inputYInfoAfterParams.GetShape());
 
-    if(inputTensorXInfo.GetShape()[axesToMul.first.second]
-       != inputTensorYInfo.GetShape()[axesToMul.second.first])
+    if(inputXInfoAfterParams.GetShape()[axesXToMul.second]
+       != inputYInfoAfterParams.GetShape()[axesYToMul.first])
     {
         throw InvalidArgumentException(descriptorName +
             ": The final axis of input tensor X must be the same size as "
             "the second last axis of input tensor Y.");
     }
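A worked example of the compatibility rule just enforced, with illustrative shapes (not taken from the patch's tests):

```cpp
// X: {5, 2, 3}  -> slices are 2x3 matrices, so axesXToMul = {1, 2}
// Y: {3, 4}     -> a single 3x4 matrix, so axesYToMul = {0, 1}
// Valid because X[axesXToMul.second] == Y[axesYToMul.first] (3 == 3);
// Y is broadcast across X's leading batch axis, giving output shape {5, 2, 4}.
armnn::TensorShape x({5, 2, 3});
armnn::TensorShape y({3, 4});
```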
-    auto axesNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters,
-                                                           inputTensorXInfo.GetShape(),
-                                                           inputTensorYInfo.GetShape());
-
     { // Separate scope so we don't pollute the rest of the scope with our temp variables
         // e.g. NHWC isn't compatible with NCHW as of now
-        DataLayout xLayout;
-        DataLayout yLayout;
-
-        if(m_Parameters.m_DataLayoutX == EmptyOptional())
-        {
-            xLayout = DataLayout::NCHW; // Not equivalent - I'm just concerned with the last 2 axes
-        }
-        else
-        {
-            xLayout = m_Parameters.m_DataLayoutX.value();
-        }
-
-        if(m_Parameters.m_DataLayoutY == EmptyOptional())
-        {
-            yLayout = DataLayout::NCHW;
-        }
-        else
-        {
-            yLayout = m_Parameters.m_DataLayoutY.value();
-        }
+        DataLayout xLayout = m_Parameters.m_DataLayoutX;
+        DataLayout yLayout = m_Parameters.m_DataLayoutY;
 
         if(xLayout == DataLayout::NCHW || xLayout == DataLayout::NCDHW)
         {
@@ -4290,8 +4311,8 @@ void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
     }
 
     // Simulate aligning the ends of the matrix dims and prepending 1's to the beginning of the shorter one
-    unsigned int outputTensorDimSize = std::max(inputTensorXInfo.GetNumDimensions(),
-                                                inputTensorYInfo.GetNumDimensions());
+    unsigned int outputTensorDimSize = std::max(inputXInfoAfterParams.GetNumDimensions(),
+                                                inputYInfoAfterParams.GetNumDimensions());
     if(outputTensorDimSize-2 > 0)
     {
         TensorInfo tiXNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
@@ -4312,12 +4333,17 @@ void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
 
         for(unsigned int i = 0; i < ti.GetNumDimensions(); i++)
         {
-            ti.GetShape()[i] = inputTensorXInfo.GetShape()[i];
+            ti.GetShape()[i] = inputXInfoAfterParams.GetShape()[i];
         }
     };
 
-    doAxisExtension(axesNotMul.first, tiXNotMul);
-    doAxisExtension(axesNotMul.second, tiYNotMul);
+    auto axesXNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutX,
+                                                            inputXInfoAfterParams.GetShape());
+    auto axesYNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters.m_DataLayoutY,
+                                                            inputYInfoAfterParams.GetShape());
+
+    doAxisExtension(axesXNotMul, tiXNotMul);
+    doAxisExtension(axesYNotMul, tiYNotMul);
 
     for(unsigned int i = 0; i < tiOutNotMul.GetNumDimensions(); i++)
     {
@@ -4332,42 +4358,6 @@ void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
                            "input_X",
                            "input_Y");
     }
-
-    // Also check descriptor parameter validity
-    // This will eventually be moved to the start of the function as explained below
-    if ((!m_Parameters.m_TransposeX.empty() && !m_Parameters.m_AdjointX.empty()) ||
-        (!m_Parameters.m_TransposeY.empty() && !m_Parameters.m_AdjointY.empty()))
-    {
-        throw InvalidArgumentException(descriptorName +
-            ": Invalid descriptor parameters - Transpose and Adjoint "
-            "vectors cannot both be true for a given input tensor.");
-    }
-
-    if(m_Parameters.m_TransposeX.size() != 0 && m_Parameters.m_TransposeX.size() != inputTensorXInfo.GetNumDimensions())
-    {
-        throw InvalidArgumentException(descriptorName +
-            ": Invalid descriptor parameter - Transpose X vector must be "
-            "the same size as tensor input X's dimensionality.");
-    }
-    if(m_Parameters.m_AdjointX.size() != 0 && m_Parameters.m_AdjointX.size() != inputTensorXInfo.GetNumDimensions())
-    {
-        throw InvalidArgumentException(descriptorName +
-            ": Invalid descriptor parameter - Adjoint X vector must be "
-            "the same size as tensor input X's dimensionality.");
-    }
-    if(m_Parameters.m_TransposeY.size() != 0 && m_Parameters.m_TransposeY.size() != inputTensorYInfo.GetNumDimensions())
-    {
-        throw InvalidArgumentException(descriptorName +
-            ": Invalid descriptor parameter - Transpose Y vector must be "
-            "the same size as tensor input Y's dimensionality.");
-    }
-    if(m_Parameters.m_AdjointY.size() != 0 && m_Parameters.m_AdjointY.size() != 
inputTensorXInfo.GetNumDimensions()) - { - throw InvalidArgumentException(descriptorName + - ": Invalid descriptor parameter - Adjoint Y vector must be " - "the same size as tensor input Y's dimensionality."); - } - // Note: for adjoint/transpose, you'll need to do the validation atop the resultant permutation. } diff --git a/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp index 41add6e6da..6fcc35ab52 100644 --- a/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp +++ b/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp @@ -191,7 +191,7 @@ LayerTestResult BatchMatMul3DSimpleTest( std::vector outputExpected = armnnUtils::QuantizedVector({ 19, 22, 43, 50 - },qScale, qOffset); + }, qScale, qOffset); return BatchMatMulTestImpl(workloadFactory, memoryManager, @@ -247,9 +247,7 @@ LayerTestResult BatchMatMulNCHWSimpleTest( const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, const armnn::ITensorHandleFactory& tensorHandleFactory) { - auto descriptor = armnn::BatchMatMulDescriptor( - armnn::Optional(armnn::DataLayout::NCHW), - armnn::Optional(armnn::DataLayout::NCHW)); + auto descriptor = armnn::BatchMatMulDescriptor(); // Default arbitrary layout is treated the same as NCHW float qScale = 0.0f; int32_t qOffset = 0; @@ -282,7 +280,7 @@ LayerTestResult BatchMatMulNCHWSimpleTest( std::vector outputExpected = armnnUtils::QuantizedVector({ 19, 22, 43, 50 - },qScale, qOffset); + }, qScale, qOffset); return BatchMatMulTestImpl(workloadFactory, memoryManager, @@ -338,9 +336,12 @@ LayerTestResult BatchMatMulNHWCSimpleTest( const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, const armnn::ITensorHandleFactory& tensorHandleFactory) { - auto descriptor = armnn::BatchMatMulDescriptor( - armnn::Optional(armnn::DataLayout::NHWC), - armnn::Optional(armnn::DataLayout::NHWC)); + auto descriptor = armnn::BatchMatMulDescriptor(false, + false, + false, + false, + armnn::DataLayout::NHWC, + armnn::DataLayout::NHWC); float qScale = 0.0f; int32_t qOffset = 0; @@ -373,7 +374,7 @@ LayerTestResult BatchMatMulNHWCSimpleTest( std::vector outputExpected = armnnUtils::QuantizedVector({ 19, 22, 43, 50 - },qScale, qOffset); + }, qScale, qOffset); return BatchMatMulTestImpl(workloadFactory, memoryManager, @@ -471,7 +472,7 @@ LayerTestResult BatchMatMul3DBatchTest( 267, 286, 323, 346 - },qScale, qOffset); + }, qScale, qOffset); return BatchMatMulTestImpl(workloadFactory, memoryManager, @@ -566,7 +567,7 @@ LayerTestResult BatchMatMul3DBroadcastTest( 267, 286, 323, 346 - },qScale, qOffset); + }, qScale, qOffset); return BatchMatMulTestImpl(workloadFactory, memoryManager, @@ -661,7 +662,7 @@ LayerTestResult BatchMatMul3D2DBroadcastTest( 267, 286, 323, 346 - },qScale, qOffset); + }, qScale, qOffset); return BatchMatMulTestImpl(workloadFactory, memoryManager, @@ -717,9 +718,12 @@ LayerTestResult BatchMatMulNDHWCNHWCTest( const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, const armnn::ITensorHandleFactory& tensorHandleFactory) { - auto descriptor = armnn::BatchMatMulDescriptor( - armnn::Optional(armnn::DataLayout::NDHWC), - armnn::Optional(armnn::DataLayout::NHWC)); + auto descriptor = armnn::BatchMatMulDescriptor(false, + false, + false, + false, + armnn::DataLayout::NDHWC, + armnn::DataLayout::NHWC); float qScale = 0.0f; int32_t qOffset = 0; @@ -761,7 +765,7 @@ LayerTestResult BatchMatMulNDHWCNHWCTest( 34, 1079, 46, 1167 - },qScale, qOffset); + }, qScale, 
qOffset); return BatchMatMulTestImpl(workloadFactory, memoryManager, @@ -959,7 +963,7 @@ LayerTestResult BatchMatMul3DNonSquareTest( 88, 100, 142, 106, 39, 61, 78, 56, 72, 52, 98, 70 - },qScale, qOffset); + }, qScale, qOffset); return BatchMatMulTestImpl(workloadFactory, memoryManager, @@ -1005,6 +1009,332 @@ BatchMatMul3DNonSquareTest( template LayerTestResult, 3> BatchMatMul3DNonSquareTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory); + +template +LayerTestResult BatchMatMul2DTranspSimpleTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory) +{ + auto descriptor = armnn::BatchMatMulDescriptor(true, + false, + false, + false); + + float qScale = 0.0f; + int32_t qOffset = 0; + + switch(ArmnnType) + { + case armnn::DataType::QAsymmS8: + case armnn::DataType::QAsymmU8: + case armnn::DataType::QSymmS16: + qScale = 1.0f; + break; + default: + break; + } + + armnn::TensorInfo inputXInfo({2,3}, ArmnnType, qScale, qOffset); + armnn::TensorInfo inputYInfo({2,3}, ArmnnType, qScale, qOffset); + armnn::TensorInfo outputInfo({3,3}, ArmnnType, qScale, qOffset); + + std::vector inputX = armnnUtils::QuantizedVector({ + 1, 2, 3, + 4, 5, 6 + }, qScale, qOffset); + + std::vector inputY = armnnUtils::QuantizedVector({ + 7, 8, 9, + 10, 11, 12 + }, qScale, qOffset); + + std::vector outputExpected = armnnUtils::QuantizedVector({ + 47, 52, 57, + 64, 71, 78, + 81, 90, 99 + }, qScale, qOffset); + + return BatchMatMulTestImpl(workloadFactory, + memoryManager, + tensorHandleFactory, + descriptor, + inputX, + inputY, + outputExpected, + inputXInfo, + inputYInfo, + outputInfo); +} + +template LayerTestResult, 2> +BatchMatMul2DTranspSimpleTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory); + +template LayerTestResult, 2> +BatchMatMul2DTranspSimpleTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory); + +template LayerTestResult, 2> +BatchMatMul2DTranspSimpleTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory); + +template LayerTestResult, 2> +BatchMatMul2DTranspSimpleTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory); + +template LayerTestResult, 2> +BatchMatMul2DTranspSimpleTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory); + +template LayerTestResult, 2> +BatchMatMul2DTranspSimpleTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory); + +template +LayerTestResult BatchMatMul2DAdjointSimpleTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory) +{ + auto descriptor = 
armnn::BatchMatMulDescriptor(false, + false, + true, + false); + + float qScale = 0.0f; + int32_t qOffset = 0; + + switch(ArmnnType) + { + case armnn::DataType::QAsymmS8: + case armnn::DataType::QAsymmU8: + case armnn::DataType::QSymmS16: + qScale = 1.0f; + break; + default: + break; + } + + armnn::TensorInfo inputXInfo({3,3}, ArmnnType, qScale, qOffset); + armnn::TensorInfo inputYInfo({3,3}, ArmnnType, qScale, qOffset); + armnn::TensorInfo outputInfo({3,3}, ArmnnType, qScale, qOffset); + + std::vector inputX = armnnUtils::QuantizedVector({ + 3, 1, 1, + 1, 3, -1, + 2, 4, 1 + }, qScale, qOffset); + + std::vector inputY = armnnUtils::QuantizedVector({ + 1, 0, 0, + 0, 1, 0, + 0, 0, 1 + }, qScale, qOffset); + + std::vector outputExpected = armnnUtils::QuantizedVector({ + 7, 3, -4, + -3, 1, 4, + -2, -10, 8 + }, qScale, qOffset); + + switch (ArmnnType) + { + case armnn::DataType::QAsymmU8: + outputExpected = armnnUtils::QuantizedVector({ + 3, 3, 0, + 0, 1, 1, + 0, 0, 8 + }, qScale, qOffset); + break; + default: + break; + } + + return BatchMatMulTestImpl(workloadFactory, + memoryManager, + tensorHandleFactory, + descriptor, + inputX, + inputY, + outputExpected, + inputXInfo, + inputYInfo, + outputInfo); +} + +template LayerTestResult, 2> +BatchMatMul2DAdjointSimpleTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory); + +template LayerTestResult, 2> +BatchMatMul2DAdjointSimpleTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory); + +template LayerTestResult, 2> +BatchMatMul2DAdjointSimpleTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory); + +template LayerTestResult, 2> +BatchMatMul2DAdjointSimpleTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory); + +template LayerTestResult, 2> +BatchMatMul2DAdjointSimpleTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory); + +template LayerTestResult, 2> +BatchMatMul2DAdjointSimpleTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory); + +template +LayerTestResult BatchMatMulNHWCParamsTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory) +{ + auto descriptor = armnn::BatchMatMulDescriptor(false, + true, + true, + false, + armnn::DataLayout::NHWC, + armnn::DataLayout::NHWC); + + float qScale = 0.0f; + int32_t qOffset = 0; + + switch(ArmnnType) + { + case armnn::DataType::QAsymmS8: + case armnn::DataType::QAsymmU8: + case armnn::DataType::QSymmS16: + qScale = 1.0f; + break; + default: + break; + } + + armnn::TensorInfo inputXInfo({1,4,4,2}, ArmnnType, qScale, qOffset); + armnn::TensorInfo inputYInfo({2,2,4,1}, ArmnnType, qScale, qOffset); + armnn::TensorInfo outputInfo({2,4,2,2}, ArmnnType, qScale, qOffset); + + std::vector inputX = armnnUtils::QuantizedVector({ + 1, -3, 1, 4, 4, 9, 
1, 2, + 2, 4, 2, 2, 10, 7, 6, -5, + 3, 8, 9, 9, 21, 1, 17, 7, + 5, 11, 11, 8, 29, 3, 23, 6 + }, qScale, qOffset); + + std::vector inputY = armnnUtils::QuantizedVector({ + 1, 2, 3, 4, + 5, 6, 7, 8, + + 9, 10, 11, 12, + 13, 14, 15, 16 + }, qScale, qOffset); + + std::vector outputExpected = armnnUtils::QuantizedVector({ + 28, 625, 140, 585, + 8, 110, -8, 1662, + -24, 401, -120, 921, + 12, 131, 108, -501, + + 252, 545, 364, 505, + -24, 3214, -40, 4766, + -216, 1441, -312, 1961, + 204, -1133, 300, -1765 + }, qScale, qOffset); + + switch (ArmnnType) + { + case armnn::DataType::QAsymmU8: + outputExpected = armnnUtils::QuantizedVector({ + 28, 80, 140, 80, + 8, 45, 0, 255, + 0, 18, 0, 18, + 12, 0, 108, 0, + + 252, 80, 255, 80, + 0, 255, 0, 255, + 0, 18, 0, 18, + 204, 0, 255, 0 + }, qScale, qOffset); + break; + default: + break; + } + + return BatchMatMulTestImpl(workloadFactory, + memoryManager, + tensorHandleFactory, + descriptor, + inputX, + inputY, + outputExpected, + inputXInfo, + inputYInfo, + outputInfo); +} + +template LayerTestResult, 4> +BatchMatMulNHWCParamsTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory); + +template LayerTestResult, 4> +BatchMatMulNHWCParamsTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory); + +template LayerTestResult, 4> +BatchMatMulNHWCParamsTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory); + +template LayerTestResult, 4> +BatchMatMulNHWCParamsTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory); + +template LayerTestResult, 4> +BatchMatMulNHWCParamsTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory); + +template LayerTestResult, 4> +BatchMatMulNHWCParamsTest( armnn::IWorkloadFactory& workloadFactory, const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, const armnn::ITensorHandleFactory& tensorHandleFactory); \ No newline at end of file diff --git a/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp b/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp index 9e2139667b..0b261fba37 100644 --- a/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp +++ b/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp @@ -80,6 +80,24 @@ LayerTestResult BatchMatMul2DTinyTest( template> LayerTestResult BatchMatMul3DNonSquareTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory); + +template> +LayerTestResult BatchMatMul2DTranspSimpleTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& tensorHandleFactory); + +template> +LayerTestResult BatchMatMul2DAdjointSimpleTest( + armnn::IWorkloadFactory& workloadFactory, + const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager, + const armnn::ITensorHandleFactory& 
    tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 4> BatchMatMulNHWCParamsTest(
+    armnn::IWorkloadFactory& workloadFactory,
+    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+    const armnn::ITensorHandleFactory& tensorHandleFactory);
\ No newline at end of file
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index 593dc7851e..ae40333658 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -1133,6 +1133,27 @@ ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareQAsymmS8, BatchMatMul3DNonSq
 ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareQAsymmU8, BatchMatMul3DNonSquareTest<DataType::QAsymmU8>);
 ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareQASymmS16, BatchMatMul3DNonSquareTest<DataType::QSymmS16>);
 
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTranspSimpleBFloat16, BatchMatMul2DTranspSimpleTest<DataType::BFloat16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTranspSimpleFloat32, BatchMatMul2DTranspSimpleTest<DataType::Float32>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTranspSimpleFloat16, BatchMatMul2DTranspSimpleTest<DataType::Float16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTranspSimpleQAsymmS8, BatchMatMul2DTranspSimpleTest<DataType::QAsymmS8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTranspSimpleQAsymmU8, BatchMatMul2DTranspSimpleTest<DataType::QAsymmU8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTranspSimpleQASymmS16,BatchMatMul2DTranspSimpleTest<DataType::QSymmS16>);
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DAdjointSimpleBFloat16, BatchMatMul2DAdjointSimpleTest<DataType::BFloat16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DAdjointSimpleFloat32, BatchMatMul2DAdjointSimpleTest<DataType::Float32>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DAdjointSimpleFloat16, BatchMatMul2DAdjointSimpleTest<DataType::Float16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DAdjointSimpleQAsymmS8, BatchMatMul2DAdjointSimpleTest<DataType::QAsymmS8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DAdjointSimpleQAsymmU8, BatchMatMul2DAdjointSimpleTest<DataType::QAsymmU8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DAdjointSimpleQASymmS16,BatchMatMul2DAdjointSimpleTest<DataType::QSymmS16>);
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCParamsBFloat16, BatchMatMulNHWCParamsTest<DataType::BFloat16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCParamsFloat32, BatchMatMulNHWCParamsTest<DataType::Float32>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCParamsFloat16, BatchMatMulNHWCParamsTest<DataType::Float16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCParamsQAsymmS8, BatchMatMulNHWCParamsTest<DataType::QAsymmS8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCParamsQAsymmU8, BatchMatMulNHWCParamsTest<DataType::QAsymmU8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCParamsQASymmS16, BatchMatMulNHWCParamsTest<DataType::QSymmS16>);
+
 // Batch Norm
 ARMNN_AUTO_TEST_CASE_WITH_THF(BatchNormFloat32, BatchNormFloat32Test)
 ARMNN_AUTO_TEST_CASE_WITH_THF(BatchNormFloat32Nhwc, BatchNormFloat32NhwcTest)
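The reworked reference implementation below runs the whole operation from its constructor: decode both inputs, apply the pre-transpose/adjoint parameters, then multiply. A hedged sketch of how RefBatchMatMulWorkload is expected to drive it, assuming the usual reference-backend MakeDecoder/MakeEncoder helpers (the pointer variables are placeholders):

```cpp
// All work happens in the BatchMatMul constructor:
// DecodeTensor -> ApplyParams -> ApplyBatchMatMul.
std::unique_ptr<armnn::Decoder<float>> xDec = armnn::MakeDecoder<float>(inputXInfo, inputXPtr);
std::unique_ptr<armnn::Decoder<float>> yDec = armnn::MakeDecoder<float>(inputYInfo, inputYPtr);
std::unique_ptr<armnn::Encoder<float>> outEnc = armnn::MakeEncoder<float>(outputInfo, outputPtr);

armnn::BatchMatMul bmm(descriptor, inputXInfo, inputYInfo, outputInfo,
                       *xDec, *yDec, *outEnc); // runs the whole operation
```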
diff --git a/src/backends/reference/workloads/BatchMatMulImpl.cpp b/src/backends/reference/workloads/BatchMatMulImpl.cpp
index 6693f15760..c592b3b76c 100644
--- a/src/backends/reference/workloads/BatchMatMulImpl.cpp
+++ b/src/backends/reference/workloads/BatchMatMulImpl.cpp
@@ -7,46 +7,53 @@
 
 #include
 #include
+#include
 
 namespace armnn
 {
 
-void BatchMatMul::BatchMatMulImpl()
+BatchMatMul::BatchMatMul(const BatchMatMulDescriptor& params,
+                         const TensorInfo& inputXInfo,
+                         const TensorInfo& inputYInfo,
+                         const TensorInfo& outputInfo,
+                         Decoder<float>& inputXDecoder,
+                         Decoder<float>& inputYDecoder,
+                         Encoder<float>& outputEncoder)
+    : params(params),
+      inputXInfo(inputXInfo),
+      inputYInfo(inputYInfo),
+      outputInfo(outputInfo),
+      inputXDecoder(inputXDecoder),
+      inputYDecoder(inputYDecoder),
+      outputEncoder(outputEncoder)
 {
-    inputXData = inputXDecoder.DecodeTensor(inputXInfo.GetShape());
-    inputYData = inputYDecoder.DecodeTensor(inputYInfo.GetShape());
+    inputXData = this->inputXDecoder.DecodeTensor(inputXInfo.GetShape());
+    inputYData = this->inputYDecoder.DecodeTensor(inputYInfo.GetShape());
     // At this point, we don't touch the input decoders - just the resultant vectors
 
-    // Pre-transpose and pre-adjoint if their vectors aren't empty
-    // and also DataLayouts which may change with permutations/adjoints
+    ApplyParams();
 
-    // Todo: Have you updated input validation and inferred output shapes to accommodate for these pre-permutes?
-
-    auto idx = std::vector<unsigned int>(outputInfo.GetNumDimensions(), 0);
-    RecurseBMM(idx, 0);
+    ApplyBatchMatMul();
 }
 
-void BatchMatMul::RecurseBMM(std::vector<unsigned int>& curIdx, unsigned int curDim)
+void BatchMatMul::ApplyBatchMatMul()
 {
-    // We're working off of the indexes of the output tensor (the max possible shape)
-
-    if(!(curDim < outputInfo.GetNumDimensions()))
-    {
-        // We're at the leaf level of this call tree, so we operate here (each leaf is a data point)
+    auto axesXToMul = BatchMatMulDescriptor::GetAxesToMul(params.m_DataLayoutX,
+                                                          inputXInfo.GetShape());
+    auto axesYToMul = BatchMatMulDescriptor::GetAxesToMul(params.m_DataLayoutY,
+                                                          inputYInfo.GetShape());
+    AdjustAxesToMulForUnequalRanks(axesXToMul, axesYToMul);
 
-        auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(params,
-                                                             inputXInfo.GetShape(),
-                                                             inputYInfo.GetShape());
-        AdjustAxesToMulForUnequalRanks(axesToMul);
+    unsigned int inputXColDim = axesXToMul.second;
+    unsigned int inputYRowDim = axesYToMul.first;
 
-        unsigned int inputXColDim = axesToMul.first.second;
-        unsigned int inputYRowDim = axesToMul.second.first;
-
-        unsigned int inputYRowSize = inputYInfo.GetShape()[inputYRowDim];
+    unsigned int inputYRowSize = inputYInfo.GetShape()[inputYRowDim];
 
+    auto batchMatMulOperation = [&](const std::vector<unsigned int>& curIdx)
+    {
         float sum = 0.0f;
 
-        // You could also use inputXColSize
+        // InputYRowSize is synonymous with inputXColSize
         for (unsigned int inputYRowIdx = 0; inputYRowIdx < inputYRowSize; inputYRowIdx++)
         {
             auto xIdx = curIdx;
             xIdx[inputXColDim] = inputYRowIdx;
@@ -54,24 +61,271 @@ void BatchMatMul::RecurseBMM(std::vector<unsigned int>& curIdx, unsigned int cur
             auto yIdx = curIdx;
             yIdx[inputYRowDim] = inputYRowIdx;
 
-            sum += (GetValueAt(DataSlot::InputX, xIdx)
-                  * GetValueAt(DataSlot::InputY, yIdx));
+            sum += (GetValueAt(DataSlot::InputX, xIdx) * GetValueAt(DataSlot::InputY, yIdx));
         }
 
         SetValueAt(sum, DataSlot::Output, curIdx);
+    };
+
+    auto startIdx = std::vector<unsigned int>(outputInfo.GetNumDimensions(), 0);
+    RecurseTensor(outputInfo,
+                  batchMatMulOperation,
+                  startIdx,
+                  0);
+}
+
+void BatchMatMul::ApplyParams()
+{
+    if(params.m_TransposeX)
+    {
+        Transpose(DataSlot::InputX);
+    }
+    else if(params.m_AdjointX)
+    {
+        Adjoint(DataSlot::InputX);
+    }
+    if(params.m_TransposeY)
+    {
+        Transpose(DataSlot::InputY);
+    }
+    else if(params.m_AdjointY)
+    {
+        Adjoint(DataSlot::InputY);
+    }
+}
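For reference before the Adjoint implementation below: the adjoint (adjugate) is the transposed cofactor matrix, which for a 2x2 slice reduces to a familiar closed form. A tiny illustrative helper, not part of the patch:

```cpp
// Illustrative only: closed-form adjugate of a single 2x2 slice,
// adj([[a, b], [c, d]]) = [[d, -b], [-c, a]].
// The implementation below computes the same thing generically via
// cofactors (Gauss elimination for the sub-determinants) plus a transpose.
struct Mat2 { float a, b, c, d; };

Mat2 Adjugate2x2(const Mat2& m)
{
    return { m.d, -m.b, -m.c, m.a };
}
```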
+
+void BatchMatMul::Transpose(DataSlot type)
+{
+    // AKA the permute of the tensor
+    // This modifies the tensor's info.
+
+    switch(type)
+    {
+        case DataSlot::InputX:
+        {
+            auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(params.m_DataLayoutX,
+                                                                   inputXInfo.GetShape());
+            inputXInfo = armnnUtils::Permuted(inputXInfo, permuteVec);
+            std::vector<float> temp(inputXData.size());
+            armnnUtils::Permute(inputXInfo.GetShape(),
+                                permuteVec,
+                                inputXData.data(),
+                                temp.data(),
+                                sizeof(float));
+            inputXData = temp;
+            break;
+        }
+        case DataSlot::InputY:
+        {
+            auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(params.m_DataLayoutY,
+                                                                   inputYInfo.GetShape());
+            inputYInfo = armnnUtils::Permuted(inputYInfo, permuteVec);
+            std::vector<float> temp(inputYData.size());
+            armnnUtils::Permute(inputYInfo.GetShape(),
+                                permuteVec,
+                                inputYData.data(),
+                                temp.data(),
+                                sizeof(float));
+            inputYData = temp;
+            break;
+        }
+        case DataSlot::Output: // We needn't transpose the output tensor
+        default:
+            break;
+    }
+}
+
+void BatchMatMul::Adjoint(DataSlot type)
+{
+    // Finding the adjoint of a square matrix:
+    // Calculate the cofactor of each element (using Gauss elimination here)
+    // Apply a transpose to it (this also modifies the tensor's info)
+
+    TensorInfo& inputInfo = (type == DataSlot::InputX) ? inputXInfo : inputYInfo;
+    const auto& dataLayout = (type == DataSlot::InputX) ? params.m_DataLayoutX : params.m_DataLayoutY;
+    const auto axesToAdjoint = BatchMatMulDescriptor::GetAxesToMul(dataLayout, inputInfo.GetShape());
+
+    ARMNN_ASSERT(inputInfo.GetShape()[axesToAdjoint.first] == inputInfo.GetShape()[axesToAdjoint.second]);
+    // We grab a copy of the tensor data to prevent overwriting
+    std::vector<float> inputDataClone = (type == DataSlot::InputX) ? inputXData : inputYData;
+
+    // The sub-matrix is the resultant matrix when the row and column of the current index is removed
+    unsigned int subMatAxisSize = inputInfo.GetShape()[axesToAdjoint.first] - 1;
+    std::vector<std::vector<float>> subMat(subMatAxisSize,
+                                           std::vector<float>(subMatAxisSize));
+
+    // Lambdas for each sub-step of the cofactor operation
+    auto almostEquals = [&](const float& a, const float& b, float unitsInLastPlace = 2.0f)
+    {
+        float diff = std::fabs(a-b);
+        float bound = diff * std::numeric_limits<float>::epsilon() * unitsInLastPlace;
+        return (diff <= bound) || (diff < std::numeric_limits<float>::min());
+    };
+
+    float swapMultiplier = std::numeric_limits<float>::max();
+    auto swapRows = [&](unsigned int rowIdxA, unsigned int rowIdxB)
+    {
+        // Every row swap flips this around by the negative (set to 1 at the beginning of each cofactor op run)
+        for(unsigned int colIdx = 0; colIdx < subMatAxisSize; colIdx++)
+        {
+            float tmp = subMat[rowIdxA][colIdx];
+            subMat[rowIdxA][colIdx] = subMat[rowIdxB][colIdx];
+            subMat[rowIdxB][colIdx] = tmp;
+        }
+        swapMultiplier *= -1.0f;
+    };
+
+    auto findNextValidPivotRowIdx = [&](unsigned int colIdx)
+    {
+        unsigned int result = std::numeric_limits<unsigned int>::max();
+
+        // The original diagonal has been checked and is invalid
+        for(unsigned int rowIdx = colIdx+1; rowIdx < subMatAxisSize; rowIdx++)
+        {
+            if(!almostEquals(subMat[rowIdx][colIdx], 0.0f))
+            {
+                result = rowIdx;
+                break;
+            }
+        }
+        return result;
+    };
+
+    auto eliminate = [&](const float& pivot, unsigned int pivotPos)
+    {
+        for(unsigned int rowIdx = pivotPos+1; rowIdx < subMatAxisSize; rowIdx++)
+        {
+            float multiplierNumerator = subMat[rowIdx][pivotPos];
+            if(almostEquals(multiplierNumerator, 0.0f))
+            {
+                continue;
+            }
+            float multiplier = multiplierNumerator / pivot; // Susceptible to floating point inaccuracies
+                                                            // Hence the almostEquals usage to counteract this
+            for(unsigned int colIdx = pivotPos; colIdx < subMatAxisSize; colIdx++)
+            {
+                // We start at col=pivotPos as we have assumed that all elements
+                // to our left have been eliminated to zero already
+
+                // We subtract based on the element directly above us in our pivot row
+                subMat[rowIdx][colIdx] -= multiplier * subMat[pivotPos][colIdx];
+            }
+        }
+    };
+
+    auto cofactorOperation = [&](const std::vector<unsigned int>& curIdx)
+    {
+        auto row = curIdx[axesToAdjoint.first];
+        auto col = curIdx[axesToAdjoint.second];
+
+        float minorMultiplier = static_cast<float>(std::pow(-1, (row + 1 + col + 1)));
+
+        for(unsigned int subRow = 0; subRow < subMatAxisSize; subRow++)
+        {
+            for(unsigned int subCol = 0; subCol < subMatAxisSize; subCol++)
+            {
+                unsigned int outerRow = (subRow >= row)?subRow + 1:subRow;
+                unsigned int outerCol = (subCol >= col)?subCol + 1:subCol;
+                auto cloneIdx = curIdx;
+                cloneIdx[axesToAdjoint.first] = outerRow;
+                cloneIdx[axesToAdjoint.second] = outerCol;
+                subMat[subRow][subCol] = GetValueAt(type, cloneIdx, inputDataClone);
+            }
+        }
+
+        float determinant = 1.0f;
+
+        // Cover the edge cases and simple base cases before resorting to Gauss elimination for larger matrices
+        switch(subMatAxisSize)
+        {
+            case 0:
+            {
+                determinant = GetValueAt(type, curIdx, inputDataClone);
+                break;
+            }
+            case 1:
+            {
+                // If the resultant sub-matrix is just one element - that's the determinant
+                determinant = subMat[0][0];
+                break;
+            }
+            case 2:
+            {
+                // For a 2x2 sub-matrix, the determinant is just a*d-b*c
+                determinant = subMat[0][0] * subMat[1][1] -
+                              subMat[0][1] * subMat[1][0];
+                break;
+            }
+            default:
+            {
+                // Gaussian elimination to find the determinant of this sub-matrix
+                swapMultiplier = 1.0f;
+                // March diagonally down the pivots and if it's invalid (a zero), swap the row with the
+                // nearest non-zero down within the column
+                for(unsigned int pivotRow = 0, pivotCol = 0;
+                    pivotRow < subMatAxisSize;
+                    pivotRow++, pivotCol++)
+                {
+                    float& pivot = subMat[pivotRow][pivotCol];
+
+                    if(almostEquals(pivot, 0.0f))
+                    {
+                        unsigned int nextValidPivotRowIdx = findNextValidPivotRowIdx(pivotCol);
+                        if(nextValidPivotRowIdx == std::numeric_limits<unsigned int>::max())
+                        {
+                            // No valid pivot down this column, which means that this pivot remains a zero.
+                            // This results in the determinant for this entire sub-matrix being zero.
+                            determinant = 0.0f;
+                            break;
+                        }
+                        swapRows(pivotRow, nextValidPivotRowIdx);
+                    }
+                    determinant *= pivot;
+                    // The actual elimination step (which will update/propagate to the pivots down the line)
+                    eliminate(pivot, pivotRow); // Synonymous with pivotCol
+                }
+
+                determinant *= swapMultiplier;
+                break;
+            }
+        }
+        float cofactor = minorMultiplier * determinant;
+        SetValueAt(cofactor, type, curIdx);
+    };
+
+    auto startIdx = std::vector<unsigned int>(inputInfo.GetNumDimensions(), 0);
+    RecurseTensor(inputInfo,
+                  cofactorOperation,
+                  startIdx,
+                  0);
+
+    Transpose(type);
+}
+
+void BatchMatMul::RecurseTensor(const TensorInfo& tensorInfo,
+                                const std::function<void(const std::vector<unsigned int>&)>& operation,
+                                std::vector<unsigned int>& curIdx,
+                                unsigned int curDim)
+{
+    if(!(curDim < tensorInfo.GetNumDimensions()))
+    {
+        // We're at the leaf level of this call tree, so we operate here (each leaf is a data point)
+        operation(curIdx);
         return;
     }
 
-    for (unsigned int i = 0; i < outputInfo.GetShape()[curDim]; i++)
+    for(unsigned int i = 0; i < tensorInfo.GetShape()[curDim]; i++)
     {
         curIdx[curDim] = i;
-        RecurseBMM(curIdx, curDim+1);
+        RecurseTensor(tensorInfo,
+                      operation,
+                      curIdx,
+                      curDim + 1);
     }
 }
 
-void BatchMatMul::AdjustAxesToMulForUnequalRanks(
-    std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>>& axesToMul)
+void BatchMatMul::AdjustAxesToMulForUnequalRanks(std::pair<unsigned int, unsigned int>& axesXToMul,
+                                                 std::pair<unsigned int, unsigned int>& axesYToMul)
 {
     int rankDiff = static_cast<int>(inputXInfo.GetNumDimensions()) -
                    static_cast<int>(inputYInfo.GetNumDimensions());
@@ -82,18 +336,18 @@ void BatchMatMul::AdjustAxesToMulForUnequalRanks(
     else if(rankDiff < 0)
     {
         // Y is the larger one
-        axesToMul.first.first += static_cast<std::make_unsigned<int>::type>(std::abs(rankDiff));
-        axesToMul.first.second += static_cast<std::make_unsigned<int>::type>(std::abs(rankDiff));
+        axesXToMul.first += static_cast<std::make_unsigned<int>::type>(std::abs(rankDiff));
+        axesXToMul.second += static_cast<std::make_unsigned<int>::type>(std::abs(rankDiff));
     }
     else if(rankDiff > 0)
    {
         // X is the larger one
-        axesToMul.second.first += static_cast<std::make_unsigned<int>::type>(std::abs(rankDiff));
-        axesToMul.second.second += static_cast<std::make_unsigned<int>::type>(std::abs(rankDiff));
+        axesYToMul.first += static_cast<std::make_unsigned<int>::type>(std::abs(rankDiff));
+        axesYToMul.second += static_cast<std::make_unsigned<int>::type>(std::abs(rankDiff));
     }
 }
 
-float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx)
+float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx, const std::vector<float>& customData)
 {
     // This gets the data from the input vector that we have, not the decoder
     // But for the output, it is operating on the encoder itself
@@ -101,14 +355,13 @@ float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx)
     AdjustToSafeIdx(type, idx);
     unsigned int flatIdx = CalcFlatIdx(type, idx);
     float value = 0.0f;
-
     switch(type)
     {
         case DataSlot::InputX:
-            value = inputXData[flatIdx];
+            value = customData.empty() ? inputXData[flatIdx] : customData[flatIdx];
             break;
         case DataSlot::InputY:
-            value = inputYData[flatIdx];
+            value = customData.empty() ? inputYData[flatIdx] : customData[flatIdx];
             break;
         case DataSlot::Output:
             outputEncoder[flatIdx];
@@ -124,9 +377,7 @@ float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx)
 void BatchMatMul::SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx)
 {
     AdjustToSafeIdx(type, idx);
-
     unsigned int flatIdx = CalcFlatIdx(type, idx);
-
     switch(type)
     {
         case DataSlot::InputX:
@@ -186,9 +437,7 @@ void BatchMatMul::AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx)
 unsigned int BatchMatMul::CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx)
 {
     unsigned int result = idx[idx.size()-1];
-
     unsigned int dimMultiplier = 1;
-
     unsigned int offset;
 
     // -2 because final dim is already accounted for in the multiplier (last dim is just a multiplier of 1x)
@@ -215,17 +464,4 @@ unsigned int BatchMatMul::CalcFlatIdx(DataSlot type, const std::vector<unsigned
     return result;
 }
 
-template <typename T>
-std::string BatchMatMul::StringifyVec(const std::vector<T>& vec)
-{
-    std::string res = "{ ";
-    for(auto x : vec)
-    {
-        res += std::to_string(x);
-        res += " ";
-    }
-    res += "}";
-    return res;
-}
-
 } // namespace armnn
\ No newline at end of file
diff --git a/src/backends/reference/workloads/BatchMatMulImpl.hpp b/src/backends/reference/workloads/BatchMatMulImpl.hpp
index 25b6c85d77..19971a4af3 100644
--- a/src/backends/reference/workloads/BatchMatMulImpl.hpp
+++ b/src/backends/reference/workloads/BatchMatMulImpl.hpp
@@ -15,6 +15,15 @@ namespace armnn
 class BatchMatMul
 {
 public:
+    BatchMatMul(const BatchMatMulDescriptor& params,
+                const TensorInfo& inputXInfo,
+                const TensorInfo& inputYInfo,
+                const TensorInfo& outputInfo,
+                Decoder<float>& inputXDecoder,
+                Decoder<float>& inputYDecoder,
+                Encoder<float>& outputEncoder);
+
+private:
     enum DataSlot
     {
         InputX = 0,
@@ -22,31 +31,35 @@ public:
         Output = 2
     };
 
-    BatchMatMul(const BatchMatMulDescriptor& params,
-                const TensorInfo& inputXInfo,
-                const TensorInfo& inputYInfo,
-                const TensorInfo& outputInfo,
-                Decoder<float>& inputXDecoder,
-                Decoder<float>& inputYDecoder,
-                Encoder<float>& outputEncoder)
-        : params(params),
-          inputXInfo(inputXInfo),
-          inputYInfo(inputYInfo),
-          outputInfo(outputInfo),
-          inputXDecoder(inputXDecoder),
-          inputYDecoder(inputYDecoder),
-          outputEncoder(outputEncoder)
-    {}
+    const BatchMatMulDescriptor& params;
+    TensorInfo inputXInfo;
+    TensorInfo inputYInfo;
+    TensorInfo outputInfo;
+    Decoder<float>& inputXDecoder;
+    Decoder<float>& inputYDecoder;
+    Encoder<float>& outputEncoder;
 
-    void BatchMatMulImpl();
+    std::vector<float> inputXData;
+    std::vector<float> inputYData;
+
+    void ApplyBatchMatMul();
+
+    void ApplyParams();
+
+    void Transpose(DataSlot type);
 
-    void RecurseBMM(std::vector<unsigned int>& curIdx, unsigned int curDim);
+    void Adjoint(DataSlot type);
+
+    void RecurseTensor(const TensorInfo& tensorInfo,
+                       std::function<void(const std::vector<unsigned int>&)> const& operation,
+                       std::vector<unsigned int>& curIdx,
+                       unsigned int curDim);
 
     // Adjusts the multiplication axes for when the input tensors are of unequal rank
-    void AdjustAxesToMulForUnequalRanks(
-        std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>>& axesToMul);
+    void AdjustAxesToMulForUnequalRanks(std::pair<unsigned int, unsigned int>& axesXToMul,
+                                        std::pair<unsigned int, unsigned int>& axesYToMul);
 
-    float GetValueAt(DataSlot type, std::vector<unsigned int> idx);
+    float GetValueAt(DataSlot type, std::vector<unsigned int> idx, const std::vector<float>& customData = {});
 
     void SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx);
 
@@ -54,22 +67,6 @@ public:
     void AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx);
 
     unsigned int CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx);
-
-    template <typename T>
-    std::string StringifyVec(const std::vector<T>& vec);
-
-private:
-    const BatchMatMulDescriptor& params;
-    const TensorInfo& inputXInfo;
-    const TensorInfo& inputYInfo;
-    const TensorInfo& outputInfo;
-    Decoder<float>& inputXDecoder;
-    Decoder<float>& inputYDecoder;
-    Encoder<float>& outputEncoder;
-
-    std::vector<float> inputXData;
-    std::vector<float> inputYData;
-
 };
 
 } // namespace armnn
\ No newline at end of file
diff --git a/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp b/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp
index 388190c4ef..027b93b5d9 100644
--- a/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp
+++ b/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp
@@ -51,9 +51,6 @@ void RefBatchMatMulWorkload::Execute(std::vector<ITensorHandle*> inputs, std::ve
                       *inputXDecoder,
                       *inputYDecoder,
                       *outputEncoder);
-
-    bmm.BatchMatMulImpl();
-
 }
 
 } // namespace armnn
\ No newline at end of file
--
cgit v1.2.1
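
Note on BatchMatMul::Adjoint above: the "adjoint" applied to each square slice is the
classical adjoint (the adjugate), i.e. the transpose of the slice's cofactor matrix.
A minimal standalone sketch of that definition for a single 3x3 matrix, assuming plain
arrays in place of the Decoder/Encoder plumbing (illustration only, not ArmNN code):

    #include <array>
    #include <cstdio>

    using Mat3 = std::array<std::array<float, 3>, 3>;

    // Determinant of the 2x2 sub-matrix left after deleting `row` and `col`
    float Minor(const Mat3& m, int row, int col)
    {
        std::array<float, 4> s{};
        int k = 0;
        for (int r = 0; r < 3; r++)
        {
            if (r == row) continue;
            for (int c = 0; c < 3; c++)
            {
                if (c == col) continue;
                s[k++] = m[r][c];
            }
        }
        return s[0] * s[3] - s[1] * s[2]; // a*d - b*c
    }

    int main()
    {
        Mat3 a = {{{ 3.0f, 0.0f, 2.0f }, { 2.0f, 0.0f, -2.0f }, { 0.0f, 1.0f, 1.0f }}};

        // adj(A)[i][j] = (-1)^(i+j) * Minor(A, j, i): cofactors, then a transpose
        Mat3 adj{};
        for (int i = 0; i < 3; i++)
        {
            for (int j = 0; j < 3; j++)
            {
                float sign = ((i + j) % 2 == 0) ? 1.0f : -1.0f;
                adj[i][j] = sign * Minor(a, j, i);
            }
        }

        // A * adj(A) should equal det(A) * I; det(A) is 10 for this matrix
        for (int i = 0; i < 3; i++)
        {
            for (int j = 0; j < 3; j++)
            {
                float sum = 0.0f;
                for (int k = 0; k < 3; k++)
                {
                    sum += a[i][k] * adj[k][j];
                }
                std::printf("%6.1f ", sum);
            }
            std::printf("\n");
        }
        return 0;
    }

Printing A * adj(A) as a diagonal matrix of det(A) is a quick sanity check on the
cofactor logic, since A * adj(A) = det(A) * I holds for any square A.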
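
The default case of cofactorOperation finds each sub-matrix determinant by forward
elimination: march down the diagonal, swap in a pivot from a lower row when the current
one is zero (each swap flips the result's sign), then take the product of the pivots.
A self-contained sketch of the same idea, with a fixed tolerance standing in for the
almostEquals comparison (illustration only):

    #include <cmath>
    #include <cstdio>
    #include <utility>
    #include <vector>

    float Determinant(std::vector<std::vector<float>> m)
    {
        const size_t n = m.size();
        float sign = 1.0f;
        float det = 1.0f;
        for (size_t p = 0; p < n; p++)
        {
            // Find a usable pivot at or below row p
            size_t pivotRow = p;
            while (pivotRow < n && std::fabs(m[pivotRow][p]) < 1e-9f)
            {
                pivotRow++;
            }
            if (pivotRow == n)
            {
                return 0.0f; // No non-zero pivot in this column, so the matrix is singular
            }
            if (pivotRow != p)
            {
                std::swap(m[p], m[pivotRow]);
                sign = -sign; // A row swap flips the determinant's sign
            }
            det *= m[p][p];
            // Eliminate everything below the pivot
            for (size_t r = p + 1; r < n; r++)
            {
                float factor = m[r][p] / m[p][p];
                for (size_t c = p; c < n; c++)
                {
                    m[r][c] -= factor * m[p][c];
                }
            }
        }
        return sign * det;
    }

    int main()
    {
        std::vector<std::vector<float>> a = {{0.0f, 2.0f}, {1.0f, 3.0f}};
        std::printf("%f\n", Determinant(a)); // -2: one swap, so -(1 * 2)
        return 0;
    }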
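
RecurseTensor is a depth-first walk over every index of an N-D tensor: each recursion
level fixes one dimension, and the callback fires once per element at the leaves of the
call tree. The pattern, sketched standalone with a plain shape vector in place of
TensorInfo (names here are illustrative):

    #include <cstdio>
    #include <functional>
    #include <vector>

    void Recurse(const std::vector<unsigned int>& shape,
                 const std::function<void(const std::vector<unsigned int>&)>& op,
                 std::vector<unsigned int>& idx,
                 unsigned int dim)
    {
        if (dim >= shape.size())
        {
            op(idx); // Leaf: every dimension is fixed, so idx names one element
            return;
        }
        for (unsigned int i = 0; i < shape[dim]; i++)
        {
            idx[dim] = i;
            Recurse(shape, op, idx, dim + 1);
        }
    }

    int main()
    {
        std::vector<unsigned int> shape = {2, 2, 3}; // e.g. a batch of two 2x3 slices
        std::vector<unsigned int> idx(shape.size(), 0);
        unsigned int count = 0;
        Recurse(shape, [&](const std::vector<unsigned int>& i)
        {
            count++;
            std::printf("(%u,%u,%u)\n", i[0], i[1], i[2]);
        }, idx, 0);
        std::printf("%u elements\n", count); // 12
        return 0;
    }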
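
CalcFlatIdx is row-major flattening: the last dimension has stride 1, and each earlier
dimension's stride is the product of the extents to its right (hence the "-2" starting
point noted in its comment, since the final dimension is already accounted for). A
simplified sketch without the DataSlot-specific offset handling, assuming shape and
index vectors of equal length:

    #include <cstdio>
    #include <vector>

    unsigned int FlatIdx(const std::vector<unsigned int>& shape,
                         const std::vector<unsigned int>& idx)
    {
        unsigned int result = idx[idx.size() - 1]; // The last dim contributes with stride 1
        unsigned int dimMultiplier = 1;
        // Walk from the second-to-last dimension leftwards, accumulating strides
        for (unsigned int i = static_cast<unsigned int>(idx.size()) - 1; i-- > 0;)
        {
            dimMultiplier *= shape[i + 1];
            result += idx[i] * dimMultiplier;
        }
        return result;
    }

    int main()
    {
        std::vector<unsigned int> shape = {2, 3, 4};
        // idx {1, 2, 3} maps to 1*(3*4) + 2*4 + 3 = 23, the final element
        std::printf("%u\n", FlatIdx(shape, {1, 2, 3}));
        return 0;
    }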