博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
PyTorch学习笔记(8)——PyTorch之随机数生成
阅读量:2032 次
发布时间:2019-04-28

本文共 9451 字,大约阅读时间需要 31 分钟。

0.前言(基于Torch0.4.1)

相信在使用PyTorch的时候,大家都用过torch.randperm等随机数生成的接口,今天就来分析一下在PyTorch中使用的随机数生成及其背后蕴含的算法原理。


1. 定位源码

首先,需要定位随机数生成的代码,经过查找,随机数生成的代码位于pytorch/aten/src/TH/下面的THRandom.hTHRandom.cpp


2. THRandom.h分析说明

#ifndef TH_RANDOM_INC#define TH_RANDOM_INC#include // _MERSENNE_STATE_N 是递归长度; _MERSENNE_STATE_M 是周期参数,用作对旋转链执行旋转算法用到的偏移量。#define _MERSENNE_STATE_N 624#define _MERSENNE_STATE_M 397/* Struct definition is moved to THGenerator.hpp, because THRandom.hneeds to be C-compatible in order to be included in C FFI extensions. *//* 需要注意,这里的结构体定义是为了兼容C FFI扩展。(这里是) */typedef struct THGenerator THGenerator;typedef struct THGeneratorState THGeneratorState;#define torch_Generator "torch.Generator".../* Checks if given generator state is valid */TH_API int THGeneratorState_isValid(THGeneratorState *_gen_state);/* Initializes the random number generator from /dev/urandom (or on Windowsplatforms with the current time (granularity: seconds)) and returns the seed. */TH_API uint64_t THRandom_seed(THGenerator *_generator);/* Initializes the random number generator with the given int64_t "the_seed_". */TH_API void THRandom_manualSeed(THGenerator *_generator, uint64_t the_seed_);/* Returns the starting seed used. */TH_API uint64_t THRandom_initialSeed(THGenerator *_generator);/* 生成32 bits的整型 */TH_API uint64_t THRandom_random(THGenerator *_generator);/* Generates a uniform 64 bits integer. */TH_API uint64_t THRandom_random64(THGenerator *_generator);/* Generates a uniform random double on [0,1). */TH_API double THRandom_standard_uniform(THGenerator *_generator);...#endif

2.1 宏定义

在开头的预编译宏中定义了

#define _MERSENNE_STATE_N 624
#define _MERSENNE_STATE_M 397

那么,它们的作用是什么呢?这里先卖个关子,这块会在THRandom.cpp中进行详细说明。


下面我们可以看到,有一些检查生成器(Generator)状态的函数,如THGeneratorState_isValid等,因为我们的随机数生成函数接收的参数是生成器(Generator)

2.2 随机数种子函数

接下来,会看到THRandom_seedTHRandom_manualSeed两个随机数种子函数,根据文档上的说明(以Linux系统为例):THRandom_seed是利用/dev/urandom来对生成器(Generator)进行初始化。

补充知识:/dev/urandom记录Linux下的熵池,所谓熵池就是当前系统下的环境噪音,描述了一个系统的混乱程度,环境噪音由这几个方面组成,如内存的使用,文件的使用量,不同类型的进程数量等等,刚开机的时候系统噪音会较小,越到后面噪音会越大。

关于如何使用/dev/urandom生成随机数,请看这篇文章——。

THRandom_manualSeed就很简单了——就是根据你传入的随机数种子对生成器(Generator)进行初始化(显然,如果随机数种子一样,如果传入THRandom_manualSeed的生成器一样,那么初始化的结果也是一样的)。

2.3 THRandom_random随机数生成函数

这是本文关注的重点,此函数的签名为:TH_API uint64_t THRandom_random(THGenerator *_generator);,其作用是生成32 bits的整型


3. THRandom.cpp分析说明

