aboutsummaryrefslogtreecommitdiff
path: root/examples/graph_lenet.cpp
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2018-07-17 12:28:42 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:54 +0000
commit7d66a8e3f603f2cd363f04a750847e3f9eabdfd4 (patch)
tree0d7e1ad5bf0ecd32cd919074f756d27c351d7638 /examples/graph_lenet.cpp
parentae54e026c86aec7d6819ee3ef76372c1a3c92467 (diff)
downloadComputeLibrary-7d66a8e3f603f2cd363f04a750847e3f9eabdfd4.tar.gz
COMPMID-1386: Add support for converting weights for CL.
Change-Id: I62e3ead903366baeeb1488f233a9b8b0c388c9de Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/140403 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Giorgio Arena <giorgio.arena@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'examples/graph_lenet.cpp')
-rw-r--r--examples/graph_lenet.cpp19
1 file changed, 13 insertions, 6 deletions
diff --git a/examples/graph_lenet.cpp b/examples/graph_lenet.cpp
index 0d8a943737..f3aa266c50 100644
--- a/examples/graph_lenet.cpp
+++ b/examples/graph_lenet.cpp
@@ -60,7 +60,7 @@ public:
// Checks
ARM_COMPUTE_EXIT_ON_MSG(arm_compute::is_data_type_quantized_asymmetric(common_params.data_type), "Unsupported data type!");
- ARM_COMPUTE_EXIT_ON_MSG(common_params.data_layout == DataLayout::NHWC, "Unsupported data layout!");
+ ARM_COMPUTE_EXIT_ON_MSG(common_params.data_layout == DataLayout::NHWC && common_params.target != Target::CL, "Unsupported data layout!");
// Print parameter values
std::cout << common_params << std::endl;
@@ -69,33 +69,40 @@ public:
std::string data_path = common_params.data_path;
unsigned int batches = 4; /** Number of batches */
+ // Create input descriptor
+ const TensorShape tensor_shape = permute_shape(TensorShape(28U, 28U, 1U, batches), DataLayout::NCHW, common_params.data_layout);
+ TensorDescriptor input_descriptor = TensorDescriptor(tensor_shape, common_params.data_type).set_layout(common_params.data_layout);
+
+ // Set weights trained layout
+ const DataLayout weights_layout = DataLayout::NCHW;
+
//conv1 << pool1 << conv2 << pool2 << fc1 << act1 << fc2 << smx
graph << common_params.target
<< common_params.fast_math_hint
- << InputLayer(TensorDescriptor(TensorShape(28U, 28U, 1U, batches), common_params.data_type), get_input_accessor(common_params))
+ << InputLayer(input_descriptor, get_input_accessor(common_params))
<< ConvolutionLayer(
5U, 5U, 20U,
- get_weights_accessor(data_path, "/cnn_data/lenet_model/conv1_w.npy"),
+ get_weights_accessor(data_path, "/cnn_data/lenet_model/conv1_w.npy", weights_layout),
get_weights_accessor(data_path, "/cnn_data/lenet_model/conv1_b.npy"),
PadStrideInfo(1, 1, 0, 0))
.set_name("conv1")
<< PoolingLayer(PoolingLayerInfo(PoolingType::MAX, 2, PadStrideInfo(2, 2, 0, 0))).set_name("pool1")
<< ConvolutionLayer(
5U, 5U, 50U,
- get_weights_accessor(data_path, "/cnn_data/lenet_model/conv2_w.npy"),
+ get_weights_accessor(data_path, "/cnn_data/lenet_model/conv2_w.npy", weights_layout),
get_weights_accessor(data_path, "/cnn_data/lenet_model/conv2_b.npy"),
PadStrideInfo(1, 1, 0, 0))
.set_name("conv2")
<< PoolingLayer(PoolingLayerInfo(PoolingType::MAX, 2, PadStrideInfo(2, 2, 0, 0))).set_name("pool2")
<< FullyConnectedLayer(
500U,
- get_weights_accessor(data_path, "/cnn_data/lenet_model/ip1_w.npy"),
+ get_weights_accessor(data_path, "/cnn_data/lenet_model/ip1_w.npy", weights_layout),
get_weights_accessor(data_path, "/cnn_data/lenet_model/ip1_b.npy"))
.set_name("ip1")
<< ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)).set_name("relu")
<< FullyConnectedLayer(
10U,
- get_weights_accessor(data_path, "/cnn_data/lenet_model/ip2_w.npy"),
+ get_weights_accessor(data_path, "/cnn_data/lenet_model/ip2_w.npy", weights_layout),
get_weights_accessor(data_path, "/cnn_data/lenet_model/ip2_b.npy"))
.set_name("ip2")
<< SoftmaxLayer().set_name("prob")