aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--arm_compute/core/Types.h4
-rw-r--r--src/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.cpp7
2 files changed, 7 insertions, 4 deletions
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h
index 7d632fec28..d46c93247c 100644
--- a/arm_compute/core/Types.h
+++ b/arm_compute/core/Types.h
@@ -1646,8 +1646,8 @@ class GEMMInfo
public:
/** Default constructor */
GEMMInfo()
- : _is_a_reshaped(false), _is_b_reshaped(false), _reshape_b_only_on_first_run(false), _depth_output_gemm3d(0), _reinterpret_input_as_3d(false), _retain_internal_weights(false),
- _gemmlowp_output_stage(), _fp_mixed_precision(false)
+ : _is_a_reshaped(false), _is_b_reshaped(false), _reshape_b_only_on_first_run(true), _depth_output_gemm3d(0), _reinterpret_input_as_3d(false), _retain_internal_weights(false), _gemmlowp_output_stage(),
+ _fp_mixed_precision(false)
{
}
/** Constructor
diff --git a/src/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.cpp b/src/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.cpp
index 6c201cedb3..41a031c1c7 100644
--- a/src/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.cpp
+++ b/src/core/NEON/kernels/assembly/NEGEMMInterleavedPrepareBWrapperKernel.cpp
@@ -68,7 +68,7 @@ void for_each_element_in_window(const Window &window, const ITensor *b, ITensor
// Calculate the size of transformed_b:
template <typename To, bool use_dot>
-unsigned int get_B_pretransposed_array_size(unsigned int N, unsigned int K, const BlockSizes &bs)
+unsigned int get_B_pretransposed_array_size(unsigned int N, unsigned int K, const BlockSizes &bs, unsigned int multis)
{
using strategy = typename Kernel<To, use_dot>::strategy;
@@ -89,6 +89,9 @@ unsigned int get_B_pretransposed_array_size(unsigned int N, unsigned int K, cons
// Calculate the total size of the buffer:
size_t total = num_full_k * normal_k_size * (num_full_x * normal_x_size + left_over_x_size);
total += left_over_k_size * (left_over_x_size + num_full_x * normal_x_size);
+
+ total *= multis;
+
return total;
}
@@ -114,7 +117,7 @@ void NEGEMMInterleavedPrepareBWrapperKernelTemplate<To, use_dot>::configure(cons
_block_sizes = calculate_block_sizes<strategy>(ci, params.M, params.N, params.K);
- auto_init_if_empty(*transformed_b->info(), b->info()->clone()->set_tensor_shape(TensorShape{ get_B_pretransposed_array_size<To, use_dot>(_Nsize, _Ksize, _block_sizes) }));
+ auto_init_if_empty(*transformed_b->info(), b->info()->clone()->set_tensor_shape(TensorShape{ get_B_pretransposed_array_size<To, use_dot>(_Nsize, _Ksize, _block_sizes, multis) }));
Window window;
window.set(Window::DimX, Window::Dimension(0, ceil_to_multiple(_Nsize, _block_sizes.x_block), _block_sizes.x_block));