aboutsummaryrefslogtreecommitdiff
path: root/examples/graph_edsr.h
diff options
context:
space:
mode:
authorthecha01 <theo.charalambous@arm.com>2020-09-02 16:18:25 +0100
committerMichele Di Giorgio <michele.digiorgio@arm.com>2020-09-04 15:54:27 +0000
commitbe058f4b5fb2383e0fa4666a92c43ca994486c86 (patch)
tree3748a2665d4254a44f3ac905e6935c2829095d2a /examples/graph_edsr.h
parentb3182b19251cd010baad8252e7607de7059ac986 (diff)
downloadComputeLibrary-be058f4b5fb2383e0fa4666a92c43ca994486c86.tar.gz
Use shape broadcast for Mult inputs in EDSR graph
We no longer have to explicitly create a tensor with the correct dimensions for the Const nodes, instead we use the graph API shape propogation logic in EltwiseLayerNode to broadcast the shapes Signed-off-by: thecha01 <theo.charalambous@arm.com> Change-Id: Ifb62b572d6391850d3357cd3307cef7cd9645ee3 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3898 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'examples/graph_edsr.h')
-rw-r--r--examples/graph_edsr.h36
1 files changed, 16 insertions, 20 deletions
diff --git a/examples/graph_edsr.h b/examples/graph_edsr.h
index 42a2789861..72012afdcb 100644
--- a/examples/graph_edsr.h
+++ b/examples/graph_edsr.h
@@ -105,14 +105,10 @@ public:
node_post_residual_FakeQuantWithMinMaxVars->output(0)->set_accessor(get_weights_accessor(data_path, "/cnn_data/edsr_model/post_residual_FakeQuantWithMinMaxVars.npy",
DataLayout::NHWC));
- TensorShape scalar_4d_shape{};
-
- scalar_4d_shape.set(0, 1, false).set(1, 1, false).set(2, 1, false).set(3, 1, false);
-
NodeID id_mul_15_y = _graph.add_node<ConstNode>(
TensorDescriptor
{
- scalar_4d_shape,
+ TensorShape{ 1 },
DataType::QASYMM8,
QuantizationInfo(0.0003921568568330258),
DataLayout::NHWC });
@@ -146,7 +142,7 @@ public:
NodeID id_mul_14_y = _graph.add_node<ConstNode>(
TensorDescriptor
{
- scalar_4d_shape,
+ TensorShape{ 1 },
DataType::QASYMM8,
QuantizationInfo(0.0003921568568330258),
DataLayout::NHWC });
@@ -180,7 +176,7 @@ public:
NodeID id_mul_13_y = _graph.add_node<ConstNode>(
TensorDescriptor
{
- scalar_4d_shape,
+ TensorShape{ 1 },
DataType::QASYMM8,
QuantizationInfo(0.0003921568568330258),
DataLayout::NHWC });
@@ -214,7 +210,7 @@ public:
NodeID id_mul_12_y = _graph.add_node<ConstNode>(
TensorDescriptor
{
- scalar_4d_shape,
+ TensorShape{ 1 },
DataType::QASYMM8,
QuantizationInfo(0.0003921568568330258),
DataLayout::NHWC });
@@ -248,7 +244,7 @@ public:
NodeID id_mul_11_y = _graph.add_node<ConstNode>(
TensorDescriptor
{
- scalar_4d_shape,
+ TensorShape{ 1 },
DataType::QASYMM8,
QuantizationInfo(0.0003921568568330258),
DataLayout::NHWC });
@@ -282,7 +278,7 @@ public:
NodeID id_mul_10_y = _graph.add_node<ConstNode>(
TensorDescriptor
{
- scalar_4d_shape,
+ TensorShape{ 1 },
DataType::QASYMM8,
QuantizationInfo(0.0003921568568330258),
DataLayout::NHWC });
@@ -316,7 +312,7 @@ public:
NodeID id_mul_9_y = _graph.add_node<ConstNode>(
TensorDescriptor
{
- scalar_4d_shape,
+ TensorShape{ 1 },
DataType::QASYMM8,
QuantizationInfo(0.0003921568568330258),
DataLayout::NHWC });
@@ -350,7 +346,7 @@ public:
NodeID id_mul_8_y = _graph.add_node<ConstNode>(
TensorDescriptor
{
- scalar_4d_shape,
+ TensorShape{ 1 },
DataType::QASYMM8,
QuantizationInfo(0.0003921568568330258),
DataLayout::NHWC });
@@ -384,7 +380,7 @@ public:
NodeID id_mul_7_y = _graph.add_node<ConstNode>(
TensorDescriptor
{
- scalar_4d_shape,
+ TensorShape{ 1 },
DataType::QASYMM8,
QuantizationInfo(0.0003921568568330258),
DataLayout::NHWC });
@@ -418,7 +414,7 @@ public:
NodeID id_mul_6_y = _graph.add_node<ConstNode>(
TensorDescriptor
{
- scalar_4d_shape,
+ TensorShape{ 1 },
DataType::QASYMM8,
QuantizationInfo(0.0003921568568330258),
DataLayout::NHWC });
@@ -452,7 +448,7 @@ public:
NodeID id_mul_5_y = _graph.add_node<ConstNode>(
TensorDescriptor
{
- scalar_4d_shape,
+ TensorShape{ 1 },
DataType::QASYMM8,
QuantizationInfo(0.0003921568568330258),
DataLayout::NHWC });
@@ -486,7 +482,7 @@ public:
NodeID id_mul_4_y = _graph.add_node<ConstNode>(
TensorDescriptor
{
- scalar_4d_shape,
+ TensorShape{ 1 },
DataType::QASYMM8,
QuantizationInfo(0.0003921568568330258),
DataLayout::NHWC });
@@ -520,7 +516,7 @@ public:
NodeID id_mul_3_y = _graph.add_node<ConstNode>(
TensorDescriptor
{
- scalar_4d_shape,
+ TensorShape{ 1 },
DataType::QASYMM8,
QuantizationInfo(0.0003921568568330258),
DataLayout::NHWC });
@@ -554,7 +550,7 @@ public:
NodeID id_mul_2_y = _graph.add_node<ConstNode>(
TensorDescriptor
{
- scalar_4d_shape,
+ TensorShape{ 1 },
DataType::QASYMM8,
QuantizationInfo(0.0003921568568330258),
DataLayout::NHWC });
@@ -588,7 +584,7 @@ public:
NodeID id_mul_1_y = _graph.add_node<ConstNode>(
TensorDescriptor
{
- scalar_4d_shape,
+ TensorShape{ 1 },
DataType::QASYMM8,
QuantizationInfo(0.0003921568568330258),
DataLayout::NHWC });
@@ -622,7 +618,7 @@ public:
NodeID id_mul_y = _graph.add_node<ConstNode>(
TensorDescriptor
{
- scalar_4d_shape,
+ TensorShape{ 1 },
DataType::QASYMM8,
QuantizationInfo(0.0003921568568330258),
DataLayout::NHWC });