aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures
diff options
context:
space:
mode:
authorSangwon Ha <sangwon.ha@arm.com>2024-06-17 13:28:13 +0100
committerSang Won Ha <sangwon.ha@arm.com>2024-06-19 15:32:36 +0000
commit4d5838ced3edbff2175a9b3e6cae33023c6249e8 (patch)
tree4ad25432445c8a54723282ab8b006291d81a9be5 /tests/validation/fixtures
parentea6c6d4e9b1397c216d65ad7342151e832f25e53 (diff)
downloadComputeLibrary-4d5838ced3edbff2175a9b3e6cae33023c6249e8.tar.gz
Separate data type for accumulator in DConv3D test
Resolves: COMPMID-6947 Signed-off-by: Sangwon Ha <sangwon.ha@arm.com> Change-Id: I7fcf4f41d2961edf1fdf05e8f0b538a94f75295a Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11710 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Ramy Elgammal <ramy.elgammal@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r--tests/validation/fixtures/DirectConvolution3DFixture.h5
1 files changed, 3 insertions, 2 deletions
diff --git a/tests/validation/fixtures/DirectConvolution3DFixture.h b/tests/validation/fixtures/DirectConvolution3DFixture.h
index e80ad2f54f..e27a41a23b 100644
--- a/tests/validation/fixtures/DirectConvolution3DFixture.h
+++ b/tests/validation/fixtures/DirectConvolution3DFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021, 2023 Arm Limited.
+ * Copyright (c) 2021, 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -46,6 +46,7 @@ class DirectConvolution3DValidationGenericFixture : public framework::Fixture
{
public:
using TBias = typename std::conditional < std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value, int32_t, T >::type;
+ using TAcc = typename std::conditional < std::is_integral<T>::value, int32_t, float >::type;
void setup(const TensorShape &input_shape, int stride_x, int stride_y, int stride_z, int pad_x, int pad_y, int pad_z, unsigned int kernel_width, int kernel_height, int kernel_depth,
unsigned int num_kernels, bool has_bias, const ActivationLayerInfo &act_info, const DataType &data_type, const DataLayout &data_layout,
@@ -150,7 +151,7 @@ protected:
fill(bias, 2);
}
- return reference::activation_layer(reference::conv3d<T, TBias>(src, weights, bias, dst, conv3d_info), conv3d_info.act_info);
+ return reference::activation_layer(reference::conv3d<T, TBias, TAcc>(src, weights, bias, dst, conv3d_info), conv3d_info.act_info);
}
TensorType _target{};