diff options
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/misc.cpp')
-rw-r--r-- | src/core/NEON/kernels/arm_gemm/misc.cpp | 44 |
1 files changed, 42 insertions, 2 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/misc.cpp b/src/core/NEON/kernels/arm_gemm/misc.cpp index 229e6b56f9..cf99bbdb46 100644 --- a/src/core/NEON/kernels/arm_gemm/misc.cpp +++ b/src/core/NEON/kernels/arm_gemm/misc.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 Arm Limited. + * Copyright (c) 2017-2018, 2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -25,6 +25,11 @@ #ifndef NO_MULTI_THREADING #include <mutex> #endif +#include <cstdint> + +#include "arm_gemm.hpp" +#include "kernel_weight_format.hpp" +#include "utils.hpp" namespace arm_gemm { @@ -32,4 +37,39 @@ namespace arm_gemm { std::mutex report_mutex; #endif -} // namespace arm_gemm
\ No newline at end of file +WeightFormat get_weight_format(const KernelWeightFormat kwf, size_t element_size) { + if (kwf==KernelWeightFormat::NON_FIXED) { + return WeightFormat::UNSPECIFIED; + } + + uint32_t kwf_i = static_cast<uint32_t>(kwf); + uint32_t wf_i = 0; + + const auto block_bytes = (kwf_i >> 8) & 0xf; + const auto vector_count = (kwf_i >> 12) & 0xf; + + uint32_t vector_bytes; + + // For fast mode BF16 kernels set the appropriate bit and override element size to 2. + if (kwf_i & 0x10) { + element_size = 2; + wf_i |= 0x10; + } + + // Get total bytes in vector output + if (kwf_i & 0x1) { + vector_bytes = vector_count * get_vector_length<uint8_t>(); + } else { + vector_bytes = vector_count * 16; + } + + auto input_blocking = block_bytes / element_size; + auto output_blocking = vector_bytes / block_bytes; + + wf_i |= (input_blocking << 20); + wf_i |= (output_blocking << 8); + + return static_cast<WeightFormat>(wf_i); +} + +} // namespace arm_gemm |