diff options
author | SiCong Li <sicong.li@arm.com> | 2017-08-23 11:02:43 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:35:24 +0000 |
commit | 86b53339679e12c952a24a8845a5409ac3d52de6 (patch) | |
tree | 807c897ca1001f22b1906d285488877a287b482b /scripts/tensorflow_data_extractor.py | |
parent | 70e9bc21682f4eaedaceb632f594f588cb2c91fc (diff) | |
download | ComputeLibrary-86b53339679e12c952a24a8845a5409ac3d52de6.tar.gz |
COMPMID-514 (3RDPARTY_UPDATE)(DATA_UPDATE) Add support to load .npy data
* Add tensorflow_data_extractor script.
* Incorporate 3rdparty npy reader libnpy.
* Port AlexNet system test to validation_new.
* Port LeNet5 system test to validation_new.
* Update 3rdparty/ and data/ submodules.
Change-Id: I156d060fe9185cd8db810b34bf524cbf5cb34f61
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/84914
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Diffstat (limited to 'scripts/tensorflow_data_extractor.py')
-rw-r--r-- | scripts/tensorflow_data_extractor.py | 51 |
1 files changed, 51 insertions, 0 deletions
diff --git a/scripts/tensorflow_data_extractor.py b/scripts/tensorflow_data_extractor.py new file mode 100644 index 0000000000..1dbf0e127e --- /dev/null +++ b/scripts/tensorflow_data_extractor.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python +"""Extracts trainable parameters from Tensorflow models and stores them in numpy arrays. +Usage + python tensorflow_data_extractor -m path_to_binary_checkpoint_file -n path_to_metagraph_file + +Saves each variable to a {variable_name}.npy binary file. + +Note that since Tensorflow version 0.11 the binary checkpoint file which contains the values for each parameter has the format of: + {model_name}.data-{step}-of-{max_step} +instead of: + {model_name}.ckpt +When dealing with binary files with version >= 0.11, only pass {model_name} to -m option; +when dealing with binary files with version < 0.11, pass the whole file name {model_name}.ckpt to -m option. + +Also note that this script relies on the parameters to be extracted being in the +'trainable_variables' tensor collection. By default all variables are automatically added to this collection unless +specified otherwise by the user. Thus should a user alter this default behavior and/or want to extract parameters from other +collections, tf.GraphKeys.TRAINABLE_VARIABLES should be replaced accordingly. + +Tested with Tensorflow 1.2, 1.3 on Python 2.7.6 and Python 3.4.3. +""" +import argparse +import numpy as np +import os +import tensorflow as tf + + +if __name__ == "__main__": + # Parse arguments + parser = argparse.ArgumentParser('Extract Tensorflow net parameters') + parser.add_argument('-m', dest='modelFile', type=str, required=True, help='Path to Tensorflow checkpoint binary\ + file. For Tensorflow version >= 0.11, only include model name; for Tensorflow version < 0.11, include\ + model name with ".ckpt" extension') + parser.add_argument('-n', dest='netFile', type=str, required=True, help='Path to Tensorflow MetaGraph file') + args = parser.parse_args() + + # Load Tensorflow Net + saver = tf.train.import_meta_graph(args.netFile) + with tf.Session() as sess: + # Restore session + saver.restore(sess, args.modelFile) + print('Model restored.') + # Save trainable variables to numpy arrays + for t in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES): + varname = t.name + if os.path.sep in t.name: + varname = varname.replace(os.path.sep, '_') + print("Renaming variable {0} to {1}".format(t.name, varname)) + print("Saving variable {0} with shape {1} ...".format(varname, t.shape)) + # Dump as binary + np.save(varname, sess.run(t)) |