aboutsummaryrefslogtreecommitdiff
path: root/scripts/caffe_mnist_image_extractor.py
blob: 2c9478a61131233fb3a2175850a33be2ef631db4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#!/usr/bin/env python
"""Extracts mnist image data from the Caffe data files and stores them in numpy arrays
Usage
    python caffe_mnist_image_extractor.py -d path_to_caffe_data_directory -o desired_output_path

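Reads the MNIST test set (the t10k-* files under <data directory>/mnist) and saves the first
10 images as input10.npy and the first 100 images as input100.npy. Both are float32 arrays in
NCHW layout (N, 1, rows, cols) with pixel values divided by 256. The corresponding 100 labels
are written as plain text, one per line, to labels100.txt.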

Tested with Caffe 1.0 on Python 2.7
"""
import argparse
import os
import struct
import numpy as np
from array import array


# Magic numbers that open the uncompressed IDX files shipped with MNIST.
IDX_IMAGES_MAGIC = 2051  # 0x00000803: unsigned bytes, 3 dimensions
IDX_LABELS_MAGIC = 2049  # 0x00000801: unsigned bytes, 1 dimension

# Sample counts are baked into the output file names (input10/input100/labels100).
NUM_PREVIEW_IMAGES = 10
NUM_IMAGES = 100


def read_idx_images(path, count):
    """Load the first `count` images from an IDX3 (MNIST) image file.

    Args:
        path: Path to an uncompressed idx3-ubyte file.
        count: Number of leading images to load.

    Returns:
        A float32 numpy array of shape (count, 1, rows, cols) whose values are
        the raw pixel bytes divided by 256. rows and cols come from the file
        header.

    Raises:
        ValueError: If the header is truncated, the magic number is not the
            IDX3 image magic, or the file holds fewer than `count` images.
    """
    with open(path, 'rb') as images_file:
        header = images_file.read(16)
        if len(header) != 16:
            raise ValueError('%s: truncated IDX image header' % path)
        magic, total, rows, cols = struct.unpack('>IIII', header)
        if magic != IDX_IMAGES_MAGIC:
            # Typically means the file is still gzipped or is the labels file.
            raise ValueError('%s: bad image magic number %d' % (path, magic))
        if total < count:
            raise ValueError('%s: holds %d images, %d required' % (path, total, count))
        # Only the requested prefix is read, not the whole data set.
        raw = images_file.read(count * rows * cols)

    if len(raw) != count * rows * cols:
        raise ValueError('%s: truncated image data' % path)

    pixels = np.frombuffer(raw, dtype=np.uint8).reshape((count, 1, rows, cols))
    return (pixels / 256.0).astype(np.float32)


def read_idx_labels(path, count):
    """Load the first `count` labels from an IDX1 (MNIST) label file.

    Args:
        path: Path to an uncompressed idx1-ubyte file.
        count: Number of leading labels to load.

    Returns:
        A list of `count` ints, each decoded as an unsigned byte.

    Raises:
        ValueError: If the header is truncated, the magic number is not the
            IDX1 label magic, or the file holds fewer than `count` labels.
    """
    with open(path, 'rb') as labels_file:
        header = labels_file.read(8)
        if len(header) != 8:
            raise ValueError('%s: truncated IDX label header' % path)
        magic, total = struct.unpack('>II', header)
        if magic != IDX_LABELS_MAGIC:
            raise ValueError('%s: bad label magic number %d' % (path, magic))
        if total < count:
            raise ValueError('%s: holds %d labels, %d required' % (path, total, count))
        raw = labels_file.read(count)

    if len(raw) != count:
        raise ValueError('%s: truncated label data' % path)

    # bytearray yields ints on both Python 2 and Python 3.
    return list(bytearray(raw))


def extract_mnist_samples(data_dir, out_dir='.'):
    """Write the first MNIST test images and labels to `out_dir`.

    Reads mnist/t10k-images-idx3-ubyte and mnist/t10k-labels-idx1-ubyte below
    `data_dir` and writes three files into `out_dir`:
      * input10.npy   - first 10 images, float32, shape (10, 1, rows, cols)
      * labels100.txt - first 100 labels as text, one per line
      * input100.npy  - first 100 images, float32, shape (100, 1, rows, cols)

    All input is read and validated before any output file is created.

    Args:
        data_dir: Caffe data directory containing the `mnist` sub-directory.
        out_dir: Existing directory that receives the output files.

    Returns:
        The list of written paths, in the order they were written.

    Raises:
        ValueError: If either input file is malformed or too short.
        IOError/OSError: If a file cannot be opened.
    """
    images_filename = os.path.join(data_dir, 'mnist/t10k-images-idx3-ubyte')
    labels_filename = os.path.join(data_dir, 'mnist/t10k-labels-idx1-ubyte')

    images = read_idx_images(images_filename, NUM_IMAGES)
    labels = read_idx_labels(labels_filename, NUM_IMAGES)

    input10_path = os.path.join(out_dir, 'input10.npy')
    input100_path = os.path.join(out_dir, 'input100.npy')
    # The labels are plain text, so the file carries a .txt extension rather
    # than .npy (np.load cannot read it).
    labels100_path = os.path.join(out_dir, 'labels100.txt')

    np.save(input10_path, images[:NUM_PREVIEW_IMAGES])
    print('Wrote %s' % input10_path)

    with open(labels100_path, 'w') as labels_output:
        labels_output.writelines('%d\n' % label for label in labels)
    print('Wrote %s' % labels100_path)

    np.save(input100_path, images)
    print('Wrote %s' % input100_path)

    return [input10_path, labels100_path, input100_path]


def main(argv=None):
    """Parse the command line and run the extraction.

    Args:
        argv: Argument list without the program name; None means sys.argv[1:].
    """
    # The text must go in `description`: the first positional parameter of
    # ArgumentParser is `prog`, which would replace the program name in usage.
    parser = argparse.ArgumentParser(description='Extract Caffe mnist image data')
    parser.add_argument('-d', dest='dataDir', type=str, required=True, help='Path to Caffe data directory')
    parser.add_argument('-o', dest='outDir', type=str, default='.', help='Output directory (default = current directory)')
    args = parser.parse_args(argv)

    extract_mnist_samples(args.dataDir, args.outDir)


if __name__ == "__main__":
    main()