diff options
-rw-r--r-- | ethosu/mlw_codec/mlw_codecmodule.c | 24 | ||||
-rw-r--r-- | ethosu/mlw_codec/mlw_encode.c | 62 |
2 files changed, 57 insertions, 29 deletions
diff --git a/ethosu/mlw_codec/mlw_codecmodule.c b/ethosu/mlw_codec/mlw_codecmodule.c index 8c540d61..1f172eed 100644 --- a/ethosu/mlw_codec/mlw_codecmodule.c +++ b/ethosu/mlw_codec/mlw_codecmodule.c @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright 2020-2021, 2023 Arm Limited and/or its affiliates <open-source-office@arm.com> + * SPDX-FileCopyrightText: Copyright 2020-2021, 2023-2024 Arm Limited and/or its affiliates <open-source-office@arm.com> * * SPDX-License-Identifier: Apache-2.0 * @@ -84,6 +84,7 @@ method_reorder_encode (PyObject *self, PyObject *args) NPY_ARRAY_ALIGNED); if (input_ndarray_object == NULL) { + PyErr_SetString(PyExc_ValueError, "Invalid input array"); return NULL; } @@ -137,17 +138,23 @@ method_reorder_encode (PyObject *self, PyObject *args) &padded_length, verbose); - PyObject *output_byte_array = PyByteArray_FromStringAndSize((char*)output_buffer, output_length); - PyObject *padded_length_obj = Py_BuildValue("i", padded_length); + PyObject* ret = NULL; + if ( output_length < 0 ) { + ret = PyErr_NoMemory(); + } else { + PyObject *output_byte_array = PyByteArray_FromStringAndSize((char*)output_buffer, output_length); + PyObject *padded_length_obj = Py_BuildValue("i", padded_length); + if ( output_byte_array && padded_length_obj ) { + ret = PyTuple_Pack(2, output_byte_array, padded_length_obj); + } + Py_XDECREF(output_byte_array); + Py_XDECREF(padded_length_obj); + } /* Discard the output buffer */ mlw_free_outbuf(output_buffer); - PyObject* ret = PyTuple_Pack(2, output_byte_array, padded_length_obj); - Py_DECREF(input_ndarray_object); - Py_DECREF(output_byte_array); - Py_DECREF(padded_length_obj); return ret; } @@ -216,7 +223,8 @@ method_encode (PyObject *self, PyObject *args) int output_length = mlw_encode(input_buffer, (int)input_length, &output_buffer, verbose); - PyObject *output_byte_array = PyByteArray_FromStringAndSize ((char *) output_buffer, output_length); + PyObject *output_byte_array = output_length < 0 ? PyErr_NoMemory() : + PyByteArray_FromStringAndSize ((char *) output_buffer, output_length); /* Discard the temporary input and output buffers. */ free (input_buffer); diff --git a/ethosu/mlw_codec/mlw_encode.c b/ethosu/mlw_codec/mlw_encode.c index e8e1a8ca..3ec24908 100644 --- a/ethosu/mlw_codec/mlw_encode.c +++ b/ethosu/mlw_codec/mlw_encode.c @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright 2020-2022 Arm Limited and/or its affiliates <open-source-office@arm.com> + * SPDX-FileCopyrightText: Copyright 2020-2022, 2024 Arm Limited and/or its affiliates <open-source-office@arm.com> * * SPDX-License-Identifier: Apache-2.0 * @@ -87,14 +87,14 @@ static int search_palette_sections( int16_t *buf, int size, int **palette_restar // Preliminary allocation of sufficient size restart_pos = (int*)malloc( max_palettes*sizeof(int) ); if (!restart_pos) { - return 0; + return -1; } last_restart_idx=0; got_palette=0; restart_i=1; restart_pos[0] = 0; zero_cnt=0; - memset( prev_idx, -1, sizeof(prev_idx)); + memset(prev_idx, -1, sizeof(prev_idx)); for(i=0; i<size; i++) { // Guess if zeros should be excluded from the palette int exclude_zero = zero_cnt > (i-last_restart_idx)/4; @@ -113,7 +113,7 @@ static int search_palette_sections( int16_t *buf, int size, int **palette_restar max_palettes = max_palettes*2; restart_pos = (int*)realloc( restart_pos, max_palettes*sizeof(int) ); if (!restart_pos) { - return 0; + return -1; } } DPRINTF("restart %d pos %d\n", restart_i, i); @@ -184,7 +184,7 @@ static int search_palette_sections( int16_t *buf, int size, int **palette_restar max_palettes = max_palettes*2; restart_pos = (int*)realloc( restart_pos, max_palettes*sizeof(int) ); if (!restart_pos) { - return 0; + return -1; } } restart_pos[restart_i++] = last_restart_idx; @@ -199,7 +199,7 @@ static int search_palette_sections( int16_t *buf, int size, int **palette_restar } // Reallocate to actual size *palette_restart_positions = (int*)realloc( restart_pos, restart_i*sizeof(int) ); - return *palette_restart_positions ? restart_i : 0; + return *palette_restart_positions ? restart_i : -1; } // Calculate frequency table @@ -417,11 +417,18 @@ static int search_grc_params( const int *inval_buf, search_state_t *state[MAX_ZWCFG]; for(i=0; i<n_cfg; i++) { - state[i] = malloc( sizeof(search_state_t) * (n_inval+1) ); + CHECKED_MALLOC(state[i], sizeof(search_state_t) * (n_inval + 1)); state[i][0].bitcnt=0; state[i][0].prev_cfg=i; } + if ( i < n_cfg ) { // Memory allocation failed - clean up and exit + while ( i ) { + free(state[--i]); + } + return -1; + } + // Loop over inval_buf int existing_idx=0; for(i=0; i<n_inval; i++) { @@ -784,6 +791,10 @@ static int encode_section( const int16_t *inbuf, CHECKED_MALLOC( w_slice_cfg, size ); CHECKED_MALLOC( w_slice_pos, size*sizeof(int) ); n_w_slice = search_grc_params( weight_values, n_weights, 0, uncompressed_bits, w_slice_cfg, w_slice_pos, size, 0, 0, &w_bitcnt); + if ( n_w_slice < 0 ) { // Memory allocation failed + bitpos = -1; + break; + } if (n_weights==0) n_w_slice = 0; @@ -793,6 +804,10 @@ static int encode_section( const int16_t *inbuf, CHECKED_MALLOC( z_slice_cfg, size ); CHECKED_MALLOC( z_slice_pos, size*sizeof(int) ); n_z_slice = search_grc_params( zrun_values, n_weights+1, 1, 0, z_slice_cfg, z_slice_pos, size, w_slice_pos, n_w_slice, &z_bitcnt); + if ( n_z_slice < 0 ) { // Memory allocation failed + bitpos = -1; + break; + } } // Encode bitstream slice @@ -875,13 +890,12 @@ int mlw_encode( int16_t *inbuf, int inbuf_size, uint8_t **outbuf, int verbose) { } // Analyse input data to find palette re-programming points - int n_restarts; int *palette_restart_pos = NULL; - n_restarts = search_palette_sections( inbuf, inbuf_size, &palette_restart_pos); + int n_restarts = search_palette_sections( inbuf, inbuf_size, &palette_restart_pos); // Compress each section (using a single palette) separately - int bitpos=0; - for(i=0; i<n_restarts; i++) { + int bitpos = 0; + for ( i = 0; i < n_restarts && bitpos >= 0; i++ ) { palette_t palette; int pos, size; pos = palette_restart_pos[i]; @@ -892,9 +906,9 @@ int mlw_encode( int16_t *inbuf, int inbuf_size, uint8_t **outbuf, int verbose) { *outbuf, bitbuf_size, bitpos, verbose ); } - - // Add end of stream marker and align to 128bit - { + int ret = -1; + if ( bitpos >= 0 && n_restarts >= 0 ) { // If allocation fails bitpos or n_restarts < 0 + // Add end of stream marker and align to 128bit bitbuf_t bitbuf_s, *bb=&bitbuf_s; bitbuf_init( bb, *outbuf, bitbuf_size, verbose&2?1:0 ); bb->pos = bitpos; @@ -906,14 +920,18 @@ int mlw_encode( int16_t *inbuf, int inbuf_size, uint8_t **outbuf, int verbose) { bitbuf_put( bb, "PAD", 8, 0xff ); } bitpos = bb->pos; + + assert((bitpos&127)==0); + int outbuf_size = bitpos/8; + *outbuf = realloc(*outbuf, outbuf_size); + if ( *outbuf ) { + ret = outbuf_size; + } } - assert((bitpos&127)==0); - int outbuf_size = bitpos/8; - *outbuf = realloc( *outbuf, outbuf_size); free(palette_restart_pos); - return *outbuf ? outbuf_size : -1; + return ret; } void mlw_free_outbuf( uint8_t *outbuf ) { @@ -965,7 +983,7 @@ static int16_t* reorder( int decomp_w, int64_t* padded_length) { - *padded_length = 0; + *padded_length = -1; /* Size unknown. Start with one page at least */ int64_t length = round_up(max(1, sizeof(int16_t)* ofm_depth* @@ -1090,7 +1108,9 @@ static int16_t* reorder( weights = (int16_t*)realloc(weights, weight_cnt * sizeof(int16_t)); - *padded_length = weights ? weight_cnt : 0; + if ( weights ) { + *padded_length = weight_cnt; + } return weights; } @@ -1136,7 +1156,7 @@ int mlw_reorder_encode( padded_length); /* Then encode */ - int output_length = 0; + int output_length = -1; if (*padded_length > 0 && *padded_length <= INT32_MAX) { output_length = mlw_encode(weights, (int)*padded_length, outbuf, verbose); |