SHA256 哈希算法原理和 Rust 实现

介绍

作为日常生活中每个网络用户都会使用的算法,SHA256 的原理可能却鲜有人知。

SHA256 的输入是任意长度的字节串,输出是一个 256 bit 的字节串。例如,

  • 输入:字符串 abc(相当于字节数组 [0x6c, 0x65, 0x73, 0x73, 0x2d, 0x62, 0x75, 0x67, 0x2e, 0x63, 0x6f, 0x6d]

  • 输出:

    1[0x20, 0x1a, 0xb7, 0x03, 0x64, 0x78, 0xee, 0x8a, 0xba, 0x2e, 0xb3, 0xc7, 0xc0, 0xc6, 0x5b, 0x8b, 0x17, 0x49, 0x58, 0x52, 0xf4, 0xeb, 0xe1, 0x5b, 0x79, 0xe5, 0x34, 0x39, 0x6f, 0x29, 0x0c, 0x5b]
    

示例

我们以 less-bug.com 为例,来演示 SHA256 的计算过程。

我们有 8 个特殊的常量,成为初始哈希变量(initial hash variables),它们是 32 位无符号整数,用十六进制表示如下:

 1static H: [u32; 8] = [
 2    0x6a09e667,
 3    0xbb67ae85,
 4    0x3c6ef372,
 5    0xa54ff53a,
 6    0x510e527f,
 7    0x9b05688c,
 8    0x1f83d9ab,
 9    0x5be0cd19,
10];

它们之所以是这些值,其实是发明算法的人自己选的。他从自然数中前 8 个质数的平方根的小数部分中,取前 32 位作为初始值。

举个例子,0x6a09e667 通过如下过程计算:

1>>> hex(int(2**32 *(2**0.5 - int(2**0.5))))
2'0x6a09e667'

同理,所有 8 个初始值可以计算出来:

1primes = [2, 3, 5, 7, 11, 13, 17, 19]
2for p in primes:
3    print(hex(int(2**32 *(p**0.5 - int(p**0.5)))))

定义初始回合常量(initial round constant variables):

 1const K: [u32; 64] = [
 2    0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
 3    0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
 4    0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
 5    0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
 6    0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
 7    0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
 8    0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
 9    0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2,
10];

它们是从前 64 个质数的立方根的小数部分中,取前 32 位作为初始值。下面是计算过程:

1primes = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311]
2for p in primes:
3    print(hex(int(2**32 *(p**0.3333333333333333 - int(p**0.3333333333333333)))))

现在处理我们的输入 less-bug.com,它对应的 ASCII 编码是

1[0x6c, 0x65, 0x73, 0x73, 0x2d, 0x62, 0x75, 0x67, 0x2e, 0x63, 0x6f, 0x6d]

转换成二进制是

101101100, 01100101, 01110011, 01110011
200101101, 01100010, 01110101, 01100111
300101110, 01100011, 01101111, 01101101

,输入数据大小为 16 个字节.

然后需要对数据分组。要求每 512 位(64 字节)为一组。因此 16 / 64 = 0,可知一共需要 1 组。并且最后一组的填充长度需要进一步计算。

分为两种情况讨论。

  1. 如果数据长度小于 448 位(56 字节),则在数据后面添加一个 1,然后添加足够的 0,使得数据长度对 512 取模后,等于 448(即 512-64,减去的 64 位留作他用)。

  2. 如果数据长度大于等于 448 位(56 字节),则在数据后面添加一个 1,然后添加足够的 0,使得数据长度对 512 取模后,等于 0。然后再添加一个新的分组,这个分组的数据全部为 0,只有最后 64 位为数据长度。

计算公式如下:

1let padding_len = if input_len % 64 < 56 {
2    56 - input_len % 64
3} else {
4    120 - input_len % 64
5};

这里的 56 是 448 位(56 字节),120 是 56 + 64 = 120 字节,表示添加了新的分组,新组的数据填充长度。

我们的输入对应于第一种情况,可知填充长度为 44. 44 中第一个字节用于填充 1000 0000,这是 SHA-256 的填充规则。后面 43 个字节全部填充 0000 0000

101101100, 01100101, 01110011, 01110011
200101101, 01100010, 01110101, 01100111
300101110, 01100011, 01101111, 01101101
410000000, 00000000, ...,      00000000

最后的 64 位用于存放位长度。由于输入为 16 字节,位长度为 64 字节,需要写入的到最后 64 位的数据是:

10x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x60

这些数据是逐个写入的,每一个的计算规则如下:

1byte = ((input_bit_len >> (7 * 8 - i * 8)) & 0xff)

这里 (7 * 8 - i * 8) 的意思是,从高位开始,每 8 位取一次,取 8 次,即取 64 位。& 0xff 的意思是取低 8 位,避免把高位取进来。

