aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/reference/ReductionOperation.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/reference/ReductionOperation.cpp')
-rw-r--r--tests/validation/reference/ReductionOperation.cpp10
1 files changed, 5 insertions, 5 deletions
diff --git a/tests/validation/reference/ReductionOperation.cpp b/tests/validation/reference/ReductionOperation.cpp
index 11947bd293..499263f11e 100644
--- a/tests/validation/reference/ReductionOperation.cpp
+++ b/tests/validation/reference/ReductionOperation.cpp
@@ -76,7 +76,7 @@ template <typename T>
SimpleTensor<T> reduction_operation(const SimpleTensor<T> &src, const TensorShape &dst_shape, unsigned int axis, ReductionOperation op)
{
// Create reference
- SimpleTensor<T> dst{ dst_shape, src.data_type() };
+ SimpleTensor<T> dst{ dst_shape, src.data_type(), 1, src.quantization_info() };
const unsigned int src_width = src.shape().x();
const unsigned int src_height = src.shape().y();
const unsigned int src_depth = src.shape().z();
@@ -102,7 +102,7 @@ SimpleTensor<T> reduction_operation(const SimpleTensor<T> &src, const TensorShap
{
res /= src_width;
}
- dst[du] = static_cast<uint8_t>(res);
+ dst[du] = saturate_cast<uint8_t>(res);
}
else
{
@@ -136,7 +136,7 @@ SimpleTensor<T> reduction_operation(const SimpleTensor<T> &src, const TensorShap
{
res /= src_height;
}
- dst[du * src_width + x] = static_cast<uint8_t>(res);
+ dst[du * src_width + x] = saturate_cast<uint8_t>(res);
}
else
{
@@ -175,7 +175,7 @@ SimpleTensor<T> reduction_operation(const SimpleTensor<T> &src, const TensorShap
{
res /= src_depth;
}
- dst[du * src_width * src_height + y * src_width + x] = static_cast<uint8_t>(res);
+ dst[du * src_width * src_height + y * src_width + x] = saturate_cast<uint8_t>(res);
}
else
{
@@ -218,7 +218,7 @@ SimpleTensor<T> reduction_operation(const SimpleTensor<T> &src, const TensorShap
res /= src_batch;
}
- dst[du * src_depth * src_height * src_width + z * src_width * src_height + y * src_width + x] = static_cast<uint8_t>(res);
+ dst[du * src_depth * src_height * src_width + z * src_width * src_height + y * src_width + x] = saturate_cast<uint8_t>(res);
}
else
{