aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/graph/frontend/Layers.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/graph/frontend/Layers.h')
-rw-r--r--arm_compute/graph/frontend/Layers.h20
1 files changed, 12 insertions, 8 deletions
diff --git a/arm_compute/graph/frontend/Layers.h b/arm_compute/graph/frontend/Layers.h
index 78a3f20f1f..d0703317cd 100644
--- a/arm_compute/graph/frontend/Layers.h
+++ b/arm_compute/graph/frontend/Layers.h
@@ -414,24 +414,27 @@ class DepthwiseConvolutionLayer final : public ILayer
public:
/** Construct a depthwise convolution layer.
*
- * @param[in] conv_width Convolution width.
- * @param[in] conv_height Convolution height.
- * @param[in] weights Accessor to get kernel weights from.
- * @param[in] bias Accessor to get kernel bias from.
- * @param[in] conv_info Padding and stride information.
- * @param[in] quant_info (Optional) Quantization info used for weights
+ * @param[in] conv_width Convolution width.
+ * @param[in] conv_height Convolution height.
+ * @param[in] weights Accessor to get kernel weights from.
+ * @param[in] bias Accessor to get kernel bias from.
+ * @param[in] conv_info Padding and stride information.
+ * @param[in] depth_multiplier (Optional) Depth multiplier parameter.
+ * @param[in] quant_info (Optional) Quantization info used for weights
*/
DepthwiseConvolutionLayer(unsigned int conv_width,
unsigned int conv_height,
ITensorAccessorUPtr weights,
ITensorAccessorUPtr bias,
PadStrideInfo conv_info,
- const QuantizationInfo quant_info = QuantizationInfo())
+ int depth_multiplier = 1,
+ const QuantizationInfo quant_info = QuantizationInfo())
: _conv_width(conv_width),
_conv_height(conv_height),
_conv_info(std::move(conv_info)),
_weights(std::move(weights)),
_bias(std::move(bias)),
+ _depth_multiplier(depth_multiplier),
_quant_info(std::move(quant_info))
{
}
@@ -441,7 +444,7 @@ public:
NodeIdxPair input = { s.tail_node(), 0 };
NodeParams common_params = { name(), s.hints().target_hint };
return GraphBuilder::add_depthwise_convolution_node(s.graph(), common_params,
- input, Size2D(_conv_width, _conv_height), _conv_info,
+ input, Size2D(_conv_width, _conv_height), _conv_info, _depth_multiplier,
s.hints().depthwise_convolution_method_hint,
std::move(_weights), std::move(_bias), std::move(_quant_info));
}
@@ -452,6 +455,7 @@ private:
const PadStrideInfo _conv_info;
ITensorAccessorUPtr _weights;
ITensorAccessorUPtr _bias;
+ int _depth_multiplier;
const QuantizationInfo _quant_info;
};