aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichele Di Giorgio <michele.digiorgio@arm.com>2019-12-19 13:53:44 +0000
committerGeorgios Pinitas <georgios.pinitas@arm.com>2020-01-06 12:50:25 +0000
commit4536193e78570989c1b54f1c5d57627f29f9d400 (patch)
treeb571913e061477a19345ee803bf91b35ff61e188
parent807ce59755c4aecc5be6d9ef7d0305f895acdfa3 (diff)
downloadComputeLibrary-4536193e78570989c1b54f1c5d57627f29f9d400.tar.gz
COMPMID-2801: Add support for QASYMM8_SIGNED in NEDirectConvolutionLayerOutputStageKernel
Change-Id: Ib047dd1024b8ecac60e2d368cb161ca418c933ff Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com> Reviewed-on: https://review.mlplatform.org/c/2503 Reviewed-by: Giorgio Arena <giorgio.arena@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
-rw-r--r--arm_compute/core/KernelDescriptors.h11
-rw-r--r--arm_compute/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.h39
-rw-r--r--src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp469
-rw-r--r--src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp14
4 files changed, 204 insertions, 329 deletions
diff --git a/arm_compute/core/KernelDescriptors.h b/arm_compute/core/KernelDescriptors.h
index f358153b0d..d009ccc73d 100644
--- a/arm_compute/core/KernelDescriptors.h
+++ b/arm_compute/core/KernelDescriptors.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2019-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -83,5 +83,14 @@ struct SoftmaxKernelInfo
bool is_log{ false }; /**< Flag used to perform Log Softmax operation */
DataType input_data_type{ DataType::UNKNOWN }; /**< Input tensor data type */
};
+
+/** Descriptor used by the direct convolution layer output stage kernels */
+struct DirectConvolutionLayerOutputStageKernelInfo
+{
+ int32_t result_fixedpoint_multiplier{ 0 }; /**< Result output stage multiplier used for quantizing */
+ int32_t result_shift{ 0 }; /**< Result output stage shift used for quantizing */
+ int32_t result_offset_after_shift{ 0 }; /**< Result offset used for quantizing */
+ DataType output_data_type{ DataType::UNKNOWN }; /**< Output tensor data type to use if the output is not initialized */
+};
} // namespace arm_compute
#endif /* ARM_COMPUTE_CORE_KERNEL_DESCRIPTORS_H */
diff --git a/arm_compute/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.h b/arm_compute/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.h
index 3f41edc5aa..b7632d70c4 100644
--- a/arm_compute/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.h
+++ b/arm_compute/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -24,6 +24,7 @@
#ifndef ARM_COMPUTE_NEDIRECTCONVOLUTIONLAYEROUTPUTSTAGEKERNEL_H
#define ARM_COMPUTE_NEDIRECTCONVOLUTIONLAYEROUTPUTSTAGEKERNEL_H
+#include "arm_compute/core/KernelDescriptors.h"
#include "arm_compute/core/NEON/INEKernel.h"
namespace arm_compute
@@ -32,6 +33,8 @@ class ITensor;
/** NEON kernel to accumulate the biases, if provided, or downscale in case of quantized input.
*
* @note We assume bias to be shared
+ * @note For quantized computations (i.e. @p input of S32 type) the output data type for auto-initialization must be passed as part
+ * of the @ref DirectConvolutionLayerOutputStageKernelInfo.
*/
class NEDirectConvolutionLayerOutputStageKernel : public INEKernel
{
@@ -54,32 +57,30 @@ public:
~NEDirectConvolutionLayerOutputStageKernel() = default;
/** Set the accumulate buffer and the biases of the kernel.
*
- * @param[in, out] input Input to add the bias to. If @p output is not specified then accumulation is done in-place.
- * Data type supported: F16/F32
- * @param[in] bias (Optional) The shared bias tensor to add. It must be 1D Tensor. Data type supported: Same as @p input
- * @param[out] output (Optional) If the output tensor is specified the accumulation is done out-of-place. (Defaults to nullptr)
- * Data type supported: F16/F32
- * @param[in] result_fixedpoint_multiplier (Optional) Fixed point value to be multiplied to each element of the input matrix once the result_offset has been added
- * @param[in] result_shift (Optional) Integer value used to round the result of the fixed point multiplication to nearest division by a power-of-two
- * @param[in] result_offset_after_shift (Optional) Offset to be applied to result before converting it back to QASYMM8
+ * @param[in, out] input Input to add the bias to. If @p output is not specified then accumulation is done in-place.
+ * Data type supported: F16/F32/S32
+ * @param[in] bias (Optional) The shared bias tensor to add. It must be 1D Tensor. Data type supported: Same as @p input
+ * @param[out] output (Optional) If the output tensor is specified the accumulation is done out-of-place. (Defaults to nullptr)
+ * Note that in-place computation is only supported for F16/F32. For S32 this must not be nullptr.
+ * Data type supported: F16/F32 or QASYMM8/QASYMM8_SIGNED if @p input is S32
+ * @param[in] info (Optional) DirectConvolutionLayerOutputStageKernel descriptor metadata
*/
void configure(ITensor *input, const ITensor *bias = nullptr, ITensor *output = nullptr,
- int result_fixedpoint_multiplier = 0, int result_shift = 0, int result_offset_after_shift = 0);
+ const DirectConvolutionLayerOutputStageKernelInfo &info = DirectConvolutionLayerOutputStageKernelInfo());
/** Static function to check if given info will lead to a valid configuration of @ref NEDirectConvolutionLayerOutputStageKernel
*
- * @param[in] input Input to add the bias to. If @p output is not specified then accumulation is done in-place.
- * Data type supported: F16/F32
- * @param[in] bias (Optional) The shared bias tensor to add. It must be 1D Tensor. Data type supported: Same as @p input
- * @param[in] output (Optional) If the output tensor is specified the accumulation is done out-of-place. (Defaults to nullptr)
- * Data type supported: F16/F32
- * @param[in] result_fixedpoint_multiplier (Optional) Fixed point value to be multiplied to each element of the input matrix once the result_offset has been added
- * @param[in] result_shift (Optional) Integer value used to round the result of the fixed point multiplication to nearest division by a power-of-two
- * @param[in] result_offset_after_shift (Optional) Offset to be applied to result before converting it back to QASYMM8
+ * @param[in] input Input to add the bias to. If @p output is not specified then accumulation is done in-place.
+ * Data type supported: F16/F32/S32
+ * @param[in] bias (Optional) The shared bias tensor to add. It must be 1D Tensor. Data type supported: Same as @p input
+ * @param[in] output (Optional) If the output tensor is specified the accumulation is done out-of-place. (Defaults to nullptr)
+ * Note that in-place computation is only supported for F16/F32. For S32 this must not be nullptr.
+ * Data type supported: F16/F32 or QASYMM8/QASYMM8_SIGNED if @p input is S32
+ * @param[in] info (Optional) DirectConvolutionLayerOutputStageKernel descriptor metadata
*
* @return a status
*/
static Status validate(const ITensorInfo *input, const ITensorInfo *bias = nullptr, const ITensorInfo *output = nullptr,
- int result_fixedpoint_multiplier = 0, int result_shift = 0, int result_offset_after_shift = 0);
+ const DirectConvolutionLayerOutputStageKernelInfo &info = DirectConvolutionLayerOutputStageKernelInfo());
// Inherited methods overridden:
void run(const Window &window, const ThreadInfo &info) override;
diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
index 8834d9747a..2f106a3f79 100644
--- a/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
+++ b/src/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -30,9 +30,11 @@
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/NEAsymm.h"
#include "arm_compute/core/NEON/NEFixedPoint.h"
+#include "arm_compute/core/NEON/wrapper/wrapper.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
+#include "arm_compute/core/utils/misc/Traits.h"
#include <arm_neon.h>
#include <cstddef>
@@ -43,62 +45,68 @@ namespace arm_compute
namespace
{
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output,
- int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
+ const DirectConvolutionLayerOutputStageKernelInfo &info)
{
- ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier);
- ARM_COMPUTE_UNUSED(result_shift);
- ARM_COMPUTE_UNUSED(result_offset_after_shift);
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input);
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8,
- DataType::F16,
- DataType::S32, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::S32, DataType::F32);
if(bias != nullptr)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(bias, 1, DataType::F16, DataType::S32, DataType::F32);
-
- if(is_data_type_quantized_asymmetric(input->data_type()))
- {
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(bias, 1, DataType::S32);
- }
- else
- {
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, bias);
- }
-
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, bias);
ARM_COMPUTE_RETURN_ERROR_ON(bias->dimension(0) != input->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL)));
ARM_COMPUTE_RETURN_ERROR_ON(bias->num_dimensions() > 1);
}
+ if(input->data_type() == DataType::S32)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(output == nullptr, "In-place computation not allowed for quantized output");
+ }
+
// Checks performed when output is configured
if((output != nullptr) && (output->total_size() != 0))
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
-
- if(is_data_type_quantized_asymmetric(output->data_type()))
+ if(is_data_type_float(input->data_type()))
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::S32 && output->data_type() != DataType::QASYMM8, "Wrong data type for bias");
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
}
else
{
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED);
}
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
+ }
+ else if(input->data_type() == DataType::S32)
+ {
+ // In case of quantized computation and unconfigured output, the output data type must be provided through DirectConvolutionLayerOutputStageKernelInfo
+ ARM_COMPUTE_RETURN_ERROR_ON((info.output_data_type != DataType::QASYMM8) && (info.output_data_type != DataType::QASYMM8_SIGNED));
}
return Status{};
}
-std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *bias, ITensorInfo *output)
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *bias, ITensorInfo *output,
+ const DirectConvolutionLayerOutputStageKernelInfo &info)
{
ARM_COMPUTE_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
+ const DataType data_type = input->data_type();
+
+ // Auto-initialize output output if required
+ if(output != nullptr)
+ {
+ // Work out expected output data type
+ const DataType output_dt = (data_type == DataType::S32) ? info.output_data_type : data_type;
+ // Output tensor auto initialization if not yet initialized
+ auto_init_if_empty(*output, input->clone()->set_data_type(output_dt));
+ }
+
bool window_changed = false;
- unsigned int num_elems_processed_per_iteration = 16 / element_size_from_data_type(input->data_type());
+ unsigned int num_elems_processed_per_iteration = 16 / element_size_from_data_type(data_type);
// Update processed elements when input is S32 (comes from quantization input)
- if(input->data_type() == DataType::S32)
+ if(data_type == DataType::S32)
{
num_elems_processed_per_iteration = 16;
}
@@ -150,107 +158,44 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITen
return std::make_pair(err, win);
}
-// Internal load
-inline float32x4_t internal_vld1q(const float *in)
-{
- return vld1q_f32(in);
-}
-
-// Internal store
-inline void internal_vst1q(float *p, const float32x4_t &v)
-{
- vst1q_f32(p, v);
-}
-
-// Internal vdup
-inline float32x4_t internal_vdupq_n(float v)
-{
- return vdupq_n_f32(v);
-}
-
-// Internal vadd
-inline float32x4_t internal_vqaddq(const float32x4_t &x, const float32x4_t &y)
+template <typename T, bool has_bias>
+typename std::enable_if<arm_compute::utils::traits::is_floating_point<T>::value, void>::type
+output_stage_nchw(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
+ int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
{
- return vaddq_f32(x, y);
-}
+ /** NEON vector tag type. */
+ using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-inline float16x8_t internal_vld1q(const float16_t *in)
-{
- return vld1q_f16(in);
-}
-inline void internal_vst1q(float16_t *p, const float16x8_t &v)
-{
- vst1q_f16(p, v);
-}
-inline float16x8_t internal_vdupq_n(float16_t v)
-{
- return vdupq_n_f16(v);
-}
-inline float16x8_t internal_vqaddq(const float16x8_t &x, const float16x8_t &y)
-{
- return vaddq_f16(x, y);
-}
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-
-template <typename T1, typename T2, bool in_place, bool has_bias>
-void output_stage_nchw(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
- int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
-{
ARM_COMPUTE_ERROR_ON(input->info()->data_layout() == DataLayout::UNKNOWN);
ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier);
ARM_COMPUTE_UNUSED(result_shift);
ARM_COMPUTE_UNUSED(result_offset_after_shift);
Iterator in(input, window);
-
- if(in_place) // In place accumulate
+ Iterator out(output, window);
+ execute_window_loop(window, [&](const Coordinates & id)
{
- execute_window_loop(window, [&](const Coordinates & id)
- {
- // Get bias and pointer to input
- const auto in_ptr = reinterpret_cast<T1 *>(in.ptr());
+ // Get bias and pointer to input
+ const auto in_ptr = reinterpret_cast<const T *>(in.ptr());
+ auto v_in = wrapper::vloadq(in_ptr);
- // Accumulate bias
- if(has_bias)
- {
- const auto vb = internal_vdupq_n(static_cast<T1>(*reinterpret_cast<const T2 *>(bias->ptr_to_element(Coordinates(id.z())))));
- internal_vst1q(in_ptr, internal_vqaddq(internal_vld1q(in_ptr), vb));
- }
- else
- {
- internal_vst1q(in_ptr, internal_vld1q(in_ptr));
- }
- },
- in);
- }
- else // Out of place accumulate
- {
- Iterator out(output, window);
- execute_window_loop(window, [&](const Coordinates & id)
+ // Accumulate bias
+ if(has_bias)
{
- // Get bias and pointer to input
- const auto in_ptr = reinterpret_cast<const T1 *>(in.ptr());
- const auto out_ptr = reinterpret_cast<T2 *>(out.ptr());
+ const auto vb = wrapper::vdup_n(*reinterpret_cast<const T *>(bias->ptr_to_element(Coordinates(id.z()))), ExactTagType{});
+ v_in = wrapper::vadd(v_in, vb);
+ }
- // Accumulate bias
- if(has_bias)
- {
- const auto vb = internal_vdupq_n(static_cast<T1>(*reinterpret_cast<const T2 *>(bias->ptr_to_element(Coordinates(id.z())))));
- internal_vst1q(out_ptr, internal_vqaddq(internal_vld1q(in_ptr), vb));
- }
- else
- {
- internal_vst1q(out_ptr, internal_vld1q(in_ptr));
- }
- },
- in, out);
- }
+ const auto out_ptr = reinterpret_cast<T *>(out.ptr());
+ wrapper::vstore(out_ptr, v_in);
+ },
+ in, out);
}
-template <typename T1, typename T2, bool in_place, bool has_bias>
-void output_stage_nhwc(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
- int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
+template <typename T, bool has_bias>
+typename std::enable_if<arm_compute::utils::traits::is_floating_point<T>::value, void>::type
+output_stage_nhwc(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
+ int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
{
ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier);
ARM_COMPUTE_UNUSED(result_shift);
@@ -263,59 +208,39 @@ void output_stage_nhwc(ITensor *input, const ITensor *bias, const Window &window
Iterator in(input, window);
Iterator bi(bias, window_bias);
-
- if(in_place) // In place accumulate
+ Iterator out(output, window);
+ execute_window_loop(window, [&](const Coordinates &)
{
- execute_window_loop(window, [&](const Coordinates &)
- {
- // Get bias and pointer to input
- const auto in_ptr = reinterpret_cast<T1 *>(in.ptr());
- const auto bias_ptr = reinterpret_cast<T2 *>(bi.ptr());
+ // Get bias and pointer to input
+ const auto in_ptr = reinterpret_cast<const T *>(in.ptr());
+ auto v_in = wrapper::vloadq(in_ptr);
- // Accumulate bias
- if(has_bias)
- {
- internal_vst1q(in_ptr, internal_vqaddq(internal_vld1q(in_ptr), internal_vld1q(bias_ptr)));
- }
- else
- {
- internal_vst1q(in_ptr, internal_vld1q(in_ptr));
- }
- },
- in, bi);
- }
- else // Out of place accumulate
- {
- Iterator out(output, window);
- execute_window_loop(window, [&](const Coordinates &)
+ // Accumulate bias
+ if(has_bias)
{
- // Get bias and pointer to input
- const auto in_ptr = reinterpret_cast<T1 *>(in.ptr());
- const auto out_ptr = reinterpret_cast<T2 *>(out.ptr());
- const auto bias_ptr = reinterpret_cast<T2 *>(bi.ptr());
+ const auto bias_ptr = reinterpret_cast<T *>(bi.ptr());
+ v_in = wrapper::vadd(v_in, wrapper::vloadq(bias_ptr));
+ }
- // Accumulate bias
- if(has_bias)
- {
- internal_vst1q(out_ptr, internal_vqaddq(internal_vld1q(in_ptr), internal_vld1q(bias_ptr)));
- }
- else
- {
- internal_vst1q(out_ptr, internal_vld1q(in_ptr));
- }
- },
- in, bi, out);
- }
+ const auto out_ptr = reinterpret_cast<T *>(out.ptr());
+ wrapper::vstore(out_ptr, v_in);
+
+ },
+ in, bi, out);
}
-// QASYMM8 specializations
-template <>
-void output_stage_nchw<int32_t, uint8_t, false, true>(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
- int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
+// Quantized case
+template < typename TOut, bool has_bias, typename std::enable_if < std::is_same<TOut, uint8_t>::value || std::is_same<TOut, int8_t>::value, int >::type = 0 >
+void output_stage_nchw(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
+ int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
{
+ using VectorType = typename wrapper::traits::neon_bitvector_t<TOut, wrapper::traits::BitWidth::W128>;
+ using TagType = typename wrapper::traits::neon_bitvector_tag_t<TOut, wrapper::traits::BitWidth::W128>;
+
const int32x4_t result_offset_after_shift_s32 = vdupq_n_s32(result_offset_after_shift);
- uint8x16_t min = vdupq_n_u8(0);
- uint8x16_t max = vdupq_n_u8(255);
+
+ const VectorType min = wrapper::vdup_n(std::numeric_limits<TOut>::lowest(), TagType{});
+ const VectorType max = wrapper::vdup_n(std::numeric_limits<TOut>::max(), TagType{});
Iterator in(input, window);
Iterator out(output, window);
@@ -327,68 +252,44 @@ void output_stage_nchw<int32_t, uint8_t, false, true>(ITensor *input, const ITen
int32x4x4_t v_in =
{
{
- vld1q_s32(in_ptr),
- vld1q_s32(in_ptr + 4),
- vld1q_s32(in_ptr + 8),
- vld1q_s32(in_ptr + 12)
+ wrapper::vloadq(in_ptr),
+ wrapper::vloadq(in_ptr + 4),
+ wrapper::vloadq(in_ptr + 8),
+ wrapper::vloadq(in_ptr + 12)
}
};
// Accumulate bias
- const auto vb = vdupq_n_s32(*reinterpret_cast<const int32_t *>(bias->ptr_to_element(Coordinates(id.z()))));
- v_in =
+ if(has_bias)
{
+ const auto vb = wrapper::vdup_n(*reinterpret_cast<const int32_t *>(bias->ptr_to_element(Coordinates(id.z()))), TagType{});
+ v_in =
{
- vaddq_s32(v_in.val[0], vb),
- vaddq_s32(v_in.val[1], vb),
- vaddq_s32(v_in.val[2], vb),
- vaddq_s32(v_in.val[3], vb)
- }
- };
+ {
+ wrapper::vadd(v_in.val[0], vb),
+ wrapper::vadd(v_in.val[1], vb),
+ wrapper::vadd(v_in.val[2], vb),
+ wrapper::vadd(v_in.val[3], vb)
+ }
+ };
+ }
- const auto out_ptr = reinterpret_cast<uint8_t *>(out.ptr());
- vst1q_u8(out_ptr, finalize_quantization<false>(v_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift_s32, min, max));
+ const auto out_ptr = reinterpret_cast<TOut *>(out.ptr());
+ wrapper::vstore(out_ptr, finalize_quantization<false>(v_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift_s32, min, max));
},
in, out);
}
-template <>
-void output_stage_nchw<int32_t, uint8_t, false, false>(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
- int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
+template < typename TOut, bool has_bias, typename std::enable_if < std::is_same<TOut, uint8_t>::value || std::is_same<TOut, int8_t>::value, int >::type = 0 >
+void output_stage_nhwc(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
+ int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
{
- ARM_COMPUTE_UNUSED(bias);
+ using VectorType = typename wrapper::traits::neon_bitvector_t<TOut, wrapper::traits::BitWidth::W128>;
+ using TagType = typename wrapper::traits::neon_bitvector_tag_t<TOut, wrapper::traits::BitWidth::W128>;
const int32x4_t result_offset_after_shift_s32 = vdupq_n_s32(result_offset_after_shift);
- uint8x16_t min = vdupq_n_u8(0);
- uint8x16_t max = vdupq_n_u8(255);
-
- Iterator in(input, window);
- Iterator out(output, window);
- execute_window_loop(window, [&](const Coordinates &)
- {
- // Get bias and pointer to input
- const auto in_ptr = reinterpret_cast<int32_t *>(in.ptr());
- int32x4x4_t v_in =
- {
- {
- vld1q_s32(in_ptr),
- vld1q_s32(in_ptr + 4),
- vld1q_s32(in_ptr + 8),
- vld1q_s32(in_ptr + 12)
- }
- };
- const auto out_ptr = reinterpret_cast<uint8_t *>(out.ptr());
- vst1q_u8(out_ptr, finalize_quantization<false>(v_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift_s32, min, max));
- },
- in, out);
-}
-template <>
-void output_stage_nhwc<int32_t, uint8_t, false, true>(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
- int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
-{
- const int32x4_t result_offset_after_shift_s32 = vdupq_n_s32(result_offset_after_shift);
- uint8x16_t min = vdupq_n_u8(0);
- uint8x16_t max = vdupq_n_u8(255);
+ const VectorType min = wrapper::vdup_n(std::numeric_limits<TOut>::lowest(), TagType{});
+ const VectorType max = wrapper::vdup_n(std::numeric_limits<TOut>::max(), TagType{});
Window window_bias = window;
window_bias.set(Window::DimY, Window::Dimension(0, 0, 0));
@@ -402,56 +303,32 @@ void output_stage_nhwc<int32_t, uint8_t, false, true>(ITensor *input, const ITen
execute_window_loop(window, [&](const Coordinates &)
{
// Get bias and pointer to input
- const auto in_ptr = reinterpret_cast<int32_t *>(in.ptr());
- const auto bias_ptr = reinterpret_cast<int32_t *>(bi.ptr());
-
- // Accumulate bias
+ const auto in_ptr = reinterpret_cast<int32_t *>(in.ptr());
int32x4x4_t v_in =
{
{
- vaddq_s32(vld1q_s32(in_ptr), vld1q_s32(bias_ptr)),
- vaddq_s32(vld1q_s32(in_ptr + 4), vld1q_s32(bias_ptr + 4)),
- vaddq_s32(vld1q_s32(in_ptr + 8), vld1q_s32(bias_ptr + 8)),
- vaddq_s32(vld1q_s32(in_ptr + 12), vld1q_s32(bias_ptr + 12))
+ wrapper::vloadq(in_ptr),
+ wrapper::vloadq(in_ptr + 4),
+ wrapper::vloadq(in_ptr + 8),
+ wrapper::vloadq(in_ptr + 12),
}
};
- const auto out_ptr = out.ptr();
- vst1q_u8(out_ptr, finalize_quantization<false>(v_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift_s32, min, max));
- },
- in, bi, out);
-}
-template <>
-void output_stage_nhwc<int32_t, uint8_t, false, false>(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
- int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
-{
- ARM_COMPUTE_UNUSED(bias);
-
- const int32x4_t result_offset_after_shift_s32 = vdupq_n_s32(result_offset_after_shift);
- uint8x16_t min = vdupq_n_u8(0);
- uint8x16_t max = vdupq_n_u8(255);
-
- Iterator in(input, window);
- Iterator out(output, window);
- execute_window_loop(window, [&](const Coordinates &)
- {
- // Get pointer to input
- const auto in_ptr = reinterpret_cast<int32_t *>(in.ptr());
-
- int32x4x4_t v_in =
+ // Accumulate bias
+ if(has_bias)
{
- {
- vld1q_s32(in_ptr),
- vld1q_s32(in_ptr + 4),
- vld1q_s32(in_ptr + 8),
- vld1q_s32(in_ptr + 12)
- }
- };
+ const auto bias_ptr = reinterpret_cast<int32_t *>(bi.ptr());
+
+ wrapper::vadd(v_in.val[0], wrapper::vloadq(bias_ptr));
+ wrapper::vadd(v_in.val[1], wrapper::vloadq(bias_ptr + 4));
+ wrapper::vadd(v_in.val[2], wrapper::vloadq(bias_ptr + 8));
+ wrapper::vadd(v_in.val[3], wrapper::vloadq(bias_ptr + 12));
+ }
- const auto out_ptr = out.ptr();
- vst1q_u8(out_ptr, finalize_quantization<false>(v_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift_s32, min, max));
+ const auto out_ptr = reinterpret_cast<TOut *>(out.ptr());
+ wrapper::vstore(out_ptr, finalize_quantization<false>(v_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift_s32, min, max));
},
- in, out);
+ in, bi, out);
}
} // namespace
@@ -461,37 +338,27 @@ NEDirectConvolutionLayerOutputStageKernel::NEDirectConvolutionLayerOutputStageKe
}
void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const ITensor *bias, ITensor *output,
- int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
+ const DirectConvolutionLayerOutputStageKernelInfo &info)
{
- ARM_COMPUTE_ERROR_ON_NULLPTR(input);
-
- // Auto-initialize output output if required
- if(output != nullptr)
- {
- // Work out expected output data type
- const DataType output_dt = (input->info()->data_type() == DataType::S32) ? DataType::QASYMM8 : input->info()->data_type();
- // Output tensor auto initialization if not yet initialized
- auto_init_if_empty(*output->info(), input->info()->clone()->set_data_type(output_dt));
- }
-
// Perform validation step
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (bias == nullptr) ? nullptr : bias->info(), (output == nullptr) ? nullptr : output->info(),
- result_fixedpoint_multiplier, result_shift, result_offset_after_shift));
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input);
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (bias == nullptr) ? nullptr : bias->info(), (output == nullptr) ? nullptr : output->info(), info));
_func = nullptr;
_bias = bias;
_input = input;
- _output = output;
- _result_fixedpoint_multiplier = result_fixedpoint_multiplier;
- _result_shift = result_shift;
- _result_offset_after_shift = result_offset_after_shift;
+ _output = (output != nullptr) ? output : input;
+ _result_fixedpoint_multiplier = info.result_fixedpoint_multiplier;
+ _result_shift = info.result_shift;
+ _result_offset_after_shift = info.result_offset_after_shift;
// Configure kernel window
- auto win_config = validate_and_configure_window(input->info(), (bias == nullptr) ? nullptr : bias->info(), (output == nullptr) ? nullptr : output->info());
+ auto win_config = validate_and_configure_window(input->info(), (bias == nullptr) ? nullptr : bias->info(), (output == nullptr) ? nullptr : output->info(), info);
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
INEKernel::configure(win_config.second);
- const bool has_bias = bias != nullptr;
+ const bool has_bias = bias != nullptr;
+ const bool is_qasymm8_signed = (output != nullptr) ? is_data_type_quantized_asymmetric_signed(output->info()->data_type()) : false;
// Set appropriate function
if(input->info()->data_layout() == DataLayout::NCHW)
@@ -500,33 +367,26 @@ void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const
{
case DataType::S32:
{
- _func = (bias == nullptr) ? &output_stage_nchw<int32_t, uint8_t, false, false> : &output_stage_nchw<int32_t, uint8_t, false, true>;
- break;
- }
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- case DataType::F16:
- {
- if(has_bias)
+ if(is_qasymm8_signed)
{
- _func = (output == nullptr) ? &output_stage_nchw<float16_t, float16_t, true, true> : &output_stage_nchw<float16_t, float16_t, false, true>;
+ _func = (has_bias) ? &output_stage_nchw<int8_t, true> : &output_stage_nchw<int8_t, false>;
}
else
{
- _func = (output == nullptr) ? &output_stage_nchw<float16_t, float16_t, true, false> : &output_stage_nchw<float16_t, float16_t, false, false>;
+ _func = (has_bias) ? &output_stage_nchw<uint8_t, true> : &output_stage_nchw<uint8_t, false>;
}
break;
}
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F16:
+ {
+ _func = (has_bias) ? &output_stage_nchw<float16_t, true> : &output_stage_nchw<float16_t, false>;
+ break;
+ }
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
case DataType::F32:
{
- if(has_bias)
- {
- _func = (output == nullptr) ? &output_stage_nchw<float, float, true, true> : &output_stage_nchw<float, float, false, true>;
- }
- else
- {
- _func = (output == nullptr) ? &output_stage_nchw<float, float, true, false> : &output_stage_nchw<float, float, false, false>;
- }
+ _func = (has_bias) ? &output_stage_nchw<float, true> : &output_stage_nchw<float, false>;
break;
}
default:
@@ -541,33 +401,26 @@ void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const
{
case DataType::S32:
{
- _func = (bias == nullptr) ? &output_stage_nhwc<int32_t, uint8_t, false, false> : &output_stage_nhwc<int32_t, uint8_t, false, true>;
- break;
- }
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- case DataType::F16:
- {
- if(has_bias)
+ if(is_qasymm8_signed)
{
- _func = (output == nullptr) ? &output_stage_nhwc<float16_t, float16_t, true, true> : &output_stage_nhwc<float16_t, float16_t, false, true>;
+ _func = (has_bias) ? &output_stage_nhwc<int8_t, true> : &output_stage_nhwc<int8_t, false>;
}
else
{
- _func = (output == nullptr) ? &output_stage_nhwc<float16_t, float16_t, true, false> : &output_stage_nhwc<float16_t, float16_t, false, false>;
+ _func = (has_bias) ? &output_stage_nhwc<uint8_t, true> : &output_stage_nhwc<uint8_t, false>;
}
break;
}
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ case DataType::F16:
+ {
+ _func = (has_bias) ? &output_stage_nhwc<float16_t, true> : &output_stage_nhwc<float16_t, false>;
+ break;
+ }
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
case DataType::F32:
{
- if(has_bias)
- {
- _func = (output == nullptr) ? &output_stage_nhwc<float, float, true, true> : &output_stage_nhwc<float, float, false, true>;
- }
- else
- {
- _func = (output == nullptr) ? &output_stage_nhwc<float, float, true, false> : &output_stage_nhwc<float, float, false, false>;
- }
+ _func = (has_bias) ? &output_stage_nhwc<float, true> : &output_stage_nhwc<float, false>;
break;
}
default:
@@ -579,10 +432,14 @@ void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const
}
Status NEDirectConvolutionLayerOutputStageKernel::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output,
- int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
+ const DirectConvolutionLayerOutputStageKernelInfo &info)
{
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, bias, output, result_fixedpoint_multiplier, result_shift, result_offset_after_shift));
- ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), bias == nullptr ? nullptr : bias->clone().get(), output == nullptr ? nullptr : output->clone().get()).first);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, bias, output, info));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(),
+ bias == nullptr ? nullptr : bias->clone().get(),
+ output == nullptr ? nullptr : output->clone().get(),
+ info)
+ .first);
return Status{};
}
diff --git a/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp b/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp
index ddcc71f466..0320002fba 100644
--- a/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -67,7 +67,9 @@ Status validate_arguments_optimized(const ITensorInfo *input, const ITensorInfo
if(is_quantized)
{
- ARM_COMPUTE_RETURN_ON_ERROR(NEDirectConvolutionLayerOutputStageKernel::validate(&accumulator, biases, output));
+ DirectConvolutionLayerOutputStageKernelInfo direct_conv_info;
+ direct_conv_info.output_data_type = input->data_type();
+ ARM_COMPUTE_RETURN_ON_ERROR(NEDirectConvolutionLayerOutputStageKernel::validate(&accumulator, biases, output, direct_conv_info));
}
}
else
@@ -196,7 +198,13 @@ void NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayerOptimizedInternal::
int32_t output_multiplier;
int32_t output_shift;
quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift);
- _output_stage_kernel.configure(&_accumulator, biases, _is_nchw ? output : &_permuted_output, output_multiplier, output_shift, oq_info.offset);
+
+ DirectConvolutionLayerOutputStageKernelInfo direct_conv_info;
+ direct_conv_info.result_fixedpoint_multiplier = output_multiplier;
+ direct_conv_info.result_shift = output_shift;
+ direct_conv_info.result_offset_after_shift = oq_info.offset;
+ direct_conv_info.output_data_type = input->info()->data_type();
+ _output_stage_kernel.configure(&_accumulator, biases, _is_nchw ? output : &_permuted_output, direct_conv_info);
_accumulator.allocator()->allocate();
}
else if(_has_bias)