From 7891a73ef36f4ad7b71069b3c57694f85bb79454 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Fri, 20 Aug 2021 21:39:25 +0100 Subject: Move CPU/GPU files from Core/Runtime to the respective backend folders Legacy structure contained two libraries core/runtime with two backends in each. We reduce the core/runtime libraries to a single library thus merging the backend files Signed-off-by: Georgios Pinitas Change-Id: I69545765fe7a730368105cdbd067d3135ec7a174 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6155 Comments-Addressed: Arm Jenkins Reviewed-by: Michele Di Giorgio Tested-by: Arm Jenkins --- .../assembly/CpuGemmAssemblyWrapperKernel.h | 126 +++++++++++ src/cpu/kernels/assembly/arm_gemm.hpp | 190 +++++++++++++++++ .../kernels/assembly/arm_gemm_compute_iface.hpp | 130 ++++++++++++ src/cpu/kernels/assembly/arm_gemm_local.hpp | 31 +++ .../kernels/assembly/convolution_parameters.hpp | 65 ++++++ src/cpu/kernels/assembly/gemm_common.hpp | 236 +++++++++++++++++++++ src/cpu/kernels/assembly/ndrange.hpp | 199 +++++++++++++++++ 7 files changed, 977 insertions(+) create mode 100644 src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h create mode 100644 src/cpu/kernels/assembly/arm_gemm.hpp create mode 100644 src/cpu/kernels/assembly/arm_gemm_compute_iface.hpp create mode 100644 src/cpu/kernels/assembly/arm_gemm_local.hpp create mode 100644 src/cpu/kernels/assembly/convolution_parameters.hpp create mode 100644 src/cpu/kernels/assembly/gemm_common.hpp create mode 100644 src/cpu/kernels/assembly/ndrange.hpp (limited to 'src/cpu/kernels/assembly') diff --git a/src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h b/src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h new file mode 100644 index 0000000000..3b9a6b4760 --- /dev/null +++ b/src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h @@ -0,0 +1,126 @@ +/* + * Copyright (c) 2018-2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ARM_COMPUTE_ASSEMBLY_GEMM_KERNEL_WRAPPER_KERNEL_H +#define ARM_COMPUTE_ASSEMBLY_GEMM_KERNEL_WRAPPER_KERNEL_H + +#include "arm_compute/core/Utils.h" +#include "arm_compute/core/Validate.h" +#include "src/core/NEON/INEKernel.h" +#include "src/cpu/kernels/assembly/arm_gemm_compute_iface.hpp" + +#include "gemm_common.hpp" + +namespace arm_compute +{ +class ITensor; + +namespace cpu +{ +namespace kernel +{ +/** This class is a wrapper for the assembly kernels. + * + * Some kernels were written in assembly and highly optimised for specific CPUs like A53 or A55. + * This class works as a wrapper for these assembly kernels. The arm compute library creates an instance + * of CpuGemmAssemblyWrapperKernel and other auxiliary data structures to execute a single assembly kernel + * in the context of an NEFunctions. + * + * The type T is the type of the actual kernel implemented in assembly which is of type + * template class GemmCommon + * + * + */ +template +class CpuGemmAssemblyWrapperKernel final : public INEKernel +{ +public: + /** Constructor + */ + CpuGemmAssemblyWrapperKernel() + : _kernel(nullptr), _name("CpuGemmAssemblyWrapperKernel") + { + } + + CpuGemmAssemblyWrapperKernel(CpuGemmAssemblyWrapperKernel &) = delete; + CpuGemmAssemblyWrapperKernel(CpuGemmAssemblyWrapperKernel &&) = default; + CpuGemmAssemblyWrapperKernel &operator=(CpuGemmAssemblyWrapperKernel &) = delete; + + const char *name() const override + { + return _name.c_str(); + } + + void run(const Window &window, const ThreadInfo &info) override + { + ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast(_kernel))); + ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); + + auto win = arm_gemm::to_ndcoord(window); + + arm_gemm::ndcoord_t thread_locator{}; + + _kernel->execute(win, thread_locator, info.thread_id); + } + + // Inherited methods overridden: + void run_nd(const Window &window, const ThreadInfo &info, const Window &thread_locator) override + { + ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast(_kernel))); + ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); + + //convert between arm_compute and arm_gemm types + auto ndc_win = arm_gemm::to_ndcoord(window); + auto ndc_tlc = arm_gemm::to_ndcoord(thread_locator); + + _kernel->execute(ndc_win, ndc_tlc, info.thread_id); + } + + /** Initialise the kernel's input and output. + * + * @param[in] kernel Pointer to an assembly kernel implementation. + * @param[in] kernel_name_tag Tag to be attacehd to the kernel's name. + */ + void configure(arm_gemm::GemmCommon *kernel, std::string kernel_name_tag) + { + ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast(kernel))); + _kernel = kernel; + + Window win = to_window(kernel->get_window_size()); + + INEKernel::configure(win); + + if(!kernel_name_tag.empty()) + { + _name += "/" + kernel_name_tag; + } + } + +private: + arm_gemm::GemmCommon *_kernel; + std::string _name; +}; +} // namespace kernel +} // namespace cpu +} // namespace arm_compute +#endif /* ARM_COMPUTE_ASSEMBLY_GEMM_KERNEL_WRAPPER_KERNEL_H */ diff --git a/src/cpu/kernels/assembly/arm_gemm.hpp b/src/cpu/kernels/assembly/arm_gemm.hpp new file mode 100644 index 0000000000..e38cc09202 --- /dev/null +++ b/src/cpu/kernels/assembly/arm_gemm.hpp @@ -0,0 +1,190 @@ +/* + * Copyright (c) 2018-2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#include +#include +#include + +#include "arm_gemm_local.hpp" +#include "gemm_common.hpp" + +namespace arm_gemm +{ +enum class GemmMethod +{ + DEFAULT, + GEMV_BATCHED, + GEMV_PRETRANSPOSED, + GEMV_NATIVE_TRANSPOSED, + GEMM_NATIVE, + GEMM_HYBRID, + GEMM_INTERLEAVED, + GEMM_INTERLEAVED_2D, + QUANTIZE_WRAPPER, + QUANTIZE_WRAPPER_2D, + GEMM_HYBRID_QUANTIZED +}; + +struct KernelDescription +{ + GemmMethod method = GemmMethod::DEFAULT; + std::string name = ""; + bool is_default = false; + uint64_t cycle_estimate = 0; + + KernelDescription(GemmMethod m, std::string n, bool d = false, uint64_t c = 0) + : method(m), name(n), is_default(d), cycle_estimate(c) + { + } + KernelDescription() noexcept + { + } +}; + +struct GemmConfig +{ + GemmMethod method = GemmMethod::DEFAULT; + std::string filter = ""; + unsigned int inner_block_size = 0; + unsigned int outer_block_size = 0; + + GemmConfig(GemmMethod method) + : method(method) + { + } + GemmConfig() + { + } +}; + +struct Activation +{ + enum class Type + { + None, + ReLU, + BoundedReLU + }; + + Type type; + float param1; + float param2; + + Activation(Type type = Type::None, float p1 = 0.0f, float p2 = 0.0f) + : type(type), param1(p1), param2(p2) + { + } +}; + +struct GemmArgs +{ +public: + const CPUInfo *_ci; + unsigned int _Msize; + unsigned int _Nsize; + unsigned int _Ksize; + unsigned int _Ksections; + unsigned int _nbatches; + unsigned int _nmulti; + bool _indirect_input; + Activation _act; + int _maxthreads; + bool _fast_mode; + const GemmConfig *_cfg; + + GemmArgs(const CPUInfo *ci, unsigned int M, unsigned int N, + unsigned int K, unsigned int Ksections, unsigned int nbatches, + unsigned int nmulti, bool indirect_input, Activation act, const int maxthreads, + bool fast_mode = false, const GemmConfig *cfg = nullptr) + : _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _Ksections(Ksections), _nbatches(nbatches), _nmulti(nmulti), _indirect_input(indirect_input), _act(act), _maxthreads(maxthreads), _fast_mode(fast_mode), + _cfg(cfg) + { + } +}; + +struct Requantize32 +{ +public: + const int32_t *bias = nullptr; + size_t bias_multi_stride = 0; + int32_t a_offset = 0; + int32_t b_offset = 0; + int32_t c_offset = 0; + bool per_channel_requant = false; + int32_t per_layer_left_shift = 0; + int32_t per_layer_right_shift = 0; + int32_t per_layer_mul = 0; + const int32_t *per_channel_left_shifts = nullptr; + const int32_t *per_channel_right_shifts = nullptr; + const int32_t *per_channel_muls = nullptr; + int32_t minval = 0; + int32_t maxval = 0; + + Requantize32() = default; + + // Constructor for per-tensor quantization + Requantize32(const int32_t *bias, size_t bias_multi_stride, + int32_t a_offset, int32_t b_offset, int32_t c_offset, + int32_t requant_shift, int32_t requant_mul, int32_t minv, int32_t maxv) + : bias(bias), bias_multi_stride(bias_multi_stride), a_offset(a_offset), b_offset(b_offset), c_offset(c_offset), per_channel_requant(false), per_layer_left_shift(std::max(requant_shift, 0)), + per_layer_right_shift(std::min(requant_shift, 0)), per_layer_mul(requant_mul), minval(minv), maxval(maxv) + { + } + + // Constructor for per-channel quantization + Requantize32(const int32_t *bias, size_t bias_multi_stride, + int32_t a_offset, int32_t b_offset, int32_t c_offset, + const int32_t *requant_left_shifts, + const int32_t *requant_right_shifts, + const int32_t *requant_muls, + int32_t minv, int32_t maxv) + : bias(bias), bias_multi_stride(bias_multi_stride), a_offset(a_offset), b_offset(b_offset), c_offset(c_offset), per_channel_requant(true), per_channel_left_shifts(requant_left_shifts), + per_channel_right_shifts(requant_right_shifts), per_channel_muls(requant_muls), minval(minv), maxval(maxv) + { + } +}; + +struct Nothing +{ +}; + +template +using UniqueGemmCommon = std::unique_ptr>; + +/* Low level API calls. + * These are implemented as 'GemmArgs' versions, or with the arguments explicitly listed. */ + +/* get_gemm_method(): Given the templated types and provided parameters, + * which is the preferred method to implement this GEMM? */ +template +KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage & = {}); + +template +UniqueGemmCommon gemm(const GemmArgs &args, const OutputStage & = {}); + +template +std::vector get_compatible_kernels(const GemmArgs &args, const OutputStage & = {}); + +} // namespace arm_gemm diff --git a/src/cpu/kernels/assembly/arm_gemm_compute_iface.hpp b/src/cpu/kernels/assembly/arm_gemm_compute_iface.hpp new file mode 100644 index 0000000000..718fcd1fb4 --- /dev/null +++ b/src/cpu/kernels/assembly/arm_gemm_compute_iface.hpp @@ -0,0 +1,130 @@ +/* + * Copyright (c) 2020-2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#include "arm_compute/core/Dimensions.h" +#include "arm_compute/core/Window.h" + +#include "ndrange.hpp" + +#include + +/* This file contains mapping between integral types used in arm_compute and arm_gemm + * These two codebases both require a degree of separation for the sake of modularity + * so maintain their own types which represent similar information. + */ + +namespace arm_gemm +{ +//we want to unify the maximum number of dimensions used beween arm_gemm and arm compute library +constexpr std::size_t ndrange_max = + arm_compute::Dimensions::num_max_dimensions; + +using ndrange_t = NDRange; +using ndcoord_t = NDCoordinate; + +/* Converts an `arm_gemm::ndrange_t` to a `arm_compute::Window` + * + * As `NDRange` does not not encode start positions, we specify + * the start to be zero in the produced `arm_compute::Window` + * + * @param [ndr] the `arm_gemm::ndrange_t` we wish to convert into a `arm_compute::Window` + * @returns an `arm_compute::Window` representing the same dimensional ranges as `ndr` + */ +inline arm_compute::Window to_window(const ndrange_t &ndr) +{ + arm_compute::Window win; + + for(unsigned int i = 0; i != ndrange_max; ++i) + { + //populate the window with the dimensions of the NDRange + win.set(i, arm_compute::Window::Dimension(0, ndr.get_size(i))); + } + + return win; +} + +/* + * Converts an `arm_gemm::ndcoord_t` to a `arm_compute::Window` + * + * @param [ndc] the `arm_gemm::ndcoord_t` we wish to convert into a `arm_compute::Window` + * @returns an `arm_compute::Window` representing the same dimensional ranges as `ndc` + */ +inline arm_compute::Window to_window(const ndcoord_t &ndc) +{ + arm_compute::Window win; + + for(unsigned int i = 0; i != ndrange_max; ++i) + { + const auto start = ndc.get_position(i); + const auto size = ndc.get_size(i); + const auto stop = start + size; + + //populate the window with the dimensions of the NDRange + win.set(i, arm_compute::Window::Dimension(start, stop)); + } + + return win; +} + +/** Convert an `arm_compute::Window` to an `arm_gemm::NDRange` of the same max dimensions + * + * It should be noted that `arm_compute::Window` specifies a `start()` and an `end()` + * where as `arm_gemm::ndrange_t` only has a size, as a result we store the delta between the range + * + * @param [win] the `arm_compute::Window` we want to convert to `arm_gemm::ndrange_t` + * @return the resultant ndrange_t + */ +inline ndrange_t to_ndrange(const arm_compute::Window &win) +{ + return + { + static_cast(win[0].end() - win[0].start()), + static_cast(win[1].end() - win[1].start()), + static_cast(win[2].end() - win[2].start()), + static_cast(win[3].end() - win[3].start()), + static_cast(win[4].end() - win[4].start()), + static_cast(win[5].end() - win[5].start()) + }; +} + +/** Convert an `arm_compute::Window` to an `arm_gemm::NDCoord` of the same max dimensions + * + * @param [win] the `arm_compute::Window` we want to convert to `arm_gemm::ndcoord_t` + * @return the resultant ndcoord_t + */ +inline ndcoord_t to_ndcoord(const arm_compute::Window &win) +{ + return + { + { static_cast(win[0].start()), static_cast(win[0].end() - win[0].start()) }, + { static_cast(win[1].start()), static_cast(win[1].end() - win[1].start()) }, + { static_cast(win[2].start()), static_cast(win[2].end() - win[2].start()) }, + { static_cast(win[3].start()), static_cast(win[3].end() - win[3].start()) }, + { static_cast(win[4].start()), static_cast(win[4].end() - win[4].start()) }, + { static_cast(win[5].start()), static_cast(win[5].end() - win[5].start()) } + }; +} + +} //namespace arm_gemm diff --git a/src/cpu/kernels/assembly/arm_gemm_local.hpp b/src/cpu/kernels/assembly/arm_gemm_local.hpp new file mode 100644 index 0000000000..78e0adf31f --- /dev/null +++ b/src/cpu/kernels/assembly/arm_gemm_local.hpp @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2018-2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +/* This file is used to configure integration-specific aspects of arm_gemm into ACL */ + +#include "arm_compute/core/CPP/CPPTypes.h" + +using CPUModel = arm_compute::CPUModel; +using CPUInfo = arm_compute::CPUInfo; diff --git a/src/cpu/kernels/assembly/convolution_parameters.hpp b/src/cpu/kernels/assembly/convolution_parameters.hpp new file mode 100644 index 0000000000..0c1ae58902 --- /dev/null +++ b/src/cpu/kernels/assembly/convolution_parameters.hpp @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2018-2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#include + +namespace arm_gemm +{ +/* + * Parameter set for "convolution" type GEMM. + * + * For a "convolution" GEMM, the GEMM parameters (M, K) are specified as if + * an im2row had been performed on the input tensor to generate the operand + * matrix, but instead this structure describes the convolution parameters + * such that this can be done on the fly. + * + * The parameters describe the convolution details - the notional shape of + * the input and output tensors, whether padding is to be applied, the size + * of the kernel and a constant value to be used for padding (needed for + * quantized tensors). + * + * The second part describes the layout of the input tensor in memory, which + * is assumed to be in NHWC format. This consists of a base pointer and + * strides for columns, rows and batches. 'multis' are not supported for + * convolution type GEMMs. + */ +struct ConvolutionParameters +{ + int64_t input_width; + int64_t input_height; + int64_t input_channels; + int64_t kernel_width; + int64_t kernel_height; + int64_t output_width; + int64_t output_height; + int64_t output_stride_w; + int64_t output_stride_h; + // output_channels not included as they do not affect the input. + int64_t padding_top; + int64_t padding_left; + float padding_value; +}; + +} // namespace arm_gemm diff --git a/src/cpu/kernels/assembly/gemm_common.hpp b/src/cpu/kernels/assembly/gemm_common.hpp new file mode 100644 index 0000000000..378f1041be --- /dev/null +++ b/src/cpu/kernels/assembly/gemm_common.hpp @@ -0,0 +1,236 @@ +/* + * Copyright (c) 2017-2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#include "convolution_parameters.hpp" +#include "ndrange.hpp" + +#include + +namespace arm_gemm +{ +// Avoid circular dependency with arm_gemm.hpp +struct GemmConfig; + +// Abstract class for the GEMM/GEMV functions. +// +// GEMM implementations may be "native" (never require any input +// permutation), "pretransposed" (require permutation up-front) or require +// working space (permute as they go along). This interface should support +// all of them. + +// The real GemmCommon class is templated based on the operand and return +// type. This is an interface class which is independent of those types. +class IGemmCommon +{ +public: + /* Pass in the pointers to the arrays to be operated on and their + * strides. This "generic" version uses void *s, the preferred version + * is the one provided by templated GemmCommon (below) which takes + * appropriately typed pointers. If B is pretransposed (see below) then + * the settings for B here are ignored. + */ + virtual void set_arrays_generic(const void *A, const int lda, const int A_batch_stride, const int A_multi_stride, + const void *B, const int ldb, /* batches share B */ const int B_multi_stride, + void *C, const int ldc, const int C_batch_stride, const int C_multi_stride, + const void *bias, /* no row or batch stride needed */ const int bias_multi_stride) = 0; + + /** @returns an ndrange containing ranges of the compute space which can be + * broken up and parallelised over + */ + virtual ndrange_t get_window_size() const = 0; + + /* The maximum thread count is specified when the GEMM is created. Some + * implementations need to know how many threads will actually run in + * order to work properly. + * + * In some cases, after creating the GEMM the number of threads needs to + * be reduced (e.g. not enough work to split across threads). This + * method allows the number of actual threads to be run to be set (must + * be equal or lower). + * + * This has an empty default implementation, as GEMMs which don't care + * about thread count can safely ignore this. + */ + virtual void set_nthreads(int) {}; + + /* Whether this GEMM can be dynamically scheduled or not. */ + virtual bool supports_dynamic_scheduling() const + { + return false; + } + + /** Main execute member fucntion + * @param [in] work_range specifies the range of work we want to be computed, total range defined by get_window_size() + * @param [in] thread_locator where are we inside of the thread space + * @param [in] threadid a unique threadid + */ + virtual void execute(const ndcoord_t &work_range, const ndcoord_t &thread_locator, int threadid) = 0; + + /*** Working space interface (optional) ***/ + /* Total number of bytes of temporary working space needed. If zero, it's not necessary to call set_working_space(). */ + virtual size_t get_working_size() const + { + return 0; + } + /* Provide working space buffer - the void * passed in must remain allocated for the duration of any execute calls. */ + virtual void set_working_space(void *) {}; + + /*** "Pretransposed" interface (optional) ***/ + /* Is this object set up for pretranspose? If so, pretranspose_array() needs to be called before execute(); */ + virtual bool B_is_pretransposed() const + { + return false; + } + /* Does pretranspose still need to be done? */ + virtual bool B_pretranspose_required() const + { + return false; + } + /* Total number of bytes of space needed for pretransposed arrays. */ + virtual size_t get_B_pretransposed_array_size() const + { + return 0; + } + /* Perform pretranspose - arguments are output, input, input row stride and input multi stride. */ + /* The "real" version of this depends on the templated operand type (see below). */ + virtual void pretranspose_B_array_generic(void *, const void *, const int, const int) = 0; + /* Set pretransposed data - the void * passed in must previously have been passed to pretranspose_B_array() for the same or a similar GEMM. */ + virtual void set_pretransposed_B_data(void *) + { + } + + /*** "Quantized bias" interface (optional) ***/ + /* Set the bias vector for quantized GEMMs */ + virtual void set_quantized_bias(const int32_t *, size_t) + { + } + + /*** Indirect interface (optional) ***/ + /* Set the indirect table. This comprises a number of values per kernel point, and a densely packed array of pointers, + * multis * batches * kernel_points */ + virtual void set_indirect_parameters_generic(size_t, const void *const *const *) + { + } + + /*** Convolution interface (optional) ***/ + /* Set the convolution parameters. */ + virtual void set_convolution_parameters(ConvolutionParameters) + { + } + + /*** Introspection interface ***/ + /* Get the configuration of this GEMM */ + virtual GemmConfig get_config() = 0; + + // Destructor + virtual ~IGemmCommon() + { + } +}; + +/* "Real" GemmCommon class which is templated on the operand and return types. + * + * In addition to correctly typed versions of the functions that operate on + * operand and return data, this class provides a default implementation of + * 'set_arrays' to capture the provided arguments in protected class + * members, as essentially any implementation will need these. + */ +template +class GemmCommon : public IGemmCommon +{ +protected: + const To *_Aptr = nullptr; + int _lda = 0; + int _A_batch_stride = 0; + int _A_multi_stride = 0; + const To *_Bptr = nullptr; + int _ldb = 0; + int _B_multi_stride = 0; + Tr *_Cptr = nullptr; + int _ldc = 0; + int _C_batch_stride = 0; + int _C_multi_stride = 0; + const Tr *_bias = nullptr; + int _bias_multi_stride = 0; + +public: + /* Pass in the pointers to the arrays to be operated on and their + * strides (templated version with appropriate types). */ + virtual void set_arrays(const To *A, const int lda, const int A_batch_stride, const int A_multi_stride, + const To *B, const int ldb, /* batches share B */ const int B_multi_stride, + Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride, + const Tr *bias, /* no row or batch stride needed */ const int bias_multi_stride) + { + _Aptr = A; + _lda = lda; + _A_batch_stride = A_batch_stride; + _A_multi_stride = A_multi_stride; + _Bptr = B; + _ldb = ldb; + _B_multi_stride = B_multi_stride; + _Cptr = C; + _ldc = ldc; + _C_batch_stride = C_batch_stride; + _C_multi_stride = C_multi_stride; + _bias = bias; + _bias_multi_stride = bias_multi_stride; + } + + /* Implementation of the void * overload which casts its arguments to the appropriate type. */ + void set_arrays_generic(const void *A, const int lda, const int A_batch_stride, const int A_multi_stride, + const void *B, const int ldb, /* batches share B */ const int B_multi_stride, + void *C, const int ldc, const int C_batch_stride, const int C_multi_stride, + const void *bias, /* no row or batch stride needed */ const int bias_multi_stride) override + { + set_arrays(static_cast(A), lda, A_batch_stride, A_multi_stride, + static_cast(B), ldb, B_multi_stride, + static_cast(C), ldc, C_batch_stride, C_multi_stride, + static_cast(bias), bias_multi_stride); + } + + /*** "Pretransposed" interface ***/ + + /* Perform pretranspose - the void * passed in must remain allocated for the duration of any execute calls. */ + /* Arguments are: output buffer pointer, source pointer, source row stride, source multi stride */ + virtual void pretranspose_B_array(void *, const To *, const int, const int) {}; + + /* Implementation of the void * overload which casts its arguments to the appropriate type. */ + void pretranspose_B_array_generic(void *out, const void *in, const int row_stride, const int multi_stride) override + { + pretranspose_B_array(out, static_cast(in), row_stride, multi_stride); + } + + /*** Indirect interface ***/ + virtual void set_indirect_parameters(size_t, const To *const *const *) + { + } + + void set_indirect_parameters_generic(size_t sz, const void *const *const *ptr) override + { + set_indirect_parameters(sz, reinterpret_cast(ptr)); + } +}; + +} // namespace arm_gemm diff --git a/src/cpu/kernels/assembly/ndrange.hpp b/src/cpu/kernels/assembly/ndrange.hpp new file mode 100644 index 0000000000..1c8261aef7 --- /dev/null +++ b/src/cpu/kernels/assembly/ndrange.hpp @@ -0,0 +1,199 @@ +/* + * Copyright (c) 2019-2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#include +#include +#include +#include + +namespace arm_gemm +{ +template +class NDRange +{ +private: + std::array m_sizes{}; + std::array m_totalsizes{}; + + class NDRangeIterator + { + private: + const NDRange &m_parent; + unsigned int m_pos = 0; + unsigned int m_end = 0; + + public: + NDRangeIterator(const NDRange &p, unsigned int s, unsigned int e) + : m_parent(p), m_pos(s), m_end(e) + { + } + + bool done() const + { + return (m_pos >= m_end); + } + + unsigned int dim(unsigned int d) const + { + unsigned int r = m_pos; + + if(d < (D - 1)) + { + r %= m_parent.m_totalsizes[d]; + } + + if(d > 0) + { + r /= m_parent.m_totalsizes[d - 1]; + } + + return r; + } + + bool next_dim0() + { + m_pos++; + + return !done(); + } + + bool next_dim1() + { + m_pos += m_parent.m_sizes[0] - dim(0); + + return !done(); + } + + unsigned int dim0_max() const + { + unsigned int offset = std::min(m_end - m_pos, m_parent.m_sizes[0] - dim(0)); + + return dim(0) + offset; + } + }; + + void set_totalsizes() + { + unsigned int t = 1; + + for(unsigned int i = 0; i < D; i++) + { + if(m_sizes[i] == 0) + { + m_sizes[i] = 1; + } + + t *= m_sizes[i]; + + m_totalsizes[i] = t; + } + } + +public: + NDRange &operator=(const NDRange &rhs) = default; + NDRange(const NDRange &rhs) = default; + + template + NDRange(T... ts) + : m_sizes{ ts... } + { + set_totalsizes(); + } + + NDRange(const std::array &n) + : m_sizes(n) + { + set_totalsizes(); + } + + NDRangeIterator iterator(unsigned int start, unsigned int end) const + { + return NDRangeIterator(*this, start, end); + } + + unsigned int total_size() const + { + return m_totalsizes[D - 1]; + } + + unsigned int get_size(unsigned int v) const + { + return m_sizes[v]; + } +}; + +/** NDCoordinate builds upon a range, but specifies a starting position + * in addition to a size which it inherits from NDRange + */ +template +class NDCoordinate : public NDRange +{ + using int_t = unsigned int; + using ndrange_t = NDRange; + + std::array m_positions{}; + +public: + NDCoordinate &operator=(const NDCoordinate &rhs) = default; + NDCoordinate(const NDCoordinate &rhs) = default; + NDCoordinate(const std::initializer_list> &list) + { + std::array sizes{}; + + std::size_t i = 0; + for(auto &p : list) + { + m_positions[i] = p.first; + sizes[i++] = p.second; + } + + //update the parents sizes + static_cast(*this) = ndrange_t(sizes); + } + + int_t get_position(int_t d) const + { + assert(d < N); + + return m_positions[d]; + } + + void set_position(int_t d, int_t v) + { + assert(d < N); + + m_positions[d] = v; + } + + int_t get_position_end(int_t d) const + { + return get_position(d) + ndrange_t::get_size(d); + } +}; //class NDCoordinate + +using ndrange_t = NDRange<6>; +using ndcoord_t = NDCoordinate<6>; + +} // namespace arm_gemm -- cgit v1.2.1