From be058f4b5fb2383e0fa4666a92c43ca994486c86 Mon Sep 17 00:00:00 2001 From: thecha01 Date: Wed, 2 Sep 2020 16:18:25 +0100 Subject: Use shape broadcast for Mult inputs in EDSR graph We no longer have to explicitly create a tensor with the correct dimensions for the Const nodes, instead we use the graph API shape propogation logic in EltwiseLayerNode to broadcast the shapes Signed-off-by: thecha01 Change-Id: Ifb62b572d6391850d3357cd3307cef7cd9645ee3 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3898 Tested-by: Arm Jenkins Reviewed-by: Michele Di Giorgio Comments-Addressed: Arm Jenkins --- examples/graph_edsr.h | 36 ++++++++++++++++-------------------- 1 file changed, 16 insertions(+), 20 deletions(-) (limited to 'examples/graph_edsr.h') diff --git a/examples/graph_edsr.h b/examples/graph_edsr.h index 42a2789861..72012afdcb 100644 --- a/examples/graph_edsr.h +++ b/examples/graph_edsr.h @@ -105,14 +105,10 @@ public: node_post_residual_FakeQuantWithMinMaxVars->output(0)->set_accessor(get_weights_accessor(data_path, "/cnn_data/edsr_model/post_residual_FakeQuantWithMinMaxVars.npy", DataLayout::NHWC)); - TensorShape scalar_4d_shape{}; - - scalar_4d_shape.set(0, 1, false).set(1, 1, false).set(2, 1, false).set(3, 1, false); - NodeID id_mul_15_y = _graph.add_node( TensorDescriptor { - scalar_4d_shape, + TensorShape{ 1 }, DataType::QASYMM8, QuantizationInfo(0.0003921568568330258), DataLayout::NHWC }); @@ -146,7 +142,7 @@ public: NodeID id_mul_14_y = _graph.add_node( TensorDescriptor { - scalar_4d_shape, + TensorShape{ 1 }, DataType::QASYMM8, QuantizationInfo(0.0003921568568330258), DataLayout::NHWC }); @@ -180,7 +176,7 @@ public: NodeID id_mul_13_y = _graph.add_node( TensorDescriptor { - scalar_4d_shape, + TensorShape{ 1 }, DataType::QASYMM8, QuantizationInfo(0.0003921568568330258), DataLayout::NHWC }); @@ -214,7 +210,7 @@ public: NodeID id_mul_12_y = _graph.add_node( TensorDescriptor { - scalar_4d_shape, + TensorShape{ 1 }, DataType::QASYMM8, QuantizationInfo(0.0003921568568330258), DataLayout::NHWC }); @@ -248,7 +244,7 @@ public: NodeID id_mul_11_y = _graph.add_node( TensorDescriptor { - scalar_4d_shape, + TensorShape{ 1 }, DataType::QASYMM8, QuantizationInfo(0.0003921568568330258), DataLayout::NHWC }); @@ -282,7 +278,7 @@ public: NodeID id_mul_10_y = _graph.add_node( TensorDescriptor { - scalar_4d_shape, + TensorShape{ 1 }, DataType::QASYMM8, QuantizationInfo(0.0003921568568330258), DataLayout::NHWC }); @@ -316,7 +312,7 @@ public: NodeID id_mul_9_y = _graph.add_node( TensorDescriptor { - scalar_4d_shape, + TensorShape{ 1 }, DataType::QASYMM8, QuantizationInfo(0.0003921568568330258), DataLayout::NHWC }); @@ -350,7 +346,7 @@ public: NodeID id_mul_8_y = _graph.add_node( TensorDescriptor { - scalar_4d_shape, + TensorShape{ 1 }, DataType::QASYMM8, QuantizationInfo(0.0003921568568330258), DataLayout::NHWC }); @@ -384,7 +380,7 @@ public: NodeID id_mul_7_y = _graph.add_node( TensorDescriptor { - scalar_4d_shape, + TensorShape{ 1 }, DataType::QASYMM8, QuantizationInfo(0.0003921568568330258), DataLayout::NHWC }); @@ -418,7 +414,7 @@ public: NodeID id_mul_6_y = _graph.add_node( TensorDescriptor { - scalar_4d_shape, + TensorShape{ 1 }, DataType::QASYMM8, QuantizationInfo(0.0003921568568330258), DataLayout::NHWC }); @@ -452,7 +448,7 @@ public: NodeID id_mul_5_y = _graph.add_node( TensorDescriptor { - scalar_4d_shape, + TensorShape{ 1 }, DataType::QASYMM8, QuantizationInfo(0.0003921568568330258), DataLayout::NHWC }); @@ -486,7 +482,7 @@ public: NodeID id_mul_4_y = _graph.add_node( TensorDescriptor { - scalar_4d_shape, + TensorShape{ 1 }, DataType::QASYMM8, QuantizationInfo(0.0003921568568330258), DataLayout::NHWC }); @@ -520,7 +516,7 @@ public: NodeID id_mul_3_y = _graph.add_node( TensorDescriptor { - scalar_4d_shape, + TensorShape{ 1 }, DataType::QASYMM8, QuantizationInfo(0.0003921568568330258), DataLayout::NHWC }); @@ -554,7 +550,7 @@ public: NodeID id_mul_2_y = _graph.add_node( TensorDescriptor { - scalar_4d_shape, + TensorShape{ 1 }, DataType::QASYMM8, QuantizationInfo(0.0003921568568330258), DataLayout::NHWC }); @@ -588,7 +584,7 @@ public: NodeID id_mul_1_y = _graph.add_node( TensorDescriptor { - scalar_4d_shape, + TensorShape{ 1 }, DataType::QASYMM8, QuantizationInfo(0.0003921568568330258), DataLayout::NHWC }); @@ -622,7 +618,7 @@ public: NodeID id_mul_y = _graph.add_node( TensorDescriptor { - scalar_4d_shape, + TensorShape{ 1 }, DataType::QASYMM8, QuantizationInfo(0.0003921568568330258), DataLayout::NHWC }); -- cgit v1.2.1