diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2018-04-27 19:07:19 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:51:17 +0000 |
commit | cac13b1cfd593889271f8e2191be2039b8d88f36 (patch) | |
tree | d1c5196877d7fbd5dcfbb9f9003faf6035f82a33 /utils/Utils.h | |
parent | ad0c7388f6261989a268ffb2d042f2bd80736e3f (diff) | |
download | ComputeLibrary-cac13b1cfd593889271f8e2191be2039b8d88f36.tar.gz |
COMPMID-1097: Port mobilenet to NHWC
Change-Id: I789065bfa0d4ef133388e1904c5caf31e450f80f
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/129495
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'utils/Utils.h')
-rw-r--r-- | utils/Utils.h | 31 |
1 files changed, 25 insertions, 6 deletions
diff --git a/utils/Utils.h b/utils/Utils.h index 6241562a28..cadba3a088 100644 --- a/utils/Utils.h +++ b/utils/Utils.h @@ -406,7 +406,14 @@ public: { ARM_COMPUTE_ERROR_ON(!is_open()); ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&tensor, 1, DataType::U8, DataType::F32); - ARM_COMPUTE_ERROR_ON(tensor.info()->dimension(0) != _width || tensor.info()->dimension(1) != _height || tensor.info()->dimension(2) != 3); + + const DataLayout data_layout = tensor.info()->data_layout(); + const TensorShape tensor_shape = tensor.info()->tensor_shape(); + + ARM_COMPUTE_UNUSED(tensor_shape); + ARM_COMPUTE_ERROR_ON(tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)] != _width); + ARM_COMPUTE_ERROR_ON(tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)] != _height); + ARM_COMPUTE_ERROR_ON(tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)] != 3); try { @@ -423,11 +430,25 @@ public: "Not enough data in file"); ARM_COMPUTE_UNUSED(end_position); + // Stride across channels + size_t stride_z = 0; + // Iterate through every pixel of the image arm_compute::Window window; - window.set(arm_compute::Window::DimX, arm_compute::Window::Dimension(0, _width, 1)); - window.set(arm_compute::Window::DimY, arm_compute::Window::Dimension(0, _height, 1)); - window.set(arm_compute::Window::DimZ, arm_compute::Window::Dimension(0, 1, 1)); + if(data_layout == DataLayout::NCHW) + { + window.set(arm_compute::Window::DimX, arm_compute::Window::Dimension(0, _width, 1)); + window.set(arm_compute::Window::DimY, arm_compute::Window::Dimension(0, _height, 1)); + window.set(arm_compute::Window::DimZ, arm_compute::Window::Dimension(0, 1, 1)); + stride_z = tensor.info()->strides_in_bytes()[2]; + } + else + { + window.set(arm_compute::Window::DimX, arm_compute::Window::Dimension(0, 1, 1)); + window.set(arm_compute::Window::DimY, arm_compute::Window::Dimension(0, _width, 1)); + window.set(arm_compute::Window::DimZ, arm_compute::Window::Dimension(0, _height, 1)); + stride_z = tensor.info()->strides_in_bytes()[0]; + } arm_compute::Iterator out(&tensor, window); @@ -435,8 +456,6 @@ public: unsigned char green = 0; unsigned char blue = 0; - size_t stride_z = tensor.info()->strides_in_bytes()[2]; - arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates & id) { red = _fs.get(); |