diff options
Diffstat (limited to 'arm_compute/runtime/NEON/AssemblyHelper.h')
-rw-r--r-- | arm_compute/runtime/NEON/AssemblyHelper.h | 9 |
1 files changed, 7 insertions, 2 deletions
diff --git a/arm_compute/runtime/NEON/AssemblyHelper.h b/arm_compute/runtime/NEON/AssemblyHelper.h index 3db419e148..ecaf35ac3e 100644 --- a/arm_compute/runtime/NEON/AssemblyHelper.h +++ b/arm_compute/runtime/NEON/AssemblyHelper.h @@ -84,7 +84,12 @@ public: const int ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput); const int ldd = _d->info()->strides_in_bytes().y() / sizeof(TypeOutput); - const int batch_stride_a = _a->info()->strides_in_bytes().z() / sizeof(TypeInput); + // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is + // the relevant multiple of the row stride. + const bool is_nhwc = _a->info()->data_layout() == DataLayout::NHWC; + const int stride_in_bytes_a = is_nhwc ? _a->info()->strides_in_bytes().y() * _d->info()->dimension(1) : _a->info()->strides_in_bytes().z(); + + const int batch_stride_a = stride_in_bytes_a / sizeof(TypeInput); const int batch_stride_d = _d->info()->strides_in_bytes().z() / sizeof(TypeOutput); const int multi_stride_a = _a->info()->strides_in_bytes()[3] / sizeof(TypeInput); @@ -158,7 +163,7 @@ inline bool setup_assembly_kernel(const ITensor *a, const ITensor *b, ITensor *d const int M = d->info()->tensor_shape().y(); const int N = d->info()->tensor_shape().x(); const int K = a->info()->tensor_shape().x(); - const int batches = a->info()->tensor_shape().total_size_upper(2); + const int batches = d->info()->tensor_shape().total_size_upper(2); const int multis = b->info()->tensor_shape().z(); unsigned int num_threads = NEScheduler::get().num_threads(); |