Split Lines Using Very Little Memory

在大数据处理领域, 分而治之是非常常用的技巧. 例如将一个很大的 CSV 或 NDJSON 文件按行拆分成多个小文件, 再逐个处理.

下面这个例子给出了 Python 中的两种实现: main1 内存消耗较小, 但速度较慢; main2 内存消耗较大, 但速度较快.

example.py
 1# -*- coding: utf-8 -*-
 2
 3"""
 4"""
 5
 6import io
 7import uuid
 8import gzip
 9import time
10from pathlib import Path
11
12import polars as pl
13from more_itertools import batched  # itertools.batched is only available in Python3.12
14from memory_profiler import profile
15
# Scratch directory next to this script; it is created eagerly at import time.
dir_tmp = Path(__file__).absolute().parent / "tmp"
dir_tmp.mkdir(exist_ok=True)
# Gzip-compressed NDJSON input written by make_source() and read by main1()/main2().
path_source = dir_tmp / "source.json.gz"
19
20
def make_source():
    """
    Generate the gzip-compressed NDJSON test input at ``path_source``.

    The file has 1,000,000 rows of the form
    ``{"id": i, "text": <uuid4 hex repeated 100 times>}``, so each record is
    about 3.2 KB. The author measured roughly 3200 MB uncompressed and 34.8 MB
    compressed. Splitting it into batches of 10,000 lines gives 100 files of
    about 32 MB each.

    The DataFrame is built column-wise, with no per-row dicts. The NDJSON is
    written straight through ``gzip.open``. This way the whole uncompressed
    payload is never buffered in a ``BytesIO`` and then copied again for
    ``gzip.compress``.
    """
    n_rows = 1_000_000
    df = pl.DataFrame(
        {
            "id": list(range(1, 1 + n_rows)),
            "text": [uuid.uuid4().hex * 100 for _ in range(n_rows)],
        }
    )
    # Stream the serialized rows into the compressor instead of materializing them.
    with gzip.open(path_source, "wb") as f_out:
        df.write_ndjson(f_out)
32
33
@profile
def main1():
    """
    Split ``path_source`` into numbered chunks of 10,000 lines each.

    The gzip file is streamed, so it is never decompressed into memory all at
    once. Each output file ``dir_tmp/0001.json``, ``0002.json``, ... holds one
    batch of lines joined by single newline characters, with no trailing
    newline. This matches the layout the earlier text-based version produced.

    Why streaming: the earlier version ran
    ``gzip.decompress(...).decode("utf-8").splitlines()``. That kept the whole
    decompressed payload plus a list of every line alive. memory_profiler showed
    about 4.5 GiB peak, and ``time.process_time`` reported elapsed = 3.99.
    Iterating ``gzip.open`` keeps only the current batch plus small I/O buffers
    in memory. Re-run memory_profiler to get numbers for this version.

    Why bytes:
    - ``str.splitlines`` also breaks on U+2028, U+2029 and U+0085. These may
      appear unescaped inside JSON string values, which would cut one record
      into two lines. Splitting only on the newline byte respects NDJSON framing.
    - Writing bytes avoids the locale-dependent encoding and the platform
      newline translation of ``Path.write_text``.
    """
    with gzip.open(path_source, "rb") as f_in:
        # Iterating a binary file object yields one line at a time, each ending in b"\n"
        # (except possibly the last), so only one batch is held in memory.
        for ith, lines in enumerate(batched(f_in, 10000), start=1):
            p = dir_tmp.joinpath(f"{str(ith).zfill(4)}.json")
            # Drop each line's terminator and rejoin with "\n": no trailing newline per file.
            p.write_bytes(b"\n".join(line.rstrip(b"\r\n") for line in lines))
52
53
@profile
def main2():
    """
    Split ``path_source`` into numbered chunks of 10,000 lines each.

    This is the fast but memory-hungry variant. It decompresses the whole file
    into memory and reads every line before writing anything. Each output file
    ``dir_tmp/0001.json``, ``0002.json``, ... is the byte concatenation of its
    lines, so every line keeps its original newline terminator.

    The author's memory_profiler measurement (line numbers refer to an earlier
    layout of this file):

        elapsed = 2.2

        Line #    Mem usage    Increment  Occurrences   Line Contents
        =============================================================
            52     62.7 MiB     62.7 MiB           1   @profile
            53                                         def main2():
            64   3543.1 MiB   3480.4 MiB           1       buffer = io.BytesIO(gzip.decompress(path_source.read_bytes()))
            65   6985.7 MiB     80.7 MiB         101       for ith, lst in enumerate(batched(buffer.readlines(), 10000), start=1):
            66   6985.7 MiB      0.0 MiB         100           p = dir_tmp.joinpath(f"{str(ith).zfill(4)}.json")
            67   6985.7 MiB      1.9 MiB         100           p.write_bytes(b"".join(lst))
    """
    stream = io.BytesIO(gzip.decompress(path_source.read_bytes()))
    all_lines = stream.readlines()
    for index, chunk in enumerate(batched(all_lines, 10000), start=1):
        target = dir_tmp / f"{index:04d}.json"
        target.write_bytes(b"".join(chunk))
72
73
if __name__ == "__main__":
    # Measure wall-clock time. time.process_time() counts only this process's CPU
    # time, so it would leave out time spent blocked on the many file reads/writes.
    st = time.perf_counter()
    # make_source()  # run once first to create tmp/source.json.gz
    main1()
    # main2()
    et = time.perf_counter()
    elapse = et - st
    print(f"{elapse = }")