aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFredrik Svedberg <fredrik.svedberg@arm.com>2024-01-12 15:32:53 +0100
committerFredrik Svedberg <fredrik.svedberg@arm.com>2024-01-24 08:40:47 +0000
commitc222f8cb5af13d890f0851558873ef6679531550 (patch)
tree2af04eec3c0ca18188ebfd0a24fc929fd30d973f
parent56e5f0c22ebc995dae13c6b72b08b28934a7871a (diff)
downloadethos-u-vela-c222f8cb5af13d890f0851558873ef6679531550.tar.gz
MLBEDSW-8568 Fix mlw_codec memory handling
Added missing memory allocation checks to mlw_codec. Change-Id: I20c04d5d9c934b9c715a2b2049705f853d90825a Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
-rw-r--r--ethosu/mlw_codec/mlw_codecmodule.c24
-rw-r--r--ethosu/mlw_codec/mlw_encode.c62
2 files changed, 57 insertions, 29 deletions
diff --git a/ethosu/mlw_codec/mlw_codecmodule.c b/ethosu/mlw_codec/mlw_codecmodule.c
index 8c540d6..1f172ee 100644
--- a/ethosu/mlw_codec/mlw_codecmodule.c
+++ b/ethosu/mlw_codec/mlw_codecmodule.c
@@ -1,5 +1,5 @@
/*
- * SPDX-FileCopyrightText: Copyright 2020-2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
+ * SPDX-FileCopyrightText: Copyright 2020-2021, 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
*
* SPDX-License-Identifier: Apache-2.0
*
@@ -84,6 +84,7 @@ method_reorder_encode (PyObject *self, PyObject *args)
NPY_ARRAY_ALIGNED);
if (input_ndarray_object == NULL)
{
+ PyErr_SetString(PyExc_ValueError, "Invalid input array");
return NULL;
}
@@ -137,17 +138,23 @@ method_reorder_encode (PyObject *self, PyObject *args)
&padded_length,
verbose);
- PyObject *output_byte_array = PyByteArray_FromStringAndSize((char*)output_buffer, output_length);
- PyObject *padded_length_obj = Py_BuildValue("i", padded_length);
+ PyObject* ret = NULL;
+ if ( output_length < 0 ) {
+ ret = PyErr_NoMemory();
+ } else {
+ PyObject *output_byte_array = PyByteArray_FromStringAndSize((char*)output_buffer, output_length);
+ PyObject *padded_length_obj = Py_BuildValue("i", padded_length);
+ if ( output_byte_array && padded_length_obj ) {
+ ret = PyTuple_Pack(2, output_byte_array, padded_length_obj);
+ }
+ Py_XDECREF(output_byte_array);
+ Py_XDECREF(padded_length_obj);
+ }
/* Discard the output buffer */
mlw_free_outbuf(output_buffer);
- PyObject* ret = PyTuple_Pack(2, output_byte_array, padded_length_obj);
-
Py_DECREF(input_ndarray_object);
- Py_DECREF(output_byte_array);
- Py_DECREF(padded_length_obj);
return ret;
}
@@ -216,7 +223,8 @@ method_encode (PyObject *self, PyObject *args)
int output_length = mlw_encode(input_buffer, (int)input_length, &output_buffer, verbose);
- PyObject *output_byte_array = PyByteArray_FromStringAndSize ((char *) output_buffer, output_length);
+ PyObject *output_byte_array = output_length < 0 ? PyErr_NoMemory() :
+ PyByteArray_FromStringAndSize ((char *) output_buffer, output_length);
/* Discard the temporary input and output buffers. */
free (input_buffer);
diff --git a/ethosu/mlw_codec/mlw_encode.c b/ethosu/mlw_codec/mlw_encode.c
index e8e1a8c..3ec2490 100644
--- a/ethosu/mlw_codec/mlw_encode.c
+++ b/ethosu/mlw_codec/mlw_encode.c
@@ -1,5 +1,5 @@
/*
- * SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+ * SPDX-FileCopyrightText: Copyright 2020-2022, 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
*
* SPDX-License-Identifier: Apache-2.0
*
@@ -87,14 +87,14 @@ static int search_palette_sections( int16_t *buf, int size, int **palette_restar
// Preliminary allocation of sufficient size
restart_pos = (int*)malloc( max_palettes*sizeof(int) );
if (!restart_pos) {
- return 0;
+ return -1;
}
last_restart_idx=0;
got_palette=0;
restart_i=1;
restart_pos[0] = 0;
zero_cnt=0;
- memset( prev_idx, -1, sizeof(prev_idx));
+ memset(prev_idx, -1, sizeof(prev_idx));
for(i=0; i<size; i++) {
// Guess if zeros should be excluded from the palette
int exclude_zero = zero_cnt > (i-last_restart_idx)/4;
@@ -113,7 +113,7 @@ static int search_palette_sections( int16_t *buf, int size, int **palette_restar
max_palettes = max_palettes*2;
restart_pos = (int*)realloc( restart_pos, max_palettes*sizeof(int) );
if (!restart_pos) {
- return 0;
+ return -1;
}
}
DPRINTF("restart %d pos %d\n", restart_i, i);
@@ -184,7 +184,7 @@ static int search_palette_sections( int16_t *buf, int size, int **palette_restar
max_palettes = max_palettes*2;
restart_pos = (int*)realloc( restart_pos, max_palettes*sizeof(int) );
if (!restart_pos) {
- return 0;
+ return -1;
}
}
restart_pos[restart_i++] = last_restart_idx;
@@ -199,7 +199,7 @@ static int search_palette_sections( int16_t *buf, int size, int **palette_restar
}
// Reallocate to actual size
*palette_restart_positions = (int*)realloc( restart_pos, restart_i*sizeof(int) );
- return *palette_restart_positions ? restart_i : 0;
+ return *palette_restart_positions ? restart_i : -1;
}
// Calculate frequency table
@@ -417,11 +417,18 @@ static int search_grc_params( const int *inval_buf,
search_state_t *state[MAX_ZWCFG];
for(i=0; i<n_cfg; i++) {
- state[i] = malloc( sizeof(search_state_t) * (n_inval+1) );
+ CHECKED_MALLOC(state[i], sizeof(search_state_t) * (n_inval + 1));
state[i][0].bitcnt=0;
state[i][0].prev_cfg=i;
}
+ if ( i < n_cfg ) { // Memory allocation failed - clean up and exit
+ while ( i ) {
+ free(state[--i]);
+ }
+ return -1;
+ }
+
// Loop over inval_buf
int existing_idx=0;
for(i=0; i<n_inval; i++) {
@@ -784,6 +791,10 @@ static int encode_section( const int16_t *inbuf,
CHECKED_MALLOC( w_slice_cfg, size );
CHECKED_MALLOC( w_slice_pos, size*sizeof(int) );
n_w_slice = search_grc_params( weight_values, n_weights, 0, uncompressed_bits, w_slice_cfg, w_slice_pos, size, 0, 0, &w_bitcnt);
+ if ( n_w_slice < 0 ) { // Memory allocation failed
+ bitpos = -1;
+ break;
+ }
if (n_weights==0)
n_w_slice = 0;
@@ -793,6 +804,10 @@ static int encode_section( const int16_t *inbuf,
CHECKED_MALLOC( z_slice_cfg, size );
CHECKED_MALLOC( z_slice_pos, size*sizeof(int) );
n_z_slice = search_grc_params( zrun_values, n_weights+1, 1, 0, z_slice_cfg, z_slice_pos, size, w_slice_pos, n_w_slice, &z_bitcnt);
+ if ( n_z_slice < 0 ) { // Memory allocation failed
+ bitpos = -1;
+ break;
+ }
}
// Encode bitstream slice
@@ -875,13 +890,12 @@ int mlw_encode( int16_t *inbuf, int inbuf_size, uint8_t **outbuf, int verbose) {
}
// Analyse input data to find palette re-programming points
- int n_restarts;
int *palette_restart_pos = NULL;
- n_restarts = search_palette_sections( inbuf, inbuf_size, &palette_restart_pos);
+ int n_restarts = search_palette_sections( inbuf, inbuf_size, &palette_restart_pos);
// Compress each section (using a single palette) separately
- int bitpos=0;
- for(i=0; i<n_restarts; i++) {
+ int bitpos = 0;
+ for ( i = 0; i < n_restarts && bitpos >= 0; i++ ) {
palette_t palette;
int pos, size;
pos = palette_restart_pos[i];
@@ -892,9 +906,9 @@ int mlw_encode( int16_t *inbuf, int inbuf_size, uint8_t **outbuf, int verbose) {
*outbuf, bitbuf_size, bitpos, verbose );
}
-
- // Add end of stream marker and align to 128bit
- {
+ int ret = -1;
+ if ( bitpos >= 0 && n_restarts >= 0 ) { // If allocation fails bitpos or n_restarts < 0
+ // Add end of stream marker and align to 128bit
bitbuf_t bitbuf_s, *bb=&bitbuf_s;
bitbuf_init( bb, *outbuf, bitbuf_size, verbose&2?1:0 );
bb->pos = bitpos;
@@ -906,14 +920,18 @@ int mlw_encode( int16_t *inbuf, int inbuf_size, uint8_t **outbuf, int verbose) {
bitbuf_put( bb, "PAD", 8, 0xff );
}
bitpos = bb->pos;
+
+ assert((bitpos&127)==0);
+ int outbuf_size = bitpos/8;
+ *outbuf = realloc(*outbuf, outbuf_size);
+ if ( *outbuf ) {
+ ret = outbuf_size;
+ }
}
- assert((bitpos&127)==0);
- int outbuf_size = bitpos/8;
- *outbuf = realloc( *outbuf, outbuf_size);
free(palette_restart_pos);
- return *outbuf ? outbuf_size : -1;
+ return ret;
}
void mlw_free_outbuf( uint8_t *outbuf ) {
@@ -965,7 +983,7 @@ static int16_t* reorder(
int decomp_w,
int64_t* padded_length)
{
- *padded_length = 0;
+ *padded_length = -1;
/* Size unknown. Start with one page at least */
int64_t length = round_up(max(1, sizeof(int16_t)*
ofm_depth*
@@ -1090,7 +1108,9 @@ static int16_t* reorder(
weights = (int16_t*)realloc(weights, weight_cnt * sizeof(int16_t));
- *padded_length = weights ? weight_cnt : 0;
+ if ( weights ) {
+ *padded_length = weight_cnt;
+ }
return weights;
}
@@ -1136,7 +1156,7 @@ int mlw_reorder_encode(
padded_length);
/* Then encode */
- int output_length = 0;
+ int output_length = -1;
if (*padded_length > 0 && *padded_length <= INT32_MAX)
{
output_length = mlw_encode(weights, (int)*padded_length, outbuf, verbose);