aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp')
-rw-r--r--src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp10
1 files changed, 8 insertions, 2 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
index 92c911c370..da5ac22fdc 100644
--- a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp
@@ -133,7 +133,10 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b,
if(_a_offset != 0)
{
TensorShape shape_vector_sum_col = b->info()->tensor_shape();
- shape_vector_sum_col.remove_dimension(1);
+ if(b->info()->num_dimensions() > 1)
+ {
+ shape_vector_sum_col.remove_dimension(1);
+ }
TensorInfo info_vector_sum_col(shape_vector_sum_col, 1, DataType::S32);
_vector_sum_col.allocator()->init(info_vector_sum_col);
_memory_group.manage(&_vector_sum_col);
@@ -147,7 +150,10 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b,
{
TensorShape shape_vector_sum_row = a->info()->tensor_shape();
shape_vector_sum_row.set(Window::DimX, a->info()->dimension(1));
- shape_vector_sum_row.remove_dimension(1);
+ if(a->info()->num_dimensions() > 1)
+ {
+ shape_vector_sum_row.remove_dimension(1);
+ }
TensorInfo info_vector_sum_row(shape_vector_sum_row, 1, DataType::S32);
_vector_sum_row.allocator()->init(info_vector_sum_row);
_memory_group.manage(&_vector_sum_row);