diff options
Diffstat (limited to 'tests/validation/reference/DepthConcatenateLayer.cpp')
-rw-r--r-- | tests/validation/reference/DepthConcatenateLayer.cpp | 30 |
1 files changed, 19 insertions, 11 deletions
diff --git a/tests/validation/reference/DepthConcatenateLayer.cpp b/tests/validation/reference/DepthConcatenateLayer.cpp index 90fbd915b1..6551f0c79e 100644 --- a/tests/validation/reference/DepthConcatenateLayer.cpp +++ b/tests/validation/reference/DepthConcatenateLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -34,7 +34,7 @@ namespace validation namespace reference { template <typename T> -SimpleTensor<T> depthconcatenate_layer(const std::vector<SimpleTensor<T>> &srcs) +SimpleTensor<T> depthconcatenate_layer(const std::vector<SimpleTensor<T>> &srcs, SimpleTensor<T> &dst) { // Create reference std::vector<TensorShape> shapes; @@ -44,10 +44,6 @@ SimpleTensor<T> depthconcatenate_layer(const std::vector<SimpleTensor<T>> &srcs) shapes.emplace_back(src.shape()); } - DataType dst_type = srcs.empty() ? DataType::UNKNOWN : srcs[0].data_type(); - TensorShape dst_shape = calculate_depth_concatenate_shape(shapes); - SimpleTensor<T> dst(dst_shape, dst_type); - // Compute reference int depth_offset = 0; const int width_out = dst.shape().x(); @@ -80,8 +76,20 @@ SimpleTensor<T> depthconcatenate_layer(const std::vector<SimpleTensor<T>> &srcs) { for(int r = 0; r < height; ++r) { - std::copy(src_ptr, src_ptr + width, dst.data() + offset_to_first_element + d * out_stride_z + r * width_out); - src_ptr += width; + if(src.data_type() == DataType::QASYMM8 && src.quantization_info() != dst.quantization_info()) + { + std::transform(src_ptr, src_ptr + width, dst.data() + offset_to_first_element + d * out_stride_z + r * width_out, [src, dst](T t) + { + const float dequantized_input = src.quantization_info().dequantize(t); + return dst.quantization_info().quantize(dequantized_input, RoundingPolicy::TO_NEAREST_UP); + }); + src_ptr += width; + } + else + { + std::copy(src_ptr, src_ptr + width, dst.data() + offset_to_first_element + d * out_stride_z + r * width_out); + src_ptr += width; + } } } } @@ -92,9 +100,9 @@ SimpleTensor<T> depthconcatenate_layer(const std::vector<SimpleTensor<T>> &srcs) return dst; } -template SimpleTensor<uint8_t> depthconcatenate_layer(const std::vector<SimpleTensor<uint8_t>> &srcs); -template SimpleTensor<float> depthconcatenate_layer(const std::vector<SimpleTensor<float>> &srcs); -template SimpleTensor<half> depthconcatenate_layer(const std::vector<SimpleTensor<half>> &srcs); +template SimpleTensor<uint8_t> depthconcatenate_layer(const std::vector<SimpleTensor<uint8_t>> &srcs, SimpleTensor<uint8_t> &dst); +template SimpleTensor<float> depthconcatenate_layer(const std::vector<SimpleTensor<float>> &srcs, SimpleTensor<float> &dst); +template SimpleTensor<half> depthconcatenate_layer(const std::vector<SimpleTensor<half>> &srcs, SimpleTensor<half> &dst); } // namespace reference } // namespace validation } // namespace test |