aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSiCong Li <sicong.li@arm.com>2017-08-23 11:02:43 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:35:24 +0000
commit86b53339679e12c952a24a8845a5409ac3d52de6 (patch)
tree807c897ca1001f22b1906d285488877a287b482b
parent70e9bc21682f4eaedaceb632f594f588cb2c91fc (diff)
downloadComputeLibrary-86b53339679e12c952a24a8845a5409ac3d52de6.tar.gz
COMPMID-514 (3RDPARTY_UPDATE)(DATA_UPDATE) Add support to load .npy data
* Add tensorflow_data_extractor script. * Incorporate 3rdparty npy reader libnpy. * Port AlexNet system test to validation_new. * Port LeNet5 system test to validation_new. * Update 3rdparty/ and data/ submodules. Change-Id: I156d060fe9185cd8db810b34bf524cbf5cb34f61 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/84914 Reviewed-by: Anthony Barbier <anthony.barbier@arm.com> Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
m---------3rdparty0
-rw-r--r--arm_compute/core/Window.h4
-rw-r--r--arm_compute/core/Window.inl6
m---------data0
-rw-r--r--docs/03_scripts.dox68
-rw-r--r--examples/neon_copy_objects.cpp4
-rwxr-xr-xscripts/caffe_data_extractor.py27
-rw-r--r--scripts/tensorflow_data_extractor.py51
-rw-r--r--src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp2
-rw-r--r--src/core/CL/kernels/CLFillBorderKernel.cpp2
-rw-r--r--src/core/CL/kernels/CLIm2ColKernel.cpp2
-rw-r--r--src/core/CL/kernels/CLLocallyConnectedMatrixMultiplyKernel.cpp2
-rw-r--r--src/core/CL/kernels/CLWeightsReshapeKernel.cpp4
-rw-r--r--src/core/ITensor.cpp6
-rw-r--r--src/core/NEON/kernels/NEFillBorderKernel.cpp2
-rw-r--r--src/core/NEON/kernels/NEFillInnerBorderKernel.cpp2
-rw-r--r--src/core/NEON/kernels/NEIm2ColKernel.cpp2
-rw-r--r--tests/AssetsLibrary.h91
-rw-r--r--tests/CL/CLAccessor.h12
-rw-r--r--tests/NEON/Accessor.h12
-rw-r--r--tests/Utils.h51
-rw-r--r--tests/validation/CL/SYSTEM/AlexNet.cpp (renamed from tests/validation_old/system_tests/CL/AlexNet.cpp)105
-rw-r--r--tests/validation/CL/SYSTEM/LeNet5.cpp (renamed from tests/validation_old/system_tests/CL/LeNet5.cpp)83
-rw-r--r--tests/validation/NEON/SYSTEM/AlexNet.cpp (renamed from tests/validation_old/system_tests/NEON/AlexNet.cpp)105
-rw-r--r--tests/validation/NEON/SYSTEM/LeNet5.cpp (renamed from tests/validation_old/system_tests/NEON/LeNet5.cpp)83
-rw-r--r--tests/validation_old/model_objects/AlexNet.h585
-rw-r--r--tests/validation_old/model_objects/LeNet5.h278
27 files changed, 456 insertions, 1133 deletions
diff --git a/3rdparty b/3rdparty
-Subproject 473b15cd5e41fc530b8619510ce45894b34739d
+Subproject 47648a7dde5b65cdce2ca4b53642b501f194930
diff --git a/arm_compute/core/Window.h b/arm_compute/core/Window.h
index baf767e7d5..654f5ed4f8 100644
--- a/arm_compute/core/Window.h
+++ b/arm_compute/core/Window.h
@@ -157,10 +157,10 @@ public:
/** Use the tensor's dimensions to fill the window dimensions.
*
- * @param[in] info Tensor information to copy the dimensions from.
+ * @param[in] shape @ref TensorShape to copy the dimensions from.
* @param[in] first_dimension Only copy dimensions which are greater or equal to this value.
*/
- void use_tensor_dimensions(const ITensorInfo *info, size_t first_dimension = Window::DimX);
+ void use_tensor_dimensions(const TensorShape &shape, size_t first_dimension = Window::DimX);
/** Shift the values of a given dimension by the given shift_value
*
diff --git a/arm_compute/core/Window.inl b/arm_compute/core/Window.inl
index 01cd988ea0..6b02128797 100644
--- a/arm_compute/core/Window.inl
+++ b/arm_compute/core/Window.inl
@@ -201,11 +201,11 @@ inline Window Window::first_slice_window() const
return slice;
}
-inline void Window::use_tensor_dimensions(const ITensorInfo *info, size_t first_dimension)
+inline void Window::use_tensor_dimensions(const TensorShape &shape, size_t first_dimension)
{
- for(unsigned int n = first_dimension; n < info->num_dimensions(); ++n)
+ for(unsigned int n = first_dimension; n < shape.num_dimensions(); ++n)
{
- set(n, Window::Dimension(0, std::max(info->dimension(n), static_cast<size_t>(1))));
+ set(n, Window::Dimension(0, std::max(shape[n], static_cast<size_t>(1))));
}
}
}
diff --git a/data b/data
-Subproject 1f4578a90cde937d510198fc0926adf42a81440
+Subproject 39104801bfc8c3885a044a723b5de98c44cc788
diff --git a/docs/03_scripts.dox b/docs/03_scripts.dox
index a91a93166b..2fd3907978 100644
--- a/docs/03_scripts.dox
+++ b/docs/03_scripts.dox
@@ -9,9 +9,9 @@ One can find caffe <a href="https://github.com/BVLC/caffe/wiki/Model-Zoo">pre-tr
caffe's official github repository.
The caffe_data_extractor.py provided in the @ref scripts folder is an example script that shows how to
-extract the hyperparameter values from a trained model.
+extract the parameter values from a trained model.
-@note complex networks might require alter the script to properly work.
+@note complex networks might require altering the script to properly work.
@subsection how_to How to use the script
@@ -22,19 +22,71 @@ Download the pre-trained caffe model.
Run the caffe_data_extractor.py script by
- ./caffe_data_extractor.py -m <caffe model> -n <caffe netlist>
+ python caffe_data_extractor.py -m <caffe model> -n <caffe netlist>
For example, to extract the data from pre-trained caffe Alex model to binary file:
- ./caffe_data_extractor.py -m /path/to/bvlc_alexnet.caffemodel -n /path/to/caffe/models/bvlc_alexnet/deploy.prototxt
+ python caffe_data_extractor.py -m /path/to/bvlc_alexnet.caffemodel -n /path/to/caffe/models/bvlc_alexnet/deploy.prototxt
The script has been tested under Python2.7.
-@subsection result What is the expected ouput from the script
+@subsection result What is the expected output from the script
-If the script run succesfully, it prints the shapes of each layer onto the standard
-output and generates *.dat files containing the weights and biases of each layer.
+If the script runs successfully, it prints the names and shapes of each layer onto the standard
+output and generates *.npy files containing the weights and biases of each layer.
The @ref arm_compute::utils::load_trained_data shows how one could load
-the weights and biases into tensor from the .dat file by the help of Accessor.
+the weights and biases into tensor from the .npy file by the help of Accessor.
+
+@section tensorflow_data_extractor Extract data from pre-trained tensorflow model
+
+The script tensorflow_data_extractor.py extracts trainable parameters (e.g. values of weights and biases) from a
+trained tensorflow model. A tensorflow model consists of the following two files:
+
+{model_name}.data-{step}-{global_step}: A binary file containing values of each variable.
+
+{model_name}.meta: A binary file containing a MetaGraph struct which defines the graph structure of the neural
+network.
+
+@note Since Tensorflow version 0.11 the binary checkpoint file which contains the values for each parameter has the format of:
+ {model_name}.data-{step}-of-{max_step}
+instead of:
+ {model_name}.ckpt
+When dealing with binary files with version >= 0.11, only pass {model_name} to -m option;
+when dealing with binary files with version < 0.11, pass the whole file name {model_name}.ckpt to -m option.
+
+@note This script relies on the parameters to be extracted being in the
+'trainable_variables' tensor collection. By default all variables are automatically added to this collection unless
+specified otherwise by the user. Thus should a user alter this default behavior and/or want to extract parameters from other
+collections, tf.GraphKeys.TRAINABLE_VARIABLES should be replaced accordingly.
+
+@subsection how_to How to use the script
+
+Install tensorflow and numpy.
+
+Download the pre-trained tensorflow model.
+
+Run tensorflow_data_extractor.py with
+
+ python tensorflow_data_extractor -m <path_to_binary_checkpoint_file> -n <path_to_metagraph_file>
+
+For example, to extract the data from pre-trained tensorflow Alex model to binary files:
+
+ python tensorflow_data_extractor -m /path/to/bvlc_alexnet -n /path/to/bvlc_alexnet.meta
+
+Or for binary checkpoint files before Tensorflow 0.11:
+
+ python tensorflow_data_extractor -m /path/to/bvlc_alexnet.ckpt -n /path/to/bvlc_alexnet.meta
+
+@note with versions >= Tensorflow 0.11 only model name is passed to the -m option
+
+The script has been tested with Tensorflow 1.2, 1.3 on Python 2.7.6 and Python 3.4.3.
+
+@subsection result What is the expected output from the script
+
+If the script runs successfully, it prints the names and shapes of each parameter onto the standard output and generates
+ *.npy files containing the weights and biases of each layer.
+
+The @ref arm_compute::utils::load_trained_data shows how one could load
+the weights and biases into tensor from the .npy file by the help of Accessor.
*/
diff --git a/examples/neon_copy_objects.cpp b/examples/neon_copy_objects.cpp
index 191f455557..04024530d5 100644
--- a/examples/neon_copy_objects.cpp
+++ b/examples/neon_copy_objects.cpp
@@ -75,7 +75,7 @@ void main_neon_copy_objects(int argc, const char **argv)
// Fill the input tensor:
// Simplest way: create an iterator to iterate through each element of the input tensor:
Window input_window;
- input_window.use_tensor_dimensions(input.info());
+ input_window.use_tensor_dimensions(input.info()->tensor_shape());
std::cout << " Dimensions of the input's iterator:\n";
std::cout << " X = [start=" << input_window.x().start() << ", end=" << input_window.x().end() << ", step=" << input_window.x().step() << "]\n";
std::cout << " Y = [start=" << input_window.y().start() << ", end=" << input_window.y().end() << ", step=" << input_window.y().step() << "]\n";
@@ -109,7 +109,7 @@ void main_neon_copy_objects(int argc, const char **argv)
// More efficient way: create an iterator to iterate through each row (instead of each element) of the output tensor:
Window output_window;
- output_window.use_tensor_dimensions(output.info(), /* first_dimension =*/Window::DimY); // Iterate through the rows (not each element)
+ output_window.use_tensor_dimensions(output.info()->tensor_shape(), /* first_dimension =*/Window::DimY); // Iterate through the rows (not each element)
std::cout << " Dimensions of the output's iterator:\n";
std::cout << " X = [start=" << output_window.x().start() << ", end=" << output_window.x().end() << ", step=" << output_window.x().step() << "]\n";
std::cout << " Y = [start=" << output_window.y().start() << ", end=" << output_window.y().end() << ", step=" << output_window.y().step() << "]\n";
diff --git a/scripts/caffe_data_extractor.py b/scripts/caffe_data_extractor.py
index 09ea0b86b0..65c9938480 100755
--- a/scripts/caffe_data_extractor.py
+++ b/scripts/caffe_data_extractor.py
@@ -1,16 +1,23 @@
#!/usr/bin/env python
-import argparse
+"""Extracts trainable parameters from Caffe models and stores them in numpy arrays.
+Usage
+ python caffe_data_extractor -m path_to_caffe_model_file -n path_to_caffe_netlist
+
+Saves each variable to a {variable_name}.npy binary file.
+Tested with Caffe 1.0 on Python 2.7
+"""
+import argparse
import caffe
+import os
import numpy as np
-import scipy.io
if __name__ == "__main__":
# Parse arguments
- parser = argparse.ArgumentParser('Extract CNN hyper-parameters')
- parser.add_argument('-m', dest='modelFile', type=str, required=True, help='Caffe model file')
- parser.add_argument('-n', dest='netFile', type=str, required=True, help='Caffe netlist')
+ parser = argparse.ArgumentParser('Extract Caffe net parameters')
+ parser.add_argument('-m', dest='modelFile', type=str, required=True, help='Path to Caffe model file')
+ parser.add_argument('-n', dest='netFile', type=str, required=True, help='Path to Caffe netlist')
args = parser.parse_args()
# Create Caffe Net
@@ -18,7 +25,7 @@ if __name__ == "__main__":
# Read and dump blobs
for name, blobs in net.params.iteritems():
- print 'Name: {0}, Blobs: {1}'.format(name, len(blobs))
+ print('Name: {0}, Blobs: {1}'.format(name, len(blobs)))
for i in range(len(blobs)):
# Weights
if i == 0:
@@ -29,6 +36,10 @@ if __name__ == "__main__":
else:
pass
- print("%s : %s" % (outname, blobs[i].data.shape))
+ varname = outname
+ if os.path.sep in varname:
+ varname = varname.replace(os.path.sep, '_')
+ print("Renaming variable {0} to {1}".format(outname, varname))
+ print("Saving variable {0} with shape {1} ...".format(varname, blobs[i].data.shape))
# Dump as binary
- blobs[i].data.tofile(outname + ".dat")
+ np.save(varname, blobs[i].data)
diff --git a/scripts/tensorflow_data_extractor.py b/scripts/tensorflow_data_extractor.py
new file mode 100644
index 0000000000..1dbf0e127e
--- /dev/null
+++ b/scripts/tensorflow_data_extractor.py
@@ -0,0 +1,51 @@
+#!/usr/bin/env python
+"""Extracts trainable parameters from Tensorflow models and stores them in numpy arrays.
+Usage
+ python tensorflow_data_extractor -m path_to_binary_checkpoint_file -n path_to_metagraph_file
+
+Saves each variable to a {variable_name}.npy binary file.
+
+Note that since Tensorflow version 0.11 the binary checkpoint file which contains the values for each parameter has the format of:
+ {model_name}.data-{step}-of-{max_step}
+instead of:
+ {model_name}.ckpt
+When dealing with binary files with version >= 0.11, only pass {model_name} to -m option;
+when dealing with binary files with version < 0.11, pass the whole file name {model_name}.ckpt to -m option.
+
+Also note that this script relies on the parameters to be extracted being in the
+'trainable_variables' tensor collection. By default all variables are automatically added to this collection unless
+specified otherwise by the user. Thus should a user alter this default behavior and/or want to extract parameters from other
+collections, tf.GraphKeys.TRAINABLE_VARIABLES should be replaced accordingly.
+
+Tested with Tensorflow 1.2, 1.3 on Python 2.7.6 and Python 3.4.3.
+"""
+import argparse
+import numpy as np
+import os
+import tensorflow as tf
+
+
+if __name__ == "__main__":
+ # Parse arguments
+ parser = argparse.ArgumentParser('Extract Tensorflow net parameters')
+ parser.add_argument('-m', dest='modelFile', type=str, required=True, help='Path to Tensorflow checkpoint binary\
+ file. For Tensorflow version >= 0.11, only include model name; for Tensorflow version < 0.11, include\
+ model name with ".ckpt" extension')
+ parser.add_argument('-n', dest='netFile', type=str, required=True, help='Path to Tensorflow MetaGraph file')
+ args = parser.parse_args()
+
+ # Load Tensorflow Net
+ saver = tf.train.import_meta_graph(args.netFile)
+ with tf.Session() as sess:
+ # Restore session
+ saver.restore(sess, args.modelFile)
+ print('Model restored.')
+ # Save trainable variables to numpy arrays
+ for t in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
+ varname = t.name
+ if os.path.sep in t.name:
+ varname = varname.replace(os.path.sep, '_')
+ print("Renaming variable {0} to {1}".format(t.name, varname))
+ print("Saving variable {0} with shape {1} ...".format(varname, t.shape))
+ # Dump as binary
+ np.save(varname, sess.run(t))
diff --git a/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp b/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp
index 9d1c87d9e1..265c5074c5 100644
--- a/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp
+++ b/src/core/CL/kernels/CLDirectConvolutionLayerKernel.cpp
@@ -246,7 +246,7 @@ void CLDirectConvolutionLayerKernel::run(const Window &window, cl::CommandQueue
if(_biases != nullptr)
{
Window slice_biases;
- slice_biases.use_tensor_dimensions(_biases->info());
+ slice_biases.use_tensor_dimensions(_biases->info()->tensor_shape());
add_1D_tensor_argument(idx1, _biases, slice_biases);
}
diff --git a/src/core/CL/kernels/CLFillBorderKernel.cpp b/src/core/CL/kernels/CLFillBorderKernel.cpp
index 6ff152113b..d2610539d1 100644
--- a/src/core/CL/kernels/CLFillBorderKernel.cpp
+++ b/src/core/CL/kernels/CLFillBorderKernel.cpp
@@ -157,7 +157,7 @@ void CLFillBorderKernel::configure(ICLTensor *tensor, BorderSize border_size, Bo
Window win;
win.set(Window::DimX, Window::Dimension(0, total_valid_width + valid_height));
win.set(Window::DimY, Window::Dimension(0, 1, 1));
- win.use_tensor_dimensions(tensor->info(), Window::DimZ);
+ win.use_tensor_dimensions(tensor->info()->tensor_shape(), Window::DimZ);
ICLKernel::configure(win);
}
diff --git a/src/core/CL/kernels/CLIm2ColKernel.cpp b/src/core/CL/kernels/CLIm2ColKernel.cpp
index 5147ea0609..3d21a9e3c0 100644
--- a/src/core/CL/kernels/CLIm2ColKernel.cpp
+++ b/src/core/CL/kernels/CLIm2ColKernel.cpp
@@ -182,7 +182,7 @@ void CLIm2ColKernel::run_reduced(const Window &window, cl::CommandQueue &queue)
ARM_COMPUTE_ERROR_ON_MISMATCHING_WINDOWS(ICLKernel::window(), window);
Window out_window;
- out_window.use_tensor_dimensions(_output->info());
+ out_window.use_tensor_dimensions(_output->info()->tensor_shape());
Window out_slice = out_window.first_slice_window_1D();
Window in_slice = window.first_slice_window_3D();
diff --git a/src/core/CL/kernels/CLLocallyConnectedMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLLocallyConnectedMatrixMultiplyKernel.cpp
index 794a1bc56e..508fb899f1 100644
--- a/src/core/CL/kernels/CLLocallyConnectedMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLLocallyConnectedMatrixMultiplyKernel.cpp
@@ -101,7 +101,7 @@ void CLLocallyConnectedMatrixMultiplyKernel::run(const Window &window, cl::Comma
Window slice = window.first_slice_window_2D();
Window matrix_b_window;
- matrix_b_window.use_tensor_dimensions(_input1->info());
+ matrix_b_window.use_tensor_dimensions(_input1->info()->tensor_shape());
Window slice_matrix_b = matrix_b_window.first_slice_window_3D();
do
diff --git a/src/core/CL/kernels/CLWeightsReshapeKernel.cpp b/src/core/CL/kernels/CLWeightsReshapeKernel.cpp
index 7b80f3ff5a..bc27477971 100644
--- a/src/core/CL/kernels/CLWeightsReshapeKernel.cpp
+++ b/src/core/CL/kernels/CLWeightsReshapeKernel.cpp
@@ -107,7 +107,7 @@ void CLWeightsReshapeKernel::run(const Window &window, cl::CommandQueue &queue)
ARM_COMPUTE_ERROR_ON_MISMATCHING_WINDOWS(ICLKernel::window(), window);
Window out_window;
- out_window.use_tensor_dimensions(_output->info());
+ out_window.use_tensor_dimensions(_output->info()->tensor_shape());
Window in_slice = window.first_slice_window_3D();
Window out_slice = out_window.first_slice_window_2D();
@@ -117,7 +117,7 @@ void CLWeightsReshapeKernel::run(const Window &window, cl::CommandQueue &queue)
if(_biases != nullptr)
{
- biases_window.use_tensor_dimensions(_biases->info());
+ biases_window.use_tensor_dimensions(_biases->info()->tensor_shape());
biases_slice = biases_window.first_slice_window_1D();
}
diff --git a/src/core/ITensor.cpp b/src/core/ITensor.cpp
index 0b29eca57b..4a54675561 100644
--- a/src/core/ITensor.cpp
+++ b/src/core/ITensor.cpp
@@ -55,9 +55,9 @@ void ITensor::copy_from(const ITensor &src)
dst_info->set_valid_region(src_info->valid_region());
Window win_src;
- win_src.use_tensor_dimensions(src_info, Window::DimY);
+ win_src.use_tensor_dimensions(src_info->tensor_shape(), Window::DimY);
Window win_dst;
- win_dst.use_tensor_dimensions(dst_info, Window::DimY);
+ win_dst.use_tensor_dimensions(dst_info->tensor_shape(), Window::DimY);
Iterator src_it(&src, win_src);
Iterator dst_it(this, win_dst);
@@ -147,4 +147,4 @@ void ITensor::print(std::ostream &s, IOFormatInfo io_fmt) const
s << io_fmt.row_delim;
}
}
-} \ No newline at end of file
+}
diff --git a/src/core/NEON/kernels/NEFillBorderKernel.cpp b/src/core/NEON/kernels/NEFillBorderKernel.cpp
index 65d5388c4b..3f1f678a7e 100644
--- a/src/core/NEON/kernels/NEFillBorderKernel.cpp
+++ b/src/core/NEON/kernels/NEFillBorderKernel.cpp
@@ -120,7 +120,7 @@ void NEFillBorderKernel::configure(ITensor *tensor, BorderSize border_size, Bord
Window win;
win.set(Window::DimX, Window::Dimension(0, 1, 1));
win.set(Window::DimY, Window::Dimension(0, 1, 1));
- win.use_tensor_dimensions(_tensor->info(), Window::DimZ);
+ win.use_tensor_dimensions(_tensor->info()->tensor_shape(), Window::DimZ);
INEKernel::configure(win);
}
diff --git a/src/core/NEON/kernels/NEFillInnerBorderKernel.cpp b/src/core/NEON/kernels/NEFillInnerBorderKernel.cpp
index 5323733fd3..017e259ca4 100644
--- a/src/core/NEON/kernels/NEFillInnerBorderKernel.cpp
+++ b/src/core/NEON/kernels/NEFillInnerBorderKernel.cpp
@@ -57,7 +57,7 @@ void NEFillInnerBorderKernel::configure(ITensor *input, BorderSize border_size,
Window win;
win.set(Window::DimX, Window::Dimension(0, 1, 1));
win.set(Window::DimY, Window::Dimension(0, 1, 1));
- win.use_tensor_dimensions(_tensor->info(), Window::DimZ);
+ win.use_tensor_dimensions(_tensor->info()->tensor_shape(), Window::DimZ);
INEKernel::configure(win);
}
diff --git a/src/core/NEON/kernels/NEIm2ColKernel.cpp b/src/core/NEON/kernels/NEIm2ColKernel.cpp
index 3e50277cdf..71910e3a69 100644
--- a/src/core/NEON/kernels/NEIm2ColKernel.cpp
+++ b/src/core/NEON/kernels/NEIm2ColKernel.cpp
@@ -231,7 +231,7 @@ void NEIm2ColKernel::run_reduced(const Window &window)
in_window.set(Window::DimX, Window::Dimension(0, 1, 1));
Window out_window;
- out_window.use_tensor_dimensions(_output->info());
+ out_window.use_tensor_dimensions(_output->info()->tensor_shape());
out_window.set(Window::DimX, Window::Dimension(out_window.x().start(), out_window.x().end(), in_width));
Window in_slice = in_window.first_slice_window_3D();
diff --git a/tests/AssetsLibrary.h b/tests/AssetsLibrary.h
index 6945aa6fe1..58406c5dca 100644
--- a/tests/AssetsLibrary.h
+++ b/tests/AssetsLibrary.h
@@ -31,6 +31,7 @@
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Window.h"
+#include "libnpy/npy.hpp"
#include "tests/RawTensor.h"
#include "tests/TensorCache.h"
#include "tests/Utils.h"
@@ -42,6 +43,7 @@
#include <random>
#include <string>
#include <type_traits>
+#include <vector>
namespace arm_compute
{
@@ -290,10 +292,22 @@ public:
template <typename T, typename D>
void fill_tensor_uniform(T &&tensor, std::random_device::result_type seed_offset, D low, D high) const;
- /** Fills the specified @p tensor with data loaded from binary in specified path.
+ /** Fills the specified @p tensor with data loaded from .npy (numpy binary) in specified path.
*
* @param[in, out] tensor To be filled tensor.
* @param[in] name Data file.
+ *
+ * @note The numpy array stored in the binary .npy file must be row-major in the sense that it
+ * must store elements within a row consecutively in the memory, then rows within a 2D slice,
+ * then 2D slices within a 3D slice and so on. Note that it imposes no restrictions on what
+ * indexing convention is used in the numpy array. That is, the numpy array can be either fortran
+ * style or C style as long as it adheres to the rule above.
+ *
+ * More concretely, the orders of dimensions for each style are as follows:
+ * C-style (numpy default):
+ * array[HigherDims..., Z, Y, X]
+ * Fortran style:
+ * array[X, Y, Z, HigherDims...]
*/
template <typename T>
void fill_layer_data(T &&tensor, std::string name) const;
@@ -644,30 +658,77 @@ void AssetsLibrary::fill_layer_data(T &&tensor, std::string name) const
#else /* _WIN32 */
const std::string path_separator("/");
#endif /* _WIN32 */
-
const std::string path = _library_path + path_separator + name;
+ std::vector<unsigned long> shape;
+
// Open file
- std::ifstream file(path, std::ios::in | std::ios::binary);
- if(!file.good())
+ std::ifstream stream(path, std::ios::in | std::ios::binary);
+ ARM_COMPUTE_ERROR_ON_MSG(!stream.good(), "Failed to load binary data");
+ // Check magic bytes and version number
+ unsigned char v_major = 0;
+ unsigned char v_minor = 0;
+ npy::read_magic(stream, &v_major, &v_minor);
+
+ // Read header
+ std::string header;
+ if(v_major == 1 && v_minor == 0)
+ {
+ header = npy::read_header_1_0(stream);
+ }
+ else if(v_major == 2 && v_minor == 0)
+ {
+ header = npy::read_header_2_0(stream);
+ }
+ else
{
- throw std::runtime_error("Could not load binary data: " + path);
+ ARM_COMPUTE_ERROR("Unsupported file format version");
}
- Window window;
- for(unsigned int d = 0; d < tensor.shape().num_dimensions(); ++d)
+ // Parse header
+ bool fortran_order = false;
+ std::string typestr;
+ npy::ParseHeader(header, typestr, &fortran_order, shape);
+
+ // Check if the typestring matches the given one
+ std::string expect_typestr = get_typestring(tensor.data_type());
+ ARM_COMPUTE_ERROR_ON_MSG(typestr != expect_typestr, "Typestrings mismatch");
+
+ // Validate tensor shape
+ ARM_COMPUTE_ERROR_ON_MSG(shape.size() != tensor.shape().num_dimensions(), "Tensor ranks mismatch");
+ if(fortran_order)
{
- window.set(d, Window::Dimension(0, tensor.shape()[d], 1));
+ for(size_t i = 0; i < shape.size(); ++i)
+ {
+ ARM_COMPUTE_ERROR_ON_MSG(tensor.shape()[i] != shape[i], "Tensor dimensions mismatch");
+ }
+ }
+ else
+ {
+ for(size_t i = 0; i < shape.size(); ++i)
+ {
+ ARM_COMPUTE_ERROR_ON_MSG(tensor.shape()[i] != shape[shape.size() - i - 1], "Tensor dimensions mismatch");
+ }
}
- //FIXME : Replace with normal loop
- execute_window_loop(window, [&](const Coordinates & id)
+ // Read data
+ if(tensor.padding().empty())
{
- float val;
- file.read(reinterpret_cast<char *>(&val), sizeof(float));
- void *const out_ptr = tensor(id);
- store_value_with_data_type(out_ptr, val, tensor.data_type());
- });
+ // If tensor has no padding read directly from stream.
+ stream.read(reinterpret_cast<char *>(tensor.data()), tensor.size());
+ }
+ else
+ {
+ // If tensor has padding accessing tensor elements through execution window.
+ Window window;
+ window.use_tensor_dimensions(tensor.shape());
+
+ //FIXME : Replace with normal loop
+ execute_window_loop(window, [&](const Coordinates & id)
+ {
+ stream.read(reinterpret_cast<char *>(tensor(id)), tensor.element_size());
+ });
+ }
}
} // namespace test
} // namespace arm_compute
diff --git a/tests/CL/CLAccessor.h b/tests/CL/CLAccessor.h
index b1d7a078b1..2f955653c8 100644
--- a/tests/CL/CLAccessor.h
+++ b/tests/CL/CLAccessor.h
@@ -63,6 +63,8 @@ public:
int fixed_point_position() const override;
const void *operator()(const Coordinates &coord) const override;
void *operator()(const Coordinates &coord) override;
+ const void *data() const;
+ void *data();
private:
CLTensor &_tensor;
@@ -124,6 +126,16 @@ inline int CLAccessor::fixed_point_position() const
return _tensor.info()->fixed_point_position();
}
+inline const void *CLAccessor::data() const
+{
+ return _tensor.buffer();
+}
+
+inline void *CLAccessor::data()
+{
+ return _tensor.buffer();
+}
+
inline const void *CLAccessor::operator()(const Coordinates &coord) const
{
return _tensor.ptr_to_element(coord);
diff --git a/tests/NEON/Accessor.h b/tests/NEON/Accessor.h
index c379018d39..e0ff35231c 100644
--- a/tests/NEON/Accessor.h
+++ b/tests/NEON/Accessor.h
@@ -57,6 +57,8 @@ public:
int fixed_point_position() const override;
const void *operator()(const Coordinates &coord) const override;
void *operator()(const Coordinates &coord) override;
+ const void *data() const;
+ void *data();
private:
Tensor &_tensor;
@@ -112,6 +114,16 @@ inline int Accessor::fixed_point_position() const
return _tensor.info()->fixed_point_position();
}
+inline const void *Accessor::data() const
+{
+ return _tensor.buffer();
+}
+
+inline void *Accessor::data()
+{
+ return _tensor.buffer();
+}
+
inline const void *Accessor::operator()(const Coordinates &coord) const
{
return _tensor.ptr_to_element(coord);
diff --git a/tests/Utils.h b/tests/Utils.h
index 7af38e5c93..f325bb3f37 100644
--- a/tests/Utils.h
+++ b/tests/Utils.h
@@ -476,6 +476,57 @@ inline void fill_array(ArrayAccessor_T &&array, const std::vector<T> &v)
array.resize(v.size());
std::memcpy(array.buffer(), v.data(), v.size() * sizeof(T));
}
+
+/** Obtain numpy type string from DataType.
+ *
+ * @param[in] data_type Data type.
+ *
+ * @return numpy type string.
+ */
+inline std::string get_typestring(DataType data_type)
+{
+ // Check endianness
+ const unsigned int i = 1;
+ const char *c = reinterpret_cast<const char *>(&i);
+ std::string endianness;
+ if(*c == 1)
+ {
+ endianness = std::string("<");
+ }
+ else
+ {
+ endianness = std::string(">");
+ }
+ const std::string no_endianness("|");
+
+ switch(data_type)
+ {
+ case DataType::U8:
+ return no_endianness + "u" + support::cpp11::to_string(sizeof(uint8_t));
+ case DataType::S8:
+ return no_endianness + "i" + support::cpp11::to_string(sizeof(int8_t));
+ case DataType::U16:
+ return endianness + "u" + support::cpp11::to_string(sizeof(uint16_t));
+ case DataType::S16:
+ return endianness + "i" + support::cpp11::to_string(sizeof(int16_t));
+ case DataType::U32:
+ return endianness + "u" + support::cpp11::to_string(sizeof(uint32_t));
+ case DataType::S32:
+ return endianness + "i" + support::cpp11::to_string(sizeof(int32_t));
+ case DataType::U64:
+ return endianness + "u" + support::cpp11::to_string(sizeof(uint64_t));
+ case DataType::S64:
+ return endianness + "i" + support::cpp11::to_string(sizeof(int64_t));
+ case DataType::F32:
+ return endianness + "f" + support::cpp11::to_string(sizeof(float));
+ case DataType::F64:
+ return endianness + "f" + support::cpp11::to_string(sizeof(double));
+ case DataType::SIZET:
+ return endianness + "u" + support::cpp11::to_string(sizeof(size_t));
+ default:
+ ARM_COMPUTE_ERROR("NOT SUPPORTED!");
+ }
+}
} // namespace test
} // namespace arm_compute
#endif /* __ARM_COMPUTE_TEST_UTILS_H__ */
diff --git a/tests/validation_old/system_tests/CL/AlexNet.cpp b/tests/validation/CL/SYSTEM/AlexNet.cpp
index b403b6e93c..d7dd62d4f2 100644
--- a/tests/validation_old/system_tests/CL/AlexNet.cpp
+++ b/tests/validation/CL/SYSTEM/AlexNet.cpp
@@ -22,83 +22,66 @@
* SOFTWARE.
*/
#ifdef INTERNAL_ONLY //FIXME Delete this file before the release
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#include "CL/CLAccessor.h"
-#include "tests/validation_old/Validation.h"
#include "arm_compute/runtime/CL/CLSubTensor.h"
#include "arm_compute/runtime/CL/functions/CLActivationLayer.h"
#include "arm_compute/runtime/CL/functions/CLConvolutionLayer.h"
+#include "arm_compute/runtime/CL/functions/CLDirectConvolutionLayer.h"
#include "arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h"
#include "arm_compute/runtime/CL/functions/CLNormalizationLayer.h"
#include "arm_compute/runtime/CL/functions/CLPoolingLayer.h"
#include "arm_compute/runtime/CL/functions/CLSoftmaxLayer.h"
+#include "tests/CL/CLAccessor.h"
+#include "tests/framework/Asserts.h"
+#include "tests/framework/Macros.h"
+#include "tests/networks/AlexNetNetwork.h"
+#include "tests/validation/Validation.h"
-#include "tests/validation_old/model_objects/AlexNet.h"
-
-#include <array>
-
-using namespace arm_compute;
-using namespace arm_compute::test;
-using namespace arm_compute::test::validation;
+#include <string>
+#include <vector>
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
namespace
{
-using CLAlexNetModel = model_objects::AlexNet<ICLTensor,
+using CLAlexNetModel = networks::AlexNetNetwork<ICLTensor,
CLTensor,
CLSubTensor,
CLAccessor,
CLActivationLayer,
CLConvolutionLayer,
+ CLDirectConvolutionLayer,
CLFullyConnectedLayer,
CLNormalizationLayer,
CLPoolingLayer,
CLSoftmaxLayer>;
-std::vector<unsigned int> compute_alexnet(unsigned int batches, std::string input_file)
+std::vector<unsigned int> compute_alexnet(DataType dt, unsigned int batches, std::string input_file)
{
- std::vector<std::string> weight_files = { "cnn_data/alexnet_model/conv1_w.dat",
- "cnn_data/alexnet_model/conv2_w.dat",
- "cnn_data/alexnet_model/conv3_w.dat",
- "cnn_data/alexnet_model/conv4_w.dat",
- "cnn_data/alexnet_model/conv5_w.dat",
- "cnn_data/alexnet_model/fc6_w.dat",
- "cnn_data/alexnet_model/fc7_w.dat",
- "cnn_data/alexnet_model/fc8_w.dat"
+ std::vector<std::string> weight_files = { "cnn_data/alexnet_model/conv1_w.npy",
+ "cnn_data/alexnet_model/conv2_w.npy",
+ "cnn_data/alexnet_model/conv3_w.npy",
+ "cnn_data/alexnet_model/conv4_w.npy",
+ "cnn_data/alexnet_model/conv5_w.npy",
+ "cnn_data/alexnet_model/fc6_w.npy",
+ "cnn_data/alexnet_model/fc7_w.npy",
+ "cnn_data/alexnet_model/fc8_w.npy"
};
- std::vector<std::string> bias_files = { "cnn_data/alexnet_model/conv1_b.dat",
- "cnn_data/alexnet_model/conv2_b.dat",
- "cnn_data/alexnet_model/conv3_b.dat",
- "cnn_data/alexnet_model/conv4_b.dat",
- "cnn_data/alexnet_model/conv5_b.dat",
- "cnn_data/alexnet_model/fc6_b.dat",
- "cnn_data/alexnet_model/fc7_b.dat",
- "cnn_data/alexnet_model/fc8_b.dat"
+ std::vector<std::string> bias_files = { "cnn_data/alexnet_model/conv1_b.npy",
+ "cnn_data/alexnet_model/conv2_b.npy",
+ "cnn_data/alexnet_model/conv3_b.npy",
+ "cnn_data/alexnet_model/conv4_b.npy",
+ "cnn_data/alexnet_model/conv5_b.npy",
+ "cnn_data/alexnet_model/fc6_b.npy",
+ "cnn_data/alexnet_model/fc7_b.npy",
+ "cnn_data/alexnet_model/fc8_b.npy"
};
CLAlexNetModel network{};
- network.init_weights(batches);
+ network.init(dt, 4, batches);
network.build();
network.allocate();
network.fill(weight_files, bias_files);
@@ -109,24 +92,24 @@ std::vector<unsigned int> compute_alexnet(unsigned int batches, std::string inpu
}
} // namespace
-#ifndef DOXYGEN_SKIP_THIS
-BOOST_AUTO_TEST_SUITE(SYSTEM_TESTS)
-BOOST_AUTO_TEST_SUITE(CL)
+TEST_SUITE(CL)
+TEST_SUITE(SYSTEM_TESTS)
-BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit"))
-BOOST_AUTO_TEST_CASE(AlexNet)
+TEST_CASE(AlexNet, framework::DatasetMode::PRECOMMIT)
{
// Compute alexnet
- std::vector<unsigned int> classified_labels = compute_alexnet(1, "cnn_data/imagenet_data/shark.dat");
+ std::vector<unsigned int> classified_labels = compute_alexnet(DataType::F32, 1, "cnn_data/imagenet_data/cat.npy");
// Expected labels
- std::vector<unsigned int> expected_labels = { 2 };
+ std::vector<unsigned int> expected_labels = { 281 };
// Validate labels
validate(classified_labels, expected_labels);
}
-BOOST_AUTO_TEST_SUITE_END()
-BOOST_AUTO_TEST_SUITE_END()
-#endif /* DOXYGEN_SKIP_THIS */
+TEST_SUITE_END()
+TEST_SUITE_END()
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
#endif /* INTERNAL_ONLY */
diff --git a/tests/validation_old/system_tests/CL/LeNet5.cpp b/tests/validation/CL/SYSTEM/LeNet5.cpp
index 0f34dd1ae7..6f68fa11a0 100644
--- a/tests/validation_old/system_tests/CL/LeNet5.cpp
+++ b/tests/validation/CL/SYSTEM/LeNet5.cpp
@@ -22,47 +22,30 @@
* SOFTWARE.
*/
#ifdef INTERNAL_ONLY //FIXME Delete this file before the release
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#include "CL/CLAccessor.h"
-#include "tests/validation_old/Validation.h"
#include "arm_compute/runtime/CL/functions/CLActivationLayer.h"
#include "arm_compute/runtime/CL/functions/CLConvolutionLayer.h"
#include "arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h"
#include "arm_compute/runtime/CL/functions/CLPoolingLayer.h"
#include "arm_compute/runtime/CL/functions/CLSoftmaxLayer.h"
+#include "tests/CL/CLAccessor.h"
+#include "tests/framework/Asserts.h"
+#include "tests/framework/Macros.h"
+#include "tests/networks/LeNet5Network.h"
+#include "tests/validation/Validation.h"
-#include "tests/validation_old/model_objects/LeNet5.h"
-
-using namespace arm_compute;
-using namespace arm_compute::test;
-using namespace arm_compute::test::validation;
+#include <string>
+#include <vector>
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
namespace
{
-using CLLeNet5Model = model_objects::LeNet5<CLTensor,
+using CLLeNet5Model = networks::LeNet5Network<CLTensor,
CLAccessor,
CLActivationLayer,
CLConvolutionLayer,
@@ -71,19 +54,21 @@ using CLLeNet5Model = model_objects::LeNet5<CLTensor,
CLSoftmaxLayer>;
std::vector<unsigned int> compute_lenet5(unsigned int batches, std::string input_file)
{
- std::vector<std::string> weight_files = { "cnn_data/lenet_model/conv1_w.dat",
- "cnn_data/lenet_model/conv2_w.dat",
- "cnn_data/lenet_model/ip1_w.dat",
- "cnn_data/lenet_model/ip2_w.dat"
+ std::vector<std::string> weight_files = { "cnn_data/lenet_model/conv1_w.npy",
+ "cnn_data/lenet_model/conv2_w.npy",
+ "cnn_data/lenet_model/ip1_w.npy",
+ "cnn_data/lenet_model/ip2_w.npy"
};
- std::vector<std::string> bias_files = { "cnn_data/lenet_model/conv1_b.dat",
- "cnn_data/lenet_model/conv2_b.dat",
- "cnn_data/lenet_model/ip1_b.dat",
- "cnn_data/lenet_model/ip2_b.dat"
+ std::vector<std::string> bias_files = { "cnn_data/lenet_model/conv1_b.npy",
+ "cnn_data/lenet_model/conv2_b.npy",
+ "cnn_data/lenet_model/ip1_b.npy",
+ "cnn_data/lenet_model/ip2_b.npy"
};
CLLeNet5Model network{};
- network.build(batches);
+ network.init(batches);
+ network.build();
+ network.allocate();
network.fill(weight_files, bias_files);
network.feed(std::move(input_file));
network.run();
@@ -92,15 +77,13 @@ std::vector<unsigned int> compute_lenet5(unsigned int batches, std::string input
}
} // namespace
-#ifndef DOXYGEN_SKIP_THIS
-BOOST_AUTO_TEST_SUITE(SYSTEM_TESTS)
-BOOST_AUTO_TEST_SUITE(CL)
+TEST_SUITE(CL)
+TEST_SUITE(SYSTEM_TESTS)
-BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit"))
-BOOST_AUTO_TEST_CASE(LeNet5)
+TEST_CASE(LeNet5, framework::DatasetMode::PRECOMMIT)
{
// Compute alexnet
- std::vector<unsigned int> classified_labels = compute_lenet5(10, "cnn_data/mnist_data/input100.dat");
+ std::vector<unsigned int> classified_labels = compute_lenet5(10, "cnn_data/mnist_data/input10.npy");
// Expected labels
std::vector<unsigned int> expected_labels = { 7, 2, 1, 0, 4, 1, 4, 9, 5, 9 };
@@ -109,7 +92,9 @@ BOOST_AUTO_TEST_CASE(LeNet5)
validate(classified_labels, expected_labels);
}
-BOOST_AUTO_TEST_SUITE_END()
-BOOST_AUTO_TEST_SUITE_END()
-#endif /* DOXYGEN_SKIP_THIS */
+TEST_SUITE_END()
+TEST_SUITE_END()
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
#endif /* INTERNAL_ONLY */
diff --git a/tests/validation_old/system_tests/NEON/AlexNet.cpp b/tests/validation/NEON/SYSTEM/AlexNet.cpp
index 9697cf35e8..7a2b0d22b0 100644
--- a/tests/validation_old/system_tests/NEON/AlexNet.cpp
+++ b/tests/validation/NEON/SYSTEM/AlexNet.cpp
@@ -22,84 +22,67 @@
* SOFTWARE.
*/
#ifdef INTERNAL_ONLY //FIXME Delete this file before the release
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#include "NEON/Accessor.h"
-#include "tests/validation_old/Validation.h"
#include "arm_compute/runtime/NEON/functions/NEActivationLayer.h"
#include "arm_compute/runtime/NEON/functions/NEConvolutionLayer.h"
+#include "arm_compute/runtime/NEON/functions/NEDirectConvolutionLayer.h"
#include "arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h"
#include "arm_compute/runtime/NEON/functions/NENormalizationLayer.h"
#include "arm_compute/runtime/NEON/functions/NEPoolingLayer.h"
#include "arm_compute/runtime/NEON/functions/NESoftmaxLayer.h"
#include "arm_compute/runtime/SubTensor.h"
+#include "tests/NEON/Accessor.h"
+#include "tests/framework/Asserts.h"
+#include "tests/framework/Macros.h"
+#include "tests/networks/AlexNetNetwork.h"
+#include "tests/validation/Validation.h"
-#include "tests/validation_old/model_objects/AlexNet.h"
-
-#include <array>
-
-using namespace arm_compute;
-using namespace arm_compute::test;
-using namespace arm_compute::test::validation;
+#include <string>
+#include <vector>
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
namespace
{
-using NEAlexNetModel = model_objects::AlexNet<ITensor,
+using NEAlexNetModel = networks::AlexNetNetwork<ITensor,
Tensor,
SubTensor,
Accessor,
NEActivationLayer,
NEConvolutionLayer,
+ NEDirectConvolutionLayer,
NEFullyConnectedLayer,
NENormalizationLayer,
NEPoolingLayer,
NESoftmaxLayer>;
-std::vector<unsigned int> compute_alexnet(unsigned int batches, std::string input_file)
+std::vector<unsigned int> compute_alexnet(DataType dt, unsigned int batches, std::string input_file)
{
- std::vector<std::string> weight_files = { "cnn_data/alexnet_model/conv1_w.dat",
- "cnn_data/alexnet_model/conv2_w.dat",
- "cnn_data/alexnet_model/conv3_w.dat",
- "cnn_data/alexnet_model/conv4_w.dat",
- "cnn_data/alexnet_model/conv5_w.dat",
- "cnn_data/alexnet_model/fc6_w.dat",
- "cnn_data/alexnet_model/fc7_w.dat",
- "cnn_data/alexnet_model/fc8_w.dat"
+ std::vector<std::string> weight_files = { "cnn_data/alexnet_model/conv1_w.npy",
+ "cnn_data/alexnet_model/conv2_w.npy",
+ "cnn_data/alexnet_model/conv3_w.npy",
+ "cnn_data/alexnet_model/conv4_w.npy",
+ "cnn_data/alexnet_model/conv5_w.npy",
+ "cnn_data/alexnet_model/fc6_w.npy",
+ "cnn_data/alexnet_model/fc7_w.npy",
+ "cnn_data/alexnet_model/fc8_w.npy"
};
- std::vector<std::string> bias_files = { "cnn_data/alexnet_model/conv1_b.dat",
- "cnn_data/alexnet_model/conv2_b.dat",
- "cnn_data/alexnet_model/conv3_b.dat",
- "cnn_data/alexnet_model/conv4_b.dat",
- "cnn_data/alexnet_model/conv5_b.dat",
- "cnn_data/alexnet_model/fc6_b.dat",
- "cnn_data/alexnet_model/fc7_b.dat",
- "cnn_data/alexnet_model/fc8_b.dat"
+ std::vector<std::string> bias_files = { "cnn_data/alexnet_model/conv1_b.npy",
+ "cnn_data/alexnet_model/conv2_b.npy",
+ "cnn_data/alexnet_model/conv3_b.npy",
+ "cnn_data/alexnet_model/conv4_b.npy",
+ "cnn_data/alexnet_model/conv5_b.npy",
+ "cnn_data/alexnet_model/fc6_b.npy",
+ "cnn_data/alexnet_model/fc7_b.npy",
+ "cnn_data/alexnet_model/fc8_b.npy"
};
NEAlexNetModel network{};
- network.init_weights(batches);
+ network.init(dt, 4, batches);
network.build();
network.allocate();
network.fill(weight_files, bias_files);
@@ -110,24 +93,24 @@ std::vector<unsigned int> compute_alexnet(unsigned int batches, std::string inpu
}
} // namespace
-#ifndef DOXYGEN_SKIP_THIS
-BOOST_AUTO_TEST_SUITE(SYSTEM_TESTS)
-BOOST_AUTO_TEST_SUITE(NEON)
+TEST_SUITE(NEON)
+TEST_SUITE(SYSTEM_TESTS)
-BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit"))
-BOOST_AUTO_TEST_CASE(AlexNet)
+TEST_CASE(AlexNet, framework::DatasetMode::PRECOMMIT)
{
// Compute alexnet
- std::vector<unsigned int> classified_labels = compute_alexnet(1, "cnn_data/imagenet_data/shark.dat");
+ std::vector<unsigned int> classified_labels = compute_alexnet(DataType::F32, 1, "cnn_data/imagenet_data/cat.npy");
// Expected labels
- std::vector<unsigned int> expected_labels = { 2 };
+ std::vector<unsigned int> expected_labels = { 281 };
// Validate labels
validate(classified_labels, expected_labels);
}
-BOOST_AUTO_TEST_SUITE_END()
-BOOST_AUTO_TEST_SUITE_END()
-#endif /* DOXYGEN_SKIP_THIS */
+TEST_SUITE_END()
+TEST_SUITE_END()
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
#endif /* INTERNAL_ONLY */
diff --git a/tests/validation_old/system_tests/NEON/LeNet5.cpp b/tests/validation/NEON/SYSTEM/LeNet5.cpp
index 916df98eb6..1642de8a97 100644
--- a/tests/validation_old/system_tests/NEON/LeNet5.cpp
+++ b/tests/validation/NEON/SYSTEM/LeNet5.cpp
@@ -22,47 +22,30 @@
* SOFTWARE.
*/
#ifdef INTERNAL_ONLY //FIXME Delete this file before the release
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#include "NEON/Accessor.h"
-#include "tests/validation_old/Validation.h"
#include "arm_compute/runtime/NEON/functions/NEActivationLayer.h"
#include "arm_compute/runtime/NEON/functions/NEConvolutionLayer.h"
#include "arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h"
#include "arm_compute/runtime/NEON/functions/NEPoolingLayer.h"
#include "arm_compute/runtime/NEON/functions/NESoftmaxLayer.h"
+#include "tests/NEON/Accessor.h"
+#include "tests/framework/Asserts.h"
+#include "tests/framework/Macros.h"
+#include "tests/networks/LeNet5Network.h"
+#include "tests/validation/Validation.h"
-#include "tests/validation_old/model_objects/LeNet5.h"
-
-using namespace arm_compute;
-using namespace arm_compute::test;
-using namespace arm_compute::test::validation;
+#include <string>
+#include <vector>
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
namespace
{
-using NELeNet5Model = model_objects::LeNet5<Tensor,
+using NELeNet5Model = networks::LeNet5Network<Tensor,
Accessor,
NEActivationLayer,
NEConvolutionLayer,
@@ -71,19 +54,21 @@ using NELeNet5Model = model_objects::LeNet5<Tensor,
NESoftmaxLayer>;
std::vector<unsigned int> compute_lenet5(unsigned int batches, std::string input_file)
{
- std::vector<std::string> weight_files = { "cnn_data/lenet_model/conv1_w.dat",
- "cnn_data/lenet_model/conv2_w.dat",
- "cnn_data/lenet_model/ip1_w.dat",
- "cnn_data/lenet_model/ip2_w.dat"
+ std::vector<std::string> weight_files = { "cnn_data/lenet_model/conv1_w.npy",
+ "cnn_data/lenet_model/conv2_w.npy",
+ "cnn_data/lenet_model/ip1_w.npy",
+ "cnn_data/lenet_model/ip2_w.npy"
};
- std::vector<std::string> bias_files = { "cnn_data/lenet_model/conv1_b.dat",
- "cnn_data/lenet_model/conv2_b.dat",
- "cnn_data/lenet_model/ip1_b.dat",
- "cnn_data/lenet_model/ip2_b.dat"
+ std::vector<std::string> bias_files = { "cnn_data/lenet_model/conv1_b.npy",
+ "cnn_data/lenet_model/conv2_b.npy",
+ "cnn_data/lenet_model/ip1_b.npy",
+ "cnn_data/lenet_model/ip2_b.npy"
};
NELeNet5Model network{};
- network.build(batches);
+ network.init(batches);
+ network.build();
+ network.allocate();
network.fill(weight_files, bias_files);
network.feed(std::move(input_file));
network.run();
@@ -92,15 +77,13 @@ std::vector<unsigned int> compute_lenet5(unsigned int batches, std::string input
}
} // namespace
-#ifndef DOXYGEN_SKIP_THIS
-BOOST_AUTO_TEST_SUITE(SYSTEM_TESTS)
-BOOST_AUTO_TEST_SUITE(NEON)
+TEST_SUITE(NEON)
+TEST_SUITE(SYSTEM_TESTS)
-BOOST_TEST_DECORATOR(*boost::unit_test::label("precommit"))
-BOOST_AUTO_TEST_CASE(LeNet5)
+TEST_CASE(LeNet5, framework::DatasetMode::PRECOMMIT)
{
// Compute alexnet
- std::vector<unsigned int> classified_labels = compute_lenet5(10, "cnn_data/mnist_data/input100.dat");
+ std::vector<unsigned int> classified_labels = compute_lenet5(10, "cnn_data/mnist_data/input10.npy");
// Expected labels
std::vector<unsigned int> expected_labels = { 7, 2, 1, 0, 4, 1, 4, 9, 5, 9 };
@@ -109,7 +92,9 @@ BOOST_AUTO_TEST_CASE(LeNet5)
validate(classified_labels, expected_labels);
}
-BOOST_AUTO_TEST_SUITE_END()
-BOOST_AUTO_TEST_SUITE_END()
-#endif /* DOXYGEN_SKIP_THIS */
+TEST_SUITE_END()
+TEST_SUITE_END()
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
#endif /* INTERNAL_ONLY */
diff --git a/tests/validation_old/model_objects/AlexNet.h b/tests/validation_old/model_objects/AlexNet.h
deleted file mode 100644
index 45622e2118..0000000000
--- a/tests/validation_old/model_objects/AlexNet.h
+++ /dev/null
@@ -1,585 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#ifndef __ARM_COMPUTE_TEST_MODEL_OBJECTS_ALEXNET_H__
-#define __ARM_COMPUTE_TEST_MODEL_OBJECTS_ALEXNET_H__
-
-#include "arm_compute/runtime/Tensor.h"
-
-#include "tests/AssetsLibrary.h"
-#include "tests/Globals.h"
-#include "tests/Utils.h"
-
-#include <memory>
-
-using namespace arm_compute;
-using namespace arm_compute::test;
-
-namespace arm_compute
-{
-namespace test
-{
-namespace model_objects
-{
-/** AlexNet model object */
-template <typename ITensorType,
- typename TensorType,
- typename SubTensorType,
- typename Accessor,
- typename ActivationLayerFunction,
- typename ConvolutionLayerFunction,
- typename FullyConnectedLayerFunction,
- typename NormalizationLayerFunction,
- typename PoolingLayerFunction,
- typename SoftmaxLayerFunction,
- DataType dt = DataType::F32,
- int fixed_point_position = 4>
-class AlexNet
-{
-public:
- AlexNet()
- : _batches(1), _reshaped_weights(false)
- {
- }
-
- void init_weights(unsigned int batches, bool reshaped_weights = false)
- {
- _batches = batches;
- _reshaped_weights = reshaped_weights;
-
- // Initialize weights and biases
- if(!_reshaped_weights)
- {
- for(auto &wi : w)
- {
- wi = std::unique_ptr<TensorType>(new TensorType());
- }
- for(auto &bi : b)
- {
- bi = std::unique_ptr<TensorType>(new TensorType());
- }
- w[0]->allocator()->init(TensorInfo(TensorShape(11U, 11U, 3U, 96U), 1, dt, fixed_point_position));
- b[0]->allocator()->init(TensorInfo(TensorShape(96U), 1, dt, fixed_point_position));
- w[1]->allocator()->init(TensorInfo(TensorShape(5U, 5U, 48U, 256U), 1, dt, fixed_point_position));
- b[1]->allocator()->init(TensorInfo(TensorShape(256U), 1, dt, fixed_point_position));
- w[2]->allocator()->init(TensorInfo(TensorShape(3U, 3U, 256U, 384U), 1, dt, fixed_point_position));
- b[2]->allocator()->init(TensorInfo(TensorShape(384U), 1, dt, fixed_point_position));
- w[3]->allocator()->init(TensorInfo(TensorShape(3U, 3U, 192U, 384U), 1, dt, fixed_point_position));
- b[3]->allocator()->init(TensorInfo(TensorShape(384U), 1, dt, fixed_point_position));
- w[4]->allocator()->init(TensorInfo(TensorShape(3U, 3U, 192U, 256U), 1, dt, fixed_point_position));
- b[4]->allocator()->init(TensorInfo(TensorShape(256U), 1, dt, fixed_point_position));
- w[5]->allocator()->init(TensorInfo(TensorShape(9216U, 4096U), 1, dt, fixed_point_position));
- b[5]->allocator()->init(TensorInfo(TensorShape(4096U), 1, dt, fixed_point_position));
- w[6]->allocator()->init(TensorInfo(TensorShape(4096U, 4096U), 1, dt, fixed_point_position));
- b[6]->allocator()->init(TensorInfo(TensorShape(4096U), 1, dt, fixed_point_position));
- w[7]->allocator()->init(TensorInfo(TensorShape(4096U, 1000U), 1, dt, fixed_point_position));
- b[7]->allocator()->init(TensorInfo(TensorShape(1000U), 1, dt, fixed_point_position));
-
- w21 = std::unique_ptr<SubTensorType>(new SubTensorType(w[1].get(), TensorShape(5U, 5U, 48U, 128U), Coordinates()));
- w22 = std::unique_ptr<SubTensorType>(new SubTensorType(w[1].get(), TensorShape(5U, 5U, 48U, 128U), Coordinates(0, 0, 0, 128)));
- b21 = std::unique_ptr<SubTensorType>(new SubTensorType(b[1].get(), TensorShape(128U), Coordinates()));
- b22 = std::unique_ptr<SubTensorType>(new SubTensorType(b[1].get(), TensorShape(128U), Coordinates(128)));
-
- w41 = std::unique_ptr<SubTensorType>(new SubTensorType(w[3].get(), TensorShape(3U, 3U, 192U, 192U), Coordinates()));
- w42 = std::unique_ptr<SubTensorType>(new SubTensorType(w[3].get(), TensorShape(3U, 3U, 192U, 192U), Coordinates(0, 0, 0, 192)));
- b41 = std::unique_ptr<SubTensorType>(new SubTensorType(b[3].get(), TensorShape(192U), Coordinates()));
- b42 = std::unique_ptr<SubTensorType>(new SubTensorType(b[3].get(), TensorShape(192U), Coordinates(192)));
-
- w51 = std::unique_ptr<SubTensorType>(new SubTensorType(w[4].get(), TensorShape(3U, 3U, 192U, 128U), Coordinates()));
- w52 = std::unique_ptr<SubTensorType>(new SubTensorType(w[4].get(), TensorShape(3U, 3U, 192U, 128U), Coordinates(0, 0, 0, 128)));
- b51 = std::unique_ptr<SubTensorType>(new SubTensorType(b[4].get(), TensorShape(128U), Coordinates()));
- b52 = std::unique_ptr<SubTensorType>(new SubTensorType(b[4].get(), TensorShape(128U), Coordinates(128)));
- }
- else
- {
- const unsigned int dt_size = 16 / arm_compute::data_size_from_type(dt);
-
- // Create tensor for the reshaped weights
- w[0] = std::unique_ptr<TensorType>(new TensorType());
- auto w21_tensor = std::unique_ptr<TensorType>(new TensorType());
- auto w22_tensor = std::unique_ptr<TensorType>(new TensorType());
- w[2] = std::unique_ptr<TensorType>(new TensorType());
- auto w41_tensor = std::unique_ptr<TensorType>(new TensorType());
- auto w42_tensor = std::unique_ptr<TensorType>(new TensorType());
- auto w51_tensor = std::unique_ptr<TensorType>(new TensorType());
- auto w52_tensor = std::unique_ptr<TensorType>(new TensorType());
-
- w[0]->allocator()->init(TensorInfo(TensorShape(366U * dt_size, 96U / dt_size), 1, dt, fixed_point_position));
- w21_tensor->allocator()->init(TensorInfo(TensorShape(1248U * dt_size, 128U / dt_size), 1, dt, fixed_point_position));
- w22_tensor->allocator()->init(TensorInfo(TensorShape(1248U * dt_size, 128U / dt_size), 1, dt, fixed_point_position));
- w[2]->allocator()->init(TensorInfo(TensorShape(2560U * dt_size, 384U / dt_size), 1, dt, fixed_point_position));
- w41_tensor->allocator()->init(TensorInfo(TensorShape(1920U * dt_size, 192U / dt_size), 1, dt, fixed_point_position));
- w42_tensor->allocator()->init(TensorInfo(TensorShape(1920U * dt_size, 192U / dt_size), 1, dt, fixed_point_position));
- w51_tensor->allocator()->init(TensorInfo(TensorShape(1920U * dt_size, 128U / dt_size), 1, dt, fixed_point_position));
- w52_tensor->allocator()->init(TensorInfo(TensorShape(1920U * dt_size, 128U / dt_size), 1, dt, fixed_point_position));
-
- w21 = std::move(w21_tensor);
- w22 = std::move(w22_tensor);
- w41 = std::move(w41_tensor);
- w42 = std::move(w42_tensor);
- w51 = std::move(w51_tensor);
- w52 = std::move(w52_tensor);
-
- w[5] = std::unique_ptr<TensorType>(new TensorType());
- w[6] = std::unique_ptr<TensorType>(new TensorType());
- w[7] = std::unique_ptr<TensorType>(new TensorType());
- b[5] = std::unique_ptr<TensorType>(new TensorType());
- b[6] = std::unique_ptr<TensorType>(new TensorType());
- b[7] = std::unique_ptr<TensorType>(new TensorType());
-
- b[5]->allocator()->init(TensorInfo(TensorShape(4096U), 1, dt, fixed_point_position));
- b[6]->allocator()->init(TensorInfo(TensorShape(4096U), 1, dt, fixed_point_position));
- b[7]->allocator()->init(TensorInfo(TensorShape(1000U), 1, dt, fixed_point_position));
-
- if(_batches > 1 && std::is_same<TensorType, Tensor>::value)
- {
- w[5]->allocator()->init(TensorInfo(TensorShape(9216U * dt_size, 4096U / dt_size), 1, dt, fixed_point_position));
- w[6]->allocator()->init(TensorInfo(TensorShape(4096U * dt_size, 4096U / dt_size), 1, dt, fixed_point_position));
- w[7]->allocator()->init(TensorInfo(TensorShape(4096U * dt_size, 1000U / dt_size), 1, dt, fixed_point_position));
- }
- else
- {
- w[5]->allocator()->init(TensorInfo(TensorShape(4096U, 9216U), 1, dt, fixed_point_position));
- w[6]->allocator()->init(TensorInfo(TensorShape(4096U, 4096U), 1, dt, fixed_point_position));
- w[7]->allocator()->init(TensorInfo(TensorShape(1000U, 4096U), 1, dt, fixed_point_position));
- }
- }
- }
-
- void build()
- {
- input.allocator()->init(TensorInfo(TensorShape(227U, 227U, 3U, _batches), 1, dt, fixed_point_position));
- output.allocator()->init(TensorInfo(TensorShape(1000U, _batches), 1, dt, fixed_point_position));
-
- // Initialize intermediate tensors
- // Layer 1
- conv1_out.allocator()->init(TensorInfo(TensorShape(55U, 55U, 96U, _batches), 1, dt, fixed_point_position));
- act1_out.allocator()->init(TensorInfo(TensorShape(55U, 55U, 96U, _batches), 1, dt, fixed_point_position));
- norm1_out.allocator()->init(TensorInfo(TensorShape(55U, 55U, 96U, _batches), 1, dt, fixed_point_position));
- pool1_out.allocator()->init(TensorInfo(TensorShape(27U, 27U, 96U, _batches), 1, dt, fixed_point_position));
- pool11_out = std::unique_ptr<SubTensorType>(new SubTensorType(&pool1_out, TensorShape(27U, 27U, 48U, _batches), Coordinates()));
- pool12_out = std::unique_ptr<SubTensorType>(new SubTensorType(&pool1_out, TensorShape(27U, 27U, 48U, _batches), Coordinates(0, 0, 48)));
- // Layer 2
- conv2_out.allocator()->init(TensorInfo(TensorShape(27U, 27U, 256U, _batches), 1, dt, fixed_point_position));
- conv21_out = std::unique_ptr<SubTensorType>(new SubTensorType(&conv2_out, TensorShape(27U, 27U, 128U, _batches), Coordinates()));
- conv22_out = std::unique_ptr<SubTensorType>(new SubTensorType(&conv2_out, TensorShape(27U, 27U, 128U, _batches), Coordinates(0, 0, 128)));
- act2_out.allocator()->init(TensorInfo(TensorShape(27U, 27U, 256U, _batches), 1, dt, fixed_point_position));
- norm2_out.allocator()->init(TensorInfo(TensorShape(27U, 27U, 256U, _batches), 1, dt, fixed_point_position));
- pool2_out.allocator()->init(TensorInfo(TensorShape(13U, 13U, 256U, _batches), 1, dt, fixed_point_position));
- // Layer 3
- conv3_out.allocator()->init(TensorInfo(TensorShape(13U, 13U, 384U, _batches), 1, dt, fixed_point_position));
- act3_out.allocator()->init(TensorInfo(TensorShape(13U, 13U, 384U, _batches), 1, dt, fixed_point_position));
- act31_out = std::unique_ptr<SubTensorType>(new SubTensorType(&act3_out, TensorShape(13U, 13U, 192U, _batches), Coordinates()));
- act32_out = std::unique_ptr<SubTensorType>(new SubTensorType(&act3_out, TensorShape(13U, 13U, 192U, _batches), Coordinates(0, 0, 192)));
- // Layer 4
- conv4_out.allocator()->init(TensorInfo(TensorShape(13U, 13U, 384U, _batches), 1, dt, fixed_point_position));
- conv41_out = std::unique_ptr<SubTensorType>(new SubTensorType(&conv4_out, TensorShape(13U, 13U, 192U, _batches), Coordinates()));
- conv42_out = std::unique_ptr<SubTensorType>(new SubTensorType(&conv4_out, TensorShape(13U, 13U, 192U, _batches), Coordinates(0, 0, 192)));
- act4_out.allocator()->init(TensorInfo(TensorShape(13U, 13U, 384U, _batches), 1, dt, fixed_point_position));
- act41_out = std::unique_ptr<SubTensorType>(new SubTensorType(&act4_out, TensorShape(13U, 13U, 192U, _batches), Coordinates()));
- act42_out = std::unique_ptr<SubTensorType>(new SubTensorType(&act4_out, TensorShape(13U, 13U, 192U, _batches), Coordinates(0, 0, 192)));
- // Layer 5
- conv5_out.allocator()->init(TensorInfo(TensorShape(13U, 13U, 256U, _batches), 1, dt, fixed_point_position));
- conv51_out = std::unique_ptr<SubTensorType>(new SubTensorType(&conv5_out, TensorShape(13U, 13U, 128U, _batches), Coordinates()));
- conv52_out = std::unique_ptr<SubTensorType>(new SubTensorType(&conv5_out, TensorShape(13U, 13U, 128U, _batches), Coordinates(0, 0, 128)));
- act5_out.allocator()->init(TensorInfo(TensorShape(13U, 13U, 256U, _batches), 1, dt, fixed_point_position));
- pool5_out.allocator()->init(TensorInfo(TensorShape(6U, 6U, 256U, _batches), 1, dt, fixed_point_position));
- // Layer 6
- fc6_out.allocator()->init(TensorInfo(TensorShape(4096U, _batches), 1, dt, fixed_point_position));
- act6_out.allocator()->init(TensorInfo(TensorShape(4096U, _batches), 1, dt, fixed_point_position));
- // Layer 7
- fc7_out.allocator()->init(TensorInfo(TensorShape(4096U, _batches), 1, dt, fixed_point_position));
- act7_out.allocator()->init(TensorInfo(TensorShape(4096U, _batches), 1, dt, fixed_point_position));
- // Layer 8
- fc8_out.allocator()->init(TensorInfo(TensorShape(1000U, _batches), 1, dt, fixed_point_position));
-
- // Allocate layers
- {
- // Layer 1
- conv1 = std::unique_ptr<ConvolutionLayerFunction>(new ConvolutionLayerFunction());
- act1 = std::unique_ptr<ActivationLayerFunction>(new ActivationLayerFunction());
- norm1 = std::unique_ptr<NormalizationLayerFunction>(new NormalizationLayerFunction());
- pool1 = std::unique_ptr<PoolingLayerFunction>(new PoolingLayerFunction());
- // Layer 2
- conv21 = std::unique_ptr<ConvolutionLayerFunction>(new ConvolutionLayerFunction());
- conv22 = std::unique_ptr<ConvolutionLayerFunction>(new ConvolutionLayerFunction());
- act2 = std::unique_ptr<ActivationLayerFunction>(new ActivationLayerFunction());
- norm2 = std::unique_ptr<NormalizationLayerFunction>(new NormalizationLayerFunction());
- pool2 = std::unique_ptr<PoolingLayerFunction>(new PoolingLayerFunction());
- // Layer 3
- conv3 = std::unique_ptr<ConvolutionLayerFunction>(new ConvolutionLayerFunction());
- act3 = std::unique_ptr<ActivationLayerFunction>(new ActivationLayerFunction());
- // Layer 4
- conv41 = std::unique_ptr<ConvolutionLayerFunction>(new ConvolutionLayerFunction());
- conv42 = std::unique_ptr<ConvolutionLayerFunction>(new ConvolutionLayerFunction());
- act4 = std::unique_ptr<ActivationLayerFunction>(new ActivationLayerFunction());
- // Layer 5
- conv51 = std::unique_ptr<ConvolutionLayerFunction>(new ConvolutionLayerFunction());
- conv52 = std::unique_ptr<ConvolutionLayerFunction>(new ConvolutionLayerFunction());
- act5 = std::unique_ptr<ActivationLayerFunction>(new ActivationLayerFunction());
- pool5 = std::unique_ptr<PoolingLayerFunction>(new PoolingLayerFunction());
- // Layer 6
- fc6 = std::unique_ptr<FullyConnectedLayerFunction>(new FullyConnectedLayerFunction());
- act6 = std::unique_ptr<ActivationLayerFunction>(new ActivationLayerFunction());
- // Layer 7
- fc7 = std::unique_ptr<FullyConnectedLayerFunction>(new FullyConnectedLayerFunction());
- act7 = std::unique_ptr<ActivationLayerFunction>(new ActivationLayerFunction());
- // Layer 8
- fc8 = std::unique_ptr<FullyConnectedLayerFunction>(new FullyConnectedLayerFunction());
- // Softmax
- smx = std::unique_ptr<SoftmaxLayerFunction>(new SoftmaxLayerFunction());
- }
-
- // Configure Layers
- {
- // Layer 1
- conv1->configure(&input, w[0].get(), b[0].get(), &conv1_out, PadStrideInfo(4, 4, 0, 0), WeightsInfo(_reshaped_weights, 11U, 11U, 96U));
- act1->configure(&conv1_out, &act1_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU));
- norm1->configure(&act1_out, &norm1_out, NormalizationLayerInfo(NormType::CROSS_MAP, 5, 0.0001f, 0.75f));
- pool1->configure(&norm1_out, &pool1_out, PoolingLayerInfo(PoolingType::MAX, 3, PadStrideInfo(2, 2, 0, 0)));
- // Layer 2
- conv21->configure(pool11_out.get(), w21.get(), b21.get(), conv21_out.get(), PadStrideInfo(1, 1, 2, 2), WeightsInfo(_reshaped_weights, 5U, 5U, 128U));
- conv22->configure(pool12_out.get(), w22.get(), b22.get(), conv22_out.get(), PadStrideInfo(1, 1, 2, 2), WeightsInfo(_reshaped_weights, 5U, 5U, 128U));
- act2->configure(&conv2_out, &act2_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU));
- norm2->configure(&act2_out, &norm2_out, NormalizationLayerInfo(NormType::CROSS_MAP, 5, 0.0001f, 0.75f));
- pool2->configure(&norm2_out, &pool2_out, PoolingLayerInfo(PoolingType::MAX, 3, PadStrideInfo(2, 2, 0, 0)));
- // Layer 3
- conv3->configure(&pool2_out, w[2].get(), b[2].get(), &conv3_out, PadStrideInfo(1, 1, 1, 1), WeightsInfo(_reshaped_weights, 3U, 3U, 384U));
- act3->configure(&conv3_out, &act3_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU));
- // Layer 4
- conv41->configure(act31_out.get(), w41.get(), b41.get(), conv41_out.get(), PadStrideInfo(1, 1, 1, 1), WeightsInfo(_reshaped_weights, 3U, 3U, 192U));
- conv42->configure(act32_out.get(), w42.get(), b42.get(), conv42_out.get(), PadStrideInfo(1, 1, 1, 1), WeightsInfo(_reshaped_weights, 3U, 3U, 192U));
- act4->configure(&conv4_out, &act4_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU));
- // Layer 5
- conv51->configure(act41_out.get(), w51.get(), b51.get(), conv51_out.get(), PadStrideInfo(1, 1, 1, 1), WeightsInfo(_reshaped_weights, 3U, 3U, 128U));
- conv52->configure(act42_out.get(), w52.get(), b52.get(), conv52_out.get(), PadStrideInfo(1, 1, 1, 1), WeightsInfo(_reshaped_weights, 3U, 3U, 128U));
- act5->configure(&conv5_out, &act5_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU));
- pool5->configure(&act5_out, &pool5_out, PoolingLayerInfo(PoolingType::MAX, 3, PadStrideInfo(2, 2, 0, 0)));
- // Layer 6
- fc6->configure(&pool5_out, w[5].get(), b[5].get(), &fc6_out, true, _reshaped_weights);
- act6->configure(&fc6_out, &act6_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU));
- // Layer 7
- fc7->configure(&act6_out, w[6].get(), b[6].get(), &fc7_out, true, _reshaped_weights);
- act7->configure(&fc7_out, &act7_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU));
- // Layer 8
- fc8->configure(&act7_out, w[7].get(), b[7].get(), &fc8_out, true, _reshaped_weights);
- // Softmax
- smx->configure(&fc8_out, &output);
- }
- }
-
- void allocate()
- {
- input.allocator()->allocate();
- output.allocator()->allocate();
- for(auto &wi : w)
- {
- if(wi.get())
- {
- wi->allocator()->allocate();
- }
- }
- for(auto &bi : b)
- {
- if(bi.get())
- {
- bi->allocator()->allocate();
- }
- }
- if(_reshaped_weights)
- {
- dynamic_cast<TensorType *>(w21.get())->allocator()->allocate();
- dynamic_cast<TensorType *>(w22.get())->allocator()->allocate();
- dynamic_cast<TensorType *>(w41.get())->allocator()->allocate();
- dynamic_cast<TensorType *>(w42.get())->allocator()->allocate();
- dynamic_cast<TensorType *>(w51.get())->allocator()->allocate();
- dynamic_cast<TensorType *>(w52.get())->allocator()->allocate();
- }
- conv1_out.allocator()->allocate();
- act1_out.allocator()->allocate();
- norm1_out.allocator()->allocate();
- pool1_out.allocator()->allocate();
- conv2_out.allocator()->allocate();
- act2_out.allocator()->allocate();
- norm2_out.allocator()->allocate();
- pool2_out.allocator()->allocate();
- conv3_out.allocator()->allocate();
- act3_out.allocator()->allocate();
- conv4_out.allocator()->allocate();
- act4_out.allocator()->allocate();
- conv5_out.allocator()->allocate();
- act5_out.allocator()->allocate();
- pool5_out.allocator()->allocate();
- fc6_out.allocator()->allocate();
- act6_out.allocator()->allocate();
- fc7_out.allocator()->allocate();
- act7_out.allocator()->allocate();
- fc8_out.allocator()->allocate();
- }
-
- /** Fills the trainable parameters and input with random data. */
- void fill_random()
- {
- library->fill_tensor_uniform(Accessor(input), 0);
- if(!_reshaped_weights)
- {
- for(unsigned int i = 0; i < w.size(); ++i)
- {
- library->fill_tensor_uniform(Accessor(*w[i]), i + 1);
- library->fill_tensor_uniform(Accessor(*b[i]), i + 10);
- }
- }
- else
- {
- library->fill_tensor_uniform(Accessor(*w[0]), 1);
- library->fill_tensor_uniform(Accessor(*w[2]), 2);
-
- library->fill_tensor_uniform(Accessor(*w[5]), 3);
- library->fill_tensor_uniform(Accessor(*b[5]), 4);
- library->fill_tensor_uniform(Accessor(*w[6]), 5);
- library->fill_tensor_uniform(Accessor(*b[6]), 6);
- library->fill_tensor_uniform(Accessor(*w[7]), 7);
- library->fill_tensor_uniform(Accessor(*b[7]), 8);
-
- library->fill_tensor_uniform(Accessor(*dynamic_cast<TensorType *>(w21.get())), 9);
- library->fill_tensor_uniform(Accessor(*dynamic_cast<TensorType *>(w22.get())), 10);
- library->fill_tensor_uniform(Accessor(*dynamic_cast<TensorType *>(w41.get())), 11);
- library->fill_tensor_uniform(Accessor(*dynamic_cast<TensorType *>(w42.get())), 12);
- library->fill_tensor_uniform(Accessor(*dynamic_cast<TensorType *>(w51.get())), 13);
- library->fill_tensor_uniform(Accessor(*dynamic_cast<TensorType *>(w52.get())), 14);
- }
- }
-
-#ifdef INTERNAL_ONLY
- /** Fills the trainable parameters from binary files
- *
- * @param weights Files names containing the weights data
- * @param biases Files names containing the bias data
- */
- void fill(std::vector<std::string> weights, std::vector<std::string> biases)
- {
- ARM_COMPUTE_ERROR_ON(weights.size() != w.size());
- ARM_COMPUTE_ERROR_ON(biases.size() != b.size());
- ARM_COMPUTE_ERROR_ON(_reshaped_weights);
-
- for(unsigned int i = 0; i < weights.size(); ++i)
- {
- library->fill_layer_data(Accessor(*w[i]), weights[i]);
- library->fill_layer_data(Accessor(*b[i]), biases[i]);
- }
- }
-
- /** Feed input to network from file.
- *
- * @param name File name of containing the input data.
- */
- void feed(std::string name)
- {
- library->fill_layer_data(Accessor(input), name);
- }
-#endif /* INTERNAL_ONLY */
-
- /** Get the classification results.
- *
- * @return Vector containing the classified labels
- */
- std::vector<unsigned int> get_classifications()
- {
- std::vector<unsigned int> classified_labels;
- Accessor output_accessor(output);
-
- Window window;
- window.set(Window::DimX, Window::Dimension(0, 1, 1));
- for(unsigned int d = 1; d < output_accessor.shape().num_dimensions(); ++d)
- {
- window.set(d, Window::Dimension(0, output_accessor.shape()[d], 1));
- }
-
- execute_window_loop(window, [&](const Coordinates & id)
- {
- int max_idx = 0;
- float val = 0;
- const void *const out_ptr = output_accessor(id);
- for(unsigned int l = 0; l < output_accessor.shape().x(); ++l)
- {
- float curr_val = reinterpret_cast<const float *>(out_ptr)[l];
- if(curr_val > val)
- {
- max_idx = l;
- val = curr_val;
- }
- }
- classified_labels.push_back(max_idx);
- });
- return classified_labels;
- }
-
- /** Clear all allocated memory from the tensor objects */
- void clear()
- {
- conv1.reset();
- act1.reset();
- norm1.reset();
- pool1.reset();
- conv21.reset();
- conv22.reset();
- act2.reset();
- norm2.reset();
- pool2.reset();
- conv3.reset();
- act3.reset();
- conv41.reset();
- conv42.reset();
- act4.reset();
- conv51.reset();
- conv52.reset();
- act5.reset();
- pool5.reset();
- fc6.reset();
- act6.reset();
- fc7.reset();
- act7.reset();
- fc8.reset();
- smx.reset();
-
- // Free allocations
- input.allocator()->free();
- output.allocator()->free();
- for(auto &wi : w)
- {
- wi.reset();
- }
- for(auto &bi : b)
- {
- bi.reset();
- }
-
- w21.reset();
- w22.reset();
- b21.reset();
- b21.reset();
- w41.reset();
- w42.reset();
- b41.reset();
- b42.reset();
- w51.reset();
- w52.reset();
- b51.reset();
- b52.reset();
-
- conv1_out.allocator()->free();
- act1_out.allocator()->free();
- norm1_out.allocator()->free();
- pool1_out.allocator()->free();
- conv2_out.allocator()->free();
- act2_out.allocator()->free();
- norm2_out.allocator()->free();
- pool2_out.allocator()->free();
- conv3_out.allocator()->free();
- act3_out.allocator()->free();
- conv4_out.allocator()->free();
- act4_out.allocator()->free();
- conv5_out.allocator()->free();
- act5_out.allocator()->free();
- pool5_out.allocator()->free();
- fc6_out.allocator()->free();
- act6_out.allocator()->free();
- fc7_out.allocator()->free();
- act7_out.allocator()->free();
- fc8_out.allocator()->free();
- }
-
- /** Runs the model */
- void run()
- {
- // Layer 1
- conv1->run();
- act1->run();
- norm1->run();
- pool1->run();
- // Layer 2
- conv21->run();
- conv22->run();
- act2->run();
- norm2->run();
- pool2->run();
- // Layer 3
- conv3->run();
- act3->run();
- // Layer 4
- conv41->run();
- conv42->run();
- act4->run();
- // Layer 5
- conv51->run();
- conv52->run();
- act5->run();
- pool5->run();
- // Layer 6
- fc6->run();
- act6->run();
- // Layer 7
- fc7->run();
- act7->run();
- // Layer 8
- fc8->run();
- // Softmax
- smx->run();
- }
-
-private:
- unsigned int _batches;
- bool _reshaped_weights;
-
- std::unique_ptr<ActivationLayerFunction> act1{ nullptr }, act2{ nullptr }, act3{ nullptr }, act4{ nullptr }, act5{ nullptr }, act6{ nullptr }, act7{ nullptr };
- std::unique_ptr<ConvolutionLayerFunction> conv1{ nullptr }, conv21{ nullptr }, conv22{ nullptr }, conv3{ nullptr }, conv41{ nullptr }, conv42{ nullptr }, conv51{ nullptr }, conv52{ nullptr };
- std::unique_ptr<FullyConnectedLayerFunction> fc6{ nullptr }, fc7{ nullptr }, fc8{};
- std::unique_ptr<NormalizationLayerFunction> norm1{ nullptr }, norm2{ nullptr };
- std::unique_ptr<PoolingLayerFunction> pool1{ nullptr }, pool2{ nullptr }, pool5{ nullptr };
- std::unique_ptr<SoftmaxLayerFunction> smx{ nullptr };
-
- TensorType input{}, output{};
- std::array<std::unique_ptr<TensorType>, 8> w{}, b{};
- std::unique_ptr<ITensorType> w21{ nullptr }, w22{ nullptr }, b21{ nullptr }, b22{ nullptr };
- std::unique_ptr<ITensorType> w41{ nullptr }, w42{ nullptr }, b41{ nullptr }, b42{ nullptr };
- std::unique_ptr<ITensorType> w51{ nullptr }, w52{ nullptr }, b51{ nullptr }, b52{ nullptr };
-
- TensorType conv1_out{}, act1_out{}, norm1_out{}, pool1_out{};
- TensorType conv2_out{}, act2_out{}, pool2_out{}, norm2_out{};
- TensorType conv3_out{}, act3_out{};
- TensorType conv4_out{}, act4_out{};
- TensorType conv5_out{}, act5_out{}, pool5_out{};
- TensorType fc6_out{}, act6_out{};
- TensorType fc7_out{}, act7_out{};
- TensorType fc8_out{};
-
- std::unique_ptr<SubTensorType> pool11_out{ nullptr }, pool12_out{ nullptr };
- std::unique_ptr<SubTensorType> conv21_out{ nullptr }, conv22_out{ nullptr };
- std::unique_ptr<SubTensorType> act31_out{ nullptr }, act32_out{ nullptr };
- std::unique_ptr<SubTensorType> conv41_out{ nullptr }, conv42_out{ nullptr }, act41_out{ nullptr }, act42_out{ nullptr };
- std::unique_ptr<SubTensorType> conv51_out{ nullptr }, conv52_out{ nullptr };
-};
-} // namespace model_objects
-} // namespace test
-} // namespace arm_compute
-#endif //__ARM_COMPUTE_TEST_MODEL_OBJECTS_ALEXNET_H__
diff --git a/tests/validation_old/model_objects/LeNet5.h b/tests/validation_old/model_objects/LeNet5.h
deleted file mode 100644
index d3e72b0010..0000000000
--- a/tests/validation_old/model_objects/LeNet5.h
+++ /dev/null
@@ -1,278 +0,0 @@
-/*
- * Copyright (c) 2017 ARM Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#ifndef __ARM_COMPUTE_TEST_MODEL_OBJECTS_LENET5_H__
-#define __ARM_COMPUTE_TEST_MODEL_OBJECTS_LENET5_H__
-
-#include "tests/AssetsLibrary.h"
-#include "tests/Globals.h"
-#include "tests/Utils.h"
-
-#include <memory>
-
-using namespace arm_compute;
-using namespace arm_compute::test;
-
-namespace arm_compute
-{
-namespace test
-{
-namespace model_objects
-{
-/** Lenet5 model object */
-template <typename TensorType,
- typename Accessor,
- typename ActivationLayerFunction,
- typename ConvolutionLayerFunction,
- typename FullyConnectedLayerFunction,
- typename PoolingLayerFunction,
- typename SoftmaxLayerFunction>
-class LeNet5
-{
-public:
- /** Initialize and build the model.
- *
- * @param batches Number of batches should handle
- */
- void build(unsigned int batches)
- {
- // Initialize input, output, weights and biases
- input.allocator()->init(TensorInfo(TensorShape(28U, 28U, 1U, batches), 1, DataType::F32));
- output.allocator()->init(TensorInfo(TensorShape(10U, batches), 1, DataType::F32));
- w[0].allocator()->init(TensorInfo(TensorShape(5U, 5U, 1U, 20U), 1, DataType::F32));
- b[0].allocator()->init(TensorInfo(TensorShape(20U), 1, DataType::F32));
- w[1].allocator()->init(TensorInfo(TensorShape(5U, 5U, 20U, 50U), 1, DataType::F32));
- b[1].allocator()->init(TensorInfo(TensorShape(50U), 1, DataType::F32));
- w[2].allocator()->init(TensorInfo(TensorShape(800U, 500U), 1, DataType::F32));
- b[2].allocator()->init(TensorInfo(TensorShape(500U), 1, DataType::F32));
- w[3].allocator()->init(TensorInfo(TensorShape(500U, 10U), 1, DataType::F32));
- b[3].allocator()->init(TensorInfo(TensorShape(10U), 1, DataType::F32));
-
- // Initialize intermediate tensors
- // Layer 1
- conv1_out.allocator()->init(TensorInfo(TensorShape(24U, 24U, 20U, batches), 1, DataType::F32));
- pool1_out.allocator()->init(TensorInfo(TensorShape(12U, 12U, 20U, batches), 1, DataType::F32));
- // Layer 2
- conv2_out.allocator()->init(TensorInfo(TensorShape(8U, 8U, 50U, batches), 1, DataType::F32));
- pool2_out.allocator()->init(TensorInfo(TensorShape(4U, 4U, 50U, batches), 1, DataType::F32));
- // Layer 3
- fc1_out.allocator()->init(TensorInfo(TensorShape(500U, batches), 1, DataType::F32));
- act1_out.allocator()->init(TensorInfo(TensorShape(500U, batches), 1, DataType::F32));
- // Layer 6
- fc2_out.allocator()->init(TensorInfo(TensorShape(10U, batches), 1, DataType::F32));
-
- // Allocate layers
- {
- // Layer 1
- conv1 = std::unique_ptr<ConvolutionLayerFunction>(new ConvolutionLayerFunction());
- pool1 = std::unique_ptr<PoolingLayerFunction>(new PoolingLayerFunction());
- // Layer 2
- conv2 = std::unique_ptr<ConvolutionLayerFunction>(new ConvolutionLayerFunction());
- pool2 = std::unique_ptr<PoolingLayerFunction>(new PoolingLayerFunction());
- // Layer 3
- fc1 = std::unique_ptr<FullyConnectedLayerFunction>(new FullyConnectedLayerFunction());
- act1 = std::unique_ptr<ActivationLayerFunction>(new ActivationLayerFunction());
- // Layer 4
- fc2 = std::unique_ptr<FullyConnectedLayerFunction>(new FullyConnectedLayerFunction());
- // Softmax
- smx = std::unique_ptr<SoftmaxLayerFunction>(new SoftmaxLayerFunction());
- }
-
- // Configure Layers
- {
- conv1->configure(&input, &w[0], &b[0], &conv1_out, PadStrideInfo(1, 1, 0, 0));
- pool1->configure(&conv1_out, &pool1_out, PoolingLayerInfo(PoolingType::MAX, 2, PadStrideInfo(2, 2, 0, 0)));
- conv2->configure(&pool1_out, &w[1], &b[1], &conv2_out, PadStrideInfo(1, 1, 0, 0));
- pool2->configure(&conv2_out, &pool2_out, PoolingLayerInfo(PoolingType::MAX, 2, PadStrideInfo(2, 2, 0, 0)));
- fc1->configure(&pool2_out, &w[2], &b[2], &fc1_out);
- act1->configure(&fc1_out, &act1_out, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU));
- fc2->configure(&act1_out, &w[3], &b[3], &fc2_out);
- smx->configure(&fc2_out, &output);
- }
-
- // Allocate tensors
- {
- input.allocator()->allocate();
- output.allocator()->allocate();
- for(auto &wi : w)
- {
- wi.allocator()->allocate();
- }
- for(auto &bi : b)
- {
- bi.allocator()->allocate();
- }
- conv1_out.allocator()->allocate();
- pool1_out.allocator()->allocate();
- conv2_out.allocator()->allocate();
- pool2_out.allocator()->allocate();
- fc1_out.allocator()->allocate();
- act1_out.allocator()->allocate();
- fc2_out.allocator()->allocate();
- }
- }
-
- /** Fills the trainable parameters and input with random data. */
- void fill_random()
- {
- std::uniform_real_distribution<> distribution(-1, 1);
- library->fill(Accessor(input), distribution, 0);
- for(unsigned int i = 0; i < w.size(); ++i)
- {
- library->fill(Accessor(w[i]), distribution, i + 1);
- library->fill(Accessor(b[i]), distribution, i + 10);
- }
- }
-
-#ifdef INTERNAL_ONLY
- /** Fills the trainable parameters from binary files
- *
- * @param weights Files names containing the weights data
- * @param biases Files names containing the bias data
- */
- void fill(std::vector<std::string> weights, std::vector<std::string> biases)
- {
- ARM_COMPUTE_ERROR_ON(weights.size() != w.size());
- ARM_COMPUTE_ERROR_ON(biases.size() != b.size());
-
- for(unsigned int i = 0; i < weights.size(); ++i)
- {
- library->fill_layer_data(Accessor(w[i]), weights[i]);
- library->fill_layer_data(Accessor(b[i]), biases[i]);
- }
- }
-
- /** Feed input to network from file.
- *
- * @param name File name of containing the input data.
- */
- void feed(std::string name)
- {
- library->fill_layer_data(Accessor(input), name);
- }
-#endif /* INTERNAL_ONLY */
-
- /** Get the classification results.
- *
- * @return Vector containing the classified labels
- */
- std::vector<unsigned int> get_classifications()
- {
- std::vector<unsigned int> classified_labels;
- Accessor output_accessor(output);
-
- Window window;
- window.set(Window::DimX, Window::Dimension(0, 1, 1));
- for(unsigned int d = 1; d < output_accessor.shape().num_dimensions(); ++d)
- {
- window.set(d, Window::Dimension(0, output_accessor.shape()[d], 1));
- }
-
- execute_window_loop(window, [&](const Coordinates & id)
- {
- int max_idx = 0;
- float val = 0;
- const void *const out_ptr = output_accessor(id);
- for(unsigned int l = 0; l < output_accessor.shape().x(); ++l)
- {
- float curr_val = reinterpret_cast<const float *>(out_ptr)[l];
- if(curr_val > val)
- {
- max_idx = l;
- val = curr_val;
- }
- }
- classified_labels.push_back(max_idx);
- });
- return classified_labels;
- }
-
- /** Clear all allocated memory from the tensor objects */
- void clear()
- {
- conv1.reset();
- pool1.reset();
- conv2.reset();
- pool2.reset();
- fc1.reset();
- act1.reset();
- fc2.reset();
- smx.reset();
-
- input.allocator()->free();
- output.allocator()->free();
- for(auto &wi : w)
- {
- wi.allocator()->free();
- }
- for(auto &bi : b)
- {
- bi.allocator()->free();
- }
-
- conv1_out.allocator()->free();
- pool1_out.allocator()->free();
- conv2_out.allocator()->free();
- pool2_out.allocator()->free();
- fc1_out.allocator()->free();
- act1_out.allocator()->free();
- fc2_out.allocator()->free();
- }
-
- /** Runs the model */
- void run()
- {
- // Layer 1
- conv1->run();
- pool1->run();
- // Layer 2
- conv2->run();
- pool2->run();
- // Layer 3
- fc1->run();
- act1->run();
- // Layer 4
- fc2->run();
- // Softmax
- smx->run();
- }
-
-private:
- std::unique_ptr<ActivationLayerFunction> act1{ nullptr };
- std::unique_ptr<ConvolutionLayerFunction> conv1{ nullptr }, conv2{ nullptr };
- std::unique_ptr<FullyConnectedLayerFunction> fc1{ nullptr }, fc2{ nullptr };
- std::unique_ptr<PoolingLayerFunction> pool1{ nullptr }, pool2{ nullptr };
- std::unique_ptr<SoftmaxLayerFunction> smx{ nullptr };
-
- TensorType input{}, output{};
- std::array<TensorType, 4> w{}, b{};
-
- TensorType conv1_out{}, pool1_out{};
- TensorType conv2_out{}, pool2_out{};
- TensorType fc1_out{}, act1_out{};
- TensorType fc2_out{};
-};
-} // namespace model_objects
-} // namespace test
-} // namespace arm_compute
-#endif //__ARM_COMPUTE_TEST_MODEL_OBJECTS_LENET5_H__