#include #include "THRandom.h"#include "THGenerator.hpp".../* Code for the Mersenne Twister random generator.... */#define n _MERSENNE_STATE_N#define m _MERSENNE_STATE_M/* Creates (unseeded) new generator*/static THGenerator* THGenerator_newUnseeded(){  THGenerator *self = (THGenerator *)THAlloc(sizeof(THGenerator));  ...  return self;}/* Creates new generator and makes sure it is seeded*/THGenerator* THGenerator_new(){  ...}#ifndef _WIN32static uint64_t readURandomLong(){  ...}#endif // _WIN32// 随机数生成uint64_t THRandom_seed(THGenerator *_generator){#ifdef _WIN32  uint64_t s = (uint64_t)time(0);#else  uint64_t s = readURandomLong();#endif  THRandom_manualSeed(_generator, s);  return s;}.../*  下面是采用了日本人松本 眞和西村 拓士开发的基于梅森(Mersenne)素数的伪随机数生成器, (pseudorandom number generator)"A C-program for MT19937", 用到了一共4个函数以及一些宏定义,还有之前在THRandom.h定义的_MERSENNE_STATE_N 等宏。*//* 梅森旋转宏定义... *//* 周期参数 *//* #define n 624 *//* #define m 397 */#define MATRIX_A 0x9908b0dfUL   /* constant vector a */.../*********************************************************** That's it. */void THRandom_manualSeed(THGenerator *_generator, uint64_t the_seed_){  ...}uint64_t THRandom_initialSeed(THGenerator *_generator){  ...}void THRandom_nextState(THGenerator *_generator){ ...}uint64_t THRandom_random(THGenerator *_generator){  ...}}

上面是整体的结构,现在,让我们来一点点的分析这块的内容:

3.1 随机数种子生成&初始化生成器——readURandomLong()THRandom_seed(THGenerator *_generator)

在第2章的的2.2节,里面包含着PyTorch中随机数种子生成的方法,这里进行详细介绍:

#ifndef _WIN32static uint64_t readURandomLong(){  int randDev = open("/dev/urandom", O_RDONLY);  uint64_t randValue;  if (randDev < 0) {    THError("Unable to open /dev/urandom");  }  ssize_t readBytes = read(randDev, &randValue, sizeof(randValue));  if (readBytes < (ssize_t) sizeof(randValue)) {    THError("Unable to read from /dev/urandom");  }  close(randDev);  return randValue;}#endif // _WIN32uint64_t THRandom_seed(THGenerator *_generator){#ifdef _WIN32  uint64_t s = (uint64_t)time(0);#else  uint64_t s = readURandomLong();#endif  THRandom_manualSeed(_generator, s);  return s;}

以Linux类操作系统平台为例,readURandomLong的作用是根据/dev/urandom的信息生成一个随机数种子——64位非负int

THRandom_seed(THGenerator *_generator)根据readURandomLong生成的随机数种子,对生成器(generator)进行初始化,返回值为随机数种子。

梅森旋转算法三步走:1. 生成器初始化 2. 对旋转链执行旋转算法 3. 对旋转算法所得结果进行处理

MT19937-32的参数列表如下(在伪代码中使用):

·(w, n, m, r) = (32, 624, 397, 31)·a = 9908B0DF(16)·f = 1812433253·(u, d) = (11, FFFFFFFF16) # 经过我的实验,FFFFFFFF16跟0xFFFFFFFF是一样的·(s, b) = (7, 9D2C568016) # 2636928640·(t, c) = (15, EFC6000016)  # 4022730752·l = 18

3.2 生成器初始化——THRandom_manualSeed(THGenerator *_generator, uint64_t the_seed_)

THRandom_manualSeed函数是梅森旋转4个函数中的第一个,调用此函数,就表明对初THGenerator 开始初始化,也就是梅森旋转算法的第一步。

算法的思路是:

首先将传入的seed(随机数种子)赋给MT[0]作为初值,然后根据递推式:

MT[i]=f×(MT[i1](MT[i1]>>(w2)))+i M T [ i ] = f × ( M T [ i − 1 ] ⊕ ( M T [ i − 1 ] >> ( w − 2 ) ) ) + i

