基于 BPE 的汉语 Tokenization 重制版

源码已发布到 Github:

pluveto/bpe_v3: 基于 BPE 实现的中文分词。优化:预处理,并行计算,多字词,多词表 (github.com)

目标

采用 BPE 算法对汉语进行子词切割,算法采用 Python (3.0 以上版本) 编码实现,自行编制代码完成算法,不直接用 subword-nmt 等已有模块。

BPE 算法介绍

BPE 的概念源自一种无损压缩算法中提出。

BPE 压缩算法

举例 1 :对于 aaabdaaabac 的 BPE 压缩过程如下:

  • 寻找出现频率最高的相邻两字符(BP,Byte Pair)
aaabdaaabac
aa|||||||||
 aa||||||||
  ab|||||||
   bd||||||
    da|||||
     aa||||
      aa|||
       ab||
        ba|
         ac

统计如下:

aa - 4
ab - 2
ac - 1
bd - 1
da - 1

因此替换频率最高的 aa[aa]

[aa]abd[aa]abac

再次统计如下:

[aa]a   - 2
ab      - 2
bd      - 1
d[aa]   - 1
ba      - 1
ac      - 1

因此替换 [aa] a[aaa]

[aaa]bd[aaa]bac

重复。替换 [aaa] b[aaab]

[aaab]d[aaab]ac

此时不必再替换。数据被压缩成:

XdXac

其中 X 映射为 aaab

此时,只要一个压缩后的数据,加上一个字典,就能表示原数据。

BPE 分词思路

无论在英文还是中文里,词汇的特点就是出现频率高。

假设我们对这句话压缩,大概率会得到一个这样的结果:

压缩后文本

"W1 在 W2W3W4 里,W5 的 W6W7W8W9 高。"

哈希表:

W1 = [无论]
W2 = [英文]
W3 = [还是]
W4 = [中文]
W5 = [词汇]
W6 = [特点]
W7 = [就是]
W8 = [出现]
W9 = [频率]

也就是说,词汇的可以用是高频的相邻字符来代表。(当然,不一定是两个字符,比如汉语成语常常是连续的四字高频的字符)

这给了我们一种分词的思路。

BPE 算法设计

上面基本已经成型了,但我们还需要对有关数据结构进行建模。

假设有数据集如下:

所谓调度就是决定某时刻,应该运行哪个进程。
调度分为实时调度和非实时调度。
实时调度分为硬实时和软实时。

这个数据集有三行,假设指派给三个执行器。记作 E1E2E3

E1 收到的数据是:所谓调度就是决定某时刻,应该运行哪个进程。 按照字符切分,得到:

所 谓 调 度 就 是 决 定 某 时 刻 , 应 该 运 行 哪 个 进 程 。

我们注意到空格标点符号容易造成干扰,因为最终词典不会有标点。不妨统一替换为 “#”。

所 谓 调 度 就 是 决 定 某 时 刻 # 应 该 运 行 哪 个 进 程 #

然后,就是进行高频组合的相邻连接。

统计高频的相邻字符

比如,我们发现 word [i]+word [i+1] = 所 谓 | 调 度 | 决 定 | 某 | 时 刻 | 运 行 | 进 程 的频次非常高,就可以合并为:

所谓 调度 就 是 决定 某 时刻 # 应 该 运行 哪 个 进程 #

多字词词汇生成

一般来说,我们需要进行多轮处理,进一步合并。因为很多次并不止两个字,例如 “某 时刻” 或许可以合并为 “某时刻”,“国家 主席 毛泽东” 可以合并为 “国家主席毛泽东”。

实现方法就是多轮处理。伪代码如下:

round_num = 1 # 当前处理的是第几轮
max_round_num = 4  # 最大轮数

