diff options
Diffstat (limited to 'scripts')
-rwxr-xr-x | scripts/caffe_data_extractor.py | 45 | ||||
-rwxr-xr-x | scripts/tensorflow_data_extractor.py | 51 |
2 files changed, 0 insertions, 96 deletions
diff --git a/scripts/caffe_data_extractor.py b/scripts/caffe_data_extractor.py deleted file mode 100755 index 47d24b265f..0000000000 --- a/scripts/caffe_data_extractor.py +++ /dev/null @@ -1,45 +0,0 @@ -#!/usr/bin/env python -"""Extracts trainable parameters from Caffe models and stores them in numpy arrays. -Usage - python caffe_data_extractor -m path_to_caffe_model_file -n path_to_caffe_netlist - -Saves each variable to a {variable_name}.npy binary file. - -Tested with Caffe 1.0 on Python 2.7 -""" -import argparse -import caffe -import os -import numpy as np - - -if __name__ == "__main__": - # Parse arguments - parser = argparse.ArgumentParser('Extract Caffe net parameters') - parser.add_argument('-m', dest='modelFile', type=str, required=True, help='Path to Caffe model file') - parser.add_argument('-n', dest='netFile', type=str, required=True, help='Path to Caffe netlist') - args = parser.parse_args() - - # Create Caffe Net - net = caffe.Net(args.netFile, 1, weights=args.modelFile) - - # Read and dump blobs - for name, blobs in net.params.iteritems(): - print('Name: {0}, Blobs: {1}'.format(name, len(blobs))) - for i in range(len(blobs)): - # Weights - if i == 0: - outname = name + "_w" - # Bias - elif i == 1: - outname = name + "_b" - else: - continue - - varname = outname - if os.path.sep in varname: - varname = varname.replace(os.path.sep, '_') - print("Renaming variable {0} to {1}".format(outname, varname)) - print("Saving variable {0} with shape {1} ...".format(varname, blobs[i].data.shape)) - # Dump as binary - np.save(varname, blobs[i].data) diff --git a/scripts/tensorflow_data_extractor.py b/scripts/tensorflow_data_extractor.py deleted file mode 100755 index 1dbf0e127e..0000000000 --- a/scripts/tensorflow_data_extractor.py +++ /dev/null @@ -1,51 +0,0 @@ -#!/usr/bin/env python -"""Extracts trainable parameters from Tensorflow models and stores them in numpy arrays. -Usage - python tensorflow_data_extractor -m path_to_binary_checkpoint_file -n path_to_metagraph_file - -Saves each variable to a {variable_name}.npy binary file. - -Note that since Tensorflow version 0.11 the binary checkpoint file which contains the values for each parameter has the format of: - {model_name}.data-{step}-of-{max_step} -instead of: - {model_name}.ckpt -When dealing with binary files with version >= 0.11, only pass {model_name} to -m option; -when dealing with binary files with version < 0.11, pass the whole file name {model_name}.ckpt to -m option. - -Also note that this script relies on the parameters to be extracted being in the -'trainable_variables' tensor collection. By default all variables are automatically added to this collection unless -specified otherwise by the user. Thus should a user alter this default behavior and/or want to extract parameters from other -collections, tf.GraphKeys.TRAINABLE_VARIABLES should be replaced accordingly. - -Tested with Tensorflow 1.2, 1.3 on Python 2.7.6 and Python 3.4.3. -""" -import argparse -import numpy as np -import os -import tensorflow as tf - - -if __name__ == "__main__": - # Parse arguments - parser = argparse.ArgumentParser('Extract Tensorflow net parameters') - parser.add_argument('-m', dest='modelFile', type=str, required=True, help='Path to Tensorflow checkpoint binary\ - file. For Tensorflow version >= 0.11, only include model name; for Tensorflow version < 0.11, include\ - model name with ".ckpt" extension') - parser.add_argument('-n', dest='netFile', type=str, required=True, help='Path to Tensorflow MetaGraph file') - args = parser.parse_args() - - # Load Tensorflow Net - saver = tf.train.import_meta_graph(args.netFile) - with tf.Session() as sess: - # Restore session - saver.restore(sess, args.modelFile) - print('Model restored.') - # Save trainable variables to numpy arrays - for t in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES): - varname = t.name - if os.path.sep in t.name: - varname = varname.replace(os.path.sep, '_') - print("Renaming variable {0} to {1}".format(t.name, varname)) - print("Saving variable {0} with shape {1} ...".format(varname, t.shape)) - # Dump as binary - np.save(varname, sess.run(t)) |