From cac13b1cfd593889271f8e2191be2039b8d88f36 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Fri, 27 Apr 2018 19:07:19 +0100 Subject: COMPMID-1097: Port mobilenet to NHWC Change-Id: I789065bfa0d4ef133388e1904c5caf31e450f80f Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/129495 Tested-by: Jenkins Reviewed-by: Anthony Barbier --- arm_compute/graph/TensorDescriptor.h | 74 ++++++++++++++++++++++++++++++++---- 1 file changed, 67 insertions(+), 7 deletions(-) (limited to 'arm_compute/graph/TensorDescriptor.h') diff --git a/arm_compute/graph/TensorDescriptor.h b/arm_compute/graph/TensorDescriptor.h index 785c493cbc..704f015672 100644 --- a/arm_compute/graph/TensorDescriptor.h +++ b/arm_compute/graph/TensorDescriptor.h @@ -26,29 +26,89 @@ #include "arm_compute/graph/Types.h" +#include "arm_compute/core/utils/misc/ICloneable.h" + +#include + namespace arm_compute { namespace graph { /** Tensor metadata class */ -struct TensorDescriptor final +struct TensorDescriptor final : public misc::ICloneable { /** Default Constructor **/ TensorDescriptor() = default; /** Constructor * - * @param[in] tensor_shape Tensor shape - * @param[in] tensor_data_type Tensor data type - * @param[in] tensor_quant_info Tensor quantization info - * @param[in] tensor_target Target to allocate the tensor for + * @param[in] tensor_shape Tensor shape + * @param[in] tensor_data_type Tensor data type + * @param[in] tensor_quant_info Tensor quantization info + * @param[in] tensor_data_layout Tensor data layout + * @param[in] tensor_target Target to allocate the tensor for + */ + TensorDescriptor(TensorShape tensor_shape, + DataType tensor_data_type, + QuantizationInfo tensor_quant_info = QuantizationInfo(), + DataLayout tensor_data_layout = DataLayout::NCHW, + Target tensor_target = Target::UNSPECIFIED) + : shape(tensor_shape), data_type(tensor_data_type), layout(tensor_data_layout), quant_info(tensor_quant_info), target(tensor_target) + { + } + /** Sets tensor descriptor shape + * + * @param[in] tensor_shape Tensor shape to set + * + * @return This tensor descriptor */ - TensorDescriptor(TensorShape tensor_shape, DataType tensor_data_type, QuantizationInfo tensor_quant_info = QuantizationInfo(), Target tensor_target = Target::UNSPECIFIED) - : shape(tensor_shape), data_type(tensor_data_type), quant_info(tensor_quant_info), target(tensor_target) + TensorDescriptor &set_shape(TensorShape &tensor_shape) + { + shape = tensor_shape; + return *this; + } + /** Sets tensor descriptor data type + * + * @param[in] tensor_data_type Data type + * + * @return This tensor descriptor + */ + TensorDescriptor &set_data_type(DataType tensor_data_type) + { + data_type = tensor_data_type; + return *this; + } + /** Sets tensor descriptor data layout + * + * @param[in] data_layout Data layout + * + * @return This tensor descriptor + */ + TensorDescriptor &set_layout(DataLayout data_layout) + { + layout = data_layout; + return *this; + } + /** Sets tensor descriptor quantization info + * + * @param[in] tensor_quant_info Quantization information + * + * @return This tensor descriptor + */ + TensorDescriptor &set_quantization_info(QuantizationInfo tensor_quant_info) + { + quant_info = tensor_quant_info; + return *this; + } + + // Inherited methods overridden: + std::unique_ptr clone() const override { + return support::cpp14::make_unique(*this); } TensorShape shape{}; /**< Tensor shape */ DataType data_type{ DataType::UNKNOWN }; /**< Data type */ + DataLayout layout{ DataLayout::NCHW }; /**< Data layout */ QuantizationInfo quant_info{}; /**< Quantization info */ Target target{ Target::UNSPECIFIED }; /**< Target */ }; -- cgit v1.2.1