diff options
Diffstat (limited to 'utils')
-rw-r--r-- | utils/Utils.h | 93 |
1 files changed, 62 insertions, 31 deletions
diff --git a/utils/Utils.h b/utils/Utils.h index 7eeeae5419..ca4509778b 100644 --- a/utils/Utils.h +++ b/utils/Utils.h @@ -722,54 +722,85 @@ void load_trained_data(T &tensor, const std::string &filename) } } -template <typename T> -void fill_random_tensor(T &tensor, float lower_bound, float upper_bound) +template <typename T, typename TensorType> +void fill_tensor_value(TensorType &tensor, T value) { - std::random_device rd; - std::mt19937 gen(rd()); + map(tensor, true); Window window; window.use_tensor_dimensions(tensor.info()->tensor_shape()); + Iterator it_tensor(&tensor, window); + execute_window_loop(window, [&](const Coordinates &) + { + *reinterpret_cast<T *>(it_tensor.ptr()) = value; + }, + it_tensor); + + unmap(tensor); +} + +template <typename T, typename TensorType> +void fill_tensor_zero(TensorType &tensor) +{ + fill_tensor_value(tensor, T(0)); +} + +template <typename T, typename TensorType> +void fill_tensor_vector(TensorType &tensor, std::vector<T> vec) +{ + ARM_COMPUTE_ERROR_ON(tensor.info()->tensor_shape().total_size() != vec.size()); + map(tensor, true); - Iterator it(&tensor, window); + Window window; + window.use_tensor_dimensions(tensor.info()->tensor_shape()); - switch(tensor.info()->data_type()) + int i = 0; + Iterator it_tensor(&tensor, window); + execute_window_loop(window, [&](const Coordinates &) { - case arm_compute::DataType::F16: - { - std::uniform_real_distribution<float> dist(lower_bound, upper_bound); + *reinterpret_cast<T *>(it_tensor.ptr()) = vec.at(i++); + }, + it_tensor); - execute_window_loop(window, [&](const Coordinates &) - { - *reinterpret_cast<half *>(it.ptr()) = (half)dist(gen); - }, - it); + unmap(tensor); +} - break; - } - case arm_compute::DataType::F32: - { - std::uniform_real_distribution<float> dist(lower_bound, upper_bound); +template <typename T, typename TensorType> +void fill_random_tensor(TensorType &tensor, std::random_device::result_type seed, T lower_bound = std::numeric_limits<T>::lowest(), T upper_bound = std::numeric_limits<T>::max()) +{ + constexpr bool is_half = std::is_same<T, half>::value; + constexpr bool is_integral = std::is_integral<T>::value && !is_half; - execute_window_loop(window, [&](const Coordinates &) - { - *reinterpret_cast<float *>(it.ptr()) = dist(gen); - }, - it); + using fp_dist_type = typename std::conditional<is_half, arm_compute::utils::uniform_real_distribution_fp16, std::uniform_real_distribution<T>>::type; + using dist_type = typename std::conditional<is_integral, std::uniform_int_distribution<T>, fp_dist_type>::type; - break; - } - default: - { - ARM_COMPUTE_ERROR("Unsupported format"); - } - } + std::mt19937 gen(seed); + dist_type dist(lower_bound, upper_bound); + + map(tensor, true); + + Window window; + window.use_tensor_dimensions(tensor.info()->tensor_shape()); + + Iterator it(&tensor, window); + execute_window_loop(window, [&](const Coordinates &) + { + *reinterpret_cast<T *>(it.ptr()) = dist(gen); + }, + it); unmap(tensor); } +template <typename T, typename TensorType> +void fill_random_tensor(TensorType &tensor, T lower_bound = std::numeric_limits<T>::lowest(), T upper_bound = std::numeric_limits<T>::max()) +{ + std::random_device rd; + fill_random_tensor(tensor, rd(), lower_bound, upper_bound); +} + template <typename T> void init_sgemm_output(T &dst, T &src0, T &src1, arm_compute::DataType dt) { |