From 91780021e25575086c6c31d014d34b6513649a9d Mon Sep 17 00:00:00 2001 From: Ramy Elgammal Date: Wed, 20 Jul 2022 14:57:37 +0100 Subject: Fix for inclusion of "arm_gemm" from src into "Types.h" from core - Added arm_compute::WeightFormat and converted to/from arm_gemm::WeightFormat when needed through two map function. - Moved to_string(WeightFormat) to TypePrinter.h Resolves: COMPMID-5415 Signed-off-by: Ramy Elgammal Change-Id: I65f7942100bcd4dbf2c5cf6c07f26c8e1e3bf86e Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/c/VisualCompute/ComputeLibrary/+/438511 Tested-by: bsgcomp Reviewed-by: Pablo Tello Reviewed-by: Sicong Li Comments-Addressed: bsgcomp Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7985 Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins Reviewed-by: Michalis Spyrou Benchmark: Arm Jenkins --- src/core/utils/AssemblyUtils.cpp | 242 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 241 insertions(+), 1 deletion(-) (limited to 'src/core/utils/AssemblyUtils.cpp') diff --git a/src/core/utils/AssemblyUtils.cpp b/src/core/utils/AssemblyUtils.cpp index 1e8a2a54c9..45e7ff78be 100644 --- a/src/core/utils/AssemblyUtils.cpp +++ b/src/core/utils/AssemblyUtils.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021 Arm Limited. + * Copyright (c) 2021-2022 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -66,5 +66,245 @@ arm_conv::PaddingValues map_to_arm_conv_padding(const PadStrideInfo &pad_stride_ pad_stride_info.pad_right(), pad_stride_info.pad_bottom() }; } + +arm_gemm::WeightFormat map_to_arm_gemm_weight_format(const arm_compute::WeightFormat &weight_format) +{ + arm_gemm::WeightFormat gemm_weight_fromat; + + switch(weight_format) + { + case arm_compute::WeightFormat::UNSPECIFIED: + gemm_weight_fromat = arm_gemm::WeightFormat::UNSPECIFIED; + break; + case arm_compute::WeightFormat::ANY: + gemm_weight_fromat = arm_gemm::WeightFormat::ANY; + break; + case arm_compute::WeightFormat::OHWI: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWI; + break; + case arm_compute::WeightFormat::OHWIo2: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo2; + break; + case arm_compute::WeightFormat::OHWIo4: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4; + break; + case arm_compute::WeightFormat::OHWIo8: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8; + break; + case arm_compute::WeightFormat::OHWIo16: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16; + break; + case arm_compute::WeightFormat::OHWIo32: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32; + break; + case arm_compute::WeightFormat::OHWIo64: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64; + break; + case arm_compute::WeightFormat::OHWIo128: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo128; + break; + case arm_compute::WeightFormat::OHWIo4i2: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4i2; + break; + case arm_compute::WeightFormat::OHWIo4i2_bf16: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4i2_bf16; + break; + case arm_compute::WeightFormat::OHWIo8i2: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8i2; + break; + case arm_compute::WeightFormat::OHWIo8i2_bf16: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8i2_bf16; + break; + case arm_compute::WeightFormat::OHWIo16i2: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16i2; + break; + case arm_compute::WeightFormat::OHWIo16i2_bf16: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16i2_bf16; + break; + case arm_compute::WeightFormat::OHWIo32i2: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32i2; + break; + case arm_compute::WeightFormat::OHWIo32i2_bf16: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32i2_bf16; + break; + case arm_compute::WeightFormat::OHWIo64i2: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64i2; + break; + case arm_compute::WeightFormat::OHWIo64i2_bf16: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64i2_bf16; + break; + case arm_compute::WeightFormat::OHWIo4i4: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4i4; + break; + case arm_compute::WeightFormat::OHWIo4i4_bf16: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4i4_bf16; + break; + case arm_compute::WeightFormat::OHWIo8i4: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8i4; + break; + case arm_compute::WeightFormat::OHWIo8i4_bf16: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8i4_bf16; + break; + case arm_compute::WeightFormat::OHWIo16i4: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16i4; + break; + case arm_compute::WeightFormat::OHWIo16i4_bf16: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16i4_bf16; + break; + case arm_compute::WeightFormat::OHWIo32i4: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32i4; + break; + case arm_compute::WeightFormat::OHWIo32i4_bf16: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32i4_bf16; + break; + case arm_compute::WeightFormat::OHWIo64i4: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64i4; + break; + case arm_compute::WeightFormat::OHWIo64i4_bf16: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64i4_bf16; + break; + case arm_compute::WeightFormat::OHWIo2i8: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo2i8; + break; + case arm_compute::WeightFormat::OHWIo4i8: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4i8; + break; + case arm_compute::WeightFormat::OHWIo8i8: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8i8; + break; + case arm_compute::WeightFormat::OHWIo16i8: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16i8; + break; + case arm_compute::WeightFormat::OHWIo32i8: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32i8; + break; + case arm_compute::WeightFormat::OHWIo64i8: + gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64i8; + break; + default: + gemm_weight_fromat = arm_gemm::WeightFormat::UNSPECIFIED; + } + return gemm_weight_fromat; +} + +arm_compute::WeightFormat map_to_arm_compute_weight_format(const arm_gemm::WeightFormat &weight_format) +{ + arm_compute::WeightFormat acl_weight_fromat; + + switch(weight_format) + { + case arm_gemm::WeightFormat::UNSPECIFIED: + acl_weight_fromat = arm_compute::WeightFormat::UNSPECIFIED; + break; + case arm_gemm::WeightFormat::ANY: + acl_weight_fromat = arm_compute::WeightFormat::ANY; + break; + case arm_gemm::WeightFormat::OHWI: + acl_weight_fromat = arm_compute::WeightFormat::OHWI; + break; + case arm_gemm::WeightFormat::OHWIo2: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo2; + break; + case arm_gemm::WeightFormat::OHWIo4: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo4; + break; + case arm_gemm::WeightFormat::OHWIo8: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo8; + break; + case arm_gemm::WeightFormat::OHWIo16: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo16; + break; + case arm_gemm::WeightFormat::OHWIo32: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo32; + break; + case arm_gemm::WeightFormat::OHWIo64: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo64; + break; + case arm_gemm::WeightFormat::OHWIo128: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo128; + break; + case arm_gemm::WeightFormat::OHWIo4i2: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo4i2; + break; + case arm_gemm::WeightFormat::OHWIo4i2_bf16: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo4i2_bf16; + break; + case arm_gemm::WeightFormat::OHWIo8i2: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo8i2; + break; + case arm_gemm::WeightFormat::OHWIo8i2_bf16: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo8i2_bf16; + break; + case arm_gemm::WeightFormat::OHWIo16i2: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo16i2; + break; + case arm_gemm::WeightFormat::OHWIo16i2_bf16: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo16i2_bf16; + break; + case arm_gemm::WeightFormat::OHWIo32i2: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo32i2; + break; + case arm_gemm::WeightFormat::OHWIo32i2_bf16: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo32i2_bf16; + break; + case arm_gemm::WeightFormat::OHWIo64i2: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo64i2; + break; + case arm_gemm::WeightFormat::OHWIo64i2_bf16: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo64i2_bf16; + break; + case arm_gemm::WeightFormat::OHWIo4i4: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo4i4; + break; + case arm_gemm::WeightFormat::OHWIo4i4_bf16: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo4i4_bf16; + break; + case arm_gemm::WeightFormat::OHWIo8i4: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo8i4; + break; + case arm_gemm::WeightFormat::OHWIo8i4_bf16: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo8i4_bf16; + break; + case arm_gemm::WeightFormat::OHWIo16i4: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo16i4; + break; + case arm_gemm::WeightFormat::OHWIo16i4_bf16: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo16i4_bf16; + break; + case arm_gemm::WeightFormat::OHWIo32i4: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo32i4; + break; + case arm_gemm::WeightFormat::OHWIo32i4_bf16: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo32i4_bf16; + break; + case arm_gemm::WeightFormat::OHWIo64i4: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo64i4; + break; + case arm_gemm::WeightFormat::OHWIo64i4_bf16: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo64i4_bf16; + break; + case arm_gemm::WeightFormat::OHWIo2i8: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo2i8; + break; + case arm_gemm::WeightFormat::OHWIo4i8: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo4i8; + break; + case arm_gemm::WeightFormat::OHWIo8i8: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo8i8; + break; + case arm_gemm::WeightFormat::OHWIo16i8: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo16i8; + break; + case arm_gemm::WeightFormat::OHWIo32i8: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo32i8; + break; + case arm_gemm::WeightFormat::OHWIo64i8: + acl_weight_fromat = arm_compute::WeightFormat::OHWIo64i8; + break; + default: + acl_weight_fromat = arm_compute::WeightFormat::UNSPECIFIED; + } + return acl_weight_fromat; +} } // namespace assembly_utils } // namespace arm_compute -- cgit v1.2.1