From d1d7722cfc5ee130115d8d195068a98b16102a21 Mon Sep 17 00:00:00 2001 From: Michalis Spyrou Date: Wed, 8 Apr 2020 14:10:15 +0100 Subject: COMPMID-3314: Enable OpenMP in the reference tests Change-Id: I05b5fedb998317144e0dd13a6377a97207b27f46 Signed-off-by: Michalis Spyrou Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3024 Tested-by: Arm Jenkins Reviewed-by: Michele Di Giorgio Comments-Addressed: Arm Jenkins --- tests/SConscript | 7 ++++- tests/validation/Helpers.cpp | 31 +++++++++++++++++++++- tests/validation/reference/AbsoluteDifference.cpp | 6 +++-- tests/validation/reference/Accumulate.cpp | 14 +++++++--- tests/validation/reference/ActivationLayer.cpp | 5 ++-- tests/validation/reference/ArithmeticDivision.cpp | 9 ++++--- .../validation/reference/ArithmeticOperations.cpp | 9 ++++--- .../reference/BatchNormalizationLayer.cpp | 6 +++-- tests/validation/reference/BitwiseAnd.cpp | 6 +++-- tests/validation/reference/BitwiseNot.cpp | 6 +++-- tests/validation/reference/BitwiseOr.cpp | 6 +++-- tests/validation/reference/BitwiseXor.cpp | 6 +++-- .../validation/reference/BoundingBoxTransform.cpp | 6 +++-- tests/validation/reference/Box3x3.cpp | 5 +++- tests/validation/reference/ChannelCombine.cpp | 5 +++- tests/validation/reference/ChannelExtract.cpp | 5 +++- tests/validation/reference/ChannelShuffle.cpp | 6 +++-- tests/validation/reference/Col2Im.cpp | 16 +++++++---- tests/validation/reference/Comparisons.cpp | 7 +++-- tests/validation/reference/ComputeAllAnchors.cpp | 5 +++- .../reference/ConvertFullyConnectedWeights.cpp | 5 +++- tests/validation/reference/Convolution.cpp | 3 +++ tests/validation/reference/ConvolutionLayer.cpp | 3 +++ tests/validation/reference/DFT.cpp | 28 ++++++++++++++++--- tests/validation/reference/DeconvolutionLayer.cpp | 7 ++++- .../validation/reference/DepthConcatenateLayer.cpp | 5 +++- tests/validation/reference/DepthConvertLayer.cpp | 6 +++++ tests/validation/reference/DepthToSpaceLayer.cpp | 9 ++++--- tests/validation/reference/DequantizationLayer.cpp | 10 ++++--- tests/validation/reference/Derivative.cpp | 5 +++- tests/validation/reference/Dilate.cpp | 5 +++- tests/validation/reference/ElementWiseUnary.cpp | 2 +- tests/validation/reference/EqualizeHistogram.cpp | 5 +++- tests/validation/reference/Erode.cpp | 5 +++- tests/validation/reference/Floor.cpp | 5 +++- tests/validation/reference/FullyConnectedLayer.cpp | 15 ++++++----- .../reference/FuseBatchNormalization.cpp | 6 ++++- tests/validation/reference/GEMM.cpp | 9 ++++++- tests/validation/reference/GEMMLowp.cpp | 9 +++++++ .../validation/reference/GEMMReshapeLHSMatrix.cpp | 2 +- .../validation/reference/GEMMReshapeRHSMatrix.cpp | 6 +++-- tests/validation/reference/Gaussian3x3.cpp | 5 +++- tests/validation/reference/Gaussian5x5.cpp | 5 +++- tests/validation/reference/Im2Col.cpp | 6 +++-- .../reference/InstanceNormalizationLayer.cpp | 13 +++++---- tests/validation/reference/QuantizationLayer.cpp | 9 +++++++ tests/validation/reference/ReorgLayer.cpp | 6 ++++- tests/validation/reference/Reverse.cpp | 6 ++++- tests/validation/reference/SoftmaxLayer.cpp | 7 +++-- tests/validation/reference/Winograd.cpp | 2 +- tests/validation/reference/YOLOLayer.cpp | 5 +++- 51 files changed, 293 insertions(+), 87 deletions(-) diff --git a/tests/SConscript b/tests/SConscript index 5c95c551f4..26d422c10b 100644 --- a/tests/SConscript +++ b/tests/SConscript @@ -30,6 +30,7 @@ Import('install_bin') variables = [ BoolVariable("benchmark_examples", "Build benchmark examples programs", True), BoolVariable("validate_examples", "Build validate examples programs", True), + BoolVariable("reference_openmp", "Build reference validation with openmp", True), #FIXME Switch the following two options to False before releasing BoolVariable("validation_tests", "Build validation test programs", True), BoolVariable("benchmark_tests", "Build benchmark test programs", True), @@ -170,8 +171,12 @@ bm_link_flags = ['-fstack-protector-strong'] if test_env['linker_script']: bm_link_flags += ['-Wl,--build-id=none', '-T', env['linker_script']] +if test_env['reference_openmp'] and env['os'] != 'bare_metal': + test_env['CXXFLAGS'].append('-fopenmp') + test_env['LINKFLAGS'].append('-fopenmp') + if test_env['validation_tests']: - arm_compute_validation_framework = env.StaticLibrary('arm_compute_validation_framework', Glob('validation/reference/*.cpp') + Glob('validation/*.cpp'), LIBS= [ arm_compute_test_framework, arm_compute_core_a]) + arm_compute_validation_framework = env.StaticLibrary('arm_compute_validation_framework', Glob('validation/reference/*.cpp') + Glob('validation/*.cpp'), LINKFLAGS=test_env['LINKFLAGS'], CXXFLAGS=test_env['CXXFLAGS'], LIBS= [ arm_compute_test_framework, arm_compute_core_a]) Depends(arm_compute_validation_framework , arm_compute_test_framework) Depends(arm_compute_validation_framework , arm_compute_core_a) diff --git a/tests/validation/Helpers.cpp b/tests/validation/Helpers.cpp index 093271244e..6e93cd0638 100644 --- a/tests/validation/Helpers.cpp +++ b/tests/validation/Helpers.cpp @@ -113,7 +113,9 @@ SimpleTensor convert_from_asymmetric(const SimpleTensor &src) { const UniformQuantizationInfo &quantization_info = src.quantization_info().uniform(); SimpleTensor dst{ src.shape(), DataType::F32, 1, QuantizationInfo(), src.data_layout() }; - +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int i = 0; i < src.num_elements(); ++i) { dst[i] = dequantize_qasymm8(src[i], quantization_info); @@ -127,6 +129,9 @@ SimpleTensor convert_from_asymmetric(const SimpleTensor &src) const UniformQuantizationInfo &quantization_info = src.quantization_info().uniform(); SimpleTensor dst{ src.shape(), DataType::F32, 1, QuantizationInfo(), src.data_layout() }; +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int i = 0; i < src.num_elements(); ++i) { dst[i] = dequantize_qasymm8_signed(src[i], quantization_info); @@ -140,6 +145,9 @@ SimpleTensor convert_from_asymmetric(const SimpleTensor &src) const UniformQuantizationInfo &quantization_info = src.quantization_info().uniform(); SimpleTensor dst{ src.shape(), DataType::F32, 1, QuantizationInfo(), src.data_layout() }; +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int i = 0; i < src.num_elements(); ++i) { dst[i] = dequantize_qasymm16(src[i], quantization_info); @@ -153,6 +161,9 @@ SimpleTensor convert_to_asymmetric(const SimpleTensor &src, cons SimpleTensor dst{ src.shape(), DataType::QASYMM8, 1, quantization_info }; const UniformQuantizationInfo &qinfo = quantization_info.uniform(); +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int i = 0; i < src.num_elements(); ++i) { dst[i] = quantize_qasymm8(src[i], qinfo); @@ -166,6 +177,9 @@ SimpleTensor convert_to_asymmetric(const SimpleTensor &src, const SimpleTensor dst{ src.shape(), DataType::QASYMM8_SIGNED, 1, quantization_info }; const UniformQuantizationInfo &qinfo = quantization_info.uniform(); +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int i = 0; i < src.num_elements(); ++i) { dst[i] = quantize_qasymm8_signed(src[i], qinfo); @@ -179,6 +193,9 @@ SimpleTensor convert_to_asymmetric(const SimpleTensor &src, con SimpleTensor dst{ src.shape(), DataType::QASYMM16, 1, quantization_info }; const UniformQuantizationInfo &qinfo = quantization_info.uniform(); +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int i = 0; i < src.num_elements(); ++i) { dst[i] = quantize_qasymm16(src[i], qinfo); @@ -192,6 +209,9 @@ SimpleTensor convert_to_symmetric(const SimpleTensor &src, const SimpleTensor dst{ src.shape(), DataType::QSYMM16, 1, quantization_info }; const UniformQuantizationInfo &qinfo = quantization_info.uniform(); +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int i = 0; i < src.num_elements(); ++i) { dst[i] = quantize_qsymm16(src[i], qinfo); @@ -205,6 +225,9 @@ SimpleTensor convert_from_symmetric(const SimpleTensor &src) const UniformQuantizationInfo &quantization_info = src.quantization_info().uniform(); SimpleTensor dst{ src.shape(), DataType::F32, 1, QuantizationInfo(), src.data_layout() }; +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int i = 0; i < src.num_elements(); ++i) { dst[i] = dequantize_qsymm16(src[i], quantization_info); @@ -223,6 +246,9 @@ void matrix_multiply(const SimpleTensor &a, const SimpleTensor &b, SimpleT const int N = b.shape()[0]; // Cols const int K = b.shape()[1]; +#if defined(_OPENMP) + #pragma omp parallel for collapse(2) +#endif /* _OPENMP */ for(int y = 0; y < M; ++y) { for(int x = 0; x < N; ++x) @@ -246,6 +272,9 @@ void transpose_matrix(const SimpleTensor &in, SimpleTensor &out) const int width = in.shape()[0]; const int height = in.shape()[1]; +#if defined(_OPENMP) + #pragma omp parallel for collapse(2) +#endif /* _OPENMP */ for(int y = 0; y < height; ++y) { for(int x = 0; x < width; ++x) diff --git a/tests/validation/reference/AbsoluteDifference.cpp b/tests/validation/reference/AbsoluteDifference.cpp index f9fce5b42a..ea7685bc9e 100644 --- a/tests/validation/reference/AbsoluteDifference.cpp +++ b/tests/validation/reference/AbsoluteDifference.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -40,7 +40,9 @@ SimpleTensor absolute_difference(const SimpleTensor &src1, const SimpleTen SimpleTensor result(src1.shape(), dst_data_type); using intermediate_type = typename common_promoted_signed_type::intermediate_type; - +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int i = 0; i < src1.num_elements(); ++i) { intermediate_type val = std::abs(static_cast(src1[i]) - static_cast(src2[i])); diff --git a/tests/validation/reference/Accumulate.cpp b/tests/validation/reference/Accumulate.cpp index 7f34be9663..2758577ef9 100644 --- a/tests/validation/reference/Accumulate.cpp +++ b/tests/validation/reference/Accumulate.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -42,7 +42,9 @@ SimpleTensor accumulate(const SimpleTensor &src, DataType output_data_ty library->fill_tensor_uniform(dst, 1, static_cast(0), static_cast(std::numeric_limits::max())); using intermediate_type = typename common_promoted_signed_type::intermediate_type; - +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int i = 0; i < src.num_elements(); ++i) { intermediate_type val = static_cast(src[i]) + static_cast(dst[i]); @@ -62,7 +64,9 @@ SimpleTensor accumulate_weighted(const SimpleTensor &src, float alpha, D library->fill_tensor_uniform(dst, 1, static_cast(0), static_cast(std::numeric_limits::max())); using intermediate_type = typename common_promoted_signed_type::intermediate_type; - +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int i = 0; i < src.num_elements(); ++i) { double val = (1. - static_cast(alpha)) * static_cast(dst[i]) + static_cast(alpha) * static_cast(src[i]); @@ -83,7 +87,9 @@ SimpleTensor accumulate_squared(const SimpleTensor &src, uint32_t shift, using intermediate_type = typename common_promoted_signed_type::intermediate_type; intermediate_type denom = 1 << shift; - +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int i = 0; i < src.num_elements(); ++i) { intermediate_type val = static_cast(dst[i]) + (static_cast(src[i]) * static_cast(src[i]) / denom); diff --git a/tests/validation/reference/ActivationLayer.cpp b/tests/validation/reference/ActivationLayer.cpp index 4aa0f880da..4aeefaaa79 100644 --- a/tests/validation/reference/ActivationLayer.cpp +++ b/tests/validation/reference/ActivationLayer.cpp @@ -45,7 +45,9 @@ SimpleTensor activation_layer(const SimpleTensor &src, ActivationLayerInfo // Compute reference const T a(info.a()); const T b(info.b()); - +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int i = 0; i < src.num_elements(); ++i) { dst[i] = activate_float(src[i], a, b, info.activation()); @@ -86,7 +88,6 @@ SimpleTensor activation_layer(const SimpleTensor &src SimpleTensor dst = convert_to_symmetric(dst_tmp, dst_qinfo); return dst; } - template SimpleTensor activation_layer(const SimpleTensor &src, ActivationLayerInfo info, const QuantizationInfo &oq_info); template SimpleTensor activation_layer(const SimpleTensor &src, ActivationLayerInfo info, const QuantizationInfo &oq_info); template SimpleTensor activation_layer(const SimpleTensor &src, ActivationLayerInfo info, const QuantizationInfo &oq_info); diff --git a/tests/validation/reference/ArithmeticDivision.cpp b/tests/validation/reference/ArithmeticDivision.cpp index 0ced439404..f86ee5e599 100644 --- a/tests/validation/reference/ArithmeticDivision.cpp +++ b/tests/validation/reference/ArithmeticDivision.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 ARM Limited. + * Copyright (c) 2018-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -49,13 +49,16 @@ struct BroadcastUnroll id_src1.set(dim - 1, 0); id_src2.set(dim - 1, 0); id_dst.set(dim - 1, 0); - - for(size_t i = 0; i < dst.shape()[dim - 1]; ++i, ++id_dst[dim - 1]) +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ + for(size_t i = 0; i < dst.shape()[dim - 1]; ++i) { BroadcastUnroll < dim - 1 >::unroll(src1, src2, dst, id_src1, id_src2, id_dst); id_src1[dim - 1] += !src1_is_broadcast; id_src2[dim - 1] += !src2_is_broadcast; + ++id_dst[dim - 1]; } } }; diff --git a/tests/validation/reference/ArithmeticOperations.cpp b/tests/validation/reference/ArithmeticOperations.cpp index d86833fefa..fd32f45cfa 100644 --- a/tests/validation/reference/ArithmeticOperations.cpp +++ b/tests/validation/reference/ArithmeticOperations.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -62,13 +62,16 @@ struct BroadcastUnroll id_src1.set(dim - 1, 0); id_src2.set(dim - 1, 0); id_dst.set(dim - 1, 0); - - for(size_t i = 0; i < dst.shape()[dim - 1]; ++i, ++id_dst[dim - 1]) +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ + for(size_t i = 0; i < dst.shape()[dim - 1]; ++i) { BroadcastUnroll < dim - 1 >::unroll(op, src1, src2, dst, convert_policy, id_src1, id_src2, id_dst); id_src1[dim - 1] += !src1_is_broadcast; id_src2[dim - 1] += !src2_is_broadcast; + ++id_dst[dim - 1]; } } }; diff --git a/tests/validation/reference/BatchNormalizationLayer.cpp b/tests/validation/reference/BatchNormalizationLayer.cpp index 37713c841d..6623b22e78 100644 --- a/tests/validation/reference/BatchNormalizationLayer.cpp +++ b/tests/validation/reference/BatchNormalizationLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -46,7 +46,9 @@ SimpleTensor batch_normalization_layer(const SimpleTensor &src, const Simp const auto rows = static_cast(src.shape()[1]); const auto depth = static_cast(src.shape()[2]); const int upper_dims = src.shape().total_size() / (cols * rows * depth); - +#if defined(_OPENMP) + #pragma omp parallel for schedule(dynamic, 1) collapse(4) +#endif /* _OPENMP */ for(int r = 0; r < upper_dims; ++r) { for(int i = 0; i < depth; ++i) diff --git a/tests/validation/reference/BitwiseAnd.cpp b/tests/validation/reference/BitwiseAnd.cpp index 6fc46b402b..356c27ee13 100644 --- a/tests/validation/reference/BitwiseAnd.cpp +++ b/tests/validation/reference/BitwiseAnd.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -35,7 +35,9 @@ template SimpleTensor bitwise_and(const SimpleTensor &src1, const SimpleTensor &src2) { SimpleTensor dst(src1.shape(), src1.data_type()); - +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int i = 0; i < src1.num_elements(); ++i) { dst[i] = src1[i] & src2[i]; diff --git a/tests/validation/reference/BitwiseNot.cpp b/tests/validation/reference/BitwiseNot.cpp index 5a6a13b56c..03578a3beb 100644 --- a/tests/validation/reference/BitwiseNot.cpp +++ b/tests/validation/reference/BitwiseNot.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -35,7 +35,9 @@ template SimpleTensor bitwise_not(const SimpleTensor &src) { SimpleTensor dst(src.shape(), src.data_type()); - +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int i = 0; i < src.num_elements(); ++i) { dst[i] = ~src[i]; diff --git a/tests/validation/reference/BitwiseOr.cpp b/tests/validation/reference/BitwiseOr.cpp index fc258d54f1..11c0a932fe 100644 --- a/tests/validation/reference/BitwiseOr.cpp +++ b/tests/validation/reference/BitwiseOr.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -35,7 +35,9 @@ template SimpleTensor bitwise_or(const SimpleTensor &src1, const SimpleTensor &src2) { SimpleTensor dst(src1.shape(), src1.data_type()); - +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int i = 0; i < src1.num_elements(); ++i) { dst[i] = src1[i] | src2[i]; diff --git a/tests/validation/reference/BitwiseXor.cpp b/tests/validation/reference/BitwiseXor.cpp index b8d275d8b5..afae032b1b 100644 --- a/tests/validation/reference/BitwiseXor.cpp +++ b/tests/validation/reference/BitwiseXor.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -35,7 +35,9 @@ template SimpleTensor bitwise_xor(const SimpleTensor &src1, const SimpleTensor &src2) { SimpleTensor dst(src1.shape(), src1.data_type()); - +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int i = 0; i < src1.num_elements(); ++i) { dst[i] = src1[i] ^ src2[i]; diff --git a/tests/validation/reference/BoundingBoxTransform.cpp b/tests/validation/reference/BoundingBoxTransform.cpp index e09bcff1c6..89182f1a9c 100644 --- a/tests/validation/reference/BoundingBoxTransform.cpp +++ b/tests/validation/reference/BoundingBoxTransform.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 ARM Limited. + * Copyright (c) 2018-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -57,7 +57,9 @@ SimpleTensor bounding_box_transform(const SimpleTensor &boxes, const Simpl const size_t box_fields = 4; const size_t class_fields = 4; - +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(size_t i = 0; i < num_boxes; ++i) { // Extract ROI information diff --git a/tests/validation/reference/Box3x3.cpp b/tests/validation/reference/Box3x3.cpp index 153f26a5c6..7ea3f1fe99 100644 --- a/tests/validation/reference/Box3x3.cpp +++ b/tests/validation/reference/Box3x3.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -41,6 +41,9 @@ SimpleTensor box3x3(const SimpleTensor &src, BorderMode border_mode, T con const std::array filter{ { 1, 1, 1, 1, 1, 1, 1, 1, 1 } }; const float scale = 1.f / static_cast(filter.size()); const uint32_t num_elements = src.num_elements(); +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(uint32_t element_idx = 0; element_idx < num_elements; ++element_idx) { const Coordinates id = index2coord(src.shape(), element_idx); diff --git a/tests/validation/reference/ChannelCombine.cpp b/tests/validation/reference/ChannelCombine.cpp index a6c0557b79..2380b58583 100644 --- a/tests/validation/reference/ChannelCombine.cpp +++ b/tests/validation/reference/ChannelCombine.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -92,6 +92,9 @@ std::vector> channel_combine(const TensorShape &shape, const std { std::vector> dst = create_image_planes(shape, format); +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(unsigned int plane_idx = 0; plane_idx < dst.size(); ++plane_idx) { SimpleTensor &dst_tensor = dst[plane_idx]; diff --git a/tests/validation/reference/ChannelExtract.cpp b/tests/validation/reference/ChannelExtract.cpp index fc7ae7d6cb..75d0a00604 100644 --- a/tests/validation/reference/ChannelExtract.cpp +++ b/tests/validation/reference/ChannelExtract.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -51,6 +51,9 @@ SimpleTensor channel_extract(const TensorShape &shape, const std::vecto const int height = dst.shape().y(); // Loop over each pixel and extract channel +#if defined(_OPENMP) + #pragma omp parallel for collapse(2) +#endif /* _OPENMP */ for(int y = 0; y < height; ++y) { for(int x = 0; x < width; ++x) diff --git a/tests/validation/reference/ChannelShuffle.cpp b/tests/validation/reference/ChannelShuffle.cpp index b8aa9203ab..39d89e9f02 100644 --- a/tests/validation/reference/ChannelShuffle.cpp +++ b/tests/validation/reference/ChannelShuffle.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -50,7 +50,9 @@ SimpleTensor channel_shuffle(const SimpleTensor &src, int num_groups) const T *src_ref = src.data(); T *dst_ref = dst.data(); - +#if defined(_OPENMP) + #pragma omp parallel for collapse(2) +#endif /* _OPENMP */ for(int n = 0; n < batches; ++n) { for(int g = 0; g < num_groups; ++g) diff --git a/tests/validation/reference/Col2Im.cpp b/tests/validation/reference/Col2Im.cpp index 53969d4725..f42582bbe8 100644 --- a/tests/validation/reference/Col2Im.cpp +++ b/tests/validation/reference/Col2Im.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -47,21 +47,26 @@ SimpleTensor col2im(const SimpleTensor &src, const TensorShape &dst_shape, if(num_groups == 1) { // Batches are on the 3rd dimension of the input tensor - int dst_idx = 0; +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif /* _OPENMP */ for(size_t b = 0; b < batches; ++b) { for(size_t x = 0; x < src_width; ++x) { for(size_t y = 0; y < src_height; ++y) { - dst[dst_idx++] = src[coord2index(src.shape(), Coordinates(x, y, b))]; + const int dst_idx = y + x * src_height + b * src_height * src_width; + dst[dst_idx] = src[coord2index(src.shape(), Coordinates(x, y, b))]; } } } } else { - int dst_idx = 0; +#if defined(_OPENMP) + #pragma omp parallel for collapse(4) +#endif /* _OPENMP */ for(size_t b = 0; b < batches; ++b) { for(size_t g = 0; g < num_groups; ++g) @@ -70,7 +75,8 @@ SimpleTensor col2im(const SimpleTensor &src, const TensorShape &dst_shape, { for(size_t y = 0; y < src_height; ++y) { - dst[dst_idx++] = src[coord2index(src.shape(), Coordinates(x, y, g, b))]; + const int dst_idx = y + x * src_height + g * src_height * src_width + b * src_height * src_width * num_groups; + dst[dst_idx] = src[coord2index(src.shape(), Coordinates(x, y, g, b))]; } } } diff --git a/tests/validation/reference/Comparisons.cpp b/tests/validation/reference/Comparisons.cpp index c0c86b1933..2313d9b022 100644 --- a/tests/validation/reference/Comparisons.cpp +++ b/tests/validation/reference/Comparisons.cpp @@ -80,13 +80,16 @@ struct BroadcastUnroll id_src1.set(dim - 1, 0); id_src2.set(dim - 1, 0); id_dst.set(dim - 1, 0); - - for(size_t i = 0; i < dst.shape()[dim - 1]; ++i, ++id_dst[dim - 1]) +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ + for(size_t i = 0; i < dst.shape()[dim - 1]; ++i) { BroadcastUnroll < dim - 1 >::unroll(op, src1, src2, dst, id_src1, id_src2, id_dst); id_src1[dim - 1] += !src1_is_broadcast; id_src2[dim - 1] += !src2_is_broadcast; + ++id_dst[dim - 1]; } } }; diff --git a/tests/validation/reference/ComputeAllAnchors.cpp b/tests/validation/reference/ComputeAllAnchors.cpp index 60be7ef8a8..9654da2ea3 100644 --- a/tests/validation/reference/ComputeAllAnchors.cpp +++ b/tests/validation/reference/ComputeAllAnchors.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 ARM Limited. + * Copyright (c) 2019-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -49,6 +49,9 @@ SimpleTensor compute_all_anchors(const SimpleTensor &anchors, const Comput T *all_anchors_ptr = all_anchors.data(); // Iterate over the input grid and anchors +#if defined(_OPENMP) + #pragma omp parallel for schedule(dynamic, 1) collapse(3) +#endif /* _OPENMP */ for(int y = 0; y < height; y++) { for(int x = 0; x < width; x++) diff --git a/tests/validation/reference/ConvertFullyConnectedWeights.cpp b/tests/validation/reference/ConvertFullyConnectedWeights.cpp index 5925496f45..710644a4e5 100644 --- a/tests/validation/reference/ConvertFullyConnectedWeights.cpp +++ b/tests/validation/reference/ConvertFullyConnectedWeights.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 ARM Limited. + * Copyright (c) 2018-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -49,6 +49,9 @@ SimpleTensor convert_fully_connected_weights(const SimpleTensor &src, cons const unsigned int factor_2 = is_nchw_to_nhwc ? num_channels : num_elems_per_input_plane; const uint32_t num_elements = src.num_elements(); +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(uint32_t i = 0; i < num_elements; ++i) { const Coordinates coords_in = index2coords(src.shape(), i); diff --git a/tests/validation/reference/Convolution.cpp b/tests/validation/reference/Convolution.cpp index 5e0f29d421..ad93b3077a 100644 --- a/tests/validation/reference/Convolution.cpp +++ b/tests/validation/reference/Convolution.cpp @@ -45,6 +45,9 @@ SimpleTensor convolution(const SimpleTensor &src, DataType output_da SimpleTensor dst(src.shape(), output_data_type); SimpleTensor sum(src.shape(), output_data_type); const uint32_t num_elements = src.num_elements(); +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(uint32_t element_idx = 0; element_idx < num_elements; ++element_idx) { const Coordinates id = index2coord(src.shape(), element_idx); diff --git a/tests/validation/reference/ConvolutionLayer.cpp b/tests/validation/reference/ConvolutionLayer.cpp index 84fb3491bd..aed976e5ee 100644 --- a/tests/validation/reference/ConvolutionLayer.cpp +++ b/tests/validation/reference/ConvolutionLayer.cpp @@ -69,6 +69,9 @@ SimpleTensor convolution_layer_nchw(const SimpleTensor &src, const SimpleT const int end_xi = output_wh.first * stride_xi; const int end_yi = output_wh.second * stride_yi; const int num_batches = src.shape().total_size() / (width_in * height_in * depth_in); +#if defined(_OPENMP) + #pragma omp parallel for collapse(5) +#endif /* _OPENMP */ for(int r = 0; r < num_batches; ++r) { for(int yi = start_yi; yi < start_yi + end_yi; yi += stride_yi) diff --git a/tests/validation/reference/DFT.cpp b/tests/validation/reference/DFT.cpp index b3c2c6b0b9..ae030c7104 100644 --- a/tests/validation/reference/DFT.cpp +++ b/tests/validation/reference/DFT.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 ARM Limited. + * Copyright (c) 2019-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -50,6 +50,9 @@ namespace template void rdft_1d_step(const T *src_ptr, size_t N, T *dst_ptr, size_t K) { +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(unsigned int k = 0; k < K; ++k) { float Xr = 0; @@ -77,6 +80,9 @@ void rdft_1d_step(const T *src_ptr, size_t N, T *dst_ptr, size_t K) template void dft_1d_step(const T *src_ptr, T *dst_ptr, size_t N) { +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(unsigned int k = 0; k < N; ++k) { float Xr = 0; @@ -111,7 +117,9 @@ void irdft_1d_step(const T *src_ptr, size_t K, T *dst_ptr, size_t N) const bool is_odd = N % 2; const unsigned int Nleft = N - K; const int tail_start = is_odd ? K - 1 : K - 2; - +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(unsigned int n = 0; n < N; ++n) { float xr = 0; @@ -142,6 +150,9 @@ void irdft_1d_step(const T *src_ptr, size_t K, T *dst_ptr, size_t N) template void idft_1d_step(const T *src_ptr, T *dst_ptr, size_t N) { +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(unsigned int n = 0; n < N; ++n) { float xr = 0; @@ -181,6 +192,9 @@ SimpleTensor rdft_1d_core(const SimpleTensor &src, FFTDirection direction, SimpleTensor dst(dst_shape, src.data_type(), num_channels); const unsigned int upper_dims = src.shape().total_size_upper(1); +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(unsigned int du = 0; du < upper_dims; ++du) { const T *src_row_ptr = src.data() + du * N * src.num_channels(); @@ -201,6 +215,9 @@ SimpleTensor dft_1d_core(const SimpleTensor &src, FFTDirection direction) SimpleTensor dst(src.shape(), src.data_type(), src.num_channels()); const unsigned int upper_dims = src.shape().total_size_upper(1); +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(unsigned int du = 0; du < upper_dims; ++du) { const T *src_row_ptr = src.data() + du * N * src.num_channels(); @@ -221,6 +238,9 @@ void scale(SimpleTensor &tensor, T scaling_factor) { const int total_elements = tensor.num_elements() * tensor.num_channels(); T *data_ptr = tensor.data(); +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int i = 0; i < total_elements; ++i) { data_ptr[i] /= scaling_factor; @@ -249,7 +269,9 @@ SimpleTensor complex_mul_and_reduce(const SimpleTensor &input, const Simpl // MemSet dst memory to zero std::memset(dst.data(), 0, dst.size()); - +#if defined(_OPENMP) + #pragma omp parallel for collapse(5) +#endif /* _OPENMP */ for(uint32_t b = 0; b < N; ++b) { for(uint32_t co = 0; co < Co; ++co) diff --git a/tests/validation/reference/DeconvolutionLayer.cpp b/tests/validation/reference/DeconvolutionLayer.cpp index 5750f51e3f..01b9c1c403 100644 --- a/tests/validation/reference/DeconvolutionLayer.cpp +++ b/tests/validation/reference/DeconvolutionLayer.cpp @@ -100,6 +100,9 @@ SimpleTensor deconvolution_layer(const SimpleTensor &src, const SimpleTens // Flip weights by 180 degrees SimpleTensor weights_flipped{ weights.shape(), weights.data_type(), 1, weights.quantization_info() }; +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int ud = 0; ud < weights_upper_dims; ++ud) { const int offset = ud * weights_width * weights_height; @@ -111,7 +114,9 @@ SimpleTensor deconvolution_layer(const SimpleTensor &src, const SimpleTens } } } - +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int slice = 0; slice < num_2d_slices; ++slice) { const int offset_slice_in = slice * width_in * height_in; diff --git a/tests/validation/reference/DepthConcatenateLayer.cpp b/tests/validation/reference/DepthConcatenateLayer.cpp index d6e6e78187..2c93e7060a 100644 --- a/tests/validation/reference/DepthConcatenateLayer.cpp +++ b/tests/validation/reference/DepthConcatenateLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -58,6 +58,9 @@ SimpleTensor depthconcatenate_layer(const std::vector> &srcs, if(srcs[0].data_type() == DataType::QASYMM8 && std::any_of(srcs.cbegin(), srcs.cend(), have_different_quantization_info)) { +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int b = 0; b < batches; ++b) { // input tensors can have smaller width and height than the output, so for each output's slice we need to requantize 0 (as this is the value diff --git a/tests/validation/reference/DepthConvertLayer.cpp b/tests/validation/reference/DepthConvertLayer.cpp index 57eeb7f6f3..9c6a9aa76c 100644 --- a/tests/validation/reference/DepthConvertLayer.cpp +++ b/tests/validation/reference/DepthConvertLayer.cpp @@ -46,6 +46,9 @@ SimpleTensor depth_convert(const SimpleTensor &src, DataType dt_out, Con // Up-casting if(element_size_from_data_type(src.data_type()) < element_size_from_data_type(dt_out)) { +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int i = 0; i < src.num_elements(); ++i) { result[i] = src[i] << shift; @@ -54,6 +57,9 @@ SimpleTensor depth_convert(const SimpleTensor &src, DataType dt_out, Con // Down-casting else { +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int i = 0; i < src.num_elements(); ++i) { T1 val = src[i] >> shift; diff --git a/tests/validation/reference/DepthToSpaceLayer.cpp b/tests/validation/reference/DepthToSpaceLayer.cpp index 4135ce5471..e2329ed60b 100644 --- a/tests/validation/reference/DepthToSpaceLayer.cpp +++ b/tests/validation/reference/DepthToSpaceLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 ARM Limited. + * Copyright (c) 2019-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -40,13 +40,14 @@ SimpleTensor depth_to_space(const SimpleTensor &src, const TensorShape &ds ARM_COMPUTE_ERROR_ON(block_shape <= 0); SimpleTensor result(dst_shape, src.data_type()); - int in_pos = 0; const auto width_in = static_cast(src.shape()[0]); const auto height_in = static_cast(src.shape()[1]); const auto channel_in = static_cast(src.shape()[2]); const auto batch_in = static_cast(src.shape()[3]); const int r = channel_in / (block_shape * block_shape); - +#if defined(_OPENMP) + #pragma omp parallel for collapse(4) +#endif /* _OPENMP */ for(int b = 0; b < batch_in; ++b) { for(int z = 0; z < channel_in; ++z) @@ -58,8 +59,8 @@ SimpleTensor depth_to_space(const SimpleTensor &src, const TensorShape &ds const int out_x = (block_shape * x + (z / r) % block_shape); const int out_y = (block_shape * y + (z / r) / block_shape); const int out_pos = out_x + dst_shape[0] * out_y + (z % r) * dst_shape[0] * dst_shape[1] + b * dst_shape[0] * dst_shape[1] * dst_shape[2]; + const int in_pos = x + width_in * y + z * width_in * height_in + b * width_in * height_in * channel_in; result[out_pos] = src[in_pos]; - ++in_pos; } } } diff --git a/tests/validation/reference/DequantizationLayer.cpp b/tests/validation/reference/DequantizationLayer.cpp index 7dd36402b3..7dec988b18 100644 --- a/tests/validation/reference/DequantizationLayer.cpp +++ b/tests/validation/reference/DequantizationLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -75,7 +75,9 @@ SimpleTensor dequantization_layer(const SimpleTensor &src) const int N = src.shape().total_size() / (WH * C); const std::vector qscales = src.quantization_info().scale(); - +#if defined(_OPENMP) + #pragma omp parallel for collapse(2) +#endif /* _OPENMP */ for(int n = 0; n < N; ++n) { for(int c = 0; c < C; ++c) @@ -95,7 +97,9 @@ SimpleTensor dequantization_layer(const SimpleTensor &src) { const UniformQuantizationInfo &quantization_info = src.quantization_info().uniform(); ARM_COMPUTE_ERROR_ON(quantization_info.offset != 0 && src_data_type == DataType::QSYMM8); - +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int i = 0; i < src.num_elements(); ++i) { dst[i] = static_cast(dequantize(static_cast(src[i]), quantization_info, src_data_type)); diff --git a/tests/validation/reference/Derivative.cpp b/tests/validation/reference/Derivative.cpp index 3c6f3259b2..f4c2934728 100644 --- a/tests/validation/reference/Derivative.cpp +++ b/tests/validation/reference/Derivative.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -62,6 +62,9 @@ std::pair, SimpleTensor> derivative(const SimpleTensor &sr ValidRegion valid_region = shape_to_valid_region(src.shape(), border_mode == BorderMode::UNDEFINED, BorderSize(filter_size / 2)); const uint32_t num_elements = src.num_elements(); +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(uint32_t i = 0; i < num_elements; ++i) { Coordinates coord = index2coord(src.shape(), i); diff --git a/tests/validation/reference/Dilate.cpp b/tests/validation/reference/Dilate.cpp index 8e244e9b7b..cba9af127f 100644 --- a/tests/validation/reference/Dilate.cpp +++ b/tests/validation/reference/Dilate.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -52,6 +52,9 @@ SimpleTensor dilate(const SimpleTensor &src, BorderMode border_mode, T con SimpleTensor dst(src.shape(), src.data_type()); const uint32_t num_elements = src.num_elements(); +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(uint32_t i = 0; i < num_elements; ++i) { Coordinates coord = index2coord(src.shape(), i); diff --git a/tests/validation/reference/ElementWiseUnary.cpp b/tests/validation/reference/ElementWiseUnary.cpp index eaaaa4ec1e..f1bb7c783c 100644 --- a/tests/validation/reference/ElementWiseUnary.cpp +++ b/tests/validation/reference/ElementWiseUnary.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 ARM Limited. + * Copyright (c) 2018-2020 ARM Limited. * * SPDX-License-Identifier: MIT * diff --git a/tests/validation/reference/EqualizeHistogram.cpp b/tests/validation/reference/EqualizeHistogram.cpp index 1a10c2c30a..34e7c397bf 100644 --- a/tests/validation/reference/EqualizeHistogram.cpp +++ b/tests/validation/reference/EqualizeHistogram.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -75,6 +75,9 @@ SimpleTensor equalize_histogram(const SimpleTensor &src) } // Fill output tensor with equalized values +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int i = 0; i < src.num_elements(); ++i) { dst[i] = lut[src[i]]; diff --git a/tests/validation/reference/Erode.cpp b/tests/validation/reference/Erode.cpp index e7598c3900..0964c3d4b2 100644 --- a/tests/validation/reference/Erode.cpp +++ b/tests/validation/reference/Erode.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -52,6 +52,9 @@ SimpleTensor erode(const SimpleTensor &src, BorderMode border_mode, T cons SimpleTensor dst(src.shape(), src.data_type()); const uint32_t num_elements = src.num_elements(); +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(uint32_t i = 0; i < num_elements; ++i) { Coordinates coord = index2coord(src.shape(), i); diff --git a/tests/validation/reference/Floor.cpp b/tests/validation/reference/Floor.cpp index b011a16974..21fa1c9932 100644 --- a/tests/validation/reference/Floor.cpp +++ b/tests/validation/reference/Floor.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -42,6 +42,9 @@ SimpleTensor floor_layer(const SimpleTensor &src) SimpleTensor dst{ src.shape(), src.data_type() }; // Compute reference +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int i = 0; i < src.num_elements(); ++i) { dst[i] = std::floor(src[i]); diff --git a/tests/validation/reference/FullyConnectedLayer.cpp b/tests/validation/reference/FullyConnectedLayer.cpp index 9aecd6cf14..908c583161 100644 --- a/tests/validation/reference/FullyConnectedLayer.cpp +++ b/tests/validation/reference/FullyConnectedLayer.cpp @@ -49,11 +49,12 @@ void vector_matrix_multiply(const SimpleTensor &src, const SimpleTensor &w const T *weights_ptr = weights.data(); const TB *bias_ptr = bias.data(); T *dst_ptr = dst.data() + offset_dst; - +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int y = 0; y < rows_weights; ++y) { - dst_ptr[y] = std::inner_product(src_ptr, src_ptr + cols_weights, weights_ptr, static_cast(0)) + bias_ptr[y]; - weights_ptr += cols_weights; + dst_ptr[y] = std::inner_product(src_ptr, src_ptr + cols_weights, &weights_ptr[cols_weights * y], static_cast(0)) + bias_ptr[y]; } } @@ -85,7 +86,9 @@ void vector_matrix_multiply(const SimpleTensor &src, const SimpleTensor &w const int min = std::numeric_limits::lowest(); const int max = std::numeric_limits::max(); - +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int y = 0; y < rows_weights; ++y) { // Reset accumulator @@ -93,7 +96,7 @@ void vector_matrix_multiply(const SimpleTensor &src, const SimpleTensor &w for(int x = 0; x < cols_weights; ++x) { - acc += (src_ptr[x] + input_offset) * (weights_ptr[x] + weights_offset); + acc += (src_ptr[x] + input_offset) * (weights_ptr[x + y * cols_weights] + weights_offset); } // Accumulate the bias @@ -104,8 +107,6 @@ void vector_matrix_multiply(const SimpleTensor &src, const SimpleTensor &w // Store the result dst_ptr[y] = static_cast(acc); - - weights_ptr += cols_weights; } } } // namespace diff --git a/tests/validation/reference/FuseBatchNormalization.cpp b/tests/validation/reference/FuseBatchNormalization.cpp index df12b25912..cb5003874b 100644 --- a/tests/validation/reference/FuseBatchNormalization.cpp +++ b/tests/validation/reference/FuseBatchNormalization.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 ARM Limited. + * Copyright (c) 2019-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -22,6 +22,7 @@ * SOFTWARE. */ #include "FuseBatchNormalization.h" +#include "tests/validation/Helpers.h" namespace arm_compute { @@ -45,6 +46,9 @@ void fuse_batch_normalization_dwc_layer(const SimpleTensor &w, const SimpleTe const unsigned int height = w.shape()[1]; const unsigned int dim2 = w.shape()[2]; +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(unsigned int b = 0; b < dim2; ++b) { const auto mean_val = mean.data()[b]; diff --git a/tests/validation/reference/GEMM.cpp b/tests/validation/reference/GEMM.cpp index 3c72b94143..20def87a64 100644 --- a/tests/validation/reference/GEMM.cpp +++ b/tests/validation/reference/GEMM.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -23,6 +23,7 @@ */ #include "GEMM.h" +#include "arm_compute/core/Helpers.h" #include "arm_compute/core/Types.h" namespace arm_compute @@ -55,6 +56,9 @@ SimpleTensor gemm(const SimpleTensor &a, const SimpleTensor &b, const S const int c_stride_z = N * M; const int c_stride_w = N * M * D; +#if defined(_OPENMP) + #pragma omp parallel for collapse(2) +#endif /* _OPENMP */ for(int w = 0; w < W; ++w) { for(int depth = 0; depth < D; ++depth) @@ -107,6 +111,9 @@ SimpleTensor gemm_mixed_precision(const SimpleTensor &a, const SimpleTenso const int c_stride_z = N * M; const int c_stride_w = N * M * D; +#if defined(_OPENMP) + #pragma omp parallel for collapse(2) +#endif /* _OPENMP */ for(int w = 0; w < W; ++w) { for(int depth = 0; depth < D; ++depth) diff --git a/tests/validation/reference/GEMMLowp.cpp b/tests/validation/reference/GEMMLowp.cpp index 36d86d1532..85a98e4a76 100644 --- a/tests/validation/reference/GEMMLowp.cpp +++ b/tests/validation/reference/GEMMLowp.cpp @@ -69,6 +69,9 @@ void quantize_down_scale(const SimpleTensor *in, const SimpleTensor *b const int cols_in = in->shape().x(); const bool is_per_channel = result_mult_int.size() > 1; +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int i = 0; i < in->num_elements(); ++i) { int32_t result = ((*in)[i] + result_offset); @@ -100,6 +103,9 @@ void quantize_down_scale_by_fixedpoint(const SimpleTensor *in, const Simple const int cols_in = in->shape().x(); const bool is_per_channel = result_fixedpoint_multiplier.size() > 1; +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int i = 0; i < in->num_elements(); ++i) { TIn result = (*in)[i]; @@ -141,6 +147,9 @@ void quantize_down_scale_by_float(const SimpleTensor *in, const SimpleTenso const int cols_in = in->shape().x(); const bool is_per_channel = result_real_multiplier.size() > 1; +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int i = 0; i < in->num_elements(); ++i) { TIn result = (*in)[i]; diff --git a/tests/validation/reference/GEMMReshapeLHSMatrix.cpp b/tests/validation/reference/GEMMReshapeLHSMatrix.cpp index 431d65696e..f21fe50e58 100644 --- a/tests/validation/reference/GEMMReshapeLHSMatrix.cpp +++ b/tests/validation/reference/GEMMReshapeLHSMatrix.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2020 ARM Limited. * * SPDX-License-Identifier: MIT * diff --git a/tests/validation/reference/GEMMReshapeRHSMatrix.cpp b/tests/validation/reference/GEMMReshapeRHSMatrix.cpp index 0224c5c67c..ebb6f856d2 100644 --- a/tests/validation/reference/GEMMReshapeRHSMatrix.cpp +++ b/tests/validation/reference/GEMMReshapeRHSMatrix.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -70,7 +70,9 @@ SimpleTensor gemm_reshape_rhs_matrix(const SimpleTensor &in, const TensorS const unsigned int offset_output_x = rhs_info.interleave ? tile_to_use->shape()[0] : tile_to_use->shape()[0] * tile_to_use->shape()[1]; const unsigned int step_output_x = rhs_info.interleave ? tile_to_use->shape()[0] * rhs_info.h0 : tile_to_use->shape()[0]; - +#ifdef ARM_COMPUTE_OPENMP + #pragma omp parallel for schedule(dynamic, 1) collapse(3) +#endif /* _OPENMP */ for(unsigned int z = 0; z < B; ++z) { for(unsigned int y = 0; y < num_tiles_y; ++y) diff --git a/tests/validation/reference/Gaussian3x3.cpp b/tests/validation/reference/Gaussian3x3.cpp index 5ca24a7961..f2ac134f17 100644 --- a/tests/validation/reference/Gaussian3x3.cpp +++ b/tests/validation/reference/Gaussian3x3.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -42,6 +42,9 @@ SimpleTensor gaussian3x3(const SimpleTensor &src, BorderMode border_mode, const float scale = 1.f / 16.f; const uint32_t num_elements = src.num_elements(); +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(uint32_t element_idx = 0; element_idx < num_elements; ++element_idx) { const Coordinates id = index2coord(src.shape(), element_idx); diff --git a/tests/validation/reference/Gaussian5x5.cpp b/tests/validation/reference/Gaussian5x5.cpp index ac84f6d097..426e66647c 100644 --- a/tests/validation/reference/Gaussian5x5.cpp +++ b/tests/validation/reference/Gaussian5x5.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -48,6 +48,9 @@ SimpleTensor gaussian5x5(const SimpleTensor &src, BorderMode border_mode, const float scale = 1.f / 256.f; const uint32_t num_elements = src.num_elements(); +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(uint32_t element_idx = 0; element_idx < num_elements; ++element_idx) { const Coordinates id = index2coord(src.shape(), element_idx); diff --git a/tests/validation/reference/Im2Col.cpp b/tests/validation/reference/Im2Col.cpp index 4b41cdb70b..a3dcf07273 100644 --- a/tests/validation/reference/Im2Col.cpp +++ b/tests/validation/reference/Im2Col.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 ARM Limited. + * Copyright (c) 2018-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -109,7 +109,9 @@ void im2col_nhwc(const SimpleTensor &src, SimpleTensor &dst, const Size2D // Compute width and height of the convolved tensors std::pair convolved_dims = scaled_dimensions(src_width, src_height, kernel_dims.width, kernel_dims.height, conv_info); - +#if defined(_OPENMP) + #pragma omp parallel for schedule(dynamic, 1) collapse(2) +#endif /* _OPENMP */ for(int b = 0; b < batches; ++b) { for(int yo = 0; yo < dst_height; ++yo) diff --git a/tests/validation/reference/InstanceNormalizationLayer.cpp b/tests/validation/reference/InstanceNormalizationLayer.cpp index ad0ac1be68..339549723d 100644 --- a/tests/validation/reference/InstanceNormalizationLayer.cpp +++ b/tests/validation/reference/InstanceNormalizationLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 ARM Limited. + * Copyright (c) 2019-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -46,12 +46,14 @@ SimpleTensor instance_normalization(const SimpleTensor &src, float gamma, const size_t h_size = src.shape()[1]; const size_t c_size = src.shape()[2]; const size_t n_size = src.shape()[3]; - +#if defined(_OPENMP) + #pragma omp parallel for collapse(2) +#endif /* _OPENMP */ for(size_t n_i = 0; n_i < n_size; ++n_i) { for(size_t c_i = 0; c_i < c_size; ++c_i) { - float sum_h_w = 0; + float sum_h_w = 0; float sum_sq_h_w = 0; for(size_t h_i = 0; h_i < h_size; ++h_i) @@ -60,13 +62,14 @@ SimpleTensor instance_normalization(const SimpleTensor &src, float gamma, { float val = src[coord2index(src.shape(), Coordinates(w_i, h_i, c_i, n_i))]; sum_h_w += val; - sum_sq_h_w += val*val; + sum_sq_h_w += val * val; } } //Compute mean const float mean_h_w = sum_h_w / (h_size * w_size); //Compute variance - const float var_h_w = sum_sq_h_w / (h_size * w_size) - mean_h_w * mean_h_w;; + const float var_h_w = sum_sq_h_w / (h_size * w_size) - mean_h_w * mean_h_w; + ; //Apply mean for(size_t h_i = 0; h_i < h_size; ++h_i) diff --git a/tests/validation/reference/QuantizationLayer.cpp b/tests/validation/reference/QuantizationLayer.cpp index cfc508529e..a70523d7da 100644 --- a/tests/validation/reference/QuantizationLayer.cpp +++ b/tests/validation/reference/QuantizationLayer.cpp @@ -50,12 +50,18 @@ SimpleTensor quantization_layer(const SimpleTensor &src, DataType out switch(output_data_type) { case DataType::QASYMM8: +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int i = 0; i < src.num_elements(); ++i) { dst[i] = quantize_qasymm8((src[i]), qinfo, rounding_policy); } break; case DataType::QASYMM8_SIGNED: +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int i = 0; i < src.num_elements(); ++i) { #ifdef __aarch64__ @@ -66,6 +72,9 @@ SimpleTensor quantization_layer(const SimpleTensor &src, DataType out } break; case DataType::QASYMM16: +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int i = 0; i < src.num_elements(); ++i) { dst[i] = quantize_qasymm16((src[i]), qinfo, rounding_policy); diff --git a/tests/validation/reference/ReorgLayer.cpp b/tests/validation/reference/ReorgLayer.cpp index 2eb5d01926..9f087d06cb 100644 --- a/tests/validation/reference/ReorgLayer.cpp +++ b/tests/validation/reference/ReorgLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -54,6 +54,10 @@ SimpleTensor reorg_layer(const SimpleTensor &src, int32_t stride) // Calculate layer reorg in NCHW Coordinates map_coords; + +#if defined(_OPENMP) + #pragma omp parallel for private(map_coords) +#endif /* _OPENMP */ for(unsigned int b = 0; b < outer_dims; ++b) { map_coords.set(3, b); diff --git a/tests/validation/reference/Reverse.cpp b/tests/validation/reference/Reverse.cpp index 4bd8efc6a8..f5630b9a40 100644 --- a/tests/validation/reference/Reverse.cpp +++ b/tests/validation/reference/Reverse.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 ARM Limited. + * Copyright (c) 2018-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -55,6 +55,10 @@ SimpleTensor reverse(const SimpleTensor &src, const SimpleTensor } const uint32_t num_elements = src.num_elements(); + +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(uint32_t i = 0; i < num_elements; ++i) { const Coordinates src_coord = index2coord(src.shape(), i); diff --git a/tests/validation/reference/SoftmaxLayer.cpp b/tests/validation/reference/SoftmaxLayer.cpp index 0e470260a9..ee7a5f175a 100644 --- a/tests/validation/reference/SoftmaxLayer.cpp +++ b/tests/validation/reference/SoftmaxLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -55,6 +55,9 @@ SimpleTensor softmax_layer_generic(const SimpleTensor &src, float beta, si upper_dims *= src.shape()[i]; } +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(int r = 0; r < upper_dims; ++r) { const T *src_row_ptr = src.data() + r * lower_dims; @@ -107,7 +110,7 @@ SimpleTensor softmax_layer(const SimpleTensor &src, float beta, size_t axi return softmax_layer_generic(src, beta, axis, false); } -template ::value || std::is_same::value, int>::type> +template < typename T, typename std::enable_if < std::is_same::value || std::is_same::value, int >::type > SimpleTensor softmax_layer(const SimpleTensor &src, float beta, size_t axis) { const QuantizationInfo output_quantization_info = arm_compute::get_softmax_output_quantization_info(src.data_type(), false); diff --git a/tests/validation/reference/Winograd.cpp b/tests/validation/reference/Winograd.cpp index 47f5ac7a7d..61ba510fc6 100644 --- a/tests/validation/reference/Winograd.cpp +++ b/tests/validation/reference/Winograd.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 ARM Limited. + * Copyright (c) 2018-2020 ARM Limited. * * SPDX-License-Identifier: MIT * diff --git a/tests/validation/reference/YOLOLayer.cpp b/tests/validation/reference/YOLOLayer.cpp index 0011b85c1e..92bbf5445f 100644 --- a/tests/validation/reference/YOLOLayer.cpp +++ b/tests/validation/reference/YOLOLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 ARM Limited. + * Copyright (c) 2018-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -47,6 +47,9 @@ SimpleTensor yolo_layer(const SimpleTensor &src, const ActivationLayerInfo const T b(info.b()); const uint32_t num_elements = src.num_elements(); +#if defined(_OPENMP) + #pragma omp parallel for +#endif /* _OPENMP */ for(uint32_t i = 0; i < num_elements; ++i) { const size_t z = index2coord(dst.shape(), i).z() % (num_classes + 5); -- cgit v1.2.1