aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/graph/TensorDescriptor.h
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2018-04-27 19:07:19 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:51:17 +0000
commitcac13b1cfd593889271f8e2191be2039b8d88f36 (patch)
treed1c5196877d7fbd5dcfbb9f9003faf6035f82a33 /arm_compute/graph/TensorDescriptor.h
parentad0c7388f6261989a268ffb2d042f2bd80736e3f (diff)
downloadComputeLibrary-cac13b1cfd593889271f8e2191be2039b8d88f36.tar.gz
COMPMID-1097: Port mobilenet to NHWC
Change-Id: I789065bfa0d4ef133388e1904c5caf31e450f80f Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/129495 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'arm_compute/graph/TensorDescriptor.h')
-rw-r--r--arm_compute/graph/TensorDescriptor.h74
1 files changed, 67 insertions, 7 deletions
diff --git a/arm_compute/graph/TensorDescriptor.h b/arm_compute/graph/TensorDescriptor.h
index 785c493cbc..704f015672 100644
--- a/arm_compute/graph/TensorDescriptor.h
+++ b/arm_compute/graph/TensorDescriptor.h
@@ -26,29 +26,89 @@
#include "arm_compute/graph/Types.h"
+#include "arm_compute/core/utils/misc/ICloneable.h"
+
+#include <memory>
+
namespace arm_compute
{
namespace graph
{
/** Tensor metadata class */
-struct TensorDescriptor final
+struct TensorDescriptor final : public misc::ICloneable<TensorDescriptor>
{
/** Default Constructor **/
TensorDescriptor() = default;
/** Constructor
*
- * @param[in] tensor_shape Tensor shape
- * @param[in] tensor_data_type Tensor data type
- * @param[in] tensor_quant_info Tensor quantization info
- * @param[in] tensor_target Target to allocate the tensor for
+ * @param[in] tensor_shape Tensor shape
+ * @param[in] tensor_data_type Tensor data type
+ * @param[in] tensor_quant_info Tensor quantization info
+ * @param[in] tensor_data_layout Tensor data layout
+ * @param[in] tensor_target Target to allocate the tensor for
+ */
+ TensorDescriptor(TensorShape tensor_shape,
+ DataType tensor_data_type,
+ QuantizationInfo tensor_quant_info = QuantizationInfo(),
+ DataLayout tensor_data_layout = DataLayout::NCHW,
+ Target tensor_target = Target::UNSPECIFIED)
+ : shape(tensor_shape), data_type(tensor_data_type), layout(tensor_data_layout), quant_info(tensor_quant_info), target(tensor_target)
+ {
+ }
+ /** Sets tensor descriptor shape
+ *
+ * @param[in] tensor_shape Tensor shape to set
+ *
+ * @return This tensor descriptor
*/
- TensorDescriptor(TensorShape tensor_shape, DataType tensor_data_type, QuantizationInfo tensor_quant_info = QuantizationInfo(), Target tensor_target = Target::UNSPECIFIED)
- : shape(tensor_shape), data_type(tensor_data_type), quant_info(tensor_quant_info), target(tensor_target)
+ TensorDescriptor &set_shape(TensorShape &tensor_shape)
+ {
+ shape = tensor_shape;
+ return *this;
+ }
+ /** Sets tensor descriptor data type
+ *
+ * @param[in] tensor_data_type Data type
+ *
+ * @return This tensor descriptor
+ */
+ TensorDescriptor &set_data_type(DataType tensor_data_type)
+ {
+ data_type = tensor_data_type;
+ return *this;
+ }
+ /** Sets tensor descriptor data layout
+ *
+ * @param[in] data_layout Data layout
+ *
+ * @return This tensor descriptor
+ */
+ TensorDescriptor &set_layout(DataLayout data_layout)
+ {
+ layout = data_layout;
+ return *this;
+ }
+ /** Sets tensor descriptor quantization info
+ *
+ * @param[in] tensor_quant_info Quantization information
+ *
+ * @return This tensor descriptor
+ */
+ TensorDescriptor &set_quantization_info(QuantizationInfo tensor_quant_info)
+ {
+ quant_info = tensor_quant_info;
+ return *this;
+ }
+
+ // Inherited methods overridden:
+ std::unique_ptr<TensorDescriptor> clone() const override
{
+ return support::cpp14::make_unique<TensorDescriptor>(*this);
}
TensorShape shape{}; /**< Tensor shape */
DataType data_type{ DataType::UNKNOWN }; /**< Data type */
+ DataLayout layout{ DataLayout::NCHW }; /**< Data layout */
QuantizationInfo quant_info{}; /**< Quantization info */
Target target{ Target::UNSPECIFIED }; /**< Target */
};