本文共 9451 字,大约阅读时间需要 31 分钟。
相信在使用PyTorch的时候,大家都用过
torch.randperm
等随机数生成的接口,今天就来分析一下在PyTorch中使用的随机数生成及其背后蕴含的算法原理。
首先,需要定位随机数生成的代码,经过查找,随机数生成的代码位于pytorch/aten/src/TH/
下面的THRandom.h
和THRandom.cpp
。
THRandom.h
分析说明#ifndef TH_RANDOM_INC#define TH_RANDOM_INC#include // _MERSENNE_STATE_N 是递归长度; _MERSENNE_STATE_M 是周期参数,用作对旋转链执行旋转算法用到的偏移量。#define _MERSENNE_STATE_N 624#define _MERSENNE_STATE_M 397/* Struct definition is moved to THGenerator.hpp, because THRandom.hneeds to be C-compatible in order to be included in C FFI extensions. *//* 需要注意,这里的结构体定义是为了兼容C FFI扩展。(这里是) */typedef struct THGenerator THGenerator;typedef struct THGeneratorState THGeneratorState;#define torch_Generator "torch.Generator".../* Checks if given generator state is valid */TH_API int THGeneratorState_isValid(THGeneratorState *_gen_state);/* Initializes the random number generator from /dev/urandom (or on Windowsplatforms with the current time (granularity: seconds)) and returns the seed. */TH_API uint64_t THRandom_seed(THGenerator *_generator);/* Initializes the random number generator with the given int64_t "the_seed_". */TH_API void THRandom_manualSeed(THGenerator *_generator, uint64_t the_seed_);/* Returns the starting seed used. */TH_API uint64_t THRandom_initialSeed(THGenerator *_generator);/* 生成32 bits的整型 */TH_API uint64_t THRandom_random(THGenerator *_generator);/* Generates a uniform 64 bits integer. */TH_API uint64_t THRandom_random64(THGenerator *_generator);/* Generates a uniform random double on [0,1). */TH_API double THRandom_standard_uniform(THGenerator *_generator);...#endif
在开头的预编译宏中定义了
①#define _MERSENNE_STATE_N 624
② #define _MERSENNE_STATE_M 397
那么,它们的作用是什么呢?这里先卖个关子,这块会在THRandom.cpp
中进行详细说明。
下面我们可以看到,有一些检查生成器(Generator)状态的函数,如THGeneratorState_isValid
等,因为我们的随机数生成函数接收的参数是生成器(Generator)
接下来,会看到THRandom_seed
和THRandom_manualSeed
两个随机数种子函数,根据文档上的说明(以Linux系统为例):THRandom_seed
是利用/dev/urandom
来对生成器(Generator)进行初始化。
补充知识:
/dev/urandom
记录Linux下的熵池,所谓熵池就是当前系统下的环境噪音,描述了一个系统的混乱程度,环境噪音由这几个方面组成,如内存的使用,文件的使用量,不同类型的进程数量等等,刚开机的时候系统噪音会较小,越到后面噪音会越大。关于如何使用
/dev/urandom
生成随机数,请看这篇文章——。
而THRandom_manualSeed
就很简单了——就是根据你传入的随机数种子对生成器(Generator)进行初始化(显然,如果随机数种子一样,如果传入THRandom_manualSeed
的生成器一样,那么初始化的结果也是一样的)。
THRandom_random
随机数生成函数这是本文关注的重点,此函数的签名为:TH_API uint64_t THRandom_random(THGenerator *_generator);
,其作用是生成32 bits的整型。
THRandom.cpp
分析说明#include #include "THRandom.h"#include "THGenerator.hpp".../* Code for the Mersenne Twister random generator.... */#define n _MERSENNE_STATE_N#define m _MERSENNE_STATE_M/* Creates (unseeded) new generator*/static THGenerator* THGenerator_newUnseeded(){ THGenerator *self = (THGenerator *)THAlloc(sizeof(THGenerator)); ... return self;}/* Creates new generator and makes sure it is seeded*/THGenerator* THGenerator_new(){ ...}#ifndef _WIN32static uint64_t readURandomLong(){ ...}#endif // _WIN32// 随机数生成uint64_t THRandom_seed(THGenerator *_generator){#ifdef _WIN32 uint64_t s = (uint64_t)time(0);#else uint64_t s = readURandomLong();#endif THRandom_manualSeed(_generator, s); return s;}.../* 下面是采用了日本人松本 眞和西村 拓士开发的基于梅森(Mersenne)素数的伪随机数生成器, (pseudorandom number generator)"A C-program for MT19937", 用到了一共4个函数以及一些宏定义,还有之前在THRandom.h定义的_MERSENNE_STATE_N 等宏。*//* 梅森旋转宏定义... *//* 周期参数 *//* #define n 624 *//* #define m 397 */#define MATRIX_A 0x9908b0dfUL /* constant vector a */.../*********************************************************** That's it. */void THRandom_manualSeed(THGenerator *_generator, uint64_t the_seed_){ ...}uint64_t THRandom_initialSeed(THGenerator *_generator){ ...}void THRandom_nextState(THGenerator *_generator){ ...}uint64_t THRandom_random(THGenerator *_generator){ ...}}
上面是整体的结构,现在,让我们来一点点的分析这块的内容:
readURandomLong()
和THRandom_seed(THGenerator *_generator)
在第2章的的2.2节,里面包含着PyTorch中随机数种子生成的方法,这里进行详细介绍:
#ifndef _WIN32static uint64_t readURandomLong(){ int randDev = open("/dev/urandom", O_RDONLY); uint64_t randValue; if (randDev < 0) { THError("Unable to open /dev/urandom"); } ssize_t readBytes = read(randDev, &randValue, sizeof(randValue)); if (readBytes < (ssize_t) sizeof(randValue)) { THError("Unable to read from /dev/urandom"); } close(randDev); return randValue;}#endif // _WIN32uint64_t THRandom_seed(THGenerator *_generator){#ifdef _WIN32 uint64_t s = (uint64_t)time(0);#else uint64_t s = readURandomLong();#endif THRandom_manualSeed(_generator, s); return s;}
以Linux类操作系统平台为例,readURandomLong
的作用是根据/dev/urandom
的信息生成一个随机数种子——64位非负int。
THRandom_seed(THGenerator *_generator)
根据readURandomLong
生成的随机数种子,对生成器(generator)进行初始化,返回值为随机数种子。
梅森旋转算法三步走:1. 生成器初始化 2. 对旋转链执行旋转算法 3. 对旋转算法所得结果进行处理
MT19937-32的参数列表如下(在伪代码中使用):
·(w, n, m, r) = (32, 624, 397, 31)·a = 9908B0DF(16)·f = 1812433253·(u, d) = (11, FFFFFFFF16) # 经过我的实验,FFFFFFFF16跟0xFFFFFFFF是一样的·(s, b) = (7, 9D2C568016) # 2636928640·(t, c) = (15, EFC6000016) # 4022730752·l = 18
THRandom_manualSeed(THGenerator *_generator, uint64_t the_seed_)
THRandom_manualSeed
函数是梅森旋转4个函数中的第一个,调用此函数,就表明对初THGenerator 开始初始化,也就是梅森旋转算法的第一步。
算法的思路是:
首先将传入的seed(随机数种子)赋给MT[0]作为初值,然后根据递推式:
递推求出梅森旋转链。伪代码如下:
// 由一个seed初始化随机数产生器 function seed_mt(int seed) { index := n MT[0] := seed for i from 1 to (n - 1) { MT[i] := lowest w bits of (f * (MT[i-1] xor (MT[i-1] >> (w-2))) + i) } }
我们先来看看n
是什么?哈哈,其实这个n
就是在THRandom.h
定义的宏_MERSENNE_STATE_N
,同样的m
就是宏_MERSENNE_STATE_M
,在MT19937-32的梅森旋转算法:
void THRandom_manualSeed(THGenerator *_generator, uint64_t the_seed_){ int j; /* This ensures reseeding resets all of the state (i.e. state for Gaussian numbers) */ .... _generator->gen_state.the_initial_seed = the_seed_; _generator->gen_state.state[0] = _generator->gen_state.the_initial_seed & 0xffffffffUL; for(j = 1; j < n; j++) { _generator->gen_state.state[j] = (1812433253UL * (_generator->gen_state.state[j-1] ^ (_generator->gen_state.state[j-1] >> 30)) + j); /* See Knuth TAOCP Vol2. 3rd Ed. P.106 for multiplier. */ /* In the previous versions, mSBs of the seed affect */ /* only mSBs of the array state[]. */ /* 2002/01/09 modified by makoto matsumoto */ _generator->gen_state.state[j] &= 0xffffffffUL; /* 对大于32bit的机器 */ } ...}
宏TWIST(u,v)和THRandom_nextState(THGenerator *_generator)
遍历旋转链,对每个MT[i],根据递推式:
其中,“||”代表连接的意思,即组合MT[i]的高 w-r 位和MT[i+1]的低 r 位,设组合后的数字为x,则xA的运算规则为(x0是最低位):
伪代码为:
lower_mask = (1 << r) - 1 // r = 31时,lower_mask = 2147483647upper_mask = !lower_mask // 旋转算法处理旋转链 function twist() { for i from 0 to (n-1) { // & 按为与 // 两者都为1为1,否则为0. // 1&1=1, 1&0=0, 0&1=0, 0&0=0 // & 举例: 5&3 = 1 解释: 101 011 相同位仅为个位1 ,故结果为 1 int x := (MT[i] & upper_mask)+ (MT[(i+1) mod n] & lower_mask) int xA := x >> 1 if (x mod 2) != 0 { // 最低位是1 xA := xA xor a } MT[i] := MT[(i + m) mod n] xor xA } index := 0}
PyTorch中的实际代码:
#define MATRIX_A 0x9908b0dfUL /* constant vector a */#define UMASK 0x80000000UL /* UpperMask w-r 比特 */#define LMASK 0x7fffffffUL /* LowerMask r 比特*/#define MIXBITS(u,v) ( ((u) & UMASK) | ((v) & LMASK) )#define TWIST(u,v) ((MIXBITS(u,v) >> 1) ^ ((v)&1UL ? MATRIX_A : 0UL))...void THRandom_nextState(THGenerator *_generator){ uint64_t *p = _generator->gen_state.state; int j; _generator->gen_state.left = n; _generator->gen_state.next = 0; for(j = n-m+1; --j; p++) *p = p[m] ^ TWIST(p[0], p[1]); for(j = m; --j; p++) *p = p[m-n] ^ TWIST(p[0], p[1]); *p = p[m-n] ^ TWIST(p[0], _generator->gen_state.state[0]);}
THRandom_random(THGenerator *_generator)
设x是当前序列的下一个值,y是一个临时中间变量,z是算法的返回值。则处理过程如下:
y := x ⊕ ((x >> u) & d) y := y ⊕ ((y << s) & b) y := y ⊕ ((y << t) & c) z := y ⊕ (y >> l)补充知识:掩码
位级运算的一个常见用法就是实现掩码运算,这里掩码是一个位模式,表示从一个字节中选出的位的集合。
看一个例子:掩码0xff(最低的8位为1)表示一个字的低位字节。位级运算 x & 0xff 生成一个由 x 的最低有效字节组成的值,而其他的字节就被置为0。 比如,对于 x = 0x89ABCDEF,其表达式将得到 0x000000EF。
伪代码如下:
// 从MT[index]中提取出一个经过处理的值// 每输出n个数字要执行一次旋转算法,以保证随机性 function extract_number() { if index >= n { if index > n { error "发生器尚未初始化" } twist() } int x := MT[index] y := x xor ((x >> u) and d) // u d包括下面的 s, b, t, c // 等参数可以看上面的MT19937-32的参数列表。 y := y xor ((y << s) and b) y := y xor ((y << t) and c) z := y xor (y >> l) index := index + 1 return lowest w bits of (z) }
PyTorch中的实际代码:
uint64_t THRandom_random(THGenerator *_generator){ uint64_t y; if (--(_generator->gen_state.left) == 0) THRandom_nextState(_generator); y = *(_generator->gen_state.state + (_generator->gen_state.next)++); /* Tempering */ y ^= (y >> 11); y ^= (y << 7) & 0x9d2c5680UL; y ^= (y << 15) & 0xefc60000UL; y ^= (y >> 18); return y;}
到这里,我们就知道了pytorch/aten/src/ATen/native/TensorFactories.cpp
这个Tensor的工厂函数封装中里面的一些诸如randperm
等用到随机数生成器的函数背后的秘密了——梅森旋转算法MT19937
但是,关于梅森算法的介绍并没有展开,这里只是把算法堆上去了,有兴趣的同学可以这篇文章,此作者对梅森算法进行了详细的解释说明。
转载地址:http://ltxaf.baihongyu.com/