From 39412954c72ca1cca20153e88fb8bfbac8b1dd15 Mon Sep 17 00:00:00 2001 From: Michalis Spyrou Date: Tue, 14 Aug 2018 17:06:16 +0100 Subject: COMPMID-1506 NPY Loader doesn't work for NHWC pipelines Change-Id: I696fcded606e82a91526a9471f16fa2d1226ff4f Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/144144 Tested-by: Jenkins Reviewed-by: Georgios Pinitas --- utils/Utils.cpp | 5 +---- utils/Utils.h | 70 ++++++++++++++++++++++----------------------------------- 2 files changed, 28 insertions(+), 47 deletions(-) (limited to 'utils') diff --git a/utils/Utils.cpp b/utils/Utils.cpp index 133248e30c..057d309b2e 100644 --- a/utils/Utils.cpp +++ b/utils/Utils.cpp @@ -223,10 +223,7 @@ std::tuple, bool, std::string> parse_npy_header(std:: std::string typestr; npy::parse_header(header, typestr, fortran_order, shape); - if(!fortran_order) - { - std::reverse(shape.begin(), shape.end()); - } + std::reverse(shape.begin(), shape.end()); return std::make_tuple(shape, fortran_order, typestr); } diff --git a/utils/Utils.h b/utils/Utils.h index adb0e54f54..0bbdcc25d1 100644 --- a/utils/Utils.h +++ b/utils/Utils.h @@ -393,21 +393,22 @@ public: } } } - // Validate tensor shape - ARM_COMPUTE_ERROR_ON_MSG(_shape.size() != tensor.info()->tensor_shape().num_dimensions(), "Tensor ranks mismatch"); - if(!_fortran_order) + + TensorShape permuted_shape = tensor.info()->tensor_shape(); + arm_compute::PermutationVector perm; + if(are_layouts_different && tensor.info()->tensor_shape().num_dimensions() > 2) { - for(size_t i = 0; i < _shape.size(); ++i) - { - ARM_COMPUTE_ERROR_ON_MSG(tensor.info()->tensor_shape()[i] != _shape[i], "Tensor dimensions mismatch"); - } + perm = (tensor.info()->data_layout() == arm_compute::DataLayout::NHWC) ? arm_compute::PermutationVector(2U, 0U, 1U) : arm_compute::PermutationVector(1U, 2U, 0U); + arm_compute::PermutationVector perm_vec = (tensor.info()->data_layout() == arm_compute::DataLayout::NCHW) ? arm_compute::PermutationVector(2U, 0U, 1U) : arm_compute::PermutationVector(1U, 2U, 0U); + + arm_compute::permute(permuted_shape, perm_vec); } - else + + // Validate tensor shape + ARM_COMPUTE_ERROR_ON_MSG(_shape.size() != tensor.info()->tensor_shape().num_dimensions(), "Tensor ranks mismatch"); + for(size_t i = 0; i < _shape.size(); ++i) { - for(size_t i = 0; i < _shape.size(); ++i) - { - ARM_COMPUTE_ERROR_ON_MSG(tensor.info()->tensor_shape()[i] != _shape[_shape.size() - i - 1], "Tensor dimensions mismatch"); - } + ARM_COMPUTE_ERROR_ON_MSG(permuted_shape[i] != _shape[i], "Tensor dimensions mismatch"); } switch(tensor.info()->data_type()) @@ -423,48 +424,31 @@ public: else { // If tensor has padding or is in fortran order accessing tensor elements through execution window. - Window window; - TensorShape permuted_shape = tensor.info()->tensor_shape(); - const unsigned int num_dims = _shape.size(); - arm_compute::PermutationVector perm; + Window window; + const unsigned int num_dims = _shape.size(); if(_fortran_order) { for(unsigned int dim = 0; dim < num_dims; dim++) { - permuted_shape.set(dim, _shape[dim]); + permuted_shape.set(dim, _shape[num_dims - dim - 1]); perm.set(dim, num_dims - dim - 1); } - } - else - { - for(unsigned int dim = 0; dim < num_dims; dim++) - { - perm.set(dim, dim); - } - } - if(are_layouts_different) - { - // Permute only if num_dimensions greater than 2 - if(num_dims > 2) + if(are_layouts_different) { - if(_file_layout == DataLayout::NHWC) // i.e destination is NCHW --> permute(1,2,0) + // Permute only if num_dimensions greater than 2 + if(num_dims > 2) { - size_t perm_0 = perm[0]; - perm[0] = perm[1]; - perm[1] = perm[2]; - perm[2] = perm_0; - } - else - { - // destination layout is NHWC --> permute (2,0,1) - size_t perm_0 = perm[0]; - perm[0] = perm[2]; - perm[2] = perm[1]; - perm[1] = perm_0; + if(_file_layout == DataLayout::NHWC) // i.e destination is NCHW --> permute(1,2,0) + { + arm_compute::permute(perm, arm_compute::PermutationVector(1U, 2U, 0U)); + } + else + { + arm_compute::permute(perm, arm_compute::PermutationVector(2U, 0U, 1U)); + } } } } - window.use_tensor_dimensions(permuted_shape); execute_window_loop(window, [&](const Coordinates & id) -- cgit v1.2.1