From a1b1e41bb261f5613f443fed7071936a360686ed Mon Sep 17 00:00:00 2001 From: Mohammed Suhail Munshi Date: Thu, 23 Mar 2023 22:21:31 +0000 Subject: Implement MatMul Function and Operator with Floating Point support for CPU - Implements MatMul function and operator for floating point datatype FP16/FP32 - Includes support for transposing dynamic tensors prior to matrix multiplication. - Adds tests for 2D/3D/4D+ tensors in MatMul with F32/F16 datatype (with all combinations of transposed/not-transposed tensors) - Updates fixture to allow for testing fused activation in MatMul - Adds tests for matmul with and without fused activation Resolved: [COMPMID-5898] Signed-off-by: Mohammed Suhail Munshi Change-Id: Iefa84b26dd723c9a51e6c3f91023152c6c31ace2 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9411 Reviewed-by: SiCong Li Tested-by: Arm Jenkins Benchmark: Arm Jenkins --- src/BUILD.bazel | 4 +- src/CMakeLists.txt | 6 +- src/core/helpers/AutoConfiguration.h | 19 +- src/cpu/operators/CpuMatMul.cpp | 226 +++++++++++++++++++++ src/cpu/operators/CpuMatMul.h | 115 +++++++++++ .../operators/internal/CpuGemmAssemblyDispatch.h | 34 +++- src/runtime/CL/functions/CLMatMul.cpp | 6 +- src/runtime/NEON/functions/NEMatMul.cpp | 75 +++++++ 8 files changed, 472 insertions(+), 13 deletions(-) create mode 100644 src/cpu/operators/CpuMatMul.cpp create mode 100644 src/cpu/operators/CpuMatMul.h create mode 100644 src/runtime/NEON/functions/NEMatMul.cpp (limited to 'src') diff --git a/src/BUILD.bazel b/src/BUILD.bazel index 279c52e151..26acc14a68 100644 --- a/src/BUILD.bazel +++ b/src/BUILD.bazel @@ -765,10 +765,10 @@ filegroup( "cpu/kernels/instancenorm/generic/neon/impl.cpp", "cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.cpp", "cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.cpp", - "cpu/kernels/lut/generic/neon/u8.cpp", "cpu/kernels/l2normlayer/generic/neon/fp16.cpp", "cpu/kernels/l2normlayer/generic/neon/fp32.cpp", "cpu/kernels/l2normlayer/generic/neon/impl.cpp", + "cpu/kernels/lut/generic/neon/u8.cpp", "cpu/kernels/maxunpool/generic/neon/fp16.cpp", "cpu/kernels/maxunpool/generic/neon/fp32.cpp", "cpu/kernels/maxunpool/generic/neon/impl.cpp", @@ -837,6 +837,7 @@ filegroup( "cpu/operators/CpuGemmDirectConv2d.cpp", "cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp", "cpu/operators/CpuGemmLowpOutputStage.cpp", + "cpu/operators/CpuMatMul.cpp", "cpu/operators/CpuMaxUnpooling.cpp", "cpu/operators/CpuMul.cpp", "cpu/operators/CpuPermute.cpp", @@ -921,6 +922,7 @@ filegroup( "runtime/NEON/functions/NELSTMLayer.cpp", "runtime/NEON/functions/NELSTMLayerQuantized.cpp", "runtime/NEON/functions/NELogical.cpp", + "runtime/NEON/functions/NEMatMul.cpp", "runtime/NEON/functions/NEMaxUnpoolingLayer.cpp", "runtime/NEON/functions/NEMeanStdDevNormalizationLayer.cpp", "runtime/NEON/functions/NENormalizationLayer.cpp", diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 92c888056e..336d2cd5cc 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -757,10 +757,10 @@ target_sources( cpu/kernels/instancenorm/generic/neon/impl.cpp cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.cpp cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.cpp - cpu/kernels/lut/generic/neon/u8.cpp cpu/kernels/l2normlayer/generic/neon/fp16.cpp cpu/kernels/l2normlayer/generic/neon/fp32.cpp cpu/kernels/l2normlayer/generic/neon/impl.cpp + cpu/kernels/lut/generic/neon/u8.cpp cpu/kernels/maxunpool/generic/neon/fp16.cpp cpu/kernels/maxunpool/generic/neon/fp32.cpp cpu/kernels/maxunpool/generic/neon/impl.cpp @@ 
-829,6 +829,7 @@ target_sources( cpu/operators/CpuGemmDirectConv2d.cpp cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp cpu/operators/CpuGemmLowpOutputStage.cpp + cpu/operators/CpuMatMul.cpp cpu/operators/CpuMaxUnpooling.cpp cpu/operators/CpuMul.cpp cpu/operators/CpuPermute.cpp @@ -913,6 +914,7 @@ target_sources( runtime/NEON/functions/NELSTMLayer.cpp runtime/NEON/functions/NELSTMLayerQuantized.cpp runtime/NEON/functions/NELogical.cpp + runtime/NEON/functions/NEMatMul.cpp runtime/NEON/functions/NEMaxUnpoolingLayer.cpp runtime/NEON/functions/NEMeanStdDevNormalizationLayer.cpp runtime/NEON/functions/NENormalizationLayer.cpp @@ -960,4 +962,4 @@ target_sources( runtime/Tensor.cpp runtime/TensorAllocator.cpp runtime/Utils.cpp -) +) \ No newline at end of file diff --git a/src/core/helpers/AutoConfiguration.h b/src/core/helpers/AutoConfiguration.h index 6880a6cb66..18ffbd6295 100644 --- a/src/core/helpers/AutoConfiguration.h +++ b/src/core/helpers/AutoConfiguration.h @@ -1,5 +1,5 @@ /* -* Copyright (c) 2020 Arm Limited. +* Copyright (c) 2020, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -57,12 +57,16 @@ inline bool auto_init_if_empty(ITensorInfo &info, } /** Auto initialize the tensor info using another tensor info. -* -* @param info_sink Tensor info used to check and assign -* @param info_source Tensor info used to assign -* -* @return True if the tensor info has been initialized -*/ + * + * (COMPMID-6012) This method should remain in sync with the fields of ITensorInfo that have setters. + * + * + * @param info_sink Tensor info used to check and assign + * @param info_source Tensor info used to assign + * + * + * @return True if the tensor info has been initialized + */ inline bool auto_init_if_empty(ITensorInfo &info_sink, const ITensorInfo &info_source) { if(info_sink.tensor_shape().total_size() == 0) @@ -72,6 +76,7 @@ inline bool auto_init_if_empty(ITensorInfo &info_sink, const ITensorInfo &info_s info_sink.set_tensor_shape(info_source.tensor_shape()); info_sink.set_quantization_info(info_source.quantization_info()); info_sink.set_data_layout(info_source.data_layout()); + info_sink.set_are_values_constant(info_source.are_values_constant()); return true; } diff --git a/src/cpu/operators/CpuMatMul.cpp b/src/cpu/operators/CpuMatMul.cpp new file mode 100644 index 0000000000..b5359e51af --- /dev/null +++ b/src/cpu/operators/CpuMatMul.cpp @@ -0,0 +1,226 @@ +/* + * Copyright (c) 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + +#include "src/cpu/operators/CpuMatMul.h" +#include "arm_compute/core/Validate.h" +#include "arm_compute/core/experimental/Types.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" +#include "arm_compute/runtime/NEON/NEScheduler.h" +#include "arm_compute/runtime/NEON/functions/NEMatMul.h" +#include "src/common/utils/Log.h" +#include "src/core/CPP/Validate.h" +#include "src/core/helpers/AutoConfiguration.h" +#include "src/core/helpers/MemoryHelpers.h" +#include "src/cpu/utils/CpuAuxTensorHandler.h" + +using namespace arm_compute::experimental; + +namespace arm_compute +{ +namespace cpu +{ +CpuMatMul::CpuMatMul() + : _transpose_kernel_lhs(), _transpose_kernel_rhs(), _asm_glue(), _lhs_transposed(), _rhs_transposed(), _original_lhs_shape(), _original_rhs_shape(), _original_dst_shape() +{ +} + +Status CpuMatMul::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, const MatMulInfo &info, const CpuMatMulSettings &settings) +{ + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F32, DataType::F16); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs->are_values_constant(), "LHS Tensor must be dynamic."); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(rhs->are_values_constant(), "RHS Tensor must be dynamic."); + ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(lhs); + ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(lhs); + + const auto adj_lhs = info.adj_lhs(); + const auto adj_rhs = info.adj_rhs(); + + const ITensorInfo *lhs_to_use = lhs; + const ITensorInfo *rhs_to_use = rhs; + TensorInfo lhs_transposed{}; + TensorInfo rhs_transposed{}; + + auto gemm_info = AsmGemmInfo(); + gemm_info.activation_info = info.fused_activation(); + gemm_info.fast_mode = settings.fast_math(); + + // Validate and then permute a/b + if(adj_lhs) + { + auto_init_if_empty(lhs_transposed, lhs->clone()->set_tensor_shape(misc::shape_calculator::compute_transposed_shape(*lhs))); + ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuTransposeKernel::validate(lhs_to_use, &lhs_transposed)); + // Assign lhs_to_use pointer to use transposed TensorInfo + lhs_to_use = &lhs_transposed; + } + if(adj_rhs) + { + auto_init_if_empty(rhs_transposed, rhs->clone()->set_tensor_shape(misc::shape_calculator::compute_transposed_shape(*rhs))); + ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuTransposeKernel::validate(rhs_to_use, &rhs_transposed)); + // Assign rhs_to_use pointer to use transposed TensorInfo + rhs_to_use = &rhs_transposed; + } + + ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_to_use->dimension(0) != rhs_to_use->dimension(1), + "The product AB is defined only if the number of columns in A is equal to the number of rows in B (after transpose)"); + + if(lhs_to_use->num_dimensions() > 2) + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_to_use->num_dimensions() != rhs_to_use->num_dimensions(), "Broadcasting in Batch dimension is unsupported by this operator."); + } + + // Iterate over dimensions to be collapsed in operator - check dimensions are equivelent between tensors + for(unsigned int i = 2; i < lhs_to_use->num_dimensions(); i++) + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_to_use->dimension(i) != rhs_to_use->dimension(i), "Broadcasting in Batch dimension is unsupported by this 
operator."); + } + + cpu::CpuGemmAssemblyDispatch::validate(lhs_to_use, rhs_to_use, nullptr, dst, gemm_info); + + return Status{}; +} + +void CpuMatMul::configure(ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, const MatMulInfo &info, const CpuMatMulSettings &settings) +{ + ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, rhs, dst); + ARM_COMPUTE_LOG_PARAMS(lhs, rhs, dst, info, settings); + ARM_COMPUTE_ERROR_THROW_ON(CpuMatMul::validate(lhs, rhs, dst, info, settings)); + + _adj_lhs = info.adj_lhs(); + _adj_rhs = info.adj_rhs(); + _fast_math = settings.fast_math(); + + // 1. Create and reshape tensors + // ------------------------------------------------------ + // a. Clone TensorInfo to prevent changing original tensor values during setup + // b. Change shape of lhs/dst to [x, y, 1, collapsed(z)] to match assembly kernel configuration + // c. For rhs collapse all dimensions larger than 3 to z dimension + TensorInfo lhs_to_use = *lhs->clone(); + TensorInfo dst_to_use = *dst->clone(); + TensorInfo rhs_to_use = *rhs->clone(); + + // Save starting shape of tensors + _original_lhs_shape = lhs_to_use.tensor_shape(); + _original_dst_shape = dst_to_use.tensor_shape(); + _original_rhs_shape = rhs_to_use.tensor_shape(); + + // Reshape lhs for use with assembly kernels. + lhs_to_use.set_tensor_shape(TensorShape(_original_lhs_shape.x(), _original_lhs_shape.y(), 1, _original_lhs_shape.collapsed_from(2).z())); + dst_to_use.set_tensor_shape(TensorShape(_original_dst_shape.x(), _original_dst_shape.y(), 1, _original_dst_shape.collapsed_from(2).z())); + rhs_to_use.set_tensor_shape(_original_rhs_shape.collapsed_from(2)); + + // 2. Configuration for transpose of lhs/rhs + // ------------------------------------------------------ + // Initialise transposed TensorInfo class for aux tensors (intermediary tensors) + if(_adj_lhs) + { + // Setup transpose LHS + _transpose_kernel_lhs = std::make_unique(); + _transpose_kernel_lhs->configure(&lhs_to_use, &_lhs_transposed); + } + + if(_adj_rhs) + { + // Setup transpose RHS + _transpose_kernel_rhs = std::make_unique(); + _transpose_kernel_rhs->configure(&rhs_to_use, &_rhs_transposed); + } + + // 3. Configure assembly kernel using transposed tensors. + // ----------------------------------------------------- + // Use transposed tensors if the corresponding transpose flags are set + // Fill AsmGemmInfo class object before configuration + _gemm_info.activation_info = info.fused_activation(); + _gemm_info.fast_mode = settings.fast_math(); + + lhs_to_use = (_adj_lhs) ? _lhs_transposed : lhs_to_use; + rhs_to_use = (_adj_rhs) ? 
_rhs_transposed : rhs_to_use; + + // Configure Asm Kernel + _asm_glue = std::make_unique(); + _asm_glue->configure(&lhs_to_use, &rhs_to_use, nullptr, &dst_to_use, _gemm_info); // c is nullptr as bias not supported in MatMul + + // Specify memory requirements for intermediate tensors + auto asm_mem_req = _asm_glue->workspace(); + // Specify memory required by gemm kernel + int idx = 0; + for(const auto &aux : asm_mem_req) + { + _aux_mem[idx] = aux; + idx++; + } + // Memory requirements for transposed tensors + _aux_mem[TransposeLHS] = MemoryInfo(offset_int_vec(TransposeLHS), MemoryLifetime::Temporary, lhs->total_size()); + _aux_mem[TransposeRHS] = MemoryInfo(offset_int_vec(TransposeRHS), MemoryLifetime::Temporary, rhs->total_size()); +} + +void CpuMatMul::run(ITensorPack &tensors) +{ + // Retrieve tensors from tensor pack + auto lhs = tensors.get_tensor(ACL_SRC_0); + auto rhs = tensors.get_const_tensor(ACL_SRC_1); + auto dst = tensors.get_tensor(ACL_DST); + + // Reshape LHS and DST to ensure compatibility with GEMM asm kernel (Batch dimensions is 4th for lhs and dst within asm) + // Collapse RHS (necessary to support dimensions larger than 3 in gemm assembly) + lhs->info()->set_tensor_shape(TensorShape(_original_lhs_shape.x(), _original_lhs_shape.y(), 1, _original_lhs_shape.collapsed_from(2).z())); // Collapsed 3+ dimensions into z + dst->info()->set_tensor_shape(TensorShape(_original_dst_shape.x(), _original_dst_shape.y(), 1, _original_dst_shape.collapsed_from(2).z())); // Collapsed 3+ dimensions into z + rhs->info()->set_tensor_shape(_original_rhs_shape.collapsed_from(2)); + + // Initialise object to handle stored transposed tensors in auxillary memory + CpuAuxTensorHandler lhs_transposed(offset_int_vec(TransposeLHS), _lhs_transposed, tensors, true); + CpuAuxTensorHandler rhs_transposed(offset_int_vec(TransposeRHS), _rhs_transposed, tensors, true); + + // Create tensor pack for asm kernel + ITensorPack asm_tensors(tensors); + + // Run transpose lhs if necessary + if(_adj_lhs) + { + ITensorPack lhs_transpose_pack = { { TensorType::ACL_SRC, lhs }, { TensorType::ACL_DST, lhs_transposed.get() } }; + NEScheduler::get().schedule_op(_transpose_kernel_lhs.get(), Window::DimY, _transpose_kernel_lhs->window(), lhs_transpose_pack); + asm_tensors.add_const_tensor(TensorType::ACL_SRC_0, lhs_transposed.get()); + } + // Run transpose rhs if necessary + if(_adj_rhs) + { + ITensorPack rhs_transpose_pack = { { TensorType::ACL_SRC, rhs }, { TensorType::ACL_DST, rhs_transposed.get() } }; + NEScheduler::get().schedule_op(_transpose_kernel_rhs.get(), Window::DimY, _transpose_kernel_rhs->window(), rhs_transpose_pack); + asm_tensors.add_const_tensor(TensorType::ACL_SRC_1, rhs_transposed.get()); + } + // Run asm kernel + _asm_glue->run(asm_tensors); + + // Undo reshape of tensors + dst->info()->set_tensor_shape(_original_dst_shape); + lhs->info()->set_tensor_shape(_original_lhs_shape); + rhs->info()->set_tensor_shape(_original_rhs_shape); +} + +experimental::MemoryRequirements CpuMatMul::workspace() const +{ + return _aux_mem; +} +} // namespace cpu +} // namespace arm_compute diff --git a/src/cpu/operators/CpuMatMul.h b/src/cpu/operators/CpuMatMul.h new file mode 100644 index 0000000000..ae6345141e --- /dev/null +++ b/src/cpu/operators/CpuMatMul.h @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2023 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef SRC_CPU_OPERATORS_CPUMATMUL +#define SRC_CPU_OPERATORS_CPUMATMUL + +#include "arm_compute/core/TensorInfo.h" +#include "src/core/common/Macros.h" +#include "src/cpu/ICpuOperator.h" +#include "src/cpu/kernels/CpuTransposeKernel.h" +#include "src/cpu/operators/internal/CpuGemmAssemblyDispatch.h" + +namespace arm_compute +{ +// Forward Declarations +class MatMulInfo; +class CpuMatMulSettings; + +namespace cpu +{ +/** Function to execute MatMul Operation. This function calls the following functions/kernels: + * + * If adjoint/adj flag is enabled for either input lhs or rhs (or both) : + * -# @ref cpu::kernels::CpuTransposeKernel + * Then : + * -# @ref cpu::CpuGemmAssemblyDispatch + */ +class CpuMatMul : public ICpuOperator +{ +public: + /* Constructor */ + CpuMatMul(); + /* Destructor */ + ~CpuMatMul() = default; + + ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuMatMul); + /** Configure operator for a given list of arguments + * + * Note: Check documentation of @ref NEMatMul for a list of supported datatypes and layouts + * + * + * @param[in] lhs Source tensor info. + * @param[in] rhs Source tensor info. + * @param[out] dst Destination tensor info. Data types supported: same as @p lhs / @p rhs. + * @param[in] info Contains MatMul operation information described in @ref MatMulInfo. 
+ * @param[in] settings The settings for matmul operation (i.e fast math) + */ + void configure(ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, const MatMulInfo &info, const CpuMatMulSettings &settings); + /** Static function to check if given info will lead to a valid configuration + * + * Similar to CpuMatMul::configure() + * + * @return a status + */ + static Status validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, const MatMulInfo &info, const CpuMatMulSettings &settings); + + // Inherited methods overridden: + void run(ITensorPack &tensors) override; + experimental::MemoryRequirements workspace() const override; + +private: + enum InternalTensorIdx + { + AsmGemmWorkspace = 0, // Pre-allocate workspace tensors for CpuGemmAssemblyDispatch + PretransposeRHS, // Pre-allocate workspace tensors for CpuGemmAssemblyDispatch + TransposeLHS, + TransposeRHS, + Count + }; + + // Define unique pointers to kernels/operators used by matmul + std::unique_ptr _transpose_kernel_lhs{ nullptr }; + std::unique_ptr _transpose_kernel_rhs{ nullptr }; + std::unique_ptr _asm_glue{ nullptr }; + + // TensorInfo for tensors stored in auxillary memory + TensorInfo _lhs_transposed{}; + TensorInfo _rhs_transposed{}; + + // Original tensor shapes prior to reshaping tensors and collapsing dimensions + TensorShape _original_lhs_shape{}; + TensorShape _original_rhs_shape{}; + TensorShape _original_dst_shape{}; + + // Note : adj_lhs means the same as transposing lhs + bool _adj_lhs{ false }; + bool _adj_rhs{ false }; + bool _fast_math{ false }; + AsmGemmInfo _gemm_info{}; + experimental::MemoryRequirements _aux_mem{ Count }; +}; +} +} + +#endif /* SRC_CPU_OPERATORS_CPUMATMUL */ diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h index 0c51c92359..588c45294a 100644 --- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h +++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022 Arm Limited. + * Copyright (c) 2018-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -81,6 +81,38 @@ public: public: /** If supported create a Compute Library function else fallback to the arm_gemm function. + * + * @note Configuring "batches" + * The shapes of @p a @p b and @p d are arranged as follows: + * Lowest dimension <-> Highest dimension + * a: [K, M, Batch, Multi] + * b: [N, K, Multi] + * d: [N, M, Batch, Multi] + * + * The "Batch" refers to where "Batch" number of MxK slices of tensor a multiplies with a single KxN slice of b + * The "Multi" refers to where "Multi" number of individual multiplication of a with b + * + * E.g. the following are some example input shape configurations + * + * (1) Normal 2D gemm + * a: [K=3, M=4] + * b: [N=5, K=3] + * d: [N=5, M=4] + * + * (2) Batches of a sharing b (e.g. gemm-based batched convolution where b is the shared ) + * a: [K=3, M=4, Batch=9] + * b: [N=5, K=3] + * d: [N=5, M=4, Batch=9] + * + * (3) "Batches" of independent gemm (e.g. 
batched matmul) + * a: [K=3, M=4, Batch=1, Multi=7] + * b: [N=5, K=3, Multi=7] + * d: [N=5, M=4, Batch=1, Multi=7] + * + * (4) "Batches" of independent gemm where b is also shared + * a: [K=3, M=4, Batch=4, Multi=7] + * b: [N=5, K=3, Multi=7] + * d: [N=5, M=4, Batch=4, Multi=7] * * @param[in] a Input tensor (Matrix A) * @param[in] b Input tensor (Matrix B) diff --git a/src/runtime/CL/functions/CLMatMul.cpp b/src/runtime/CL/functions/CLMatMul.cpp index f42e4ff309..ae5a01f679 100644 --- a/src/runtime/CL/functions/CLMatMul.cpp +++ b/src/runtime/CL/functions/CLMatMul.cpp @@ -42,14 +42,16 @@ CLMatMul::CLMatMul() CLMatMul::~CLMatMul() = default; -void CLMatMul::configure(ICLTensor *lhs, ICLTensor *rhs, ICLTensor *output, const MatMulInfo &matmul_info) +void CLMatMul::configure(ICLTensor *lhs, ICLTensor *rhs, ICLTensor *output, const MatMulInfo &matmul_info, const GpuMatMulSettings &settings) { + ARM_COMPUTE_UNUSED(settings); configure(CLKernelLibrary::get().get_compile_context(), lhs, rhs, output, matmul_info); } -void CLMatMul::configure(const CLCompileContext &compile_context, ICLTensor *lhs, ICLTensor *rhs, ICLTensor *output, const MatMulInfo &matmul_info) +void CLMatMul::configure(const CLCompileContext &compile_context, ICLTensor *lhs, ICLTensor *rhs, ICLTensor *output, const MatMulInfo &matmul_info, const GpuMatMulSettings &settings) { ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, rhs, output); + ARM_COMPUTE_UNUSED(settings); _impl->op = std::make_unique(); _impl->op->configure(compile_context, lhs->info(), rhs->info(), output->info(), matmul_info); diff --git a/src/runtime/NEON/functions/NEMatMul.cpp b/src/runtime/NEON/functions/NEMatMul.cpp new file mode 100644 index 0000000000..0c46516f1e --- /dev/null +++ b/src/runtime/NEON/functions/NEMatMul.cpp @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2023 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */
+#include "arm_compute/runtime/NEON/functions/NEMatMul.h"
+
+#include "arm_compute/core/Validate.h"
+#include "arm_compute/runtime/MemoryGroup.h"
+#include "src/core/helpers/MemoryHelpers.h"
+#include "src/cpu/operators/CpuMatMul.h"
+#include "arm_compute/runtime/Tensor.h"
+
+namespace arm_compute
+{
+struct NEMatMul::Impl
+{
+    const ITensor *lhs{ nullptr };
+    const ITensor *rhs{ nullptr };
+    ITensor *output{ nullptr };
+    std::unique_ptr<cpu::CpuMatMul> op{ nullptr };
+    MemoryGroup memory_group{};
+    WorkspaceData<Tensor> workspace_tensors{};
+    ITensorPack run_pack{};
+};
+
+NEMatMul::NEMatMul()
+    : _impl(std::make_unique<Impl>())
+{
+}
+
+NEMatMul::~NEMatMul() = default;
+
+void NEMatMul::configure(ITensor *lhs, ITensor *rhs, ITensor *output, const MatMulInfo &info, const CpuMatMulSettings &settings)
+{
+    _impl->lhs    = lhs;
+    _impl->rhs    = rhs;
+    _impl->output = output;
+
+    ARM_COMPUTE_ERROR_ON_NULLPTR(_impl->lhs, _impl->rhs, _impl->output);
+    _impl->op = std::make_unique<cpu::CpuMatMul>();
+    _impl->op->configure(lhs->info(), rhs->info(), output->info(), info, settings);
+    _impl->run_pack          = { { ACL_SRC_0, lhs }, { ACL_SRC_1, rhs }, { ACL_DST, output } };
+    _impl->workspace_tensors = manage_workspace<Tensor>(_impl->op->workspace(), _impl->memory_group, _impl->run_pack);
+}
+
+Status NEMatMul::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *output, const MatMulInfo &info, const CpuMatMulSettings &settings)
+{
+    return cpu::CpuMatMul::validate(lhs, rhs, output, info, settings);
+}
+
+void NEMatMul::run()
+{
+    MemoryGroupResourceScope scope_mg(_impl->memory_group);
+    _impl->op->run(_impl->run_pack);
+}
+} // namespace arm_compute
--
cgit v1.2.1
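
A minimal usage sketch for the new NEMatMul function (not part of the commit above), based on the interfaces this patch adds: NEMatMul::configure()/validate()/run(), MatMulInfo and CpuMatMulSettings. Shapes follow example (1) in the CpuGemmAssemblyDispatch notes; the MatMulInfo.h include path and the default behaviour of the default-constructed info/settings objects are assumptions, not taken from the patch.

// Usage sketch for NEMatMul (FP32, 2D case). Assumptions are marked below.
#include "arm_compute/core/Error.h"
#include "arm_compute/core/MatMulInfo.h" // assumed location of MatMulInfo at the time of this patch
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/NEON/functions/NEMatMul.h"
#include "arm_compute/runtime/Tensor.h"

using namespace arm_compute;

int main()
{
    // lhs: [K=3, M=4], rhs: [N=5, K=3] -> dst: [N=5, M=4]
    TensorInfo lhs_info(TensorShape(3U, 4U), 1, DataType::F32);
    TensorInfo rhs_info(TensorShape(5U, 3U), 1, DataType::F32);
    TensorInfo dst_info(TensorShape(5U, 4U), 1, DataType::F32);

    // CpuMatMul::validate() rejects constant operands ("LHS/RHS Tensor must be dynamic"),
    // so mark both inputs as non-constant before configuring.
    lhs_info.set_are_values_constant(false);
    rhs_info.set_are_values_constant(false);

    Tensor lhs, rhs, dst;
    lhs.allocator()->init(lhs_info);
    rhs.allocator()->init(rhs_info);
    dst.allocator()->init(dst_info);

    MatMulInfo        matmul_info; // adjoint flags left at their defaults (assumed: no transpose)
    CpuMatMulSettings settings;    // fast_math left at its default (assumed: disabled)

    // Check the configuration first, then configure the function.
    ARM_COMPUTE_ERROR_THROW_ON(NEMatMul::validate(lhs.info(), rhs.info(), dst.info(), matmul_info, settings));

    NEMatMul matmul;
    matmul.configure(&lhs, &rhs, &dst, matmul_info, settings);

    lhs.allocator()->allocate();
    rhs.allocator()->allocate();
    dst.allocator()->allocate();

    // ... fill lhs/rhs with F32 data ...

    matmul.run(); // runs optional CpuTransposeKernel(s), then CpuGemmAssemblyDispatch
    return 0;
}

When the adjoint flags in MatMulInfo are set, CpuMatMul stages the transposed operand in auxiliary memory via CpuTransposeKernel before dispatching to the assembly GEMM, as in CpuMatMul::run() above; batched cases simply append matching higher dimensions to all three shapes, since broadcasting in the batch dimensions is not supported.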