递推求出梅森旋转链。伪代码如下:

// 由一个seed初始化随机数产生器 function seed_mt(int seed) {     index := n     MT[0] := seed     for i from 1 to (n - 1) {         MT[i] := lowest w bits of (f * (MT[i-1] xor (MT[i-1] >> (w-2))) + i)     } }

我们先来看看n是什么?哈哈,其实这个n就是在THRandom.h定义的宏_MERSENNE_STATE_N,同样的m就是宏_MERSENNE_STATE_M,在MT19937-32的梅森旋转算法:

n=624;m=397 n = 624 ; m = 397
PyTorch中的实际代码:

void THRandom_manualSeed(THGenerator *_generator, uint64_t the_seed_){  int j;  /* This ensures reseeding resets all of the state (i.e. state for Gaussian numbers) */  ....  _generator->gen_state.the_initial_seed = the_seed_;  _generator->gen_state.state[0] = _generator->gen_state.the_initial_seed & 0xffffffffUL;  for(j = 1; j < n; j++)  {    _generator->gen_state.state[j] = (1812433253UL * (_generator->gen_state.state[j-1] ^ (_generator->gen_state.state[j-1] >> 30)) + j);    /* See Knuth TAOCP Vol2. 3rd Ed. P.106 for multiplier. */    /* In the previous versions, mSBs of the seed affect   */    /* only mSBs of the array state[].                        */    /* 2002/01/09 modified by makoto matsumoto             */    _generator->gen_state.state[j] &= 0xffffffffUL;  /* 对大于32bit的机器 */  }  ...}

3.3 对旋转链执行旋转算法——宏TWIST(u,v)和THRandom_nextState(THGenerator *_generator)

遍历旋转链,对每个MT[i],根据递推式:

