Split Lines Using Very Little Memory
在大数据处理领域, 分而治之是非常常用的技巧: 例如将一个很大的 CSV 或 NDJSON 文件按行拆分成多个小文件, 然后逐个处理.
下面的例子给出了 Python 中按行拆分的两种实现, 用于对比内存与速度之间的取舍: main1 内存消耗较小, 速度较慢; main2 内存消耗较大, 速度较快.
example.py
1# -*- coding: utf-8 -*-
2
3"""
4"""
5
import gzip
import io
import json
import time
import uuid
from pathlib import Path

import polars as pl
from more_itertools import batched  # itertools.batched is only available in Python 3.12+
from memory_profiler import profile
15
# Scratch directory next to this script. It holds the generated source
# archive and the split output files, and is created at import time so
# every function below can write into it.
dir_tmp = Path(__file__).absolute().parent / "tmp"
dir_tmp.mkdir(exist_ok=True)
# Gzipped NDJSON input: written by make_source(), read by main1()/main2().
path_source = dir_tmp / "source.json.gz"
19
20
def make_source():
    """
    Generate the gzipped NDJSON test input at ``path_source``.

    1M rows; each ``text`` field is one uuid4 hex string repeated 100 times
    (about 3.2 KB per record), so the data is ~3200 MB uncompressed and
    about 34.8 MB gzipped. Splitting it 10000 lines per file yields 100
    files of ~32 MB each.

    Records are serialized one at a time and written straight into a
    streaming gzip writer. The previous version first built a list of
    dicts, then a polars DataFrame, then an in-memory NDJSON buffer plus its
    ``getvalue()`` copy; several of those held a full copy of the ~3.2 GB
    payload at once.
    """
    n_rows = 1_000_000
    # newline="\n" disables newline translation, so the file has plain "\n"
    # line endings on every platform.
    with gzip.open(path_source, "wt", encoding="utf-8", newline="\n") as f_out:
        for i in range(1, 1 + n_rows):
            record = {"id": i, "text": uuid.uuid4().hex * 100}
            # Compact separators: one compact JSON object per line (NDJSON).
            f_out.write(json.dumps(record, separators=(",", ":")))
            f_out.write("\n")
32
33
@profile
def main1(batch_size: int = 10000):
    """
    Split the gzipped NDJSON source into files of ``batch_size`` lines each,
    streaming the input so only one batch is held in memory at a time.

    The source is decompressed incrementally by ``gzip.open`` and iterated
    line by line in binary mode. As a result:

    - Peak memory is roughly one batch (~32 MB for the example source at
      10000 lines per file) on top of the interpreter baseline. It does not
      scale with the size of the decompressed file.
    - Lines are split only on ``b"\\n"``. ``str.splitlines()`` would also
      split on characters such as U+2028 / U+0085, which may legally appear
      unescaped inside JSON strings and would cut records in half.
    - Bytes are copied verbatim. There is no decode/encode round trip and no
      dependence on the platform's default text encoding.

    Each output file keeps the original line terminators.

    The earlier version, which decoded the whole file to ``str`` and called
    ``splitlines()``, measured elapsed = 3.99 (time.process_time) with a peak
    of ~4593 MiB. Re-profile this version before quoting numbers.

    :param batch_size: number of lines written to each output file.
    """
    with gzip.open(path_source, "rb") as f_in:
        for ith, lst in enumerate(batched(f_in, batch_size), start=1):
            p = dir_tmp.joinpath(f"{str(ith).zfill(4)}.json")
            p.write_bytes(b"".join(lst))
52
53
@profile
def main2(batch_size: int = 10000):
    """
    Split the gzipped NDJSON source into files of ``batch_size`` lines each,
    decompressing the whole source into memory first.

    ``gzip.decompress`` inflates the entire file in one call. This is fast,
    but it keeps the full decompressed payload (~3.4 GiB for the example
    source) resident.

    Lines are pulled lazily by iterating the ``BytesIO``, which splits on
    ``b"\\n"`` and keeps the terminators. The earlier version called
    ``readlines()`` instead, which built a second full copy of the data as a
    list of ``bytes`` and roughly doubled peak memory. The output files are
    byte-identical either way.

    The earlier ``readlines()`` version measured elapsed = 2.2
    (time.process_time) with a peak of ~6986 MiB. Re-profile this version
    before quoting numbers.

    :param batch_size: number of lines written to each output file.
    """
    buffer = io.BytesIO(gzip.decompress(path_source.read_bytes()))
    for ith, lst in enumerate(batched(buffer, batch_size), start=1):
        p = dir_tmp.joinpath(f"{str(ith).zfill(4)}.json")
        p.write_bytes(b"".join(lst))
72
73
if __name__ == "__main__":
    # Wall-clock timing. time.process_time() (used previously) counts only
    # this process's CPU time, so time spent blocked on reading/writing files
    # was excluded from the reported "elapsed" value.
    # Note: @profile runs the functions under memory_profiler's line tracer,
    # which may add noticeable overhead to these timings.
    st = time.perf_counter()
    # Run make_source() once to create the input, then toggle main1()/main2()
    # to compare the two strategies.
    # make_source()
    main1()
    # main2()
    et = time.perf_counter()
    elapse = et - st
    print(f"{elapse = }")