aboutsummaryrefslogtreecommitdiff
path: root/examples/graph_shufflenet.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'examples/graph_shufflenet.cpp')
-rw-r--r--examples/graph_shufflenet.cpp13
1 files changed, 7 insertions, 6 deletions
diff --git a/examples/graph_shufflenet.cpp b/examples/graph_shufflenet.cpp
index 0a67f5873a..0b977982b5 100644
--- a/examples/graph_shufflenet.cpp
+++ b/examples/graph_shufflenet.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 ARM Limited.
+ * Copyright (c) 2018-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -81,8 +81,9 @@ public:
}
// Create input descriptor
- const TensorShape tensor_shape = permute_shape(TensorShape(224U, 224U, 3U, 1U), DataLayout::NCHW, common_params.data_layout);
- TensorDescriptor input_descriptor = TensorDescriptor(tensor_shape, common_params.data_type).set_layout(common_params.data_layout);
+ const auto operation_layout = common_params.data_layout;
+ const TensorShape tensor_shape = permute_shape(TensorShape(224U, 224U, 3U, 1U), DataLayout::NCHW, operation_layout);
+ TensorDescriptor input_descriptor = TensorDescriptor(tensor_shape, common_params.data_type).set_layout(operation_layout);
// Set weights trained layout
const DataLayout weights_layout = DataLayout::NCHW;
@@ -107,7 +108,7 @@ public:
1e-5f)
.set_name("Conv1/BatchNorm")
<< ActivationLayer(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU)).set_name("Conv1/Relu")
- << PoolingLayer(PoolingLayerInfo(PoolingType::MAX, 3, PadStrideInfo(2, 2, 1, 1))).set_name("pool1/MaxPool");
+ << PoolingLayer(PoolingLayerInfo(PoolingType::MAX, 3, operation_layout, PadStrideInfo(2, 2, 1, 1))).set_name("pool1/MaxPool");
// Stage 2
add_residual_block(data_path, DataLayout::NCHW, 0U /* unit */, 112U /* depth */, 2U /* stride */);
@@ -131,7 +132,7 @@ public:
add_residual_block(data_path, DataLayout::NCHW, 14U /* unit */, 544U /* depth */, 1U /* stride */);
add_residual_block(data_path, DataLayout::NCHW, 15U /* unit */, 544U /* depth */, 1U /* stride */);
- graph << PoolingLayer(PoolingLayerInfo(PoolingType::AVG)).set_name("predictions/AvgPool")
+ graph << PoolingLayer(PoolingLayerInfo(PoolingType::AVG, operation_layout)).set_name("predictions/AvgPool")
<< FlattenLayer().set_name("predictions/Reshape")
<< FullyConnectedLayer(
1000U,
@@ -181,7 +182,7 @@ private:
if(stride == 2)
{
- right_ss << PoolingLayer(PoolingLayerInfo(PoolingType::AVG, 3, PadStrideInfo(2, 2, 1, 1))).set_name(unit_name + "/pool_1/AveragePool");
+ right_ss << PoolingLayer(PoolingLayerInfo(PoolingType::AVG, 3, common_params.data_layout, PadStrideInfo(2, 2, 1, 1))).set_name(unit_name + "/pool_1/AveragePool");
dwc_info = PadStrideInfo(2, 2, 1, 1);
}