feat(nodejs): add compatibility across arrow versions (#1337)

While adding some more docs and examples for the new JS SDK, I ran across
a few compatibility issues when using different Apache Arrow versions. This
should fix those issues.
This commit is contained in:
Cory Grinstead
2024-05-29 17:36:34 -05:00
committed by GitHub
parent dbea3a7544
commit bc139000bd
12 changed files with 211 additions and 77 deletions

View File

@@ -11,18 +11,21 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import { Float, Float32, Int32, Utf8, Vector } from "apache-arrow";
import * as arrow from "apache-arrow";
import * as arrowOld from "apache-arrow-old";
import * as tmp from "tmp";
import { connect } from "../lancedb";
import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding";
import { getRegistry, register } from "../lancedb/embedding/registry";
describe("LanceSchema", () => {
describe.each([arrow, arrowOld])("LanceSchema", (arrow) => {
test("should preserve input order", async () => {
const schema = LanceSchema({
id: new Int32(),
text: new Utf8(),
vector: new Float32(),
id: new arrow.Int32(),
text: new arrow.Utf8(),
vector: new arrow.Float32(),
});
expect(schema.fields.map((x) => x.name)).toEqual(["id", "text", "vector"]);
});
@@ -53,8 +56,8 @@ describe("Registry", () => {
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
embeddingDataType(): arrow.Float {
return new arrow.Float32();
}
async computeSourceEmbeddings(data: string[]) {
return data.map(() => [1, 2, 3]);
@@ -65,8 +68,8 @@ describe("Registry", () => {
.create();
const schema = LanceSchema({
id: new Int32(),
text: func.sourceField(new Utf8()),
id: new arrow.Int32(),
text: func.sourceField(new arrow.Utf8()),
vector: func.vectorField(),
});
@@ -88,7 +91,7 @@ describe("Registry", () => {
.getChild("vector")
?.toArray()
.map((x: unknown) => {
if (x instanceof Vector) {
if (x instanceof arrow.Vector) {
return [...x];
} else {
return x;
@@ -109,8 +112,8 @@ describe("Registry", () => {
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
embeddingDataType(): arrow.Float {
return new arrow.Float32();
}
async computeSourceEmbeddings(data: string[]) {
return data.map(() => [1, 2, 3]);
@@ -134,8 +137,8 @@ describe("Registry", () => {
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
embeddingDataType(): arrow.Float {
return new arrow.Float32();
}
async computeSourceEmbeddings(data: string[]) {
return data.map(() => [1, 2, 3]);
@@ -144,8 +147,8 @@ describe("Registry", () => {
const func = new MockEmbeddingFunction();
const schema = LanceSchema({
id: new Int32(),
text: func.sourceField(new Utf8()),
id: new arrow.Int32(),
text: func.sourceField(new arrow.Utf8()),
vector: func.vectorField(),
});
const expectedMetadata = new Map<string, string>([

View File

@@ -16,6 +16,10 @@ import * as fs from "fs";
import * as path from "path";
import * as tmp from "tmp";
import * as arrow from "apache-arrow";
import * as arrowOld from "apache-arrow-old";
import { Table, connect } from "../lancedb";
import {
Field,
FixedSizeList,
@@ -26,17 +30,20 @@ import {
Int64,
Schema,
Utf8,
} from "apache-arrow";
import { Table, connect } from "../lancedb";
import { makeArrowTable } from "../lancedb/arrow";
makeArrowTable,
} from "../lancedb/arrow";
import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding";
import { getRegistry, register } from "../lancedb/embedding/registry";
import { Index } from "../lancedb/indices";
describe("Given a table", () => {
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
describe.each([arrow, arrowOld])("Given a table", (arrow: any) => {
let tmpDir: tmp.DirResult;
let table: Table;
const schema = new Schema([new Field("id", new Float64(), true)]);
const schema = new arrow.Schema([
new arrow.Field("id", new arrow.Float64(), true),
]);
beforeEach(async () => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
const conn = await connect(tmpDir.name);
@@ -551,7 +558,7 @@ describe("embedding functions", () => {
const func = getRegistry().get<MockEmbeddingFunction>("mock")!.create();
const schema = LanceSchema({
id: new Float64(),
id: new arrow.Float64(),
text: func.sourceField(new Utf8()),
vector: func.vectorField(),
});