diff options
Diffstat (limited to 'tests/AssetsLibrary.h')
-rw-r--r-- | tests/AssetsLibrary.h | 68 |
1 files changed, 61 insertions, 7 deletions
diff --git a/tests/AssetsLibrary.h b/tests/AssetsLibrary.h index 84653ed089..16ebff7ec0 100644 --- a/tests/AssetsLibrary.h +++ b/tests/AssetsLibrary.h @@ -392,6 +392,19 @@ public: template <typename T, typename D> void fill_tensor_value(T &&tensor, D value) const; + /** Fill a tensor with a given vector with static values. + * + * @param[in, out] tensor To be filled tensor. + * @param[in] values A vector containing values + * + * To cope with various size tensors, the vector size doens't have to be + * the same as tensor's size. If the size of the tensor is larger than the vector, + * the iterator the vector will keep iterating and wrap around. If the vector is + * larger, values located after the required size won't be used. + */ + template <typename T, typename DataType> + void fill_static_values(T &&tensor, const std::vector<DataType> &values) const; + private: // Function prototype to convert between image formats. using Converter = void (*)(const RawTensor &src, RawTensor &dst); @@ -399,6 +412,9 @@ private: using Extractor = void (*)(const RawTensor &src, RawTensor &dst); // Function prototype to load an image file. using Loader = RawTensor (*)(const std::string &path); + // Function type to generate a number to fill tensors. + template <typename ResultType> + using GeneratorFunctionType = std::function<ResultType(void)>; const Converter &get_converter(Format src, Format dst) const; const Converter &get_converter(DataType src, Format dst) const; @@ -443,6 +459,14 @@ private: */ const RawTensor &find_or_create_raw_tensor(const std::string &name, Format format, Channel channel) const; + /** Fill a tensor with a value generator function. + * + * @param[in, out] tensor To be filled tensor. + * @param[in] generate_value A function that generates values. + */ + template <typename T, typename ResultType> + void fill_with_generator(T &&tensor, const GeneratorFunctionType<ResultType> &generate_value) const; + mutable TensorCache _cache{}; mutable arm_compute::Mutex _format_lock{}; mutable arm_compute::Mutex _channel_lock{}; @@ -556,13 +580,9 @@ void AssetsLibrary::fill(std::vector<T> &vec, D &&distribution, std::random_devi } } -template <typename T, typename D> -void AssetsLibrary::fill(T &&tensor, D &&distribution, std::random_device::result_type seed_offset) const +template <typename T, typename ResultType> +void AssetsLibrary::fill_with_generator(T &&tensor, const GeneratorFunctionType<ResultType> &generate_value) const { - using ResultType = typename std::remove_reference<D>::type::result_type; - - std::mt19937 gen(_seed + seed_offset); - const bool is_nhwc = tensor.data_layout() == DataLayout::NHWC; TensorShape shape(tensor.shape()); @@ -587,16 +607,50 @@ void AssetsLibrary::fill(T &&tensor, D &&distribution, std::random_device::resul // Iterate over all channels for(int channel = 0; channel < tensor.num_channels(); ++channel) { - const ResultType value = distribution(gen); + const ResultType value = generate_value(); ResultType &target_value = reinterpret_cast<ResultType *>(tensor(id))[channel]; store_value_with_data_type(&target_value, value, tensor.data_type()); } } +} +template <typename T, typename D> +void AssetsLibrary::fill(T &&tensor, D &&distribution, std::random_device::result_type seed_offset) const +{ + using ResultType = typename std::remove_reference<D>::type::result_type; + std::mt19937 gen(_seed + seed_offset); + + GeneratorFunctionType<ResultType> number_generator = [&]() + { + const ResultType value = distribution(gen); + return value; + }; + + fill_with_generator(tensor, number_generator); fill_borders_with_garbage(tensor, distribution, seed_offset); } +template <typename T, typename DataType> +void AssetsLibrary::fill_static_values(T &&tensor, const std::vector<DataType> &values) const +{ + auto it = values.begin(); + GeneratorFunctionType<DataType> get_next_value = [&]() + { + const DataType value = *it; + ++it; + + if(it == values.end()) + { + it = values.begin(); + } + + return value; + }; + + fill_with_generator(tensor, get_next_value); +} + template <typename D> void AssetsLibrary::fill(RawTensor &raw, D &&distribution, std::random_device::result_type seed_offset) const { |