while round_num <= max_round_num:
    logger.info("Round {} start...".format(round_num))
    # 初始化
    shared_freq_stat_map = {}
    # 进行词典的建立
    ParalledTask.create("--freq stat round={}".format(round_num))\
        .set_nworker(self.nworker)\
        .set_worker_args({'datasets': ds})\
        .set_worker_func(_train_worker)\
        .set_progress_goal(len(self._train_lines_np))\
        .execute()
    # 从高频到低频排序
    sorted_freq_stat_ls = sorted(
        shared_freq_stat_map.obj.items(), key=lambda x: x[1], reverse=True)
    # 选出一些词汇纳入词汇表
    sorted_freq_stat_ls = dict_filter(sorted_freq_stat_ls)
    # 利用词汇表合并字节对
    ParalledTask.create("--connect round={}".format(round_num))\
        .set_nworker(self.nworker)\
        .set_worker_args({'datasets': ds, 'thold': thold})\
        .set_worker_func(_connect_worker)\
        .set_progress_goal(len(self._train_lines_np))\
        .execute()
    # 进入下一轮
    round_num += 1

其中,涉及到两个工作函数,一个负责 单行数据处理,生成词表,一个负责 连接词表命中词

单行数据的处理思路

for line in line_strs:
    # 对每个 Byte Pair 进行处理
    for i in range(len(line) - 1):
        # 如果是 `#`,则跳过
        if line[i] == '#' or line[i+1] == '#':
            continue
        # 获取当前词和下一个词,如 ' 天安 ', ' 门'
        cur_word: str = line[i]
        next_word: str = line[i + 1]
        # 当前词和下一个词拼接,如 ' 天安门'
        cur_word_next_word: str = cur_word + next_word
        # 当前词和下一个词拼接的词频
        freq = inner_freq_stat_map[cur_word_next_word]
        # 当前词和下一个词拼接的词频加 1
        inner_freq_stat_map[cur_word_next_word] = freq + 1
    with bar.mutex:
        bar.obj()
# 将当前行的词频统计表加入总表
with shared_freq_stat_map.mutex:
    for key, value in inner_freq_stat_map.items():
        shared_freq_stat_map.obj[key] += value

词汇连接的实现原理

bpe_cn.py 145:

bar: WithMutex = ctx.get('bar')
worker_id = ctx.get('worker_id')
line_strs: List[List[str]] = ctx.get('datasets')[worker_id]
thold = ctx.get('thold')
for line in line_strs:
    # 对每个 Byte Pair 进行处
    i = 0
    while(i < len(line) - 1):
        # 如果是 `#`,则跳过
        if line[i] == '#' or line[i+1] == '#':
            i += 1
            continue
        # 获取当前词和下一个词,如 ' 天安 ', ' 门'
        cur_word: str = line[i]
        next_word: str = line[i + 1]
        # 当前词和下一个词拼接,如 ' 天安门'
        cur_word_next_word: str = cur_word + next_word
        # 比较是否大于阈值
        if shared_freq_stat_map.obj[cur_word_next_word] > thold:
            # 如果大于阈值,则连接
            line[i] = cur_word_next_word
            # 删除下一个词
            line.pop(i + 1)
            i += 1
        i += 1
    with bar.mutex:
        bar.obj()

最大匹配分词算法设计

这部分相对简单。查询词典中最长的词,与当前串比较即可。

算法预览

    def tokenize(self, spaced_sentence: list) -> list:
        """对句子分词"""
        si = 0
        result = []
        while si < len(spaced_sentence):
            matched = False
            prevlen = 0
            to_match = ''
            for w, c in self.user_dict_items:
                if prevlen != len(w):
                    to_match = ''.join(spaced_sentence[si:si + len(w)])
                    prevlen = len(w)
                if w == to_match:
                    result.append(w)
                    si += len(w)
                    matched = True
                    break
            if not matched:
                result.append(spaced_sentence[si])
                si += 1

        return result

效果预览

测试数据:

他是一位声誉很高的学者,凭借丰富的知识储备,在这部重要的作品中,就弥源太是否皈依进行过讨论。

结果:

