aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/WidthConcatenateLayerFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/WidthConcatenateLayerFixture.h')
-rw-r--r--tests/validation/fixtures/WidthConcatenateLayerFixture.h44
1 files changed, 29 insertions, 15 deletions
diff --git a/tests/validation/fixtures/WidthConcatenateLayerFixture.h b/tests/validation/fixtures/WidthConcatenateLayerFixture.h
index 1f79210350..47a03ed865 100644
--- a/tests/validation/fixtures/WidthConcatenateLayerFixture.h
+++ b/tests/validation/fixtures/WidthConcatenateLayerFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -53,9 +53,20 @@ public:
// Create input shapes
std::mt19937 gen(library->seed());
std::uniform_int_distribution<> num_dis(2, 8);
- const int num_tensors = num_dis(gen);
+ std::uniform_int_distribution<> offset_dis(0, 20);
- std::vector<TensorShape> shapes(num_tensors, shape);
+ const int num_tensors = num_dis(gen);
+
+ std::vector<TensorShape> shapes(num_tensors, shape);
+
+ // vector holding the quantization info:
+ // the last element is the output quantization info
+ // all other elements are the quantization info for the input tensors
+ std::vector<QuantizationInfo> qinfo(num_tensors + 1, QuantizationInfo());
+ for(auto &qi : qinfo)
+ {
+ qi = QuantizationInfo(1.f / 255.f, offset_dis(gen));
+ }
std::bernoulli_distribution mutate_dis(0.5f);
std::uniform_real_distribution<> change_dis(-0.25f, 0.f);
@@ -71,8 +82,8 @@ public:
}
}
- _target = compute_target(shapes, data_type);
- _reference = compute_reference(shapes, data_type);
+ _target = compute_target(shapes, qinfo, data_type);
+ _reference = compute_reference(shapes, qinfo, data_type);
}
protected:
@@ -82,7 +93,7 @@ protected:
library->fill_tensor_uniform(tensor, i);
}
- TensorType compute_target(std::vector<TensorShape> shapes, DataType data_type)
+ TensorType compute_target(std::vector<TensorShape> shapes, const std::vector<QuantizationInfo> &qinfo, DataType data_type)
{
std::vector<TensorType> srcs;
std::vector<ITensorType *> src_ptrs;
@@ -90,14 +101,15 @@ protected:
// Create tensors
srcs.reserve(shapes.size());
- for(const auto &shape : shapes)
+ for(size_t j = 0; j < shapes.size(); ++j)
{
- srcs.emplace_back(create_tensor<TensorType>(shape, data_type, 1));
+ srcs.emplace_back(create_tensor<TensorType>(shapes[j], data_type, 1, qinfo[j]));
src_ptrs.emplace_back(&srcs.back());
}
TensorShape dst_shape = misc::shape_calculator::calculate_width_concatenate_shape(src_ptrs);
- TensorType dst = create_tensor<TensorType>(dst_shape, data_type, 1);
+
+ TensorType dst = create_tensor<TensorType>(dst_shape, data_type, 1, qinfo[shapes.size()]);
// Create and configure function
FunctionType width_concat;
@@ -133,19 +145,21 @@ protected:
return dst;
}
- SimpleTensor<T> compute_reference(std::vector<TensorShape> shapes, DataType data_type)
+ SimpleTensor<T> compute_reference(std::vector<TensorShape> shapes, const std::vector<QuantizationInfo> &qinfo, DataType data_type)
{
std::vector<SimpleTensor<T>> srcs;
// Create and fill tensors
- int i = 0;
- for(const auto &shape : shapes)
+ for(size_t j = 0; j < shapes.size(); ++j)
{
- srcs.emplace_back(shape, data_type, 1);
- fill(srcs.back(), i++);
+ srcs.emplace_back(shapes[j], data_type, 1, qinfo[j]);
+ fill(srcs.back(), j);
}
- return reference::widthconcatenate_layer<T>(srcs);
+ const TensorShape dst_shape = calculate_width_concatenate_shape(shapes);
+ SimpleTensor<T> dst{ dst_shape, data_type, 1, qinfo[shapes.size()] };
+
+ return reference::widthconcatenate_layer<T>(srcs, dst);
}
TensorType _target{};