aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r-- src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h 9
-rw-r--r-- src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp 3
2 files changed, 11 insertions, 1 deletions
diff --git a/src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h b/src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h
index da6ef2dea9..26d9e9999d 100644
--- a/src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h
+++ b/src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h
@@ -52,6 +52,11 @@ class IInterleavedStrategy
public:
/** Virtual Destructor */
virtual ~IInterleavedStrategy() = default;
+ /** Return output height of the interleaved strategy
+ *
+ * NEGEMMInterleavedWrapper::configure() passes this value as the divisor to
+ * iceildiv(M, out_height()) when computing the number of windows.
+ * NOTE(review): presumably this must be non-zero because it is used as a divisor - confirm.
+ *
+ * @return Output height of strategy
+ */
+ virtual unsigned int out_height() const = 0;
/** Instantiate and configure a prepareB Kernel
*
* @param[in] b Input tensor B.
@@ -117,6 +122,10 @@ public:
public:
// Inherited methods overridden
+ unsigned int out_height() const override
+ {
+ // Delegate to strategy::out_height() and hand its result back unchanged.
+ const unsigned int height = strategy::out_height();
+ return height;
+ }
std::unique_ptr<NEGEMMInterleavedPrepareBWrapperKernel> instantiate_prepareB(const ITensor *b,
ITensor *transformed_b,
const INEGEMMWrapperKernel::Params &params,
diff --git a/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp b/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp
index 695fc859de..34aaea0ef1 100644
--- a/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp
+++ b/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp
@@ -207,7 +207,7 @@ void NEGEMMInterleavedWrapper::prepare()
// Upper bound on the number of workloads to create: the larger of the scheduler's thread count and _num_windows.
const unsigned int num_threads = NEScheduler::get().num_threads();
- const unsigned int max_iterations = num_threads == 1 ? 1 : num_threads;
+ const unsigned int max_iterations = std::max(num_threads, _num_windows);
// Maximum number of iterations the parameters allow (total iterations of the batch window):
const unsigned int num_iterations = _batch_window.num_iterations_total();
// Keep the smallest of the two:
@@ -357,6 +357,7 @@ void NEGEMMInterleavedWrapper::configure(const ITensor *a, const ITensor *b, ITe
// Get strategy
std::unique_ptr<detail::IInterleavedStrategy> strategy = detail::create_strategy(gemm_kernel_info.name);
// Check the strategy before its first dereference; the check used to come after strategy->out_height().
ARM_COMPUTE_ERROR_ON(strategy == nullptr);
+ // Number of windows: iceildiv(M, out_height()) per batch, multiplied by the number of batches.
+ _num_windows = iceildiv(_params.M, strategy->out_height()) * _params.batches;
if(!_pretranspose_b)