我们输入较短,只会产生一个分组(block),但为了完整说明算法,下面介绍所有 block 的计算过程。

首先,这个过程的输入是待处理的字节数据,输出是最终的哈希值。内部主要是一个循环,每次循环处理一个 block,每个 block 的长度为 512 位(64 字节)。结果哈希则在循环存放,从而能够在每次迭代中更新这个结果(其实就是加到上面)。

预留数组(schedule array)用于存放每一轮的数据,在每一轮循环中生成,长度为 64 位,名称为 w。初始值分为两部分:

  1. 前 16 字节,就是输入数据在对应 block 的字节数据

    1let mut w = [0; 64];
    2for j in 0..16 {
    3    let offset = i * 64 + j * 4;
    4    w[j] = ((input_bytes[offset + 0] as u32) << 24)
    5        | ((input_bytes[offset + 1] as u32) << 16)
    6        | ((input_bytes[offset + 2] as u32) << 8)
    7        | ((input_bytes[offset + 3] as u32) << 0);
    8}
    
  2. 后 48 字节,需要通过如下公式计算:

    1w[i] = w[i-16] + sigma0(w[i-15]) + w[i-7] + sigma1(w[i-2])
    

    其中的 sigma0 和 sigma1 是:

    1sigma0(i) = (w[i - 15] rotr 7) xor (w[i - 15] rotr 18) xor (w[i - 15] shr 3)
    2sigma1(i) = (w[i - 2] rotr 17) xor (w[i - 2] rotr 19) xor (w[i - 2] shr 10)
    

    其中 rotr 表示循环右移,shr 表示右移。

    例如:0b1001 rotr 2 表示让 0b1001 循环右移 2 位,即 0b1100。右侧移除的位在左侧补出。 0b1001 shr 2 表示让 0b1001 右移 2 位,高位补零,即 0b0010

    具体代码如下:

    1for j in 16..64 {
    2let s0 = w[j - 15].rotate_right(7) ^ w[j - 15].rotate_right(18) ^ (w[j - 15] >> 3);
    3let s1 = w[j - 2].rotate_right(17) ^ w[j - 2].rotate_right(19) ^ (w[j - 2] >> 10);
    4w[j] = w[j - 16]
    5    .wrapping_add(s0)
    6    .wrapping_add(w[j - 7])
    7    .wrapping_add(s1);
    8}
    

然后我们定义 8 个工作变量 working[8],它们的初始值就是当前哈希值。然后通过如下公式迭代:

 1s1 = (e rotr 6) xor (e rotr 11) xor (e rotr 25)
 2choose = (e and f) xor ((not e) and g)
 3temp1 = h + s1 + choose + k[i] + w[i]
 4s0 = (a rotr 2) xor (a rotr 13) xor (a rotr 22)
 5major = (a and b) xor (a and c) xor (b and c)
 6temp2 = s0 + major
 7
 8working[7] = working[6]
 9working[6] = working[5]
10working[5] = working[4]
11working[4] = working[3] + temp1
12working[3] = working[2]
13working[2] = working[1]
14working[1] = working[0]
15working[0] = temp1 + temp2

迭代次数为 64 次,和工作变量的长度相同,从而将每个哈希值参与运算。

迭代完成后,最后将工作变量和哈希值相加,得到本 block 最终的哈希值。然后继续进入下一个 block 的运算直到所有 block 都处理完毕。

最后得到的就是 SHA256 哈希值。

