diff options
Diffstat (limited to 'arm_compute/graph/backends/ValidateHelpers.h')
-rw-r--r-- | arm_compute/graph/backends/ValidateHelpers.h | 299 |
1 files changed, 242 insertions, 57 deletions
diff --git a/arm_compute/graph/backends/ValidateHelpers.h b/arm_compute/graph/backends/ValidateHelpers.h index 673caf9eac..0e102942a7 100644 --- a/arm_compute/graph/backends/ValidateHelpers.h +++ b/arm_compute/graph/backends/ValidateHelpers.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 ARM Limited. + * Copyright (c) 2018-2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -21,17 +21,16 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef ARM_COMPUTE_GRAPH_BACKENDS_DETAIL_VALIDATE_HELPERS_H -#define ARM_COMPUTE_GRAPH_BACKENDS_DETAIL_VALIDATE_HELPERS_H - -#include "arm_compute/graph/Logger.h" -#include "arm_compute/graph/Tensor.h" -#include "arm_compute/graph/Types.h" -#include "arm_compute/graph/nodes/Nodes.h" +#ifndef ACL_ARM_COMPUTE_GRAPH_BACKENDS_VALIDATEHELPERS_H +#define ACL_ARM_COMPUTE_GRAPH_BACKENDS_VALIDATEHELPERS_H #include "arm_compute/core/Error.h" #include "arm_compute/core/Helpers.h" #include "arm_compute/core/ITensorInfo.h" +#include "arm_compute/graph/Logger.h" +#include "arm_compute/graph/nodes/Nodes.h" +#include "arm_compute/graph/Tensor.h" +#include "arm_compute/graph/Types.h" namespace arm_compute { @@ -52,6 +51,30 @@ inline arm_compute::ITensorInfo *get_backing_tensor_info(arm_compute::graph::Ten return ((tensor == nullptr) || (tensor->handle() == nullptr)) ? nullptr : tensor->handle()->tensor().info(); } +/** Validates a ArgMinMax layer node + * + * @tparam ArgMinMax layer function type + * + * @param[in] node Node to validate + * + * @return Status + */ +template <typename ArgMinMaxLayer> +Status validate_arg_min_max_layer(ArgMinMaxLayerNode &node) +{ + ARM_COMPUTE_LOG_GRAPH_VERBOSE( + "Validating ArgMinMaxLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); + ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1); + ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1); + + // Extract IO and info + arm_compute::ITensorInfo *input = detail::get_backing_tensor_info(node.input(0)); + arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0)); + + // Validate function + return ArgMinMaxLayer::validate(input, node.axis(), output, node.reduction_operation()); +} + /** Validates a Bounding Box Transform layer node * * @tparam BoundingBoxTransformLayer Bounding Box Transform layer function type @@ -63,7 +86,8 @@ inline arm_compute::ITensorInfo *get_backing_tensor_info(arm_compute::graph::Ten template <typename BoundingBoxTransformLayer> Status validate_bounding_box_transform_layer(BoundingBoxTransformLayerNode &node) { - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating BoundingBoxTransformLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); + ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating BoundingBoxTransformLayer node with ID : " << node.id() << " and Name: " + << node.name() << std::endl); ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 2); ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1); @@ -87,7 +111,8 @@ Status validate_bounding_box_transform_layer(BoundingBoxTransformLayerNode &node template <typename ChannelShuffleLayer> Status validate_channel_shuffle_layer(ChannelShuffleLayerNode &node) { - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating ChannelShuffle node with ID : " << node.id() << " and Name: " << node.name() << std::endl); + ARM_COMPUTE_LOG_GRAPH_VERBOSE( + "Validating ChannelShuffle node with ID : " << node.id() << " and Name: " << node.name() << std::endl); ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1); ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1); @@ -110,10 +135,14 @@ Status validate_channel_shuffle_layer(ChannelShuffleLayerNode &node) * * @return Status */ -template <typename ConvolutionLayer, typename DirectConvolutionLayer, typename GEMMConvolutionLayer, typename WinogradConvolutionLayer> +template <typename ConvolutionLayer, + typename DirectConvolutionLayer, + typename GEMMConvolutionLayer, + typename WinogradConvolutionLayer> Status validate_convolution_layer(ConvolutionLayerNode &node) { - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating ConvolutionLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); + ARM_COMPUTE_LOG_GRAPH_VERBOSE( + "Validating ConvolutionLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 3); ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1); @@ -123,7 +152,7 @@ Status validate_convolution_layer(ConvolutionLayerNode &node) arm_compute::ITensorInfo *biases = get_backing_tensor_info(node.input(2)); arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0)); - if(is_data_type_quantized_asymmetric(input->data_type())) + if (is_data_type_quantized_asymmetric(input->data_type())) { biases->set_data_type(DataType::S32); } @@ -135,23 +164,24 @@ Status validate_convolution_layer(ConvolutionLayerNode &node) // Validate function Status status{}; - switch(conv_algorithm) + switch (conv_algorithm) { case ConvolutionMethod::Direct: ARM_COMPUTE_RETURN_ERROR_ON_MSG(num_groups != 1, "DirectConvolutionLayer does not support grouping!"); status = DirectConvolutionLayer::validate(input, weights, biases, output, conv_info); break; case ConvolutionMethod::GEMM: - status = GEMMConvolutionLayer::validate(input, weights, biases, output, conv_info, - WeightsInfo(), Size2D(1, 1), ActivationLayerInfo(), num_groups); + status = GEMMConvolutionLayer::validate(input, weights, biases, output, conv_info, WeightsInfo(), + Size2D(1, 1), ActivationLayerInfo(), num_groups); break; case ConvolutionMethod::Winograd: ARM_COMPUTE_RETURN_ERROR_ON_MSG(num_groups != 1, "WinogradConvolutionLayer does not support grouping!"); - status = WinogradConvolutionLayer::validate(input, weights, biases, output, conv_info, ActivationLayerInfo(), fast_math); + status = WinogradConvolutionLayer::validate(input, weights, biases, output, conv_info, + ActivationLayerInfo(), fast_math); break; case ConvolutionMethod::Default: - status = ConvolutionLayer::validate(input, weights, biases, output, conv_info, - WeightsInfo(), Size2D(1, 1), ActivationLayerInfo(), fast_math, num_groups); + status = ConvolutionLayer::validate(input, weights, biases, output, conv_info, WeightsInfo(), Size2D(1, 1), + ActivationLayerInfo(), fast_math, num_groups); break; default: ARM_COMPUTE_RETURN_ERROR_MSG("Unsupported convolution method"); @@ -171,7 +201,8 @@ Status validate_convolution_layer(ConvolutionLayerNode &node) template <typename DepthwiseConvolutionLayer> Status validate_depthwise_convolution_layer(DepthwiseConvolutionLayerNode &node) { - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating DepthwiseConvolutionLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); + ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating DepthwiseConvolutionLayer node with ID : " << node.id() << " and Name: " + << node.name() << std::endl); ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 3); ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1); @@ -187,7 +218,7 @@ Status validate_depthwise_convolution_layer(DepthwiseConvolutionLayerNode &node) // Validate function Status status{}; - switch(dwc_algorithm) + switch (dwc_algorithm) { case DepthwiseConvolutionMethod::Default: case DepthwiseConvolutionMethod::Optimized3x3: @@ -199,6 +230,28 @@ Status validate_depthwise_convolution_layer(DepthwiseConvolutionLayerNode &node) return status; } +/** Validates a depth to space layer node + * + * @tparam DequantizationLayer Dequantize layer type + * + * @param[in] node Node to validate + * + * @return Status + */ +template <typename DepthToSpaceLayer> +Status validate_depth_to_space_layer(DepthToSpaceLayerNode &node) +{ + ARM_COMPUTE_LOG_GRAPH_VERBOSE( + "Validating DetectionOutputLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); + ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1); + ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1); + + // Extract IO and info + arm_compute::ITensorInfo *input = get_backing_tensor_info(node.input(0)); + arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0)); + + return DepthToSpaceLayer::validate(input, output, node.block_shape()); +} /** Validates a dequantize layer node * * @tparam DequantizationLayer Dequantize layer type @@ -210,7 +263,8 @@ Status validate_depthwise_convolution_layer(DepthwiseConvolutionLayerNode &node) template <typename DequantizationLayer> Status validate_dequantization_layer(DequantizationLayerNode &node) { - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating DetectionOutputLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); + ARM_COMPUTE_LOG_GRAPH_VERBOSE( + "Validating DetectionOutputLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1); ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1); @@ -231,7 +285,8 @@ Status validate_dequantization_layer(DequantizationLayerNode &node) template <typename DetectionOutputLayer> Status validate_detection_output_layer(DetectionOutputLayerNode &node) { - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating DetectionOutputLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); + ARM_COMPUTE_LOG_GRAPH_VERBOSE( + "Validating DetectionOutputLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 3); ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1); @@ -255,7 +310,8 @@ Status validate_detection_output_layer(DetectionOutputLayerNode &node) template <typename DetectionPostProcessLayer> Status validate_detection_post_process_layer(DetectionPostProcessLayerNode &node) { - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating DetectionPostProcessLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); + ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating DetectionPostProcessLayer node with ID : " << node.id() << " and Name: " + << node.name() << std::endl); ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 3); ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 4); @@ -283,7 +339,8 @@ Status validate_detection_post_process_layer(DetectionPostProcessLayerNode &node template <typename GenerateProposalsLayer> Status validate_generate_proposals_layer(GenerateProposalsLayerNode &node) { - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating GenerateProposalsLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); + ARM_COMPUTE_LOG_GRAPH_VERBOSE( + "Validating GenerateProposalsLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 3); ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 3); @@ -299,6 +356,32 @@ Status validate_generate_proposals_layer(GenerateProposalsLayerNode &node) return GenerateProposalsLayer::validate(scores, deltas, anchors, proposals, scores_out, num_valid_proposals, info); } +/** Validates a L2Normalization layer node + * + * @tparam L2Normalization layer type + * + * @param[in] node Node to validate + * + * @return Status + */ +template <typename L2NormalizeLayer> +Status validate_l2_normalize_layer(L2NormalizeLayerNode &node) +{ + ARM_COMPUTE_LOG_GRAPH_VERBOSE( + "Validating L2NormalizeLayerNode node with ID : " << node.id() << " and Name: " << node.name() << std::endl); + ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1); + ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1); + + // Extract IO and info + arm_compute::ITensorInfo *input = detail::get_backing_tensor_info(node.input(0)); + arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0)); + int axis = node.axis(); + float epsilon = node.epsilon(); + + // Validate function + return L2NormalizeLayer::validate(input, output, axis, epsilon); +} + /** Validates a NormalizePlanarYUV layer node * * @tparam NormalizePlanarYUVLayer layer type @@ -310,7 +393,8 @@ Status validate_generate_proposals_layer(GenerateProposalsLayerNode &node) template <typename NormalizePlanarYUVLayer> Status validate_normalize_planar_yuv_layer(NormalizePlanarYUVLayerNode &node) { - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating NormalizePlanarYUVLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); + ARM_COMPUTE_LOG_GRAPH_VERBOSE( + "Validating NormalizePlanarYUVLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 3); ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1); @@ -335,7 +419,8 @@ Status validate_normalize_planar_yuv_layer(NormalizePlanarYUVLayerNode &node) template <typename PadLayer> Status validate_pad_layer(PadLayerNode &node) { - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating PadLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); + ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating PadLayer node with ID : " << node.id() << " and Name: " << node.name() + << std::endl); ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1); ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1); @@ -358,14 +443,15 @@ Status validate_pad_layer(PadLayerNode &node) template <typename PermuteLayer> Status validate_permute_layer(PermuteLayerNode &node) { - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating PermuteLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); + ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating PermuteLayer node with ID : " << node.id() << " and Name: " << node.name() + << std::endl); ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1); ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1); // Extract IO and info arm_compute::ITensorInfo *input = get_backing_tensor_info(node.input(0)); arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0)); - const PermutationVector &perm = node.permutation_vector(); + const PermutationVector &perm = node.permutation_vector(); return PermuteLayer::validate(input, output, perm); } @@ -381,7 +467,8 @@ Status validate_permute_layer(PermuteLayerNode &node) template <typename PReluLayer> Status validate_prelu_layer(PReluLayerNode &node) { - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating PRelu node with ID : " << node.id() << " and Name: " << node.name() << std::endl); + ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating PRelu node with ID : " << node.id() << " and Name: " << node.name() + << std::endl); ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 2); ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1); @@ -404,7 +491,8 @@ Status validate_prelu_layer(PReluLayerNode &node) template <typename PriorBoxLayer> Status validate_priorbox_layer(PriorBoxLayerNode &node) { - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating PriorBoxLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); + ARM_COMPUTE_LOG_GRAPH_VERBOSE( + "Validating PriorBoxLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 2); ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1); @@ -428,7 +516,8 @@ Status validate_priorbox_layer(PriorBoxLayerNode &node) template <typename QuantizationLayer> Status validate_quantization_layer(QuantizationLayerNode &node) { - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating QuantizationLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); + ARM_COMPUTE_LOG_GRAPH_VERBOSE( + "Validating QuantizationLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1); ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1); @@ -440,6 +529,31 @@ Status validate_quantization_layer(QuantizationLayerNode &node) return QuantizationLayer::validate(input, output); } +/** Validates a Reduction operation layer node + * + * @tparam ReductionLayer Reduction layer type + * + * @param[in] node Node to validate + * + * @return Status + */ +template <typename ReductionLayer> +Status validate_reduction_operation_layer(ReductionLayerNode &node) +{ + ARM_COMPUTE_LOG_GRAPH_VERBOSE( + "Validating ReductionLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); + + ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1); + ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1); + + // Extract input and output + arm_compute::ITensorInfo *input = detail::get_backing_tensor_info(node.input(0)); + arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0)); + + // Validate function + return ReductionLayer::validate(input, output, node.axis(), node.op(), node.keep_dims()); +} + /** Validates a Reorg layer node * * @tparam ReorgLayer Reorg layer type @@ -451,7 +565,8 @@ Status validate_quantization_layer(QuantizationLayerNode &node) template <typename ReorgLayer> Status validate_reorg_layer(ReorgLayerNode &node) { - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating ReorgLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); + ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating ReorgLayer node with ID : " << node.id() << " and Name: " << node.name() + << std::endl); ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1); ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1); @@ -474,7 +589,8 @@ Status validate_reorg_layer(ReorgLayerNode &node) template <typename ReshapeLayer> Status validate_reshape_layer(ReshapeLayerNode &node) { - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating ReshapeLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); + ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating ReshapeLayer node with ID : " << node.id() << " and Name: " << node.name() + << std::endl); ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1); ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1); @@ -497,14 +613,15 @@ Status validate_reshape_layer(ReshapeLayerNode &node) template <typename ROIAlignLayer> Status validate_roi_align_layer(ROIAlignLayerNode &node) { - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating ROIAlignLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); + ARM_COMPUTE_LOG_GRAPH_VERBOSE( + "Validating ROIAlignLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 2); ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1); // Extract input and output - arm_compute::ITensorInfo *input = detail::get_backing_tensor_info(node.input(0)); - arm_compute::ITensorInfo *rois = detail::get_backing_tensor_info(node.input(1)); - arm_compute::ITensorInfo *output = detail::get_backing_tensor_info(node.output(0)); + arm_compute::ITensorInfo *input = detail::get_backing_tensor_info(node.input(0)); + arm_compute::ITensorInfo *rois = detail::get_backing_tensor_info(node.input(1)); + arm_compute::ITensorInfo *output = detail::get_backing_tensor_info(node.output(0)); const ROIPoolingLayerInfo &pool_info = node.pooling_info(); // Validate function @@ -522,7 +639,8 @@ Status validate_roi_align_layer(ROIAlignLayerNode &node) template <typename SliceLayer> Status validate_slice_layer(SliceLayerNode &node) { - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating Slice node with ID : " << node.id() << " and Name: " << node.name() << std::endl); + ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating Slice node with ID : " << node.id() << " and Name: " << node.name() + << std::endl); ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1); ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1); @@ -535,53 +653,120 @@ Status validate_slice_layer(SliceLayerNode &node) return SliceLayer::validate(input, output, starts, ends); } -/** Validates a Upsample layer node +/** Validates a Strided Slice layer node * - * @tparam UpsampleLayer Upsample layer type + * @tparam StridedSliceLayer Strided Slice layer function type * * @param[in] node Node to validate * * @return Status */ -template <typename UpsampleLayer> -Status validate_upsample_layer(UpsampleLayerNode &node) +template <typename StridedSliceLayer> +Status validate_strided_slice_layer(StridedSliceLayerNode &node) { - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating UpsampleLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); + ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating StridedSlice node with ID : " << node.id() << " and Name: " << node.name() + << std::endl); ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1); ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1); + // Extract IO and info + arm_compute::ITensorInfo *input = get_backing_tensor_info(node.input(0)); + arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0)); + const Coordinates starts = node.starts(); + const Coordinates ends = node.ends(); + const BiStrides strides = node.strides(); + const StridedSliceLayerInfo info = node.strided_slice_info(); + + return StridedSliceLayer::validate(input, output, starts, ends, strides, info.begin_mask(), info.end_mask(), + info.shrink_axis_mask()); +} + +/** Validates a element-wise layer node + * + * @param[in] node Node to validate + * + * @return Status + */ +template <typename EltwiseLayerFunctions> +Status validate_eltwise_Layer(EltwiseLayerNode &node) +{ + ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating EltwiseLayer node with ID : " << node.id() << " and Name: " << node.name() + << std::endl); + ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 2); + ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1); + // Extract input and output - arm_compute::ITensorInfo *input = detail::get_backing_tensor_info(node.input(0)); - arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0)); + const arm_compute::ITensorInfo *input1 = detail::get_backing_tensor_info(node.input(0)); + const arm_compute::ITensorInfo *input2 = detail::get_backing_tensor_info(node.input(1)); + const arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0)); + const EltwiseOperation eltwise_op = node.eltwise_operation(); + const ConvertPolicy convert_policy = node.convert_policy(); + const RoundingPolicy round_policy = node.rounding_policy(); + const ActivationLayerInfo act_info = node.fused_activation(); + const QuantizationInfo quant_info = node.output_quant_info(); // Validate function - return UpsampleLayer::validate(input, output, node.info(), node.upsampling_policy()); + if (eltwise_op == EltwiseOperation::Add) + { + return EltwiseLayerFunctions::ArithmeticAddition::validate(input1, input2, output, convert_policy, act_info); + } + else if (eltwise_op == EltwiseOperation::Sub) + { + return EltwiseLayerFunctions::ArithmeticSubtraction::validate(input1, input2, output, convert_policy, act_info); + } + else if (eltwise_op == EltwiseOperation::Mul) + { + return EltwiseLayerFunctions::PixelWiseMultiplication::validate(input1, input2, output, 1.0f, convert_policy, + round_policy, act_info); + } + else if (eltwise_op == EltwiseOperation::Max) + { + return EltwiseLayerFunctions::ElementwiseMax::validate(input1, input2, output, act_info); + } + else if (eltwise_op == EltwiseOperation::Div) + { + return EltwiseLayerFunctions::ArithmeticDivision::validate(input1, input2, output, act_info); + } + else + { + ARM_COMPUTE_ERROR("Unsupported element-wise operation!"); + } + return Status{}; } -/** Validates a YOLO layer node - * - * @tparam YOLOLayer YOLO layer type +/** Validates a unary element-wise layer node * * @param[in] node Node to validate * * @return Status */ -template <typename YOLOLayer> -Status validate_yolo_layer(YOLOLayerNode &node) +template <typename UnaryEltwiseLayerFunctions> +Status validate_unary_eltwise_layer(UnaryEltwiseLayerNode &node) { - ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating YOLOLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl); + ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating EltwiseLayer node with ID : " << node.id() << " and Name: " << node.name() + << std::endl); ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1); ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1); // Extract input and output - arm_compute::ITensorInfo *input = detail::get_backing_tensor_info(node.input(0)); - arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0)); + arm_compute::ITensorInfo *input = detail::get_backing_tensor_info(node.input(0)); + arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0)); + const UnaryEltwiseOperation eltwise_op = node.eltwise_descriptor().op; // Validate function - return YOLOLayer::validate(input, output, node.activation_info(), node.num_classes()); + if (eltwise_op == UnaryEltwiseOperation::Exp) + { + return UnaryEltwiseLayerFunctions::ExpLayer::validate(input, output); + } + else + { + ARM_COMPUTE_ERROR("Unsupported unary element-wise operation!"); + } + + return Status{}; } } // namespace detail } // namespace backends } // namespace graph } // namespace arm_compute -#endif /* ARM_COMPUTE_GRAPH_BACKENDS_DETAIL_VALIDATE_HELPERS_H */ +#endif // ACL_ARM_COMPUTE_GRAPH_BACKENDS_VALIDATEHELPERS_H |