diff options
Diffstat (limited to 'ethosu/mlw_codec/mlw_main.c')
-rw-r--r-- | ethosu/mlw_codec/mlw_main.c | 177 |
1 files changed, 177 insertions, 0 deletions
diff --git a/ethosu/mlw_codec/mlw_main.c b/ethosu/mlw_codec/mlw_main.c new file mode 100644 index 00000000..9f720495 --- /dev/null +++ b/ethosu/mlw_codec/mlw_main.c @@ -0,0 +1,177 @@ +/* + * Copyright (c) 2020 Arm Limited. All rights reserved. + * + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the License); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <stdio.h> +#include <stdlib.h> +#include <stdint.h> +#include <stdbool.h> +#include <string.h> +#include <assert.h> +#include <math.h> +#include <getopt.h> +#include <stdarg.h> +#include "mlw_encode.h" +#include "mlw_decode.h" + +static void fatal_error(const char *format, ...) { + va_list ap; + va_start (ap, format); + vfprintf(stderr, format, ap); + va_end(ap); + exit(1); +} + +static void print_usage(void) { + printf("Usage:\n"); + printf(" Encode: ./mlw_codec [<options>] [-o <outfile.mlw>] infiles.bin\n"); + printf(" Decode: ./mlw_codec [<options>] -d [-o <outfile.bin>] infiles.mlw\n"); + printf("\n"); + printf("Options:\n"); + printf(" -w The uncompressed weight file is an int16_t (word) stream.\n"); + printf(" This is to support 9bit signed weights. Little endian is assuemd.\n"); + printf(" The default format is int8_t (byte) stream (if -w is not specified)\n"); + printf("\n"); +} + +// Read file into allocated buffer. Return length in bytes. +static int read_file( FILE *f, uint8_t **buf) { + + fseek(f, 0, SEEK_END); + int size = ftell(f); + fseek(f, 0, SEEK_SET); + *buf = malloc(size); + assert(*buf); + int rsize = fread(*buf, 1, size, f); + assert(rsize==size); + fclose(f); + return size; +} + + +#define MAX_INFILES 1000 + +int main(int argc, char *argv[]) +{ + int c, decode=0, inbuf_size, outbuf_size; + char *infile_name[MAX_INFILES], *outfile_name=0; + uint8_t *inbuf=0, *outbuf=0; + FILE *infile, *outfile=0; + int verbose=0, infile_idx=0; + int int16_format=0; + + if (argc==1) { + print_usage(); + exit(1); + } + + // Parse command line options + while( optind < argc) { + // Parse options + while ((c = getopt (argc, argv, "di:o:v:w?")) != -1) { + switch (c) { + case 'd': + decode=1; + break; + case 'i': + assert(infile_idx<MAX_INFILES); + infile_name[infile_idx++]=optarg; + break; + case 'o': + outfile_name=optarg; + break; + case 'v': + verbose=atoi(optarg); + break; + case 'w': + int16_format=1; + break; + case '?': + print_usage(); + exit(0); + } + } + + if (optind<argc) { + assert(infile_idx<MAX_INFILES); + infile_name[infile_idx++]=argv[optind]; + optind++; + + } + } + + if (outfile_name) { + outfile=fopen(outfile_name, "wb"); + if (!outfile) + fatal_error("ERROR: cannot open outfile %s\n", outfile_name); + } + + // Loop over input files + int nbr_of_infiles=infile_idx; + for(infile_idx=0; infile_idx<nbr_of_infiles; infile_idx++) { + infile=fopen(infile_name[infile_idx], "rb"); + if (!infile) + fatal_error("ERROR: cannot open infile %s\n", infile_name[infile_idx]); + + // Read infile to buffer + inbuf_size = read_file(infile, &inbuf); + + if (!decode) { + // Encode + int i, n = int16_format ? inbuf_size/sizeof(int16_t) : inbuf_size; + int16_t *weights = malloc( n * sizeof(int16_t) ); + for(i=0; i<n; i++) { + weights[i] = int16_format ? ((int16_t*)inbuf)[i] : ((int8_t*)inbuf)[i]; + } + outbuf_size = mlw_encode( weights, n, &outbuf, verbose); + free(weights); + printf("Input size %d output size %d bpw %4.2f\n", n, outbuf_size, outbuf_size*8.0/n); + } else { + // Decode + int i, n; + int16_t *weights; + n = mlw_decode( inbuf, inbuf_size, &weights, verbose); + outbuf_size = int16_format ? n*sizeof(int16_t) : n; + outbuf = malloc( outbuf_size ); + assert(outbuf); + for(i=0; i<n; i++) { + if (int16_format) + ((int16_t*)outbuf)[i] = weights[i]; + else + outbuf[i] = weights[i]; + } + free(weights); + printf("Input size %d output size %d bpw %4.2f\n", inbuf_size, n, inbuf_size*8.0/n); + + } + + if (outfile) { + fwrite(outbuf, 1, outbuf_size, outfile); + } + + if (inbuf) + free(inbuf); + if (outbuf) + free(outbuf); + } + + if (outfile) { + fclose(outfile); + } + + return 0; +} |