diff options
Diffstat (limited to 'ethosu/mlw_codec/mlw_codecmodule.c')
-rw-r--r-- | ethosu/mlw_codec/mlw_codecmodule.c | 145 |
1 files changed, 135 insertions, 10 deletions
diff --git a/ethosu/mlw_codec/mlw_codecmodule.c b/ethosu/mlw_codec/mlw_codecmodule.c index 2c2fba2c..1e13dd22 100644 --- a/ethosu/mlw_codec/mlw_codecmodule.c +++ b/ethosu/mlw_codec/mlw_codecmodule.c @@ -18,10 +18,137 @@ #define PY_SSIZE_T_CLEAN #include <Python.h> +#include <numpy/ndarrayobject.h> #include "mlw_decode.h" #include "mlw_encode.h" +/* C extension wrapper for mlw_reorder_encode + * + * This method is exposed directly in python with the arguments with a + * prototype of the form: + * + * output = mlw_codec.reorder_encode( + * ifm_ublock_depth, + * ofm_ublock_depth, + * input, + * ofm_block_depth, + * is_depthwise, + * is_partkernel, + * ifm_bitdepth, + * decomp_h, + * decomp_w, + * verbose=0) + * + * output: bytearray + */ + +static PyObject * +method_reorder_encode (PyObject *self, PyObject *args) +{ + /* Object to hold the input integer list. */ + int ifm_ublock_depth; + int ofm_ublock_depth; + PyObject *input_object; + int ofm_block_depth; + int is_depthwise; + int is_partkernel; + int ifm_bitdepth; + int decomp_h; + int decomp_w; + + /* Object to hold the input verbosity integer, the verbose argument + * is optional so defaulted to 0. + */ + int verbose = 0; + + /* Arguments to the method are delivered as a tuple, unpack the + * tuple to get the individual arguments, note the second is + * optional. + */ + if (!PyArg_ParseTuple(args, "iiOiiiiii|i", + &ifm_ublock_depth, + &ofm_ublock_depth, + &input_object, + &ofm_block_depth, + &is_depthwise, + &is_partkernel, + &ifm_bitdepth, + &decomp_h, + &decomp_w, + &verbose)) + return NULL; + + PyArrayObject* input_ndarray_object = PyArray_FROM_OTF( + input_object, + NPY_INT64, + NPY_ARRAY_ALIGNED); + if (input_ndarray_object == NULL) + { + return NULL; + } + + if ((int)PyArray_NDIM(input_ndarray_object) < 4) + { + PyErr_SetString(PyExc_ValueError, "Invalid input shape"); + return NULL; + } + + int ofm_depth = (int)PyArray_DIM(input_ndarray_object, 0); + int kernel_height = (int)PyArray_DIM(input_ndarray_object, 1); + int kernel_width = (int)PyArray_DIM(input_ndarray_object, 2); + int ifm_depth = (int)PyArray_DIM(input_ndarray_object, 3); + + int64_t* brick_weights = (int64_t*)PyArray_DATA(input_ndarray_object); + int brick_strides[4]; + for (int i = 0; i < 4; i++) + { + brick_strides[i] = (int)PyArray_STRIDE(input_ndarray_object, i); + } + if ((unsigned)PyArray_ITEMSIZE(input_ndarray_object) != sizeof(int64_t)) + { + PyErr_SetString(PyExc_ValueError, "Invalid input type"); + return NULL; + } + uint8_t* output_buffer = NULL; + int padded_length; + + int output_length = mlw_reorder_encode( + ifm_ublock_depth, + ofm_ublock_depth, + ofm_depth, + kernel_height, + kernel_width, + ifm_depth, + brick_strides, + brick_weights, + ofm_block_depth, + is_depthwise, + is_partkernel, + ifm_bitdepth, + decomp_h, + decomp_w, + &output_buffer, + &padded_length, + verbose); + + if (output_buffer == NULL) + { + return PyErr_NoMemory(); + } + + PyObject *output_byte_array = PyByteArray_FromStringAndSize((char*)output_buffer, output_length); + PyObject *padded_length_obj = Py_BuildValue("i", padded_length); + + /* Discard the output buffer */ + mlw_free_outbuf(output_buffer); + + PyObject* ret = PyTuple_Pack(2, output_byte_array, padded_length_obj); + Py_DECREF(output_byte_array); + Py_DECREF(padded_length_obj); + return ret; +} + /* C extension wrapper for mlw_encode * * This method is exposed directly in python with the arguments with a @@ -63,6 +190,7 @@ method_encode (PyObject *self, PyObject *args) * for that purpose. */ int16_t *input_buffer = (int16_t *) malloc(sizeof(int16_t *) * input_length); + uint8_t *output_buffer = NULL; if (input_buffer == NULL) return PyErr_NoMemory(); @@ -84,20 +212,13 @@ method_encode (PyObject *self, PyObject *args) return NULL; } - /* We don't know the output length required, we guess worst case, - * the mlw_encode call will do a resize (downwards) anyway. - */ - uint8_t *output_buffer = (uint8_t *) malloc(input_length); - if (output_buffer == NULL) - return PyErr_NoMemory(); - int output_length = mlw_encode(input_buffer, input_length, &output_buffer, verbose); PyObject *output_byte_array = PyByteArray_FromStringAndSize ((char *) output_buffer, output_length); /* Discard the temporary input and output buffers. */ free (input_buffer); - free (output_buffer); + mlw_free_outbuf(output_buffer); return output_byte_array; } @@ -163,6 +284,7 @@ method_decode(PyObject *self, PyObject *args) static PyMethodDef mlw_methods[] = { {"decode", method_decode, METH_VARARGS, "Python interface for decode"}, {"encode", method_encode, METH_VARARGS, "Python interface for encode"}, + {"reorder_encode", method_reorder_encode, METH_VARARGS, "Python interface for reorder and encode"}, {NULL, NULL, 0, NULL} }; @@ -177,6 +299,9 @@ static struct PyModuleDef mlw_codecmodule = { mlw_methods }; -PyMODINIT_FUNC PyInit_mlw_codec(void) { - return PyModule_Create(&mlw_codecmodule); +PyMODINIT_FUNC PyInit_mlw_codec(void) +{ + PyObject* ret = PyModule_Create(&mlw_codecmodule); + import_array(); + return ret; } |