aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2019-07-29 14:14:13 +0100
committerGeorgios Pinitas <georgios.pinitas@arm.com>2019-07-31 16:06:44 +0000
commitc4d5136707280d98f660a67219114f5ee5b10fb8 (patch)
tree368b0b044127915ba50de53f228a8fb2ee06b13d
parent4d600c728a75792c5479b54114ec11c6d8fea61a (diff)
downloadComputeLibrary-c4d5136707280d98f660a67219114f5ee5b10fb8.tar.gz
COMPMID-2493: Update qs8 in Depthwise assembly
Introduces minor optimisation for qasymm8 for depthwise convolution. Change-Id: I1b88b1475f8f1ef34c3a7c5580cdeef8b032a100 Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com> Reviewed-on: https://review.mlplatform.org/c/1647 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--arm_compute/core/NEON/kernels/convolution/depthwise/depthwise_quantized.hpp5
-rw-r--r--src/core/NEON/kernels/convolution/depthwise/impl_qa8_qa8.hpp805
2 files changed, 167 insertions, 643 deletions
diff --git a/arm_compute/core/NEON/kernels/convolution/depthwise/depthwise_quantized.hpp b/arm_compute/core/NEON/kernels/convolution/depthwise/depthwise_quantized.hpp
index b65ced6f35..f8db4db6cc 100644
--- a/arm_compute/core/NEON/kernels/convolution/depthwise/depthwise_quantized.hpp
+++ b/arm_compute/core/NEON/kernels/convolution/depthwise/depthwise_quantized.hpp
@@ -109,11 +109,6 @@ class QAsymm8DepthwiseConvolution : public DepthwiseConvolutionBase<
);
protected:
- static nck::ActivationFunction get_activation_fn(
- nck::ActivationFunction activation,
- const qasymm8::QAsymm8Params& output_quantisation
- );
-
uint8_t _input_padding_value(void) const;
void _pack_params(
diff --git a/src/core/NEON/kernels/convolution/depthwise/impl_qa8_qa8.hpp b/src/core/NEON/kernels/convolution/depthwise/impl_qa8_qa8.hpp
index bda875dfe1..f638f0bb38 100644
--- a/src/core/NEON/kernels/convolution/depthwise/impl_qa8_qa8.hpp
+++ b/src/core/NEON/kernels/convolution/depthwise/impl_qa8_qa8.hpp
@@ -38,33 +38,10 @@
#pragma once
-// Comment the following to use floating-point based quantisation, leave
-// uncommented to use fixed-point.
-#define FIXED_POINT_REQUANTISATION 1
-
using namespace neon_convolution_kernels;
using namespace qasymm8;
template <typename T>
-struct clamp_to_limits
-{
- template <typename U>
- static inline U clamp(const U& v)
- {
- const std::numeric_limits<T> limits;
- const U min = static_cast<U>(limits.min());
- const U max = static_cast<U>(limits.max());
- return std::min(std::max(v, min), max);
- }
-
- template <typename U>
- static inline T clamp_and_cast(const U& v)
- {
- return static_cast<U>(clamp(v));
- }
-};
-
-template <typename T>
inline T saturating_doubling_high_mul(const T&, const int32_t&);
template <>
@@ -182,8 +159,7 @@ QAsymm8DepthwiseConvolution<
unsigned int padding_bottom,
unsigned int padding_right
) : Base(
- n_batches, n_input_rows, n_input_cols, n_channels,
- get_activation_fn(activation, output_quantisation),
+ n_batches, n_input_rows, n_input_cols, n_channels, activation,
padding_top, padding_left, padding_bottom, padding_right
),
_weights_quant(weight_quantisation),
@@ -214,8 +190,7 @@ QAsymm8DepthwiseConvolution<
unsigned int padding_right
) : Base(
n_batches, n_input_rows, n_input_cols, n_channels,
- n_output_rows, n_output_cols,
- get_activation_fn(activation, output_quantisation),
+ n_output_rows, n_output_cols, activation,
padding_top, padding_left, padding_bottom, padding_right
),
_weights_quant(weight_quantisation),
@@ -230,45 +205,6 @@ template <
unsigned int KernelRows, unsigned int KernelCols,
unsigned int StrideRows, unsigned int StrideCols
>
-ActivationFunction QAsymm8DepthwiseConvolution<
- OutputTileRows, OutputTileCols, KernelRows, KernelCols, StrideRows, StrideCols
->::get_activation_fn(
- const ActivationFunction activation,
- const QAsymm8Params& output_quant
-)
-{
- if (
- (activation == ActivationFunction::ReLU &&
- output_quant.quantize(0) == 0) ||
- (activation == ActivationFunction::ReLU6 &&
- output_quant.quantize(0) == 0 &&
- output_quant.dequantize(255) <= 6.0f)
- )
- {
- // If the range of values which can be represented by a quantized value are
- // within the range that would be produced by the activation function, then
- // the activation function is redundant and can be skipped.
- return ActivationFunction::None;
- }
- else if(
- activation == ActivationFunction::ReLU6 &&
- output_quant.dequantize(255) <= 6.0f
- )
- {
- // If the largest value that can be represented by a quantized value is
- // lower than the upper boundary, then the activation function can be
- // relaxed to a ReLU.
- return ActivationFunction::ReLU;
- }
-
- return activation;
-}
-
-template <
- unsigned int OutputTileRows, unsigned int OutputTileCols,
- unsigned int KernelRows, unsigned int KernelCols,
- unsigned int StrideRows, unsigned int StrideCols
->
uint8_t QAsymm8DepthwiseConvolution<
OutputTileRows, OutputTileCols, KernelRows, KernelCols, StrideRows, StrideCols
>::_input_padding_value(void) const
@@ -295,20 +231,9 @@ void QAsymm8DepthwiseConvolution<
const int32_t *bptr = static_cast<const int32_t *>(biases);
uint8_t *outptr = static_cast<uint8_t *>(buffer);
- // We set the vector length to use quad registers on Aarch64 and only doubles
- // on Aarch32. NOTE For SVE set this to the actual vector length.
-#if defined(__aarch64__)
- unsigned int veclen = 16;
-#else
-#if defined(__arm__)
+ // We set the vector length to use doubles on both Aarch64 and Aarch32. NOTE
+ // For SVE set this to half the vector length.
unsigned int veclen = 8;
-#endif
-#endif
-
- // Compute the rank 0 offset arising from the quantisation parameters.
- const int32_t rank0_offset = (KernelRows * KernelCols *
- static_cast<int32_t>(_weights_quant.offset) *
- static_cast<int32_t>(_inputs_quant.offset));
// While there are channels left to process, pack a vector length of them at
// a time and reduce the size of vector used as the size of the tensor
@@ -335,8 +260,8 @@ void QAsymm8DepthwiseConvolution<
// Copy a vector length of elements
for (unsigned int n = 0; n < veclen && n < n_channels; n++)
{
- int32_t bias = (bptr != nullptr) ? *(bptr++) : 0;
- uint32_t weight_sum = 0;
+ const int32_t bias = (bptr != nullptr) ? *(bptr++) : 0;
+ out_bptr[n] = bias;
for (unsigned int i = 0; i < KernelRows; i++)
{
@@ -345,16 +270,9 @@ void QAsymm8DepthwiseConvolution<
{
uint8_t w = *(wptr + i*weight_row_stride + j*weight_col_stride);
row_outptr[j*veclen + n] = w;
- weight_sum += static_cast<uint32_t>(w);
}
}
wptr++;
-
- // Include in the bias contributions from the quantisation offset
- int32_t rank1_offset = static_cast<int32_t>(
- static_cast<uint32_t>(_inputs_quant.offset) * weight_sum
- );
- out_bptr[n] = bias + rank0_offset - rank1_offset;
}
}
}
@@ -362,156 +280,33 @@ void QAsymm8DepthwiseConvolution<
template <
unsigned int OutputTileRows, unsigned int OutputTileCols,
unsigned int KernelRows, unsigned int KernelCols,
- unsigned int StrideRows, unsigned int StrideCols
+ unsigned int StrideRows, unsigned int StrideCols,
+ typename FInput, typename FOutput
>
-template<ActivationFunction Activation>
-void QAsymm8DepthwiseConvolution<
- OutputTileRows, OutputTileCols, KernelRows, KernelCols, StrideRows, StrideCols
->::execute_tile(
+static inline void tilefn(
int n_channels,
const void* packed_params,
- const uint8_t* inptr,
- const unsigned int in_row_stride,
- const unsigned int in_col_stride,
- uint8_t* outptr,
- const unsigned int out_row_stride,
- const unsigned int out_col_stride
+ FInput &get_input_ptr,
+ FOutput &get_output_ptr,
+ const int32_t clamp_max,
+ const int32_t clamp_min,
+ const uint8_t input_offset,
+ const uint8_t weight_offset,
+ const uint8_t output_offset,
+ const int32_t requant_multiplier,
+ const int32_t requant_shift
)
{
- // Activation parameters (unused if Activation is None)
- const uint8_t aqmin = _output_quant.offset;
- const uint8_t aqmax = (Activation == ActivationFunction::ReLU6) ?
- std::min<uint8_t>(255u, _output_quant.quantize(6.0f)) : 255u;
+ constexpr int InnerTileRows = StrideRows * (OutputTileRows - 1) + KernelRows;
+ constexpr int InnerTileCols = StrideCols * (OutputTileCols - 1) + KernelCols;
+
+ // Offset into channels
+ int channel = 0;
// Byte type pointer to weights and biases
const uint8_t *wbptr = static_cast<const uint8_t *>(packed_params);
-#if defined(__aarch64__) // Under Aarch64 only use quad registers
- for (; n_channels >= 16; n_channels -= 16)
- {
- // Load biases
- const int32x4_t biases[4] = {
- vld1q_s32(reinterpret_cast<const int32_t *>(wbptr)),
- vld1q_s32(reinterpret_cast<const int32_t *>(wbptr) + 4),
- vld1q_s32(reinterpret_cast<const int32_t *>(wbptr) + 8),
- vld1q_s32(reinterpret_cast<const int32_t *>(wbptr) + 12)
- };
- wbptr += 16*sizeof(int32_t);
-
- // Load weights
- uint8x16_t weights[KernelRows][KernelCols];
- for (unsigned int i = 0; i < KernelRows; i++)
- {
- for (unsigned int j = 0; j < KernelCols; j++)
- {
- weights[i][j] = vld1q_u8(wbptr);
- wbptr += 16;
- }
- }
-
- // Load the input activations
- uint8x16_t inputs[Base::inner_tile_rows][Base::inner_tile_cols];
- for (unsigned int i = 0; i < Base::inner_tile_rows; i++)
- {
- for (unsigned int j = 0; j < Base::inner_tile_cols; j++)
- {
- inputs[i][j] = vld1q_u8(inptr + i*in_row_stride + j*in_col_stride);
- }
- }
- inptr += 16;
-
- // Perform the convolution
- for (unsigned int oi = 0; oi < OutputTileRows; oi++)
- {
- for (unsigned int oj = 0; oj < OutputTileCols; oj++)
- {
- // Two sets of operations are required, we perform the
- // multiply-accumulates for the convolution proper but must also sum
- // the tile elements to account for the _weight_ offset.
- uint32x4_t accs[4];
- for (unsigned int i = 0; i < 4; i++)
- {
- accs[i] = reinterpret_cast<uint32x4_t>(biases[i]);
- }
-
- for (unsigned int wi = 0; wi < KernelRows; wi++)
- {
- for (unsigned int wj = 0; wj < KernelCols; wj++)
- {
- // Get relevant weight and activation pixel
- const uint8x16_t w = weights[wi][wj];
- const uint8x16_t x = inputs[oi*StrideRows + wi][oj*StrideCols + wj];
-
- // Perform multiplication and accumulation
- const uint16x8_t muls[2] = {
- vmull_u8(vget_low_u8(w), vget_low_u8(x)),
- vmull_u8(vget_high_u8(w), vget_high_u8(x))
- };
-
- const uint8x8_t woffset = vdup_n_u8(_weights_quant.offset);
- const uint16x8_t sum_elems[2] = {
- vmull_u8(vget_low_u8(x), woffset),
- vmull_u8(vget_high_u8(x), woffset)
- };
-
- const uint32x4_t tmps[4] = {
- vsubl_u16(vget_low_u16(muls[0]), vget_low_u16(sum_elems[0])),
- vsubl_u16(vget_high_u16(muls[0]), vget_high_u16(sum_elems[0])),
- vsubl_u16(vget_low_u16(muls[1]), vget_low_u16(sum_elems[1])),
- vsubl_u16(vget_high_u16(muls[1]), vget_high_u16(sum_elems[1])),
- };
- for (unsigned int i = 0; i < 4; i++)
- {
- accs[i] = vaddq_u32(accs[i], tmps[i]);
- }
- }
- }
-
- // Rescale the accumulator and add in the new offset.
- uint32x4_t final_accs[4];
- for (unsigned int i = 0; i < 4; i++)
- {
-#ifdef FIXED_POINT_REQUANTISATION
- const int32x4_t y = rounding_divide_by_exp2(
- saturating_doubling_high_mul(
- reinterpret_cast<int32x4_t>(accs[i]), rescale_parameters.multiplier
- ),
- rescale_parameters.shift
- );
- const int32x4_t offset = reinterpret_cast<int32x4_t>(vdupq_n_u32(_output_quant.offset));
- final_accs[i] = reinterpret_cast<uint32x4_t>(vmaxq_s32(vaddq_s32(y, offset), vdupq_n_s32(0)));
-#else // floating point requantisation
- float32x4_t fp_acc = vcvtq_f32_s32(reinterpret_cast<int32x4_t>(accs[i]));
- fp_acc = vmulq_f32(fp_acc, vdupq_n_f32(rescale_parameters.rescale));
- fp_acc = vaddq_f32(fp_acc, vdupq_n_f32(static_cast<float>(_output_quant.offset)));
- fp_acc = vmaxq_f32(fp_acc, vdupq_n_f32(0.0f));
- final_accs[i] = vcvtq_u32_f32(fp_acc);
-#endif
- }
-
- uint8x16_t output = vcombine_u8(
- vqmovn_u16(vcombine_u16(vqmovn_u32(final_accs[0]), vqmovn_u32(final_accs[1]))),
- vqmovn_u16(vcombine_u16(vqmovn_u32(final_accs[2]), vqmovn_u32(final_accs[3])))
- );
-
- // Apply the activation function
- if (Activation == ActivationFunction::ReLU ||
- Activation == ActivationFunction::ReLU6)
- {
- output = vmaxq_u8(output, vdupq_n_u8(aqmin));
- }
- if (Activation == ActivationFunction::ReLU6)
- {
- output = vminq_u8(output, vdupq_n_u8(aqmax));
- }
-
- vst1q_u8(outptr + oi*out_row_stride + oj*out_col_stride, output);
- }
- }
- outptr += 16;
- }
-#endif // defined(__aarch64__)
- for (; n_channels >= 8; n_channels -= 8)
+ for (; n_channels >= 8; n_channels -= 8, channel += 8)
{
const int32x4_t biases[2] = {
vld1q_s32(reinterpret_cast<const int32_t *>(wbptr)),
@@ -519,123 +314,99 @@ void QAsymm8DepthwiseConvolution<
};
wbptr += 8*sizeof(int32_t);
- uint8x8_t weights[KernelRows][KernelCols];
+ int16x8_t weights[KernelRows][KernelCols];
+ const uint8x8_t woffset = vdup_n_u8(weight_offset);
for (unsigned int i = 0; i < KernelRows; i++)
{
for (unsigned int j = 0; j < KernelCols; j++)
{
- weights[i][j] = vld1_u8(wbptr);
+ const uint8x8_t w = vld1_u8(wbptr);
+ weights[i][j] = reinterpret_cast<int16x8_t>(vsubl_u8(w, woffset));
wbptr += 8;
}
}
- uint8x8_t inputs[Base::inner_tile_rows][Base::inner_tile_cols];
- for (unsigned int i = 0; i < Base::inner_tile_rows; i++)
+ int16x8_t inputs[InnerTileRows][InnerTileCols];
+ const uint8x8_t ioffset = vdup_n_u8(input_offset);
+ for (unsigned int i = 0; i < InnerTileRows; i++)
{
- for (unsigned int j = 0; j < Base::inner_tile_cols; j++)
+ for (unsigned int j = 0; j < InnerTileCols; j++)
{
- inputs[i][j] = vld1_u8(inptr + i*in_row_stride + j*in_col_stride);
+ const auto x = vld1_u8(get_input_ptr(i, j, channel));
+ inputs[i][j] = reinterpret_cast<int16x8_t>(vsubl_u8(x, ioffset));
}
}
- inptr += 8;
for (unsigned int oi = 0; oi < OutputTileRows; oi++)
{
for (unsigned int oj = 0; oj < OutputTileCols; oj++)
{
- uint32x4_t accs[2];
- for (unsigned int i = 0; i < 2; i++)
- {
- accs[i] = reinterpret_cast<uint32x4_t>(biases[i]);
- }
+ int32x4_t acc_a = biases[0], acc_b = biases[1];
for (unsigned int wi = 0; wi < KernelRows; wi++)
{
for (unsigned int wj = 0; wj < KernelCols; wj++)
{
- const uint8x8_t w = weights[wi][wj];
- const uint8x8_t x = inputs[oi*StrideRows + wi][oj*StrideCols + wj];
-
- const uint16x8_t muls = vmull_u8(w, x);
- const uint8x8_t woffset = vdup_n_u8(_weights_quant.offset);
- const uint16x8_t sum_elems = vmull_u8(x, woffset);
-
- const uint32x4_t tmps[2] = {
- vsubl_u16(vget_low_u16(muls), vget_low_u16(sum_elems)),
- vsubl_u16(vget_high_u16(muls), vget_high_u16(sum_elems)),
- };
- for (unsigned int i = 0; i < 2; i++)
- {
- accs[i] = vaddq_u32(accs[i], tmps[i]);
- }
+ const auto w = weights[wi][wj];
+ const auto x = inputs[oi * StrideRows + wi][oj * StrideCols + wj];
+#ifndef __aarch64__
+ acc_a = vmlal_s16(acc_a, vget_low_s16(w), vget_low_s16(x));
+ acc_b = vmlal_s16(acc_b, vget_high_s16(w), vget_high_s16(x));
+#else
+ asm("smlal %[acc_a].4s, %[w].4h, %[x].4h\n"
+ "smlal2 %[acc_b].4s, %[w].8h, %[x].8h\n"
+ : [acc_a] "+w"(acc_a), [acc_b] "+w"(acc_b)
+ : [w] "w"(w), [x] "w"(x));
+#endif // __aarch64__
}
}
- uint32x4_t final_accs[2];
+ int32x4_t final_accs[2];
for (unsigned int i = 0; i < 2; i++)
{
-#ifdef FIXED_POINT_REQUANTISATION
const int32x4_t y = rounding_divide_by_exp2(
- saturating_doubling_high_mul(
- reinterpret_cast<int32x4_t>(accs[i]), rescale_parameters.multiplier
- ),
- rescale_parameters.shift
- );
- const int32x4_t offset = reinterpret_cast<int32x4_t>(vdupq_n_u32(_output_quant.offset));
- final_accs[i] = reinterpret_cast<uint32x4_t>(vmaxq_s32(vaddq_s32(y, offset), vdupq_n_s32(0)));
-#else // floating point requantisation
- float32x4_t fp_acc = vcvtq_f32_s32(reinterpret_cast<int32x4_t>(accs[i]));
- fp_acc = vmulq_f32(fp_acc, vdupq_n_f32(rescale_parameters.rescale));
- fp_acc = vaddq_f32(fp_acc, vdupq_n_f32(static_cast<float>(_output_quant.offset)));
- fp_acc = vmaxq_f32(fp_acc, vdupq_n_f32(0.0f));
- final_accs[i] = vcvtq_u32_f32(fp_acc);
-#endif
- }
-
- uint8x8_t output = vqmovn_u16(vcombine_u16(vqmovn_u32(final_accs[0]), vqmovn_u32(final_accs[1])));
-
- // Apply the activation function
- if (Activation == ActivationFunction::ReLU ||
- Activation == ActivationFunction::ReLU6)
- {
- output = vmax_u8(output, vdup_n_u8(aqmin));
- }
- if (Activation == ActivationFunction::ReLU6)
- {
- output = vmin_u8(output, vdup_n_u8(aqmax));
+ saturating_doubling_high_mul((i == 0 ? acc_a : acc_b), requant_multiplier),
+ requant_shift);
+ const int32x4_t offset = reinterpret_cast<int32x4_t>(vdupq_n_u32(output_offset));
+ final_accs[i] = vaddq_s32(y, offset);
+ final_accs[i] = vmaxq_s32(final_accs[i], vdupq_n_s32(clamp_min));
+ final_accs[i] = vminq_s32(final_accs[i], vdupq_n_s32(clamp_max));
}
- vst1_u8(outptr + oi*out_row_stride + oj*out_col_stride, output);
+ const int8x16_t elems = vreinterpretq_s8_s16(
+ vuzp1q_s16(vreinterpretq_s16_s32(final_accs[0]),
+ vreinterpretq_s16_s32(final_accs[1])));
+ const uint8x8_t output =
+ vget_low_u8(vreinterpretq_u8_s8(vuzp1q_s8(elems, elems)));
+ vst1_u8(get_output_ptr(oi, oj, channel), output);
}
}
- outptr += 8;
}
- for (; n_channels; n_channels--)
+ for (; n_channels; n_channels--, channel++)
{
// Load bias
const int32_t bias = *reinterpret_cast<const int32_t *>(wbptr);
wbptr += sizeof(int32_t);
// Load weights
- uint8_t weights[KernelRows][KernelCols];
+ int16_t weights[KernelRows][KernelCols];
for (unsigned int i = 0; i < KernelRows; i++)
{
for (unsigned int j = 0; j < KernelCols; j++)
{
- weights[i][j] = *(wbptr++);
+ weights[i][j] = *(wbptr++) - weight_offset;
}
}
// Load the input activations
- uint8_t inputs[Base::inner_tile_rows][Base::inner_tile_cols];
- for (unsigned int i = 0; i < Base::inner_tile_rows; i++)
+ int16_t inputs[InnerTileRows][InnerTileCols];
+ for (unsigned int i = 0; i < InnerTileRows; i++)
{
- for (unsigned int j = 0; j < Base::inner_tile_cols; j++)
+ for (unsigned int j = 0; j < InnerTileCols; j++)
{
- inputs[i][j] = *(inptr + i*in_row_stride + j*in_col_stride);
+ inputs[i][j] = *(get_input_ptr(i, j, channel)) - input_offset;
}
}
- inptr++;
// Perform the convolution
for (unsigned int oi = 0; oi < OutputTileRows; oi++)
@@ -643,377 +414,135 @@ void QAsymm8DepthwiseConvolution<
for (unsigned int oj = 0; oj < OutputTileCols; oj++)
{
int32_t acc = bias;
- uint32_t element_sum = 0;
for (unsigned int wi = 0; wi < KernelRows; wi++)
{
for (unsigned int wj = 0; wj < KernelCols; wj++)
{
const auto w = weights[wi][wj], x = inputs[oi*StrideRows + wi][oj*StrideCols + wj];
- acc += static_cast<int32_t>(static_cast<uint32_t>(w) * static_cast<uint32_t>(x));
- element_sum += static_cast<uint32_t>(x);
+ acc += w * x;
}
}
- acc -= static_cast<int32_t>(element_sum) * static_cast<int32_t>(_weights_quant.offset);
-
// Requantize
-#ifdef FIXED_POINT_REQUANTISATION
acc = rounding_divide_by_exp2(
- saturating_doubling_high_mul(acc, rescale_parameters.multiplier),
- rescale_parameters.shift
- );
- acc += _output_quant.offset;
- uint8_t output = clamp_to_limits<uint8_t>::clamp_and_cast<int32_t>(acc);
-#else // floating point requantization
- float fp_acc = static_cast<float>(acc);
- fp_acc *= rescale_parameters.rescale;
- fp_acc += static_cast<float>(_output_quant.offset);
- fp_acc = std::max<float>(fp_acc, 0.0f);
- uint8_t output = static_cast<uint8_t>(std::min<int32_t>(static_cast<int32_t>(fp_acc), 255));
-#endif
-
- // Apply the activation function
- if (Activation == ActivationFunction::ReLU ||
- Activation == ActivationFunction::ReLU6)
- {
- output = std::max(output, aqmin);
- }
- if (Activation == ActivationFunction::ReLU6)
- {
- output = std::min(output, aqmax);
- }
-
- *(outptr + oi*out_row_stride + oj*out_col_stride) = output;
+ saturating_doubling_high_mul(acc, requant_multiplier),
+ requant_shift);
+ acc += output_offset;
+ acc = std::max(acc, clamp_min);
+ acc = std::min(acc, clamp_max);
+ uint8_t output = static_cast<uint8_t>(acc);
+ *(get_output_ptr(oi, oj, channel)) = output;
}
}
- outptr++;
}
}
template <
unsigned int OutputTileRows, unsigned int OutputTileCols,
unsigned int KernelRows, unsigned int KernelCols,
- unsigned int StrideRows, unsigned int StrideCols
+ unsigned int StrideRows, unsigned int StrideCols,
+ typename FInput, typename FOutput
>
-template<ActivationFunction Activation>
-void QAsymm8DepthwiseConvolution<
- OutputTileRows, OutputTileCols, KernelRows, KernelCols, StrideRows, StrideCols
->::execute_tile(
+static inline void execute_tilefn(
int n_channels,
const void* packed_params,
- const uint8_t* inptrs[Base::inner_tile_rows][Base::inner_tile_cols],
- uint8_t* outptrs[Base::output_tile_rows][Base::output_tile_cols]
-)
-{
- // Activation parameters (unused if Activation is None)
- const uint8_t aqmin = _output_quant.offset;
- const uint8_t aqmax = (Activation == ActivationFunction::ReLU6) ?
- std::min<uint8_t>(255u, _output_quant.quantize(6.0f)) : 255u;
-
- // Byte type pointer to weights and biases
- const uint8_t *wbptr = static_cast<const uint8_t *>(packed_params);
-
- // Offset into input/output tensors
- int n = 0;
-
-#if defined(__aarch64__) // Under Aarch64 only use quad registers
- for (; n_channels >= 16; n_channels -= 16, n += 16)
- {
- // Load biases
- const int32x4_t biases[4] = {
- vld1q_s32(reinterpret_cast<const int32_t *>(wbptr)),
- vld1q_s32(reinterpret_cast<const int32_t *>(wbptr) + 4),
- vld1q_s32(reinterpret_cast<const int32_t *>(wbptr) + 8),
- vld1q_s32(reinterpret_cast<const int32_t *>(wbptr) + 12)
- };
- wbptr += 16*sizeof(int32_t);
-
- // Load weights
- uint8x16_t weights[KernelRows][KernelCols];
- for (unsigned int i = 0; i < KernelRows; i++)
- {
- for (unsigned int j = 0; j < KernelCols; j++)
- {
- weights[i][j] = vld1q_u8(wbptr);
- wbptr += 16;
- }
- }
-
- // Load the input activations
- uint8x16_t inputs[Base::inner_tile_rows][Base::inner_tile_cols];
- for (unsigned int i = 0; i < Base::inner_tile_rows; i++)
- {
- for (unsigned int j = 0; j < Base::inner_tile_cols; j++)
- {
- inputs[i][j] = vld1q_u8(inptrs[i][j] + n);
- }
- }
-
- // Perform the convolution
- for (unsigned int oi = 0; oi < OutputTileRows; oi++)
- {
- for (unsigned int oj = 0; oj < OutputTileCols; oj++)
- {
- // Two sets of operations are required, we perform the
- // multiply-accumulates for the convolution proper but must also sum
- // the tile elements to account for the _weight_ offset.
- uint32x4_t accs[4];
- for (unsigned int i = 0; i < 4; i++)
- {
- accs[i] = reinterpret_cast<uint32x4_t>(biases[i]);
- }
-
- for (unsigned int wi = 0; wi < KernelRows; wi++)
- {
- for (unsigned int wj = 0; wj < KernelCols; wj++)
- {
- // Get relevant weight and activation pixel
- const uint8x16_t w = weights[wi][wj];
- const uint8x16_t x = inputs[oi*StrideRows + wi][oj*StrideCols + wj];
-
- // Perform multiplication and accumulation
- const uint16x8_t muls[2] = {
- vmull_u8(vget_low_u8(w), vget_low_u8(x)),
- vmull_u8(vget_high_u8(w), vget_high_u8(x))
- };
-
- const uint8x8_t woffset = vdup_n_u8(_weights_quant.offset);
- const uint16x8_t sum_elems[2] = {
- vmull_u8(vget_low_u8(x), woffset),
- vmull_u8(vget_high_u8(x), woffset)
- };
-
- const uint32x4_t tmps[4] = {
- vsubl_u16(vget_low_u16(muls[0]), vget_low_u16(sum_elems[0])),
- vsubl_u16(vget_high_u16(muls[0]), vget_high_u16(sum_elems[0])),
- vsubl_u16(vget_low_u16(muls[1]), vget_low_u16(sum_elems[1])),
- vsubl_u16(vget_high_u16(muls[1]), vget_high_u16(sum_elems[1])),
- };
- for (unsigned int i = 0; i < 4; i++)
- {
- accs[i] = vaddq_u32(accs[i], tmps[i]);
- }
- }
- }
-
- // Rescale the accumulator and add in the new offset.
- uint32x4_t final_accs[4];
- for (unsigned int i = 0; i < 4; i++)
- {
-#ifdef FIXED_POINT_REQUANTISATION
- const int32x4_t y = rounding_divide_by_exp2(
- saturating_doubling_high_mul(
- reinterpret_cast<int32x4_t>(accs[i]), rescale_parameters.multiplier
- ),
- rescale_parameters.shift
- );
- const int32x4_t offset = reinterpret_cast<int32x4_t>(vdupq_n_u32(_output_quant.offset));
- final_accs[i] = reinterpret_cast<uint32x4_t>(vmaxq_s32(vaddq_s32(y, offset), vdupq_n_s32(0)));
-#else // floating point requantisation
- float32x4_t fp_acc = vcvtq_f32_s32(reinterpret_cast<int32x4_t>(accs[i]));
- fp_acc = vmulq_f32(fp_acc, vdupq_n_f32(rescale_parameters.rescale));
- fp_acc = vaddq_f32(fp_acc, vdupq_n_f32(static_cast<float>(_output_quant.offset)));
- fp_acc = vmaxq_f32(fp_acc, vdupq_n_f32(0.0f));
- final_accs[i] = vcvtq_u32_f32(fp_acc);
-#endif
- }
-
- uint8x16_t output = vcombine_u8(
- vqmovn_u16(vcombine_u16(vqmovn_u32(final_accs[0]), vqmovn_u32(final_accs[1]))),
- vqmovn_u16(vcombine_u16(vqmovn_u32(final_accs[2]), vqmovn_u32(final_accs[3])))
- );
-
- // Apply the activation function
- if (Activation == ActivationFunction::ReLU ||
- Activation == ActivationFunction::ReLU6)
- {
- output = vmaxq_u8(output, vdupq_n_u8(aqmin));
- }
- if (Activation == ActivationFunction::ReLU6)
- {
- output = vminq_u8(output, vdupq_n_u8(aqmax));
- }
-
- vst1q_u8(outptrs[oi][oj] + n, output);
- }
- }
+ const nck::ActivationFunction actfn,
+ FInput &get_input_ptr,
+ FOutput &get_output_ptr,
+ const QAsymm8Params &input_quant,
+ const QAsymm8Params &weight_quant,
+ const QAsymm8Params &output_quant,
+ const QAsymm8RescaleParams &requant
+) {
+ // Compute min/max clamp values
+ int32_t clamp_min = std::numeric_limits<uint8_t>::min();
+ int32_t clamp_max = std::numeric_limits<uint8_t>::max();
+
+ if (actfn == nck::ActivationFunction::ReLU ||
+ actfn == nck::ActivationFunction::ReLU6) {
+ const int32_t bottom_rail = output_quant.offset;
+ clamp_min = std::max(clamp_min, bottom_rail);
}
-#endif // defined(__aarch64__)
- for (; n_channels >= 8; n_channels -= 8, n += 8)
- {
- const int32x4_t biases[2] = {
- vld1q_s32(reinterpret_cast<const int32_t *>(wbptr)),
- vld1q_s32(reinterpret_cast<const int32_t *>(wbptr) + 4),
- };
- wbptr += 8*sizeof(int32_t);
- uint8x8_t weights[KernelRows][KernelCols];
- for (unsigned int i = 0; i < KernelRows; i++)
- {
- for (unsigned int j = 0; j < KernelCols; j++)
- {
- weights[i][j] = vld1_u8(wbptr);
- wbptr += 8;
- }
- }
-
- uint8x8_t inputs[Base::inner_tile_rows][Base::inner_tile_cols];
- for (unsigned int i = 0; i < Base::inner_tile_rows; i++)
- {
- for (unsigned int j = 0; j < Base::inner_tile_cols; j++)
- {
- inputs[i][j] = vld1_u8(inptrs[i][j] + n);
- }
- }
-
- for (unsigned int oi = 0; oi < OutputTileRows; oi++)
- {
- for (unsigned int oj = 0; oj < OutputTileCols; oj++)
- {
- uint32x4_t accs[2];
- for (unsigned int i = 0; i < 2; i++)
- {
- accs[i] = reinterpret_cast<uint32x4_t>(biases[i]);
- }
-
- for (unsigned int wi = 0; wi < KernelRows; wi++)
- {
- for (unsigned int wj = 0; wj < KernelCols; wj++)
- {
- const uint8x8_t w = weights[wi][wj];
- const uint8x8_t x = inputs[oi*StrideRows + wi][oj*StrideCols + wj];
-
- const uint16x8_t muls = vmull_u8(w, x);
- const uint8x8_t woffset = vdup_n_u8(_weights_quant.offset);
- const uint16x8_t sum_elems = vmull_u8(x, woffset);
-
- const uint32x4_t tmps[2] = {
- vsubl_u16(vget_low_u16(muls), vget_low_u16(sum_elems)),
- vsubl_u16(vget_high_u16(muls), vget_high_u16(sum_elems)),
- };
- for (unsigned int i = 0; i < 2; i++)
- {
- accs[i] = vaddq_u32(accs[i], tmps[i]);
- }
- }
- }
-
- uint32x4_t final_accs[2];
- for (unsigned int i = 0; i < 2; i++)
- {
-#ifdef FIXED_POINT_REQUANTISATION
- const int32x4_t y = rounding_divide_by_exp2(
- saturating_doubling_high_mul(
- reinterpret_cast<int32x4_t>(accs[i]), rescale_parameters.multiplier
- ),
- rescale_parameters.shift
- );
- const int32x4_t offset = reinterpret_cast<int32x4_t>(vdupq_n_u32(_output_quant.offset));
- final_accs[i] = reinterpret_cast<uint32x4_t>(vmaxq_s32(vaddq_s32(y, offset), vdupq_n_s32(0)));
-#else // floating point requantisation
- float32x4_t fp_acc = vcvtq_f32_s32(reinterpret_cast<int32x4_t>(accs[i]));
- fp_acc = vmulq_f32(fp_acc, vdupq_n_f32(rescale_parameters.rescale));
- fp_acc = vaddq_f32(fp_acc, vdupq_n_f32(static_cast<float>(_output_quant.offset)));
- fp_acc = vmaxq_f32(fp_acc, vdupq_n_f32(0.0f));
- final_accs[i] = vcvtq_u32_f32(fp_acc);
-#endif
- }
-
- uint8x8_t output = vqmovn_u16(vcombine_u16(vqmovn_u32(final_accs[0]), vqmovn_u32(final_accs[1])));
-
- // Apply the activation function
- if (Activation == ActivationFunction::ReLU ||
- Activation == ActivationFunction::ReLU6)
- {
- output = vmax_u8(output, vdup_n_u8(aqmin));
- }
- if (Activation == ActivationFunction::ReLU6)
- {
- output = vmin_u8(output, vdup_n_u8(aqmax));
- }
-
- vst1_u8(outptrs[oi][oj] + n, output);
- }
- }
+ if (actfn == nck::ActivationFunction::ReLU6) {
+ const int32_t top_rail = output_quant.quantize(6.0f);
+ clamp_max = std::min(clamp_max, top_rail);
}
- for (; n_channels; n_channels--, n++)
- {
- // Load bias
- const int32_t bias = *reinterpret_cast<const int32_t *>(wbptr);
- wbptr += sizeof(int32_t);
-
- // Load weights
- uint8_t weights[KernelRows][KernelCols];
- for (unsigned int i = 0; i < KernelRows; i++)
- {
- for (unsigned int j = 0; j < KernelCols; j++)
- {
- weights[i][j] = *(wbptr++);
- }
- }
- // Load the input activations
- uint8_t inputs[Base::inner_tile_rows][Base::inner_tile_cols];
- for (unsigned int i = 0; i < Base::inner_tile_rows; i++)
- {
- for (unsigned int j = 0; j < Base::inner_tile_cols; j++)
- {
- inputs[i][j] = *(inptrs[i][j] + n);
- }
- }
-
- // Perform the convolution
- for (unsigned int oi = 0; oi < OutputTileRows; oi++)
- {
- for (unsigned int oj = 0; oj < OutputTileCols; oj++)
- {
- int32_t acc = bias;
- uint32_t element_sum = 0;
-
- for (unsigned int wi = 0; wi < KernelRows; wi++)
- {
- for (unsigned int wj = 0; wj < KernelCols; wj++)
- {
- const auto w = weights[wi][wj], x = inputs[oi*StrideRows + wi][oj*StrideCols + wj];
- acc += static_cast<int32_t>(static_cast<uint32_t>(w) * static_cast<uint32_t>(x));
- element_sum += static_cast<uint32_t>(x);
- }
- }
-
- acc -= static_cast<int32_t>(element_sum) * static_cast<int32_t>(_weights_quant.offset);
+ // Call the tile execution method
+ tilefn<OutputTileRows, OutputTileCols, KernelRows, KernelCols, StrideRows,
+ StrideCols>(n_channels, packed_params, get_input_ptr, get_output_ptr,
+ clamp_max, clamp_min, input_quant.offset,
+ weight_quant.offset, output_quant.offset,
+ requant.multiplier, requant.shift);
+}
- // Requantize
-#ifdef FIXED_POINT_REQUANTISATION
- acc = rounding_divide_by_exp2(
- saturating_doubling_high_mul(acc, rescale_parameters.multiplier),
- rescale_parameters.shift
- );
- acc += _output_quant.offset;
- uint8_t output = clamp_to_limits<uint8_t>::clamp_and_cast<int32_t>(acc);
-#else // floating point requantization
- float fp_acc = static_cast<float>(acc);
- fp_acc *= rescale_parameters.rescale;
- fp_acc += static_cast<float>(_output_quant.offset);
- fp_acc = std::max<float>(fp_acc, 0.0f);
- uint8_t output = static_cast<uint8_t>(std::min<int32_t>(static_cast<int32_t>(fp_acc), 255));
-#endif
-
- // Apply the activation function
- if (Activation == ActivationFunction::ReLU ||
- Activation == ActivationFunction::ReLU6)
- {
- output = std::max(output, aqmin);
- }
- if (Activation == ActivationFunction::ReLU6)
- {
- output = std::min(output, aqmax);
- }
+template <
+ unsigned int OutputTileRows, unsigned int OutputTileCols,
+ unsigned int KernelRows, unsigned int KernelCols,
+ unsigned int StrideRows, unsigned int StrideCols
+>
+template <nck::ActivationFunction Activation>
+void QAsymm8DepthwiseConvolution<
+ OutputTileRows, OutputTileCols, KernelRows, KernelCols, StrideRows, StrideCols
+>::execute_tile(
+ int n_channels,
+ const void* packed_params,
+ const uint8_t* inptr,
+ unsigned int in_row_stride,
+ unsigned int in_col_stride,
+ uint8_t* outptr,
+ unsigned int out_row_stride,
+ unsigned int out_col_stride
+) {
+ // Construct methods to get pointers
+ const auto get_input_ptr = [inptr, in_row_stride, in_col_stride](
+ const int i, const int j, const int channel) {
+ return inptr + i * in_row_stride + j * in_col_stride + channel;
+ };
+
+ const auto get_output_ptr = [outptr, out_row_stride, out_col_stride](
+ const int i, const int j, const int channel) {
+ return outptr + i * out_row_stride + j * out_col_stride + channel;
+ };
+
+ execute_tilefn<OutputTileRows, OutputTileCols, KernelRows, KernelCols,
+ StrideRows, StrideCols>(
+ n_channels, packed_params, Activation, get_input_ptr, get_output_ptr,
+ _inputs_quant, _weights_quant, _output_quant, rescale_parameters);
+}
- *(outptrs[oi][oj] + n) = output;
- }
- }
- }
+template <
+ unsigned int OutputTileRows, unsigned int OutputTileCols,
+ unsigned int KernelRows, unsigned int KernelCols,
+ unsigned int StrideRows, unsigned int StrideCols
+>
+template <nck::ActivationFunction Activation>
+void QAsymm8DepthwiseConvolution<
+ OutputTileRows, OutputTileCols, KernelRows, KernelCols, StrideRows, StrideCols
+>::execute_tile(
+ int n_channels,
+ const void* packed_params,
+ const uint8_t* inptrs[Base::inner_tile_rows][Base::inner_tile_cols],
+ uint8_t* outptrs[Base::output_tile_rows][Base::output_tile_cols]
+) {
+ // Construct methods to get pointers
+ const auto get_input_ptr = [inptrs](const int i, const int j,
+ const int channel) {
+ return inptrs[i][j] + channel;
+ };
+
+ const auto get_output_ptr = [outptrs](const int i, const int j,
+ const int channel) {
+ return outptrs[i][j] + channel;
+ };
+
+ // Call the tile execution method
+ execute_tilefn<OutputTileRows, OutputTileCols, KernelRows, KernelCols,
+ StrideRows, StrideCols>(
+ n_channels, packed_params, Activation, get_input_ptr, get_output_ptr,
+ _inputs_quant, _weights_quant, _output_quant, rescale_parameters);
}
} // namespace depthwise