[' 他是 ', ' 一位 ', ' 声誉 ', ' 很高的 ', ' 学者 ', ',', ' 凭借 ', ' 丰富的 ', ' 知识 ', ' 储备 ', ',', ' 在这 ', ' 部 ', ' 重要的 ', ' 作品 ', ' 中 ', ',', ' 就 ', ' 弥 ', ' 源 ', ' 太 ', ' 是
否 ', ' 皈 ', ' 依 ', ' 进行过 ', ' 讨论 ', '。']

进一步优化

标点符号预处理

中英文标点一般来说不是词汇的组成部分。(除了 - 等特殊情况),因此可以将其替换为 # ,而在训练时,对 # 当作硬切分,即不成词。

        puncs_zh = ['。', ',', '?', '!', ';', ':', '、', '(', ')', '「',
                    '」', '“', '”', '‘', '’', '《', '》', '【', '】', '…', '—', '~', ' ']
        puncs_en = ['.', ',', '?', '!', ';', ':', 
                    '(', ')', '"', '"', '\'', '\'', '<', '>', '[', ']', '...','~']
        puncs = [*puncs_zh, *puncs_en]
        # 替换标点符号为 `#`

        def _replace_worker(ctx: dict):
            task = ctx.get('task')
            bar = ctx.get('bar')
            worker_id = ctx.get('worker_id')
            line_strs: List[List[str]] = ctx.get('datasets')[worker_id]
            for line in line_strs:
                for i in range(len(line)):
                    if line[i] in puncs:
                        line[i] = '#'
                with bar.mutex:
                    bar.obj()

数据集去重

我们发现训练集有很多是重复的,会造成过拟合,即把一些生僻的字节对当成高频词处理。解决方法非常简单:

self._train_lines_np = np.unique(self._train_lines_np)

并行计算优化

详见另一篇文章 “Python 实现简单的多线程 MapReduce 计算框架”

image-20220514181951422

二字词与三字词的取舍

我们试验发现,有时候虽然产生了三字词,但实际上其子二字词的词频远远更高。举个例子,“写代码” 与 “代码”,后者的频率更高,因此将后者纳入词表是更恰当的。

因此对于前面的词,就可以舍掉:

# 比较是否大于阈值
if shared_freq_stat_map.obj[cur_word_next_word] > thold:
    if shared_freq_stat_map.obj[cur_word] / shared_freq_stat_map.obj[cur_word_next_word]\
            > ratio:
        i += 1
        continue
    # 如果大于阈值,则连接
    line[i] = cur_word_next_word
    # 删除下一个词
    line.pop(i + 1)
    i += 1

建词 Baseline 与阈值动态调整

我们不断迭代生成更长的词的过程中,如果最重要生成 10000 词,则前几轮迭代实际上必须采用更多的词。

同时,我们希望有一个最低阈值,如果频率还低于这个阈值,就应该放弃。

baseline = int(nline / 12)
ntok_tholds = [int(baseline*1.5), int(baseline*1.3),
                       int(baseline*1.1), baseline]
while in_round:
    ntok_thold = ntok_tholds[round_num - 1][-1][1]
    thold = max(min_thold, thold)
    logger.info("Thold: {}".format(thold))

建立多级词表

image-20220514181410229

我们实际上会发现,有时候有必要按多种标准进行划分。例如 “全面深化改革” 实际上是一个专有名词,应该划为一词。同时,当上下文为:全面 xxx 时,“全面” 一词也应该被划出。所以我们生成多词表,然后用前面的 MMT 算法进行分词:

max_match_token.py

def main():
    print(
        Tokenizer()
        .add_dict(load_word_freq_map("output/vocab_BPE.txt_1"))
        .add_dict(load_word_freq_map("output/vocab_BPE.txt_2"))
        .add_dict(load_word_freq_map("output/vocab_BPE.txt_3"))
        .add_dict(load_word_freq_map("output/vocab_BPE.txt_4"))
        .tokenize (list ("他是一位声誉很高的学者,凭借丰富的知识储备,在这部重要的作品中,就弥源太是否皈依进行过讨论。"))
    )

  1. Byte Pair Encoding — The Dark Horse of Modern NLP | by Akashdeep Singh Jaswal | Towards Data Science ↩︎