author     Samuel Yap <samuel.yap@arm.com>    2022-07-06 15:36:03 +0100
committer  Samuel Yap <samuel.yap@arm.com>    2022-07-22 16:52:38 +0100
commit     4b7a34dd92eb3f736e05ac6623fd147ecd8636b1 (patch)
tree       c33e5820f89e359c80d8773288e8adb075735039
parent     16929a2b432232f7a34fcbd1f1b0fe1212500206 (diff)
download   armnn-4b7a34dd92eb3f736e05ac6623fd147ecd8636b1.tar.gz
IVGCVSW-7109: Add Batch MatMul front end support - Reference
* Descriptors added for BatchMatMul
* Layer definition added
* Input validation added (will likely change when opt. param support comes in)
* Ref workload implementation for BatchMatMul added (will also change with opt. param support)
* Ref layer tests made for BatchMatMul
* CMake and other build files updated

Signed-off-by: Samuel Yap <samuel.yap@arm.com>
Change-Id: Ic885301da543ee0fbe7922b85e7f9658c4efc617
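As a sketch of the semantics being added (inferred from the shape-inference and validation code in this patch): the last two layout-adjusted axes of each input are treated as the matrix axes, the inner axes must agree, and all leading axes broadcast.

    % BatchMatMul shape rule (sketch)
    % X : (..., M, K),  Y : (..., K, N)  ->  Out : (broadcast of leading axes, M, N)
    \[
    (XY)_{\ldots,m,n} = \sum_{k=1}^{K} X_{\ldots,m,k} \, Y_{\ldots,k,n}
    \]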
-rw-r--r--  Android.mk  1
-rw-r--r--  CMakeLists.txt  2
-rw-r--r--  docs/02_operator_list.dox  42
-rw-r--r--  include/armnn/BackendHelper.hpp  6
-rw-r--r--  include/armnn/Descriptors.hpp  54
-rw-r--r--  include/armnn/DescriptorsFwd.hpp  1
-rw-r--r--  include/armnn/INetwork.hpp  6
-rw-r--r--  include/armnn/Types.hpp  3
-rw-r--r--  include/armnn/backends/WorkloadData.hpp  5
-rw-r--r--  src/armnn/BackendHelper.cpp  16
-rw-r--r--  src/armnn/Descriptors.cpp  82
-rw-r--r--  src/armnn/ILayerSupport.cpp  2
-rw-r--r--  src/armnn/LayersFwd.hpp  4
-rw-r--r--  src/armnn/Network.cpp  11
-rw-r--r--  src/armnn/Network.hpp  3
-rw-r--r--  src/armnn/layers/BatchMatMulLayer.cpp  97
-rw-r--r--  src/armnn/layers/BatchMatMulLayer.hpp  46
-rw-r--r--  src/backends/backendsCommon/LayerSupportRules.hpp  8
-rw-r--r--  src/backends/backendsCommon/WorkloadData.cpp  227
-rw-r--r--  src/backends/backendsCommon/WorkloadFactory.cpp  16
-rw-r--r--  src/backends/backendsCommon/common.mk  1
-rw-r--r--  src/backends/backendsCommon/test/CMakeLists.txt  2
-rw-r--r--  src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp  2
-rw-r--r--  src/backends/backendsCommon/test/LayerTests.hpp  1
-rw-r--r--  src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp  1010
-rw-r--r--  src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp  85
-rw-r--r--  src/backends/reference/RefLayerSupport.cpp  52
-rw-r--r--  src/backends/reference/RefLayerSupport.hpp  6
-rw-r--r--  src/backends/reference/RefWorkloadFactory.cpp  5
-rw-r--r--  src/backends/reference/backend.mk  2
-rw-r--r--  src/backends/reference/test/RefLayerTests.cpp  71
-rw-r--r--  src/backends/reference/workloads/BatchMatMulImpl.cpp  230
-rw-r--r--  src/backends/reference/workloads/BatchMatMulImpl.hpp  75
-rw-r--r--  src/backends/reference/workloads/CMakeLists.txt  4
-rw-r--r--  src/backends/reference/workloads/RefBatchMatMulWorkload.cpp  59
-rw-r--r--  src/backends/reference/workloads/RefBatchMatMulWorkload.hpp  30
-rw-r--r--  src/backends/reference/workloads/RefWorkloads.hpp  1
37 files changed, 2265 insertions, 3 deletions
diff --git a/Android.mk b/Android.mk
index 2d291a44b8..74a6deeb8f 100644
--- a/Android.mk
+++ b/Android.mk
@@ -208,6 +208,7 @@ LOCAL_SRC_FILES := \
src/armnn/layers/ActivationLayer.cpp \
src/armnn/layers/AdditionLayer.cpp \
src/armnn/layers/ArgMinMaxLayer.cpp \
+ src/armnn/layers/BatchMatMulLayer.cpp \
src/armnn/layers/BatchNormalizationLayer.cpp \
src/armnn/layers/BatchToSpaceNdLayer.cpp \
src/armnn/layers/CastLayer.cpp \
diff --git a/CMakeLists.txt b/CMakeLists.txt
index f0eb81cc6c..1d8ebe2952 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -187,6 +187,8 @@ list(APPEND armnn_sources
src/armnn/layers/AdditionLayer.cpp
src/armnn/layers/ArgMinMaxLayer.hpp
src/armnn/layers/ArgMinMaxLayer.cpp
+ src/armnn/layers/BatchMatMulLayer.hpp
+ src/armnn/layers/BatchMatMulLayer.cpp
src/armnn/layers/BatchNormalizationLayer.hpp
src/armnn/layers/BatchNormalizationLayer.cpp
src/armnn/layers/BatchToSpaceNdLayer.hpp
diff --git a/docs/02_operator_list.dox b/docs/02_operator_list.dox
index 960428999e..658aa07d1d 100644
--- a/docs/02_operator_list.dox
+++ b/docs/02_operator_list.dox
@@ -268,6 +268,48 @@ where N = batches, C = channels, H = height, W = width
<tr><td>FLOAT32
</table>
<tr>
+ <td rowspan="3">BatchMatMulLayer
+ <td rowspan="3" style="width:200px;"> Layer to perform batch matrix multiplication.
+ <td rowspan="3">
+ <ul>
+ <li>N/A
+ </ul>
+ <td>CpuRef
+ <td>
+ <ul>
+ <li>All
+ </ul>
+ <td>
+ <table>
+ <tr><th>
+ <tr><td>BFLOAT16
+ <tr><td>FLOAT16
+ <tr><td>FLOAT32
+ <tr><td>QASYMMS8
+ <tr><td>QASYMMU8
+ <tr><td>QSYMMS16
+ </table>
+<tr>
+ <td>CpuAcc
+ <td>
+ <ul>
+ <li>N/A
+ </ul>
+ <td>
+ <ul>
+ <li>N/A
+ </ul>
+<tr>
+ <td>GpuAcc
+ <td>
+ <ul>
+ <li>N/A
+ </ul>
+ <td>
+ <ul>
+ <li>N/A
+ </ul>
+<tr>
<td rowspan="3">BatchNormalizationLayer
<td rowspan="3" style="width:200px;"> Layer to perform batch normalization.
<td rowspan="3">
diff --git a/include/armnn/BackendHelper.hpp b/include/armnn/BackendHelper.hpp
index 09c7385d5c..f78b4f80b9 100644
--- a/include/armnn/BackendHelper.hpp
+++ b/include/armnn/BackendHelper.hpp
@@ -43,6 +43,12 @@ public:
const ArgMinMaxDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional());
+ bool IsBatchMatMulSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ const BatchMatMulDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional());
+
bool IsBatchNormalizationSupported(const TensorInfo& input,
const TensorInfo& output,
const TensorInfo& mean,
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index 628d045529..38e3c61500 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -1550,4 +1550,58 @@ struct ChannelShuffleDescriptor : BaseDescriptor
uint32_t m_Axis;
};
+/// A BatchMatMulDescriptor for the BatchMatMul operator
+struct BatchMatMulDescriptor : BaseDescriptor
+{
+ BatchMatMulDescriptor(Optional<DataLayout> dataLayoutX = EmptyOptional(),
+ Optional<DataLayout> dataLayoutY = EmptyOptional(),
+ std::vector<unsigned int> transposeX = {},
+ std::vector<unsigned int> transposeY = {},
+ std::vector<unsigned int> adjointX = {},
+ std::vector<unsigned int> adjointY = {})
+ : m_DataLayoutX(dataLayoutX)
+ , m_DataLayoutY(dataLayoutY)
+ , m_TransposeX(transposeX)
+ , m_TransposeY(transposeY)
+ , m_AdjointX(adjointX)
+ , m_AdjointY(adjointY)
+ {}
+
+ bool operator ==(const BatchMatMulDescriptor &rhs) const
+ {
+ return m_DataLayoutX == rhs.m_DataLayoutX &&
+ m_DataLayoutY == rhs.m_DataLayoutY &&
+ m_TransposeX == rhs.m_TransposeX &&
+ m_TransposeY == rhs.m_TransposeY &&
+ m_AdjointX == rhs.m_AdjointX &&
+ m_AdjointY == rhs.m_AdjointY;
+ }
+
+ /// Data layout of each input tensor, such as NHWC/NDHWC (or leave as EmptyOptional for arbitrary layout)
+ Optional<DataLayout> m_DataLayoutX;
+ Optional<DataLayout> m_DataLayoutY;
+
+ /// Transpose vector for each input tensor (leave as empty vector for no pre-transposing)
+ /// Transpose and Adjoint cannot both be set for the same tensor at the same time
+ std::vector<unsigned int> m_TransposeX;
+ std::vector<unsigned int> m_TransposeY;
+
+ /// Adjoint vector for each input tensor (leave as empty vector for no pre-adjoint)
+ /// Transpose and Adjoint cannot both be set for the same tensor at the same time
+ std::vector<unsigned int> m_AdjointX;
+ std::vector<unsigned int> m_AdjointY;
+
+ /// Static helper to get the two axes (for each input) for multiplication
+ static std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>> GetAxesToMul(
+ const BatchMatMulDescriptor& desc,
+ const TensorShape& tensorXShape,
+ const TensorShape& tensorYShape);
+
+ /// Static helper to get the axes (for each input) that will not be multiplied together
+ static std::pair<std::vector<unsigned int>, std::vector<unsigned int>> GetAxesNotMul(
+ const BatchMatMulDescriptor& desc,
+ const TensorShape& inputXShape,
+ const TensorShape& inputYShape);
+};
+
} // namespace armnn
diff --git a/include/armnn/DescriptorsFwd.hpp b/include/armnn/DescriptorsFwd.hpp
index ab6c7d235a..c0c1cc238d 100644
--- a/include/armnn/DescriptorsFwd.hpp
+++ b/include/armnn/DescriptorsFwd.hpp
@@ -11,6 +11,7 @@ struct BaseDescriptor;
struct ActivationDescriptor;
struct ArgMinMaxDescriptor;
+struct BatchMatMulDescriptor;
struct BatchNormalizationDescriptor;
struct BatchToSpaceNdDescriptor;
struct ChannelShuffleDescriptor;
diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp
index 3d4be1a7fa..349c7e87b5 100644
--- a/include/armnn/INetwork.hpp
+++ b/include/armnn/INetwork.hpp
@@ -752,6 +752,12 @@ public:
IConnectableLayer* AddChannelShuffleLayer(const ChannelShuffleDescriptor& descriptor,
const char* name = nullptr);
+ /// Add a BatchMatMul layer to the network
+ /// @param descriptor - Parameters for the BatchMatMul operation
+ /// @param name - Optional name for the layer
+ /// @return - Interface for configuring the layer
+ IConnectableLayer* AddBatchMatMulLayer(const BatchMatMulDescriptor& descriptor,
+ const char* name = nullptr);
void ExecuteStrategy(IStrategy& strategy) const;
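A minimal usage sketch of the new entry point (a sketch only; the 2x2 shapes, binding ids and layer names are illustrative, and the surrounding calls are the usual INetwork graph-building API):

    #include <armnn/ArmNN.hpp>

    // Sketch: multiply two 2x2 Float32 matrices with the new layer.
    void BuildBatchMatMulGraph()
    {
        using namespace armnn;

        INetworkPtr net = INetwork::Create();
        TensorInfo matInfo(TensorShape({2, 2}), DataType::Float32);

        IConnectableLayer* inX = net->AddInputLayer(0, "inputX");
        IConnectableLayer* inY = net->AddInputLayer(1, "inputY");

        // Default descriptor: arbitrary data layout, no pre-transpose/adjoint.
        IConnectableLayer* bmm = net->AddBatchMatMulLayer(BatchMatMulDescriptor(), "batchMatMul");
        IConnectableLayer* out = net->AddOutputLayer(0, "output");

        inX->GetOutputSlot(0).Connect(bmm->GetInputSlot(0));
        inY->GetOutputSlot(0).Connect(bmm->GetInputSlot(1));
        bmm->GetOutputSlot(0).Connect(out->GetInputSlot(0));

        inX->GetOutputSlot(0).SetTensorInfo(matInfo);
        inY->GetOutputSlot(0).SetTensorInfo(matInfo);
        bmm->GetOutputSlot(0).SetTensorInfo(matInfo);
    }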
diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp
index af75513638..98229df07f 100644
--- a/include/armnn/Types.hpp
+++ b/include/armnn/Types.hpp
@@ -458,7 +458,8 @@ using InferenceTimingPair = std::pair<HighResolutionClock, HighResolutionClock>;
X(ChannelShuffle) \
X(Convolution3d) \
X(Pooling3d) \
- X(GatherNd)\
+ X(GatherNd) \
+ X(BatchMatMul) \
// New layers should be added at the end to minimize instability.
diff --git a/include/armnn/backends/WorkloadData.hpp b/include/armnn/backends/WorkloadData.hpp
index 1a2f34e21f..00962ed52c 100644
--- a/include/armnn/backends/WorkloadData.hpp
+++ b/include/armnn/backends/WorkloadData.hpp
@@ -785,4 +785,9 @@ struct ChannelShuffleQueueDescriptor : QueueDescriptorWithParameters<ChannelShuf
void Validate(const WorkloadInfo& workloadInfo) const;
};
+struct BatchMatMulQueueDescriptor : QueueDescriptorWithParameters<BatchMatMulDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
} // namespace armnn
diff --git a/src/armnn/BackendHelper.cpp b/src/armnn/BackendHelper.cpp
index 5b5bece783..6638709d6f 100644
--- a/src/armnn/BackendHelper.cpp
+++ b/src/armnn/BackendHelper.cpp
@@ -179,6 +179,22 @@ bool LayerSupportHandle::IsArgMinMaxSupported(const TensorInfo& input,
reasonIfUnsupported);
}
+bool LayerSupportHandle::IsBatchMatMulSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ const BatchMatMulDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported)
+{
+ TensorInfos infos{input0, input1, output};
+
+ return m_LayerSupport->IsLayerSupported(LayerType::BatchMatMul,
+ infos,
+ descriptor,
+ EmptyOptional(),
+ EmptyOptional(),
+ reasonIfUnsupported);
+}
+
bool LayerSupportHandle::IsBatchNormalizationSupported(const TensorInfo& input,
const TensorInfo& output,
const TensorInfo& mean,
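For illustration, the new query can be reached through the existing handle lookup (a sketch; GetILayerSupportByBackendId is the helper already declared alongside LayerSupportHandle in BackendHelper.hpp, and the shapes are up to the caller):

    #include <armnn/BackendHelper.hpp>
    #include <string>

    // Sketch: ask the reference backend whether a BatchMatMul configuration is supported.
    bool RefSupportsBatchMatMul(const armnn::TensorInfo& x,
                                const armnn::TensorInfo& y,
                                const armnn::TensorInfo& out)
    {
        armnn::LayerSupportHandle handle =
            armnn::GetILayerSupportByBackendId(armnn::BackendId("CpuRef"));

        std::string reason;
        return handle.IsBatchMatMulSupported(x, y, out,
                                             armnn::BatchMatMulDescriptor(),
                                             armnn::Optional<std::string&>(reason));
    }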
diff --git a/src/armnn/Descriptors.cpp b/src/armnn/Descriptors.cpp
index c740fd03ad..f9576271d5 100644
--- a/src/armnn/Descriptors.cpp
+++ b/src/armnn/Descriptors.cpp
@@ -455,4 +455,86 @@ uint32_t DepthwiseConvolution2dDescriptor::GetNumInputs() const
return armnn::GetNumInputs(m_BiasEnabled);
}
+std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>>
+BatchMatMulDescriptor::GetAxesToMul(
+ const BatchMatMulDescriptor& desc,
+ const TensorShape& tensorXShape,
+ const TensorShape& tensorYShape)
+{
+ // May refactor to just work on one input per call - makes it less confusing and also
+ // allows more flexibility (e.g. in Layer output shape inference)
+
+ auto xNumDims = tensorXShape.GetNumDimensions();
+ auto yNumDims = tensorYShape.GetNumDimensions();
+
+ std::pair<unsigned int, unsigned int> xAxes = { xNumDims-2, xNumDims-1 };
+ std::pair<unsigned int, unsigned int> yAxes = { yNumDims-2, yNumDims-1 };
+
+ if(desc.m_DataLayoutX.has_value())
+ {
+ switch(desc.m_DataLayoutX.value())
+ {
+ case DataLayout::NDHWC:
+ case DataLayout::NHWC:
+ xAxes.first -= 1;
+ xAxes.second -= 1;
+ break;
+ case DataLayout::NCDHW:
+ case DataLayout::NCHW:
+ default:
+ break;
+ }
+ }
+
+ if(desc.m_DataLayoutY.has_value())
+ {
+ switch(desc.m_DataLayoutY.value())
+ {
+ case DataLayout::NDHWC:
+ case DataLayout::NHWC:
+ yAxes.first -= 1;
+ yAxes.second -= 1;
+ break;
+ case DataLayout::NCDHW:
+ case DataLayout::NCHW:
+ default:
+ break;
+ }
+ }
+
+ return { xAxes, yAxes };
+}
+
+std::pair<std::vector<unsigned int>, std::vector<unsigned int>> BatchMatMulDescriptor::GetAxesNotMul(
+ const BatchMatMulDescriptor& desc,
+ const TensorShape& inputXShape,
+ const TensorShape& inputYShape)
+{
+ // May refactor to just work on one input per call - makes it less confusing and also
+ // allows more flexibility (e.g. in Layer output shape inference)
+ auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(desc, inputXShape, inputYShape);
+
+ std::vector<unsigned int> axesXNotMul;
+ std::vector<unsigned int> axesYNotMul;
+
+ for(unsigned int i = 0; i < inputXShape.GetNumDimensions(); i++)
+ {
+ if(i == axesToMul.first.first || i == axesToMul.first.second)
+ {
+ continue;
+ }
+ axesXNotMul.push_back(i);
+ }
+ for(unsigned int i = 0; i < inputYShape.GetNumDimensions(); i++)
+ {
+ if(i == axesToMul.second.first || i == axesToMul.second.second)
+ {
+ continue;
+ }
+ axesYNotMul.push_back(i);
+ }
+
+ return { axesXNotMul, axesYNotMul };
+}
+
}
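To make the helper semantics concrete, a small sketch with illustrative shapes: with no data layout set, the last two dimensions of each input are multiplied, while an NHWC/NDHWC layout shifts both multiplied axes one position left, past the channel axis.

    #include <armnn/Descriptors.hpp>
    #include <armnn/Tensor.hpp>

    void AxesHelperSketch()
    {
        using namespace armnn;

        // No layout set: for rank-3 inputs the multiplied axes default to {1, 2}.
        BatchMatMulDescriptor plain;
        auto axes = BatchMatMulDescriptor::GetAxesToMul(plain,
                                                        TensorShape({5, 2, 3}),   // X
                                                        TensorShape({5, 3, 4}));  // Y
        // axes.first == {1, 2} and axes.second == {1, 2}

        // NHWC on X: the multiplied axes of a rank-4 input become {1, 2} instead of {2, 3}.
        BatchMatMulDescriptor nhwcX(Optional<DataLayout>(DataLayout::NHWC));
        auto nhwcAxes = BatchMatMulDescriptor::GetAxesToMul(nhwcX,
                                                            TensorShape({1, 2, 3, 4}),
                                                            TensorShape({1, 1, 3, 4}));
        // nhwcAxes.first == {1, 2}

        // The complement: GetAxesNotMul returns the batch/broadcast axes, here {0} for each input.
        auto notMul = BatchMatMulDescriptor::GetAxesNotMul(plain,
                                                           TensorShape({5, 2, 3}),
                                                           TensorShape({5, 3, 4}));
        (void)axes; (void)nhwcAxes; (void)notMul;
    }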
diff --git a/src/armnn/ILayerSupport.cpp b/src/armnn/ILayerSupport.cpp
index 5366b13088..8099782750 100644
--- a/src/armnn/ILayerSupport.cpp
+++ b/src/armnn/ILayerSupport.cpp
@@ -13,7 +13,7 @@ namespace armnn
{
ARMNN_NO_DEPRECATE_WARN_BEGIN
-// IsLayerSupport() forwards to the deprecated virtual methods depending on input LayerType.
+// IsLayerSupported() forwards to the deprecated virtual methods depending on input LayerType.
// Allows backends to continue to behave as before, maintaining backward compatibility.
bool ILayerSupport::IsLayerSupported(const LayerType& type,
const std::vector<TensorInfo>& infos,
diff --git a/src/armnn/LayersFwd.hpp b/src/armnn/LayersFwd.hpp
index dcfb91b65a..acac1f9988 100644
--- a/src/armnn/LayersFwd.hpp
+++ b/src/armnn/LayersFwd.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -9,6 +9,7 @@
#include "layers/ActivationLayer.hpp"
#include "layers/AdditionLayer.hpp"
#include "layers/ArgMinMaxLayer.hpp"
+#include "layers/BatchMatMulLayer.hpp"
#include "layers/BatchNormalizationLayer.hpp"
#include "layers/BatchToSpaceNdLayer.hpp"
#include "layers/CastLayer.hpp"
@@ -110,6 +111,7 @@ constexpr LayerType LayerEnumOf(const T* = nullptr);
DECLARE_LAYER(Activation)
DECLARE_LAYER(Addition)
DECLARE_LAYER(ArgMinMax)
+DECLARE_LAYER(BatchMatMul)
DECLARE_LAYER(BatchNormalization)
DECLARE_LAYER(BatchToSpaceNd)
DECLARE_LAYER(Cast)
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index 5d443068ce..ef9f4e7522 100644
--- a/src/armnn/Network.cpp
+++ b/src/armnn/Network.cpp
@@ -456,6 +456,12 @@ IConnectableLayer* INetwork::AddChannelShuffleLayer(const ChannelShuffleDescript
return pNetworkImpl->AddChannelShuffleLayer(descriptor, name);
}
+IConnectableLayer* INetwork::AddBatchMatMulLayer(const BatchMatMulDescriptor &descriptor,
+ const char* name)
+{
+ return pNetworkImpl->AddBatchMatMulLayer(descriptor, name);
+}
+
void INetwork::ExecuteStrategy(IStrategy& strategy) const
{
return pNetworkImpl->ExecuteStrategy(strategy);
@@ -2876,6 +2882,11 @@ IConnectableLayer* NetworkImpl::AddUnidirectionalSequenceLstmLayer(
return layer;
}
+IConnectableLayer* NetworkImpl::AddBatchMatMulLayer(const BatchMatMulDescriptor& desc, const char* name)
+{
+ return m_Graph->AddLayer<BatchMatMulLayer>(desc, name);
+}
+
IConnectableLayer* NetworkImpl::AddPrecompiledLayer(const PreCompiledDescriptor& preCompiledDescriptor,
CompiledBlobPtr compiledBlobPtr,
const Optional<BackendId>& backend,
diff --git a/src/armnn/Network.hpp b/src/armnn/Network.hpp
index a4387e65c0..19a0286e95 100644
--- a/src/armnn/Network.hpp
+++ b/src/armnn/Network.hpp
@@ -49,6 +49,9 @@ public:
IConnectableLayer* AddArgMinMaxLayer(const ArgMinMaxDescriptor& desc,
const char* name = nullptr);
+ IConnectableLayer* AddBatchMatMulLayer(const BatchMatMulDescriptor& desc,
+ const char* name = nullptr);
+
IConnectableLayer* AddBatchNormalizationLayer(const BatchNormalizationDescriptor& desc,
const ConstTensor& mean,
const ConstTensor& variance,
diff --git a/src/armnn/layers/BatchMatMulLayer.cpp b/src/armnn/layers/BatchMatMulLayer.cpp
new file mode 100644
index 0000000000..501de2d091
--- /dev/null
+++ b/src/armnn/layers/BatchMatMulLayer.cpp
@@ -0,0 +1,97 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#include "BatchMatMulLayer.hpp"
+
+#include <armnn/backends/WorkloadFactory.hpp>
+#include "layers/LayerCloneBase.hpp"
+
+namespace armnn
+{
+
+BatchMatMulLayer::BatchMatMulLayer(const BatchMatMulDescriptor& param, const char* name)
+ : LayerWithParameters(2, 1, LayerType::BatchMatMul, param, name)
+{}
+
+std::unique_ptr<IWorkload> BatchMatMulLayer::CreateWorkload(const IWorkloadFactory& factory) const
+{
+ BatchMatMulQueueDescriptor descriptor;
+ SetAdditionalInfo(descriptor);
+
+ return factory.CreateWorkload(LayerType::BatchMatMul, descriptor, PrepInfoAndDesc(descriptor));
+}
+
+BatchMatMulLayer* BatchMatMulLayer::Clone(Graph& graph) const
+{
+ auto layer = CloneBase<BatchMatMulLayer>(graph, m_Param, GetName());
+
+ return std::move(layer);
+}
+
+std::vector<TensorShape> BatchMatMulLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
+{
+ ARMNN_ASSERT(inputShapes.size() == 2);
+
+ TensorShape inputXShape = inputShapes[0];
+ TensorShape inputYShape = inputShapes[1];
+
+ // Note: Take into account what pre-adjoint or pre-transposing will do to the inferred output shape
+
+ TensorShape& longerInput = inputXShape.GetNumDimensions() >= inputYShape.GetNumDimensions() ?
+ inputXShape : inputYShape;
+ TensorShape& shorterInput = inputXShape.GetNumDimensions() >= inputYShape.GetNumDimensions() ?
+ inputYShape : inputXShape;
+
+ unsigned int inputNumDimsOffset = longerInput.GetNumDimensions() - shorterInput.GetNumDimensions();
+
+ unsigned int outputNumDimensions = longerInput.GetNumDimensions();
+
+ std::vector<unsigned int> tensorDimensions(outputNumDimensions, 0);
+
+ auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Param, inputXShape, inputYShape);
+ const auto& longerAxesToMul = (axesToMul.first.first >= axesToMul.second.first &&
+ axesToMul.first.second >= axesToMul.second.second) ?
+ axesToMul.first : axesToMul.second;
+
+ for (unsigned int i = 0; i < outputNumDimensions; ++i)
+ {
+ if (i == longerAxesToMul.first)
+ {
+ tensorDimensions[i] = &shorterInput == &inputXShape ? inputXShape[i - inputNumDimsOffset] : inputXShape[i];
+ }
+ else if(i == longerAxesToMul.second)
+ {
+ tensorDimensions[i] = &shorterInput == &inputYShape ? inputYShape[i - inputNumDimsOffset] : inputYShape[i];
+ }
+ else // The other dimensions not to be multiplied (but may be broadcasted)
+ {
+ // Does NOT validate whether it's a valid broadcast - that's done in the validate func in WorkloadData.cpp
+ tensorDimensions[i] = static_cast<int>(i) - static_cast<int>(inputNumDimsOffset) < 0 ?
+ longerInput[i] :
+ std::max(longerInput[i], shorterInput[i - inputNumDimsOffset]);
+ }
+ }
+
+ auto outputShape = TensorShape(outputNumDimensions, tensorDimensions.data());
+ return std::vector<TensorShape>({ outputShape });
+}
+
+void BatchMatMulLayer::ValidateTensorShapesFromInputs()
+{
+ VerifyLayerConnections(2, CHECK_LOCATION());
+
+ const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
+
+ VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
+
+ auto inferredShapes = InferOutputShapes({
+ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
+ GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape() });
+
+ ARMNN_ASSERT(inferredShapes.size() == 1);
+
+ ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "BatchMatMulLayer");
+}
+
+} // namespace armnn
\ No newline at end of file
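A worked instance of the inference above, using the shapes from the 3D/2D broadcast test added later in this patch: the shorter input is right-aligned against the longer one, the multiplied axes contribute M from X and N from Y, and every other axis takes the broadcast maximum.

    X : (2, 2, 2),  Y : (2, 2)    // Y right-aligned, i.e. treated as (1, 2, 2)
    axis 0 : max(2, 1) = 2        // broadcast axis
    axis 1 : 2                    // M, from X
    axis 2 : 2                    // N, from Y
    => output shape (2, 2, 2)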
diff --git a/src/armnn/layers/BatchMatMulLayer.hpp b/src/armnn/layers/BatchMatMulLayer.hpp
new file mode 100644
index 0000000000..8dc79d33c4
--- /dev/null
+++ b/src/armnn/layers/BatchMatMulLayer.hpp
@@ -0,0 +1,46 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "LayerWithParameters.hpp"
+
+namespace armnn
+{
+
+class BatchMatMulLayer : public LayerWithParameters<BatchMatMulDescriptor>
+{
+public:
+ /// Makes a workload for the BatchMatMul type.
+ /// @param [in] graph The graph where this layer can be found.
+ /// @param [in] factory The workload factory which will create the workload.
+ /// @return A pointer to the created workload, or nullptr if not created.
+ virtual std::unique_ptr<IWorkload> CreateWorkload(const IWorkloadFactory &factory) const override;
+
+ /// Creates a dynamically-allocated copy of this layer.
+ /// @param [in] graph The graph into which this layer is being cloned.
+ BatchMatMulLayer* Clone(Graph &graph) const override;
+
+ /// Infers the output shape from the given input shapes.
+ /// @param [in] inputShapes The vector of input shapes for BatchMatMul.
+ /// @return A vector of inferred output shape.
+ std::vector<TensorShape> InferOutputShapes(const std::vector<TensorShape>& inputShapes) const override;
+
+ /// Check if the input tensor shapes
+ /// will lead to a valid configuration of @ref BatchMatMulLayer.
+ void ValidateTensorShapesFromInputs() override;
+
+protected:
+ /// Constructor to create a BatchMatMulLayer.
+ /// @param [in] param BatchMatMulDescriptor to configure optional parameters for batch matrix multiplication
+ /// @param [in] name Optional name for the layer
+ BatchMatMulLayer(const BatchMatMulDescriptor& param, const char* name);
+
+ /// Default destructor
+ ~BatchMatMulLayer() = default;
+};
+
+}
\ No newline at end of file
diff --git a/src/backends/backendsCommon/LayerSupportRules.hpp b/src/backends/backendsCommon/LayerSupportRules.hpp
index e616ecf022..a83fd62867 100644
--- a/src/backends/backendsCommon/LayerSupportRules.hpp
+++ b/src/backends/backendsCommon/LayerSupportRules.hpp
@@ -186,4 +186,12 @@ struct TensorNumDimensionsAreCorrect : public Rule
}
};
+struct TensorNumDimensionsAreGreaterOrEqualTo : public Rule
+{
+ TensorNumDimensionsAreGreaterOrEqualTo(const TensorInfo& info, unsigned int numDimensionsToCompare)
+ {
+ m_Res = info.GetNumDimensions() >= numDimensionsToCompare;
+ }
+};
+
} //namespace armnn
\ No newline at end of file
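A sketch of how the new rule slots into the existing CheckSupportRule pattern from this header (the wrapper function and messages below are illustrative, in the style of RefLayerSupport.cpp):

    #include <armnn/Tensor.hpp>
    #include <backendsCommon/LayerSupportRules.hpp>

    // Sketch of a call site: reject BatchMatMul inputs with fewer than two dimensions.
    bool CheckBatchMatMulRanks(const armnn::TensorInfo& input0,
                               const armnn::TensorInfo& input1,
                               armnn::Optional<std::string&> reasonIfUnsupported)
    {
        using namespace armnn;

        bool supported = true;
        supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(input0, 2),
                                      reasonIfUnsupported,
                                      "Reference batch matrix multiplication: input X is not 2D or greater.");
        supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(input1, 2),
                                      reasonIfUnsupported,
                                      "Reference batch matrix multiplication: input Y is not 2D or greater.");
        return supported;
    }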
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 606821b5e5..9a4c60f551 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -4143,5 +4143,232 @@ void UnidirectionalSequenceLstmQueueDescriptor::Validate(const WorkloadInfo& wor
}
}
+void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+ const std::string descriptorName{"BatchMatMulDescriptor"};
+
+ ValidateNumInputs(workloadInfo, descriptorName, 2);
+ ValidateNumOutputs(workloadInfo, descriptorName, 1);
+
+ // Inputs must be: both 2D+
+ // For inputs X and Y whose dimensions to be multiplied are (M,N) and (I,J) respectively,
+ // axes N and I must be the same size
+
+ const auto& inputTensorXInfo = workloadInfo.m_InputTensorInfos[0];
+ const auto& inputTensorYInfo = workloadInfo.m_InputTensorInfos[1];
+ const auto& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
+
+ std::vector<DataType> supportedTypes =
+ {
+ DataType::BFloat16,
+ DataType::Float16,
+ DataType::Float32,
+ DataType::QAsymmS8,
+ DataType::QAsymmU8,
+ DataType::QSymmS16
+ };
+
+ ValidateDataTypes(inputTensorXInfo, supportedTypes, descriptorName);
+ ValidateDataTypes(inputTensorYInfo, supportedTypes, descriptorName);
+ ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
+
+ if ((inputTensorXInfo.GetNumDimensions() < 2) ||
+ (inputTensorYInfo.GetNumDimensions() < 2))
+ {
+ throw InvalidArgumentException(descriptorName + ": Input tensors are not 2D or greater.");
+ }
+
+ if(m_Parameters.m_DataLayoutX.has_value())
+ {
+ switch(m_Parameters.m_DataLayoutX.value())
+ {
+ case DataLayout::NCHW:
+ case DataLayout::NHWC:
+ if(inputTensorXInfo.GetNumDimensions() != 4)
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": Input tensor X does not have the correct "
+ "number of dimensions for the Data Layout that it has been assigned.");
+ }
+ break;
+ case DataLayout::NCDHW:
+ case DataLayout::NDHWC:
+ if(inputTensorXInfo.GetNumDimensions() != 5)
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": Input tensor X does not have the correct "
+ "number of dimensions for the Data Layout that it has been assigned.");
+ }
+ break;
+ default:
+ break;
+ }
+ }
+
+ if(m_Parameters.m_DataLayoutY.has_value())
+ {
+ switch(m_Parameters.m_DataLayoutY.value())
+ {
+ case DataLayout::NCHW:
+ case DataLayout::NHWC:
+ if(inputTensorYInfo.GetNumDimensions() != 4)
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": Input tensor Y does not have the correct "
+ "number of dimensions for the Data Layout that it has been assigned.");
+ }
+ break;
+ case DataLayout::NCDHW:
+ case DataLayout::NDHWC:
+ if(inputTensorYInfo.GetNumDimensions() != 5)
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": Input tensor Y does not have the correct "
+ "number of dimensions for the Data Layout that it has been assigned.");
+ }
+ break;
+ default:
+ break;
+ }
+ }
+
+ auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(m_Parameters,
+ inputTensorXInfo.GetShape(),
+ inputTensorYInfo.GetShape());
+
+ if(inputTensorXInfo.GetShape()[axesToMul.first.second]
+ != inputTensorYInfo.GetShape()[axesToMul.second.first])
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": The final axis of input tensor X must be the same size as "
+ "the second last axis of input tensor Y.");
+ }
+
+ auto axesNotMul = BatchMatMulDescriptor::GetAxesNotMul(m_Parameters,
+ inputTensorXInfo.GetShape(),
+ inputTensorYInfo.GetShape());
+
+ { // Separate scope so we don't pollute the rest of the scope with our temp variables
+ // e.g. NHWC isn't compatible with NCHW as of now
+ DataLayout xLayout;
+ DataLayout yLayout;
+
+ if(m_Parameters.m_DataLayoutX == EmptyOptional())
+ {
+ xLayout = DataLayout::NCHW; // Not equivalent - I'm just concerned with the last 2 axes
+ }
+ else
+ {
+ xLayout = m_Parameters.m_DataLayoutX.value();
+ }
+
+ if(m_Parameters.m_DataLayoutY == EmptyOptional())
+ {
+ yLayout = DataLayout::NCHW;
+ }
+ else
+ {
+ yLayout = m_Parameters.m_DataLayoutY.value();
+ }
+
+ if(xLayout == DataLayout::NCHW || xLayout == DataLayout::NCDHW)
+ {
+ if(yLayout == DataLayout::NHWC || yLayout == DataLayout::NDHWC)
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": Invalid input tensor data layout combination.");
+ }
+ }
+ if(yLayout == DataLayout::NCHW || yLayout == DataLayout::NCDHW)
+ {
+ if(xLayout == DataLayout::NHWC || xLayout == DataLayout::NDHWC)
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": Invalid input tensor data layout combination.");
+ }
+ }
+ }
+
+ // Simulate aligning the ends of the matrix dims and prepending 1's to the beginning of the shorter one
+ unsigned int outputTensorDimSize = std::max(inputTensorXInfo.GetNumDimensions(),
+ inputTensorYInfo.GetNumDimensions());
+ if(outputTensorDimSize-2 > 0)
+ {
+ TensorInfo tiXNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
+ DataType::Float32);
+ TensorInfo tiYNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
+ DataType::Float32);
+ TensorInfo tiOutNotMul = TensorInfo(TensorShape(outputTensorDimSize-2),
+ DataType::Float32);
+
+ auto doAxisExtension = [&](std::vector<unsigned int> axisIndices, const TensorInfo& inputInfo, TensorInfo& ti)
+ {
+ // Prepend broadcast (size 1) dimensions for the shorter input so both line up with the output rank
+ auto sizeDiff = (outputTensorDimSize-2) - axisIndices.size();
+
+ for(unsigned int i = 0; i < ti.GetNumDimensions(); i++)
+ {
+ ti.GetShape()[i] = i < sizeDiff ? 1 : inputInfo.GetShape()[axisIndices[i - sizeDiff]];
+ }
+ };
+
+ doAxisExtension(axesNotMul.first, inputTensorXInfo, tiXNotMul);
+ doAxisExtension(axesNotMul.second, inputTensorYInfo, tiYNotMul);
+
+ for(unsigned int i = 0; i < tiOutNotMul.GetNumDimensions(); i++)
+ {
+ tiOutNotMul.GetShape()[i] = std::max(tiXNotMul.GetShape()[i],
+ tiYNotMul.GetShape()[i]);
+ }
+
+ ValidateBroadcastTensorShapesMatch(tiXNotMul,
+ tiYNotMul,
+ tiOutNotMul,
+ descriptorName,
+ "input_X",
+ "input_Y");
+ }
+
+ // Also check descriptor parameter validity
+ // This will eventually be moved to the start of the function as explained below
+ if ((!m_Parameters.m_TransposeX.empty() && !m_Parameters.m_AdjointX.empty()) ||
+ (!m_Parameters.m_TransposeY.empty() && !m_Parameters.m_AdjointY.empty()))
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": Invalid descriptor parameters - Transpose and Adjoint "
+ "vectors cannot both be true for a given input tensor.");
+ }
+
+ if(m_Parameters.m_TransposeX.size() != 0 && m_Parameters.m_TransposeX.size() != inputTensorXInfo.GetNumDimensions())
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": Invalid descriptor parameter - Transpose X vector must be "
+ "the same size as tensor input X's dimensionality.");
+ }
+ if(m_Parameters.m_AdjointX.size() != 0 && m_Parameters.m_AdjointX.size() != inputTensorXInfo.GetNumDimensions())
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": Invalid descriptor parameter - Adjoint X vector must be "
+ "the same size as tensor input X's dimensionality.");
+ }
+ if(m_Parameters.m_TransposeY.size() != 0 && m_Parameters.m_TransposeY.size() != inputTensorYInfo.GetNumDimensions())
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": Invalid descriptor parameter - Transpose Y vector must be "
+ "the same size as tensor input Y's dimensionality.");
+ }
+ if(m_Parameters.m_AdjointY.size() != 0 && m_Parameters.m_AdjointY.size() != inputTensorYInfo.GetNumDimensions())
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": Invalid descriptor parameter - Adjoint Y vector must be "
+ "the same size as tensor input Y's dimensionality.");
+ }
+ // Note: for adjoint/transpose, you'll need to do the validation atop the resultant permutation.
+}
+
} // namespace armnn
\ No newline at end of file
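As a quick illustration of what this validation rejects (a sketch only; in practice the WorkloadInfo is populated by the framework, and the shapes here are made up): multiplying a (2,3) by a (4,5) trips the inner-axis check before any of the broadcast or descriptor checks.

    #include <armnn/backends/WorkloadData.hpp>
    #include <armnn/Exceptions.hpp>

    // Sketch: run a deliberately mismatched pair through the new Validate().
    void ValidateSketch()
    {
        using namespace armnn;

        BatchMatMulQueueDescriptor queueDesc; // default descriptor: no layouts, no transpose/adjoint

        WorkloadInfo info;
        info.m_InputTensorInfos  = { TensorInfo(TensorShape({2, 3}), DataType::Float32),
                                     TensorInfo(TensorShape({4, 5}), DataType::Float32) };
        info.m_OutputTensorInfos = { TensorInfo(TensorShape({2, 5}), DataType::Float32) };

        try
        {
            queueDesc.Validate(info); // throws: X's final axis (3) != Y's second-to-last axis (4)
        }
        catch (const InvalidArgumentException&)
        {
            // expected
        }
    }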
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index 3660e6e721..70006e4f79 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -133,6 +133,22 @@ bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
reason);
break;
}
+ case LayerType::BatchMatMul:
+ {
+ auto cLayer = PolymorphicDowncast<const BatchMatMulLayer*>(&layer);
+ const BatchMatMulDescriptor& descriptor = cLayer->GetParameters();
+
+ const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
+ const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
+ const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+ result = layerSupportObject.IsBatchMatMulSupported(
+ OverrideDataType(input0, dataType),
+ OverrideDataType(input1, dataType),
+ OverrideDataType(output, dataType),
+ descriptor,
+ reason);
+ break;
+ }
case LayerType::BatchNormalization:
{
auto cLayer = PolymorphicDowncast<const BatchNormalizationLayer*>(&layer);
diff --git a/src/backends/backendsCommon/common.mk b/src/backends/backendsCommon/common.mk
index 86de7e331b..007cca57fa 100644
--- a/src/backends/backendsCommon/common.mk
+++ b/src/backends/backendsCommon/common.mk
@@ -46,6 +46,7 @@ COMMON_TEST_SOURCES := \
test/layerTests/ActivationTestImpl.cpp \
test/layerTests/AdditionTestImpl.cpp \
test/layerTests/ArgMinMaxTestImpl.cpp \
+ test/layerTests/BatchMatMulTestImpl.cpp \
test/layerTests/BatchNormalizationTestImpl.cpp \
test/layerTests/CastTestImpl.cpp \
test/layerTests/ChannelShuffleTestImpl.cpp \
diff --git a/src/backends/backendsCommon/test/CMakeLists.txt b/src/backends/backendsCommon/test/CMakeLists.txt
index 8beb7c4169..c5b97ebf4c 100644
--- a/src/backends/backendsCommon/test/CMakeLists.txt
+++ b/src/backends/backendsCommon/test/CMakeLists.txt
@@ -68,6 +68,8 @@ list(APPEND armnnBackendsCommonUnitTests_sources
layerTests/AdditionTestImpl.hpp
layerTests/ArgMinMaxTestImpl.cpp
layerTests/ArgMinMaxTestImpl.hpp
+ layerTests/BatchMatMulTestImpl.cpp
+ layerTests/BatchMatMulTestImpl.hpp
layerTests/BatchNormalizationTestImpl.cpp
layerTests/BatchNormalizationTestImpl.hpp
layerTests/BatchToSpaceNdTestImpl.hpp
diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
index ba8cfd5f68..5fdcd9c57a 100644
--- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
+++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
@@ -614,6 +614,8 @@ DECLARE_LAYER_POLICY_1_PARAM(Addition)
DECLARE_LAYER_POLICY_2_PARAM(ArgMinMax)
+DECLARE_LAYER_POLICY_2_PARAM(BatchMatMul)
+
DECLARE_LAYER_POLICY_2_PARAM(BatchNormalization)
DECLARE_LAYER_POLICY_2_PARAM(BatchToSpaceNd)
diff --git a/src/backends/backendsCommon/test/LayerTests.hpp b/src/backends/backendsCommon/test/LayerTests.hpp
index 8d73027783..25435b24ec 100644
--- a/src/backends/backendsCommon/test/LayerTests.hpp
+++ b/src/backends/backendsCommon/test/LayerTests.hpp
@@ -9,6 +9,7 @@
#include <backendsCommon/test/layerTests/ActivationTestImpl.hpp>
#include <backendsCommon/test/layerTests/AdditionTestImpl.hpp>
#include <backendsCommon/test/layerTests/ArgMinMaxTestImpl.hpp>
+#include <backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp>
#include <backendsCommon/test/layerTests/BatchNormalizationTestImpl.hpp>
#include <backendsCommon/test/layerTests/BatchToSpaceNdTestImpl.hpp>
#include <backendsCommon/test/layerTests/CastTestImpl.hpp>
diff --git a/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp
new file mode 100644
index 0000000000..41add6e6da
--- /dev/null
+++ b/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.cpp
@@ -0,0 +1,1010 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "BatchMatMulTestImpl.hpp"
+
+#include <armnn/backends/IBackendInternal.hpp>
+#include <armnn/backends/Workload.hpp>
+#include <armnn/backends/WorkloadData.hpp>
+#include <armnn/backends/WorkloadFactory.hpp>
+
+#include <armnnTestUtils/WorkloadTestUtils.hpp>
+#include <armnnUtils/QuantizeHelper.hpp>
+#include <armnnTestUtils/TensorCopyUtils.hpp>
+#include <armnn/Optional.hpp>
+
+
+template<armnn::DataType ArmnnType, typename T, std::size_t NumDims>
+LayerTestResult<T, NumDims> BatchMatMulTestImpl(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory,
+ armnn::BatchMatMulDescriptor descriptor,
+ const std::vector<T>& inputX,
+ const std::vector<T>& inputY,
+ const std::vector<T>& outputExpected,
+ const armnn::TensorInfo& inputXInfo,
+ const armnn::TensorInfo& inputYInfo,
+ const armnn::TensorInfo& outputInfo)
+{
+ std::vector<T> outputActual(outputInfo.GetNumElements());
+
+ std::unique_ptr<armnn::ITensorHandle> inputXHandle = tensorHandleFactory.CreateTensorHandle(inputXInfo);
+ std::unique_ptr<armnn::ITensorHandle> inputYHandle = tensorHandleFactory.CreateTensorHandle(inputYInfo);
+ std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputInfo);
+
+ armnn::BatchMatMulQueueDescriptor queueDescriptor;
+ queueDescriptor.m_Parameters = descriptor;
+ armnn::WorkloadInfo workloadInfo;
+
+ AddInputToWorkload(queueDescriptor, workloadInfo, inputXInfo, inputXHandle.get());
+ AddInputToWorkload(queueDescriptor, workloadInfo, inputYInfo, inputYHandle.get());
+ AddOutputToWorkload(queueDescriptor, workloadInfo, outputInfo, outputHandle.get());
+
+ auto workload = workloadFactory.CreateWorkload(armnn::LayerType::BatchMatMul, queueDescriptor, workloadInfo);
+
+ inputXHandle->Allocate();
+ inputYHandle->Allocate();
+ outputHandle->Allocate();
+
+ CopyDataToITensorHandle(inputXHandle.get(), inputX.data());
+ CopyDataToITensorHandle(inputYHandle.get(), inputY.data());
+
+ workload->PostAllocationConfigure();
+ ExecuteWorkload(*workload, memoryManager);
+
+ CopyDataFromITensorHandle(outputActual.data(), outputHandle.get());
+
+ return LayerTestResult<T, NumDims>(outputActual,
+ outputExpected,
+ outputHandle->GetShape(),
+ outputInfo.GetShape());
+}
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 2> BatchMatMul2DSimpleTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+ auto descriptor = armnn::BatchMatMulDescriptor(); // Arbitrary layout with no transpose/adjointing
+
+ float qScale = 0.0f;
+ int32_t qOffset = 0;
+
+ switch(ArmnnType)
+ {
+ case armnn::DataType::QAsymmS8:
+ case armnn::DataType::QAsymmU8:
+ case armnn::DataType::QSymmS16:
+ qScale = 1.0f;
+ break;
+ default:
+ break;
+ }
+
+ armnn::TensorInfo inputXInfo({2,2}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo inputYInfo({2,2}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo outputInfo({2,2}, ArmnnType, qScale, qOffset);
+
+ std::vector<T> inputX = armnnUtils::QuantizedVector<T>({
+ 1, 2,
+ 3, 4
+ }, qScale, qOffset);
+
+ std::vector<T> inputY = armnnUtils::QuantizedVector<T>({
+ 5, 6,
+ 7, 8
+ }, qScale, qOffset);
+
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
+ 19, 22,
+ 43, 50
+ }, qScale, qOffset);
+
+ return BatchMatMulTestImpl<ArmnnType, T, 2>(workloadFactory,
+ memoryManager,
+ tensorHandleFactory,
+ descriptor,
+ inputX,
+ inputY,
+ outputExpected,
+ inputXInfo,
+ inputYInfo,
+ outputInfo);
+}
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 2>
+BatchMatMul2DSimpleTest<armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 2>
+BatchMatMul2DSimpleTest<armnn::DataType::Float32>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 2>
+BatchMatMul2DSimpleTest<armnn::DataType::Float16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 2>
+BatchMatMul2DSimpleTest<armnn::DataType::QAsymmS8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 2>
+BatchMatMul2DSimpleTest<armnn::DataType::QAsymmU8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 2>
+BatchMatMul2DSimpleTest<armnn::DataType::QSymmS16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 3> BatchMatMul3DSimpleTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+ auto descriptor = armnn::BatchMatMulDescriptor(); // Arbitrary layout with no transpose/adjointing
+
+ float qScale = 0.0f;
+ int32_t qOffset = 0;
+
+ switch(ArmnnType)
+ {
+ case armnn::DataType::QAsymmS8:
+ case armnn::DataType::QAsymmU8:
+ case armnn::DataType::QSymmS16:
+ qScale = 1.0f;
+ break;
+ default:
+ break;
+ }
+
+ armnn::TensorInfo inputXInfo({1,2,2}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo inputYInfo({1,2,2}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo outputInfo({1,2,2}, ArmnnType, qScale, qOffset);
+
+ std::vector<T> inputX = armnnUtils::QuantizedVector<T>({
+ 1, 2,
+ 3, 4
+ }, qScale, qOffset);
+
+ std::vector<T> inputY = armnnUtils::QuantizedVector<T>({
+ 5, 6,
+ 7, 8
+ }, qScale, qOffset);
+
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
+ 19, 22,
+ 43, 50
+ }, qScale, qOffset);
+
+ return BatchMatMulTestImpl<ArmnnType, T, 3>(workloadFactory,
+ memoryManager,
+ tensorHandleFactory,
+ descriptor,
+ inputX,
+ inputY,
+ outputExpected,
+ inputXInfo,
+ inputYInfo,
+ outputInfo);
+}
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 3>
+BatchMatMul3DSimpleTest<armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 3>
+BatchMatMul3DSimpleTest<armnn::DataType::Float32>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 3>
+BatchMatMul3DSimpleTest<armnn::DataType::Float16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 3>
+BatchMatMul3DSimpleTest<armnn::DataType::QAsymmS8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 3>
+BatchMatMul3DSimpleTest<armnn::DataType::QAsymmU8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 3>
+BatchMatMul3DSimpleTest<armnn::DataType::QSymmS16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 4> BatchMatMulNCHWSimpleTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+ auto descriptor = armnn::BatchMatMulDescriptor(
+ armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NCHW),
+ armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NCHW));
+
+ float qScale = 0.0f;
+ int32_t qOffset = 0;
+
+ switch(ArmnnType)
+ {
+ case armnn::DataType::QAsymmS8:
+ case armnn::DataType::QAsymmU8:
+ case armnn::DataType::QSymmS16:
+ qScale = 1.0f;
+ break;
+ default:
+ break;
+ }
+
+ armnn::TensorInfo inputXInfo({1,1,2,2}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo inputYInfo({1,1,2,2}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo outputInfo({1,1,2,2}, ArmnnType, qScale, qOffset);
+
+ std::vector<T> inputX = armnnUtils::QuantizedVector<T>({
+ 1, 2,
+ 3, 4
+ }, qScale, qOffset);
+
+ std::vector<T> inputY = armnnUtils::QuantizedVector<T>({
+ 5, 6,
+ 7, 8
+ }, qScale, qOffset);
+
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
+ 19, 22,
+ 43, 50
+ }, qScale, qOffset);
+
+ return BatchMatMulTestImpl<ArmnnType, T, 4>(workloadFactory,
+ memoryManager,
+ tensorHandleFactory,
+ descriptor,
+ inputX,
+ inputY,
+ outputExpected,
+ inputXInfo,
+ inputYInfo,
+ outputInfo);
+}
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 4>
+BatchMatMulNCHWSimpleTest<armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 4>
+BatchMatMulNCHWSimpleTest<armnn::DataType::Float32>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 4>
+BatchMatMulNCHWSimpleTest<armnn::DataType::Float16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 4>
+BatchMatMulNCHWSimpleTest<armnn::DataType::QAsymmS8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 4>
+BatchMatMulNCHWSimpleTest<armnn::DataType::QAsymmU8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 4>
+BatchMatMulNCHWSimpleTest<armnn::DataType::QSymmS16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 4> BatchMatMulNHWCSimpleTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+ auto descriptor = armnn::BatchMatMulDescriptor(
+ armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NHWC),
+ armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NHWC));
+
+ float qScale = 0.0f;
+ int32_t qOffset = 0;
+
+ switch(ArmnnType)
+ {
+ case armnn::DataType::QAsymmS8:
+ case armnn::DataType::QAsymmU8:
+ case armnn::DataType::QSymmS16:
+ qScale = 1.0f;
+ break;
+ default:
+ break;
+ }
+
+ armnn::TensorInfo inputXInfo({1,2,2,1}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo inputYInfo({1,2,2,1}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo outputInfo({1,2,2,1}, ArmnnType, qScale, qOffset);
+
+ std::vector<T> inputX = armnnUtils::QuantizedVector<T>({
+ 1, 2,
+ 3, 4
+ }, qScale, qOffset);
+
+ std::vector<T> inputY = armnnUtils::QuantizedVector<T>({
+ 5, 6,
+ 7, 8
+ }, qScale, qOffset);
+
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
+ 19, 22,
+ 43, 50
+ }, qScale, qOffset);
+
+ return BatchMatMulTestImpl<ArmnnType, T, 4>(workloadFactory,
+ memoryManager,
+ tensorHandleFactory,
+ descriptor,
+ inputX,
+ inputY,
+ outputExpected,
+ inputXInfo,
+ inputYInfo,
+ outputInfo);
+}
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 4>
+BatchMatMulNHWCSimpleTest<armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 4>
+BatchMatMulNHWCSimpleTest<armnn::DataType::Float32>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 4>
+BatchMatMulNHWCSimpleTest<armnn::DataType::Float16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 4>
+BatchMatMulNHWCSimpleTest<armnn::DataType::QAsymmS8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 4>
+BatchMatMulNHWCSimpleTest<armnn::DataType::QAsymmU8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 4>
+BatchMatMulNHWCSimpleTest<armnn::DataType::QSymmS16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 3> BatchMatMul3DBatchTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+ auto descriptor = armnn::BatchMatMulDescriptor(); // Arbitrary layout with no transpose/adjointing
+
+ float qScale = 0.0f;
+ int32_t qOffset = 0;
+
+ switch(ArmnnType)
+ {
+ case armnn::DataType::QAsymmS8:
+ case armnn::DataType::QAsymmU8:
+ case armnn::DataType::QSymmS16:
+ qScale = 1.0f;
+ break;
+ default:
+ break;
+ }
+
+ armnn::TensorInfo inputXInfo({2,2,2}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo inputYInfo({2,2,2}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo outputInfo({2,2,2}, ArmnnType, qScale, qOffset);
+
+ std::vector<T> inputX = armnnUtils::QuantizedVector<T>({
+ 1, 2,
+ 3, 4,
+
+ 9, 10,
+ 11, 12
+ }, qScale, qOffset);
+
+ std::vector<T> inputY = armnnUtils::QuantizedVector<T>({
+ 5, 6,
+ 7, 8,
+
+ 13, 14,
+ 15, 16
+ }, qScale, qOffset);
+
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
+ 19, 22,
+ 43, 50,
+
+ 267, 286,
+ 323, 346
+ }, qScale, qOffset);
+
+ return BatchMatMulTestImpl<ArmnnType, T, 3>(workloadFactory,
+ memoryManager,
+ tensorHandleFactory,
+ descriptor,
+ inputX,
+ inputY,
+ outputExpected,
+ inputXInfo,
+ inputYInfo,
+ outputInfo);
+}
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 3>
+BatchMatMul3DBatchTest<armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 3>
+BatchMatMul3DBatchTest<armnn::DataType::Float32>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 3>
+BatchMatMul3DBatchTest<armnn::DataType::Float16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 3>
+BatchMatMul3DBatchTest<armnn::DataType::QAsymmS8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 3>
+BatchMatMul3DBatchTest<armnn::DataType::QAsymmU8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 3>
+BatchMatMul3DBatchTest<armnn::DataType::QSymmS16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 3> BatchMatMul3DBroadcastTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+ auto descriptor = armnn::BatchMatMulDescriptor(); // Arbitrary layout with no transpose/adjointing
+
+ float qScale = 0.0f;
+ int32_t qOffset = 0;
+
+ switch(ArmnnType)
+ {
+ case armnn::DataType::QAsymmS8:
+ case armnn::DataType::QAsymmU8:
+ case armnn::DataType::QSymmS16:
+ qScale = 1.0f;
+ break;
+ default:
+ break;
+ }
+
+ armnn::TensorInfo inputXInfo({2,2,2}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo inputYInfo({1,2,2}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo outputInfo({2,2,2}, ArmnnType, qScale, qOffset);
+
+ std::vector<T> inputX = armnnUtils::QuantizedVector<T>({
+ 1, 2,
+ 3, 4,
+
+ 9, 10,
+ 11, 12
+ }, qScale, qOffset);
+
+ std::vector<T> inputY = armnnUtils::QuantizedVector<T>({
+ 13, 14,
+ 15, 16
+ }, qScale, qOffset);
+
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
+ 43, 46,
+ 99, 106,
+
+ 267, 286,
+ 323, 346
+ }, qScale, qOffset);
+
+ return BatchMatMulTestImpl<ArmnnType, T, 3>(workloadFactory,
+ memoryManager,
+ tensorHandleFactory,
+ descriptor,
+ inputX,
+ inputY,
+ outputExpected,
+ inputXInfo,
+ inputYInfo,
+ outputInfo);
+}
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 3>
+BatchMatMul3DBroadcastTest<armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 3>
+BatchMatMul3DBroadcastTest<armnn::DataType::Float32>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 3>
+BatchMatMul3DBroadcastTest<armnn::DataType::Float16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 3>
+BatchMatMul3DBroadcastTest<armnn::DataType::QAsymmS8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 3>
+BatchMatMul3DBroadcastTest<armnn::DataType::QAsymmU8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 3>
+BatchMatMul3DBroadcastTest<armnn::DataType::QSymmS16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 3> BatchMatMul3D2DBroadcastTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+ auto descriptor = armnn::BatchMatMulDescriptor(); // Arbitrary layout with no transpose/adjointing
+
+ float qScale = 0.0f;
+ int32_t qOffset = 0;
+
+ switch(ArmnnType)
+ {
+ case armnn::DataType::QAsymmS8:
+ case armnn::DataType::QAsymmU8:
+ case armnn::DataType::QSymmS16:
+ qScale = 1.0f;
+ break;
+ default:
+ break;
+ }
+
+ armnn::TensorInfo inputXInfo({2,2,2}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo inputYInfo({2,2}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo outputInfo({2,2,2}, ArmnnType, qScale, qOffset);
+
+ std::vector<T> inputX = armnnUtils::QuantizedVector<T>({
+ 1, 2,
+ 3, 4,
+
+ 9, 10,
+ 11, 12
+ }, qScale, qOffset);
+
+ std::vector<T> inputY = armnnUtils::QuantizedVector<T>({
+ 13, 14,
+ 15, 16
+ }, qScale, qOffset);
+
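+    // Hand check: the rank-2 Y acts as a single matrix applied to every batch of X,
+    // so the expected values match the 3D broadcast case above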
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
+ 43, 46,
+ 99, 106,
+
+ 267, 286,
+ 323, 346
+    }, qScale, qOffset);
+
+ return BatchMatMulTestImpl<ArmnnType, T, 3>(workloadFactory,
+ memoryManager,
+ tensorHandleFactory,
+ descriptor,
+ inputX,
+ inputY,
+ outputExpected,
+ inputXInfo,
+ inputYInfo,
+ outputInfo);
+}
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 3>
+BatchMatMul3D2DBroadcastTest<armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 3>
+BatchMatMul3D2DBroadcastTest<armnn::DataType::Float32>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 3>
+BatchMatMul3D2DBroadcastTest<armnn::DataType::Float16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 3>
+BatchMatMul3D2DBroadcastTest<armnn::DataType::QAsymmS8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 3>
+BatchMatMul3D2DBroadcastTest<armnn::DataType::QAsymmU8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 3>
+BatchMatMul3D2DBroadcastTest<armnn::DataType::QSymmS16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 5> BatchMatMulNDHWCNHWCTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+ auto descriptor = armnn::BatchMatMulDescriptor(
+ armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NDHWC),
+ armnn::Optional<armnn::DataLayout>(armnn::DataLayout::NHWC));
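+    // The optional DataLayouts mark both inputs as channels-last, so the matrices being
+    // multiplied live in the (H,W) axes, one 2x2 matrix per channel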
+
+ float qScale = 0.0f;
+ int32_t qOffset = 0;
+
+ switch(ArmnnType)
+ {
+ case armnn::DataType::QAsymmS8:
+ case armnn::DataType::QAsymmU8:
+ case armnn::DataType::QSymmS16:
+ qScale = 1.0f;
+ break;
+ default:
+ break;
+ }
+
+ armnn::TensorInfo inputXInfo({1,1,2,2,2}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo inputYInfo({1,2,2,2}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo outputInfo({1,1,2,2,2}, ArmnnType, qScale, qOffset);
+
+ std::vector<T> inputX = armnnUtils::QuantizedVector<T>({
+ 1, 20,
+ 3, 22,
+
+ 2, 21,
+ 4, 23
+ }, qScale, qOffset);
+
+ std::vector<T> inputY = armnnUtils::QuantizedVector<T>({
+ 5, 24,
+ 7, 26,
+
+ 6, 25,
+ 8, 27
+ }, qScale, qOffset);
+
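+    // Channels-last data interleaves the two channels; per channel this computes
+    // [1 3; 2 4] * [5 7; 6 8] = [23 31; 34 46] and [20 22; 21 23] * [24 26; 25 27] = [1030 1114; 1079 1167]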
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
+ 23, 1030,
+ 31, 1114,
+
+ 34, 1079,
+ 46, 1167
+    }, qScale, qOffset);
+
+ return BatchMatMulTestImpl<ArmnnType, T, 5>(workloadFactory,
+ memoryManager,
+ tensorHandleFactory,
+ descriptor,
+ inputX,
+ inputY,
+ outputExpected,
+ inputXInfo,
+ inputYInfo,
+ outputInfo);
+}
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 5>
+BatchMatMulNDHWCNHWCTest<armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 5>
+BatchMatMulNDHWCNHWCTest<armnn::DataType::Float32>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 5>
+BatchMatMulNDHWCNHWCTest<armnn::DataType::Float16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 5>
+BatchMatMulNDHWCNHWCTest<armnn::DataType::QAsymmS8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 5>
+BatchMatMulNDHWCNHWCTest<armnn::DataType::QAsymmU8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 5>
+BatchMatMulNDHWCNHWCTest<armnn::DataType::QSymmS16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 2> BatchMatMul2DTinyTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+    auto descriptor = armnn::BatchMatMulDescriptor(); // Arbitrary layout with no transpose or adjoint
+
+ float qScale = 0.0f;
+ int32_t qOffset = 0;
+
+ switch(ArmnnType)
+ {
+ case armnn::DataType::QAsymmS8:
+ case armnn::DataType::QAsymmU8:
+ case armnn::DataType::QSymmS16:
+ qScale = 1.0f;
+ break;
+ default:
+ break;
+ }
+
+ armnn::TensorInfo inputXInfo({1,1}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo inputYInfo({1,1}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo outputInfo({1,1}, ArmnnType, qScale, qOffset);
+
+ std::vector<T> inputX = armnnUtils::QuantizedVector<T>({
+ 3
+ }, qScale, qOffset);
+
+ std::vector<T> inputY = armnnUtils::QuantizedVector<T>({
+ 5
+ }, qScale, qOffset);
+
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
+ 15
+ }, qScale, qOffset);
+
+ return BatchMatMulTestImpl<ArmnnType, T, 2>(workloadFactory,
+ memoryManager,
+ tensorHandleFactory,
+ descriptor,
+ inputX,
+ inputY,
+ outputExpected,
+ inputXInfo,
+ inputYInfo,
+ outputInfo);
+}
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 2>
+BatchMatMul2DTinyTest<armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 2>
+BatchMatMul2DTinyTest<armnn::DataType::Float32>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 2>
+BatchMatMul2DTinyTest<armnn::DataType::Float16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 2>
+BatchMatMul2DTinyTest<armnn::DataType::QAsymmS8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 2>
+BatchMatMul2DTinyTest<armnn::DataType::QAsymmU8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 2>
+BatchMatMul2DTinyTest<armnn::DataType::QSymmS16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T>
+LayerTestResult<T, 3> BatchMatMul3DNonSquareTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory)
+{
+    auto descriptor = armnn::BatchMatMulDescriptor(); // Arbitrary layout with no transpose or adjoint
+
+ float qScale = 0.0f;
+ int32_t qOffset = 0;
+
+ switch(ArmnnType)
+ {
+ case armnn::DataType::QAsymmS8:
+ case armnn::DataType::QAsymmU8:
+ case armnn::DataType::QSymmS16:
+ qScale = 1.0f;
+ break;
+ default:
+ break;
+ }
+
+ armnn::TensorInfo inputXInfo({2,5,3}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo inputYInfo({2,3,4}, ArmnnType, qScale, qOffset);
+ armnn::TensorInfo outputInfo({2,5,4}, ArmnnType, qScale, qOffset);
+
+ std::vector<T> inputX = armnnUtils::QuantizedVector<T>({
+ 8, 8, 4,
+ 6, 1, 3,
+ 8, 8, 3,
+ 8, 9, 8,
+ 5, 4, 4,
+
+ 1, 8, 5,
+ 7, 1, 1,
+ 8, 7, 9,
+ 3, 2, 7,
+ 8, 5, 3
+ }, qScale, qOffset);
+
+ std::vector<T> inputY = armnnUtils::QuantizedVector<T>({
+ 6, 2, 3, 2,
+ 6, 2, 2, 8,
+ 3, 7, 8, 1,
+
+ 7, 2, 9, 5,
+ 2, 3, 1, 3,
+ 2, 7, 7, 5
+ }, qScale, qOffset);
+
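+    // Hand check (first batch, first row): [8 8 4] against Y's columns
+    // [6 6 3], [2 2 7], [3 2 8], [2 8 1] gives [108 60 72 84]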
+ std::vector<T> outputExpected = armnnUtils::QuantizedVector<T>({
+ 108, 60, 72, 84,
+ 51, 35, 44, 23,
+ 105, 53, 64, 83,
+ 126, 90, 106, 96,
+ 66, 46, 55, 46,
+
+ 33, 61, 52, 54,
+ 53, 24, 71, 43,
+ 88, 100, 142, 106,
+ 39, 61, 78, 56,
+ 72, 52, 98, 70
+    }, qScale, qOffset);
+
+ return BatchMatMulTestImpl<ArmnnType, T, 3>(workloadFactory,
+ memoryManager,
+ tensorHandleFactory,
+ descriptor,
+ inputX,
+ inputY,
+ outputExpected,
+ inputXInfo,
+ inputYInfo,
+ outputInfo);
+}
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::BFloat16>, 3>
+BatchMatMul3DNonSquareTest<armnn::DataType::BFloat16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 3>
+BatchMatMul3DNonSquareTest<armnn::DataType::Float32>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::Float16>, 3>
+BatchMatMul3DNonSquareTest<armnn::DataType::Float16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 3>
+BatchMatMul3DNonSquareTest<armnn::DataType::QAsymmS8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmU8>, 3>
+BatchMatMul3DNonSquareTest<armnn::DataType::QAsymmU8>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template LayerTestResult<armnn::ResolveType<armnn::DataType::QSymmS16>, 3>
+BatchMatMul3DNonSquareTest<armnn::DataType::QSymmS16>(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory); \ No newline at end of file
diff --git a/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp b/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp
new file mode 100644
index 0000000000..9e2139667b
--- /dev/null
+++ b/src/backends/backendsCommon/test/layerTests/BatchMatMulTestImpl.hpp
@@ -0,0 +1,85 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <armnnTestUtils/LayerTestResult.hpp>
+
+#include <ResolveType.hpp>
+
+#include <armnn/backends/IBackendInternal.hpp>
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>, std::size_t NumDims>
+LayerTestResult<T, NumDims> BatchMatMulTestImpl(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory,
+ armnn::BatchMatMulDescriptor descriptor,
+ const std::vector<T>& inputX,
+ const std::vector<T>& inputY,
+ const std::vector<T>& outputExpected,
+ const armnn::TensorInfo& inputXInfo,
+ const armnn::TensorInfo& inputYInfo,
+ const armnn::TensorInfo& outputInfo);
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 2> BatchMatMul2DSimpleTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 3> BatchMatMul3DSimpleTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 4> BatchMatMulNCHWSimpleTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 4> BatchMatMulNHWCSimpleTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 3> BatchMatMul3DBatchTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 3> BatchMatMul3DBroadcastTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 3> BatchMatMul3D2DBroadcastTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 5> BatchMatMulNDHWCNHWCTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 2> BatchMatMul2DTinyTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory);
+
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
+LayerTestResult<T, 3> BatchMatMul3DNonSquareTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
+ const armnn::ITensorHandleFactory& tensorHandleFactory); \ No newline at end of file
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 8051dcffa0..40909019ba 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -79,6 +79,12 @@ bool RefLayerSupport::IsLayerSupported(const LayerType& type,
infos[1],
*(PolymorphicDowncast<const ArgMinMaxDescriptor*>(&descriptor)),
reasonIfUnsupported);
+ case LayerType::BatchMatMul:
+ return IsBatchMatMulSupported(infos[0],
+ infos[1],
+ infos[2],
+ *(PolymorphicDowncast<const BatchMatMulDescriptor*>(&descriptor)),
+ reasonIfUnsupported);
case LayerType::BatchNormalization:
return IsBatchNormalizationSupported(infos[0],
infos[1],
@@ -642,6 +648,52 @@ bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const
return supported;
}
+bool RefLayerSupport::IsBatchMatMulSupported(const TensorInfo& inputX,
+ const TensorInfo& inputY,
+ const TensorInfo& output,
+ const BatchMatMulDescriptor& descriptor,
+                                             Optional<std::string&> reasonIfUnsupported) const
+{
+ IgnoreUnused(descriptor);
+
+ std::array<DataType, 6> supportedTypes =
+ {
+ DataType::BFloat16,
+ DataType::Float16,
+ DataType::Float32,
+ DataType::QAsymmS8,
+ DataType::QAsymmU8,
+ DataType::QSymmS16
+ };
+
+ bool supported = true;
+
+ supported &= CheckSupportRule(TypeAnyOf(inputX, supportedTypes), reasonIfUnsupported,
+ "Reference batch matrix multiplication: input X is not a supported type");
+
+ supported &= CheckSupportRule(TypeAnyOf(inputY, supportedTypes), reasonIfUnsupported,
+ "Reference batch matrix multiplication: input Y is not a supported type");
+
+ supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+ "Reference batch matrix multiplication: output is not a supported type");
+
+ supported &= CheckSupportRule(TypesAreEqual(inputX, inputY), reasonIfUnsupported,
+ "Reference batch matrix multiplication: input X and input Y types are mismatched");
+
+ supported &= CheckSupportRule(TypesAreEqual(inputX, output), reasonIfUnsupported,
+ "Reference batch matrix multiplication: inputs and output types are mismatched");
+
+ supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputX, 2),
+ reasonIfUnsupported,
+ "Reference batch matrix multiplication: input X is not of rank 2 or greater");
+
+ supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputY, 2),
+ reasonIfUnsupported,
+ "Reference batch matrix multiplication: input Y is not of rank 2 or greater");
+
+ return supported;
+}
+
bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
const TensorInfo& output,
const TensorInfo& mean,
diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp
index aa8bd8dda4..b64244db24 100644
--- a/src/backends/reference/RefLayerSupport.hpp
+++ b/src/backends/reference/RefLayerSupport.hpp
@@ -34,6 +34,12 @@ public:
const ArgMinMaxDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ bool IsBatchMatMulSupported(const TensorInfo& inputX,
+ const TensorInfo& inputY,
+ const TensorInfo& output,
+ const BatchMatMulDescriptor& descriptor,
+                                Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const;
+
bool IsBatchNormalizationSupported(const TensorInfo& input,
const TensorInfo& output,
const TensorInfo& mean,
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 2d956582db..093d0d5e20 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -170,6 +170,11 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateWorkload(LayerType type,
auto argMinMaxQueueDescriptor = PolymorphicDowncast<const ArgMinMaxQueueDescriptor*>(&descriptor);
return std::make_unique<RefArgMinMaxWorkload>(*argMinMaxQueueDescriptor, info);
}
+ case LayerType::BatchMatMul:
+ {
+ auto batchMatMulQueueDescriptor = PolymorphicDowncast<const BatchMatMulQueueDescriptor*>(&descriptor);
+ return std::make_unique<RefBatchMatMulWorkload>(*batchMatMulQueueDescriptor, info);
+ }
case LayerType::BatchNormalization :
{
auto batchNormQueueDescriptor = PolymorphicDowncast<const BatchNormalizationQueueDescriptor*>(&descriptor);
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index d9a5a1d32c..ed942e67cd 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -23,6 +23,7 @@ BACKEND_SOURCES := \
RefTensorHandleFactory.cpp \
workloads/Activation.cpp \
workloads/ArgMinMax.cpp \
+ workloads/BatchMatMulImpl.cpp \
workloads/BatchNormImpl.cpp \
workloads/BatchToSpaceNd.cpp \
workloads/Broadcast.cpp \
@@ -49,6 +50,7 @@ BACKEND_SOURCES := \
workloads/Reduce.cpp \
workloads/RefActivationWorkload.cpp \
workloads/RefArgMinMaxWorkload.cpp \
+ workloads/RefBatchMatMulWorkload.cpp \
workloads/RefBatchNormalizationWorkload.cpp \
workloads/RefBatchToSpaceNdWorkload.cpp \
workloads/RefCastWorkload.cpp \
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index 419ae2b0e9..593dc7851e 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -1062,6 +1062,77 @@ ARMNN_AUTO_TEST_CASE_WITH_THF(MultiplicationBroadcast1ElementInt32, Multiplicati
ARMNN_AUTO_TEST_CASE_WITH_THF(MultiplicationBroadcast1DVectorInt32, MultiplicationBroadcast1DVectorInt32Test)
ARMNN_AUTO_TEST_CASE_WITH_THF(Multiplication5d, Multiplication5dTest)
+// Batch Mat Mul
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleBFloat16, BatchMatMul2DSimpleTest<DataType::BFloat16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleFloat32, BatchMatMul2DSimpleTest<DataType::Float32>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleFloat16, BatchMatMul2DSimpleTest<DataType::Float16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleQAsymmS8, BatchMatMul2DSimpleTest<DataType::QAsymmS8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleQAsymmU8, BatchMatMul2DSimpleTest<DataType::QAsymmU8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleQSymmS16, BatchMatMul2DSimpleTest<DataType::QSymmS16>);
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleBFloat16, BatchMatMul3DSimpleTest<DataType::BFloat16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleFloat32, BatchMatMul3DSimpleTest<DataType::Float32>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleFloat16, BatchMatMul3DSimpleTest<DataType::Float16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleQAsymmS8, BatchMatMul3DSimpleTest<DataType::QAsymmS8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleQAsymmU8, BatchMatMul3DSimpleTest<DataType::QAsymmU8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleQSymmS16, BatchMatMul3DSimpleTest<DataType::QSymmS16>);
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleBFloat16, BatchMatMulNCHWSimpleTest<DataType::BFloat16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleFloat32, BatchMatMulNCHWSimpleTest<DataType::Float32>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleFloat16, BatchMatMulNCHWSimpleTest<DataType::Float16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleQAsymmS8, BatchMatMulNCHWSimpleTest<DataType::QAsymmS8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleQAsymmU8, BatchMatMulNCHWSimpleTest<DataType::QAsymmU8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleQSymmS16, BatchMatMulNCHWSimpleTest<DataType::QSymmS16>);
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleBFloat16, BatchMatMulNHWCSimpleTest<DataType::BFloat16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleFloat32, BatchMatMulNHWCSimpleTest<DataType::Float32>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleFloat16, BatchMatMulNHWCSimpleTest<DataType::Float16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleQAsymmS8, BatchMatMulNHWCSimpleTest<DataType::QAsymmS8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleQAsymmU8, BatchMatMulNHWCSimpleTest<DataType::QAsymmU8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleQSymmS16, BatchMatMulNHWCSimpleTest<DataType::QSymmS16>);
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchBFloat16, BatchMatMul3DBatchTest<DataType::BFloat16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchFloat32, BatchMatMul3DBatchTest<DataType::Float32>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchFloat16, BatchMatMul3DBatchTest<DataType::Float16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchQAsymmS8, BatchMatMul3DBatchTest<DataType::QAsymmS8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchQAsymmU8, BatchMatMul3DBatchTest<DataType::QAsymmU8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchQSymmS16, BatchMatMul3DBatchTest<DataType::QSymmS16>);
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastBFloat16, BatchMatMul3DBroadcastTest<DataType::BFloat16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastFloat32, BatchMatMul3DBroadcastTest<DataType::Float32>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastFloat16, BatchMatMul3DBroadcastTest<DataType::Float16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastQAsymmS8, BatchMatMul3DBroadcastTest<DataType::QAsymmS8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastQAsymmU8, BatchMatMul3DBroadcastTest<DataType::QAsymmU8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastQSymmS16, BatchMatMul3DBroadcastTest<DataType::QSymmS16>);
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastBFloat16, BatchMatMul3D2DBroadcastTest<DataType::BFloat16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastFloat32, BatchMatMul3D2DBroadcastTest<DataType::Float32>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastFloat16, BatchMatMul3D2DBroadcastTest<DataType::Float16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastQAsymmS8, BatchMatMul3D2DBroadcastTest<DataType::QAsymmS8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastQAsymmU8, BatchMatMul3D2DBroadcastTest<DataType::QAsymmU8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastQSymmS16, BatchMatMul3D2DBroadcastTest<DataType::QSymmS16>);
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCBFloat16, BatchMatMulNDHWCNHWCTest<DataType::BFloat16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCFloat32, BatchMatMulNDHWCNHWCTest<DataType::Float32>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCFloat16, BatchMatMulNDHWCNHWCTest<DataType::Float16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCQAsymmS8, BatchMatMulNDHWCNHWCTest<DataType::QAsymmS8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCQAsymmU8, BatchMatMulNDHWCNHWCTest<DataType::QAsymmU8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCQSymmS16, BatchMatMulNDHWCNHWCTest<DataType::QSymmS16>);
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyBFloat16, BatchMatMul2DTinyTest<DataType::BFloat16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyFloat32, BatchMatMul2DTinyTest<DataType::Float32>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyFloat16, BatchMatMul2DTinyTest<DataType::Float16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyQAsymmS8, BatchMatMul2DTinyTest<DataType::QAsymmS8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyQAsymmU8, BatchMatMul2DTinyTest<DataType::QAsymmU8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyQSymmS16, BatchMatMul2DTinyTest<DataType::QSymmS16>);
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareBFloat16, BatchMatMul3DNonSquareTest<DataType::BFloat16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareFloat32, BatchMatMul3DNonSquareTest<DataType::Float32>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareFloat16, BatchMatMul3DNonSquareTest<DataType::Float16>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareQAsymmS8, BatchMatMul3DNonSquareTest<DataType::QAsymmS8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareQAsymmU8, BatchMatMul3DNonSquareTest<DataType::QAsymmU8>);
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareQSymmS16, BatchMatMul3DNonSquareTest<DataType::QSymmS16>);
+
// Batch Norm
ARMNN_AUTO_TEST_CASE_WITH_THF(BatchNormFloat32, BatchNormFloat32Test)
ARMNN_AUTO_TEST_CASE_WITH_THF(BatchNormFloat32Nhwc, BatchNormFloat32NhwcTest)
diff --git a/src/backends/reference/workloads/BatchMatMulImpl.cpp b/src/backends/reference/workloads/BatchMatMulImpl.cpp
new file mode 100644
index 0000000000..74a358cc5c
--- /dev/null
+++ b/src/backends/reference/workloads/BatchMatMulImpl.cpp
@@ -0,0 +1,230 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "BatchMatMulImpl.hpp"
+
+#include <armnn/backends/WorkloadData.hpp>
+#include <armnn/Logging.hpp>
+
+namespace armnn
+{
+
+void BatchMatMul::BatchMatMulImpl()
+{
+ inputXData = inputXDecoder.DecodeTensor(inputXInfo.GetShape());
+ inputYData = inputYDecoder.DecodeTensor(inputYInfo.GetShape());
+    // From this point on we only touch the decoded vectors, not the input decoders.
+
+    // Placeholder: pre-transpose and pre-adjoint the inputs when those parameters are set,
+    // along with any DataLayouts that the permutations/adjoints would change.
+
+    // TODO: Update input validation and inferred output shapes to account for these pre-permutes.
+
+ auto idx = std::vector<unsigned int>(outputInfo.GetNumDimensions(), 0);
+ RecurseBMM(idx, 0);
+}
+
+void BatchMatMul::RecurseBMM(std::vector<unsigned int>& curIdx, unsigned int curDim)
+{
+ // We're working off of the indexes of the output tensor (the max possible shape)
+
+    if(curDim >= outputInfo.GetNumDimensions())
+ {
+ // We're at the leaf level of this call tree, so we operate here (each leaf is a data point)
+
+ auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(params,
+ inputXInfo.GetShape(),
+ inputYInfo.GetShape());
+ AdjustAxesToMulForUnequalRanks(axesToMul);
+
+ unsigned int inputXColDim = axesToMul.first.second;
+ unsigned int inputYRowDim = axesToMul.second.first;
+
+ unsigned int inputYRowSize = inputYInfo.GetShape()[inputYRowDim];
+
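+        // Accumulate the dot product of X's row and Y's column at the indices fixed by curIdx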
+ float sum = 0.0f;
+
+        // inputXColSize would serve equally well here; the contracted dimensions are equal
+        for (unsigned int inputYRowIdx = 0; inputYRowIdx < inputYRowSize; inputYRowIdx++)
+        {
+ auto xIdx = curIdx;
+ xIdx[inputXColDim] = inputYRowIdx;
+
+ auto yIdx = curIdx;
+ yIdx[inputYRowDim] = inputYRowIdx;
+
+ sum += (GetValueAt(DataSlot::InputX, xIdx)
+ * GetValueAt(DataSlot::InputY, yIdx));
+ }
+
+ SetValueAt(sum, DataSlot::Output, curIdx);
+
+ return;
+ }
+
+ for (unsigned int i = 0; i < outputInfo.GetShape()[curDim]; i++)
+ {
+ curIdx[curDim] = i;
+ RecurseBMM(curIdx, curDim+1);
+ }
+}
+
+void BatchMatMul::AdjustAxesToMulForUnequalRanks(
+ std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>>& axesToMul)
+{
+ long rankDiff = static_cast<long>(inputXInfo.GetNumDimensions()) - inputYInfo.GetNumDimensions();
+ if(rankDiff == 0)
+ {
+ return;
+ }
+ else if(rankDiff < 0)
+ {
+ // Y is the larger one
+        axesToMul.first.first += static_cast<unsigned int>(std::abs(rankDiff));
+        axesToMul.first.second += static_cast<unsigned int>(std::abs(rankDiff));
+ }
+ else if(rankDiff > 0)
+ {
+ // X is the larger one
+        axesToMul.second.first += static_cast<unsigned int>(std::abs(rankDiff));
+        axesToMul.second.second += static_cast<unsigned int>(std::abs(rankDiff));
+ }
+}
+
+float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx)
+{
+    // Inputs are read from the decoded data vectors rather than the decoders;
+    // the output, however, is read back through the encoder itself.
+
+ AdjustToSafeIdx(type, idx);
+ unsigned int flatIdx = CalcFlatIdx(type, idx);
+ float value = 0.0f;
+
+ switch(type)
+ {
+ case DataSlot::InputX:
+ value = inputXData[flatIdx];
+ break;
+ case DataSlot::InputY:
+ value = inputYData[flatIdx];
+ break;
+ case DataSlot::Output:
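+            // operator[] repositions the encoder at flatIdx before reading back via Get()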
+ outputEncoder[flatIdx];
+ value = outputEncoder.Get();
+ break;
+ default:
+ break;
+ }
+
+ return value;
+}
+
+void BatchMatMul::SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx)
+{
+ AdjustToSafeIdx(type, idx);
+
+ unsigned int flatIdx = CalcFlatIdx(type, idx);
+
+ switch(type)
+ {
+ case DataSlot::InputX:
+ inputXData[flatIdx] = value;
+ break;
+ case DataSlot::InputY:
+ inputYData[flatIdx] = value;
+ break;
+ case DataSlot::Output:
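+            // operator[] repositions the encoder at flatIdx before writing via Set()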
+ outputEncoder[flatIdx];
+ outputEncoder.Set(value);
+ break;
+ default:
+ break;
+ }
+}
+
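+// Example: with output shape {2,2,2} and input Y of shape {1,2,2}, an index {1,i,j} is
+// clamped to {0,i,j} (broadcast); for a rank-2 Y of shape {2,2}, the missing leading
+// batch dimension is likewise pinned to 0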
+void BatchMatMul::AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx)
+{
+ for(unsigned int dim = 0; dim < idx.size(); dim++)
+ {
+ switch(type)
+ {
+ case DataSlot::InputX:
+ {
+ auto xRank = inputXInfo.GetNumDimensions();
+ auto xDiff = outputInfo.GetNumDimensions() - xRank;
+ if (dim < xDiff ||
+ idx[dim] > inputXInfo.GetShape()[dim-xDiff]-1)
+ {
+ idx[dim] = 0; // Broadcasting
+ }
+ break;
+ }
+ case DataSlot::InputY:
+ {
+ auto yRank = inputYInfo.GetNumDimensions();
+ auto yDiff = outputInfo.GetNumDimensions() - yRank;
+ if (dim < yDiff ||
+ idx[dim] > inputYInfo.GetShape()[dim-yDiff]-1)
+ {
+ idx[dim] = 0;
+ }
+ break;
+ }
+ case DataSlot::Output:
+ {
+ // Our indices are based off the output
+ break;
+ }
+ default:
+ break;
+ }
+ }
+}
+
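+// Row-major flattening that skips any leading output dimensions a lower-rank input lacks.
+// Example: output shape {2,2,2} and idx {1,0,1} flatten to 1 + 0*2 + 1*(2*2) = 5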
+unsigned int BatchMatMul::CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx)
+{
+ unsigned int result = idx[idx.size()-1];
+
+ unsigned int dimMultiplier = 1;
+
+    unsigned int offset = 0;
+
+    // Start at idx.size()-2: the final dimension is already in 'result' above, with an implicit stride of 1
+ for(unsigned int i = static_cast<unsigned int>(idx.size()-2); static_cast<int>(i) >= 0; i--)
+ {
+ switch(type)
+ {
+ case DataSlot::InputX:
+ offset = outputInfo.GetNumDimensions() - inputXInfo.GetNumDimensions();
+ dimMultiplier *= inputXInfo.GetShape()[i + 1 - offset];
+ break;
+ case DataSlot::InputY:
+ offset = outputInfo.GetNumDimensions() - inputYInfo.GetNumDimensions();
+ dimMultiplier *= inputYInfo.GetShape()[i + 1 - offset];
+ break;
+ case DataSlot::Output:
+ dimMultiplier *= outputInfo.GetShape()[i+1];
+ break;
+ default:
+ break;
+ }
+ result += (idx[i] * dimMultiplier);
+ }
+ return result;
+}
+
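+// Debug helper; as a template defined only in this .cpp, it is only instantiable from
+// within this translation unit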
+template <typename T>
+std::string BatchMatMul::StringifyVec(const std::vector<T>& vec)
+{
+ std::string res = "{ ";
+ for(auto x : vec)
+ {
+ res += std::to_string(x);
+ res += " ";
+ }
+ res += "}";
+ return res;
+}
+
+} // namespace armnn \ No newline at end of file
diff --git a/src/backends/reference/workloads/BatchMatMulImpl.hpp b/src/backends/reference/workloads/BatchMatMulImpl.hpp
new file mode 100644
index 0000000000..25b6c85d77
--- /dev/null
+++ b/src/backends/reference/workloads/BatchMatMulImpl.hpp
@@ -0,0 +1,75 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "Encoders.hpp"
+#include "Decoders.hpp"
+
+#include <armnn/backends/WorkloadData.hpp>
+
+namespace armnn
+{
+
+class BatchMatMul {
+public:
+ enum DataSlot
+ {
+ InputX = 0,
+ InputY = 1,
+ Output = 2
+ };
+
+ BatchMatMul(const BatchMatMulDescriptor& params,
+ const TensorInfo& inputXInfo,
+ const TensorInfo& inputYInfo,
+ const TensorInfo& outputInfo,
+ Decoder<float>& inputXDecoder,
+ Decoder<float>& inputYDecoder,
+ Encoder<float>& outputEncoder)
+ : params(params),
+ inputXInfo(inputXInfo),
+ inputYInfo(inputYInfo),
+ outputInfo(outputInfo),
+ inputXDecoder(inputXDecoder),
+ inputYDecoder(inputYDecoder),
+ outputEncoder(outputEncoder)
+ {}
+
+ void BatchMatMulImpl();
+
+ void RecurseBMM(std::vector<unsigned int>& curIdx, unsigned int curDim);
+
+ // Adjusts it for when input tensors are of unequal rank
+ void AdjustAxesToMulForUnequalRanks(
+ std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>>& axesToMul);
+
+ float GetValueAt(DataSlot type, std::vector<unsigned int> idx);
+
+ void SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx);
+
+ // Takes into account broadcasting
+ void AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx);
+
+ unsigned int CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx);
+
+ template <typename T>
+ std::string StringifyVec(const std::vector<T>& vec);
+
+private:
+ const BatchMatMulDescriptor& params;
+ const TensorInfo& inputXInfo;
+ const TensorInfo& inputYInfo;
+ const TensorInfo& outputInfo;
+ Decoder<float>& inputXDecoder;
+ Decoder<float>& inputYDecoder;
+ Encoder<float>& outputEncoder;
+
+ std::vector<float> inputXData;
+ std::vector<float> inputYData;
+
+};
+
+} // namespace armnn \ No newline at end of file
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index b1f6d8b250..b8835e3cdb 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -10,6 +10,8 @@ list(APPEND armnnRefBackendWorkloads_sources
ArgMinMax.cpp
ArgMinMax.hpp
BaseIterator.hpp
+ BatchMatMulImpl.cpp
+ BatchMatMulImpl.hpp
BatchNormImpl.cpp
BatchNormImpl.hpp
BatchToSpaceNd.cpp
@@ -69,6 +71,8 @@ list(APPEND armnnRefBackendWorkloads_sources
RefArgMinMaxWorkload.cpp
RefArgMinMaxWorkload.hpp
RefBaseWorkload.hpp
+ RefBatchMatMulWorkload.cpp
+ RefBatchMatMulWorkload.hpp
RefBatchNormalizationWorkload.cpp
RefBatchNormalizationWorkload.hpp
RefBatchToSpaceNdWorkload.cpp
diff --git a/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp b/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp
new file mode 100644
index 0000000000..388190c4ef
--- /dev/null
+++ b/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp
@@ -0,0 +1,59 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefBatchMatMulWorkload.hpp"
+
+#include "BatchMatMulImpl.hpp"
+#include "RefWorkloadUtils.hpp"
+#include "Profiling.hpp"
+
+namespace armnn
+{
+
+RefBatchMatMulWorkload::RefBatchMatMulWorkload(const BatchMatMulQueueDescriptor& descriptor, const WorkloadInfo& info)
+ : RefBaseWorkload(descriptor, info)
+{}
+
+void RefBatchMatMulWorkload::Execute() const
+{
+ Execute(m_Data.m_Inputs, m_Data.m_Outputs);
+}
+
+void RefBatchMatMulWorkload::ExecuteAsync(ExecutionData& executionData)
+{
+ WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
+ Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
+}
+
+void RefBatchMatMulWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
+{
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefBatchMatMulWorkload_Execute");
+
+ const TensorInfo& inputXInfo = GetTensorInfo(inputs[0]);
+ const TensorInfo& inputYInfo = GetTensorInfo(inputs[1]);
+ const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
+
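+    // The reference path computes in float: the decoders and encoder take care of
+    // dequantizing and requantizing the quantized data types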
+    std::unique_ptr<Decoder<float>> inputXDecoder = MakeDecoder<float>(inputXInfo, inputs[0]->Map());
+
+    std::unique_ptr<Decoder<float>> inputYDecoder = MakeDecoder<float>(inputYInfo, inputs[1]->Map());
+
+    std::unique_ptr<Encoder<float>> outputEncoder = MakeEncoder<float>(outputInfo, outputs[0]->Map());
+
+ auto bmm = BatchMatMul(m_Data.m_Parameters,
+ inputXInfo,
+ inputYInfo,
+ outputInfo,
+ *inputXDecoder,
+ *inputYDecoder,
+ *outputEncoder);
+
+ bmm.BatchMatMulImpl();
+}
+
+} // namespace armnn \ No newline at end of file
diff --git a/src/backends/reference/workloads/RefBatchMatMulWorkload.hpp b/src/backends/reference/workloads/RefBatchMatMulWorkload.hpp
new file mode 100644
index 0000000000..e9dfcaef73
--- /dev/null
+++ b/src/backends/reference/workloads/RefBatchMatMulWorkload.hpp
@@ -0,0 +1,30 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "RefBaseWorkload.hpp"
+#include <armnn/backends/WorkloadData.hpp>
+
+#include "BatchMatMulImpl.hpp"
+
+namespace armnn
+{
+
+class RefBatchMatMulWorkload : public RefBaseWorkload<BatchMatMulQueueDescriptor>
+{
+public:
+ explicit RefBatchMatMulWorkload(const BatchMatMulQueueDescriptor& descriptor,
+ const WorkloadInfo& info);
+
+ void Execute() const override;
+ void ExecuteAsync(ExecutionData& executionData) override;
+
+private:
+ void Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const;
+
+};
+
+} // namespace armnn \ No newline at end of file
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index b9c7a2a1fb..e049d8db2c 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -7,6 +7,7 @@
#include "RefActivationWorkload.hpp"
#include "RefArgMinMaxWorkload.hpp"
+#include "RefBatchMatMulWorkload.hpp"
#include "RefBatchNormalizationWorkload.hpp"
#include "RefBatchToSpaceNdWorkload.hpp"
#include "RefCastWorkload.hpp"