aboutsummaryrefslogtreecommitdiff
path: root/utils/Utils.h
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2018-04-27 19:07:19 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:51:17 +0000
commitcac13b1cfd593889271f8e2191be2039b8d88f36 (patch)
treed1c5196877d7fbd5dcfbb9f9003faf6035f82a33 /utils/Utils.h
parentad0c7388f6261989a268ffb2d042f2bd80736e3f (diff)
downloadComputeLibrary-cac13b1cfd593889271f8e2191be2039b8d88f36.tar.gz
COMPMID-1097: Port mobilenet to NHWC
Change-Id: I789065bfa0d4ef133388e1904c5caf31e450f80f Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/129495 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'utils/Utils.h')
-rw-r--r--utils/Utils.h31
1 files changed, 25 insertions, 6 deletions
diff --git a/utils/Utils.h b/utils/Utils.h
index 6241562a28..cadba3a088 100644
--- a/utils/Utils.h
+++ b/utils/Utils.h
@@ -406,7 +406,14 @@ public:
{
ARM_COMPUTE_ERROR_ON(!is_open());
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&tensor, 1, DataType::U8, DataType::F32);
- ARM_COMPUTE_ERROR_ON(tensor.info()->dimension(0) != _width || tensor.info()->dimension(1) != _height || tensor.info()->dimension(2) != 3);
+
+ const DataLayout data_layout = tensor.info()->data_layout();
+ const TensorShape tensor_shape = tensor.info()->tensor_shape();
+
+ ARM_COMPUTE_UNUSED(tensor_shape);
+ ARM_COMPUTE_ERROR_ON(tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)] != _width);
+ ARM_COMPUTE_ERROR_ON(tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)] != _height);
+ ARM_COMPUTE_ERROR_ON(tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)] != 3);
try
{
@@ -423,11 +430,25 @@ public:
"Not enough data in file");
ARM_COMPUTE_UNUSED(end_position);
+ // Stride across channels
+ size_t stride_z = 0;
+
// Iterate through every pixel of the image
arm_compute::Window window;
- window.set(arm_compute::Window::DimX, arm_compute::Window::Dimension(0, _width, 1));
- window.set(arm_compute::Window::DimY, arm_compute::Window::Dimension(0, _height, 1));
- window.set(arm_compute::Window::DimZ, arm_compute::Window::Dimension(0, 1, 1));
+ if(data_layout == DataLayout::NCHW)
+ {
+ window.set(arm_compute::Window::DimX, arm_compute::Window::Dimension(0, _width, 1));
+ window.set(arm_compute::Window::DimY, arm_compute::Window::Dimension(0, _height, 1));
+ window.set(arm_compute::Window::DimZ, arm_compute::Window::Dimension(0, 1, 1));
+ stride_z = tensor.info()->strides_in_bytes()[2];
+ }
+ else
+ {
+ window.set(arm_compute::Window::DimX, arm_compute::Window::Dimension(0, 1, 1));
+ window.set(arm_compute::Window::DimY, arm_compute::Window::Dimension(0, _width, 1));
+ window.set(arm_compute::Window::DimZ, arm_compute::Window::Dimension(0, _height, 1));
+ stride_z = tensor.info()->strides_in_bytes()[0];
+ }
arm_compute::Iterator out(&tensor, window);
@@ -435,8 +456,6 @@ public:
unsigned char green = 0;
unsigned char blue = 0;
- size_t stride_z = tensor.info()->strides_in_bytes()[2];
-
arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates & id)
{
red = _fs.get();