From c222f8cb5af13d890f0851558873ef6679531550 Mon Sep 17 00:00:00 2001 From: Fredrik Svedberg Date: Fri, 12 Jan 2024 15:32:53 +0100 Subject: MLBEDSW-8568 Fix mlw_codec memory handling Added missing memory allocation checks to mlw_codec. Change-Id: I20c04d5d9c934b9c715a2b2049705f853d90825a Signed-off-by: Fredrik Svedberg --- ethosu/mlw_codec/mlw_codecmodule.c | 24 ++++++++++----- ethosu/mlw_codec/mlw_encode.c | 62 +++++++++++++++++++++++++------------- 2 files changed, 57 insertions(+), 29 deletions(-) diff --git a/ethosu/mlw_codec/mlw_codecmodule.c b/ethosu/mlw_codec/mlw_codecmodule.c index 8c540d6..1f172ee 100644 --- a/ethosu/mlw_codec/mlw_codecmodule.c +++ b/ethosu/mlw_codec/mlw_codecmodule.c @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright 2020-2021, 2023 Arm Limited and/or its affiliates + * SPDX-FileCopyrightText: Copyright 2020-2021, 2023-2024 Arm Limited and/or its affiliates * * SPDX-License-Identifier: Apache-2.0 * @@ -84,6 +84,7 @@ method_reorder_encode (PyObject *self, PyObject *args) NPY_ARRAY_ALIGNED); if (input_ndarray_object == NULL) { + PyErr_SetString(PyExc_ValueError, "Invalid input array"); return NULL; } @@ -137,17 +138,23 @@ method_reorder_encode (PyObject *self, PyObject *args) &padded_length, verbose); - PyObject *output_byte_array = PyByteArray_FromStringAndSize((char*)output_buffer, output_length); - PyObject *padded_length_obj = Py_BuildValue("i", padded_length); + PyObject* ret = NULL; + if ( output_length < 0 ) { + ret = PyErr_NoMemory(); + } else { + PyObject *output_byte_array = PyByteArray_FromStringAndSize((char*)output_buffer, output_length); + PyObject *padded_length_obj = Py_BuildValue("i", padded_length); + if ( output_byte_array && padded_length_obj ) { + ret = PyTuple_Pack(2, output_byte_array, padded_length_obj); + } + Py_XDECREF(output_byte_array); + Py_XDECREF(padded_length_obj); + } /* Discard the output buffer */ mlw_free_outbuf(output_buffer); - PyObject* ret = PyTuple_Pack(2, output_byte_array, padded_length_obj); - Py_DECREF(input_ndarray_object); - Py_DECREF(output_byte_array); - Py_DECREF(padded_length_obj); return ret; } @@ -216,7 +223,8 @@ method_encode (PyObject *self, PyObject *args) int output_length = mlw_encode(input_buffer, (int)input_length, &output_buffer, verbose); - PyObject *output_byte_array = PyByteArray_FromStringAndSize ((char *) output_buffer, output_length); + PyObject *output_byte_array = output_length < 0 ? PyErr_NoMemory() : + PyByteArray_FromStringAndSize ((char *) output_buffer, output_length); /* Discard the temporary input and output buffers. */ free (input_buffer); diff --git a/ethosu/mlw_codec/mlw_encode.c b/ethosu/mlw_codec/mlw_encode.c index e8e1a8c..3ec2490 100644 --- a/ethosu/mlw_codec/mlw_encode.c +++ b/ethosu/mlw_codec/mlw_encode.c @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates + * SPDX-FileCopyrightText: Copyright 2020-2022, 2024 Arm Limited and/or its affiliates * * SPDX-License-Identifier: Apache-2.0 * @@ -87,14 +87,14 @@ static int search_palette_sections( int16_t *buf, int size, int **palette_restar // Preliminary allocation of sufficient size restart_pos = (int*)malloc( max_palettes*sizeof(int) ); if (!restart_pos) { - return 0; + return -1; } last_restart_idx=0; got_palette=0; restart_i=1; restart_pos[0] = 0; zero_cnt=0; - memset( prev_idx, -1, sizeof(prev_idx)); + memset(prev_idx, -1, sizeof(prev_idx)); for(i=0; i (i-last_restart_idx)/4; @@ -113,7 +113,7 @@ static int search_palette_sections( int16_t *buf, int size, int **palette_restar max_palettes = max_palettes*2; restart_pos = (int*)realloc( restart_pos, max_palettes*sizeof(int) ); if (!restart_pos) { - return 0; + return -1; } } DPRINTF("restart %d pos %d\n", restart_i, i); @@ -184,7 +184,7 @@ static int search_palette_sections( int16_t *buf, int size, int **palette_restar max_palettes = max_palettes*2; restart_pos = (int*)realloc( restart_pos, max_palettes*sizeof(int) ); if (!restart_pos) { - return 0; + return -1; } } restart_pos[restart_i++] = last_restart_idx; @@ -199,7 +199,7 @@ static int search_palette_sections( int16_t *buf, int size, int **palette_restar } // Reallocate to actual size *palette_restart_positions = (int*)realloc( restart_pos, restart_i*sizeof(int) ); - return *palette_restart_positions ? restart_i : 0; + return *palette_restart_positions ? restart_i : -1; } // Calculate frequency table @@ -417,11 +417,18 @@ static int search_grc_params( const int *inval_buf, search_state_t *state[MAX_ZWCFG]; for(i=0; i= 0; i++ ) { palette_t palette; int pos, size; pos = palette_restart_pos[i]; @@ -892,9 +906,9 @@ int mlw_encode( int16_t *inbuf, int inbuf_size, uint8_t **outbuf, int verbose) { *outbuf, bitbuf_size, bitpos, verbose ); } - - // Add end of stream marker and align to 128bit - { + int ret = -1; + if ( bitpos >= 0 && n_restarts >= 0 ) { // If allocation fails bitpos or n_restarts < 0 + // Add end of stream marker and align to 128bit bitbuf_t bitbuf_s, *bb=&bitbuf_s; bitbuf_init( bb, *outbuf, bitbuf_size, verbose&2?1:0 ); bb->pos = bitpos; @@ -906,14 +920,18 @@ int mlw_encode( int16_t *inbuf, int inbuf_size, uint8_t **outbuf, int verbose) { bitbuf_put( bb, "PAD", 8, 0xff ); } bitpos = bb->pos; + + assert((bitpos&127)==0); + int outbuf_size = bitpos/8; + *outbuf = realloc(*outbuf, outbuf_size); + if ( *outbuf ) { + ret = outbuf_size; + } } - assert((bitpos&127)==0); - int outbuf_size = bitpos/8; - *outbuf = realloc( *outbuf, outbuf_size); free(palette_restart_pos); - return *outbuf ? outbuf_size : -1; + return ret; } void mlw_free_outbuf( uint8_t *outbuf ) { @@ -965,7 +983,7 @@ static int16_t* reorder( int decomp_w, int64_t* padded_length) { - *padded_length = 0; + *padded_length = -1; /* Size unknown. Start with one page at least */ int64_t length = round_up(max(1, sizeof(int16_t)* ofm_depth* @@ -1090,7 +1108,9 @@ static int16_t* reorder( weights = (int16_t*)realloc(weights, weight_cnt * sizeof(int16_t)); - *padded_length = weights ? weight_cnt : 0; + if ( weights ) { + *padded_length = weight_cnt; + } return weights; } @@ -1136,7 +1156,7 @@ int mlw_reorder_encode( padded_length); /* Then encode */ - int output_length = 0; + int output_length = -1; if (*padded_length > 0 && *padded_length <= INT32_MAX) { output_length = mlw_encode(weights, (int)*padded_length, outbuf, verbose); -- cgit v1.2.1