aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/graph/nodes/ConvolutionLayerNode.h
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2018-08-15 12:14:46 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:54 +0000
commit2a2db590fd179dcb8e1a575293cd2b887e2dc246 (patch)
tree5e10da7cb6777f3020b84a2389b279ceef2be5ee /arm_compute/graph/nodes/ConvolutionLayerNode.h
parentc1961b51df2e15a01a5950139e81bbd47fbfa627 (diff)
downloadComputeLibrary-2a2db590fd179dcb8e1a575293cd2b887e2dc246.tar.gz
COMPMID-1505: Add native grouping support at graph level
Change-Id: Iedc91b0aee743b59af5140c8acb8124548da3163 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/144362 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Giorgio Arena <giorgio.arena@arm.com> Reviewed-by: Michele DiGiorgio <michele.digiorgio@arm.com>
Diffstat (limited to 'arm_compute/graph/nodes/ConvolutionLayerNode.h')
-rw-r--r--arm_compute/graph/nodes/ConvolutionLayerNode.h8
1 files changed, 8 insertions, 0 deletions
diff --git a/arm_compute/graph/nodes/ConvolutionLayerNode.h b/arm_compute/graph/nodes/ConvolutionLayerNode.h
index 4299be6bb5..0698ac1360 100644
--- a/arm_compute/graph/nodes/ConvolutionLayerNode.h
+++ b/arm_compute/graph/nodes/ConvolutionLayerNode.h
@@ -37,11 +37,13 @@ public:
/** Constructor
*
* @param[in] info Convolution layer attributes
+ * @param[in] num_groups (Optional) Number of groups (Defaults to 1)
* @param[in] method (Optional) Convolution method to use
* @param[in] fast_math_hint (Optional) Fast math hint
* @param[in] out_quant_info (Optional) Output quantization info
*/
ConvolutionLayerNode(PadStrideInfo info,
+ unsigned int num_groups = 1,
ConvolutionMethod method = ConvolutionMethod::Default,
FastMathHint fast_math_hint = FastMathHint::Disabled,
QuantizationInfo out_quant_info = QuantizationInfo());
@@ -73,6 +75,11 @@ public:
* @return Convolution information
*/
PadStrideInfo convolution_info() const;
+ /** Number of groups in convolution accessor
+ *
+ * @return Number of groups in convolution
+ */
+ unsigned int num_groups() const;
/** Computes convolution output descriptor
*
* @param[in] input_descriptor Input descriptor
@@ -93,6 +100,7 @@ public:
private:
PadStrideInfo _info;
+ unsigned int _num_groups;
ConvolutionMethod _method;
FastMathHint _fast_math_hint;
QuantizationInfo _out_quant_info;