From b34b9d4db3c6c75d1a167a4fc25d40214b351f80 Mon Sep 17 00:00:00 2001 From: Alex Gilday Date: Thu, 8 Mar 2018 11:28:29 +0000 Subject: COMPMID-977: Create script to generate npy image inputs to examples Change-Id: Ia951b038133054e6f54c41b006cb711b17536ec1 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/123708 Tested-by: Jenkins Reviewed-by: Anthony Barbier Reviewed-by: Giorgio Arena Reviewed-by: Georgios Pinitas --- scripts/caffe_mnist_image_extractor.py | 59 ++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 scripts/caffe_mnist_image_extractor.py (limited to 'scripts') diff --git a/scripts/caffe_mnist_image_extractor.py b/scripts/caffe_mnist_image_extractor.py new file mode 100644 index 0000000000..2c9478a611 --- /dev/null +++ b/scripts/caffe_mnist_image_extractor.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python +"""Extracts mnist image data from the Caffe data files and stores them in numpy arrays +Usage + python caffe_mnist_image_extractor.py -d path_to_caffe_data_directory -o desired_output_path + +Saves the first 10 images extracted as input10.npy, the first 100 images as input100.npy, and the +corresponding labels to labels100.txt. + +Tested with Caffe 1.0 on Python 2.7 +""" +import argparse +import os +import struct +import numpy as np +from array import array + + +if __name__ == "__main__": + # Parse arguments + parser = argparse.ArgumentParser('Extract Caffe mnist image data') + parser.add_argument('-d', dest='dataDir', type=str, required=True, help='Path to Caffe data directory') + parser.add_argument('-o', dest='outDir', type=str, default='.', help='Output directory (default = current directory)') + args = parser.parse_args() + + images_filename = os.path.join(args.dataDir, 'mnist/t10k-images-idx3-ubyte') + labels_filename = os.path.join(args.dataDir, 'mnist/t10k-labels-idx1-ubyte') + + images_file = open(images_filename, 'rb') + labels_file = open(labels_filename, 'rb') + images_magic, images_size, rows, cols = struct.unpack('>IIII', images_file.read(16)) + labels_magic, labels_size = struct.unpack('>II', labels_file.read(8)) + images = array('B', images_file.read()) + labels = array('b', labels_file.read()) + + input10_path = os.path.join(args.outDir, 'input10.npy') + input100_path = os.path.join(args.outDir, 'input100.npy') + labels100_path = os.path.join(args.outDir, 'labels100.npy') + + outputs_10 = np.zeros(( 10, 28, 28, 1), dtype=np.float32) + outputs_100 = np.zeros((100, 28, 28, 1), dtype=np.float32) + labels_output = open(labels100_path, 'w') + for i in xrange(100): + image = np.array(images[i * rows * cols : (i + 1) * rows * cols]).reshape((rows, cols)) / 256.0 + outputs_100[i, :, :, 0] = image + + if i < 10: + outputs_10[i, :, :, 0] = image + + if i == 10: + np.save(input10_path, np.transpose(outputs_10, (0, 3, 1, 2))) + print "Wrote", input10_path + + labels_output.write(str(labels[i]) + '\n') + + labels_output.close() + print "Wrote", labels100_path + + np.save(input100_path, np.transpose(outputs_100, (0, 3, 1, 2))) + print "Wrote", input100_path -- cgit v1.2.1