diff options
author | Samuel Yap <samuel.yap@arm.com> | 2022-07-06 15:36:03 +0100 |
---|---|---|
committer | Nikhil Raj <nikhil.raj@arm.com> | 2022-07-27 15:58:31 +0100 |
commit | 6b47809e7d6c55d20a05d863ce2f09159f381f85 (patch) | |
tree | c33e5820f89e359c80d8773288e8adb075735039 /src/backends/reference | |
parent | 919ec71ea7f44bb2d284eb88cda511c2424358b2 (diff) | |
download | armnn-6b47809e7d6c55d20a05d863ce2f09159f381f85.tar.gz |
IVGCVSW-7109: Add Batch MatMul front end support - Reference
* Descriptors added for BatchMatMul
* Layer definition added
* Input validation added (will likely change when opt. param support comes in)
* Ref workload implementation for BatchMatMul added (will also change with opt. param support)
* Ref layer tests made for BatchMatMul
* CMake and other build files updated
Signed-off-by: Samuel Yap <samuel.yap@arm.com>
Change-Id: Ic885301da543ee0fbe7922b85e7f9658c4efc617
Diffstat (limited to 'src/backends/reference')
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 52 | ||||
-rw-r--r-- | src/backends/reference/RefLayerSupport.hpp | 6 | ||||
-rw-r--r-- | src/backends/reference/RefWorkloadFactory.cpp | 5 | ||||
-rw-r--r-- | src/backends/reference/backend.mk | 2 | ||||
-rw-r--r-- | src/backends/reference/test/RefLayerTests.cpp | 71 | ||||
-rw-r--r-- | src/backends/reference/workloads/BatchMatMulImpl.cpp | 230 | ||||
-rw-r--r-- | src/backends/reference/workloads/BatchMatMulImpl.hpp | 75 | ||||
-rw-r--r-- | src/backends/reference/workloads/CMakeLists.txt | 4 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefBatchMatMulWorkload.cpp | 59 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefBatchMatMulWorkload.hpp | 30 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefWorkloads.hpp | 1 |
11 files changed, 535 insertions, 0 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 8051dcffa0..40909019ba 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -79,6 +79,12 @@ bool RefLayerSupport::IsLayerSupported(const LayerType& type, infos[1], *(PolymorphicDowncast<const ArgMinMaxDescriptor*>(&descriptor)), reasonIfUnsupported); + case LayerType::BatchMatMul: + return IsBatchMatMulSupported(infos[0], + infos[1], + infos[2], + *(PolymorphicDowncast<const BatchMatMulDescriptor*>(&descriptor)), + reasonIfUnsupported); case LayerType::BatchNormalization: return IsBatchNormalizationSupported(infos[0], infos[1], @@ -642,6 +648,52 @@ bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const return supported; } +bool RefLayerSupport::IsBatchMatMulSupported(const TensorInfo& inputX, + const TensorInfo& inputY, + const TensorInfo& output, + const BatchMatMulDescriptor& descriptor, + Optional<std::string &> reasonIfUnsupported) const +{ + IgnoreUnused(descriptor); + + std::array<DataType, 6> supportedTypes = + { + DataType::BFloat16, + DataType::Float16, + DataType::Float32, + DataType::QAsymmS8, + DataType::QAsymmU8, + DataType::QSymmS16 + }; + + bool supported = true; + + supported &= CheckSupportRule(TypeAnyOf(inputX, supportedTypes), reasonIfUnsupported, + "Reference batch matrix multiplication: input X is not a supported type"); + + supported &= CheckSupportRule(TypeAnyOf(inputY, supportedTypes), reasonIfUnsupported, + "Reference batch matrix multiplication: input Y is not a supported type"); + + supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported, + "Reference batch matrix multiplication: output is not a supported type"); + + supported &= CheckSupportRule(TypesAreEqual(inputX, inputY), reasonIfUnsupported, + "Reference batch matrix multiplication: input X and input Y types are mismatched"); + + supported &= CheckSupportRule(TypesAreEqual(inputX, output), reasonIfUnsupported, + "Reference batch matrix multiplication: inputs and output types are mismatched"); + + supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputX, 2), + reasonIfUnsupported, + "Reference batch matrix multiplication: input X is not of rank 2 or greater"); + + supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputY, 2), + reasonIfUnsupported, + "Reference batch matrix multiplication: input Y is not of rank 2 or greater"); + + return supported; +} + bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input, const TensorInfo& output, const TensorInfo& mean, diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp index aa8bd8dda4..b64244db24 100644 --- a/src/backends/reference/RefLayerSupport.hpp +++ b/src/backends/reference/RefLayerSupport.hpp @@ -34,6 +34,12 @@ public: const ArgMinMaxDescriptor& descriptor, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override; + bool IsBatchMatMulSupported(const TensorInfo& inputX, + const TensorInfo& inputY, + const TensorInfo& output, + const BatchMatMulDescriptor& descriptor, + Optional<std::string &> reasonIfUnsupported = EmptyOptional()) const; + bool IsBatchNormalizationSupported(const TensorInfo& input, const TensorInfo& output, const TensorInfo& mean, diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index 2d956582db..093d0d5e20 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -170,6 +170,11 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateWorkload(LayerType type, auto argMinMaxQueueDescriptor = PolymorphicDowncast<const ArgMinMaxQueueDescriptor*>(&descriptor); return std::make_unique<RefArgMinMaxWorkload>(*argMinMaxQueueDescriptor, info); } + case LayerType::BatchMatMul: + { + auto batchMatMulQueueDescriptor = PolymorphicDowncast<const BatchMatMulQueueDescriptor*>(&descriptor); + return std::make_unique<RefBatchMatMulWorkload>(*batchMatMulQueueDescriptor, info); + } case LayerType::BatchNormalization : { auto batchNormQueueDescriptor = PolymorphicDowncast<const BatchNormalizationQueueDescriptor*>(&descriptor); diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk index d9a5a1d32c..ed942e67cd 100644 --- a/src/backends/reference/backend.mk +++ b/src/backends/reference/backend.mk @@ -23,6 +23,7 @@ BACKEND_SOURCES := \ RefTensorHandleFactory.cpp \ workloads/Activation.cpp \ workloads/ArgMinMax.cpp \ + workloads/BatchMatMulImpl.cpp \ workloads/BatchNormImpl.cpp \ workloads/BatchToSpaceNd.cpp \ workloads/Broadcast.cpp \ @@ -49,6 +50,7 @@ BACKEND_SOURCES := \ workloads/Reduce.cpp \ workloads/RefActivationWorkload.cpp \ workloads/RefArgMinMaxWorkload.cpp \ + workloads/RefBatchMatMulWorkload.cpp \ workloads/RefBatchNormalizationWorkload.cpp \ workloads/RefBatchToSpaceNdWorkload.cpp \ workloads/RefCastWorkload.cpp \ diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp index 419ae2b0e9..593dc7851e 100644 --- a/src/backends/reference/test/RefLayerTests.cpp +++ b/src/backends/reference/test/RefLayerTests.cpp @@ -1062,6 +1062,77 @@ ARMNN_AUTO_TEST_CASE_WITH_THF(MultiplicationBroadcast1ElementInt32, Multiplicati ARMNN_AUTO_TEST_CASE_WITH_THF(MultiplicationBroadcast1DVectorInt32, MultiplicationBroadcast1DVectorInt32Test) ARMNN_AUTO_TEST_CASE_WITH_THF(Multiplication5d, Multiplication5dTest) +// Batch Mat Mul +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleBFloat16, BatchMatMul2DSimpleTest<DataType::BFloat16>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleFloat32, BatchMatMul2DSimpleTest<DataType::Float32>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleFloat16, BatchMatMul2DSimpleTest<DataType::Float16>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleQAsymmS8, BatchMatMul2DSimpleTest<DataType::QAsymmS8>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleQAsymmU8, BatchMatMul2DSimpleTest<DataType::QAsymmU8>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleQASymmS16, BatchMatMul2DSimpleTest<DataType::QSymmS16>); + +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleBFloat16, BatchMatMul3DSimpleTest<DataType::BFloat16>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleFloat32, BatchMatMul3DSimpleTest<DataType::Float32>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleFloat16, BatchMatMul3DSimpleTest<DataType::Float16>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleQAsymmS8, BatchMatMul3DSimpleTest<DataType::QAsymmS8>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleQAsymmU8, BatchMatMul3DSimpleTest<DataType::QAsymmU8>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleQASymmS16, BatchMatMul3DSimpleTest<DataType::QSymmS16>); + +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleBFloat16, BatchMatMulNCHWSimpleTest<DataType::BFloat16>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleFloat32, BatchMatMulNCHWSimpleTest<DataType::Float32>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleFloat16, BatchMatMulNCHWSimpleTest<DataType::Float16>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleQAsymmS8, BatchMatMulNCHWSimpleTest<DataType::QAsymmS8>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleQAsymmU8, BatchMatMulNCHWSimpleTest<DataType::QAsymmU8>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleQASymmS16, BatchMatMulNCHWSimpleTest<DataType::QSymmS16>); + +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleBFloat16, BatchMatMulNHWCSimpleTest<DataType::BFloat16>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleFloat32, BatchMatMulNHWCSimpleTest<DataType::Float32>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleFloat16, BatchMatMulNHWCSimpleTest<DataType::Float16>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleQAsymmS8, BatchMatMulNHWCSimpleTest<DataType::QAsymmS8>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleQAsymmU8, BatchMatMulNHWCSimpleTest<DataType::QAsymmU8>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleQASymmS16, BatchMatMulNHWCSimpleTest<DataType::QSymmS16>); + +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchBFloat16, BatchMatMul3DBatchTest<DataType::BFloat16>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchFloat32, BatchMatMul3DBatchTest<DataType::Float32>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchFloat16, BatchMatMul3DBatchTest<DataType::Float16>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchQAsymmS8, BatchMatMul3DBatchTest<DataType::QAsymmS8>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchQAsymmU8, BatchMatMul3DBatchTest<DataType::QAsymmU8>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchQASymmS16, BatchMatMul3DBatchTest<DataType::QSymmS16>); + +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastBFloat16, BatchMatMul3DBroadcastTest<DataType::BFloat16>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastFloat32, BatchMatMul3DBroadcastTest<DataType::Float32>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastFloat16, BatchMatMul3DBroadcastTest<DataType::Float16>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastQAsymmS8, BatchMatMul3DBroadcastTest<DataType::QAsymmS8>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastQAsymmU8, BatchMatMul3DBroadcastTest<DataType::QAsymmU8>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastQASymmS16, BatchMatMul3DBroadcastTest<DataType::QSymmS16>); + +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastBFloat16, BatchMatMul3D2DBroadcastTest<DataType::BFloat16>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastFloat32, BatchMatMul3D2DBroadcastTest<DataType::Float32>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastFloat16, BatchMatMul3D2DBroadcastTest<DataType::Float16>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastQAsymmS8, BatchMatMul3D2DBroadcastTest<DataType::QAsymmS8>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastQAsymmU8, BatchMatMul3D2DBroadcastTest<DataType::QAsymmU8>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastQASymmSS16, BatchMatMul3D2DBroadcastTest<DataType::QSymmS16>); + +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCBFloat16, BatchMatMulNDHWCNHWCTest<DataType::BFloat16>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCFloat32, BatchMatMulNDHWCNHWCTest<DataType::Float32>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCFloat16, BatchMatMulNDHWCNHWCTest<DataType::Float16>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCQAsymmS8, BatchMatMulNDHWCNHWCTest<DataType::QAsymmS8>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCQAsymmU8, BatchMatMulNDHWCNHWCTest<DataType::QAsymmU8>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCQASymmSS16, BatchMatMulNDHWCNHWCTest<DataType::QSymmS16>); + +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyBFloat16, BatchMatMul2DTinyTest<DataType::BFloat16>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyFloat32, BatchMatMul2DTinyTest<DataType::Float32>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyFloat16, BatchMatMul2DTinyTest<DataType::Float16>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyQAsymmS8, BatchMatMul2DTinyTest<DataType::QAsymmS8>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyQAsymmU8, BatchMatMul2DTinyTest<DataType::QAsymmU8>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyQASymmS16, BatchMatMul2DTinyTest<DataType::QSymmS16>); + +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareBFloat16, BatchMatMul3DNonSquareTest<DataType::BFloat16>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareFloat32, BatchMatMul3DNonSquareTest<DataType::Float32>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareFloat16, BatchMatMul3DNonSquareTest<DataType::Float16>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareQAsymmS8, BatchMatMul3DNonSquareTest<DataType::QAsymmS8>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareQAsymmU8, BatchMatMul3DNonSquareTest<DataType::QAsymmU8>); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareQASymmS16, BatchMatMul3DNonSquareTest<DataType::QSymmS16>); + // Batch Norm ARMNN_AUTO_TEST_CASE_WITH_THF(BatchNormFloat32, BatchNormFloat32Test) ARMNN_AUTO_TEST_CASE_WITH_THF(BatchNormFloat32Nhwc, BatchNormFloat32NhwcTest) diff --git a/src/backends/reference/workloads/BatchMatMulImpl.cpp b/src/backends/reference/workloads/BatchMatMulImpl.cpp new file mode 100644 index 0000000000..74a358cc5c --- /dev/null +++ b/src/backends/reference/workloads/BatchMatMulImpl.cpp @@ -0,0 +1,230 @@ +// +// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "BatchMatMulImpl.hpp" + +#include <armnn/backends/WorkloadData.hpp> +#include <armnn/Logging.hpp> + +namespace armnn +{ + +void BatchMatMul::BatchMatMulImpl() +{ + inputXData = inputXDecoder.DecodeTensor(inputXInfo.GetShape()); + inputYData = inputYDecoder.DecodeTensor(inputYInfo.GetShape()); + // At this point, we don't touch the input decoders - just the resultant vectors + + // Pre-transpose and pre-adjoint if their vectors aren't empty + // and also DataLayouts which may change with permutations/adjoints + + // Todo: Have you updated input validation and inferred output shapes to accommodate for these pre-permutes? + + auto idx = std::vector<unsigned int>(outputInfo.GetNumDimensions(), 0); + RecurseBMM(idx, 0); +} + +void BatchMatMul::RecurseBMM(std::vector<unsigned int>& curIdx, unsigned int curDim) +{ + // We're working off of the indexes of the output tensor (the max possible shape) + + if(!(curDim < outputInfo.GetNumDimensions())) + { + // We're at the leaf level of this call tree, so we operate here (each leaf is a data point) + + auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(params, + inputXInfo.GetShape(), + inputYInfo.GetShape()); + AdjustAxesToMulForUnequalRanks(axesToMul); + + unsigned int inputXColDim = axesToMul.first.second; + unsigned int inputYRowDim = axesToMul.second.first; + + unsigned int inputYRowSize = inputYInfo.GetShape()[inputYRowDim]; + + float sum = 0.0f; + + // You could also use inputXColSize + for (unsigned int inputYRowIdx = 0; inputYRowIdx < inputYRowSize; inputYRowIdx++) { + auto xIdx = curIdx; + xIdx[inputXColDim] = inputYRowIdx; + + auto yIdx = curIdx; + yIdx[inputYRowDim] = inputYRowIdx; + + sum += (GetValueAt(DataSlot::InputX, xIdx) + * GetValueAt(DataSlot::InputY, yIdx)); + } + + SetValueAt(sum, DataSlot::Output, curIdx); + + return; + } + + for (unsigned int i = 0; i < outputInfo.GetShape()[curDim]; i++) + { + curIdx[curDim] = i; + RecurseBMM(curIdx, curDim+1); + } +} + +void BatchMatMul::AdjustAxesToMulForUnequalRanks( + std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>>& axesToMul) +{ + long rankDiff = static_cast<long>(inputXInfo.GetNumDimensions()) - inputYInfo.GetNumDimensions(); + if(rankDiff == 0) + { + return; + } + else if(rankDiff < 0) + { + // Y is the larger one + axesToMul.first.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff)); + axesToMul.first.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff)); + } + else if(rankDiff > 0) + { + // X is the larger one + axesToMul.second.first += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff)); + axesToMul.second.second += static_cast<std::make_unsigned<unsigned int>::type>(std::abs(rankDiff)); + } +} + +float BatchMatMul::GetValueAt(DataSlot type, std::vector<unsigned int> idx) +{ + // This gets the data from the input vector that we have, Not the decoder + // But for the output, it is operating on the encoder itself + + AdjustToSafeIdx(type, idx); + unsigned int flatIdx = CalcFlatIdx(type, idx); + float value = 0.0f; + + switch(type) + { + case DataSlot::InputX: + value = inputXData[flatIdx]; + break; + case DataSlot::InputY: + value = inputYData[flatIdx]; + break; + case DataSlot::Output: + outputEncoder[flatIdx]; + value = outputEncoder.Get(); + break; + default: + break; + } + + return value; +} + +void BatchMatMul::SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx) +{ + AdjustToSafeIdx(type, idx); + + unsigned int flatIdx = CalcFlatIdx(type, idx); + + switch(type) + { + case DataSlot::InputX: + inputXData[flatIdx] = value; + break; + case DataSlot::InputY: + inputYData[flatIdx] = value; + break; + case DataSlot::Output: + outputEncoder[flatIdx]; + outputEncoder.Set(value); + break; + default: + break; + } +} + +void BatchMatMul::AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx) +{ + for(unsigned int dim = 0; dim < idx.size(); dim++) + { + switch(type) + { + case DataSlot::InputX: + { + auto xRank = inputXInfo.GetNumDimensions(); + auto xDiff = outputInfo.GetNumDimensions() - xRank; + if (dim < xDiff || + idx[dim] > inputXInfo.GetShape()[dim-xDiff]-1) + { + idx[dim] = 0; // Broadcasting + } + break; + } + case DataSlot::InputY: + { + auto yRank = inputYInfo.GetNumDimensions(); + auto yDiff = outputInfo.GetNumDimensions() - yRank; + if (dim < yDiff || + idx[dim] > inputYInfo.GetShape()[dim-yDiff]-1) + { + idx[dim] = 0; + } + break; + } + case DataSlot::Output: + { + // Our indices are based off the output + break; + } + default: + break; + } + } +} + +unsigned int BatchMatMul::CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx) +{ + unsigned int result = idx[idx.size()-1]; + + unsigned int dimMultiplier = 1; + + unsigned int offset; + + // -2 because final dim is already accounted for in the multiplier (last dim is just a multiplier of 1x) + for(unsigned int i = static_cast<unsigned int>(idx.size()-2); static_cast<int>(i) >= 0; i--) + { + switch(type) + { + case DataSlot::InputX: + offset = outputInfo.GetNumDimensions() - inputXInfo.GetNumDimensions(); + dimMultiplier *= inputXInfo.GetShape()[i + 1 - offset]; + break; + case DataSlot::InputY: + offset = outputInfo.GetNumDimensions() - inputYInfo.GetNumDimensions(); + dimMultiplier *= inputYInfo.GetShape()[i + 1 - offset]; + break; + case DataSlot::Output: + dimMultiplier *= outputInfo.GetShape()[i+1]; + break; + default: + break; + } + result += (idx[i] * dimMultiplier); + } + return result; +} + +template <typename T> +std::string BatchMatMul::StringifyVec(const std::vector<T>& vec) +{ + std::string res = "{ "; + for(auto x : vec) + { + res += std::to_string(x); + res += " "; + } + res += "}"; + return res; +} + +} // namespace armnn
\ No newline at end of file diff --git a/src/backends/reference/workloads/BatchMatMulImpl.hpp b/src/backends/reference/workloads/BatchMatMulImpl.hpp new file mode 100644 index 0000000000..25b6c85d77 --- /dev/null +++ b/src/backends/reference/workloads/BatchMatMulImpl.hpp @@ -0,0 +1,75 @@ +// +// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "Encoders.hpp" +#include "Decoders.hpp" + +#include <armnn/backends/WorkloadData.hpp> + +namespace armnn +{ + +class BatchMatMul { +public: + enum DataSlot + { + InputX = 0, + InputY = 1, + Output = 2 + }; + + BatchMatMul(const BatchMatMulDescriptor& params, + const TensorInfo& inputXInfo, + const TensorInfo& inputYInfo, + const TensorInfo& outputInfo, + Decoder<float>& inputXDecoder, + Decoder<float>& inputYDecoder, + Encoder<float>& outputEncoder) + : params(params), + inputXInfo(inputXInfo), + inputYInfo(inputYInfo), + outputInfo(outputInfo), + inputXDecoder(inputXDecoder), + inputYDecoder(inputYDecoder), + outputEncoder(outputEncoder) + {} + + void BatchMatMulImpl(); + + void RecurseBMM(std::vector<unsigned int>& curIdx, unsigned int curDim); + + // Adjusts it for when input tensors are of unequal rank + void AdjustAxesToMulForUnequalRanks( + std::pair<std::pair<unsigned int, unsigned int>, std::pair<unsigned int, unsigned int>>& axesToMul); + + float GetValueAt(DataSlot type, std::vector<unsigned int> idx); + + void SetValueAt(float value, DataSlot type, std::vector<unsigned int> idx); + + // Takes into account broadcasting + void AdjustToSafeIdx(DataSlot type, std::vector<unsigned int>& idx); + + unsigned int CalcFlatIdx(DataSlot type, const std::vector<unsigned int>& idx); + + template <typename T> + std::string StringifyVec(const std::vector<T>& vec); + +private: + const BatchMatMulDescriptor& params; + const TensorInfo& inputXInfo; + const TensorInfo& inputYInfo; + const TensorInfo& outputInfo; + Decoder<float>& inputXDecoder; + Decoder<float>& inputYDecoder; + Encoder<float>& outputEncoder; + + std::vector<float> inputXData; + std::vector<float> inputYData; + +}; + +} // namespace armnn
\ No newline at end of file diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index b1f6d8b250..b8835e3cdb 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -10,6 +10,8 @@ list(APPEND armnnRefBackendWorkloads_sources ArgMinMax.cpp ArgMinMax.hpp BaseIterator.hpp + BatchMatMulImpl.cpp + BatchMatMulImpl.hpp BatchNormImpl.cpp BatchNormImpl.hpp BatchToSpaceNd.cpp @@ -69,6 +71,8 @@ list(APPEND armnnRefBackendWorkloads_sources RefArgMinMaxWorkload.cpp RefArgMinMaxWorkload.hpp RefBaseWorkload.hpp + RefBatchMatMulWorkload.cpp + RefBatchMatMulWorkload.hpp RefBatchNormalizationWorkload.cpp RefBatchNormalizationWorkload.hpp RefBatchToSpaceNdWorkload.cpp diff --git a/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp b/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp new file mode 100644 index 0000000000..388190c4ef --- /dev/null +++ b/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp @@ -0,0 +1,59 @@ +// +// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "RefBatchMatMulWorkload.hpp" + +#include "BatchMatMulImpl.hpp" +#include "RefWorkloadUtils.hpp" +#include "Profiling.hpp" + +namespace armnn +{ + +RefBatchMatMulWorkload::RefBatchMatMulWorkload(const BatchMatMulQueueDescriptor& descriptor, const WorkloadInfo& info) + : RefBaseWorkload(descriptor, info) +{} + +void RefBatchMatMulWorkload::Execute() const +{ + Execute(m_Data.m_Inputs, m_Data.m_Outputs); +} + +void RefBatchMatMulWorkload::ExecuteAsync(ExecutionData& executionData) +{ + WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data); + Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs); +} + +void RefBatchMatMulWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const +{ + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefBatchMatMulWorkload_Execute"); + + const TensorInfo& inputXInfo = GetTensorInfo(inputs[0]); + const TensorInfo& inputYInfo = GetTensorInfo(inputs[1]); + const TensorInfo& outputInfo = GetTensorInfo(outputs[0]); + + std::unique_ptr<Decoder<float>> inputXDecoder = MakeDecoder<float>(GetTensorInfo(inputs[0]), + inputs[0]->Map()); + + std::unique_ptr<Decoder<float>> inputYDecoder = MakeDecoder<float>(GetTensorInfo(inputs[1]), + inputs[1]->Map()); + + std::unique_ptr<Encoder<float>> outputEncoder = MakeEncoder<float>(GetTensorInfo(outputs[0]), + outputs[0]->Map()); + + auto bmm = BatchMatMul(m_Data.m_Parameters, + inputXInfo, + inputYInfo, + outputInfo, + *inputXDecoder, + *inputYDecoder, + *outputEncoder); + + bmm.BatchMatMulImpl(); + +} + +} // namespace armnn
\ No newline at end of file diff --git a/src/backends/reference/workloads/RefBatchMatMulWorkload.hpp b/src/backends/reference/workloads/RefBatchMatMulWorkload.hpp new file mode 100644 index 0000000000..e9dfcaef73 --- /dev/null +++ b/src/backends/reference/workloads/RefBatchMatMulWorkload.hpp @@ -0,0 +1,30 @@ +// +// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "RefBaseWorkload.hpp" +#include <armnn/backends/WorkloadData.hpp> + +#include "BatchMatMulImpl.hpp" + +namespace armnn +{ + +class RefBatchMatMulWorkload : public RefBaseWorkload<BatchMatMulQueueDescriptor> +{ +public: + explicit RefBatchMatMulWorkload(const BatchMatMulQueueDescriptor& descriptor, + const WorkloadInfo& info); + + void Execute() const override; + void ExecuteAsync(ExecutionData& executionData) override; + +private: + void Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const; + +}; + +} // namespace armnn
\ No newline at end of file diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp index b9c7a2a1fb..e049d8db2c 100644 --- a/src/backends/reference/workloads/RefWorkloads.hpp +++ b/src/backends/reference/workloads/RefWorkloads.hpp @@ -7,6 +7,7 @@ #include "RefActivationWorkload.hpp" #include "RefArgMinMaxWorkload.hpp" +#include "RefBatchMatMulWorkload.hpp" #include "RefBatchNormalizationWorkload.hpp" #include "RefBatchToSpaceNdWorkload.hpp" #include "RefCastWorkload.hpp" |