Python 实现简单的多线程 MapReduce 计算框架

背景与目标

前段时间太忙,导致有的作业是一晚上赶工出来的。最近在重写 NLP 相关的作业,由于并行优化使用比较频繁,所以想封装一下,实现简单地进行 MapReduce 多线程计算。

工业项目中,往往会进一步抽象为 DAG 分布式工作流,典型的框架有 Sparn 等。

我们希望封装一个 `` 类,然后可以用类似这样的接口调用:

1result_lines = ParalledTask.create('-- test')\
2    .set_nworker(args.nworker)\
3    .set_worker_func(tok_worker)\
4    .set_reducer_func(tok_reducer)\
5    .set_progress_goal(len(test_lines))\
6    .set_worker_args({'datasets': datasets, 'tokenizer': tokenizer})\
7    .execute()\
8    .get_results()

接口设计

  • set_nworker:设置并行计算的线程数。

  • set_worker_func:设置并行计算的工作函数。

  • set_reducer_func:设置并行计算的结果合并函数。

  • set_progress_goal:设置计算总进度的量化值。可选,用于产生进度条。

  • set_worker_args:设置并行计算的工作函数的参数。比如我们会将数据集划分为多个部分,每个部分都被一个线程计算。

  • execute:执行并行计算。

  • get_results:获取并行计算的结果。

严格来说并不是一定真的并行,因为很可能实际被分配到很少的核心上计算。但只要对提高性能有帮助,我们不深究这些词汇定义。

原理

我们主要对 ThreadPoolExecutor 进行封装。

抛开建造者模式的抽象表象,本质上是三个步骤:

  1. 通过 ThreadPoolExecutor(max_workers=n) 创建执行器

  2. 通过 executor.submit(func, param) 将任务指派到线程

  3. 通过 concurrent.futures.as_completed(futures) 等待执行完成

在用户看来,步骤是这样的:

  1. 设计执行函数 mapper 和合并函数 reducer

  2. 提交任务。任务自动配到 mapper,并在最后自动调用 reducer 合并

  3. 使用结果

总的来说比较简单。

实现

WithMutex 类

实现将对象(或者函数)与一个互斥锁绑定。我们主要是用来给进度条加锁,否则进度会出现视觉上跑不满 100%.

owned_mutex.py 11:

 1import threading
 2from typing import TypeVar
 3
 4T = TypeVar('T')
 5
 6
 7class WithMutex:
 8    def __init__(self, obj: T):
 9        self.obj = obj
10        self.mutex = threading.Lock()

ParalledTask 类

paralled_task.py 124:

  1import logging
  2import concurrent.futures
  3from asyncio.log import logger
  4from concurrent.futures import ThreadPoolExecutor
  5from contextlib import ExitStack
  6from alive_progress import alive_bar
  7
  8from owned_mutex import WithMutex
  9
 10logger = logging.getLogger("default")
 11
 12class ParalledTask:
 13    """实现了对并行协程的封装
 14    """
 15    @staticmethod
 16    def create(task_name) -> 'ParalledTask':
 17        """
 18        创建并行任务
 19        """
 20        instance = ParalledTask()
 21        instance.task_name = task_name
 22        instance.results = {} # key: worker_id, value: worker return value
 23        return instance
 24
 25    def set_nworker(self, nworker: int) -> 'ParalledTask':
 26        """
 27        设置并行协程数
 28        Args:
 29            nworker (int): 并行协程数
 30        """
 31        self.nworker = nworker
 32        return self
 33
 34    def set_worker_func(self, worker_func) -> 'ParalledTask':
 35        """
 36        设置并行协程执行器
 37        Args:
 38            worker_func (callable): 协程执行器
 39        """
 40        self.worker_func = worker_func
 41        return self
 42
 43    def set_worker_args(self, worker_args) -> 'ParalledTask':
 44        """
 45        设置协程执行器的参数
 46        Args:
 47            worker_args (list): 协程执行器的参数
 48        """
 49        self.worker_args = worker_args
 50        return self
 51
 52    def set_worker_arg_provider_func(self, worker_arg_provider_func):
 53        """
 54        设置参数提供函数
 55        函数原型为:worker_arg_provider_func(worker_id=worker_id, nworker=nworker)
 56        Args:
 57            worker_arg_provider_func (callable): 参数提供函数
 58        """
 59        self.worker_arg_provider_func = worker_arg_provider_func
 60        return self
 61
 62    def set_reducer_func(self, reducer_func) -> 'ParalledTask':
 63        """
 64        设置并行任务执行结果合并器
 65        Args:
 66            reducer_func (callable): 合并器
 67        """
 68        self.reducer_func = reducer_func
 69        return self
 70
 71    def set_progress_goal(self, goal: int) -> 'ParalledTask':
 72        self.progress_goal = goal
 73        return self
 74
 75    def execute(self) -> 'ParalledTask':
 76        """
 77        执行并行任务
 78        """
 79        logger.info(f'{self.task_name} start')
 80        with ExitStack() as stack, \
 81                ThreadPoolExecutor(max_workers=self.nworker) as executor:
 82            ctxs = []
 83            if hasattr(self, 'progress_goal'):
 84                goal = self.progress_goal
 85                # 创建进度条
 86                bar = stack.enter_context(alive_bar(goal))
 87                # 使用互斥锁封装进度条
 88                bar_with_mutex = WithMutex(bar)
 89            for worker_id in range(self.nworker):
 90                # if has worker_arg_provider_func attr
 91                if(hasattr(self, 'worker_arg_provider_func')):
 92                    worker_arg = self.worker_arg_provider_func(
 93                        worker_id, self.nworker)
 94                else:
 95                    worker_arg = self.worker_args
 96
 97                worker_ctx = {
 98                    **worker_arg,
 99                    'worker_id': worker_id,
100                    'task': self,
101                    'bar': bar_with_mutex,
102                }
103                ctxs.append(worker_ctx)
104            # 提交任务到执行器
105            futures = [executor.submit(self.worker_func, ctxs[i]) for i in range(self.nworker)]
106            # 等待完成,并收集执行结果
107            for future in concurrent.futures.as_completed(futures):
108                workder_id = futures.index(future)
109                self.results[workder_id] = future.result()
110            # 根据执行器 id 排序,避免乱序
111            self.results = {k: v for k, v in sorted(self.results.items(), key=lambda item: item[0])}
112        logger.info(f'{self.task_name} done')
113        return self
114
115    def get_results(self):        
116        """
117        获取并行任务执行结果
118        """
119        # 如果有合并器,则进行合并,否则直接返回结果
120        if(self.reducer_func is None):
121            return self.results
122
123        return self.reducer_func(self.results)

使用实例

可以参考我写的“使用 BPE 原理进行汉语字词切分”一文。

不足之处

由于 Python 的限制,上述方案对 CPU 利用率的提高非常有限。

如果想要更高的性能,可以使用多进程的方式,但这样对 Python 来说复杂度太高了,尤其是跨进程共享内存,实现起来十分繁琐。

有追求的读者可以尝试使用 Golang 等语言的协程,通过 rpc 等方式传递要计算的数据,封装成一个使用简单的并行计算服务。