diff --git a/scripts/check-snafu.py b/scripts/check-snafu.py index b91950692b..662f6758d5 100644 --- a/scripts/check-snafu.py +++ b/scripts/check-snafu.py @@ -14,6 +14,7 @@ import os import re +from multiprocessing import Pool def find_rust_files(directory): @@ -33,13 +34,11 @@ def extract_branch_names(file_content): return pattern.findall(file_content) -def check_snafu_in_files(branch_name, rust_files): +def check_snafu_in_files(branch_name, rust_files_content): branch_name_snafu = f"{branch_name}Snafu" - for rust_file in rust_files: - with open(rust_file, "r") as file: - content = file.read() - if branch_name_snafu in content: - return True + for content in rust_files_content.values(): + if branch_name_snafu in content: + return True return False @@ -49,21 +48,24 @@ def main(): for error_file in error_files: with open(error_file, "r") as file: - content = file.read() - branch_names.extend(extract_branch_names(content)) + branch_names.extend(extract_branch_names(file.read())) - unused_snafu = [ - branch_name - for branch_name in branch_names - if not check_snafu_in_files(branch_name, other_rust_files) - ] + # Read all rust files into memory once + rust_files_content = {} + for rust_file in other_rust_files: + with open(rust_file, "r") as file: + rust_files_content[rust_file] = file.read() + + with Pool() as pool: + results = pool.starmap( + check_snafu_in_files, [(bn, rust_files_content) for bn in branch_names] + ) + unused_snafu = [bn for bn, found in zip(branch_names, results) if not found] if unused_snafu: print("Unused error variants:") for name in unused_snafu: print(name) - - if unused_snafu: raise SystemExit(1) diff --git a/src/common/datasource/tests/orc/write.py b/src/common/datasource/tests/orc/write.py index f0e2792299..aa97c09a63 100644 --- a/src/common/datasource/tests/orc/write.py +++ b/src/common/datasource/tests/orc/write.py @@ -35,10 +35,23 @@ data = { "bigint_other": [5, -5, 1, 5, 5], "utf8_increase": ["a", "bb", "ccc", "dddd", "eeeee"], "utf8_decrease": ["eeeee", "dddd", "ccc", "bb", "a"], - "timestamp_simple": [datetime.datetime(2023, 4, 1, 20, 15, 30, 2000), datetime.datetime.fromtimestamp(int('1629617204525777000')/1000000000), datetime.datetime(2023, 1, 1), datetime.datetime(2023, 2, 1), datetime.datetime(2023, 3, 1)], - "date_simple": [datetime.date(2023, 4, 1), datetime.date(2023, 3, 1), datetime.date(2023, 1, 1), datetime.date(2023, 2, 1), datetime.date(2023, 3, 1)] + "timestamp_simple": [ + datetime.datetime(2023, 4, 1, 20, 15, 30, 2000), + datetime.datetime.fromtimestamp(int("1629617204525777000") / 1000000000), + datetime.datetime(2023, 1, 1), + datetime.datetime(2023, 2, 1), + datetime.datetime(2023, 3, 1), + ], + "date_simple": [ + datetime.date(2023, 4, 1), + datetime.date(2023, 3, 1), + datetime.date(2023, 1, 1), + datetime.date(2023, 2, 1), + datetime.date(2023, 3, 1), + ], } + def infer_schema(data): schema = "struct<" for key, value in data.items(): @@ -56,7 +69,7 @@ def infer_schema(data): elif key.startswith("date"): dt = "date" else: - print(key,value,dt) + print(key, value, dt) raise NotImplementedError if key.startswith("double"): dt = "double" @@ -68,7 +81,6 @@ def infer_schema(data): return schema - def _write( schema: str, data,