aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
diff options
context:
space:
mode:
authorMichael Tyler <michael.tyler@arm.com>2024-06-04 15:47:37 +0100
committerMichael Tyler <michael.tyler@arm.com>2024-06-25 09:10:13 +0000
commitfc94f4d23abd4bc427b701f54ad85282e9ec7872 (patch)
tree5e2980599256e2b2f4374e5beb61596fc95c9d5a /src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
parentc2237ec4094c7824f8f7e61bc89504d01c5b59ff (diff)
downloadComputeLibrary-fc94f4d23abd4bc427b701f54ad85282e9ec7872.tar.gz
Update CPU kernels and add mixed sign GEMM support
- Add support for mixed sign quantized convolution. - Add support for mixed sign dequantized GEMM. - Add SME FP16 GEMV kernel. - Change SME vector length function to use RDSVL instead of static variable. - Add GEMM dilation support internally (not exposed yet). - Remove unused "get_default_activation_values" functions. - Add SVE fixed format interleaved BF16 DOT kernel. - Updates and optimizations to assembly kernels. Resolves COMPMID-6926 Change-Id: I227f502502611d4cc4111c89e30c53ce94079544 Signed-off-by: Michael Tyler <michael.tyler@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11570 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp12
1 files changed, 6 insertions, 6 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
index 0cc4d4f3d9..8bbb877c1b 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid_indirect.hpp
@@ -260,8 +260,8 @@ struct kernel_weight_format<strategy, false> {
} // anonymous namespace
// Implementation of the GemmCommon abstract class.
-template<typename strategy, typename To, typename Tr, typename OutputStage=Nothing, bool SeparateQuantize=false, bool FixedFormat=false>
-class GemmHybridIndirect : public GemmCommon<To, Tr> {
+template<typename strategy, typename To, typename Tw, typename Tr, typename OutputStage=Nothing, bool SeparateQuantize=false, bool FixedFormat=false>
+class GemmHybridIndirect : public GemmCommon<To, Tw, Tr> {
typedef typename strategy::lhs_operand_type Tloi;
typedef typename strategy::rhs_operand_type Troi;
typedef typename strategy::result_type Tri;
@@ -618,7 +618,7 @@ public:
return _args._nmulti * iceildiv(_args._Nsize, strategy::out_width());
}
- void requantize_bias(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override {
+ void requantize_bias(void *in_buffer, const Tw *B, const int ldb, const int B_multi_stride) override {
if (std::is_same<OutputStage, Requantize32>::value) {
_col_bias = reinterpret_cast<int32_t *>(in_buffer);
@@ -636,11 +636,11 @@ public:
return strat.transforms.PrepareB_supports_transpose();
}
- void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, bool transposed) override {
+ void pretranspose_B_array(void *in_buffer, const Tw *B, const int ldb, const int B_multi_stride, bool transposed) override {
pretranspose_B_array_part(in_buffer, B, ldb, B_multi_stride, transposed, 0, get_B_pretranspose_window_size());
}
- void pretranspose_B_array_part(void *in_buffer, const To *B, const int ldb, const int B_multi_stride, bool transposed, size_t start, size_t end) override {
+ void pretranspose_B_array_part(void *in_buffer, const Tw *B, const int ldb, const int B_multi_stride, bool transposed, size_t start, size_t end) override {
if (end >= get_B_pretranspose_window_size()) {
requantize_bias(in_buffer, B, ldb, B_multi_stride);
}
@@ -835,7 +835,7 @@ public:
};
template<typename strategy, typename To, typename Tr, typename OutputStage=Nothing>
-using GemmHybridIndirectFixedFormat = GemmHybridIndirect<strategy, To, Tr, OutputStage, false, true>;
+using GemmHybridIndirectFixedFormat = GemmHybridIndirect<strategy, To, To, Tr, OutputStage, false, true>;
} // namespace arm_gemm