aboutsummaryrefslogtreecommitdiff
path: root/examples/graph_shufflenet.cpp
diff options
context:
space:
mode:
authorSang-Hoon Park <sang-hoon.park@arm.com>2020-01-15 14:44:04 +0000
committerGeorgios Pinitas <georgios.pinitas@arm.com>2020-01-28 16:05:27 +0000
commit11fedda86532cf632b9a3ae4b0f57e85f2a7c4f4 (patch)
tree6fd8003a38fe9baa262696754bdd5cb1d1595947 /examples/graph_shufflenet.cpp
parent6c89ffac750010cb9335794defe8a366c04db937 (diff)
downloadComputeLibrary-11fedda86532cf632b9a3ae4b0f57e85f2a7c4f4.tar.gz
COMPMID-2985 add data_layout to PoolingLayerInfo
- use data layout from PoolingLayerInfo if it's available - deprecate constructors without data_layout - (3RDPARTY_UPDATE) modify examples and test suites to give data layout Change-Id: Ie9ae8cc4837c339ff69a16a816110be704863c2d Signed-off-by: Sang-Hoon Park <sang-hoon.park@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2603 Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'examples/graph_shufflenet.cpp')
-rw-r--r--examples/graph_shufflenet.cpp13
1 files changed, 7 insertions, 6 deletions
diff --git a/examples/graph_shufflenet.cpp b/examples/graph_shufflenet.cpp
index 0a67f5873a..0b977982b5 100644
--- a/examples/graph_shufflenet.cpp
+++ b/examples/graph_shufflenet.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 ARM Limited.
+ * Copyright (c) 2018-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -81,8 +81,9 @@ public:
}
// Create input descriptor
- const TensorShape tensor_shape = permute_shape(TensorShape(224U, 224U, 3U, 1U), DataLayout::NCHW, common_params.data_layout);
- TensorDescriptor input_descriptor = TensorDescriptor(tensor_shape, common_params.data_type).set_layout(common_params.data_layout);
+ const auto operation_layout = common_params.data_layout;
+ const TensorShape tensor_shape = permute_shape(TensorShape(224U, 224U, 3U, 1U), DataLayout::NCHW, operation_layout);
+ TensorDescriptor input_descriptor = TensorDescriptor(tensor_shape, common_params.data_type).set_layout(operation_layout);
// Set weights trained layout
const DataLayout weights_layout = DataLayout::NCHW;
@@ -107,7 +108,7 @@ public:
1e-5f)
.set_name("Conv1/BatchNorm")
<< ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)).set_name("Conv1/Relu")
- << PoolingLayer(PoolingLayerInfo(PoolingType::MAX, 3, PadStrideInfo(2, 2, 1, 1))).set_name("pool1/MaxPool");
+ << PoolingLayer(PoolingLayerInfo(PoolingType::MAX, 3, operation_layout, PadStrideInfo(2, 2, 1, 1))).set_name("pool1/MaxPool");
// Stage 2
add_residual_block(data_path, DataLayout::NCHW, 0U /* unit */, 112U /* depth */, 2U /* stride */);
@@ -131,7 +132,7 @@ public:
add_residual_block(data_path, DataLayout::NCHW, 14U /* unit */, 544U /* depth */, 1U /* stride */);
add_residual_block(data_path, DataLayout::NCHW, 15U /* unit */, 544U /* depth */, 1U /* stride */);
- graph << PoolingLayer(PoolingLayerInfo(PoolingType::AVG)).set_name("predictions/AvgPool")
+ graph << PoolingLayer(PoolingLayerInfo(PoolingType::AVG, operation_layout)).set_name("predictions/AvgPool")
<< FlattenLayer().set_name("predictions/Reshape")
<< FullyConnectedLayer(
1000U,
@@ -181,7 +182,7 @@ private:
if(stride == 2)
{
- right_ss << PoolingLayer(PoolingLayerInfo(PoolingType::AVG, 3, PadStrideInfo(2, 2, 1, 1))).set_name(unit_name + "/pool_1/AveragePool");
+ right_ss << PoolingLayer(PoolingLayerInfo(PoolingType::AVG, 3, common_params.data_layout, PadStrideInfo(2, 2, 1, 1))).set_name(unit_name + "/pool_1/AveragePool");
dwc_info = PadStrideInfo(2, 2, 1, 1);
}