fix: better fmt check from 40s to 4s (#5279)

Signed-off-by: yihong0618 <zouzou0208@gmail.com>
This commit is contained in:
yihong
2025-01-03 16:12:49 +08:00
committed by GitHub
parent 89399131dd
commit 856bba5d95
2 changed files with 33 additions and 19 deletions

View File

@@ -14,6 +14,7 @@
import os import os
import re import re
from multiprocessing import Pool
def find_rust_files(directory): def find_rust_files(directory):
@@ -33,11 +34,9 @@ def extract_branch_names(file_content):
return pattern.findall(file_content) return pattern.findall(file_content)
def check_snafu_in_files(branch_name, rust_files): def check_snafu_in_files(branch_name, rust_files_content):
branch_name_snafu = f"{branch_name}Snafu" branch_name_snafu = f"{branch_name}Snafu"
for rust_file in rust_files: for content in rust_files_content.values():
with open(rust_file, "r") as file:
content = file.read()
if branch_name_snafu in content: if branch_name_snafu in content:
return True return True
return False return False
@@ -49,21 +48,24 @@ def main():
for error_file in error_files: for error_file in error_files:
with open(error_file, "r") as file: with open(error_file, "r") as file:
content = file.read() branch_names.extend(extract_branch_names(file.read()))
branch_names.extend(extract_branch_names(content))
unused_snafu = [ # Read all rust files into memory once
branch_name rust_files_content = {}
for branch_name in branch_names for rust_file in other_rust_files:
if not check_snafu_in_files(branch_name, other_rust_files) with open(rust_file, "r") as file:
] rust_files_content[rust_file] = file.read()
with Pool() as pool:
results = pool.starmap(
check_snafu_in_files, [(bn, rust_files_content) for bn in branch_names]
)
unused_snafu = [bn for bn, found in zip(branch_names, results) if not found]
if unused_snafu: if unused_snafu:
print("Unused error variants:") print("Unused error variants:")
for name in unused_snafu: for name in unused_snafu:
print(name) print(name)
if unused_snafu:
raise SystemExit(1) raise SystemExit(1)

View File

@@ -35,10 +35,23 @@ data = {
"bigint_other": [5, -5, 1, 5, 5], "bigint_other": [5, -5, 1, 5, 5],
"utf8_increase": ["a", "bb", "ccc", "dddd", "eeeee"], "utf8_increase": ["a", "bb", "ccc", "dddd", "eeeee"],
"utf8_decrease": ["eeeee", "dddd", "ccc", "bb", "a"], "utf8_decrease": ["eeeee", "dddd", "ccc", "bb", "a"],
"timestamp_simple": [datetime.datetime(2023, 4, 1, 20, 15, 30, 2000), datetime.datetime.fromtimestamp(int('1629617204525777000')/1000000000), datetime.datetime(2023, 1, 1), datetime.datetime(2023, 2, 1), datetime.datetime(2023, 3, 1)], "timestamp_simple": [
"date_simple": [datetime.date(2023, 4, 1), datetime.date(2023, 3, 1), datetime.date(2023, 1, 1), datetime.date(2023, 2, 1), datetime.date(2023, 3, 1)] datetime.datetime(2023, 4, 1, 20, 15, 30, 2000),
datetime.datetime.fromtimestamp(int("1629617204525777000") / 1000000000),
datetime.datetime(2023, 1, 1),
datetime.datetime(2023, 2, 1),
datetime.datetime(2023, 3, 1),
],
"date_simple": [
datetime.date(2023, 4, 1),
datetime.date(2023, 3, 1),
datetime.date(2023, 1, 1),
datetime.date(2023, 2, 1),
datetime.date(2023, 3, 1),
],
} }
def infer_schema(data): def infer_schema(data):
schema = "struct<" schema = "struct<"
for key, value in data.items(): for key, value in data.items():
@@ -56,7 +69,7 @@ def infer_schema(data):
elif key.startswith("date"): elif key.startswith("date"):
dt = "date" dt = "date"
else: else:
print(key,value,dt) print(key, value, dt)
raise NotImplementedError raise NotImplementedError
if key.startswith("double"): if key.startswith("double"):
dt = "double" dt = "double"
@@ -68,7 +81,6 @@ def infer_schema(data):
return schema return schema
def _write( def _write(
schema: str, schema: str,
data, data,