diff options
author | Georgios Pinitas <georgios.pinitas@arm.com> | 2021-01-08 03:14:31 +0000 |
---|---|---|
committer | Georgios Pinitas <georgios.pinitas@arm.com> | 2021-01-12 03:50:44 +0000 |
commit | b14a0f0c1c72a2365c42f7bd1ff698f8fb94c070 (patch) | |
tree | f35a184fc9775ff4a74cd0a0354c31bda459f6fb /arm_compute/core/TensorInfo.h | |
parent | f8f0442e9a6105be0e32f4defec5fbc10248ea6e (diff) | |
download | ComputeLibrary-b14a0f0c1c72a2365c42f7bd1ff698f8fb94c070.tar.gz |
Add meta-data to express dynamic shapes in ITensorInfo
Add `get_tensor_shape_state` and `set_tensor_shape_state` to inject
shape dynamism.
The state is represented by an array of integers which index maps to the
respective shape dimension index.
If -1 is passed as a dimension state then the corresponding dimension
is dynamic.
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Change-Id: I3a8a5ad109b90d4df8545b460a9f8dfcc13dfa0f
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4784
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/core/TensorInfo.h')
-rw-r--r-- | arm_compute/core/TensorInfo.h | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/arm_compute/core/TensorInfo.h b/arm_compute/core/TensorInfo.h index 31f27328dd..42a969e01b 100644 --- a/arm_compute/core/TensorInfo.h +++ b/arm_compute/core/TensorInfo.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2019 Arm Limited. + * Copyright (c) 2016-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -224,6 +224,7 @@ public: ITensorInfo &set_num_channels(int num_channels) override; ITensorInfo &set_format(Format format) override; ITensorInfo &set_tensor_shape(const TensorShape &shape) override; + ITensorInfo &set_tensor_dims_state(const TensorDimsState &state) override; ITensorInfo &set_quantization_info(const QuantizationInfo &quantization_info) override; ITensorInfo &set_data_layout(const DataLayout &data_layout) override; ITensorInfo &reset_padding() override; @@ -262,6 +263,10 @@ public: { return _tensor_shape; } + const TensorDimsState &tensor_dims_state() const override + { + return _dims_state; + } DataType data_type() const override { return _data_type; @@ -288,18 +293,13 @@ public: } bool is_dynamic() const override { - return _is_dynamic; + return std::find(std::cbegin(_dims_state), std::cend(_dims_state), -1) != std::cend(_dims_state); } ITensorInfo &set_is_resizable(bool is_resizable) override { _is_resizable = is_resizable; return *this; } - ITensorInfo &set_is_dynamic(bool is_dynamic) override - { - _is_dynamic = is_dynamic; - return *this; - } ValidRegion valid_region() const override { return _valid_region; @@ -329,10 +329,10 @@ private: Strides _strides_in_bytes; size_t _num_channels; TensorShape _tensor_shape; + TensorDimsState _dims_state; DataType _data_type; Format _format; bool _is_resizable; - bool _is_dynamic; ValidRegion _valid_region; PaddingSize _padding; QuantizationInfo _quantization_info; |