aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/graph/TensorDescriptor.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/graph/TensorDescriptor.h')
-rw-r--r--arm_compute/graph/TensorDescriptor.h74
1 files changed, 67 insertions, 7 deletions
diff --git a/arm_compute/graph/TensorDescriptor.h b/arm_compute/graph/TensorDescriptor.h
index 785c493cbc..704f015672 100644
--- a/arm_compute/graph/TensorDescriptor.h
+++ b/arm_compute/graph/TensorDescriptor.h
@@ -26,29 +26,89 @@
#include "arm_compute/graph/Types.h"
+#include "arm_compute/core/utils/misc/ICloneable.h"
+
+#include <memory>
+
namespace arm_compute
{
namespace graph
{
/** Tensor metadata class */
-struct TensorDescriptor final
+struct TensorDescriptor final : public misc::ICloneable<TensorDescriptor>
{
/** Default Constructor **/
TensorDescriptor() = default;
/** Constructor
*
- * @param[in] tensor_shape Tensor shape
- * @param[in] tensor_data_type Tensor data type
- * @param[in] tensor_quant_info Tensor quantization info
- * @param[in] tensor_target Target to allocate the tensor for
+ * @param[in] tensor_shape Tensor shape
+ * @param[in] tensor_data_type Tensor data type
+ * @param[in] tensor_quant_info Tensor quantization info
+ * @param[in] tensor_data_layout Tensor data layout
+ * @param[in] tensor_target Target to allocate the tensor for
+ */
+ TensorDescriptor(TensorShape tensor_shape,
+ DataType tensor_data_type,
+ QuantizationInfo tensor_quant_info = QuantizationInfo(),
+ DataLayout tensor_data_layout = DataLayout::NCHW,
+ Target tensor_target = Target::UNSPECIFIED)
+ : shape(tensor_shape), data_type(tensor_data_type), layout(tensor_data_layout), quant_info(tensor_quant_info), target(tensor_target)
+ {
+ }
+ /** Sets tensor descriptor shape
+ *
+ * @param[in] tensor_shape Tensor shape to set
+ *
+ * @return This tensor descriptor
*/
- TensorDescriptor(TensorShape tensor_shape, DataType tensor_data_type, QuantizationInfo tensor_quant_info = QuantizationInfo(), Target tensor_target = Target::UNSPECIFIED)
- : shape(tensor_shape), data_type(tensor_data_type), quant_info(tensor_quant_info), target(tensor_target)
+ TensorDescriptor &set_shape(TensorShape &tensor_shape)
+ {
+ shape = tensor_shape;
+ return *this;
+ }
+ /** Sets tensor descriptor data type
+ *
+ * @param[in] tensor_data_type Data type
+ *
+ * @return This tensor descriptor
+ */
+ TensorDescriptor &set_data_type(DataType tensor_data_type)
+ {
+ data_type = tensor_data_type;
+ return *this;
+ }
+ /** Sets tensor descriptor data layout
+ *
+ * @param[in] data_layout Data layout
+ *
+ * @return This tensor descriptor
+ */
+ TensorDescriptor &set_layout(DataLayout data_layout)
+ {
+ layout = data_layout;
+ return *this;
+ }
+ /** Sets tensor descriptor quantization info
+ *
+ * @param[in] tensor_quant_info Quantization information
+ *
+ * @return This tensor descriptor
+ */
+ TensorDescriptor &set_quantization_info(QuantizationInfo tensor_quant_info)
+ {
+ quant_info = tensor_quant_info;
+ return *this;
+ }
+
+ // Inherited methods overridden:
+ std::unique_ptr<TensorDescriptor> clone() const override
{
+ return support::cpp14::make_unique<TensorDescriptor>(*this);
}
TensorShape shape{}; /**< Tensor shape */
DataType data_type{ DataType::UNKNOWN }; /**< Data type */
+ DataLayout layout{ DataLayout::NCHW }; /**< Data layout */
QuantizationInfo quant_info{}; /**< Quantization info */
Target target{ Target::UNSPECIFIED }; /**< Target */
};