aboutsummaryrefslogtreecommitdiff
path: root/ethosu/mlw_codec/mlw_codecmodule.c
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/mlw_codec/mlw_codecmodule.c')
-rw-r--r--ethosu/mlw_codec/mlw_codecmodule.c145
1 files changed, 135 insertions, 10 deletions
diff --git a/ethosu/mlw_codec/mlw_codecmodule.c b/ethosu/mlw_codec/mlw_codecmodule.c
index 2c2fba2c..1e13dd22 100644
--- a/ethosu/mlw_codec/mlw_codecmodule.c
+++ b/ethosu/mlw_codec/mlw_codecmodule.c
@@ -18,10 +18,137 @@
#define PY_SSIZE_T_CLEAN
#include <Python.h>
+#include <numpy/ndarrayobject.h>
#include "mlw_decode.h"
#include "mlw_encode.h"
+/* C extension wrapper for mlw_reorder_encode
+ *
+ * This method is exposed directly in python with the arguments with a
+ * prototype of the form:
+ *
+ * output = mlw_codec.reorder_encode(
+ * ifm_ublock_depth,
+ * ofm_ublock_depth,
+ * input,
+ * ofm_block_depth,
+ * is_depthwise,
+ * is_partkernel,
+ * ifm_bitdepth,
+ * decomp_h,
+ * decomp_w,
+ * verbose=0)
+ *
+ * output: bytearray
+ */
+
+static PyObject *
+method_reorder_encode (PyObject *self, PyObject *args)
+{
+ /* Object to hold the input integer list. */
+ int ifm_ublock_depth;
+ int ofm_ublock_depth;
+ PyObject *input_object;
+ int ofm_block_depth;
+ int is_depthwise;
+ int is_partkernel;
+ int ifm_bitdepth;
+ int decomp_h;
+ int decomp_w;
+
+ /* Object to hold the input verbosity integer, the verbose argument
+ * is optional so defaulted to 0.
+ */
+ int verbose = 0;
+
+ /* Arguments to the method are delivered as a tuple, unpack the
+ * tuple to get the individual arguments, note the second is
+ * optional.
+ */
+ if (!PyArg_ParseTuple(args, "iiOiiiiii|i",
+ &ifm_ublock_depth,
+ &ofm_ublock_depth,
+ &input_object,
+ &ofm_block_depth,
+ &is_depthwise,
+ &is_partkernel,
+ &ifm_bitdepth,
+ &decomp_h,
+ &decomp_w,
+ &verbose))
+ return NULL;
+
+ PyArrayObject* input_ndarray_object = PyArray_FROM_OTF(
+ input_object,
+ NPY_INT64,
+ NPY_ARRAY_ALIGNED);
+ if (input_ndarray_object == NULL)
+ {
+ return NULL;
+ }
+
+ if ((int)PyArray_NDIM(input_ndarray_object) < 4)
+ {
+ PyErr_SetString(PyExc_ValueError, "Invalid input shape");
+ return NULL;
+ }
+
+ int ofm_depth = (int)PyArray_DIM(input_ndarray_object, 0);
+ int kernel_height = (int)PyArray_DIM(input_ndarray_object, 1);
+ int kernel_width = (int)PyArray_DIM(input_ndarray_object, 2);
+ int ifm_depth = (int)PyArray_DIM(input_ndarray_object, 3);
+
+ int64_t* brick_weights = (int64_t*)PyArray_DATA(input_ndarray_object);
+ int brick_strides[4];
+ for (int i = 0; i < 4; i++)
+ {
+ brick_strides[i] = (int)PyArray_STRIDE(input_ndarray_object, i);
+ }
+ if ((unsigned)PyArray_ITEMSIZE(input_ndarray_object) != sizeof(int64_t))
+ {
+ PyErr_SetString(PyExc_ValueError, "Invalid input type");
+ return NULL;
+ }
+ uint8_t* output_buffer = NULL;
+ int padded_length;
+
+ int output_length = mlw_reorder_encode(
+ ifm_ublock_depth,
+ ofm_ublock_depth,
+ ofm_depth,
+ kernel_height,
+ kernel_width,
+ ifm_depth,
+ brick_strides,
+ brick_weights,
+ ofm_block_depth,
+ is_depthwise,
+ is_partkernel,
+ ifm_bitdepth,
+ decomp_h,
+ decomp_w,
+ &output_buffer,
+ &padded_length,
+ verbose);
+
+ if (output_buffer == NULL)
+ {
+ return PyErr_NoMemory();
+ }
+
+ PyObject *output_byte_array = PyByteArray_FromStringAndSize((char*)output_buffer, output_length);
+ PyObject *padded_length_obj = Py_BuildValue("i", padded_length);
+
+ /* Discard the output buffer */
+ mlw_free_outbuf(output_buffer);
+
+ PyObject* ret = PyTuple_Pack(2, output_byte_array, padded_length_obj);
+ Py_DECREF(output_byte_array);
+ Py_DECREF(padded_length_obj);
+ return ret;
+}
+
/* C extension wrapper for mlw_encode
*
* This method is exposed directly in python with the arguments with a
@@ -63,6 +190,7 @@ method_encode (PyObject *self, PyObject *args)
* for that purpose.
*/
int16_t *input_buffer = (int16_t *) malloc(sizeof(int16_t *) * input_length);
+ uint8_t *output_buffer = NULL;
if (input_buffer == NULL)
return PyErr_NoMemory();
@@ -84,20 +212,13 @@ method_encode (PyObject *self, PyObject *args)
return NULL;
}
- /* We don't know the output length required, we guess worst case,
- * the mlw_encode call will do a resize (downwards) anyway.
- */
- uint8_t *output_buffer = (uint8_t *) malloc(input_length);
- if (output_buffer == NULL)
- return PyErr_NoMemory();
-
int output_length = mlw_encode(input_buffer, input_length, &output_buffer, verbose);
PyObject *output_byte_array = PyByteArray_FromStringAndSize ((char *) output_buffer, output_length);
/* Discard the temporary input and output buffers. */
free (input_buffer);
- free (output_buffer);
+ mlw_free_outbuf(output_buffer);
return output_byte_array;
}
@@ -163,6 +284,7 @@ method_decode(PyObject *self, PyObject *args)
static PyMethodDef mlw_methods[] = {
{"decode", method_decode, METH_VARARGS, "Python interface for decode"},
{"encode", method_encode, METH_VARARGS, "Python interface for encode"},
+ {"reorder_encode", method_reorder_encode, METH_VARARGS, "Python interface for reorder and encode"},
{NULL, NULL, 0, NULL}
};
@@ -177,6 +299,9 @@ static struct PyModuleDef mlw_codecmodule = {
mlw_methods
};
-PyMODINIT_FUNC PyInit_mlw_codec(void) {
- return PyModule_Create(&mlw_codecmodule);
+PyMODINIT_FUNC PyInit_mlw_codec(void)
+{
+ PyObject* ret = PyModule_Create(&mlw_codecmodule);
+ import_array();
+ return ret;
}