From 50fbc6f0a1e781d4f9f83d1e4ea2588390facea2 Mon Sep 17 00:00:00 2001 From: Pablo Tello Date: Wed, 15 Nov 2017 18:12:13 +0000 Subject: COMPMID-675: Updated UDOT product kernels Change-Id: I565397b58b2297fc7fd3c8a2a873c2cb762ceb5c Reviewed-on: http://mpd-gerrit.cambridge.arm.com/95940 Tested-by: Kaizen Reviewed-by: Gian Marco Iodice --- .../kernels/assembly/kernels/a64_gemm_u8_12x8.hpp | 4 + .../assembly/kernels/a64_gemm_u8_12x8/a55r1.hpp | 395 +++++++++++++++++++++ .../assembly/kernels/a64_gemm_u8_12x8/generic.hpp | 41 ++- 3 files changed, 419 insertions(+), 21 deletions(-) create mode 100644 arm_compute/core/NEON/kernels/assembly/kernels/a64_gemm_u8_12x8/a55r1.hpp (limited to 'arm_compute/core/NEON/kernels/assembly') diff --git a/arm_compute/core/NEON/kernels/assembly/kernels/a64_gemm_u8_12x8.hpp b/arm_compute/core/NEON/kernels/assembly/kernels/a64_gemm_u8_12x8.hpp index ebd1512f23..62cd747d7c 100644 --- a/arm_compute/core/NEON/kernels/assembly/kernels/a64_gemm_u8_12x8.hpp +++ b/arm_compute/core/NEON/kernels/assembly/kernels/a64_gemm_u8_12x8.hpp @@ -27,6 +27,7 @@ // Load the actual kernel #include "a64_gemm_u8_12x8/generic.hpp" +#include "a64_gemm_u8_12x8/a55r1.hpp" class gemm_u8_12x8 { public: @@ -54,6 +55,9 @@ public: gemm_u8_12x8(const CPUInfo *ci) { kernel = a64_gemm_u8_12x8; + if (ci->CPU == CPUTarget::A55_DOT) { + kernel = a64_gemm_u8_12x8_a55r1; + } } }; diff --git a/arm_compute/core/NEON/kernels/assembly/kernels/a64_gemm_u8_12x8/a55r1.hpp b/arm_compute/core/NEON/kernels/assembly/kernels/a64_gemm_u8_12x8/a55r1.hpp new file mode 100644 index 0000000000..3ede256f40 --- /dev/null +++ b/arm_compute/core/NEON/kernels/assembly/kernels/a64_gemm_u8_12x8/a55r1.hpp @@ -0,0 +1,395 @@ +/* + * Copyright (c) 2017 ARM Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#ifdef __aarch64__ + +#include +#include "dot_toolchain_support.h" +#include + +inline void a64_gemm_u8_12x8_a55r1(const uint8_t *Apanel, const uint8_t *Bpanel, uint32_t *Cpanel, int ablocks, int bblocks, int K) { + assert(Apanel); + assert(Bpanel); + assert(Cpanel); + const uint8_t *a_ptr = Apanel; + uint32_t *c_ptr = Cpanel; + // We divide K by 4 because the udot instruction processes 4 elements at a time. + const int W = K/4; + // Fix up for odd lengths - set a flag if K is odd, but make + // sure we round up the iteration count. + const int oddk = (W & 1); + const int init_value_k = ((W+1)/2) - 1; + for (int yb=0; yb - inline void a64_gemm_u8_12x8(const uint8_t *Apanel, const uint8_t *Bpanel, uint32_t *Cpanel, int ablocks, int bblocks, int K) { assert(Apanel); assert(Bpanel); assert(Cpanel); - K/=4; - const uint32_t *a_ptr = reinterpret_cast(Apanel); - uint32_t *c_ptr = reinterpret_cast(Cpanel); + const uint8_t *a_ptr = Apanel; + uint32_t *c_ptr = Cpanel; + // We divide K by 4 because the udot instruction processes 4 elements at a time. + const int W = K/4; + // Fix up for odd lengths - set a flag if K is odd, but make + // sure we round up the iteration count. + const int oddk = (W & 1); + const int init_value_k = ((W+1)/2) - 1; for (int yb=0; yb(Bpanel); + const uint8_t *a_ptr0 = a_ptr; + const uint8_t *b_ptr = Bpanel; for (int xb=0; xb