aboutsummaryrefslogtreecommitdiff
path: root/src/core/utils/AssemblyUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/utils/AssemblyUtils.cpp')
-rw-r--r--src/core/utils/AssemblyUtils.cpp242
1 files changed, 241 insertions, 1 deletions
diff --git a/src/core/utils/AssemblyUtils.cpp b/src/core/utils/AssemblyUtils.cpp
index 1e8a2a54c9..45e7ff78be 100644
--- a/src/core/utils/AssemblyUtils.cpp
+++ b/src/core/utils/AssemblyUtils.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021-2022 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -66,5 +66,245 @@ arm_conv::PaddingValues map_to_arm_conv_padding(const PadStrideInfo &pad_stride_
pad_stride_info.pad_right(),
pad_stride_info.pad_bottom() };
}
+
+arm_gemm::WeightFormat map_to_arm_gemm_weight_format(const arm_compute::WeightFormat &weight_format)
+{
+ arm_gemm::WeightFormat gemm_weight_fromat;
+
+ switch(weight_format)
+ {
+ case arm_compute::WeightFormat::UNSPECIFIED:
+ gemm_weight_fromat = arm_gemm::WeightFormat::UNSPECIFIED;
+ break;
+ case arm_compute::WeightFormat::ANY:
+ gemm_weight_fromat = arm_gemm::WeightFormat::ANY;
+ break;
+ case arm_compute::WeightFormat::OHWI:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWI;
+ break;
+ case arm_compute::WeightFormat::OHWIo2:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo2;
+ break;
+ case arm_compute::WeightFormat::OHWIo4:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4;
+ break;
+ case arm_compute::WeightFormat::OHWIo8:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8;
+ break;
+ case arm_compute::WeightFormat::OHWIo16:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16;
+ break;
+ case arm_compute::WeightFormat::OHWIo32:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32;
+ break;
+ case arm_compute::WeightFormat::OHWIo64:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64;
+ break;
+ case arm_compute::WeightFormat::OHWIo128:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo128;
+ break;
+ case arm_compute::WeightFormat::OHWIo4i2:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4i2;
+ break;
+ case arm_compute::WeightFormat::OHWIo4i2_bf16:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4i2_bf16;
+ break;
+ case arm_compute::WeightFormat::OHWIo8i2:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8i2;
+ break;
+ case arm_compute::WeightFormat::OHWIo8i2_bf16:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8i2_bf16;
+ break;
+ case arm_compute::WeightFormat::OHWIo16i2:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16i2;
+ break;
+ case arm_compute::WeightFormat::OHWIo16i2_bf16:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16i2_bf16;
+ break;
+ case arm_compute::WeightFormat::OHWIo32i2:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32i2;
+ break;
+ case arm_compute::WeightFormat::OHWIo32i2_bf16:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32i2_bf16;
+ break;
+ case arm_compute::WeightFormat::OHWIo64i2:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64i2;
+ break;
+ case arm_compute::WeightFormat::OHWIo64i2_bf16:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64i2_bf16;
+ break;
+ case arm_compute::WeightFormat::OHWIo4i4:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4i4;
+ break;
+ case arm_compute::WeightFormat::OHWIo4i4_bf16:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4i4_bf16;
+ break;
+ case arm_compute::WeightFormat::OHWIo8i4:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8i4;
+ break;
+ case arm_compute::WeightFormat::OHWIo8i4_bf16:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8i4_bf16;
+ break;
+ case arm_compute::WeightFormat::OHWIo16i4:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16i4;
+ break;
+ case arm_compute::WeightFormat::OHWIo16i4_bf16:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16i4_bf16;
+ break;
+ case arm_compute::WeightFormat::OHWIo32i4:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32i4;
+ break;
+ case arm_compute::WeightFormat::OHWIo32i4_bf16:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32i4_bf16;
+ break;
+ case arm_compute::WeightFormat::OHWIo64i4:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64i4;
+ break;
+ case arm_compute::WeightFormat::OHWIo64i4_bf16:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64i4_bf16;
+ break;
+ case arm_compute::WeightFormat::OHWIo2i8:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo2i8;
+ break;
+ case arm_compute::WeightFormat::OHWIo4i8:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo4i8;
+ break;
+ case arm_compute::WeightFormat::OHWIo8i8:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo8i8;
+ break;
+ case arm_compute::WeightFormat::OHWIo16i8:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo16i8;
+ break;
+ case arm_compute::WeightFormat::OHWIo32i8:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo32i8;
+ break;
+ case arm_compute::WeightFormat::OHWIo64i8:
+ gemm_weight_fromat = arm_gemm::WeightFormat::OHWIo64i8;
+ break;
+ default:
+ gemm_weight_fromat = arm_gemm::WeightFormat::UNSPECIFIED;
+ }
+ return gemm_weight_fromat;
+}
+
+arm_compute::WeightFormat map_to_arm_compute_weight_format(const arm_gemm::WeightFormat &weight_format)
+{
+ arm_compute::WeightFormat acl_weight_fromat;
+
+ switch(weight_format)
+ {
+ case arm_gemm::WeightFormat::UNSPECIFIED:
+ acl_weight_fromat = arm_compute::WeightFormat::UNSPECIFIED;
+ break;
+ case arm_gemm::WeightFormat::ANY:
+ acl_weight_fromat = arm_compute::WeightFormat::ANY;
+ break;
+ case arm_gemm::WeightFormat::OHWI:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWI;
+ break;
+ case arm_gemm::WeightFormat::OHWIo2:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo2;
+ break;
+ case arm_gemm::WeightFormat::OHWIo4:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo4;
+ break;
+ case arm_gemm::WeightFormat::OHWIo8:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo8;
+ break;
+ case arm_gemm::WeightFormat::OHWIo16:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo16;
+ break;
+ case arm_gemm::WeightFormat::OHWIo32:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo32;
+ break;
+ case arm_gemm::WeightFormat::OHWIo64:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo64;
+ break;
+ case arm_gemm::WeightFormat::OHWIo128:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo128;
+ break;
+ case arm_gemm::WeightFormat::OHWIo4i2:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo4i2;
+ break;
+ case arm_gemm::WeightFormat::OHWIo4i2_bf16:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo4i2_bf16;
+ break;
+ case arm_gemm::WeightFormat::OHWIo8i2:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo8i2;
+ break;
+ case arm_gemm::WeightFormat::OHWIo8i2_bf16:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo8i2_bf16;
+ break;
+ case arm_gemm::WeightFormat::OHWIo16i2:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo16i2;
+ break;
+ case arm_gemm::WeightFormat::OHWIo16i2_bf16:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo16i2_bf16;
+ break;
+ case arm_gemm::WeightFormat::OHWIo32i2:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo32i2;
+ break;
+ case arm_gemm::WeightFormat::OHWIo32i2_bf16:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo32i2_bf16;
+ break;
+ case arm_gemm::WeightFormat::OHWIo64i2:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo64i2;
+ break;
+ case arm_gemm::WeightFormat::OHWIo64i2_bf16:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo64i2_bf16;
+ break;
+ case arm_gemm::WeightFormat::OHWIo4i4:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo4i4;
+ break;
+ case arm_gemm::WeightFormat::OHWIo4i4_bf16:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo4i4_bf16;
+ break;
+ case arm_gemm::WeightFormat::OHWIo8i4:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo8i4;
+ break;
+ case arm_gemm::WeightFormat::OHWIo8i4_bf16:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo8i4_bf16;
+ break;
+ case arm_gemm::WeightFormat::OHWIo16i4:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo16i4;
+ break;
+ case arm_gemm::WeightFormat::OHWIo16i4_bf16:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo16i4_bf16;
+ break;
+ case arm_gemm::WeightFormat::OHWIo32i4:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo32i4;
+ break;
+ case arm_gemm::WeightFormat::OHWIo32i4_bf16:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo32i4_bf16;
+ break;
+ case arm_gemm::WeightFormat::OHWIo64i4:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo64i4;
+ break;
+ case arm_gemm::WeightFormat::OHWIo64i4_bf16:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo64i4_bf16;
+ break;
+ case arm_gemm::WeightFormat::OHWIo2i8:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo2i8;
+ break;
+ case arm_gemm::WeightFormat::OHWIo4i8:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo4i8;
+ break;
+ case arm_gemm::WeightFormat::OHWIo8i8:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo8i8;
+ break;
+ case arm_gemm::WeightFormat::OHWIo16i8:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo16i8;
+ break;
+ case arm_gemm::WeightFormat::OHWIo32i8:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo32i8;
+ break;
+ case arm_gemm::WeightFormat::OHWIo64i8:
+ acl_weight_fromat = arm_compute::WeightFormat::OHWIo64i8;
+ break;
+ default:
+ acl_weight_fromat = arm_compute::WeightFormat::UNSPECIFIED;
+ }
+ return acl_weight_fromat;
+}
} // namespace assembly_utils
} // namespace arm_compute