From e2220551b7a64b929650ba9a60529c31e70c13c5 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Fri, 20 Jul 2018 13:23:44 +0100 Subject: COMPMID-1367: Enable NHWC in graph examples Change-Id: Iabc54a3a1bdcd46a9a921cda39c7c85fef672b72 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/141449 Reviewed-by: Giorgio Arena Reviewed-by: Anthony Barbier Tested-by: Jenkins --- src/graph/backends/GLES/GCFunctionsFactory.cpp | 50 ++++++++++++++++++++++---- 1 file changed, 43 insertions(+), 7 deletions(-) (limited to 'src/graph/backends/GLES/GCFunctionsFactory.cpp') diff --git a/src/graph/backends/GLES/GCFunctionsFactory.cpp b/src/graph/backends/GLES/GCFunctionsFactory.cpp index e6bd5a5f02..f72513c87c 100644 --- a/src/graph/backends/GLES/GCFunctionsFactory.cpp +++ b/src/graph/backends/GLES/GCFunctionsFactory.cpp @@ -68,6 +68,42 @@ struct GCEltwiseFunctions namespace detail { +// Specialize functions +template <> +std::unique_ptr create_concatenate_layer(ConcatenateLayerNode &node) +{ + ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating Concatenate node with ID : " << node.id() << " and Name: " << node.name() << std::endl); + ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); + + // Return nullptr if depth concatenate is switched off + if(!node.is_enabled()) + { + return nullptr; + } + + // Extract IO and info + std::vector inputs; + for(unsigned int i = 0; i < node.num_inputs(); ++i) + { + inputs.push_back(get_backing_tensor(node.input(i))); + } + typename GCTargetInfo::TensorType *output = get_backing_tensor(node.output(0)); + + // Create and configure function + auto func = support::cpp14::make_unique(); + func->configure(inputs, output); + + // Log info + ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated " << node.type() + << " Target " << GCTargetInfo::TargetType + << " Data Type: " << output->info()->data_type() + << " Shape: " << output->info()->tensor_shape() + << " Num Inputs: " << inputs.size() + << std::endl); + + return std::move(func); +} + template <> std::unique_ptr create_convolution_layer(ConvolutionLayerNode &node, GraphContext &ctx) { @@ -92,7 +128,7 @@ std::unique_ptr create_convolution_layer func; std::string func_name; - if(conv_algorithm == ConvolutionMethod::DIRECT) + if(conv_algorithm == ConvolutionMethod::Direct) { std::tie(func, func_name) = create_named_function( std::string("DirectConvolutionLayer"), @@ -139,7 +175,7 @@ std::unique_ptr create_depthwise_convolution_layer func; std::string func_name; - if(dwc_algorithm == DepthwiseConvolutionMethod::OPTIMIZED_3x3) + if(dwc_algorithm == DepthwiseConvolutionMethod::Optimized3x3) { std::tie(func, func_name) = create_named_function( std::string("DepthwiseConvolutionLayer3x3"), @@ -183,17 +219,17 @@ std::unique_ptr create_eltwise_layer func = nullptr; std::string func_name; - if(eltwise_op == EltwiseOperation::ADD) + if(eltwise_op == EltwiseOperation::Add) { std::tie(func, func_name) = create_named_function( std::string("GCArithmeticAddition"), input1, input2, output, convert_policy); } - else if(eltwise_op == EltwiseOperation::SUB) + else if(eltwise_op == EltwiseOperation::Sub) { ARM_COMPUTE_ERROR("Arithmetic subtraction is not supported in GLES backend"); } - else if(eltwise_op == EltwiseOperation::MUL) + else if(eltwise_op == EltwiseOperation::Mul) { std::tie(func, func_name) = create_named_function( std::string("PixelWiseMultiplication"), @@ -232,8 +268,8 @@ std::unique_ptr GCFunctionFactory::create(INode *node, GraphContext & return detail::create_batch_normalization_layer(*polymorphic_downcast(node)); case NodeType::ConvolutionLayer: return detail::create_convolution_layer(*polymorphic_downcast(node), ctx); - case NodeType::DepthConcatenateLayer: - return detail::create_depth_concatenate_layer(*polymorphic_downcast(node)); + case NodeType::ConcatenateLayer: + return detail::create_concatenate_layer(*polymorphic_downcast(node)); case NodeType::DepthwiseConvolutionLayer: return detail::create_depthwise_convolution_layer(*polymorphic_downcast(node)); case NodeType::EltwiseLayer: -- cgit v1.2.1