Design and Implementation of the Merkle Tree and Its Algorithms

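Hi, I'm Pluveto, and I'm learning to become a blockchain engineer. Today's topic is the Merkle tree. The Merkle 🌲 is a cleverly conceived, simply designed data structure. Given a proof, it can quickly verify whether a data block belongs to a larger data set. "Quickly" here means that both the proof size and the checking time are logarithmic in the number of blocks. It can even tell you where the block sits.

I'll implement it in Python so that more people can follow along. I'll also point out a few details that other articles tend to gloss over. Let's get started.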

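Just Piling Up Hashes...

Before writing any code, let's sort out the idea. A Merkle tree is also called a hash tree. Suppose we have a list l with 4 elements:

Alice Bob Carol David

We can easily compute each element's hash, written H(A), H(B), H(C), H(D). These are the leaf nodes.

Next we compute H(AB) = H(H(A)·H(B)) and H(CD) = H(H(C)·H(D)). Here · joins the two child hashes, normally by concatenation. This gives two hashes, which form the second-to-last level.

Finally we compute H(ABCD) = H(H(AB)·H(CD)), which is the root node.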

      H(ABCD)
      /    \
    H(AB)  H(CD)
   / \     /  \
H(A) H(B) H(C) H(D)
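The construction above can be sketched in a few lines of Python. This sketch makes two assumptions: H is SHA-256, and · is byte concatenation. H is just a local helper named after the notation, not part of the article's code.

```python
import hashlib


def H(data: bytes) -> bytes:
    """SHA-256 digest (the choice of hash function is an assumption here)."""
    return hashlib.sha256(data).digest()


# Leaves: the hashes of the four data blocks
h_a, h_b, h_c, h_d = (H(name.encode()) for name in ["Alice", "Bob", "Carol", "David"])

# Second-to-last level: hash the concatenation of each pair of children
h_ab = H(h_a + h_b)
h_cd = H(h_c + h_d)

# Root: hash the concatenation of the two second-level hashes
h_abcd = H(h_ab + h_cd)
print(h_abcd.hex())
```

Because concatenation is order-sensitive, H(H(A)·H(B)) and H(H(B)·H(A)) are different values. This keeps the leaves' order baked into the root.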

The tree looks simple, but there is enormous power behind it!

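Say No to Lies: Discovering the Merkle Tree

Now suppose you are a blockchain system and you know this entire tree. You also store the actual values of A, B, C, and D (Alice, Bob, ...) somewhere else, but reading data from there is very slow.

Now Lao Wang claims that he has put a data block D into the tree.

How do we know whether he is lying?

Option 1: Scan the entire database

As with substring matching, we take the data D that Lao Wang uploaded and search the whole database for it. You can think of the database as one very long string. Once we find D, we tell Lao Wang: you were not lying.

However, the system is so large that by the time we find the data, the person standing in front of you may be Lao Wang's grandson...

Option 2: Record the hashes of all the data, then search those

This time we don't search the raw data, because anyone could hand you huge payloads to search, which amounts to a DDoS attack. So you decide to record the hash of every transaction (a transaction is just a piece of data). By scanning the hashes of all the transactions in a block, you can then tell whether a given transaction exists. Lao Wang could fool you by luck if a hash he made up happened to collide with one in the list, but that probability is negligible.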
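Here is a minimal sketch of Option 2, assuming SHA-256. The transaction strings and the `contains` helper are hypothetical. Each query still compares the claimed hash against every stored hash, so the cost grows linearly with the number of transactions.

```python
import hashlib


def sha256(data: bytes) -> bytes:
    return hashlib.sha256(data).digest()


# Hypothetical transactions in one block
transactions = [b"Alice pays Bob 1", b"Bob pays Carol 2", b"Carol pays David 3"]

# Store only the hash of every transaction, not the raw data
tx_hashes = [sha256(tx) for tx in transactions]


def contains(claimed_hash: bytes) -> bool:
    """Linear scan over the stored hashes: O(n) comparisons per query."""
    return any(h == claimed_hash for h in tx_hashes)


print(contains(sha256(b"Bob pays Carol 2")))           # True
print(contains(sha256(b"Lao Wang pays himself 100")))  # False
```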

That is much faster, until you discover an even cleverer approach.

Option 3: Hash tree

Look back at the tree from the beginning:

      H(ABCD)
      /    \
    H(AB)  H(CD)
   / \     /  \
H(A) H(B) H(C) H(D)

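You notice that:

  • If Lao Wang gives you D and H(C), you can compute H(CD).

  • If he also gives you H(AB), you can compute H(ABCD).

  • If the computed H(ABCD) differs from the real H(ABCD), Lao Wang has lied to you.

Suppose any element of [H(AB), H(C)] is invalid, or D itself is not the real block. Barring a hash collision, the final H(ABCD) will then differ from the real one. We call [H(AB), H(C)] a proof.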
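Here is the check described above, written out for block D. This sketch assumes SHA-256 and concatenation for ·. D is the right child at both levels, so each sibling taken from the proof goes on the left when hashing.

```python
import hashlib


def H(data: bytes) -> bytes:
    return hashlib.sha256(data).digest()


# What the system already trusts: the root of the 4-leaf tree
leaves = [H(x) for x in [b"Alice", b"Bob", b"Carol", b"David"]]
h_ab = H(leaves[0] + leaves[1])
h_cd = H(leaves[2] + leaves[3])
trusted_root = H(h_ab + h_cd)

# What Lao Wang sends: the block D plus the proof [H(AB), H(C)] (top-down)
claimed_block = b"David"
proof = [h_ab, leaves[2]]

# Recompute bottom-up, starting from the claimed block
h = H(claimed_block)
h = H(proof[1] + h)  # H(CD) = H(H(C) · H(D))
h = H(proof[0] + h)  # H(ABCD) = H(H(AB) · H(CD))
print(h == trusted_root)  # True

# A forged block with the same proof yields a different root
forged = H(proof[0] + H(proof[1] + H(b"Mallory")))
print(forged == trusted_root)  # False
```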

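The proof also tells us where the element sits:

  1. Start from the root with an empty position sequence (0 means left, 1 means right).

  2. The first proof element is H(AB), the left child, so we take its sibling branch H(CD). Position sequence: 1.

  3. The next element is H(C), the left child of H(CD), so we take its sibling H(D). Position sequence: 1, 1.

Congratulations, you have just discovered the Merkle tree and its properties.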
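The walk above can be sketched as follows. The sketch assumes the verifier holds the whole tree and uses concatenation for ·. `TreeNode`, `parent` and `locate` are hypothetical helpers. At each level the proof element is compared with the two children. If it matches the left child we descend right (1); if it matches the right child we descend left (0). Like `get_location` later in the article, the path leaves out the root itself.

```python
import hashlib
from dataclasses import dataclass
from typing import List, Optional


def H(data: bytes) -> bytes:
    return hashlib.sha256(data).digest()


@dataclass
class TreeNode:
    value: bytes
    left: Optional["TreeNode"] = None
    right: Optional["TreeNode"] = None


def parent(left: TreeNode, right: TreeNode) -> TreeNode:
    # Parent hash = H(left · right), assuming concatenation
    return TreeNode(H(left.value + right.value), left, right)


def locate(root: TreeNode, proof: List[bytes]) -> Optional[List[int]]:
    """Follow a top-down proof and return the path (0 = left, 1 = right)."""
    node, path = root, []
    for sibling in proof:
        if node.left is None or node.right is None:
            return None  # the proof is longer than the tree is deep
        if sibling == node.left.value:
            node, bit = node.right, 1  # sibling is on the left, so we go right
        elif sibling == node.right.value:
            node, bit = node.left, 0  # sibling is on the right, so we go left
        else:
            return None  # the sibling hash does not appear at this level
        path.append(bit)
    return path


a, b, c, d = (TreeNode(H(x)) for x in [b"Alice", b"Bob", b"Carol", b"David"])
ab, cd = parent(a, b), parent(c, d)
root = parent(ab, cd)

print(locate(root, [ab.value, c.value]))  # [1, 1], i.e. the leaf H(D)
```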

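Implementing the Merkle Tree

Defining the tree node

Since this is a binary tree, a node only needs left and right children. We also define an is_copied field that records whether the node was produced by duplication. To guarantee a binary tree, whenever there is an odd number of nodes we duplicate the last one so that every node has a partner.

Note:

  • The content field exists only for learning and debugging and should be removed in production.

  • For performance-critical environments, it is better to implement this in another language and optimize the algorithm.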

import dataclasses
from typing import Optional


@dataclasses.dataclass
class Node:
    """
    Represents a binary tree node, in our case a Merkle Tree node.

    Attributes:
        left: left child node
        right: right child node
        value: hash value of the node
        content: content of the node
        is_copied: whether the node is a copy, this is because we duplicate
            the last element if the number of elements is odd when building
    """

    left: Optional["Node"]
    right: Optional["Node"]
    value: bytes
    content: bytes  # just for debugging, remove in production
    is_copied: bool = False

    def __str__(self) -> str:
        # hash digests are arbitrary bytes, not valid UTF-8, so show them as hex
        return self.value.hex()

    def copy(self) -> "Node":
        """
        Return a duplicate of the node, marked with is_copied=True.
        This is how we mark the last element when duplicating it.
        """
        return Node(self.left, self.right, self.value, self.content, True)

Tree construction algorithm

First, the basic shape of the class.

class MerkleTree:
    """
    Represents a Merkle Tree, for which every leaf node is labelled with the hash
    of a data block, and every non-leaf node is labelled with the crypto hash of
    the labels of its child nodes. It is used to verify the integrity of blocks.
    """

    def __init__(self, values: List[bytes], hash_fn: HashFn) -> None:
        self._hash_fn = hash_fn
        self._root: Node = self._buildTree(values)

    def __str__(self) -> str:
        return self._root.value.hex()

    @property
    def root(self) -> Node:
        """
        Return a copy of the root node of the Merkle Tree
        """
        return self._root.copy()

The _buildTree function builds the tree. When there is an odd number of nodes, the last one is duplicated to complete the pair.

    def _buildTree(self, values: List[bytes]) -> Node:
        if not values:
            # an empty list would otherwise recurse until RecursionError
            raise ValueError("cannot build a Merkle Tree from an empty list")
        leaves: List[Node] = [Node(None, None, self._hash_fn(e), e) for e in values]
        return self._buildTreeRec(leaves)

    def _buildTreeRec(self, nodes: List[Node]) -> Node:
        # duplicate last elem if odd number of elements
        if len(nodes) % 2 == 1:
            nodes.append(nodes[-1].copy())

        half: int = len(nodes) // 2
        # base case: two nodes become the children of a single parent
        if len(nodes) == 2:
            value = self._hash_fn(bytes_xor(nodes[0].value, nodes[1].value))
            return Node(nodes[0], nodes[1], value, nodes[0].content + nodes[1].content)

        # otherwise split the list in half and build each subtree recursively
        left: Node = self._buildTreeRec(nodes[:half])
        right: Node = self._buildTreeRec(nodes[half:])
        value: bytes = self._hash_fn(bytes_xor(left.value, right.value))
        return Node(left, right, value, left.content + right.content)
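One caveat about the code above: it combines two child hashes with bytes_xor, whereas the text's H(H(A)·H(B)) is normally a concatenation. XOR is commutative, and a node paired with its own copy XORs to all zero bytes. As a result, some changes to the data do not change the root. The sketch below re-implements the same split-in-half shape with a pluggable combine step to compare the two choices on the poets from the test data. SHA-256 is assumed, and `merkle_root`, `xor` and `concat` are illustrative names.

```python
import hashlib
from typing import Callable, List

Combine = Callable[[bytes, bytes], bytes]


def sha256(data: bytes) -> bytes:
    return hashlib.sha256(data).digest()


def xor(a: bytes, b: bytes) -> bytes:
    return bytes(x ^ y for x, y in zip(a, b))


def concat(a: bytes, b: bytes) -> bytes:
    return a + b


def merkle_root(values: List[bytes], combine: Combine) -> bytes:
    """Mirror the shape of _buildTreeRec: pad an odd list with a copy of its
    last hash, hash two entries directly, otherwise split in half and recurse."""
    if not values:
        raise ValueError("need at least one value")

    def rec(hashes: List[bytes]) -> bytes:
        if len(hashes) % 2 == 1:
            hashes = hashes + [hashes[-1]]
        if len(hashes) == 2:
            return sha256(combine(hashes[0], hashes[1]))
        half = len(hashes) // 2
        return sha256(combine(rec(hashes[:half]), rec(hashes[half:])))

    return rec([sha256(v) for v in values])


poets_a = [b"Li Bai", b"Du Fu", b"Wang Wei"]
poets_b = [b"Li Bai", b"Du Fu", b"Su Shi"]  # only the last value differs

# XOR: here the last leaf is paired with its own copy, h ^ h is all zeros, so it drops out
print(merkle_root(poets_a, xor) == merkle_root(poets_b, xor))  # True
# XOR is commutative: swapping two siblings leaves the root unchanged
print(merkle_root([b"Li Bai", b"Du Fu"], xor) == merkle_root([b"Du Fu", b"Li Bai"], xor))  # True
# Concatenation keeps both the content and the order
print(merkle_root(poets_a, concat) == merkle_root(poets_b, concat))  # False
```

The price of concatenation is that a verifier must know, at each level, whether the sibling is on the left or the right.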

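Finding a leaf's position

This works by traversal, so it takes O(n) time. **In practice, the position should be derived from the element's index in the original list** rather than by calling this function. For example, index 6 in binary is 110, so the position sequence is [1, 1, 0]. This shortcut holds when the number of leaves is a power of two, so that the tree is complete.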

    def get_location(self, block_hash: bytes) -> Optional[List[int]]:
        """
        Get the location of a block hash in a Merkle Tree
        """
        return self._get_location_rec(self._root, block_hash, [])

    def _get_location_rec(
        self, node: Optional[Node], block_hash: bytes, path: List[int]
    ) -> Optional[List[int]]:
        if node is None:
            return None

        if node.value == block_hash and not node.is_copied:
            return path

        left_path = self._get_location_rec(node.left, block_hash, path + [0])
        if left_path is not None:
            return left_path

        right_path = self._get_location_rec(node.right, block_hash, path + [1])
        if right_path is not None:
            return right_path

        return None
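As the note above suggests, for a tree whose leaf count is a power of two the location can be computed directly from the index in O(log n), with no traversal. Here is a sketch; `location_from_index` is a hypothetical helper. It reads the index's bits from most to least significant, with 0 for left and 1 for right, and, like get_location, it leaves out the root. It does not match _buildTreeRec's layout for other leaf counts, because that construction pads and splits the leaf list rather than each level.

```python
from typing import List


def location_from_index(index: int, depth: int) -> List[int]:
    """Top-down path of leaf `index` in a complete tree with 2**depth leaves."""
    if not 0 <= index < 2 ** depth:
        raise ValueError("index out of range for a tree of this depth")
    # Take bits from the most significant (root level) to the least significant (leaf level)
    return [(index >> (depth - 1 - level)) & 1 for level in range(depth)]


print(location_from_index(6, 3))  # [1, 1, 0]
print(location_from_index(3, 2))  # [1, 1], i.e. D in the 4-leaf example
```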

Generating the proof

This is the key part. The idea is to walk down from the root along the position path, recording the sibling hash at each step.

    def get_proof(self, block_hash: bytes) -> Optional[List[bytes]]:
        """
        Get the proof of a block hash in a Merkle Tree.
        The proof is a list of sibling hashes of the block hash.

        Note:
            This function returns hashes in the top-down order. So, don't forget
            to reverse the list when you want to verify the proof.
        """
        location = self.get_location(block_hash)
        if location is None:
            return None

        return self._get_proof_rec(self._root, location, 0, [])

    def _get_proof_rec(
        self, node: Optional[Node], location: List[int], index: int, proof: List[bytes]
    ) -> Optional[List[bytes]]:
        if node is None or index >= len(location):
            return proof

        if node.right and location[index] == 0:
            proof.append(node.right.value)
            return self._get_proof_rec(node.left, location, index + 1, proof)

        elif node.left and location[index] == 1:
            proof.append(node.left.value)
            return self._get_proof_rec(node.right, location, index + 1, proof)

        return None
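get_proof first calls get_location, so the whole operation is O(n). A common alternative keeps every level of hashes and produces the proof bottom-up from the leaf index in O(log n). The sketch below assumes that level-by-level layout and uses concatenation for ·. Odd levels are padded with a copy of their last hash. `build_levels` and `proof_for_index` are hypothetical names. For a power-of-two leaf count this layout has the same shape as the tree built by _buildTreeRec. Unlike get_proof, it returns the proof bottom-up and records which side each sibling is on.

```python
import hashlib
from typing import List, Tuple


def H(data: bytes) -> bytes:
    return hashlib.sha256(data).digest()


def build_levels(leaf_hashes: List[bytes]) -> List[List[bytes]]:
    """Return every level bottom-up (leaves first, root last).

    An odd-length level is paired by duplicating its last hash, and a parent
    is H(left + right), i.e. concatenation rather than XOR.
    """
    if not leaf_hashes:
        raise ValueError("need at least one leaf")
    levels = [list(leaf_hashes)]
    while len(levels[-1]) > 1:
        level = levels[-1]
        if len(level) % 2 == 1:
            level = level + [level[-1]]  # pad a copy; the stored level stays unpadded
        levels.append([H(level[i] + level[i + 1]) for i in range(0, len(level), 2)])
    return levels


def proof_for_index(levels: List[List[bytes]], index: int) -> List[Tuple[bytes, int]]:
    """Bottom-up proof as (sibling_hash, side) pairs.

    side 0 means the sibling is on the left, side 1 means it is on the right.
    """
    proof = []
    for level in levels[:-1]:  # the root level has no sibling
        sibling = index ^ 1  # flipping the lowest bit gives the sibling's index
        if sibling >= len(level):
            sibling = index  # odd level: the last node was paired with its own copy
        side = 0 if sibling < index else 1
        proof.append((level[sibling], side))
        index //= 2  # move up to the parent's index
    return proof


leaves = [H(x) for x in [b"Alice", b"Bob", b"Carol", b"David"]]
levels = build_levels(leaves)
proof = proof_for_index(levels, 3)  # proof for David, leaf index 3
```

For David the proof is [(H(C), left), (H(AB), left)]. This is the same two hashes as the article's [H(AB), H(C)], but listed bottom-up and with their sides.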

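Verifying a proof

Verification is simple. Walk the proof in reverse (bottom-up), hashing as you go, and finally compare the result with the root hash.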

def verify_proof(
    root_hash: bytes, block_hash: bytes, proof: List[bytes], hash_fn: HashFn
) -> bool:
    """verify if a block hash is in a Merkle Tree with a given root hash and proof"""
    current_hash = block_hash
    # the proof is top-down, so fold it from the leaf upwards
    for sibling_hash in reversed(proof):
        current_hash = hash_fn(bytes_xor(current_hash, sibling_hash))

    return current_hash == root_hash
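Because bytes_xor is commutative, verify_proof never needs to know whether a sibling sits on the left or the right. With concatenation, as in H(H(A)·H(B)), order matters, so the verifier needs a side flag for each proof entry. The sketch below assumes bottom-up (sibling_hash, side) pairs and SHA-256. `verify_proof_with_sides` is a hypothetical name, not part of the article's code.

```python
import hashlib
from typing import List, Tuple


def H(data: bytes) -> bytes:
    return hashlib.sha256(data).digest()


def verify_proof_with_sides(
    root_hash: bytes, leaf_hash: bytes, proof: List[Tuple[bytes, int]]
) -> bool:
    """proof: bottom-up (sibling_hash, side) pairs; side 0 = sibling on the left."""
    current = leaf_hash
    for sibling, side in proof:
        # Put the sibling on the side it occupies in the tree before hashing
        current = H(sibling + current) if side == 0 else H(current + sibling)
    return current == root_hash


h_a, h_b, h_c, h_d = (H(x) for x in [b"Alice", b"Bob", b"Carol", b"David"])
h_ab, h_cd = H(h_a + h_b), H(h_c + h_d)
root = H(h_ab + h_cd)

print(verify_proof_with_sides(root, h_d, [(h_c, 0), (h_ab, 0)]))  # True
# Same hashes but wrong sides: rejected, because concatenation is order-sensitive
print(verify_proof_with_sides(root, h_d, [(h_c, 1), (h_ab, 1)]))  # False
```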

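Testing the Merkle Tree

The complete implementation (merkle_tree.py, the module the tests import) and the test code are as follows: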

import dataclasses

from typing import Callable, Optional, List

HashFn = Callable[[bytes], bytes]


def bytes_xor(a: bytes, b: bytes) -> bytes:
    assert len(a) == len(b), "length of a and b should be equal"
    return bytes([_a ^ _b for _a, _b in zip(a, b)])


def verify_proof(
    root_hash: bytes, block_hash: bytes, proof: List[bytes], hash_fn: HashFn
) -> bool:
    """verify if a block hash is in a Merkle Tree with a given root hash and proof"""
    current_hash = block_hash
    # the proof is top-down, so fold it from the leaf upwards
    for sibling_hash in reversed(proof):
        current_hash = hash_fn(bytes_xor(current_hash, sibling_hash))

    return current_hash == root_hash


@dataclasses.dataclass
class Node:
    """
    Represents a binary tree node, in our case a Merkle Tree node.

    Attributes:
        left: left child node
        right: right child node
        value: hash value of the node
        content: content of the node
        is_copied: whether the node is a copy, this is because we duplicate
            the last element if the number of elements is odd when building
    """

    left: Optional["Node"]
    right: Optional["Node"]
    value: bytes
    content: bytes  # just for debugging, remove in production
    is_copied: bool = False

    def __str__(self) -> str:
        # hash digests are arbitrary bytes, not valid UTF-8, so show them as hex
        return self.value.hex()

    def copy(self) -> "Node":
        """
        Return a duplicate of the node, marked with is_copied=True.
        This is how we mark the last element when duplicating it.
        """
        return Node(self.left, self.right, self.value, self.content, True)


class MerkleTree:
    """
    Represents a Merkle Tree, for which every leaf node is labelled with the hash
    of a data block, and every non-leaf node is labelled with the crypto hash of
    the labels of its child nodes. It is used to verify the integrity of blocks.
    """

    def __init__(self, values: List[bytes], hash_fn: HashFn) -> None:
        self._hash_fn = hash_fn
        self._root: Node = self._buildTree(values)

    def __str__(self) -> str:
        return self._root.value.hex()

    @property
    def root(self) -> Node:
        """
        Return a copy of the root node of the Merkle Tree
        """
        return self._root.copy()

    def _buildTree(self, values: List[bytes]) -> Node:
        if not values:
            # an empty list would otherwise recurse until RecursionError
            raise ValueError("cannot build a Merkle Tree from an empty list")
        leaves: List[Node] = [Node(None, None, self._hash_fn(e), e) for e in values]
        return self._buildTreeRec(leaves)

    def _buildTreeRec(self, nodes: List[Node]) -> Node:
        # duplicate last elem if odd number of elements
        if len(nodes) % 2 == 1:
            nodes.append(nodes[-1].copy())

        half: int = len(nodes) // 2
        # base case: two nodes become the children of a single parent
        if len(nodes) == 2:
            value = self._hash_fn(bytes_xor(nodes[0].value, nodes[1].value))
            return Node(nodes[0], nodes[1], value, nodes[0].content + nodes[1].content)

        # otherwise split the list in half and build each subtree recursively
        left: Node = self._buildTreeRec(nodes[:half])
        right: Node = self._buildTreeRec(nodes[half:])
        value: bytes = self._hash_fn(bytes_xor(left.value, right.value))
        return Node(left, right, value, left.content + right.content)

    def compare_trees(self, other: "MerkleTree") -> bool:
        """
        Compare the root hashes of two Merkle Trees
        """
        return self._root.value == other._root.value

    def verify_block(self, root_hash: bytes, block_hash: bytes) -> bool:
        """
        Verify if a block hash is in a Merkle Tree with a given root hash
        """
        return self._root.value == root_hash and self._verify_block_rec(
            self._root, block_hash
        )

    def _verify_block_rec(self, node: Optional[Node], block_hash: bytes) -> bool:
        if node is None:
            return False

        return (
            node.value == block_hash
            or self._verify_block_rec(node.left, block_hash)
            or self._verify_block_rec(node.right, block_hash)
        )

    def get_location(self, block_hash: bytes) -> Optional[List[int]]:
        """
        Get the location of a block hash in a Merkle Tree
        """
        return self._get_location_rec(self._root, block_hash, [])

    def _get_location_rec(
        self, node: Optional[Node], block_hash: bytes, path: List[int]
    ) -> Optional[List[int]]:
        if node is None:
            return None

        if node.value == block_hash and not node.is_copied:
            return path

        left_path = self._get_location_rec(node.left, block_hash, path + [0])
        if left_path is not None:
            return left_path

        right_path = self._get_location_rec(node.right, block_hash, path + [1])
        if right_path is not None:
            return right_path

        return None

    def get_proof(self, block_hash: bytes) -> Optional[List[bytes]]:
        """
        Get the proof of a block hash in a Merkle Tree.
        The proof is a list of sibling hashes of the block hash.

        Note:
            This function returns hashes in the top-down order. So, don't forget
            to reverse the list when you want to verify the proof.
        """
        location = self.get_location(block_hash)
        if location is None:
            return None

        return self._get_proof_rec(self._root, location, 0, [])

    def _get_proof_rec(
        self, node: Optional[Node], location: List[int], index: int, proof: List[bytes]
    ) -> Optional[List[bytes]]:
        if node is None or index >= len(location):
            return proof

        if node.right and location[index] == 0:
            proof.append(node.right.value)
            return self._get_proof_rec(node.left, location, index + 1, proof)

        elif node.left and location[index] == 1:
            proof.append(node.left.value)
            return self._get_proof_rec(node.right, location, index + 1, proof)

        return None

    def print_tree(self, brief: bool = True) -> None:
        """
        Print the Merkle Tree in a tree structure.
        """
        self._print_tree_rec(self._root, 0, brief)

    def _print_tree_rec(self, node: Optional[Node], level: int, brief: bool) -> None:
        """helper function for print_tree"""
        if node is None:
            return

        value = (node.value[:4] if brief else node.value).hex()
        content = node.content
        print(f'{"    " * level}{value=}, {content=}')
        self._print_tree_rec(node.left, level + 1, brief)
        self._print_tree_rec(node.right, level + 1, brief)

Test code:

import hashlib
import unittest
from merkle_tree import MerkleTree, verify_proof


def sha256(val: bytes) -> bytes:
    return hashlib.sha256(val).digest()


class TestMerkleTree(unittest.TestCase):
    def setUp(self):
        """set up a MerkleTree with some testing data"""
        self._data = list(
            map(
                lambda x: x.encode("utf-8"),
                [
                    # https://en.wikipedia.org/wiki/Classical_Chinese_poetry
                    "Li Bai",
                    "Du Fu",
                    "Wang Wei",
                    "Bai Juyi",
                    "Su Shi",
                    "Li Shangyin",
                    "Li Qingzhao",
                    "Wang Anshi",
                ],
            )
        )
        self._hash_fn = sha256
        self._tree = MerkleTree(self._data, self._hash_fn)
        self._tree.print_tree()

    def test_verify_block(self):
        """test the verify_block method"""
        root_hash = self._tree.root.value
        for i in range(len(self._data)):
            block_hash = self._hash_fn(self._data[i])
            # verify if the block hash is in the tree
            self.assertTrue(self._tree.verify_block(root_hash, block_hash))

        # verify an invalid block hash
        self.assertFalse(
            self._tree.verify_block(root_hash, self._hash_fn("invalid".encode("utf-8")))
        )

    def test_get_location(self):
        """test the get_location method"""
        for i in range(len(self._data)):
            block_hash = self._hash_fn(self._data[i])
            location = self._tree.get_location(block_hash)
            # compare the location with the expected value location(block_hash)
            assert location is not None, "location should not be None"
            # location should conform to the binary counting sequence 000...111
            self.assertEqual(location, [int(x) for x in format(i, "b").zfill(3)])

        self.assertIsNone(
            self._tree.get_location(self._hash_fn("invalid".encode("utf-8")))
        )

    def test_get_proof(self):
        """test the get_proof method"""
        for i in range(len(self._data)):
            block_hash = self._hash_fn(self._data[i])
            # get the proof of the block hash in the tree
            proof = self._tree.get_proof(block_hash)
            assert proof is not None, "proof should not be None"
            self.assertEqual(len(proof), 3)
            self.assertTrue(
                verify_proof(self._tree.root.value, block_hash, proof, self._hash_fn),
                f"proof {proof} is invalid for block {block_hash}",
            )

        self.assertIsNone(
            self._tree.get_proof(self._hash_fn("invalid".encode("utf-8")))
        )

    def test_verify_proof(self):
        """test the verify_proof method"""
        root_hash = self._tree.root.value
        for i in range(len(self._data)):
            block_hash = self._hash_fn(self._data[i])
            # get the proof of the block hash in the tree
            proof = self._tree.get_proof(block_hash)
            assert proof is not None, "proof should not be None"
            self.assertTrue(verify_proof(root_hash, block_hash, proof, self._hash_fn))

        # verify an invalid proof
        self.assertFalse(
            verify_proof(
                root_hash, self._hash_fn("invalid".encode("utf-8")), [], self._hash_fn
            )
        )


if __name__ == "__main__":
    unittest.main()

Conclusion

Thanks for reading. The source code for this article is available at pluveto/merkle-tree.