aboutsummaryrefslogtreecommitdiff
path: root/scripts/caffe_data_extractor.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/caffe_data_extractor.py')
-rwxr-xr-xscripts/caffe_data_extractor.py34
1 files changed, 34 insertions, 0 deletions
diff --git a/scripts/caffe_data_extractor.py b/scripts/caffe_data_extractor.py
new file mode 100755
index 0000000000..09ea0b86b0
--- /dev/null
+++ b/scripts/caffe_data_extractor.py
@@ -0,0 +1,34 @@
+#!/usr/bin/env python
+import argparse
+
+import caffe
+import numpy as np
+import scipy.io
+
+
+if __name__ == "__main__":
+ # Parse arguments
+ parser = argparse.ArgumentParser('Extract CNN hyper-parameters')
+ parser.add_argument('-m', dest='modelFile', type=str, required=True, help='Caffe model file')
+ parser.add_argument('-n', dest='netFile', type=str, required=True, help='Caffe netlist')
+ args = parser.parse_args()
+
+ # Create Caffe Net
+ net = caffe.Net(args.netFile, 1, weights=args.modelFile)
+
+ # Read and dump blobs
+ for name, blobs in net.params.iteritems():
+ print 'Name: {0}, Blobs: {1}'.format(name, len(blobs))
+ for i in range(len(blobs)):
+ # Weights
+ if i == 0:
+ outname = name + "_w"
+ # Bias
+ elif i == 1:
+ outname = name + "_b"
+ else:
+ pass
+
+ print("%s : %s" % (outname, blobs[i].data.shape))
+ # Dump as binary
+ blobs[i].data.tofile(outname + ".dat")