feat: add Python Permutation class to mimic Hugging Face dataset and provide PyTorch dataloader (#2725)

This commit is contained in:
Weston Pace
2025-11-06 16:15:33 -08:00
committed by GitHub
parent 6ddd271627
commit aeac9c7644
24 changed files with 2071 additions and 126 deletions

View File

@@ -138,7 +138,9 @@ describe("PermutationBuilder", () => {
});
test("should create permutation with calculated splits", async () => {
const builder = permutationBuilder(table).splitCalculated("id % 2");
const builder = permutationBuilder(table).splitCalculated({
calculation: "id % 2",
});
const permutationTable = await builder.execute();
const rowCount = await permutationTable.countRows();
@@ -224,4 +226,146 @@ describe("PermutationBuilder", () => {
// Should throw error on second execution
await expect(builder.execute()).rejects.toThrow("Builder already consumed");
});
test("should accept custom split names with random splits", async () => {
  // Random 30% / 70% split with a fixed seed and a label for each split.
  const randomSplitOptions = {
    ratios: [0.3, 0.7],
    seed: 42,
    splitNames: ["train", "test"],
  };
  const permuted = await permutationBuilder(table)
    .splitRandom(randomSplitOptions)
    .execute();

  // The permutation is expected to contain 10 rows in total.
  const totalRows = await permuted.countRows();
  expect(totalRows).toBe(10);

  // Supplying names does not change the split_id column: it stays numeric
  // (0, 1, ...). The names are metadata for higher-level APIs to consume.
  const firstSplitRows = await permuted.countRows("split_id = 0");
  const secondSplitRows = await permuted.countRows("split_id = 1");

  // Both splits must be populated and together cover all 10 rows.
  expect(firstSplitRows).toBeGreaterThan(0);
  expect(secondSplitRows).toBeGreaterThan(0);
  expect(firstSplitRows + secondSplitRows).toBe(10);
});
test("should accept custom split names with hash splits", async () => {
  // Hash on the "id" column with two equally weighted splits, nothing
  // discarded, and a label for each split.
  const hashSplitOptions = {
    columns: ["id"],
    splitWeights: [50, 50],
    discardWeight: 0,
    splitNames: ["set_a", "set_b"],
  };
  const permuted = await permutationBuilder(table)
    .splitHash(hashSplitOptions)
    .execute();

  // The permutation is expected to contain 10 rows in total.
  const totalRows = await permuted.countRows();
  expect(totalRows).toBe(10);

  // Names are supplied, yet split_id is still queried by its numeric value.
  const setARows = await permuted.countRows("split_id = 0");
  const setBRows = await permuted.countRows("split_id = 1");

  // Both splits must be populated and together cover all 10 rows.
  expect(setARows).toBeGreaterThan(0);
  expect(setBRows).toBeGreaterThan(0);
  expect(setARows + setBRows).toBe(10);
});
test("should accept custom split names with sequential splits", async () => {
  // Sequential split into two equal halves, each given a label.
  const sequentialSplitOptions = {
    ratios: [0.5, 0.5],
    splitNames: ["first", "second"],
  };
  const permuted = await permutationBuilder(table)
    .splitSequential(sequentialSplitOptions)
    .execute();

  // The permutation is expected to contain 10 rows in total.
  const totalRows = await permuted.countRows();
  expect(totalRows).toBe(10);

  // Names are supplied, yet split_id is still queried by its numeric value.
  const firstHalfRows = await permuted.countRows("split_id = 0");
  const secondHalfRows = await permuted.countRows("split_id = 1");

  // A 50/50 sequential split of 10 rows must give exactly 5 rows per split.
  expect(firstHalfRows).toBe(5);
  expect(secondHalfRows).toBe(5);
});
test("should accept custom split names with calculated splits", async () => {
  // The split is computed from the expression "id % 2"; each resulting
  // split is given a label.
  const calculatedSplitOptions = {
    calculation: "id % 2",
    splitNames: ["even", "odd"],
  };
  const permuted = await permutationBuilder(table)
    .splitCalculated(calculatedSplitOptions)
    .execute();

  // The permutation is expected to contain 10 rows in total.
  const totalRows = await permuted.countRows();
  expect(totalRows).toBe(10);

  // Names are supplied, yet split_id is still queried by its numeric value.
  const evenRows = await permuted.countRows("split_id = 0");
  const oddRows = await permuted.countRows("split_id = 1");

  // Both splits must be populated and together cover all 10 rows.
  expect(evenRows).toBeGreaterThan(0);
  expect(oddRows).toBeGreaterThan(0);
  expect(evenRows + oddRows).toBe(10);
});
test("should persist permutation to a new table", async () => {
  const db = await connect(tmpDir.name);

  // Random 70% / 30% split with labels, persisted as "my_permutation".
  const randomSplitOptions = {
    ratios: [0.7, 0.3],
    seed: 42,
    splitNames: ["train", "validation"],
  };
  const persistingBuilder = permutationBuilder(table)
    .splitRandom(randomSplitOptions)
    .persist(db, "my_permutation");

  // Executing the builder is what persists the table.
  const executedTable = await persistingBuilder.execute();

  // The persisted table must exist and be openable by name.
  const reopenedTable = await db.openTable("my_permutation");
  expect(reopenedTable).toBeDefined();

  // The persisted table must hold the expected number of rows.
  const persistedRows = await reopenedTable.countRows();
  expect(persistedRows).toBe(10);

  // Both splits must be present (split_id values are numeric) and together
  // cover every persisted row.
  const trainRows = await reopenedTable.countRows("split_id = 0");
  const validationRows = await reopenedTable.countRows("split_id = 1");
  expect(trainRows).toBeGreaterThan(0);
  expect(validationRows).toBeGreaterThan(0);
  expect(trainRows + validationRows).toBe(10);

  // The table returned by execute() must report the same row count as the
  // persisted one.
  const executedRows = await executedTable.countRows();
  expect(executedRows).toBe(10);
});
test("should persist permutation with multiple operations", async () => {
  const db = await connect(tmpDir.name);

  // Chain filter -> random split -> shuffle -> persist on a single builder.
  const splitOptions = { ratios: [0.5, 0.5], seed: 123, splitNames: ["a", "b"] };
  const shuffleOptions = { seed: 456 };
  const pipeline = permutationBuilder(table)
    .filter("value > 30")
    .splitRandom(splitOptions)
    .shuffle(shuffleOptions)
    .persist(db, "filtered_permutation");

  // Run the whole pipeline.
  const executedTable = await pipeline.execute();

  // Re-open the persisted table and check that only the filtered rows remain.
  const reopenedTable = await db.openTable("filtered_permutation");
  const persistedRows = await reopenedTable.countRows();
  expect(persistedRows).toBe(7); // Values 40, 50, 60, 70, 80, 90, 100

  // Both splits must be present (split_id values are numeric) and together
  // cover every filtered row.
  const splitARows = await reopenedTable.countRows("split_id = 0");
  const splitBRows = await reopenedTable.countRows("split_id = 1");
  expect(splitARows).toBeGreaterThan(0);
  expect(splitBRows).toBeGreaterThan(0);
  expect(splitARows + splitBRows).toBe(7);

  // The table returned by execute() must report the same row count.
  const executedRows = await executedTable.countRows();
  expect(executedRows).toBe(7);
});
});