背景与目标

前段时间太忙,导致有的作业是一晚上赶工出来的。最近在重写 NLP 相关的作业,由于并行优化使用比较频繁,所以想封装一下,实现简单地进行 MapReduce 多线程计算。

工业项目中,往往会进一步抽象为 DAG 分布式工作流,典型的框架有 Sparn 等。

我们希望封装一个 `` 类,然后可以用类似这样的接口调用:

result_lines = ParalledTask.create('-- test')\
    .set_nworker(args.nworker)\
    .set_worker_func(tok_worker)\
    .set_reducer_func(tok_reducer)\
    .set_progress_goal(len(test_lines))\
    .set_worker_args({'datasets': datasets, 'tokenizer': tokenizer})\
    .execute()\
    .get_results()

接口设计

  • set_nworker:设置并行计算的线程数。
  • set_worker_func:设置并行计算的工作函数。
  • set_reducer_func:设置并行计算的结果合并函数。
  • set_progress_goal:设置计算总进度的量化值。可选,用于产生进度条。
  • set_worker_args:设置并行计算的工作函数的参数。比如我们会将数据集划分为多个部分,每个部分都被一个线程计算。
  • execute:执行并行计算。
  • get_results:获取并行计算的结果。

严格来说并不是一定真的并行,因为很可能实际被分配到很少的核心上计算。但只要对提高性能有帮助,我们不深究这些词汇定义。

原理

我们主要对 ThreadPoolExecutor 进行封装。

抛开建造者模式的抽象表象,本质上是三个步骤:

  1. 通过 ThreadPoolExecutor (max_workers=n) 创建执行器
  2. 通过 executor.submit (func, param) 将任务指派到线程
  3. 通过 concurrent.futures.as_completed (futures) 等待执行完成

在用户看来,步骤是这样的:

  1. 设计执行函数 mapper 和合并函数 reducer
  2. 提交任务。任务自动配到 mapper,并在最后自动调用 reducer 合并
  3. 使用结果

总的来说比较简单。

实现

WithMutex 类

实现将对象(或者函数)与一个互斥锁绑定。我们主要是用来给进度条加锁,否则进度会出现视觉上跑不满 100%.

owned_mutex.py 11:

import threading
from typing import TypeVar

T = TypeVar('T')


class WithMutex:
    def __init__(self, obj: T):
        self.obj = obj
        self.mutex = threading.Lock()

ParalledTask 类

paralled_task.py 124:

import logging
import concurrent.futures
from asyncio.log import logger
from concurrent.futures import ThreadPoolExecutor
from contextlib import ExitStack
from alive_progress import alive_bar

from owned_mutex import WithMutex

logger = logging.getLogger("default")

class ParalledTask:
    """ 实现了对并行协程的封装
    """
    @staticmethod
    def create(task_name) -> 'ParalledTask':
        """
        创建并行任务
        """
        instance = ParalledTask()
        instance.task_name = task_name
        instance.results = {} # key: worker_id, value: worker return value
        return instance

    def set_nworker(self, nworker: int) -> 'ParalledTask':
        """
        设置并行协程数
        Args:
            nworker (int): 并行协程数
        """
        self.nworker = nworker
        return self

    def set_worker_func(self, worker_func) -> 'ParalledTask':
        """
        设置并行协程执行器
        Args:
            worker_func (callable): 协程执行器
        """
        self.worker_func = worker_func
        return self

    def set_worker_args(self, worker_args) -> 'ParalledTask':
        """
        设置协程执行器的参数
        Args:
            worker_args (list): 协程执行器的参数
        """
        self.worker_args = worker_args
        return self

    def set_worker_arg_provider_func(self, worker_arg_provider_func):
        """
        设置参数提供函数
        函数原型为:worker_arg_provider_func (worker_id=worker_id, nworker=nworker)
        Args:
            worker_arg_provider_func (callable): 参数提供函数
        """
        self.worker_arg_provider_func = worker_arg_provider_func
        return self

    def set_reducer_func(self, reducer_func) -> 'ParalledTask':
        """
        设置并行任务执行结果合并器
        Args:
            reducer_func (callable): 合并器
        """
        self.reducer_func = reducer_func
        return self

    def set_progress_goal(self, goal: int) -> 'ParalledTask':
        self.progress_goal = goal
        return self

    def execute(self) -> 'ParalledTask':
        """
        执行并行任务
        """
        logger.info(f'{self.task_name} start')
        with ExitStack() as stack, \
                ThreadPoolExecutor(max_workers=self.nworker) as executor:
            ctxs = []
            if hasattr(self, 'progress_goal'):
                goal = self.progress_goal
                # 创建进度条
                bar = stack.enter_context(alive_bar(goal))
                # 使用互斥锁封装进度条
                bar_with_mutex = WithMutex(bar)
            for worker_id in range(self.nworker):
                # if has worker_arg_provider_func attr
                if(hasattr(self, 'worker_arg_provider_func')):
                    worker_arg = self.worker_arg_provider_func(
                        worker_id, self.nworker)
                else:
                    worker_arg = self.worker_args

                worker_ctx = {
                    **worker_arg,
                    'worker_id': worker_id,
                    'task': self,
                    'bar': bar_with_mutex,
                }
                ctxs.append(worker_ctx)
            # 提交任务到执行器
            futures = [executor.submit(self.worker_func, ctxs[i]) for i in range(self.nworker)]
            # 等待完成,并收集执行结果
            for future in concurrent.futures.as_completed(futures):
                workder_id = futures.index(future)
                self.results[workder_id] = future.result()
            # 根据执行器 id 排序,避免乱序
            self.results = {k: v for k, v in sorted(self.results.items(), key=lambda item: item[0])}
        logger.info(f'{self.task_name} done')
        return self

    def get_results(self):        
        """
        获取并行任务执行结果
        """
        # 如果有合并器,则进行合并,否则直接返回结果
        if(self.reducer_func is None):
            return self.results

        return self.reducer_func(self.results)

使用实例

可以参考我写的 “使用 BPE 原理进行汉语字词切分” 一文。

不足之处

由于 Python 的限制,上述方案对 CPU 利用率的提高非常有限。

如果想要更高的性能,可以使用多进程的方式,但这样对 Python 来说复杂度太高了,尤其是跨进程共享内存,实现起来十分繁琐。

有追求的读者可以尝试使用 Golang 等语言的协程,通过 rpc 等方式传递要计算的数据,封装成一个使用简单的并行计算服务。