完整的代码如下:

  1fn compress(input_bytes: &Vec<u8>) -> [u32; 8] {
  2    static H: [u32; 8] = [
  3        0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab,
  4        0x5be0cd19,
  5    ];
  6    static K: [u32; 64] = [
  7        0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4,
  8        0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe,
  9        0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f,
 10        0x4a7484aa, 0x5cb0a9dc, 0x76f988da, 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
 11        0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc,
 12        0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b,
 13        0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, 0x19a4c116,
 14        0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
 15        0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7,
 16        0xc67178f2,
 17    ];
 18
 19    let mut hash = H;
 20    let nblocks = input_bytes.len() / 64;
 21    for i in 0..nblocks {
 22        let mut w = [0; 64];
 23        for j in 0..16 {
 24            let offset = i * 64 + j * 4;
 25            w[j] = ((input_bytes[offset + 0] as u32) << 24)
 26                | ((input_bytes[offset + 1] as u32) << 16)
 27                | ((input_bytes[offset + 2] as u32) << 8)
 28                | ((input_bytes[offset + 3] as u32) << 0);
 29        }
 30        for j in 16..64 {
 31            let s0 = w[j - 15].rotate_right(7) ^ w[j - 15].rotate_right(18) ^ (w[j - 15] >> 3);
 32            let s1 = w[j - 2].rotate_right(17) ^ w[j - 2].rotate_right(19) ^ (w[j - 2] >> 10);
 33            w[j] = w[j - 16]
 34                .wrapping_add(s0)
 35                .wrapping_add(w[j - 7])
 36                .wrapping_add(s1);
 37        }
 38        let mut working: [u32; 8] = hash; // working variables
 39        for j in 0..64 {
 40            let s1 = working[4].rotate_right(6)
 41                ^ working[4].rotate_right(11)
 42                ^ working[4].rotate_right(25);
 43            let choose = (working[4] & working[5]) ^ ((!working[4]) & working[6]);
 44            let temp1 = working[7]
 45                .wrapping_add(s1)
 46                .wrapping_add(choose)
 47                .wrapping_add(K[j])
 48                .wrapping_add(w[j]);
 49            let s0 = working[0].rotate_right(2)
 50                ^ working[0].rotate_right(13)
 51                ^ working[0].rotate_right(22);
 52            let major =
 53                (working[0] & working[1]) ^ (working[0] & working[2]) ^ (working[1] & working[2]);
 54            let temp2 = s0.wrapping_add(major);
 55            working[7] = working[6];
 56            working[6] = working[5];
 57            working[5] = working[4];
 58            working[4] = working[3].wrapping_add(temp1);
 59            working[3] = working[2];
 60            working[2] = working[1];
 61            working[1] = working[0];
 62            working[0] = temp1.wrapping_add(temp2);
 63        }
 64
 65        for j in 0..8 {
 66            hash[j] = hash[j].wrapping_add(working[j]);
 67        }
 68    }
 69    return hash;
 70}
 71
 72pub fn sha256(input: &[u8]) -> [u8; 32] {
 73    // padding
 74    let mut input_bytes = input.to_vec();
 75    let input_len = input_bytes.len();
 76    let padding_len = if input_len % 64 < 56 {
 77        56 - input_len % 64
 78    } else {
 79        120 - input_len % 64
 80    };
 81
 82    input_bytes.push(0x80); // 1000 0000
 83    for _ in 0..padding_len - 1 {
 84        input_bytes.push(0x00);
 85    }
 86    assert!(input_bytes.len() % 64 == 56);
 87    let input_bit_len = input_len * 8;
 88    for i in 0..8 {
 89        let byte = ((input_bit_len >> (56 - i * 8)) & 0xff) as u8;
 90
 91        input_bytes.push(byte);
 92    }
 93
 94    assert!(input_bytes.len() % 64 == 0);
 95
 96    let hash = compress(&input_bytes);
 97
 98    let mut ret = [0 as u8; 32];
 99    for i in 0..8 {
100        ret[i * 4 + 0] = ((hash[i] >> 24) & 0xff) as u8;
101        ret[i * 4 + 1] = ((hash[i] >> 16) & 0xff) as u8;
102        ret[i * 4 + 2] = ((hash[i] >> 8) & 0xff) as u8;
103        ret[i * 4 + 3] = ((hash[i] >> 0) & 0xff) as u8;
104    }
105
106    ret
107}
108
109#[cfg(test)]
110mod tests {
111    use super::*;
112    fn u8array_to_string(arr: &[u8]) -> String {
113        let mut ret = String::new();
114        for i in arr {
115            ret.push_str(&format!("{:02x}", i));
116        }
117        ret
118    }
119    #[test]
120    fn short_string_test() {
121        let tests = [
122            (
123                "",
124                "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
125            ),
126            (
127                "less-bug.com",
128                "201ab7036478ee8aba2eb3c7c0c65b8b17495852f4ebe15b79e534396f290c5b",
129            ),
130            (
131                "abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq",
132                "248d6a61d20638b8e5c026930c3e6039a33ce45964ff2167f6ecedd419db06c1",
133            ),
134        ];
135        for (input, expected) in tests.iter() {
136            let input_bytes = input.as_bytes();
137            let output = sha256(input_bytes);
138
139            let output_string = u8array_to_string(&output);
140            assert_eq!(output_string, *expected);
141        }
142    }
143
144    #[test]
145    fn long_string_test() {
146        let input = {
147            let mut ret = String::new();
148            for _ in 0..(512 * 16 + 500) {
149                ret.push_str("a");
150            }
151            ret
152        };
153        let expected =
154            "31ef976b92b5879f6068892a737803b40dac69e6a9c5563e05dd6197b2b39a27".to_string();
155        let input_bytes = input.as_bytes();
156        let output = sha256(input_bytes);
157        let output_string = u8array_to_string(&output);
158        assert_eq!(output_string, expected);
159    }
160}