From 104fbd7b533c40f19465c85e884f10ae500e639e Mon Sep 17 00:00:00 2001 From: Sang-Hoon Park Date: Fri, 13 Mar 2020 11:31:53 +0000 Subject: COMPMID-3221: Add DeconvolutionLayerDescriptor A new struct for DeconvolutionLayerNode is added for better extendability. Change-Id: I935277e8073a8295de7b0059b946cb637085f1ff Signed-off-by: Sang-Hoon Park Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2883 Tested-by: Arm Jenkins Reviewed-by: Michele Di Giorgio Comments-Addressed: Arm Jenkins --- arm_compute/graph/LayerDescriptors.h | 16 ++++++++++++++++ arm_compute/graph/nodes/DeconvolutionLayerNode.h | 9 ++++----- examples/graph_edsr.h | 15 ++++++++------- src/graph/GraphBuilder.cpp | 2 +- src/graph/nodes/DeconvolutionLayerNode.cpp | 12 ++++++------ 5 files changed, 35 insertions(+), 19 deletions(-) diff --git a/arm_compute/graph/LayerDescriptors.h b/arm_compute/graph/LayerDescriptors.h index af69682fa3..0cf203174e 100644 --- a/arm_compute/graph/LayerDescriptors.h +++ b/arm_compute/graph/LayerDescriptors.h @@ -86,6 +86,22 @@ struct EltwiseLayerDescriptor RoundingPolicy r_policy; /**< Rounding policy */ }; +/** Deconvolution layer descriptor */ +struct DeconvolutionLayerDescriptor +{ + /** Constructor + * + * @param[in] info Dedonvolution layer attributes + * @param[in] out_quant_info (Optional) Output quantization infomation + */ + DeconvolutionLayerDescriptor(PadStrideInfo info, QuantizationInfo out_quant_info = QuantizationInfo()) + : info(info), out_quant_info(out_quant_info) + { + } + + PadStrideInfo info; /**< Padding and stride information */ + QuantizationInfo out_quant_info; /**< Output quantization information */ +}; } // namespace descriptor } // namespace graph } // namespace arm_compute diff --git a/arm_compute/graph/nodes/DeconvolutionLayerNode.h b/arm_compute/graph/nodes/DeconvolutionLayerNode.h index 5633898366..a5efdfb3bc 100644 --- a/arm_compute/graph/nodes/DeconvolutionLayerNode.h +++ b/arm_compute/graph/nodes/DeconvolutionLayerNode.h @@ -25,6 +25,7 @@ #define ARM_COMPUTE_GRAPH_DECONVOLUTION_LAYER_NODE_H #include "arm_compute/graph/INode.h" +#include "arm_compute/graph/LayerDescriptors.h" namespace arm_compute { @@ -36,10 +37,9 @@ class DeconvolutionLayerNode final : public INode public: /** Constructor * - * @param[in] info DeConvolution layer attributes - * @param[in] out_quant_info (Optional) Output quantization infomation + * @param[in] descriptor Contains information used by this layer described in @ref descriptors::DeconvolutionLayerDescriptor */ - DeconvolutionLayerNode(PadStrideInfo info, QuantizationInfo out_quant_info = QuantizationInfo()); + DeconvolutionLayerNode(const descriptors::DeconvolutionLayerDescriptor &descriptor); /** Deconvolution metadata accessor * * @return Deconvolution information @@ -64,8 +64,7 @@ public: void accept(INodeVisitor &v) override; private: - PadStrideInfo _info; - QuantizationInfo _out_quant_info; + descriptors::DeconvolutionLayerDescriptor descriptor; }; } // namespace graph } // namespace arm_compute diff --git a/examples/graph_edsr.h b/examples/graph_edsr.h index e31cc8940a..cb467d0377 100644 --- a/examples/graph_edsr.h +++ b/examples/graph_edsr.h @@ -1245,13 +1245,14 @@ public: _graph.add_connection(id_pre_upscale_Conv2D_bias, 0, id_pre_upscale_BiasAdd, 2); NodeID id_upscale_net_FakeQuantWithMinMaxVars_1 = _graph.add_node( - PadStrideInfo - { - 2, 2, - 0, 0, - 0, 0, - DimensionRoundingType::FLOOR }, - QuantizationInfo{ 0.004990961868315935, 26 }); + descriptors::DeconvolutionLayerDescriptor + { + PadStrideInfo{ + 2, 2, + 0, 0, + 0, 0, + DimensionRoundingType::FLOOR }, + QuantizationInfo{ 0.004990961868315935, 26 } }); INode *node_upscale_net_FakeQuantWithMinMaxVars_1 = _graph.node(id_upscale_net_FakeQuantWithMinMaxVars_1); node_upscale_net_FakeQuantWithMinMaxVars_1->set_common_node_parameters(NodeParams{ "upscale_net_FakeQuantWithMinMaxVars_1", target }); _graph.add_connection(id_pre_upscale_BiasAdd, 0, id_upscale_net_FakeQuantWithMinMaxVars_1, 0); diff --git a/src/graph/GraphBuilder.cpp b/src/graph/GraphBuilder.cpp index e429817d50..218e6ce62d 100644 --- a/src/graph/GraphBuilder.cpp +++ b/src/graph/GraphBuilder.cpp @@ -306,7 +306,7 @@ NodeID GraphBuilder::add_deconvolution_node(Graph &g, NodeParams params, NodeIdx } // Create convolution node and connect - NodeID deconv_nid = g.add_node(deconv_info); + NodeID deconv_nid = g.add_node(descriptors::DeconvolutionLayerDescriptor{ deconv_info }); g.add_connection(input.node_id, input.index, deconv_nid, 0); g.add_connection(w_nid, 0, deconv_nid, 1); if(has_bias) diff --git a/src/graph/nodes/DeconvolutionLayerNode.cpp b/src/graph/nodes/DeconvolutionLayerNode.cpp index a2e4e2b056..2daeaaccf7 100644 --- a/src/graph/nodes/DeconvolutionLayerNode.cpp +++ b/src/graph/nodes/DeconvolutionLayerNode.cpp @@ -32,8 +32,8 @@ namespace arm_compute { namespace graph { -DeconvolutionLayerNode::DeconvolutionLayerNode(PadStrideInfo info, QuantizationInfo out_quant_info) - : _info(std::move(info)), _out_quant_info(std::move(out_quant_info)) +DeconvolutionLayerNode::DeconvolutionLayerNode(const descriptors::DeconvolutionLayerDescriptor &descriptor) + : descriptor(std::move(descriptor)) { _input_edges.resize(3, EmptyEdgeID); _outputs.resize(1, NullTensorID); @@ -41,7 +41,7 @@ DeconvolutionLayerNode::DeconvolutionLayerNode(PadStrideInfo info, QuantizationI PadStrideInfo DeconvolutionLayerNode::deconvolution_info() const { - return _info; + return descriptor.info; } TensorDescriptor DeconvolutionLayerNode::compute_output_descriptor(const TensorDescriptor &input_descriptor, @@ -87,11 +87,11 @@ TensorDescriptor DeconvolutionLayerNode::configure_output(size_t idx) const ARM_COMPUTE_ERROR_ON(src == nullptr || weights == nullptr); - TensorDescriptor output_info = compute_output_descriptor(src->desc(), weights->desc(), _info); + TensorDescriptor output_info = compute_output_descriptor(src->desc(), weights->desc(), descriptor.info); - if(!_out_quant_info.empty()) + if(!descriptor.out_quant_info.empty()) { - output_info.set_quantization_info(_out_quant_info); + output_info.set_quantization_info(descriptor.out_quant_info); } return output_info; -- cgit v1.2.1