Diffstat (limited to 'src/core/NEON/kernels/arm_gemm')
13 files changed, 781 insertions, 102 deletions
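Before the diff itself, a brief orientation may help: the patch replaces the boolean `is_recommended` hook on each GEMM method with a `cycle_estimate`, computed from the new `PerformanceParameters` (MACs, prepare bytes and merge bytes per cycle), and `find_implementation` now picks the candidate with the lowest estimate. The sketch below is not part of the patch; it is a minimal, self-contained restatement of that cost model with hypothetical throughput numbers, intended only to show the arithmetic that the new `estimate_cycles()` members perform (work divided by per-cycle rate, plus a penalty when the problem cannot occupy all requested threads).

```cpp
#include <cstdint>
#include <iostream>

// Stand-in for the struct added in performance_parameters.hpp: per-cycle
// throughput figures measured for a given kernel on a given CPU.
struct PerformanceParameters {
    float kernel_macs_cycle;    // multiply-accumulates retired per cycle
    float prepare_bytes_cycle;  // bytes of input transform per cycle
    float merge_bytes_cycle;    // bytes of result merge per cycle
};

// Minimal restatement of the cost model used by the new estimate_cycles()
// members: each phase's work divided by its rate, then scaled up if there
// are fewer independent work items than threads (the 0.9 fudge factor and
// the penalty mirror the patch).
uint64_t estimate_cycles_sketch(uint64_t macs, uint64_t prepare_bytes, uint64_t merge_bytes,
                                unsigned int work_items, unsigned int maxthreads,
                                const PerformanceParameters &p) {
    float cycles = static_cast<float>(macs)          / p.kernel_macs_cycle
                 + static_cast<float>(prepare_bytes) / p.prepare_bytes_cycle
                 + static_cast<float>(merge_bytes)   / p.merge_bytes_cycle;

    float parallelism_available = static_cast<float>(work_items) * 0.9f;
    if (parallelism_available < maxthreads) {
        cycles *= static_cast<float>(maxthreads) / parallelism_available;
    }
    return static_cast<uint64_t>(cycles);
}

int main() {
    // Hypothetical throughput figures and problem size, for illustration only.
    PerformanceParameters params{4.0f, 2.0f, 1.5f};
    std::cout << estimate_cycles_sketch(1u << 24, 1u << 20, 1u << 20, 6, 8, params)
              << " cycles (estimated)\n";
}
```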
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp index aa206e3f42..ddb438f06c 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp @@ -120,13 +120,13 @@ static const GemmImplementation<float, float> gemm_fp32_methods[] = [](const GemmArgs &args) { return (args._Nsize < 12); }, [](const GemmArgs &args) { return new GemmHybrid<hybrid_fp32_mla_4x8, float, float>(args); } }, -{ +GemmImplementation<float, float>::with_estimate( GemmMethod::GEMM_HYBRID, "hybrid_fp32_mla_16x4", [](const GemmArgs &args) { return (args._Ksize >= 4); }, - [](const GemmArgs &args) { return ((args._Ksize <= 256) && (args._Nsize <= 256)) || (args._Msize < 16) || (args._nmulti > 1); }, + [](const GemmArgs &args) { return GemmHybrid<hybrid_fp32_mla_16x4, float, float>::estimate_cycles(args, hybrid_fp32_mla_16x4::get_performance_parameters(args._ci)); }, [](const GemmArgs &args) { return new GemmHybrid<hybrid_fp32_mla_16x4, float, float>(args); } -}, +), #ifdef __ARM_FEATURE_SVE { @@ -138,21 +138,21 @@ static const GemmImplementation<float, float> gemm_fp32_methods[] = }, #endif // __ARM_FEATURE_SVE // Pretranposed, 2D split -{ +GemmImplementation<float, float>::with_estimate( GemmMethod::GEMM_INTERLEAVED_2D, "sgemm_12x8_2d", nullptr, - [](const GemmArgs &args) { return args._maxthreads >= 8; }, + [](const GemmArgs &args) { return GemmInterleavedPretransposed2d<sgemm_12x8, float, float>::estimate_cycles(args, sgemm_12x8::get_performance_parameters(args._ci)); }, [](const GemmArgs &args) { return new GemmInterleavedPretransposed2d<sgemm_12x8, float, float>(args); } -}, +), // 1D split (with pretransposed or not) -{ +GemmImplementation<float, float>::with_estimate( GemmMethod::GEMM_INTERLEAVED, "sgemm_12x8_1d", nullptr, - nullptr, + [](const GemmArgs &args) { return GemmInterleaved<sgemm_12x8, float, float>::estimate_cycles(args, sgemm_12x8::get_performance_parameters(args._ci)); }, [](const GemmArgs &args) { return new GemmInterleaved<sgemm_12x8, float, float>(args); } -}, +), #endif // __aarch64__ #ifdef __arm__ diff --git a/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp b/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp index 353d681fe2..7a983ed6ac 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_hybrid.hpp @@ -23,17 +23,15 @@ */ #pragma once -#include <assert.h> - #include <algorithm> +#include <cassert> #include "arm_gemm.hpp" #include "bias_adder.hpp" #include "ndrange.hpp" -#include "utils.hpp" - -#include "mergeresults.hpp" +#include "performance_parameters.hpp" #include "transform.hpp" +#include "utils.hpp" #ifdef CYCLE_PROFILING #include "profiler.hpp" @@ -252,6 +250,28 @@ public: void set_pretransposed_B_data(void *in_buffer) override { _B_transposed = reinterpret_cast<Toi *>(in_buffer); } + + // Estimate cycles for given problem given provided parameters + static uint64_t estimate_cycles(const GemmArgs &args, const PerformanceParameters ¶ms) { + // Note: Current hybrid kernels don't actually round up height (they + // have paths for each possible height). Might need to make this + // configurable in future. 
+ uint64_t total_macs = static_cast<uint64_t>(args._nbatches) * args._nmulti * args._Msize * roundup(args._Nsize, strategy::out_width()) * roundup(args._Ksize, strategy::k_unroll()); + + float mac_cycles = static_cast<float>(total_macs) / params.kernel_macs_cycle; + + // TODO: A bit of a kludge here: current hybrid kernels incur extra + // overhead where the width is not a multiple of kernel width. It's + // most noticable where the overall width is quite low, so add 15% + // penalty for such widths. + if ((args._Nsize < strategy::out_width()) || (args._Nsize > strategy::out_width() && args._Nsize < 2*strategy::out_width())) { + mac_cycles *= 1.15f; + } + + uint64_t total_cycles = mac_cycles; + + return total_cycles; + } }; } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp index c726d7b0aa..261e7d2d9c 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_implementation.hpp @@ -24,6 +24,7 @@ #include "arm_gemm.hpp" +#include <cstdint> #include <functional> namespace arm_gemm { @@ -37,7 +38,7 @@ struct GemmImplementation { const GemmMethod method; const char * name; std::function<bool(const GemmArgs &, const OutputStage &)> is_supported; - std::function<bool(const GemmArgs &, const OutputStage &)> is_recommended; + std::function<uint64_t(const GemmArgs &, const OutputStage &)> cycle_estimate; std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate; bool do_is_supported(const GemmArgs &args, const OutputStage &os) const { @@ -48,17 +49,27 @@ struct GemmImplementation { } } - bool do_is_recommended(const GemmArgs &args, const OutputStage &os) const { - if (is_recommended != nullptr) { - return is_recommended(args, os); + uint64_t do_cycle_estimate(const GemmArgs &args, const OutputStage &os) const { + if (cycle_estimate != nullptr) { + return cycle_estimate(args, os); } else { - return true; + return 0; } } + GemmImplementation(const GemmImplementation &) = default; + GemmImplementation &operator= (const GemmImplementation &) = default; + GemmCommon<Top, Tret> *do_instantiate(const GemmArgs &args, const OutputStage &os) const { return instantiate(args, os); } + + GemmImplementation(GemmMethod m, const char *n, + std::function<bool(const GemmArgs &, const OutputStage &)> is_supported, std::function<bool(const GemmArgs &, const OutputStage &)> is_recommended, + std::function<GemmCommon<Top, Tret> *(const GemmArgs &, const OutputStage &)> instantiate) : + method(m), name(n), is_supported(is_supported), + cycle_estimate( [is_recommended](const GemmArgs &args, const OutputStage &os) { return (is_recommended == nullptr) ? 0 : (is_recommended(args, os) ? 
0 : UINT64_MAX); } ), + instantiate(instantiate) { } }; /* Slightly different version of above for straightforward GEMMs with no @@ -69,7 +80,7 @@ struct GemmImplementation<Top, Tret, Nothing> { const GemmMethod method; const char * name; std::function<bool(const GemmArgs &)> is_supported; - std::function<bool(const GemmArgs &)> is_recommended; + std::function<uint64_t(const GemmArgs &)> cycle_estimate; std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate; bool do_is_supported(const GemmArgs &args, const Nothing &) const { @@ -80,17 +91,42 @@ struct GemmImplementation<Top, Tret, Nothing> { } } - bool do_is_recommended(const GemmArgs &args, const Nothing &) const { - if (is_recommended != nullptr) { - return is_recommended(args); + uint64_t do_cycle_estimate(const GemmArgs &args, const Nothing &) const { + if (cycle_estimate != nullptr) { + return cycle_estimate(args); } else { - return true; + return 0; } } GemmCommon<Top, Tret> *do_instantiate(const GemmArgs &args, const Nothing &) const { return instantiate(args); } + + + static GemmImplementation with_estimate(GemmMethod m, const char *n, + std::function<bool(const GemmArgs &)> is_supported, std::function<uint64_t(const GemmArgs &)> cycle_estimate, + std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) { + GemmImplementation impl(m,n); + + impl.is_supported=is_supported; + impl.cycle_estimate=cycle_estimate; + impl.instantiate=instantiate; + + return impl; + } + + GemmImplementation(GemmMethod m, const char * n) : method(m), name(n), is_supported(nullptr), cycle_estimate(nullptr), instantiate(nullptr) {} + + GemmImplementation(GemmMethod m, const char *n, + std::function<bool(const GemmArgs &)> is_supported, std::function<bool(const GemmArgs &)> is_recommended, + std::function<GemmCommon<Top, Tret> *(const GemmArgs &)> instantiate) : + method(m), name(n), is_supported(is_supported), + cycle_estimate( [is_recommended](const GemmArgs &args) -> uint64_t { return (is_recommended == nullptr) ? 0 : (is_recommended(args) ? 0 : UINT64_MAX); } ), + instantiate(instantiate) { } + + GemmImplementation(const GemmImplementation &) = default; + GemmImplementation &operator=(const GemmImplementation &) = default; }; /* "Master" function implemented for each valid combination of types. @@ -103,13 +139,11 @@ const GemmImplementation<Top, Tret, OutputStage> *gemm_implementation_list(); /* * Select a GEMM implementation for the given arguments. * - * The logic here returns the first method on the list which supports the + * The logic here returns the method on the list which supports the * requested problem parameters, matches the provided filters (method and/or - * name string match) and recommends itself. - * - * If there is no such method, it will return the first method which - * supports the requested parameters and passes the filters, regardless of - * recommendation. + * name string match) and offers the lowest cycle estimate. A cycle + * estimate of '0' is treated as a special value, causing the corresponding + * method to be selected immediately. 
* * If no method supports the requested parameters and passes the filters, * this function returns false and doesn't touch the provided pointer @@ -121,6 +155,7 @@ bool find_implementation(const GemmArgs &args, const OutputStage &os, const Gemm const GemmConfig *cfg = args._cfg; const GemmImplementation<Top, Tret, OutputStage> *saved_impl = nullptr; + uint64_t best_estimate = 0; for (const GemmImplementation<Top, Tret, OutputStage> *i = gemms; i->method != GemmMethod::DEFAULT; i++) { /* Skip if this implementation doesn't support these args. */ @@ -138,27 +173,24 @@ bool find_implementation(const GemmArgs &args, const OutputStage &os, const Gemm continue; } - /* At this point, if we don't have a saved implementation, save this - * one. This is so that we always return something if a filter - * matches, even if it doesn't recommend itself. - */ - if (saved_impl == nullptr) { - saved_impl=i; - } + /* Test the cycle estimate */ + uint64_t estimate = i->do_cycle_estimate(args, os); - /* Check that this method recommends itself. */ - if (!i->do_is_recommended(args, os)) { - continue; + /* Short circuit - if the estimate is zero, return this one immediately. */ + if (estimate==0) { + impl=i; + return true; } - impl=i; - - return true; + /* Otherwise, remember this is our best so far if we don't yet have + * a valid candidate, or we beat the estimate. */ + if ((saved_impl == nullptr) || (estimate < best_estimate)) { + saved_impl = i; + best_estimate = estimate; + } } - /* We didn't find an option matching the filters that recommended - * itself. But if we found something earlier that matched the filters - * but wasn't recommended, return it here. */ + /* Return whichever method gave the best estimate. */ if (saved_impl != nullptr) { impl = saved_impl; return true; @@ -183,7 +215,7 @@ std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, cons continue; } - res.push_back(KernelDescription(i->method, i->name, i==default_impl)); + res.push_back(KernelDescription(i->method, i->name, i==default_impl, i->do_cycle_estimate(args, os))); } return res; diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp index 3ee47492db..bddcc8dab1 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp @@ -122,6 +122,13 @@ static const GemmImplementation<int8_t, int32_t> gemm_s8_methods[] = { [](const GemmArgs &args) { return new GemmInterleaved<gemm_s8_12x8, int8_t, int32_t>(args); } }, { + GemmMethod::GEMM_INTERLEAVED, + "gemm_s16_12x8", + nullptr, + [](const GemmArgs &args) { return args._ci->get_cpu_model() == CPUModel::A53; }, + [](const GemmArgs &args) { return new GemmInterleaved<gemm_s16_12x8, int8_t, int32_t>(args); }, +}, +{ GemmMethod::GEMM_INTERLEAVED_2D, "gemm_s8_4x4_2d", nullptr, diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp index 3b829491ca..c4dceef922 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp @@ -23,15 +23,14 @@ */ #pragma once -#include <stdio.h> -#include <assert.h> - #include <algorithm> +#include <cassert> #include "arm_gemm.hpp" #include "utils.hpp" #include "mergeresults.hpp" +#include "performance_parameters.hpp" #include "transform.hpp" #ifdef CYCLE_PROFILING @@ -149,6 +148,33 @@ class GemmInterleaved : public GemmCommon<To, Tr> { return ROUND_UP(sizeof(Tri) * _x_block * strategy::out_height()); } + static unsigned int 
get_k_block_size(const GemmArgs &args) { + if (args._cfg && args._cfg->inner_block_size) { + return args._cfg->inner_block_size; + } + + const unsigned int L1_size = args._ci->get_L1_cache_size(); + unsigned int k_block; + + // k_block: Find out how much of the larger array can be loaded into half the cache. + // This should account for associative caches. + k_block = (L1_size / 2) / (sizeof(Toi) * (std::max(strategy::out_width(), strategy::out_height()))); + + // Needs to be (at least a single) multiple of the K unroll level. + k_block /= strategy::k_unroll(); + k_block = std::max(k_block, 1U) * strategy::k_unroll(); + + // Now tune to presented problem size; this is how many blocks we need. + unsigned int num_k_blocks = iceildiv(args._Ksize, k_block); + + // So divide the space equally into that many blocks. + k_block = iceildiv(args._Ksize, num_k_blocks); + + // And round UP to the K unroll level required. + k_block = roundup(k_block, strategy::k_unroll()); + + return k_block; + } public: GemmInterleaved(GemmInterleaved &) = delete; @@ -158,35 +184,14 @@ public: GemmInterleaved(const GemmArgs &args) : _ci(args._ci), _Msize(args._Msize), _Nsize(args._Nsize), _Ksize(args._Ksize), _nbatches(args._nbatches), _nmulti(args._nmulti), - _act(args._act), _maxthreads(args._maxthreads), _nthreads(args._maxthreads) { - const unsigned int L1_size = _ci->get_L1_cache_size(); + _act(args._act), _maxthreads(args._maxthreads), _nthreads(args._maxthreads), + _k_block(get_k_block_size(args)) { const unsigned int L2_size = _ci->get_L2_cache_size(); assert(_maxthreads > 0); // Work out blocking parameters, or override from provided GemmConfig - if (args._cfg && args._cfg->inner_block_size) { - _k_block = args._cfg->inner_block_size; - } else { - // k_block: Find out how much of the larger array can be loaded into half the cache. - // This should account for associative caches. - _k_block = (L1_size / 2) / (sizeof(Toi) * (std::max(strategy::out_width(), strategy::out_height()))); - - // Needs to be (at least a single) multiple of the K unroll level. - _k_block /= strategy::k_unroll(); - _k_block = std::max(_k_block, 1U) * strategy::k_unroll(); - - // Now tune to presented problem size; this is how many blocks we need. - unsigned int num_k_blocks = iceildiv(_Ksize, _k_block); - - // So divide the space equally into that many blocks. - _k_block = iceildiv(_Ksize, num_k_blocks); - - // And round UP to the K unroll level required. - _k_block = iceildiv(_k_block, strategy::k_unroll()); - _k_block *= strategy::k_unroll(); - } - + // TODO: Move outer block into a static function too. 
if (args._cfg && args._cfg->outer_block_size) { _x_block = args._cfg->outer_block_size; } else { @@ -422,6 +427,31 @@ public: void set_pretransposed_B_data(void *in_buffer) override { _B_transposed = reinterpret_cast<Toi *>(in_buffer); } + + // Estimate cycles for given problem given provided parameters + static uint64_t estimate_cycles(const GemmArgs &args, const PerformanceParameters ¶ms) { + unsigned int k_blocks = iceildiv(args._Ksize, get_k_block_size(args)); + + uint64_t total_macs = static_cast<uint64_t>(args._nbatches) * args._nmulti * roundup(args._Msize, strategy::out_height()) * roundup(args._Nsize, strategy::out_width()) * roundup(args._Ksize, strategy::k_unroll()); + uint64_t prepare_bytes = static_cast<uint64_t>(args._nbatches) * args._nmulti * roundup(args._Msize, strategy::out_height()) * roundup(args._Ksize, strategy::k_unroll()) * sizeof(Toi); + uint64_t merge_bytes = static_cast<uint16_t>(args._nbatches) * args._nmulti * k_blocks * roundup(args._Msize, strategy::out_height()) * roundup(args._Nsize, strategy::out_width()) * sizeof(Tr); + + float mac_cycles = static_cast<float>(total_macs) / params.kernel_macs_cycle; + float prepare_cycles = static_cast<float>(prepare_bytes) / params.prepare_bytes_cycle; + float merge_cycles = static_cast<float>(merge_bytes) / params.merge_bytes_cycle; + + float total_cycles = mac_cycles + prepare_cycles + merge_cycles; + + // We can't thread over multis or width, which makes this a poor + // choice in many threaded cases. Penalize that here. + float parallelism_available = static_cast<float>(iceildiv(args._Msize, strategy::out_height()) * args._nbatches) * 0.9f; + + if (parallelism_available < args._maxthreads) { + total_cycles *= (static_cast<float>(args._maxthreads) / parallelism_available); + } + + return static_cast<uint64_t>(total_cycles); + } }; } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved_pretransposed_2d.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved_pretransposed_2d.hpp index ebe33ab271..bdccd05326 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved_pretransposed_2d.hpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved_pretransposed_2d.hpp @@ -35,6 +35,7 @@ #include <algorithm> #include <cassert> +#include <cmath> // Some macros used to decide how much working space to allocate. // Round allocations up to the next cache line. @@ -301,6 +302,36 @@ class GemmInterleavedPretransposed2d : public GemmCommon<To, Tr> { } } + static unsigned int get_k_block_size(const GemmArgs &args) { + // Work out blocking parameters, or override from provided GemmConfig + if (args._cfg && args._cfg->inner_block_size) { + return args._cfg->inner_block_size; + } + + const unsigned int L1_size = args._ci->get_L1_cache_size(); + unsigned int k_block; + + // k_block: Find out how much of the larger array can be loaded into half the cache. + // This should account for associative caches. + k_block = (L1_size / 2) / (sizeof(Toi) * (std::max(strategy::out_width(), strategy::out_height()))); + + // Needs to be (at least a single) multiple of the K unroll level. + k_block /= strategy::k_unroll(); + k_block = std::max(k_block, 1U) * strategy::k_unroll(); + + // Now tune to presented problem size; this is how many blocks we need. + unsigned int numk_blocks = iceildiv(args._Ksize, k_block); + + // So divide the space equally into that many blocks. + k_block = iceildiv(args._Ksize, numk_blocks); + + // And round UP to the K unroll level required. 
+ k_block = iceildiv(k_block, strategy::k_unroll()); + k_block *= strategy::k_unroll(); + + return k_block; + } + public: GemmInterleavedPretransposed2d(GemmInterleavedPretransposed2d &) = delete; GemmInterleavedPretransposed2d & operator= (GemmInterleavedPretransposed2d &) = delete; @@ -315,8 +346,8 @@ public: , _nmulti(args._nmulti) , _act(args._act) , _maxthreads(args._maxthreads) - , _nthreads(args._maxthreads) - + , _nthreads(args._maxthreads) + , _k_block(get_k_block_size(args)) // Work out the rounded size of M - needed for some buffers. , _Mround_div ( iceildiv(_Msize, strategy::out_height()) ) , _Mround ( _Mround_div * strategy::out_height() ) @@ -326,32 +357,8 @@ public: { assert(_maxthreads > 0); - const unsigned int L1_size = _ci->get_L1_cache_size(); const unsigned int L2_size = _ci->get_L2_cache_size(); - // Work out blocking parameters, or override from provided GemmConfig - if (args._cfg && args._cfg->inner_block_size) { - _k_block = args._cfg->inner_block_size; - } else { - // k_block: Find out how much of the larger array can be loaded into half the cache. - // This should account for associative caches. - _k_block = (L1_size / 2) / (sizeof(Toi) * (std::max(strategy::out_width(), strategy::out_height()))); - - // Needs to be (at least a single) multiple of the K unroll level. - _k_block /= strategy::k_unroll(); - _k_block = std::max(_k_block, 1U) * strategy::k_unroll(); - - // Now tune to presented problem size; this is how many blocks we need. - unsigned int num_k_blocks = iceildiv(_Ksize, _k_block); - - // So divide the space equally into that many blocks. - _k_block = iceildiv(_Ksize, num_k_blocks); - - // And round UP to the K unroll level required. - _k_block = iceildiv(_k_block, strategy::k_unroll()); - _k_block *= strategy::k_unroll(); - } - if (args._cfg && args._cfg->outer_block_size) { _x_block = args._cfg->outer_block_size; } else { @@ -381,6 +388,10 @@ public: return { m, n }; } + bool supports_dynamic_scheduling() const override { + return true; + } + // set_nthreads: pass on to buffer manager to avoid it waiting for non-existant threads. void set_nthreads(int nthreads) override { _nthreads = std::min(nthreads, _maxthreads); @@ -495,7 +506,60 @@ public: _B_transposed = reinterpret_cast<Toi *>(in_buffer); } - ~GemmInterleavedPretransposed2d() override { } + // Estimate cycles for given problem given provided parameters + static uint64_t estimate_cycles(const GemmArgs &args, const PerformanceParameters ¶ms) { + unsigned int k_blocks = iceildiv(args._Ksize, get_k_block_size(args)); + unsigned int m_blocks = iceildiv(args._Msize, strategy::out_height()) * args._nbatches; + unsigned int n_blocks = iceildiv(args._Nsize, strategy::out_width()); + + uint64_t total_macs = static_cast<uint64_t>(args._nbatches) * args._nmulti * roundup(args._Msize, strategy::out_height()) * roundup(args._Nsize, strategy::out_width()) * roundup(args._Ksize, strategy::k_unroll()); + uint64_t prepare_bytes = static_cast<uint64_t>(args._nbatches) * args._nmulti * roundup(args._Msize, strategy::out_height()) * roundup(args._Ksize, strategy::k_unroll()) * sizeof(Toi); + uint64_t merge_bytes = static_cast<uint64_t>(args._nbatches) * args._nmulti * k_blocks * roundup(args._Msize, strategy::out_height()) * roundup(args._Nsize, strategy::out_width()) * sizeof(Tr); + + // Wide problems incur extra preparation cost, as it is done per thread. 
+ // Duplicate the logic the scheduler will later use to figure out how much that will affect us + float ratio = m_blocks / static_cast<float>(n_blocks); + + unsigned int ideal_height = static_cast<unsigned int>(std::sqrt(args._maxthreads * ratio) + 0.5); + unsigned int height = 1; + + if (ideal_height == 0) { + height = 1; + } else { + for (unsigned int adj=0; adj<ideal_height; adj++) { + const unsigned int round_down = ideal_height - adj; + if (args._maxthreads % round_down == 0) { + height = round_down; + break; + } + + const unsigned int round_up = ideal_height + adj; + if (args._maxthreads % round_up == 0) { + height = round_up; + break; + } + } + } + + // We've computed the height here - we need to multiply the amount of preparation effort by the width (which is total threads / height) + prepare_bytes *= (args._maxthreads / height); + + float mac_cycles = static_cast<float>(total_macs) / params.kernel_macs_cycle; + float prepare_cycles = static_cast<float>(prepare_bytes) / params.prepare_bytes_cycle; + float merge_cycles = static_cast<float>(merge_bytes) / params.merge_bytes_cycle; + + float total_cycles = mac_cycles + prepare_cycles + merge_cycles; + + // We can't thread over multis, which might be a problem in some + // threaded cases. Penalize that here. + float parallelism_available = static_cast<float>(iceildiv(args._Msize, strategy::out_height()) * args._nbatches * iceildiv(args._Nsize, strategy::out_width())) * 0.9; + + if (parallelism_available < args._maxthreads) { + total_cycles *= (static_cast<float>(args._maxthreads) / parallelism_available); + } + + return static_cast<uint64_t>(total_cycles); + } }; } // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp index caab2e2cc2..88726b1448 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp @@ -101,6 +101,13 @@ static const GemmImplementation<uint8_t, uint32_t> gemm_u8_methods[] = { [](const GemmArgs &args) { return new GemmHybrid<smallK_hybrid_u8u32_dot_4x6, uint8_t, uint32_t>(args); } }, { + GemmMethod::GEMM_INTERLEAVED, + "gemm_u16_12x8", + nullptr, + [](const GemmArgs &args) { return args._ci->get_cpu_model() == CPUModel::A53; }, + [](const GemmArgs &args) { return new GemmInterleaved<gemm_u16_12x8, uint8_t, uint32_t>(args); }, +}, +{ GemmMethod::GEMM_HYBRID, "hybrid_u8u32_dot_16x4", [](const GemmArgs &args) { return args._ci->has_dotprod() && args._Ksize>=16; }, diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_16x4.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_16x4.hpp index 8d8ede8137..4147ab60dc 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_16x4.hpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hybrid_fp32_mla_16x4.hpp @@ -25,7 +25,7 @@ #ifdef __aarch64__ - +#include "../performance_parameters.hpp" #include "../std_transforms_fixed.hpp" namespace arm_gemm @@ -75,6 +75,22 @@ public: return true; } + static PerformanceParameters get_performance_parameters(const CPUInfo *ci) { + switch (ci->get_cpu_model()) { + case CPUModel::A55r1: + return { 2.866 }; + + case CPUModel::A53: + return { 1.419 }; + + case CPUModel::A73: + return { 2.551 }; + + default: + return { 6.25 }; + } + } + StdTransformsFixed<operand_type, result_type, 4, 16, 1> transforms = {}; // Default to the generic kernel diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8.hpp index 
5c3d6409b9..981ce34b49 100644 --- a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8.hpp +++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8.hpp @@ -67,6 +67,22 @@ public: // Use the standard fixed size transforms. StdTransformsFixed<operand_type, result_type, 8, 12> transforms = {}; + static PerformanceParameters get_performance_parameters(const CPUInfo *ci) { + switch (ci->get_cpu_model()) { + case CPUModel::A55r1: + return { 3.724, 1.416, 1.113 }; + + case CPUModel::A53: + return { 2.777, 0.987, 0.898 }; + + case CPUModel::A73: + return { 2.885, 1.429, 1.163 }; + + default: + return { 6.949, 4.149, 2.826 }; + } + } + kern_type kernel=a64_sgemm_asimd_12x8; sgemm_12x8(const CPUInfo *ci) { diff --git a/src/core/NEON/kernels/arm_gemm/performance_parameters.hpp b/src/core/NEON/kernels/arm_gemm/performance_parameters.hpp new file mode 100644 index 0000000000..059ab5f7df --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/performance_parameters.hpp @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2020 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +namespace arm_gemm { + +struct PerformanceParameters { + float kernel_macs_cycle; + float prepare_bytes_cycle = 0.0f; + float merge_bytes_cycle = 0.0f; + + PerformanceParameters(float k) : kernel_macs_cycle(k) { } + PerformanceParameters(float k, float p, float m) : kernel_macs_cycle(k), prepare_bytes_cycle(p), merge_bytes_cycle(m) { } +}; + +} // namespace arm_gemm diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_s8_to_s16.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_s8_to_s16.hpp new file mode 100644 index 0000000000..37344a82a9 --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_s8_to_s16.hpp @@ -0,0 +1,224 @@ +/* + * Copyright (c) 2017-2020 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#if defined(__aarch64__) && defined(__ARM_FP16_ARGS) + +#include <arm_neon.h> +#include <cstdint> + +#include "../asmlib.hpp" + +template<> +template<> +inline void TransformImpl<8, 1, false, 2, 1, false>::Transform(int16_t *out, const int8_t *in, int ldin, int y0, int ymax, int k0, int kmax) { + int16_t *outptr = out; + const int8_t *inptr = in; + bool first = true; + + int8_t zerobuff[32] = { 0 }; // 16 for asm loop plus up to 15 for overflow loop + + for (int y=y0; y<ymax; y+=8) { + const int8_t *inptr0 = inptr + y * ldin + k0; + const int8_t *inptr1 = inptr0 + ldin; + const int8_t *inptr2 = inptr1 + ldin; + const int8_t *inptr3 = inptr2 + ldin; + const int8_t *inptr4 = inptr3 + ldin; + const int8_t *inptr5 = inptr4 + ldin; + const int8_t *inptr6 = inptr5 + ldin; + const int8_t *inptr7 = inptr6 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int x=(kmax-k0); + for (;(x>15) || first;x-=16) { + /* Cope with ragged cases by copying from a buffer of zeroes instead */ + /* 'first' forces this to always run at least once, needed if the total size is <=7. */ + if ((y + 7) >= ymax) { + switch ((y + 7) - ymax) { + case 6: + inptr1 = zerobuff; + // fall through + case 5: + inptr2 = zerobuff; + // fall through + case 4: + inptr3 = zerobuff; + // fall through + case 3: + inptr4 = zerobuff; + // fall through + case 2: + inptr5 = zerobuff; + // fall through + case 1: + inptr6 = zerobuff; + // fall through + case 0: + inptr7 = zerobuff; + break; + + default: + UNREACHABLE("Impossible."); + } + } + + if (first) { + if (x<=15) { + break; + } + + first = false; + } + + __asm __volatile ( + // Load up 16 elements (1 source vector, 2 destination vectors) from each of 8 sources. 
+ "LDR q0, [%[inptr0]], #16\n" + "LDR q2, [%[inptr1]], #16\n" + "SSHLL2 v1.8h, v0.16b, #0\n" + "SSHLL v0.8h, v0.8b, #0\n" + "LDR q4, [%[inptr2]], #16\n" // q4=C0C1C2C3 + "SSHLL2 v3.8h, v2.16b, #0\n" + "SSHLL v2.8h, v2.8b, #0\n" + "SSHLL2 v5.8h, v4.16b, #0\n" + "SSHLL v4.8h, v4.8b, #0\n" + "ZIP1 v16.8h, v0.8h, v4.8h\n" // q16=A0C0A1C1 + ASM_PREFETCH("[%[inptr0], #128]") + "LDR q6, [%[inptr3]], #16\n" // q6=D0D1D2D3 + "SSHLL2 v7.8h, v6.16b, #0\n" + "SSHLL v6.8h, v6.8b, #0\n" + "ZIP1 v17.8h, v2.8h, v6.8h\n" // q17=B0D0B1D1 + "LDR q8, [%[inptr4]], #16\n" + "LDR q10, [%[inptr5]], #16\n" + "SSHLL2 v9.8h, v8.16b, #0\n" + "SSHLL v8.8h, v8.8b, #0\n" + ASM_PREFETCH("[%[inptr1], #128]") + "LDR q12, [%[inptr6]], #16\n" + "SSHLL2 v11.8h, v10.16b, #0\n" + "SSHLL v10.8h, v10.8b, #0\n" + "SSHLL2 v13.8h, v12.16b, #0\n" + "SSHLL v12.8h, v12.8b, #0\n" + "ZIP1 v18.8h, v8.8h, v12.8h\n" + "LDR q14, [%[inptr7]], #16\n" + "SSHLL2 v15.8h, v14.16b, #0\n" + "SSHLL v14.8h, v14.8b, #0\n" + "ZIP1 v19.8h, v10.8h, v14.8h\n" + + ASM_PREFETCH("[%[inptr2], #128]") + "ZIP1 v20.8h, v16.8h, v17.8h\n" // q20=A0B0C0D0A1B1C1D1 + "ZIP1 v21.8h, v18.8h, v19.8h\n" // q21=E0F0G0H0E1F1G1H1 + "ZIP2 v22.8h, v16.8h, v17.8h\n" // q22=A2B2C2D2A3B3C3D3 + "ZIP2 v23.8h, v18.8h, v19.8h\n" // q23=E2F2G2H1E3F3G3H3 + ASM_PREFETCH("[%[inptr3], #128]") + + "ZIP2 v16.8h, v0.8h, v4.8h\n" + "ZIP2 v17.8h, v2.8h, v6.8h\n" + "TRN1 v24.2d, v20.2d, v21.2d\n" + "TRN2 v25.2d, v20.2d, v21.2d\n" + + "ZIP2 v18.8h, v8.8h, v12.8h\n" + ASM_PREFETCH("[%[inptr4], #128]") + "ZIP2 v19.8h, v10.8h, v14.8h\n" + "STP q24, q25, [%[outptr]], #32\n" // Write back the first element of each source + "TRN1 v24.2d, v22.2d, v23.2d\n" + "TRN2 v25.2d, v22.2d, v23.2d\n" + + "ZIP1 v20.8h, v16.8h, v17.8h\n" + "ZIP1 v21.8h, v18.8h, v19.8h\n" + ASM_PREFETCH("[%[inptr5], #128]") + "ZIP2 v22.8h, v16.8h, v17.8h\n" + "ZIP2 v23.8h, v18.8h, v19.8h\n" + "STP q24, q25, [%[outptr]], #32\n" // Write back the second element of each source + + "ZIP1 v16.8h, v1.8h, v5.8h\n" + "ZIP1 v17.8h, v3.8h, v7.8h\n" + ASM_PREFETCH("[%[inptr6], #128]") + "TRN1 v24.2d, v20.2d, v21.2d\n" + "TRN2 v25.2d, v20.2d, v21.2d\n" + + "ZIP1 v18.8h, v9.8h, v13.8h\n" + "ZIP1 v19.8h, v11.8h, v15.8h\n" + "STP q24, q25, [%[outptr]], #32\n" // Third element + "TRN1 v24.2d, v22.2d, v23.2d\n" + "TRN2 v25.2d, v22.2d, v23.2d\n" + ASM_PREFETCH("[%[inptr7], #128]") + + "ZIP1 v20.8h, v16.8h, v17.8h\n" + "ZIP1 v21.8h, v18.8h, v19.8h\n" + "STP q24, q25, [%[outptr]], #32\n" // Fourth element + "ZIP2 v22.8h, v16.8h, v17.8h\n" + "ZIP2 v23.8h, v18.8h, v19.8h\n" + + "ZIP2 v16.8h, v1.8h, v5.8h\n" + "ZIP2 v17.8h, v3.8h, v7.8h\n" + "TRN1 v24.2d, v20.2d, v21.2d\n" + "TRN2 v25.2d, v20.2d, v21.2d\n" + + "ZIP2 v18.8h, v9.8h, v13.8h\n" + "ZIP2 v19.8h, v11.8h, v15.8h\n" + "STP q24, q25, [%[outptr]], #32\n" // Fifth element + "TRN1 v24.2d, v22.2d, v23.2d\n" + "TRN2 v25.2d, v22.2d, v23.2d\n" + + "ZIP1 v20.8h, v16.8h, v17.8h\n" + "ZIP1 v21.8h, v18.8h, v19.8h\n" + "STP q24, q25, [%[outptr]], #32\n" // Sixth element + "TRN1 v24.2d, v20.2d, v21.2d\n" + "TRN2 v25.2d, v20.2d, v21.2d\n" + + "ZIP2 v22.8h, v16.8h, v17.8h\n" + "ZIP2 v23.8h, v18.8h, v19.8h\n" + "STP q24, q25, [%[outptr]], #32\n" // Seventh element + "TRN1 v24.2d, v22.2d, v23.2d\n" + "TRN2 v25.2d, v22.2d, v23.2d\n" + "STP q24, q25, [%[outptr]], #32\n" // Eighth element + : [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), + [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5), [inptr6] "+r" (inptr6), [inptr7] "+r" (inptr7), [outptr] "+r" (outptr) + : + : "v0", 
"v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "memory" + ); + } + + for (;x>0;x--) { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + *outptr++ = *inptr4++; + *outptr++ = *inptr5++; + *outptr++ = *inptr6++; + *outptr++ = *inptr7++; + } + } +} + +#endif // __aarch64__ && __ARM_FP16_ARGS diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_u8_to_u16.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_u8_to_u16.hpp new file mode 100644 index 0000000000..a3a269c9cd --- /dev/null +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_u8_to_u16.hpp @@ -0,0 +1,224 @@ +/* + * Copyright (c) 2017-2020 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#pragma once + +#if defined(__aarch64__) && defined(__ARM_FP16_ARGS) + +#include <arm_neon.h> +#include <cstdint> + +#include "../asmlib.hpp" + +template<> +template<> +inline void TransformImpl<8, 1, false, 2, 1, false>::Transform(uint16_t *out, const uint8_t *in, int ldin, int y0, int ymax, int k0, int kmax) { + uint16_t *outptr = out; + const uint8_t *inptr = in; + bool first = true; + + uint8_t zerobuff[32] = { 0 }; // 16 for asm loop plus up to 15 for overflow loop + + for (int y=y0; y<ymax; y+=8) { + const uint8_t *inptr0 = inptr + y * ldin + k0; + const uint8_t *inptr1 = inptr0 + ldin; + const uint8_t *inptr2 = inptr1 + ldin; + const uint8_t *inptr3 = inptr2 + ldin; + const uint8_t *inptr4 = inptr3 + ldin; + const uint8_t *inptr5 = inptr4 + ldin; + const uint8_t *inptr6 = inptr5 + ldin; + const uint8_t *inptr7 = inptr6 + ldin; + + prefetch_2x(inptr0); + prefetch_2x(inptr1); + prefetch_2x(inptr2); + prefetch_2x(inptr3); + prefetch_2x(inptr4); + prefetch_2x(inptr5); + prefetch_2x(inptr6); + prefetch_2x(inptr7); + + int x=(kmax-k0); + for (;(x>15) || first;x-=16) { + /* Cope with ragged cases by copying from a buffer of zeroes instead */ + /* 'first' forces this to always run at least once, needed if the total size is <=7. 
*/ + if ((y + 7) >= ymax) { + switch ((y + 7) - ymax) { + case 6: + inptr1 = zerobuff; + // fall through + case 5: + inptr2 = zerobuff; + // fall through + case 4: + inptr3 = zerobuff; + // fall through + case 3: + inptr4 = zerobuff; + // fall through + case 2: + inptr5 = zerobuff; + // fall through + case 1: + inptr6 = zerobuff; + // fall through + case 0: + inptr7 = zerobuff; + break; + + default: + UNREACHABLE("Impossible."); + } + } + + if (first) { + if (x<=15) { + break; + } + + first = false; + } + + __asm __volatile ( + // Load up 16 elements (1 source vector, 2 destination vectors) from each of 8 sources. + "LDR q0, [%[inptr0]], #16\n" + "LDR q2, [%[inptr1]], #16\n" + "USHLL2 v1.8h, v0.16b, #0\n" + "USHLL v0.8h, v0.8b, #0\n" + "LDR q4, [%[inptr2]], #16\n" // q4=C0C1C2C3 + "USHLL2 v3.8h, v2.16b, #0\n" + "USHLL v2.8h, v2.8b, #0\n" + "USHLL2 v5.8h, v4.16b, #0\n" + "USHLL v4.8h, v4.8b, #0\n" + "ZIP1 v16.8h, v0.8h, v4.8h\n" // q16=A0C0A1C1 + ASM_PREFETCH("[%[inptr0], #128]") + "LDR q6, [%[inptr3]], #16\n" // q6=D0D1D2D3 + "USHLL2 v7.8h, v6.16b, #0\n" + "USHLL v6.8h, v6.8b, #0\n" + "ZIP1 v17.8h, v2.8h, v6.8h\n" // q17=B0D0B1D1 + "LDR q8, [%[inptr4]], #16\n" + "LDR q10, [%[inptr5]], #16\n" + "USHLL2 v9.8h, v8.16b, #0\n" + "USHLL v8.8h, v8.8b, #0\n" + ASM_PREFETCH("[%[inptr1], #128]") + "LDR q12, [%[inptr6]], #16\n" + "USHLL2 v11.8h, v10.16b, #0\n" + "USHLL v10.8h, v10.8b, #0\n" + "USHLL2 v13.8h, v12.16b, #0\n" + "USHLL v12.8h, v12.8b, #0\n" + "ZIP1 v18.8h, v8.8h, v12.8h\n" + "LDR q14, [%[inptr7]], #16\n" + "USHLL2 v15.8h, v14.16b, #0\n" + "USHLL v14.8h, v14.8b, #0\n" + "ZIP1 v19.8h, v10.8h, v14.8h\n" + + ASM_PREFETCH("[%[inptr2], #128]") + "ZIP1 v20.8h, v16.8h, v17.8h\n" // q20=A0B0C0D0A1B1C1D1 + "ZIP1 v21.8h, v18.8h, v19.8h\n" // q21=E0F0G0H0E1F1G1H1 + "ZIP2 v22.8h, v16.8h, v17.8h\n" // q22=A2B2C2D2A3B3C3D3 + "ZIP2 v23.8h, v18.8h, v19.8h\n" // q23=E2F2G2H1E3F3G3H3 + ASM_PREFETCH("[%[inptr3], #128]") + + "ZIP2 v16.8h, v0.8h, v4.8h\n" + "ZIP2 v17.8h, v2.8h, v6.8h\n" + "TRN1 v24.2d, v20.2d, v21.2d\n" + "TRN2 v25.2d, v20.2d, v21.2d\n" + + "ZIP2 v18.8h, v8.8h, v12.8h\n" + ASM_PREFETCH("[%[inptr4], #128]") + "ZIP2 v19.8h, v10.8h, v14.8h\n" + "STP q24, q25, [%[outptr]], #32\n" // Write back the first element of each source + "TRN1 v24.2d, v22.2d, v23.2d\n" + "TRN2 v25.2d, v22.2d, v23.2d\n" + + "ZIP1 v20.8h, v16.8h, v17.8h\n" + "ZIP1 v21.8h, v18.8h, v19.8h\n" + ASM_PREFETCH("[%[inptr5], #128]") + "ZIP2 v22.8h, v16.8h, v17.8h\n" + "ZIP2 v23.8h, v18.8h, v19.8h\n" + "STP q24, q25, [%[outptr]], #32\n" // Write back the second element of each source + + "ZIP1 v16.8h, v1.8h, v5.8h\n" + "ZIP1 v17.8h, v3.8h, v7.8h\n" + ASM_PREFETCH("[%[inptr6], #128]") + "TRN1 v24.2d, v20.2d, v21.2d\n" + "TRN2 v25.2d, v20.2d, v21.2d\n" + + "ZIP1 v18.8h, v9.8h, v13.8h\n" + "ZIP1 v19.8h, v11.8h, v15.8h\n" + "STP q24, q25, [%[outptr]], #32\n" // Third element + "TRN1 v24.2d, v22.2d, v23.2d\n" + "TRN2 v25.2d, v22.2d, v23.2d\n" + ASM_PREFETCH("[%[inptr7], #128]") + + "ZIP1 v20.8h, v16.8h, v17.8h\n" + "ZIP1 v21.8h, v18.8h, v19.8h\n" + "STP q24, q25, [%[outptr]], #32\n" // Fourth element + "ZIP2 v22.8h, v16.8h, v17.8h\n" + "ZIP2 v23.8h, v18.8h, v19.8h\n" + + "ZIP2 v16.8h, v1.8h, v5.8h\n" + "ZIP2 v17.8h, v3.8h, v7.8h\n" + "TRN1 v24.2d, v20.2d, v21.2d\n" + "TRN2 v25.2d, v20.2d, v21.2d\n" + + "ZIP2 v18.8h, v9.8h, v13.8h\n" + "ZIP2 v19.8h, v11.8h, v15.8h\n" + "STP q24, q25, [%[outptr]], #32\n" // Fifth element + "TRN1 v24.2d, v22.2d, v23.2d\n" + "TRN2 v25.2d, v22.2d, v23.2d\n" + + "ZIP1 v20.8h, v16.8h, v17.8h\n" + "ZIP1 v21.8h, 
v18.8h, v19.8h\n" + "STP q24, q25, [%[outptr]], #32\n" // Sixth element + "TRN1 v24.2d, v20.2d, v21.2d\n" + "TRN2 v25.2d, v20.2d, v21.2d\n" + + "ZIP2 v22.8h, v16.8h, v17.8h\n" + "ZIP2 v23.8h, v18.8h, v19.8h\n" + "STP q24, q25, [%[outptr]], #32\n" // Seventh element + "TRN1 v24.2d, v22.2d, v23.2d\n" + "TRN2 v25.2d, v22.2d, v23.2d\n" + "STP q24, q25, [%[outptr]], #32\n" // Eighth element + : [inptr0] "+r" (inptr0), [inptr1] "+r" (inptr1), [inptr2] "+r" (inptr2), [inptr3] "+r" (inptr3), + [inptr4] "+r" (inptr4), [inptr5] "+r" (inptr5), [inptr6] "+r" (inptr6), [inptr7] "+r" (inptr7), [outptr] "+r" (outptr) + : + : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "memory" + ); + } + + for (;x>0;x--) { + *outptr++ = *inptr0++; + *outptr++ = *inptr1++; + *outptr++ = *inptr2++; + *outptr++ = *inptr3++; + *outptr++ = *inptr4++; + *outptr++ = *inptr5++; + *outptr++ = *inptr6++; + *outptr++ = *inptr7++; + } + } +} + +#endif // __aarch64__ && __ARM_FP16_ARGS diff --git a/src/core/NEON/kernels/arm_gemm/transforms/list.hpp b/src/core/NEON/kernels/arm_gemm/transforms/list.hpp index 2c698b2576..b825e1c358 100644 --- a/src/core/NEON/kernels/arm_gemm/transforms/list.hpp +++ b/src/core/NEON/kernels/arm_gemm/transforms/list.hpp @@ -28,6 +28,8 @@ #include "a64_interleave_8way_32bit.hpp" #include "a64_interleave_8way_block4_8bit.hpp" #include "a64_interleave_8way_half_to_float.hpp" +#include "a64_interleave_8way_s8_to_s16.hpp" +#include "a64_interleave_8way_u8_to_u16.hpp" #include "a64_transpose_interleave_12way_16bit.hpp" #include "a64_transpose_interleave_12way_half_to_float.hpp" #include "a64_transpose_interleave_24way_16bit.hpp" |
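Finally, a usage-level sketch of how the rewritten find_implementation() behaves (illustrative only — the candidate type and names here are simplified stand-ins for GemmImplementation<>, not code from the patch): every supported candidate reports a cycle estimate; an estimate of 0 short-circuits the search, otherwise the lowest estimate wins.

```cpp
#include <cstdint>
#include <functional>
#include <string>
#include <vector>

// Simplified stand-in for GemmImplementation<>: a named method with a
// support check and a cycle estimate (0 means "select immediately",
// matching the special value described in the patch).
struct Candidate {
    std::string name;
    std::function<bool()>     supported;
    std::function<uint64_t()> cycle_estimate;
};

// Mirrors the selection loop added to find_implementation(): skip
// unsupported methods, return at once on a zero estimate, otherwise keep
// the candidate with the lowest estimate seen so far.
const Candidate *select_method(const std::vector<Candidate> &methods) {
    const Candidate *best = nullptr;
    uint64_t best_estimate = 0;

    for (const auto &m : methods) {
        if (!m.supported()) {
            continue;
        }

        uint64_t estimate = m.cycle_estimate();
        if (estimate == 0) {
            return &m;
        }

        if (best == nullptr || estimate < best_estimate) {
            best = &m;
            best_estimate = estimate;
        }
    }

    return best;  // nullptr if nothing was supported
}
```

In the patch itself, legacy entries that still supply a boolean `is_recommended` are adapted to this scheme by the compatibility constructors in gemm_implementation.hpp, which map "recommended" to an estimate of 0 and "not recommended" to UINT64_MAX, so existing method tables keep their previous priority ordering.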