aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/reference/DepthConcatenateLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/reference/DepthConcatenateLayer.cpp')
-rw-r--r--tests/validation/reference/DepthConcatenateLayer.cpp30
1 files changed, 19 insertions, 11 deletions
diff --git a/tests/validation/reference/DepthConcatenateLayer.cpp b/tests/validation/reference/DepthConcatenateLayer.cpp
index 90fbd915b1..6551f0c79e 100644
--- a/tests/validation/reference/DepthConcatenateLayer.cpp
+++ b/tests/validation/reference/DepthConcatenateLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -34,7 +34,7 @@ namespace validation
namespace reference
{
template <typename T>
-SimpleTensor<T> depthconcatenate_layer(const std::vector<SimpleTensor<T>> &srcs)
+SimpleTensor<T> depthconcatenate_layer(const std::vector<SimpleTensor<T>> &srcs, SimpleTensor<T> &dst)
{
// Create reference
std::vector<TensorShape> shapes;
@@ -44,10 +44,6 @@ SimpleTensor<T> depthconcatenate_layer(const std::vector<SimpleTensor<T>> &srcs)
shapes.emplace_back(src.shape());
}
- DataType dst_type = srcs.empty() ? DataType::UNKNOWN : srcs[0].data_type();
- TensorShape dst_shape = calculate_depth_concatenate_shape(shapes);
- SimpleTensor<T> dst(dst_shape, dst_type);
-
// Compute reference
int depth_offset = 0;
const int width_out = dst.shape().x();
@@ -80,8 +76,20 @@ SimpleTensor<T> depthconcatenate_layer(const std::vector<SimpleTensor<T>> &srcs)
{
for(int r = 0; r < height; ++r)
{
- std::copy(src_ptr, src_ptr + width, dst.data() + offset_to_first_element + d * out_stride_z + r * width_out);
- src_ptr += width;
+ if(src.data_type() == DataType::QASYMM8 && src.quantization_info() != dst.quantization_info())
+ {
+ std::transform(src_ptr, src_ptr + width, dst.data() + offset_to_first_element + d * out_stride_z + r * width_out, [src, dst](T t)
+ {
+ const float dequantized_input = src.quantization_info().dequantize(t);
+ return dst.quantization_info().quantize(dequantized_input, RoundingPolicy::TO_NEAREST_UP);
+ });
+ src_ptr += width;
+ }
+ else
+ {
+ std::copy(src_ptr, src_ptr + width, dst.data() + offset_to_first_element + d * out_stride_z + r * width_out);
+ src_ptr += width;
+ }
}
}
}
@@ -92,9 +100,9 @@ SimpleTensor<T> depthconcatenate_layer(const std::vector<SimpleTensor<T>> &srcs)
return dst;
}
-template SimpleTensor<uint8_t> depthconcatenate_layer(const std::vector<SimpleTensor<uint8_t>> &srcs);
-template SimpleTensor<float> depthconcatenate_layer(const std::vector<SimpleTensor<float>> &srcs);
-template SimpleTensor<half> depthconcatenate_layer(const std::vector<SimpleTensor<half>> &srcs);
+template SimpleTensor<uint8_t> depthconcatenate_layer(const std::vector<SimpleTensor<uint8_t>> &srcs, SimpleTensor<uint8_t> &dst);
+template SimpleTensor<float> depthconcatenate_layer(const std::vector<SimpleTensor<float>> &srcs, SimpleTensor<float> &dst);
+template SimpleTensor<half> depthconcatenate_layer(const std::vector<SimpleTensor<half>> &srcs, SimpleTensor<half> &dst);
} // namespace reference
} // namespace validation
} // namespace test