author    Diego Lopez Recas <Diego.LopezRecas@arm.com>  2017-12-04 18:56:10 +0000
committer Anthony Barbier <anthony.barbier@arm.com>  2018-11-02 16:45:00 +0000
commit    35ceeb2199c569810a1524a0a21c2df2a3f5f29e (patch)
tree      4a55f8626cb2960843547fabdb2431a70ec1029a /arm_compute/core
parent    97cf2497d2b617de3209330893ad51bd0cc126ce (diff)
IVGCVSW-798 Add Softmax NEON support for QASYMM8
Change-Id: I4f2cca52caf210fdb7d6bb7e9436ac51cb5088b4
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/112398
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
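As context for the change: QASYMM8 softmax is normally exercised through the NEON runtime function rather than these kernels directly. A minimal usage sketch, assuming the runtime wrapper NESoftmaxLayer and made-up shapes and quantization parameters:

    #include "arm_compute/runtime/NEON/functions/NESoftmaxLayer.h"
    #include "arm_compute/runtime/Tensor.h"

    using namespace arm_compute;

    Tensor input, output;
    // 128 logits per row, 32 rows, 8-bit asymmetric quantization (values illustrative).
    input.allocator()->init(TensorInfo(TensorShape(128U, 32U), 1, DataType::QASYMM8, QuantizationInfo(1.f / 255.f, 0)));
    // Softmax output is conventionally quantized with scale 1/256 and offset 0.
    output.allocator()->init(TensorInfo(TensorShape(128U, 32U), 1, DataType::QASYMM8, QuantizationInfo(1.f / 256.f, 0)));

    NESoftmaxLayer softmax;
    softmax.configure(&input, &output, 1.0f); // beta = 1.0f
    input.allocator()->allocate();
    output.allocator()->allocate();
    softmax.run();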
Diffstat (limited to 'arm_compute/core')
-rw-r--r-- arm_compute/core/AccessWindowAutoPadding.h           |   4
-rw-r--r-- arm_compute/core/AccessWindowStatic.h                |   4
-rw-r--r-- arm_compute/core/AccessWindowTranspose.h             |   4
-rw-r--r-- arm_compute/core/Error.h                             |   6
-rw-r--r-- arm_compute/core/FixedPoint.inl                      |  46
-rw-r--r-- arm_compute/core/Helpers.h                           |   4
-rw-r--r-- arm_compute/core/IAccessWindow.h                     |   6
-rw-r--r-- arm_compute/core/ITensorInfo.h                       |   8
-rw-r--r-- arm_compute/core/NEON/kernels/NESoftmaxLayerKernel.h | 116
-rw-r--r-- arm_compute/core/SubTensorInfo.h                     |   8
-rw-r--r-- arm_compute/core/TensorInfo.h                        |  10
-rw-r--r-- arm_compute/core/TensorShape.h                       |  26
-rw-r--r-- arm_compute/core/Types.h                             |  15
-rw-r--r-- arm_compute/core/Utils.h                             |  16
-rw-r--r-- arm_compute/core/utils/misc/utility.h                |  19
15 files changed, 133 insertions(+), 159 deletions(-)
diff --git a/arm_compute/core/AccessWindowAutoPadding.h b/arm_compute/core/AccessWindowAutoPadding.h
index 0a3344b115..0003bb26cd 100644
--- a/arm_compute/core/AccessWindowAutoPadding.h
+++ b/arm_compute/core/AccessWindowAutoPadding.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -66,7 +66,7 @@ public:
// Inherited methods overridden:
bool update_window_if_needed(Window &window) const override;
- bool update_padding_if_needed(const Window &window) const override;
+ bool update_padding_if_needed(const Window &window) override;
ValidRegion compute_valid_region(const Window &window, ValidRegion input_valid_region, bool border_undefined, BorderSize border_size) const override;
private:
diff --git a/arm_compute/core/AccessWindowStatic.h b/arm_compute/core/AccessWindowStatic.h
index 6dcba072c4..a0ceeda273 100644
--- a/arm_compute/core/AccessWindowStatic.h
+++ b/arm_compute/core/AccessWindowStatic.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -79,7 +79,7 @@ public:
// Inherited methods overridden:
bool update_window_if_needed(Window &window) const override;
- bool update_padding_if_needed(const Window &window) const override;
+ bool update_padding_if_needed(const Window &window) override;
ValidRegion compute_valid_region(const Window &window, ValidRegion input_valid_region, bool border_undefined, BorderSize border_size) const override;
ITensorInfo *_info;
diff --git a/arm_compute/core/AccessWindowTranspose.h b/arm_compute/core/AccessWindowTranspose.h
index 102860f9d8..4e59e58dce 100644
--- a/arm_compute/core/AccessWindowTranspose.h
+++ b/arm_compute/core/AccessWindowTranspose.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -40,7 +40,7 @@ class AccessWindowTranspose : public AccessWindowRectangle
public:
using AccessWindowRectangle::AccessWindowRectangle;
bool update_window_if_needed(Window &window) const override;
- bool update_padding_if_needed(const Window &window) const override;
+ bool update_padding_if_needed(const Window &window) override;
using AccessWindowRectangle::compute_valid_region;
ValidRegion compute_valid_region(const Window &window, ValidRegion input_valid_region, bool border_undefined, BorderSize border_size) const override;
};
diff --git a/arm_compute/core/Error.h b/arm_compute/core/Error.h
index 97dbba3fab..56c7ccdd93 100644
--- a/arm_compute/core/Error.h
+++ b/arm_compute/core/Error.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -86,7 +86,7 @@ public:
return _error_description;
}
/** Throws a runtime exception in case it contains a valid error status */
- void throw_if_error()
+ void throw_if_error() const
{
if(!bool(*this))
{
@@ -96,7 +96,7 @@ public:
private:
/** Internal throwing function */
- [[noreturn]] void internal_throw_on_error();
+ [[noreturn]] void internal_throw_on_error() const;
private:
ErrorCode _code;
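With throw_if_error() and internal_throw_on_error() now const, error handling works on const (including returned temporary) Status objects. A small sketch, with hypothetical tensor infos:

    const Status s = NELogits1DMaxKernel::validate(&input_info, &output_info);
    if(!bool(s))
    {
        std::cerr << s.error_description() << std::endl;
    }
    s.throw_if_error(); // legal on a const Status after this change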
diff --git a/arm_compute/core/FixedPoint.inl b/arm_compute/core/FixedPoint.inl
index 5ea0f6c825..9c7e35ab16 100644
--- a/arm_compute/core/FixedPoint.inl
+++ b/arm_compute/core/FixedPoint.inl
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -22,27 +22,11 @@
* SOFTWARE.
*/
#include "arm_compute/core/Error.h"
+#include "arm_compute/core/utils/misc/utility.h"
#include <cmath>
#include <limits>
-namespace
-{
-template <typename TpIn, typename TpSat>
-inline TpSat saturate_convert(TpIn a)
-{
- if(a > std::numeric_limits<TpSat>::max())
- {
- a = std::numeric_limits<TpSat>::max();
- }
- if(a < std::numeric_limits<TpSat>::min())
- {
- a = std::numeric_limits<TpSat>::min();
- }
- return static_cast<TpSat>(a);
-}
-} // namespace
-
namespace arm_compute
{
inline qint8_t sqshl_qs8(qint8_t a, int shift)
@@ -50,7 +34,7 @@ inline qint8_t sqshl_qs8(qint8_t a, int shift)
qint16_t tmp = static_cast<qint16_t>(a) << shift;
// Saturate the result in case of overflow and cast to qint8_t
- return saturate_convert<qint16_t, qint8_t>(tmp);
+ return utility::saturate_cast<qint8_t>(tmp);
}
inline qint16_t sqshl_qs16(qint16_t a, int shift)
@@ -58,7 +42,7 @@ inline qint16_t sqshl_qs16(qint16_t a, int shift)
qint32_t tmp = static_cast<qint32_t>(a) << shift;
// Saturate the result in case of overflow and cast to qint16_t
- return saturate_convert<qint32_t, qint16_t>(tmp);
+ return utility::saturate_cast<qint16_t>(tmp);
}
inline qint8_t sshr_qs8(qint8_t a, int shift)
@@ -101,7 +85,7 @@ inline qint8_t sqadd_qs8(qint8_t a, qint8_t b)
qint16_t tmp = (static_cast<qint16_t>(a) + static_cast<qint16_t>(b));
// Saturate the result in case of overflow and cast to qint8_t
- return saturate_convert<qint16_t, qint8_t>(tmp);
+ return utility::saturate_cast<qint8_t>(tmp);
}
inline qint16_t sqadd_qs16(qint16_t a, qint16_t b)
@@ -110,7 +94,7 @@ inline qint16_t sqadd_qs16(qint16_t a, qint16_t b)
qint32_t tmp = (static_cast<qint32_t>(a) + static_cast<qint32_t>(b));
// Saturate the result in case of overflow and cast to qint16_t
- return saturate_convert<qint32_t, qint16_t>(tmp);
+ return utility::saturate_cast<qint16_t>(tmp);
}
inline qint32_t sqadd_qs32(qint32_t a, qint32_t b)
@@ -119,7 +103,7 @@ inline qint32_t sqadd_qs32(qint32_t a, qint32_t b)
qint64_t tmp = (static_cast<qint64_t>(a) + static_cast<qint64_t>(b));
// Saturate the result in case of overflow and cast to qint32_t
- return saturate_convert<qint64_t, qint32_t>(tmp);
+ return utility::saturate_cast<qint32_t>(tmp);
}
inline qint8_t ssub_qs8(qint8_t a, qint8_t b)
@@ -138,7 +122,7 @@ inline qint8_t sqsub_qs8(qint8_t a, qint8_t b)
qint16_t tmp = static_cast<qint16_t>(a) - static_cast<qint16_t>(b);
// Saturate the result in case of overflow and cast to qint8_t
- return saturate_convert<qint16_t, qint8_t>(tmp);
+ return utility::saturate_cast<qint8_t>(tmp);
}
inline qint16_t sqsub_qs16(qint16_t a, qint16_t b)
@@ -147,7 +131,7 @@ inline qint16_t sqsub_qs16(qint16_t a, qint16_t b)
qint32_t tmp = static_cast<qint32_t>(a) - static_cast<qint32_t>(b);
// Saturate the result in case of overflow and cast to qint16_t
- return saturate_convert<qint32_t, qint16_t>(tmp);
+ return utility::saturate_cast<qint16_t>(tmp);
}
inline qint8_t smul_qs8(qint8_t a, qint8_t b, int fixed_point_position)
@@ -183,7 +167,7 @@ inline qint8_t sqmul_qs8(qint8_t a, qint8_t b, int fixed_point_position)
// Rounding up
tmp += round_up_const;
- return saturate_convert<qint16_t, qint8_t>(tmp >> fixed_point_position);
+ return utility::saturate_cast<qint8_t>(tmp >> fixed_point_position);
}
inline qint16_t sqmul_qs16(qint16_t a, qint16_t b, int fixed_point_position)
@@ -195,7 +179,7 @@ inline qint16_t sqmul_qs16(qint16_t a, qint16_t b, int fixed_point_position)
// Rounding up
tmp += round_up_const;
- return saturate_convert<qint32_t, qint16_t>(tmp >> fixed_point_position);
+ return utility::saturate_cast<qint16_t>(tmp >> fixed_point_position);
}
inline qint16_t sqmull_qs8(qint8_t a, qint8_t b, int fixed_point_position)
@@ -394,7 +378,7 @@ inline float scvt_f32_qs8(qint8_t a, int fixed_point_position)
inline qint8_t sqcvt_qs8_f32(float a, int fixed_point_position)
{
// round_nearest_integer(a * 2^(fixed_point_position))
- return saturate_convert<float, qint8_t>(a * (1 << fixed_point_position) + ((a >= 0) ? 0.5 : -0.5));
+ return utility::saturate_cast<qint8_t>(a * (1 << fixed_point_position) + ((a >= 0) ? 0.5 : -0.5));
}
inline float scvt_f32_qs16(qint16_t a, int fixed_point_position)
@@ -405,18 +389,18 @@ inline float scvt_f32_qs16(qint16_t a, int fixed_point_position)
inline qint16_t sqcvt_qs16_f32(float a, int fixed_point_position)
{
// round_nearest_integer(a * 2^(fixed_point_position))
- return saturate_convert<float, qint16_t>(a * (1 << fixed_point_position) + ((a >= 0) ? 0.5 : -0.5));
+ return utility::saturate_cast<qint16_t>(a * (1 << fixed_point_position) + ((a >= 0) ? 0.5 : -0.5));
}
inline qint8_t sqmovn_qs16(qint16_t a)
{
// Saturate the result in case of overflow and cast to qint8_t
- return saturate_convert<qint16_t, qint8_t>(a);
+ return utility::saturate_cast<qint8_t>(a);
}
inline qint16_t sqmovn_qs32(qint32_t a)
{
// Saturate the result in case of overflow and cast to qint16_t
- return saturate_convert<qint32_t, qint16_t>(a);
+ return utility::saturate_cast<qint16_t>(a);
}
}
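Behaviour is unchanged by the switch to utility::saturate_cast: intermediate results that overflow the narrower fixed-point type still clamp to its numeric limits. For example:

    using namespace arm_compute;

    qint8_t s1 = sqadd_qs8(100, 100);  // 200 overflows qint8_t -> saturates to 127
    qint8_t s2 = sqsub_qs8(-100, 100); // -200 -> saturates to -128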
diff --git a/arm_compute/core/Helpers.h b/arm_compute/core/Helpers.h
index e01e4baa6b..c6a7db4f96 100644
--- a/arm_compute/core/Helpers.h
+++ b/arm_compute/core/Helpers.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2018 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -350,7 +350,7 @@ bool update_window_and_padding(Window &win, Ts &&... patterns)
bool padding_changed = false;
- utility::for_each([&](const IAccessWindow & w)
+ utility::for_each([&](IAccessWindow & w)
{
padding_changed |= w.update_padding_if_needed(win);
},
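Because update_padding_if_needed() is no longer const, the access windows forwarded by update_window_and_padding() must be mutable lvalues. A typical configure-time sketch (step and access-pattern values illustrative):

    Window win = calculate_max_window(*input->info(), Steps(16));

    AccessWindowHorizontal input_access(input->info(), 0, 16);
    AccessWindowHorizontal output_access(output->info(), 0, 16);

    // The access windows are passed as non-const lvalues; tensor padding may be grown here.
    const bool padding_changed = update_window_and_padding(win, input_access, output_access);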
diff --git a/arm_compute/core/IAccessWindow.h b/arm_compute/core/IAccessWindow.h
index cf7490d53e..583041a48b 100644
--- a/arm_compute/core/IAccessWindow.h
+++ b/arm_compute/core/IAccessWindow.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -85,7 +85,7 @@ public:
*
* @return True if the padding has been changed.
*/
- virtual bool update_padding_if_needed(const Window &window) const = 0;
+ virtual bool update_padding_if_needed(const Window &window) = 0;
/** Compute the valid region based on access pattern and valid region of the inputs.
*
* @note This method assumes that there is no border.
@@ -168,7 +168,7 @@ public:
ValidRegion compute_valid_region(const Window &window, ValidRegion input_valid_region, bool border_undefined, BorderSize border_size) const override;
bool update_window_if_needed(Window &window) const override;
- bool update_padding_if_needed(const Window &window) const override;
+ bool update_padding_if_needed(const Window &window) override;
protected:
ITensorInfo *_info;
diff --git a/arm_compute/core/ITensorInfo.h b/arm_compute/core/ITensorInfo.h
index 9a67712f3d..9112f3ea18 100644
--- a/arm_compute/core/ITensorInfo.h
+++ b/arm_compute/core/ITensorInfo.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -79,7 +79,7 @@ public:
*
* @return Reference to this ITensorInfo object
*/
- virtual ITensorInfo &set_tensor_shape(TensorShape shape) = 0;
+ virtual ITensorInfo &set_tensor_shape(const TensorShape &shape) = 0;
/** Set the fixed point position to the specified value
*
* @warning The fixed point position must be set once the data type has been configured
@@ -95,7 +95,7 @@ public:
*
* @return Reference to this ITensorInfo object
*/
- virtual ITensorInfo &set_quantization_info(QuantizationInfo quantization_info) = 0;
+ virtual ITensorInfo &set_quantization_info(const QuantizationInfo &quantization_info) = 0;
/** Resets the padding settings of the tensor.
*
* @return Reference to this ITensorInfo object
@@ -214,7 +214,7 @@ public:
*
* @param[in] valid_region Valid region to set.
*/
- virtual void set_valid_region(ValidRegion valid_region) = 0;
+ virtual void set_valid_region(const ValidRegion &valid_region) = 0;
/** Get the quantization settings (scale and offset) of the tensor.
*
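Taking the setters by const reference avoids copying the shape and quantization info at each call site while keeping chained configuration intact; for instance (values illustrative):

    TensorInfo info;
    info.set_data_type(DataType::QASYMM8)
        .set_tensor_shape(TensorShape(64U, 16U))
        .set_quantization_info(QuantizationInfo(0.05f, 128));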
diff --git a/arm_compute/core/NEON/kernels/NESoftmaxLayerKernel.h b/arm_compute/core/NEON/kernels/NESoftmaxLayerKernel.h
index bd0e642d76..c30a4cd23d 100644
--- a/arm_compute/core/NEON/kernels/NESoftmaxLayerKernel.h
+++ b/arm_compute/core/NEON/kernels/NESoftmaxLayerKernel.h
@@ -43,13 +43,13 @@ public:
NELogits1DMaxKernel();
/** Set the input and output tensors.
*
- * @param[in] input Source tensor. Data types supported: QS8/QS16/F16/F32.
+ * @param[in] input Source tensor. Data types supported: QASYMM8/QS8/QS16/F16/F32.
* @param[out] output Destination tensor. Data types supported: same as @p input
*/
void configure(const ITensor *input, ITensor *output);
/** Static function to check if given info will lead to a valid configuration of @ref NELogits1DMaxKernel
*
- * @param[in] input Source tensor. Data types supported: QS8/QS16/F16/F32
+ * @param[in] input Source tensor. Data types supported: QASYMM8/QS8/QS16/F16/F32.
* @param[in] output Destination tensor. Data types supported: same as @p input
*
* @return a status
@@ -61,117 +61,71 @@ public:
BorderSize border_size() const override;
private:
- using Logits1DMaxFunction = void(const ITensor *in, ITensor *out, const Window &window);
+ using Logits1DMaxFunction = void(const ITensor &in, ITensor &out, const Window &window);
private:
Logits1DMaxFunction *_func;
BorderSize _border_size;
};
-/** Interface for shifting the logits values around the max value and exponentiating the result */
-class NELogits1DShiftExpSumKernel : public INEKernel
+/** Interface for softmax computation for QASYMM8 with pre-computed max. */
+class NELogits1DSoftmaxKernel : public INEKernel
{
public:
const char *name() const override
{
- return "NELogits1DShiftExpSumKernel";
+ return "NELogits1DSoftmaxKernel";
}
/** Default constructor */
- NELogits1DShiftExpSumKernel();
+ NELogits1DSoftmaxKernel();
/** Prevent instances of this class from being copied (As this class contains pointers) */
- NELogits1DShiftExpSumKernel(const NELogits1DShiftExpSumKernel &) = delete;
+ NELogits1DSoftmaxKernel(const NELogits1DSoftmaxKernel &) = delete;
/** Prevent instances of this class from being copied (As this class contains pointers) */
- NELogits1DShiftExpSumKernel &operator=(const NELogits1DShiftExpSumKernel &) = delete;
+ NELogits1DSoftmaxKernel &operator=(const NELogits1DSoftmaxKernel &) = delete;
/** Allow instances of this class to be moved */
- NELogits1DShiftExpSumKernel(NELogits1DShiftExpSumKernel &&) = default;
+ NELogits1DSoftmaxKernel(NELogits1DSoftmaxKernel &&) = default;
/** Allow instances of this class to be moved */
- NELogits1DShiftExpSumKernel &operator=(NELogits1DShiftExpSumKernel &&) = default;
+ NELogits1DSoftmaxKernel &operator=(NELogits1DSoftmaxKernel &&) = default;
/** Default destructor */
- ~NELogits1DShiftExpSumKernel() = default;
+ ~NELogits1DSoftmaxKernel() = default;
/** Set the input and output tensors.
*
- * @param[in] input Source tensor. Data types supported: QS8/QS16/F16/F32.
- * @param[in] max Max values tensor. Data types supported: same as @p input.
+ * @param[in] input Source tensor. Data types supported: QASYMM8/QS8/QS16/F16/F32.
+ * @param[in] max Max values tensor. Same shape as input with dimension 0 set to 1.
+ * Data types supported: same as @p input.
* @param[out] output Destination tensor. Data types supported: same as @p input.
- * @param[out] sum Sum of 1D logits tensor. Data types supported: same as @p input.
- * @param[in] beta (Optional) A scaling factor for the exponent. QS8/QS16 only support a beta value of 1.
- */
- void configure(const ITensor *input, const ITensor *max, ITensor *output, ITensor *sum, float beta = 1.0f);
- /** Static function to check if given info will lead to a valid configuration of @ref NELogits1DShiftExpSumKernel
- *
- * @param[in] input Source tensor. Data types supported: QS8/QS16/F16/F32
- * @param[in] max Max values tensor. Data types supported: same as @p input
- * @param[in] output Destination tensor. Data types supported: same as @p input.
- * @param[in] sum Sum of 1D logits tensor. Data types supported: same as @p input.
- * @param[in] beta (Optional) A scaling factor for the exponent. QS8/QS16 only support a beta value of 1.
- *
- * @return a status
- */
- static Status validate(const ITensorInfo *input, const ITensorInfo *max, const ITensorInfo *output, const ITensorInfo *sum, float beta = 1.0f);
-
- // Inherited methods overridden:
- void run(const Window &window, const ThreadInfo &info) override;
-
-private:
- using Logits1DShiftExpSumFunction = void(const ITensor *in, const ITensor *max, ITensor *out, ITensor *sum, const Window &window, float beta);
-
-private:
- Logits1DShiftExpSumFunction *_func;
- const ITensor *_input;
- const ITensor *_max;
- ITensor *_output;
- ITensor *_sum;
- float _beta;
-};
-
-/** Interface for calculating the final step of the Softmax Layer where each logit value is multiplied by the inverse of the sum of the logits. */
-class NELogits1DNormKernel : public INEKernel
-{
-public:
- const char *name() const override
- {
- return "NELogits1DNormKernel";
- }
- /** Default constructor */
- NELogits1DNormKernel();
- /** Prevent instances of this class from being copied (As this class contains pointers) */
- NELogits1DNormKernel(const NELogits1DNormKernel &) = delete;
- /** Prevent instances of this class from being copied (As this class contains pointers) */
- NELogits1DNormKernel &operator=(const NELogits1DNormKernel &) = delete;
- /** Allow instances of this class to be moved */
- NELogits1DNormKernel(NELogits1DNormKernel &&) = default;
- /** Allow instances of this class to be moved */
- NELogits1DNormKernel &operator=(NELogits1DNormKernel &&) = default;
- /** Default destructor */
- ~NELogits1DNormKernel() = default;
- /** Set the input and output tensors.
+ * @param[in] beta A scaling factor for the exponent.
*
- * @param[in] input Source tensor. Data types supported: QS8/QS16/F16/F32.
- * @param[in] sum Sum tensor. The number of dimensions should be dim(input)-1. Data types supported: same as @p input.
- * @param[out] output Destination tensor. Data types supported: same as @p input.
+     * @param[in]  tmp    Auxiliary tensor. Must be of type F32 and the same shape as the input.
*/
- void configure(const ITensor *input, const ITensor *sum, ITensor *output);
- /** Static function to check if given info will lead to a valid configuration of @ref NELogits1DNormKernel
+ void configure(const ITensor *input, const ITensor *max, ITensor *output, const float beta, ITensor *tmp);
+ /** Static function to check if given info will lead to a valid configuration of @ref NELogits1DSoftmaxKernel
*
- * @param[in] input Source tensor. Data types supported: QS8/QS16/S32/F16/F32
- * @param[in] sum Sum tensor. The number of dimensions should be dim(input)-1. Data types supported: same as @p input.
- * @param[in] output Destination tensor. Data types supported: same as @p input.
+ * @param[in] input Source tensor info. Data types supported: QASYMM8/QS8/QS16/F16/F32.
+ * @param[in] max Max values tensor info. Same shape as input with dimension 0 set to 1.
+ * Data types supported: same as @p input.
+ * @param[in] output Destination tensor info. Data types supported: same as @p input.
+ * @param[in] beta A scaling factor for the exponent.
+     * @param[in] tmp    Auxiliary tensor info. Must be of type F32 and the same shape as the input.
*
* @return a status
*/
- static Status validate(const ITensorInfo *input, const ITensorInfo *sum, const ITensorInfo *output);
+ static Status validate(const ITensorInfo *input, const ITensorInfo *max,
+ const ITensorInfo *output, const float beta, const ITensorInfo *tmp);
// Inherited methods overridden:
void run(const Window &window, const ThreadInfo &info) override;
private:
- using Logits1DNormFunction = void(const ITensor *in, const ITensor *sum, ITensor *out, const Window &window);
+ using LogitsSoftmaxFunction = void(const ITensor &in, const ITensor &max, void *const tmp, ITensor &out, const float beta,
+ const Window &window);
-private:
- Logits1DNormFunction *_func;
- const ITensor *_input;
- const ITensor *_sum;
- ITensor *_output;
+ LogitsSoftmaxFunction *_func;
+ const ITensor *_input;
+ const ITensor *_max;
+ ITensor *_output;
+ float _beta;
+    ITensor       *_tmp; // Temporary. Used internally
};
} // namespace arm_compute
#endif /*__ARM_COMPUTE_NESOFTMAXLAYERKERNEL_H__ */
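The QASYMM8-capable path collapses the former shift-exp-sum and norm kernels into a single softmax kernel fed by the per-row max. A configure-time sketch of the new two-kernel pipeline (tensor setup illustrative):

    // max has the input's shape with dimension 0 collapsed to 1.
    TensorShape max_shape = input.info()->tensor_shape();
    max_shape.set(0, 1);

    Tensor max, tmp;
    max.allocator()->init(TensorInfo(max_shape, 1, DataType::QASYMM8, input.info()->quantization_info()));
    // Auxiliary buffer: F32, same shape as the input.
    tmp.allocator()->init(TensorInfo(input.info()->tensor_shape(), 1, DataType::F32));

    NELogits1DMaxKernel     max_kernel;
    NELogits1DSoftmaxKernel softmax_kernel;
    max_kernel.configure(&input, &max);
    softmax_kernel.configure(&input, &max, &output, 1.0f /* beta */, &tmp);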
diff --git a/arm_compute/core/SubTensorInfo.h b/arm_compute/core/SubTensorInfo.h
index 67574f1326..7f4239d49b 100644
--- a/arm_compute/core/SubTensorInfo.h
+++ b/arm_compute/core/SubTensorInfo.h
@@ -98,8 +98,8 @@ public:
_parent->set_fixed_point_position(fixed_point_position);
return *this;
};
- ITensorInfo &set_tensor_shape(TensorShape shape) override;
- ITensorInfo &set_quantization_info(QuantizationInfo quantization_info) override
+ ITensorInfo &set_tensor_shape(const TensorShape &shape) override;
+ ITensorInfo &set_quantization_info(const QuantizationInfo &quantization_info) override
{
ARM_COMPUTE_ERROR_ON(_parent == nullptr);
_parent->set_quantization_info(quantization_info);
@@ -196,7 +196,7 @@ public:
{
return _valid_region;
}
- void set_valid_region(ValidRegion valid_region) override
+ void set_valid_region(const ValidRegion &valid_region) override
{
ARM_COMPUTE_ERROR_ON(_parent == nullptr);
// Check if subtensor is valid if parent is configured
@@ -204,7 +204,7 @@ public:
{
ARM_COMPUTE_ERROR_ON_INVALID_SUBTENSOR_VALID_REGION(_parent->valid_region(), valid_region);
}
- _valid_region = std::move(valid_region);
+ _valid_region = valid_region;
}
QuantizationInfo quantization_info() const override
{
diff --git a/arm_compute/core/TensorInfo.h b/arm_compute/core/TensorInfo.h
index 80ef7f8d5a..0b8989f942 100644
--- a/arm_compute/core/TensorInfo.h
+++ b/arm_compute/core/TensorInfo.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -217,9 +217,9 @@ public:
ITensorInfo &set_data_type(DataType data_type) override;
ITensorInfo &set_num_channels(int num_channels) override;
ITensorInfo &set_format(Format format) override;
- ITensorInfo &set_tensor_shape(TensorShape shape) override;
+ ITensorInfo &set_tensor_shape(const TensorShape &shape) override;
ITensorInfo &set_fixed_point_position(int fixed_point_position) override;
- ITensorInfo &set_quantization_info(QuantizationInfo quantization_info) override;
+ ITensorInfo &set_quantization_info(const QuantizationInfo &quantization_info) override;
ITensorInfo &reset_padding() override;
bool auto_padding() override;
bool extend_padding(const PaddingSize &padding) override;
@@ -289,9 +289,9 @@ public:
{
return _valid_region;
}
- void set_valid_region(ValidRegion valid_region) override
+ void set_valid_region(const ValidRegion &valid_region) override
{
- _valid_region = std::move(valid_region);
+ _valid_region = valid_region;
}
QuantizationInfo quantization_info() const override
{
diff --git a/arm_compute/core/TensorShape.h b/arm_compute/core/TensorShape.h
index ad102607e8..50f1211c18 100644
--- a/arm_compute/core/TensorShape.h
+++ b/arm_compute/core/TensorShape.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -70,26 +70,30 @@ public:
*
* @param[in] dimension Dimension for which the value is set.
* @param[in] value Value to be set for the dimension.
+ *
+ * @return *this.
*/
- void set(size_t dimension, size_t value)
+ TensorShape &set(size_t dimension, size_t value)
{
// Clear entire shape if one dimension is zero
if(value == 0)
{
_num_dimensions = 0;
std::fill(_id.begin(), _id.end(), 0);
- return;
}
+ else
+ {
+ // Make sure all empty dimensions are filled with 1
+ std::fill(_id.begin() + _num_dimensions, _id.end(), 1);
- // Make sure all empty dimensions are filled with 1
- std::fill(_id.begin() + _num_dimensions, _id.end(), 1);
-
- // Set the specified dimension and increase the number of dimensions if
- // necessary
- Dimensions::set(dimension, value);
+ // Set the specified dimension and increase the number of dimensions if
+ // necessary
+ Dimensions::set(dimension, value);
- // Correct number dimensions to ignore trailing dimensions of size 1
- apply_dimension_correction();
+ // Correct number dimensions to ignore trailing dimensions of size 1
+ apply_dimension_correction();
+ }
+ return *this;
}
/** Accessor to remove the dimension n from the tensor shape.
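Returning *this from set() allows shapes to be built in a single expression:

    // Build a 3D shape inline without naming intermediates.
    const TensorShape shape = TensorShape().set(0, 9U).set(1, 5U).set(2, 3U);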
diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h
index 5197000bf9..aa415acebe 100644
--- a/arm_compute/core/Types.h
+++ b/arm_compute/core/Types.h
@@ -190,6 +190,21 @@ struct ValidRegion
return anchor[d] + shape[d];
}
+ /** Accessor to set the value of anchor and shape for one of the dimensions.
+ *
+ * @param[in] dimension Dimension for which the value is set.
+ * @param[in] start Value to be set in anchor for the dimension.
+ * @param[in] size Value to be set in shape for the dimension.
+ *
+ * @return *this.
+ */
+ ValidRegion &set(size_t dimension, int start, size_t size)
+ {
+ anchor.set(dimension, start);
+ shape.set(dimension, size);
+ return *this;
+ }
+
Coordinates anchor;
TensorShape shape;
};
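ValidRegion::set() mirrors the chaining style of TensorShape::set(), updating anchor and shape together for one dimension; for example (values illustrative):

    ValidRegion region(Coordinates(), TensorShape(32U, 32U));
    // Dimension 0: valid elements start at 1 and span 30; dimension 1 is fully valid.
    region.set(0, 1, 30).set(1, 0, 32);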
diff --git a/arm_compute/core/Utils.h b/arm_compute/core/Utils.h
index 51967b1762..fc89d97073 100644
--- a/arm_compute/core/Utils.h
+++ b/arm_compute/core/Utils.h
@@ -40,12 +40,19 @@
namespace arm_compute
{
+/** Calculate the rounded up quotient of val / m. */
+template <typename S, typename T>
+constexpr auto DIV_CEIL(S val, T m) -> decltype((val + m - 1) / m)
+{
+ return (val + m - 1) / m;
+}
+
/** Computes the smallest number larger or equal to value that is a multiple of divisor. */
template <typename S, typename T>
inline auto ceil_to_multiple(S value, T divisor) -> decltype(((value + divisor - 1) / divisor) * divisor)
{
ARM_COMPUTE_ERROR_ON(value < 0 || divisor <= 0);
- return ((value + divisor - 1) / divisor) * divisor;
+ return DIV_CEIL(value, divisor) * divisor;
}
/** Computes the largest number smaller or equal to value that is a multiple of divisor. */
@@ -56,13 +63,6 @@ inline auto floor_to_multiple(S value, T divisor) -> decltype((value / divisor)
return (value / divisor) * divisor;
}
-/** Calculate the rounded up quotient of val / m. */
-template <typename S, typename T>
-constexpr auto DIV_CEIL(S val, T m) -> decltype((val + m - 1) / m)
-{
- return (val + m - 1) / m;
-}
-
/** Returns the arm_compute library build information
*
* Contains the version number and the build options used to build the library
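DIV_CEIL is moved above ceil_to_multiple so the latter can reuse it. Numerically:

    static_assert(DIV_CEIL(10, 3) == 4, "(10 + 3 - 1) / 3 == 4");
    // ceil_to_multiple(10, 3) == DIV_CEIL(10, 3) * 3 == 12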
diff --git a/arm_compute/core/utils/misc/utility.h b/arm_compute/core/utils/misc/utility.h
index 45b3b5268e..e8d823b5bc 100644
--- a/arm_compute/core/utils/misc/utility.h
+++ b/arm_compute/core/utils/misc/utility.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -25,6 +25,7 @@
#define __ARM_COMPUTE_MISC_UTILITY_H__
#include <array>
+#include <limits>
namespace arm_compute
{
@@ -123,6 +124,22 @@ inline auto foldl(F &&func, T &&initial, U &&value, Us &&... values) -> decltype
{
return foldl(std::forward<F>(func), func(std::forward<T>(initial), std::forward<U>(value)), std::forward<Us>(values)...);
}
+
+/** Type cast with saturation.
+ *
+ * @param[in] val Value of type U to cast.
+ *
+ * @return Original value clamped to numeric limits of T and converted to type T.
+ *
+ * @warning Numeric limits of T must be representable without loss in type U.
+ */
+template <typename T, typename U>
+T saturate_cast(U val)
+{
+ const auto low = static_cast<U>(std::numeric_limits<T>::lowest());
+ const auto high = static_cast<U>(std::numeric_limits<T>::max());
+ return static_cast<T>(clamp(val, low, high));
+}
} // namespace utility
} // namespace arm_compute
#endif /* __ARM_COMPUTE_MISC_UTILITY_H__ */
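A quick check of saturate_cast's clamping, under the documented precondition that T's limits are representable in U:

    #include "arm_compute/core/utils/misc/utility.h"
    #include <cstdint>

    int main()
    {
        using arm_compute::utility::saturate_cast;
        const auto hi = saturate_cast<int8_t>(300);  //  300 -> 127
        const auto lo = saturate_cast<int8_t>(-300); // -300 -> -128
        return (hi == 127 && lo == -128) ? 0 : 1;
    }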