diff options
author | Francesco Petrogalli <francesco.petrogalli@arm.com> | 2022-06-30 10:22:01 +0000 |
---|---|---|
committer | Francesco Petrogalli <francesco.petrogalli@arm.com> | 2022-07-19 09:26:27 +0000 |
commit | 553f6953fe3bdfad53c11c25f305a16d79d83b24 (patch) | |
tree | 73642b948b79662096f593458c6138d2f7f48ec6 /src/cpu/kernels | |
parent | 99c46475daf277aa53e6747f9e41209f418fed33 (diff) | |
download | ComputeLibrary-553f6953fe3bdfad53c11c25f305a16d79d83b24.tar.gz |
[ONCPUML-951] Variable weight support for Convolution.
API changes for NEGEMMConvolutionLayer and CpuGemmConv2d
Built with:
scons neon=1 opencl=0 os=linux arch=armv8.2-a multi_isa=1 \
build=native -j32 Werror=false validation_tests=1 build_dir=opt \
standalone=1 asserts=1 experimental_fixed_format_kernels=1 .
Tested with:
./build/opt/tests/arm_compute_validation
Hardware where the test executable was run:
Neoverse N1
Test coverage:
* NEGEMMConvolutionLayer, CpuGemmConv2d
* NHWC (the only one supported by the fixed-format kernels)
* F16, F32
* Shapes: RunSmall
Change-Id: I4fd3e495a7cbf61210ea02d37440ba9652934e99
Signed-off-by: Francesco Petrogalli <francesco.petrogalli@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/7632
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/cpu/kernels')
-rw-r--r-- | src/cpu/kernels/assembly/arm_gemm.hpp | 116 |
1 files changed, 115 insertions, 1 deletions
diff --git a/src/cpu/kernels/assembly/arm_gemm.hpp b/src/cpu/kernels/assembly/arm_gemm.hpp index 247cb1d470..48fd7c6b43 100644 --- a/src/cpu/kernels/assembly/arm_gemm.hpp +++ b/src/cpu/kernels/assembly/arm_gemm.hpp @@ -47,6 +47,57 @@ enum class GemmMethod GEMM_HYBRID_QUANTIZED }; +/** Memory layouts for the weights tensor. + * + * * UNSPECIFIED is used to select kernels that do not run in + * variable weights mode. + * + * * ANY is used to query the kernel database to retrieve any of the + * kernels that runs in variable weights mode. Once a kernel is + * found, the specific format expected by the kernel can be + * retrieved by the user for reordering the weights tensor + * accordingly. + * + * The other values OHWIo{interleave_by}i{block_by} describe the + * memory layout of a 4D tensor with layout OHWI that has been + * transformed into a 4D tensor with dimensions O'HWI' where: + * + * O' = first multiple of {interleave_by} s.t. O<=O' + * I' = first multiple of {block_by} s.t. I<=I' + * + * The total size of the dst tensor is O' x H x W x I' + * + * The access function of the tensor with layout + * OHWIo{interleave_by}i{block_by} and size O'HWI' is a 6-parameter + * access function, where the 6 parameters are computed as follows: + * + * x5 = floor(o/{interleave_by}) RANGE [0, O'/{interleave_by} -1] SIZE: O'/{interleave_by} + * + * x4 = h RANGE [0, H-1] SIZE: H + * x3 = w RANGE [0, W-1] SIZE: W + * x2 = floor(i/{block_by}) RANGE [0, I'/{block_by} -1] SIZE: I'/{block_by} + * x1 = o%{interleave_by} RANGE [0, {interleave_by} -1] SIZE: {interleave_by} + * x0 = i%{block_by} RANGE [0, {block_by} -1] SIZE: {block_by} + * TOTAL SIZE: O' * H * W * I' + * + * 4D 6D + * ----------------- ----------------------------------- + * value(o, h, w, i) = x5 * H * W * I' * {interleave_by} + * + x4 * W * I' * {interleave_by} + * + x3 * I' * {interleave_by} + * + x2 * {interleave_by} * {block_by} + * + x1 * {block_by} + * + x0 + * + * Notice that in arm_gemm the 4D tensor of dimension O'HWI' created + * for the OHWIo{interleave_by}i{block_by} format is in reality seen + * as a 2D tensor, where the number of rows is O'/{interleave_by} + * and the number of columns is {interleave_by} * H * W * I'. + * + * The postfix *_bf16 is for the memory layout needed for the + * fast-mode kernels, in which the weights are passed in bfloat16 + * format. + */ enum class WeightFormat { UNSPECIFIED = 0x1, @@ -87,6 +138,69 @@ enum class WeightFormat OHWIo64i8 = 0x804000 }; +// OHWIo<interleave_by>i<block_by> +inline int interleave_by(const WeightFormat wf) +{ + return ((int)wf >> 8) & 0xFFF; +} +inline int block_by(const WeightFormat wf) +{ + return ((int)wf >> 20) & 0xF; +} +inline bool is_fixed_format(const WeightFormat wf) +{ + return wf != WeightFormat::UNSPECIFIED && wf != WeightFormat::ANY; +} + +inline std::string to_string(WeightFormat wf) +{ +#define __CASE_WEIGHT_FORMAT(wf) \ +case WeightFormat::wf: \ + return #wf; + switch(wf) + { + __CASE_WEIGHT_FORMAT(UNSPECIFIED) + __CASE_WEIGHT_FORMAT(ANY) + __CASE_WEIGHT_FORMAT(OHWI) + __CASE_WEIGHT_FORMAT(OHWIo2) + __CASE_WEIGHT_FORMAT(OHWIo4) + __CASE_WEIGHT_FORMAT(OHWIo8) + __CASE_WEIGHT_FORMAT(OHWIo16) + __CASE_WEIGHT_FORMAT(OHWIo32) + __CASE_WEIGHT_FORMAT(OHWIo64) + __CASE_WEIGHT_FORMAT(OHWIo128) + __CASE_WEIGHT_FORMAT(OHWIo4i2) + __CASE_WEIGHT_FORMAT(OHWIo4i2_bf16) + __CASE_WEIGHT_FORMAT(OHWIo8i2) + __CASE_WEIGHT_FORMAT(OHWIo8i2_bf16) + __CASE_WEIGHT_FORMAT(OHWIo16i2) + __CASE_WEIGHT_FORMAT(OHWIo16i2_bf16) + __CASE_WEIGHT_FORMAT(OHWIo32i2) + __CASE_WEIGHT_FORMAT(OHWIo32i2_bf16) + __CASE_WEIGHT_FORMAT(OHWIo64i2) + __CASE_WEIGHT_FORMAT(OHWIo64i2_bf16) + __CASE_WEIGHT_FORMAT(OHWIo4i4) + __CASE_WEIGHT_FORMAT(OHWIo4i4_bf16) + __CASE_WEIGHT_FORMAT(OHWIo8i4) + __CASE_WEIGHT_FORMAT(OHWIo8i4_bf16) + __CASE_WEIGHT_FORMAT(OHWIo16i4) + __CASE_WEIGHT_FORMAT(OHWIo16i4_bf16) + __CASE_WEIGHT_FORMAT(OHWIo32i4) + __CASE_WEIGHT_FORMAT(OHWIo32i4_bf16) + __CASE_WEIGHT_FORMAT(OHWIo64i4) + __CASE_WEIGHT_FORMAT(OHWIo64i4_bf16) + __CASE_WEIGHT_FORMAT(OHWIo2i8) + __CASE_WEIGHT_FORMAT(OHWIo4i8) + __CASE_WEIGHT_FORMAT(OHWIo8i8) + __CASE_WEIGHT_FORMAT(OHWIo16i8) + __CASE_WEIGHT_FORMAT(OHWIo32i8) + __CASE_WEIGHT_FORMAT(OHWIo64i8) + default: + return "invalid value"; + } +#undef __CASE_WEIGHT_FORMAT +} + struct KernelDescription { GemmMethod method = GemmMethod::DEFAULT; @@ -230,6 +344,6 @@ template <typename Top, typename Tret, class OutputStage = Nothing> std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage & = {}); template <typename Top, typename Tret, class OutputStage = Nothing> -bool has_opt_gemm(const GemmArgs &args, const OutputStage & = {}); +bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const OutputStage & = {}); } // namespace arm_gemm |