aboutsummaryrefslogtreecommitdiff
path: root/scripts/caffe_data_extractor.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/caffe_data_extractor.py')
-rwxr-xr-xscripts/caffe_data_extractor.py27
1 files changed, 19 insertions, 8 deletions
diff --git a/scripts/caffe_data_extractor.py b/scripts/caffe_data_extractor.py
index 09ea0b86b0..65c9938480 100755
--- a/scripts/caffe_data_extractor.py
+++ b/scripts/caffe_data_extractor.py
@@ -1,16 +1,23 @@
#!/usr/bin/env python
-import argparse
+"""Extracts trainable parameters from Caffe models and stores them in numpy arrays.
+Usage
+ python caffe_data_extractor -m path_to_caffe_model_file -n path_to_caffe_netlist
+
+Saves each variable to a {variable_name}.npy binary file.
+Tested with Caffe 1.0 on Python 2.7
+"""
+import argparse
import caffe
+import os
import numpy as np
-import scipy.io
if __name__ == "__main__":
# Parse arguments
- parser = argparse.ArgumentParser('Extract CNN hyper-parameters')
- parser.add_argument('-m', dest='modelFile', type=str, required=True, help='Caffe model file')
- parser.add_argument('-n', dest='netFile', type=str, required=True, help='Caffe netlist')
+ parser = argparse.ArgumentParser('Extract Caffe net parameters')
+ parser.add_argument('-m', dest='modelFile', type=str, required=True, help='Path to Caffe model file')
+ parser.add_argument('-n', dest='netFile', type=str, required=True, help='Path to Caffe netlist')
args = parser.parse_args()
# Create Caffe Net
@@ -18,7 +25,7 @@ if __name__ == "__main__":
# Read and dump blobs
for name, blobs in net.params.iteritems():
- print 'Name: {0}, Blobs: {1}'.format(name, len(blobs))
+ print('Name: {0}, Blobs: {1}'.format(name, len(blobs)))
for i in range(len(blobs)):
# Weights
if i == 0:
@@ -29,6 +36,10 @@ if __name__ == "__main__":
else:
pass
- print("%s : %s" % (outname, blobs[i].data.shape))
+ varname = outname
+ if os.path.sep in varname:
+ varname = varname.replace(os.path.sep, '_')
+ print("Renaming variable {0} to {1}".format(outname, varname))
+ print("Saving variable {0} with shape {1} ...".format(varname, blobs[i].data.shape))
# Dump as binary
- blobs[i].data.tofile(outname + ".dat")
+ np.save(varname, blobs[i].data)