aboutsummaryrefslogtreecommitdiff
path: root/utils/Utils.h
diff options
context:
space:
mode:
Diffstat (limited to 'utils/Utils.h')
-rw-r--r--utils/Utils.h22
1 files changed, 8 insertions, 14 deletions
diff --git a/utils/Utils.h b/utils/Utils.h
index b10d18aca2..7eeeae5419 100644
--- a/utils/Utils.h
+++ b/utils/Utils.h
@@ -296,20 +296,15 @@ inline void unmap(GCTensor &tensor)
*/
class uniform_real_distribution_fp16
{
- half min{ 0.0f }, max{ 0.0f };
- std::uniform_real_distribution<float> neg{ min, -0.3f };
- std::uniform_real_distribution<float> pos{ 0.3f, max };
- std::uniform_int_distribution<uint8_t> sign_picker{ 0, 1 };
-
public:
using result_type = half;
/** Constructor
*
- * @param[in] a Minimum value of the distribution
- * @param[in] b Maximum value of the distribution
+ * @param[in] min Minimum value of the distribution
+ * @param[in] max Maximum value of the distribution
*/
- explicit uniform_real_distribution_fp16(half a = half(0.0), half b = half(1.0))
- : min(a), max(b)
+ explicit uniform_real_distribution_fp16(half min = half(0.0), half max = half(1.0))
+ : dist(min, max)
{
}
@@ -319,12 +314,11 @@ public:
*/
half operator()(std::mt19937 &gen)
{
- if(sign_picker(gen))
- {
- return (half)neg(gen);
- }
- return (half)pos(gen);
+ return half(dist(gen));
}
+
+private:
+ std::uniform_real_distribution<float> dist;
};
/** Numpy data loader */