From 53b405f1e08ad41cb9a527abfe0308ec1edf18ff Mon Sep 17 00:00:00 2001 From: Michalis Spyrou Date: Wed, 27 Sep 2017 15:55:31 +0100 Subject: COMPMID-417 - Add RandomAccessor support in Graph API. Change-Id: I54dd435258a2d0ff486ded64b23654bab6b80f3f Reviewed-on: http://mpd-gerrit.cambridge.arm.com/89373 Tested-by: Kaizen Reviewed-by: Georgios Pinitas --- utils/GraphUtils.cpp | 112 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 112 insertions(+) (limited to 'utils/GraphUtils.cpp') diff --git a/utils/GraphUtils.cpp b/utils/GraphUtils.cpp index d763606867..f0b0dded18 100644 --- a/utils/GraphUtils.cpp +++ b/utils/GraphUtils.cpp @@ -31,8 +31,10 @@ #endif /* ARM_COMPUTE_CL */ #include "arm_compute/core/Error.h" +#include "arm_compute/core/PixelValue.h" #include "libnpy/npy.hpp" +#include #include using namespace arm_compute::graph_utils; @@ -85,6 +87,116 @@ bool DummyAccessor::access_tensor(ITensor &tensor) return ret; } +RandomAccessor::RandomAccessor(PixelValue lower, PixelValue upper, std::random_device::result_type seed) + : _lower(lower), _upper(upper), _seed(seed) +{ +} + +template +void RandomAccessor::fill(ITensor &tensor, D &&distribution) +{ + std::mt19937 gen(_seed); + + if(tensor.info()->padding().empty()) + { + for(size_t offset = 0; offset < tensor.info()->total_size(); offset += tensor.info()->element_size()) + { + const T value = distribution(gen); + *reinterpret_cast(tensor.buffer() + offset) = value; + } + } + else + { + // If tensor has padding accessing tensor elements through execution window. + Window window; + window.use_tensor_dimensions(tensor.info()->tensor_shape()); + + execute_window_loop(window, [&](const Coordinates & id) + { + const T value = distribution(gen); + *reinterpret_cast(tensor.ptr_to_element(id)) = value; + }); + } +} + +bool RandomAccessor::access_tensor(ITensor &tensor) +{ + switch(tensor.info()->data_type()) + { + case DataType::U8: + { + std::uniform_int_distribution distribution_u8(_lower.get(), _upper.get()); + fill(tensor, distribution_u8); + break; + } + case DataType::S8: + case DataType::QS8: + { + std::uniform_int_distribution distribution_s8(_lower.get(), _upper.get()); + fill(tensor, distribution_s8); + break; + } + case DataType::U16: + { + std::uniform_int_distribution distribution_u16(_lower.get(), _upper.get()); + fill(tensor, distribution_u16); + break; + } + case DataType::S16: + case DataType::QS16: + { + std::uniform_int_distribution distribution_s16(_lower.get(), _upper.get()); + fill(tensor, distribution_s16); + break; + } + case DataType::U32: + { + std::uniform_int_distribution distribution_u32(_lower.get(), _upper.get()); + fill(tensor, distribution_u32); + break; + } + case DataType::S32: + { + std::uniform_int_distribution distribution_s32(_lower.get(), _upper.get()); + fill(tensor, distribution_s32); + break; + } + case DataType::U64: + { + std::uniform_int_distribution distribution_u64(_lower.get(), _upper.get()); + fill(tensor, distribution_u64); + break; + } + case DataType::S64: + { + std::uniform_int_distribution distribution_s64(_lower.get(), _upper.get()); + fill(tensor, distribution_s64); + break; + } + case DataType::F16: + { + std::uniform_real_distribution distribution_f16(_lower.get(), _upper.get()); + fill(tensor, distribution_f16); + break; + } + case DataType::F32: + { + std::uniform_real_distribution distribution_f32(_lower.get(), _upper.get()); + fill(tensor, distribution_f32); + break; + } + case DataType::F64: + { + std::uniform_real_distribution distribution_f64(_lower.get(), _upper.get()); + fill(tensor, distribution_f64); + break; + } + default: + ARM_COMPUTE_ERROR("NOT SUPPORTED!"); + } + return true; +} + NumPyBinLoader::NumPyBinLoader(std::string filename) : _filename(std::move(filename)) { -- cgit v1.2.1