aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSang-Hoon Park <sang-hoon.park@arm.com>2020-03-13 11:31:53 +0000
committerSang-Hoon Park <sang-hoon.park@arm.com>2020-03-13 16:11:38 +0000
commit104fbd7b533c40f19465c85e884f10ae500e639e (patch)
tree10b6fb262fe72e5afeb13ba5b8f491fb3c1a825e
parent797b76b1aef38ea3be6f68ae2bf323048e9beff8 (diff)
downloadComputeLibrary-104fbd7b533c40f19465c85e884f10ae500e639e.tar.gz
COMPMID-3221: Add DeconvolutionLayerDescriptor
A new struct for DeconvolutionLayerNode is added for better extendability. Change-Id: I935277e8073a8295de7b0059b946cb637085f1ff Signed-off-by: Sang-Hoon Park <sang-hoon.park@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2883 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--arm_compute/graph/LayerDescriptors.h16
-rw-r--r--arm_compute/graph/nodes/DeconvolutionLayerNode.h9
-rw-r--r--examples/graph_edsr.h15
-rw-r--r--src/graph/GraphBuilder.cpp2
-rw-r--r--src/graph/nodes/DeconvolutionLayerNode.cpp12
5 files changed, 35 insertions, 19 deletions
diff --git a/arm_compute/graph/LayerDescriptors.h b/arm_compute/graph/LayerDescriptors.h
index af69682fa3..0cf203174e 100644
--- a/arm_compute/graph/LayerDescriptors.h
+++ b/arm_compute/graph/LayerDescriptors.h
@@ -86,6 +86,22 @@ struct EltwiseLayerDescriptor
RoundingPolicy r_policy; /**< Rounding policy */
};
+/** Deconvolution layer descriptor */
+struct DeconvolutionLayerDescriptor
+{
+ /** Constructor
+ *
+ * @param[in] info Dedonvolution layer attributes
+ * @param[in] out_quant_info (Optional) Output quantization infomation
+ */
+ DeconvolutionLayerDescriptor(PadStrideInfo info, QuantizationInfo out_quant_info = QuantizationInfo())
+ : info(info), out_quant_info(out_quant_info)
+ {
+ }
+
+ PadStrideInfo info; /**< Padding and stride information */
+ QuantizationInfo out_quant_info; /**< Output quantization information */
+};
} // namespace descriptor
} // namespace graph
} // namespace arm_compute
diff --git a/arm_compute/graph/nodes/DeconvolutionLayerNode.h b/arm_compute/graph/nodes/DeconvolutionLayerNode.h
index 5633898366..a5efdfb3bc 100644
--- a/arm_compute/graph/nodes/DeconvolutionLayerNode.h
+++ b/arm_compute/graph/nodes/DeconvolutionLayerNode.h
@@ -25,6 +25,7 @@
#define ARM_COMPUTE_GRAPH_DECONVOLUTION_LAYER_NODE_H
#include "arm_compute/graph/INode.h"
+#include "arm_compute/graph/LayerDescriptors.h"
namespace arm_compute
{
@@ -36,10 +37,9 @@ class DeconvolutionLayerNode final : public INode
public:
/** Constructor
*
- * @param[in] info DeConvolution layer attributes
- * @param[in] out_quant_info (Optional) Output quantization infomation
+ * @param[in] descriptor Contains information used by this layer described in @ref descriptors::DeconvolutionLayerDescriptor
*/
- DeconvolutionLayerNode(PadStrideInfo info, QuantizationInfo out_quant_info = QuantizationInfo());
+ DeconvolutionLayerNode(const descriptors::DeconvolutionLayerDescriptor &descriptor);
/** Deconvolution metadata accessor
*
* @return Deconvolution information
@@ -64,8 +64,7 @@ public:
void accept(INodeVisitor &v) override;
private:
- PadStrideInfo _info;
- QuantizationInfo _out_quant_info;
+ descriptors::DeconvolutionLayerDescriptor descriptor;
};
} // namespace graph
} // namespace arm_compute
diff --git a/examples/graph_edsr.h b/examples/graph_edsr.h
index e31cc8940a..cb467d0377 100644
--- a/examples/graph_edsr.h
+++ b/examples/graph_edsr.h
@@ -1245,13 +1245,14 @@ public:
_graph.add_connection(id_pre_upscale_Conv2D_bias, 0, id_pre_upscale_BiasAdd, 2);
NodeID id_upscale_net_FakeQuantWithMinMaxVars_1 = _graph.add_node<DeconvolutionLayerNode>(
- PadStrideInfo
- {
- 2, 2,
- 0, 0,
- 0, 0,
- DimensionRoundingType::FLOOR },
- QuantizationInfo{ 0.004990961868315935, 26 });
+ descriptors::DeconvolutionLayerDescriptor
+ {
+ PadStrideInfo{
+ 2, 2,
+ 0, 0,
+ 0, 0,
+ DimensionRoundingType::FLOOR },
+ QuantizationInfo{ 0.004990961868315935, 26 } });
INode *node_upscale_net_FakeQuantWithMinMaxVars_1 = _graph.node(id_upscale_net_FakeQuantWithMinMaxVars_1);
node_upscale_net_FakeQuantWithMinMaxVars_1->set_common_node_parameters(NodeParams{ "upscale_net_FakeQuantWithMinMaxVars_1", target });
_graph.add_connection(id_pre_upscale_BiasAdd, 0, id_upscale_net_FakeQuantWithMinMaxVars_1, 0);
diff --git a/src/graph/GraphBuilder.cpp b/src/graph/GraphBuilder.cpp
index e429817d50..218e6ce62d 100644
--- a/src/graph/GraphBuilder.cpp
+++ b/src/graph/GraphBuilder.cpp
@@ -306,7 +306,7 @@ NodeID GraphBuilder::add_deconvolution_node(Graph &g, NodeParams params, NodeIdx
}
// Create convolution node and connect
- NodeID deconv_nid = g.add_node<DeconvolutionLayerNode>(deconv_info);
+ NodeID deconv_nid = g.add_node<DeconvolutionLayerNode>(descriptors::DeconvolutionLayerDescriptor{ deconv_info });
g.add_connection(input.node_id, input.index, deconv_nid, 0);
g.add_connection(w_nid, 0, deconv_nid, 1);
if(has_bias)
diff --git a/src/graph/nodes/DeconvolutionLayerNode.cpp b/src/graph/nodes/DeconvolutionLayerNode.cpp
index a2e4e2b056..2daeaaccf7 100644
--- a/src/graph/nodes/DeconvolutionLayerNode.cpp
+++ b/src/graph/nodes/DeconvolutionLayerNode.cpp
@@ -32,8 +32,8 @@ namespace arm_compute
{
namespace graph
{
-DeconvolutionLayerNode::DeconvolutionLayerNode(PadStrideInfo info, QuantizationInfo out_quant_info)
- : _info(std::move(info)), _out_quant_info(std::move(out_quant_info))
+DeconvolutionLayerNode::DeconvolutionLayerNode(const descriptors::DeconvolutionLayerDescriptor &descriptor)
+ : descriptor(std::move(descriptor))
{
_input_edges.resize(3, EmptyEdgeID);
_outputs.resize(1, NullTensorID);
@@ -41,7 +41,7 @@ DeconvolutionLayerNode::DeconvolutionLayerNode(PadStrideInfo info, QuantizationI
PadStrideInfo DeconvolutionLayerNode::deconvolution_info() const
{
- return _info;
+ return descriptor.info;
}
TensorDescriptor DeconvolutionLayerNode::compute_output_descriptor(const TensorDescriptor &input_descriptor,
@@ -87,11 +87,11 @@ TensorDescriptor DeconvolutionLayerNode::configure_output(size_t idx) const
ARM_COMPUTE_ERROR_ON(src == nullptr || weights == nullptr);
- TensorDescriptor output_info = compute_output_descriptor(src->desc(), weights->desc(), _info);
+ TensorDescriptor output_info = compute_output_descriptor(src->desc(), weights->desc(), descriptor.info);
- if(!_out_quant_info.empty())
+ if(!descriptor.out_quant_info.empty())
{
- output_info.set_quantization_info(_out_quant_info);
+ output_info.set_quantization_info(descriptor.out_quant_info);
}
return output_info;