aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference')
-rw-r--r--src/backends/reference/RefLayerSupport.cpp52
-rw-r--r--src/backends/reference/RefLayerSupport.hpp6
-rw-r--r--src/backends/reference/RefWorkloadFactory.cpp5
-rw-r--r--src/backends/reference/backend.mk2
-rw-r--r--src/backends/reference/test/RefLayerTests.cpp71
-rw-r--r--src/backends/reference/workloads/BatchMatMulImpl.cpp230
-rw-r--r--src/backends/reference/workloads/BatchMatMulImpl.hpp75
-rw-r--r--src/backends/reference/workloads/CMakeLists.txt4
-rw-r--r--src/backends/reference/workloads/RefBatchMatMulWorkload.cpp59
-rw-r--r--src/backends/reference/workloads/RefBatchMatMulWorkload.hpp30
-rw-r--r--src/backends/reference/workloads/RefWorkloads.hpp1
11 files changed, 535 insertions, 0 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 8051dcffa0..40909019ba 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -79,6 +79,12 @@ bool RefLayerSupport::IsLayerSupported(const LayerType& type,
infos[1],
*(PolymorphicDowncast<const ArgMinMaxDescriptor*>(&descriptor)),
reasonIfUnsupported);
+ case LayerType::BatchMatMul:
+ return IsBatchMatMulSupported(infos[0],
+ infos[1],
+ infos[2],
+ *(PolymorphicDowncast<const BatchMatMulDescriptor*>(&descriptor)),
+ reasonIfUnsupported);
case LayerType::BatchNormalization:
return IsBatchNormalizationSupported(infos[0],
infos[1],
@@ -642,6 +648,52 @@ bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const
return supported;
}
+// Decides whether the reference backend can execute BatchMatMul for the given tensors.
+// Every rule is evaluated, even after one has failed; each rule is handed reasonIfUnsupported
+// together with its own message. The descriptor is not consulted.
+// Returns true only when all rules pass.
+bool RefLayerSupport::IsBatchMatMulSupported(const TensorInfo& inputX,
+                                             const TensorInfo& inputY,
+                                             const TensorInfo& output,
+                                             const BatchMatMulDescriptor& descriptor,
+                                             Optional<std::string &> reasonIfUnsupported) const
+{
+    IgnoreUnused(descriptor);
+
+    // Data types accepted for both inputs and for the output
+    std::array<DataType, 6> allowedTypes =
+    {
+        DataType::BFloat16,
+        DataType::Float16,
+        DataType::Float32,
+        DataType::QAsymmS8,
+        DataType::QAsymmU8,
+        DataType::QSymmS16
+    };
+
+    // Each tensor must individually use an accepted data type
+    bool isSupported = CheckSupportRule(TypeAnyOf(inputX, allowedTypes), reasonIfUnsupported,
+        "Reference batch matrix multiplication: input X is not a supported type");
+
+    isSupported &= CheckSupportRule(TypeAnyOf(inputY, allowedTypes), reasonIfUnsupported,
+        "Reference batch matrix multiplication: input Y is not a supported type");
+
+    isSupported &= CheckSupportRule(TypeAnyOf(output, allowedTypes), reasonIfUnsupported,
+        "Reference batch matrix multiplication: output is not a supported type");
+
+    // All three tensors must share one data type
+    isSupported &= CheckSupportRule(TypesAreEqual(inputX, inputY), reasonIfUnsupported,
+        "Reference batch matrix multiplication: input X and input Y types are mismatched");
+
+    isSupported &= CheckSupportRule(TypesAreEqual(inputX, output), reasonIfUnsupported,
+        "Reference batch matrix multiplication: inputs and output types are mismatched");
+
+    // Each input needs at least two dimensions
+    isSupported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputX, 2),
+        reasonIfUnsupported,
+        "Reference batch matrix multiplication: input X is not of rank 2 or greater");
+
+    isSupported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputY, 2),
+        reasonIfUnsupported,
+        "Reference batch matrix multiplication: input Y is not of rank 2 or greater");
+
+    return isSupported;
+}
+
bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
const TensorInfo& output,
const TensorInfo& mean,
diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp
index aa8bd8dda4..b64244db24 100644
--- a/src/backends/reference/RefLayerSupport.hpp
+++ b/src/backends/reference/RefLayerSupport.hpp
@@ -34,6 +34,12 @@ public:
const ArgMinMaxDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ // Reports whether the reference backend supports BatchMatMul for the given input and output
+ // tensor infos. reasonIfUnsupported is optional and defaults to EmptyOptional().
+ // Unlike the neighbouring declarations this one is not marked override.
+ bool IsBatchMatMulSupported(const TensorInfo& inputX,
+ const TensorInfo& inputY,
+ const TensorInfo& output,
+ const BatchMatMulDescriptor& descriptor,
+ Optional<std::string &> reasonIfUnsupported = EmptyOptional()) const;
+
bool IsBatchNormalizationSupported(const TensorInfo& input,
const TensorInfo& output,
const TensorInfo& mean,
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 2d956582db..093d0d5e20 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -170,6 +170,11 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateWorkload(LayerType type,
auto argMinMaxQueueDescriptor = PolymorphicDowncast<const ArgMinMaxQueueDescriptor*>(&descriptor);
return std::make_unique<RefArgMinMaxWorkload>(*argMinMaxQueueDescriptor, info);
}
+ case LayerType::BatchMatMul:
+ {
+ auto batchMatMulQueueDescriptor = PolymorphicDowncast<const BatchMatMulQueueDescriptor*>(&descriptor);
+ return std::make_unique<RefBatchMatMulWorkload>(*batchMatMulQueueDescriptor, info);
+ }
case LayerType::BatchNormalization :
{
auto batchNormQueueDescriptor = PolymorphicDowncast<const BatchNormalizationQueueDescriptor*>(&descriptor);
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index d9a5a1d32c..ed942e67cd 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -23,6 +23,7 @@ BACKEND_SOURCES := \
RefTensorHandleFactory.cpp \
workloads/Activation.cpp \
workloads/ArgMinMax.cpp \
+ workloads/BatchMatMulImpl.cpp \
workloads/BatchNormImpl.cpp \
workloads/BatchToSpaceNd.cpp \
workloads/Broadcast.cpp \
@@ -49,6 +50,7 @@ BACKEND_SOURCES := \
workloads/Reduce.cpp \
workloads/RefActivationWorkload.cpp \
workloads/RefArgMinMaxWorkload.cpp \
+ workloads/RefBatchMatMulWorkload.cpp \
workloads/RefBatchNormalizationWorkload.cpp \
workloads/RefBatchToSpaceNdWorkload.cpp \
workloads/RefCastWorkload.cpp \
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index 419ae2b0e9..593dc7851e 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -1062,6 +1062,77 @@ ARMNN_AUTO_TEST_CASE_WITH_THF(MultiplicationBroadcast1ElementInt32, Multiplicati
ARMNN_AUTO_TEST_CASE_WITH_THF(MultiplicationBroadcast1DVectorInt32, MultiplicationBroadcast1DVectorInt32Test)
ARMNN_AUTO_TEST_CASE_WITH_THF(Multiplication5d, Multiplication5dTest)
+// Batch Mat Mul
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleBFloat16, BatchMatMul2DSimpleTest<DataType::BFloat16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleFloat32, BatchMatMul2DSimpleTest<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleFloat16, BatchMatMul2DSimpleTest<DataType::Float16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleQAsymmS8, BatchMatMul2DSimpleTest<DataType::QAsymmS8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleQAsymmU8, BatchMatMul2DSimpleTest<DataType::QAsymmU8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleQASymmS16, BatchMatMul2DSimpleTest<DataType::QSymmS16>)
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleBFloat16, BatchMatMul3DSimpleTest<DataType::BFloat16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleFloat32, BatchMatMul3DSimpleTest<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleFloat16, BatchMatMul3DSimpleTest<DataType::Float16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleQAsymmS8, BatchMatMul3DSimpleTest<DataType::QAsymmS8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleQAsymmU8, BatchMatMul3DSimpleTest<DataType::QAsymmU8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleQASymmS16, BatchMatMul3DSimpleTest<DataType::QSymmS16>)
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleBFloat16, BatchMatMulNCHWSimpleTest<DataType::BFloat16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleFloat32, BatchMatMulNCHWSimpleTest<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleFloat16, BatchMatMulNCHWSimpleTest<DataType::Float16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleQAsymmS8, BatchMatMulNCHWSimpleTest<DataType::QAsymmS8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleQAsymmU8, BatchMatMulNCHWSimpleTest<DataType::QAsymmU8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleQASymmS16, BatchMatMulNCHWSimpleTest<DataType::QSymmS16>)
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleBFloat16, BatchMatMulNHWCSimpleTest<DataType::BFloat16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleFloat32, BatchMatMulNHWCSimpleTest<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleFloat16, BatchMatMulNHWCSimpleTest<DataType::Float16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleQAsymmS8, BatchMatMulNHWCSimpleTest<DataType::QAsymmS8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleQAsymmU8, BatchMatMulNHWCSimpleTest<DataType::QAsymmU8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleQASymmS16, BatchMatMulNHWCSimpleTest<DataType::QSymmS16>)
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchBFloat16, BatchMatMul3DBatchTest<DataType::BFloat16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchFloat32, BatchMatMul3DBatchTest<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchFloat16, BatchMatMul3DBatchTest<DataType::Float16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchQAsymmS8, BatchMatMul3DBatchTest<DataType::QAsymmS8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchQAsymmU8, BatchMatMul3DBatchTest<DataType::QAsymmU8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchQASymmS16, BatchMatMul3DBatchTest<DataType::QSymmS16>)
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastBFloat16, BatchMatMul3DBroadcastTest<DataType::BFloat16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastFloat32, BatchMatMul3DBroadcastTest<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastFloat16, BatchMatMul3DBroadcastTest<DataType::Float16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastQAsymmS8, BatchMatMul3DBroadcastTest<DataType::QAsymmS8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastQAsymmU8, BatchMatMul3DBroadcastTest<DataType::QAsymmU8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastQASymmS16, BatchMatMul3DBroadcastTest<DataType::QSymmS16>)
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastBFloat16, BatchMatMul3D2DBroadcastTest<DataType::BFloat16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastFloat32, BatchMatMul3D2DBroadcastTest<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastFloat16, BatchMatMul3D2DBroadcastTest<DataType::Float16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastQAsymmS8, BatchMatMul3D2DBroadcastTest<DataType::QAsymmS8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastQAsymmU8, BatchMatMul3D2DBroadcastTest<DataType::QAsymmU8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastQASymmS16, BatchMatMul3D2DBroadcastTest<DataType::QSymmS16>)
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCBFloat16, BatchMatMulNDHWCNHWCTest<DataType::BFloat16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCFloat32, BatchMatMulNDHWCNHWCTest<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCFloat16, BatchMatMulNDHWCNHWCTest<DataType::Float16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCQAsymmS8, BatchMatMulNDHWCNHWCTest<DataType::QAsymmS8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCQAsymmU8, BatchMatMulNDHWCNHWCTest<DataType::QAsymmU8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCQASymmS16, BatchMatMulNDHWCNHWCTest<DataType::QSymmS16>)
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyBFloat16, BatchMatMul2DTinyTest<DataType::BFloat16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyFloat32, BatchMatMul2DTinyTest<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyFloat16, BatchMatMul2DTinyTest<DataType::Float16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyQAsymmS8, BatchMatMul2DTinyTest<DataType::QAsymmS8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyQAsymmU8, BatchMatMul2DTinyTest<DataType::QAsymmU8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyQASymmS16, BatchMatMul2DTinyTest<DataType::QSymmS16>)
+
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareBFloat16, BatchMatMul3DNonSquareTest<DataType::BFloat16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareFloat32, BatchMatMul3DNonSquareTest<DataType::Float32>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareFloat16, BatchMatMul3DNonSquareTest<DataType::Float16>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareQAsymmS8, BatchMatMul3DNonSquareTest<DataType::QAsymmS8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareQAsymmU8, BatchMatMul3DNonSquareTest<DataType::QAsymmU8>)
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareQASymmS16, BatchMatMul3DNonSquareTest<DataType::QSymmS16>)
+
// Batch Norm
ARMNN_AUTO_TEST_CASE_WITH_THF(BatchNormFloat32, BatchNormFloat32Test)
ARMNN_AUTO_TEST_CASE_WITH_THF(BatchNormFloat32Nhwc, BatchNormFloat32NhwcTest)
diff --git a/src/backends/reference/workloads/BatchMatMulImpl.cpp b/src/backends/reference/workloads/BatchMatMulImpl.cpp
new file mode 100644
index 0000000000..74a358cc5c
--- /dev/null
+++ b/src/backends/reference/workloads/BatchMatMulImpl.cpp
@@ -0,0 +1,230 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "BatchMatMulImpl.hpp"
+
+#include <armnn/backends/WorkloadData.hpp>
+#include <armnn/Logging.hpp>
+
+namespace armnn
+{
+
+// Entry point: decodes both inputs into float vectors, then walks every output element
+// starting from the all-zero index.
+void BatchMatMul::BatchMatMulImpl()
+{
+    // From here on only the decoded vectors are read; the decoders themselves are left alone.
+    inputXData = inputXDecoder.DecodeTensor(inputXInfo.GetShape());
+    inputYData = inputYDecoder.DecodeTensor(inputYInfo.GetShape());
+
+    // TODO: pre-transpose / pre-adjoint of the inputs (and the data layout changes those imply)
+    // are not applied here yet; input validation and output shape inference must be updated to
+    // match once they are.
+
+    // One coordinate per output dimension, all starting at zero
+    std::vector<unsigned int> outputIdx(outputInfo.GetNumDimensions(), 0);
+    RecurseBMM(outputIdx, 0);
+}
+
+// Visits every element of the output tensor and stores the dot product that defines it.
+//
+// curIdx holds the output coordinates chosen so far and curDim is the output dimension being
+// enumerated by this call. Once curDim has moved past the last output dimension, curIdx names a
+// single output element, which is computed and written through SetValueAt.
+void BatchMatMul::RecurseBMM(std::vector<unsigned int>& curIdx, unsigned int curDim)
+{
+    if (curDim < outputInfo.GetNumDimensions())
+    {
+        // Not at a leaf yet: enumerate this dimension and descend
+        for (unsigned int i = 0; i < outputInfo.GetShape()[curDim]; i++)
+        {
+            curIdx[curDim] = i;
+            RecurseBMM(curIdx, curDim + 1);
+        }
+        return;
+    }
+
+    // Leaf of the call tree: curIdx identifies exactly one output element
+    auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(params,
+                                                        inputXInfo.GetShape(),
+                                                        inputYInfo.GetShape());
+
+    // Read the reduction length while the Y row axis still indexes Y's own shape.
+    // AdjustAxesToMulForUnequalRanks shifts the axes of the lower-rank input so they can be used
+    // on curIdx; using the shifted axis on inputYInfo.GetShape() picks the wrong (or an
+    // out-of-range) dimension whenever X has more dimensions than Y.
+    const unsigned int inputYRowSize = inputYInfo.GetShape()[axesToMul.second.first];
+
+    AdjustAxesToMulForUnequalRanks(axesToMul);
+
+    const unsigned int inputXColDim = axesToMul.first.second;
+    const unsigned int inputYRowDim = axesToMul.second.first;
+
+    // Both index vectors start as the output coordinates; only the reduced axis changes per
+    // step, so they are copied once here rather than once per multiply-accumulate.
+    auto xIdx = curIdx;
+    auto yIdx = curIdx;
+
+    float sum = 0.0f;
+
+    // The X column size could be used as the bound instead
+    for (unsigned int inputYRowIdx = 0; inputYRowIdx < inputYRowSize; inputYRowIdx++)
+    {
+        xIdx[inputXColDim] = inputYRowIdx;
+        yIdx[inputYRowDim] = inputYRowIdx;
+
+        sum += (GetValueAt(DataSlot::InputX, xIdx)
+              * GetValueAt(DataSlot::InputY, yIdx));
+    }
+
+    SetValueAt(sum, DataSlot::Output, curIdx);
+}
+
+// Shifts the multiplication axes of whichever input has fewer dimensions by the difference in
+// rank between the two inputs. axesToMul.first holds the X axes and axesToMul.second the Y axes;
+// the pair belonging to the higher-rank input is left untouched, and nothing changes when the
+// ranks are equal.
+// NOTE(review): callers use the shifted axes to index output-rank coordinate vectors, which
+// presumes the output rank equals the larger input rank - confirm.
+void BatchMatMul::AdjustAxesToMulForUnequalRanks(
+    std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>>& axesToMul)
+{
+    const unsigned int xRank = inputXInfo.GetNumDimensions();
+    const unsigned int yRank = inputYInfo.GetNumDimensions();
+
+    if (yRank > xRank)
+    {
+        // Y is the larger one, so the X axes move right
+        const unsigned int shift = yRank - xRank;
+        axesToMul.first.first += shift;
+        axesToMul.first.second += shift;
+    }
+    else if (xRank > yRank)
+    {
+        // X is the larger one, so the Y axes move right
+        const unsigned int shift = xRank - yRank;
+        axesToMul.second.first += shift;
+        axesToMul.second.second += shift;
+    }
+}
+
+// Returns the value stored at idx in the tensor selected by type.
+// Inputs are read from the decoded vectors rather than from the decoders; the output is read
+// through the encoder itself. idx is taken by value because AdjustToSafeIdx rewrites it.
+// An unrecognised slot yields 0.0f.
+float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx)
+{
+    AdjustToSafeIdx(type, idx);
+    const unsigned int flatIdx = CalcFlatIdx(type, idx);
+
+    switch(type)
+    {
+        case DataSlot::InputX:
+            return inputXData[flatIdx];
+        case DataSlot::InputY:
+            return inputYData[flatIdx];
+        case DataSlot::Output:
+            // Select the element first, then read it back (presumably operator[] repositions
+            // the encoder - confirm against Encoder)
+            outputEncoder[flatIdx];
+            return outputEncoder.Get();
+        default:
+            return 0.0f;
+    }
+}
+
+// Stores value at idx in the tensor selected by type.
+// Inputs are written into the decoded vectors; the output is written through the encoder.
+// idx is taken by value because AdjustToSafeIdx rewrites it. Unrecognised slots are ignored.
+void BatchMatMul::SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx)
+{
+    AdjustToSafeIdx(type, idx);
+    const unsigned int flatIdx = CalcFlatIdx(type, idx);
+
+    if (type == DataSlot::InputX)
+    {
+        inputXData[flatIdx] = value;
+    }
+    else if (type == DataSlot::InputY)
+    {
+        inputYData[flatIdx] = value;
+    }
+    else if (type == DataSlot::Output)
+    {
+        // Select the element first, then write it
+        outputEncoder[flatIdx];
+        outputEncoder.Set(value);
+    }
+}
+
+// Rewrites idx in place so it can address the input tensor selected by type (broadcasting).
+// A coordinate is reset to 0 when its position has no counterpart in that input (the input has
+// fewer dimensions than the output) or when it exceeds the input's last valid index in that
+// dimension. For any other slot idx is left untouched.
+void BatchMatMul::AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx)
+{
+    const TensorInfo* slotInfo = nullptr;
+    switch(type)
+    {
+        case DataSlot::InputX:
+            slotInfo = &inputXInfo;
+            break;
+        case DataSlot::InputY:
+            slotInfo = &inputYInfo;
+            break;
+        default:
+            // Indices are based off the output, so there is nothing to adjust
+            return;
+    }
+
+    // Number of leading output dimensions that this input does not have
+    const auto rankGap = outputInfo.GetNumDimensions() - slotInfo->GetNumDimensions();
+    const auto& slotShape = slotInfo->GetShape();
+
+    for (unsigned int dim = 0; dim < idx.size(); dim++)
+    {
+        // The second operand is only evaluated when dim has a counterpart in the input
+        if (dim < rankGap || idx[dim] > slotShape[dim - rankGap] - 1)
+        {
+            idx[dim] = 0;
+        }
+    }
+}
+
+// Converts idx, given in output-rank coordinates, into a flat element offset for the tensor
+// selected by type; the last dimension has a stride of 1.
+//
+// When the selected tensor has fewer dimensions than the output, the leading positions of idx
+// have no counterpart in it and do not contribute to the result. Stopping there also keeps the
+// shape lookup in range: the previous loop evaluated i + 1 - offset for every position, which
+// wraps around once the tensor is two or more dimensions short of the output.
+// Returns 0 for an empty idx.
+unsigned int BatchMatMul::CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx)
+{
+    if (idx.empty())
+    {
+        return 0;
+    }
+
+    const TensorInfo* slotInfo = &outputInfo;
+    switch(type)
+    {
+        case DataSlot::InputX:
+            slotInfo = &inputXInfo;
+            break;
+        case DataSlot::InputY:
+            slotInfo = &inputYInfo;
+            break;
+        default:
+            break;
+    }
+
+    const auto& slotShape = slotInfo->GetShape();
+
+    // Number of leading output dimensions that the selected tensor does not have
+    const unsigned int leadingGap = outputInfo.GetNumDimensions() - slotInfo->GetNumDimensions();
+
+    unsigned int flatIdx = idx.back();
+    unsigned int stride = 1;
+
+    // pos is the position whose extent is folded into the stride of position pos - 1
+    for (unsigned int pos = static_cast<unsigned int>(idx.size()) - 1; pos > leadingGap; --pos)
+    {
+        stride *= slotShape[pos - leadingGap];
+        flatIdx += idx[pos - 1] * stride;
+    }
+    return flatIdx;
+}
+
+// Formats vec as "{ a b c }": each element is converted with std::to_string and followed by a
+// single space. An empty vector gives "{ }".
+template <typename T>
+std::string BatchMatMul::StringifyVec(const std::vector<T>& vec)
+{
+    std::string text("{ ");
+    for (const auto& element : vec)
+    {
+        text += std::to_string(element) + " ";
+    }
+    text += "}";
+    return text;
+}
+
+} // namespace armnn \ No newline at end of file
diff --git a/src/backends/reference/workloads/BatchMatMulImpl.hpp b/src/backends/reference/workloads/BatchMatMulImpl.hpp
new file mode 100644
index 0000000000..25b6c85d77
--- /dev/null
+++ b/src/backends/reference/workloads/BatchMatMulImpl.hpp
@@ -0,0 +1,75 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "Encoders.hpp"
+#include "Decoders.hpp"
+
+#include <armnn/backends/WorkloadData.hpp>
+
+namespace armnn
+{
+
+// Reference batched matrix multiplication operating on float-decoded tensor data.
+// The object keeps references to everything passed to its constructor, so those arguments
+// must outlive it.
+class BatchMatMul {
+public:
+ // Selects which of the three tensors an index or value refers to
+ enum DataSlot
+ {
+ InputX = 0,
+ InputY = 1,
+ Output = 2
+ };
+
+ // Stores references to the descriptor, the tensor infos, the two input decoders and the
+ // output encoder; nothing is copied or computed here.
+ BatchMatMul(const BatchMatMulDescriptor& params,
+ const TensorInfo& inputXInfo,
+ const TensorInfo& inputYInfo,
+ const TensorInfo& outputInfo,
+ Decoder<float>& inputXDecoder,
+ Decoder<float>& inputYDecoder,
+ Encoder<float>& outputEncoder)
+ : params(params),
+ inputXInfo(inputXInfo),
+ inputYInfo(inputYInfo),
+ outputInfo(outputInfo),
+ inputXDecoder(inputXDecoder),
+ inputYDecoder(inputYDecoder),
+ outputEncoder(outputEncoder)
+ {}
+
+ // Entry point: decodes both inputs and fills the output tensor
+ void BatchMatMulImpl();
+
+ // Recursive walk over the output dimensions; curIdx holds the coordinates fixed so far and
+ // curDim is the dimension being enumerated
+ void RecurseBMM(std::vector<unsigned int>& curIdx, unsigned int curDim);
+
+ // Shifts the axes of the lower-rank input, in place, when the two inputs are of unequal rank
+ void AdjustAxesToMulForUnequalRanks(
+ std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>>& axesToMul);
+
+ // Reads the value at idx from the tensor selected by type
+ float GetValueAt(DataSlot type, std::vector<unsigned int> idx);
+
+ // Writes value at idx into the tensor selected by type
+ void SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx);
+
+ // Rewrites idx in place so it is valid for the selected input, taking broadcasting into account
+ void AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx);
+
+ // Converts a multi-dimensional idx into a flat element offset for the selected tensor
+ unsigned int CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx);
+
+ // Renders a vector as text.
+ // NOTE(review): only declared here; if the definition lives in a .cpp file it can only be
+ // instantiated from that translation unit - confirm this is intended.
+ template <typename T>
+ std::string StringifyVec(const std::vector<T>& vec);
+
+private:
+ // Non-owning references supplied by the caller
+ const BatchMatMulDescriptor& params;
+ const TensorInfo& inputXInfo;
+ const TensorInfo& inputYInfo;
+ const TensorInfo& outputInfo;
+ Decoder<float>& inputXDecoder;
+ Decoder<float>& inputYDecoder;
+ Encoder<float>& outputEncoder;
+
+ // Float copies of the input tensors
+ std::vector<float> inputXData;
+ std::vector<float> inputYData;
+
+};
+
+} // namespace armnn \ No newline at end of file
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index b1f6d8b250..b8835e3cdb 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -10,6 +10,8 @@ list(APPEND armnnRefBackendWorkloads_sources
ArgMinMax.cpp
ArgMinMax.hpp
BaseIterator.hpp
+ BatchMatMulImpl.cpp
+ BatchMatMulImpl.hpp
BatchNormImpl.cpp
BatchNormImpl.hpp
BatchToSpaceNd.cpp
@@ -69,6 +71,8 @@ list(APPEND armnnRefBackendWorkloads_sources
RefArgMinMaxWorkload.cpp
RefArgMinMaxWorkload.hpp
RefBaseWorkload.hpp
+ RefBatchMatMulWorkload.cpp
+ RefBatchMatMulWorkload.hpp
RefBatchNormalizationWorkload.cpp
RefBatchNormalizationWorkload.hpp
RefBatchToSpaceNdWorkload.cpp
diff --git a/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp b/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp
new file mode 100644
index 0000000000..388190c4ef
--- /dev/null
+++ b/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp
@@ -0,0 +1,59 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefBatchMatMulWorkload.hpp"
+
+#include "BatchMatMulImpl.hpp"
+#include "RefWorkloadUtils.hpp"
+#include "Profiling.hpp"
+
+namespace armnn
+{
+
+// Forwards the queue descriptor and workload info to RefBaseWorkload; no other setup is done.
+RefBatchMatMulWorkload::RefBatchMatMulWorkload(const BatchMatMulQueueDescriptor& descriptor, const WorkloadInfo& info)
+ : RefBaseWorkload(descriptor, info)
+{}
+
+// Runs the workload on the input and output handles held in m_Data.
+void RefBatchMatMulWorkload::Execute() const
+{
+ Execute(m_Data.m_Inputs, m_Data.m_Outputs);
+}
+
+// Runs the workload on the input and output handles carried by executionData.
+// executionData.m_Data is cast to a WorkingMemDescriptor without any check.
+void RefBatchMatMulWorkload::ExecuteAsync(ExecutionData& executionData)
+{
+    auto* memDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
+    Execute(memDescriptor->m_Inputs, memDescriptor->m_Outputs);
+}
+
+// Shared implementation behind Execute() and ExecuteAsync().
+// inputs[0] and inputs[1] are the X and Y operands and outputs[0] receives the result. Each
+// handle is mapped, wrapped in a float decoder or encoder, and handed to BatchMatMul.
+void RefBatchMatMulWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
+{
+    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefBatchMatMulWorkload_Execute");
+
+    const TensorInfo& xInfo   = GetTensorInfo(inputs[0]);
+    const TensorInfo& yInfo   = GetTensorInfo(inputs[1]);
+    const TensorInfo& outInfo = GetTensorInfo(outputs[0]);
+
+    // Float views over the mapped tensor memory
+    auto xDecoder   = MakeDecoder<float>(xInfo, inputs[0]->Map());
+    auto yDecoder   = MakeDecoder<float>(yInfo, inputs[1]->Map());
+    auto outEncoder = MakeEncoder<float>(outInfo, outputs[0]->Map());
+
+    BatchMatMul bmm(m_Data.m_Parameters,
+                    xInfo,
+                    yInfo,
+                    outInfo,
+                    *xDecoder,
+                    *yDecoder,
+                    *outEncoder);
+
+    bmm.BatchMatMulImpl();
+}
+
+} // namespace armnn \ No newline at end of file
diff --git a/src/backends/reference/workloads/RefBatchMatMulWorkload.hpp b/src/backends/reference/workloads/RefBatchMatMulWorkload.hpp
new file mode 100644
index 0000000000..e9dfcaef73
--- /dev/null
+++ b/src/backends/reference/workloads/RefBatchMatMulWorkload.hpp
@@ -0,0 +1,30 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "RefBaseWorkload.hpp"
+#include <armnn/backends/WorkloadData.hpp>
+
+#include "BatchMatMulImpl.hpp"
+
+namespace armnn
+{
+
+// Reference-backend workload for the BatchMatMul layer.
+class RefBatchMatMulWorkload : public RefBaseWorkload<BatchMatMulQueueDescriptor>
+{
+public:
+ // Builds the workload from its queue descriptor and workload info
+ explicit RefBatchMatMulWorkload(const BatchMatMulQueueDescriptor& descriptor,
+ const WorkloadInfo& info);
+
+ // Runs on the tensor handles stored in the workload
+ void Execute() const override;
+ // Runs on the tensor handles supplied through executionData
+ void ExecuteAsync(ExecutionData& executionData) override;
+
+private:
+ // Common implementation used by both public entry points
+ void Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const;
+
+};
+
+} // namespace armnn \ No newline at end of file
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index b9c7a2a1fb..e049d8db2c 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -7,6 +7,7 @@
#include "RefActivationWorkload.hpp"
#include "RefArgMinMaxWorkload.hpp"
+#include "RefBatchMatMulWorkload.hpp"
#include "RefBatchNormalizationWorkload.hpp"
#include "RefBatchToSpaceNdWorkload.hpp"
#include "RefCastWorkload.hpp"