From 54e98d98fbe082b265b2c4a384eabe0144866bcc Mon Sep 17 00:00:00 2001 From: Pablo Tello Date: Tue, 5 Feb 2019 16:16:19 +0000 Subject: COMPMID-1918: Different qinfos support in NEConcatLayer. Added support in NEDepthConcatenateLayerKernel and NEWidthConcatenateLayer for different quantization arguments both for the input and output. If input's quantization infos are not homogeneous the input values are requantized using the output's quantization info. Change-Id: I2daa638361947eb3ec848d5425d0a5bbfea1936d Reviewed-on: https://review.mlplatform.org/627 Tested-by: Arm Jenkins Reviewed-by: Michele Di Giorgio Reviewed-by: Isabella Gottardi --- arm_compute/core/NEON/NEAsymm.h | 62 +++++++++++++++++++++- .../NEON/kernels/NEDepthConcatenateLayerKernel.cpp | 33 +++++++++--- .../NEON/kernels/NEWidthConcatenateLayerKernel.cpp | 36 +++++++++---- .../NEON/functions/NEWidthConcatenateLayer.cpp | 8 +-- .../fixtures/DepthConcatenateLayerFixture.h | 45 ++++++++++------ .../fixtures/WidthConcatenateLayerFixture.h | 44 +++++++++------ .../validation/reference/DepthConcatenateLayer.cpp | 30 +++++++---- tests/validation/reference/DepthConcatenateLayer.h | 4 +- .../validation/reference/WidthConcatenateLayer.cpp | 31 ++++++----- tests/validation/reference/WidthConcatenateLayer.h | 4 +- 10 files changed, 216 insertions(+), 81 deletions(-) diff --git a/arm_compute/core/NEON/NEAsymm.h b/arm_compute/core/NEON/NEAsymm.h index faff59563b..c7f59e9eba 100644 --- a/arm_compute/core/NEON/NEAsymm.h +++ b/arm_compute/core/NEON/NEAsymm.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -124,6 +124,66 @@ uint8x16_t finalize_quantization(int32x4x4_t &in_s32, return out_u8; } + +/** Dequantize a neon vector holding 16 quantized values. + * + * @param qv Input values to be dequantized. + * @param qi Quantization information to be used in the computation. + * + * @return Dequantized values in a neon vector + */ +inline float32x4x4_t vdequantize(const uint8x16_t &qv, const QuantizationInfo &qi) +{ + const float scale = qi.scale; + const int offset = qi.offset; + const int32x4_t voffset = vdupq_n_s32(offset); + const float32x4_t vscale = vdupq_n_f32(scale); + const float32x4x4_t vdequantized_input = + { + { + vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(qv))))), voffset)), vscale), + vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_low_u8(qv))))), voffset)), vscale), + vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_high_u8(qv))))), voffset)), vscale), + vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_high_u8(qv))))), voffset)), vscale), + } + }; + return vdequantized_input; +} + +/** Quantize a neon vector holding 16 floating point values. + * + * @param qv Input values to be quantized. + * @param qi Quantization information to be used in the computation. + * + * @return A neon vector holding the quantized values + */ +inline uint8x16_t vquantize(const float32x4x4_t &qv, const QuantizationInfo &qi) +{ + const float scale = qi.scale; + const int offset = qi.offset; + const float32x4_t voffset = vdupq_n_f32(offset); + const float32x4_t vinvscale = vdupq_n_f32(1.f / scale); + const int32x4x4_t rf = + { + { +#ifdef __aarch64__ + vcvtnq_s32_f32(vmlaq_f32(voffset, qv.val[0], vinvscale)), + vcvtnq_s32_f32(vmlaq_f32(voffset, qv.val[1], vinvscale)), + vcvtnq_s32_f32(vmlaq_f32(voffset, qv.val[2], vinvscale)), + vcvtnq_s32_f32(vmlaq_f32(voffset, qv.val[3], vinvscale)), +#else //__aarch64__ + vcvtq_s32_f32(vmlaq_f32(voffset, qv.val[0], vinvscale)), + vcvtq_s32_f32(vmlaq_f32(voffset, qv.val[1], vinvscale)), + vcvtq_s32_f32(vmlaq_f32(voffset, qv.val[2], vinvscale)), + vcvtq_s32_f32(vmlaq_f32(voffset, qv.val[3], vinvscale)), +#endif //__aarch64__ + } + }; + const uint8x8_t pa = vqmovun_s16(vcombine_s16(vqmovn_s32(rf.val[0]), vqmovn_s32(rf.val[1]))); + const uint8x8_t pb = vqmovun_s16(vcombine_s16(vqmovn_s32(rf.val[2]), vqmovn_s32(rf.val[3]))); + return vcombine_u8(pa, pb); +} + } // namespace arm_compute #include "arm_compute/core/NEON/NEAsymm.inl" #endif // __ARM_COMPUTE_NEASYMM_H__ diff --git a/src/core/NEON/kernels/NEDepthConcatenateLayerKernel.cpp b/src/core/NEON/kernels/NEDepthConcatenateLayerKernel.cpp index 8c875cdb2d..8352c94586 100644 --- a/src/core/NEON/kernels/NEDepthConcatenateLayerKernel.cpp +++ b/src/core/NEON/kernels/NEDepthConcatenateLayerKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -27,6 +27,7 @@ #include "arm_compute/core/Helpers.h" #include "arm_compute/core/IAccessWindow.h" #include "arm_compute/core/ITensor.h" +#include "arm_compute/core/NEON/NEAsymm.h" #include "arm_compute/core/NEON/NEFixedPoint.h" #include "arm_compute/core/NEON/wrapper/wrapper.h" #include "arm_compute/core/TensorInfo.h" @@ -57,14 +58,30 @@ void depth_concat(const ITensor *in, ITensor *out, std::pair start_xy, Iterator input(in, window); Iterator output(out, window); - execute_window_loop(window, [&](const Coordinates & id) + const DataType dt = in->info()->data_type(); + const QuantizationInfo &input_qinfo = in->info()->quantization_info(); + const QuantizationInfo &output_qinfo = out->info()->quantization_info(); + if(dt == DataType::QASYMM8 && input_qinfo != output_qinfo) { - const auto in_ptr = reinterpret_cast(input_ptr + input.offset()); - const auto out_ptr = reinterpret_cast(output_ptr + output.offset()); - - wrapper::vstore(out_ptr, wrapper::vloadq(in_ptr)); - }, - input, output); + execute_window_loop(window, [&](const Coordinates &) + { + const auto in_ptr = reinterpret_cast(input_ptr + input.offset()); + const auto out_ptr = reinterpret_cast(output_ptr + output.offset()); + vst1q_u8(out_ptr, vquantize(vdequantize(vld1q_u8(in_ptr), input_qinfo), output_qinfo)); + }, + input, output); + } + else + { + execute_window_loop(window, [&](const Coordinates &) + { + const auto in_ptr = reinterpret_cast(input_ptr + input.offset()); + const auto out_ptr = reinterpret_cast(output_ptr + output.offset()); + + wrapper::vstore(out_ptr, wrapper::vloadq(in_ptr)); + }, + input, output); + } } std::pair validate_and_configure_window(ITensorInfo *input, unsigned int depth_offset, ITensorInfo *output) diff --git a/src/core/NEON/kernels/NEWidthConcatenateLayerKernel.cpp b/src/core/NEON/kernels/NEWidthConcatenateLayerKernel.cpp index a84a6d9028..ca27a26493 100644 --- a/src/core/NEON/kernels/NEWidthConcatenateLayerKernel.cpp +++ b/src/core/NEON/kernels/NEWidthConcatenateLayerKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -27,6 +27,7 @@ #include "arm_compute/core/Helpers.h" #include "arm_compute/core/IAccessWindow.h" #include "arm_compute/core/ITensor.h" +#include "arm_compute/core/NEON/NEAsymm.h" #include "arm_compute/core/NEON/wrapper/wrapper.h" #include "arm_compute/core/TensorInfo.h" #include "arm_compute/core/Utils.h" @@ -110,15 +111,28 @@ void NEWidthConcatenateLayerKernel::run(const Window &window, const ThreadInfo & uint8_t *output_ptr = _output->buffer() + _output->info()->offset_first_element_in_bytes() + _width_offset * _output->info()->strides_in_bytes()[0]; // Create iterators - Iterator input(_input, window); - Iterator output(_output, window); - - execute_window_loop(window, [&](const Coordinates & id) + Iterator input(_input, window); + Iterator output(_output, window); + const DataType dt = _input->info()->data_type(); + const QuantizationInfo &input_qinfo = _input->info()->quantization_info(); + const QuantizationInfo &output_qinfo = _output->info()->quantization_info(); + if(dt == DataType::QASYMM8 && input_qinfo != output_qinfo) { - const auto in_ptr = input.ptr(); - const auto out_ptr = output_ptr + output.offset(); - - wrapper::vstore(out_ptr, wrapper::vloadq(in_ptr)); - }, - input, output); + execute_window_loop(window, [&](const Coordinates &) + { + vst1q_u8(output_ptr + output.offset(), vquantize(vdequantize(vld1q_u8(input.ptr()), input_qinfo), output_qinfo)); + }, + input, output); + } + else + { + execute_window_loop(window, [&](const Coordinates &) + { + const auto in_ptr = input.ptr(); + const auto out_ptr = output_ptr + output.offset(); + + wrapper::vstore(out_ptr, wrapper::vloadq(in_ptr)); + }, + input, output); + } } diff --git a/src/runtime/NEON/functions/NEWidthConcatenateLayer.cpp b/src/runtime/NEON/functions/NEWidthConcatenateLayer.cpp index 097605c062..7e435c34b1 100644 --- a/src/runtime/NEON/functions/NEWidthConcatenateLayer.cpp +++ b/src/runtime/NEON/functions/NEWidthConcatenateLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -66,7 +66,7 @@ void NEWidthConcatenateLayer::configure(std::vector inputs_vector, IT _num_inputs = inputs_vector.size(); std::vector inputs_vector_info; - for(unsigned int i = 0; i < _num_inputs; i++) + for(unsigned int i = 0; i < _num_inputs; ++i) { inputs_vector_info.emplace_back(inputs_vector.at(i)->info()); } @@ -80,7 +80,7 @@ void NEWidthConcatenateLayer::configure(std::vector inputs_vector, IT _concat_kernels_vector = arm_compute::support::cpp14::make_unique(_num_inputs); - for(unsigned int i = 0; i < _num_inputs; i++) + for(unsigned int i = 0; i < _num_inputs; ++i) { _concat_kernels_vector[i].configure(inputs_vector.at(i), width_offset, output); width_offset += inputs_vector.at(i)->info()->dimension(0); @@ -89,7 +89,7 @@ void NEWidthConcatenateLayer::configure(std::vector inputs_vector, IT void NEWidthConcatenateLayer::run() { - for(unsigned i = 0; i < _num_inputs; i++) + for(unsigned i = 0; i < _num_inputs; ++i) { NEScheduler::get().schedule(_concat_kernels_vector.get() + i, Window::DimY); } diff --git a/tests/validation/fixtures/DepthConcatenateLayerFixture.h b/tests/validation/fixtures/DepthConcatenateLayerFixture.h index 5fdfacbb76..edeefa228a 100644 --- a/tests/validation/fixtures/DepthConcatenateLayerFixture.h +++ b/tests/validation/fixtures/DepthConcatenateLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -53,9 +53,22 @@ public: // Create input shapes std::mt19937 gen(library->seed()); std::uniform_int_distribution<> num_dis(2, 4); - const int num_tensors = num_dis(gen); + std::uniform_int_distribution<> offset_dis(0, 20); + + const int num_tensors = num_dis(gen); + + std::vector shapes(num_tensors, shape); + + // vector holding the quantization info: + // the last element is the output quantization info + // all other elements are the quantization info for the input tensors + std::vector qinfo(num_tensors + 1, QuantizationInfo()); + + for(auto &qi : qinfo) + { + qi = QuantizationInfo(1.f / 255.f, offset_dis(gen)); + } - std::vector shapes(num_tensors, shape); std::uniform_int_distribution<> depth_dis(1, 3); std::bernoulli_distribution mutate_dis(0.5f); std::uniform_real_distribution<> change_dis(-0.25f, 0.f); @@ -82,8 +95,8 @@ public: } } - _target = compute_target(shapes, data_type); - _reference = compute_reference(shapes, data_type); + _target = compute_target(shapes, qinfo, data_type); + _reference = compute_reference(shapes, qinfo, data_type); } protected: @@ -93,7 +106,7 @@ protected: library->fill_tensor_uniform(tensor, i); } - TensorType compute_target(std::vector shapes, DataType data_type) + TensorType compute_target(std::vector shapes, const std::vector &qinfo, DataType data_type) { std::vector srcs; std::vector src_ptrs; @@ -101,14 +114,14 @@ protected: // Create tensors srcs.reserve(shapes.size()); - for(const auto &shape : shapes) + for(size_t j = 0; j < shapes.size(); ++j) { - srcs.emplace_back(create_tensor(shape, data_type, 1)); + srcs.emplace_back(create_tensor(shapes[j], data_type, 1, qinfo[j])); src_ptrs.emplace_back(&srcs.back()); } TensorShape dst_shape = misc::shape_calculator::calculate_depth_concatenate_shape(src_ptrs); - TensorType dst = create_tensor(dst_shape, data_type, 1); + TensorType dst = create_tensor(dst_shape, data_type, 1, qinfo[shapes.size()]); // Create and configure function FunctionType depth_concat; @@ -144,19 +157,21 @@ protected: return dst; } - SimpleTensor compute_reference(std::vector shapes, DataType data_type) + SimpleTensor compute_reference(std::vector shapes, const std::vector &qinfo, DataType data_type) { std::vector> srcs; // Create and fill tensors - int i = 0; - for(const auto &shape : shapes) + for(size_t j = 0; j < shapes.size(); ++j) { - srcs.emplace_back(shape, data_type, 1); - fill(srcs.back(), i++); + srcs.emplace_back(shapes[j], data_type, 1, qinfo[j]); + fill(srcs.back(), j); } - return reference::depthconcatenate_layer(srcs); + const TensorShape dst_shape = calculate_depth_concatenate_shape(shapes); + SimpleTensor dst{ dst_shape, data_type, 1, qinfo[shapes.size()] }; + + return reference::depthconcatenate_layer(srcs, dst); } TensorType _target{}; diff --git a/tests/validation/fixtures/WidthConcatenateLayerFixture.h b/tests/validation/fixtures/WidthConcatenateLayerFixture.h index 1f79210350..47a03ed865 100644 --- a/tests/validation/fixtures/WidthConcatenateLayerFixture.h +++ b/tests/validation/fixtures/WidthConcatenateLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -53,9 +53,20 @@ public: // Create input shapes std::mt19937 gen(library->seed()); std::uniform_int_distribution<> num_dis(2, 8); - const int num_tensors = num_dis(gen); + std::uniform_int_distribution<> offset_dis(0, 20); - std::vector shapes(num_tensors, shape); + const int num_tensors = num_dis(gen); + + std::vector shapes(num_tensors, shape); + + // vector holding the quantization info: + // the last element is the output quantization info + // all other elements are the quantization info for the input tensors + std::vector qinfo(num_tensors + 1, QuantizationInfo()); + for(auto &qi : qinfo) + { + qi = QuantizationInfo(1.f / 255.f, offset_dis(gen)); + } std::bernoulli_distribution mutate_dis(0.5f); std::uniform_real_distribution<> change_dis(-0.25f, 0.f); @@ -71,8 +82,8 @@ public: } } - _target = compute_target(shapes, data_type); - _reference = compute_reference(shapes, data_type); + _target = compute_target(shapes, qinfo, data_type); + _reference = compute_reference(shapes, qinfo, data_type); } protected: @@ -82,7 +93,7 @@ protected: library->fill_tensor_uniform(tensor, i); } - TensorType compute_target(std::vector shapes, DataType data_type) + TensorType compute_target(std::vector shapes, const std::vector &qinfo, DataType data_type) { std::vector srcs; std::vector src_ptrs; @@ -90,14 +101,15 @@ protected: // Create tensors srcs.reserve(shapes.size()); - for(const auto &shape : shapes) + for(size_t j = 0; j < shapes.size(); ++j) { - srcs.emplace_back(create_tensor(shape, data_type, 1)); + srcs.emplace_back(create_tensor(shapes[j], data_type, 1, qinfo[j])); src_ptrs.emplace_back(&srcs.back()); } TensorShape dst_shape = misc::shape_calculator::calculate_width_concatenate_shape(src_ptrs); - TensorType dst = create_tensor(dst_shape, data_type, 1); + + TensorType dst = create_tensor(dst_shape, data_type, 1, qinfo[shapes.size()]); // Create and configure function FunctionType width_concat; @@ -133,19 +145,21 @@ protected: return dst; } - SimpleTensor compute_reference(std::vector shapes, DataType data_type) + SimpleTensor compute_reference(std::vector shapes, const std::vector &qinfo, DataType data_type) { std::vector> srcs; // Create and fill tensors - int i = 0; - for(const auto &shape : shapes) + for(size_t j = 0; j < shapes.size(); ++j) { - srcs.emplace_back(shape, data_type, 1); - fill(srcs.back(), i++); + srcs.emplace_back(shapes[j], data_type, 1, qinfo[j]); + fill(srcs.back(), j); } - return reference::widthconcatenate_layer(srcs); + const TensorShape dst_shape = calculate_width_concatenate_shape(shapes); + SimpleTensor dst{ dst_shape, data_type, 1, qinfo[shapes.size()] }; + + return reference::widthconcatenate_layer(srcs, dst); } TensorType _target{}; diff --git a/tests/validation/reference/DepthConcatenateLayer.cpp b/tests/validation/reference/DepthConcatenateLayer.cpp index 90fbd915b1..6551f0c79e 100644 --- a/tests/validation/reference/DepthConcatenateLayer.cpp +++ b/tests/validation/reference/DepthConcatenateLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -34,7 +34,7 @@ namespace validation namespace reference { template -SimpleTensor depthconcatenate_layer(const std::vector> &srcs) +SimpleTensor depthconcatenate_layer(const std::vector> &srcs, SimpleTensor &dst) { // Create reference std::vector shapes; @@ -44,10 +44,6 @@ SimpleTensor depthconcatenate_layer(const std::vector> &srcs) shapes.emplace_back(src.shape()); } - DataType dst_type = srcs.empty() ? DataType::UNKNOWN : srcs[0].data_type(); - TensorShape dst_shape = calculate_depth_concatenate_shape(shapes); - SimpleTensor dst(dst_shape, dst_type); - // Compute reference int depth_offset = 0; const int width_out = dst.shape().x(); @@ -80,8 +76,20 @@ SimpleTensor depthconcatenate_layer(const std::vector> &srcs) { for(int r = 0; r < height; ++r) { - std::copy(src_ptr, src_ptr + width, dst.data() + offset_to_first_element + d * out_stride_z + r * width_out); - src_ptr += width; + if(src.data_type() == DataType::QASYMM8 && src.quantization_info() != dst.quantization_info()) + { + std::transform(src_ptr, src_ptr + width, dst.data() + offset_to_first_element + d * out_stride_z + r * width_out, [src, dst](T t) + { + const float dequantized_input = src.quantization_info().dequantize(t); + return dst.quantization_info().quantize(dequantized_input, RoundingPolicy::TO_NEAREST_UP); + }); + src_ptr += width; + } + else + { + std::copy(src_ptr, src_ptr + width, dst.data() + offset_to_first_element + d * out_stride_z + r * width_out); + src_ptr += width; + } } } } @@ -92,9 +100,9 @@ SimpleTensor depthconcatenate_layer(const std::vector> &srcs) return dst; } -template SimpleTensor depthconcatenate_layer(const std::vector> &srcs); -template SimpleTensor depthconcatenate_layer(const std::vector> &srcs); -template SimpleTensor depthconcatenate_layer(const std::vector> &srcs); +template SimpleTensor depthconcatenate_layer(const std::vector> &srcs, SimpleTensor &dst); +template SimpleTensor depthconcatenate_layer(const std::vector> &srcs, SimpleTensor &dst); +template SimpleTensor depthconcatenate_layer(const std::vector> &srcs, SimpleTensor &dst); } // namespace reference } // namespace validation } // namespace test diff --git a/tests/validation/reference/DepthConcatenateLayer.h b/tests/validation/reference/DepthConcatenateLayer.h index 3c486a8015..8a78441651 100644 --- a/tests/validation/reference/DepthConcatenateLayer.h +++ b/tests/validation/reference/DepthConcatenateLayer.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -37,7 +37,7 @@ namespace validation namespace reference { template -SimpleTensor depthconcatenate_layer(const std::vector> &srcs); +SimpleTensor depthconcatenate_layer(const std::vector> &srcs, SimpleTensor &dst); } // namespace reference } // namespace validation } // namespace test diff --git a/tests/validation/reference/WidthConcatenateLayer.cpp b/tests/validation/reference/WidthConcatenateLayer.cpp index 6be171b64d..38543393ce 100644 --- a/tests/validation/reference/WidthConcatenateLayer.cpp +++ b/tests/validation/reference/WidthConcatenateLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -34,7 +34,7 @@ namespace validation namespace reference { template -SimpleTensor widthconcatenate_layer(const std::vector> &srcs) +SimpleTensor widthconcatenate_layer(const std::vector> &srcs, SimpleTensor &dst) { // Create reference std::vector shapes; @@ -44,10 +44,6 @@ SimpleTensor widthconcatenate_layer(const std::vector> &srcs) shapes.emplace_back(src.shape()); } - DataType dst_type = srcs.empty() ? DataType::UNKNOWN : srcs[0].data_type(); - TensorShape dst_shape = calculate_width_concatenate_shape(shapes); - SimpleTensor dst(dst_shape, dst_type); - // Compute reference int width_offset = 0; const int width_out = dst.shape().x(); @@ -74,21 +70,32 @@ SimpleTensor widthconcatenate_layer(const std::vector> &srcs) for(int r = 0; r < height; ++r) { const int offset = u * height * depth + d * height + r; - std::copy(src_ptr, src_ptr + width, dst_ptr + width_offset + offset * width_out); - src_ptr += width; + if(src.data_type() == DataType::QASYMM8 && src.quantization_info() != dst.quantization_info()) + { + std::transform(src_ptr, src_ptr + width, dst_ptr + width_offset + offset * width_out, [src, dst](T t) + { + const float dequantized_input = src.quantization_info().dequantize(t); + return dst.quantization_info().quantize(dequantized_input, RoundingPolicy::TO_NEAREST_UP); + }); + src_ptr += width; + } + else + { + std::copy(src_ptr, src_ptr + width, dst_ptr + width_offset + offset * width_out); + src_ptr += width; + } } } } - width_offset += width; } return dst; } -template SimpleTensor widthconcatenate_layer(const std::vector> &srcs); -template SimpleTensor widthconcatenate_layer(const std::vector> &srcs); -template SimpleTensor widthconcatenate_layer(const std::vector> &srcs); +template SimpleTensor widthconcatenate_layer(const std::vector> &srcs, SimpleTensor &dst); +template SimpleTensor widthconcatenate_layer(const std::vector> &srcs, SimpleTensor &dst); +template SimpleTensor widthconcatenate_layer(const std::vector> &srcs, SimpleTensor &dst); } // namespace reference } // namespace validation } // namespace test diff --git a/tests/validation/reference/WidthConcatenateLayer.h b/tests/validation/reference/WidthConcatenateLayer.h index 237e72b947..0f1f428f10 100644 --- a/tests/validation/reference/WidthConcatenateLayer.h +++ b/tests/validation/reference/WidthConcatenateLayer.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -37,7 +37,7 @@ namespace validation namespace reference { template -SimpleTensor widthconcatenate_layer(const std::vector> &srcs); +SimpleTensor widthconcatenate_layer(const std::vector> &srcs, SimpleTensor &dst); } // namespace reference } // namespace validation } // namespace test -- cgit v1.2.1