diff options
Diffstat (limited to 'utils')
-rw-r--r-- | utils/GraphUtils.cpp | 112 | ||||
-rw-r--r-- | utils/GraphUtils.h | 32 |
2 files changed, 142 insertions, 2 deletions
diff --git a/utils/GraphUtils.cpp b/utils/GraphUtils.cpp index d763606867..f0b0dded18 100644 --- a/utils/GraphUtils.cpp +++ b/utils/GraphUtils.cpp @@ -31,8 +31,10 @@ #endif /* ARM_COMPUTE_CL */ #include "arm_compute/core/Error.h" +#include "arm_compute/core/PixelValue.h" #include "libnpy/npy.hpp" +#include <random> #include <sstream> using namespace arm_compute::graph_utils; @@ -85,6 +87,116 @@ bool DummyAccessor::access_tensor(ITensor &tensor) return ret; } +RandomAccessor::RandomAccessor(PixelValue lower, PixelValue upper, std::random_device::result_type seed) + : _lower(lower), _upper(upper), _seed(seed) +{ +} + +template <typename T, typename D> +void RandomAccessor::fill(ITensor &tensor, D &&distribution) +{ + std::mt19937 gen(_seed); + + if(tensor.info()->padding().empty()) + { + for(size_t offset = 0; offset < tensor.info()->total_size(); offset += tensor.info()->element_size()) + { + const T value = distribution(gen); + *reinterpret_cast<T *>(tensor.buffer() + offset) = value; + } + } + else + { + // If tensor has padding accessing tensor elements through execution window. + Window window; + window.use_tensor_dimensions(tensor.info()->tensor_shape()); + + execute_window_loop(window, [&](const Coordinates & id) + { + const T value = distribution(gen); + *reinterpret_cast<T *>(tensor.ptr_to_element(id)) = value; + }); + } +} + +bool RandomAccessor::access_tensor(ITensor &tensor) +{ + switch(tensor.info()->data_type()) + { + case DataType::U8: + { + std::uniform_int_distribution<uint8_t> distribution_u8(_lower.get<uint8_t>(), _upper.get<uint8_t>()); + fill<uint8_t>(tensor, distribution_u8); + break; + } + case DataType::S8: + case DataType::QS8: + { + std::uniform_int_distribution<int8_t> distribution_s8(_lower.get<int8_t>(), _upper.get<int8_t>()); + fill<int8_t>(tensor, distribution_s8); + break; + } + case DataType::U16: + { + std::uniform_int_distribution<uint16_t> distribution_u16(_lower.get<uint16_t>(), _upper.get<uint16_t>()); + fill<uint16_t>(tensor, distribution_u16); + break; + } + case DataType::S16: + case DataType::QS16: + { + std::uniform_int_distribution<int16_t> distribution_s16(_lower.get<int16_t>(), _upper.get<int16_t>()); + fill<int16_t>(tensor, distribution_s16); + break; + } + case DataType::U32: + { + std::uniform_int_distribution<uint32_t> distribution_u32(_lower.get<uint32_t>(), _upper.get<uint32_t>()); + fill<uint32_t>(tensor, distribution_u32); + break; + } + case DataType::S32: + { + std::uniform_int_distribution<int32_t> distribution_s32(_lower.get<int32_t>(), _upper.get<int32_t>()); + fill<int32_t>(tensor, distribution_s32); + break; + } + case DataType::U64: + { + std::uniform_int_distribution<uint64_t> distribution_u64(_lower.get<uint64_t>(), _upper.get<uint64_t>()); + fill<uint64_t>(tensor, distribution_u64); + break; + } + case DataType::S64: + { + std::uniform_int_distribution<int64_t> distribution_s64(_lower.get<int64_t>(), _upper.get<int64_t>()); + fill<int64_t>(tensor, distribution_s64); + break; + } + case DataType::F16: + { + std::uniform_real_distribution<float> distribution_f16(_lower.get<float>(), _upper.get<float>()); + fill<float>(tensor, distribution_f16); + break; + } + case DataType::F32: + { + std::uniform_real_distribution<float> distribution_f32(_lower.get<float>(), _upper.get<float>()); + fill<float>(tensor, distribution_f32); + break; + } + case DataType::F64: + { + std::uniform_real_distribution<double> distribution_f64(_lower.get<double>(), _upper.get<double>()); + fill<double>(tensor, distribution_f64); + break; + } + default: + ARM_COMPUTE_ERROR("NOT SUPPORTED!"); + } + return true; +} + NumPyBinLoader::NumPyBinLoader(std::string filename) : _filename(std::move(filename)) { diff --git a/utils/GraphUtils.h b/utils/GraphUtils.h index a19f7e510d..c8cbb00237 100644 --- a/utils/GraphUtils.h +++ b/utils/GraphUtils.h @@ -24,9 +24,12 @@ #ifndef __ARM_COMPUTE_GRAPH_UTILS_H__ #define __ARM_COMPUTE_GRAPH_UTILS_H__ +#include "arm_compute/core/PixelValue.h" #include "arm_compute/graph/ITensorAccessor.h" #include "arm_compute/graph/Types.h" +#include <random> + namespace arm_compute { namespace graph_utils @@ -54,7 +57,7 @@ private: }; /** Dummy accessor class */ -class DummyAccessor : public graph::ITensorAccessor +class DummyAccessor final : public graph::ITensorAccessor { public: /** Constructor @@ -73,8 +76,33 @@ private: unsigned int _maximum; }; +/** Random accessor class */ +class RandomAccessor final : public graph::ITensorAccessor +{ +public: + /** Constructor + * + * @param[in] lower Lower bound value. + * @param[in] upper Upper bound value. + * @param[in] seed (Optional) Seed used to initialise the random number generator. + */ + RandomAccessor(PixelValue lower, PixelValue upper, const std::random_device::result_type seed = 0); + /** Allows instances to move constructed */ + RandomAccessor(RandomAccessor &&) = default; + + // Inherited methods overriden: + bool access_tensor(ITensor &tensor) override; + +private: + template <typename T, typename D> + void fill(ITensor &tensor, D &&distribution); + PixelValue _lower; + PixelValue _upper; + std::random_device::result_type _seed; +}; + /** Numpy Binary loader class*/ -class NumPyBinLoader : public graph::ITensorAccessor +class NumPyBinLoader final : public graph::ITensorAccessor { public: /** Default Constructor |