diff options
Diffstat (limited to 'tests/validation/fixtures/WidthConcatenateLayerFixture.h')
-rw-r--r-- | tests/validation/fixtures/WidthConcatenateLayerFixture.h | 44 |
1 files changed, 29 insertions, 15 deletions
diff --git a/tests/validation/fixtures/WidthConcatenateLayerFixture.h b/tests/validation/fixtures/WidthConcatenateLayerFixture.h index 1f79210350..47a03ed865 100644 --- a/tests/validation/fixtures/WidthConcatenateLayerFixture.h +++ b/tests/validation/fixtures/WidthConcatenateLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -53,9 +53,20 @@ public: // Create input shapes std::mt19937 gen(library->seed()); std::uniform_int_distribution<> num_dis(2, 8); - const int num_tensors = num_dis(gen); + std::uniform_int_distribution<> offset_dis(0, 20); - std::vector<TensorShape> shapes(num_tensors, shape); + const int num_tensors = num_dis(gen); + + std::vector<TensorShape> shapes(num_tensors, shape); + + // vector holding the quantization info: + // the last element is the output quantization info + // all other elements are the quantization info for the input tensors + std::vector<QuantizationInfo> qinfo(num_tensors + 1, QuantizationInfo()); + for(auto &qi : qinfo) + { + qi = QuantizationInfo(1.f / 255.f, offset_dis(gen)); + } std::bernoulli_distribution mutate_dis(0.5f); std::uniform_real_distribution<> change_dis(-0.25f, 0.f); @@ -71,8 +82,8 @@ public: } } - _target = compute_target(shapes, data_type); - _reference = compute_reference(shapes, data_type); + _target = compute_target(shapes, qinfo, data_type); + _reference = compute_reference(shapes, qinfo, data_type); } protected: @@ -82,7 +93,7 @@ protected: library->fill_tensor_uniform(tensor, i); } - TensorType compute_target(std::vector<TensorShape> shapes, DataType data_type) + TensorType compute_target(std::vector<TensorShape> shapes, const std::vector<QuantizationInfo> &qinfo, DataType data_type) { std::vector<TensorType> srcs; std::vector<ITensorType *> src_ptrs; @@ -90,14 +101,15 @@ protected: // Create tensors srcs.reserve(shapes.size()); - for(const auto &shape : shapes) + for(size_t j = 0; j < shapes.size(); ++j) { - srcs.emplace_back(create_tensor<TensorType>(shape, data_type, 1)); + srcs.emplace_back(create_tensor<TensorType>(shapes[j], data_type, 1, qinfo[j])); src_ptrs.emplace_back(&srcs.back()); } TensorShape dst_shape = misc::shape_calculator::calculate_width_concatenate_shape(src_ptrs); - TensorType dst = create_tensor<TensorType>(dst_shape, data_type, 1); + + TensorType dst = create_tensor<TensorType>(dst_shape, data_type, 1, qinfo[shapes.size()]); // Create and configure function FunctionType width_concat; @@ -133,19 +145,21 @@ protected: return dst; } - SimpleTensor<T> compute_reference(std::vector<TensorShape> shapes, DataType data_type) + SimpleTensor<T> compute_reference(std::vector<TensorShape> shapes, const std::vector<QuantizationInfo> &qinfo, DataType data_type) { std::vector<SimpleTensor<T>> srcs; // Create and fill tensors - int i = 0; - for(const auto &shape : shapes) + for(size_t j = 0; j < shapes.size(); ++j) { - srcs.emplace_back(shape, data_type, 1); - fill(srcs.back(), i++); + srcs.emplace_back(shapes[j], data_type, 1, qinfo[j]); + fill(srcs.back(), j); } - return reference::widthconcatenate_layer<T>(srcs); + const TensorShape dst_shape = calculate_width_concatenate_shape(shapes); + SimpleTensor<T> dst{ dst_shape, data_type, 1, qinfo[shapes.size()] }; + + return reference::widthconcatenate_layer<T>(srcs, dst); } TensorType _target{}; |