aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp')
-rw-r--r--src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp30
1 files changed, 22 insertions, 8 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
index b9729d4c5c..484892dc81 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
@@ -21,7 +21,9 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
+
+// This can only be built if the target/compiler supports FP16 arguments.
+#ifdef __ARM_FP16_ARGS
#include "arm_gemm.hpp"
@@ -40,26 +42,38 @@ UniqueGemmCommon<__fp16, __fp16> gemm(const CPUInfo &ci, const unsigned int M, c
const int maxthreads, const bool pretransposed_hint)
{
#ifdef __aarch64__
- /* If FP16 is supported, use it */
- if(ci.has_fp16())
+
+ // Only consider the native FP16 kernel if it will get built.
+#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) || defined(FP16_KERNELS)
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+ // If the compiler is configured to enable this feature always, then assume it is available at runtime too.
+ const bool use_fp16 = true;
+#else
+ // Otherwise, detect at runtime via CPUInfo.
+ const bool use_fp16 = ci.has_fp16();
+#endif
+
+ // If FP16 is supported, use it.
+ if(use_fp16)
{
return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<hgemm_24x8, __fp16, __fp16>(&ci, M, N, K, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
}
+#endif
- /* Fallback to using the blocked SGEMM kernel. */
+ // Fallback to using the blocked SGEMM kernel.
return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<sgemm_12x8, __fp16, __fp16>(&ci, M, N, K, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
#else
- /* For AArch32, only support the SGEMM route. */
+ // For AArch32, only support the SGEMM route for now.
return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<sgemm_8x6, __fp16, __fp16>(&ci, M, N, K, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
#endif
}
-// Instantiate static class members
-#ifdef __aarch64__
+// Instantiate static class members if necessary.
+#if defined(__aarch64__) && (defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) || defined(FP16_KERNELS))
const int hgemm_24x8::out_width;
const int hgemm_24x8::out_height;
#endif
} // namespace arm_gemm
-#endif // __ARM_FEATURE_FP16_SCALAR_ARITHMETIC
+#endif // __ARM_FP16_ARGS