From 283c1790da45ab562ecfb2aa7741297191886d85 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Fri, 10 Nov 2017 18:14:06 +0000 Subject: COMPMID-676: Rework TensorInfo building Change-Id: Ic98f64ffe30739437a1fe31ef98d83ee900741e3 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/95512 Reviewed-by: Michalis Spyrou Tested-by: Kaizen Reviewed-by: Anthony Barbier --- arm_compute/core/Dimensions.h | 11 ++++++++ arm_compute/core/Helpers.h | 9 ++++++ arm_compute/core/Helpers.inl | 15 ++++++++++ arm_compute/core/ITensorInfo.h | 36 ++++++++++++++++-------- arm_compute/core/SubTensorInfo.h | 29 +++++++++++-------- arm_compute/core/TensorInfo.h | 17 ++++++----- arm_compute/core/utils/misc/ICloneable.h | 48 ++++++++++++++++++++++++++++++++ 7 files changed, 133 insertions(+), 32 deletions(-) create mode 100644 arm_compute/core/utils/misc/ICloneable.h (limited to 'arm_compute/core') diff --git a/arm_compute/core/Dimensions.h b/arm_compute/core/Dimensions.h index 70b6e1a301..3d9a3fa7ff 100644 --- a/arm_compute/core/Dimensions.h +++ b/arm_compute/core/Dimensions.h @@ -179,5 +179,16 @@ protected: std::array _id; size_t _num_dimensions{ 0 }; }; + +template +inline bool operator==(const Dimensions &lhs, const Dimensions &rhs) +{ + return ((lhs.num_dimensions() == rhs.num_dimensions()) && std::equal(lhs.cbegin(), lhs.cend(), rhs.cbegin())); +} +template +inline bool operator!=(const Dimensions &lhs, const Dimensions &rhs) +{ + return !(lhs == rhs); +} } #endif /*__ARM_COMPUTE_DIMENSIONS_H__*/ diff --git a/arm_compute/core/Helpers.h b/arm_compute/core/Helpers.h index edb05e99a1..13d1f6c99f 100644 --- a/arm_compute/core/Helpers.h +++ b/arm_compute/core/Helpers.h @@ -476,6 +476,15 @@ bool auto_init_if_empty(ITensorInfo &info, int fixed_point_position, QuantizationInfo quantization_info = QuantizationInfo()); +/** Auto initialize the tensor info using another tensor info. + * + * @param info_sink Tensor info used to check and assign + * @param info_source Tensor info used to assign + * + * @return True if the tensor info has been initialized + */ +bool auto_init_if_empty(ITensorInfo &info_sink, ITensorInfo &info_source); + /* Set the shape to the specified value if the current assignment is empty. * * @param[in,out] info Tensor info used to check and assign. diff --git a/arm_compute/core/Helpers.inl b/arm_compute/core/Helpers.inl index 656956d00a..1e565344b7 100644 --- a/arm_compute/core/Helpers.inl +++ b/arm_compute/core/Helpers.inl @@ -217,6 +217,21 @@ inline bool auto_init_if_empty(ITensorInfo &info, return false; } +inline bool auto_init_if_empty(ITensorInfo &info_sink, ITensorInfo &info_source) +{ + if(info_sink.tensor_shape().total_size() == 0) + { + info_sink.set_data_type(info_source.data_type()); + info_sink.set_num_channels(info_source.num_channels()); + info_sink.set_tensor_shape(info_source.tensor_shape()); + info_sink.set_fixed_point_position(info_source.fixed_point_position()); + info_sink.set_quantization_info(info_source.quantization_info()); + return true; + } + + return false; +} + inline bool set_shape_if_empty(ITensorInfo &info, const TensorShape &shape) { if(info.tensor_shape().total_size() == 0) diff --git a/arm_compute/core/ITensorInfo.h b/arm_compute/core/ITensorInfo.h index 09351522dd..5e8d4e8136 100644 --- a/arm_compute/core/ITensorInfo.h +++ b/arm_compute/core/ITensorInfo.h @@ -29,13 +29,14 @@ #include "arm_compute/core/TensorShape.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/Utils.h" +#include "arm_compute/core/utils/misc/ICloneable.h" #include namespace arm_compute { /** Store the tensor's metadata */ -class ITensorInfo +class ITensorInfo : public misc::ICloneable { public: /** Default virtual destructor */ @@ -45,15 +46,19 @@ public: * @warning This resets the format to UNKNOWN. * * @param[in] data_type The new data type. + * + * @return Reference to this ITensorInfo object */ - virtual void set_data_type(DataType data_type) = 0; + virtual ITensorInfo &set_data_type(DataType data_type) = 0; /** Set the number of channels to the specified value. * * @warning This resets the format to UNKNOWN. * * @param[in] num_channels New number of channels. + * + * @return Reference to this ITensorInfo object */ - virtual void set_num_channels(int num_channels) = 0; + virtual ITensorInfo &set_num_channels(int num_channels) = 0; /** Set the format of an already initialized tensor. * * @note If the data type has already been configured (i.e. not UNKNOWN) it @@ -61,23 +66,36 @@ public: * be based on the format. * * @param[in] format Single-plane format of the tensor. + * + * @return Reference to this ITensorInfo object */ - virtual void set_format(Format format) = 0; + virtual ITensorInfo &set_format(Format format) = 0; /** Set the shape of an already initialized tensor. * * @warning Changing the shape requires to recompute the strides and is * therefore only possible if the tensor hasn't been allocated yet. * * @param[in] shape New tensor shape. + * + * @return Reference to this ITensorInfo object */ - virtual void set_tensor_shape(TensorShape shape) = 0; + virtual ITensorInfo &set_tensor_shape(TensorShape shape) = 0; /** Set the fixed point position to the specified value * * @warning The fixed point position must be set once the data type has been configured * * @param[in] fixed_point_position The new fixed point position + * + * @return Reference to this ITensorInfo object */ - virtual void set_fixed_point_position(int fixed_point_position) = 0; + virtual ITensorInfo &set_fixed_point_position(int fixed_point_position) = 0; + /** Set the quantization settings (scale and offset) of the tensor. + * + * @param[in] quantization_info QuantizationInfo containing the scale and offset + * + * @return Reference to this ITensorInfo object + */ + virtual ITensorInfo &set_quantization_info(QuantizationInfo quantization_info) = 0; /** Update the offset to the first element and the strides to automatically computed values. * * @note The padding used by this method is really conservative so that the tensor can be used for most functions. @@ -196,12 +214,6 @@ public: * @return A QuantizationInfo containing the scale and offset. */ virtual QuantizationInfo quantization_info() const = 0; - - /** Set the quantization settings (scale and offset) of the tensor. - * - * @param[in] quantization_info QuantizationInfo containing the scale and offset. - */ - virtual void set_quantization_info(QuantizationInfo quantization_info) = 0; }; } #endif /*__ARM_COMPUTE_TENSORINFO_H__ */ diff --git a/arm_compute/core/SubTensorInfo.h b/arm_compute/core/SubTensorInfo.h index 3a88ebae5a..5fec11a2e8 100644 --- a/arm_compute/core/SubTensorInfo.h +++ b/arm_compute/core/SubTensorInfo.h @@ -34,6 +34,7 @@ #include "arm_compute/core/Validate.h" #include +#include namespace arm_compute { @@ -50,7 +51,7 @@ public: * X and Y dimensions must match the parent's ones. * @param[in] coords Coordinates of starting element inside parent tensor. */ - SubTensorInfo(ITensorInfo *parent, const TensorShape &tensor_shape, const Coordinates &coords); + SubTensorInfo(ITensorInfo *parent, TensorShape tensor_shape, Coordinates coords); /** Default destructor */ ~SubTensorInfo() = default; /** Allow instances of this class to be copy constructed */ @@ -71,27 +72,38 @@ public: } // Inherited methods overridden: - void set_data_type(DataType data_type) override + std::unique_ptr clone() const override; + ITensorInfo &set_data_type(DataType data_type) override { ARM_COMPUTE_ERROR_ON(_parent == nullptr); _parent->set_data_type(data_type); + return *this; }; - void set_num_channels(int num_channels) override + ITensorInfo &set_num_channels(int num_channels) override { ARM_COMPUTE_ERROR_ON(_parent == nullptr); _parent->set_num_channels(num_channels); + return *this; }; - void set_format(Format format) override + ITensorInfo &set_format(Format format) override { ARM_COMPUTE_ERROR_ON(_parent == nullptr); _parent->set_format(format); + return *this; }; - void set_fixed_point_position(int fixed_point_position) override + ITensorInfo &set_fixed_point_position(int fixed_point_position) override { ARM_COMPUTE_ERROR_ON(_parent == nullptr); _parent->set_fixed_point_position(fixed_point_position); + return *this; }; - void set_tensor_shape(TensorShape shape) override; + ITensorInfo &set_tensor_shape(TensorShape shape) override; + ITensorInfo &set_quantization_info(QuantizationInfo quantization_info) override + { + ARM_COMPUTE_ERROR_ON(_parent == nullptr); + _parent->set_quantization_info(quantization_info); + return *this; + } bool auto_padding() override { ARM_COMPUTE_ERROR_ON(_parent == nullptr); @@ -191,11 +203,6 @@ public: ARM_COMPUTE_ERROR_ON(_parent == nullptr); return _parent->quantization_info(); } - void set_quantization_info(QuantizationInfo quantization_info) override - { - ARM_COMPUTE_ERROR_ON(_parent == nullptr); - _parent->set_quantization_info(quantization_info); - } private: ITensorInfo *_parent; diff --git a/arm_compute/core/TensorInfo.h b/arm_compute/core/TensorInfo.h index 5d1ee7c578..2383f2db21 100644 --- a/arm_compute/core/TensorInfo.h +++ b/arm_compute/core/TensorInfo.h @@ -34,6 +34,7 @@ #include "arm_compute/core/Utils.h" #include +#include namespace arm_compute { @@ -212,11 +213,13 @@ public: size_t init_auto_padding(const HOGInfo &hog_info, unsigned int width, unsigned int height); // Inherited methods overridden: - void set_data_type(DataType data_type) override; - void set_num_channels(int num_channels) override; - void set_format(Format format) override; - void set_tensor_shape(TensorShape shape) override; - void set_fixed_point_position(int fixed_point_position) override; + std::unique_ptr clone() const override; + ITensorInfo &set_data_type(DataType data_type) override; + ITensorInfo &set_num_channels(int num_channels) override; + ITensorInfo &set_format(Format format) override; + ITensorInfo &set_tensor_shape(TensorShape shape) override; + ITensorInfo &set_fixed_point_position(int fixed_point_position) override; + ITensorInfo &set_quantization_info(QuantizationInfo quantization_info) override; bool auto_padding() override; bool extend_padding(const PaddingSize &padding) override; size_t dimension(size_t index) const override @@ -292,10 +295,6 @@ public: { return _quantization_info; } - void set_quantization_info(QuantizationInfo quantization_info) override - { - _quantization_info = quantization_info; - } private: /** Calculates strides, offset and total size resulting from the specified padding around the XY plane. diff --git a/arm_compute/core/utils/misc/ICloneable.h b/arm_compute/core/utils/misc/ICloneable.h new file mode 100644 index 0000000000..5852f14f7a --- /dev/null +++ b/arm_compute/core/utils/misc/ICloneable.h @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef __ARM_COMPUTE_MISC_ICLONEABLE_H__ +#define __ARM_COMPUTE_MISC_ICLONEABLE_H__ + +#include + +namespace arm_compute +{ +namespace misc +{ +/** Clonable Interface */ +template +class ICloneable +{ +public: + /** Default virtual desctructor */ + virtual ~ICloneable() = default; + /** Provide a clone of the current object of class T + * + * @return Clone object of class T + */ + virtual std::unique_ptr clone() const = 0; +}; +} // namespace misc +} // namespace arm_compute +#endif /* __ARM_COMPUTE_MISC_ICLONEABLE_H__ */ -- cgit v1.2.1