From cac13b1cfd593889271f8e2191be2039b8d88f36 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Fri, 27 Apr 2018 19:07:19 +0100 Subject: COMPMID-1097: Port mobilenet to NHWC Change-Id: I789065bfa0d4ef133388e1904c5caf31e450f80f Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/129495 Tested-by: Jenkins Reviewed-by: Anthony Barbier --- utils/Utils.h | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) (limited to 'utils/Utils.h') diff --git a/utils/Utils.h b/utils/Utils.h index 6241562a28..cadba3a088 100644 --- a/utils/Utils.h +++ b/utils/Utils.h @@ -406,7 +406,14 @@ public: { ARM_COMPUTE_ERROR_ON(!is_open()); ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&tensor, 1, DataType::U8, DataType::F32); - ARM_COMPUTE_ERROR_ON(tensor.info()->dimension(0) != _width || tensor.info()->dimension(1) != _height || tensor.info()->dimension(2) != 3); + + const DataLayout data_layout = tensor.info()->data_layout(); + const TensorShape tensor_shape = tensor.info()->tensor_shape(); + + ARM_COMPUTE_UNUSED(tensor_shape); + ARM_COMPUTE_ERROR_ON(tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)] != _width); + ARM_COMPUTE_ERROR_ON(tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)] != _height); + ARM_COMPUTE_ERROR_ON(tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)] != 3); try { @@ -423,11 +430,25 @@ public: "Not enough data in file"); ARM_COMPUTE_UNUSED(end_position); + // Stride across channels + size_t stride_z = 0; + // Iterate through every pixel of the image arm_compute::Window window; - window.set(arm_compute::Window::DimX, arm_compute::Window::Dimension(0, _width, 1)); - window.set(arm_compute::Window::DimY, arm_compute::Window::Dimension(0, _height, 1)); - window.set(arm_compute::Window::DimZ, arm_compute::Window::Dimension(0, 1, 1)); + if(data_layout == DataLayout::NCHW) + { + window.set(arm_compute::Window::DimX, arm_compute::Window::Dimension(0, _width, 1)); + window.set(arm_compute::Window::DimY, arm_compute::Window::Dimension(0, _height, 1)); + window.set(arm_compute::Window::DimZ, arm_compute::Window::Dimension(0, 1, 1)); + stride_z = tensor.info()->strides_in_bytes()[2]; + } + else + { + window.set(arm_compute::Window::DimX, arm_compute::Window::Dimension(0, 1, 1)); + window.set(arm_compute::Window::DimY, arm_compute::Window::Dimension(0, _width, 1)); + window.set(arm_compute::Window::DimZ, arm_compute::Window::Dimension(0, _height, 1)); + stride_z = tensor.info()->strides_in_bytes()[0]; + } arm_compute::Iterator out(&tensor, window); @@ -435,8 +456,6 @@ public: unsigned char green = 0; unsigned char blue = 0; - size_t stride_z = tensor.info()->strides_in_bytes()[2]; - arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates & id) { red = _fs.get(); -- cgit v1.2.1