aboutsummaryrefslogtreecommitdiff
path: root/utils/Utils.h
diff options
context:
space:
mode:
Diffstat (limited to 'utils/Utils.h')
-rw-r--r--utils/Utils.h31
1 files changed, 25 insertions, 6 deletions
diff --git a/utils/Utils.h b/utils/Utils.h
index 6241562a28..cadba3a088 100644
--- a/utils/Utils.h
+++ b/utils/Utils.h
@@ -406,7 +406,14 @@ public:
{
ARM_COMPUTE_ERROR_ON(!is_open());
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&tensor, 1, DataType::U8, DataType::F32);
- ARM_COMPUTE_ERROR_ON(tensor.info()->dimension(0) != _width || tensor.info()->dimension(1) != _height || tensor.info()->dimension(2) != 3);
+
+ const DataLayout data_layout = tensor.info()->data_layout();
+ const TensorShape tensor_shape = tensor.info()->tensor_shape();
+
+ ARM_COMPUTE_UNUSED(tensor_shape);
+ ARM_COMPUTE_ERROR_ON(tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)] != _width);
+ ARM_COMPUTE_ERROR_ON(tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)] != _height);
+ ARM_COMPUTE_ERROR_ON(tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)] != 3);
try
{
@@ -423,11 +430,25 @@ public:
"Not enough data in file");
ARM_COMPUTE_UNUSED(end_position);
+ // Stride across channels
+ size_t stride_z = 0;
+
// Iterate through every pixel of the image
arm_compute::Window window;
- window.set(arm_compute::Window::DimX, arm_compute::Window::Dimension(0, _width, 1));
- window.set(arm_compute::Window::DimY, arm_compute::Window::Dimension(0, _height, 1));
- window.set(arm_compute::Window::DimZ, arm_compute::Window::Dimension(0, 1, 1));
+ if(data_layout == DataLayout::NCHW)
+ {
+ window.set(arm_compute::Window::DimX, arm_compute::Window::Dimension(0, _width, 1));
+ window.set(arm_compute::Window::DimY, arm_compute::Window::Dimension(0, _height, 1));
+ window.set(arm_compute::Window::DimZ, arm_compute::Window::Dimension(0, 1, 1));
+ stride_z = tensor.info()->strides_in_bytes()[2];
+ }
+ else
+ {
+ window.set(arm_compute::Window::DimX, arm_compute::Window::Dimension(0, 1, 1));
+ window.set(arm_compute::Window::DimY, arm_compute::Window::Dimension(0, _width, 1));
+ window.set(arm_compute::Window::DimZ, arm_compute::Window::Dimension(0, _height, 1));
+ stride_z = tensor.info()->strides_in_bytes()[0];
+ }
arm_compute::Iterator out(&tensor, window);
@@ -435,8 +456,6 @@ public:
unsigned char green = 0;
unsigned char blue = 0;
- size_t stride_z = tensor.info()->strides_in_bytes()[2];
-
arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates & id)
{
red = _fs.get();