aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2021-01-08 03:14:31 +0000
committerGeorgios Pinitas <georgios.pinitas@arm.com>2021-01-12 03:50:44 +0000
commitb14a0f0c1c72a2365c42f7bd1ff698f8fb94c070 (patch)
treef35a184fc9775ff4a74cd0a0354c31bda459f6fb
parentf8f0442e9a6105be0e32f4defec5fbc10248ea6e (diff)
downloadComputeLibrary-b14a0f0c1c72a2365c42f7bd1ff698f8fb94c070.tar.gz
Add meta-data to express dynamic shapes in ITensorInfo
Add `get_tensor_shape_state` and `set_tensor_shape_state` to inject shape dynamism. The state is represented by an array of integers which index maps to the respective shape dimension index. If -1 is passed as a dimension state then the corresponding dimension is dynamic. Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com> Change-Id: I3a8a5ad109b90d4df8545b460a9f8dfcc13dfa0f Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4784 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--arm_compute/core/ITensorInfo.h28
-rw-r--r--arm_compute/core/SubTensorInfo.h25
-rw-r--r--arm_compute/core/TensorInfo.h16
-rw-r--r--src/core/SubTensorInfo.cpp13
-rw-r--r--src/core/TensorInfo.cpp12
5 files changed, 60 insertions, 34 deletions
diff --git a/arm_compute/core/ITensorInfo.h b/arm_compute/core/ITensorInfo.h
index 3eb7239460..9ddafce7c0 100644
--- a/arm_compute/core/ITensorInfo.h
+++ b/arm_compute/core/ITensorInfo.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2020 Arm Limited.
+ * Copyright (c) 2016-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -40,6 +40,9 @@ namespace arm_compute
class ITensorInfo : public misc::ICloneable<ITensorInfo>
{
public:
+ using TensorDimsState = Coordinates;
+
+public:
/** Default virtual destructor */
virtual ~ITensorInfo() = default;
/** Set the data type to the specified value.
@@ -81,6 +84,17 @@ public:
* @return Reference to this ITensorInfo object
*/
virtual ITensorInfo &set_tensor_shape(const TensorShape &shape) = 0;
+ /** Set the state for each dimension of the tensor
+ *
+ * This sets the state of each dimension of the shape in terms of dynamic behavior using -1 where appropriate.
+ * The index in the state is a 1 to 1 mapping with the shape dimension index.
+ * For example if you want to express [?, 3, 3] as a dynamic input then [-1, 3, 3] has to be set as a state
+ *
+ * @param[in] state Tensor dimensions state
+ *
+ * @return Reference to this ITensorInfo object
+ */
+ virtual ITensorInfo &set_tensor_dims_state(const TensorDimsState &state) = 0;
/** Set the quantization settings (scale and offset) of the tensor.
*
* @param[in] quantization_info QuantizationInfo containing the scale and offset
@@ -170,6 +184,11 @@ public:
* @return A vector with the size for each dimension of the tensor
*/
virtual const TensorShape &tensor_shape() const = 0;
+ /** State of each dimension of the tensor shape
+ *
+ * @return A vector with the state for each dimension of the tensor, where -1 specifies dynamic dimension
+ */
+ virtual const TensorDimsState &tensor_dims_state() const = 0;
/** Data type used for each element of the tensor
*
* @return Tensor data type
@@ -212,13 +231,6 @@ public:
* @return Reference to this ITensorInfo object
*/
virtual ITensorInfo &set_is_resizable(bool is_resizable) = 0;
- /** Set the flag whether the tensor size is dynamic.
- *
- * @param[in] is_dynamic Flag that marks the tensor if it's dynamic.
- *
- * @return Reference to this ITensorInfo object
- */
- virtual ITensorInfo &set_is_dynamic(bool is_dynamic) = 0;
/** Valid region of the tensor. All elements in the valid region have defined values, i.e. are not undefined.
*
* @return The valid region.
diff --git a/arm_compute/core/SubTensorInfo.h b/arm_compute/core/SubTensorInfo.h
index 6654ccf00a..1b2278d99b 100644
--- a/arm_compute/core/SubTensorInfo.h
+++ b/arm_compute/core/SubTensorInfo.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -98,6 +98,7 @@ public:
return *this;
};
ITensorInfo &set_tensor_shape(const TensorShape &shape) override;
+ ITensorInfo &set_tensor_dims_state(const TensorDimsState &state) override;
ITensorInfo &set_quantization_info(const QuantizationInfo &quantization_info) override
{
ARM_COMPUTE_ERROR_ON(_parent == nullptr);
@@ -155,6 +156,11 @@ public:
ARM_COMPUTE_ERROR_ON(_parent == nullptr);
return _tensor_shape;
}
+ const TensorDimsState &tensor_dims_state() const override
+ {
+ ARM_COMPUTE_ERROR_ON(_parent == nullptr);
+ return _dims_state;
+ }
DataType data_type() const override
{
ARM_COMPUTE_ERROR_ON(_parent == nullptr);
@@ -196,12 +202,6 @@ public:
_parent->set_is_resizable(is_resizable);
return *this;
}
- ITensorInfo &set_is_dynamic(bool is_dynamic) override
- {
- ARM_COMPUTE_ERROR_ON(_parent == nullptr);
- _parent->set_is_dynamic(is_dynamic);
- return *this;
- }
ValidRegion valid_region() const override
{
return _valid_region;
@@ -228,11 +228,12 @@ public:
}
private:
- ITensorInfo *_parent;
- TensorShape _tensor_shape;
- Coordinates _coords;
- ValidRegion _valid_region;
- bool _extend_parent;
+ ITensorInfo *_parent;
+ TensorShape _tensor_shape;
+ TensorDimsState _dims_state;
+ Coordinates _coords;
+ ValidRegion _valid_region;
+ bool _extend_parent;
};
} // namespace arm_compute
#endif /*ARM_COMPUTE_SUBTENSORINFO_H */
diff --git a/arm_compute/core/TensorInfo.h b/arm_compute/core/TensorInfo.h
index 31f27328dd..42a969e01b 100644
--- a/arm_compute/core/TensorInfo.h
+++ b/arm_compute/core/TensorInfo.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2019 Arm Limited.
+ * Copyright (c) 2016-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -224,6 +224,7 @@ public:
ITensorInfo &set_num_channels(int num_channels) override;
ITensorInfo &set_format(Format format) override;
ITensorInfo &set_tensor_shape(const TensorShape &shape) override;
+ ITensorInfo &set_tensor_dims_state(const TensorDimsState &state) override;
ITensorInfo &set_quantization_info(const QuantizationInfo &quantization_info) override;
ITensorInfo &set_data_layout(const DataLayout &data_layout) override;
ITensorInfo &reset_padding() override;
@@ -262,6 +263,10 @@ public:
{
return _tensor_shape;
}
+ const TensorDimsState &tensor_dims_state() const override
+ {
+ return _dims_state;
+ }
DataType data_type() const override
{
return _data_type;
@@ -288,18 +293,13 @@ public:
}
bool is_dynamic() const override
{
- return _is_dynamic;
+ return std::find(std::cbegin(_dims_state), std::cend(_dims_state), -1) != std::cend(_dims_state);
}
ITensorInfo &set_is_resizable(bool is_resizable) override
{
_is_resizable = is_resizable;
return *this;
}
- ITensorInfo &set_is_dynamic(bool is_dynamic) override
- {
- _is_dynamic = is_dynamic;
- return *this;
- }
ValidRegion valid_region() const override
{
return _valid_region;
@@ -329,10 +329,10 @@ private:
Strides _strides_in_bytes;
size_t _num_channels;
TensorShape _tensor_shape;
+ TensorDimsState _dims_state;
DataType _data_type;
Format _format;
bool _is_resizable;
- bool _is_dynamic;
ValidRegion _valid_region;
PaddingSize _padding;
QuantizationInfo _quantization_info;
diff --git a/src/core/SubTensorInfo.cpp b/src/core/SubTensorInfo.cpp
index bb8ecf60ea..6279992e89 100644
--- a/src/core/SubTensorInfo.cpp
+++ b/src/core/SubTensorInfo.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -56,12 +56,12 @@ TensorShape extend_parent_shape(TensorShape parent_shape, TensorShape shape, Coo
} // namespace
SubTensorInfo::SubTensorInfo()
- : _parent(nullptr), _tensor_shape(), _coords(), _valid_region{ Coordinates(), _tensor_shape }, _extend_parent(false)
+ : _parent(nullptr), _tensor_shape(), _dims_state(), _coords(), _valid_region{ Coordinates(), _tensor_shape }, _extend_parent(false)
{
}
SubTensorInfo::SubTensorInfo(ITensorInfo *parent, TensorShape tensor_shape, Coordinates coords, bool extend_parent)
- : _parent(parent), _tensor_shape(tensor_shape), _coords(coords), _valid_region{ Coordinates(), _tensor_shape }, _extend_parent(extend_parent)
+ : _parent(parent), _tensor_shape(tensor_shape), _dims_state(), _coords(coords), _valid_region{ Coordinates(), _tensor_shape }, _extend_parent(extend_parent)
{
ARM_COMPUTE_ERROR_ON(parent == nullptr);
@@ -107,6 +107,13 @@ ITensorInfo &SubTensorInfo::set_tensor_shape(const TensorShape &shape)
return *this;
}
+ITensorInfo &SubTensorInfo::set_tensor_dims_state(const TensorDimsState &state)
+{
+ ARM_COMPUTE_ERROR_ON(_parent == nullptr);
+ _dims_state = state;
+ return *this;
+}
+
bool SubTensorInfo::extend_padding(const PaddingSize &padding)
{
ARM_COMPUTE_ERROR_ON(_parent == nullptr);
diff --git a/src/core/TensorInfo.cpp b/src/core/TensorInfo.cpp
index 7b1f9c542a..bedfe147b0 100644
--- a/src/core/TensorInfo.cpp
+++ b/src/core/TensorInfo.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2020 Arm Limited.
+ * Copyright (c) 2016-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -35,7 +35,7 @@
using namespace arm_compute;
TensorInfo::TensorInfo()
- : _total_size(0), _offset_first_element_in_bytes(0), _strides_in_bytes(), _num_channels(0), _tensor_shape(), _data_type(DataType::UNKNOWN), _format(Format::UNKNOWN), _is_resizable{ true }, _is_dynamic{ false },
+ : _total_size(0), _offset_first_element_in_bytes(0), _strides_in_bytes(), _num_channels(0), _tensor_shape(), _dims_state(), _data_type(DataType::UNKNOWN), _format(Format::UNKNOWN), _is_resizable{ true },
_valid_region{ Coordinates(), _tensor_shape }, _padding{ 0 }, _quantization_info(), _data_layout(DataLayout::NCHW)
{
}
@@ -48,10 +48,10 @@ TensorInfo::TensorInfo(const ITensorInfo &info)
_strides_in_bytes = info.strides_in_bytes();
_num_channels = info.num_channels();
_tensor_shape = info.tensor_shape();
+ _dims_state = info.tensor_dims_state();
_data_type = info.data_type();
_format = info.format();
_is_resizable = info.is_resizable();
- _is_dynamic = info.is_dynamic();
_valid_region = info.valid_region();
_padding = info.padding();
_quantization_info = info.quantization_info();
@@ -371,6 +371,12 @@ ITensorInfo &TensorInfo::set_tensor_shape(const TensorShape &shape)
return *this;
}
+ITensorInfo &TensorInfo::set_tensor_dims_state(const TensorDimsState &state)
+{
+ _dims_state = state;
+ return *this;
+}
+
ITensorInfo &TensorInfo::set_quantization_info(const QuantizationInfo &quantization_info)
{
_quantization_info = quantization_info;