diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp | 30 |
1 files changed, 22 insertions, 8 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp index b9729d4c5c..484892dc81 100644 --- a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp +++ b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp @@ -21,7 +21,9 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC + +// This can only be built if the target/compiler supports FP16 arguments. +#ifdef __ARM_FP16_ARGS #include "arm_gemm.hpp" @@ -40,26 +42,38 @@ UniqueGemmCommon<__fp16, __fp16> gemm(const CPUInfo &ci, const unsigned int M, c const int maxthreads, const bool pretransposed_hint) { #ifdef __aarch64__ - /* If FP16 is supported, use it */ - if(ci.has_fp16()) + + // Only consider the native FP16 kernel if it will get built. +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) || defined(FP16_KERNELS) +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + // If the compiler is configured to enable this feature always, then assume it is available at runtime too. + const bool use_fp16 = true; +#else + // Otherwise, detect at runtime via CPUInfo. + const bool use_fp16 = ci.has_fp16(); +#endif + + // If FP16 is supported, use it. + if(use_fp16) { return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<hgemm_24x8, __fp16, __fp16>(&ci, M, N, K, trA, trB, alpha, beta, maxthreads, pretransposed_hint)); } +#endif - /* Fallback to using the blocked SGEMM kernel. */ + // Fallback to using the blocked SGEMM kernel. return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<sgemm_12x8, __fp16, __fp16>(&ci, M, N, K, trA, trB, alpha, beta, maxthreads, pretransposed_hint)); #else - /* For AArch32, only support the SGEMM route. */ + // For AArch32, only support the SGEMM route for now. return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<sgemm_8x6, __fp16, __fp16>(&ci, M, N, K, trA, trB, alpha, beta, maxthreads, pretransposed_hint)); #endif } -// Instantiate static class members -#ifdef __aarch64__ +// Instantiate static class members if necessary. +#if defined(__aarch64__) && (defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) || defined(FP16_KERNELS)) const int hgemm_24x8::out_width; const int hgemm_24x8::out_height; #endif } // namespace arm_gemm -#endif // __ARM_FEATURE_FP16_SCALAR_ARITHMETIC +#endif // __ARM_FP16_ARGS |