From e2220551b7a64b929650ba9a60529c31e70c13c5 Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Fri, 20 Jul 2018 13:23:44 +0100 Subject: COMPMID-1367: Enable NHWC in graph examples Change-Id: Iabc54a3a1bdcd46a9a921cda39c7c85fef672b72 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/141449 Reviewed-by: Giorgio Arena Reviewed-by: Anthony Barbier Tested-by: Jenkins --- arm_compute/graph/backends/ValidateHelpers.h | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) (limited to 'arm_compute/graph/backends/ValidateHelpers.h') diff --git a/arm_compute/graph/backends/ValidateHelpers.h b/arm_compute/graph/backends/ValidateHelpers.h index 189fbdc9c7..ae52593b03 100644 --- a/arm_compute/graph/backends/ValidateHelpers.h +++ b/arm_compute/graph/backends/ValidateHelpers.h @@ -106,22 +106,22 @@ Status validate_convolution_layer(ConvolutionLayerNode &node) const PadStrideInfo conv_info = node.convolution_info(); const ConvolutionMethod conv_algorithm = node.convolution_method(); - //const bool fast_math = node.fast_math_hint() == FastMathHint::ENABLED; // FIXME (COMPMID-1138): uncomment once NEON and GLES support fast_math + const bool fast_math = node.fast_math_hint() == FastMathHint::Enabled; // Validate function Status status{}; switch(conv_algorithm) { - case ConvolutionMethod::DIRECT: + case ConvolutionMethod::Direct: status = DirectConvolutionLayer::validate(input, weights, biases, output, conv_info); break; case ConvolutionMethod::GEMM: status = GEMMConvolutionLayer::validate(input, weights, biases, output, conv_info); break; - case ConvolutionMethod::WINOGRAD: - status = WinogradConvolutionLayer::validate(input, weights, biases, output, conv_info /*, fast_math*/); + case ConvolutionMethod::Winograd: + status = WinogradConvolutionLayer::validate(input, weights, biases, output, conv_info, ActivationLayerInfo(), fast_math); break; - case ConvolutionMethod::DEFAULT: + case ConvolutionMethod::Default: status = ConvolutionLayer::validate(input, weights, biases, output, conv_info); break; default: @@ -136,7 +136,7 @@ Status validate_convolution_layer(ConvolutionLayerNode &node) { ARM_COMPUTE_LOG_GRAPH_INFO("Switched ConvolutionLayer method of node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - node.set_convolution_method(ConvolutionMethod::DEFAULT); + node.set_convolution_method(ConvolutionMethod::Default); } } @@ -166,11 +166,11 @@ Status validate_depthwise_convolution_layer(DepthwiseConvolutionLayerNode &node) // TODO (geopin01) : Switch when validation is implemented // Validate function - if((dwc_algorithm == DepthwiseConvolutionMethod::OPTIMIZED_3x3) && (weights->tensor_shape()[get_data_layout_dimension_index(weights->data_layout(), DataLayoutDimension::WIDTH)] != 3)) + if((dwc_algorithm == DepthwiseConvolutionMethod::Optimized3x3) && (weights->tensor_shape()[get_data_layout_dimension_index(weights->data_layout(), DataLayoutDimension::WIDTH)] != 3)) { ARM_COMPUTE_LOG_GRAPH_INFO("Switched DepthwiseConvolutionLayer method of node with ID : " << node.id() << " and Name: " << node.name() << std::endl); - node.set_depthwise_convolution_method(DepthwiseConvolutionMethod::DEFAULT); + node.set_depthwise_convolution_method(DepthwiseConvolutionMethod::Default); } return Status{}; -- cgit v1.2.1