aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/NEON/kernels/assembly/gemm_common.hpp')
-rw-r--r--arm_compute/core/NEON/kernels/assembly/gemm_common.hpp40
1 files changed, 30 insertions, 10 deletions
diff --git a/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp b/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp
index d17fd5fe97..ea9b524e15 100644
--- a/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp
+++ b/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 Arm Limited.
+ * Copyright (c) 2017-2020 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -23,7 +23,10 @@
*/
#pragma once
+#include "arm_compute/core/NEON/kernels/assembly/arm_gemm_compute_iface.hpp"
+
#include <cstddef>
+#include <cassert>
#define UNUSED(x) (void)(x)
@@ -51,10 +54,10 @@ public:
void *C, const int ldc, const int C_batch_stride, const int C_multi_stride,
const void *bias, /* no row or batch stride needed */ const int bias_multi_stride) = 0;
- /* For threading, we divide the work into some number of units and work
- * out internally what unit corresponds to what work. This returns the
- * total number of units. */
- virtual unsigned int get_window_size() const = 0;
+ /** @returns an ndrange containing ranges of the compute space which can be
+ * broken up and parallelised over
+ */
+ virtual ndrange_t get_window_size() const = 0;
/* The maximum thread count is specified when the GEMM is created. Some
* implementations need to know how many threads will actually run in
@@ -73,9 +76,12 @@ public:
/* Whether this GEMM can be dynamically scheduled or not. */
virtual bool supports_dynamic_scheduling() const { return false; }
- /* Actually do the work. Provide a threadid to index any per-thread
- * buffers, and a start/end range to indicate which work to do. */
- virtual void execute(unsigned int, unsigned int, int) = 0;
+ /** Main execute member function
+ * @param [in] work_range specifies the range of work we want to be computed, total range defined by get_window_size()
+ * @param [in] thread_locator where we are inside of the thread space
+ * @param [in] threadid a unique threadid
+ */
+ virtual void execute(const ndcoord_t& work_range, const ndcoord_t& thread_locator, int threadid) = 0;
/*** Working space interface (optional) ***/
/* Total number of bytes of temporary working space needed. If zero, it's not necessary to call set_working_space(). */
@@ -108,8 +114,7 @@ public:
virtual ~IGemmCommon() { }
};
-/*
- * "Real" GemmCommon class which is templated on the operand and return types.
+/* "Real" GemmCommon class which is templated on the operand and return types.
*
* In addition to correctly typed versions of the functions that operate on
* operand and return data, this class provides a default implementation of
@@ -178,4 +183,19 @@ public:
}
};
+/** Compute the total size of a kernel's window.
+ *
+ * Multiplies together the value returned by get_size(d) on the kernel's
+ * window for every dimension index d in [0, arm_gemm::ndrange_max).
+ *
+ * @param [in] kernel object providing a get_window_size() member
+ * @returns the product of the per-dimension sizes, accumulated in an unsigned int
+ */
+template<typename GemmKernel>
+inline
+int unsigned get_total_window_size(const GemmKernel& kernel)
+{
+ auto extents = kernel.get_window_size();
+
+ unsigned int product = 1;
+ unsigned dim = 0;
+ while(dim != arm_gemm::ndrange_max)
+ {
+ product *= extents.get_size(dim);
+ ++dim;
+ }
+
+ return product;
+}
+
} // namespace arm_gemm