diff options
Diffstat (limited to 'utils/GraphUtils.cpp')
-rw-r--r-- | utils/GraphUtils.cpp | 112 |
1 files changed, 112 insertions, 0 deletions
diff --git a/utils/GraphUtils.cpp b/utils/GraphUtils.cpp index d763606867..f0b0dded18 100644 --- a/utils/GraphUtils.cpp +++ b/utils/GraphUtils.cpp @@ -31,8 +31,10 @@ #endif /* ARM_COMPUTE_CL */ #include "arm_compute/core/Error.h" +#include "arm_compute/core/PixelValue.h" #include "libnpy/npy.hpp" +#include <random> #include <sstream> using namespace arm_compute::graph_utils; @@ -85,6 +87,116 @@ bool DummyAccessor::access_tensor(ITensor &tensor) return ret; } +RandomAccessor::RandomAccessor(PixelValue lower, PixelValue upper, std::random_device::result_type seed) + : _lower(lower), _upper(upper), _seed(seed) +{ +} + +template <typename T, typename D> +void RandomAccessor::fill(ITensor &tensor, D &&distribution) +{ + std::mt19937 gen(_seed); + + if(tensor.info()->padding().empty()) + { + for(size_t offset = 0; offset < tensor.info()->total_size(); offset += tensor.info()->element_size()) + { + const T value = distribution(gen); + *reinterpret_cast<T *>(tensor.buffer() + offset) = value; + } + } + else + { + // If tensor has padding accessing tensor elements through execution window. + Window window; + window.use_tensor_dimensions(tensor.info()->tensor_shape()); + + execute_window_loop(window, [&](const Coordinates & id) + { + const T value = distribution(gen); + *reinterpret_cast<T *>(tensor.ptr_to_element(id)) = value; + }); + } +} + +bool RandomAccessor::access_tensor(ITensor &tensor) +{ + switch(tensor.info()->data_type()) + { + case DataType::U8: + { + std::uniform_int_distribution<uint8_t> distribution_u8(_lower.get<uint8_t>(), _upper.get<uint8_t>()); + fill<uint8_t>(tensor, distribution_u8); + break; + } + case DataType::S8: + case DataType::QS8: + { + std::uniform_int_distribution<int8_t> distribution_s8(_lower.get<int8_t>(), _upper.get<int8_t>()); + fill<int8_t>(tensor, distribution_s8); + break; + } + case DataType::U16: + { + std::uniform_int_distribution<uint16_t> distribution_u16(_lower.get<uint16_t>(), _upper.get<uint16_t>()); + fill<uint16_t>(tensor, distribution_u16); + break; + } + case DataType::S16: + case DataType::QS16: + { + std::uniform_int_distribution<int16_t> distribution_s16(_lower.get<int16_t>(), _upper.get<int16_t>()); + fill<int16_t>(tensor, distribution_s16); + break; + } + case DataType::U32: + { + std::uniform_int_distribution<uint32_t> distribution_u32(_lower.get<uint32_t>(), _upper.get<uint32_t>()); + fill<uint32_t>(tensor, distribution_u32); + break; + } + case DataType::S32: + { + std::uniform_int_distribution<int32_t> distribution_s32(_lower.get<int32_t>(), _upper.get<int32_t>()); + fill<int32_t>(tensor, distribution_s32); + break; + } + case DataType::U64: + { + std::uniform_int_distribution<uint64_t> distribution_u64(_lower.get<uint64_t>(), _upper.get<uint64_t>()); + fill<uint64_t>(tensor, distribution_u64); + break; + } + case DataType::S64: + { + std::uniform_int_distribution<int64_t> distribution_s64(_lower.get<int64_t>(), _upper.get<int64_t>()); + fill<int64_t>(tensor, distribution_s64); + break; + } + case DataType::F16: + { + std::uniform_real_distribution<float> distribution_f16(_lower.get<float>(), _upper.get<float>()); + fill<float>(tensor, distribution_f16); + break; + } + case DataType::F32: + { + std::uniform_real_distribution<float> distribution_f32(_lower.get<float>(), _upper.get<float>()); + fill<float>(tensor, distribution_f32); + break; + } + case DataType::F64: + { + std::uniform_real_distribution<double> distribution_f64(_lower.get<double>(), _upper.get<double>()); + fill<double>(tensor, distribution_f64); + break; + } + default: + ARM_COMPUTE_ERROR("NOT SUPPORTED!"); + } + return true; +} + NumPyBinLoader::NumPyBinLoader(std::string filename) : _filename(std::move(filename)) { |