diff options
Diffstat (limited to 'arm_compute/core/NEON/kernels/assembly/gemm_common.hpp')
-rw-r--r-- | arm_compute/core/NEON/kernels/assembly/gemm_common.hpp | 40 |
1 files changed, 30 insertions, 10 deletions
diff --git a/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp b/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp index d17fd5fe97..ea9b524e15 100644 --- a/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp +++ b/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 Arm Limited. + * Copyright (c) 2017-2020 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -23,7 +23,10 @@ */ #pragma once +#include "arm_compute/core/NEON/kernels/assembly/arm_gemm_compute_iface.hpp" + #include <cstddef> +#include <cassert> #define UNUSED(x) (void)(x) @@ -51,10 +54,10 @@ public: void *C, const int ldc, const int C_batch_stride, const int C_multi_stride, const void *bias, /* no row or batch stride needed */ const int bias_multi_stride) = 0; - /* For threading, we divide the work into some number of units and work - * out internally what unit corresponds to what work. This returns the - * total number of units. */ - virtual unsigned int get_window_size() const = 0; + /** @returns an ndrange containing ranges of the compute space which can be + * broken up and parallelised over + */ + virtual ndrange_t get_window_size() const = 0; /* The maximum thread count is specified when the GEMM is created. Some * implementations need to know how many threads will actually run in @@ -73,9 +76,12 @@ public: /* Whether this GEMM can be dynamically scheduled or not. */ virtual bool supports_dynamic_scheduling() const { return false; } - /* Actually do the work. Provide a threadid to index any per-thread - * buffers, and a start/end range to indicate which work to do. 
*/ - virtual void execute(unsigned int, unsigned int, int) = 0; + /** Main execute member function + * @param [in] work_range specifies the range of work we want to be computed, total range defined by get_window_size() + * @param [in] thread_locator where are we inside of the thread space + * @param [in] threadid a unique threadid + */ + virtual void execute(const ndcoord_t& work_range, const ndcoord_t& thread_locator, int threadid) = 0; /*** Working space interface (optional) ***/ /* Total number of bytes of temporary working space needed. If zero, it's not necessary to call set_working_space(). */ @@ -108,8 +114,7 @@ public: virtual ~IGemmCommon() { } }; -/* - * "Real" GemmCommon class which is templated on the operand and return types. +/* "Real" GemmCommon class which is templated on the operand and return types. * * In addition to correctly typed versions of the functions that operate on * operand and return data, this class provides a default implementation of @@ -178,4 +183,19 @@ public: } }; +template<typename GemmKernel> +inline +int unsigned get_total_window_size(const GemmKernel& kernel) +{ + auto window=kernel.get_window_size(); + + unsigned int total = 1; + for(unsigned i = 0; i != arm_gemm::ndrange_max; ++i) + { + total *= window.get_size(i); + } + + return total; +} + } // namespace arm_gemm |