aboutsummaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
authorAlex Gilday <alexander.gilday@arm.com>2018-03-08 11:28:29 +0000
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:49:16 +0000
commitb34b9d4db3c6c75d1a167a4fc25d40214b351f80 (patch)
treea42febb59b8891207d175db8ddd9d0bf2758c5fa /scripts
parent287051663030ccd945accdcd90905fb48bf30948 (diff)
downloadComputeLibrary-b34b9d4db3c6c75d1a167a4fc25d40214b351f80.tar.gz
COMPMID-977: Create script to generate npy image inputs to examples
Change-Id: Ia951b038133054e6f54c41b006cb711b17536ec1 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/123708 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com> Reviewed-by: Giorgio Arena <giorgio.arena@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Diffstat (limited to 'scripts')
-rw-r--r--scripts/caffe_mnist_image_extractor.py59
1 files changed, 59 insertions, 0 deletions
diff --git a/scripts/caffe_mnist_image_extractor.py b/scripts/caffe_mnist_image_extractor.py
new file mode 100644
index 0000000000..2c9478a611
--- /dev/null
+++ b/scripts/caffe_mnist_image_extractor.py
@@ -0,0 +1,59 @@
+#!/usr/bin/env python
+"""Extracts mnist image data from the Caffe data files and stores them in numpy arrays
+Usage
+ python caffe_mnist_image_extractor.py -d path_to_caffe_data_directory -o desired_output_path
+
+Saves the first 10 images extracted as input10.npy, the first 100 images as input100.npy, and the
+corresponding labels to labels100.txt.
+
+Tested with Caffe 1.0 on Python 2.7
+"""
+import argparse
+import os
+import struct
+import numpy as np
+from array import array
+
+
+if __name__ == "__main__":
+ # Parse arguments
+ parser = argparse.ArgumentParser('Extract Caffe mnist image data')
+ parser.add_argument('-d', dest='dataDir', type=str, required=True, help='Path to Caffe data directory')
+ parser.add_argument('-o', dest='outDir', type=str, default='.', help='Output directory (default = current directory)')
+ args = parser.parse_args()
+
+ images_filename = os.path.join(args.dataDir, 'mnist/t10k-images-idx3-ubyte')
+ labels_filename = os.path.join(args.dataDir, 'mnist/t10k-labels-idx1-ubyte')
+
+ images_file = open(images_filename, 'rb')
+ labels_file = open(labels_filename, 'rb')
+ images_magic, images_size, rows, cols = struct.unpack('>IIII', images_file.read(16))
+ labels_magic, labels_size = struct.unpack('>II', labels_file.read(8))
+ images = array('B', images_file.read())
+ labels = array('b', labels_file.read())
+
+ input10_path = os.path.join(args.outDir, 'input10.npy')
+ input100_path = os.path.join(args.outDir, 'input100.npy')
+ labels100_path = os.path.join(args.outDir, 'labels100.npy')
+
+ outputs_10 = np.zeros(( 10, 28, 28, 1), dtype=np.float32)
+ outputs_100 = np.zeros((100, 28, 28, 1), dtype=np.float32)
+ labels_output = open(labels100_path, 'w')
+ for i in xrange(100):
+ image = np.array(images[i * rows * cols : (i + 1) * rows * cols]).reshape((rows, cols)) / 256.0
+ outputs_100[i, :, :, 0] = image
+
+ if i < 10:
+ outputs_10[i, :, :, 0] = image
+
+ if i == 10:
+ np.save(input10_path, np.transpose(outputs_10, (0, 3, 1, 2)))
+ print "Wrote", input10_path
+
+ labels_output.write(str(labels[i]) + '\n')
+
+ labels_output.close()
+ print "Wrote", labels100_path
+
+ np.save(input100_path, np.transpose(outputs_100, (0, 3, 1, 2)))
+ print "Wrote", input100_path