diff options
Diffstat (limited to 'arm_compute/graph/backends/ValidateHelpers.h')
-rw-r--r-- | arm_compute/graph/backends/ValidateHelpers.h | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/arm_compute/graph/backends/ValidateHelpers.h b/arm_compute/graph/backends/ValidateHelpers.h index 189fbdc9c7..ae52593b03 100644 --- a/arm_compute/graph/backends/ValidateHelpers.h +++ b/arm_compute/graph/backends/ValidateHelpers.h @@ -106,22 +106,22 @@ Status validate_convolution_layer(ConvolutionLayerNode &node) const PadStrideInfo conv_info = node.convolution_info(); const ConvolutionMethod conv_algorithm = node.convolution_method(); - //const bool fast_math = node.fast_math_hint() == FastMathHint::ENABLED; // FIXME (COMPMID-1138): uncomment once NEON and GLES support fast_math + const bool fast_math = node.fast_math_hint() == FastMathHint::Enabled; // Validate function Status status{}; switch(conv_algorithm) { - case ConvolutionMethod::DIRECT: + case ConvolutionMethod::Direct: status = DirectConvolutionLayer::validate(input, weights, biases, output, conv_info); break; case ConvolutionMethod::GEMM: status = GEMMConvolutionLayer::validate(input, weights, biases, output, conv_info); break; - case ConvolutionMethod::WINOGRAD: - status = WinogradConvolutionLayer::validate(input, weights, biases, output, conv_info /*, fast_math*/); + case ConvolutionMethod::Winograd: + status = WinogradConvolutionLayer::validate(input, weights, biases, output, conv_info, ActivationLayerInfo(), fast_math); break; - case ConvolutionMethod::DEFAULT: + case ConvolutionMethod::Default: status = ConvolutionLayer::validate(input, weights, biases, output, conv_info); break; default: @@ -136,7 +136,7 @@ Status validate_convolution_layer(ConvolutionLayerNode &node) { ARM_COMPUTE_LOG_GRAPH_INFO("Switched ConvolutionLayer method of node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - node.set_convolution_method(ConvolutionMethod::DEFAULT); + node.set_convolution_method(ConvolutionMethod::Default); } } @@ -166,11 +166,11 @@ Status validate_depthwise_convolution_layer(DepthwiseConvolutionLayerNode &node) // TODO (geopin01) : Switch when validation is implemented // Validate function - if((dwc_algorithm == DepthwiseConvolutionMethod::OPTIMIZED_3x3) && (weights->tensor_shape()[get_data_layout_dimension_index(weights->data_layout(), DataLayoutDimension::WIDTH)] != 3)) + if((dwc_algorithm == DepthwiseConvolutionMethod::Optimized3x3) && (weights->tensor_shape()[get_data_layout_dimension_index(weights->data_layout(), DataLayoutDimension::WIDTH)] != 3)) { ARM_COMPUTE_LOG_GRAPH_INFO("Switched DepthwiseConvolutionLayer method of node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - node.set_depthwise_convolution_method(DepthwiseConvolutionMethod::DEFAULT); + node.set_depthwise_convolution_method(DepthwiseConvolutionMethod::Default); } return Status{}; |