Diffstat (limited to 'arm_compute/runtime/NEON/AssemblyHelper.h')
-rw-r--r--  arm_compute/runtime/NEON/AssemblyHelper.h | 52
1 file changed, 37 insertions(+), 15 deletions(-)
diff --git a/arm_compute/runtime/NEON/AssemblyHelper.h b/arm_compute/runtime/NEON/AssemblyHelper.h
index 3aa43ec96e..c4ba1a584e 100644
--- a/arm_compute/runtime/NEON/AssemblyHelper.h
+++ b/arm_compute/runtime/NEON/AssemblyHelper.h
@@ -51,7 +51,7 @@ public:
using TypeResult = TypeOutput;
/** Default constructor. */
AssemblyKernelGlue()
- : _gemm_kernel_asm(nullptr), _optimised_kernel(nullptr), _a(nullptr), _b(nullptr), _d(nullptr), _workspace(nullptr), _pretranspose(nullptr)
+ : _gemm_kernel_asm(nullptr), _optimised_kernel(nullptr), _a(nullptr), _b(nullptr), _d(nullptr), _workspace(nullptr), _pretranspose(nullptr), _is_prepared(false)
{
}
/** Assembly Gemm */
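
Editor's note: the _is_prepared member initialised above backs a run-once preparation step. As a minimal sketch of that pattern (illustrative names only, not the library's code), such a flag separates expensive one-off work from the per-run path:

    struct PreparedOnce
    {
        bool _is_prepared{ false };

        void prepare()
        {
            // One-off work, e.g. pre-transposing a constant input matrix.
            _is_prepared = true;
        }

        void run()
        {
            if(!_is_prepared)
            {
                prepare(); // pay the preparation cost only once
            }
            // Per-run work goes here.
        }
    };

Callers that know their weights are constant can also invoke prepare() eagerly, ahead of the first run().
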
@@ -76,6 +76,31 @@ public:
ITensor *_workspace;
/** Pre-transpose tensor */
ITensor *_pretranspose;
+ /** Prepared flag */
+ bool _is_prepared;
+
+ /** Runs a preparation step, usually for pre-transposing matrix b */
+ void prepare()
+ {
+ // Pretranspose B if required
+ if(_gemm_kernel_asm->B_pretranspose_required())
+ {
+ const int ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+ const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer());
+ const int multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput);
+
+ // Forcing 128-byte alignment (required by 32-bit kernels)
+ const unsigned int alignment = 128;
+ void *raw_ptr = reinterpret_cast<void *>(_pretranspose->buffer());
+ size_t space = _pretranspose->info()->total_size();
+ void *aligned_ptr = support::cpp11::align(alignment, _gemm_kernel_asm->get_B_pretransposed_array_size(), raw_ptr, space);
+ ARM_COMPUTE_ERROR_ON(_pretranspose == nullptr || _pretranspose->buffer() == nullptr);
+ _gemm_kernel_asm->pretranspose_B_array(aligned_ptr, in1_ptr, ldb, multi_stride_b);
+ _b->mark_as_unused();
+ }
+
+ _is_prepared = true;
+ }
/** Configures the array pointers and strides in the assembly kernel and executes the assembly kernel.
* The call to set_arrays is needed to deal with input sizes containing batches (dims > 2)
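
Editor's note: support::cpp11::align used in prepare() above is Arm Compute Library's backport of C++11 std::align; it bumps a raw pointer up to the next 128-byte boundary inside the pre-transpose buffer. A self-contained sketch of the same operation using the standard std::align (buffer sizes here are made up for illustration):

    #include <cstddef>
    #include <cstdint>
    #include <memory>
    #include <vector>

    int main()
    {
        // Over-allocate so a 128-byte-aligned region of the requested size
        // is guaranteed to fit somewhere inside the buffer.
        const std::size_t alignment = 128;
        const std::size_t size      = 1024;
        std::vector<std::uint8_t> buffer(size + alignment);

        void       *ptr   = buffer.data();
        std::size_t space = buffer.size();

        // std::align advances ptr to the first 128-byte boundary and shrinks
        // space accordingly; it returns nullptr if the region does not fit.
        void *aligned_ptr = std::align(alignment, size, ptr, space);

        return aligned_ptr != nullptr ? 0 : 1;
    }

If the aligned region does not fit in the remaining space, std::align returns nullptr, which is why the backing buffer is over-allocated by the alignment amount.
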
@@ -102,28 +127,25 @@ public:
const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer());
auto out_ptr = reinterpret_cast<TypeOutput *>(_d->buffer());
- // Set workspace if needed
+ // Set workspace if needed and reset number of threads as buffer manager gets re-created with max_threads
if(_workspace != nullptr)
{
_gemm_kernel_asm->set_working_space(reinterpret_cast<void *>(_workspace->buffer()));
+ const unsigned int window_size = _gemm_kernel_asm->get_window_size();
+ unsigned int num_threads = NEScheduler::get().num_threads();
+ if(window_size < num_threads)
+ {
+ num_threads = window_size;
+ _gemm_kernel_asm->set_nthreads(num_threads);
+ }
}
+ // Prepare assembly kernel
+ prepare();
+
// Set gemm parameters
_gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a, in1_ptr, ldb, multi_stride_b, out_ptr, ldd, batch_stride_d, multi_stride_d);
- // Pretranspose B if required
- if(_gemm_kernel_asm->B_pretranspose_required())
- {
- // Forcing 128-byte alignment (required by 32-bit kernels)
- const unsigned int alignment = 128;
- void *raw_ptr = reinterpret_cast<void *>(_pretranspose->buffer());
- size_t space = _pretranspose->info()->total_size();
- void *aligned_ptr = support::cpp11::align(alignment, _gemm_kernel_asm->get_B_pretransposed_array_size(), raw_ptr, space);
- ARM_COMPUTE_ERROR_ON(_pretranspose == nullptr || _pretranspose->buffer() == nullptr);
- _gemm_kernel_asm->pretranspose_B_array(aligned_ptr, in1_ptr, ldb, multi_stride_b);
- _b->mark_as_unused();
- }
-
// Schedule assembly kernel
NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX);
}
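
Editor's note: the thread handling added to run() boils down to clamping the scheduler's thread count to the kernel's window size, so that no thread is scheduled with an empty slice of the window. A minimal sketch of that clamp (the helper name is illustrative, not the library's API):

    #include <algorithm>

    // Clamp the number of worker threads to the number of schedulable work
    // items: with window_size = 4 on an 8-thread scheduler, only 4 threads
    // are used, and the reduced count is what set_nthreads() receives.
    unsigned int effective_threads(unsigned int window_size, unsigned int num_threads)
    {
        return std::min(window_size, num_threads);
    }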