diff options
Diffstat (limited to 'ethosu/mlw_codec')
-rw-r--r-- | ethosu/mlw_codec/makefile | 49 | ||||
-rw-r--r-- | ethosu/mlw_codec/mlw_codecmodule.c | 174 | ||||
-rw-r--r-- | ethosu/mlw_codec/mlw_common.h | 29 | ||||
-rw-r--r-- | ethosu/mlw_codec/mlw_decode.c | 300 | ||||
-rw-r--r-- | ethosu/mlw_codec/mlw_decode.h | 42 | ||||
-rw-r--r-- | ethosu/mlw_codec/mlw_encode.c | 874 | ||||
-rw-r--r-- | ethosu/mlw_codec/mlw_encode.h | 45 | ||||
-rw-r--r-- | ethosu/mlw_codec/mlw_main.c | 177 | ||||
-rw-r--r-- | ethosu/mlw_codec/test_mlw_codec.py | 43 |
9 files changed, 1733 insertions, 0 deletions
diff --git a/ethosu/mlw_codec/makefile b/ethosu/mlw_codec/makefile new file mode 100644 index 00000000..6eb418dd --- /dev/null +++ b/ethosu/mlw_codec/makefile @@ -0,0 +1,49 @@ +# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Description: +# Makefile to build mlw_codec + +UNAME=$(shell uname -o) + +CFLAGS=-Wall -Wno-unused-function -Wno-unused-variable + +ifeq ($(DEBUG),1) + CFLAGS+=-g -O0 -DDEBUG +else + CFLAGS+=-O3 +endif + +LIBSRCS=mlw_encode.c mlw_decode.c +LIBHDRS=mlw_encode.h mlw_decode.h mlw_common.h + +ifeq ($(UNAME),Cygwin) + MLWEXE=mlw_codec.exe +else + MLWEXE=mlw_codec +endif + +all: mlwexe + +.PHONY: mlwexe +mlwexe: $(MLWEXE) + +clean: + rm -f $(MLWEXE) + +$(MLWEXE): mlw_main.c $(LIBSRCS) $(LIBHDRS) makefile + gcc $(CFLAGS) mlw_main.c $(LIBSRCS) -o $(MLWEXE) -lm diff --git a/ethosu/mlw_codec/mlw_codecmodule.c b/ethosu/mlw_codec/mlw_codecmodule.c new file mode 100644 index 00000000..de945ab3 --- /dev/null +++ b/ethosu/mlw_codec/mlw_codecmodule.c @@ -0,0 +1,174 @@ +/* + * Copyright (c) 2020 Arm Limited. All rights reserved. + * + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#define PY_SSIZE_T_CLEAN +#include <Python.h> + +#include "mlw_decode.h" +#include "mlw_encode.h" + +/* C extension wrapper for mlw_encode + * + * This method is exposed directly in python with the arguments with a + * prototype of the form: + * + * output = mlw_codec.encode(input, verbose=0) + * + * input: [int] + * verbose: int + * output: bytearray + */ + +static PyObject * +method_encode (PyObject *self, PyObject *args) +{ + /* Object to hold the input integer list. */ + PyObject *input_list_object; + + /* Object to hold the input verbosity integer, the verbose argument + * is optional so defaulted to 0. + */ + int verbose = 0; + + /* Arguments to the method are delivered as a tuple, unpack the + * tuple to get the individual arguments, note the second is + * optional. + */ + if (!PyArg_ParseTuple(args, "O|i", &input_list_object, &verbose)) + return NULL; + + /* Unpack the length of the input integer list. */ + int input_length = PyObject_Length (input_list_object); + if (input_length < 0) + input_length = 0; + + /* We need to marshall the integer list into an input buffer + * suitable for mlw_encode, use a temporary heap allocated buffer + * for that purpose. + */ + int16_t *input_buffer = (int16_t *) malloc(sizeof(int16_t *) * input_length); + if (input_buffer == NULL) + return PyErr_NoMemory(); + + /* Unpack the input integer list into the temporary buffer. + */ + for (int i = 0; i < input_length; i++) + { + PyObject *item; + item = PyList_GetItem(input_list_object, i); + if (!PyLong_Check(item)) + input_buffer[i] = 0; + input_buffer[i] = PyLong_AsLong(item); + } + + /* We don't know the output length required, we guess worst case, + * the mlw_encode call will do a resize (downwards) anyway. + */ + uint8_t *output_buffer = malloc(input_length); + if (output_buffer == NULL) + return PyErr_NoMemory(); + + int output_length = mlw_encode(input_buffer, input_length, &output_buffer, verbose); + + PyObject *output_byte_array = PyByteArray_FromStringAndSize ((char *) output_buffer, output_length); + + /* Discard the temporary input and output buffers. */ + free (input_buffer); + free (output_buffer); + + return output_byte_array; +} + +/* C extension wrapper for mlw_decode + * + * This method is exposed directly in python with the arguments with a + * prototype of the form: + * + * output = mlw_codec.decode(input, verbose=0) + * + * input: bytearray + * verbose: int + * output: [int] + */ + +static PyObject * +method_decode(PyObject *self, PyObject *args) +{ + /* Object to hold the input bytearray. */ + PyObject *input_bytearray_object; + + /* Object to hold the input verbosity integer, the verbose argument + * is optional so defaulted to 0. + */ + int verbose = 0; + + /* Arguments to the method are delivered as a tuple, unpack the + * tuple to get the individual arguments, note the second is + * optional. + */ + if (!PyArg_ParseTuple(args, "Y|i", &input_bytearray_object, &verbose)) + return NULL; + + /* Unpack the input buffer and length from the bytearray object. */ + uint8_t *input_buffer = (uint8_t *) PyByteArray_AsString(input_bytearray_object); + int input_length = PyByteArray_Size(input_bytearray_object); + + /* We don't know the output length required, we guess, but the guess + * will be too small, the mlw_decode call will do a resize (upwards) + * anyway. + */ + int16_t *output_buffer = malloc (input_length); + if (output_buffer == NULL) + return PyErr_NoMemory(); + + int output_length = mlw_decode (input_buffer, input_length, &output_buffer, verbose); + + /* Construct a new integer list and marshall the output buffer + * contents into the list. */ + PyObject *output_list = PyList_New(output_length); + for (int i = 0; i <output_length; i++) + PyList_SetItem (output_list, i, PyLong_FromLong (output_buffer[i])); + + free (output_buffer); + + return output_list; +} + +/* mlw_codec method descriptors. + */ + +static PyMethodDef mlw_methods[] = { + {"decode", method_decode, METH_VARARGS, "Python interface for decode"}, + {"encode", method_encode, METH_VARARGS, "Python interface for encode"}, + {NULL, NULL, 0, NULL} +}; + +/* mlw_codec module descriptor. + */ + +static struct PyModuleDef mlw_codecmodule = { + PyModuleDef_HEAD_INIT, + "mlw_codec", + "Python interface for the mlw encoder", + -1, + mlw_methods +}; + +PyMODINIT_FUNC PyInit_mlw_codec(void) { + return PyModule_Create(&mlw_codecmodule); +} diff --git a/ethosu/mlw_codec/mlw_common.h b/ethosu/mlw_codec/mlw_common.h new file mode 100644 index 00000000..008473a5 --- /dev/null +++ b/ethosu/mlw_codec/mlw_common.h @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2020 Arm Limited. All rights reserved. + * + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <stdint.h> + +#ifndef __MLW_COMMON_H__ +#define __MLW_COMMON_H__ + +#define ZDIV_DISABLE 6 // not alternating mode +#define ZDIV_EOS 7 // indicates end of stream + +#define WDIV_UNCOMPRESSED 7 // indicates uncompressed weights + +#endif diff --git a/ethosu/mlw_codec/mlw_decode.c b/ethosu/mlw_codec/mlw_decode.c new file mode 100644 index 00000000..92aaea67 --- /dev/null +++ b/ethosu/mlw_codec/mlw_decode.c @@ -0,0 +1,300 @@ +/* + * Copyright (c) 2020 Arm Limited. All rights reserved. + * + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <stdio.h> +#include <stdlib.h> +#include <stdint.h> +#include <stdbool.h> +#include <string.h> +#include <assert.h> +#include <math.h> +#include <stdarg.h> +#include <math.h> +#include "mlw_common.h" +#include "mlw_decode.h" + + +/////////////////////////////// Read from bitstream + +typedef struct bitbuf { + uint8_t *buf; + int buf_size; // in bytes + int pos; // bit pos of next bit + int log_symbols; +} bitbuf_t; + + +// size in byte +static void bitbuf_init( bitbuf_t *bb, uint8_t *buf, int size, int log_symbols) { + bb->buf = buf; + bb->pos = 0; + bb->buf_size = size; + bb->log_symbols = log_symbols; +} + +static int bitbuf_getbit( bitbuf_t *bb) { + int byte_pos = bb->pos>>3; + int bit_pos = bb->pos&7; + if ( byte_pos < 0 || byte_pos >= bb->buf_size ) { + printf("bitbuf_getbit: underrun, bit_pos %3d byte_pos %3d buf_size %3d\n", bit_pos, byte_pos, bb->buf_size); + exit(1); + } + int bit = bb->buf[ byte_pos ] & (1<<bit_pos) ? 1 : 0; + bb->pos++; + return bit; +} + +static int bitbuf_get( bitbuf_t *bb, const char *name, int len) { + int i, data=0, save_pos=bb->pos; + if (len>0) { + for(i=0; i<len; i++) { + data |= bitbuf_getbit(bb)<<i; + } + if (bb->log_symbols) + printf("bitbuf: pos %3d %7s len %d data %x\n", save_pos, name, len, data); + } + return data; +} + +// Decode the given weight stream +// inbuf compressed bitstream +// inbuf_size size of compressed bitstream in bytes +// outbuf uncompressed 9bit signed weights, buffer malloced +// verbose if non-zero, printf log +// Return value is the number of uncompressed weights +int mlw_decode( uint8_t *inbuf, int inbuf_size, int16_t **outbuf, int verbose) { + int nvalues; + int w_grc_div; + int w_grc_trunc; + int w_uncompressed; + int z_grc_div, z_prev_grc_div=0; + int new_palette; + int palsize=0, palbits=0; + int direct_offset=0; + int16_t palette[512]; + int first=1; + int use_zero_run, i, j; + int outbuf_size=0; + int nchunks=0; + + *outbuf=0; + + bitbuf_t bitbuf_s, *bb=&bitbuf_s; + bitbuf_init( bb, inbuf, inbuf_size, (verbose&2)?1:0 ); + + // Loop over all slices + while(1) { + // Decode slice header + z_grc_div = bitbuf_get( bb, "ZDIV", 3 ); + while(z_grc_div==ZDIV_EOS) { // TODO: change to ZDIV_PAD + // End of stream + // Byte align + bitbuf_get( bb, "BYTEALIGN", (8-(bb->pos&7))&7 ); + first=1; + if ( (bb->pos/8) == inbuf_size) { + // Quit if we actually reached end of input stream + break; + } + z_grc_div = bitbuf_get( bb, "ZDIV", 3 ); + } + if ( (bb->pos/8) == inbuf_size) { + break; // reached end of input stream + } + assert(z_grc_div<4 || z_grc_div==ZDIV_DISABLE); + use_zero_run = z_grc_div!=ZDIV_DISABLE; // alternating grc + nvalues = bitbuf_get( bb, "SLICELEN", 15 )+1; + w_grc_div = bitbuf_get( bb, "WDIV", 3 ); + w_grc_trunc = bitbuf_get( bb, "WTRUNC", 1 ); + new_palette = bitbuf_get( bb, "NEWPAL", 1 ); + if (first) { + // the first slice must have a palette/direct mode setup + assert(new_palette); + first=0; + } + if (!new_palette) { + // At the moment it is not supported to change between alternating + // and non-alternating without redefining the palette (this is because + // the zero is not included in the palette in case of alternating) + int prev_use_zero_run = z_prev_grc_div!=ZDIV_DISABLE; + (void)(prev_use_zero_run); + assert( use_zero_run == prev_use_zero_run); + } + z_prev_grc_div = z_grc_div; + if (new_palette) { + direct_offset = bitbuf_get( bb, "DIROFS", 5 ); + palsize = bitbuf_get( bb, "PALSIZE", 5 ); + if (palsize>0) + palsize++; + palbits = bitbuf_get( bb, "PALBITS", 3 )+2; + for(i=0; i<palsize; i++) { + palette[i] = bitbuf_get( bb, "PALETTE", palbits ); + } + } + + if (w_grc_div==WDIV_UNCOMPRESSED) { + // Uncompressed mode + w_uncompressed = 1; + int uncompressed_bits; + if (palsize>0) { + // Uncompressed bits is given by palette size. + uncompressed_bits=0; + while( (1<<uncompressed_bits) < palsize ) + uncompressed_bits++; + } else { + // No palette. PALBITS is used to specify uncompressed bits. + uncompressed_bits=palbits; + } + // In uncompressed mode there's only a remainder part (no unary) + // This is achieved by setting w_grc_div to index bit width + w_grc_div = uncompressed_bits; + } else { + w_uncompressed = 0; + assert(w_grc_div<6); + } + + // Decode the slice + int z_nvalues = nvalues + (new_palette?1:0); + int *w_value = malloc( nvalues*sizeof(int) ); + int *z_value = malloc( z_nvalues*sizeof(int) ); + int w_pos=0, z_pos=0; + int w_prev_pos=0, z_prev_pos=0; + int w_unary0=0, w_unary1=0, w_unary1_len=0, w_q[12]={0}, w_carry=0; + int z_unary=0, z_q[12]={0}, z_carry=0; + int w_nsymbols=0; + int w_prev_enable=0, w_prev_nsymbols=0, w_prev_q[12]={0}; + int z_nsymbols=0; + int z_prev_enable=0, z_prev_nsymbols=0, z_prev_q[12]={0}; + int total_zcnt=0; + int z_unary_len = z_grc_div<3 ? 12 : 8; + + // Loop over all chunks in the slice + do { + // Flow control to possibly throttle either the weights or zero-runs + int balance = use_zero_run ? w_pos - z_pos : 0; + int w_enable = (balance<8 || !use_zero_run) && w_pos<nvalues; + int z_enable = balance>=0 && use_zero_run && z_pos<z_nvalues; + if (w_enable) { + if (!w_uncompressed) + w_unary0 = bitbuf_get( bb, "WUNARY0", 12 ); + else + w_unary0 = 0; + } + if (z_enable) { + z_unary = bitbuf_get( bb, "ZUNARY", z_unary_len ); + z_nsymbols=0; + int cnt = z_carry; + for(i=0; i<z_unary_len; i++) { + if (z_unary & (1<<i)) { + cnt++; + } else { + z_q[z_nsymbols++] = cnt; + cnt=0; + } + } + z_carry = cnt; + z_pos += z_nsymbols; + } + if (w_enable) { + w_unary1_len=0; + int max_symbols = w_uncompressed && w_grc_div>5 ? 8 : 12; + for(i=0; i<max_symbols; i++) { + if (w_unary0&(1<<i)) + w_unary1_len++; + } + w_unary1 = bitbuf_get( bb, "WUNARY1", w_unary1_len ); + w_nsymbols=0; + int cnt = w_carry; + for(i=0; i<max_symbols; i++) { + int code=0; + if (w_unary0 & (1<<i)) { + code++; + if (w_unary1&1) { + code++; + } + w_unary1 = w_unary1>>1; + } + cnt += code; + if (code<2 || w_grc_trunc) { + w_q[w_nsymbols++] = cnt; + cnt=0; + } + } + w_carry = cnt; + w_pos += w_nsymbols; + } + if (w_prev_enable) { + for(i=0; i<w_prev_nsymbols && w_prev_pos<nvalues; i++, w_prev_pos++) { + int remain = bitbuf_get( bb, "WREMAIN", w_grc_div ); + w_value[w_prev_pos] = (w_prev_q[i]<<w_grc_div) + remain; + } + } + if (z_prev_enable) { + for(i=0; i<z_prev_nsymbols && z_prev_pos<z_nvalues; i++, z_prev_pos++) { + int remain = bitbuf_get( bb, "ZREMAIN", z_grc_div ); + z_value[z_prev_pos] = (z_prev_q[i]<<z_grc_div) + remain; + total_zcnt += z_value[z_prev_pos]; + } + } + w_prev_enable = w_enable; + w_prev_nsymbols = w_nsymbols; + memcpy( w_prev_q, w_q, sizeof(w_prev_q)); + z_prev_enable = z_enable; + z_prev_nsymbols = z_nsymbols; + memcpy( z_prev_q, z_q, sizeof(z_prev_q)); + nchunks++; + } while( w_prev_enable || z_prev_enable ); + + // Interleave non-zero and zeros into the outbut buffer + // Increase the outbuffer to fit the new slice + *outbuf = realloc( *outbuf, (outbuf_size + nvalues + total_zcnt)*sizeof(int16_t)); + + int k=outbuf_size; + + // Insert initial zeros + // (slices redefining the palette may start with zeros) + if (new_palette && use_zero_run) { + for(j=0; j<z_value[0]; j++) { + (*outbuf)[k++] = 0; + } + } + + // Loop over all weights and insert zeros in-between + for(i=0; i<nvalues; i++) { + int val; + assert(w_value[i]<512); // HW supports 9bit + if (w_value[i]<palsize) { + val = palette[w_value[i]]; + } else { + val = w_value[i]-palsize+direct_offset; + } + int sign = val&1; + int mag = val>>1; + (*outbuf)[k++] = sign ? -mag : mag; + if (use_zero_run) { + for(j=0; j<z_value[i+(new_palette?1:0)]; j++) { + (*outbuf)[k++] = 0; + } + } + } + + outbuf_size = k; + free(w_value); + free(z_value); + } + return outbuf_size; +} diff --git a/ethosu/mlw_codec/mlw_decode.h b/ethosu/mlw_codec/mlw_decode.h new file mode 100644 index 00000000..a15261ad --- /dev/null +++ b/ethosu/mlw_codec/mlw_decode.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2020 Arm Limited. All rights reserved. + * + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <stdint.h> + +#ifndef __MLW_DECODE_H__ +#define __MLW_DECODE_H__ + +#ifdef _MSC_VER + #define EXPORTED __declspec(dllexport) +#else + #define EXPORTED __attribute__((visibility("default"))) +#endif + +#if __cplusplus +extern "C" +{ +#endif + +EXPORTED +int mlw_decode(uint8_t *inbuf, int inbuf_size, int16_t **outbuf, int verbose); + +#if __cplusplus +} +#endif + +#endif diff --git a/ethosu/mlw_codec/mlw_encode.c b/ethosu/mlw_codec/mlw_encode.c new file mode 100644 index 00000000..ac25fc52 --- /dev/null +++ b/ethosu/mlw_codec/mlw_encode.c @@ -0,0 +1,874 @@ +/* + * Copyright (c) 2020 Arm Limited. All rights reserved. + * + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <stdio.h> +#include <stdlib.h> +#include <stdint.h> +#include <stdbool.h> +#include <string.h> +#include <assert.h> +#include <math.h> +#include <stdarg.h> +#include <math.h> +#include "mlw_common.h" +#include "mlw_encode.h" + +#define DPRINTF(...) +//#define DPRINTF(...) printf(__VA_ARGS__) + +#define ZERO_RUN_THRES 4 + +#define min(a,b) ((a)<(b)?(a):(b)) +#define max(a,b) ((a)>(b)?(a):(b)) + +typedef struct palette { + int16_t lut[32]; + int16_t inv_lut[512]; + int palsize; // number of palette entries + int palbits; // bit width of palette entries + int use_zero_runs; // zeros are coded separately + int only_palette; // no values outside the palette + int direct_offset; // added to the decoded weight index before direct conversion to sign/mag + int only_zeros; // special case that the section is all zeros +} palette_t; + +static int is_power_of_two( int x ) { + return ((x-1) & x)==0; +} + +static int get_palette_index_bits( int size ) { + int i; + for(i=7; i>=0; i--) + if (size > (1<<i) ) + return i+1; + return 0; +} + +// Search the stream for suitable palette restart positions +// Return the number of restarts +static int search_palette_sections( int16_t *buf, int size, int **palette_restart_positions ) { + int i,j,got_palette,restart_i,palette_size=0, last_restart_idx, zero_cnt; + int prev_idx[512]; // For each value, keep track of the index of the previous occurence + int *restart_pos; + int max_palettes = size/64; + + // Preliminary allocation of sufficient size + restart_pos = (int*)malloc( max_palettes*sizeof(int) ); + last_restart_idx=0; + got_palette=0; + restart_i=1; + restart_pos[0] = 0; + zero_cnt=0; + memset( prev_idx, -1, sizeof(prev_idx)); + for(i=0; i<size; i++) { + // Guess if zeros should be excluded from the palette + int exclude_zero = zero_cnt > (i-last_restart_idx)/4; + + if (got_palette) { + // Check if the next value is not covered by the current palette + if ( prev_idx[ buf[i]+256 ] < last_restart_idx ) { + // New value: increase the palette size + palette_size++; + DPRINTF("Note: at pos %d extend palette to size %d\n", i, palette_size); + if ( is_power_of_two(palette_size-1-exclude_zero) ) { + if ( (i - last_restart_idx - zero_cnt) > 512 || (palette_size-exclude_zero)>32 ) { + // create a new palette because we extend a long lasting palette to require one more index bit + DPRINTF("Note: at pos %d create new palette because previous has to increase one more index bit. last_restart_idx %d n %d zero_cnt %d\n", i, last_restart_idx, i - last_restart_idx, zero_cnt ); + assert( restart_i < max_palettes ); + DPRINTF("restart %d pos %d\n", restart_i, i); + restart_pos[restart_i++] = i; + last_restart_idx = i; + got_palette=0; + zero_cnt=0; + } + } + } + } + + prev_idx[ buf[i]+256 ] = i; + if (buf[i]==0) + zero_cnt++; + + static const int window_sizes[5][2] = {{32,1}, {64,1}, {128,1}, {256,1}, {512,1}}; + int k; + // loop over window sizes + for(k=0; k<5; k++) { + // Every Nth non-zero value, count what would be the size of a palette covering the last N NZ. + int N = window_sizes[k][0] * (got_palette?2:1); + if ( (i - last_restart_idx - zero_cnt) > 0 && ((i - last_restart_idx - zero_cnt) % N)==0 ) { + // Search backward to the position N nonzero values earlier + int nzcnt=0; + for( j=i; j>last_restart_idx; j--) { + if ( buf[j]!=0 ) { + if (nzcnt==N+1) + break; + nzcnt++; + } + } + int restart_idx = j; + + // Calculate the size of a new palette (starting at restart_idx) + int new_palette_size=0; + for(j=0; j<512; j++) { + if ( prev_idx[j] >= restart_idx ) { + new_palette_size++; + } + } + + int create_new_palette=0; + if (got_palette) { + int new_size_bits = get_palette_index_bits( new_palette_size - exclude_zero ); + int old_size_bits = get_palette_index_bits( palette_size - exclude_zero ); + int savings = N*(old_size_bits*15-new_size_bits*15)/16 - new_palette_size*8 - 20; + if ( savings>0 ) { + // Create new palette because it can be smaller than the existing palette + create_new_palette=1; + DPRINTF("Note: at pos %d restart smaller palette\n", restart_idx); + } + } else { + if ( (new_palette_size-exclude_zero) <= 32) { + int new_size_bits = get_palette_index_bits( new_palette_size - exclude_zero ); + // estimate if we will make savings by using palette mode + int savings = N*(90-new_size_bits*15)/16 - new_palette_size*8 - 20; + create_new_palette = savings>0; + } + } + if (create_new_palette) { + palette_size=new_palette_size; + got_palette=1; + last_restart_idx = restart_idx; + DPRINTF("Note: at pos %d create palette of size %d\n", last_restart_idx, new_palette_size); + if ( restart_pos[restart_i-1] != last_restart_idx) { + assert( restart_i < max_palettes ); + restart_pos[restart_i++] = last_restart_idx; + } + zero_cnt=0; + for( j=last_restart_idx; j<=i; j++) + if (buf[j]==0) + zero_cnt++; + } + } + } + } + // Reallocate to actual size + *palette_restart_positions = (int*)realloc( restart_pos, restart_i*sizeof(int) ); + return restart_i; +} + +// Calculate frequency table +static void calc_freq( const int16_t *buf, int size, int freq[512] ) { + int i; + memset(freq, 0, 512*sizeof(int)); + for(i=0; i<size; i++) { + freq[buf[i]+256]++; + } +} + +static int cmp_uint64(const void * a, const void * b) { + uint64_t aa = *(uint64_t*)a; + uint64_t bb = *(uint64_t*)b; + return aa>bb ? -1 : aa<bb ? 1 : 0; +} + +// Create palette from the given frequencies +// Freq index 0-511 correspond to weights -256..255 +static void create_palette( int freq[512], + int use_zero_runs, + palette_t *p ) { + uint64_t freq64[512]; + int i,all_cnt,all_max_val; + + // Pair the frequency with the value so that + // the array can be sorted on frequency while keeping + // track of the corresponding palette value + memset(freq64, 0, sizeof(freq64)); + all_cnt=0; + all_max_val=0; + for(i=-255; i<256; i++) { + if (i==0 && use_zero_runs) + continue; + int sign = i<0; + int mag = abs(i); + int palval = (mag<<1) | sign; + + // Store palette value in 16 LSB bits, which will not affect the sorting + freq64[palval] = (((uint64_t)freq[i+256])<<16) | palval; + all_cnt+=freq[i+256]; + + if (freq[i+256]>0) { + all_max_val = max(all_max_val, palval); + } + } + + // Count number of non-used weight values around zero (0, -1, +1, -2, +2 etc) + for(i=0; i<31; i++) { + if ((freq64[i]>>16)!=0) + break; + } + p->direct_offset = i; + + // Sort in descending frequency order + qsort(freq64, 512, sizeof(uint64_t), cmp_uint64); + + // Identify special case that there are no weights to code + // in the weight index stream (i.e. all weights are zeros) + p->only_zeros = (freq64[0]>>16)==0; + if (p->only_zeros) { + p->direct_offset=0; + } + + // Check if all weights fit into the palette (and the palette is not empty) + p->only_palette = (freq64[0]>>16)>0 && (freq64[32]>>16)==0; + + int max_palette_size; + if (p->only_palette) { + max_palette_size = 32; + } else { + // For direct-lut we must make sure that the encoded weight + // index is not > 511. We do that by limiting the palette size + // such that the greatest value can be reached after subtracting + // the palette size. + max_palette_size = min(32, 511-all_max_val); + if (max_palette_size==1) { + max_palette_size=0; // because palette of size 1 is not supported + } + } + + // Setup the 32 entry palette + int palette_max_val = 0, val, cnt, pal_cnt=0; + for(i=0; i<max_palette_size; i++) { + cnt = freq64[i]>>16; + val = freq64[i]&0xffff; + if ( cnt==0 ) + break; + p->lut[i] = val; + palette_max_val = max(palette_max_val, val); + pal_cnt+=cnt; + } + if (i==1) + i++; // palette size of 1 is not supported, make it 2 + + // Heuristic for when to use the palette. If more than half of the + // weights are in the palette then we use it. This ensures we don't + // use palette for e.g. rectangular distributions. + int palbits_val; + if (pal_cnt > all_cnt/2) { + p->palsize = i; + palbits_val = palette_max_val; + } else { + // No palette + p->palsize = 0; + // If no palette, then palbits is used to specify the + // number of bits required for uncompressed mode, i.e. + // the number of bits for the greatest weight value + palbits_val = all_max_val; + } + + // the palette entry bit width + // minimum 2bits (because PALBITS is in range 2..9) + int palbits=2; + while( (1<<palbits) <= palbits_val ) + palbits++; + assert(palbits<=9); + p->palbits = palbits; + p->use_zero_runs = use_zero_runs; +} + +// Return 1 if zero runs should be used +// If palette_size is 512, then palette is not used (in that case the palette is setup +// with the standard alternating unsigned to signed mapping) +static int find_palette( const int16_t *inbuf, int inbuf_size, palette_t *p) { + int freq[512], i; + + // Calculate frequencies of the given weight stream + calc_freq( inbuf, inbuf_size, freq); + + // Find two most common values + int most_common_freq[2]={0}, most_common_val[2]={0}; + for(i=0; i<512; i++) { + if ( freq[i] > most_common_freq[0] ) { + most_common_freq[1] = most_common_freq[0]; + most_common_val[1] = most_common_val[0]; + most_common_freq[0] = freq[i]; + most_common_val[0] = i-256; + } else if ( freq[i] > most_common_freq[1] ) { + most_common_freq[1] = freq[i]; + most_common_val[1] = i-256; + } + } + + // Decide if zero-runs (alternating mode) should be used: + // * zero should be the most common symbol + // * zero should be sufficiently more common than the second most common symbol + int use_zero_runs = most_common_val[0]==0 && most_common_freq[0] > ZERO_RUN_THRES*most_common_freq[1]; + + // Create the palette + create_palette( freq, use_zero_runs, p); + + return use_zero_runs; +} + +static void create_inverse_palette( palette_t *p) { + int i; + memset( p->inv_lut, 0, sizeof(p->inv_lut)); + for(i=0; i<512; i++) { + int val = i; + int sign = val&1; + int mag = val>>1; + int weight = sign ? -mag : mag; + if (weight+256 < 512) + p->inv_lut[ weight+256 ] = i + p->palsize - p->direct_offset; + } + for(i=0; i<p->palsize; i++) { + int val = p->lut[i]; + int sign = val&1; + int mag = val>>1; + int weight = sign ? -mag : mag; + if (weight+256 < 512) + p->inv_lut[ weight+256 ] = i; + } +} + +#define NWCFG 13 +#define NZCFG 4 // restrict search to ZDIV=0..3 +#define MAX_ZWCFG (max(NWCFG,NZCFG)) + +// search state +typedef struct search_state { + int bitcnt; // number of bits to reach this state + uint8_t prev_cfg; // previous grc parameter config +} search_state_t; + +// (trunc<<4) | div, 0x20 means uncompressed +static const char w_grc_params[] = { 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x20 }; +static const char z_grc_params[] = { 0x00, 0x01, 0x02, 0x03, 0x04 }; + + + +// An algorithm similar to the Viterbi algorithm is used to search for a +// good GRC parameter sequence for the given input value sequence. +// The inval buffer can contain weights, weight indices or runs. +// The return value is the resulting number of bitstream sections. +static int search_grc_params( const int *inval_buf, + int n_inval, + int zrun_mode, + int uncompressed_bits, + uint8_t *grc_param_cfg, + int *grc_param_pos, + int max_grc_param_cfg, + int *existing_grc_param_pos, + int n_existing_grc_param_pos, + int *bitcnt ) +{ + int n_cfg = zrun_mode ? NZCFG : NWCFG; + const char *grc_params = zrun_mode ? z_grc_params : w_grc_params; + int i,j; + + search_state_t *state[MAX_ZWCFG]; + for(i=0; i<n_cfg; i++) { + state[i] = malloc( sizeof(search_state_t) * (n_inval+1) ); + state[i][0].bitcnt=0; + state[i][0].prev_cfg=i; + } + + // Loop over inval_buf + int existing_idx=0; + for(i=0; i<n_inval; i++) { + int value = inval_buf[i]; + + // Best GRC parameter so far + int best_bitcnt=0x7fffffff, best_cfg=0; + for(j=0; j<n_cfg; j++) { + if (state[j][i].bitcnt < best_bitcnt) { + best_bitcnt = state[j][i].bitcnt; + best_cfg = j; + } + } + + int cmd_cost = 40; + if (existing_idx < n_existing_grc_param_pos && existing_grc_param_pos[existing_idx] == (i+1)) { + // free transition, because the weight stream already inserted a command at this position + cmd_cost = 0; + existing_idx++; + } + + // Loop over GRC parameters, calculate bits to code value, and then update the search state + for(j=0; j<n_cfg; j++) { + int div = grc_params[j]&15; + int trunc = grc_params[j]>>4; + int q = value>>div; + int bits = trunc ? min(q+1,2) + div : q+1+div; + if (!zrun_mode && ((trunc && q>2) || q>31)) + bits=10000; // it's not possible to code the current value; give it a high cost + if (trunc==2) + bits=uncompressed_bits; + + if ( best_bitcnt + cmd_cost < state[j][i].bitcnt ) { + // Change GRC parameters + state[j][i+1].prev_cfg = best_cfg; + state[j][i+1].bitcnt = best_bitcnt + cmd_cost + bits; + } else { + // Keep same GRC parameters + state[j][i+1].prev_cfg = j; + state[j][i+1].bitcnt = state[j][i].bitcnt + bits; + } + } + } + + + // Best GRC parameter + int best_bitcnt=0x7fffffff, best_cfg=0; + for(j=0; j<n_cfg; j++) { + if (state[j][n_inval].bitcnt < best_bitcnt) { + best_bitcnt = state[j][n_inval].bitcnt; + best_cfg = j; + } + } + + int cfg = best_cfg; + int n_cmds=0; + for(i=n_inval; i>=0; i--) { + if (state[cfg][i].prev_cfg != cfg || i==0) { + n_cmds++; + cfg = state[cfg][i].prev_cfg; + } + } + + (void)(max_grc_param_cfg); + assert(n_cmds<=max_grc_param_cfg); + + cfg = best_cfg; + j=n_cmds-1; + int endpos=n_inval; + for(i=n_inval; i>=0; i--) { + if (state[cfg][i].prev_cfg != cfg || i==0) { + grc_param_cfg[j] = cfg; + grc_param_pos[j] = endpos; + j--; + cfg = state[cfg][i].prev_cfg; + endpos = i-1; + } + } + assert(j==-1); + + for(i=0; i<n_cfg; i++) { + free(state[i]); + } + + *bitcnt = best_bitcnt; + + return n_cmds; +} + + +/////////////////////////////// Write to bitstream + +typedef struct bitbuf { + uint8_t *buf; + int buf_size; // in bytes + int pos; // bit pos of next bit + int log_symbols; +} bitbuf_t; + +// size in byte +static void bitbuf_init( bitbuf_t *bb, uint8_t *buf, int size, int log_symbols ) { + bb->buf = buf; + bb->pos = 0; + bb->buf_size = size; + bb->log_symbols = log_symbols; +} + +static void bitbuf_putbit( bitbuf_t *bb, int bit) { + int byte_pos = bb->pos>>3; + int bit_pos = bb->pos&7; + assert( byte_pos >= 0 ); + assert( byte_pos < bb->buf_size ); + bb->buf[ byte_pos ] = (bb->buf[ byte_pos ] & ~(1<<bit_pos)) | (bit<<bit_pos); + bb->pos += 1; +} + +static void bitbuf_put( bitbuf_t *bb, const char *name, int len, int data) { + int i; + if (len>0) { + if (bb->log_symbols) + printf("bitbuf: pos %3d %7s len %d data %x\n", bb->pos, name, len, data); + for(i=0; i<len; i++) { + bitbuf_putbit(bb, (data>>i)&1); + } + } +} + +// Return new bitpos +static int encode_slice( const int *w_value, + const int *z_value, + int nvalues, + palette_t *p, + int new_palette, + int uncompressed_bits, + int w_cfg, + int z_cfg, + uint8_t *bitbuf, + int bitbuf_size, + int bitpos, + int verbose ) +{ + int i,j; + bitbuf_t bitbuf_s, *bb=&bitbuf_s; + bitbuf_init( bb, bitbuf, bitbuf_size, verbose&2?1:0 ); + bb->pos = bitpos; + + assert(nvalues<32768); + // GRC parameters for this slice + int w_grc_div = w_grc_params[w_cfg] & 15; + int w_grc_trunc = (w_grc_params[w_cfg] >> 4)==1; + int w_uncompressed = (w_grc_params[w_cfg] >> 4)==2; + int z_grc_div = z_grc_params[z_cfg] & 15; + + if (w_uncompressed) { + w_grc_div = uncompressed_bits; + } + + int zdiv = p->use_zero_runs ? z_grc_div : ZDIV_DISABLE; + int wdiv = !w_uncompressed ? w_grc_div : WDIV_UNCOMPRESSED; + + if (verbose&1) { + printf("slice: bitoffset %7d slicelen %5d zdiv %d wdiv %d wtrunc %d newpal %d palbits %d palsize %2d\n", + bb->pos, nvalues, zdiv, wdiv, w_grc_trunc, new_palette, p->palbits, p->palsize); + } + + // Write slice header + bitbuf_put( bb, "ZDIV", 3, zdiv); + bitbuf_put( bb, "SLICELEN", 15, nvalues-1 ); + bitbuf_put( bb, "WDIV", 3, wdiv); + bitbuf_put( bb, "WTRUNC", 1, w_grc_trunc ); + bitbuf_put( bb, "NEWPAL", 1, new_palette ); + if (new_palette) { + bitbuf_put( bb, "DIROFS", 5, p->direct_offset ); + bitbuf_put( bb, "PALSIZE", 5, max(0, p->palsize-1)); + bitbuf_put( bb, "PALBITS", 3, p->palbits-2 ); + for(i=0; i<p->palsize; i++) { + bitbuf_put( bb, "PALETTE", p->palbits, p->lut[i] ); + } + } + + int z_nvalues = nvalues + (new_palette?1:0); + int w_pos=0, z_pos=0; + int w_unary0=0, w_unary1=0, w_unary1_len=0, w_q=-1, w_r=0; + int z_unary=0, z_q=-1, z_r=0; + int w_nsymbols=0, w_remain[12]={0}; + int w_prev_enable=0, w_prev_nsymbols=0, w_prev_remain[12]={0}; + int z_nsymbols=0, z_remain[12]={0}; + int z_prev_enable=0, z_prev_nsymbols=0, z_prev_remain[12]={0}; + int z_unary_len = z_grc_div<3 ? 12 : 8; + do { + int balance = p->use_zero_runs ? w_pos - z_pos : 0; + int w_enable = balance<8 && w_pos<nvalues; + int z_enable = balance>=0 && p->use_zero_runs && z_pos<z_nvalues; + if (w_enable) { + // Encode chunk (weights) + j=0; + w_nsymbols=0; + w_unary0=0; + w_unary1=0; + w_unary1_len=0; + int max_symbols = w_uncompressed && w_grc_div>5 ? 8 : 12; + while(j<max_symbols) { + if (w_q<0) { + if (w_pos<nvalues) { + int value = w_value[w_pos]; + assert(value<512); + w_q = value>>w_grc_div; + w_r = value&((1<<w_grc_div)-1); + assert( w_q<=31 && (!w_grc_trunc || w_q<=2)); + } else { + w_q = 0; + w_r = -1; // don't send remainder + } + } + while( w_q>=0 && j<max_symbols) { + w_unary0 |= w_q>0 ? (1<<j) : 0; + if (w_q>0) { + w_unary1 |= w_q>1 ? (1<<w_unary1_len) : 0; + w_unary1_len++; + } + j++; + w_q-=2; + if (w_grc_trunc) + w_q--; + } + if (w_q<0 && w_r>=0) { + w_remain[w_nsymbols] = w_r; + w_nsymbols++; + w_pos++; + } + } + } + + if (z_enable) { + // Encode chunk (zrun) + j=0; + z_nsymbols=0; + z_unary=0; + while(j<z_unary_len) { + if (z_q<0) { + if (z_pos<z_nvalues) { + int value = z_value[z_pos]; + z_q = value>>z_grc_div; + z_r = value&((1<<z_grc_div)-1); + } else { + z_q = 0; + z_r = -1; + } + } + while( z_q>=0 && j<z_unary_len) { + z_unary |= z_q>0 ? (1<<j) : 0; + j++; + z_q--; + } + if (z_q<0 && z_r>=0) { + z_remain[z_nsymbols] = z_r; + z_nsymbols++; + z_pos++; + } + } + } + + // Write chunk to bitstream + if (w_enable && !w_uncompressed) { + bitbuf_put( bb, "WUNARY0", 12, w_unary0); + } + if (z_enable) { + bitbuf_put( bb, "ZUNARY", z_unary_len, z_unary); + } + if (w_enable && !w_uncompressed) { + bitbuf_put( bb, "WUNARY1", w_unary1_len, w_unary1); + } + if (w_prev_enable) { + for(i=0; i<w_prev_nsymbols; i++) { + bitbuf_put( bb, "WREMAIN", w_grc_div, w_prev_remain[i]); + } + } + if (z_prev_enable) { + for(i=0; i<z_prev_nsymbols; i++) { + bitbuf_put( bb, "ZREMAIN", z_grc_div, z_prev_remain[i]); + } + } + w_prev_enable = w_enable; + w_prev_nsymbols = w_nsymbols; + memcpy( w_prev_remain, w_remain, sizeof(w_prev_remain)); + z_prev_enable = z_enable; + z_prev_nsymbols = z_nsymbols; + memcpy( z_prev_remain, z_remain, sizeof(z_prev_remain)); + } while( w_prev_enable || z_prev_enable ); + + return bb->pos; +} + + +// return new bitpos +static int encode_section( const int16_t *inbuf, + int size, + palette_t *p, + uint8_t *bitbuf, + int bitbuf_size, + int bitpos, + int verbose ) +{ + int uncompressed_bits; + + // Uncompressed mode can only be used if either all weights + // are in the palette OR if the palette is not used. + if (p->only_palette) { + // Uncompressed bits derived from palette size + uncompressed_bits=0; + while( (1<<uncompressed_bits) < p->palsize ) + uncompressed_bits++; + } else if (p->palsize==0) { + // Uncompressed bits is palbits (which is the bitdepth of the greatest weight) + uncompressed_bits = p->palbits; + } else { + // Don't use uncompressed + uncompressed_bits = 100; + } + + int *weight_values = malloc( size*sizeof(int) ); + int *zrun_values = malloc( size*sizeof(int) ); + + // Get weights (or weight indicies) AND zero-runs from the input weight stream. + int i=0, n_weights = 0, zcnt; + while(1) { + if (p->use_zero_runs) { + zcnt=0; + // Count zero run + // Special case: if all weights in the section are zero, we must + // still ensure we have one coded weight so the the slice length + // doesn't become 0. Therefore we skip the first zero run and code + // the zero explicitly as a weight value instead + if (!p->only_zeros || i>0) { + while( i<size && inbuf[i]==0) { + zcnt++; + i++; + } + } + zrun_values[n_weights] = zcnt; + } + if (i==size) + break; + int value = p->inv_lut[inbuf[i]+256]; + weight_values[n_weights] = value; + n_weights++; + i++; + } + + // Search for good GRC parameters for the weight stream + int n_w_slice, w_bitcnt; + uint8_t *w_slice_cfg; + int *w_slice_pos; + w_slice_cfg = malloc( size ); + w_slice_pos = malloc( size*sizeof(int) ); + n_w_slice = search_grc_params( weight_values, n_weights, 0, uncompressed_bits, w_slice_cfg, w_slice_pos, size, 0, 0, &w_bitcnt); + if (n_weights==0) + n_w_slice = 0; + + // Search for good GRC parameters for the zrun stream + int n_z_slice=0, z_bitcnt=0; + uint8_t *z_slice_cfg=0; + int *z_slice_pos=0; + if (p->use_zero_runs) { + z_slice_cfg = malloc( size ); + z_slice_pos = malloc( size*sizeof(int) ); + n_z_slice = search_grc_params( zrun_values, n_weights+1, 1, 0, z_slice_cfg, z_slice_pos, size, w_slice_pos, n_w_slice, &z_bitcnt); + } + + // Encode bitstream slice + int pos=0, i_w_slice=0, i_z_slice=0, new_palette=1; + while(pos<n_weights || new_palette) { + int endpos=pos+32767; // max slice length + + if (i_w_slice<n_w_slice && w_slice_pos[i_w_slice]<endpos) { + endpos = w_slice_pos[i_w_slice]; + } + + if (i_z_slice<n_z_slice && z_slice_pos[i_z_slice]<endpos) { + endpos = z_slice_pos[i_z_slice]; + } + + if (n_weights < endpos) { + endpos = n_weights; + } + + // The first slice (when new_palette is 1) encodes zero runs both at the + // beginning and end (i.e. number of zero runs are len+1). + // The following slices only encode zero runs at the end (there cannot be + // any zeros in the beginning since they are encoded by the previous slice) + int len = endpos - pos; + int *zrun_buf = p->use_zero_runs ? zrun_values+pos+(!new_palette) : 0; + bitpos = encode_slice( weight_values+pos, zrun_buf, len, + p, new_palette, uncompressed_bits, + w_slice_cfg[i_w_slice], p->use_zero_runs ? z_slice_cfg[i_z_slice] : 0, + bitbuf, bitbuf_size, bitpos, verbose ); + new_palette = 0; + + if (i_w_slice<n_w_slice && w_slice_pos[i_w_slice]==endpos) { + i_w_slice++; + } + if (i_z_slice<n_z_slice && z_slice_pos[i_z_slice]==endpos) { + i_z_slice++; + } + pos = endpos; + } + + // Free temporary buffers + free(w_slice_cfg); + free(w_slice_pos); + if (p->use_zero_runs) { + free(z_slice_cfg); + free(z_slice_pos); + } + free(weight_values); + free(zrun_values); + + return bitpos; +} + +// Encode the given weight stream +// inbuf uncompressed 9bit signed weights +// inbuf_size number of weights +// outbuf compressed bitstream, buffer is malloced +// verbose if non-zero, printf log +// Return value is the size in bytes of the compressed output +// Return -1 if error +int mlw_encode( int16_t *inbuf, int inbuf_size, uint8_t **outbuf, int verbose) { + int i; + // Range check + for(i=0; i<inbuf_size; i++) { + if (inbuf[i]<-255 || inbuf[i]>255) { + printf("ERROR: weight out of range at index %d, weight value is %d (valid range is -255..255)\n", i, inbuf[i]); + return -1; + } + } + + int bitbuf_size = inbuf_size*2+1024; + *outbuf = malloc( bitbuf_size ); + + // Analyse input data to find palette re-programming points + int n_restarts; + int *palette_restart_pos; + n_restarts = search_palette_sections( inbuf, inbuf_size, &palette_restart_pos); + + // Compress each section (using a single palette) separately + int bitpos=0; + for(i=0; i<n_restarts; i++) { + palette_t palette; + int pos, size; + pos = palette_restart_pos[i]; + size = (i<n_restarts-1 ? palette_restart_pos[i+1] : inbuf_size) - pos; + find_palette( inbuf+pos, size, &palette); + create_inverse_palette( &palette); + bitpos = encode_section( inbuf+pos, size, &palette, + *outbuf, bitbuf_size, bitpos, verbose ); + } + + + // Add end of stream marker and align to 128bit + { + bitbuf_t bitbuf_s, *bb=&bitbuf_s; + bitbuf_init( bb, *outbuf, bitbuf_size, verbose&2?1:0 ); + bb->pos = bitpos; + bitbuf_put( bb, "ZDIV", 3, ZDIV_EOS); + bitbuf_put( bb, "BYTEALIGN", (8-(bb->pos&7))&7, 0xff ); + + // Pad with 0xff until 64bit aligned + while( bb->pos & 127 ) { + bitbuf_put( bb, "PAD", 8, 0xff ); + } + bitpos = bb->pos; + } + assert((bitpos&127)==0); + int outbuf_size = bitpos/8; + *outbuf = realloc( *outbuf, outbuf_size); + + free(palette_restart_pos); + + return outbuf_size; +} + +void mlw_free_outbuf( uint8_t *outbuf ) { + if (outbuf) + free(outbuf); +} diff --git a/ethosu/mlw_codec/mlw_encode.h b/ethosu/mlw_codec/mlw_encode.h new file mode 100644 index 00000000..a995ac6e --- /dev/null +++ b/ethosu/mlw_codec/mlw_encode.h @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2020 Arm Limited. All rights reserved. + * + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <stdint.h> + +#ifndef __MLW_ENCODE_H__ +#define __MLW_ENCODE_H__ + +#ifdef _MSC_VER + #define EXPORTED __declspec(dllexport) +#else + #define EXPORTED __attribute__((visibility("default"))) +#endif + +#if __cplusplus +extern "C" +{ +#endif + +EXPORTED +int mlw_encode(int16_t *inbuf, int inbuf_size, uint8_t **outbuf, int verbose); + +EXPORTED +void mlw_free_outbuf(uint8_t *outbuf); + +#if __cplusplus +} +#endif + +#endif diff --git a/ethosu/mlw_codec/mlw_main.c b/ethosu/mlw_codec/mlw_main.c new file mode 100644 index 00000000..9f720495 --- /dev/null +++ b/ethosu/mlw_codec/mlw_main.c @@ -0,0 +1,177 @@ +/* + * Copyright (c) 2020 Arm Limited. All rights reserved. + * + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <stdio.h> +#include <stdlib.h> +#include <stdint.h> +#include <stdbool.h> +#include <string.h> +#include <assert.h> +#include <math.h> +#include <getopt.h> +#include <stdarg.h> +#include "mlw_encode.h" +#include "mlw_decode.h" + +static void fatal_error(const char *format, ...) { + va_list ap; + va_start (ap, format); + vfprintf(stderr, format, ap); + va_end(ap); + exit(1); +} + +static void print_usage(void) { + printf("Usage:\n"); + printf(" Encode: ./mlw_codec [<options>] [-o <outfile.mlw>] infiles.bin\n"); + printf(" Decode: ./mlw_codec [<options>] -d [-o <outfile.bin>] infiles.mlw\n"); + printf("\n"); + printf("Options:\n"); + printf(" -w The uncompressed weight file is an int16_t (word) stream.\n"); + printf(" This is to support 9bit signed weights. Little endian is assuemd.\n"); + printf(" The default format is int8_t (byte) stream (if -w is not specified)\n"); + printf("\n"); +} + +// Read file into allocated buffer. Return length in bytes. +static int read_file( FILE *f, uint8_t **buf) { + + fseek(f, 0, SEEK_END); + int size = ftell(f); + fseek(f, 0, SEEK_SET); + *buf = malloc(size); + assert(*buf); + int rsize = fread(*buf, 1, size, f); + assert(rsize==size); + fclose(f); + return size; +} + + +#define MAX_INFILES 1000 + +int main(int argc, char *argv[]) +{ + int c, decode=0, inbuf_size, outbuf_size; + char *infile_name[MAX_INFILES], *outfile_name=0; + uint8_t *inbuf=0, *outbuf=0; + FILE *infile, *outfile=0; + int verbose=0, infile_idx=0; + int int16_format=0; + + if (argc==1) { + print_usage(); + exit(1); + } + + // Parse command line options + while( optind < argc) { + // Parse options + while ((c = getopt (argc, argv, "di:o:v:w?")) != -1) { + switch (c) { + case 'd': + decode=1; + break; + case 'i': + assert(infile_idx<MAX_INFILES); + infile_name[infile_idx++]=optarg; + break; + case 'o': + outfile_name=optarg; + break; + case 'v': + verbose=atoi(optarg); + break; + case 'w': + int16_format=1; + break; + case '?': + print_usage(); + exit(0); + } + } + + if (optind<argc) { + assert(infile_idx<MAX_INFILES); + infile_name[infile_idx++]=argv[optind]; + optind++; + + } + } + + if (outfile_name) { + outfile=fopen(outfile_name, "wb"); + if (!outfile) + fatal_error("ERROR: cannot open outfile %s\n", outfile_name); + } + + // Loop over input files + int nbr_of_infiles=infile_idx; + for(infile_idx=0; infile_idx<nbr_of_infiles; infile_idx++) { + infile=fopen(infile_name[infile_idx], "rb"); + if (!infile) + fatal_error("ERROR: cannot open infile %s\n", infile_name[infile_idx]); + + // Read infile to buffer + inbuf_size = read_file(infile, &inbuf); + + if (!decode) { + // Encode + int i, n = int16_format ? inbuf_size/sizeof(int16_t) : inbuf_size; + int16_t *weights = malloc( n * sizeof(int16_t) ); + for(i=0; i<n; i++) { + weights[i] = int16_format ? ((int16_t*)inbuf)[i] : ((int8_t*)inbuf)[i]; + } + outbuf_size = mlw_encode( weights, n, &outbuf, verbose); + free(weights); + printf("Input size %d output size %d bpw %4.2f\n", n, outbuf_size, outbuf_size*8.0/n); + } else { + // Decode + int i, n; + int16_t *weights; + n = mlw_decode( inbuf, inbuf_size, &weights, verbose); + outbuf_size = int16_format ? n*sizeof(int16_t) : n; + outbuf = malloc( outbuf_size ); + assert(outbuf); + for(i=0; i<n; i++) { + if (int16_format) + ((int16_t*)outbuf)[i] = weights[i]; + else + outbuf[i] = weights[i]; + } + free(weights); + printf("Input size %d output size %d bpw %4.2f\n", inbuf_size, n, inbuf_size*8.0/n); + + } + + if (outfile) { + fwrite(outbuf, 1, outbuf_size, outfile); + } + + if (inbuf) + free(inbuf); + if (outbuf) + free(outbuf); + } + + if (outfile) { + fclose(outfile); + } + + return 0; +} diff --git a/ethosu/mlw_codec/test_mlw_codec.py b/ethosu/mlw_codec/test_mlw_codec.py new file mode 100644 index 00000000..b8687210 --- /dev/null +++ b/ethosu/mlw_codec/test_mlw_codec.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 + +# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Simple example of the usage of mlw_codec. + +import sys + +from ethosu import mlw_codec + + +# Simple example +if __name__ == "__main__": + weights = [0, 2, 3, 0, -1, -2, -3, 0, 0, 0, 1, -250, 240] * 3 + print("Original weights :", weights) + + compressed_weights = mlw_codec.encode(weights) + print("Compressed weights :", len(compressed_weights), compressed_weights) + + uncompressed_weights = mlw_codec.decode(compressed_weights) + print("Uncompressed weights:", uncompressed_weights) + + if weights != uncompressed_weights: + print("TEST FAILED") + sys.exit(1) + else: + print("TEST PASSED") + sys.exit(0) |