From d56e770e7c394d13706a21ee350e7dafe4278987 Mon Sep 17 00:00:00 2001 From: Isabella Gottardi Date: Wed, 28 Feb 2018 14:29:36 +0000 Subject: COMPMID-979: Add NHWC data layout to the tensor's metadata (Part 2) Change-Id: I24aa35a85834abf0c9954aba714aeae654615b44 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/122646 Tested-by: Jenkins Reviewed-by: Anthony Barbier --- arm_compute/core/Helpers.h | 2 +- arm_compute/core/Helpers.inl | 10 +++++----- arm_compute/core/ITensor.h | 4 ++-- arm_compute/core/ITensorInfo.h | 7 +++++++ arm_compute/core/SubTensorInfo.h | 5 +++++ arm_compute/core/TensorInfo.h | 5 +++++ src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp | 1 + src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp | 1 + src/core/CL/kernels/CLTransposeKernel.cpp | 3 ++- src/core/GLES_COMPUTE/kernels/GCIm2ColKernel.cpp | 1 + src/core/NEON/kernels/NETransposeKernel.cpp | 3 ++- 11 files changed, 32 insertions(+), 10 deletions(-) diff --git a/arm_compute/core/Helpers.h b/arm_compute/core/Helpers.h index c91299f218..24ba521d60 100644 --- a/arm_compute/core/Helpers.h +++ b/arm_compute/core/Helpers.h @@ -655,7 +655,7 @@ inline int coords2index(const TensorShape &shape, const Coordinates &coord); * * @return The int conversion of the requested data layout index. */ -inline int get_data_layout_dimension_index(const ITensorInfo &info, const DataLayoutDimension data_layout_dimension); +inline size_t get_data_layout_dimension_index(const DataLayout data_layout, const DataLayoutDimension data_layout_dimension); } // namespace arm_compute #include "arm_compute/core/Helpers.inl" diff --git a/arm_compute/core/Helpers.inl b/arm_compute/core/Helpers.inl index ff85773abb..3db8369f08 100644 --- a/arm_compute/core/Helpers.inl +++ b/arm_compute/core/Helpers.inl @@ -370,9 +370,9 @@ inline int coords2index(const TensorShape &shape, const Coordinates &coord) return index; } -inline int get_data_layout_dimension_index(const ITensorInfo &info, const DataLayoutDimension data_layout_dimension) +inline size_t get_data_layout_dimension_index(const DataLayout data_layout, const DataLayoutDimension data_layout_dimension) { - ARM_COMPUTE_ERROR_ON_MSG(info.data_layout() == DataLayout::UNKNOWN, "Cannot retrieve the dimension index for an unknown layout!"); + ARM_COMPUTE_ERROR_ON_MSG(data_layout == DataLayout::UNKNOWN, "Cannot retrieve the dimension index for an unknown layout!"); /* Return the index based on the data layout * [N C H W] @@ -382,13 +382,13 @@ inline int get_data_layout_dimension_index(const ITensorInfo &info, const DataLa switch(data_layout_dimension) { case DataLayoutDimension::CHANNEL: - return (info.data_layout() == DataLayout::NCHW) ? 2 : 0; + return (data_layout == DataLayout::NCHW) ? 2 : 0; break; case DataLayoutDimension::HEIGHT: - return (info.data_layout() == DataLayout::NCHW) ? 1 : 2; + return (data_layout == DataLayout::NCHW) ? 1 : 2; break; case DataLayoutDimension::WIDTH: - return (info.data_layout() == DataLayout::NCHW) ? 0 : 1; + return (data_layout == DataLayout::NCHW) ? 0 : 1; break; case DataLayoutDimension::BATCHES: return 3; diff --git a/arm_compute/core/ITensor.h b/arm_compute/core/ITensor.h index 202b50a0d8..1ef9c6d3f6 100644 --- a/arm_compute/core/ITensor.h +++ b/arm_compute/core/ITensor.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016, 2017 ARM Limited. + * Copyright (c) 2016-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -24,7 +24,7 @@ #ifndef __ARM_COMPUTE_ITENSOR_H__ #define __ARM_COMPUTE_ITENSOR_H__ -#include "arm_compute/core/TensorInfo.h" +#include "arm_compute/core/ITensorInfo.h" #include diff --git a/arm_compute/core/ITensorInfo.h b/arm_compute/core/ITensorInfo.h index 50a1eb2ff1..167fb41bb3 100644 --- a/arm_compute/core/ITensorInfo.h +++ b/arm_compute/core/ITensorInfo.h @@ -132,6 +132,13 @@ public: * @return Dimension of the requested dimension */ virtual size_t dimension(size_t index) const = 0; + /** Return the size of the requested data layout dimension + * + * @param[in] dimension DataLayoutDimension of the dimension + * + * @return Dimension of the requested dimension + */ + virtual size_t dimension(DataLayoutDimension dimension) const = 0; /** The strides in bytes for accessing each dimension of the tensor * * @return Strides in bytes for each tensor dimension diff --git a/arm_compute/core/SubTensorInfo.h b/arm_compute/core/SubTensorInfo.h index f9ed99b308..882e4ec1d0 100644 --- a/arm_compute/core/SubTensorInfo.h +++ b/arm_compute/core/SubTensorInfo.h @@ -127,6 +127,11 @@ public: { return _tensor_shape[index]; } + size_t dimension(DataLayoutDimension dimension) const override + { + ARM_COMPUTE_ERROR_ON(_parent == nullptr); + return get_data_layout_dimension_index(_parent->data_layout(), dimension); + } const Strides &strides_in_bytes() const override { ARM_COMPUTE_ERROR_ON(_parent == nullptr); diff --git a/arm_compute/core/TensorInfo.h b/arm_compute/core/TensorInfo.h index 27cf5bae82..97f9d03dc7 100644 --- a/arm_compute/core/TensorInfo.h +++ b/arm_compute/core/TensorInfo.h @@ -28,6 +28,7 @@ #include "ITensorInfo.h" #include "arm_compute/core/Coordinates.h" +#include "arm_compute/core/Helpers.h" #include "arm_compute/core/Strides.h" #include "arm_compute/core/TensorShape.h" #include "arm_compute/core/Types.h" @@ -228,6 +229,10 @@ public: { return _tensor_shape[index]; } + size_t dimension(DataLayoutDimension dimension) const override + { + return get_data_layout_dimension_index(_data_layout, dimension); + } const Strides &strides_in_bytes() const override { return _strides_in_bytes; diff --git a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp index ae498ec8a7..3f705ac0a7 100644 --- a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp +++ b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp @@ -31,6 +31,7 @@ #include "arm_compute/core/CL/OpenCL.h" #include "arm_compute/core/Error.h" #include "arm_compute/core/Helpers.h" +#include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/Utils.h" #include "arm_compute/core/Validate.h" diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp index 6c31e371da..3143075a9d 100644 --- a/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp +++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.cpp @@ -32,6 +32,7 @@ #include "arm_compute/core/Error.h" #include "arm_compute/core/FixedPoint.h" #include "arm_compute/core/Helpers.h" +#include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/Utils.h" #include "arm_compute/core/Validate.h" diff --git a/src/core/CL/kernels/CLTransposeKernel.cpp b/src/core/CL/kernels/CLTransposeKernel.cpp index deb22e3044..c295cad599 100644 --- a/src/core/CL/kernels/CLTransposeKernel.cpp +++ b/src/core/CL/kernels/CLTransposeKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -30,6 +30,7 @@ #include "arm_compute/core/CL/OpenCL.h" #include "arm_compute/core/Error.h" #include "arm_compute/core/Helpers.h" +#include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/Utils.h" diff --git a/src/core/GLES_COMPUTE/kernels/GCIm2ColKernel.cpp b/src/core/GLES_COMPUTE/kernels/GCIm2ColKernel.cpp index 4ab6f3e89d..47bfebcc09 100644 --- a/src/core/GLES_COMPUTE/kernels/GCIm2ColKernel.cpp +++ b/src/core/GLES_COMPUTE/kernels/GCIm2ColKernel.cpp @@ -32,6 +32,7 @@ #include "arm_compute/core/GLES_COMPUTE/OpenGLES.h" #include "arm_compute/core/Helpers.h" #include "arm_compute/core/Size2D.h" +#include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/Validate.h" #include "support/ToolchainSupport.h" diff --git a/src/core/NEON/kernels/NETransposeKernel.cpp b/src/core/NEON/kernels/NETransposeKernel.cpp index fc22b05823..92271378ff 100644 --- a/src/core/NEON/kernels/NETransposeKernel.cpp +++ b/src/core/NEON/kernels/NETransposeKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -27,6 +27,7 @@ #include "arm_compute/core/Error.h" #include "arm_compute/core/Helpers.h" #include "arm_compute/core/ITensor.h" +#include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Utils.h" #include "arm_compute/core/Validate.h" -- cgit v1.2.1