From 60232140a2927865d1f6f9bc48871df3b2bb135b Mon Sep 17 00:00:00 2001 From: Louis Verhaard Date: Fri, 22 Jan 2021 14:11:15 +0100 Subject: MLBEDSW-3832: mlw_codec: improve C API - Removed unnecessary casts - Added more error handling Signed-off-by: Louis Verhaard Change-Id: I30cc37a2fb1e855b9f67599c280c1f383f0b059e --- ethosu/mlw_codec/mlw_codecmodule.c | 26 +++++++++++++++++--------- ethosu/mlw_codec/test/test_mlw_codec.py | 14 ++++++++++++++ 2 files changed, 31 insertions(+), 9 deletions(-) (limited to 'ethosu') diff --git a/ethosu/mlw_codec/mlw_codecmodule.c b/ethosu/mlw_codec/mlw_codecmodule.c index 6dde12dc..2c2fba2c 100644 --- a/ethosu/mlw_codec/mlw_codecmodule.c +++ b/ethosu/mlw_codec/mlw_codecmodule.c @@ -53,9 +53,10 @@ method_encode (PyObject *self, PyObject *args) return NULL; /* Unpack the length of the input integer list. */ - int input_length = (int)PyObject_Length (input_list_object); - if (input_length < 0) - input_length = 0; + Py_ssize_t input_length = PyObject_Length (input_list_object); + if (input_length < 0) { + return NULL; + } /* We need to marshall the integer list into an input buffer * suitable for mlw_encode, use a temporary heap allocated buffer @@ -71,15 +72,22 @@ method_encode (PyObject *self, PyObject *args) { PyObject *item; item = PyList_GetItem(input_list_object, i); - if (!PyLong_Check(item)) - input_buffer[i] = 0; - input_buffer[i] = (int16_t)PyLong_AsLong(item); + long value = PyLong_AsLong(item); + if (value < -255 || value > 255) { + PyErr_SetString(PyExc_ValueError, "Input value out of bounds"); + return NULL; + } + input_buffer[i] = value; } + if (PyErr_Occurred() != NULL) { + PyErr_SetString(PyExc_ValueError, "Invalid input"); + return NULL; + } /* We don't know the output length required, we guess worst case, * the mlw_encode call will do a resize (downwards) anyway. */ - uint8_t *output_buffer = malloc(input_length); + uint8_t *output_buffer = (uint8_t *) malloc(input_length); if (output_buffer == NULL) return PyErr_NoMemory(); @@ -126,13 +134,13 @@ method_decode(PyObject *self, PyObject *args) /* Unpack the input buffer and length from the bytearray object. */ uint8_t *input_buffer = (uint8_t *) PyByteArray_AsString(input_bytearray_object); - int input_length = (int)PyByteArray_Size(input_bytearray_object); + Py_ssize_t input_length = PyByteArray_Size(input_bytearray_object); /* We don't know the output length required, we guess, but the guess * will be too small, the mlw_decode call will do a resize (upwards) * anyway. */ - int16_t *output_buffer = malloc (input_length); + int16_t *output_buffer = (int16_t *) malloc (input_length); if (output_buffer == NULL) return PyErr_NoMemory(); diff --git a/ethosu/mlw_codec/test/test_mlw_codec.py b/ethosu/mlw_codec/test/test_mlw_codec.py index d37462d1..18c828a3 100644 --- a/ethosu/mlw_codec/test/test_mlw_codec.py +++ b/ethosu/mlw_codec/test/test_mlw_codec.py @@ -60,3 +60,17 @@ class TestMLWCodec: def _call_mlw_codec_method(self, method_name, test_input, expected): output = method_name(test_input) assert output == expected + + invalid_encode_test_data = [None, 3, [4, 5, None, 7], [0, 1, "two", 3], [1, 2, 256, 4], [2, 4, 8, -256]] + + @pytest.mark.parametrize("input", invalid_encode_test_data) + def test_encode_invalid_input(self, input): + with pytest.raises(Exception): + mlw_codec.encode(input) + + invalid_decode_test_data = [None, 3, []] + + @pytest.mark.parametrize("input", invalid_decode_test_data) + def test_decode_invalid_input(self, input): + with pytest.raises(Exception): + mlw_codec.decode(input) -- cgit v1.2.1