diff options
Diffstat (limited to 'src/core/utils/AssemblyUtils.cpp')
-rw-r--r-- | src/core/utils/AssemblyUtils.cpp | 242 |
1 files changed, 241 insertions, 1 deletions
diff --git a/src/core/utils/AssemblyUtils.cpp b/src/core/utils/AssemblyUtils.cpp index 1e8a2a54c9..45e7ff78be 100644 --- a/src/core/utils/AssemblyUtils.cpp +++ b/src/core/utils/AssemblyUtils.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. + * Copyright (c) 2021-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -66,5 +66,245 @@ arm_conv::PaddingValues map_to_arm_conv_padding(const PadStrideInfo &pad_stride_ pad_stride_info.pad_right(), pad_stride_info.pad_bottom() }; } + +arm_gemm::WeightFormat map_to_arm_gemm_weight_format(const arm_compute::WeightFormat &weight_format) +{ + arm_gemm::WeightFormat gemm_weight_fromat; + + switch(weight_format) + { + case arm_compute::WeightFormat::UNSPECIFIED: + gemm_weight_fromat = arm_gemm::WeightFormat::UNSPECIFIED; + break; + case arm_compute::WeightFormat::ANY: + gemm_weight_fromat = arm_gemm::WeightFormat::ANY; + break; + case arm_compute::WeightFormat::OHWI: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWI; + break; + case arm_compute::WeightFormat::OHWIo2: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo2; + break; + case arm_compute::WeightFormat::OHWIo4: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4; + break; + case arm_compute::WeightFormat::OHWIo8: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8; + break; + case arm_compute::WeightFormat::OHWIo16: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16; + break; + case arm_compute::WeightFormat::OHWIo32: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32; + break; + case arm_compute::WeightFormat::OHWIo64: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64; + break; + case arm_compute::WeightFormat::OHWIo128: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo128; + break; + case arm_compute::WeightFormat::OHWIo4i2: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4i2; + break; + case arm_compute::WeightFormat::OHWIo4i2_bf16: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4i2_bf16; + break; + case arm_compute::WeightFormat::OHWIo8i2: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8i2; + break; + case arm_compute::WeightFormat::OHWIo8i2_bf16: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8i2_bf16; + break; + case arm_compute::WeightFormat::OHWIo16i2: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16i2; + break; + case arm_compute::WeightFormat::OHWIo16i2_bf16: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16i2_bf16; + break; + case arm_compute::WeightFormat::OHWIo32i2: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32i2; + break; + case arm_compute::WeightFormat::OHWIo32i2_bf16: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32i2_bf16; + break; + case arm_compute::WeightFormat::OHWIo64i2: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64i2; + break; + case arm_compute::WeightFormat::OHWIo64i2_bf16: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64i2_bf16; + break; + case arm_compute::WeightFormat::OHWIo4i4: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4i4; + break; + case arm_compute::WeightFormat::OHWIo4i4_bf16: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4i4_bf16; + break; + case arm_compute::WeightFormat::OHWIo8i4: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8i4; + break; + case arm_compute::WeightFormat::OHWIo8i4_bf16: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8i4_bf16; + break; + case arm_compute::WeightFormat::OHWIo16i4: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16i4; + break; + case arm_compute::WeightFormat::OHWIo16i4_bf16: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16i4_bf16; + break; + case arm_compute::WeightFormat::OHWIo32i4: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32i4; + break; + case arm_compute::WeightFormat::OHWIo32i4_bf16: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32i4_bf16; + break; + case arm_compute::WeightFormat::OHWIo64i4: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64i4; + break; + case arm_compute::WeightFormat::OHWIo64i4_bf16: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64i4_bf16; + break; + case arm_compute::WeightFormat::OHWIo2i8: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo2i8; + break; + case arm_compute::WeightFormat::OHWIo4i8: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4i8; + break; + case arm_compute::WeightFormat::OHWIo8i8: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8i8; + break; + case arm_compute::WeightFormat::OHWIo16i8: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16i8; + break; + case arm_compute::WeightFormat::OHWIo32i8: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32i8; + break; + case arm_compute::WeightFormat::OHWIo64i8: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64i8; + break; + default: + gemm_weight_fromat = arm_gemm::WeightFormat::UNSPECIFIED; + } + return gemm_weight_fromat; +} + +arm_compute::WeightFormat map_to_arm_compute_weight_format(const arm_gemm::WeightFormat &weight_format) +{ + arm_compute::WeightFormat acl_weight_fromat; + + switch(weight_format) + { + case arm_gemm::WeightFormat::UNSPECIFIED: + acl_weight_fromat = arm_compute::WeightFormat::UNSPECIFIED; + break; + case arm_gemm::WeightFormat::ANY: + acl_weight_fromat = arm_compute::WeightFormat::ANY; + break; + case arm_gemm::WeightFormat::OHWI: + acl_weight_fromat = arm_compute::WeightFormat::OHWI; + break; + case arm_gemm::WeightFormat::OHWIo2: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo2; + break; + case arm_gemm::WeightFormat::OHWIo4: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo4; + break; + case arm_gemm::WeightFormat::OHWIo8: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo8; + break; + case arm_gemm::WeightFormat::OHWIo16: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo16; + break; + case arm_gemm::WeightFormat::OHWIo32: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo32; + break; + case arm_gemm::WeightFormat::OHWIo64: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo64; + break; + case arm_gemm::WeightFormat::OHWIo128: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo128; + break; + case arm_gemm::WeightFormat::OHWIo4i2: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo4i2; + break; + case arm_gemm::WeightFormat::OHWIo4i2_bf16: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo4i2_bf16; + break; + case arm_gemm::WeightFormat::OHWIo8i2: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo8i2; + break; + case arm_gemm::WeightFormat::OHWIo8i2_bf16: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo8i2_bf16; + break; + case arm_gemm::WeightFormat::OHWIo16i2: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo16i2; + break; + case arm_gemm::WeightFormat::OHWIo16i2_bf16: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo16i2_bf16; + break; + case arm_gemm::WeightFormat::OHWIo32i2: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo32i2; + break; + case arm_gemm::WeightFormat::OHWIo32i2_bf16: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo32i2_bf16; + break; + case arm_gemm::WeightFormat::OHWIo64i2: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo64i2; + break; + case arm_gemm::WeightFormat::OHWIo64i2_bf16: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo64i2_bf16; + break; + case arm_gemm::WeightFormat::OHWIo4i4: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo4i4; + break; + case arm_gemm::WeightFormat::OHWIo4i4_bf16: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo4i4_bf16; + break; + case arm_gemm::WeightFormat::OHWIo8i4: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo8i4; + break; + case arm_gemm::WeightFormat::OHWIo8i4_bf16: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo8i4_bf16; + break; + case arm_gemm::WeightFormat::OHWIo16i4: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo16i4; + break; + case arm_gemm::WeightFormat::OHWIo16i4_bf16: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo16i4_bf16; + break; + case arm_gemm::WeightFormat::OHWIo32i4: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo32i4; + break; + case arm_gemm::WeightFormat::OHWIo32i4_bf16: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo32i4_bf16; + break; + case arm_gemm::WeightFormat::OHWIo64i4: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo64i4; + break; + case arm_gemm::WeightFormat::OHWIo64i4_bf16: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo64i4_bf16; + break; + case arm_gemm::WeightFormat::OHWIo2i8: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo2i8; + break; + case arm_gemm::WeightFormat::OHWIo4i8: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo4i8; + break; + case arm_gemm::WeightFormat::OHWIo8i8: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo8i8; + break; + case arm_gemm::WeightFormat::OHWIo16i8: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo16i8; + break; + case arm_gemm::WeightFormat::OHWIo32i8: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo32i8; + break; + case arm_gemm::WeightFormat::OHWIo64i8: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo64i8; + break; + default: + acl_weight_fromat = arm_compute::WeightFormat::UNSPECIFIED; + } + return acl_weight_fromat; +} } // namespace assembly_utils } // namespace arm_compute |