aboutsummaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
authorSiCong Li <sicong.li@arm.com>2017-08-23 11:02:43 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:35:24 +0000
commit86b53339679e12c952a24a8845a5409ac3d52de6 (patch)
tree807c897ca1001f22b1906d285488877a287b482b /scripts
parent70e9bc21682f4eaedaceb632f594f588cb2c91fc (diff)
downloadComputeLibrary-86b53339679e12c952a24a8845a5409ac3d52de6.tar.gz
COMPMID-514 (3RDPARTY_UPDATE)(DATA_UPDATE) Add support to load .npy data
* Add tensorflow_data_extractor script. * Incorporate 3rdparty npy reader libnpy. * Port AlexNet system test to validation_new. * Port LeNet5 system test to validation_new. * Update 3rdparty/ and data/ submodules. Change-Id: I156d060fe9185cd8db810b34bf524cbf5cb34f61 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/84914 Reviewed-by: Anthony Barbier <anthony.barbier@arm.com> Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Diffstat (limited to 'scripts')
-rwxr-xr-xscripts/caffe_data_extractor.py27
-rw-r--r--scripts/tensorflow_data_extractor.py51
2 files changed, 70 insertions, 8 deletions
diff --git a/scripts/caffe_data_extractor.py b/scripts/caffe_data_extractor.py
index 09ea0b86b0..65c9938480 100755
--- a/scripts/caffe_data_extractor.py
+++ b/scripts/caffe_data_extractor.py
@@ -1,16 +1,23 @@
#!/usr/bin/env python
-import argparse
+"""Extracts trainable parameters from Caffe models and stores them in numpy arrays.
+Usage
+ python caffe_data_extractor -m path_to_caffe_model_file -n path_to_caffe_netlist
+
+Saves each variable to a {variable_name}.npy binary file.
+Tested with Caffe 1.0 on Python 2.7
+"""
+import argparse
import caffe
+import os
import numpy as np
-import scipy.io
if __name__ == "__main__":
# Parse arguments
- parser = argparse.ArgumentParser('Extract CNN hyper-parameters')
- parser.add_argument('-m', dest='modelFile', type=str, required=True, help='Caffe model file')
- parser.add_argument('-n', dest='netFile', type=str, required=True, help='Caffe netlist')
+ parser = argparse.ArgumentParser('Extract Caffe net parameters')
+ parser.add_argument('-m', dest='modelFile', type=str, required=True, help='Path to Caffe model file')
+ parser.add_argument('-n', dest='netFile', type=str, required=True, help='Path to Caffe netlist')
args = parser.parse_args()
# Create Caffe Net
@@ -18,7 +25,7 @@ if __name__ == "__main__":
# Read and dump blobs
for name, blobs in net.params.iteritems():
- print 'Name: {0}, Blobs: {1}'.format(name, len(blobs))
+ print('Name: {0}, Blobs: {1}'.format(name, len(blobs)))
for i in range(len(blobs)):
# Weights
if i == 0:
@@ -29,6 +36,10 @@ if __name__ == "__main__":
else:
pass
- print("%s : %s" % (outname, blobs[i].data.shape))
+ varname = outname
+ if os.path.sep in varname:
+ varname = varname.replace(os.path.sep, '_')
+ print("Renaming variable {0} to {1}".format(outname, varname))
+ print("Saving variable {0} with shape {1} ...".format(varname, blobs[i].data.shape))
# Dump as binary
- blobs[i].data.tofile(outname + ".dat")
+ np.save(varname, blobs[i].data)
diff --git a/scripts/tensorflow_data_extractor.py b/scripts/tensorflow_data_extractor.py
new file mode 100644
index 0000000000..1dbf0e127e
--- /dev/null
+++ b/scripts/tensorflow_data_extractor.py
@@ -0,0 +1,51 @@
+#!/usr/bin/env python
+"""Extracts trainable parameters from Tensorflow models and stores them in numpy arrays.
+Usage
+ python tensorflow_data_extractor -m path_to_binary_checkpoint_file -n path_to_metagraph_file
+
+Saves each variable to a {variable_name}.npy binary file.
+
+Note that since Tensorflow version 0.11 the binary checkpoint file which contains the values for each parameter has the format of:
+ {model_name}.data-{step}-of-{max_step}
+instead of:
+ {model_name}.ckpt
+When dealing with binary files with version >= 0.11, only pass {model_name} to -m option;
+when dealing with binary files with version < 0.11, pass the whole file name {model_name}.ckpt to -m option.
+
+Also note that this script relies on the parameters to be extracted being in the
+'trainable_variables' tensor collection. By default all variables are automatically added to this collection unless
+specified otherwise by the user. Thus should a user alter this default behavior and/or want to extract parameters from other
+collections, tf.GraphKeys.TRAINABLE_VARIABLES should be replaced accordingly.
+
+Tested with Tensorflow 1.2, 1.3 on Python 2.7.6 and Python 3.4.3.
+"""
+import argparse
+import numpy as np
+import os
+import tensorflow as tf
+
+
+if __name__ == "__main__":
+ # Parse arguments
+ parser = argparse.ArgumentParser('Extract Tensorflow net parameters')
+ parser.add_argument('-m', dest='modelFile', type=str, required=True, help='Path to Tensorflow checkpoint binary\
+ file. For Tensorflow version >= 0.11, only include model name; for Tensorflow version < 0.11, include\
+ model name with ".ckpt" extension')
+ parser.add_argument('-n', dest='netFile', type=str, required=True, help='Path to Tensorflow MetaGraph file')
+ args = parser.parse_args()
+
+ # Load Tensorflow Net
+ saver = tf.train.import_meta_graph(args.netFile)
+ with tf.Session() as sess:
+ # Restore session
+ saver.restore(sess, args.modelFile)
+ print('Model restored.')
+ # Save trainable variables to numpy arrays
+ for t in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
+ varname = t.name
+ if os.path.sep in t.name:
+ varname = varname.replace(os.path.sep, '_')
+ print("Renaming variable {0} to {1}".format(t.name, varname))
+ print("Saving variable {0} with shape {1} ...".format(varname, t.shape))
+ # Dump as binary
+ np.save(varname, sess.run(t))