From 35ceeb2199c569810a1524a0a21c2df2a3f5f29e Mon Sep 17 00:00:00 2001 From: Diego Lopez Recas Date: Mon, 4 Dec 2017 18:56:10 +0000 Subject: IVGCVSW-798 Add Softmax NEON support for QASYMM8 Change-Id: I4f2cca52caf210fdb7d6bb7e9436ac51cb5088b4 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/112398 Reviewed-by: Anthony Barbier Tested-by: Jenkins --- arm_compute/core/AccessWindowAutoPadding.h | 4 +- arm_compute/core/AccessWindowStatic.h | 4 +- arm_compute/core/AccessWindowTranspose.h | 4 +- arm_compute/core/Error.h | 6 +- arm_compute/core/FixedPoint.inl | 46 +++----- arm_compute/core/Helpers.h | 4 +- arm_compute/core/IAccessWindow.h | 6 +- arm_compute/core/ITensorInfo.h | 8 +- .../core/NEON/kernels/NESoftmaxLayerKernel.h | 116 +++++++-------------- arm_compute/core/SubTensorInfo.h | 8 +- arm_compute/core/TensorInfo.h | 10 +- arm_compute/core/TensorShape.h | 26 +++-- arm_compute/core/Types.h | 15 +++ arm_compute/core/Utils.h | 16 +-- arm_compute/core/utils/misc/utility.h | 19 +++- 15 files changed, 133 insertions(+), 159 deletions(-) (limited to 'arm_compute/core') diff --git a/arm_compute/core/AccessWindowAutoPadding.h b/arm_compute/core/AccessWindowAutoPadding.h index 0a3344b115..0003bb26cd 100644 --- a/arm_compute/core/AccessWindowAutoPadding.h +++ b/arm_compute/core/AccessWindowAutoPadding.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -66,7 +66,7 @@ public: // Inherited methods overridden: bool update_window_if_needed(Window &window) const override; - bool update_padding_if_needed(const Window &window) const override; + bool update_padding_if_needed(const Window &window) override; ValidRegion compute_valid_region(const Window &window, ValidRegion input_valid_region, bool border_undefined, BorderSize border_size) const override; private: diff --git a/arm_compute/core/AccessWindowStatic.h b/arm_compute/core/AccessWindowStatic.h index 6dcba072c4..a0ceeda273 100644 --- a/arm_compute/core/AccessWindowStatic.h +++ b/arm_compute/core/AccessWindowStatic.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -79,7 +79,7 @@ public: // Inherited methods overriden: bool update_window_if_needed(Window &window) const override; - bool update_padding_if_needed(const Window &window) const override; + bool update_padding_if_needed(const Window &window) override; ValidRegion compute_valid_region(const Window &window, ValidRegion input_valid_region, bool border_undefined, BorderSize border_size) const override; ITensorInfo *_info; diff --git a/arm_compute/core/AccessWindowTranspose.h b/arm_compute/core/AccessWindowTranspose.h index 102860f9d8..4e59e58dce 100644 --- a/arm_compute/core/AccessWindowTranspose.h +++ b/arm_compute/core/AccessWindowTranspose.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017-2018 ARM Limited. 
 *
 * SPDX-License-Identifier: MIT
 *
@@ -40,7 +40,7 @@ class AccessWindowTranspose : public AccessWindowRectangle
 public:
     using AccessWindowRectangle::AccessWindowRectangle;
     bool update_window_if_needed(Window &window) const override;
-    bool update_padding_if_needed(const Window &window) const override;
+    bool update_padding_if_needed(const Window &window) override;
     using AccessWindowRectangle::compute_valid_region;
     ValidRegion compute_valid_region(const Window &window, ValidRegion input_valid_region, bool border_undefined, BorderSize border_size) const override;
 };
diff --git a/arm_compute/core/Error.h b/arm_compute/core/Error.h
index 97dbba3fab..56c7ccdd93 100644
--- a/arm_compute/core/Error.h
+++ b/arm_compute/core/Error.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
@@ -86,7 +86,7 @@ public:
         return _error_description;
     }
     /** Throws a runtime exception in case it contains a valid error status */
-    void throw_if_error()
+    void throw_if_error() const
     {
         if(!bool(*this))
         {
@@ -96,7 +96,7 @@ public:
 private:
     /** Internal throwing function */
-    [[noreturn]] void internal_throw_on_error();
+    [[noreturn]] void internal_throw_on_error() const;
 private:
     ErrorCode _code;
diff --git a/arm_compute/core/FixedPoint.inl b/arm_compute/core/FixedPoint.inl
index 5ea0f6c825..9c7e35ab16 100644
--- a/arm_compute/core/FixedPoint.inl
+++ b/arm_compute/core/FixedPoint.inl
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
@@ -22,27 +22,11 @@
 * SOFTWARE.
 */
 #include "arm_compute/core/Error.h"
+#include "arm_compute/core/utils/misc/utility.h"
 #include
 #include
-namespace
-{
-template <typename TpIn, typename TpSat>
-inline TpSat saturate_convert(TpIn a)
-{
-    if(a > std::numeric_limits<TpSat>::max())
-    {
-        a = std::numeric_limits<TpSat>::max();
-    }
-    if(a < std::numeric_limits<TpSat>::min())
-    {
-        a = std::numeric_limits<TpSat>::min();
-    }
-    return static_cast<TpSat>(a);
-}
-} // namespace
-
 namespace arm_compute
 {
 inline qint8_t sqshl_qs8(qint8_t a, int shift)
@@ -50,7 +34,7 @@ inline qint8_t sqshl_qs8(qint8_t a, int shift)
     qint16_t tmp = static_cast<qint16_t>(a) << shift;
     // Saturate the result in case of overflow and cast to qint8_t
-    return saturate_convert<qint16_t, qint8_t>(tmp);
+    return utility::saturate_cast<qint8_t>(tmp);
 }
 inline qint16_t sqshl_qs16(qint16_t a, int shift)
@@ -58,7 +42,7 @@ inline qint16_t sqshl_qs16(qint16_t a, int shift)
     qint32_t tmp = static_cast<qint32_t>(a) << shift;
     // Saturate the result in case of overflow and cast to qint16_t
-    return saturate_convert<qint32_t, qint16_t>(tmp);
+    return utility::saturate_cast<qint16_t>(tmp);
 }
 inline qint8_t sshr_qs8(qint8_t a, int shift)
@@ -101,7 +85,7 @@ inline qint8_t sqadd_qs8(qint8_t a, qint8_t b)
     qint16_t tmp = (static_cast<qint16_t>(a) + static_cast<qint16_t>(b));
     // Saturate the result in case of overflow and cast to qint8_t
-    return saturate_convert<qint16_t, qint8_t>(tmp);
+    return utility::saturate_cast<qint8_t>(tmp);
 }
 inline qint16_t sqadd_qs16(qint16_t a, qint16_t b)
@@ -110,7 +94,7 @@ inline qint16_t sqadd_qs16(qint16_t a, qint16_t b)
     qint32_t tmp = (static_cast<qint32_t>(a) + static_cast<qint32_t>(b));
     // Saturate the result in case of overflow and cast to qint16_t
-    return saturate_convert<qint32_t, qint16_t>(tmp);
+    return utility::saturate_cast<qint16_t>(tmp);
 }
 inline qint32_t sqadd_qs32(qint32_t a, qint32_t b)
@@ -119,7 +103,7 @@ inline qint32_t sqadd_qs32(qint32_t a, qint32_t b)
     qint64_t tmp = (static_cast<qint64_t>(a) + static_cast<qint64_t>(b));
     // Saturate the result in case of overflow and cast to qint32_t
-    return saturate_convert<qint64_t, qint32_t>(tmp);
+    return utility::saturate_cast<qint32_t>(tmp);
 }
 inline qint8_t ssub_qs8(qint8_t a, qint8_t b)
@@ -138,7 +122,7 @@ inline qint8_t sqsub_qs8(qint8_t a, qint8_t b)
     qint16_t tmp = static_cast<qint16_t>(a) - static_cast<qint16_t>(b);
     // Saturate the result in case of overflow and cast to qint8_t
-    return saturate_convert<qint16_t, qint8_t>(tmp);
+    return utility::saturate_cast<qint8_t>(tmp);
 }
 inline qint16_t sqsub_qs16(qint16_t a, qint16_t b)
@@ -147,7 +131,7 @@ inline qint16_t sqsub_qs16(qint16_t a, qint16_t b)
     qint32_t tmp = static_cast<qint32_t>(a) - static_cast<qint32_t>(b);
     // Saturate the result in case of overflow and cast to qint16_t
-    return saturate_convert<qint32_t, qint16_t>(tmp);
+    return utility::saturate_cast<qint16_t>(tmp);
 }
 inline qint8_t smul_qs8(qint8_t a, qint8_t b, int fixed_point_position)
@@ -183,7 +167,7 @@ inline qint8_t sqmul_qs8(qint8_t a, qint8_t b, int fixed_point_position)
     // Rounding up
     tmp += round_up_const;
-    return saturate_convert<qint16_t, qint8_t>(tmp >> fixed_point_position);
+    return utility::saturate_cast<qint8_t>(tmp >> fixed_point_position);
 }
 inline qint16_t sqmul_qs16(qint16_t a, qint16_t b, int fixed_point_position)
@@ -195,7 +179,7 @@ inline qint16_t sqmul_qs16(qint16_t a, qint16_t b, int fixed_point_position)
     // Rounding up
     tmp += round_up_const;
-    return saturate_convert<qint32_t, qint16_t>(tmp >> fixed_point_position);
+    return utility::saturate_cast<qint16_t>(tmp >> fixed_point_position);
 }
 inline qint16_t sqmull_qs8(qint8_t a, qint8_t b, int fixed_point_position)
@@ -394,7 +378,7 @@ inline float scvt_f32_qs8(qint8_t a, int fixed_point_position)
 inline qint8_t sqcvt_qs8_f32(float a, int fixed_point_position)
 {
     // round_nearest_integer(a * 2^(fixed_point_position))
-    return saturate_convert<float, qint8_t>(a * (1 << fixed_point_position) + ((a >= 0) ? 0.5 : -0.5));
+    return utility::saturate_cast<qint8_t>(a * (1 << fixed_point_position) + ((a >= 0) ? 0.5 : -0.5));
 }
 inline float scvt_f32_qs16(qint16_t a, int fixed_point_position)
@@ -405,18 +389,18 @@ inline qint16_t sqcvt_qs16_f32(float a, int fixed_point_position)
 {
     // round_nearest_integer(a * 2^(fixed_point_position))
-    return saturate_convert<float, qint16_t>(a * (1 << fixed_point_position) + ((a >= 0) ? 0.5 : -0.5));
+    return utility::saturate_cast<qint16_t>(a * (1 << fixed_point_position) + ((a >= 0) ? 0.5 : -0.5));
 }
 inline qint8_t sqmovn_qs16(qint16_t a)
 {
     // Saturate the result in case of overflow and cast to qint8_t
-    return saturate_convert<qint16_t, qint8_t>(a);
+    return utility::saturate_cast<qint8_t>(a);
 }
 inline qint16_t sqmovn_qs32(qint32_t a)
 {
     // Saturate the result in case of overflow and cast to qint16_t
-    return saturate_convert<qint32_t, qint16_t>(a);
+    return utility::saturate_cast<qint16_t>(a);
 }
 }
diff --git a/arm_compute/core/Helpers.h b/arm_compute/core/Helpers.h
index e01e4baa6b..c6a7db4f96 100644
--- a/arm_compute/core/Helpers.h
+++ b/arm_compute/core/Helpers.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2016, 2018 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
@@ -350,7 +350,7 @@ bool update_window_and_padding(Window &win, Ts &&... patterns)
     bool padding_changed = false;
-    utility::for_each([&](const IAccessWindow & w)
+    utility::for_each([&](IAccessWindow & w)
     {
         padding_changed |= w.update_padding_if_needed(win);
     },
diff --git a/arm_compute/core/IAccessWindow.h b/arm_compute/core/IAccessWindow.h
index cf7490d53e..583041a48b 100644
--- a/arm_compute/core/IAccessWindow.h
+++ b/arm_compute/core/IAccessWindow.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
@@ -85,7 +85,7 @@ public:
     *
     * @return True if the padding has been changed.
*/ - virtual bool update_padding_if_needed(const Window &window) const = 0; + virtual bool update_padding_if_needed(const Window &window) = 0; /** Compute the valid region based on access pattern and valid region of the inputs. * * @note This method assumes that there is no border. @@ -168,7 +168,7 @@ public: ValidRegion compute_valid_region(const Window &window, ValidRegion input_valid_region, bool border_undefined, BorderSize border_size) const override; bool update_window_if_needed(Window &window) const override; - bool update_padding_if_needed(const Window &window) const override; + bool update_padding_if_needed(const Window &window) override; protected: ITensorInfo *_info; diff --git a/arm_compute/core/ITensorInfo.h b/arm_compute/core/ITensorInfo.h index 9a67712f3d..9112f3ea18 100644 --- a/arm_compute/core/ITensorInfo.h +++ b/arm_compute/core/ITensorInfo.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016, 2017 ARM Limited. + * Copyright (c) 2016-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -79,7 +79,7 @@ public: * * @return Reference to this ITensorInfo object */ - virtual ITensorInfo &set_tensor_shape(TensorShape shape) = 0; + virtual ITensorInfo &set_tensor_shape(const TensorShape &shape) = 0; /** Set the fixed point position to the specified value * * @warning The fixed point position must be set once the data type has been configured @@ -95,7 +95,7 @@ public: * * @return Reference to this ITensorInfo object */ - virtual ITensorInfo &set_quantization_info(QuantizationInfo quantization_info) = 0; + virtual ITensorInfo &set_quantization_info(const QuantizationInfo &quantization_info) = 0; /** Resets the padding settings of the tensor. * * @return Reference to this ITensorInfo object @@ -214,7 +214,7 @@ public: * * @param[in] valid_region Valid region to set. */ - virtual void set_valid_region(ValidRegion valid_region) = 0; + virtual void set_valid_region(const ValidRegion &valid_region) = 0; /** Get the quantization settings (scale and offset) of the tensor. * diff --git a/arm_compute/core/NEON/kernels/NESoftmaxLayerKernel.h b/arm_compute/core/NEON/kernels/NESoftmaxLayerKernel.h index bd0e642d76..c30a4cd23d 100644 --- a/arm_compute/core/NEON/kernels/NESoftmaxLayerKernel.h +++ b/arm_compute/core/NEON/kernels/NESoftmaxLayerKernel.h @@ -43,13 +43,13 @@ public: NELogits1DMaxKernel(); /** Set the input and output tensors. * - * @param[in] input Source tensor. Data types supported: QS8/QS16/F16/F32. + * @param[in] input Source tensor. Data types supported: QASYMM8/QS8/QS16/F16/F32. * @param[out] output Destination tensor. Data types supported: same as @p input */ void configure(const ITensor *input, ITensor *output); /** Static function to check if given info will lead to a valid configuration of @ref NELogits1DMaxKernel * - * @param[in] input Source tensor. Data types supported: QS8/QS16/F16/F32 + * @param[in] input Source tensor. Data types supported: QASYMM8/QS8/QS16/F16/F32. * @param[in] output Destination tensor. 
Data types supported: same as @p input * * @return a status @@ -61,117 +61,71 @@ public: BorderSize border_size() const override; private: - using Logits1DMaxFunction = void(const ITensor *in, ITensor *out, const Window &window); + using Logits1DMaxFunction = void(const ITensor &in, ITensor &out, const Window &window); private: Logits1DMaxFunction *_func; BorderSize _border_size; }; -/** Interface for shifting the logits values around the max value and exponentiating the result */ -class NELogits1DShiftExpSumKernel : public INEKernel +/** Interface for softmax computation for QASYMM8 with pre-computed max. */ +class NELogits1DSoftmaxKernel : public INEKernel { public: const char *name() const override { - return "NELogits1DShiftExpSumKernel"; + return "NELogits1DSoftmaxKernel"; } /** Default constructor */ - NELogits1DShiftExpSumKernel(); + NELogits1DSoftmaxKernel(); /** Prevent instances of this class from being copied (As this class contains pointers) */ - NELogits1DShiftExpSumKernel(const NELogits1DShiftExpSumKernel &) = delete; + NELogits1DSoftmaxKernel(const NELogits1DSoftmaxKernel &) = delete; /** Prevent instances of this class from being copied (As this class contains pointers) */ - NELogits1DShiftExpSumKernel &operator=(const NELogits1DShiftExpSumKernel &) = delete; + NELogits1DSoftmaxKernel &operator=(const NELogits1DSoftmaxKernel &) = delete; /** Allow instances of this class to be moved */ - NELogits1DShiftExpSumKernel(NELogits1DShiftExpSumKernel &&) = default; + NELogits1DSoftmaxKernel(NELogits1DSoftmaxKernel &&) = default; /** Allow instances of this class to be moved */ - NELogits1DShiftExpSumKernel &operator=(NELogits1DShiftExpSumKernel &&) = default; + NELogits1DSoftmaxKernel &operator=(NELogits1DSoftmaxKernel &&) = default; /** Default destructor */ - ~NELogits1DShiftExpSumKernel() = default; + ~NELogits1DSoftmaxKernel() = default; /** Set the input and output tensors. * - * @param[in] input Source tensor. Data types supported: QS8/QS16/F16/F32. - * @param[in] max Max values tensor. Data types supported: same as @p input. + * @param[in] input Source tensor. Data types supported: QASYMM8/QS8/QS16/F16/F32. + * @param[in] max Max values tensor. Same shape as input with dimension 0 set to 1. + * Data types supported: same as @p input. * @param[out] output Destination tensor. Data types supported: same as @p input. - * @param[out] sum Sum of 1D logits tensor. Data types supported: same as @p input. - * @param[in] beta (Optional) A scaling factor for the exponent. QS8/QS16 only support a beta value of 1. - */ - void configure(const ITensor *input, const ITensor *max, ITensor *output, ITensor *sum, float beta = 1.0f); - /** Static function to check if given info will lead to a valid configuration of @ref NELogits1DShiftExpSumKernel - * - * @param[in] input Source tensor. Data types supported: QS8/QS16/F16/F32 - * @param[in] max Max values tensor. Data types supported: same as @p input - * @param[in] output Destination tensor. Data types supported: same as @p input. - * @param[in] sum Sum of 1D logits tensor. Data types supported: same as @p input. - * @param[in] beta (Optional) A scaling factor for the exponent. QS8/QS16 only support a beta value of 1. 
- * - * @return a status - */ - static Status validate(const ITensorInfo *input, const ITensorInfo *max, const ITensorInfo *output, const ITensorInfo *sum, float beta = 1.0f); - - // Inherited methods overridden: - void run(const Window &window, const ThreadInfo &info) override; - -private: - using Logits1DShiftExpSumFunction = void(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window, float beta); - -private: - Logits1DShiftExpSumFunction *_func; - const ITensor *_input; - const ITensor *_max; - ITensor *_output; - ITensor *_sum; - float _beta; -}; - -/** Interface for calculating the final step of the Softmax Layer where each logit value is multiplied by the inverse of the sum of the logits. */ -class NELogits1DNormKernel : public INEKernel -{ -public: - const char *name() const override - { - return "NELogits1DNormKernel"; - } - /** Default constructor */ - NELogits1DNormKernel(); - /** Prevent instances of this class from being copied (As this class contains pointers) */ - NELogits1DNormKernel(const NELogits1DNormKernel &) = delete; - /** Prevent instances of this class from being copied (As this class contains pointers) */ - NELogits1DNormKernel &operator=(const NELogits1DNormKernel &) = delete; - /** Allow instances of this class to be moved */ - NELogits1DNormKernel(NELogits1DNormKernel &&) = default; - /** Allow instances of this class to be moved */ - NELogits1DNormKernel &operator=(NELogits1DNormKernel &&) = default; - /** Default destructor */ - ~NELogits1DNormKernel() = default; - /** Set the input and output tensors. + * @param[in] beta A scaling factor for the exponent. * - * @param[in] input Source tensor. Data types supported: QS8/QS16/F16/F32. - * @param[in] sum Sum tensor. The number of dimensions should be dim(input)-1. Data types supported: same as @p input. - * @param[out] output Destination tensor. Data types supported: same as @p input. + * @param tmp Auxiliary tensor. Must be type F32 and same shape as the input. */ - void configure(const ITensor *input, const ITensor *sum, ITensor *output); - /** Static function to check if given info will lead to a valid configuration of @ref NELogits1DNormKernel + void configure(const ITensor *input, const ITensor *max, ITensor *output, const float beta, ITensor *tmp); + /** Static function to check if given info will lead to a valid configuration of @ref NELogits1DSoftmaxKernel * - * @param[in] input Source tensor. Data types supported: QS8/QS16/S32/F16/F32 - * @param[in] sum Sum tensor. The number of dimensions should be dim(input)-1. Data types supported: same as @p input. - * @param[in] output Destination tensor. Data types supported: same as @p input. + * @param[in] input Source tensor info. Data types supported: QASYMM8/QS8/QS16/F16/F32. + * @param[in] max Max values tensor info. Same shape as input with dimension 0 set to 1. + * Data types supported: same as @p input. + * @param[in] output Destination tensor info. Data types supported: same as @p input. + * @param[in] beta A scaling factor for the exponent. + * @param[in] tmp Tensor info of auxiliary. Must be type F32 and same shape as the input. 
* * @return a status */ - static Status validate(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output); + static Status validate(const ITensorInfo *input, const ITensorInfo *max, + const ITensorInfo *output, const float beta, const ITensorInfo *tmp); // Inherited methods overridden: void run(const Window &window, const ThreadInfo &info) override; private: - using Logits1DNormFunction = void(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window); + using LogitsSoftmaxFunction = void(const ITensor &in, const ITensor &max, void *const tmp, ITensor &out, const float beta, + const Window &window); -private: - Logits1DNormFunction *_func; - const ITensor *_input; - const ITensor *_sum; - ITensor *_output; + LogitsSoftmaxFunction *_func; + const ITensor *_input; + const ITensor *_max; + ITensor *_output; + float _beta; + ITensor *_tmp; //Temporary. Used internally }; } // namespace arm_compute #endif /*__ARM_COMPUTE_NESOFTMAXLAYERKERNEL_H__ */ diff --git a/arm_compute/core/SubTensorInfo.h b/arm_compute/core/SubTensorInfo.h index 67574f1326..7f4239d49b 100644 --- a/arm_compute/core/SubTensorInfo.h +++ b/arm_compute/core/SubTensorInfo.h @@ -98,8 +98,8 @@ public: _parent->set_fixed_point_position(fixed_point_position); return *this; }; - ITensorInfo &set_tensor_shape(TensorShape shape) override; - ITensorInfo &set_quantization_info(QuantizationInfo quantization_info) override + ITensorInfo &set_tensor_shape(const TensorShape &shape) override; + ITensorInfo &set_quantization_info(const QuantizationInfo &quantization_info) override { ARM_COMPUTE_ERROR_ON(_parent == nullptr); _parent->set_quantization_info(quantization_info); @@ -196,7 +196,7 @@ public: { return _valid_region; } - void set_valid_region(ValidRegion valid_region) override + void set_valid_region(const ValidRegion &valid_region) override { ARM_COMPUTE_ERROR_ON(_parent == nullptr); // Check if subtensor is valid if parent is configured @@ -204,7 +204,7 @@ public: { ARM_COMPUTE_ERROR_ON_INVALID_SUBTENSOR_VALID_REGION(_parent->valid_region(), valid_region); } - _valid_region = std::move(valid_region); + _valid_region = valid_region; } QuantizationInfo quantization_info() const override { diff --git a/arm_compute/core/TensorInfo.h b/arm_compute/core/TensorInfo.h index 80ef7f8d5a..0b8989f942 100644 --- a/arm_compute/core/TensorInfo.h +++ b/arm_compute/core/TensorInfo.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016, 2017 ARM Limited. + * Copyright (c) 2016-2018 ARM Limited. 
* * SPDX-License-Identifier: MIT * @@ -217,9 +217,9 @@ public: ITensorInfo &set_data_type(DataType data_type) override; ITensorInfo &set_num_channels(int num_channels) override; ITensorInfo &set_format(Format format) override; - ITensorInfo &set_tensor_shape(TensorShape shape) override; + ITensorInfo &set_tensor_shape(const TensorShape &shape) override; ITensorInfo &set_fixed_point_position(int fixed_point_position) override; - ITensorInfo &set_quantization_info(QuantizationInfo quantization_info) override; + ITensorInfo &set_quantization_info(const QuantizationInfo &quantization_info) override; ITensorInfo &reset_padding() override; bool auto_padding() override; bool extend_padding(const PaddingSize &padding) override; @@ -289,9 +289,9 @@ public: { return _valid_region; } - void set_valid_region(ValidRegion valid_region) override + void set_valid_region(const ValidRegion &valid_region) override { - _valid_region = std::move(valid_region); + _valid_region = valid_region; } QuantizationInfo quantization_info() const override { diff --git a/arm_compute/core/TensorShape.h b/arm_compute/core/TensorShape.h index ad102607e8..50f1211c18 100644 --- a/arm_compute/core/TensorShape.h +++ b/arm_compute/core/TensorShape.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016, 2017 ARM Limited. + * Copyright (c) 2016-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -70,26 +70,30 @@ public: * * @param[in] dimension Dimension for which the value is set. * @param[in] value Value to be set for the dimension. + * + * @return *this. */ - void set(size_t dimension, size_t value) + TensorShape &set(size_t dimension, size_t value) { // Clear entire shape if one dimension is zero if(value == 0) { _num_dimensions = 0; std::fill(_id.begin(), _id.end(), 0); - return; } + else + { + // Make sure all empty dimensions are filled with 1 + std::fill(_id.begin() + _num_dimensions, _id.end(), 1); - // Make sure all empty dimensions are filled with 1 - std::fill(_id.begin() + _num_dimensions, _id.end(), 1); - - // Set the specified dimension and increase the number of dimensions if - // necessary - Dimensions::set(dimension, value); + // Set the specified dimension and increase the number of dimensions if + // necessary + Dimensions::set(dimension, value); - // Correct number dimensions to ignore trailing dimensions of size 1 - apply_dimension_correction(); + // Correct number dimensions to ignore trailing dimensions of size 1 + apply_dimension_correction(); + } + return *this; } /** Accessor to remove the dimension n from the tensor shape. diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h index 5197000bf9..aa415acebe 100644 --- a/arm_compute/core/Types.h +++ b/arm_compute/core/Types.h @@ -190,6 +190,21 @@ struct ValidRegion return anchor[d] + shape[d]; } + /** Accessor to set the value of anchor and shape for one of the dimensions. + * + * @param[in] dimension Dimension for which the value is set. + * @param[in] start Value to be set in anchor for the dimension. + * @param[in] size Value to be set in shape for the dimension. + * + * @return *this. + */ + ValidRegion &set(size_t dimension, int start, size_t size) + { + anchor.set(dimension, start); + shape.set(dimension, size); + return *this; + } + Coordinates anchor; TensorShape shape; }; diff --git a/arm_compute/core/Utils.h b/arm_compute/core/Utils.h index 51967b1762..fc89d97073 100644 --- a/arm_compute/core/Utils.h +++ b/arm_compute/core/Utils.h @@ -40,12 +40,19 @@ namespace arm_compute { +/** Calculate the rounded up quotient of val / m. 
*/
+template <typename S, typename T>
+constexpr auto DIV_CEIL(S val, T m) -> decltype((val + m - 1) / m)
+{
+    return (val + m - 1) / m;
+}
+
 /** Computes the smallest number larger or equal to value that is a multiple of divisor. */
 template <typename S, typename T>
 inline auto ceil_to_multiple(S value, T divisor) -> decltype(((value + divisor - 1) / divisor) * divisor)
 {
     ARM_COMPUTE_ERROR_ON(value < 0 || divisor <= 0);
-    return ((value + divisor - 1) / divisor) * divisor;
+    return DIV_CEIL(value, divisor) * divisor;
 }
 /** Computes the largest number smaller or equal to value that is a multiple of divisor. */
@@ -56,13 +63,6 @@ inline auto floor_to_multiple(S value, T divisor) -> decltype((value / divisor)
     return (value / divisor) * divisor;
 }
-/** Calculate the rounded up quotient of val / m. */
-template <typename S, typename T>
-constexpr auto DIV_CEIL(S val, T m) -> decltype((val + m - 1) / m)
-{
-    return (val + m - 1) / m;
-}
-
 /** Returns the arm_compute library build information
 *
 * Contains the version number and the build options used to build the library
diff --git a/arm_compute/core/utils/misc/utility.h b/arm_compute/core/utils/misc/utility.h
index 45b3b5268e..e8d823b5bc 100644
--- a/arm_compute/core/utils/misc/utility.h
+++ b/arm_compute/core/utils/misc/utility.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
@@ -25,6 +25,7 @@
 #define __ARM_COMPUTE_MISC_UTILITY_H__
 #include
+#include <limits>
 namespace arm_compute
 {
@@ -123,6 +124,22 @@ inline auto foldl(F &&func, T &&initial, U &&value, Us &&... values) -> decltype
 {
     return foldl(std::forward<F>(func), func(std::forward<T>(initial), std::forward<U>(value)), std::forward<Us>(values)...);
 }
+
+/** Type cast with saturation.
+ *
+ * @param[in] val Value of type U to cast.
+ *
+ * @return Original value clamped to numeric limits of T and converted to type T.
+ *
+ * @warning Numeric limits of T must be representable without loss in type U.
+ */
+template <typename T, typename U>
+T saturate_cast(U val)
+{
+    const auto low = static_cast<U>(std::numeric_limits<T>::lowest());
+    const auto high = static_cast<U>(std::numeric_limits<T>::max());
+    return static_cast<T>(clamp(val, low, high));
+}
 } // namespace utility
 } // namespace arm_compute
 #endif /* __ARM_COMPUTE_MISC_UTILITY_H__ */
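
For reference, the DIV_CEIL and utility::saturate_cast helpers introduced by this change can be exercised on their own. The snippet below is not part of the commit: it re-declares minimal stand-ins for the pieces it depends on (the qint8_t/qint16_t typedefs and a local clamp) so it compiles standalone, and it only assumes the definitions visible in the hunks above.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <limits>

// Stand-ins for library pieces; the real definitions live in
// arm_compute/core/FixedPoint.h and arm_compute/core/utils/misc/utility.h.
using qint8_t  = std::int8_t;
using qint16_t = std::int16_t;

template <typename T>
T clamp(const T &value, const T &lower, const T &upper)
{
    return std::max(lower, std::min(value, upper));
}

// Mirrors the saturate_cast added by this patch: clamp to T's range while the
// value is still held in the wider type U, then convert to T.
template <typename T, typename U>
T saturate_cast(U val)
{
    const auto low  = static_cast<U>(std::numeric_limits<T>::lowest());
    const auto high = static_cast<U>(std::numeric_limits<T>::max());
    return static_cast<T>(clamp(val, low, high));
}

// Mirrors DIV_CEIL: integer quotient rounded towards positive infinity.
template <typename S, typename T>
constexpr auto DIV_CEIL(S val, T m) -> decltype((val + m - 1) / m)
{
    return (val + m - 1) / m;
}

int main()
{
    std::cout << int(saturate_cast<qint8_t>(qint16_t(300)))  << '\n'; // 127, clamped to qint8_t max
    std::cout << int(saturate_cast<qint8_t>(qint16_t(-300))) << '\n'; // -128, clamped to qint8_t min
    std::cout << DIV_CEIL(10, 4) << '\n';                             // 3
    return 0;
}

Clamping in the wider type before the narrowing conversion is what keeps the cast well-defined; converting first and clamping afterwards would already have overflowed, which is the behaviour the removed saturate_convert helper also avoided.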