feat(nodejs): add compatibility across arrow versions (#1337)

While adding more docs and examples for the new JS SDK, I ran across
a few compatibility issues when using different Arrow versions. This
commit should fix those issues.
This commit is contained in:
Cory Grinstead
2024-05-29 17:36:34 -05:00
committed by GitHub
parent dbea3a7544
commit bc139000bd
12 changed files with 211 additions and 77 deletions

View File

@@ -11,18 +11,21 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import { Float, Float32, Int32, Utf8, Vector } from "apache-arrow";
import * as arrow from "apache-arrow";
import * as arrowOld from "apache-arrow-old";
import * as tmp from "tmp";
import { connect } from "../lancedb";
import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding";
import { getRegistry, register } from "../lancedb/embedding/registry";
describe("LanceSchema", () => {
describe.each([arrow, arrowOld])("LanceSchema", (arrow) => {
test("should preserve input order", async () => {
const schema = LanceSchema({
id: new Int32(),
text: new Utf8(),
vector: new Float32(),
id: new arrow.Int32(),
text: new arrow.Utf8(),
vector: new arrow.Float32(),
});
expect(schema.fields.map((x) => x.name)).toEqual(["id", "text", "vector"]);
});
@@ -53,8 +56,8 @@ describe("Registry", () => {
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
embeddingDataType(): arrow.Float {
return new arrow.Float32();
}
async computeSourceEmbeddings(data: string[]) {
return data.map(() => [1, 2, 3]);
@@ -65,8 +68,8 @@ describe("Registry", () => {
.create();
const schema = LanceSchema({
id: new Int32(),
text: func.sourceField(new Utf8()),
id: new arrow.Int32(),
text: func.sourceField(new arrow.Utf8()),
vector: func.vectorField(),
});
@@ -88,7 +91,7 @@ describe("Registry", () => {
.getChild("vector")
?.toArray()
.map((x: unknown) => {
if (x instanceof Vector) {
if (x instanceof arrow.Vector) {
return [...x];
} else {
return x;
@@ -109,8 +112,8 @@ describe("Registry", () => {
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
embeddingDataType(): arrow.Float {
return new arrow.Float32();
}
async computeSourceEmbeddings(data: string[]) {
return data.map(() => [1, 2, 3]);
@@ -134,8 +137,8 @@ describe("Registry", () => {
ndims() {
return 3;
}
embeddingDataType(): Float {
return new Float32();
embeddingDataType(): arrow.Float {
return new arrow.Float32();
}
async computeSourceEmbeddings(data: string[]) {
return data.map(() => [1, 2, 3]);
@@ -144,8 +147,8 @@ describe("Registry", () => {
const func = new MockEmbeddingFunction();
const schema = LanceSchema({
id: new Int32(),
text: func.sourceField(new Utf8()),
id: new arrow.Int32(),
text: func.sourceField(new arrow.Utf8()),
vector: func.vectorField(),
});
const expectedMetadata = new Map<string, string>([

View File

@@ -16,6 +16,10 @@ import * as fs from "fs";
import * as path from "path";
import * as tmp from "tmp";
import * as arrow from "apache-arrow";
import * as arrowOld from "apache-arrow-old";
import { Table, connect } from "../lancedb";
import {
Field,
FixedSizeList,
@@ -26,17 +30,20 @@ import {
Int64,
Schema,
Utf8,
} from "apache-arrow";
import { Table, connect } from "../lancedb";
import { makeArrowTable } from "../lancedb/arrow";
makeArrowTable,
} from "../lancedb/arrow";
import { EmbeddingFunction, LanceSchema } from "../lancedb/embedding";
import { getRegistry, register } from "../lancedb/embedding/registry";
import { Index } from "../lancedb/indices";
describe("Given a table", () => {
// biome-ignore lint/suspicious/noExplicitAny: the two arrow module versions have incompatible static types
describe.each([arrow, arrowOld])("Given a table", (arrow: any) => {
let tmpDir: tmp.DirResult;
let table: Table;
const schema = new Schema([new Field("id", new Float64(), true)]);
const schema = new arrow.Schema([
new arrow.Field("id", new arrow.Float64(), true),
]);
beforeEach(async () => {
tmpDir = tmp.dirSync({ unsafeCleanup: true });
const conn = await connect(tmpDir.name);
@@ -551,7 +558,7 @@ describe("embedding functions", () => {
const func = getRegistry().get<MockEmbeddingFunction>("mock")!.create();
const schema = LanceSchema({
id: new Float64(),
id: new arrow.Float64(),
text: func.sourceField(new Utf8()),
vector: func.vectorField(),
});

View File

@@ -17,10 +17,14 @@ import {
Binary,
DataType,
Field,
FixedSizeBinary,
FixedSizeList,
type Float,
Float,
Float32,
Int,
LargeBinary,
List,
Null,
RecordBatch,
RecordBatchFileWriter,
RecordBatchStreamWriter,
@@ -35,7 +39,98 @@ import {
} from "apache-arrow";
import { type EmbeddingFunction } from "./embedding/embedding_function";
import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry";
import { sanitizeSchema } from "./sanitize";
import { sanitizeField, sanitizeSchema, sanitizeType } from "./sanitize";
export * from "apache-arrow";
/**
 * Type guard for Arrow tables.
 *
 * Accepts genuine `ArrowTable` instances as well as any object that exposes
 * both a `schema` and a `batches` property, so that table-like objects that
 * fail the `instanceof` check (e.g. created by another copy of the arrow
 * library) are still recognized.
 *
 * @param value - The object to inspect.
 * @returns `true` when `value` is, or structurally looks like, an Arrow table.
 */
export function isArrowTable(value: object): value is ArrowTable {
  const looksLikeTable = () => "schema" in value && "batches" in value;
  return value instanceof ArrowTable || looksLikeTable();
}
/**
 * Type guard for Arrow data types.
 *
 * A value qualifies when it is an instance of this module's `DataType`, or
 * when any of the `DataType.isXxx` static checks listed below accepts it.
 * The static checks are tried in a fixed order and evaluation stops at the
 * first one that matches.
 *
 * @param value - The value to inspect.
 * @returns `true` when `value` is recognized as an Arrow data type.
 */
export function isDataType(value: unknown): value is DataType {
  if (value instanceof DataType) {
    return true;
  }

  // Each entry wraps the static method so it is always invoked on `DataType`.
  const typeChecks: Array<(candidate: unknown) => boolean> = [
    (candidate) => DataType.isNull(candidate),
    (candidate) => DataType.isInt(candidate),
    (candidate) => DataType.isFloat(candidate),
    (candidate) => DataType.isBinary(candidate),
    (candidate) => DataType.isLargeBinary(candidate),
    (candidate) => DataType.isUtf8(candidate),
    (candidate) => DataType.isLargeUtf8(candidate),
    (candidate) => DataType.isBool(candidate),
    (candidate) => DataType.isDecimal(candidate),
    (candidate) => DataType.isDate(candidate),
    (candidate) => DataType.isTime(candidate),
    (candidate) => DataType.isTimestamp(candidate),
    (candidate) => DataType.isInterval(candidate),
    (candidate) => DataType.isDuration(candidate),
    (candidate) => DataType.isList(candidate),
    (candidate) => DataType.isStruct(candidate),
    (candidate) => DataType.isUnion(candidate),
    (candidate) => DataType.isFixedSizeBinary(candidate),
    (candidate) => DataType.isFixedSizeList(candidate),
    (candidate) => DataType.isMap(candidate),
    (candidate) => DataType.isDictionary(candidate),
  ];
  return typeChecks.some((check) => check(value));
}
/**
 * Type guard for the Arrow `Null` type.
 *
 * @param value - The value to inspect.
 * @returns `true` when `value` is a `Null` instance or is accepted by
 *   `DataType.isNull`.
 */
export function isNull(value: unknown): value is Null {
  if (value instanceof Null) {
    return true;
  }
  return DataType.isNull(value);
}
/**
 * Type guard for Arrow integer types.
 *
 * @param value - The value to inspect.
 * @returns `true` when `value` is an `Int` instance or is accepted by
 *   `DataType.isInt`.
 */
export function isInt(value: unknown): value is Int {
  if (value instanceof Int) {
    return true;
  }
  return DataType.isInt(value);
}
/**
 * Type guard for Arrow floating-point types.
 *
 * @param value - The value to inspect.
 * @returns `true` when `value` is a `Float` instance or is accepted by
 *   `DataType.isFloat`.
 */
export function isFloat(value: unknown): value is Float {
  if (value instanceof Float) {
    return true;
  }
  return DataType.isFloat(value);
}
/**
 * Type guard for the Arrow `Binary` type.
 *
 * @param value - The value to inspect.
 * @returns `true` when `value` is a `Binary` instance or is accepted by
 *   `DataType.isBinary`.
 */
export function isBinary(value: unknown): value is Binary {
  if (value instanceof Binary) {
    return true;
  }
  return DataType.isBinary(value);
}
/**
 * Type guard for the Arrow `LargeBinary` type.
 *
 * @param value - The value to inspect.
 * @returns `true` when `value` is a `LargeBinary` instance or is accepted by
 *   `DataType.isLargeBinary`.
 */
export function isLargeBinary(value: unknown): value is LargeBinary {
  if (value instanceof LargeBinary) {
    return true;
  }
  return DataType.isLargeBinary(value);
}
/**
 * Type guard for the Arrow `Utf8` (string) type.
 *
 * @param value - The value to inspect.
 * @returns `true` when `value` is a `Utf8` instance or is accepted by
 *   `DataType.isUtf8`.
 */
export function isUtf8(value: unknown): value is Utf8 {
  if (value instanceof Utf8) {
    return true;
  }
  return DataType.isUtf8(value);
}
/**
 * Type guard for the Arrow `LargeUtf8` type.
 *
 * The decision is delegated entirely to `DataType.isLargeUtf8`. No
 * `instanceof Utf8` shortcut is used: a plain `Utf8` is a different type and
 * must not be reported as `LargeUtf8`.
 *
 * The predicate narrows to the general `DataType` rather than to `Utf8`,
 * because a `LargeUtf8` value is not a `Utf8`.
 *
 * @param value - The value to inspect.
 * @returns `true` when `DataType.isLargeUtf8` accepts `value`.
 */
export function isLargeUtf8(value: unknown): value is DataType {
  return DataType.isLargeUtf8(value);
}
/**
 * Type guard for the Arrow `Bool` type.
 *
 * The decision is delegated entirely to `DataType.isBool`. No
 * `instanceof Utf8` shortcut is used: a `Utf8` string type is not a boolean
 * type and must not be reported as one.
 *
 * The predicate narrows to the general `DataType` rather than to `Utf8`,
 * because a `Bool` value is not a `Utf8`.
 *
 * @param value - The value to inspect.
 * @returns `true` when `DataType.isBool` accepts `value`.
 */
export function isBool(value: unknown): value is DataType {
  return DataType.isBool(value);
}
/**
 * Type guard for the Arrow `Decimal` type.
 *
 * The decision is delegated entirely to `DataType.isDecimal`. No
 * `instanceof Utf8` shortcut is used: a `Utf8` string type is not a decimal
 * type and must not be reported as one.
 *
 * The predicate narrows to the general `DataType` rather than to `Utf8`,
 * because a `Decimal` value is not a `Utf8`.
 *
 * @param value - The value to inspect.
 * @returns `true` when `DataType.isDecimal` accepts `value`.
 */
export function isDecimal(value: unknown): value is DataType {
  return DataType.isDecimal(value);
}
/**
 * Type guard for the Arrow `Date` type.
 *
 * The decision is delegated entirely to `DataType.isDate`. No
 * `instanceof Utf8` shortcut is used: a `Utf8` string type is not a date
 * type and must not be reported as one.
 *
 * The predicate narrows to the general `DataType` rather than to `Utf8`,
 * because a date type is not a `Utf8`.
 *
 * @param value - The value to inspect.
 * @returns `true` when `DataType.isDate` accepts `value`.
 */
export function isDate(value: unknown): value is DataType {
  return DataType.isDate(value);
}
/**
 * Type guard for the Arrow `Time` type.
 *
 * The decision is delegated entirely to `DataType.isTime`. No
 * `instanceof Utf8` shortcut is used: a `Utf8` string type is not a time
 * type and must not be reported as one.
 *
 * The predicate narrows to the general `DataType` rather than to `Utf8`,
 * because a time type is not a `Utf8`.
 *
 * @param value - The value to inspect.
 * @returns `true` when `DataType.isTime` accepts `value`.
 */
export function isTime(value: unknown): value is DataType {
  return DataType.isTime(value);
}
/**
 * Type guard for the Arrow `Timestamp` type.
 *
 * The decision is delegated entirely to `DataType.isTimestamp`. No
 * `instanceof Utf8` shortcut is used: a `Utf8` string type is not a
 * timestamp type and must not be reported as one.
 *
 * The predicate narrows to the general `DataType` rather than to `Utf8`,
 * because a timestamp type is not a `Utf8`.
 *
 * @param value - The value to inspect.
 * @returns `true` when `DataType.isTimestamp` accepts `value`.
 */
export function isTimestamp(value: unknown): value is DataType {
  return DataType.isTimestamp(value);
}
/**
 * Type guard for the Arrow `Interval` type.
 *
 * The decision is delegated entirely to `DataType.isInterval`. No
 * `instanceof Utf8` shortcut is used: a `Utf8` string type is not an
 * interval type and must not be reported as one.
 *
 * The predicate narrows to the general `DataType` rather than to `Utf8`,
 * because an interval type is not a `Utf8`.
 *
 * @param value - The value to inspect.
 * @returns `true` when `DataType.isInterval` accepts `value`.
 */
export function isInterval(value: unknown): value is DataType {
  return DataType.isInterval(value);
}
/**
 * Type guard for the Arrow `Duration` type.
 *
 * The decision is delegated entirely to `DataType.isDuration`. No
 * `instanceof Utf8` shortcut is used: a `Utf8` string type is not a
 * duration type and must not be reported as one.
 *
 * The predicate narrows to the general `DataType` rather than to `Utf8`,
 * because a duration type is not a `Utf8`.
 *
 * @param value - The value to inspect.
 * @returns `true` when `DataType.isDuration` accepts `value`.
 */
export function isDuration(value: unknown): value is DataType {
  return DataType.isDuration(value);
}
/**
 * Type guard for the Arrow `List` type.
 *
 * @param value - The value to inspect.
 * @returns `true` when `value` is a `List` instance or is accepted by
 *   `DataType.isList`.
 */
export function isList(value: unknown): value is List {
  if (value instanceof List) {
    return true;
  }
  return DataType.isList(value);
}
/**
 * Type guard for the Arrow `Struct` type.
 *
 * @param value - The value to inspect.
 * @returns `true` when `value` is a `Struct` instance or is accepted by
 *   `DataType.isStruct`.
 */
export function isStruct(value: unknown): value is Struct {
  if (value instanceof Struct) {
    return true;
  }
  return DataType.isStruct(value);
}
/**
 * Type guard for the Arrow `Union` type.
 *
 * The decision is delegated entirely to `DataType.isUnion`. No
 * `instanceof Struct` shortcut is used: a `Struct` is a different nested
 * type and must not be reported as a union.
 *
 * The predicate narrows to the general `DataType` rather than to `Struct`,
 * because a union type is not a `Struct`.
 *
 * @param value - The value to inspect.
 * @returns `true` when `DataType.isUnion` accepts `value`.
 */
export function isUnion(value: unknown): value is DataType {
  return DataType.isUnion(value);
}
/**
 * Type guard for the Arrow `FixedSizeBinary` type.
 *
 * @param value - The value to inspect.
 * @returns `true` when `value` is a `FixedSizeBinary` instance or is
 *   accepted by `DataType.isFixedSizeBinary`.
 */
export function isFixedSizeBinary(value: unknown): value is FixedSizeBinary {
  if (value instanceof FixedSizeBinary) {
    return true;
  }
  return DataType.isFixedSizeBinary(value);
}
/**
 * Type guard for the Arrow `FixedSizeList` type.
 *
 * @param value - The value to inspect.
 * @returns `true` when `value` is a `FixedSizeList` instance or is accepted
 *   by `DataType.isFixedSizeList`.
 */
export function isFixedSizeList(value: unknown): value is FixedSizeList {
  if (value instanceof FixedSizeList) {
    return true;
  }
  return DataType.isFixedSizeList(value);
}
/** Data type accepted by NodeJS SDK */
export type Data = Record<string, unknown>[] | ArrowTable;
@@ -442,8 +537,8 @@ async function applyEmbeddingsFromMetadata(
}
let destType: DataType;
const dtype = schema.fields.find((f) => f.name === destColumn)!.type;
if (dtype instanceof FixedSizeList) {
destType = dtype;
if (isFixedSizeList(dtype)) {
destType = sanitizeType(dtype);
} else {
throw new Error(
"Expected FixedSizeList as datatype for vector field, instead got: " +
@@ -588,7 +683,7 @@ export function newVectorType<T extends Float>(
): FixedSizeList<T> {
// in Lance we always default to have the elements nullable, so we need to set it to true
// otherwise we often get schema mismatches because the stored data always has schema with nullable elements
const children = new Field<T>("item", innerType, true);
const children = new Field("item", <T>sanitizeType(innerType), true);
return new FixedSizeList(dim, children);
}
@@ -669,7 +764,7 @@ export async function fromDataToBuffer(
if (schema !== undefined && schema !== null) {
schema = sanitizeSchema(schema);
}
if (data instanceof ArrowTable) {
if (isArrowTable(data)) {
return fromTableToBuffer(data, embeddings, schema);
} else {
const table = await convertToTable(data, embeddings, { schema });
@@ -750,8 +845,10 @@ function validateSchemaEmbeddings(
// if it does not, we add it to the list of missing embedding fields
// Finally, we check if those missing embedding fields are `this._embeddings`
// if they are not, we throw an error
for (const field of schema.fields) {
if (field.type instanceof FixedSizeList) {
for (let field of schema.fields) {
if (isFixedSizeList(field.type)) {
field = sanitizeField(field);
if (data.length !== 0 && data?.[0]?.[field.name] === undefined) {
if (schema.metadata.has("embedding_functions")) {
const embeddings = JSON.parse(

View File

@@ -12,8 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
import { Table as ArrowTable, Schema } from "apache-arrow";
import { fromTableToBuffer, makeArrowTable, makeEmptyTable } from "./arrow";
import { Table as ArrowTable, Schema } from "./arrow";
import {
fromTableToBuffer,
isArrowTable,
makeArrowTable,
makeEmptyTable,
} from "./arrow";
import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry";
import { ConnectionOptions, Connection as LanceDbConnection } from "./native";
import { Table } from "./table";
@@ -200,7 +205,7 @@ export class Connection {
}
let table: ArrowTable;
if (data instanceof ArrowTable) {
if (isArrowTable(data)) {
table = data;
} else {
table = makeArrowTable(data, options);

View File

@@ -12,9 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
import { DataType, Field, FixedSizeList, Float, Float32 } from "apache-arrow";
import "reflect-metadata";
import { newVectorType } from "../arrow";
import {
DataType,
Field,
FixedSizeList,
Float,
Float32,
isDataType,
isFixedSizeList,
isFloat,
newVectorType,
} from "../arrow";
import { sanitizeType } from "../sanitize";
/**
* Options for a given embedding function
@@ -69,13 +79,13 @@ export abstract class EmbeddingFunction<
sourceField(
optionsOrDatatype: Partial<FieldOptions> | DataType,
): [DataType, Map<string, EmbeddingFunction>] {
const datatype =
optionsOrDatatype instanceof DataType
? optionsOrDatatype
: optionsOrDatatype?.datatype;
let datatype = isDataType(optionsOrDatatype)
? optionsOrDatatype
: optionsOrDatatype?.datatype;
if (!datatype) {
throw new Error("Datatype is required");
}
datatype = sanitizeType(datatype);
const metadata = new Map<string, EmbeddingFunction>();
metadata.set("source_column_for", this);
@@ -100,9 +110,9 @@ export abstract class EmbeddingFunction<
}
dtype = new FixedSizeList(dims, new Field("item", new Float32(), true));
} else {
if (options.datatype instanceof FixedSizeList) {
if (isFixedSizeList(options.datatype)) {
dtype = options.datatype;
} else if (options.datatype instanceof Float) {
} else if (isFloat(options.datatype)) {
if (dims === undefined) {
throw new Error("ndims is required for vector field");
}

View File

@@ -12,12 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
import { DataType, Field, Schema } from "apache-arrow";
import { DataType, Field, Schema } from "../arrow";
import { isDataType } from "../arrow";
import { sanitizeType } from "../sanitize";
import { EmbeddingFunction } from "./embedding_function";
import { EmbeddingFunctionConfig, getRegistry } from "./registry";
export { EmbeddingFunction } from "./embedding_function";
// We need to explicitly export '*' so that the `register` decorator actually registers the class.
export * from "./openai";
export * from "./registry";
/**
* Create a schema with embedding functions.
@@ -42,7 +47,7 @@ export * from "./openai";
* ```
*/
export function LanceSchema(
fields: Record<string, [DataType, Map<string, EmbeddingFunction>] | DataType>,
fields: Record<string, [object, Map<string, EmbeddingFunction>] | object>,
): Schema {
const arrowFields: Field[] = [];
@@ -51,11 +56,14 @@ export function LanceSchema(
Partial<EmbeddingFunctionConfig>
>();
Object.entries(fields).forEach(([key, value]) => {
if (value instanceof DataType) {
arrowFields.push(new Field(key, value, true));
if (isDataType(value)) {
arrowFields.push(new Field(key, sanitizeType(value), true));
} else {
const [dtype, metadata] = value;
arrowFields.push(new Field(key, dtype, true));
const [dtype, metadata] = value as [
object,
Map<string, EmbeddingFunction>,
];
arrowFields.push(new Field(key, sanitizeType(dtype), true));
parseEmbeddingFunctions(embeddingFunctions, key, metadata);
}
});

View File

@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
import { Float, Float32 } from "apache-arrow";
import type OpenAI from "openai";
import { Float, Float32 } from "../arrow";
import { EmbeddingFunction } from "./embedding_function";
import { register } from "./registry";

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
import { Table as ArrowTable, RecordBatch, tableFromIPC } from "apache-arrow";
import { Table as ArrowTable, RecordBatch, tableFromIPC } from "./arrow";
import { type IvfPqOptions } from "./indices";
import {
RecordBatchIterator as NativeBatchIterator,

View File

@@ -20,6 +20,7 @@
// comes from the exact same library instance. This is not always the case
// and so we must sanitize the input to ensure that it is compatible.
import type { IntBitWidth, TKeys, TimeBitWidth } from "apache-arrow/type";
import {
Binary,
Bool,
@@ -75,10 +76,9 @@ import {
Uint64,
Union,
Utf8,
} from "apache-arrow";
import type { IntBitWidth, TKeys, TimeBitWidth } from "apache-arrow/type";
} from "./arrow";
function sanitizeMetadata(
export function sanitizeMetadata(
metadataLike?: unknown,
): Map<string, string> | undefined {
if (metadataLike === undefined || metadataLike === null) {
@@ -97,7 +97,7 @@ function sanitizeMetadata(
return metadataLike as Map<string, string>;
}
function sanitizeInt(typeLike: object) {
export function sanitizeInt(typeLike: object) {
if (
!("bitWidth" in typeLike) ||
typeof typeLike.bitWidth !== "number" ||
@@ -111,14 +111,14 @@ function sanitizeInt(typeLike: object) {
return new Int(typeLike.isSigned, typeLike.bitWidth as IntBitWidth);
}
function sanitizeFloat(typeLike: object) {
export function sanitizeFloat(typeLike: object) {
if (!("precision" in typeLike) || typeof typeLike.precision !== "number") {
throw Error("Expected a Float Type to have a `precision` property");
}
return new Float(typeLike.precision as Precision);
}
function sanitizeDecimal(typeLike: object) {
export function sanitizeDecimal(typeLike: object) {
if (
!("scale" in typeLike) ||
typeof typeLike.scale !== "number" ||
@@ -134,14 +134,14 @@ function sanitizeDecimal(typeLike: object) {
return new Decimal(typeLike.scale, typeLike.precision, typeLike.bitWidth);
}
function sanitizeDate(typeLike: object) {
export function sanitizeDate(typeLike: object) {
if (!("unit" in typeLike) || typeof typeLike.unit !== "number") {
throw Error("Expected a Date type to have a `unit` property");
}
return new Date_(typeLike.unit as DateUnit);
}
function sanitizeTime(typeLike: object) {
export function sanitizeTime(typeLike: object) {
if (
!("unit" in typeLike) ||
typeof typeLike.unit !== "number" ||
@@ -155,7 +155,7 @@ function sanitizeTime(typeLike: object) {
return new Time(typeLike.unit, typeLike.bitWidth as TimeBitWidth);
}
function sanitizeTimestamp(typeLike: object) {
export function sanitizeTimestamp(typeLike: object) {
if (!("unit" in typeLike) || typeof typeLike.unit !== "number") {
throw Error("Expected a Timestamp type to have a `unit` property");
}
@@ -166,7 +166,7 @@ function sanitizeTimestamp(typeLike: object) {
return new Timestamp(typeLike.unit, timezone);
}
function sanitizeTypedTimestamp(
export function sanitizeTypedTimestamp(
typeLike: object,
// eslint-disable-next-line @typescript-eslint/naming-convention
Datatype:
@@ -182,14 +182,14 @@ function sanitizeTypedTimestamp(
return new Datatype(timezone);
}
function sanitizeInterval(typeLike: object) {
export function sanitizeInterval(typeLike: object) {
if (!("unit" in typeLike) || typeof typeLike.unit !== "number") {
throw Error("Expected an Interval type to have a `unit` property");
}
return new Interval(typeLike.unit);
}
function sanitizeList(typeLike: object) {
export function sanitizeList(typeLike: object) {
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
throw Error(
"Expected a List type to have an array-like `children` property",
@@ -201,7 +201,7 @@ function sanitizeList(typeLike: object) {
return new List(sanitizeField(typeLike.children[0]));
}
function sanitizeStruct(typeLike: object) {
export function sanitizeStruct(typeLike: object) {
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
throw Error(
"Expected a Struct type to have an array-like `children` property",
@@ -210,7 +210,7 @@ function sanitizeStruct(typeLike: object) {
return new Struct(typeLike.children.map((child) => sanitizeField(child)));
}
function sanitizeUnion(typeLike: object) {
export function sanitizeUnion(typeLike: object) {
if (
!("typeIds" in typeLike) ||
!("mode" in typeLike) ||
@@ -234,7 +234,7 @@ function sanitizeUnion(typeLike: object) {
);
}
function sanitizeTypedUnion(
export function sanitizeTypedUnion(
typeLike: object,
// eslint-disable-next-line @typescript-eslint/naming-convention
UnionType: typeof DenseUnion | typeof SparseUnion,
@@ -256,7 +256,7 @@ function sanitizeTypedUnion(
);
}
function sanitizeFixedSizeBinary(typeLike: object) {
export function sanitizeFixedSizeBinary(typeLike: object) {
if (!("byteWidth" in typeLike) || typeof typeLike.byteWidth !== "number") {
throw Error(
"Expected a FixedSizeBinary type to have a `byteWidth` property",
@@ -265,7 +265,7 @@ function sanitizeFixedSizeBinary(typeLike: object) {
return new FixedSizeBinary(typeLike.byteWidth);
}
function sanitizeFixedSizeList(typeLike: object) {
export function sanitizeFixedSizeList(typeLike: object) {
if (!("listSize" in typeLike) || typeof typeLike.listSize !== "number") {
throw Error("Expected a FixedSizeList type to have a `listSize` property");
}
@@ -283,7 +283,7 @@ function sanitizeFixedSizeList(typeLike: object) {
);
}
function sanitizeMap(typeLike: object) {
export function sanitizeMap(typeLike: object) {
if (!("children" in typeLike) || !Array.isArray(typeLike.children)) {
throw Error(
"Expected a Map type to have an array-like `children` property",
@@ -300,14 +300,14 @@ function sanitizeMap(typeLike: object) {
);
}
function sanitizeDuration(typeLike: object) {
export function sanitizeDuration(typeLike: object) {
if (!("unit" in typeLike) || typeof typeLike.unit !== "number") {
throw Error("Expected a Duration type to have a `unit` property");
}
return new Duration(typeLike.unit);
}
function sanitizeDictionary(typeLike: object) {
export function sanitizeDictionary(typeLike: object) {
if (!("id" in typeLike) || typeof typeLike.id !== "number") {
throw Error("Expected a Dictionary type to have an `id` property");
}
@@ -329,7 +329,7 @@ function sanitizeDictionary(typeLike: object) {
}
// biome-ignore lint/suspicious/noExplicitAny: skip
function sanitizeType(typeLike: unknown): DataType<any> {
export function sanitizeType(typeLike: unknown): DataType<any> {
if (typeof typeLike !== "object" || typeLike === null) {
throw Error("Expected a Type but object was null/undefined");
}
@@ -449,7 +449,7 @@ function sanitizeType(typeLike: unknown): DataType<any> {
}
}
function sanitizeField(fieldLike: unknown): Field {
export function sanitizeField(fieldLike: unknown): Field {
if (fieldLike instanceof Field) {
return fieldLike;
}

View File

@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
import { Schema, tableFromIPC } from "apache-arrow";
import { Data, fromDataToBuffer } from "./arrow";
import { Data, Schema, fromDataToBuffer, tableFromIPC } from "./arrow";
import { getRegistry } from "./embedding/registry";
import { IndexOptions } from "./indices";
import {

View File

@@ -1,12 +1,12 @@
{
"name": "@lancedb/lancedb",
"version": "0.4.20",
"version": "0.5.0",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "@lancedb/lancedb",
"version": "0.4.20",
"version": "0.5.0",
"cpu": [
"x64",
"arm64"

View File

@@ -1,8 +1,12 @@
{
"name": "@lancedb/lancedb",
"version": "0.5.0",
"main": "./dist/index.js",
"types": "./dist/index.d.ts",
"main": "dist/index.js",
"exports": {
".": "./dist/index.js",
"./embedding": "./dist/embedding/index.js"
},
"types": "dist/index.d.ts",
"napi": {
"name": "lancedb",
"triples": {