From b14a0f0c1c72a2365c42f7bd1ff698f8fb94c070 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Fri, 8 Jan 2021 03:14:31 +0000 Subject: Add meta-data to express dynamic shapes in ITensorInfo Add `get_tensor_shape_state` and `set_tensor_shape_state` to inject shape dynamism. The state is represented by an array of integers which index maps to the respective shape dimension index. If -1 is passed as a dimension state then the corresponding dimension is dynamic. Signed-off-by: Georgios Pinitas Change-Id: I3a8a5ad109b90d4df8545b460a9f8dfcc13dfa0f Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4784 Tested-by: Arm Jenkins Reviewed-by: Michele Di Giorgio Comments-Addressed: Arm Jenkins --- arm_compute/core/ITensorInfo.h | 28 ++++++++++++++++++++-------- arm_compute/core/SubTensorInfo.h | 25 +++++++++++++------------ arm_compute/core/TensorInfo.h | 16 ++++++++-------- src/core/SubTensorInfo.cpp | 13 ++++++++++--- src/core/TensorInfo.cpp | 12 +++++++++--- 5 files changed, 60 insertions(+), 34 deletions(-) diff --git a/arm_compute/core/ITensorInfo.h b/arm_compute/core/ITensorInfo.h index 3eb7239460..9ddafce7c0 100644 --- a/arm_compute/core/ITensorInfo.h +++ b/arm_compute/core/ITensorInfo.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2020 Arm Limited. + * Copyright (c) 2016-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -39,6 +39,9 @@ namespace arm_compute /** Store the tensor's metadata */ class ITensorInfo : public misc::ICloneable { +public: + using TensorDimsState = Coordinates; + public: /** Default virtual destructor */ virtual ~ITensorInfo() = default; @@ -81,6 +84,17 @@ public: * @return Reference to this ITensorInfo object */ virtual ITensorInfo &set_tensor_shape(const TensorShape &shape) = 0; + /** Set the state for each dimension of the tensor + * + * This sets the state of each dimension of the shape in terms of dynamic behavior using -1 where appropriate. + * The index in the state is a 1 to 1 mapping with the shape dimension index. + * For example if you want to express [?, 3, 3] as a dynamic input then [-1, 3, 3] has to be set as a state + * + * @param[in] state Tensor dimensions state + * + * @return Reference to this ITensorInfo object + */ + virtual ITensorInfo &set_tensor_dims_state(const TensorDimsState &state) = 0; /** Set the quantization settings (scale and offset) of the tensor. * * @param[in] quantization_info QuantizationInfo containing the scale and offset @@ -170,6 +184,11 @@ public: * @return A vector with the size for each dimension of the tensor */ virtual const TensorShape &tensor_shape() const = 0; + /** State of each dimension of the tensor shape + * + * @return A vector with the state for each dimension of the tensor, where -1 specifies dynamic dimension + */ + virtual const TensorDimsState &tensor_dims_state() const = 0; /** Data type used for each element of the tensor * * @return Tensor data type @@ -212,13 +231,6 @@ public: * @return Reference to this ITensorInfo object */ virtual ITensorInfo &set_is_resizable(bool is_resizable) = 0; - /** Set the flag whether the tensor size is dynamic. - * - * @param[in] is_dynamic Flag that marks the tensor if it's dynamic. - * - * @return Reference to this ITensorInfo object - */ - virtual ITensorInfo &set_is_dynamic(bool is_dynamic) = 0; /** Valid region of the tensor. All elements in the valid region have defined values, i.e. are not undefined. * * @return The valid region. diff --git a/arm_compute/core/SubTensorInfo.h b/arm_compute/core/SubTensorInfo.h index 6654ccf00a..1b2278d99b 100644 --- a/arm_compute/core/SubTensorInfo.h +++ b/arm_compute/core/SubTensorInfo.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020 Arm Limited. + * Copyright (c) 2017-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -98,6 +98,7 @@ public: return *this; }; ITensorInfo &set_tensor_shape(const TensorShape &shape) override; + ITensorInfo &set_tensor_dims_state(const TensorDimsState &state) override; ITensorInfo &set_quantization_info(const QuantizationInfo &quantization_info) override { ARM_COMPUTE_ERROR_ON(_parent == nullptr); @@ -155,6 +156,11 @@ public: ARM_COMPUTE_ERROR_ON(_parent == nullptr); return _tensor_shape; } + const TensorDimsState &tensor_dims_state() const override + { + ARM_COMPUTE_ERROR_ON(_parent == nullptr); + return _dims_state; + } DataType data_type() const override { ARM_COMPUTE_ERROR_ON(_parent == nullptr); @@ -196,12 +202,6 @@ public: _parent->set_is_resizable(is_resizable); return *this; } - ITensorInfo &set_is_dynamic(bool is_dynamic) override - { - ARM_COMPUTE_ERROR_ON(_parent == nullptr); - _parent->set_is_dynamic(is_dynamic); - return *this; - } ValidRegion valid_region() const override { return _valid_region; @@ -228,11 +228,12 @@ public: } private: - ITensorInfo *_parent; - TensorShape _tensor_shape; - Coordinates _coords; - ValidRegion _valid_region; - bool _extend_parent; + ITensorInfo *_parent; + TensorShape _tensor_shape; + TensorDimsState _dims_state; + Coordinates _coords; + ValidRegion _valid_region; + bool _extend_parent; }; } // namespace arm_compute #endif /*ARM_COMPUTE_SUBTENSORINFO_H */ diff --git a/arm_compute/core/TensorInfo.h b/arm_compute/core/TensorInfo.h index 31f27328dd..42a969e01b 100644 --- a/arm_compute/core/TensorInfo.h +++ b/arm_compute/core/TensorInfo.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2019 Arm Limited. + * Copyright (c) 2016-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -224,6 +224,7 @@ public: ITensorInfo &set_num_channels(int num_channels) override; ITensorInfo &set_format(Format format) override; ITensorInfo &set_tensor_shape(const TensorShape &shape) override; + ITensorInfo &set_tensor_dims_state(const TensorDimsState &state) override; ITensorInfo &set_quantization_info(const QuantizationInfo &quantization_info) override; ITensorInfo &set_data_layout(const DataLayout &data_layout) override; ITensorInfo &reset_padding() override; @@ -262,6 +263,10 @@ public: { return _tensor_shape; } + const TensorDimsState &tensor_dims_state() const override + { + return _dims_state; + } DataType data_type() const override { return _data_type; @@ -288,18 +293,13 @@ public: } bool is_dynamic() const override { - return _is_dynamic; + return std::find(std::cbegin(_dims_state), std::cend(_dims_state), -1) != std::cend(_dims_state); } ITensorInfo &set_is_resizable(bool is_resizable) override { _is_resizable = is_resizable; return *this; } - ITensorInfo &set_is_dynamic(bool is_dynamic) override - { - _is_dynamic = is_dynamic; - return *this; - } ValidRegion valid_region() const override { return _valid_region; @@ -329,10 +329,10 @@ private: Strides _strides_in_bytes; size_t _num_channels; TensorShape _tensor_shape; + TensorDimsState _dims_state; DataType _data_type; Format _format; bool _is_resizable; - bool _is_dynamic; ValidRegion _valid_region; PaddingSize _padding; QuantizationInfo _quantization_info; diff --git a/src/core/SubTensorInfo.cpp b/src/core/SubTensorInfo.cpp index bb8ecf60ea..6279992e89 100644 --- a/src/core/SubTensorInfo.cpp +++ b/src/core/SubTensorInfo.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020 Arm Limited. + * Copyright (c) 2017-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -56,12 +56,12 @@ TensorShape extend_parent_shape(TensorShape parent_shape, TensorShape shape, Coo } // namespace SubTensorInfo::SubTensorInfo() - : _parent(nullptr), _tensor_shape(), _coords(), _valid_region{ Coordinates(), _tensor_shape }, _extend_parent(false) + : _parent(nullptr), _tensor_shape(), _dims_state(), _coords(), _valid_region{ Coordinates(), _tensor_shape }, _extend_parent(false) { } SubTensorInfo::SubTensorInfo(ITensorInfo *parent, TensorShape tensor_shape, Coordinates coords, bool extend_parent) - : _parent(parent), _tensor_shape(tensor_shape), _coords(coords), _valid_region{ Coordinates(), _tensor_shape }, _extend_parent(extend_parent) + : _parent(parent), _tensor_shape(tensor_shape), _dims_state(), _coords(coords), _valid_region{ Coordinates(), _tensor_shape }, _extend_parent(extend_parent) { ARM_COMPUTE_ERROR_ON(parent == nullptr); @@ -107,6 +107,13 @@ ITensorInfo &SubTensorInfo::set_tensor_shape(const TensorShape &shape) return *this; } +ITensorInfo &SubTensorInfo::set_tensor_dims_state(const TensorDimsState &state) +{ + ARM_COMPUTE_ERROR_ON(_parent == nullptr); + _dims_state = state; + return *this; +} + bool SubTensorInfo::extend_padding(const PaddingSize &padding) { ARM_COMPUTE_ERROR_ON(_parent == nullptr); diff --git a/src/core/TensorInfo.cpp b/src/core/TensorInfo.cpp index 7b1f9c542a..bedfe147b0 100644 --- a/src/core/TensorInfo.cpp +++ b/src/core/TensorInfo.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2020 Arm Limited. + * Copyright (c) 2016-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -35,7 +35,7 @@ using namespace arm_compute; TensorInfo::TensorInfo() - : _total_size(0), _offset_first_element_in_bytes(0), _strides_in_bytes(), _num_channels(0), _tensor_shape(), _data_type(DataType::UNKNOWN), _format(Format::UNKNOWN), _is_resizable{ true }, _is_dynamic{ false }, + : _total_size(0), _offset_first_element_in_bytes(0), _strides_in_bytes(), _num_channels(0), _tensor_shape(), _dims_state(), _data_type(DataType::UNKNOWN), _format(Format::UNKNOWN), _is_resizable{ true }, _valid_region{ Coordinates(), _tensor_shape }, _padding{ 0 }, _quantization_info(), _data_layout(DataLayout::NCHW) { } @@ -48,10 +48,10 @@ TensorInfo::TensorInfo(const ITensorInfo &info) _strides_in_bytes = info.strides_in_bytes(); _num_channels = info.num_channels(); _tensor_shape = info.tensor_shape(); + _dims_state = info.tensor_dims_state(); _data_type = info.data_type(); _format = info.format(); _is_resizable = info.is_resizable(); - _is_dynamic = info.is_dynamic(); _valid_region = info.valid_region(); _padding = info.padding(); _quantization_info = info.quantization_info(); @@ -371,6 +371,12 @@ ITensorInfo &TensorInfo::set_tensor_shape(const TensorShape &shape) return *this; } +ITensorInfo &TensorInfo::set_tensor_dims_state(const TensorDimsState &state) +{ + _dims_state = state; + return *this; +} + ITensorInfo &TensorInfo::set_quantization_info(const QuantizationInfo &quantization_info) { _quantization_info = quantization_info; -- cgit v1.2.1