From d56e770e7c394d13706a21ee350e7dafe4278987 Mon Sep 17 00:00:00 2001 From: Isabella Gottardi Date: Wed, 28 Feb 2018 14:29:36 +0000 Subject: COMPMID-979: Add NHWC data layout to the tensor's metadata (Part 2) Change-Id: I24aa35a85834abf0c9954aba714aeae654615b44 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/122646 Tested-by: Jenkins Reviewed-by: Anthony Barbier --- arm_compute/core/Helpers.h | 2 +- arm_compute/core/Helpers.inl | 10 +++++----- arm_compute/core/ITensor.h | 4 ++-- arm_compute/core/ITensorInfo.h | 7 +++++++ arm_compute/core/SubTensorInfo.h | 5 +++++ arm_compute/core/TensorInfo.h | 5 +++++ 6 files changed, 25 insertions(+), 8 deletions(-) (limited to 'arm_compute/core') diff --git a/arm_compute/core/Helpers.h b/arm_compute/core/Helpers.h index c91299f218..24ba521d60 100644 --- a/arm_compute/core/Helpers.h +++ b/arm_compute/core/Helpers.h @@ -655,7 +655,7 @@ inline int coords2index(const TensorShape &shape, const Coordinates &coord); * * @return The int conversion of the requested data layout index. */ -inline int get_data_layout_dimension_index(const ITensorInfo &info, const DataLayoutDimension data_layout_dimension); +inline size_t get_data_layout_dimension_index(const DataLayout data_layout, const DataLayoutDimension data_layout_dimension); } // namespace arm_compute #include "arm_compute/core/Helpers.inl" diff --git a/arm_compute/core/Helpers.inl b/arm_compute/core/Helpers.inl index ff85773abb..3db8369f08 100644 --- a/arm_compute/core/Helpers.inl +++ b/arm_compute/core/Helpers.inl @@ -370,9 +370,9 @@ inline int coords2index(const TensorShape &shape, const Coordinates &coord) return index; } -inline int get_data_layout_dimension_index(const ITensorInfo &info, const DataLayoutDimension data_layout_dimension) +inline size_t get_data_layout_dimension_index(const DataLayout data_layout, const DataLayoutDimension data_layout_dimension) { - ARM_COMPUTE_ERROR_ON_MSG(info.data_layout() == DataLayout::UNKNOWN, "Cannot retrieve the dimension index for an unknown layout!"); + ARM_COMPUTE_ERROR_ON_MSG(data_layout == DataLayout::UNKNOWN, "Cannot retrieve the dimension index for an unknown layout!"); /* Return the index based on the data layout * [N C H W] @@ -382,13 +382,13 @@ inline int get_data_layout_dimension_index(const ITensorInfo &info, const DataLa switch(data_layout_dimension) { case DataLayoutDimension::CHANNEL: - return (info.data_layout() == DataLayout::NCHW) ? 2 : 0; + return (data_layout == DataLayout::NCHW) ? 2 : 0; break; case DataLayoutDimension::HEIGHT: - return (info.data_layout() == DataLayout::NCHW) ? 1 : 2; + return (data_layout == DataLayout::NCHW) ? 1 : 2; break; case DataLayoutDimension::WIDTH: - return (info.data_layout() == DataLayout::NCHW) ? 0 : 1; + return (data_layout == DataLayout::NCHW) ? 0 : 1; break; case DataLayoutDimension::BATCHES: return 3; diff --git a/arm_compute/core/ITensor.h b/arm_compute/core/ITensor.h index 202b50a0d8..1ef9c6d3f6 100644 --- a/arm_compute/core/ITensor.h +++ b/arm_compute/core/ITensor.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016, 2017 ARM Limited. + * Copyright (c) 2016-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -24,7 +24,7 @@ #ifndef __ARM_COMPUTE_ITENSOR_H__ #define __ARM_COMPUTE_ITENSOR_H__ -#include "arm_compute/core/TensorInfo.h" +#include "arm_compute/core/ITensorInfo.h" #include diff --git a/arm_compute/core/ITensorInfo.h b/arm_compute/core/ITensorInfo.h index 50a1eb2ff1..167fb41bb3 100644 --- a/arm_compute/core/ITensorInfo.h +++ b/arm_compute/core/ITensorInfo.h @@ -132,6 +132,13 @@ public: * @return Dimension of the requested dimension */ virtual size_t dimension(size_t index) const = 0; + /** Return the size of the requested data layout dimension + * + * @param[in] dimension DataLayoutDimension of the dimension + * + * @return Dimension of the requested dimension + */ + virtual size_t dimension(DataLayoutDimension dimension) const = 0; /** The strides in bytes for accessing each dimension of the tensor * * @return Strides in bytes for each tensor dimension diff --git a/arm_compute/core/SubTensorInfo.h b/arm_compute/core/SubTensorInfo.h index f9ed99b308..882e4ec1d0 100644 --- a/arm_compute/core/SubTensorInfo.h +++ b/arm_compute/core/SubTensorInfo.h @@ -127,6 +127,11 @@ public: { return _tensor_shape[index]; } + size_t dimension(DataLayoutDimension dimension) const override + { + ARM_COMPUTE_ERROR_ON(_parent == nullptr); + return get_data_layout_dimension_index(_parent->data_layout(), dimension); + } const Strides &strides_in_bytes() const override { ARM_COMPUTE_ERROR_ON(_parent == nullptr); diff --git a/arm_compute/core/TensorInfo.h b/arm_compute/core/TensorInfo.h index 27cf5bae82..97f9d03dc7 100644 --- a/arm_compute/core/TensorInfo.h +++ b/arm_compute/core/TensorInfo.h @@ -28,6 +28,7 @@ #include "ITensorInfo.h" #include "arm_compute/core/Coordinates.h" +#include "arm_compute/core/Helpers.h" #include "arm_compute/core/Strides.h" #include "arm_compute/core/TensorShape.h" #include "arm_compute/core/Types.h" @@ -228,6 +229,10 @@ public: { return _tensor_shape[index]; } + size_t dimension(DataLayoutDimension dimension) const override + { + return get_data_layout_dimension_index(_data_layout, dimension); + } const Strides &strides_in_bytes() const override { return _strides_in_bytes; -- cgit v1.2.1