#!/usr/bin/env python
"""Extracts trainable parameters from Tensorflow models and stores them in numpy arrays.

Usage
    python tensorflow_data_extractor -m path_to_binary_checkpoint_file -n path_to_metagraph_file

Saves each variable to a {variable_name}.npy binary file.

Note that since Tensorflow version 0.11 the binary checkpoint file which contains the values for each parameter
has the format of:
    {model_name}.data-{step}-of-{max_step}
instead of:
    {model_name}.ckpt
When dealing with binary files with version >= 0.11, only pass {model_name} to -m option;
when dealing with binary files with version < 0.11, pass the whole file name {model_name}.ckpt to -m option.

Also note that this script relies on the parameters to be extracted being in the
'trainable_variables' tensor collection. By default all variables are automatically added to this collection unless
specified otherwise by the user. Thus should a user alter this default behavior and/or want to extract parameters from
other collections, tf.GraphKeys.TRAINABLE_VARIABLES should be replaced accordingly.

Tested with Tensorflow 1.2, 1.3 on Python 2.7.6 and Python 3.4.3.
"""
import argparse
import os

import numpy as np
import tensorflow as tf

# Tensorflow joins name scopes with a forward slash on every platform, so this
# is the separator that has to be flattened regardless of the host OS.
_TF_SCOPE_SEP = '/'


def _sanitize_variable_name(name):
    """Return `name` with every path separator replaced by an underscore.

    The result is used as the base name of the output .npy file, so it must
    not contain anything the operating system would interpret as a directory
    boundary.

    Args:
        name: Tensorflow variable name, e.g. 'conv1/weights:0'.

    Returns:
        The name with the Tensorflow scope separator and the platform's
        path separator(s) each replaced by '_'.
    """
    separators = {_TF_SCOPE_SEP, os.path.sep}
    # os.path.altsep is None on POSIX and '/' on Windows.
    if os.path.altsep:
        separators.add(os.path.altsep)
    for sep in separators:
        name = name.replace(sep, '_')
    return name


def _parse_args(argv=None):
    """Parse the command line.

    Args:
        argv: Argument list to parse; None makes argparse read sys.argv[1:].

    Returns:
        argparse.Namespace with `modelFile` (checkpoint path) and `netFile`
        (MetaGraph path).
    """
    parser = argparse.ArgumentParser('Extract Tensorflow net parameters')
    parser.add_argument(
        '-m', dest='modelFile', type=str, required=True,
        help='Path to Tensorflow checkpoint binary file. '
             'For Tensorflow version >= 0.11, only include model name; '
             'for Tensorflow version < 0.11, include model name with ".ckpt" extension')
    parser.add_argument(
        '-n', dest='netFile', type=str, required=True,
        help='Path to Tensorflow MetaGraph file')
    return parser.parse_args(argv)


def main(argv=None):
    """Restore a checkpoint and dump every trainable variable to a .npy file.

    Files are written to the current working directory, one per variable,
    named after the sanitized variable name.

    Args:
        argv: Argument list to parse; None makes argparse read sys.argv[1:].

    Raises:
        SystemExit: If the MetaGraph file contains no variables to restore.
    """
    args = _parse_args(argv)

    # Load Tensorflow Net
    saver = tf.train.import_meta_graph(args.netFile)
    if saver is None:
        # import_meta_graph returns None when the MetaGraph holds no variables;
        # fail with a readable message instead of an AttributeError on restore.
        raise SystemExit('No variables found in MetaGraph file {0}'.format(args.netFile))

    with tf.Session() as sess:
        # Restore session
        saver.restore(sess, args.modelFile)
        print('Model restored.')

        # Save trainable variables to numpy arrays
        for t in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
            varname = _sanitize_variable_name(t.name)
            if varname != t.name:
                print("Renaming variable {0} to {1}".format(t.name, varname))

            print("Saving variable {0} with shape {1} ...".format(varname, t.shape))
            # Dump as binary; np.save appends the '.npy' extension itself.
            np.save(varname, sess.run(t))


if __name__ == "__main__":
    main()