diff options
Diffstat (limited to 'utils/GraphUtils.h')
-rw-r--r-- | utils/GraphUtils.h | 32 |
1 files changed, 30 insertions, 2 deletions
diff --git a/utils/GraphUtils.h b/utils/GraphUtils.h index a19f7e510d..c8cbb00237 100644 --- a/utils/GraphUtils.h +++ b/utils/GraphUtils.h @@ -24,9 +24,12 @@ #ifndef __ARM_COMPUTE_GRAPH_UTILS_H__ #define __ARM_COMPUTE_GRAPH_UTILS_H__ +#include "arm_compute/core/PixelValue.h" #include "arm_compute/graph/ITensorAccessor.h" #include "arm_compute/graph/Types.h" +#include <random> + namespace arm_compute { namespace graph_utils @@ -54,7 +57,7 @@ private: }; /** Dummy accessor class */ -class DummyAccessor : public graph::ITensorAccessor +class DummyAccessor final : public graph::ITensorAccessor { public: /** Constructor @@ -73,8 +76,33 @@ private: unsigned int _maximum; }; +/** Random accessor class */ +class RandomAccessor final : public graph::ITensorAccessor +{ +public: + /** Constructor + * + * @param[in] lower Lower bound value. + * @param[in] upper Upper bound value. + * @param[in] seed (Optional) Seed used to initialise the random number generator. + */ + RandomAccessor(PixelValue lower, PixelValue upper, const std::random_device::result_type seed = 0); + /** Allows instances to move constructed */ + RandomAccessor(RandomAccessor &&) = default; + + // Inherited methods overriden: + bool access_tensor(ITensor &tensor) override; + +private: + template <typename T, typename D> + void fill(ITensor &tensor, D &&distribution); + PixelValue _lower; + PixelValue _upper; + std::random_device::result_type _seed; +}; + /** Numpy Binary loader class*/ -class NumPyBinLoader : public graph::ITensorAccessor +class NumPyBinLoader final : public graph::ITensorAccessor { public: /** Default Constructor |