From 6b47809e7d6c55d20a05d863ce2f09159f381f85 Mon Sep 17 00:00:00 2001 From: Samuel Yap Date: Wed, 6 Jul 2022 15:36:03 +0100 Subject: IVGCVSW-7109: Add Batch MatMul front end support - Reference * Descriptors added for BatchMatMul * Layer definition added * Input validation added (will likely change when opt. param support comes in) * Ref workload implementation for BatchMatMul added (will also change with opt. param support) * Ref layer tests made for BatchMatMul * CMake and other build files updated Signed-off-by: Samuel Yap Change-Id: Ic885301da543ee0fbe7922b85e7f9658c4efc617 --- src/backends/reference/RefLayerSupport.cpp | 52 +++++ src/backends/reference/RefLayerSupport.hpp | 6 + src/backends/reference/RefWorkloadFactory.cpp | 5 + src/backends/reference/backend.mk | 2 + src/backends/reference/test/RefLayerTests.cpp | 71 +++++++ .../reference/workloads/BatchMatMulImpl.cpp | 230 +++++++++++++++++++++ .../reference/workloads/BatchMatMulImpl.hpp | 75 +++++++ src/backends/reference/workloads/CMakeLists.txt | 4 + .../reference/workloads/RefBatchMatMulWorkload.cpp | 59 ++++++ .../reference/workloads/RefBatchMatMulWorkload.hpp | 30 +++ src/backends/reference/workloads/RefWorkloads.hpp | 1 + 11 files changed, 535 insertions(+) create mode 100644 src/backends/reference/workloads/BatchMatMulImpl.cpp create mode 100644 src/backends/reference/workloads/BatchMatMulImpl.hpp create mode 100644 src/backends/reference/workloads/RefBatchMatMulWorkload.cpp create mode 100644 src/backends/reference/workloads/RefBatchMatMulWorkload.hpp (limited to 'src/backends/reference') diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 8051dcffa0..40909019ba 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -79,6 +79,12 @@ bool RefLayerSupport::IsLayerSupported(const LayerType& type, infos[1], *(PolymorphicDowncast(&descriptor)), reasonIfUnsupported); + case LayerType::BatchMatMul: + return IsBatchMatMulSupported(infos[0], + infos[1], + infos[2], + *(PolymorphicDowncast(&descriptor)), + reasonIfUnsupported); case LayerType::BatchNormalization: return IsBatchNormalizationSupported(infos[0], infos[1], @@ -642,6 +648,52 @@ bool RefLayerSupport::IsArgMinMaxSupported(const armnn::TensorInfo &input, const return supported; } +bool RefLayerSupport::IsBatchMatMulSupported(const TensorInfo& inputX, + const TensorInfo& inputY, + const TensorInfo& output, + const BatchMatMulDescriptor& descriptor, + Optional reasonIfUnsupported) const +{ + IgnoreUnused(descriptor); + + std::array supportedTypes = + { + DataType::BFloat16, + DataType::Float16, + DataType::Float32, + DataType::QAsymmS8, + DataType::QAsymmU8, + DataType::QSymmS16 + }; + + bool supported = true; + + supported &= CheckSupportRule(TypeAnyOf(inputX, supportedTypes), reasonIfUnsupported, + "Reference batch matrix multiplication: input X is not a supported type"); + + supported &= CheckSupportRule(TypeAnyOf(inputY, supportedTypes), reasonIfUnsupported, + "Reference batch matrix multiplication: input Y is not a supported type"); + + supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported, + "Reference batch matrix multiplication: output is not a supported type"); + + supported &= CheckSupportRule(TypesAreEqual(inputX, inputY), reasonIfUnsupported, + "Reference batch matrix multiplication: input X and input Y types are mismatched"); + + supported &= CheckSupportRule(TypesAreEqual(inputX, output), reasonIfUnsupported, + "Reference 
batch matrix multiplication: inputs and output types are mismatched"); + + supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputX, 2), + reasonIfUnsupported, + "Reference batch matrix multiplication: input X is not of rank 2 or greater"); + + supported &= CheckSupportRule(TensorNumDimensionsAreGreaterOrEqualTo(inputY, 2), + reasonIfUnsupported, + "Reference batch matrix multiplication: input Y is not of rank 2 or greater"); + + return supported; +} + bool RefLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input, const TensorInfo& output, const TensorInfo& mean, diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp index aa8bd8dda4..b64244db24 100644 --- a/src/backends/reference/RefLayerSupport.hpp +++ b/src/backends/reference/RefLayerSupport.hpp @@ -34,6 +34,12 @@ public: const ArgMinMaxDescriptor& descriptor, Optional reasonIfUnsupported = EmptyOptional()) const override; + bool IsBatchMatMulSupported(const TensorInfo& inputX, + const TensorInfo& inputY, + const TensorInfo& output, + const BatchMatMulDescriptor& descriptor, + Optional reasonIfUnsupported = EmptyOptional()) const; + bool IsBatchNormalizationSupported(const TensorInfo& input, const TensorInfo& output, const TensorInfo& mean, diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index 2d956582db..093d0d5e20 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -170,6 +170,11 @@ std::unique_ptr RefWorkloadFactory::CreateWorkload(LayerType type, auto argMinMaxQueueDescriptor = PolymorphicDowncast(&descriptor); return std::make_unique(*argMinMaxQueueDescriptor, info); } + case LayerType::BatchMatMul: + { + auto batchMatMulQueueDescriptor = PolymorphicDowncast(&descriptor); + return std::make_unique(*batchMatMulQueueDescriptor, info); + } case LayerType::BatchNormalization : { auto batchNormQueueDescriptor = PolymorphicDowncast(&descriptor); diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk index d9a5a1d32c..ed942e67cd 100644 --- a/src/backends/reference/backend.mk +++ b/src/backends/reference/backend.mk @@ -23,6 +23,7 @@ BACKEND_SOURCES := \ RefTensorHandleFactory.cpp \ workloads/Activation.cpp \ workloads/ArgMinMax.cpp \ + workloads/BatchMatMulImpl.cpp \ workloads/BatchNormImpl.cpp \ workloads/BatchToSpaceNd.cpp \ workloads/Broadcast.cpp \ @@ -49,6 +50,7 @@ BACKEND_SOURCES := \ workloads/Reduce.cpp \ workloads/RefActivationWorkload.cpp \ workloads/RefArgMinMaxWorkload.cpp \ + workloads/RefBatchMatMulWorkload.cpp \ workloads/RefBatchNormalizationWorkload.cpp \ workloads/RefBatchToSpaceNdWorkload.cpp \ workloads/RefCastWorkload.cpp \ diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp index 419ae2b0e9..593dc7851e 100644 --- a/src/backends/reference/test/RefLayerTests.cpp +++ b/src/backends/reference/test/RefLayerTests.cpp @@ -1062,6 +1062,77 @@ ARMNN_AUTO_TEST_CASE_WITH_THF(MultiplicationBroadcast1ElementInt32, Multiplicati ARMNN_AUTO_TEST_CASE_WITH_THF(MultiplicationBroadcast1DVectorInt32, MultiplicationBroadcast1DVectorInt32Test) ARMNN_AUTO_TEST_CASE_WITH_THF(Multiplication5d, Multiplication5dTest) +// Batch Mat Mul +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleBFloat16, BatchMatMul2DSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleFloat32, BatchMatMul2DSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleFloat16, 
BatchMatMul2DSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleQAsymmS8, BatchMatMul2DSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleQAsymmU8, BatchMatMul2DSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DSimpleQASymmS16, BatchMatMul2DSimpleTest); + +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleBFloat16, BatchMatMul3DSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleFloat32, BatchMatMul3DSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleFloat16, BatchMatMul3DSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleQAsymmS8, BatchMatMul3DSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleQAsymmU8, BatchMatMul3DSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DSimpleQASymmS16, BatchMatMul3DSimpleTest); + +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleBFloat16, BatchMatMulNCHWSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleFloat32, BatchMatMulNCHWSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleFloat16, BatchMatMulNCHWSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleQAsymmS8, BatchMatMulNCHWSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleQAsymmU8, BatchMatMulNCHWSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNCHWSimpleQASymmS16, BatchMatMulNCHWSimpleTest); + +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleBFloat16, BatchMatMulNHWCSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleFloat32, BatchMatMulNHWCSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleFloat16, BatchMatMulNHWCSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleQAsymmS8, BatchMatMulNHWCSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleQAsymmU8, BatchMatMulNHWCSimpleTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNHWCSimpleQASymmS16, BatchMatMulNHWCSimpleTest); + +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchBFloat16, BatchMatMul3DBatchTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchFloat32, BatchMatMul3DBatchTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchFloat16, BatchMatMul3DBatchTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchQAsymmS8, BatchMatMul3DBatchTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchQAsymmU8, BatchMatMul3DBatchTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBatchQASymmS16, BatchMatMul3DBatchTest); + +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastBFloat16, BatchMatMul3DBroadcastTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastFloat32, BatchMatMul3DBroadcastTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastFloat16, BatchMatMul3DBroadcastTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastQAsymmS8, BatchMatMul3DBroadcastTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastQAsymmU8, BatchMatMul3DBroadcastTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DBroadcastQASymmS16, BatchMatMul3DBroadcastTest); + +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastBFloat16, BatchMatMul3D2DBroadcastTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastFloat32, BatchMatMul3D2DBroadcastTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastFloat16, BatchMatMul3D2DBroadcastTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastQAsymmS8, BatchMatMul3D2DBroadcastTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastQAsymmU8, BatchMatMul3D2DBroadcastTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3D2DBroadcastQASymmSS16, BatchMatMul3D2DBroadcastTest); + 
+ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCBFloat16, BatchMatMulNDHWCNHWCTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCFloat32, BatchMatMulNDHWCNHWCTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCFloat16, BatchMatMulNDHWCNHWCTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCQAsymmS8, BatchMatMulNDHWCNHWCTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCQAsymmU8, BatchMatMulNDHWCNHWCTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMulNDHWCNHWCQASymmSS16, BatchMatMulNDHWCNHWCTest); + +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyBFloat16, BatchMatMul2DTinyTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyFloat32, BatchMatMul2DTinyTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyFloat16, BatchMatMul2DTinyTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyQAsymmS8, BatchMatMul2DTinyTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyQAsymmU8, BatchMatMul2DTinyTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul2DTinyQASymmS16, BatchMatMul2DTinyTest); + +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareBFloat16, BatchMatMul3DNonSquareTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareFloat32, BatchMatMul3DNonSquareTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareFloat16, BatchMatMul3DNonSquareTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareQAsymmS8, BatchMatMul3DNonSquareTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareQAsymmU8, BatchMatMul3DNonSquareTest); +ARMNN_AUTO_TEST_CASE_WITH_THF(BatchMatMul3DNonSquareQASymmS16, BatchMatMul3DNonSquareTest); + // Batch Norm ARMNN_AUTO_TEST_CASE_WITH_THF(BatchNormFloat32, BatchNormFloat32Test) ARMNN_AUTO_TEST_CASE_WITH_THF(BatchNormFloat32Nhwc, BatchNormFloat32NhwcTest) diff --git a/src/backends/reference/workloads/BatchMatMulImpl.cpp b/src/backends/reference/workloads/BatchMatMulImpl.cpp new file mode 100644 index 0000000000..74a358cc5c --- /dev/null +++ b/src/backends/reference/workloads/BatchMatMulImpl.cpp @@ -0,0 +1,230 @@ +// +// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "BatchMatMulImpl.hpp" + +#include +#include + +namespace armnn +{ + +void BatchMatMul::BatchMatMulImpl() +{ + inputXData = inputXDecoder.DecodeTensor(inputXInfo.GetShape()); + inputYData = inputYDecoder.DecodeTensor(inputYInfo.GetShape()); + // At this point, we don't touch the input decoders - just the resultant vectors + + // Pre-transpose and pre-adjoint if their vectors aren't empty + // and also DataLayouts which may change with permutations/adjoints + + // Todo: Have you updated input validation and inferred output shapes to accommodate for these pre-permutes? 
+ + auto idx = std::vector(outputInfo.GetNumDimensions(), 0); + RecurseBMM(idx, 0); +} + +void BatchMatMul::RecurseBMM(std::vector& curIdx, unsigned int curDim) +{ + // We're working off of the indexes of the output tensor (the max possible shape) + + if(!(curDim < outputInfo.GetNumDimensions())) + { + // We're at the leaf level of this call tree, so we operate here (each leaf is a data point) + + auto axesToMul = BatchMatMulDescriptor::GetAxesToMul(params, + inputXInfo.GetShape(), + inputYInfo.GetShape()); + AdjustAxesToMulForUnequalRanks(axesToMul); + + unsigned int inputXColDim = axesToMul.first.second; + unsigned int inputYRowDim = axesToMul.second.first; + + unsigned int inputYRowSize = inputYInfo.GetShape()[inputYRowDim]; + + float sum = 0.0f; + + // You could also use inputXColSize + for (unsigned int inputYRowIdx = 0; inputYRowIdx < inputYRowSize; inputYRowIdx++) { + auto xIdx = curIdx; + xIdx[inputXColDim] = inputYRowIdx; + + auto yIdx = curIdx; + yIdx[inputYRowDim] = inputYRowIdx; + + sum += (GetValueAt(DataSlot::InputX, xIdx) + * GetValueAt(DataSlot::InputY, yIdx)); + } + + SetValueAt(sum, DataSlot::Output, curIdx); + + return; + } + + for (unsigned int i = 0; i < outputInfo.GetShape()[curDim]; i++) + { + curIdx[curDim] = i; + RecurseBMM(curIdx, curDim+1); + } +} + +void BatchMatMul::AdjustAxesToMulForUnequalRanks( + std::pair, std::pair>& axesToMul) +{ + long rankDiff = static_cast(inputXInfo.GetNumDimensions()) - inputYInfo.GetNumDimensions(); + if(rankDiff == 0) + { + return; + } + else if(rankDiff < 0) + { + // Y is the larger one + axesToMul.first.first += static_cast::type>(std::abs(rankDiff)); + axesToMul.first.second += static_cast::type>(std::abs(rankDiff)); + } + else if(rankDiff > 0) + { + // X is the larger one + axesToMul.second.first += static_cast::type>(std::abs(rankDiff)); + axesToMul.second.second += static_cast::type>(std::abs(rankDiff)); + } +} + +float BatchMatMul::GetValueAt(DataSlot type, std::vector idx) +{ + // This gets the data from the input vector that we have, Not the decoder + // But for the output, it is operating on the encoder itself + + AdjustToSafeIdx(type, idx); + unsigned int flatIdx = CalcFlatIdx(type, idx); + float value = 0.0f; + + switch(type) + { + case DataSlot::InputX: + value = inputXData[flatIdx]; + break; + case DataSlot::InputY: + value = inputYData[flatIdx]; + break; + case DataSlot::Output: + outputEncoder[flatIdx]; + value = outputEncoder.Get(); + break; + default: + break; + } + + return value; +} + +void BatchMatMul::SetValueAt(float value, DataSlot type, std::vector idx) +{ + AdjustToSafeIdx(type, idx); + + unsigned int flatIdx = CalcFlatIdx(type, idx); + + switch(type) + { + case DataSlot::InputX: + inputXData[flatIdx] = value; + break; + case DataSlot::InputY: + inputYData[flatIdx] = value; + break; + case DataSlot::Output: + outputEncoder[flatIdx]; + outputEncoder.Set(value); + break; + default: + break; + } +} + +void BatchMatMul::AdjustToSafeIdx(DataSlot type, std::vector& idx) +{ + for(unsigned int dim = 0; dim < idx.size(); dim++) + { + switch(type) + { + case DataSlot::InputX: + { + auto xRank = inputXInfo.GetNumDimensions(); + auto xDiff = outputInfo.GetNumDimensions() - xRank; + if (dim < xDiff || + idx[dim] > inputXInfo.GetShape()[dim-xDiff]-1) + { + idx[dim] = 0; // Broadcasting + } + break; + } + case DataSlot::InputY: + { + auto yRank = inputYInfo.GetNumDimensions(); + auto yDiff = outputInfo.GetNumDimensions() - yRank; + if (dim < yDiff || + idx[dim] > inputYInfo.GetShape()[dim-yDiff]-1) + { + idx[dim] = 0; + 
} + break; + } + case DataSlot::Output: + { + // Our indices are based off the output + break; + } + default: + break; + } + } +} + +unsigned int BatchMatMul::CalcFlatIdx(DataSlot type, const std::vector& idx) +{ + unsigned int result = idx[idx.size()-1]; + + unsigned int dimMultiplier = 1; + + unsigned int offset; + + // -2 because final dim is already accounted for in the multiplier (last dim is just a multiplier of 1x) + for(unsigned int i = static_cast(idx.size()-2); static_cast(i) >= 0; i--) + { + switch(type) + { + case DataSlot::InputX: + offset = outputInfo.GetNumDimensions() - inputXInfo.GetNumDimensions(); + dimMultiplier *= inputXInfo.GetShape()[i + 1 - offset]; + break; + case DataSlot::InputY: + offset = outputInfo.GetNumDimensions() - inputYInfo.GetNumDimensions(); + dimMultiplier *= inputYInfo.GetShape()[i + 1 - offset]; + break; + case DataSlot::Output: + dimMultiplier *= outputInfo.GetShape()[i+1]; + break; + default: + break; + } + result += (idx[i] * dimMultiplier); + } + return result; +} + +template +std::string BatchMatMul::StringifyVec(const std::vector& vec) +{ + std::string res = "{ "; + for(auto x : vec) + { + res += std::to_string(x); + res += " "; + } + res += "}"; + return res; +} + +} // namespace armnn \ No newline at end of file diff --git a/src/backends/reference/workloads/BatchMatMulImpl.hpp b/src/backends/reference/workloads/BatchMatMulImpl.hpp new file mode 100644 index 0000000000..25b6c85d77 --- /dev/null +++ b/src/backends/reference/workloads/BatchMatMulImpl.hpp @@ -0,0 +1,75 @@ +// +// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "Encoders.hpp" +#include "Decoders.hpp" + +#include + +namespace armnn +{ + +class BatchMatMul { +public: + enum DataSlot + { + InputX = 0, + InputY = 1, + Output = 2 + }; + + BatchMatMul(const BatchMatMulDescriptor& params, + const TensorInfo& inputXInfo, + const TensorInfo& inputYInfo, + const TensorInfo& outputInfo, + Decoder& inputXDecoder, + Decoder& inputYDecoder, + Encoder& outputEncoder) + : params(params), + inputXInfo(inputXInfo), + inputYInfo(inputYInfo), + outputInfo(outputInfo), + inputXDecoder(inputXDecoder), + inputYDecoder(inputYDecoder), + outputEncoder(outputEncoder) + {} + + void BatchMatMulImpl(); + + void RecurseBMM(std::vector& curIdx, unsigned int curDim); + + // Adjusts it for when input tensors are of unequal rank + void AdjustAxesToMulForUnequalRanks( + std::pair, std::pair>& axesToMul); + + float GetValueAt(DataSlot type, std::vector idx); + + void SetValueAt(float value, DataSlot type, std::vector idx); + + // Takes into account broadcasting + void AdjustToSafeIdx(DataSlot type, std::vector& idx); + + unsigned int CalcFlatIdx(DataSlot type, const std::vector& idx); + + template + std::string StringifyVec(const std::vector& vec); + +private: + const BatchMatMulDescriptor& params; + const TensorInfo& inputXInfo; + const TensorInfo& inputYInfo; + const TensorInfo& outputInfo; + Decoder& inputXDecoder; + Decoder& inputYDecoder; + Encoder& outputEncoder; + + std::vector inputXData; + std::vector inputYData; + +}; + +} // namespace armnn \ No newline at end of file diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index b1f6d8b250..b8835e3cdb 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -10,6 +10,8 @@ list(APPEND armnnRefBackendWorkloads_sources ArgMinMax.cpp ArgMinMax.hpp BaseIterator.hpp 
+    BatchMatMulImpl.cpp
+    BatchMatMulImpl.hpp
     BatchNormImpl.cpp
     BatchNormImpl.hpp
     BatchToSpaceNd.cpp
@@ -69,6 +71,8 @@ list(APPEND armnnRefBackendWorkloads_sources
     RefArgMinMaxWorkload.cpp
     RefArgMinMaxWorkload.hpp
     RefBaseWorkload.hpp
+    RefBatchMatMulWorkload.cpp
+    RefBatchMatMulWorkload.hpp
     RefBatchNormalizationWorkload.cpp
     RefBatchNormalizationWorkload.hpp
     RefBatchToSpaceNdWorkload.cpp
diff --git a/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp b/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp
new file mode 100644
index 0000000000..388190c4ef
--- /dev/null
+++ b/src/backends/reference/workloads/RefBatchMatMulWorkload.cpp
@@ -0,0 +1,59 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefBatchMatMulWorkload.hpp"
+
+#include "BatchMatMulImpl.hpp"
+#include "RefWorkloadUtils.hpp"
+#include "Profiling.hpp"
+
+namespace armnn
+{
+
+RefBatchMatMulWorkload::RefBatchMatMulWorkload(const BatchMatMulQueueDescriptor& descriptor, const WorkloadInfo& info)
+    : RefBaseWorkload<BatchMatMulQueueDescriptor>(descriptor, info)
+{}
+
+void RefBatchMatMulWorkload::Execute() const
+{
+    Execute(m_Data.m_Inputs, m_Data.m_Outputs);
+}
+
+void RefBatchMatMulWorkload::ExecuteAsync(ExecutionData& executionData)
+{
+    WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
+    Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
+}
+
+void RefBatchMatMulWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
+{
+    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefBatchMatMulWorkload_Execute");
+
+    const TensorInfo& inputXInfo = GetTensorInfo(inputs[0]);
+    const TensorInfo& inputYInfo = GetTensorInfo(inputs[1]);
+    const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
+
+    std::unique_ptr<Decoder<float>> inputXDecoder = MakeDecoder<float>(GetTensorInfo(inputs[0]),
+                                                                       inputs[0]->Map());
+
+    std::unique_ptr<Decoder<float>> inputYDecoder = MakeDecoder<float>(GetTensorInfo(inputs[1]),
+                                                                       inputs[1]->Map());
+
+    std::unique_ptr<Encoder<float>> outputEncoder = MakeEncoder<float>(GetTensorInfo(outputs[0]),
+                                                                       outputs[0]->Map());
+
+    auto bmm = BatchMatMul(m_Data.m_Parameters,
+                           inputXInfo,
+                           inputYInfo,
+                           outputInfo,
+                           *inputXDecoder,
+                           *inputYDecoder,
+                           *outputEncoder);
+
+    bmm.BatchMatMulImpl();
+
+}
+
+} // namespace armnn
\ No newline at end of file
diff --git a/src/backends/reference/workloads/RefBatchMatMulWorkload.hpp b/src/backends/reference/workloads/RefBatchMatMulWorkload.hpp
new file mode 100644
index 0000000000..e9dfcaef73
--- /dev/null
+++ b/src/backends/reference/workloads/RefBatchMatMulWorkload.hpp
@@ -0,0 +1,30 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "RefBaseWorkload.hpp"
+#include <armnn/backends/WorkloadData.hpp>
+
+#include "BatchMatMulImpl.hpp"
+
+namespace armnn
+{
+
+class RefBatchMatMulWorkload : public RefBaseWorkload<BatchMatMulQueueDescriptor>
+{
+public:
+    explicit RefBatchMatMulWorkload(const BatchMatMulQueueDescriptor& descriptor,
+                                    const WorkloadInfo& info);
+
+    void Execute() const override;
+    void ExecuteAsync(ExecutionData& executionData) override;
+
+private:
+    void Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const;
+
+};
+
+} // namespace armnn
\ No newline at end of file
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index b9c7a2a1fb..e049d8db2c 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -7,6 +7,7 @@
 
 #include "RefActivationWorkload.hpp"
 #include "RefArgMinMaxWorkload.hpp"
+#include "RefBatchMatMulWorkload.hpp"
 #include "RefBatchNormalizationWorkload.hpp"
 #include "RefBatchToSpaceNdWorkload.hpp"
 #include "RefCastWorkload.hpp"
--
cgit v1.2.1
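
For readers who want the semantics of BatchMatMulImpl.cpp at a glance: the reference workload decodes both inputs to float, walks every index of the output tensor (RecurseBMM), and for each output element accumulates a dot product over the shared inner dimension, re-mapping the output index onto each input and zeroing any coordinate that input lacks or holds with extent 1 (AdjustToSafeIdx, CalcFlatIdx). The standalone C++ sketch below illustrates the same index-walking scheme; the BatchMatMul and FlatIndex names, the Shape alias, and the flat-vector interface are illustrative only, not the Arm NN API, and, like this initial commit, it ignores the descriptor's transpose/adjoint options.

// Standalone sketch of batched matrix multiplication with broadcasting,
// mirroring the index-walking approach of the reference workload.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

using Shape = std::vector<std::size_t>;

// Row-major flat offset of a (right-aligned) output index into a tensor,
// treating missing leading dimensions and extent-1 dimensions as broadcast.
std::size_t FlatIndex(const Shape& shape, const Shape& idx)
{
    const std::size_t offset = idx.size() - shape.size();
    std::size_t flat = 0;
    for (std::size_t d = 0; d < shape.size(); ++d)
    {
        const std::size_t i = (shape[d] == 1) ? 0 : idx[d + offset];
        flat = flat * shape[d] + i;
    }
    return flat;
}

// C[..., m, n] = sum_k A[..., m, k] * B[..., k, n]; batch dims broadcast.
std::vector<float> BatchMatMul(const std::vector<float>& a, const Shape& shapeA,
                               const std::vector<float>& b, const Shape& shapeB,
                               Shape& shapeC)
{
    assert(shapeA.size() >= 2 && shapeB.size() >= 2);
    const std::size_t m = shapeA[shapeA.size() - 2];
    const std::size_t k = shapeA[shapeA.size() - 1];
    const std::size_t n = shapeB[shapeB.size() - 1];
    assert(shapeB[shapeB.size() - 2] == k);

    // Output shape: broadcast batch dimensions, then the MxN result dims.
    const std::size_t rank = std::max(shapeA.size(), shapeB.size());
    shapeC.assign(rank, 1);
    for (std::size_t d = 0; d + 2 < rank; ++d)
    {
        const std::size_t dA = (d + shapeA.size() >= rank) ? shapeA[d + shapeA.size() - rank] : 1;
        const std::size_t dB = (d + shapeB.size() >= rank) ? shapeB[d + shapeB.size() - rank] : 1;
        assert(dA == dB || dA == 1 || dB == 1);
        shapeC[d] = std::max(dA, dB);
    }
    shapeC[rank - 2] = m;
    shapeC[rank - 1] = n;

    std::size_t total = 1;
    for (std::size_t extent : shapeC) { total *= extent; }
    std::vector<float> c(total, 0.0f);

    // Visit every output index (the loop equivalent of RecurseBMM) and
    // accumulate the dot product over the shared dimension k.
    Shape idx(rank, 0);
    for (std::size_t flat = 0; flat < total; ++flat)
    {
        float sum = 0.0f;
        for (std::size_t kk = 0; kk < k; ++kk)
        {
            Shape idxA = idx; idxA[rank - 1] = kk; // A[..., row, kk]
            Shape idxB = idx; idxB[rank - 2] = kk; // B[..., kk, col]
            sum += a[FlatIndex(shapeA, idxA)] * b[FlatIndex(shapeB, idxB)];
        }
        c[flat] = sum;

        // Advance the multi-dimensional index in row-major order.
        for (std::size_t d = rank; d-- > 0;)
        {
            if (++idx[d] < shapeC[d]) { break; }
            idx[d] = 0;
        }
    }
    return c;
}

int main()
{
    // A: two 2x2 matrices; B: a single 2x2 identity, broadcast across the batch.
    const std::vector<float> a = {1, 2, 3, 4,  5, 6, 7, 8};
    const std::vector<float> b = {1, 0, 0, 1};
    Shape shapeC;
    const auto c = BatchMatMul(a, {2, 2, 2}, b, {2, 2}, shapeC);
    for (float v : c) { std::cout << v << ' '; } // 1 2 3 4 5 6 7 8
    std::cout << '\n';
    return 0;
}

As in the reference workload, clarity wins over speed here: each output element recomputes its source offsets on the fly, so broadcasting a lower-rank or extent-1 input needs no materialised copy of the data.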