aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/runtime/NEON/AssemblyHelper.h
diff options
context:
space:
mode:
authorMichalis Spyrou <michalis.spyrou@arm.com>2018-04-23 15:17:31 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:51:37 +0000
commite250389ed6d78153a55382fa5b3519c151bfd79f (patch)
tree80c63793769ad18fd0406e7f8b40840aed7ac3ce /arm_compute/runtime/NEON/AssemblyHelper.h
parent79ffadebd8dff7eaecbcfa3a28106736f240f1c5 (diff)
downloadComputeLibrary-e250389ed6d78153a55382fa5b3519c151bfd79f.tar.gz
COMPMID-810 Add NHWC data format support for NEON convolution
Change-Id: I2a7b49a12da7f3bc3f04749243b1dc111160de6e Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/129348 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'arm_compute/runtime/NEON/AssemblyHelper.h')
-rw-r--r--arm_compute/runtime/NEON/AssemblyHelper.h9
1 files changed, 7 insertions, 2 deletions
diff --git a/arm_compute/runtime/NEON/AssemblyHelper.h b/arm_compute/runtime/NEON/AssemblyHelper.h
index 3db419e148..ecaf35ac3e 100644
--- a/arm_compute/runtime/NEON/AssemblyHelper.h
+++ b/arm_compute/runtime/NEON/AssemblyHelper.h
@@ -84,7 +84,12 @@ public:
const int ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
const int ldd = _d->info()->strides_in_bytes().y() / sizeof(TypeOutput);
- const int batch_stride_a = _a->info()->strides_in_bytes().z() / sizeof(TypeInput);
+ // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is
+ // the relevant multiple of the row stride.
+ const bool is_nhwc = _a->info()->data_layout() == DataLayout::NHWC;
+ const int stride_in_bytes_a = is_nhwc ? _a->info()->strides_in_bytes().y() * _d->info()->dimension(1) : _a->info()->strides_in_bytes().z();
+
+ const int batch_stride_a = stride_in_bytes_a / sizeof(TypeInput);
const int batch_stride_d = _d->info()->strides_in_bytes().z() / sizeof(TypeOutput);
const int multi_stride_a = _a->info()->strides_in_bytes()[3] / sizeof(TypeInput);
@@ -158,7 +163,7 @@ inline bool setup_assembly_kernel(const ITensor *a, const ITensor *b, ITensor *d
const int M = d->info()->tensor_shape().y();
const int N = d->info()->tensor_shape().x();
const int K = a->info()->tensor_shape().x();
- const int batches = a->info()->tensor_shape().total_size_upper(2);
+ const int batches = d->info()->tensor_shape().total_size_upper(2);
const int multis = b->info()->tensor_shape().z();
unsigned int num_threads = NEScheduler::get().num_threads();