From 67e11f7bce40d72e0dda97cf658a3c3ee600c1eb Mon Sep 17 00:00:00 2001 From: Mauricio Briceno Date: Wed, 5 May 2021 12:47:28 +0200 Subject: weight_compressor: added mlw_reorder_encode - Moves reordering to C - Runtime is greatly minimized for encoding weights Change-Id: Ifff01e7b1ea6d5cec68310a155c3b80aa1a38545 Signed-off-by: Mauricio Briceno --- ethosu/mlw_codec/mlw_encode.c | 267 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 266 insertions(+), 1 deletion(-) (limited to 'ethosu/mlw_codec/mlw_encode.c') diff --git a/ethosu/mlw_codec/mlw_encode.c b/ethosu/mlw_codec/mlw_encode.c index 04afa3ee..62e8360e 100644 --- a/ethosu/mlw_codec/mlw_encode.c +++ b/ethosu/mlw_codec/mlw_encode.c @@ -819,12 +819,13 @@ static int encode_section( const int16_t *inbuf, // Encode the given weight stream // inbuf uncompressed 9bit signed weights // inbuf_size number of weights -// outbuf compressed bitstream, buffer is malloced +// outbuf compressed bitstream, buffer is malloced within this function // verbose if non-zero, printf log // Return value is the size in bytes of the compressed output // Return -1 if error int mlw_encode( int16_t *inbuf, int inbuf_size, uint8_t **outbuf, int verbose) { int i; +#ifndef NDEBUG // Range check for(i=0; i255) { @@ -832,8 +833,10 @@ int mlw_encode( int16_t *inbuf, int inbuf_size, uint8_t **outbuf, int verbose) { return -1; } } +#endif int bitbuf_size = inbuf_size*2+1024; + assert(*outbuf == NULL); *outbuf = malloc( bitbuf_size ); // Analyse input data to find palette re-programming points @@ -882,3 +885,265 @@ void mlw_free_outbuf( uint8_t *outbuf ) { if (outbuf) free(outbuf); } + +static int round_up_divide(int num, int den) +{ + return (num + den - 1) / den; +} + +static int round_up(int num, int den) +{ + return round_up_divide(num, den) * den; +} + +static int get_weight_cnt( + int ifm_ublock_depth, + int ofm_ublock_depth, + int ofm_depth, + int kernel_height, + int kernel_width, + int ifm_depth, + int ofm_block_depth, + int is_depthwise, + int is_partkernel, + int ifm_bitdepth, + int decomp_h, + int decomp_w) +{ + int ifm_block_depth = is_partkernel || ifm_bitdepth == 16 ? 16 : 32; + int subkernel_elements = decomp_w * decomp_h; + if (is_partkernel) + { + if (ifm_bitdepth == 16 && subkernel_elements % 2 != 0) + { + subkernel_elements = round_up(subkernel_elements, 2); + } + else if (ifm_bitdepth == 8 && subkernel_elements % 4 != 0) + { + subkernel_elements = round_up(subkernel_elements, 4); + } + } + else if (is_depthwise) + { + subkernel_elements = round_up(subkernel_elements, 4); + } + int clipped_ifm_block_depth = is_depthwise ? ifm_ublock_depth : ifm_block_depth; + int ifm_block_depth_outer = is_partkernel ? clipped_ifm_block_depth : 1; + int ifm_block_depth_inner = is_partkernel ? 1 : clipped_ifm_block_depth; + + int input_length = 1; + input_length *= is_depthwise ? 1 : ifm_ublock_depth; + input_length *= ofm_ublock_depth; + input_length *= round_up_divide(ifm_block_depth_inner, ifm_ublock_depth); + input_length *= subkernel_elements; + input_length *= round_up_divide(ofm_block_depth, ofm_ublock_depth); + input_length *= round_up_divide(ifm_block_depth_outer, ifm_ublock_depth); + input_length *= round_up_divide(kernel_width, decomp_w); + input_length *= round_up_divide(kernel_height, decomp_h); + input_length *= round_up_divide(is_depthwise ? 1 : ifm_depth, ifm_block_depth); + input_length *= round_up_divide(ofm_depth, ofm_block_depth); + + return input_length; +} + +struct brick_buf_s +{ + uint8_t* buf; + int* strides; +}; +typedef struct brick_buf_s brick_buf_t; + +static int16_t get_brick_weight(brick_buf_t* buf, int ofm_z, int wy, int wx, int ifm_z) +{ + uint8_t* p = buf->buf; + + p += ofm_z * buf->strides[0]; + p += wy * buf->strides[1]; + p += wx * buf->strides[2]; + p += ifm_z * buf->strides[3]; + + return *(int16_t*)p; +} + +static int reorder( + int ifm_ublock_depth, + int ofm_ublock_depth, + int ofm_depth, + int kernel_height, + int kernel_width, + int ifm_depth, + int* strides, + void* inbuf, + int ofm_block_depth, + int is_depthwise, + int is_partkernel, + int ifm_bitdepth, + int decomp_h, + int decomp_w, + int16_t* weights) +{ + brick_buf_t brick_buf; + brick_buf.buf = inbuf; + brick_buf.strides = strides; + + int ifm_block_depth = is_partkernel || ifm_bitdepth == 16 ? 16 : 32; + int weight_cnt = 0; + for (int ofm_block_z = 0; ofm_block_z < ofm_depth; ofm_block_z += ofm_block_depth) + { + int clipped_ofm_block_depth = min(ofm_block_depth, ofm_depth - ofm_block_z); + // IFM blocks required for the brick + for (int ifm_block_z = 0; ifm_block_z < (is_depthwise ? 1 : ifm_depth); ifm_block_z += ifm_block_depth) + { + int clipped_ifm_block_depth; + if (is_depthwise) + { + clipped_ifm_block_depth = ifm_ublock_depth; + } + else + { + clipped_ifm_block_depth = is_partkernel ? + min(ifm_block_depth, ifm_depth - ifm_block_z) : ifm_block_depth; + } + // Weight decomposition + // Subkernel Splitting (H) + for (int subkernel_y = 0; subkernel_y < kernel_height; subkernel_y += decomp_h) + { + int sub_height = min(kernel_height - subkernel_y, decomp_h); + // Subkernel splitting (W) + for (int subkernel_x = 0; subkernel_x < kernel_width; subkernel_x += decomp_w) + { + int sub_width = min(kernel_width - subkernel_x, decomp_w); + int subkernel_elements = sub_width * sub_height; + // Part kernel first works across the kernel H/W and needs padding + if (is_partkernel) + { + if (ifm_bitdepth == 16 && subkernel_elements % 2 != 0) + { + subkernel_elements = round_up(subkernel_elements, 2); + } + else if (ifm_bitdepth == 8 && subkernel_elements % 4 != 0) + { + subkernel_elements = round_up(subkernel_elements, 4); + } + } + else if (is_depthwise) + { + subkernel_elements = round_up(subkernel_elements, 4); + } + int ifm_block_depth_outer = is_partkernel ? clipped_ifm_block_depth : 1; + int ifm_block_depth_inner = is_partkernel ? 1 : clipped_ifm_block_depth; + for (int ifm_ublk_outer = 0; ifm_ublk_outer < ifm_block_depth_outer; ifm_ublk_outer += ifm_ublock_depth) + { + // OFM Ublocks in OFM-block over depth + for (int ofm_ublk = 0; ofm_ublk < clipped_ofm_block_depth; ofm_ublk += ofm_ublock_depth) + { + // HW Kernel element traversal - cannot be a H/W loop due to element + // padding requirement on depthwise/part-kernel configurations + for (int element = 0; element < subkernel_elements; element++) + { + int kx = element % sub_width; + int ky = element / sub_width; + // IFM Ublocks in IFM-block over depth (only 1 ublock if depthwise) + // In case of part-kernel-first IFM Ublock traversal have already been handled + // and this loop is ignored. + for (int ifm_ublk_inner = 0; ifm_ublk_inner < ifm_block_depth_inner; ifm_ublk_inner += ifm_ublock_depth) + { + // Feed OFM ublock elements + for (int ofm_ublock_z = 0; ofm_ublock_z < ofm_ublock_depth; ofm_ublock_z++) + { + // Source IFM ublock elements (only 1 element deep if depthwise) + for (int ifm_ublock_z = 0; ifm_ublock_z < (is_depthwise ? 1 : ifm_ublock_depth); ifm_ublock_z++) + { + // Source position within the current subkernel + int wx = subkernel_x + kx; + int wy = subkernel_y + ky; + // Source IFM/OFM slices + int ifm_ublk = ifm_ublk_inner + ifm_ublk_outer; + int ifm_z = ifm_block_z + ifm_ublk + ifm_ublock_z; + int ofm_z = ofm_block_z + ofm_ublk + ofm_ublock_z; + if ((ifm_z < ifm_depth) && (ofm_z < ofm_depth) && (ky < sub_height)) + { + weights[weight_cnt] = get_brick_weight(&brick_buf, ofm_z, wy, wx, ifm_z); + } + weight_cnt++; + } + } + } + } + } + } + } + } + } + } + + return weight_cnt; +} + +// Reorder and encode the given weight stream +// Return value is the size in bytes of the compressed output +// Return -1 if error +int mlw_reorder_encode( + int ifm_ublock_depth, + int ofm_ublock_depth, + int ofm_depth, + int kernel_height, + int kernel_width, + int ifm_depth, + int* brick_strides, + void* inbuf, + int ofm_block_depth, + int is_depthwise, + int is_partkernel, + int ifm_bitdepth, + int decomp_h, + int decomp_w, + uint8_t **outbuf, // *outbuf must be freed by caller + int* padded_length, + int verbose) +{ + /* Get an upper bound of the weight count */ + int input_length = get_weight_cnt( + ifm_ublock_depth, + ofm_ublock_depth, + ofm_depth, + kernel_height, + kernel_width, + ifm_depth, + ofm_block_depth, + is_depthwise, + is_partkernel, + ifm_bitdepth, + decomp_h, + decomp_w); + + int16_t* weights = (int16_t*)calloc(input_length, sizeof(int16_t)); + if (weights == NULL) + { + return 0; + } + + /* Reorder weights and update input_length */ + input_length = reorder( + ifm_ublock_depth, + ofm_ublock_depth, + ofm_depth, + kernel_height, + kernel_width, + ifm_depth, + brick_strides, + inbuf, + ofm_block_depth, + is_depthwise, + is_partkernel, + ifm_bitdepth, + decomp_h, + decomp_w, + weights); + + int output_length = mlw_encode(weights, input_length, outbuf, verbose); + free(weights); + *padded_length = input_length; + + return output_length; +} -- cgit v1.2.1