aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/graph/nodes/DepthwiseConvolutionLayerNode.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/graph/nodes/DepthwiseConvolutionLayerNode.h')
-rw-r--r--arm_compute/graph/nodes/DepthwiseConvolutionLayerNode.h19
1 files changed, 15 insertions, 4 deletions
diff --git a/arm_compute/graph/nodes/DepthwiseConvolutionLayerNode.h b/arm_compute/graph/nodes/DepthwiseConvolutionLayerNode.h
index 7fa44b798f..8c0aae13c9 100644
--- a/arm_compute/graph/nodes/DepthwiseConvolutionLayerNode.h
+++ b/arm_compute/graph/nodes/DepthwiseConvolutionLayerNode.h
@@ -36,10 +36,13 @@ class DepthwiseConvolutionLayerNode final : public INode
public:
/** Constructor
*
- * @param[in] info Convolution layer attributes
- * @param[in] method Depthwise convolution method to use
+ * @param[in] info Convolution layer attributes
+ * @param[in] depth_multiplier (Optional) Depth multiplier parameter.
+ * @param[in] method (Optional) Depthwise convolution method to use
*/
- DepthwiseConvolutionLayerNode(PadStrideInfo info, DepthwiseConvolutionMethod method = DepthwiseConvolutionMethod::Default);
+ DepthwiseConvolutionLayerNode(PadStrideInfo info,
+ int depth_multiplier = 1,
+ DepthwiseConvolutionMethod method = DepthwiseConvolutionMethod::Default);
/** Sets the depthwise convolution method to use
*
* @param[in] method Depthwise convolution method to use
@@ -53,6 +56,11 @@ public:
* @return Depthwise convolution layer method do be used by the node
*/
DepthwiseConvolutionMethod depthwise_convolution_method() const;
+ /** Depth multiplier accessor
+ *
+ * @return Depth multiplier
+ */
+ int depth_multiplier() const;
/** Convolution metadata accessor
*
* @return Convolution information
@@ -73,12 +81,14 @@ public:
* @param[in] input_descriptor Input descriptor
* @param[in] weights_descriptor Weights descriptor
* @param[in] info Convolution operation attributes
+ * @param[in] depth_multiplier (Optional) Depth multiplier parameter.
*
* @return Output descriptor
*/
static TensorDescriptor compute_output_descriptor(const TensorDescriptor &input_descriptor,
const TensorDescriptor &weights_descriptor,
- const PadStrideInfo &info);
+ const PadStrideInfo &info,
+ int depth_multiplier = 1);
// Inherited overridden methods:
NodeType type() const override;
@@ -91,6 +101,7 @@ public:
private:
PadStrideInfo _info;
+ int _depth_multiplier;
DepthwiseConvolutionMethod _method;
ActivationLayerInfo _fused_activation;
};