From 86b53339679e12c952a24a8845a5409ac3d52de6 Mon Sep 17 00:00:00 2001 From: SiCong Li Date: Wed, 23 Aug 2017 11:02:43 +0100 Subject: COMPMID-514 (3RDPARTY_UPDATE)(DATA_UPDATE) Add support to load .npy data * Add tensorflow_data_extractor script. * Incorporate 3rdparty npy reader libnpy. * Port AlexNet system test to validation_new. * Port LeNet5 system test to validation_new. * Update 3rdparty/ and data/ submodules. Change-Id: I156d060fe9185cd8db810b34bf524cbf5cb34f61 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/84914 Reviewed-by: Anthony Barbier Tested-by: Kaizen --- scripts/caffe_data_extractor.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) (limited to 'scripts/caffe_data_extractor.py') diff --git a/scripts/caffe_data_extractor.py b/scripts/caffe_data_extractor.py index 09ea0b86b0..65c9938480 100755 --- a/scripts/caffe_data_extractor.py +++ b/scripts/caffe_data_extractor.py @@ -1,16 +1,23 @@ #!/usr/bin/env python -import argparse +"""Extracts trainable parameters from Caffe models and stores them in numpy arrays. +Usage + python caffe_data_extractor -m path_to_caffe_model_file -n path_to_caffe_netlist + +Saves each variable to a {variable_name}.npy binary file. +Tested with Caffe 1.0 on Python 2.7 +""" +import argparse import caffe +import os import numpy as np -import scipy.io if __name__ == "__main__": # Parse arguments - parser = argparse.ArgumentParser('Extract CNN hyper-parameters') - parser.add_argument('-m', dest='modelFile', type=str, required=True, help='Caffe model file') - parser.add_argument('-n', dest='netFile', type=str, required=True, help='Caffe netlist') + parser = argparse.ArgumentParser('Extract Caffe net parameters') + parser.add_argument('-m', dest='modelFile', type=str, required=True, help='Path to Caffe model file') + parser.add_argument('-n', dest='netFile', type=str, required=True, help='Path to Caffe netlist') args = parser.parse_args() # Create Caffe Net @@ -18,7 +25,7 @@ if __name__ == "__main__": # Read and dump blobs for name, blobs in net.params.iteritems(): - print 'Name: {0}, Blobs: {1}'.format(name, len(blobs)) + print('Name: {0}, Blobs: {1}'.format(name, len(blobs))) for i in range(len(blobs)): # Weights if i == 0: @@ -29,6 +36,10 @@ if __name__ == "__main__": else: pass - print("%s : %s" % (outname, blobs[i].data.shape)) + varname = outname + if os.path.sep in varname: + varname = varname.replace(os.path.sep, '_') + print("Renaming variable {0} to {1}".format(outname, varname)) + print("Saving variable {0} with shape {1} ...".format(varname, blobs[i].data.shape)) # Dump as binary - blobs[i].data.tofile(outname + ".dat") + np.save(varname, blobs[i].data) -- cgit v1.2.1