Compare commits


1 Commit

Author         SHA1        Message                             Date
albertlockett  dcfa17c9fc  temporarily use local dependencies  2024-06-26 15:28:30 -03:00
35 changed files with 339 additions and 4039 deletions

View File

@@ -24,7 +24,7 @@ env:
jobs:
test-python:
name: Test doc python code
runs-on: "warp-ubuntu-latest-x64-4x"
runs-on: "buildjet-8vcpu-ubuntu-2204"
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -56,7 +56,7 @@ jobs:
for d in *; do cd "$d"; echo "$d".py; python "$d".py; cd ..; done
test-node:
name: Test doc nodejs code
runs-on: "warp-ubuntu-latest-x64-4x"
runs-on: "buildjet-8vcpu-ubuntu-2204"
timeout-minutes: 60
strategy:
fail-fast: false

View File

@@ -14,7 +14,7 @@ repos:
hooks:
- id: local-biome-check
name: biome check
entry: npx @biomejs/biome@1.8.3 check --config-path nodejs/biome.json nodejs/
entry: npx @biomejs/biome@1.7.3 check --config-path nodejs/biome.json nodejs/
language: system
types: [text]
files: "nodejs/.*"

View File

@@ -20,11 +20,18 @@ keywords = ["lancedb", "lance", "database", "vector", "search"]
categories = ["database-implementations"]
[workspace.dependencies]
lance = { "version" = "=0.13.0", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.13.0" }
lance-linalg = { "version" = "=0.13.0" }
lance-testing = { "version" = "=0.13.0" }
lance-datafusion = { "version" = "=0.13.0" }
# lance = { "version" = "=0.13.0", "features" = ["dynamodb"] }
# lance-index = { "version" = "=0.13.0" }
# lance-linalg = { "version" = "=0.13.0" }
# lance-testing = { "version" = "=0.13.0" }
# lance-datafusion = { "version" = "=0.13.0" }
lance = { path = "../lance/rust/lance" }
lance-index = { path = "../lance/rust/lance-index" }
lance-linalg = { path = "../lance/rust/lance-linalg" }
lance-testing = { path = "../lance/rust/lance-testing" }
lance-datafusion = { path = "../lance/rust/lance-datafusion" }
# Note that this one does not include pyarrow
arrow = { version = "51.0", optional = false }
arrow-array = "51.0"

View File

@@ -57,8 +57,6 @@ plugins:
- https://arrow.apache.org/docs/objects.inv
- https://pandas.pydata.org/docs/objects.inv
- mkdocs-jupyter
- render_swagger:
allow_arbitrary_locations : true
markdown_extensions:
- admonition
@@ -160,7 +158,6 @@ nav:
- API reference:
- 🐍 Python: python/saas-python.md
- 👾 JavaScript: javascript/modules.md
- REST API: cloud/rest.md
- Quick start: basic.md
- Concepts:
@@ -231,7 +228,6 @@ nav:
- API reference:
- 🐍 Python: python/saas-python.md
- 👾 JavaScript: javascript/modules.md
- REST API: cloud/rest.md
extra_css:
- styles/global.css

View File

@@ -1,479 +0,0 @@
openapi: 3.1.0
info:
version: 1.0.0
title: LanceDB Cloud API
description: |
LanceDB Cloud API is a RESTful API that allows users to access and modify data stored in LanceDB Cloud.
Table actions are considered temporary resource creations and all use the POST method.
contact:
name: LanceDB support
url: https://lancedb.com
email: contact@lancedb.com
servers:
- url: https://{db}.{region}.api.lancedb.com
description: LanceDB Cloud REST endpoint.
variables:
db:
default: ""
description: the name of the DB
region:
default: "us-east-1"
description: the service region of the DB
security:
- key_auth: []
components:
securitySchemes:
key_auth:
name: x-api-key
type: apiKey
in: header
parameters:
table_name:
name: name
in: path
description: name of the table
required: true
schema:
type: string
responses:
invalid_request:
description: Invalid request
content:
text/plain:
schema:
type: string
not_found:
description: Not found
content:
text/plain:
schema:
type: string
unauthorized:
description: Unauthorized
content:
text/plain:
schema:
type: string
requestBodies:
arrow_stream_buffer:
description: Arrow IPC stream buffer
required: true
content:
application/vnd.apache.arrow.stream:
schema:
type: string
format: binary
paths:
/v1/table/:
get:
description: List tables, optionally with pagination.
tags:
- Tables
summary: List Tables
operationId: listTables
parameters:
- name: limit
in: query
description: Limits the number of items to return.
schema:
type: integer
- name: page_token
in: query
description: Specifies the starting position of the next query.
schema:
type: string
responses:
"200":
description: Successfully returned a list of tables in the DB
content:
application/json:
schema:
type: object
properties:
tables:
type: array
items:
type: string
page_token:
type: string
"400":
$ref: "#/components/responses/invalid_request"
"401":
$ref: "#/components/responses/unauthorized"
"404":
$ref: "#/components/responses/not_found"
/v1/table/{name}/create/:
post:
description: Create a new table
summary: Create a new table
operationId: createTable
tags:
- Tables
parameters:
- $ref: "#/components/parameters/table_name"
requestBody:
$ref: "#/components/requestBodies/arrow_stream_buffer"
responses:
"200":
description: Table successfully created
"400":
$ref: "#/components/responses/invalid_request"
"401":
$ref: "#/components/responses/unauthorized"
"404":
$ref: "#/components/responses/not_found"
/v1/table/{name}/query/:
post:
description: Vector Query
url: https://{db-uri}.{aws-region}.api.lancedb.com/v1/table/{name}/query/
tags:
- Data
summary: Vector Query
parameters:
- $ref: "#/components/parameters/table_name"
requestBody:
required: true
content:
application/json:
schema:
type: object
properties:
vector:
type: FixedSizeList
description: |
The targeted vector to search for. Required.
vector_column:
type: string
description: |
The column to query; it can be inferred from the schema if there is only one vector column.
prefilter:
type: boolean
description: |
Whether to prefilter the data. Optional.
k:
type: integer
description: |
The number of search results to return. Default is 10.
distance_type:
type: string
description: |
The distance metric to use for search. L2, Cosine, Dot and Hamming are supported. Default is L2.
bypass_vector_index:
type: boolean
description: |
Whether to bypass vector index. Optional.
filter:
type: string
description: |
A filter expression that specifies the rows to query. Optional.
columns:
type: array
items:
type: string
description: |
The columns to return. Optional.
nprobe:
type: integer
description: |
The number of probes to use for search. Optional.
refine_factor:
type: integer
description: |
The refine factor to use for search. Optional.
responses:
"200":
description: top k results if query is successfully executed
content:
application/json:
schema:
type: object
properties:
results:
type: array
items:
type: object
properties:
id:
type: integer
selected_col_1_to_return:
type: col_1_type
selected_col_n_to_return:
type: col_n_type
_distance:
type: float
"400":
$ref: "#/components/responses/invalid_request"
"401":
$ref: "#/components/responses/unauthorized"
"404":
$ref: "#/components/responses/not_found"
/v1/table/{name}/insert/:
post:
description: Insert new data to the Table.
tags:
- Data
operationId: insertData
summary: Insert new data.
parameters:
- $ref: "#/components/parameters/table_name"
requestBody:
$ref: "#/components/requestBodies/arrow_stream_buffer"
responses:
"200":
description: Insert successful
"400":
$ref: "#/components/responses/invalid_request"
"401":
$ref: "#/components/responses/unauthorized"
"404":
$ref: "#/components/responses/not_found"
/v1/table/{name}/merge_insert/:
post:
description: Create a "merge insert" operation.
This operation can add rows, update rows, and remove rows all in a single
transaction. See the Python method `lancedb.table.Table.merge_insert` for examples.
tags:
- Data
summary: Merge Insert
operationId: mergeInsert
parameters:
- $ref: "#/components/parameters/table_name"
- name: on
in: query
description: |
The column to use as the primary key for the merge operation.
required: true
schema:
type: string
- name: when_matched_update_all
in: query
description: |
Rows that exist in both the source table (new data) and
the target table (old data) will be updated, replacing
the old row with the corresponding matching row.
required: false
schema:
type: boolean
- name: when_matched_update_all_filt
in: query
description: |
If present then only rows that satisfy the filter expression will
be updated
required: false
schema:
type: string
- name: when_not_matched_insert_all
in: query
description: |
Rows that exist only in the source table (new data) will be
inserted into the target table (old data).
required: false
schema:
type: boolean
- name: when_not_matched_by_source_delete
in: query
description: |
Rows that exist only in the target table (old data) will be
deleted. An optional condition (`when_not_matched_by_source_delete_filt`)
can be provided to limit what data is deleted.
required: false
schema:
type: boolean
- name: when_not_matched_by_source_delete_filt
in: query
description: |
The filter expression that specifies the rows to delete.
required: false
schema:
type: string
requestBody:
$ref: "#/components/requestBodies/arrow_stream_buffer"
responses:
"200":
description: Merge Insert successful
"400":
$ref: "#/components/responses/invalid_request"
"401":
$ref: "#/components/responses/unauthorized"
"404":
$ref: "#/components/responses/not_found"
/v1/table/{name}/delete/:
post:
description: Delete rows from a table.
tags:
- Data
summary: Delete rows from a table
operationId: deleteData
parameters:
- $ref: "#/components/parameters/table_name"
requestBody:
required: true
content:
application/json:
schema:
type: object
properties:
predicate:
type: string
description: |
A filter expression that specifies the rows to delete.
responses:
"200":
description: Delete successful
"401":
$ref: "#/components/responses/unauthorized"
/v1/table/{name}/drop/:
post:
description: Drop a table
tags:
- Tables
summary: Drop a table
operationId: dropTable
parameters:
- $ref: "#/components/parameters/table_name"
requestBody:
$ref: "#/components/requestBodies/arrow_stream_buffer"
responses:
"200":
description: Drop successful
"401":
$ref: "#/components/responses/unauthorized"
/v1/table/{name}/describe/:
post:
description: Describe a table and return Table Information.
tags:
- Tables
summary: Describe a table
operationId: describeTable
parameters:
- $ref: "#/components/parameters/table_name"
responses:
"200":
description: Table information
content:
application/json:
schema:
type: object
properties:
table:
type: string
version:
type: integer
schema:
type: string
stats:
type: object
"401":
$ref: "#/components/responses/unauthorized"
"404":
$ref: "#/components/responses/not_found"
/v1/table/{name}/index/list/:
post:
description: List indexes of a table
tags:
- Tables
summary: List indexes of a table
operationId: listIndexes
parameters:
- $ref: "#/components/parameters/table_name"
responses:
"200":
description: Available list of indexes on the table.
content:
application/json:
schema:
type: object
properties:
indexes:
type: array
items:
type: object
properties:
columns:
type: array
items:
type: string
index_name:
type: string
index_uuid:
type: string
"401":
$ref: "#/components/responses/unauthorized"
"404":
$ref: "#/components/responses/not_found"
/v1/table/{name}/create_index/:
post:
description: Create vector index on a Table
tags:
- Tables
summary: Create vector index on a Table
operationId: createIndex
parameters:
- $ref: "#/components/parameters/table_name"
requestBody:
required: true
content:
application/json:
schema:
type: object
properties:
column:
type: string
metric_type:
type: string
nullable: false
description: |
The metric type to use for the index. L2, Cosine, Dot are supported.
index_type:
type: string
responses:
"200":
description: Index successfully created
"400":
$ref: "#/components/responses/invalid_request"
"401":
$ref: "#/components/responses/unauthorized"
"404":
$ref: "#/components/responses/not_found"
/v1/table/{name}/create_scalar_index/:
post:
description: Create a scalar index on a table
tags:
- Tables
summary: Create a scalar index on a table
operationId: createScalarIndex
parameters:
- $ref: "#/components/parameters/table_name"
requestBody:
required: true
content:
application/json:
schema:
type: object
properties:
column:
type: string
index_type:
type: string
required: false
responses:
"200":
description: Scalar Index successfully created
"400":
$ref: "#/components/responses/invalid_request"
"401":
$ref: "#/components/responses/unauthorized"
"404":
$ref: "#/components/responses/not_found"

View File

@@ -2,5 +2,4 @@ mkdocs==1.5.3
mkdocs-jupyter==0.24.1
mkdocs-material==9.5.3
mkdocstrings[python]==0.20.0
mkdocs-render-swagger-plugin
pydantic
pydantic

View File

@@ -1 +0,0 @@
!!swagger ../../openapi.yml!!

View File

@@ -193,13 +193,13 @@ from lancedb.pydantic import LanceModel, Vector
model = get_registry().get("huggingface").create(name='facebook/bart-base')
class Words(LanceModel):
class TextModel(LanceModel):
text: str = model.SourceField()
vector: Vector(model.ndims()) = model.VectorField()
df = pd.DataFrame({"text": ["hi hello sayonara", "goodbye world"]})
table = db.create_table("greets", schema=Words)
table.add(df)
table.add()
query = "old greeting"
actual = table.search(query).limit(1).to_pydantic(Words)[0]
print(actual.text)

View File

@@ -265,108 +265,6 @@ For **read-only access**, LanceDB will need a policy such as:
}
```
#### DynamoDB Commit Store for concurrent writes
By default, S3 does not support concurrent writes. Having two or more processes
writing to the same table at the same time can lead to data corruption. This is
because S3, unlike other object stores, does not have any atomic put or copy
operation.
To enable concurrent writes, you can configure LanceDB to use a DynamoDB table
as a commit store. This table will be used to coordinate writes between
different processes. To enable this feature, you must modify your connection
URI to use the `s3+ddb` scheme and add a query parameter `ddbTableName` with the
name of the table to use.
=== "Python"
```python
import lancedb
db = await lancedb.connect_async(
"s3+ddb://bucket/path?ddbTableName=my-dynamodb-table",
)
```
=== "JavaScript"
```javascript
const lancedb = require("lancedb");
const db = await lancedb.connect(
"s3+ddb://bucket/path?ddbTableName=my-dynamodb-table",
);
```
The DynamoDB table must be created with the following schema:
- Hash key: `base_uri` (string)
- Range key: `version` (number)
You can create this programmatically with:
=== "Python"
<!-- skip-test -->
```python
import boto3
dynamodb = boto3.client("dynamodb")
table = dynamodb.create_table(
TableName=table_name,
KeySchema=[
{"AttributeName": "base_uri", "KeyType": "HASH"},
{"AttributeName": "version", "KeyType": "RANGE"},
],
AttributeDefinitions=[
{"AttributeName": "base_uri", "AttributeType": "S"},
{"AttributeName": "version", "AttributeType": "N"},
],
ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1},
)
```
=== "JavaScript"
<!-- skip-test -->
```javascript
import {
CreateTableCommand,
DynamoDBClient,
} from "@aws-sdk/client-dynamodb";
const dynamodb = new DynamoDBClient({
region: CONFIG.awsRegion,
credentials: {
accessKeyId: CONFIG.awsAccessKeyId,
secretAccessKey: CONFIG.awsSecretAccessKey,
},
endpoint: CONFIG.awsEndpoint,
});
const command = new CreateTableCommand({
TableName: table_name,
AttributeDefinitions: [
{
AttributeName: "base_uri",
AttributeType: "S",
},
{
AttributeName: "version",
AttributeType: "N",
},
],
KeySchema: [
{ AttributeName: "base_uri", KeyType: "HASH" },
{ AttributeName: "version", KeyType: "RANGE" },
],
ProvisionedThroughput: {
ReadCapacityUnits: 1,
WriteCapacityUnits: 1,
},
});
await client.send(command);
```
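
Putting the two pieces together, a minimal sketch of connecting once the bucket and commit table exist (`my-bucket` and `my-commit-table` are hypothetical names):

```python
import asyncio

import lancedb
import pyarrow as pa

async def main():
    # s3+ddb scheme plus the ddbTableName query parameter, per the docs above.
    db = await lancedb.connect_async(
        "s3+ddb://my-bucket/path?ddbTableName=my-commit-table",
    )
    table = await db.create_table("t", pa.table({"x": [1, 2, 3]}))
    # Writers holding independent table handles can now append concurrently;
    # commits are coordinated through the DynamoDB table.
    print(await table.count_rows())

asyncio.run(main())
```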
#### S3-compatible stores
LanceDB can also connect to S3-compatible stores, such as MinIO. To do so, you must specify both region and endpoint:

File diff suppressed because one or more lines are too long

View File

@@ -6,7 +6,7 @@
"types": "dist/index.d.ts",
"scripts": {
"tsc": "tsc -b",
"build": "npm run tsc && cargo-cp-artifact --artifact cdylib lancedb_node index.node -- cargo build -p lancedb-node --message-format=json",
"build": "npm run tsc && cargo-cp-artifact --artifact cdylib lancedb_node index.node -- cargo build --message-format=json",
"build-release": "npm run build -- --release",
"test": "npm run tsc && mocha -recursive dist/test",
"integration-test": "npm run tsc && mocha -recursive dist/integration_test",

View File

@@ -15,11 +15,11 @@ crate-type = ["cdylib"]
arrow-ipc.workspace = true
futures.workspace = true
lancedb = { path = "../rust/lancedb" }
napi = { version = "2.16.8", default-features = false, features = [
"napi9",
napi = { version = "2.15", default-features = false, features = [
"napi7",
"async",
] }
napi-derive = "2.16.4"
napi-derive = "2"
# Prevent dynamic linking of lzma, which comes from datafusion
lzma-sys = { version = "*", features = ["static"] }

View File

@@ -63,7 +63,6 @@ describe("Registry", () => {
return data.map(() => [1, 2, 3]);
}
}
const func = getRegistry()
.get<MockEmbeddingFunction>("mock-embedding")!
.create();

View File

@@ -14,11 +14,6 @@
/* eslint-disable @typescript-eslint/naming-convention */
import {
CreateTableCommand,
DeleteTableCommand,
DynamoDBClient,
} from "@aws-sdk/client-dynamodb";
import {
CreateKeyCommand,
KMSClient,
@@ -43,7 +38,6 @@ const CONFIG = {
awsAccessKeyId: "ACCESSKEY",
awsSecretAccessKey: "SECRETKEY",
awsEndpoint: "http://127.0.0.1:4566",
dynamodbEndpoint: "http://127.0.0.1:4566",
awsRegion: "us-east-1",
};
@@ -72,6 +66,7 @@ class S3Bucket {
} catch {
// It's fine if the bucket doesn't exist
}
// biome-ignore lint/style/useNamingConvention: we dont control s3's api
await client.send(new CreateBucketCommand({ Bucket: name }));
return new S3Bucket(name);
}
@@ -84,27 +79,32 @@ class S3Bucket {
static async deleteBucket(client: S3Client, name: string) {
// Must delete all objects before we can delete the bucket
const objects = await client.send(
// biome-ignore lint/style/useNamingConvention: we dont control s3's api
new ListObjectsV2Command({ Bucket: name }),
);
if (objects.Contents) {
for (const object of objects.Contents) {
await client.send(
// biome-ignore lint/style/useNamingConvention: we dont control s3's api
new DeleteObjectCommand({ Bucket: name, Key: object.Key }),
);
}
}
// biome-ignore lint/style/useNamingConvention: we dont control s3's api
await client.send(new DeleteBucketCommand({ Bucket: name }));
}
public async assertAllEncrypted(path: string, keyId: string) {
const client = S3Bucket.s3Client();
const objects = await client.send(
// biome-ignore lint/style/useNamingConvention: we dont control s3's api
new ListObjectsV2Command({ Bucket: this.name, Prefix: path }),
);
if (objects.Contents) {
for (const object of objects.Contents) {
const metadata = await client.send(
// biome-ignore lint/style/useNamingConvention: we dont control s3's api
new HeadObjectCommand({ Bucket: this.name, Key: object.Key }),
);
expect(metadata.ServerSideEncryption).toBe("aws:kms");
@@ -143,6 +143,7 @@ class KmsKey {
public async delete() {
const client = KmsKey.kmsClient();
// biome-ignore lint/style/useNamingConvention: we dont control s3's api
await client.send(new ScheduleKeyDeletionCommand({ KeyId: this.keyId }));
}
}
@@ -223,91 +224,3 @@ maybeDescribe("storage_options", () => {
await bucket.assertAllEncrypted("test/table2.lance", kmsKey.keyId);
});
});
class DynamoDBCommitTable {
name: string;
constructor(name: string) {
this.name = name;
}
static dynamoClient() {
return new DynamoDBClient({
region: CONFIG.awsRegion,
credentials: {
accessKeyId: CONFIG.awsAccessKeyId,
secretAccessKey: CONFIG.awsSecretAccessKey,
},
endpoint: CONFIG.awsEndpoint,
});
}
public static async create(name: string): Promise<DynamoDBCommitTable> {
const client = DynamoDBCommitTable.dynamoClient();
const command = new CreateTableCommand({
TableName: name,
AttributeDefinitions: [
{
AttributeName: "base_uri",
AttributeType: "S",
},
{
AttributeName: "version",
AttributeType: "N",
},
],
KeySchema: [
{ AttributeName: "base_uri", KeyType: "HASH" },
{ AttributeName: "version", KeyType: "RANGE" },
],
ProvisionedThroughput: {
ReadCapacityUnits: 1,
WriteCapacityUnits: 1,
},
});
await client.send(command);
return new DynamoDBCommitTable(name);
}
public async delete() {
const client = DynamoDBCommitTable.dynamoClient();
await client.send(new DeleteTableCommand({ TableName: this.name }));
}
}
maybeDescribe("DynamoDB Lock", () => {
let bucket: S3Bucket;
let commitTable: DynamoDBCommitTable;
beforeAll(async () => {
bucket = await S3Bucket.create("lancedb2");
commitTable = await DynamoDBCommitTable.create("commitTable");
});
afterAll(async () => {
await commitTable.delete();
await bucket.delete();
});
it("can be used to configure a DynamoDB table for commit log", async () => {
const uri = `s3+ddb://${bucket.name}/test?ddbTableName=${commitTable.name}`;
const db = await connect(uri, {
storageOptions: CONFIG,
readConsistencyInterval: 0,
});
const table = await db.createTable("test", [{ a: 1, b: 2 }]);
// 5 concurrent appends
const futs = Array.from({ length: 5 }, async () => {
// Open a table so each append has a separate table reference. Otherwise
// they will share the same table reference and the internal ReadWriteLock
// will prevent any real concurrency.
const table = await db.openTable("test");
await table.add([{ a: 2, b: 3 }]);
});
await Promise.all(futs);
const rowCount = await table.countRows();
expect(rowCount).toBe(6);
});
});

View File

@@ -1,5 +1,5 @@
{
"$schema": "https://biomejs.dev/schemas/1.8.3/schema.json",
"$schema": "https://biomejs.dev/schemas/1.7.3/schema.json",
"organizeImports": {
"enabled": true
},
@@ -100,16 +100,6 @@
"globals": []
},
"overrides": [
{
"include": ["__test__/s3_integration.test.ts"],
"linter": {
"rules": {
"style": {
"useNamingConvention": "off"
}
}
}
},
{
"include": [
"**/*.ts",

View File

@@ -35,11 +35,6 @@ export interface FunctionOptions {
[key: string]: any;
}
export interface EmbeddingFunctionConstructor<
T extends EmbeddingFunction = EmbeddingFunction,
> {
new (modelOptions?: T["TOptions"]): T;
}
/**
* An embedding function that automatically creates vector representation for a given column.
*/
@@ -48,12 +43,6 @@ export abstract class EmbeddingFunction<
T = any,
M extends FunctionOptions = FunctionOptions,
> {
/**
* @ignore
* This is only used for associating the options type with the class for type checking
*/
// biome-ignore lint/style/useNamingConvention: we want to keep the name as it is
readonly TOptions!: M;
/**
* Convert the embedding function to a JSON object
* It is used to serialize the embedding function to the schema

View File

@@ -13,29 +13,24 @@
// limitations under the License.
import type OpenAI from "openai";
import { type EmbeddingCreateParams } from "openai/resources";
import { Float, Float32 } from "../arrow";
import { EmbeddingFunction } from "./embedding_function";
import { register } from "./registry";
export type OpenAIOptions = {
apiKey: string;
model: EmbeddingCreateParams["model"];
apiKey?: string;
model?: string;
};
@register("openai")
export class OpenAIEmbeddingFunction extends EmbeddingFunction<
string,
Partial<OpenAIOptions>
OpenAIOptions
> {
#openai: OpenAI;
#modelName: OpenAIOptions["model"];
#modelName: string;
constructor(
options: Partial<OpenAIOptions> = {
model: "text-embedding-ada-002",
},
) {
constructor(options: OpenAIOptions = { model: "text-embedding-ada-002" }) {
super();
const openAIKey = options?.apiKey ?? process.env.OPENAI_API_KEY;
if (!openAIKey) {
@@ -78,7 +73,7 @@ export class OpenAIEmbeddingFunction extends EmbeddingFunction<
case "text-embedding-3-small":
return 1536;
default:
throw new Error(`Unknown model: ${this.#modelName}`);
return null as never;
}
}

View File

@@ -12,15 +12,21 @@
// See the License for the specific language governing permissions and
// limitations under the License.
import {
type EmbeddingFunction,
type EmbeddingFunctionConstructor,
} from "./embedding_function";
import type { EmbeddingFunction } from "./embedding_function";
import "reflect-metadata";
import { OpenAIEmbeddingFunction } from "./openai";
export interface EmbeddingFunctionOptions {
[key: string]: unknown;
}
export interface EmbeddingFunctionFactory<
T extends EmbeddingFunction = EmbeddingFunction,
> {
new (modelOptions?: EmbeddingFunctionOptions): T;
}
interface EmbeddingFunctionCreate<T extends EmbeddingFunction> {
create(options?: T["TOptions"]): T;
create(options?: EmbeddingFunctionOptions): T;
}
/**
@@ -30,7 +36,7 @@ interface EmbeddingFunctionCreate<T extends EmbeddingFunction> {
* or TextEmbeddingFunction and registering it with the registry
*/
export class EmbeddingFunctionRegistry {
#functions = new Map<string, EmbeddingFunctionConstructor>();
#functions: Map<string, EmbeddingFunctionFactory> = new Map();
/**
* Register an embedding function
@@ -38,9 +44,7 @@ export class EmbeddingFunctionRegistry {
* @param func The function to register
* @throws Error if the function is already registered
*/
register<
T extends EmbeddingFunctionConstructor = EmbeddingFunctionConstructor,
>(
register<T extends EmbeddingFunctionFactory = EmbeddingFunctionFactory>(
this: EmbeddingFunctionRegistry,
alias?: string,
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
@@ -65,34 +69,18 @@ export class EmbeddingFunctionRegistry {
* Fetch an embedding function by name
* @param name The name of the function
*/
get<T extends EmbeddingFunction<unknown>, Name extends string = "">(
name: Name extends "openai" ? "openai" : string,
//This makes it so that you can use string constants as "types", or use an explicitly supplied type
// ex:
// `registry.get("openai") -> EmbeddingFunctionCreate<OpenAIEmbeddingFunction>`
// `registry.get<MyCustomEmbeddingFunction>("my_func") -> EmbeddingFunctionCreate<MyCustomEmbeddingFunction> | undefined`
//
// the reason this is important is that we always know our built-in functions are defined, so the user isn't forced to do a non-null/undefined check
// ```ts
// const openai: OpenAIEmbeddingFunction = registry.get("openai").create()
// ```
): Name extends "openai"
? EmbeddingFunctionCreate<OpenAIEmbeddingFunction>
: EmbeddingFunctionCreate<T> | undefined {
type Output = Name extends "openai"
? EmbeddingFunctionCreate<OpenAIEmbeddingFunction>
: EmbeddingFunctionCreate<T> | undefined;
get<T extends EmbeddingFunction<unknown> = EmbeddingFunction>(
name: string,
): EmbeddingFunctionCreate<T> | undefined {
const factory = this.#functions.get(name);
if (!factory) {
return undefined as Output;
return undefined;
}
return {
create: function (options?: T["TOptions"]) {
return new factory(options);
create: function (options: EmbeddingFunctionOptions) {
return new factory(options) as unknown as T;
},
} as Output;
};
}
/**
@@ -116,7 +104,7 @@ export class EmbeddingFunctionRegistry {
name: string;
sourceColumn: string;
vectorColumn: string;
model: EmbeddingFunction["TOptions"];
model: EmbeddingFunctionOptions;
};
const functions = <FunctionConfig[]>(
JSON.parse(metadata.get("embedding_functions")!)

View File

@@ -55,7 +55,7 @@ export class RestfulLanceDBClient {
return axios.create({
baseURL: this.url,
headers: {
// biome-ignore lint: external API
// biome-ignore lint/style/useNamingConvention: external api
Authorization: `Bearer ${this.#apiKey}`,
},
transformResponse: decodeErrorData,

nodejs/package-lock.json (generated, 1403 lines changed)

File diff suppressed because it is too large

View File

@@ -34,10 +34,9 @@
"devDependencies": {
"@aws-sdk/client-kms": "^3.33.0",
"@aws-sdk/client-s3": "^3.33.0",
"@aws-sdk/client-dynamodb": "^3.33.0",
"@biomejs/biome": "^1.7.3",
"@jest/globals": "^29.7.0",
"@napi-rs/cli": "^2.18.3",
"@napi-rs/cli": "^2.18.0",
"@types/jest": "^29.1.2",
"@types/tmp": "^0.2.6",
"apache-arrow-old": "npm:apache-arrow@13.0.0",
@@ -69,7 +68,7 @@
"lint-ci": "biome ci .",
"docs": "typedoc --plugin typedoc-plugin-markdown --out ../docs/src/js lancedb/index.ts",
"lint": "biome check . && biome format .",
"lint-fix": "biome check --write . && biome format --write .",
"lint-fix": "biome check --apply-unsafe . && biome format --write .",
"prepublishOnly": "napi prepublish -t npm",
"test": "jest --verbose",
"integration": "S3_TEST=1 npm run test",
@@ -77,13 +76,9 @@
"version": "napi version"
},
"dependencies": {
"apache-arrow": "^15.0.0",
"axios": "^1.7.2",
"openai": "^4.29.2",
"reflect-metadata": "^0.2.2"
},
"optionalDependencies": {
"openai": "^4.29.2"
},
"peerDependencies": {
"apache-arrow": "^15.0.0"
}
}

View File

@@ -89,7 +89,7 @@ impl Connection {
}
/// List all tables in the dataset.
#[napi(catch_unwind)]
#[napi]
pub async fn table_names(
&self,
start_after: Option<String>,
@@ -113,7 +113,7 @@ impl Connection {
/// - name: The name of the table.
/// - buf: The buffer containing the IPC file.
///
#[napi(catch_unwind)]
#[napi]
pub async fn create_table(
&self,
name: String,
@@ -141,7 +141,7 @@ impl Connection {
Ok(Table::new(tbl))
}
#[napi(catch_unwind)]
#[napi]
pub async fn create_empty_table(
&self,
name: String,
@@ -173,7 +173,7 @@ impl Connection {
Ok(Table::new(tbl))
}
#[napi(catch_unwind)]
#[napi]
pub async fn open_table(
&self,
name: String,
@@ -197,7 +197,7 @@ impl Connection {
}
/// Drop table with the name. Or raise an error if the table does not exist.
#[napi(catch_unwind)]
#[napi]
pub async fn drop_table(&self, name: String) -> napi::Result<()> {
self.get_inner()?
.drop_table(&name)

View File

@@ -30,7 +30,7 @@ impl RecordBatchIterator {
Self { inner }
}
#[napi(catch_unwind)]
#[napi]
pub async unsafe fn next(&mut self) -> napi::Result<Option<Buffer>> {
if let Some(rst) = self.inner.next().await {
let batch = rst.map_err(|e| {

View File

@@ -31,7 +31,7 @@ impl NativeMergeInsertBuilder {
this
}
#[napi(catch_unwind)]
#[napi]
pub async fn execute(&self, buf: Buffer) -> napi::Result<()> {
let data = ipc_file_to_batches(buf.to_vec())
.and_then(IntoArrow::into_arrow)

View File

@@ -62,7 +62,7 @@ impl Query {
Ok(VectorQuery { inner })
}
#[napi(catch_unwind)]
#[napi]
pub async fn execute(
&self,
max_batch_length: Option<u32>,
@@ -136,7 +136,7 @@ impl VectorQuery {
self.inner = self.inner.clone().limit(limit as usize);
}
#[napi(catch_unwind)]
#[napi]
pub async fn execute(
&self,
max_batch_length: Option<u32>,

View File

@@ -70,7 +70,7 @@ impl Table {
}
/// Return Schema as empty Arrow IPC file.
#[napi(catch_unwind)]
#[napi]
pub async fn schema(&self) -> napi::Result<Buffer> {
let schema =
self.inner_ref()?.schema().await.map_err(|e| {
@@ -86,7 +86,7 @@ impl Table {
})?))
}
#[napi(catch_unwind)]
#[napi]
pub async fn add(&self, buf: Buffer, mode: String) -> napi::Result<()> {
let batches = ipc_file_to_batches(buf.to_vec())
.map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
@@ -108,7 +108,7 @@ impl Table {
})
}
#[napi(catch_unwind)]
#[napi]
pub async fn count_rows(&self, filter: Option<String>) -> napi::Result<i64> {
self.inner_ref()?
.count_rows(filter)
@@ -122,7 +122,7 @@ impl Table {
})
}
#[napi(catch_unwind)]
#[napi]
pub async fn delete(&self, predicate: String) -> napi::Result<()> {
self.inner_ref()?.delete(&predicate).await.map_err(|e| {
napi::Error::from_reason(format!(
@@ -132,7 +132,7 @@ impl Table {
})
}
#[napi(catch_unwind)]
#[napi]
pub async fn create_index(
&self,
index: Option<&Index>,
@@ -151,7 +151,7 @@ impl Table {
builder.execute().await.default_error()
}
#[napi(catch_unwind)]
#[napi]
pub async fn update(
&self,
only_if: Option<String>,
@@ -167,17 +167,17 @@ impl Table {
op.execute().await.default_error()
}
#[napi(catch_unwind)]
#[napi]
pub fn query(&self) -> napi::Result<Query> {
Ok(Query::new(self.inner_ref()?.query()))
}
#[napi(catch_unwind)]
#[napi]
pub fn vector_search(&self, vector: Float32Array) -> napi::Result<VectorQuery> {
self.query()?.nearest_to(vector)
}
#[napi(catch_unwind)]
#[napi]
pub async fn add_columns(&self, transforms: Vec<AddColumnsSql>) -> napi::Result<()> {
let transforms = transforms
.into_iter()
@@ -196,7 +196,7 @@ impl Table {
Ok(())
}
#[napi(catch_unwind)]
#[napi]
pub async fn alter_columns(&self, alterations: Vec<ColumnAlteration>) -> napi::Result<()> {
for alteration in &alterations {
if alteration.rename.is_none() && alteration.nullable.is_none() {
@@ -222,7 +222,7 @@ impl Table {
Ok(())
}
#[napi(catch_unwind)]
#[napi]
pub async fn drop_columns(&self, columns: Vec<String>) -> napi::Result<()> {
let col_refs = columns.iter().map(String::as_str).collect::<Vec<_>>();
self.inner_ref()?
@@ -237,7 +237,7 @@ impl Table {
Ok(())
}
#[napi(catch_unwind)]
#[napi]
pub async fn version(&self) -> napi::Result<i64> {
self.inner_ref()?
.version()
@@ -246,7 +246,7 @@ impl Table {
.default_error()
}
#[napi(catch_unwind)]
#[napi]
pub async fn checkout(&self, version: i64) -> napi::Result<()> {
self.inner_ref()?
.checkout(version as u64)
@@ -254,17 +254,17 @@ impl Table {
.default_error()
}
#[napi(catch_unwind)]
#[napi]
pub async fn checkout_latest(&self) -> napi::Result<()> {
self.inner_ref()?.checkout_latest().await.default_error()
}
#[napi(catch_unwind)]
#[napi]
pub async fn restore(&self) -> napi::Result<()> {
self.inner_ref()?.restore().await.default_error()
}
#[napi(catch_unwind)]
#[napi]
pub async fn optimize(&self, older_than_ms: Option<i64>) -> napi::Result<OptimizeStats> {
let inner = self.inner_ref()?;
@@ -318,7 +318,7 @@ impl Table {
})
}
#[napi(catch_unwind)]
#[napi]
pub async fn list_indices(&self) -> napi::Result<Vec<IndexConfig>> {
Ok(self
.inner_ref()?
@@ -330,14 +330,14 @@ impl Table {
.collect::<Vec<_>>())
}
#[napi(catch_unwind)]
#[napi]
pub async fn index_stats(&self, index_name: String) -> napi::Result<Option<IndexStatistics>> {
let tbl = self.inner_ref()?.as_native().unwrap();
let stats = tbl.index_stats(&index_name).await.default_error()?;
Ok(stats.map(IndexStatistics::from))
}
#[napi(catch_unwind)]
#[napi]
pub fn merge_insert(&self, on: Vec<String>) -> napi::Result<NativeMergeInsertBuilder> {
let on: Vec<_> = on.iter().map(String::as_str).collect();
Ok(self.inner_ref()?.merge_insert(on.as_slice()).into())

View File

@@ -28,11 +28,12 @@ from lancedb.common import data_to_reader, validate_schema
from ._lancedb import connect as lancedb_connect
from .pydantic import LanceModel
from .table import AsyncTable, LanceTable, Table, _sanitize_data, _table_path
from .table import AsyncTable, LanceTable, Table, _sanitize_data
from .util import (
fs_from_uri,
get_uri_location,
get_uri_scheme,
join_uri,
validate_table_name,
)
@@ -456,18 +457,16 @@ class LanceDBConnection(DBConnection):
If True, ignore if the table does not exist.
"""
try:
table_uri = _table_path(self.uri, name)
filesystem, path = fs_from_uri(table_uri)
filesystem.delete_dir(path)
filesystem, path = fs_from_uri(self.uri)
table_path = join_uri(path, name + ".lance")
filesystem.delete_dir(table_path)
except FileNotFoundError:
if not ignore_missing:
raise
@override
def drop_database(self):
dummy_table_uri = _table_path(self.uri, "dummy")
uri = dummy_table_uri.removesuffix("dummy.lance")
filesystem, path = fs_from_uri(uri)
filesystem, path = fs_from_uri(self.uri)
filesystem.delete_dir(path)

View File

@@ -4,7 +4,6 @@ from .colbert import ColbertReranker
from .cross_encoder import CrossEncoderReranker
from .linear_combination import LinearCombinationReranker
from .openai import OpenaiReranker
from .jina import JinaReranker
__all__ = [
"Reranker",
@@ -13,5 +12,4 @@ __all__ = [
"LinearCombinationReranker",
"OpenaiReranker",
"ColbertReranker",
"JinaReranker",
]

View File

@@ -1,103 +0,0 @@
from functools import cached_property
from typing import Union
import pyarrow as pa
from ..util import attempt_import_or_raise
from .base import Reranker
class JinaReranker(Reranker):
"""
Reranks the results using Jina reranker model.
Parameters
----------
model_name : str, default "jinaai/jina-reranker-v1-turbo-en"
The name of the reranker to use. For all models, see
https://huggingface.co/jinaai/jina-reranker-v1-turbo-en
column : str, default "text"
The name of the column to use as input to the cross encoder model.
device : str, default None
The device to use for the cross encoder model. If None, will use "cuda"
if available, otherwise "cpu".
"""
def __init__(
self,
model_name: str = "jinaai/jina-reranker-v1-turbo-en",
column: str = "text",
device: Union[str, None] = None,
return_score="relevance",
):
super().__init__(return_score)
torch = attempt_import_or_raise("torch")
self.model_name = model_name
self.column = column
self.device = device
if self.device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
@cached_property
def model(self):
transformers = attempt_import_or_raise("transformers")
model = transformers.AutoModelForSequenceClassification.from_pretrained(
self.model_name, num_labels=1, trust_remote_code=True
)
return model
def _rerank(self, result_set: pa.Table, query: str):
passages = result_set[self.column].to_pylist()
cross_inp = [[query, passage] for passage in passages]
cross_scores = self.model.compute_score(cross_inp)
result_set = result_set.append_column(
"_relevance_score", pa.array(cross_scores, type=pa.float32())
)
return result_set
def rerank_hybrid(
self,
query: str,
vector_results: pa.Table,
fts_results: pa.Table,
):
combined_results = self.merge_results(vector_results, fts_results)
combined_results = self._rerank(combined_results, query)
# sort the results by _score
if self.score == "relevance":
combined_results = combined_results.drop_columns(["score", "_distance"])
elif self.score == "all":
raise NotImplementedError(
"return_score='all' not implemented for CrossEncoderReranker"
)
combined_results = combined_results.sort_by(
[("_relevance_score", "descending")]
)
return combined_results
def rerank_vector(
self,
query: str,
vector_results: pa.Table,
):
vector_results = self._rerank(vector_results, query)
if self.score == "relevance":
vector_results = vector_results.drop_columns(["_distance"])
vector_results = vector_results.sort_by([("_relevance_score", "descending")])
return vector_results
def rerank_fts(
self,
query: str,
fts_results: pa.Table,
):
fts_results = self._rerank(fts_results, query)
if self.score == "relevance":
fts_results = fts_results.drop_columns(["score"])
fts_results = fts_results.sort_by([("_relevance_score", "descending")])
return fts_results
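
For context, a brief usage sketch of the class above (illustrative; requires `transformers` and `torch`, and downloads the model on first use):

```python
import pyarrow as pa

from lancedb.rerankers import JinaReranker

reranker = JinaReranker()  # defaults to jinaai/jina-reranker-v1-turbo-en

# A hypothetical FTS result set with the "text" and "score" columns that
# rerank_fts expects.
fts_results = pa.table({
    "text": ["our father who art in heaven", "hallowed be thy name"],
    "score": pa.array([1.2, 0.8], type=pa.float32()),
})

# Scores each passage against the query and sorts by _relevance_score.
reranked = reranker.rerank_fts("old prayer", fts_results)
print(reranked.column("_relevance_score"))
```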

View File

@@ -30,7 +30,6 @@ from typing import (
Tuple,
Union,
)
from urllib.parse import urlparse
import lance
import numpy as np
@@ -48,7 +47,6 @@ from .pydantic import LanceModel, model_to_dict
from .query import AsyncQuery, AsyncVectorQuery, LanceQueryBuilder, Query
from .util import (
fs_from_uri,
get_uri_scheme,
inf_vector_column_query,
join_uri,
safe_import_pandas,
@@ -210,26 +208,6 @@ def _to_record_batch_generator(
yield b
def _table_path(base: str, table_name: str) -> str:
"""
Get a table path that can be used in PyArrow FS.
Removes any weird schemes (such as "s3+ddb") and drops any query params.
"""
uri = _table_uri(base, table_name)
# Parse as URL
parsed = urlparse(uri)
# If scheme is s3+ddb, convert to s3
if parsed.scheme == "s3+ddb":
parsed = parsed._replace(scheme="s3")
# Remove query parameters
return parsed._replace(query=None).geturl()
def _table_uri(base: str, table_name: str) -> str:
return join_uri(base, f"{table_name}.lance")
class Table(ABC):
"""
A Table is a collection of Records in a LanceDB Database.
@@ -930,7 +908,7 @@ class LanceTable(Table):
@classmethod
def open(cls, db, name, **kwargs):
tbl = cls(db, name, **kwargs)
fs, path = fs_from_uri(tbl._dataset_path)
fs, path = fs_from_uri(tbl._dataset_uri)
file_info = fs.get_file_info(path)
if file_info.type != pa.fs.FileType.Directory:
raise FileNotFoundError(
@@ -940,14 +918,9 @@ class LanceTable(Table):
return tbl
@cached_property
def _dataset_path(self) -> str:
# Cacheable since it's deterministic
return _table_path(self._conn.uri, self.name)
@cached_property
@property
def _dataset_uri(self) -> str:
return _table_uri(self._conn.uri, self.name)
return join_uri(self._conn.uri, f"{self.name}.lance")
@property
def _dataset(self) -> LanceDataset:
@@ -1257,10 +1230,6 @@ class LanceTable(Table):
)
def _get_fts_index_path(self):
if get_uri_scheme(self._dataset_uri) != "file":
raise NotImplementedError(
"Full-text search is not supported on object stores."
)
return join_uri(self._dataset_uri, "_indices", "tantivy")
def add(

View File

@@ -139,11 +139,8 @@ def join_uri(base: Union[str, pathlib.Path], *parts: str) -> str:
# using pathlib for local paths make this windows compatible
# `get_uri_scheme` returns `file` for windows drive names (e.g. `c:\path`)
return str(pathlib.Path(base, *parts))
else:
# there might be query parameters in the base URI
url = urlparse(base)
new_path = "/".join([p.rstrip("/") for p in [url.path, *parts]])
return url._replace(path=new_path).geturl()
# for remote paths, join the segments with "/"
return "/".join([p.rstrip("/") for p in [base, *parts]])
def attempt_import_or_raise(module: str, mitigation=None):
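
To make the change above concrete, a small illustration (not part of this diff) of how the two `join_uri` variants treat a base URI that carries query parameters:

```python
from urllib.parse import urlparse

base = "s3+ddb://bucket/path?ddbTableName=my-table"

# Plain string join: path segments land after the query string.
naive = "/".join([p.rstrip("/") for p in [base, "tbl.lance"]])
print(naive)  # s3+ddb://bucket/path?ddbTableName=my-table/tbl.lance

# urlparse-based join: the new segment goes into the path, and the
# query string stays at the end where it belongs.
url = urlparse(base)
new_path = "/".join([p.rstrip("/") for p in [url.path, "tbl.lance"]])
print(url._replace(path=new_path).geturl())
# s3+ddb://bucket/path/tbl.lance?ddbTableName=my-table
```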

View File

@@ -1,3 +1,5 @@
import os
import lancedb
import numpy as np
import pytest
@@ -9,7 +11,6 @@ from lancedb.rerankers import (
ColbertReranker,
CrossEncoderReranker,
OpenaiReranker,
JinaReranker,
)
from lancedb.table import LanceTable
@@ -118,18 +119,136 @@ def test_linear_combination(tmp_path):
)
@pytest.mark.slow
@pytest.mark.parametrize(
"reranker",
[
ColbertReranker(),
OpenaiReranker(),
CohereReranker(),
CrossEncoderReranker(),
JinaReranker(),
],
@pytest.mark.skipif(
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
)
def test_colbert_reranker(tmp_path, reranker):
def test_cohere_reranker(tmp_path):
pytest.importorskip("cohere")
reranker = CohereReranker()
table, schema = get_test_table(tmp_path)
# Hybrid search setting
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score", reranker=CohereReranker())
.to_pydantic(schema)
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(reranker=reranker)
.to_pydantic(schema)
)
assert result1 == result2
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query))
.limit(30)
.rerank(reranker=reranker)
.to_arrow()
)
assert len(result) == 30
err = (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
# Vector search setting
query = "Our father who art in heaven"
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
result_explicit = (
table.search(query_vector)
.rerank(reranker=reranker, query_string=query)
.limit(30)
.to_arrow()
)
assert len(result_explicit) == 30
with pytest.raises(
ValueError
): # This raises an error because a vector query is provided without a reranking query
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
# FTS search setting
result = (
table.search(query, query_type="fts")
.rerank(reranker=reranker)
.limit(30)
.to_arrow()
)
assert len(result) > 0
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
def test_cross_encoder_reranker(tmp_path):
pytest.importorskip("sentence_transformers")
reranker = CrossEncoderReranker()
table, schema = get_test_table(tmp_path)
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score", reranker=reranker)
.to_pydantic(schema)
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(reranker=reranker)
.to_pydantic(schema)
)
assert result1 == result2
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query), query_type="hybrid")
.limit(30)
.rerank(reranker=reranker)
.to_arrow()
)
assert len(result) == 30
err = (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
# Vector search setting
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
result_explicit = (
table.search(query_vector)
.rerank(reranker=reranker, query_string=query)
.limit(30)
.to_arrow()
)
assert len(result_explicit) == 30
with pytest.raises(
ValueError
): # This raises an error because a vector query is provided without a reranking query
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
# FTS search setting
result = (
table.search(query, query_type="fts")
.rerank(reranker=reranker)
.limit(30)
.to_arrow()
)
assert len(result) > 0
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
def test_colbert_reranker(tmp_path):
pytest.importorskip("transformers")
reranker = ColbertReranker()
table, schema = get_test_table(tmp_path)
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
@@ -186,3 +305,67 @@ def test_colbert_reranker(tmp_path, reranker):
)
assert len(result) > 0
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
@pytest.mark.skipif(
os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
)
def test_openai_reranker(tmp_path):
pytest.importorskip("openai")
table, schema = get_test_table(tmp_path)
reranker = OpenaiReranker()
result1 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(normalize="score", reranker=reranker)
.to_pydantic(schema)
)
result2 = (
table.search("Our father who art in heaven", query_type="hybrid")
.rerank(reranker=OpenaiReranker())
.to_pydantic(schema)
)
assert result1 == result2
# test explicit hybrid query
query = "Our father who art in heaven"
query_vector = table.to_pandas()["vector"][0]
result = (
table.search((query_vector, query))
.limit(30)
.rerank(reranker=reranker)
.to_arrow()
)
assert len(result) == 30
err = (
"The _relevance_score column of the results returned by the reranker "
"represents the relevance of the result to the query & should "
"be descending."
)
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
# Vector search setting
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
assert len(result) == 30
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
result_explicit = (
table.search(query_vector)
.rerank(reranker=reranker, query_string=query)
.limit(30)
.to_arrow()
)
assert len(result_explicit) == 30
with pytest.raises(
ValueError
): # This raises an error because a vector query is provided without a reranking query
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
# FTS search setting
result = (
table.search(query, query_type="fts")
.rerank(reranker=reranker)
.limit(30)
.to_arrow()
)
assert len(result) > 0
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err

View File

@@ -13,8 +13,6 @@
import asyncio
import copy
from datetime import timedelta
import threading
import pytest
import pyarrow as pa
@@ -27,7 +25,6 @@ CONFIG = {
"aws_access_key_id": "ACCESSKEY",
"aws_secret_access_key": "SECRETKEY",
"aws_endpoint": "http://localhost:4566",
"dynamodb_endpoint": "http://localhost:4566",
"aws_region": "us-east-1",
}
@@ -159,104 +156,3 @@ def test_s3_sse(s3_bucket: str, kms_key: str):
validate_objects_encrypted(s3_bucket, path, kms_key)
asyncio.run(test())
@pytest.fixture(scope="module")
def commit_table():
ddb = get_boto3_client("dynamodb", endpoint_url=CONFIG["dynamodb_endpoint"])
table_name = "lance-integtest"
try:
ddb.delete_table(TableName=table_name)
except ddb.exceptions.ResourceNotFoundException:
pass
ddb.create_table(
TableName=table_name,
KeySchema=[
{"AttributeName": "base_uri", "KeyType": "HASH"},
{"AttributeName": "version", "KeyType": "RANGE"},
],
AttributeDefinitions=[
{"AttributeName": "base_uri", "AttributeType": "S"},
{"AttributeName": "version", "AttributeType": "N"},
],
ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1},
)
yield table_name
ddb.delete_table(TableName=table_name)
@pytest.mark.s3_test
def test_s3_dynamodb(s3_bucket: str, commit_table: str):
storage_options = copy.copy(CONFIG)
uri = f"s3+ddb://{s3_bucket}/test?ddbTableName={commit_table}"
data = pa.table({"x": [1, 2, 3]})
async def test():
db = await lancedb.connect_async(
uri,
storage_options=storage_options,
read_consistency_interval=timedelta(0),
)
table = await db.create_table("test", data)
# Five concurrent writers
async def insert():
# independent table refs for true concurrent writes.
table = await db.open_table("test")
await table.add(data, mode="append")
tasks = [insert() for _ in range(5)]
await asyncio.gather(*tasks)
row_count = await table.count_rows()
assert row_count == 3 * 6
asyncio.run(test())
@pytest.mark.s3_test
def test_s3_dynamodb_sync(s3_bucket: str, commit_table: str, monkeypatch):
# Sync API doesn't support storage_options, so we have to provide them as env vars
for key, value in CONFIG.items():
monkeypatch.setenv(key.upper(), value)
uri = f"s3+ddb://{s3_bucket}/test2?ddbTableName={commit_table}"
data = pa.table({"x": ["a", "b", "c"]})
db = lancedb.connect(
uri,
read_consistency_interval=timedelta(0),
)
table = db.create_table("test_ddb_sync", data)
# Five concurrent writers
def insert():
table = db.open_table("test_ddb_sync")
table.add(data, mode="append")
threads = []
for _ in range(5):
thread = threading.Thread(target=insert)
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
row_count = table.count_rows()
assert row_count == 3 * 6
# FTS indices should error since they are not supported yet.
with pytest.raises(
NotImplementedError, match="Full-text search is not supported on object stores."
):
table.create_fts_index("x")
# make sure list tables still works
assert db.table_names() == ["test_ddb_sync"]
db.drop_table("test_ddb_sync")
assert db.table_names() == []
db.drop_database()

View File

@@ -55,11 +55,10 @@ walkdir = "2"
# For s3 integration tests (dev deps aren't allowed to be optional atm)
# We pin these because the content-length check breaks with localstack
# https://github.com/smithy-lang/smithy-rs/releases/tag/release-2024-05-21
aws-sdk-dynamodb = { version = "=1.23.0" }
aws-sdk-s3 = { version = "=1.23.0" }
aws-sdk-kms = { version = "=1.21.0" }
aws-config = { version = "1.0" }
aws-smithy-runtime = { version = "=1.3.1" }
aws-smithy-runtime = { version = "=1.3.0" }
[features]
default = []

View File

@@ -25,9 +25,7 @@ const CONFIG: &[(&str, &str)] = &[
("access_key_id", "ACCESS_KEY"),
("secret_access_key", "SECRET_KEY"),
("endpoint", "http://127.0.0.1:4566"),
("dynamodb_endpoint", "http://127.0.0.1:4566"),
("allow_http", "true"),
("region", "us-east-1"),
];
async fn aws_config() -> SdkConfig {
@@ -290,126 +288,3 @@ async fn test_encryption() -> Result<()> {
Ok(())
}
struct DynamoDBCommitTable(String);
impl DynamoDBCommitTable {
async fn new(name: &str) -> Self {
let config = aws_config().await;
let client = aws_sdk_dynamodb::Client::new(&config);
// In case it wasn't deleted earlier
Self::delete_table(client.clone(), name).await;
tokio::time::sleep(std::time::Duration::from_millis(200)).await;
use aws_sdk_dynamodb::types::*;
client
.create_table()
.table_name(name)
.attribute_definitions(
AttributeDefinition::builder()
.attribute_name("base_uri")
.attribute_type(ScalarAttributeType::S)
.build()
.unwrap(),
)
.attribute_definitions(
AttributeDefinition::builder()
.attribute_name("version")
.attribute_type(ScalarAttributeType::N)
.build()
.unwrap(),
)
.key_schema(
KeySchemaElement::builder()
.attribute_name("base_uri")
.key_type(KeyType::Hash)
.build()
.unwrap(),
)
.key_schema(
KeySchemaElement::builder()
.attribute_name("version")
.key_type(KeyType::Range)
.build()
.unwrap(),
)
.provisioned_throughput(
ProvisionedThroughput::builder()
.read_capacity_units(1)
.write_capacity_units(1)
.build()
.unwrap(),
)
.send()
.await
.unwrap();
Self(name.to_string())
}
async fn delete_table(client: aws_sdk_dynamodb::Client, name: &str) {
match client
.delete_table()
.table_name(name)
.send()
.await
.map_err(|err| err.into_service_error())
{
Ok(_) => {}
Err(e) if e.is_resource_not_found_exception() => {}
Err(e) => panic!("Failed to delete table: {}", e),
};
}
}
impl Drop for DynamoDBCommitTable {
fn drop(&mut self) {
let table_name = self.0.clone();
tokio::task::spawn(async move {
let config = aws_config().await;
let client = aws_sdk_dynamodb::Client::new(&config);
Self::delete_table(client, &table_name).await;
});
}
}
#[tokio::test]
async fn test_concurrent_dynamodb_commit() {
// test concurrent commit on dynamodb
let bucket = S3Bucket::new("test-dynamodb").await;
let table = DynamoDBCommitTable::new("test_table").await;
let uri = format!("s3+ddb://{}?ddbTableName={}", bucket.0, table.0);
let db = lancedb::connect(&uri)
.storage_options(CONFIG.iter().cloned())
.execute()
.await
.unwrap();
let data = test_data();
let data = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema());
let table = db.create_table("test_table", data).execute().await.unwrap();
let data = test_data();
let mut tasks = vec![];
for _ in 0..5 {
let table = db.open_table("test_table").execute().await.unwrap();
let data = data.clone();
tasks.push(tokio::spawn(async move {
let data = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema());
table.add(data).execute().await.unwrap();
}));
}
for task in tasks {
task.await.unwrap();
}
table.checkout_latest().await.unwrap();
let row_count = table.count_rows(None).await.unwrap();
assert_eq!(row_count, 18);
}