aboutsummaryrefslogtreecommitdiff
path: root/ethosu/mlw_codec/mlw_codecmodule.c
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/mlw_codec/mlw_codecmodule.c')
-rw-r--r--ethosu/mlw_codec/mlw_codecmodule.c31
1 files changed, 31 insertions, 0 deletions
diff --git a/ethosu/mlw_codec/mlw_codecmodule.c b/ethosu/mlw_codec/mlw_codecmodule.c
index 1f172ee..5d37302 100644
--- a/ethosu/mlw_codec/mlw_codecmodule.c
+++ b/ethosu/mlw_codec/mlw_codecmodule.c
@@ -91,6 +91,7 @@ method_reorder_encode (PyObject *self, PyObject *args)
if ((int)PyArray_NDIM(input_ndarray_object) < 4)
{
PyErr_SetString(PyExc_ValueError, "Invalid input shape");
+ Py_DECREF(input_ndarray_object);
return NULL;
}
@@ -99,6 +100,34 @@ method_reorder_encode (PyObject *self, PyObject *args)
int kernel_width = (int)PyArray_DIM(input_ndarray_object, 2);
int ifm_depth = (int)PyArray_DIM(input_ndarray_object, 3);
+ if (ofm_depth < 1)
+ {
+ PyErr_SetString(PyExc_ValueError, "Invalid output depth");
+ Py_DECREF(input_ndarray_object);
+ return NULL;
+ }
+
+ if (ifm_depth < 1)
+ {
+ PyErr_SetString(PyExc_ValueError, "Invalid input depth");
+ Py_DECREF(input_ndarray_object);
+ return NULL;
+ }
+
+ if (kernel_height < 1)
+ {
+ PyErr_SetString(PyExc_ValueError, "Invalid kernel height");
+ Py_DECREF(input_ndarray_object);
+ return NULL;
+ }
+
+ if (kernel_width < 1)
+ {
+ PyErr_SetString(PyExc_ValueError, "Invalid kernel width");
+ Py_DECREF(input_ndarray_object);
+ return NULL;
+ }
+
int16_t* brick_weights = (int16_t*)PyArray_DATA(input_ndarray_object);
int brick_strides[4];
for (int i = 0; i < 4; i++)
@@ -107,6 +136,7 @@ method_reorder_encode (PyObject *self, PyObject *args)
if (stride % sizeof(int16_t))
{
PyErr_SetString(PyExc_ValueError, "Invalid stride");
+ Py_DECREF(input_ndarray_object);
return NULL;
}
brick_strides[i] = stride / sizeof(int16_t);
@@ -114,6 +144,7 @@ method_reorder_encode (PyObject *self, PyObject *args)
if ((unsigned)PyArray_ITEMSIZE(input_ndarray_object) != sizeof(int16_t))
{
PyErr_SetString(PyExc_ValueError, "Invalid input type");
+ Py_DECREF(input_ndarray_object);
return NULL;
}
uint8_t* output_buffer = NULL;