aboutsummaryrefslogtreecommitdiff
path: root/tests/validation
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation')
-rw-r--r--tests/validation/fixtures/DirectConvolution3DFixture.h5
-rw-r--r--tests/validation/reference/Conv3D.cpp24
-rw-r--r--tests/validation/reference/Conv3D.h10
3 files changed, 20 insertions, 19 deletions
diff --git a/tests/validation/fixtures/DirectConvolution3DFixture.h b/tests/validation/fixtures/DirectConvolution3DFixture.h
index e80ad2f54f..e27a41a23b 100644
--- a/tests/validation/fixtures/DirectConvolution3DFixture.h
+++ b/tests/validation/fixtures/DirectConvolution3DFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021, 2023 Arm Limited.
+ * Copyright (c) 2021, 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -46,6 +46,7 @@ class DirectConvolution3DValidationGenericFixture : public framework::Fixture
{
public:
using TBias = typename std::conditional < std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value, int32_t, T >::type;
+ using TAcc = typename std::conditional < std::is_integral<T>::value, int32_t, float >::type;
void setup(const TensorShape &input_shape, int stride_x, int stride_y, int stride_z, int pad_x, int pad_y, int pad_z, unsigned int kernel_width, int kernel_height, int kernel_depth,
unsigned int num_kernels, bool has_bias, const ActivationLayerInfo &act_info, const DataType &data_type, const DataLayout &data_layout,
@@ -150,7 +151,7 @@ protected:
fill(bias, 2);
}
- return reference::activation_layer(reference::conv3d<T, TBias>(src, weights, bias, dst, conv3d_info), conv3d_info.act_info);
+ return reference::activation_layer(reference::conv3d<T, TBias, TAcc>(src, weights, bias, dst, conv3d_info), conv3d_info.act_info);
}
TensorType _target{};
diff --git a/tests/validation/reference/Conv3D.cpp b/tests/validation/reference/Conv3D.cpp
index e4010a507a..38472a9aec 100644
--- a/tests/validation/reference/Conv3D.cpp
+++ b/tests/validation/reference/Conv3D.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021, 2023 Arm Limited.
+ * Copyright (c) 2021, 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -58,7 +58,7 @@ inline bool is_valid_pixel(int i, int min, int max)
}
// Evaluate the weights against an element in a given tensor.
-template < typename T, typename TB, typename std::enable_if < validation::is_floating_point<T>::value &&validation::is_floating_point<TB>::value, int >::type = 0 >
+template < typename T, typename TB, typename TACC, typename std::enable_if < validation::is_floating_point<T>::value &&validation::is_floating_point<TB>::value, int >::type = 0 >
T calculate_conv3d(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &bias, const Size3D &dilation, int batch,
int z_start, int y_start, int x_start, int ch_out, UniformQuantizationInfo oq_info)
{
@@ -73,7 +73,7 @@ T calculate_conv3d(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, c
const unsigned int src_height = src.shape()[height_dim];
const unsigned int src_depth = src.shape()[depth_dim];
- T total(0);
+ TACC total(0);
for(unsigned int weight_d = 0; weight_d < weights_depth; ++weight_d)
{
const int idx_z = z_start + dilation.depth * weight_d;
@@ -112,10 +112,10 @@ T calculate_conv3d(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, c
const TB *b_ptr = bias.data();
TB bias_value = b_ptr[ch_out];
- return total + bias_value;
+ return static_cast<T>(total) + bias_value;
}
-template < typename T, typename TB, ARM_COMPUTE_REQUIRES_TA(std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value) >
+template < typename T, typename TB, typename TACC, ARM_COMPUTE_REQUIRES_TA(std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value) >
T calculate_conv3d(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &bias, const Size3D &dilation, int batch,
int z_start, int y_start, int x_start, int ch_out, UniformQuantizationInfo oq_info)
{
@@ -143,7 +143,7 @@ T calculate_conv3d(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, c
const float multiplier = input_scale * weights_scale / output_scale;
arm_compute::quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift);
- int32_t total(0);
+ TACC total(0);
for(unsigned int weight_d = 0; weight_d < weights_depth; ++weight_d)
{
const int idx_z = z_start + dilation.depth * weight_d;
@@ -189,7 +189,7 @@ T calculate_conv3d(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, c
}
} // namespace
-template <typename T, typename TB>
+template <typename T, typename TB, typename TACC = T>
SimpleTensor<T> conv3d(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &bias, SimpleTensor<T> &dst, const Conv3dInfo &conv3d_info)
{
// Compute reference
@@ -237,7 +237,7 @@ SimpleTensor<T> conv3d(const SimpleTensor<T> &src, const SimpleTensor<T> &weight
T *out_ptr = dst.data();
const int out_offset = coord2index(dst.shape(), Coordinates{ ch_out, x_out, y_out, z_out, batch });
- out_ptr[out_offset] = calculate_conv3d<T, TB>(src, weights, bias, conv3d_info.dilation, batch, z_start, y_start, x_start, ch_out, dst.quantization_info().uniform());
+ out_ptr[out_offset] = calculate_conv3d<T, TB, TACC>(src, weights, bias, conv3d_info.dilation, batch, z_start, y_start, x_start, ch_out, dst.quantization_info().uniform());
}
}
}
@@ -246,13 +246,13 @@ SimpleTensor<T> conv3d(const SimpleTensor<T> &src, const SimpleTensor<T> &weight
return dst;
}
-template SimpleTensor<float> conv3d(const SimpleTensor<float> &src, const SimpleTensor<float> &weights, const SimpleTensor<float> &bias, SimpleTensor<float> &dst,
+template SimpleTensor<float> conv3d<float, float, float>(const SimpleTensor<float> &src, const SimpleTensor<float> &weights, const SimpleTensor<float> &bias, SimpleTensor<float> &dst,
const Conv3dInfo &conv3d_info);
-template SimpleTensor<half> conv3d(const SimpleTensor<half> &src, const SimpleTensor<half> &weights, const SimpleTensor<half> &bias, SimpleTensor<half> &dst,
+template SimpleTensor<half> conv3d<half, half, float>(const SimpleTensor<half> &src, const SimpleTensor<half> &weights, const SimpleTensor<half> &bias, SimpleTensor<half> &dst,
const Conv3dInfo &conv3d_info);
-template SimpleTensor<uint8_t> conv3d(const SimpleTensor<uint8_t> &src, const SimpleTensor<uint8_t> &weights, const SimpleTensor<int32_t> &bias, SimpleTensor<uint8_t> &dst,
+template SimpleTensor<uint8_t> conv3d<uint8_t, int32_t, int32_t>(const SimpleTensor<uint8_t> &src, const SimpleTensor<uint8_t> &weights, const SimpleTensor<int32_t> &bias, SimpleTensor<uint8_t> &dst,
const Conv3dInfo &conv3d_info);
-template SimpleTensor<int8_t> conv3d(const SimpleTensor<int8_t> &src, const SimpleTensor<int8_t> &weights, const SimpleTensor<int32_t> &bias, SimpleTensor<int8_t> &dst,
+template SimpleTensor<int8_t> conv3d<int8_t, int32_t, int32_t>(const SimpleTensor<int8_t> &src, const SimpleTensor<int8_t> &weights, const SimpleTensor<int32_t> &bias, SimpleTensor<int8_t> &dst,
const Conv3dInfo &conv3d_info);
} // namespace reference
} // namespace validation
diff --git a/tests/validation/reference/Conv3D.h b/tests/validation/reference/Conv3D.h
index e3674f4bfb..a440b15d55 100644
--- a/tests/validation/reference/Conv3D.h
+++ b/tests/validation/reference/Conv3D.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_TEST_CONV3D_LAYER_H
-#define ARM_COMPUTE_TEST_CONV3D_LAYER_H
+#ifndef ACL_TESTS_VALIDATION_REFERENCE_CONV3D_H
+#define ACL_TESTS_VALIDATION_REFERENCE_CONV3D_H
#include "Utils.h"
#include "arm_compute/runtime/FunctionDescriptors.h"
@@ -37,11 +37,11 @@ namespace validation
{
namespace reference
{
-template <typename T, typename TB>
+template <typename T, typename TB, typename TACC>
SimpleTensor<T> conv3d(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &bias, SimpleTensor<T> &dst,
const Conv3dInfo &conv3d_info);
} // namespace reference
} // namespace validation
} // namespace test
} // namespace arm_compute
-#endif /* ARM_COMPUTE_TEST_CONV3D_LAYER_H */
+#endif // ACL_TESTS_VALIDATION_REFERENCE_CONV3D_H