aboutsummaryrefslogtreecommitdiff
path: root/examples/graph_vgg16.cpp
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2018-07-17 12:28:42 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:54 +0000
commit7d66a8e3f603f2cd363f04a750847e3f9eabdfd4 (patch)
tree0d7e1ad5bf0ecd32cd919074f756d27c351d7638 /examples/graph_vgg16.cpp
parentae54e026c86aec7d6819ee3ef76372c1a3c92467 (diff)
downloadComputeLibrary-7d66a8e3f603f2cd363f04a750847e3f9eabdfd4.tar.gz
COMPMID-1386: Add support for converting weights for CL.
Change-Id: I62e3ead903366baeeb1488f233a9b8b0c388c9de Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/140403 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Giorgio Arena <giorgio.arena@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'examples/graph_vgg16.cpp')
-rw-r--r--examples/graph_vgg16.cpp45
1 files changed, 26 insertions, 19 deletions
diff --git a/examples/graph_vgg16.cpp b/examples/graph_vgg16.cpp
index e677650d04..e23ea65dd7 100644
--- a/examples/graph_vgg16.cpp
+++ b/examples/graph_vgg16.cpp
@@ -60,7 +60,7 @@ public:
// Checks
ARM_COMPUTE_EXIT_ON_MSG(arm_compute::is_data_type_quantized_asymmetric(common_params.data_type), "Unsupported data type!");
- ARM_COMPUTE_EXIT_ON_MSG(common_params.data_layout == DataLayout::NHWC, "Unsupported data layout!");
+ ARM_COMPUTE_EXIT_ON_MSG(common_params.data_layout == DataLayout::NHWC && common_params.target != Target::CL, "Unsupported data layout!");
// Print parameter values
std::cout << common_params << std::endl;
@@ -72,14 +72,21 @@ public:
const std::array<float, 3> mean_rgb{ { 123.68f, 116.779f, 103.939f } };
std::unique_ptr<IPreprocessor> preprocessor = arm_compute::support::cpp14::make_unique<CaffePreproccessor>(mean_rgb);
+ // Create input descriptor
+ const TensorShape tensor_shape = permute_shape(TensorShape(224U, 224U, 3U, 1U), DataLayout::NCHW, common_params.data_layout);
+ TensorDescriptor input_descriptor = TensorDescriptor(tensor_shape, common_params.data_type).set_layout(common_params.data_layout);
+
+ // Set weights trained layout
+ const DataLayout weights_layout = DataLayout::NCHW;
+
+ // Create graph
graph << common_params.target
<< common_params.fast_math_hint
- << InputLayer(TensorDescriptor(TensorShape(224U, 224U, 3U, 1U), common_params.data_type),
- get_input_accessor(common_params, std::move(preprocessor)))
+ << InputLayer(input_descriptor, get_input_accessor(common_params, std::move(preprocessor)))
// Layer 1
<< ConvolutionLayer(
3U, 3U, 64U,
- get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv1_1_w.npy"),
+ get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv1_1_w.npy", weights_layout),
get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv1_1_b.npy"),
PadStrideInfo(1, 1, 1, 1))
.set_name("conv1_1")
@@ -87,7 +94,7 @@ public:
// Layer 2
<< ConvolutionLayer(
3U, 3U, 64U,
- get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv1_2_w.npy"),
+ get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv1_2_w.npy", weights_layout),
get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv1_2_b.npy"),
PadStrideInfo(1, 1, 1, 1))
.set_name("conv1_2")
@@ -96,7 +103,7 @@ public:
// Layer 3
<< ConvolutionLayer(
3U, 3U, 128U,
- get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv2_1_w.npy"),
+ get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv2_1_w.npy", weights_layout),
get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv2_1_b.npy"),
PadStrideInfo(1, 1, 1, 1))
.set_name("conv2_1")
@@ -104,7 +111,7 @@ public:
// Layer 4
<< ConvolutionLayer(
3U, 3U, 128U,
- get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv2_2_w.npy"),
+ get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv2_2_w.npy", weights_layout),
get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv2_2_b.npy"),
PadStrideInfo(1, 1, 1, 1))
.set_name("conv2_2")
@@ -113,7 +120,7 @@ public:
// Layer 5
<< ConvolutionLayer(
3U, 3U, 256U,
- get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv3_1_w.npy"),
+ get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv3_1_w.npy", weights_layout),
get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv3_1_b.npy"),
PadStrideInfo(1, 1, 1, 1))
.set_name("conv3_1")
@@ -121,7 +128,7 @@ public:
// Layer 6
<< ConvolutionLayer(
3U, 3U, 256U,
- get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv3_2_w.npy"),
+ get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv3_2_w.npy", weights_layout),
get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv3_2_b.npy"),
PadStrideInfo(1, 1, 1, 1))
.set_name("conv3_2")
@@ -129,7 +136,7 @@ public:
// Layer 7
<< ConvolutionLayer(
3U, 3U, 256U,
- get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv3_3_w.npy"),
+ get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv3_3_w.npy", weights_layout),
get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv3_3_b.npy"),
PadStrideInfo(1, 1, 1, 1))
.set_name("conv3_3")
@@ -138,7 +145,7 @@ public:
// Layer 8
<< ConvolutionLayer(
3U, 3U, 512U,
- get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv4_1_w.npy"),
+ get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv4_1_w.npy", weights_layout),
get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv4_1_b.npy"),
PadStrideInfo(1, 1, 1, 1))
.set_name("conv4_1")
@@ -146,7 +153,7 @@ public:
// Layer 9
<< ConvolutionLayer(
3U, 3U, 512U,
- get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv4_2_w.npy"),
+ get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv4_2_w.npy", weights_layout),
get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv4_2_b.npy"),
PadStrideInfo(1, 1, 1, 1))
.set_name("conv4_2")
@@ -154,7 +161,7 @@ public:
// Layer 10
<< ConvolutionLayer(
3U, 3U, 512U,
- get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv4_3_w.npy"),
+ get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv4_3_w.npy", weights_layout),
get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv4_3_b.npy"),
PadStrideInfo(1, 1, 1, 1))
.set_name("conv4_3")
@@ -163,7 +170,7 @@ public:
// Layer 11
<< ConvolutionLayer(
3U, 3U, 512U,
- get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv5_1_w.npy"),
+ get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv5_1_w.npy", weights_layout),
get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv5_1_b.npy"),
PadStrideInfo(1, 1, 1, 1))
.set_name("conv5_1")
@@ -171,7 +178,7 @@ public:
// Layer 12
<< ConvolutionLayer(
3U, 3U, 512U,
- get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv5_2_w.npy"),
+ get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv5_2_w.npy", weights_layout),
get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv5_2_b.npy"),
PadStrideInfo(1, 1, 1, 1))
.set_name("conv5_2")
@@ -179,7 +186,7 @@ public:
// Layer 13
<< ConvolutionLayer(
3U, 3U, 512U,
- get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv5_3_w.npy"),
+ get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv5_3_w.npy", weights_layout),
get_weights_accessor(data_path, "/cnn_data/vgg16_model/conv5_3_b.npy"),
PadStrideInfo(1, 1, 1, 1))
.set_name("conv5_3")
@@ -188,21 +195,21 @@ public:
// Layer 14
<< FullyConnectedLayer(
4096U,
- get_weights_accessor(data_path, "/cnn_data/vgg16_model/fc6_w.npy"),
+ get_weights_accessor(data_path, "/cnn_data/vgg16_model/fc6_w.npy", weights_layout),
get_weights_accessor(data_path, "/cnn_data/vgg16_model/fc6_b.npy"))
.set_name("fc6")
<< ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)).set_name("Relu")
// Layer 15
<< FullyConnectedLayer(
4096U,
- get_weights_accessor(data_path, "/cnn_data/vgg16_model/fc7_w.npy"),
+ get_weights_accessor(data_path, "/cnn_data/vgg16_model/fc7_w.npy", weights_layout),
get_weights_accessor(data_path, "/cnn_data/vgg16_model/fc7_b.npy"))
.set_name("fc7")
<< ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)).set_name("Relu_1")
// Layer 16
<< FullyConnectedLayer(
1000U,
- get_weights_accessor(data_path, "/cnn_data/vgg16_model/fc8_w.npy"),
+ get_weights_accessor(data_path, "/cnn_data/vgg16_model/fc8_w.npy", weights_layout),
get_weights_accessor(data_path, "/cnn_data/vgg16_model/fc8_b.npy"))
.set_name("fc8")
// Softmax