aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichalis Spyrou <michalis.spyrou@arm.com>2020-04-08 14:10:15 +0100
committerMichalis Spyrou <michalis.spyrou@arm.com>2020-04-17 14:20:01 +0000
commitd1d7722cfc5ee130115d8d195068a98b16102a21 (patch)
treef68f3ecca02ab4edde90189266fa186ec1a69474
parent4c6bd514a8d424a29b776754f1b3426fa3a8c339 (diff)
downloadComputeLibrary-d1d7722cfc5ee130115d8d195068a98b16102a21.tar.gz
COMPMID-3314: Enable OpenMP in the reference tests
Change-Id: I05b5fedb998317144e0dd13a6377a97207b27f46 Signed-off-by: Michalis Spyrou <michalis.spyrou@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3024 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--tests/SConscript7
-rw-r--r--tests/validation/Helpers.cpp31
-rw-r--r--tests/validation/reference/AbsoluteDifference.cpp6
-rw-r--r--tests/validation/reference/Accumulate.cpp14
-rw-r--r--tests/validation/reference/ActivationLayer.cpp5
-rw-r--r--tests/validation/reference/ArithmeticDivision.cpp9
-rw-r--r--tests/validation/reference/ArithmeticOperations.cpp9
-rw-r--r--tests/validation/reference/BatchNormalizationLayer.cpp6
-rw-r--r--tests/validation/reference/BitwiseAnd.cpp6
-rw-r--r--tests/validation/reference/BitwiseNot.cpp6
-rw-r--r--tests/validation/reference/BitwiseOr.cpp6
-rw-r--r--tests/validation/reference/BitwiseXor.cpp6
-rw-r--r--tests/validation/reference/BoundingBoxTransform.cpp6
-rw-r--r--tests/validation/reference/Box3x3.cpp5
-rw-r--r--tests/validation/reference/ChannelCombine.cpp5
-rw-r--r--tests/validation/reference/ChannelExtract.cpp5
-rw-r--r--tests/validation/reference/ChannelShuffle.cpp6
-rw-r--r--tests/validation/reference/Col2Im.cpp16
-rw-r--r--tests/validation/reference/Comparisons.cpp7
-rw-r--r--tests/validation/reference/ComputeAllAnchors.cpp5
-rw-r--r--tests/validation/reference/ConvertFullyConnectedWeights.cpp5
-rw-r--r--tests/validation/reference/Convolution.cpp3
-rw-r--r--tests/validation/reference/ConvolutionLayer.cpp3
-rw-r--r--tests/validation/reference/DFT.cpp28
-rw-r--r--tests/validation/reference/DeconvolutionLayer.cpp7
-rw-r--r--tests/validation/reference/DepthConcatenateLayer.cpp5
-rw-r--r--tests/validation/reference/DepthConvertLayer.cpp6
-rw-r--r--tests/validation/reference/DepthToSpaceLayer.cpp9
-rw-r--r--tests/validation/reference/DequantizationLayer.cpp10
-rw-r--r--tests/validation/reference/Derivative.cpp5
-rw-r--r--tests/validation/reference/Dilate.cpp5
-rw-r--r--tests/validation/reference/ElementWiseUnary.cpp2
-rw-r--r--tests/validation/reference/EqualizeHistogram.cpp5
-rw-r--r--tests/validation/reference/Erode.cpp5
-rw-r--r--tests/validation/reference/Floor.cpp5
-rw-r--r--tests/validation/reference/FullyConnectedLayer.cpp15
-rw-r--r--tests/validation/reference/FuseBatchNormalization.cpp6
-rw-r--r--tests/validation/reference/GEMM.cpp9
-rw-r--r--tests/validation/reference/GEMMLowp.cpp9
-rw-r--r--tests/validation/reference/GEMMReshapeLHSMatrix.cpp2
-rw-r--r--tests/validation/reference/GEMMReshapeRHSMatrix.cpp6
-rw-r--r--tests/validation/reference/Gaussian3x3.cpp5
-rw-r--r--tests/validation/reference/Gaussian5x5.cpp5
-rw-r--r--tests/validation/reference/Im2Col.cpp6
-rw-r--r--tests/validation/reference/InstanceNormalizationLayer.cpp13
-rw-r--r--tests/validation/reference/QuantizationLayer.cpp9
-rw-r--r--tests/validation/reference/ReorgLayer.cpp6
-rw-r--r--tests/validation/reference/Reverse.cpp6
-rw-r--r--tests/validation/reference/SoftmaxLayer.cpp7
-rw-r--r--tests/validation/reference/Winograd.cpp2
-rw-r--r--tests/validation/reference/YOLOLayer.cpp5
51 files changed, 293 insertions, 87 deletions
diff --git a/tests/SConscript b/tests/SConscript
index 5c95c551f4..26d422c10b 100644
--- a/tests/SConscript
+++ b/tests/SConscript
@@ -30,6 +30,7 @@ Import('install_bin')
variables = [
BoolVariable("benchmark_examples", "Build benchmark examples programs", True),
BoolVariable("validate_examples", "Build validate examples programs", True),
+ BoolVariable("reference_openmp", "Build reference validation with openmp", True),
#FIXME Switch the following two options to False before releasing
BoolVariable("validation_tests", "Build validation test programs", True),
BoolVariable("benchmark_tests", "Build benchmark test programs", True),
@@ -170,8 +171,12 @@ bm_link_flags = ['-fstack-protector-strong']
if test_env['linker_script']:
bm_link_flags += ['-Wl,--build-id=none', '-T', env['linker_script']]
+if test_env['reference_openmp'] and env['os'] != 'bare_metal':
+ test_env['CXXFLAGS'].append('-fopenmp')
+ test_env['LINKFLAGS'].append('-fopenmp')
+
if test_env['validation_tests']:
- arm_compute_validation_framework = env.StaticLibrary('arm_compute_validation_framework', Glob('validation/reference/*.cpp') + Glob('validation/*.cpp'), LIBS= [ arm_compute_test_framework, arm_compute_core_a])
+ arm_compute_validation_framework = env.StaticLibrary('arm_compute_validation_framework', Glob('validation/reference/*.cpp') + Glob('validation/*.cpp'), LINKFLAGS=test_env['LINKFLAGS'], CXXFLAGS=test_env['CXXFLAGS'], LIBS= [ arm_compute_test_framework, arm_compute_core_a])
Depends(arm_compute_validation_framework , arm_compute_test_framework)
Depends(arm_compute_validation_framework , arm_compute_core_a)
diff --git a/tests/validation/Helpers.cpp b/tests/validation/Helpers.cpp
index 093271244e..6e93cd0638 100644
--- a/tests/validation/Helpers.cpp
+++ b/tests/validation/Helpers.cpp
@@ -113,7 +113,9 @@ SimpleTensor<float> convert_from_asymmetric(const SimpleTensor<uint8_t> &src)
{
const UniformQuantizationInfo &quantization_info = src.quantization_info().uniform();
SimpleTensor<float> dst{ src.shape(), DataType::F32, 1, QuantizationInfo(), src.data_layout() };
-
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int i = 0; i < src.num_elements(); ++i)
{
dst[i] = dequantize_qasymm8(src[i], quantization_info);
@@ -127,6 +129,9 @@ SimpleTensor<float> convert_from_asymmetric(const SimpleTensor<int8_t> &src)
const UniformQuantizationInfo &quantization_info = src.quantization_info().uniform();
SimpleTensor<float> dst{ src.shape(), DataType::F32, 1, QuantizationInfo(), src.data_layout() };
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int i = 0; i < src.num_elements(); ++i)
{
dst[i] = dequantize_qasymm8_signed(src[i], quantization_info);
@@ -140,6 +145,9 @@ SimpleTensor<float> convert_from_asymmetric(const SimpleTensor<uint16_t> &src)
const UniformQuantizationInfo &quantization_info = src.quantization_info().uniform();
SimpleTensor<float> dst{ src.shape(), DataType::F32, 1, QuantizationInfo(), src.data_layout() };
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int i = 0; i < src.num_elements(); ++i)
{
dst[i] = dequantize_qasymm16(src[i], quantization_info);
@@ -153,6 +161,9 @@ SimpleTensor<uint8_t> convert_to_asymmetric(const SimpleTensor<float> &src, cons
SimpleTensor<uint8_t> dst{ src.shape(), DataType::QASYMM8, 1, quantization_info };
const UniformQuantizationInfo &qinfo = quantization_info.uniform();
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int i = 0; i < src.num_elements(); ++i)
{
dst[i] = quantize_qasymm8(src[i], qinfo);
@@ -166,6 +177,9 @@ SimpleTensor<int8_t> convert_to_asymmetric(const SimpleTensor<float> &src, const
SimpleTensor<int8_t> dst{ src.shape(), DataType::QASYMM8_SIGNED, 1, quantization_info };
const UniformQuantizationInfo &qinfo = quantization_info.uniform();
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int i = 0; i < src.num_elements(); ++i)
{
dst[i] = quantize_qasymm8_signed(src[i], qinfo);
@@ -179,6 +193,9 @@ SimpleTensor<uint16_t> convert_to_asymmetric(const SimpleTensor<float> &src, con
SimpleTensor<uint16_t> dst{ src.shape(), DataType::QASYMM16, 1, quantization_info };
const UniformQuantizationInfo &qinfo = quantization_info.uniform();
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int i = 0; i < src.num_elements(); ++i)
{
dst[i] = quantize_qasymm16(src[i], qinfo);
@@ -192,6 +209,9 @@ SimpleTensor<int16_t> convert_to_symmetric(const SimpleTensor<float> &src, const
SimpleTensor<int16_t> dst{ src.shape(), DataType::QSYMM16, 1, quantization_info };
const UniformQuantizationInfo &qinfo = quantization_info.uniform();
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int i = 0; i < src.num_elements(); ++i)
{
dst[i] = quantize_qsymm16(src[i], qinfo);
@@ -205,6 +225,9 @@ SimpleTensor<float> convert_from_symmetric(const SimpleTensor<int16_t> &src)
const UniformQuantizationInfo &quantization_info = src.quantization_info().uniform();
SimpleTensor<float> dst{ src.shape(), DataType::F32, 1, QuantizationInfo(), src.data_layout() };
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int i = 0; i < src.num_elements(); ++i)
{
dst[i] = dequantize_qsymm16(src[i], quantization_info);
@@ -223,6 +246,9 @@ void matrix_multiply(const SimpleTensor<T> &a, const SimpleTensor<T> &b, SimpleT
const int N = b.shape()[0]; // Cols
const int K = b.shape()[1];
+#if defined(_OPENMP)
+ #pragma omp parallel for collapse(2)
+#endif /* _OPENMP */
for(int y = 0; y < M; ++y)
{
for(int x = 0; x < N; ++x)
@@ -246,6 +272,9 @@ void transpose_matrix(const SimpleTensor<T> &in, SimpleTensor<T> &out)
const int width = in.shape()[0];
const int height = in.shape()[1];
+#if defined(_OPENMP)
+ #pragma omp parallel for collapse(2)
+#endif /* _OPENMP */
for(int y = 0; y < height; ++y)
{
for(int x = 0; x < width; ++x)
diff --git a/tests/validation/reference/AbsoluteDifference.cpp b/tests/validation/reference/AbsoluteDifference.cpp
index f9fce5b42a..ea7685bc9e 100644
--- a/tests/validation/reference/AbsoluteDifference.cpp
+++ b/tests/validation/reference/AbsoluteDifference.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -40,7 +40,9 @@ SimpleTensor<T> absolute_difference(const SimpleTensor<T> &src1, const SimpleTen
SimpleTensor<T> result(src1.shape(), dst_data_type);
using intermediate_type = typename common_promoted_signed_type<T>::intermediate_type;
-
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int i = 0; i < src1.num_elements(); ++i)
{
intermediate_type val = std::abs(static_cast<intermediate_type>(src1[i]) - static_cast<intermediate_type>(src2[i]));
diff --git a/tests/validation/reference/Accumulate.cpp b/tests/validation/reference/Accumulate.cpp
index 7f34be9663..2758577ef9 100644
--- a/tests/validation/reference/Accumulate.cpp
+++ b/tests/validation/reference/Accumulate.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -42,7 +42,9 @@ SimpleTensor<T2> accumulate(const SimpleTensor<T1> &src, DataType output_data_ty
library->fill_tensor_uniform(dst, 1, static_cast<T2>(0), static_cast<T2>(std::numeric_limits<T1>::max()));
using intermediate_type = typename common_promoted_signed_type<T1, T2>::intermediate_type;
-
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int i = 0; i < src.num_elements(); ++i)
{
intermediate_type val = static_cast<intermediate_type>(src[i]) + static_cast<intermediate_type>(dst[i]);
@@ -62,7 +64,9 @@ SimpleTensor<T2> accumulate_weighted(const SimpleTensor<T1> &src, float alpha, D
library->fill_tensor_uniform(dst, 1, static_cast<T2>(0), static_cast<T2>(std::numeric_limits<T1>::max()));
using intermediate_type = typename common_promoted_signed_type<T1, T2>::intermediate_type;
-
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int i = 0; i < src.num_elements(); ++i)
{
double val = (1. - static_cast<double>(alpha)) * static_cast<intermediate_type>(dst[i]) + static_cast<double>(alpha) * static_cast<intermediate_type>(src[i]);
@@ -83,7 +87,9 @@ SimpleTensor<T2> accumulate_squared(const SimpleTensor<T1> &src, uint32_t shift,
using intermediate_type = typename common_promoted_signed_type<T1, T2>::intermediate_type;
intermediate_type denom = 1 << shift;
-
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int i = 0; i < src.num_elements(); ++i)
{
intermediate_type val = static_cast<intermediate_type>(dst[i]) + (static_cast<intermediate_type>(src[i]) * static_cast<intermediate_type>(src[i]) / denom);
diff --git a/tests/validation/reference/ActivationLayer.cpp b/tests/validation/reference/ActivationLayer.cpp
index 4aa0f880da..4aeefaaa79 100644
--- a/tests/validation/reference/ActivationLayer.cpp
+++ b/tests/validation/reference/ActivationLayer.cpp
@@ -45,7 +45,9 @@ SimpleTensor<T> activation_layer(const SimpleTensor<T> &src, ActivationLayerInfo
// Compute reference
const T a(info.a());
const T b(info.b());
-
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int i = 0; i < src.num_elements(); ++i)
{
dst[i] = activate_float<T>(src[i], a, b, info.activation());
@@ -86,7 +88,6 @@ SimpleTensor<int16_t> activation_layer<int16_t>(const SimpleTensor<int16_t> &src
SimpleTensor<int16_t> dst = convert_to_symmetric<int16_t>(dst_tmp, dst_qinfo);
return dst;
}
-
template SimpleTensor<int32_t> activation_layer(const SimpleTensor<int32_t> &src, ActivationLayerInfo info, const QuantizationInfo &oq_info);
template SimpleTensor<float> activation_layer(const SimpleTensor<float> &src, ActivationLayerInfo info, const QuantizationInfo &oq_info);
template SimpleTensor<half> activation_layer(const SimpleTensor<half> &src, ActivationLayerInfo info, const QuantizationInfo &oq_info);
diff --git a/tests/validation/reference/ArithmeticDivision.cpp b/tests/validation/reference/ArithmeticDivision.cpp
index 0ced439404..f86ee5e599 100644
--- a/tests/validation/reference/ArithmeticDivision.cpp
+++ b/tests/validation/reference/ArithmeticDivision.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 ARM Limited.
+ * Copyright (c) 2018-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -49,13 +49,16 @@ struct BroadcastUnroll
id_src1.set(dim - 1, 0);
id_src2.set(dim - 1, 0);
id_dst.set(dim - 1, 0);
-
- for(size_t i = 0; i < dst.shape()[dim - 1]; ++i, ++id_dst[dim - 1])
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
+ for(size_t i = 0; i < dst.shape()[dim - 1]; ++i)
{
BroadcastUnroll < dim - 1 >::unroll(src1, src2, dst, id_src1, id_src2, id_dst);
id_src1[dim - 1] += !src1_is_broadcast;
id_src2[dim - 1] += !src2_is_broadcast;
+ ++id_dst[dim - 1];
}
}
};
diff --git a/tests/validation/reference/ArithmeticOperations.cpp b/tests/validation/reference/ArithmeticOperations.cpp
index d86833fefa..fd32f45cfa 100644
--- a/tests/validation/reference/ArithmeticOperations.cpp
+++ b/tests/validation/reference/ArithmeticOperations.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -62,13 +62,16 @@ struct BroadcastUnroll
id_src1.set(dim - 1, 0);
id_src2.set(dim - 1, 0);
id_dst.set(dim - 1, 0);
-
- for(size_t i = 0; i < dst.shape()[dim - 1]; ++i, ++id_dst[dim - 1])
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
+ for(size_t i = 0; i < dst.shape()[dim - 1]; ++i)
{
BroadcastUnroll < dim - 1 >::unroll(op, src1, src2, dst, convert_policy, id_src1, id_src2, id_dst);
id_src1[dim - 1] += !src1_is_broadcast;
id_src2[dim - 1] += !src2_is_broadcast;
+ ++id_dst[dim - 1];
}
}
};
diff --git a/tests/validation/reference/BatchNormalizationLayer.cpp b/tests/validation/reference/BatchNormalizationLayer.cpp
index 37713c841d..6623b22e78 100644
--- a/tests/validation/reference/BatchNormalizationLayer.cpp
+++ b/tests/validation/reference/BatchNormalizationLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -46,7 +46,9 @@ SimpleTensor<T> batch_normalization_layer(const SimpleTensor<T> &src, const Simp
const auto rows = static_cast<int>(src.shape()[1]);
const auto depth = static_cast<int>(src.shape()[2]);
const int upper_dims = src.shape().total_size() / (cols * rows * depth);
-
+#if defined(_OPENMP)
+ #pragma omp parallel for schedule(dynamic, 1) collapse(4)
+#endif /* _OPENMP */
for(int r = 0; r < upper_dims; ++r)
{
for(int i = 0; i < depth; ++i)
diff --git a/tests/validation/reference/BitwiseAnd.cpp b/tests/validation/reference/BitwiseAnd.cpp
index 6fc46b402b..356c27ee13 100644
--- a/tests/validation/reference/BitwiseAnd.cpp
+++ b/tests/validation/reference/BitwiseAnd.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -35,7 +35,9 @@ template <typename T>
SimpleTensor<T> bitwise_and(const SimpleTensor<T> &src1, const SimpleTensor<T> &src2)
{
SimpleTensor<T> dst(src1.shape(), src1.data_type());
-
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int i = 0; i < src1.num_elements(); ++i)
{
dst[i] = src1[i] & src2[i];
diff --git a/tests/validation/reference/BitwiseNot.cpp b/tests/validation/reference/BitwiseNot.cpp
index 5a6a13b56c..03578a3beb 100644
--- a/tests/validation/reference/BitwiseNot.cpp
+++ b/tests/validation/reference/BitwiseNot.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -35,7 +35,9 @@ template <typename T>
SimpleTensor<T> bitwise_not(const SimpleTensor<T> &src)
{
SimpleTensor<T> dst(src.shape(), src.data_type());
-
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int i = 0; i < src.num_elements(); ++i)
{
dst[i] = ~src[i];
diff --git a/tests/validation/reference/BitwiseOr.cpp b/tests/validation/reference/BitwiseOr.cpp
index fc258d54f1..11c0a932fe 100644
--- a/tests/validation/reference/BitwiseOr.cpp
+++ b/tests/validation/reference/BitwiseOr.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -35,7 +35,9 @@ template <typename T>
SimpleTensor<T> bitwise_or(const SimpleTensor<T> &src1, const SimpleTensor<T> &src2)
{
SimpleTensor<T> dst(src1.shape(), src1.data_type());
-
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int i = 0; i < src1.num_elements(); ++i)
{
dst[i] = src1[i] | src2[i];
diff --git a/tests/validation/reference/BitwiseXor.cpp b/tests/validation/reference/BitwiseXor.cpp
index b8d275d8b5..afae032b1b 100644
--- a/tests/validation/reference/BitwiseXor.cpp
+++ b/tests/validation/reference/BitwiseXor.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -35,7 +35,9 @@ template <typename T>
SimpleTensor<T> bitwise_xor(const SimpleTensor<T> &src1, const SimpleTensor<T> &src2)
{
SimpleTensor<T> dst(src1.shape(), src1.data_type());
-
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int i = 0; i < src1.num_elements(); ++i)
{
dst[i] = src1[i] ^ src2[i];
diff --git a/tests/validation/reference/BoundingBoxTransform.cpp b/tests/validation/reference/BoundingBoxTransform.cpp
index e09bcff1c6..89182f1a9c 100644
--- a/tests/validation/reference/BoundingBoxTransform.cpp
+++ b/tests/validation/reference/BoundingBoxTransform.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 ARM Limited.
+ * Copyright (c) 2018-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -57,7 +57,9 @@ SimpleTensor<T> bounding_box_transform(const SimpleTensor<T> &boxes, const Simpl
const size_t box_fields = 4;
const size_t class_fields = 4;
-
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(size_t i = 0; i < num_boxes; ++i)
{
// Extract ROI information
diff --git a/tests/validation/reference/Box3x3.cpp b/tests/validation/reference/Box3x3.cpp
index 153f26a5c6..7ea3f1fe99 100644
--- a/tests/validation/reference/Box3x3.cpp
+++ b/tests/validation/reference/Box3x3.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -41,6 +41,9 @@ SimpleTensor<T> box3x3(const SimpleTensor<T> &src, BorderMode border_mode, T con
const std::array<T, 9> filter{ { 1, 1, 1, 1, 1, 1, 1, 1, 1 } };
const float scale = 1.f / static_cast<float>(filter.size());
const uint32_t num_elements = src.num_elements();
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(uint32_t element_idx = 0; element_idx < num_elements; ++element_idx)
{
const Coordinates id = index2coord(src.shape(), element_idx);
diff --git a/tests/validation/reference/ChannelCombine.cpp b/tests/validation/reference/ChannelCombine.cpp
index a6c0557b79..2380b58583 100644
--- a/tests/validation/reference/ChannelCombine.cpp
+++ b/tests/validation/reference/ChannelCombine.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -92,6 +92,9 @@ std::vector<SimpleTensor<T>> channel_combine(const TensorShape &shape, const std
{
std::vector<SimpleTensor<T>> dst = create_image_planes<T>(shape, format);
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(unsigned int plane_idx = 0; plane_idx < dst.size(); ++plane_idx)
{
SimpleTensor<T> &dst_tensor = dst[plane_idx];
diff --git a/tests/validation/reference/ChannelExtract.cpp b/tests/validation/reference/ChannelExtract.cpp
index fc7ae7d6cb..75d0a00604 100644
--- a/tests/validation/reference/ChannelExtract.cpp
+++ b/tests/validation/reference/ChannelExtract.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -51,6 +51,9 @@ SimpleTensor<uint8_t> channel_extract(const TensorShape &shape, const std::vecto
const int height = dst.shape().y();
// Loop over each pixel and extract channel
+#if defined(_OPENMP)
+ #pragma omp parallel for collapse(2)
+#endif /* _OPENMP */
for(int y = 0; y < height; ++y)
{
for(int x = 0; x < width; ++x)
diff --git a/tests/validation/reference/ChannelShuffle.cpp b/tests/validation/reference/ChannelShuffle.cpp
index b8aa9203ab..39d89e9f02 100644
--- a/tests/validation/reference/ChannelShuffle.cpp
+++ b/tests/validation/reference/ChannelShuffle.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -50,7 +50,9 @@ SimpleTensor<T> channel_shuffle(const SimpleTensor<T> &src, int num_groups)
const T *src_ref = src.data();
T *dst_ref = dst.data();
-
+#if defined(_OPENMP)
+ #pragma omp parallel for collapse(2)
+#endif /* _OPENMP */
for(int n = 0; n < batches; ++n)
{
for(int g = 0; g < num_groups; ++g)
diff --git a/tests/validation/reference/Col2Im.cpp b/tests/validation/reference/Col2Im.cpp
index 53969d4725..f42582bbe8 100644
--- a/tests/validation/reference/Col2Im.cpp
+++ b/tests/validation/reference/Col2Im.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -47,21 +47,26 @@ SimpleTensor<T> col2im(const SimpleTensor<T> &src, const TensorShape &dst_shape,
if(num_groups == 1)
{
// Batches are on the 3rd dimension of the input tensor
- int dst_idx = 0;
+#if defined(_OPENMP)
+ #pragma omp parallel for collapse(3)
+#endif /* _OPENMP */
for(size_t b = 0; b < batches; ++b)
{
for(size_t x = 0; x < src_width; ++x)
{
for(size_t y = 0; y < src_height; ++y)
{
- dst[dst_idx++] = src[coord2index(src.shape(), Coordinates(x, y, b))];
+ const int dst_idx = y + x * src_height + b * src_height * src_width;
+ dst[dst_idx] = src[coord2index(src.shape(), Coordinates(x, y, b))];
}
}
}
}
else
{
- int dst_idx = 0;
+#if defined(_OPENMP)
+ #pragma omp parallel for collapse(4)
+#endif /* _OPENMP */
for(size_t b = 0; b < batches; ++b)
{
for(size_t g = 0; g < num_groups; ++g)
@@ -70,7 +75,8 @@ SimpleTensor<T> col2im(const SimpleTensor<T> &src, const TensorShape &dst_shape,
{
for(size_t y = 0; y < src_height; ++y)
{
- dst[dst_idx++] = src[coord2index(src.shape(), Coordinates(x, y, g, b))];
+ const int dst_idx = y + x * src_height + g * src_height * src_width + b * src_height * src_width * num_groups;
+ dst[dst_idx] = src[coord2index(src.shape(), Coordinates(x, y, g, b))];
}
}
}
diff --git a/tests/validation/reference/Comparisons.cpp b/tests/validation/reference/Comparisons.cpp
index c0c86b1933..2313d9b022 100644
--- a/tests/validation/reference/Comparisons.cpp
+++ b/tests/validation/reference/Comparisons.cpp
@@ -80,13 +80,16 @@ struct BroadcastUnroll
id_src1.set(dim - 1, 0);
id_src2.set(dim - 1, 0);
id_dst.set(dim - 1, 0);
-
- for(size_t i = 0; i < dst.shape()[dim - 1]; ++i, ++id_dst[dim - 1])
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
+ for(size_t i = 0; i < dst.shape()[dim - 1]; ++i)
{
BroadcastUnroll < dim - 1 >::unroll(op, src1, src2, dst, id_src1, id_src2, id_dst);
id_src1[dim - 1] += !src1_is_broadcast;
id_src2[dim - 1] += !src2_is_broadcast;
+ ++id_dst[dim - 1];
}
}
};
diff --git a/tests/validation/reference/ComputeAllAnchors.cpp b/tests/validation/reference/ComputeAllAnchors.cpp
index 60be7ef8a8..9654da2ea3 100644
--- a/tests/validation/reference/ComputeAllAnchors.cpp
+++ b/tests/validation/reference/ComputeAllAnchors.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2019-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -49,6 +49,9 @@ SimpleTensor<T> compute_all_anchors(const SimpleTensor<T> &anchors, const Comput
T *all_anchors_ptr = all_anchors.data();
// Iterate over the input grid and anchors
+#if defined(_OPENMP)
+ #pragma omp parallel for schedule(dynamic, 1) collapse(3)
+#endif /* _OPENMP */
for(int y = 0; y < height; y++)
{
for(int x = 0; x < width; x++)
diff --git a/tests/validation/reference/ConvertFullyConnectedWeights.cpp b/tests/validation/reference/ConvertFullyConnectedWeights.cpp
index 5925496f45..710644a4e5 100644
--- a/tests/validation/reference/ConvertFullyConnectedWeights.cpp
+++ b/tests/validation/reference/ConvertFullyConnectedWeights.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 ARM Limited.
+ * Copyright (c) 2018-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -49,6 +49,9 @@ SimpleTensor<T> convert_fully_connected_weights(const SimpleTensor<T> &src, cons
const unsigned int factor_2 = is_nchw_to_nhwc ? num_channels : num_elems_per_input_plane;
const uint32_t num_elements = src.num_elements();
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(uint32_t i = 0; i < num_elements; ++i)
{
const Coordinates coords_in = index2coords(src.shape(), i);
diff --git a/tests/validation/reference/Convolution.cpp b/tests/validation/reference/Convolution.cpp
index 5e0f29d421..ad93b3077a 100644
--- a/tests/validation/reference/Convolution.cpp
+++ b/tests/validation/reference/Convolution.cpp
@@ -45,6 +45,9 @@ SimpleTensor<T> convolution(const SimpleTensor<uint8_t> &src, DataType output_da
SimpleTensor<T> dst(src.shape(), output_data_type);
SimpleTensor<int32_t> sum(src.shape(), output_data_type);
const uint32_t num_elements = src.num_elements();
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(uint32_t element_idx = 0; element_idx < num_elements; ++element_idx)
{
const Coordinates id = index2coord(src.shape(), element_idx);
diff --git a/tests/validation/reference/ConvolutionLayer.cpp b/tests/validation/reference/ConvolutionLayer.cpp
index 84fb3491bd..aed976e5ee 100644
--- a/tests/validation/reference/ConvolutionLayer.cpp
+++ b/tests/validation/reference/ConvolutionLayer.cpp
@@ -69,6 +69,9 @@ SimpleTensor<T> convolution_layer_nchw(const SimpleTensor<T> &src, const SimpleT
const int end_xi = output_wh.first * stride_xi;
const int end_yi = output_wh.second * stride_yi;
const int num_batches = src.shape().total_size() / (width_in * height_in * depth_in);
+#if defined(_OPENMP)
+ #pragma omp parallel for collapse(5)
+#endif /* _OPENMP */
for(int r = 0; r < num_batches; ++r)
{
for(int yi = start_yi; yi < start_yi + end_yi; yi += stride_yi)
diff --git a/tests/validation/reference/DFT.cpp b/tests/validation/reference/DFT.cpp
index b3c2c6b0b9..ae030c7104 100644
--- a/tests/validation/reference/DFT.cpp
+++ b/tests/validation/reference/DFT.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2019-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -50,6 +50,9 @@ namespace
template <typename T>
void rdft_1d_step(const T *src_ptr, size_t N, T *dst_ptr, size_t K)
{
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(unsigned int k = 0; k < K; ++k)
{
float Xr = 0;
@@ -77,6 +80,9 @@ void rdft_1d_step(const T *src_ptr, size_t N, T *dst_ptr, size_t K)
template <typename T>
void dft_1d_step(const T *src_ptr, T *dst_ptr, size_t N)
{
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(unsigned int k = 0; k < N; ++k)
{
float Xr = 0;
@@ -111,7 +117,9 @@ void irdft_1d_step(const T *src_ptr, size_t K, T *dst_ptr, size_t N)
const bool is_odd = N % 2;
const unsigned int Nleft = N - K;
const int tail_start = is_odd ? K - 1 : K - 2;
-
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(unsigned int n = 0; n < N; ++n)
{
float xr = 0;
@@ -142,6 +150,9 @@ void irdft_1d_step(const T *src_ptr, size_t K, T *dst_ptr, size_t N)
template <typename T>
void idft_1d_step(const T *src_ptr, T *dst_ptr, size_t N)
{
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(unsigned int n = 0; n < N; ++n)
{
float xr = 0;
@@ -181,6 +192,9 @@ SimpleTensor<T> rdft_1d_core(const SimpleTensor<T> &src, FFTDirection direction,
SimpleTensor<T> dst(dst_shape, src.data_type(), num_channels);
const unsigned int upper_dims = src.shape().total_size_upper(1);
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(unsigned int du = 0; du < upper_dims; ++du)
{
const T *src_row_ptr = src.data() + du * N * src.num_channels();
@@ -201,6 +215,9 @@ SimpleTensor<T> dft_1d_core(const SimpleTensor<T> &src, FFTDirection direction)
SimpleTensor<T> dst(src.shape(), src.data_type(), src.num_channels());
const unsigned int upper_dims = src.shape().total_size_upper(1);
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(unsigned int du = 0; du < upper_dims; ++du)
{
const T *src_row_ptr = src.data() + du * N * src.num_channels();
@@ -221,6 +238,9 @@ void scale(SimpleTensor<T> &tensor, T scaling_factor)
{
const int total_elements = tensor.num_elements() * tensor.num_channels();
T *data_ptr = tensor.data();
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int i = 0; i < total_elements; ++i)
{
data_ptr[i] /= scaling_factor;
@@ -249,7 +269,9 @@ SimpleTensor<T> complex_mul_and_reduce(const SimpleTensor<T> &input, const Simpl
// MemSet dst memory to zero
std::memset(dst.data(), 0, dst.size());
-
+#if defined(_OPENMP)
+ #pragma omp parallel for collapse(5)
+#endif /* _OPENMP */
for(uint32_t b = 0; b < N; ++b)
{
for(uint32_t co = 0; co < Co; ++co)
diff --git a/tests/validation/reference/DeconvolutionLayer.cpp b/tests/validation/reference/DeconvolutionLayer.cpp
index 5750f51e3f..01b9c1c403 100644
--- a/tests/validation/reference/DeconvolutionLayer.cpp
+++ b/tests/validation/reference/DeconvolutionLayer.cpp
@@ -100,6 +100,9 @@ SimpleTensor<T> deconvolution_layer(const SimpleTensor<T> &src, const SimpleTens
// Flip weights by 180 degrees
SimpleTensor<T> weights_flipped{ weights.shape(), weights.data_type(), 1, weights.quantization_info() };
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int ud = 0; ud < weights_upper_dims; ++ud)
{
const int offset = ud * weights_width * weights_height;
@@ -111,7 +114,9 @@ SimpleTensor<T> deconvolution_layer(const SimpleTensor<T> &src, const SimpleTens
}
}
}
-
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int slice = 0; slice < num_2d_slices; ++slice)
{
const int offset_slice_in = slice * width_in * height_in;
diff --git a/tests/validation/reference/DepthConcatenateLayer.cpp b/tests/validation/reference/DepthConcatenateLayer.cpp
index d6e6e78187..2c93e7060a 100644
--- a/tests/validation/reference/DepthConcatenateLayer.cpp
+++ b/tests/validation/reference/DepthConcatenateLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -58,6 +58,9 @@ SimpleTensor<T> depthconcatenate_layer(const std::vector<SimpleTensor<T>> &srcs,
if(srcs[0].data_type() == DataType::QASYMM8 && std::any_of(srcs.cbegin(), srcs.cend(), have_different_quantization_info))
{
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int b = 0; b < batches; ++b)
{
// input tensors can have smaller width and height than the output, so for each output's slice we need to requantize 0 (as this is the value
diff --git a/tests/validation/reference/DepthConvertLayer.cpp b/tests/validation/reference/DepthConvertLayer.cpp
index 57eeb7f6f3..9c6a9aa76c 100644
--- a/tests/validation/reference/DepthConvertLayer.cpp
+++ b/tests/validation/reference/DepthConvertLayer.cpp
@@ -46,6 +46,9 @@ SimpleTensor<T2> depth_convert(const SimpleTensor<T1> &src, DataType dt_out, Con
// Up-casting
if(element_size_from_data_type(src.data_type()) < element_size_from_data_type(dt_out))
{
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int i = 0; i < src.num_elements(); ++i)
{
result[i] = src[i] << shift;
@@ -54,6 +57,9 @@ SimpleTensor<T2> depth_convert(const SimpleTensor<T1> &src, DataType dt_out, Con
// Down-casting
else
{
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int i = 0; i < src.num_elements(); ++i)
{
T1 val = src[i] >> shift;
diff --git a/tests/validation/reference/DepthToSpaceLayer.cpp b/tests/validation/reference/DepthToSpaceLayer.cpp
index 4135ce5471..e2329ed60b 100644
--- a/tests/validation/reference/DepthToSpaceLayer.cpp
+++ b/tests/validation/reference/DepthToSpaceLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2019-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -40,13 +40,14 @@ SimpleTensor<T> depth_to_space(const SimpleTensor<T> &src, const TensorShape &ds
ARM_COMPUTE_ERROR_ON(block_shape <= 0);
SimpleTensor<T> result(dst_shape, src.data_type());
- int in_pos = 0;
const auto width_in = static_cast<int>(src.shape()[0]);
const auto height_in = static_cast<int>(src.shape()[1]);
const auto channel_in = static_cast<int>(src.shape()[2]);
const auto batch_in = static_cast<int>(src.shape()[3]);
const int r = channel_in / (block_shape * block_shape);
-
+#if defined(_OPENMP)
+ #pragma omp parallel for collapse(4)
+#endif /* _OPENMP */
for(int b = 0; b < batch_in; ++b)
{
for(int z = 0; z < channel_in; ++z)
@@ -58,8 +59,8 @@ SimpleTensor<T> depth_to_space(const SimpleTensor<T> &src, const TensorShape &ds
const int out_x = (block_shape * x + (z / r) % block_shape);
const int out_y = (block_shape * y + (z / r) / block_shape);
const int out_pos = out_x + dst_shape[0] * out_y + (z % r) * dst_shape[0] * dst_shape[1] + b * dst_shape[0] * dst_shape[1] * dst_shape[2];
+ const int in_pos = x + width_in * y + z * width_in * height_in + b * width_in * height_in * channel_in;
result[out_pos] = src[in_pos];
- ++in_pos;
}
}
}
diff --git a/tests/validation/reference/DequantizationLayer.cpp b/tests/validation/reference/DequantizationLayer.cpp
index 7dd36402b3..7dec988b18 100644
--- a/tests/validation/reference/DequantizationLayer.cpp
+++ b/tests/validation/reference/DequantizationLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -75,7 +75,9 @@ SimpleTensor<TOut> dequantization_layer(const SimpleTensor<TIn> &src)
const int N = src.shape().total_size() / (WH * C);
const std::vector<float> qscales = src.quantization_info().scale();
-
+#if defined(_OPENMP)
+ #pragma omp parallel for collapse(2)
+#endif /* _OPENMP */
for(int n = 0; n < N; ++n)
{
for(int c = 0; c < C; ++c)
@@ -95,7 +97,9 @@ SimpleTensor<TOut> dequantization_layer(const SimpleTensor<TIn> &src)
{
const UniformQuantizationInfo &quantization_info = src.quantization_info().uniform();
ARM_COMPUTE_ERROR_ON(quantization_info.offset != 0 && src_data_type == DataType::QSYMM8);
-
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int i = 0; i < src.num_elements(); ++i)
{
dst[i] = static_cast<TOut>(dequantize<TOut>(static_cast<TIn>(src[i]), quantization_info, src_data_type));
diff --git a/tests/validation/reference/Derivative.cpp b/tests/validation/reference/Derivative.cpp
index 3c6f3259b2..f4c2934728 100644
--- a/tests/validation/reference/Derivative.cpp
+++ b/tests/validation/reference/Derivative.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -62,6 +62,9 @@ std::pair<SimpleTensor<T>, SimpleTensor<T>> derivative(const SimpleTensor<U> &sr
ValidRegion valid_region = shape_to_valid_region(src.shape(), border_mode == BorderMode::UNDEFINED, BorderSize(filter_size / 2));
const uint32_t num_elements = src.num_elements();
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(uint32_t i = 0; i < num_elements; ++i)
{
Coordinates coord = index2coord(src.shape(), i);
diff --git a/tests/validation/reference/Dilate.cpp b/tests/validation/reference/Dilate.cpp
index 8e244e9b7b..cba9af127f 100644
--- a/tests/validation/reference/Dilate.cpp
+++ b/tests/validation/reference/Dilate.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -52,6 +52,9 @@ SimpleTensor<T> dilate(const SimpleTensor<T> &src, BorderMode border_mode, T con
SimpleTensor<T> dst(src.shape(), src.data_type());
const uint32_t num_elements = src.num_elements();
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(uint32_t i = 0; i < num_elements; ++i)
{
Coordinates coord = index2coord(src.shape(), i);
diff --git a/tests/validation/reference/ElementWiseUnary.cpp b/tests/validation/reference/ElementWiseUnary.cpp
index eaaaa4ec1e..f1bb7c783c 100644
--- a/tests/validation/reference/ElementWiseUnary.cpp
+++ b/tests/validation/reference/ElementWiseUnary.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 ARM Limited.
+ * Copyright (c) 2018-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
diff --git a/tests/validation/reference/EqualizeHistogram.cpp b/tests/validation/reference/EqualizeHistogram.cpp
index 1a10c2c30a..34e7c397bf 100644
--- a/tests/validation/reference/EqualizeHistogram.cpp
+++ b/tests/validation/reference/EqualizeHistogram.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -75,6 +75,9 @@ SimpleTensor<T> equalize_histogram(const SimpleTensor<T> &src)
}
// Fill output tensor with equalized values
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int i = 0; i < src.num_elements(); ++i)
{
dst[i] = lut[src[i]];
diff --git a/tests/validation/reference/Erode.cpp b/tests/validation/reference/Erode.cpp
index e7598c3900..0964c3d4b2 100644
--- a/tests/validation/reference/Erode.cpp
+++ b/tests/validation/reference/Erode.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -52,6 +52,9 @@ SimpleTensor<T> erode(const SimpleTensor<T> &src, BorderMode border_mode, T cons
SimpleTensor<T> dst(src.shape(), src.data_type());
const uint32_t num_elements = src.num_elements();
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(uint32_t i = 0; i < num_elements; ++i)
{
Coordinates coord = index2coord(src.shape(), i);
diff --git a/tests/validation/reference/Floor.cpp b/tests/validation/reference/Floor.cpp
index b011a16974..21fa1c9932 100644
--- a/tests/validation/reference/Floor.cpp
+++ b/tests/validation/reference/Floor.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -42,6 +42,9 @@ SimpleTensor<T> floor_layer(const SimpleTensor<T> &src)
SimpleTensor<T> dst{ src.shape(), src.data_type() };
// Compute reference
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int i = 0; i < src.num_elements(); ++i)
{
dst[i] = std::floor(src[i]);
diff --git a/tests/validation/reference/FullyConnectedLayer.cpp b/tests/validation/reference/FullyConnectedLayer.cpp
index 9aecd6cf14..908c583161 100644
--- a/tests/validation/reference/FullyConnectedLayer.cpp
+++ b/tests/validation/reference/FullyConnectedLayer.cpp
@@ -49,11 +49,12 @@ void vector_matrix_multiply(const SimpleTensor<T> &src, const SimpleTensor<T> &w
const T *weights_ptr = weights.data();
const TB *bias_ptr = bias.data();
T *dst_ptr = dst.data() + offset_dst;
-
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int y = 0; y < rows_weights; ++y)
{
- dst_ptr[y] = std::inner_product(src_ptr, src_ptr + cols_weights, weights_ptr, static_cast<T>(0)) + bias_ptr[y];
- weights_ptr += cols_weights;
+ dst_ptr[y] = std::inner_product(src_ptr, src_ptr + cols_weights, &weights_ptr[cols_weights * y], static_cast<T>(0)) + bias_ptr[y];
}
}
@@ -85,7 +86,9 @@ void vector_matrix_multiply(const SimpleTensor<T> &src, const SimpleTensor<T> &w
const int min = std::numeric_limits<T>::lowest();
const int max = std::numeric_limits<T>::max();
-
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int y = 0; y < rows_weights; ++y)
{
// Reset accumulator
@@ -93,7 +96,7 @@ void vector_matrix_multiply(const SimpleTensor<T> &src, const SimpleTensor<T> &w
for(int x = 0; x < cols_weights; ++x)
{
- acc += (src_ptr[x] + input_offset) * (weights_ptr[x] + weights_offset);
+ acc += (src_ptr[x] + input_offset) * (weights_ptr[x + y * cols_weights] + weights_offset);
}
// Accumulate the bias
@@ -104,8 +107,6 @@ void vector_matrix_multiply(const SimpleTensor<T> &src, const SimpleTensor<T> &w
// Store the result
dst_ptr[y] = static_cast<T>(acc);
-
- weights_ptr += cols_weights;
}
}
} // namespace
diff --git a/tests/validation/reference/FuseBatchNormalization.cpp b/tests/validation/reference/FuseBatchNormalization.cpp
index df12b25912..cb5003874b 100644
--- a/tests/validation/reference/FuseBatchNormalization.cpp
+++ b/tests/validation/reference/FuseBatchNormalization.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2019-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -22,6 +22,7 @@
* SOFTWARE.
*/
#include "FuseBatchNormalization.h"
+#include "tests/validation/Helpers.h"
namespace arm_compute
{
@@ -45,6 +46,9 @@ void fuse_batch_normalization_dwc_layer(const SimpleTensor<T> &w, const SimpleTe
const unsigned int height = w.shape()[1];
const unsigned int dim2 = w.shape()[2];
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(unsigned int b = 0; b < dim2; ++b)
{
const auto mean_val = mean.data()[b];
diff --git a/tests/validation/reference/GEMM.cpp b/tests/validation/reference/GEMM.cpp
index 3c72b94143..20def87a64 100644
--- a/tests/validation/reference/GEMM.cpp
+++ b/tests/validation/reference/GEMM.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -23,6 +23,7 @@
*/
#include "GEMM.h"
+#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/Types.h"
namespace arm_compute
@@ -55,6 +56,9 @@ SimpleTensor<T> gemm(const SimpleTensor<T> &a, const SimpleTensor<T> &b, const S
const int c_stride_z = N * M;
const int c_stride_w = N * M * D;
+#if defined(_OPENMP)
+ #pragma omp parallel for collapse(2)
+#endif /* _OPENMP */
for(int w = 0; w < W; ++w)
{
for(int depth = 0; depth < D; ++depth)
@@ -107,6 +111,9 @@ SimpleTensor<T> gemm_mixed_precision(const SimpleTensor<T> &a, const SimpleTenso
const int c_stride_z = N * M;
const int c_stride_w = N * M * D;
+#if defined(_OPENMP)
+ #pragma omp parallel for collapse(2)
+#endif /* _OPENMP */
for(int w = 0; w < W; ++w)
{
for(int depth = 0; depth < D; ++depth)
diff --git a/tests/validation/reference/GEMMLowp.cpp b/tests/validation/reference/GEMMLowp.cpp
index 36d86d1532..85a98e4a76 100644
--- a/tests/validation/reference/GEMMLowp.cpp
+++ b/tests/validation/reference/GEMMLowp.cpp
@@ -69,6 +69,9 @@ void quantize_down_scale(const SimpleTensor<TIn> *in, const SimpleTensor<TIn> *b
const int cols_in = in->shape().x();
const bool is_per_channel = result_mult_int.size() > 1;
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int i = 0; i < in->num_elements(); ++i)
{
int32_t result = ((*in)[i] + result_offset);
@@ -100,6 +103,9 @@ void quantize_down_scale_by_fixedpoint(const SimpleTensor<TIn> *in, const Simple
const int cols_in = in->shape().x();
const bool is_per_channel = result_fixedpoint_multiplier.size() > 1;
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int i = 0; i < in->num_elements(); ++i)
{
TIn result = (*in)[i];
@@ -141,6 +147,9 @@ void quantize_down_scale_by_float(const SimpleTensor<TIn> *in, const SimpleTenso
const int cols_in = in->shape().x();
const bool is_per_channel = result_real_multiplier.size() > 1;
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int i = 0; i < in->num_elements(); ++i)
{
TIn result = (*in)[i];
diff --git a/tests/validation/reference/GEMMReshapeLHSMatrix.cpp b/tests/validation/reference/GEMMReshapeLHSMatrix.cpp
index 431d65696e..f21fe50e58 100644
--- a/tests/validation/reference/GEMMReshapeLHSMatrix.cpp
+++ b/tests/validation/reference/GEMMReshapeLHSMatrix.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
diff --git a/tests/validation/reference/GEMMReshapeRHSMatrix.cpp b/tests/validation/reference/GEMMReshapeRHSMatrix.cpp
index 0224c5c67c..ebb6f856d2 100644
--- a/tests/validation/reference/GEMMReshapeRHSMatrix.cpp
+++ b/tests/validation/reference/GEMMReshapeRHSMatrix.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -70,7 +70,9 @@ SimpleTensor<T> gemm_reshape_rhs_matrix(const SimpleTensor<T> &in, const TensorS
const unsigned int offset_output_x = rhs_info.interleave ? tile_to_use->shape()[0] : tile_to_use->shape()[0] * tile_to_use->shape()[1];
const unsigned int step_output_x = rhs_info.interleave ? tile_to_use->shape()[0] * rhs_info.h0 : tile_to_use->shape()[0];
-
+#ifdef ARM_COMPUTE_OPENMP
+ #pragma omp parallel for schedule(dynamic, 1) collapse(3)
+#endif /* _OPENMP */
for(unsigned int z = 0; z < B; ++z)
{
for(unsigned int y = 0; y < num_tiles_y; ++y)
diff --git a/tests/validation/reference/Gaussian3x3.cpp b/tests/validation/reference/Gaussian3x3.cpp
index 5ca24a7961..f2ac134f17 100644
--- a/tests/validation/reference/Gaussian3x3.cpp
+++ b/tests/validation/reference/Gaussian3x3.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -42,6 +42,9 @@ SimpleTensor<T> gaussian3x3(const SimpleTensor<T> &src, BorderMode border_mode,
const float scale = 1.f / 16.f;
const uint32_t num_elements = src.num_elements();
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(uint32_t element_idx = 0; element_idx < num_elements; ++element_idx)
{
const Coordinates id = index2coord(src.shape(), element_idx);
diff --git a/tests/validation/reference/Gaussian5x5.cpp b/tests/validation/reference/Gaussian5x5.cpp
index ac84f6d097..426e66647c 100644
--- a/tests/validation/reference/Gaussian5x5.cpp
+++ b/tests/validation/reference/Gaussian5x5.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -48,6 +48,9 @@ SimpleTensor<T> gaussian5x5(const SimpleTensor<T> &src, BorderMode border_mode,
const float scale = 1.f / 256.f;
const uint32_t num_elements = src.num_elements();
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(uint32_t element_idx = 0; element_idx < num_elements; ++element_idx)
{
const Coordinates id = index2coord(src.shape(), element_idx);
diff --git a/tests/validation/reference/Im2Col.cpp b/tests/validation/reference/Im2Col.cpp
index 4b41cdb70b..a3dcf07273 100644
--- a/tests/validation/reference/Im2Col.cpp
+++ b/tests/validation/reference/Im2Col.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 ARM Limited.
+ * Copyright (c) 2018-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -109,7 +109,9 @@ void im2col_nhwc(const SimpleTensor<T> &src, SimpleTensor<T> &dst, const Size2D
// Compute width and height of the convolved tensors
std::pair<unsigned int, unsigned int> convolved_dims = scaled_dimensions(src_width, src_height, kernel_dims.width, kernel_dims.height, conv_info);
-
+#if defined(_OPENMP)
+ #pragma omp parallel for schedule(dynamic, 1) collapse(2)
+#endif /* _OPENMP */
for(int b = 0; b < batches; ++b)
{
for(int yo = 0; yo < dst_height; ++yo)
diff --git a/tests/validation/reference/InstanceNormalizationLayer.cpp b/tests/validation/reference/InstanceNormalizationLayer.cpp
index ad0ac1be68..339549723d 100644
--- a/tests/validation/reference/InstanceNormalizationLayer.cpp
+++ b/tests/validation/reference/InstanceNormalizationLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2019-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -46,12 +46,14 @@ SimpleTensor<T> instance_normalization(const SimpleTensor<T> &src, float gamma,
const size_t h_size = src.shape()[1];
const size_t c_size = src.shape()[2];
const size_t n_size = src.shape()[3];
-
+#if defined(_OPENMP)
+ #pragma omp parallel for collapse(2)
+#endif /* _OPENMP */
for(size_t n_i = 0; n_i < n_size; ++n_i)
{
for(size_t c_i = 0; c_i < c_size; ++c_i)
{
- float sum_h_w = 0;
+ float sum_h_w = 0;
float sum_sq_h_w = 0;
for(size_t h_i = 0; h_i < h_size; ++h_i)
@@ -60,13 +62,14 @@ SimpleTensor<T> instance_normalization(const SimpleTensor<T> &src, float gamma,
{
float val = src[coord2index(src.shape(), Coordinates(w_i, h_i, c_i, n_i))];
sum_h_w += val;
- sum_sq_h_w += val*val;
+ sum_sq_h_w += val * val;
}
}
//Compute mean
const float mean_h_w = sum_h_w / (h_size * w_size);
//Compute variance
- const float var_h_w = sum_sq_h_w / (h_size * w_size) - mean_h_w * mean_h_w;;
+ const float var_h_w = sum_sq_h_w / (h_size * w_size) - mean_h_w * mean_h_w;
+ ;
//Apply mean
for(size_t h_i = 0; h_i < h_size; ++h_i)
diff --git a/tests/validation/reference/QuantizationLayer.cpp b/tests/validation/reference/QuantizationLayer.cpp
index cfc508529e..a70523d7da 100644
--- a/tests/validation/reference/QuantizationLayer.cpp
+++ b/tests/validation/reference/QuantizationLayer.cpp
@@ -50,12 +50,18 @@ SimpleTensor<Tout> quantization_layer(const SimpleTensor<Tin> &src, DataType out
switch(output_data_type)
{
case DataType::QASYMM8:
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int i = 0; i < src.num_elements(); ++i)
{
dst[i] = quantize_qasymm8((src[i]), qinfo, rounding_policy);
}
break;
case DataType::QASYMM8_SIGNED:
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int i = 0; i < src.num_elements(); ++i)
{
#ifdef __aarch64__
@@ -66,6 +72,9 @@ SimpleTensor<Tout> quantization_layer(const SimpleTensor<Tin> &src, DataType out
}
break;
case DataType::QASYMM16:
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int i = 0; i < src.num_elements(); ++i)
{
dst[i] = quantize_qasymm16((src[i]), qinfo, rounding_policy);
diff --git a/tests/validation/reference/ReorgLayer.cpp b/tests/validation/reference/ReorgLayer.cpp
index 2eb5d01926..9f087d06cb 100644
--- a/tests/validation/reference/ReorgLayer.cpp
+++ b/tests/validation/reference/ReorgLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -54,6 +54,10 @@ SimpleTensor<T> reorg_layer(const SimpleTensor<T> &src, int32_t stride)
// Calculate layer reorg in NCHW
Coordinates map_coords;
+
+#if defined(_OPENMP)
+ #pragma omp parallel for private(map_coords)
+#endif /* _OPENMP */
for(unsigned int b = 0; b < outer_dims; ++b)
{
map_coords.set(3, b);
diff --git a/tests/validation/reference/Reverse.cpp b/tests/validation/reference/Reverse.cpp
index 4bd8efc6a8..f5630b9a40 100644
--- a/tests/validation/reference/Reverse.cpp
+++ b/tests/validation/reference/Reverse.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 ARM Limited.
+ * Copyright (c) 2018-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -55,6 +55,10 @@ SimpleTensor<T> reverse(const SimpleTensor<T> &src, const SimpleTensor<uint32_t>
}
const uint32_t num_elements = src.num_elements();
+
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(uint32_t i = 0; i < num_elements; ++i)
{
const Coordinates src_coord = index2coord(src.shape(), i);
diff --git a/tests/validation/reference/SoftmaxLayer.cpp b/tests/validation/reference/SoftmaxLayer.cpp
index 0e470260a9..ee7a5f175a 100644
--- a/tests/validation/reference/SoftmaxLayer.cpp
+++ b/tests/validation/reference/SoftmaxLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -55,6 +55,9 @@ SimpleTensor<T> softmax_layer_generic(const SimpleTensor<T> &src, float beta, si
upper_dims *= src.shape()[i];
}
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(int r = 0; r < upper_dims; ++r)
{
const T *src_row_ptr = src.data() + r * lower_dims;
@@ -107,7 +110,7 @@ SimpleTensor<T> softmax_layer(const SimpleTensor<T> &src, float beta, size_t axi
return softmax_layer_generic<T>(src, beta, axis, false);
}
-template <typename T, typename std::enable_if<std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value, int>::type>
+template < typename T, typename std::enable_if < std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value, int >::type >
SimpleTensor<T> softmax_layer(const SimpleTensor<T> &src, float beta, size_t axis)
{
const QuantizationInfo output_quantization_info = arm_compute::get_softmax_output_quantization_info(src.data_type(), false);
diff --git a/tests/validation/reference/Winograd.cpp b/tests/validation/reference/Winograd.cpp
index 47f5ac7a7d..61ba510fc6 100644
--- a/tests/validation/reference/Winograd.cpp
+++ b/tests/validation/reference/Winograd.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 ARM Limited.
+ * Copyright (c) 2018-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
diff --git a/tests/validation/reference/YOLOLayer.cpp b/tests/validation/reference/YOLOLayer.cpp
index 0011b85c1e..92bbf5445f 100644
--- a/tests/validation/reference/YOLOLayer.cpp
+++ b/tests/validation/reference/YOLOLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 ARM Limited.
+ * Copyright (c) 2018-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -47,6 +47,9 @@ SimpleTensor<T> yolo_layer(const SimpleTensor<T> &src, const ActivationLayerInfo
const T b(info.b());
const uint32_t num_elements = src.num_elements();
+#if defined(_OPENMP)
+ #pragma omp parallel for
+#endif /* _OPENMP */
for(uint32_t i = 0; i < num_elements; ++i)
{
const size_t z = index2coord(dst.shape(), i).z() % (num_classes + 5);