aboutsummaryrefslogtreecommitdiff
path: root/scripts/caffe_data_extractor.py
blob: 47d24b265f71cb29c8e19ab01db246b46f7c5023 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#!/usr/bin/env python
"""Extracts trainable parameters from Caffe models and stores them in numpy arrays.
Usage
    python caffe_data_extractor.py -m path_to_caffe_model_file -n path_to_caffe_netlist

Saves each variable to a {variable_name}.npy binary file.

Tested with Caffe 1.0 on Python 2.7
"""
import argparse
import caffe
import os
import numpy as np


if __name__ == "__main__":
    # Parse arguments: both the trained model file and the netlist are mandatory.
    parser = argparse.ArgumentParser('Extract Caffe net parameters')
    parser.add_argument('-m', dest='modelFile', type=str, required=True, help='Path to Caffe model file')
    parser.add_argument('-n', dest='netFile', type=str, required=True, help='Path to Caffe netlist')
    args = parser.parse_args()

    # Create the Caffe net in test phase and load the trained weights into it.
    # caffe.TEST is the named constant for the phase value previously passed
    # as the literal 1.
    net = caffe.Net(args.netFile, caffe.TEST, weights=args.modelFile)

    # Only the first two blobs of every layer are dumped: blob 0 is saved as
    # "<layer>_w" (weights) and blob 1 as "<layer>_b" (bias). Any further
    # blobs a layer may carry are skipped.
    suffixes = ("_w", "_b")

    # Read and dump blobs.
    # items() is used instead of iteritems() so the script runs on both
    # Python 2 and Python 3 (iteritems() does not exist on Python 3).
    for name, blobs in net.params.items():
        print('Name: {0}, Blobs: {1}'.format(name, len(blobs)))
        # Slicing the suffix tuple keeps the index within the blobs available.
        for index, suffix in enumerate(suffixes[:len(blobs)]):
            outname = name + suffix

            # np.save() treats its first argument as a file path, so a path
            # separator inside a layer name would point into a sub-directory.
            # Flatten it to an underscore and tell the user about the rename.
            varname = outname.replace(os.path.sep, '_')
            if varname != outname:
                print("Renaming variable {0} to {1}".format(outname, varname))

            data = blobs[index].data
            print("Saving variable {0} with shape {1} ...".format(varname, data.shape))
            # Dump as binary (np.save appends the ".npy" extension).
            np.save(varname, data)