From bee466b5eac4ec39d4032d946c9a4aee051f2b31 Mon Sep 17 00:00:00 2001 From: steniu01 Date: Wed, 21 Jun 2017 16:45:41 +0100 Subject: COMPID-345 Add caffe_data_extractor.py script and the instructions Change-Id: Ibb84b2060c4d6362be9ce4b1757e273e013de618 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/78630 Tested-by: Kaizen Reviewed-by: Georgios Pinitas Reviewed-by: Anthony Barbier --- scripts/caffe_data_extractor.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100755 scripts/caffe_data_extractor.py (limited to 'scripts') diff --git a/scripts/caffe_data_extractor.py b/scripts/caffe_data_extractor.py new file mode 100755 index 0000000000..09ea0b86b0 --- /dev/null +++ b/scripts/caffe_data_extractor.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python +import argparse + +import caffe +import numpy as np +import scipy.io + + +if __name__ == "__main__": + # Parse arguments + parser = argparse.ArgumentParser('Extract CNN hyper-parameters') + parser.add_argument('-m', dest='modelFile', type=str, required=True, help='Caffe model file') + parser.add_argument('-n', dest='netFile', type=str, required=True, help='Caffe netlist') + args = parser.parse_args() + + # Create Caffe Net + net = caffe.Net(args.netFile, 1, weights=args.modelFile) + + # Read and dump blobs + for name, blobs in net.params.iteritems(): + print 'Name: {0}, Blobs: {1}'.format(name, len(blobs)) + for i in range(len(blobs)): + # Weights + if i == 0: + outname = name + "_w" + # Bias + elif i == 1: + outname = name + "_b" + else: + pass + + print("%s : %s" % (outname, blobs[i].data.shape)) + # Dump as binary + blobs[i].data.tofile(outname + ".dat") -- cgit v1.2.1