aboutsummaryrefslogtreecommitdiff
path: root/ethosu/mlw_codec
diff options
context:
space:
mode:
authorMauricio Briceno <mauricio.briceno@arm.com>2021-05-05 12:47:28 +0200
committerpatrik.gustavsson <patrik.gustavsson@arm.com>2021-05-07 13:16:11 +0000
commit67e11f7bce40d72e0dda97cf658a3c3ee600c1eb (patch)
treeb9b281a07b352fde25161002034770cfde39f115 /ethosu/mlw_codec
parentc875aa6fdd8740f759305ff0fec9917977d019f0 (diff)
downloadethos-u-vela-67e11f7bce40d72e0dda97cf658a3c3ee600c1eb.tar.gz
weight_compressor: added mlw_reorder_encode3.0.0.rc1
- Moves reordering to C - Runtime is greatly minimized for encoding weights Change-Id: Ifff01e7b1ea6d5cec68310a155c3b80aa1a38545 Signed-off-by: Mauricio Briceno <mauricio.briceno@arm.com>
Diffstat (limited to 'ethosu/mlw_codec')
-rw-r--r--ethosu/mlw_codec/mlw_codecmodule.c145
-rw-r--r--ethosu/mlw_codec/mlw_encode.c267
2 files changed, 401 insertions, 11 deletions
diff --git a/ethosu/mlw_codec/mlw_codecmodule.c b/ethosu/mlw_codec/mlw_codecmodule.c
index 2c2fba2c..1e13dd22 100644
--- a/ethosu/mlw_codec/mlw_codecmodule.c
+++ b/ethosu/mlw_codec/mlw_codecmodule.c
@@ -18,10 +18,137 @@
#define PY_SSIZE_T_CLEAN
#include <Python.h>
+#include <numpy/ndarrayobject.h>
#include "mlw_decode.h"
#include "mlw_encode.h"
+/* C extension wrapper for mlw_reorder_encode
+ *
+ * This method is exposed directly in python with the arguments with a
+ * prototype of the form:
+ *
+ * output = mlw_codec.reorder_encode(
+ * ifm_ublock_depth,
+ * ofm_ublock_depth,
+ * input,
+ * ofm_block_depth,
+ * is_depthwise,
+ * is_partkernel,
+ * ifm_bitdepth,
+ * decomp_h,
+ * decomp_w,
+ * verbose=0)
+ *
+ * output: bytearray
+ */
+
+static PyObject *
+method_reorder_encode (PyObject *self, PyObject *args)
+{
+ /* Object to hold the input integer list. */
+ int ifm_ublock_depth;
+ int ofm_ublock_depth;
+ PyObject *input_object;
+ int ofm_block_depth;
+ int is_depthwise;
+ int is_partkernel;
+ int ifm_bitdepth;
+ int decomp_h;
+ int decomp_w;
+
+ /* Object to hold the input verbosity integer, the verbose argument
+ * is optional so defaulted to 0.
+ */
+ int verbose = 0;
+
+ /* Arguments to the method are delivered as a tuple, unpack the
+ * tuple to get the individual arguments, note the second is
+ * optional.
+ */
+ if (!PyArg_ParseTuple(args, "iiOiiiiii|i",
+ &ifm_ublock_depth,
+ &ofm_ublock_depth,
+ &input_object,
+ &ofm_block_depth,
+ &is_depthwise,
+ &is_partkernel,
+ &ifm_bitdepth,
+ &decomp_h,
+ &decomp_w,
+ &verbose))
+ return NULL;
+
+ PyArrayObject* input_ndarray_object = PyArray_FROM_OTF(
+ input_object,
+ NPY_INT64,
+ NPY_ARRAY_ALIGNED);
+ if (input_ndarray_object == NULL)
+ {
+ return NULL;
+ }
+
+ if ((int)PyArray_NDIM(input_ndarray_object) < 4)
+ {
+ PyErr_SetString(PyExc_ValueError, "Invalid input shape");
+ return NULL;
+ }
+
+ int ofm_depth = (int)PyArray_DIM(input_ndarray_object, 0);
+ int kernel_height = (int)PyArray_DIM(input_ndarray_object, 1);
+ int kernel_width = (int)PyArray_DIM(input_ndarray_object, 2);
+ int ifm_depth = (int)PyArray_DIM(input_ndarray_object, 3);
+
+ int64_t* brick_weights = (int64_t*)PyArray_DATA(input_ndarray_object);
+ int brick_strides[4];
+ for (int i = 0; i < 4; i++)
+ {
+ brick_strides[i] = (int)PyArray_STRIDE(input_ndarray_object, i);
+ }
+ if ((unsigned)PyArray_ITEMSIZE(input_ndarray_object) != sizeof(int64_t))
+ {
+ PyErr_SetString(PyExc_ValueError, "Invalid input type");
+ return NULL;
+ }
+ uint8_t* output_buffer = NULL;
+ int padded_length;
+
+ int output_length = mlw_reorder_encode(
+ ifm_ublock_depth,
+ ofm_ublock_depth,
+ ofm_depth,
+ kernel_height,
+ kernel_width,
+ ifm_depth,
+ brick_strides,
+ brick_weights,
+ ofm_block_depth,
+ is_depthwise,
+ is_partkernel,
+ ifm_bitdepth,
+ decomp_h,
+ decomp_w,
+ &output_buffer,
+ &padded_length,
+ verbose);
+
+ if (output_buffer == NULL)
+ {
+ return PyErr_NoMemory();
+ }
+
+ PyObject *output_byte_array = PyByteArray_FromStringAndSize((char*)output_buffer, output_length);
+ PyObject *padded_length_obj = Py_BuildValue("i", padded_length);
+
+ /* Discard the output buffer */
+ mlw_free_outbuf(output_buffer);
+
+ PyObject* ret = PyTuple_Pack(2, output_byte_array, padded_length_obj);
+ Py_DECREF(output_byte_array);
+ Py_DECREF(padded_length_obj);
+ return ret;
+}
+
/* C extension wrapper for mlw_encode
*
* This method is exposed directly in python with the arguments with a
@@ -63,6 +190,7 @@ method_encode (PyObject *self, PyObject *args)
* for that purpose.
*/
int16_t *input_buffer = (int16_t *) malloc(sizeof(int16_t *) * input_length);
+ uint8_t *output_buffer = NULL;
if (input_buffer == NULL)
return PyErr_NoMemory();
@@ -84,20 +212,13 @@ method_encode (PyObject *self, PyObject *args)
return NULL;
}
- /* We don't know the output length required, we guess worst case,
- * the mlw_encode call will do a resize (downwards) anyway.
- */
- uint8_t *output_buffer = (uint8_t *) malloc(input_length);
- if (output_buffer == NULL)
- return PyErr_NoMemory();
-
int output_length = mlw_encode(input_buffer, input_length, &output_buffer, verbose);
PyObject *output_byte_array = PyByteArray_FromStringAndSize ((char *) output_buffer, output_length);
/* Discard the temporary input and output buffers. */
free (input_buffer);
- free (output_buffer);
+ mlw_free_outbuf(output_buffer);
return output_byte_array;
}
@@ -163,6 +284,7 @@ method_decode(PyObject *self, PyObject *args)
static PyMethodDef mlw_methods[] = {
{"decode", method_decode, METH_VARARGS, "Python interface for decode"},
{"encode", method_encode, METH_VARARGS, "Python interface for encode"},
+ {"reorder_encode", method_reorder_encode, METH_VARARGS, "Python interface for reorder and encode"},
{NULL, NULL, 0, NULL}
};
@@ -177,6 +299,9 @@ static struct PyModuleDef mlw_codecmodule = {
mlw_methods
};
-PyMODINIT_FUNC PyInit_mlw_codec(void) {
- return PyModule_Create(&mlw_codecmodule);
+PyMODINIT_FUNC PyInit_mlw_codec(void)
+{
+ PyObject* ret = PyModule_Create(&mlw_codecmodule);
+ import_array();
+ return ret;
}
diff --git a/ethosu/mlw_codec/mlw_encode.c b/ethosu/mlw_codec/mlw_encode.c
index 04afa3ee..62e8360e 100644
--- a/ethosu/mlw_codec/mlw_encode.c
+++ b/ethosu/mlw_codec/mlw_encode.c
@@ -819,12 +819,13 @@ static int encode_section( const int16_t *inbuf,
// Encode the given weight stream
// inbuf uncompressed 9bit signed weights
// inbuf_size number of weights
-// outbuf compressed bitstream, buffer is malloced
+// outbuf compressed bitstream, buffer is malloced within this function
// verbose if non-zero, printf log
// Return value is the size in bytes of the compressed output
// Return -1 if error
int mlw_encode( int16_t *inbuf, int inbuf_size, uint8_t **outbuf, int verbose) {
int i;
+#ifndef NDEBUG
// Range check
for(i=0; i<inbuf_size; i++) {
if (inbuf[i]<-255 || inbuf[i]>255) {
@@ -832,8 +833,10 @@ int mlw_encode( int16_t *inbuf, int inbuf_size, uint8_t **outbuf, int verbose) {
return -1;
}
}
+#endif
int bitbuf_size = inbuf_size*2+1024;
+ assert(*outbuf == NULL);
*outbuf = malloc( bitbuf_size );
// Analyse input data to find palette re-programming points
@@ -882,3 +885,265 @@ void mlw_free_outbuf( uint8_t *outbuf ) {
if (outbuf)
free(outbuf);
}
+
+static int round_up_divide(int num, int den)
+{
+ return (num + den - 1) / den;
+}
+
+static int round_up(int num, int den)
+{
+ return round_up_divide(num, den) * den;
+}
+
+static int get_weight_cnt(
+ int ifm_ublock_depth,
+ int ofm_ublock_depth,
+ int ofm_depth,
+ int kernel_height,
+ int kernel_width,
+ int ifm_depth,
+ int ofm_block_depth,
+ int is_depthwise,
+ int is_partkernel,
+ int ifm_bitdepth,
+ int decomp_h,
+ int decomp_w)
+{
+ int ifm_block_depth = is_partkernel || ifm_bitdepth == 16 ? 16 : 32;
+ int subkernel_elements = decomp_w * decomp_h;
+ if (is_partkernel)
+ {
+ if (ifm_bitdepth == 16 && subkernel_elements % 2 != 0)
+ {
+ subkernel_elements = round_up(subkernel_elements, 2);
+ }
+ else if (ifm_bitdepth == 8 && subkernel_elements % 4 != 0)
+ {
+ subkernel_elements = round_up(subkernel_elements, 4);
+ }
+ }
+ else if (is_depthwise)
+ {
+ subkernel_elements = round_up(subkernel_elements, 4);
+ }
+ int clipped_ifm_block_depth = is_depthwise ? ifm_ublock_depth : ifm_block_depth;
+ int ifm_block_depth_outer = is_partkernel ? clipped_ifm_block_depth : 1;
+ int ifm_block_depth_inner = is_partkernel ? 1 : clipped_ifm_block_depth;
+
+ int input_length = 1;
+ input_length *= is_depthwise ? 1 : ifm_ublock_depth;
+ input_length *= ofm_ublock_depth;
+ input_length *= round_up_divide(ifm_block_depth_inner, ifm_ublock_depth);
+ input_length *= subkernel_elements;
+ input_length *= round_up_divide(ofm_block_depth, ofm_ublock_depth);
+ input_length *= round_up_divide(ifm_block_depth_outer, ifm_ublock_depth);
+ input_length *= round_up_divide(kernel_width, decomp_w);
+ input_length *= round_up_divide(kernel_height, decomp_h);
+ input_length *= round_up_divide(is_depthwise ? 1 : ifm_depth, ifm_block_depth);
+ input_length *= round_up_divide(ofm_depth, ofm_block_depth);
+
+ return input_length;
+}
+
+struct brick_buf_s
+{
+ uint8_t* buf;
+ int* strides;
+};
+typedef struct brick_buf_s brick_buf_t;
+
+static int16_t get_brick_weight(brick_buf_t* buf, int ofm_z, int wy, int wx, int ifm_z)
+{
+ uint8_t* p = buf->buf;
+
+ p += ofm_z * buf->strides[0];
+ p += wy * buf->strides[1];
+ p += wx * buf->strides[2];
+ p += ifm_z * buf->strides[3];
+
+ return *(int16_t*)p;
+}
+
+static int reorder(
+ int ifm_ublock_depth,
+ int ofm_ublock_depth,
+ int ofm_depth,
+ int kernel_height,
+ int kernel_width,
+ int ifm_depth,
+ int* strides,
+ void* inbuf,
+ int ofm_block_depth,
+ int is_depthwise,
+ int is_partkernel,
+ int ifm_bitdepth,
+ int decomp_h,
+ int decomp_w,
+ int16_t* weights)
+{
+ brick_buf_t brick_buf;
+ brick_buf.buf = inbuf;
+ brick_buf.strides = strides;
+
+ int ifm_block_depth = is_partkernel || ifm_bitdepth == 16 ? 16 : 32;
+ int weight_cnt = 0;
+ for (int ofm_block_z = 0; ofm_block_z < ofm_depth; ofm_block_z += ofm_block_depth)
+ {
+ int clipped_ofm_block_depth = min(ofm_block_depth, ofm_depth - ofm_block_z);
+ // IFM blocks required for the brick
+ for (int ifm_block_z = 0; ifm_block_z < (is_depthwise ? 1 : ifm_depth); ifm_block_z += ifm_block_depth)
+ {
+ int clipped_ifm_block_depth;
+ if (is_depthwise)
+ {
+ clipped_ifm_block_depth = ifm_ublock_depth;
+ }
+ else
+ {
+ clipped_ifm_block_depth = is_partkernel ?
+ min(ifm_block_depth, ifm_depth - ifm_block_z) : ifm_block_depth;
+ }
+ // Weight decomposition
+ // Subkernel Splitting (H)
+ for (int subkernel_y = 0; subkernel_y < kernel_height; subkernel_y += decomp_h)
+ {
+ int sub_height = min(kernel_height - subkernel_y, decomp_h);
+ // Subkernel splitting (W)
+ for (int subkernel_x = 0; subkernel_x < kernel_width; subkernel_x += decomp_w)
+ {
+ int sub_width = min(kernel_width - subkernel_x, decomp_w);
+ int subkernel_elements = sub_width * sub_height;
+ // Part kernel first works across the kernel H/W and needs padding
+ if (is_partkernel)
+ {
+ if (ifm_bitdepth == 16 && subkernel_elements % 2 != 0)
+ {
+ subkernel_elements = round_up(subkernel_elements, 2);
+ }
+ else if (ifm_bitdepth == 8 && subkernel_elements % 4 != 0)
+ {
+ subkernel_elements = round_up(subkernel_elements, 4);
+ }
+ }
+ else if (is_depthwise)
+ {
+ subkernel_elements = round_up(subkernel_elements, 4);
+ }
+ int ifm_block_depth_outer = is_partkernel ? clipped_ifm_block_depth : 1;
+ int ifm_block_depth_inner = is_partkernel ? 1 : clipped_ifm_block_depth;
+ for (int ifm_ublk_outer = 0; ifm_ublk_outer < ifm_block_depth_outer; ifm_ublk_outer += ifm_ublock_depth)
+ {
+ // OFM Ublocks in OFM-block over depth
+ for (int ofm_ublk = 0; ofm_ublk < clipped_ofm_block_depth; ofm_ublk += ofm_ublock_depth)
+ {
+ // HW Kernel element traversal - cannot be a H/W loop due to element
+ // padding requirement on depthwise/part-kernel configurations
+ for (int element = 0; element < subkernel_elements; element++)
+ {
+ int kx = element % sub_width;
+ int ky = element / sub_width;
+ // IFM Ublocks in IFM-block over depth (only 1 ublock if depthwise)
+ // In case of part-kernel-first IFM Ublock traversal have already been handled
+ // and this loop is ignored.
+ for (int ifm_ublk_inner = 0; ifm_ublk_inner < ifm_block_depth_inner; ifm_ublk_inner += ifm_ublock_depth)
+ {
+ // Feed OFM ublock elements
+ for (int ofm_ublock_z = 0; ofm_ublock_z < ofm_ublock_depth; ofm_ublock_z++)
+ {
+ // Source IFM ublock elements (only 1 element deep if depthwise)
+ for (int ifm_ublock_z = 0; ifm_ublock_z < (is_depthwise ? 1 : ifm_ublock_depth); ifm_ublock_z++)
+ {
+ // Source position within the current subkernel
+ int wx = subkernel_x + kx;
+ int wy = subkernel_y + ky;
+ // Source IFM/OFM slices
+ int ifm_ublk = ifm_ublk_inner + ifm_ublk_outer;
+ int ifm_z = ifm_block_z + ifm_ublk + ifm_ublock_z;
+ int ofm_z = ofm_block_z + ofm_ublk + ofm_ublock_z;
+ if ((ifm_z < ifm_depth) && (ofm_z < ofm_depth) && (ky < sub_height))
+ {
+ weights[weight_cnt] = get_brick_weight(&brick_buf, ofm_z, wy, wx, ifm_z);
+ }
+ weight_cnt++;
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ return weight_cnt;
+}
+
+// Reorder and encode the given weight stream
+// Return value is the size in bytes of the compressed output
+// Return -1 if error
+int mlw_reorder_encode(
+ int ifm_ublock_depth,
+ int ofm_ublock_depth,
+ int ofm_depth,
+ int kernel_height,
+ int kernel_width,
+ int ifm_depth,
+ int* brick_strides,
+ void* inbuf,
+ int ofm_block_depth,
+ int is_depthwise,
+ int is_partkernel,
+ int ifm_bitdepth,
+ int decomp_h,
+ int decomp_w,
+ uint8_t **outbuf, // *outbuf must be freed by caller
+ int* padded_length,
+ int verbose)
+{
+ /* Get an upper bound of the weight count */
+ int input_length = get_weight_cnt(
+ ifm_ublock_depth,
+ ofm_ublock_depth,
+ ofm_depth,
+ kernel_height,
+ kernel_width,
+ ifm_depth,
+ ofm_block_depth,
+ is_depthwise,
+ is_partkernel,
+ ifm_bitdepth,
+ decomp_h,
+ decomp_w);
+
+ int16_t* weights = (int16_t*)calloc(input_length, sizeof(int16_t));
+ if (weights == NULL)
+ {
+ return 0;
+ }
+
+ /* Reorder weights and update input_length */
+ input_length = reorder(
+ ifm_ublock_depth,
+ ofm_ublock_depth,
+ ofm_depth,
+ kernel_height,
+ kernel_width,
+ ifm_depth,
+ brick_strides,
+ inbuf,
+ ofm_block_depth,
+ is_depthwise,
+ is_partkernel,
+ ifm_bitdepth,
+ decomp_h,
+ decomp_w,
+ weights);
+
+ int output_length = mlw_encode(weights, input_length, outbuf, verbose);
+ free(weights);
+ *padded_length = input_length;
+
+ return output_length;
+}