author     Murray Kornelsen <murray.kornelsen@mail.mcgill.ca>  2022-07-13 21:40:26 -0400
committer  Pablo Marquez Tello <pablo.tello@arm.com>           2022-09-14 06:48:39 +0000
commit     6e09e1404c635d948cf20eb6b4b5747dfb6656f2 (patch)
tree       006199bd21b8a1330e1f1c86be60084bfb466706
parent     a4814e8394ffdd7e268614d54cc22e30648f48ff (diff)
INT8 Quantized MeanStdDevNorm (LayerNorm)
Implements LayerNorm for qasymm8 tensors. Uses uint8x16 loads and stores.
Summation is performed in integer arithmetic (vpaddl); normalization is
performed in float32 before requantizing back to int8.

Signed-off-by: Murray Kornelsen <murray.kornelsen@mail.mcgill.ca>
Change-Id: I2407c8b34717fb47adab98791bd76fb8a3c62f4a
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7922
Comments-Addressed: Pablo Marquez Tello <pablo.tello@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Reviewed-by: Pablo Marquez Tello <pablo.tello@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
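For reference, the per-row computation that the vectorized kernel below
implements reduces to the following scalar C++ sketch (an illustration only,
not part of the patch; the helper name and signature are hypothetical):

#include <algorithm>
#include <cmath>
#include <cstdint>

// Hypothetical scalar model of the kernel's math: sums accumulate in integer
// arithmetic, the normalization itself runs in float32, and the result is
// requantized into the output tensor's uint8 representation.
void mean_std_norm_u8_sketch(const uint8_t *in, uint8_t *out, int len,
                             float epsilon, float out_scale, int out_offset)
{
    uint32_t sum = 0, sum_sq = 0;
    for(int i = 0; i < len; ++i)
    {
        const uint32_t v = in[i];
        sum += v;
        sum_sq += v * v;
    }
    const float mean      = static_cast<float>(sum) / static_cast<float>(len);
    const float var       = static_cast<float>(sum_sq) / static_cast<float>(len) - mean * mean;
    const float stdev_inv = 1.0f / std::sqrt(var + epsilon);
    for(int i = 0; i < len; ++i)
    {
        // Normalize in float32, then requantize: q = x_norm / scale + offset,
        // clamped to the valid quantized range [0, 255].
        const float norm = (static_cast<float>(in[i]) - mean) * stdev_inv;
        const float q    = norm / out_scale + static_cast<float>(out_offset);
        out[i] = static_cast<uint8_t>(std::min(std::max(q, 0.0f), 255.0f));
    }
}

The kernel folds the second loop's arithmetic into a single fused scale and
offset per row (v_scale and v_offset in the diff), so each lane needs only a
multiply-add before clamping and narrowing.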
-rw-r--r--  Android.bp                                                         1
-rw-r--r--  filelist.json                                                      3
-rw-r--r--  scripts/clang-tidy.h                                               7
-rw-r--r--  src/core/NEON/kernels/NEMeanStdDevNormalizationKernel.cpp          9
-rw-r--r--  src/cpu/kernels/meanstddevnorm/generic/neon/qasymm8.cpp          145
-rw-r--r--  src/cpu/kernels/meanstddevnorm/list.h                              1
-rw-r--r--  tests/validation/NEON/MeanStdDevNormalizationLayer.cpp            19
-rw-r--r--  tests/validation/fixtures/MeanStdDevNormalizationLayerFixture.h   39
-rw-r--r--  tests/validation/reference/MeanStdDevNormalizationLayer.cpp       11
9 files changed, 213 insertions(+), 22 deletions(-)
diff --git a/Android.bp b/Android.bp
index 6f6c66cc55..8c6d700062 100644
--- a/Android.bp
+++ b/Android.bp
@@ -520,6 +520,7 @@ cc_library_static {
"src/cpu/kernels/meanstddevnorm/generic/neon/fp16.cpp",
"src/cpu/kernels/meanstddevnorm/generic/neon/fp32.cpp",
"src/cpu/kernels/meanstddevnorm/generic/neon/impl.cpp",
+ "src/cpu/kernels/meanstddevnorm/generic/neon/qasymm8.cpp",
"src/cpu/kernels/pool2d/neon/fp16.cpp",
"src/cpu/kernels/pool2d/neon/fp32.cpp",
"src/cpu/kernels/pool2d/neon/nchw/all.cpp",
diff --git a/filelist.json b/filelist.json
index c218ed9129..eb39915524 100644
--- a/filelist.json
+++ b/filelist.json
@@ -1738,7 +1738,8 @@
"neon":{
"common":["src/cpu/kernels/meanstddevnorm/generic/neon/impl.cpp"],
"fp32":["src/cpu/kernels/meanstddevnorm/generic/neon/fp32.cpp"],
- "fp16":["src/cpu/kernels/meanstddevnorm/generic/neon/fp16.cpp"]
+ "fp16":["src/cpu/kernels/meanstddevnorm/generic/neon/fp16.cpp"],
+ "qasymm8":["src/cpu/kernels/meanstddevnorm/generic/neon/qasymm8.cpp"]
}
}
},
diff --git a/scripts/clang-tidy.h b/scripts/clang-tidy.h
index b3705122c6..24e4b15c6f 100644
--- a/scripts/clang-tidy.h
+++ b/scripts/clang-tidy.h
@@ -1,5 +1,12 @@
#include <arm_neon.h>
+#if __arm__
+inline uint32x4_t vpaddq_u32(uint32x4_t, uint32x4_t)
+{
+ return vdupq_n_u32(0);
+}
+#endif
+
inline float16x4_t vrsqrts_f16 (float16x4_t, float16x4_t)
{
return vdup_n_f16(0);
diff --git a/src/core/NEON/kernels/NEMeanStdDevNormalizationKernel.cpp b/src/core/NEON/kernels/NEMeanStdDevNormalizationKernel.cpp
index 7d8fc7ec7f..37e88a8565 100644
--- a/src/core/NEON/kernels/NEMeanStdDevNormalizationKernel.cpp
+++ b/src/core/NEON/kernels/NEMeanStdDevNormalizationKernel.cpp
@@ -55,7 +55,7 @@ struct MeanStdDevNormKernel
MeanStdDevNormUKernelPtr ukernel;
};
-static const MeanStdDevNormKernel available_kernels[] =
+static const std::vector<MeanStdDevNormKernel> available_kernels =
{
{
"fp32_neon_meanstddevnorm",
@@ -69,6 +69,11 @@ static const MeanStdDevNormKernel available_kernels[] =
REGISTER_FP16_NEON(arm_compute::cpu::neon_fp16_meanstddevnorm)
},
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ {
+ "qasymm8_neon_meanstddevnorm",
+ [](const MeanStdDevNormSelectorData & data) { return data.dt == DataType::QASYMM8; },
+ REGISTER_QASYMM8_NEON(arm_compute::cpu::neon_qasymm8_meanstddevnorm)
+ },
};
/** Micro-kernel selector
@@ -95,7 +100,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, f
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->num_dimensions() > 2, "Input tensor cannot have more than 2 dimensions");
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32, DataType::QASYMM8);
// Checks performed when output is configured
if((output != nullptr) && (output->total_size() != 0))
diff --git a/src/cpu/kernels/meanstddevnorm/generic/neon/qasymm8.cpp b/src/cpu/kernels/meanstddevnorm/generic/neon/qasymm8.cpp
new file mode 100644
index 0000000000..53af1e4b16
--- /dev/null
+++ b/src/cpu/kernels/meanstddevnorm/generic/neon/qasymm8.cpp
@@ -0,0 +1,145 @@
+/*
+ * Copyright (c) 2022 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/Window.h"
+#include "src/core/NEON/NEAsymm.h"
+#include "src/core/NEON/NEMath.h"
+#include "src/core/NEON/wrapper/wrapper.h"
+
+#include <arm_neon.h>
+namespace
+{
+inline float32x4_t clamp_v4f32(float32x4_t block, float32x4_t quant_min_vec, float32x4_t quant_max_vec)
+{
+ return vminq_f32(vmaxq_f32(block, quant_min_vec), quant_max_vec);
+}
+inline uint16x8_t fuse_words_f32(float32x4_t fb1, float32x4_t fb2)
+{
+ return vcombine_u16(vmovn_u32(vcvtq_u32_f32(fb1)), vmovn_u32(vcvtq_u32_f32(fb2)));
+}
+inline uint8x16_t fuse_shorts_u16(uint16x8_t sb1, uint16x8_t sb2)
+{
+ return vcombine_u8(vmovn_u16(sb1), vmovn_u16(sb2));
+}
+} // namespace
+
+namespace arm_compute
+{
+namespace cpu
+{
+void neon_qasymm8_meanstddevnorm(ITensor *input, ITensor *output, float epsilon, const Window &window)
+{
+ Window win = window;
+ win.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ const int window_step_x = 16;
+ const int window_start_x = static_cast<int>(window.x().start());
+ const int window_end_x = static_cast<int>(window.x().end());
+
+ const UniformQuantizationInfo qi_out = output->info()->quantization_info().uniform();
+ const float output_scale = qi_out.scale;
+ const int output_offset = qi_out.offset;
+
+ Iterator input_itr(input, win);
+ Iterator output_itr(output, win);
+
+ const float output_inv_scale = 1.0f / output_scale;
+ const float32x4_t quant_max_vec = vdupq_n_f32(255.0f);
+ const float32x4_t quant_min_vec = vdupq_n_f32(0.0f);
+
+ execute_window_loop(
+ win, [&](const Coordinates &)
+ {
+ int x = window_start_x;
+ auto in_ptr = reinterpret_cast<const uint8_t *>(input_itr.ptr());
+ auto out_ptr = reinterpret_cast<uint8_t *>(output_itr.ptr());
+
+ uint32x4_t sum_vec = vdupq_n_u32(0);
+ uint32x4_t sum_sq_vec = vdupq_n_u32(0);
+
+ for(; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ const uint8x16_t data = vld1q_u8(in_ptr + x);
+ sum_vec = vaddq_u32(sum_vec, vpaddlq_u16(vpaddlq_u8(data)));
+ const uint16x8_t squares_low = vmull_u8(vget_low_u8(data), vget_low_u8(data));
+ const uint16x8_t squares_high = vmull_u8(vget_high_u8(data), vget_high_u8(data));
+ sum_sq_vec = vaddq_u32(sum_sq_vec, vaddq_u32(vpaddlq_u16(squares_low), vpaddlq_u16(squares_high)));
+ }
+
+#ifdef __aarch64__
+ sum_vec = vpaddq_u32(sum_vec, sum_vec);
+ sum_vec = vpaddq_u32(sum_vec, sum_vec);
+ uint32_t sum = vgetq_lane_u32(sum_vec, 0);
+ sum_sq_vec = vpaddq_u32(sum_sq_vec, sum_sq_vec);
+ sum_sq_vec = vpaddq_u32(sum_sq_vec, sum_sq_vec);
+ uint32_t sum_sq = vgetq_lane_u32(sum_sq_vec, 0);
+#elif __arm__ // #ifdef __aarch64__
+ uint32_t sum = vgetq_lane_u32(sum_vec, 0) +
+ vgetq_lane_u32(sum_vec, 1) +
+ vgetq_lane_u32(sum_vec, 2) +
+ vgetq_lane_u32(sum_vec, 3);
+
+ uint32_t sum_sq = vgetq_lane_u32(sum_sq_vec, 0) +
+ vgetq_lane_u32(sum_sq_vec, 1) +
+ vgetq_lane_u32(sum_sq_vec, 2) +
+ vgetq_lane_u32(sum_sq_vec, 3);
+#endif // #ifdef __aarch64__
+ for(; x < window_end_x; ++x)
+ {
+ auto data = static_cast<uint32_t>(*(in_ptr + x));
+ sum += data;
+ sum_sq += (data * data);
+ }
+
+ const float mean = (static_cast<float>(sum) / static_cast<float>(input->info()->dimension(0)));
+ const float var = (static_cast<float>(sum_sq) / static_cast<float>(input->info()->dimension(0))) - (mean * mean);
+ const float stdev_inv = 1.0f / sqrtf(var + epsilon);
+ const float32x4_t v_scale = vdupq_n_f32(stdev_inv * output_inv_scale);
+ const float32x4_t v_offset = vdupq_n_f32(-mean * stdev_inv * output_inv_scale + output_offset);
+ for(x = window_start_x; x <= (window_end_x - window_step_x); x += window_step_x)
+ {
+ const uint8x16_t data = vld1q_u8(in_ptr + x);
+ float32x4_t db1 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(data)))));
+ float32x4_t db2 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_low_u8(data)))));
+ float32x4_t db3 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_high_u8(data)))));
+ float32x4_t db4 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_high_u8(data)))));
+ db1 = clamp_v4f32(vaddq_f32(vmulq_f32(db1, v_scale), v_offset), quant_min_vec, quant_max_vec);
+ db2 = clamp_v4f32(vaddq_f32(vmulq_f32(db2, v_scale), v_offset), quant_min_vec, quant_max_vec);
+ db3 = clamp_v4f32(vaddq_f32(vmulq_f32(db3, v_scale), v_offset), quant_min_vec, quant_max_vec);
+ db4 = clamp_v4f32(vaddq_f32(vmulq_f32(db4, v_scale), v_offset), quant_min_vec, quant_max_vec);
+ const uint8x16_t out = fuse_shorts_u16(fuse_words_f32(db1, db2), fuse_words_f32(db3, db4));
+ vst1q_u8(out_ptr + x, out);
+ }
+
+ for(; x < window_end_x; ++x)
+ {
+ auto data = static_cast<float32_t>(*(in_ptr + x));
+ const uint8_t res = data * (stdev_inv * output_inv_scale) + (-mean * stdev_inv * output_inv_scale + output_offset);
+ *(out_ptr + x) = res;
+ }
+ },
+ input_itr, output_itr);
+}
+} // namespace cpu
+} // namespace arm_compute
diff --git a/src/cpu/kernels/meanstddevnorm/list.h b/src/cpu/kernels/meanstddevnorm/list.h
index ac9cb37d23..6277d65884 100644
--- a/src/cpu/kernels/meanstddevnorm/list.h
+++ b/src/cpu/kernels/meanstddevnorm/list.h
@@ -32,6 +32,7 @@ namespace cpu
DECLARE_MEANSTDDEVNORM_KERNEL(neon_fp32_meanstddevnorm);
DECLARE_MEANSTDDEVNORM_KERNEL(neon_fp16_meanstddevnorm);
+DECLARE_MEANSTDDEVNORM_KERNEL(neon_qasymm8_meanstddevnorm);
#undef DECLARE_MEANSTDDEVNORM_KERNEL
} // namespace cpu
diff --git a/tests/validation/NEON/MeanStdDevNormalizationLayer.cpp b/tests/validation/NEON/MeanStdDevNormalizationLayer.cpp
index dee8f78da9..085f3608a0 100644
--- a/tests/validation/NEON/MeanStdDevNormalizationLayer.cpp
+++ b/tests/validation/NEON/MeanStdDevNormalizationLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2021 Arm Limited.
+ * Copyright (c) 2019-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -47,7 +47,8 @@ namespace
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
RelativeTolerance<half> tolerance_f16(half(0.2f));
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-RelativeTolerance<float> tolerance_f32(1e-4f);
+RelativeTolerance<float> tolerance_f32(1e-4f);
+RelativeTolerance<uint8_t> tolerance_qasymm8(1);
} // namespace
TEST_SUITE(NEON)
@@ -114,9 +115,23 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEMeanStdDevNormalizationLayerFixture<float>, f
// Validate output
validate(Accessor(_target), _reference, tolerance_f32);
}
+
TEST_SUITE_END() // FP32
TEST_SUITE_END() // Float
+TEST_SUITE(Quantized)
+TEST_SUITE(QASYMM8)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEMeanStdDevNormalizationLayerFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::Small2DShapes(),
+ framework::dataset::make("DataType", DataType::QASYMM8)),
+ framework::dataset::make("InPlace", { false, true })),
+ framework::dataset::make("Epsilon", { 1e-7 })))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_qasymm8);
+}
+TEST_SUITE_END() // QASYMM8
+TEST_SUITE_END() // Quantized
+
TEST_SUITE_END() // MeanStdNormalizationLayer
TEST_SUITE_END() // Neon
} // namespace validation
diff --git a/tests/validation/fixtures/MeanStdDevNormalizationLayerFixture.h b/tests/validation/fixtures/MeanStdDevNormalizationLayerFixture.h
index 9868cd1abf..f3c108e6da 100644
--- a/tests/validation/fixtures/MeanStdDevNormalizationLayerFixture.h
+++ b/tests/validation/fixtures/MeanStdDevNormalizationLayerFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2021 Arm Limited.
+ * Copyright (c) 2019-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -45,29 +45,35 @@ class MeanStdDevNormalizationLayerValidationFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape shape, DataType dt, bool in_place, float epsilon = 1e-8f)
+ void setup(TensorShape shape, DataType dt, bool in_place, float epsilon = 1e-8)
{
- _data_type = dt;
- _target = compute_target(shape, dt, in_place, epsilon);
- _reference = compute_reference(shape, dt, epsilon);
+ QuantizationInfo qi = QuantizationInfo(0.5f, 10);
+ _data_type = dt;
+ _target = compute_target(shape, dt, in_place, epsilon, qi);
+ _reference = compute_reference(shape, dt, epsilon, qi);
}
protected:
template <typename U>
- void fill(U &&src_tensor)
+ void fill(U &&tensor)
{
- static_assert(std::is_floating_point<T>::value || std::is_same<T, half>::value, "Only floating point data types supported.");
- using DistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_16bit<T>, std::uniform_real_distribution<T>>::type;
-
- DistributionType distribution{ T(-1.0f), T(1.0f) };
- library->fill(src_tensor, distribution, 0);
+ if(is_data_type_float(_data_type))
+ {
+ std::uniform_real_distribution<> distribution{ -1.0f, 1.0f };
+ library->fill(tensor, distribution, 0);
+ }
+ else
+ {
+ std::uniform_int_distribution<> distribution{ 0, 255 };
+ library->fill(tensor, distribution, 0);
+ }
}
- TensorType compute_target(TensorShape shape, DataType dt, bool in_place, float epsilon)
+ TensorType compute_target(TensorShape shape, DataType dt, bool in_place, float epsilon, QuantizationInfo qi)
{
// Create tensors
- TensorType src = create_tensor<TensorType>(shape, dt, 1);
- TensorType dst;
+ TensorType src = create_tensor<TensorType>(shape, dt, 1, qi);
+ TensorType dst = create_tensor<TensorType>(shape, dt, 1, qi);
TensorType *dst_ptr = in_place ? &src : &dst;
@@ -104,10 +110,10 @@ protected:
}
}
- SimpleTensor<T> compute_reference(const TensorShape &shape, DataType dt, float epsilon)
+ SimpleTensor<T> compute_reference(const TensorShape &shape, DataType dt, float epsilon, QuantizationInfo qi)
{
// Create reference
- SimpleTensor<T> ref_src{ shape, dt, 1 };
+ SimpleTensor<T> ref_src{ shape, dt, 1, qi };
// Fill reference
fill(ref_src);
@@ -119,6 +125,7 @@ protected:
SimpleTensor<T> _reference{};
DataType _data_type{};
};
+
} // namespace validation
} // namespace test
} // namespace arm_compute
diff --git a/tests/validation/reference/MeanStdDevNormalizationLayer.cpp b/tests/validation/reference/MeanStdDevNormalizationLayer.cpp
index 0a23fa19bb..a7c8a784d9 100644
--- a/tests/validation/reference/MeanStdDevNormalizationLayer.cpp
+++ b/tests/validation/reference/MeanStdDevNormalizationLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 Arm Limited.
+ * Copyright (c) 2019, 2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -63,6 +63,15 @@ SimpleTensor<T> mean_std_normalization_layer(const SimpleTensor<T> &src, float e
return dst;
}
+template <>
+SimpleTensor<uint8_t> mean_std_normalization_layer(const SimpleTensor<uint8_t> &src, float epsilon)
+{
+ SimpleTensor<float> src_tmp = convert_from_asymmetric(src);
+ SimpleTensor<float> dst_tmp = mean_std_normalization_layer<float>(src_tmp, epsilon);
+ SimpleTensor<uint8_t> dst = convert_to_asymmetric<uint8_t>(dst_tmp, src.quantization_info());
+ return dst;
+}
+
template SimpleTensor<float> mean_std_normalization_layer(const SimpleTensor<float> &src, float epsilon);
template SimpleTensor<half> mean_std_normalization_layer(const SimpleTensor<half> &src, float epsilon);
} // namespace reference