MT[i]=MT[i+m]((upper_mask(MT[i])||lower_mask(MT[i+1]))A M T [ i ] = M T [ i + m ] ⊕ ( ( u p p e r _ m a s k ( M T [ i ] ) | | l o w e r _ m a s k ( M T [ i + 1 ] ) ) A )
进行旋转链处理。

其中,“||”代表连接的意思,即组合MT[i]的高 w-r 位和MT[i+1]的低 r 位,设组合后的数字为x,则xA的运算规则为(x0是最低位):

xA={
x>>1,(x>>1)a,if x0= 0if x0 = 1
x A = { x >> 1 , if  x 0 = 0 ( x >> 1 ) ⊕ a , if  x 0  = 1
其中,
x0 x 0
是最低位,
a a <script type="math/tex" id="MathJax-Element-22">a</script> = 0x9908B0DF,也就是源代码里面的宏
MATRIX_A

伪代码为:

lower_mask = (1 << r) - 1 // r = 31时,lower_mask = 2147483647upper_mask = !lower_mask // 旋转算法处理旋转链  function twist() {
for i from 0 to (n-1) { // & 按为与 // 两者都为1为1,否则为0. // 1&1=1, 1&0=0, 0&1=0, 0&0=0 // & 举例: 5&3 = 1 解释: 101 011 相同位仅为个位1 ,故结果为 1 int x := (MT[i] & upper_mask)+ (MT[(i+1) mod n] & lower_mask) int xA := x >> 1 if (x mod 2) != 0 { // 最低位是1 xA := xA xor a } MT[i] := MT[(i + m) mod n] xor xA } index := 0}

PyTorch中的实际代码:

#define MATRIX_A 0x9908b0dfUL   /* constant vector a */#define UMASK 0x80000000UL /* UpperMask w-r 比特 */#define LMASK 0x7fffffffUL /* LowerMask r 比特*/#define MIXBITS(u,v) ( ((u) & UMASK) | ((v) & LMASK) )#define TWIST(u,v) ((MIXBITS(u,v) >> 1) ^ ((v)&1UL ? MATRIX_A : 0UL))...void THRandom_nextState(THGenerator *_generator){  uint64_t *p = _generator->gen_state.state;  int j;  _generator->gen_state.left = n;  _generator->gen_state.next = 0;  for(j = n-m+1; --j; p++)    *p = p[m] ^ TWIST(p[0], p[1]);  for(j = m; --j; p++)    *p = p[m-n] ^ TWIST(p[0], p[1]);  *p = p[m-n] ^ TWIST(p[0], _generator->gen_state.state[0]);}

3.4 对旋转算法所得结果进行处理——THRandom_random(THGenerator *_generator)

设x是当前序列的下一个值,y是一个临时中间变量,z是算法的返回值。则处理过程如下:

y := x ⊕ ((x >> u) & d)
y := y ⊕ ((y << s) & b)
y := y ⊕ ((y << t) & c)
z := y ⊕ (y >> l)

补充知识:掩码

位级运算的一个常见用法就是实现掩码运算,这里掩码是一个位模式,表示从一个字节中选出的位的集合。

看一个例子:掩码0xff(最低的8位为1)表示一个字的低位字节。位级运算 x & 0xff 生成一个由 x 的最低有效字节组成的值,而其他的字节就被置为0。 比如,对于 x = 0x89ABCDEF,其表达式将得到 0x000000EF。

伪代码如下:

// 从MT[index]中提取出一个经过处理的值// 每输出n个数字要执行一次旋转算法,以保证随机性 function extract_number() {     if index >= n {         if index > n {           error "发生器尚未初始化"         }         twist()     }     int x := MT[index]     y := x xor ((x >> u) and d) // u d包括下面的 s, b, t, c                                 // 等参数可以看上面的MT19937-32的参数列表。     y := y xor ((y << s) and b)     y := y xor ((y << t) and c)     z := y xor (y >> l)     index := index + 1     return lowest w bits of (z) }

PyTorch中的实际代码:

uint64_t THRandom_random(THGenerator *_generator){  uint64_t y;  if (--(_generator->gen_state.left) == 0)    THRandom_nextState(_generator);  y = *(_generator->gen_state.state + (_generator->gen_state.next)++);  /* Tempering */  y ^= (y >> 11);  y ^= (y << 7) & 0x9d2c5680UL;  y ^= (y << 15) & 0xefc60000UL;  y ^= (y >> 18);  return y;}

4. 总结

到这里,我们就知道了pytorch/aten/src/ATen/native/TensorFactories.cpp这个Tensor的工厂函数封装中里面的一些诸如randperm等用到随机数生成器的函数背后的秘密了——梅森旋转算法MT19937

但是,关于梅森算法的介绍并没有展开,这里只是把算法堆上去了,有兴趣的同学可以这篇文章,此作者对梅森算法进行了详细的解释说明。

转载地址:http://ltxaf.baihongyu.com/

你可能感兴趣的文章
python学习手册笔记——16.函数基础
查看>>
python学习手册笔记——20.迭代和解析
查看>>
python学习手册笔记——30.类的设计
查看>>
Spring Boot 使用 Log4j2 & Logback 输出日志到 EKL
查看>>
使用 febootstrap 制作自定义基础镜像
查看>>
Big Analytice with Cassandra
查看>>
spring多个AOP执行先后顺序(面试问题:怎么控制多个aop的执行循序)
查看>>
leetcode 之 Single Number II
查看>>
关于AOP无法切入同类调用方法的问题
查看>>
两线程交替打印数字
查看>>
Post with HttpClient4
查看>>
打靶问题 一个射击运动员打靶,靶一共有10环,连开10枪打中90环的可能行有多少种?...
查看>>
zk的watcher机制的实现
查看>>
缓存兼容性
查看>>
Hessian序列化
查看>>
Thread的中断机制(interrupt)
查看>>
[LeetCode] 268. Missing Number ☆(丢失的数字)
查看>>
http1.0 1.1 2.0区别
查看>>
spring bean生命周期
查看>>
学习成长之路
查看>>