From da2491fb6d3cefb69846f220356fff282486495c Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Fri, 1 Jun 2018 17:49:09 +0100 Subject: COMPMID-1151: Templatize FunctionFactories. Change-Id: Id1c68c3bf442c3fcff265041b260d007db7593cb Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/134027 Tested-by: Jenkins Reviewed-by: Anthony Barbier --- src/graph/backends/CL/CLFunctionsFactory.cpp | 683 ++----------------------- src/graph/backends/GLES/GCFunctionsFactory.cpp | 425 +++------------ src/graph/backends/NEON/NEFunctionFactory.cpp | 587 +++------------------ 3 files changed, 189 insertions(+), 1506 deletions(-) (limited to 'src/graph') diff --git a/src/graph/backends/CL/CLFunctionsFactory.cpp b/src/graph/backends/CL/CLFunctionsFactory.cpp index 90ea81f21a..4d6734846a 100644 --- a/src/graph/backends/CL/CLFunctionsFactory.cpp +++ b/src/graph/backends/CL/CLFunctionsFactory.cpp @@ -25,16 +25,9 @@ #include "arm_compute/core/utils/misc/Cast.h" #include "arm_compute/graph/Graph.h" -#include "arm_compute/graph/GraphContext.h" -#include "arm_compute/graph/Logger.h" -#include "arm_compute/graph/TypePrinter.h" -#include "arm_compute/graph/Types.h" -#include "arm_compute/graph/backends/Utils.h" -#include "arm_compute/graph/nodes/Nodes.h" +#include "arm_compute/graph/backends/FunctionHelpers.h" #include "arm_compute/runtime/CL/CLFunctions.h" -#include "support/ToolchainSupport.h" - using namespace arm_compute::utils::cast; namespace arm_compute @@ -43,634 +36,38 @@ namespace graph { namespace backends { -namespace -{ -/** Returns backing tensor of a given tensor - * - * @param[in] tensor Tensor to extract the backing tensor from - * - * @return Backing tensor if present else nullptr - */ -arm_compute::ICLTensor *get_backing_tensor(arm_compute::graph::Tensor *tensor) -{ - arm_compute::ICLTensor *backing_tensor = nullptr; - if(tensor != nullptr) - { - ARM_COMPUTE_ERROR_ON(tensor->desc().target != arm_compute::graph::Target::CL); - // Get backing tensor handle - ITensorHandle *tensor_handle = tensor->handle(); - // Get backing tensor - backing_tensor = (tensor_handle != nullptr) ? polymorphic_cast(&tensor_handle->tensor()) : nullptr; - } - - return backing_tensor; -} - -/** Create a backend activation layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend activation layer function - */ -std::unique_ptr create_activation_layer(ActivationLayerNode &node) -{ - ARM_COMPUTE_LOG_GRAPH_VERBOSE( - "Creating CL ActivationLayerNode node with ID : " << node.id() << " and Name: " << node.name() - << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Extract IO and info - ICLTensor *input = get_backing_tensor(node.input(0)); - ICLTensor *output = get_backing_tensor(node.output(0)); - const ActivationLayerInfo act_info = node.activation_info(); - - // Create function - auto func = support::cpp14::make_unique(); - func->configure(input, output, act_info); - - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated CLActivationLayer" - << " Data Type: " << input->info()->data_type() - << " Shape: " << input->info()->tensor_shape() - << " Activation function: " << act_info.activation() - << " a: " << act_info.a() - << " b: " << act_info.b() - << " InPlace : " << is_in_place_operation(input, output) - << std::endl); - - return std::move(func); -} - -/** Create a backend batch normalization layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend batch normalization layer function - */ -std::unique_ptr create_batch_normalization_layer(BatchNormalizationLayerNode &node) +/** Target specific information structure used to pass information to the layer templates */ +struct CLTargetInfo { - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating CL BatchNormalization node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - - // TODO (geopin01) : Var and mean are compulsory, switch function to accept nullptr as beta and/or gamma - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 5); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Extract IO and info - ICLTensor *input = get_backing_tensor(node.input(0)); - ICLTensor *mean = get_backing_tensor(node.input(1)); - ICLTensor *var = get_backing_tensor(node.input(2)); - ICLTensor *beta = get_backing_tensor(node.input(3)); - ICLTensor *gamma = get_backing_tensor(node.input(4)); - ICLTensor *output = get_backing_tensor(node.output(0)); - const float epsilon = node.epsilon(); - const ActivationLayerInfo fused_act = node.fused_activation(); - - // Create and configure function - auto func = support::cpp14::make_unique(); - func->configure(input, output, mean, var, beta, gamma, epsilon, fused_act); - - // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated CLBatchNormalizationLayer" - << " Data Type: " << input->info()->data_type() - << " Shape: " << input->info()->tensor_shape() - << " Epsilon: " << epsilon << " " - << (fused_act.enabled() ? to_string(fused_act.activation()) : "") - << " InPlace : " << is_in_place_operation(input, output) - << std::endl); - - return std::move(func); -} - -/** Create a backend channel shuffle layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend channel shuffle layer function - */ -std::unique_ptr create_channel_shuffle_layer(ChannelShuffleLayerNode &node) -{ - ARM_COMPUTE_LOG_GRAPH_VERBOSE( - "Creating CL Channel Shuffle node with ID : " << node.id() << " and Name: " << node.name() - << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Extract IO and info - ICLTensor *input = get_backing_tensor(node.input(0)); - ICLTensor *output = get_backing_tensor(node.output(0)); - const unsigned int num_groups = node.num_groups(); - - // Create function - auto func = support::cpp14::make_unique(); - func->configure(input, output, num_groups); - - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated CLChannelShuffleLayer" - << " Data Type: " << input->info()->data_type() - << " Shape: " << input->info()->tensor_shape() - << " Num groups: " << num_groups - << std::endl); - - return std::move(func); -} - -/** Create a backend convolution layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend convolution layer function - */ -std::unique_ptr create_convolution_layer(ConvolutionLayerNode &node, GraphContext &ctx) -{ - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating CL ConvolutionLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 3); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Extract IO and info - ICLTensor *input = get_backing_tensor(node.input(0)); - ICLTensor *weights = get_backing_tensor(node.input(1)); - ICLTensor *biases = get_backing_tensor(node.input(2)); - ICLTensor *output = get_backing_tensor(node.output(0)); - - if(is_data_type_quantized_asymmetric(input->info()->data_type())) - { - biases->info()->set_data_type(DataType::S32); - } - - const PadStrideInfo conv_info = node.convolution_info(); - const ConvolutionMethod conv_algorithm = node.convolution_method(); - const bool fast_math = node.fast_math_hint() == FastMathHint::ENABLED; - - // Create and configure function (we assume that functions have been validated before creation) - std::shared_ptr mm = get_memory_manager(ctx, Target::CL); - std::unique_ptr func; - std::string func_name; - - if(conv_algorithm == ConvolutionMethod::WINOGRAD) - { - std::tie(func, func_name) = create_named_memory_managed_function( - std::string("CLWinogradConvolutionLayer"), mm, input, weights, biases, output, conv_info, ActivationLayerInfo(), fast_math); - } - else if(conv_algorithm == ConvolutionMethod::DIRECT) - { - std::tie(func, func_name) = create_named_function( - std::string("CLDirectConvolutionLayer"), input, weights, biases, output, conv_info); - } - else if(conv_algorithm == ConvolutionMethod::GEMM) - { - std::tie(func, func_name) = create_named_memory_managed_function(std::string("CLGEMMConvolutionLayer"), mm, - input, weights, biases, output, conv_info); - } - else - { - std::tie(func, func_name) = create_named_memory_managed_function(std::string("CLConvolutionLayer"), mm, - input, weights, biases, output, conv_info, WeightsInfo(), Size2D(1U, 1U), ActivationLayerInfo(), fast_math); - } + using TensorType = arm_compute::ICLTensor; + static Target TargetType; +}; - // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated " << func_name - << " Data Type: " << input->info()->data_type() - << " Input QuantInfo: " << input->info()->quantization_info() - << " Weights QuantInfo: " << weights->info()->quantization_info() - << " Input shape: " << input->info()->tensor_shape() - << " Weights shape: " << weights->info()->tensor_shape() - << " Output shape: " << output->info()->tensor_shape() - << std::endl); - return func; -} +Target CLTargetInfo::TargetType = Target::CL; -/** Create a backend deconvolution layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend deconvolution layer function - */ -std::unique_ptr create_deconvolution_layer(DeconvolutionLayerNode &node, GraphContext &ctx) +/** Collection of CL convolution functions */ +struct CLConvolutionLayerFunctions { - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating CL DeconvolutionLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 3); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Extract IO and info - ICLTensor *input = get_backing_tensor(node.input(0)); - ICLTensor *weights = get_backing_tensor(node.input(1)); - ICLTensor *biases = get_backing_tensor(node.input(2)); - ICLTensor *output = get_backing_tensor(node.output(0)); + using GenericConvolutionLayer = CLConvolutionLayer; + using GEMMConvolutionLayer = CLGEMMConvolutionLayer; + using DirectConvolutionLayer = CLDirectConvolutionLayer; + using WinogradConvolutionLayer = CLWinogradConvolutionLayer; +}; - const PadStrideInfo deconv_info = node.deconvolution_info(); - const Size2D inner_border = node.inner_border(); - - // Create and configure function (we assume that functions have been validated before creation) - std::shared_ptr mm = get_memory_manager(ctx, Target::CL); - std::unique_ptr func; - std::string func_name; - - std::tie(func, func_name) = create_named_memory_managed_function(std::string("CLDeconvolutionLayer"), mm, - input, weights, biases, output, - deconv_info, inner_border.x(), inner_border.y()); - - // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated " << func_name - << " Data Type: " << input->info()->data_type() - << " Input shape: " << input->info()->tensor_shape() - << " Weights shape: " << weights->info()->tensor_shape() - << " Output shape: " << output->info()->tensor_shape() - << std::endl); - return func; -} - -/** Create a backend layer depth concatenate function - * - * @param[in] node Node to create the backend function for - * - * @return Backend depth concatenate layer function - */ -std::unique_ptr create_depth_concatenate_layer(DepthConcatenateLayerNode &node) +/** Collection of CL depthwise convolution functions */ +struct CLDepthwiseConvolutionLayerFunctions { - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating CL DepthConcatenate node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Return nullptr if depth concatenate is switched off - if(!node.is_enabled()) - { - return nullptr; - } + using GenericDepthwiseConvolutionLayer = CLDepthwiseConvolutionLayer; + using DepthwiseConvolutionLayer3x3 = CLDepthwiseConvolutionLayer3x3; +}; - // Extract IO and info - std::vector inputs; - for(unsigned int i = 0; i < node.num_inputs(); ++i) - { - inputs.push_back(get_backing_tensor(node.input(i))); - } - ICLTensor *output = get_backing_tensor(node.output(0)); - - // Create and configure function - auto func = support::cpp14::make_unique(); - func->configure(inputs, output); - - // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated CLDepthConcatenateLayer" - << " Data Type: " << output->info()->data_type() - << " Shape: " << output->info()->tensor_shape() - << " Num Inputs: " << inputs.size() - << std::endl); - - return std::move(func); -} - -/** Create a backend layer depth-wise convolution function - * - * @param[in] node Node to create the backend function for - * - * @return Backend depth-wise convolution layer function - */ -std::unique_ptr create_depthwise_convolution_layer(DepthwiseConvolutionLayerNode &node) +/** Collection of CL element-wise functions */ +struct CLEltwiseFunctions { - ARM_COMPUTE_LOG_GRAPH_VERBOSE( - "Creating CL DepthwiseConvolutionLayer node with ID : " << node.id() << " and Name: " << node.name() - << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 3); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Extract IO and info - ICLTensor *input = get_backing_tensor(node.input(0)); - ICLTensor *weights = get_backing_tensor(node.input(1)); - ICLTensor *biases = get_backing_tensor(node.input(2)); - ICLTensor *output = get_backing_tensor(node.output(0)); - - if(is_data_type_quantized_asymmetric(input->info()->data_type())) - { - biases->info()->set_data_type(DataType::S32); - } - - const PadStrideInfo conv_info = node.convolution_info(); - const DepthwiseConvolutionMethod dwc_algorithm = node.depthwise_convolution_method(); - - // Create and configure function (we assume that functions have been validated before creation) - std::unique_ptr func; - std::string func_name; - if(dwc_algorithm == DepthwiseConvolutionMethod::OPTIMIZED_3x3) - { - std::tie(func, func_name) = create_named_function( - std::string("CLDepthwiseConvolutionLayer3x3"), input, weights, biases, output, conv_info); - } - else - { - std::tie(func, func_name) = create_named_function( - std::string("CLDepthwiseConvolutionLayer"), input, weights, biases, output, conv_info); - } - - // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated " << func_name - << " Data Type: " << input->info()->data_type() - << " Input QuantInfo: " << input->info()->quantization_info() - << " Weights QuantInfo: " << weights->info()->quantization_info() - << " Input shape: " << input->info()->tensor_shape() - << " Weights shape: " << weights->info()->tensor_shape() - << " Output shape: " << output->info()->tensor_shape() - << std::endl); - return func; -} - -/** Create a backend element-wise operation layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend element-wise operation layer function - */ -std::unique_ptr create_eltwise_layer(EltwiseLayerNode &node) -{ - ARM_COMPUTE_LOG_GRAPH_VERBOSE( - "Creating CL EltwiseLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 2); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Extract IO and info - ICLTensor *input1 = get_backing_tensor(node.input(0)); - ICLTensor *input2 = get_backing_tensor(node.input(1)); - ICLTensor *output = get_backing_tensor(node.output(0)); - const EltwiseOperation eltwise_op = node.eltwise_operation(); - const ConvertPolicy convert_policy = node.convert_policy(); - ARM_COMPUTE_ERROR_ON(input1 == nullptr); - ARM_COMPUTE_ERROR_ON(input2 == nullptr); - ARM_COMPUTE_ERROR_ON(output == nullptr); - - std::unique_ptr func = nullptr; - std::string func_name; - if(eltwise_op == EltwiseOperation::ADD) - { - std::tie(func, func_name) = create_named_function(std::string("CLArithmeticAddition"), - input1, input2, output, - convert_policy); - } - else if(eltwise_op == EltwiseOperation::SUB) - { - std::tie(func, func_name) = create_named_function( - std::string("CLArithmeticSubtraction"), input1, input2, output, convert_policy); - } - else if(eltwise_op == EltwiseOperation::MUL) - { - std::tie(func, func_name) = create_named_function( - std::string("CLPixelWiseMultiplication"), input1, input2, output, 1.f, convert_policy, - node.rounding_policy()); - } - else - { - ARM_COMPUTE_ERROR("Unsupported element-wise operation!"); - } - - // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated " << func_name - << " Data Type: " << input1->info()->data_type() - << " Shape : " << input1->info()->tensor_shape() - << std::endl); - - return func; -} - -/** Create a backend flatten layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend flatten layer function - */ -std::unique_ptr create_flatten_layer(FlattenLayerNode &node) -{ - ARM_COMPUTE_LOG_GRAPH_VERBOSE( - "Creating CL FlattenLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Extract IO and info - ICLTensor *input = get_backing_tensor(node.input(0)); - ICLTensor *output = get_backing_tensor(node.output(0)); - - // Create and configure function - auto func = support::cpp14::make_unique(); - func->configure(input, output); - ARM_COMPUTE_ERROR_ON(input == nullptr); - ARM_COMPUTE_ERROR_ON(output == nullptr); - - // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated CLFlattenLayer" - << " Data Type: " << input->info()->data_type() - << " Input shape: " << input->info()->tensor_shape() - << " Output shape: " << output->info()->tensor_shape() - << std::endl); - - return std::move(func); -} - -/** Create a backend fully connected layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend fully connected layer function - */ -std::unique_ptr create_fully_connected_layer(FullyConnectedLayerNode &node, GraphContext &ctx) -{ - ARM_COMPUTE_LOG_GRAPH_VERBOSE( - "Creating CL FullyConnectedLayer node with ID : " << node.id() << " and Name: " << node.name() - << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 3); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Extract IO and info - ICLTensor *input = get_backing_tensor(node.input(0)); - ICLTensor *weights = get_backing_tensor(node.input(1)); - ICLTensor *biases = get_backing_tensor(node.input(2)); - ICLTensor *output = get_backing_tensor(node.output(0)); - - // Create and configure function - auto func = support::cpp14::make_unique(get_memory_manager(ctx, Target::CL)); - func->configure(input, weights, biases, output); - ARM_COMPUTE_ERROR_ON(input == nullptr); - ARM_COMPUTE_ERROR_ON(weights == nullptr); - ARM_COMPUTE_ERROR_ON(output == nullptr); - - // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated CLFullyConnectedLayer" - << " Data Type: " << input->info()->data_type() - << " Input shape: " << input->info()->tensor_shape() - << " Weights shape: " << weights->info()->tensor_shape() - << " Biases Shape: " << biases->info()->tensor_shape() - << " Output shape: " << output->info()->tensor_shape() - << std::endl); - - return std::move(func); -} - -/** Create a backend normalization layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend normalization layer function - */ -std::unique_ptr create_normalization_layer(NormalizationLayerNode &node) -{ - ARM_COMPUTE_LOG_GRAPH_VERBOSE( - "Creating CL NormalizationLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Extract IO and info - ICLTensor *input = get_backing_tensor(node.input(0)); - ICLTensor *output = get_backing_tensor(node.output(0)); - const NormalizationLayerInfo norm_info = node.normalization_info(); - ARM_COMPUTE_ERROR_ON(input == nullptr); - ARM_COMPUTE_ERROR_ON(output == nullptr); - - // Create and configure function - auto func = support::cpp14::make_unique(); - func->configure(input, output, norm_info); - - // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated CLNormalizationLayer" - << " Data Type: " << input->info()->data_type() - << " Input shape: " << input->info()->tensor_shape() - << " Output shape: " << output->info()->tensor_shape() - << " Normalization info: " << norm_info.type() - << std::endl); - - return std::move(func); -} - -/** Create a backend pooling layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend pooling layer function - */ -std::unique_ptr create_pooling_layer(PoolingLayerNode &node) -{ - ARM_COMPUTE_LOG_GRAPH_VERBOSE( - "Creating CL PoolingLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Extract IO and info - ICLTensor *input = get_backing_tensor(node.input(0)); - ICLTensor *output = get_backing_tensor(node.output(0)); - const PoolingLayerInfo pool_info = node.pooling_info(); - ARM_COMPUTE_ERROR_ON(input == nullptr); - ARM_COMPUTE_ERROR_ON(output == nullptr); - - // Create and configure function - auto func = support::cpp14::make_unique(); - func->configure(input, output, pool_info); - - // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated CLPoolingLayer" - << " Data Type: " << input->info()->data_type() - << " Input shape: " << input->info()->tensor_shape() - << " Output shape: " << output->info()->tensor_shape() - << " Pooling info: " << pool_info.pool_type() - << std::endl); - - return std::move(func); -} - -/** Create a backend reshape layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend reshape layer function - */ -std::unique_ptr create_reshape_layer(ReshapeLayerNode &node) -{ - ARM_COMPUTE_LOG_GRAPH_VERBOSE( - "Creating CL ReshapeLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Extract IO and info - ICLTensor *input = get_backing_tensor(node.input(0)); - ICLTensor *output = get_backing_tensor(node.output(0)); - ARM_COMPUTE_ERROR_ON(input == nullptr); - ARM_COMPUTE_ERROR_ON(output == nullptr); - - // Create and configure function - auto func = support::cpp14::make_unique(); - func->configure(input, output); - - // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated CLReshapeLayer" - << " Data Type: " << input->info()->data_type() - << " Input shape: " << input->info()->tensor_shape() - << " Output shape: " << output->info()->tensor_shape() - << std::endl); - - return std::move(func); -} - -/** Create a backend resize layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend resize layer function - */ -std::unique_ptr create_resize_layer(ResizeLayerNode &node) -{ - ARM_COMPUTE_LOG_GRAPH_VERBOSE( - "Creating CL Resize node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Extract IO and info - ICLTensor *input = get_backing_tensor(node.input(0)); - ICLTensor *output = get_backing_tensor(node.output(0)); - ARM_COMPUTE_ERROR_ON(input == nullptr); - ARM_COMPUTE_ERROR_ON(output == nullptr); - const InterpolationPolicy policy = node.policy(); - - // Create and configure function - auto func = support::cpp14::make_unique(); - func->configure(input, output, policy, BorderMode::CONSTANT); - - // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated CLScale" - << " Data Type: " << input->info()->data_type() - << " Input shape: " << input->info()->tensor_shape() - << " Output shape: " << output->info()->tensor_shape() - << " Interpolation: " << policy - << std::endl); - - return std::move(func); -} - -/** Create a backend softmax layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend softmax layer function - */ -std::unique_ptr create_softmax_layer(SoftmaxLayerNode &node, GraphContext &ctx) -{ - ARM_COMPUTE_LOG_GRAPH_VERBOSE( - "Creating CL SoftmaxLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Extract IO and info - ICLTensor *input = get_backing_tensor(node.input(0)); - ICLTensor *output = get_backing_tensor(node.output(0)); - const float beta = node.beta(); - ARM_COMPUTE_ERROR_ON(input == nullptr); - ARM_COMPUTE_ERROR_ON(output == nullptr); - - // Create and configure function - auto func = support::cpp14::make_unique(get_memory_manager(ctx, Target::CL)); - func->configure(input, output, beta); - - // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated CLSoftmaxLayer" - << " Data Type: " << input->info()->data_type() - << " Input shape: " << input->info()->tensor_shape() - << " Output shape: " << output->info()->tensor_shape() - << std::endl); - - return std::move(func); -} -} // namespace + using Addition = CLArithmeticAddition; + using Subtraction = CLArithmeticSubtraction; + using Multiplication = CLPixelWiseMultiplication; +}; std::unique_ptr CLFunctionFactory::create(INode *node, GraphContext &ctx) { @@ -683,35 +80,35 @@ std::unique_ptr CLFunctionFactory::create(INode *node, GraphContext & switch(type) { case NodeType::ActivationLayer: - return create_activation_layer(*polymorphic_downcast(node)); + return detail::create_activation_layer(*polymorphic_downcast(node)); case NodeType::BatchNormalizationLayer: - return create_batch_normalization_layer(*polymorphic_downcast(node)); + return detail::create_batch_normalization_layer(*polymorphic_downcast(node)); case NodeType::ChannelShuffleLayer: - return create_channel_shuffle_layer(*polymorphic_downcast(node)); + return detail::create_channel_shuffle_layer(*polymorphic_downcast(node)); case NodeType::ConvolutionLayer: - return create_convolution_layer(*polymorphic_downcast(node), ctx); + return detail::create_convolution_layer(*polymorphic_downcast(node), ctx); case NodeType::DeconvolutionLayer: - return create_deconvolution_layer(*polymorphic_downcast(node), ctx); + return detail::create_deconvolution_layer(*polymorphic_downcast(node), ctx); case NodeType::DepthConcatenateLayer: - return create_depth_concatenate_layer(*polymorphic_downcast(node)); + return detail::create_depth_concatenate_layer(*polymorphic_downcast(node)); case NodeType::DepthwiseConvolutionLayer: - return create_depthwise_convolution_layer(*polymorphic_downcast(node)); + return detail::create_depthwise_convolution_layer(*polymorphic_downcast(node)); case NodeType::EltwiseLayer: - return create_eltwise_layer(*polymorphic_downcast(node)); + return detail::create_eltwise_layer(*polymorphic_downcast(node)); case NodeType::FlattenLayer: - return create_flatten_layer(*polymorphic_downcast(node)); + return detail::create_flatten_layer(*polymorphic_downcast(node)); case NodeType::FullyConnectedLayer: - return create_fully_connected_layer(*polymorphic_downcast(node), ctx); + return detail::create_fully_connected_layer(*polymorphic_downcast(node), ctx); case NodeType::NormalizationLayer: - return create_normalization_layer(*polymorphic_downcast(node)); + return detail::create_normalization_layer(*polymorphic_downcast(node), ctx); case NodeType::PoolingLayer: - return create_pooling_layer(*polymorphic_downcast(node)); + return detail::create_pooling_layer(*polymorphic_downcast(node)); case NodeType::ReshapeLayer: - return create_reshape_layer(*polymorphic_downcast(node)); + return detail::create_reshape_layer(*polymorphic_downcast(node)); case NodeType::ResizeLayer: - return create_resize_layer(*polymorphic_downcast(node)); + return detail::create_resize_layer(*polymorphic_downcast(node)); case NodeType::SoftmaxLayer: - return create_softmax_layer(*polymorphic_downcast(node), ctx); + return detail::create_softmax_layer(*polymorphic_downcast(node), ctx); default: return nullptr; } diff --git a/src/graph/backends/GLES/GCFunctionsFactory.cpp b/src/graph/backends/GLES/GCFunctionsFactory.cpp index d53daf1109..e6bd5a5f02 100644 --- a/src/graph/backends/GLES/GCFunctionsFactory.cpp +++ b/src/graph/backends/GLES/GCFunctionsFactory.cpp @@ -25,16 +25,9 @@ #include "arm_compute/core/utils/misc/Cast.h" #include "arm_compute/graph/Graph.h" -#include "arm_compute/graph/GraphContext.h" -#include "arm_compute/graph/Logger.h" -#include "arm_compute/graph/TypePrinter.h" -#include "arm_compute/graph/Types.h" -#include "arm_compute/graph/backends/Utils.h" -#include "arm_compute/graph/nodes/Nodes.h" +#include "arm_compute/graph/backends/FunctionHelpers.h" #include "arm_compute/runtime/GLES_COMPUTE/GCFunctions.h" -#include "support/ToolchainSupport.h" - using namespace arm_compute::utils::cast; namespace arm_compute @@ -43,121 +36,48 @@ namespace graph { namespace backends { -namespace -{ -/** Returns backing tensor of a given tensor - * - * @param[in] tensor Tensor to extract the backing tensor from - * - * @return Backing tensor if present else nullptr - */ -arm_compute::IGCTensor *get_backing_tensor(arm_compute::graph::Tensor *tensor) +/** Target specific information structure used to pass information to the layer templates */ +struct GCTargetInfo { - arm_compute::IGCTensor *backing_tensor = nullptr; - if(tensor != nullptr) - { - ARM_COMPUTE_ERROR_ON(tensor->desc().target != arm_compute::graph::Target::GC); - // Get backing tensor handle - ITensorHandle *tensor_handle = tensor->handle(); - // Get backing tensor - backing_tensor = (tensor_handle != nullptr) ? polymorphic_cast(&tensor_handle->tensor()) : nullptr; - } + using TensorType = arm_compute::IGCTensor; + static Target TargetType; +}; - return backing_tensor; -} +Target GCTargetInfo::TargetType = Target::GC; -/** Create a backend activation layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend activation layer function - */ -std::unique_ptr create_activation_layer(ActivationLayerNode &node) +/** Collection of GC convolution functions */ +struct GCConvolutionLayerFunctions { - ARM_COMPUTE_LOG_GRAPH_VERBOSE( - "Creating GC ActivationLayerNode node with ID : " << node.id() << " and Name: " << node.name() - << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Extract IO and info - IGCTensor *input = get_backing_tensor(node.input(0)); - IGCTensor *output = get_backing_tensor(node.output(0)); - const ActivationLayerInfo act_info = node.activation_info(); - - // Create function - auto func = support::cpp14::make_unique(); - func->configure(input, output, act_info); - - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated GCActivationLayer" - << " Data Type: " << input->info()->data_type() - << " Shape: " << input->info()->tensor_shape() - << " Activation function: " << act_info.activation() - << " a: " << act_info.a() - << " b: " << act_info.b() - << " InPlace : " << is_in_place_operation(input, output) - << std::endl); + using GenericConvolutionLayer = GCConvolutionLayer; + using GEMMConvolutionLayer = GCConvolutionLayer; + using DirectConvolutionLayer = GCDirectConvolutionLayer; +}; - return std::move(func); -} - -/** Create a backend batch normalization layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend batch normalization layer function - */ -std::unique_ptr create_batch_normalization_layer(BatchNormalizationLayerNode &node) +/** Collection of GC depthwise convolution functions */ +struct GCDepthwiseConvolutionLayerFunctions { - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating GC BatchNormalization node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - - // TODO (geopin01) : Var and mean are compulsory, switch function to accept nullptr as beta and/or gamma - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 5); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Extract IO and info - IGCTensor *input = get_backing_tensor(node.input(0)); - IGCTensor *mean = get_backing_tensor(node.input(1)); - IGCTensor *var = get_backing_tensor(node.input(2)); - IGCTensor *beta = get_backing_tensor(node.input(3)); - IGCTensor *gamma = get_backing_tensor(node.input(4)); - IGCTensor *output = get_backing_tensor(node.output(0)); - const float epsilon = node.epsilon(); - const ActivationLayerInfo fused_act = node.fused_activation(); + using DepthwiseConvolutionLayer3x3 = GCDepthwiseConvolutionLayer3x3; +}; - // Create and configure function - auto func = support::cpp14::make_unique(); - func->configure(input, output, mean, var, beta, gamma, epsilon, fused_act); - - // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated GCBatchNormalizationLayer" - << " Data Type: " << input->info()->data_type() - << " Shape: " << input->info()->tensor_shape() - << " Epsilon: " << epsilon << " " - << (fused_act.enabled() ? to_string(fused_act.activation()) : "") - << " InPlace : " << is_in_place_operation(input, output) - << std::endl); - - return std::move(func); -} +/** Collection of GC element-wise functions */ +struct GCEltwiseFunctions +{ + using Addition = GCArithmeticAddition; + using Multiplication = GCPixelWiseMultiplication; +}; -/** Create a backend convolution layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend convolution layer function - */ -std::unique_ptr create_convolution_layer(ConvolutionLayerNode &node, GraphContext &ctx) +namespace detail { - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating GC ConvolutionLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 3); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); +template <> +std::unique_ptr create_convolution_layer(ConvolutionLayerNode &node, GraphContext &ctx) +{ + validate_node(node, 3 /* expected inputs */, 1 /* expected outputs */); // Extract IO and info - IGCTensor *input = get_backing_tensor(node.input(0)); - IGCTensor *weights = get_backing_tensor(node.input(1)); - IGCTensor *biases = get_backing_tensor(node.input(2)); - IGCTensor *output = get_backing_tensor(node.output(0)); + GCTargetInfo::TensorType *input = get_backing_tensor(node.input(0)); + GCTargetInfo::TensorType *weights = get_backing_tensor(node.input(1)); + GCTargetInfo::TensorType *biases = get_backing_tensor(node.input(2)); + GCTargetInfo::TensorType *output = get_backing_tensor(node.output(0)); if(is_data_type_quantized_asymmetric(input->info()->data_type())) { @@ -168,19 +88,21 @@ std::unique_ptr create_convolution_layer(ConvolutionLayerNode &node, const ConvolutionMethod conv_algorithm = node.convolution_method(); // Create and configure function (we assume that functions have been validated before creation) - std::shared_ptr mm = get_memory_manager(ctx, Target::GC); + std::shared_ptr mm = get_memory_manager(ctx, GCTargetInfo::TargetType); std::unique_ptr func; std::string func_name; if(conv_algorithm == ConvolutionMethod::DIRECT) { - std::tie(func, func_name) = create_named_function( - std::string("GCDirectConvolutionLayer"), input, weights, biases, output, conv_info); + std::tie(func, func_name) = create_named_function( + std::string("DirectConvolutionLayer"), + input, weights, biases, output, conv_info); } else { - std::tie(func, func_name) = create_named_memory_managed_function(std::string("GCConvolutionLayer"), mm, - input, weights, biases, output, conv_info); + std::tie(func, func_name) = create_named_memory_managed_function( + std::string("ConvolutionLayer"), mm, + input, weights, biases, output, conv_info); } // Log info @@ -195,64 +117,16 @@ std::unique_ptr create_convolution_layer(ConvolutionLayerNode &node, return func; } -/** Create a backend layer depth concatenate function - * - * @param[in] node Node to create the backend function for - * - * @return Backend depth concatenate layer function - */ -std::unique_ptr create_depth_concatenate_layer(DepthConcatenateLayerNode &node) -{ - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating GC DepthConcatenate node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Return nullptr if depth concatenate is switched off - if(!node.is_enabled()) - { - return nullptr; - } - - // Extract IO and info - std::vector inputs; - for(unsigned int i = 0; i < node.num_inputs(); ++i) - { - inputs.push_back(get_backing_tensor(node.input(i))); - } - IGCTensor *output = get_backing_tensor(node.output(0)); - - // Create and configure function - auto func = support::cpp14::make_unique(); - func->configure(inputs, output); - - // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated GCDepthConcatenateLayer" - << " Data Type: " << output->info()->data_type() - << " Shape: " << output->info()->tensor_shape() - << " Num Inputs: " << inputs.size() - << std::endl); - - return std::move(func); -} - -/** Create a backend layer depth-wise convolution function - * - * @param[in] node Node to create the backend function for - * - * @return Backend depth-wise convolution layer function - */ -std::unique_ptr create_depthwise_convolution_layer(DepthwiseConvolutionLayerNode &node) +template <> +std::unique_ptr create_depthwise_convolution_layer(DepthwiseConvolutionLayerNode &node) { - ARM_COMPUTE_LOG_GRAPH_VERBOSE( - "Creating GC DepthwiseConvolutionLayer node with ID : " << node.id() << " and Name: " << node.name() - << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 3); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); + validate_node(node, 3 /* expected inputs */, 1 /* expected outputs */); // Extract IO and info - IGCTensor *input = get_backing_tensor(node.input(0)); - IGCTensor *weights = get_backing_tensor(node.input(1)); - IGCTensor *biases = get_backing_tensor(node.input(2)); - IGCTensor *output = get_backing_tensor(node.output(0)); + GCTargetInfo::TensorType *input = get_backing_tensor(node.input(0)); + GCTargetInfo::TensorType *weights = get_backing_tensor(node.input(1)); + GCTargetInfo::TensorType *biases = get_backing_tensor(node.input(2)); + GCTargetInfo::TensorType *output = get_backing_tensor(node.output(0)); if(is_data_type_quantized_asymmetric(input->info()->data_type())) { @@ -267,8 +141,9 @@ std::unique_ptr create_depthwise_convolution_layer(DepthwiseConvoluti std::string func_name; if(dwc_algorithm == DepthwiseConvolutionMethod::OPTIMIZED_3x3) { - std::tie(func, func_name) = create_named_function( - std::string("GCDepthwiseConvolutionLayer3x3"), input, weights, biases, output, conv_info); + std::tie(func, func_name) = create_named_function( + std::string("DepthwiseConvolutionLayer3x3"), + input, weights, biases, output, conv_info); } else { @@ -277,6 +152,7 @@ std::unique_ptr create_depthwise_convolution_layer(DepthwiseConvoluti // Log info ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated " << func_name + << " Target " << GCTargetInfo::TargetType << " Data Type: " << input->info()->data_type() << " Input QuantInfo: " << input->info()->quantization_info() << " Weights QuantInfo: " << weights->info()->quantization_info() @@ -287,13 +163,8 @@ std::unique_ptr create_depthwise_convolution_layer(DepthwiseConvoluti return func; } -/** Create a backend element-wise operation layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend element-wise operation layer function - */ -std::unique_ptr create_eltwise_layer(EltwiseLayerNode &node) +template <> +std::unique_ptr create_eltwise_layer(EltwiseLayerNode &node) { ARM_COMPUTE_LOG_GRAPH_VERBOSE( "Creating GC EltwiseLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); @@ -301,11 +172,11 @@ std::unique_ptr create_eltwise_layer(EltwiseLayerNode &node) ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); // Extract IO and info - IGCTensor *input1 = get_backing_tensor(node.input(0)); - IGCTensor *input2 = get_backing_tensor(node.input(1)); - IGCTensor *output = get_backing_tensor(node.output(0)); - const EltwiseOperation eltwise_op = node.eltwise_operation(); - const ConvertPolicy convert_policy = node.convert_policy(); + GCTargetInfo::TensorType *input1 = get_backing_tensor(node.input(0)); + GCTargetInfo::TensorType *input2 = get_backing_tensor(node.input(1)); + GCTargetInfo::TensorType *output = get_backing_tensor(node.output(0)); + const EltwiseOperation eltwise_op = node.eltwise_operation(); + const ConvertPolicy convert_policy = node.convert_policy(); ARM_COMPUTE_ERROR_ON(input1 == nullptr); ARM_COMPUTE_ERROR_ON(input2 == nullptr); ARM_COMPUTE_ERROR_ON(output == nullptr); @@ -314,9 +185,9 @@ std::unique_ptr create_eltwise_layer(EltwiseLayerNode &node) std::string func_name; if(eltwise_op == EltwiseOperation::ADD) { - std::tie(func, func_name) = create_named_function(std::string("GCArithmeticAddition"), - input1, input2, output, - convert_policy); + std::tie(func, func_name) = create_named_function( + std::string("GCArithmeticAddition"), + input1, input2, output, convert_policy); } else if(eltwise_op == EltwiseOperation::SUB) { @@ -324,8 +195,9 @@ std::unique_ptr create_eltwise_layer(EltwiseLayerNode &node) } else if(eltwise_op == EltwiseOperation::MUL) { - std::tie(func, func_name) = create_named_function( - std::string("GCPixelWiseMultiplication"), input1, input2, output, 1.f); + std::tie(func, func_name) = create_named_function( + std::string("PixelWiseMultiplication"), + input1, input2, output, 1.f); } else { @@ -333,157 +205,16 @@ std::unique_ptr create_eltwise_layer(EltwiseLayerNode &node) } // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated " << func_name + ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated " << node.type() + << " Target " << GCTargetInfo::TargetType + << " Operation " << func_name << " Data Type: " << input1->info()->data_type() << " Shape : " << input1->info()->tensor_shape() << std::endl); return func; } - -/** Create a backend fully connected layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend fully connected layer function - */ -std::unique_ptr create_fully_connected_layer(FullyConnectedLayerNode &node, GraphContext &ctx) -{ - ARM_COMPUTE_LOG_GRAPH_VERBOSE( - "Creating GC FullyConnectedLayer node with ID : " << node.id() << " and Name: " << node.name() - << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 3); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Extract IO and info - IGCTensor *input = get_backing_tensor(node.input(0)); - IGCTensor *weights = get_backing_tensor(node.input(1)); - IGCTensor *biases = get_backing_tensor(node.input(2)); - IGCTensor *output = get_backing_tensor(node.output(0)); - - // Create and configure function - auto func = support::cpp14::make_unique(get_memory_manager(ctx, Target::GC)); - func->configure(input, weights, biases, output); - ARM_COMPUTE_ERROR_ON(input == nullptr); - ARM_COMPUTE_ERROR_ON(weights == nullptr); - ARM_COMPUTE_ERROR_ON(output == nullptr); - - // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated GCFullyConnectedLayer" - << " Data Type: " << input->info()->data_type() - << " Input shape: " << input->info()->tensor_shape() - << " Weights shape: " << weights->info()->tensor_shape() - << " Biases Shape: " << biases->info()->tensor_shape() - << " Output shape: " << output->info()->tensor_shape() - << std::endl); - - return std::move(func); -} - -/** Create a backend normalization layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend normalization layer function - */ -std::unique_ptr create_normalization_layer(NormalizationLayerNode &node) -{ - ARM_COMPUTE_LOG_GRAPH_VERBOSE( - "Creating GC NormalizationLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Extract IO and info - IGCTensor *input = get_backing_tensor(node.input(0)); - IGCTensor *output = get_backing_tensor(node.output(0)); - const NormalizationLayerInfo norm_info = node.normalization_info(); - ARM_COMPUTE_ERROR_ON(input == nullptr); - ARM_COMPUTE_ERROR_ON(output == nullptr); - - // Create and configure function - auto func = support::cpp14::make_unique(); - func->configure(input, output, norm_info); - - // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated GCNormalizationLayer" - << " Data Type: " << input->info()->data_type() - << " Input shape: " << input->info()->tensor_shape() - << " Output shape: " << output->info()->tensor_shape() - << " Normalization info: " << norm_info.type() - << std::endl); - - return std::move(func); -} - -/** Create a backend pooling layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend pooling layer function - */ -std::unique_ptr create_pooling_layer(PoolingLayerNode &node) -{ - ARM_COMPUTE_LOG_GRAPH_VERBOSE( - "Creating GC PoolingLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Extract IO and info - IGCTensor *input = get_backing_tensor(node.input(0)); - IGCTensor *output = get_backing_tensor(node.output(0)); - const PoolingLayerInfo pool_info = node.pooling_info(); - ARM_COMPUTE_ERROR_ON(input == nullptr); - ARM_COMPUTE_ERROR_ON(output == nullptr); - - // Create and configure function - auto func = support::cpp14::make_unique(); - func->configure(input, output, pool_info); - - // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated GCPoolingLayer" - << " Data Type: " << input->info()->data_type() - << " Input shape: " << input->info()->tensor_shape() - << " Output shape: " << output->info()->tensor_shape() - << " Pooling info: " << pool_info.pool_type() - << std::endl); - - return std::move(func); -} - -/** Create a backend softmax layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend softmax layer function - */ -std::unique_ptr create_softmax_layer(SoftmaxLayerNode &node, GraphContext &ctx) -{ - ARM_COMPUTE_LOG_GRAPH_VERBOSE( - "Creating GC SoftmaxLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Extract IO and info - IGCTensor *input = get_backing_tensor(node.input(0)); - IGCTensor *output = get_backing_tensor(node.output(0)); - const float beta = node.beta(); - ARM_COMPUTE_ERROR_ON(input == nullptr); - ARM_COMPUTE_ERROR_ON(output == nullptr); - - // Create and configure function - auto func = support::cpp14::make_unique(get_memory_manager(ctx, Target::CL)); - func->configure(input, output, beta); - - // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated GCSoftmaxLayer" - << " Data Type: " << input->info()->data_type() - << " Input shape: " << input->info()->tensor_shape() - << " Output shape: " << output->info()->tensor_shape() - << std::endl); - - return std::move(func); -} -} // namespace +} //namespace detail std::unique_ptr GCFunctionFactory::create(INode *node, GraphContext &ctx) { @@ -496,25 +227,27 @@ std::unique_ptr GCFunctionFactory::create(INode *node, GraphContext & switch(type) { case NodeType::ActivationLayer: - return create_activation_layer(*polymorphic_downcast(node)); + return detail::create_activation_layer(*polymorphic_downcast(node)); case NodeType::BatchNormalizationLayer: - return create_batch_normalization_layer(*polymorphic_downcast(node)); + return detail::create_batch_normalization_layer(*polymorphic_downcast(node)); case NodeType::ConvolutionLayer: - return create_convolution_layer(*polymorphic_downcast(node), ctx); + return detail::create_convolution_layer(*polymorphic_downcast(node), ctx); case NodeType::DepthConcatenateLayer: - return create_depth_concatenate_layer(*polymorphic_downcast(node)); + return detail::create_depth_concatenate_layer(*polymorphic_downcast(node)); case NodeType::DepthwiseConvolutionLayer: - return create_depthwise_convolution_layer(*polymorphic_downcast(node)); + return detail::create_depthwise_convolution_layer(*polymorphic_downcast(node)); case NodeType::EltwiseLayer: - return create_eltwise_layer(*polymorphic_downcast(node)); + return detail::create_eltwise_layer(*polymorphic_downcast(node)); case NodeType::FullyConnectedLayer: - return create_fully_connected_layer(*polymorphic_downcast(node), ctx); + return detail::create_fully_connected_layer(*polymorphic_downcast(node), ctx); case NodeType::NormalizationLayer: - return create_normalization_layer(*polymorphic_downcast(node)); + return detail::create_normalization_layer(*polymorphic_downcast(node), ctx); case NodeType::PoolingLayer: - return create_pooling_layer(*polymorphic_downcast(node)); + return detail::create_pooling_layer(*polymorphic_downcast(node)); + case NodeType::ResizeLayer: + return detail::create_resize_layer(*polymorphic_downcast(node)); case NodeType::SoftmaxLayer: - return create_softmax_layer(*polymorphic_downcast(node), ctx); + return detail::create_softmax_layer(*polymorphic_downcast(node), ctx); default: return nullptr; } diff --git a/src/graph/backends/NEON/NEFunctionFactory.cpp b/src/graph/backends/NEON/NEFunctionFactory.cpp index 8376feb265..3b7417da3f 100644 --- a/src/graph/backends/NEON/NEFunctionFactory.cpp +++ b/src/graph/backends/NEON/NEFunctionFactory.cpp @@ -28,6 +28,7 @@ #include "arm_compute/graph/GraphContext.h" #include "arm_compute/graph/Logger.h" #include "arm_compute/graph/TypePrinter.h" +#include "arm_compute/graph/backends/FunctionHelpers.h" #include "arm_compute/graph/backends/Utils.h" #include "arm_compute/graph/nodes/Nodes.h" #include "arm_compute/runtime/NEON/NEFunctions.h" @@ -41,109 +42,53 @@ namespace graph { namespace backends { -namespace +/** Target specific information structure used to pass information to the layer templates */ +struct NETargetInfo { -/** Returns backing tensor of a given tensor - * - * @param[in] tensor Tensor to extract the backing tensor from - * - * @return Backing tensor if present else nullptr - */ -arm_compute::ITensor *get_backing_tensor(arm_compute::graph::Tensor *tensor) -{ - return ((tensor == nullptr) || (tensor->handle() == nullptr)) ? nullptr : &tensor->handle()->tensor(); -} + using TensorType = arm_compute::ITensor; + static Target TargetType; +}; -/** Create a backend activation layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend activation layer function - */ -std::unique_ptr create_activation_layer(ActivationLayerNode &node) -{ - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating NEON ActivationLayerNode node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); +Target NETargetInfo::TargetType = Target::NEON; - // Extract IO and info - ITensor *input = get_backing_tensor(node.input(0)); - ITensor *output = get_backing_tensor(node.output(0)); - const ActivationLayerInfo act_info = node.activation_info(); - - // Create function - auto func = support::cpp14::make_unique(); - func->configure(input, output, act_info); - - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated NEActivationLayer" - << " Data Type: " << input->info()->data_type() - << " Shape: " << input->info()->tensor_shape() - << " Activation function: " << act_info.activation() - << " a: " << act_info.a() - << " b: " << act_info.b() - << " InPlace : " << is_in_place_operation(input, output) - << std::endl); - - return std::move(func); -} - -/** Create a backend batch normalization layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend batch normalization layer function - */ -std::unique_ptr create_batch_normalization_layer(BatchNormalizationLayerNode &node) +/** Collection of CL convolution functions */ +struct NEConvolutionLayerFunctions { - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating NEON BatchNormalization node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - - // TODO (geopin01) : Var and mean are compulsory, switch function to accept nullptr as beta and/or gamma - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 5); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Extract IO and info - ITensor *input = get_backing_tensor(node.input(0)); - ITensor *mean = get_backing_tensor(node.input(1)); - ITensor *var = get_backing_tensor(node.input(2)); - ITensor *beta = get_backing_tensor(node.input(3)); - ITensor *gamma = get_backing_tensor(node.input(4)); - ITensor *output = get_backing_tensor(node.output(0)); - const float epsilon = node.epsilon(); - const ActivationLayerInfo fused_act = node.fused_activation(); - - // Create and configure function - auto func = support::cpp14::make_unique(); - func->configure(input, output, mean, var, beta, gamma, epsilon, fused_act); + using GenericConvolutionLayer = NEConvolutionLayer; + using GEMMConvolutionLayer = NEGEMMConvolutionLayer; + using DirectConvolutionLayer = NEDirectConvolutionLayer; + using WinogradConvolutionLayer = NEWinogradConvolutionLayer; +}; - // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated NEBatchNormalizationLayer" - << " Data Type: " << input->info()->data_type() - << " Shape: " << input->info()->tensor_shape() - << " Epsilon: " << epsilon << " " - << (fused_act.enabled() ? to_string(fused_act.activation()) : "") - << " InPlace : " << is_in_place_operation(input, output) - << std::endl); +/** Collection of CL depthwise convolution functions */ +struct NEDepthwiseConvolutionLayerFunctions +{ + using GenericDepthwiseConvolutionLayer = NEDepthwiseConvolutionLayer; + using DepthwiseConvolutionLayer3x3 = NEDepthwiseConvolutionLayer3x3; +}; - return std::move(func); -} +/** Collection of CL element-wise functions */ +struct NEEltwiseFunctions +{ + using Addition = NEArithmeticAddition; + using Subtraction = NEArithmeticSubtraction; + using Multiplication = NEPixelWiseMultiplication; +}; -/** Create a backend convolution layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend convolution layer function - */ -std::unique_ptr create_convolution_layer(ConvolutionLayerNode &node, GraphContext &ctx) +namespace detail { - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating NEON ConvolutionLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 3); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); +// Specialize functions +template <> +std::unique_ptr create_convolution_layer(ConvolutionLayerNode &node, + GraphContext &ctx) +{ + validate_node(node, 3 /* expected inputs */, 1 /* expected outputs */); // Extract IO and info - ITensor *input = get_backing_tensor(node.input(0)); - ITensor *weights = get_backing_tensor(node.input(1)); - ITensor *biases = get_backing_tensor(node.input(2)); - ITensor *output = get_backing_tensor(node.output(0)); + NETargetInfo::TensorType *input = get_backing_tensor(node.input(0)); + NETargetInfo::TensorType *weights = get_backing_tensor(node.input(1)); + NETargetInfo::TensorType *biases = get_backing_tensor(node.input(2)); + NETargetInfo::TensorType *output = get_backing_tensor(node.output(0)); if(is_data_type_quantized_asymmetric(input->info()->data_type())) { @@ -159,27 +104,28 @@ std::unique_ptr create_convolution_layer(ConvolutionLayerNode &node, std::string func_name; if(conv_algorithm == ConvolutionMethod::DIRECT) { - std::tie(func, func_name) = create_named_memory_managed_function(std::string("NEDirectConvolutionLayer"), mm, - input, weights, biases, output, conv_info); + std::tie(func, func_name) = create_named_memory_managed_function( + std::string("DirectConvolutionLayer"), mm, input, weights, biases, output, conv_info); } else if(conv_algorithm == ConvolutionMethod::GEMM) { - std::tie(func, func_name) = create_named_memory_managed_function(std::string("NEGEMMConvolutionLayer"), mm, - input, weights, biases, output, conv_info); + std::tie(func, func_name) = create_named_memory_managed_function( + std::string("GEMMConvolutionLayer"), mm, input, weights, biases, output, conv_info); } else if(conv_algorithm == ConvolutionMethod::WINOGRAD) { - std::tie(func, func_name) = create_named_memory_managed_function(std::string("NEWinogradConvolutionLayer"), mm, - input, weights, biases, output, conv_info); + std::tie(func, func_name) = create_named_memory_managed_function( + std::string("WinogradConvolutionLayer"), mm, input, weights, biases, output, conv_info); } else { - std::tie(func, func_name) = create_named_memory_managed_function(std::string("NEConvolutionLayer"), mm, - input, weights, biases, output, conv_info); + std::tie(func, func_name) = create_named_memory_managed_function( + std::string("ConvolutionLayer"), mm, input, weights, biases, output, conv_info); } // Log info ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated " << func_name + << " Target " << NETargetInfo::TargetType << " Data Type: " << input->info()->data_type() << " Input QuantInfo: " << input->info()->quantization_info() << " Weights QuantInfo: " << weights->info()->quantization_info() @@ -190,284 +136,25 @@ std::unique_ptr create_convolution_layer(ConvolutionLayerNode &node, return func; } -/** Create a backend deconvolution layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend deconvolution layer function - */ -std::unique_ptr create_deconvolution_layer(DeconvolutionLayerNode &node, GraphContext &ctx) -{ - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating NEON DeconvolutionLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 3); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Extract IO and info - ITensor *input = get_backing_tensor(node.input(0)); - ITensor *weights = get_backing_tensor(node.input(1)); - ITensor *biases = get_backing_tensor(node.input(2)); - ITensor *output = get_backing_tensor(node.output(0)); - - const PadStrideInfo deconv_info = node.deconvolution_info(); - const Size2D inner_border = node.inner_border(); - - // Create and configure function (we assume that functions have been validated before creation) - std::shared_ptr mm = get_memory_manager(ctx, Target::CL); - std::unique_ptr func; - std::string func_name; - - std::tie(func, func_name) = create_named_memory_managed_function(std::string("NEDeconvolutionLayer"), mm, - input, weights, biases, output, - deconv_info, inner_border.x(), inner_border.y()); - - // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated " << func_name - << " Data Type: " << input->info()->data_type() - << " Input shape: " << input->info()->tensor_shape() - << " Weights shape: " << weights->info()->tensor_shape() - << " Output shape: " << output->info()->tensor_shape() - << std::endl); - return func; -} - -/** Create a backend layer depth concatenate function - * - * @param[in] node Node to create the backend function for - * - * @return Backend depth concatenate layer function - */ -std::unique_ptr create_depth_concatenate_layer(DepthConcatenateLayerNode &node) -{ - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating NEON DepthConcatenate node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Return nullptr if depth concatenate is switched off - if(!node.is_enabled()) - { - return nullptr; - } - - // Extract IO and info - std::vector inputs; - for(unsigned int i = 0; i < node.num_inputs(); ++i) - { - inputs.push_back(get_backing_tensor(node.input(i))); - } - ITensor *output = get_backing_tensor(node.output(0)); - - // Create and configure function - auto func = support::cpp14::make_unique(); - func->configure(inputs, output); - - // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated NEDepthConcatenateLayer" - << " Data Type: " << output->info()->data_type() - << " Shape: " << output->info()->tensor_shape() - << " Num Inputs: " << inputs.size() - << std::endl); - - return std::move(func); -} - -/** Create a backend layer depth-wise convolution function - * - * @param[in] node Node to create the backend function for - * - * @return Backend depth-wise convolution layer function - */ -std::unique_ptr create_depthwise_convolution_layer(DepthwiseConvolutionLayerNode &node) -{ - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating NEON DepthwiseConvolutionLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 3); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Extract IO and info - ITensor *input = get_backing_tensor(node.input(0)); - ITensor *weights = get_backing_tensor(node.input(1)); - ITensor *biases = get_backing_tensor(node.input(2)); - ITensor *output = get_backing_tensor(node.output(0)); - - if(is_data_type_quantized_asymmetric(input->info()->data_type())) - { - biases->info()->set_data_type(DataType::S32); - } - - const PadStrideInfo conv_info = node.convolution_info(); - const DepthwiseConvolutionMethod dwc_algorithm = node.depthwise_convolution_method(); - - // Create and configure function (we assume that functions have been validated before creation) - std::unique_ptr func; - std::string func_name; - if(dwc_algorithm == DepthwiseConvolutionMethod::OPTIMIZED_3x3) - { - std::tie(func, func_name) = create_named_function(std::string("NEDepthwiseConvolutionLayer3x3"), - input, weights, biases, output, conv_info); - } - else - { - std::tie(func, func_name) = create_named_function(std::string("NEDepthwiseConvolutionLayer"), - input, weights, biases, output, conv_info); - } - - // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated " << func_name - << " Data Type: " << input->info()->data_type() - << " Input QuantInfo: " << input->info()->quantization_info() - << " Weights QuantInfo: " << weights->info()->quantization_info() - << " Input shape: " << input->info()->tensor_shape() - << " Weights shape: " << weights->info()->tensor_shape() - << " Output shape: " << output->info()->tensor_shape() - << std::endl); - return func; -} - -/** Create a backend element-wise operation layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend element-wise operation layer function - */ -std::unique_ptr create_eltwise_layer(EltwiseLayerNode &node) -{ - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating NEON EltwiseLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 2); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Extract IO and info - ITensor *input1 = get_backing_tensor(node.input(0)); - ITensor *input2 = get_backing_tensor(node.input(1)); - ITensor *output = get_backing_tensor(node.output(0)); - const EltwiseOperation eltwise_op = node.eltwise_operation(); - const ConvertPolicy convert_policy = node.convert_policy(); - ARM_COMPUTE_ERROR_ON(input1 == nullptr); - ARM_COMPUTE_ERROR_ON(input2 == nullptr); - ARM_COMPUTE_ERROR_ON(output == nullptr); - - std::unique_ptr func = nullptr; - std::string func_name; - if(eltwise_op == EltwiseOperation::ADD) - { - std::tie(func, func_name) = create_named_function(std::string("NEArithmeticAddition"), - input1, input2, output, convert_policy); - } - else if(eltwise_op == EltwiseOperation::SUB) - { - std::tie(func, func_name) = create_named_function(std::string("NEArithmeticSubtraction"), - input1, input2, output, convert_policy); - } - else if(eltwise_op == EltwiseOperation::MUL) - { - std::tie(func, func_name) = create_named_function(std::string("NEPixelWiseMultiplication"), - input1, input2, output, 1.f, - convert_policy, node.rounding_policy()); - } - else - { - ARM_COMPUTE_ERROR("Unsupported element-wise operation!"); - } - - // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated " << func_name - << " Data Type: " << input1->info()->data_type() - << " Shape : " << input1->info()->tensor_shape() - << std::endl); - - return func; -} - -/** Create a backend flatten layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend flatten layer function - */ -std::unique_ptr create_flatten_layer(FlattenLayerNode &node) -{ - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating NEON FlattenLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Extract IO and info - ITensor *input = get_backing_tensor(node.input(0)); - ITensor *output = get_backing_tensor(node.output(0)); - - // Create and configure function - auto func = support::cpp14::make_unique(); - func->configure(input, output); - ARM_COMPUTE_ERROR_ON(input == nullptr); - ARM_COMPUTE_ERROR_ON(output == nullptr); - - // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated NEFlattenLayer" - << " Data Type: " << input->info()->data_type() - << " Input shape: " << input->info()->tensor_shape() - << " Output shape: " << output->info()->tensor_shape() - << std::endl); - - return std::move(func); -} - -/** Create a backend fully connected layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend fully connected layer function - */ -std::unique_ptr create_fully_connected_layer(FullyConnectedLayerNode &node, GraphContext &ctx) -{ - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating NEON FullyConnectedLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 3); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Extract IO and info - ITensor *input = get_backing_tensor(node.input(0)); - ITensor *weights = get_backing_tensor(node.input(1)); - ITensor *biases = get_backing_tensor(node.input(2)); - ITensor *output = get_backing_tensor(node.output(0)); - - // Create and configure function - auto func = support::cpp14::make_unique(get_memory_manager(ctx, Target::NEON)); - func->configure(input, weights, biases, output); - ARM_COMPUTE_ERROR_ON(input == nullptr); - ARM_COMPUTE_ERROR_ON(weights == nullptr); - ARM_COMPUTE_ERROR_ON(output == nullptr); - - // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated NEFullyConnectedLayer" - << " Data Type: " << input->info()->data_type() - << " Input shape: " << input->info()->tensor_shape() - << " Weights shape: " << weights->info()->tensor_shape() - << " Output shape: " << output->info()->tensor_shape() - << std::endl); - - return std::move(func); -} - -/** Create a backend normalization layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend normalization layer function - */ -std::unique_ptr create_normalization_layer(NormalizationLayerNode &node, GraphContext &ctx) +template <> +std::unique_ptr create_normalization_layer(NormalizationLayerNode &node, GraphContext &ctx) { - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating NEON NormalizationLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); + validate_node(node, 1 /* expected inputs */, 1 /* expected outputs */); // Extract IO and info - ITensor *input = get_backing_tensor(node.input(0)); - ITensor *output = get_backing_tensor(node.output(0)); + NETargetInfo::TensorType *input = get_backing_tensor(node.input(0)); + NETargetInfo::TensorType *output = get_backing_tensor(node.output(0)); const NormalizationLayerInfo norm_info = node.normalization_info(); ARM_COMPUTE_ERROR_ON(input == nullptr); ARM_COMPUTE_ERROR_ON(output == nullptr); // Create and configure function - auto func = support::cpp14::make_unique(get_memory_manager(ctx, Target::NEON)); + auto func = support::cpp14::make_unique(get_memory_manager(ctx, NETargetInfo::TargetType)); func->configure(input, output, norm_info); // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated NENormalizationLayer" + ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated " << node.type() + << " Target " << NETargetInfo::TargetType << " Data Type: " << input->info()->data_type() << " Input shape: " << input->info()->tensor_shape() << " Output shape: " << output->info()->tensor_shape() @@ -476,141 +163,7 @@ std::unique_ptr create_normalization_layer(NormalizationLayerNode &no return std::move(func); } - -/** Create a backend pooling layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend pooling layer function - */ -std::unique_ptr create_pooling_layer(PoolingLayerNode &node) -{ - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating NEON PoolingLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Extract IO and info - ITensor *input = get_backing_tensor(node.input(0)); - ITensor *output = get_backing_tensor(node.output(0)); - const PoolingLayerInfo pool_info = node.pooling_info(); - ARM_COMPUTE_ERROR_ON(input == nullptr); - ARM_COMPUTE_ERROR_ON(output == nullptr); - - // Create and configure function - auto func = support::cpp14::make_unique(); - func->configure(input, output, pool_info); - - // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated NEPoolingLayer" - << " Data Type: " << input->info()->data_type() - << " Input shape: " << input->info()->tensor_shape() - << " Output shape: " << output->info()->tensor_shape() - << " Pooling info: " << pool_info.pool_type() - << std::endl); - - return std::move(func); -} - -/** Create a backend reshape layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend reshape layer function - */ -std::unique_ptr create_reshape_layer(ReshapeLayerNode &node) -{ - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating NEON ReshapeLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Extract IO and info - ITensor *input = get_backing_tensor(node.input(0)); - ITensor *output = get_backing_tensor(node.output(0)); - ARM_COMPUTE_ERROR_ON(input == nullptr); - ARM_COMPUTE_ERROR_ON(output == nullptr); - - // Create and configure function - auto func = support::cpp14::make_unique(); - func->configure(input, output); - - // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated NEReshapeLayer" - << " Data Type: " << input->info()->data_type() - << " Input shape: " << input->info()->tensor_shape() - << " Output shape: " << output->info()->tensor_shape() - << std::endl); - - return std::move(func); -} - -/** Create a backend resize layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend resize layer function - */ -std::unique_ptr create_resize_layer(ResizeLayerNode &node) -{ - ARM_COMPUTE_LOG_GRAPH_VERBOSE( - "Creating NEON Resize node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Extract IO and info - ITensor *input = get_backing_tensor(node.input(0)); - ITensor *output = get_backing_tensor(node.output(0)); - ARM_COMPUTE_ERROR_ON(input == nullptr); - ARM_COMPUTE_ERROR_ON(output == nullptr); - const InterpolationPolicy policy = node.policy(); - - // Create and configure function - auto func = support::cpp14::make_unique(); - func->configure(input, output, policy, BorderMode::CONSTANT); - - // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated NEScale" - << " Data Type: " << input->info()->data_type() - << " Input shape: " << input->info()->tensor_shape() - << " Output shape: " << output->info()->tensor_shape() - << " Interpolation: " << policy - << std::endl); - - return std::move(func); -} - -/** Create a backend softmax layer function - * - * @param[in] node Node to create the backend function for - * - * @return Backend softmax layer function - */ -std::unique_ptr create_softmax_layer(SoftmaxLayerNode &node, GraphContext &ctx) -{ - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating NEON SoftmaxLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_inputs() != 1); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Extract IO and info - ITensor *input = get_backing_tensor(node.input(0)); - ITensor *output = get_backing_tensor(node.output(0)); - const float beta = node.beta(); - ARM_COMPUTE_ERROR_ON(input == nullptr); - ARM_COMPUTE_ERROR_ON(output == nullptr); - - // Create and configure function - auto func = support::cpp14::make_unique(get_memory_manager(ctx, Target::NEON)); - func->configure(input, output, beta); - - // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated NESoftmaxLayer" - << " Data Type: " << input->info()->data_type() - << " Input shape: " << input->info()->tensor_shape() - << " Output shape: " << output->info()->tensor_shape() - << std::endl); - - return std::move(func); -} -} // namespace +} // namespace detail std::unique_ptr NEFunctionFactory::create(INode *node, GraphContext &ctx) { @@ -623,33 +176,33 @@ std::unique_ptr NEFunctionFactory::create(INode *node, GraphContext & switch(type) { case NodeType::ActivationLayer: - return create_activation_layer(*polymorphic_downcast(node)); + return detail::create_activation_layer(*polymorphic_downcast(node)); case NodeType::BatchNormalizationLayer: - return create_batch_normalization_layer(*polymorphic_downcast(node)); + return detail::create_batch_normalization_layer(*polymorphic_downcast(node)); case NodeType::ConvolutionLayer: - return create_convolution_layer(*polymorphic_downcast(node), ctx); + return detail::create_convolution_layer(*polymorphic_downcast(node), ctx); case NodeType::DeconvolutionLayer: - return create_deconvolution_layer(*polymorphic_downcast(node), ctx); + return detail::create_deconvolution_layer(*polymorphic_downcast(node), ctx); case NodeType::DepthConcatenateLayer: - return create_depth_concatenate_layer(*polymorphic_downcast(node)); + return detail::create_depth_concatenate_layer(*polymorphic_downcast(node)); case NodeType::DepthwiseConvolutionLayer: - return create_depthwise_convolution_layer(*polymorphic_downcast(node)); + return detail::create_depthwise_convolution_layer(*polymorphic_downcast(node)); case NodeType::EltwiseLayer: - return create_eltwise_layer(*polymorphic_downcast(node)); + return detail::create_eltwise_layer(*polymorphic_downcast(node)); case NodeType::FlattenLayer: - return create_flatten_layer(*polymorphic_downcast(node)); + return detail::create_flatten_layer(*polymorphic_downcast(node)); case NodeType::FullyConnectedLayer: - return create_fully_connected_layer(*polymorphic_downcast(node), ctx); + return detail::create_fully_connected_layer(*polymorphic_downcast(node), ctx); case NodeType::NormalizationLayer: - return create_normalization_layer(*polymorphic_downcast(node), ctx); + return detail::create_normalization_layer(*polymorphic_downcast(node), ctx); case NodeType::PoolingLayer: - return create_pooling_layer(*polymorphic_downcast(node)); + return detail::create_pooling_layer(*polymorphic_downcast(node)); case NodeType::ReshapeLayer: - return create_reshape_layer(*polymorphic_downcast(node)); + return detail::create_reshape_layer(*polymorphic_downcast(node)); case NodeType::ResizeLayer: - return create_resize_layer(*polymorphic_downcast(node)); + return detail::create_resize_layer(*polymorphic_downcast(node)); case NodeType::SoftmaxLayer: - return create_softmax_layer(*polymorphic_downcast(node), ctx); + return detail::create_softmax_layer(*polymorphic_downcast(node), ctx); default: return nullptr; } -- cgit v1.2.1