aboutsummaryrefslogtreecommitdiff
path: root/tests/AssetsLibrary.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/AssetsLibrary.h')
-rw-r--r--tests/AssetsLibrary.h68
1 files changed, 61 insertions, 7 deletions
diff --git a/tests/AssetsLibrary.h b/tests/AssetsLibrary.h
index 84653ed089..16ebff7ec0 100644
--- a/tests/AssetsLibrary.h
+++ b/tests/AssetsLibrary.h
@@ -392,6 +392,19 @@ public:
template <typename T, typename D>
void fill_tensor_value(T &&tensor, D value) const;
+ /** Fill a tensor with a given vector with static values.
+ *
+ * @param[in, out] tensor To be filled tensor.
+ * @param[in] values A vector containing values
+ *
+ * To cope with various size tensors, the vector size doens't have to be
+ * the same as tensor's size. If the size of the tensor is larger than the vector,
+ * the iterator the vector will keep iterating and wrap around. If the vector is
+ * larger, values located after the required size won't be used.
+ */
+ template <typename T, typename DataType>
+ void fill_static_values(T &&tensor, const std::vector<DataType> &values) const;
+
private:
// Function prototype to convert between image formats.
using Converter = void (*)(const RawTensor &src, RawTensor &dst);
@@ -399,6 +412,9 @@ private:
using Extractor = void (*)(const RawTensor &src, RawTensor &dst);
// Function prototype to load an image file.
using Loader = RawTensor (*)(const std::string &path);
+ // Function type to generate a number to fill tensors.
+ template <typename ResultType>
+ using GeneratorFunctionType = std::function<ResultType(void)>;
const Converter &get_converter(Format src, Format dst) const;
const Converter &get_converter(DataType src, Format dst) const;
@@ -443,6 +459,14 @@ private:
*/
const RawTensor &find_or_create_raw_tensor(const std::string &name, Format format, Channel channel) const;
+ /** Fill a tensor with a value generator function.
+ *
+ * @param[in, out] tensor To be filled tensor.
+ * @param[in] generate_value A function that generates values.
+ */
+ template <typename T, typename ResultType>
+ void fill_with_generator(T &&tensor, const GeneratorFunctionType<ResultType> &generate_value) const;
+
mutable TensorCache _cache{};
mutable arm_compute::Mutex _format_lock{};
mutable arm_compute::Mutex _channel_lock{};
@@ -556,13 +580,9 @@ void AssetsLibrary::fill(std::vector<T> &vec, D &&distribution, std::random_devi
}
}
-template <typename T, typename D>
-void AssetsLibrary::fill(T &&tensor, D &&distribution, std::random_device::result_type seed_offset) const
+template <typename T, typename ResultType>
+void AssetsLibrary::fill_with_generator(T &&tensor, const GeneratorFunctionType<ResultType> &generate_value) const
{
- using ResultType = typename std::remove_reference<D>::type::result_type;
-
- std::mt19937 gen(_seed + seed_offset);
-
const bool is_nhwc = tensor.data_layout() == DataLayout::NHWC;
TensorShape shape(tensor.shape());
@@ -587,16 +607,50 @@ void AssetsLibrary::fill(T &&tensor, D &&distribution, std::random_device::resul
// Iterate over all channels
for(int channel = 0; channel < tensor.num_channels(); ++channel)
{
- const ResultType value = distribution(gen);
+ const ResultType value = generate_value();
ResultType &target_value = reinterpret_cast<ResultType *>(tensor(id))[channel];
store_value_with_data_type(&target_value, value, tensor.data_type());
}
}
+}
+template <typename T, typename D>
+void AssetsLibrary::fill(T &&tensor, D &&distribution, std::random_device::result_type seed_offset) const
+{
+ using ResultType = typename std::remove_reference<D>::type::result_type;
+ std::mt19937 gen(_seed + seed_offset);
+
+ GeneratorFunctionType<ResultType> number_generator = [&]()
+ {
+ const ResultType value = distribution(gen);
+ return value;
+ };
+
+ fill_with_generator(tensor, number_generator);
fill_borders_with_garbage(tensor, distribution, seed_offset);
}
+template <typename T, typename DataType>
+void AssetsLibrary::fill_static_values(T &&tensor, const std::vector<DataType> &values) const
+{
+ auto it = values.begin();
+ GeneratorFunctionType<DataType> get_next_value = [&]()
+ {
+ const DataType value = *it;
+ ++it;
+
+ if(it == values.end())
+ {
+ it = values.begin();
+ }
+
+ return value;
+ };
+
+ fill_with_generator(tensor, get_next_value);
+}
+
template <typename D>
void AssetsLibrary::fill(RawTensor &raw, D &&distribution, std::random_device::result_type seed_offset) const
{