aboutsummaryrefslogtreecommitdiff
path: root/scripts/tensorflow_data_extractor.py
blob: 1dbf0e127edd197465cc1d9afa3f9b04318db83c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#!/usr/bin/env python
"""Extracts trainable parameters from Tensorflow models and stores them in numpy arrays.
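Usage:
    python tensorflow_data_extractor.py -m path_to_binary_checkpoint_file -n path_to_metagraph_file

Saves each variable to a {variable_name}.npy binary file in the current working directory.
Path separators in a variable name are replaced with underscores to form the file name.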

Note that since Tensorflow version 0.11 the binary checkpoint file which contains the values for each parameter has the format of:
    {model_name}.data-{step}-of-{max_step}
instead of:
    {model_name}.ckpt
When dealing with binary files with version >= 0.11, only pass {model_name} to -m option;
when dealing with binary files with version < 0.11, pass the whole file name {model_name}.ckpt to -m option.

Also note that this script relies on the parameters to be extracted being in the
'trainable_variables' tensor collection. By default, all variables are added to this collection automatically unless
the user specifies otherwise. If that default has been changed, or if parameters from other collections need to be
extracted, replace tf.GraphKeys.TRAINABLE_VARIABLES in this script accordingly.

Tested with Tensorflow 1.2, 1.3 on Python 2.7.6 and Python 3.4.3.
"""
import argparse
import numpy as np
import os
import tensorflow as tf


def variable_to_filename(variable_name):
    """Return a flat file-name stem for a TensorFlow variable name.

    TensorFlow separates name scopes with '/' regardless of the host
    operating system, so '/' is replaced explicitly instead of relying on
    ``os.path.sep`` (which is '\\' on Windows and would leave the name
    untouched, making the later save point into a non-existent directory).

    Args:
        variable_name: Variable name as reported by its ``name`` attribute,
            e.g. ``'conv1/weights:0'``.

    Returns:
        The name with every '/' replaced by '_'. All other characters,
        including a trailing output index such as ':0', are kept unchanged
        so that output file names stay the same as before on POSIX systems.
    """
    return variable_name.replace('/', '_')


def main():
    """Restore a checkpoint and dump every trainable variable to a .npy file.

    Reads the checkpoint path (-m) and the MetaGraph path (-n) from the
    command line, imports the graph, restores the variable values into a
    session and saves each variable of the TRAINABLE_VARIABLES collection
    with ``np.save`` under a flattened version of its name.
    """
    # Parse arguments.
    # The text is passed as `description`; the first positional parameter of
    # ArgumentParser is `prog`, which would put it in the usage line as the
    # program name.
    parser = argparse.ArgumentParser(description='Extract Tensorflow net parameters')
    parser.add_argument('-m', dest='modelFile', type=str, required=True, help='Path to Tensorflow checkpoint binary\
            file. For Tensorflow version >= 0.11, only include model name; for Tensorflow version < 0.11, include\
            model name with ".ckpt" extension')
    parser.add_argument('-n', dest='netFile', type=str, required=True, help='Path to Tensorflow MetaGraph file')
    args = parser.parse_args()

    # Load Tensorflow Net
    saver = tf.train.import_meta_graph(args.netFile)
    with tf.Session() as sess:
        # Restore session
        saver.restore(sess, args.modelFile)
        print('Model restored.')
        # Save trainable variables to numpy arrays
        for t in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
            varname = variable_to_filename(t.name)
            if varname != t.name:
                print("Renaming variable {0} to {1}".format(t.name, varname))
            print("Saving variable {0} with shape {1} ...".format(varname, t.shape))
            # Dump as binary; np.save appends the ".npy" extension itself.
            np.save(varname, sess.run(t))


if __name__ == "__main__":
    main()