aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/graph/backends
diff options
context:
space:
mode:
authorManuel Bottini <manuel.bottini@arm.com>2019-09-26 17:18:26 +0100
committerManuel Bottini <manuel.bottini@arm.com>2019-10-23 16:56:45 +0000
commit05069f07bcf95676597698a79926327555276362 (patch)
treea4a861127660aa439c9468da7479d92cecc85138 /arm_compute/graph/backends
parente36b5266e4c6593932432bc0289e431d007b8710 (diff)
downloadComputeLibrary-05069f07bcf95676597698a79926327555276362.tar.gz
COMPMID-2515: Merge optimized depthwise convolution to the generic depthwise convolution function
3RDPARTY_UPDATE Change-Id: Iff9e915c5329c617527b6f5042979f4e21a8b2b8 Signed-off-by: Manuel Bottini <manuel.bottini@arm.com> Reviewed-on: https://review.mlplatform.org/c/2022 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Giorgio Arena <giorgio.arena@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Diffstat (limited to 'arm_compute/graph/backends')
-rw-r--r--arm_compute/graph/backends/FunctionHelpers.h25
-rw-r--r--arm_compute/graph/backends/ValidateHelpers.h8
2 files changed, 10 insertions, 23 deletions
diff --git a/arm_compute/graph/backends/FunctionHelpers.h b/arm_compute/graph/backends/FunctionHelpers.h
index 94b385e81e..ee257e3abf 100644
--- a/arm_compute/graph/backends/FunctionHelpers.h
+++ b/arm_compute/graph/backends/FunctionHelpers.h
@@ -538,7 +538,7 @@ std::unique_ptr<IFunction> create_deconvolution_layer(DeconvolutionLayerNode &no
*
* @return Backend depth-wise convolution layer function
*/
-template <typename DepthwiseConvolutionLayerFunctions, typename TargetInfo>
+template <typename DepthwiseConvolutionLayer, typename TargetInfo>
std::unique_ptr<IFunction> create_depthwise_convolution_layer(DepthwiseConvolutionLayerNode &node)
{
validate_node<TargetInfo>(node, 3 /* expected inputs */, 1 /* expected outputs */);
@@ -556,26 +556,17 @@ std::unique_ptr<IFunction> create_depthwise_convolution_layer(DepthwiseConvoluti
biases->info()->set_data_type(DataType::S32);
}
- const PadStrideInfo conv_info = node.convolution_info();
- const DepthwiseConvolutionMethod dwc_algorithm = node.depthwise_convolution_method();
- const unsigned int depth_multiplier = node.depth_multiplier();
- const ActivationLayerInfo fused_act = node.fused_activation();
+ const PadStrideInfo conv_info = node.convolution_info();
+ const unsigned int depth_multiplier = node.depth_multiplier();
+ const ActivationLayerInfo fused_act = node.fused_activation();
// Create and configure function (we assume that functions have been validated before creation)
std::unique_ptr<IFunction> func;
std::string func_name;
- if(dwc_algorithm == DepthwiseConvolutionMethod::Optimized3x3)
- {
- std::tie(func, func_name) = create_named_function<typename DepthwiseConvolutionLayerFunctions::OptimizedDepthwiseConvolutionLayer>(
- std::string("DepthwiseConvolutionLayer3x3"),
- input, weights, biases, output, conv_info, depth_multiplier, fused_act);
- }
- else
- {
- std::tie(func, func_name) = create_named_function<typename DepthwiseConvolutionLayerFunctions::GenericDepthwiseConvolutionLayer>(
- std::string("DepthwiseConvolutionLayer"),
- input, weights, biases, output, conv_info, depth_multiplier, fused_act);
- }
+
+ std::tie(func, func_name) = create_named_function<DepthwiseConvolutionLayer>(
+ std::string("DepthwiseConvolutionLayer"),
+ input, weights, biases, output, conv_info, depth_multiplier, fused_act);
// Log info
std::ostringstream qss;
diff --git a/arm_compute/graph/backends/ValidateHelpers.h b/arm_compute/graph/backends/ValidateHelpers.h
index 13de273bdf..9170006d9c 100644
--- a/arm_compute/graph/backends/ValidateHelpers.h
+++ b/arm_compute/graph/backends/ValidateHelpers.h
@@ -163,13 +163,12 @@ Status validate_convolution_layer(ConvolutionLayerNode &node)
/** Validates a Depthwise Convolution layer node
*
* @tparam DepthwiseConvolutionLayer Default Depthwise Convolution layer type
- * @tparam DepthwiseConvolutionLayer3x3 Optimized 3x3 Depthwise Convolution layer type
*
* @param[in] node Node to validate
*
* @return Status
*/
-template <typename DepthwiseConvolutionLayer, typename DepthwiseConvolutionLayer3x3>
+template <typename DepthwiseConvolutionLayer>
Status validate_depthwise_convolution_layer(DepthwiseConvolutionLayerNode &node)
{
ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating DepthwiseConvolutionLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
@@ -191,11 +190,8 @@ Status validate_depthwise_convolution_layer(DepthwiseConvolutionLayerNode &node)
switch(dwc_algorithm)
{
case DepthwiseConvolutionMethod::Default:
- case DepthwiseConvolutionMethod::GEMV:
- status = DepthwiseConvolutionLayer::validate(input, weights, biases, output, conv_info, depth_multiplier);
- break;
case DepthwiseConvolutionMethod::Optimized3x3:
- status = DepthwiseConvolutionLayer3x3::validate(input, weights, biases, output, conv_info, depth_multiplier);
+ status = DepthwiseConvolutionLayer::validate(input, weights, biases, output, conv_info, depth_multiplier);
break;
default:
ARM_COMPUTE_RETURN_ERROR_MSG("Unsupported depthwise convolution method");