aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/reference/ConcatenateLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/reference/ConcatenateLayer.cpp')
-rw-r--r--tests/validation/reference/ConcatenateLayer.cpp23
1 files changed, 18 insertions, 5 deletions
diff --git a/tests/validation/reference/ConcatenateLayer.cpp b/tests/validation/reference/ConcatenateLayer.cpp
index aa74ca2474..266dae1c27 100644
--- a/tests/validation/reference/ConcatenateLayer.cpp
+++ b/tests/validation/reference/ConcatenateLayer.cpp
@@ -70,16 +70,27 @@ SimpleTensor<T> widthconcatenate_layer(const std::vector<SimpleTensor<T>> &srcs,
for(int r = 0; r < height; ++r)
{
const int offset = u * height * depth + d * height + r;
- if(src.data_type() == DataType::QASYMM8 && src.quantization_info() != dst.quantization_info())
+ if(is_data_type_quantized(src.data_type()) && src.quantization_info() != dst.quantization_info())
{
const UniformQuantizationInfo iq_info = src.quantization_info().uniform();
const UniformQuantizationInfo oq_info = dst.quantization_info().uniform();
- std::transform(src_ptr, src_ptr + width, dst_ptr + width_offset + offset * width_out, [&](T t)
+ if(src.data_type() == DataType::QASYMM8)
{
- const float dequantized_input = dequantize_qasymm8(t, iq_info);
- return quantize_qasymm8(dequantized_input, oq_info);
- });
+ std::transform(src_ptr, src_ptr + width, dst_ptr + width_offset + offset * width_out, [&](T t)
+ {
+ const float dequantized_input = dequantize_qasymm8(t, iq_info);
+ return quantize_qasymm8(dequantized_input, oq_info);
+ });
+ }
+ else
+ {
+ std::transform(src_ptr, src_ptr + width, dst_ptr + width_offset + offset * width_out, [&](T t)
+ {
+ const float dequantized_input = dequantize_qasymm8_signed(t, iq_info);
+ return quantize_qasymm8_signed(dequantized_input, oq_info);
+ });
+ }
src_ptr += width;
}
else
@@ -98,6 +109,7 @@ SimpleTensor<T> widthconcatenate_layer(const std::vector<SimpleTensor<T>> &srcs,
template SimpleTensor<float> widthconcatenate_layer(const std::vector<SimpleTensor<float>> &srcs, SimpleTensor<float> &dst);
template SimpleTensor<half> widthconcatenate_layer(const std::vector<SimpleTensor<half>> &srcs, SimpleTensor<half> &dst);
template SimpleTensor<uint8_t> widthconcatenate_layer(const std::vector<SimpleTensor<uint8_t>> &srcs, SimpleTensor<uint8_t> &dst);
+template SimpleTensor<int8_t> widthconcatenate_layer(const std::vector<SimpleTensor<int8_t>> &srcs, SimpleTensor<int8_t> &dst);
} // namespace
template <typename T>
@@ -148,6 +160,7 @@ SimpleTensor<T> concatenate_layer(std::vector<SimpleTensor<T>> &srcs, SimpleTens
template SimpleTensor<float> concatenate_layer(std::vector<SimpleTensor<float>> &srcs, SimpleTensor<float> &dst, unsigned int axis);
template SimpleTensor<half> concatenate_layer(std::vector<SimpleTensor<half>> &srcs, SimpleTensor<half> &dst, unsigned int axis);
template SimpleTensor<uint8_t> concatenate_layer(std::vector<SimpleTensor<uint8_t>> &srcs, SimpleTensor<uint8_t> &dst, unsigned int axis);
+template SimpleTensor<int8_t> concatenate_layer(std::vector<SimpleTensor<int8_t>> &srcs, SimpleTensor<int8_t> &dst, unsigned int axis);
} // namespace reference
} // namespace validation
} // namespace test