diff options
Diffstat (limited to 'arm_compute/graph/backends/FunctionHelpers.h')
-rw-r--r-- | arm_compute/graph/backends/FunctionHelpers.h | 116 |
1 files changed, 60 insertions, 56 deletions
diff --git a/arm_compute/graph/backends/FunctionHelpers.h b/arm_compute/graph/backends/FunctionHelpers.h index 978d3bc1a8..172f00277e 100644 --- a/arm_compute/graph/backends/FunctionHelpers.h +++ b/arm_compute/graph/backends/FunctionHelpers.h @@ -192,6 +192,52 @@ std::unique_ptr<IFunction> create_channel_shuffle_layer(ChannelShuffleLayerNode return std::move(func); } +/** Create a backend layer concatenate function + * + * @tparam ConcatenateLayerFunction Backend concatenate function + * @tparam TargetInfo Target-specific information + * + * @param[in] node Node to create the backend function for + * + * @return Backend concatenate layer function + */ +template <typename ConcatenateLayerFunction, typename TargetInfo> +std::unique_ptr<arm_compute::IFunction> create_concatenate_layer(ConcatenateLayerNode &node) +{ + ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating Concatenate node with ID : " << node.id() << " and Name: " << node.name() << std::endl); + ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); + + // Return nullptr if depth concatenate is switched off + if(!node.is_enabled()) + { + return nullptr; + } + + // Extract IO and info + std::vector<typename TargetInfo::TensorType *> inputs; + for(unsigned int i = 0; i < node.num_inputs(); ++i) + { + inputs.push_back(get_backing_tensor<TargetInfo>(node.input(i))); + } + typename TargetInfo::TensorType *output = get_backing_tensor<TargetInfo>(node.output(0)); + const DataLayoutDimension concat_axis = node.concatenation_axis(); + + // Create and configure function + auto func = support::cpp14::make_unique<ConcatenateLayerFunction>(); + func->configure(inputs, output, concat_axis); + + // Log info + ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated " << node.type() + << " Target " << TargetInfo::TargetType + << " Data Type: " << output->info()->data_type() + << " Shape: " << output->info()->tensor_shape() + << " Num Inputs: " << inputs.size() + << " Axis: " << concat_axis + << std::endl); + + return std::move(func); +} + /** Create a backend convolution layer function * * @tparam ConvolutionLayerFunctions Backend convolution functions @@ -220,20 +266,20 @@ std::unique_ptr<IFunction> create_convolution_layer(ConvolutionLayerNode &node, const PadStrideInfo conv_info = node.convolution_info(); const ConvolutionMethod conv_algorithm = node.convolution_method(); - const bool fast_math = node.fast_math_hint() == FastMathHint::ENABLED; + const bool fast_math = node.fast_math_hint() == FastMathHint::Enabled; // Create and configure function (we assume that functions have been validated before creation) std::shared_ptr<IMemoryManager> mm = get_memory_manager(ctx, TargetInfo::TargetType); std::unique_ptr<IFunction> func; std::string func_name; - if(conv_algorithm == ConvolutionMethod::WINOGRAD) + if(conv_algorithm == ConvolutionMethod::Winograd) { std::tie(func, func_name) = create_named_memory_managed_function<typename ConvolutionLayerFunctions::WinogradConvolutionLayer>( std::string("WinogradConvolutionLayer"), mm, input, weights, biases, output, conv_info, ActivationLayerInfo(), fast_math); } - else if(conv_algorithm == ConvolutionMethod::DIRECT) + else if(conv_algorithm == ConvolutionMethod::Direct) { std::tie(func, func_name) = create_named_function<typename ConvolutionLayerFunctions::DirectConvolutionLayer>( std::string("DirectConvolutionLayer"), @@ -308,50 +354,6 @@ std::unique_ptr<IFunction> create_deconvolution_layer(DeconvolutionLayerNode &no return func; } -/** Create a backend layer depth concatenate function - * - * @tparam DepthConcatenateLayerFunction Backend depth concatenate function - * @tparam TargetInfo Target-specific information - * - * @param[in] node Node to create the backend function for - * - * @return Backend depth concatenate layer function - */ -template <typename DepthConcatenateLayerFunction, typename TargetInfo> -std::unique_ptr<arm_compute::IFunction> create_depth_concatenate_layer(DepthConcatenateLayerNode &node) -{ - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Creating DepthConcatenate node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - ARM_COMPUTE_ERROR_ON(node.num_outputs() != 1); - - // Return nullptr if depth concatenate is switched off - if(!node.is_enabled()) - { - return nullptr; - } - - // Extract IO and info - std::vector<typename TargetInfo::TensorType *> inputs; - for(unsigned int i = 0; i < node.num_inputs(); ++i) - { - inputs.push_back(get_backing_tensor<TargetInfo>(node.input(i))); - } - typename TargetInfo::TensorType *output = get_backing_tensor<TargetInfo>(node.output(0)); - - // Create and configure function - auto func = support::cpp14::make_unique<DepthConcatenateLayerFunction>(); - func->configure(inputs, output); - - // Log info - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated " << node.type() - << " Target " << TargetInfo::TargetType - << " Data Type: " << output->info()->data_type() - << " Shape: " << output->info()->tensor_shape() - << " Num Inputs: " << inputs.size() - << std::endl); - - return std::move(func); -} - /** Create a backend layer depth-wise convolution function * * @tparam DepthwiseConvolutionLayerFunctions Backend depthwise convolution function @@ -383,7 +385,7 @@ std::unique_ptr<IFunction> create_depthwise_convolution_layer(DepthwiseConvoluti // Create and configure function (we assume that functions have been validated before creation) std::unique_ptr<IFunction> func; std::string func_name; - if(dwc_algorithm == DepthwiseConvolutionMethod::OPTIMIZED_3x3) + if(dwc_algorithm == DepthwiseConvolutionMethod::Optimized3x3) { std::tie(func, func_name) = create_named_function<typename DepthwiseConvolutionLayerFunctions::DepthwiseConvolutionLayer3x3>( std::string("DepthwiseConvolutionLayer3x3"), @@ -435,19 +437,19 @@ std::unique_ptr<IFunction> create_eltwise_layer(EltwiseLayerNode &node) std::unique_ptr<IFunction> func = nullptr; std::string func_name; - if(eltwise_op == EltwiseOperation::ADD) + if(eltwise_op == EltwiseOperation::Add) { std::tie(func, func_name) = create_named_function<typename EltwiseFunctions::Addition>( std::string("ArithmeticAddition"), input1, input2, output, convert_policy); } - else if(eltwise_op == EltwiseOperation::SUB) + else if(eltwise_op == EltwiseOperation::Sub) { std::tie(func, func_name) = create_named_function<typename EltwiseFunctions::Subtraction>( std::string("ArithmeticSubtraction"), input1, input2, output, convert_policy); } - else if(eltwise_op == EltwiseOperation::MUL) + else if(eltwise_op == EltwiseOperation::Mul) { std::tie(func, func_name) = create_named_function<typename EltwiseFunctions::Multiplication>( std::string("PixelWiseMultiplication"), @@ -487,11 +489,12 @@ std::unique_ptr<IFunction> create_flatten_layer(FlattenLayerNode &node) typename TargetInfo::TensorType *input = get_backing_tensor<TargetInfo>(node.input(0)); typename TargetInfo::TensorType *output = get_backing_tensor<TargetInfo>(node.output(0)); + ARM_COMPUTE_ERROR_ON(input == nullptr); + ARM_COMPUTE_ERROR_ON(output == nullptr); + // Create and configure function auto func = support::cpp14::make_unique<FlattenLayerFunction>(); func->configure(input, output); - ARM_COMPUTE_ERROR_ON(input == nullptr); - ARM_COMPUTE_ERROR_ON(output == nullptr); // Log info ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated " << node.type() @@ -526,13 +529,14 @@ std::unique_ptr<IFunction> create_fully_connected_layer(FullyConnectedLayerNode typename TargetInfo::TensorType *output = get_backing_tensor<TargetInfo>(node.output(0)); const FullyConnectedLayerInfo fc_info = node.info(); - // Create and configure function - auto func = support::cpp14::make_unique<FullyConnectedLayerFunction>(get_memory_manager(ctx, TargetInfo::TargetType)); - func->configure(input, weights, biases, output, fc_info); ARM_COMPUTE_ERROR_ON(input == nullptr); ARM_COMPUTE_ERROR_ON(weights == nullptr); ARM_COMPUTE_ERROR_ON(output == nullptr); + // Create and configure function + auto func = support::cpp14::make_unique<FullyConnectedLayerFunction>(get_memory_manager(ctx, TargetInfo::TargetType)); + func->configure(input, weights, biases, output, fc_info); + // Log info ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated " << node.type() << " Target " << TargetInfo::TargetType |