-rw-r--r--  arm_compute/core/NEON/NEAsymm.h                           | 62
-rw-r--r--  src/core/NEON/kernels/NEDepthConcatenateLayerKernel.cpp   | 33
-rw-r--r--  src/core/NEON/kernels/NEWidthConcatenateLayerKernel.cpp   | 36
-rw-r--r--  src/runtime/NEON/functions/NEWidthConcatenateLayer.cpp    |  8
-rw-r--r--  tests/validation/fixtures/DepthConcatenateLayerFixture.h  | 45
-rw-r--r--  tests/validation/fixtures/WidthConcatenateLayerFixture.h  | 44
-rw-r--r--  tests/validation/reference/DepthConcatenateLayer.cpp      | 30
-rw-r--r--  tests/validation/reference/DepthConcatenateLayer.h        |  4
-rw-r--r--  tests/validation/reference/WidthConcatenateLayer.cpp      | 31
-rw-r--r--  tests/validation/reference/WidthConcatenateLayer.h        |  4
10 files changed, 216 insertions(+), 81 deletions(-)
diff --git a/arm_compute/core/NEON/NEAsymm.h b/arm_compute/core/NEON/NEAsymm.h
index faff59563b..c7f59e9eba 100644
--- a/arm_compute/core/NEON/NEAsymm.h
+++ b/arm_compute/core/NEON/NEAsymm.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -124,6 +124,66 @@ uint8x16_t finalize_quantization(int32x4x4_t &in_s32,
return out_u8;
}
+
+/** Dequantize a neon vector holding 16 quantized values.
+ *
+ * @param qv Input values to be dequantized.
+ * @param qi Quantization information to be used in the computation.
+ *
+ * @return Dequantized values in a neon vector
+ */
+inline float32x4x4_t vdequantize(const uint8x16_t &qv, const QuantizationInfo &qi)
+{
+ const float scale = qi.scale;
+ const int offset = qi.offset;
+ const int32x4_t voffset = vdupq_n_s32(offset);
+ const float32x4_t vscale = vdupq_n_f32(scale);
+ const float32x4x4_t vdequantized_input =
+ {
+ {
+ vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_low_u8(qv))))), voffset)), vscale),
+ vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_low_u8(qv))))), voffset)), vscale),
+ vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(vmovl_u8(vget_high_u8(qv))))), voffset)), vscale),
+ vmulq_f32(vcvtq_f32_s32(vsubq_s32(vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(vmovl_u8(vget_high_u8(qv))))), voffset)), vscale),
+ }
+ };
+ return vdequantized_input;
+}
+
+/** Quantize a neon vector holding 16 floating point values.
+ *
+ * @param qv Input values to be quantized.
+ * @param qi Quantization information to be used in the computation.
+ *
+ * @return A neon vector holding the quantized values
+ */
+inline uint8x16_t vquantize(const float32x4x4_t &qv, const QuantizationInfo &qi)
+{
+ const float scale = qi.scale;
+ const int offset = qi.offset;
+ const float32x4_t voffset = vdupq_n_f32(offset);
+ const float32x4_t vinvscale = vdupq_n_f32(1.f / scale);
+ const int32x4x4_t rf =
+ {
+ {
+#ifdef __aarch64__
+ vcvtnq_s32_f32(vmlaq_f32(voffset, qv.val[0], vinvscale)),
+ vcvtnq_s32_f32(vmlaq_f32(voffset, qv.val[1], vinvscale)),
+ vcvtnq_s32_f32(vmlaq_f32(voffset, qv.val[2], vinvscale)),
+ vcvtnq_s32_f32(vmlaq_f32(voffset, qv.val[3], vinvscale)),
+#else //__aarch64__
+ vcvtq_s32_f32(vmlaq_f32(voffset, qv.val[0], vinvscale)),
+ vcvtq_s32_f32(vmlaq_f32(voffset, qv.val[1], vinvscale)),
+ vcvtq_s32_f32(vmlaq_f32(voffset, qv.val[2], vinvscale)),
+ vcvtq_s32_f32(vmlaq_f32(voffset, qv.val[3], vinvscale)),
+#endif //__aarch64__
+ }
+ };
+ const uint8x8_t pa = vqmovun_s16(vcombine_s16(vqmovn_s32(rf.val[0]), vqmovn_s32(rf.val[1])));
+ const uint8x8_t pb = vqmovun_s16(vcombine_s16(vqmovn_s32(rf.val[2]), vqmovn_s32(rf.val[3])));
+ return vcombine_u8(pa, pb);
+}
+
} // namespace arm_compute
#include "arm_compute/core/NEON/NEAsymm.inl"
#endif // __ARM_COMPUTE_NEASYMM_H__
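
The two helpers above are the core of this change: vdequantize widens 16 QASYMM8 values into four float32x4 lanes and applies (q - offset) * scale, while vquantize applies the inverse x / scale + offset, rounds (to nearest on AArch64 via vcvtnq_s32_f32, truncating on 32-bit Arm via vcvtq_s32_f32) and saturates back to uint8 through vqmovn/vqmovun. A minimal scalar sketch of the same round trip, for reference only (the helper names here are illustrative, not part of the library):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Scalar model of vdequantize: subtract the zero-point offset, then scale.
    inline float scalar_dequantize(uint8_t q, float scale, int offset)
    {
        return (static_cast<int>(q) - offset) * scale;
    }

    // Scalar model of vquantize: scale, round to nearest, add the offset,
    // then saturate to [0, 255] as vqmovun does in the vector version.
    // (std::lround rounds halves away from zero; vcvtnq_s32_f32 rounds
    // halves to even, so results can differ on exact .5 inputs.)
    inline uint8_t scalar_quantize(float x, float scale, int offset)
    {
        const int q = static_cast<int>(std::lround(x / scale)) + offset;
        return static_cast<uint8_t>(std::min(std::max(q, 0), 255));
    }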
diff --git a/src/core/NEON/kernels/NEDepthConcatenateLayerKernel.cpp b/src/core/NEON/kernels/NEDepthConcatenateLayerKernel.cpp
index 8c875cdb2d..8352c94586 100644
--- a/src/core/NEON/kernels/NEDepthConcatenateLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEDepthConcatenateLayerKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -27,6 +27,7 @@
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/NEON/NEAsymm.h"
#include "arm_compute/core/NEON/NEFixedPoint.h"
#include "arm_compute/core/NEON/wrapper/wrapper.h"
#include "arm_compute/core/TensorInfo.h"
@@ -57,14 +58,30 @@ void depth_concat(const ITensor *in, ITensor *out, std::pair<int, int> start_xy,
Iterator input(in, window);
Iterator output(out, window);
- execute_window_loop(window, [&](const Coordinates & id)
+ const DataType dt = in->info()->data_type();
+ const QuantizationInfo &input_qinfo = in->info()->quantization_info();
+ const QuantizationInfo &output_qinfo = out->info()->quantization_info();
+ if(dt == DataType::QASYMM8 && input_qinfo != output_qinfo)
{
- const auto in_ptr = reinterpret_cast<const T *>(input_ptr + input.offset());
- const auto out_ptr = reinterpret_cast<T *>(output_ptr + output.offset());
-
- wrapper::vstore(out_ptr, wrapper::vloadq(in_ptr));
- },
- input, output);
+ execute_window_loop(window, [&](const Coordinates &)
+ {
+ const auto in_ptr = reinterpret_cast<const uint8_t *>(input_ptr + input.offset());
+ const auto out_ptr = reinterpret_cast<uint8_t *>(output_ptr + output.offset());
+ vst1q_u8(out_ptr, vquantize(vdequantize(vld1q_u8(in_ptr), input_qinfo), output_qinfo));
+ },
+ input, output);
+ }
+ else
+ {
+ execute_window_loop(window, [&](const Coordinates &)
+ {
+ const auto in_ptr = reinterpret_cast<const T *>(input_ptr + input.offset());
+ const auto out_ptr = reinterpret_cast<T *>(output_ptr + output.offset());
+
+ wrapper::vstore(out_ptr, wrapper::vloadq(in_ptr));
+ },
+ input, output);
+ }
}
std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, unsigned int depth_offset, ITensorInfo *output)
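
The kernel now hoists the quantization check out of the window loop: when the tensors are QASYMM8 and their QuantizationInfo objects differ, each 16-byte vector is requantized through the new helpers; otherwise the original raw vector copy runs unchanged. A minimal sketch of the requantizing path over a contiguous buffer, assuming the NEAsymm.h helpers above and a length that is a multiple of 16 (the kernel's window step guarantees whole vectors):

    #include "arm_compute/core/NEON/NEAsymm.h"
    #include <arm_neon.h>
    #include <cstddef>
    #include <cstdint>

    using namespace arm_compute;

    // Illustrative only: requantize n QASYMM8 values (n a multiple of 16)
    // from in_qinfo to out_qinfo, 16 lanes at a time, mirroring the body
    // of the kernel's new execute_window_loop lambda.
    void requantize_buffer(const uint8_t *src, uint8_t *dst, size_t n,
                           const QuantizationInfo &in_qinfo,
                           const QuantizationInfo &out_qinfo)
    {
        for(size_t i = 0; i < n; i += 16)
        {
            vst1q_u8(dst + i, vquantize(vdequantize(vld1q_u8(src + i), in_qinfo), out_qinfo));
        }
    }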
diff --git a/src/core/NEON/kernels/NEWidthConcatenateLayerKernel.cpp b/src/core/NEON/kernels/NEWidthConcatenateLayerKernel.cpp
index a84a6d9028..ca27a26493 100644
--- a/src/core/NEON/kernels/NEWidthConcatenateLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEWidthConcatenateLayerKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -27,6 +27,7 @@
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/NEON/NEAsymm.h"
#include "arm_compute/core/NEON/wrapper/wrapper.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Utils.h"
@@ -110,15 +111,28 @@ void NEWidthConcatenateLayerKernel::run(const Window &window, const ThreadInfo &
uint8_t *output_ptr = _output->buffer() + _output->info()->offset_first_element_in_bytes() + _width_offset * _output->info()->strides_in_bytes()[0];
// Create iterators
- Iterator input(_input, window);
- Iterator output(_output, window);
-
- execute_window_loop(window, [&](const Coordinates & id)
+ Iterator input(_input, window);
+ Iterator output(_output, window);
+ const DataType dt = _input->info()->data_type();
+ const QuantizationInfo &input_qinfo = _input->info()->quantization_info();
+ const QuantizationInfo &output_qinfo = _output->info()->quantization_info();
+ if(dt == DataType::QASYMM8 && input_qinfo != output_qinfo)
{
- const auto in_ptr = input.ptr();
- const auto out_ptr = output_ptr + output.offset();
-
- wrapper::vstore(out_ptr, wrapper::vloadq(in_ptr));
- },
- input, output);
+ execute_window_loop(window, [&](const Coordinates &)
+ {
+ vst1q_u8(output_ptr + output.offset(), vquantize(vdequantize(vld1q_u8(input.ptr()), input_qinfo), output_qinfo));
+ },
+ input, output);
+ }
+ else
+ {
+ execute_window_loop(window, [&](const Coordinates &)
+ {
+ const auto in_ptr = input.ptr();
+ const auto out_ptr = output_ptr + output.offset();
+
+ wrapper::vstore(out_ptr, wrapper::vloadq(in_ptr));
+ },
+ input, output);
+ }
}
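
The width kernel applies the same dispatch; the only difference is destination addressing, since output_ptr is pre-offset by _width_offset elements along x. A tiny illustrative sketch of that placement math in plain C++ (names made up for illustration):

    #include <cstdint>
    #include <cstring>

    // Illustrative placement math only: two rows of widths w0 and w1 land
    // side by side in a width-(w0 + w1) output row, matching output_ptr =
    // buffer + offset_first_element + width_offset * stride_x in the kernel.
    void concat_two_rows(const uint8_t *in0, int w0, const uint8_t *in1, int w1, uint8_t *out)
    {
        std::memcpy(out, in0, static_cast<size_t>(w0));      // width_offset = 0
        std::memcpy(out + w0, in1, static_cast<size_t>(w1)); // width_offset = w0
    }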
diff --git a/src/runtime/NEON/functions/NEWidthConcatenateLayer.cpp b/src/runtime/NEON/functions/NEWidthConcatenateLayer.cpp
index 097605c062..7e435c34b1 100644
--- a/src/runtime/NEON/functions/NEWidthConcatenateLayer.cpp
+++ b/src/runtime/NEON/functions/NEWidthConcatenateLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -66,7 +66,7 @@ void NEWidthConcatenateLayer::configure(std::vector<ITensor *> inputs_vector, IT
_num_inputs = inputs_vector.size();
std::vector<ITensorInfo *> inputs_vector_info;
- for(unsigned int i = 0; i < _num_inputs; i++)
+ for(unsigned int i = 0; i < _num_inputs; ++i)
{
inputs_vector_info.emplace_back(inputs_vector.at(i)->info());
}
@@ -80,7 +80,7 @@ void NEWidthConcatenateLayer::configure(std::vector<ITensor *> inputs_vector, IT
_concat_kernels_vector = arm_compute::support::cpp14::make_unique<NEWidthConcatenateLayerKernel[]>(_num_inputs);
- for(unsigned int i = 0; i < _num_inputs; i++)
+ for(unsigned int i = 0; i < _num_inputs; ++i)
{
_concat_kernels_vector[i].configure(inputs_vector.at(i), width_offset, output);
width_offset += inputs_vector.at(i)->info()->dimension(0);
@@ -89,7 +89,7 @@ void NEWidthConcatenateLayer::configure(std::vector<ITensor *> inputs_vector, IT
void NEWidthConcatenateLayer::run()
{
- for(unsigned i = 0; i < _num_inputs; i++)
+ for(unsigned i = 0; i < _num_inputs; ++i)
{
NEScheduler::get().schedule(_concat_kernels_vector.get() + i, Window::DimY);
}
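
For callers, nothing changes in the API; the requantizing path is chosen automatically when input and output quantization infos disagree. A hedged usage sketch, assuming the public Tensor/TensorInfo API of this release (shapes and quantization parameters below are made up):

    #include "arm_compute/runtime/NEON/functions/NEWidthConcatenateLayer.h"
    #include "arm_compute/runtime/Tensor.h"

    using namespace arm_compute;

    // Sketch only: concatenate two QASYMM8 tensors whose quantization infos
    // differ, exercising the new requantizing path.
    void example()
    {
        Tensor src0, src1, dst;
        src0.allocator()->init(TensorInfo(TensorShape(16U, 4U), 1, DataType::QASYMM8, QuantizationInfo(1.f / 255.f, 5)));
        src1.allocator()->init(TensorInfo(TensorShape(16U, 4U), 1, DataType::QASYMM8, QuantizationInfo(1.f / 255.f, 12)));
        dst.allocator()->init(TensorInfo(TensorShape(32U, 4U), 1, DataType::QASYMM8, QuantizationInfo(1.f / 255.f, 0)));

        NEWidthConcatenateLayer concat;
        concat.configure({ &src0, &src1 }, &dst);

        src0.allocator()->allocate();
        src1.allocator()->allocate();
        dst.allocator()->allocate();

        concat.run(); // fill src0/src1 with data before running in real code
    }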
diff --git a/tests/validation/fixtures/DepthConcatenateLayerFixture.h b/tests/validation/fixtures/DepthConcatenateLayerFixture.h
index 5fdfacbb76..edeefa228a 100644
--- a/tests/validation/fixtures/DepthConcatenateLayerFixture.h
+++ b/tests/validation/fixtures/DepthConcatenateLayerFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -53,9 +53,22 @@ public:
// Create input shapes
std::mt19937 gen(library->seed());
std::uniform_int_distribution<> num_dis(2, 4);
- const int num_tensors = num_dis(gen);
+ std::uniform_int_distribution<> offset_dis(0, 20);
+
+ const int num_tensors = num_dis(gen);
+
+ std::vector<TensorShape> shapes(num_tensors, shape);
+
+ // vector holding the quantization info:
+ // the last element is the output quantization info
+ // all other elements are the quantization info for the input tensors
+ std::vector<QuantizationInfo> qinfo(num_tensors + 1, QuantizationInfo());
+
+ for(auto &qi : qinfo)
+ {
+ qi = QuantizationInfo(1.f / 255.f, offset_dis(gen));
+ }
- std::vector<TensorShape> shapes(num_tensors, shape);
std::uniform_int_distribution<> depth_dis(1, 3);
std::bernoulli_distribution mutate_dis(0.5f);
std::uniform_real_distribution<> change_dis(-0.25f, 0.f);
@@ -82,8 +95,8 @@ public:
}
}
- _target = compute_target(shapes, data_type);
- _reference = compute_reference(shapes, data_type);
+ _target = compute_target(shapes, qinfo, data_type);
+ _reference = compute_reference(shapes, qinfo, data_type);
}
protected:
@@ -93,7 +106,7 @@ protected:
library->fill_tensor_uniform(tensor, i);
}
- TensorType compute_target(std::vector<TensorShape> shapes, DataType data_type)
+ TensorType compute_target(std::vector<TensorShape> shapes, const std::vector<QuantizationInfo> &qinfo, DataType data_type)
{
std::vector<TensorType> srcs;
std::vector<ITensorType *> src_ptrs;
@@ -101,14 +114,14 @@ protected:
// Create tensors
srcs.reserve(shapes.size());
- for(const auto &shape : shapes)
+ for(size_t j = 0; j < shapes.size(); ++j)
{
- srcs.emplace_back(create_tensor<TensorType>(shape, data_type, 1));
+ srcs.emplace_back(create_tensor<TensorType>(shapes[j], data_type, 1, qinfo[j]));
src_ptrs.emplace_back(&srcs.back());
}
TensorShape dst_shape = misc::shape_calculator::calculate_depth_concatenate_shape(src_ptrs);
- TensorType dst = create_tensor<TensorType>(dst_shape, data_type, 1);
+ TensorType dst = create_tensor<TensorType>(dst_shape, data_type, 1, qinfo[shapes.size()]);
// Create and configure function
FunctionType depth_concat;
@@ -144,19 +157,21 @@ protected:
return dst;
}
- SimpleTensor<T> compute_reference(std::vector<TensorShape> shapes, DataType data_type)
+ SimpleTensor<T> compute_reference(std::vector<TensorShape> shapes, const std::vector<QuantizationInfo> &qinfo, DataType data_type)
{
std::vector<SimpleTensor<T>> srcs;
// Create and fill tensors
- int i = 0;
- for(const auto &shape : shapes)
+ for(size_t j = 0; j < shapes.size(); ++j)
{
- srcs.emplace_back(shape, data_type, 1);
- fill(srcs.back(), i++);
+ srcs.emplace_back(shapes[j], data_type, 1, qinfo[j]);
+ fill(srcs.back(), j);
}
- return reference::depthconcatenate_layer<T>(srcs);
+ const TensorShape dst_shape = calculate_depth_concatenate_shape(shapes);
+ SimpleTensor<T> dst{ dst_shape, data_type, 1, qinfo[shapes.size()] };
+
+ return reference::depthconcatenate_layer<T>(srcs, dst);
}
TensorType _target{};
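
Both concatenate fixtures (this one and the width fixture below) now share the same convention: num_tensors + 1 QuantizationInfo entries, inputs first and output last, each with fixed scale 1/255 and an independently drawn offset in [0, 20], so input and output infos usually differ and the requantizing path is exercised. A sketch of that generation step in isolation:

    #include "arm_compute/core/Types.h"
    #include <random>
    #include <vector>

    using namespace arm_compute;

    // Sketch of the fixtures' convention: one QuantizationInfo per input
    // plus one trailing entry for the output tensor.
    std::vector<QuantizationInfo> make_qinfo(int num_tensors, std::mt19937 &gen)
    {
        std::uniform_int_distribution<> offset_dis(0, 20);
        std::vector<QuantizationInfo> qinfo;
        qinfo.reserve(num_tensors + 1);
        for(int i = 0; i < num_tensors + 1; ++i)
        {
            qinfo.emplace_back(1.f / 255.f, offset_dis(gen));
        }
        return qinfo;
    }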
diff --git a/tests/validation/fixtures/WidthConcatenateLayerFixture.h b/tests/validation/fixtures/WidthConcatenateLayerFixture.h
index 1f79210350..47a03ed865 100644
--- a/tests/validation/fixtures/WidthConcatenateLayerFixture.h
+++ b/tests/validation/fixtures/WidthConcatenateLayerFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -53,9 +53,20 @@ public:
// Create input shapes
std::mt19937 gen(library->seed());
std::uniform_int_distribution<> num_dis(2, 8);
- const int num_tensors = num_dis(gen);
+ std::uniform_int_distribution<> offset_dis(0, 20);
- std::vector<TensorShape> shapes(num_tensors, shape);
+ const int num_tensors = num_dis(gen);
+
+ std::vector<TensorShape> shapes(num_tensors, shape);
+
+ // vector holding the quantization info:
+ // the last element is the output quantization info
+ // all other elements are the quantization info for the input tensors
+ std::vector<QuantizationInfo> qinfo(num_tensors + 1, QuantizationInfo());
+ for(auto &qi : qinfo)
+ {
+ qi = QuantizationInfo(1.f / 255.f, offset_dis(gen));
+ }
std::bernoulli_distribution mutate_dis(0.5f);
std::uniform_real_distribution<> change_dis(-0.25f, 0.f);
@@ -71,8 +82,8 @@ public:
}
}
- _target = compute_target(shapes, data_type);
- _reference = compute_reference(shapes, data_type);
+ _target = compute_target(shapes, qinfo, data_type);
+ _reference = compute_reference(shapes, qinfo, data_type);
}
protected:
@@ -82,7 +93,7 @@ protected:
library->fill_tensor_uniform(tensor, i);
}
- TensorType compute_target(std::vector<TensorShape> shapes, DataType data_type)
+ TensorType compute_target(std::vector<TensorShape> shapes, const std::vector<QuantizationInfo> &qinfo, DataType data_type)
{
std::vector<TensorType> srcs;
std::vector<ITensorType *> src_ptrs;
@@ -90,14 +101,15 @@ protected:
// Create tensors
srcs.reserve(shapes.size());
- for(const auto &shape : shapes)
+ for(size_t j = 0; j < shapes.size(); ++j)
{
- srcs.emplace_back(create_tensor<TensorType>(shape, data_type, 1));
+ srcs.emplace_back(create_tensor<TensorType>(shapes[j], data_type, 1, qinfo[j]));
src_ptrs.emplace_back(&srcs.back());
}
TensorShape dst_shape = misc::shape_calculator::calculate_width_concatenate_shape(src_ptrs);
- TensorType dst = create_tensor<TensorType>(dst_shape, data_type, 1);
+
+ TensorType dst = create_tensor<TensorType>(dst_shape, data_type, 1, qinfo[shapes.size()]);
// Create and configure function
FunctionType width_concat;
@@ -133,19 +145,21 @@ protected:
return dst;
}
- SimpleTensor<T> compute_reference(std::vector<TensorShape> shapes, DataType data_type)
+ SimpleTensor<T> compute_reference(std::vector<TensorShape> shapes, const std::vector<QuantizationInfo> &qinfo, DataType data_type)
{
std::vector<SimpleTensor<T>> srcs;
// Create and fill tensors
- int i = 0;
- for(const auto &shape : shapes)
+ for(size_t j = 0; j < shapes.size(); ++j)
{
- srcs.emplace_back(shape, data_type, 1);
- fill(srcs.back(), i++);
+ srcs.emplace_back(shapes[j], data_type, 1, qinfo[j]);
+ fill(srcs.back(), j);
}
- return reference::widthconcatenate_layer<T>(srcs);
+ const TensorShape dst_shape = calculate_width_concatenate_shape(shapes);
+ SimpleTensor<T> dst{ dst_shape, data_type, 1, qinfo[shapes.size()] };
+
+ return reference::widthconcatenate_layer<T>(srcs, dst);
}
TensorType _target{};
diff --git a/tests/validation/reference/DepthConcatenateLayer.cpp b/tests/validation/reference/DepthConcatenateLayer.cpp
index 90fbd915b1..6551f0c79e 100644
--- a/tests/validation/reference/DepthConcatenateLayer.cpp
+++ b/tests/validation/reference/DepthConcatenateLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -34,7 +34,7 @@ namespace validation
namespace reference
{
template <typename T>
-SimpleTensor<T> depthconcatenate_layer(const std::vector<SimpleTensor<T>> &srcs)
+SimpleTensor<T> depthconcatenate_layer(const std::vector<SimpleTensor<T>> &srcs, SimpleTensor<T> &dst)
{
// Create reference
std::vector<TensorShape> shapes;
@@ -44,10 +44,6 @@ SimpleTensor<T> depthconcatenate_layer(const std::vector<SimpleTensor<T>> &srcs)
shapes.emplace_back(src.shape());
}
- DataType dst_type = srcs.empty() ? DataType::UNKNOWN : srcs[0].data_type();
- TensorShape dst_shape = calculate_depth_concatenate_shape(shapes);
- SimpleTensor<T> dst(dst_shape, dst_type);
-
// Compute reference
int depth_offset = 0;
const int width_out = dst.shape().x();
@@ -80,8 +76,20 @@ SimpleTensor<T> depthconcatenate_layer(const std::vector<SimpleTensor<T>> &srcs)
{
for(int r = 0; r < height; ++r)
{
- std::copy(src_ptr, src_ptr + width, dst.data() + offset_to_first_element + d * out_stride_z + r * width_out);
- src_ptr += width;
+ if(src.data_type() == DataType::QASYMM8 && src.quantization_info() != dst.quantization_info())
+ {
+ std::transform(src_ptr, src_ptr + width, dst.data() + offset_to_first_element + d * out_stride_z + r * width_out, [src, dst](T t)
+ {
+ const float dequantized_input = src.quantization_info().dequantize(t);
+ return dst.quantization_info().quantize(dequantized_input, RoundingPolicy::TO_NEAREST_UP);
+ });
+ src_ptr += width;
+ }
+ else
+ {
+ std::copy(src_ptr, src_ptr + width, dst.data() + offset_to_first_element + d * out_stride_z + r * width_out);
+ src_ptr += width;
+ }
}
}
}
@@ -92,9 +100,9 @@ SimpleTensor<T> depthconcatenate_layer(const std::vector<SimpleTensor<T>> &srcs)
return dst;
}
-template SimpleTensor<uint8_t> depthconcatenate_layer(const std::vector<SimpleTensor<uint8_t>> &srcs);
-template SimpleTensor<float> depthconcatenate_layer(const std::vector<SimpleTensor<float>> &srcs);
-template SimpleTensor<half> depthconcatenate_layer(const std::vector<SimpleTensor<half>> &srcs);
+template SimpleTensor<uint8_t> depthconcatenate_layer(const std::vector<SimpleTensor<uint8_t>> &srcs, SimpleTensor<uint8_t> &dst);
+template SimpleTensor<float> depthconcatenate_layer(const std::vector<SimpleTensor<float>> &srcs, SimpleTensor<float> &dst);
+template SimpleTensor<half> depthconcatenate_layer(const std::vector<SimpleTensor<half>> &srcs, SimpleTensor<half> &dst);
} // namespace reference
} // namespace validation
} // namespace test
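
The reference now mirrors the kernel element by element: dequantize with the source info, requantize with the destination info under RoundingPolicy::TO_NEAREST_UP. Worked numbers, as a scalar sketch: with scale 1/255 on both sides, offset_in = 5 and offset_out = 12, an input of 100 dequantizes to (100 - 5) / 255 ~ 0.3725 and requantizes to round(0.3725 * 255) + 12 = 107; with equal scales the whole transform collapses to shifting by the offset difference.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Scalar sketch of the reference lambda (illustrative, not library code).
    uint8_t requantize(uint8_t q, float scale_in, int offset_in,
                       float scale_out, int offset_out)
    {
        const float dequantized = (static_cast<int>(q) - offset_in) * scale_in;
        // RoundingPolicy::TO_NEAREST_UP rounds halves upwards: floor(x + 0.5).
        const int rounded = static_cast<int>(std::floor(dequantized / scale_out + 0.5f));
        return static_cast<uint8_t>(std::min(std::max(rounded + offset_out, 0), 255));
    }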
diff --git a/tests/validation/reference/DepthConcatenateLayer.h b/tests/validation/reference/DepthConcatenateLayer.h
index 3c486a8015..8a78441651 100644
--- a/tests/validation/reference/DepthConcatenateLayer.h
+++ b/tests/validation/reference/DepthConcatenateLayer.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -37,7 +37,7 @@ namespace validation
namespace reference
{
template <typename T>
-SimpleTensor<T> depthconcatenate_layer(const std::vector<SimpleTensor<T>> &srcs);
+SimpleTensor<T> depthconcatenate_layer(const std::vector<SimpleTensor<T>> &srcs, SimpleTensor<T> &dst);
} // namespace reference
} // namespace validation
} // namespace test
diff --git a/tests/validation/reference/WidthConcatenateLayer.cpp b/tests/validation/reference/WidthConcatenateLayer.cpp
index 6be171b64d..38543393ce 100644
--- a/tests/validation/reference/WidthConcatenateLayer.cpp
+++ b/tests/validation/reference/WidthConcatenateLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -34,7 +34,7 @@ namespace validation
namespace reference
{
template <typename T>
-SimpleTensor<T> widthconcatenate_layer(const std::vector<SimpleTensor<T>> &srcs)
+SimpleTensor<T> widthconcatenate_layer(const std::vector<SimpleTensor<T>> &srcs, SimpleTensor<T> &dst)
{
// Create reference
std::vector<TensorShape> shapes;
@@ -44,10 +44,6 @@ SimpleTensor<T> widthconcatenate_layer(const std::vector<SimpleTensor<T>> &srcs)
shapes.emplace_back(src.shape());
}
- DataType dst_type = srcs.empty() ? DataType::UNKNOWN : srcs[0].data_type();
- TensorShape dst_shape = calculate_width_concatenate_shape(shapes);
- SimpleTensor<T> dst(dst_shape, dst_type);
-
// Compute reference
int width_offset = 0;
const int width_out = dst.shape().x();
@@ -74,21 +70,32 @@ SimpleTensor<T> widthconcatenate_layer(const std::vector<SimpleTensor<T>> &srcs)
for(int r = 0; r < height; ++r)
{
const int offset = u * height * depth + d * height + r;
- std::copy(src_ptr, src_ptr + width, dst_ptr + width_offset + offset * width_out);
- src_ptr += width;
+ if(src.data_type() == DataType::QASYMM8 && src.quantization_info() != dst.quantization_info())
+ {
+ std::transform(src_ptr, src_ptr + width, dst_ptr + width_offset + offset * width_out, [src, dst](T t)
+ {
+ const float dequantized_input = src.quantization_info().dequantize(t);
+ return dst.quantization_info().quantize(dequantized_input, RoundingPolicy::TO_NEAREST_UP);
+ });
+ src_ptr += width;
+ }
+ else
+ {
+ std::copy(src_ptr, src_ptr + width, dst_ptr + width_offset + offset * width_out);
+ src_ptr += width;
+ }
}
}
}
-
width_offset += width;
}
return dst;
}
-template SimpleTensor<float> widthconcatenate_layer(const std::vector<SimpleTensor<float>> &srcs);
-template SimpleTensor<half> widthconcatenate_layer(const std::vector<SimpleTensor<half>> &srcs);
-template SimpleTensor<uint8_t> widthconcatenate_layer(const std::vector<SimpleTensor<uint8_t>> &srcs);
+template SimpleTensor<float> widthconcatenate_layer(const std::vector<SimpleTensor<float>> &srcs, SimpleTensor<float> &dst);
+template SimpleTensor<half> widthconcatenate_layer(const std::vector<SimpleTensor<half>> &srcs, SimpleTensor<half> &dst);
+template SimpleTensor<uint8_t> widthconcatenate_layer(const std::vector<SimpleTensor<uint8_t>> &srcs, SimpleTensor<uint8_t> &dst);
} // namespace reference
} // namespace validation
} // namespace test
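
The width reference uses the same per-element requantization; what differs is the destination indexing, which flattens batch u, channel d and row r into a single row index before applying width_offset. A sketch of that index computation:

    #include <cstddef>

    // Illustrative: start of the destination span for one source row, given
    // the flattened row index offset = u * height * depth + d * height + r
    // used by the reference implementation.
    size_t dst_row_start(int u, int d, int r, int height, int depth,
                         int width_out, int width_offset)
    {
        const int offset = u * height * depth + d * height + r;
        return static_cast<size_t>(width_offset) + static_cast<size_t>(offset) * width_out;
    }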
diff --git a/tests/validation/reference/WidthConcatenateLayer.h b/tests/validation/reference/WidthConcatenateLayer.h
index 237e72b947..0f1f428f10 100644
--- a/tests/validation/reference/WidthConcatenateLayer.h
+++ b/tests/validation/reference/WidthConcatenateLayer.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -37,7 +37,7 @@ namespace validation
namespace reference
{
template <typename T>
-SimpleTensor<T> widthconcatenate_layer(const std::vector<SimpleTensor<T>> &srcs);
+SimpleTensor<T> widthconcatenate_layer(const std::vector<SimpleTensor<T>> &srcs, SimpleTensor<T> &dst);
} // namespace reference
} // namespace validation
} // namespace test