diff options
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp')
-rw-r--r-- | src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp index 92c911c370..da5ac22fdc 100644 --- a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp +++ b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp @@ -133,7 +133,10 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b, if(_a_offset != 0) { TensorShape shape_vector_sum_col = b->info()->tensor_shape(); - shape_vector_sum_col.remove_dimension(1); + if(b->info()->num_dimensions() > 1) + { + shape_vector_sum_col.remove_dimension(1); + } TensorInfo info_vector_sum_col(shape_vector_sum_col, 1, DataType::S32); _vector_sum_col.allocator()->init(info_vector_sum_col); _memory_group.manage(&_vector_sum_col); @@ -147,7 +150,10 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b, { TensorShape shape_vector_sum_row = a->info()->tensor_shape(); shape_vector_sum_row.set(Window::DimX, a->info()->dimension(1)); - shape_vector_sum_row.remove_dimension(1); + if(a->info()->num_dimensions() > 1) + { + shape_vector_sum_row.remove_dimension(1); + } TensorInfo info_vector_sum_row(shape_vector_sum_row, 1, DataType::S32); _vector_sum_row.allocator()->init(info_vector_sum_row); _memory_group.manage(&_vector_sum_row); |