mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-06 20:02:58 +00:00
feat: add python Permutation class to mimic hugging face dataset and provide pytorch dataloader (#2725)
This commit is contained in:
@@ -138,7 +138,9 @@ describe("PermutationBuilder", () => {
|
||||
});
|
||||
|
||||
test("should create permutation with calculated splits", async () => {
|
||||
const builder = permutationBuilder(table).splitCalculated("id % 2");
|
||||
const builder = permutationBuilder(table).splitCalculated({
|
||||
calculation: "id % 2",
|
||||
});
|
||||
|
||||
const permutationTable = await builder.execute();
|
||||
const rowCount = await permutationTable.countRows();
|
||||
@@ -224,4 +226,146 @@ describe("PermutationBuilder", () => {
|
||||
// Should throw error on second execution
|
||||
await expect(builder.execute()).rejects.toThrow("Builder already consumed");
|
||||
});
|
||||
|
||||
test("should accept custom split names with random splits", async () => {
|
||||
const builder = permutationBuilder(table).splitRandom({
|
||||
ratios: [0.3, 0.7],
|
||||
seed: 42,
|
||||
splitNames: ["train", "test"],
|
||||
});
|
||||
|
||||
const permutationTable = await builder.execute();
|
||||
const rowCount = await permutationTable.countRows();
|
||||
expect(rowCount).toBe(10);
|
||||
|
||||
// Split names are provided but split_id is still numeric (0, 1, etc.)
|
||||
// The names are metadata that can be used by higher-level APIs
|
||||
const split0Count = await permutationTable.countRows("split_id = 0");
|
||||
const split1Count = await permutationTable.countRows("split_id = 1");
|
||||
|
||||
expect(split0Count).toBeGreaterThan(0);
|
||||
expect(split1Count).toBeGreaterThan(0);
|
||||
expect(split0Count + split1Count).toBe(10);
|
||||
});
|
||||
|
||||
test("should accept custom split names with hash splits", async () => {
|
||||
const builder = permutationBuilder(table).splitHash({
|
||||
columns: ["id"],
|
||||
splitWeights: [50, 50],
|
||||
discardWeight: 0,
|
||||
splitNames: ["set_a", "set_b"],
|
||||
});
|
||||
|
||||
const permutationTable = await builder.execute();
|
||||
const rowCount = await permutationTable.countRows();
|
||||
expect(rowCount).toBe(10);
|
||||
|
||||
// Split names are provided but split_id is still numeric
|
||||
const split0Count = await permutationTable.countRows("split_id = 0");
|
||||
const split1Count = await permutationTable.countRows("split_id = 1");
|
||||
|
||||
expect(split0Count).toBeGreaterThan(0);
|
||||
expect(split1Count).toBeGreaterThan(0);
|
||||
expect(split0Count + split1Count).toBe(10);
|
||||
});
|
||||
|
||||
test("should accept custom split names with sequential splits", async () => {
|
||||
const builder = permutationBuilder(table).splitSequential({
|
||||
ratios: [0.5, 0.5],
|
||||
splitNames: ["first", "second"],
|
||||
});
|
||||
|
||||
const permutationTable = await builder.execute();
|
||||
const rowCount = await permutationTable.countRows();
|
||||
expect(rowCount).toBe(10);
|
||||
|
||||
// Split names are provided but split_id is still numeric
|
||||
const split0Count = await permutationTable.countRows("split_id = 0");
|
||||
const split1Count = await permutationTable.countRows("split_id = 1");
|
||||
|
||||
expect(split0Count).toBe(5);
|
||||
expect(split1Count).toBe(5);
|
||||
});
|
||||
|
||||
test("should accept custom split names with calculated splits", async () => {
|
||||
const builder = permutationBuilder(table).splitCalculated({
|
||||
calculation: "id % 2",
|
||||
splitNames: ["even", "odd"],
|
||||
});
|
||||
|
||||
const permutationTable = await builder.execute();
|
||||
const rowCount = await permutationTable.countRows();
|
||||
expect(rowCount).toBe(10);
|
||||
|
||||
// Split names are provided but split_id is still numeric
|
||||
const split0Count = await permutationTable.countRows("split_id = 0");
|
||||
const split1Count = await permutationTable.countRows("split_id = 1");
|
||||
|
||||
expect(split0Count).toBeGreaterThan(0);
|
||||
expect(split1Count).toBeGreaterThan(0);
|
||||
expect(split0Count + split1Count).toBe(10);
|
||||
});
|
||||
|
||||
test("should persist permutation to a new table", async () => {
|
||||
const db = await connect(tmpDir.name);
|
||||
const builder = permutationBuilder(table)
|
||||
.splitRandom({
|
||||
ratios: [0.7, 0.3],
|
||||
seed: 42,
|
||||
splitNames: ["train", "validation"],
|
||||
})
|
||||
.persist(db, "my_permutation");
|
||||
|
||||
// Execute the builder which will persist the table
|
||||
const permutationTable = await builder.execute();
|
||||
|
||||
// Verify the persisted table exists and can be opened
|
||||
const persistedTable = await db.openTable("my_permutation");
|
||||
expect(persistedTable).toBeDefined();
|
||||
|
||||
// Verify the persisted table has the correct number of rows
|
||||
const rowCount = await persistedTable.countRows();
|
||||
expect(rowCount).toBe(10);
|
||||
|
||||
// Verify splits exist (numeric split_id values)
|
||||
const split0Count = await persistedTable.countRows("split_id = 0");
|
||||
const split1Count = await persistedTable.countRows("split_id = 1");
|
||||
|
||||
expect(split0Count).toBeGreaterThan(0);
|
||||
expect(split1Count).toBeGreaterThan(0);
|
||||
expect(split0Count + split1Count).toBe(10);
|
||||
|
||||
// Verify the table returned by execute is the same as the persisted one
|
||||
const executedRowCount = await permutationTable.countRows();
|
||||
expect(executedRowCount).toBe(10);
|
||||
});
|
||||
|
||||
test("should persist permutation with multiple operations", async () => {
|
||||
const db = await connect(tmpDir.name);
|
||||
const builder = permutationBuilder(table)
|
||||
.filter("value > 30")
|
||||
.splitRandom({ ratios: [0.5, 0.5], seed: 123, splitNames: ["a", "b"] })
|
||||
.shuffle({ seed: 456 })
|
||||
.persist(db, "filtered_permutation");
|
||||
|
||||
// Execute the builder
|
||||
const permutationTable = await builder.execute();
|
||||
|
||||
// Verify the persisted table
|
||||
const persistedTable = await db.openTable("filtered_permutation");
|
||||
const rowCount = await persistedTable.countRows();
|
||||
expect(rowCount).toBe(7); // Values 40, 50, 60, 70, 80, 90, 100
|
||||
|
||||
// Verify splits exist (numeric split_id values)
|
||||
const split0Count = await persistedTable.countRows("split_id = 0");
|
||||
const split1Count = await persistedTable.countRows("split_id = 1");
|
||||
|
||||
expect(split0Count).toBeGreaterThan(0);
|
||||
expect(split1Count).toBeGreaterThan(0);
|
||||
expect(split0Count + split1Count).toBe(7);
|
||||
|
||||
// Verify the executed table matches
|
||||
const executedRowCount = await permutationTable.countRows();
|
||||
expect(executedRowCount).toBe(7);
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user