mirror of
https://github.com/lancedb/lancedb.git
synced 2025-12-25 14:29:56 +00:00
Compare commits
15 Commits
ayush/jina
...
python-v0.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7a3ef68306 | ||
|
|
43952e01d7 | ||
|
|
495c335831 | ||
|
|
77707db543 | ||
|
|
d6d7ad3b06 | ||
|
|
e58d64c286 | ||
|
|
76cbd18c46 | ||
|
|
4abb38ac70 | ||
|
|
cc7bc5011d | ||
|
|
8193183304 | ||
|
|
cf28b58b7d | ||
|
|
e3b7ee47b9 | ||
|
|
97c9c906e4 | ||
|
|
358f86b9c6 | ||
|
|
5489e215a3 |
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.6.0"
|
||||
current_version = "0.5.2"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
4
.github/workflows/docs_test.yml
vendored
4
.github/workflows/docs_test.yml
vendored
@@ -24,7 +24,7 @@ env:
|
||||
jobs:
|
||||
test-python:
|
||||
name: Test doc python code
|
||||
runs-on: "warp-ubuntu-latest-x64-4x"
|
||||
runs-on: "buildjet-8vcpu-ubuntu-2204"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
@@ -56,7 +56,7 @@ jobs:
|
||||
for d in *; do cd "$d"; echo "$d".py; python "$d".py; cd ..; done
|
||||
test-node:
|
||||
name: Test doc nodejs code
|
||||
runs-on: "warp-ubuntu-latest-x64-4x"
|
||||
runs-on: "buildjet-8vcpu-ubuntu-2204"
|
||||
timeout-minutes: 60
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
||||
@@ -14,7 +14,7 @@ repos:
|
||||
hooks:
|
||||
- id: local-biome-check
|
||||
name: biome check
|
||||
entry: npx @biomejs/biome@1.8.3 check --config-path nodejs/biome.json nodejs/
|
||||
entry: npx @biomejs/biome@1.7.3 check --config-path nodejs/biome.json nodejs/
|
||||
language: system
|
||||
types: [text]
|
||||
files: "nodejs/.*"
|
||||
|
||||
@@ -57,8 +57,6 @@ plugins:
|
||||
- https://arrow.apache.org/docs/objects.inv
|
||||
- https://pandas.pydata.org/docs/objects.inv
|
||||
- mkdocs-jupyter
|
||||
- render_swagger:
|
||||
allow_arbitrary_locations : true
|
||||
|
||||
markdown_extensions:
|
||||
- admonition
|
||||
@@ -160,7 +158,6 @@ nav:
|
||||
- API reference:
|
||||
- 🐍 Python: python/saas-python.md
|
||||
- 👾 JavaScript: javascript/modules.md
|
||||
- REST API: cloud/rest.md
|
||||
|
||||
- Quick start: basic.md
|
||||
- Concepts:
|
||||
@@ -231,7 +228,6 @@ nav:
|
||||
- API reference:
|
||||
- 🐍 Python: python/saas-python.md
|
||||
- 👾 JavaScript: javascript/modules.md
|
||||
- REST API: cloud/rest.md
|
||||
|
||||
extra_css:
|
||||
- styles/global.css
|
||||
|
||||
479
docs/openapi.yml
479
docs/openapi.yml
@@ -1,479 +0,0 @@
|
||||
openapi: 3.1.0
|
||||
info:
|
||||
version: 1.0.0
|
||||
title: LanceDB Cloud API
|
||||
description: |
|
||||
LanceDB Cloud API is a RESTful API that allows users to access and modify data stored in LanceDB Cloud.
|
||||
Table actions are considered temporary resource creations and all use POST method.
|
||||
contact:
|
||||
name: LanceDB support
|
||||
url: https://lancedb.com
|
||||
email: contact@lancedb.com
|
||||
|
||||
servers:
|
||||
- url: https://{db}.{region}.api.lancedb.com
|
||||
description: LanceDB Cloud REST endpoint.
|
||||
variables:
|
||||
db:
|
||||
default: ""
|
||||
description: the name of DB
|
||||
region:
|
||||
default: "us-east-1"
|
||||
description: the service region of the DB
|
||||
|
||||
security:
|
||||
- key_auth: []
|
||||
|
||||
components:
|
||||
securitySchemes:
|
||||
key_auth:
|
||||
name: x-api-key
|
||||
type: apiKey
|
||||
in: header
|
||||
parameters:
|
||||
table_name:
|
||||
name: name
|
||||
in: path
|
||||
description: name of the table
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
responses:
|
||||
invalid_request:
|
||||
description: Invalid request
|
||||
content:
|
||||
text/plain:
|
||||
schema:
|
||||
type: string
|
||||
not_found:
|
||||
description: Not found
|
||||
content:
|
||||
text/plain:
|
||||
schema:
|
||||
type: string
|
||||
unauthorized:
|
||||
description: Unauthorized
|
||||
content:
|
||||
text/plain:
|
||||
schema:
|
||||
type: string
|
||||
requestBodies:
|
||||
arrow_stream_buffer:
|
||||
description: Arrow IPC stream buffer
|
||||
required: true
|
||||
content:
|
||||
application/vnd.apache.arrow.stream:
|
||||
schema:
|
||||
type: string
|
||||
format: binary
|
||||
|
||||
paths:
|
||||
/v1/table/:
|
||||
get:
|
||||
description: List tables, optionally, with pagination.
|
||||
tags:
|
||||
- Tables
|
||||
summary: List Tables
|
||||
operationId: listTables
|
||||
parameters:
|
||||
- name: limit
|
||||
in: query
|
||||
description: Limits the number of items to return.
|
||||
schema:
|
||||
type: integer
|
||||
- name: page_token
|
||||
in: query
|
||||
description: Specifies the starting position of the next query
|
||||
schema:
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
description: Successfully returned a list of tables in the DB
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
tables:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
page_token:
|
||||
type: string
|
||||
|
||||
"400":
|
||||
$ref: "#/components/responses/invalid_request"
|
||||
"401":
|
||||
$ref: "#/components/responses/unauthorized"
|
||||
"404":
|
||||
$ref: "#/components/responses/not_found"
|
||||
|
||||
/v1/table/{name}/create/:
|
||||
post:
|
||||
description: Create a new table
|
||||
summary: Create a new table
|
||||
operationId: createTable
|
||||
tags:
|
||||
- Tables
|
||||
parameters:
|
||||
- $ref: "#/components/parameters/table_name"
|
||||
requestBody:
|
||||
$ref: "#/components/requestBodies/arrow_stream_buffer"
|
||||
responses:
|
||||
"200":
|
||||
description: Table successfully created
|
||||
"400":
|
||||
$ref: "#/components/responses/invalid_request"
|
||||
"401":
|
||||
$ref: "#/components/responses/unauthorized"
|
||||
"404":
|
||||
$ref: "#/components/responses/not_found"
|
||||
|
||||
/v1/table/{name}/query/:
|
||||
post:
|
||||
description: Vector Query
|
||||
url: https://{db-uri}.{aws-region}.api.lancedb.com/v1/table/{name}/query/
|
||||
tags:
|
||||
- Data
|
||||
summary: Vector Query
|
||||
parameters:
|
||||
- $ref: "#/components/parameters/table_name"
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
vector:
|
||||
type: FixedSizeList
|
||||
description: |
|
||||
The targetted vector to search for. Required.
|
||||
vector_column:
|
||||
type: string
|
||||
description: |
|
||||
The column to query, it can be inferred from the schema if there is only one vector column.
|
||||
prefilter:
|
||||
type: boolean
|
||||
description: |
|
||||
Whether to prefilter the data. Optional.
|
||||
k:
|
||||
type: integer
|
||||
description: |
|
||||
The number of search results to return. Default is 10.
|
||||
distance_type:
|
||||
type: string
|
||||
description: |
|
||||
The distance metric to use for search. L2, Cosine, Dot and Hamming are supported. Default is L2.
|
||||
bypass_vector_index:
|
||||
type: boolean
|
||||
description: |
|
||||
Whether to bypass vector index. Optional.
|
||||
filter:
|
||||
type: string
|
||||
description: |
|
||||
A filter expression that specifies the rows to query. Optional.
|
||||
columns:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: |
|
||||
The columns to return. Optional.
|
||||
nprobe:
|
||||
type: integer
|
||||
description: |
|
||||
The number of probes to use for search. Optional.
|
||||
refine_factor:
|
||||
type: integer
|
||||
description: |
|
||||
The refine factor to use for search. Optional.
|
||||
|
||||
responses:
|
||||
"200":
|
||||
description: top k results if query is successfully executed
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
results:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
type: integer
|
||||
selected_col_1_to_return:
|
||||
type: col_1_type
|
||||
selected_col_n_to_return:
|
||||
type: col_n_type
|
||||
_distance:
|
||||
type: float
|
||||
|
||||
"400":
|
||||
$ref: "#/components/responses/invalid_request"
|
||||
"401":
|
||||
$ref: "#/components/responses/unauthorized"
|
||||
"404":
|
||||
$ref: "#/components/responses/not_found"
|
||||
|
||||
/v1/table/{name}/insert/:
|
||||
post:
|
||||
description: Insert new data to the Table.
|
||||
tags:
|
||||
- Data
|
||||
operationId: insertData
|
||||
summary: Insert new data.
|
||||
parameters:
|
||||
- $ref: "#/components/parameters/table_name"
|
||||
requestBody:
|
||||
$ref: "#/components/requestBodies/arrow_stream_buffer"
|
||||
responses:
|
||||
"200":
|
||||
description: Insert successful
|
||||
"400":
|
||||
$ref: "#/components/responses/invalid_request"
|
||||
"401":
|
||||
$ref: "#/components/responses/unauthorized"
|
||||
"404":
|
||||
$ref: "#/components/responses/not_found"
|
||||
/v1/table/{name}/merge_insert/:
|
||||
post:
|
||||
description: Create a "merge insert" operation
|
||||
This operation can add rows, update rows, and remove rows all in a single
|
||||
transaction. See python method `lancedb.table.Table.merge_insert` for examples.
|
||||
tags:
|
||||
- Data
|
||||
summary: Merge Insert
|
||||
operationId: mergeInsert
|
||||
parameters:
|
||||
- $ref: "#/components/parameters/table_name"
|
||||
- name: on
|
||||
in: query
|
||||
description: |
|
||||
The column to use as the primary key for the merge operation.
|
||||
required: true
|
||||
schema:
|
||||
type: string
|
||||
- name: when_matched_update_all
|
||||
in: query
|
||||
description: |
|
||||
Rows that exist in both the source table (new data) and
|
||||
the target table (old data) will be updated, replacing
|
||||
the old row with the corresponding matching row.
|
||||
required: false
|
||||
schema:
|
||||
type: boolean
|
||||
- name: when_matched_update_all_filt
|
||||
in: query
|
||||
description: |
|
||||
If present then only rows that satisfy the filter expression will
|
||||
be updated
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
- name: when_not_matched_insert_all
|
||||
in: query
|
||||
description: |
|
||||
Rows that exist only in the source table (new data) will be
|
||||
inserted into the target table (old data).
|
||||
required: false
|
||||
schema:
|
||||
type: boolean
|
||||
- name: when_not_matched_by_source_delete
|
||||
in: query
|
||||
description: |
|
||||
Rows that exist only in the target table (old data) will be
|
||||
deleted. An optional condition (`when_not_matched_by_source_delete_filt`)
|
||||
can be provided to limit what data is deleted.
|
||||
required: false
|
||||
schema:
|
||||
type: boolean
|
||||
- name: when_not_matched_by_source_delete_filt
|
||||
in: query
|
||||
description: |
|
||||
The filter expression that specifies the rows to delete.
|
||||
required: false
|
||||
schema:
|
||||
type: string
|
||||
requestBody:
|
||||
$ref: "#/components/requestBodies/arrow_stream_buffer"
|
||||
responses:
|
||||
"200":
|
||||
description: Merge Insert successful
|
||||
"400":
|
||||
$ref: "#/components/responses/invalid_request"
|
||||
"401":
|
||||
$ref: "#/components/responses/unauthorized"
|
||||
"404":
|
||||
$ref: "#/components/responses/not_found"
|
||||
/v1/table/{name}/delete/:
|
||||
post:
|
||||
description: Delete rows from a table.
|
||||
tags:
|
||||
- Data
|
||||
summary: Delete rows from a table
|
||||
operationId: deleteData
|
||||
parameters:
|
||||
- $ref: "#/components/parameters/table_name"
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
predicate:
|
||||
type: string
|
||||
description: |
|
||||
A filter expression that specifies the rows to delete.
|
||||
responses:
|
||||
"200":
|
||||
description: Delete successful
|
||||
"401":
|
||||
$ref: "#/components/responses/unauthorized"
|
||||
/v1/table/{name}/drop/:
|
||||
post:
|
||||
description: Drop a table
|
||||
tags:
|
||||
- Tables
|
||||
summary: Drop a table
|
||||
operationId: dropTable
|
||||
parameters:
|
||||
- $ref: "#/components/parameters/table_name"
|
||||
requestBody:
|
||||
$ref: "#/components/requestBodies/arrow_stream_buffer"
|
||||
responses:
|
||||
"200":
|
||||
description: Drop successful
|
||||
"401":
|
||||
$ref: "#/components/responses/unauthorized"
|
||||
|
||||
/v1/table/{name}/describe/:
|
||||
post:
|
||||
description: Describe a table and return Table Information.
|
||||
tags:
|
||||
- Tables
|
||||
summary: Describe a table
|
||||
operationId: describeTable
|
||||
parameters:
|
||||
- $ref: "#/components/parameters/table_name"
|
||||
responses:
|
||||
"200":
|
||||
description: Table information
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
table:
|
||||
type: string
|
||||
version:
|
||||
type: integer
|
||||
schema:
|
||||
type: string
|
||||
stats:
|
||||
type: object
|
||||
"401":
|
||||
$ref: "#/components/responses/unauthorized"
|
||||
"404":
|
||||
$ref: "#/components/responses/not_found"
|
||||
|
||||
/v1/table/{name}/index/list/:
|
||||
post:
|
||||
description: List indexes of a table
|
||||
tags:
|
||||
- Tables
|
||||
summary: List indexes of a table
|
||||
operationId: listIndexes
|
||||
parameters:
|
||||
- $ref: "#/components/parameters/table_name"
|
||||
responses:
|
||||
"200":
|
||||
description: Available list of indexes on the table.
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
indexes:
|
||||
type: array
|
||||
items:
|
||||
type: object
|
||||
properties:
|
||||
columns:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
index_name:
|
||||
type: string
|
||||
index_uuid:
|
||||
type: string
|
||||
"401":
|
||||
$ref: "#/components/responses/unauthorized"
|
||||
"404":
|
||||
$ref: "#/components/responses/not_found"
|
||||
/v1/table/{name}/create_index/:
|
||||
post:
|
||||
description: Create vector index on a Table
|
||||
tags:
|
||||
- Tables
|
||||
summary: Create vector index on a Table
|
||||
operationId: createIndex
|
||||
parameters:
|
||||
- $ref: "#/components/parameters/table_name"
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
column:
|
||||
type: string
|
||||
metric_type:
|
||||
type: string
|
||||
nullable: false
|
||||
description: |
|
||||
The metric type to use for the index. L2, Cosine, Dot are supported.
|
||||
index_type:
|
||||
type: string
|
||||
responses:
|
||||
"200":
|
||||
description: Index successfully created
|
||||
"400":
|
||||
$ref: "#/components/responses/invalid_request"
|
||||
"401":
|
||||
$ref: "#/components/responses/unauthorized"
|
||||
"404":
|
||||
$ref: "#/components/responses/not_found"
|
||||
/v1/table/{name}/create_scalar_index/:
|
||||
post:
|
||||
description: Create a scalar index on a table
|
||||
tags:
|
||||
- Tables
|
||||
summary: Create a scalar index on a table
|
||||
operationId: createScalarIndex
|
||||
parameters:
|
||||
- $ref: "#/components/parameters/table_name"
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
column:
|
||||
type: string
|
||||
index_type:
|
||||
type: string
|
||||
required: false
|
||||
responses:
|
||||
"200":
|
||||
description: Scalar Index successfully created
|
||||
"400":
|
||||
$ref: "#/components/responses/invalid_request"
|
||||
"401":
|
||||
$ref: "#/components/responses/unauthorized"
|
||||
"404":
|
||||
$ref: "#/components/responses/not_found"
|
||||
@@ -2,5 +2,4 @@ mkdocs==1.5.3
|
||||
mkdocs-jupyter==0.24.1
|
||||
mkdocs-material==9.5.3
|
||||
mkdocstrings[python]==0.20.0
|
||||
mkdocs-render-swagger-plugin
|
||||
pydantic
|
||||
pydantic
|
||||
@@ -1 +0,0 @@
|
||||
!!swagger ../../openapi.yml!!
|
||||
@@ -193,13 +193,13 @@ from lancedb.pydantic import LanceModel, Vector
|
||||
|
||||
model = get_registry().get("huggingface").create(name='facebook/bart-base')
|
||||
|
||||
class Words(LanceModel):
|
||||
class TextModel(LanceModel):
|
||||
text: str = model.SourceField()
|
||||
vector: Vector(model.ndims()) = model.VectorField()
|
||||
|
||||
df = pd.DataFrame({"text": ["hi hello sayonara", "goodbye world"]})
|
||||
table = db.create_table("greets", schema=Words)
|
||||
table.add(df)
|
||||
table.add()
|
||||
query = "old greeting"
|
||||
actual = table.search(query).limit(1).to_pydantic(Words)[0]
|
||||
print(actual.text)
|
||||
|
||||
@@ -265,108 +265,6 @@ For **read-only access**, LanceDB will need a policy such as:
|
||||
}
|
||||
```
|
||||
|
||||
#### DynamoDB Commit Store for concurrent writes
|
||||
|
||||
By default, S3 does not support concurrent writes. Having two or more processes
|
||||
writing to the same table at the same time can lead to data corruption. This is
|
||||
because S3, unlike other object stores, does not have any atomic put or copy
|
||||
operation.
|
||||
|
||||
To enable concurrent writes, you can configure LanceDB to use a DynamoDB table
|
||||
as a commit store. This table will be used to coordinate writes between
|
||||
different processes. To enable this feature, you must modify your connection
|
||||
URI to use the `s3+ddb` scheme and add a query parameter `ddbTableName` with the
|
||||
name of the table to use.
|
||||
|
||||
=== "Python"
|
||||
|
||||
```python
|
||||
import lancedb
|
||||
db = await lancedb.connect_async(
|
||||
"s3+ddb://bucket/path?ddbTableName=my-dynamodb-table",
|
||||
)
|
||||
```
|
||||
|
||||
=== "JavaScript"
|
||||
|
||||
```javascript
|
||||
const lancedb = require("lancedb");
|
||||
|
||||
const db = await lancedb.connect(
|
||||
"s3+ddb://bucket/path?ddbTableName=my-dynamodb-table",
|
||||
);
|
||||
```
|
||||
|
||||
The DynamoDB table must be created with the following schema:
|
||||
|
||||
- Hash key: `base_uri` (string)
|
||||
- Range key: `version` (number)
|
||||
|
||||
You can create this programmatically with:
|
||||
|
||||
=== "Python"
|
||||
|
||||
<!-- skip-test -->
|
||||
```python
|
||||
import boto3
|
||||
|
||||
dynamodb = boto3.client("dynamodb")
|
||||
table = dynamodb.create_table(
|
||||
TableName=table_name,
|
||||
KeySchema=[
|
||||
{"AttributeName": "base_uri", "KeyType": "HASH"},
|
||||
{"AttributeName": "version", "KeyType": "RANGE"},
|
||||
],
|
||||
AttributeDefinitions=[
|
||||
{"AttributeName": "base_uri", "AttributeType": "S"},
|
||||
{"AttributeName": "version", "AttributeType": "N"},
|
||||
],
|
||||
ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1},
|
||||
)
|
||||
```
|
||||
|
||||
=== "JavaScript"
|
||||
|
||||
<!-- skip-test -->
|
||||
```javascript
|
||||
import {
|
||||
CreateTableCommand,
|
||||
DynamoDBClient,
|
||||
} from "@aws-sdk/client-dynamodb";
|
||||
|
||||
const dynamodb = new DynamoDBClient({
|
||||
region: CONFIG.awsRegion,
|
||||
credentials: {
|
||||
accessKeyId: CONFIG.awsAccessKeyId,
|
||||
secretAccessKey: CONFIG.awsSecretAccessKey,
|
||||
},
|
||||
endpoint: CONFIG.awsEndpoint,
|
||||
});
|
||||
const command = new CreateTableCommand({
|
||||
TableName: table_name,
|
||||
AttributeDefinitions: [
|
||||
{
|
||||
AttributeName: "base_uri",
|
||||
AttributeType: "S",
|
||||
},
|
||||
{
|
||||
AttributeName: "version",
|
||||
AttributeType: "N",
|
||||
},
|
||||
],
|
||||
KeySchema: [
|
||||
{ AttributeName: "base_uri", KeyType: "HASH" },
|
||||
{ AttributeName: "version", KeyType: "RANGE" },
|
||||
],
|
||||
ProvisionedThroughput: {
|
||||
ReadCapacityUnits: 1,
|
||||
WriteCapacityUnits: 1,
|
||||
},
|
||||
});
|
||||
await client.send(command);
|
||||
```
|
||||
|
||||
|
||||
#### S3-compatible stores
|
||||
|
||||
LanceDB can also connect to S3-compatible stores, such as MinIO. To do so, you must specify both region and endpoint:
|
||||
|
||||
@@ -116,21 +116,21 @@ This guide will show how to create tables, insert data into them, and update the
|
||||
|
||||
### From a Polars DataFrame
|
||||
|
||||
LanceDB supports [Polars](https://pola.rs/), a modern, fast DataFrame library
|
||||
written in Rust. Just like in Pandas, the Polars integration is enabled by PyArrow
|
||||
under the hood. A deeper integration between LanceDB Tables and Polars DataFrames
|
||||
is on the way.
|
||||
LanceDB supports [Polars](https://pola.rs/), a modern, fast DataFrame library
|
||||
written in Rust. Just like in Pandas, the Polars integration is enabled by PyArrow
|
||||
under the hood. A deeper integration between LanceDB Tables and Polars DataFrames
|
||||
is on the way.
|
||||
|
||||
```python
|
||||
import polars as pl
|
||||
```python
|
||||
import polars as pl
|
||||
|
||||
data = pl.DataFrame({
|
||||
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
||||
"item": ["foo", "bar"],
|
||||
"price": [10.0, 20.0]
|
||||
})
|
||||
table = db.create_table("pl_table", data=data)
|
||||
```
|
||||
data = pl.DataFrame({
|
||||
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
||||
"item": ["foo", "bar"],
|
||||
"price": [10.0, 20.0]
|
||||
})
|
||||
table = db.create_table("pl_table", data=data)
|
||||
```
|
||||
|
||||
### From an Arrow Table
|
||||
=== "Python"
|
||||
|
||||
File diff suppressed because one or more lines are too long
4
node/package-lock.json
generated
4
node/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "vectordb",
|
||||
"version": "0.6.0",
|
||||
"version": "0.5.2",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "vectordb",
|
||||
"version": "0.6.0",
|
||||
"version": "0.5.2",
|
||||
"cpu": [
|
||||
"x64",
|
||||
"arm64"
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "vectordb",
|
||||
"version": "0.6.0",
|
||||
"version": "0.5.2",
|
||||
"description": " Serverless, low-latency vector database for AI applications",
|
||||
"main": "dist/index.js",
|
||||
"types": "dist/index.d.ts",
|
||||
"scripts": {
|
||||
"tsc": "tsc -b",
|
||||
"build": "npm run tsc && cargo-cp-artifact --artifact cdylib lancedb_node index.node -- cargo build -p lancedb-node --message-format=json",
|
||||
"build": "npm run tsc && cargo-cp-artifact --artifact cdylib lancedb_node index.node -- cargo build --message-format=json",
|
||||
"build-release": "npm run build -- --release",
|
||||
"test": "npm run tsc && mocha -recursive dist/test",
|
||||
"integration-test": "npm run tsc && mocha -recursive dist/integration_test",
|
||||
|
||||
@@ -15,11 +15,11 @@ crate-type = ["cdylib"]
|
||||
arrow-ipc.workspace = true
|
||||
futures.workspace = true
|
||||
lancedb = { path = "../rust/lancedb" }
|
||||
napi = { version = "2.16.8", default-features = false, features = [
|
||||
"napi9",
|
||||
napi = { version = "2.15", default-features = false, features = [
|
||||
"napi7",
|
||||
"async",
|
||||
] }
|
||||
napi-derive = "2.16.4"
|
||||
napi-derive = "2"
|
||||
|
||||
# Prevent dynamic linking of lzma, which comes from datafusion
|
||||
lzma-sys = { version = "*", features = ["static"] }
|
||||
|
||||
@@ -63,7 +63,6 @@ describe("Registry", () => {
|
||||
return data.map(() => [1, 2, 3]);
|
||||
}
|
||||
}
|
||||
|
||||
const func = getRegistry()
|
||||
.get<MockEmbeddingFunction>("mock-embedding")!
|
||||
.create();
|
||||
|
||||
@@ -14,11 +14,6 @@
|
||||
|
||||
/* eslint-disable @typescript-eslint/naming-convention */
|
||||
|
||||
import {
|
||||
CreateTableCommand,
|
||||
DeleteTableCommand,
|
||||
DynamoDBClient,
|
||||
} from "@aws-sdk/client-dynamodb";
|
||||
import {
|
||||
CreateKeyCommand,
|
||||
KMSClient,
|
||||
@@ -43,7 +38,6 @@ const CONFIG = {
|
||||
awsAccessKeyId: "ACCESSKEY",
|
||||
awsSecretAccessKey: "SECRETKEY",
|
||||
awsEndpoint: "http://127.0.0.1:4566",
|
||||
dynamodbEndpoint: "http://127.0.0.1:4566",
|
||||
awsRegion: "us-east-1",
|
||||
};
|
||||
|
||||
@@ -72,6 +66,7 @@ class S3Bucket {
|
||||
} catch {
|
||||
// It's fine if the bucket doesn't exist
|
||||
}
|
||||
// biome-ignore lint/style/useNamingConvention: we dont control s3's api
|
||||
await client.send(new CreateBucketCommand({ Bucket: name }));
|
||||
return new S3Bucket(name);
|
||||
}
|
||||
@@ -84,27 +79,32 @@ class S3Bucket {
|
||||
static async deleteBucket(client: S3Client, name: string) {
|
||||
// Must delete all objects before we can delete the bucket
|
||||
const objects = await client.send(
|
||||
// biome-ignore lint/style/useNamingConvention: we dont control s3's api
|
||||
new ListObjectsV2Command({ Bucket: name }),
|
||||
);
|
||||
if (objects.Contents) {
|
||||
for (const object of objects.Contents) {
|
||||
await client.send(
|
||||
// biome-ignore lint/style/useNamingConvention: we dont control s3's api
|
||||
new DeleteObjectCommand({ Bucket: name, Key: object.Key }),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// biome-ignore lint/style/useNamingConvention: we dont control s3's api
|
||||
await client.send(new DeleteBucketCommand({ Bucket: name }));
|
||||
}
|
||||
|
||||
public async assertAllEncrypted(path: string, keyId: string) {
|
||||
const client = S3Bucket.s3Client();
|
||||
const objects = await client.send(
|
||||
// biome-ignore lint/style/useNamingConvention: we dont control s3's api
|
||||
new ListObjectsV2Command({ Bucket: this.name, Prefix: path }),
|
||||
);
|
||||
if (objects.Contents) {
|
||||
for (const object of objects.Contents) {
|
||||
const metadata = await client.send(
|
||||
// biome-ignore lint/style/useNamingConvention: we dont control s3's api
|
||||
new HeadObjectCommand({ Bucket: this.name, Key: object.Key }),
|
||||
);
|
||||
expect(metadata.ServerSideEncryption).toBe("aws:kms");
|
||||
@@ -143,6 +143,7 @@ class KmsKey {
|
||||
|
||||
public async delete() {
|
||||
const client = KmsKey.kmsClient();
|
||||
// biome-ignore lint/style/useNamingConvention: we dont control s3's api
|
||||
await client.send(new ScheduleKeyDeletionCommand({ KeyId: this.keyId }));
|
||||
}
|
||||
}
|
||||
@@ -223,91 +224,3 @@ maybeDescribe("storage_options", () => {
|
||||
await bucket.assertAllEncrypted("test/table2.lance", kmsKey.keyId);
|
||||
});
|
||||
});
|
||||
|
||||
class DynamoDBCommitTable {
|
||||
name: string;
|
||||
constructor(name: string) {
|
||||
this.name = name;
|
||||
}
|
||||
|
||||
static dynamoClient() {
|
||||
return new DynamoDBClient({
|
||||
region: CONFIG.awsRegion,
|
||||
credentials: {
|
||||
accessKeyId: CONFIG.awsAccessKeyId,
|
||||
secretAccessKey: CONFIG.awsSecretAccessKey,
|
||||
},
|
||||
endpoint: CONFIG.awsEndpoint,
|
||||
});
|
||||
}
|
||||
|
||||
public static async create(name: string): Promise<DynamoDBCommitTable> {
|
||||
const client = DynamoDBCommitTable.dynamoClient();
|
||||
const command = new CreateTableCommand({
|
||||
TableName: name,
|
||||
AttributeDefinitions: [
|
||||
{
|
||||
AttributeName: "base_uri",
|
||||
AttributeType: "S",
|
||||
},
|
||||
{
|
||||
AttributeName: "version",
|
||||
AttributeType: "N",
|
||||
},
|
||||
],
|
||||
KeySchema: [
|
||||
{ AttributeName: "base_uri", KeyType: "HASH" },
|
||||
{ AttributeName: "version", KeyType: "RANGE" },
|
||||
],
|
||||
ProvisionedThroughput: {
|
||||
ReadCapacityUnits: 1,
|
||||
WriteCapacityUnits: 1,
|
||||
},
|
||||
});
|
||||
await client.send(command);
|
||||
return new DynamoDBCommitTable(name);
|
||||
}
|
||||
|
||||
public async delete() {
|
||||
const client = DynamoDBCommitTable.dynamoClient();
|
||||
await client.send(new DeleteTableCommand({ TableName: this.name }));
|
||||
}
|
||||
}
|
||||
|
||||
maybeDescribe("DynamoDB Lock", () => {
|
||||
let bucket: S3Bucket;
|
||||
let commitTable: DynamoDBCommitTable;
|
||||
|
||||
beforeAll(async () => {
|
||||
bucket = await S3Bucket.create("lancedb2");
|
||||
commitTable = await DynamoDBCommitTable.create("commitTable");
|
||||
});
|
||||
|
||||
afterAll(async () => {
|
||||
await commitTable.delete();
|
||||
await bucket.delete();
|
||||
});
|
||||
|
||||
it("can be used to configure a DynamoDB table for commit log", async () => {
|
||||
const uri = `s3+ddb://${bucket.name}/test?ddbTableName=${commitTable.name}`;
|
||||
const db = await connect(uri, {
|
||||
storageOptions: CONFIG,
|
||||
readConsistencyInterval: 0,
|
||||
});
|
||||
|
||||
const table = await db.createTable("test", [{ a: 1, b: 2 }]);
|
||||
|
||||
// 5 concurrent appends
|
||||
const futs = Array.from({ length: 5 }, async () => {
|
||||
// Open a table so each append has a separate table reference. Otherwise
|
||||
// they will share the same table reference and the internal ReadWriteLock
|
||||
// will prevent any real concurrency.
|
||||
const table = await db.openTable("test");
|
||||
await table.add([{ a: 2, b: 3 }]);
|
||||
});
|
||||
await Promise.all(futs);
|
||||
|
||||
const rowCount = await table.countRows();
|
||||
expect(rowCount).toBe(6);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -39,9 +39,7 @@ describe.each([arrow, arrowOld])("Given a table", (arrow: any) => {
|
||||
let tmpDir: tmp.DirResult;
|
||||
let table: Table;
|
||||
|
||||
const schema:
|
||||
| import("apache-arrow").Schema
|
||||
| import("apache-arrow-old").Schema = new arrow.Schema([
|
||||
const schema = new arrow.Schema([
|
||||
new arrow.Field("id", new arrow.Float64(), true),
|
||||
]);
|
||||
|
||||
@@ -317,7 +315,7 @@ describe("When creating an index", () => {
|
||||
.query()
|
||||
.limit(2)
|
||||
.nearestTo(queryVec)
|
||||
.distanceType("dot")
|
||||
.distanceType("DoT")
|
||||
.toArrow();
|
||||
expect(rst.numRows).toBe(2);
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
{
|
||||
"$schema": "https://biomejs.dev/schemas/1.8.3/schema.json",
|
||||
"$schema": "https://biomejs.dev/schemas/1.7.3/schema.json",
|
||||
"organizeImports": {
|
||||
"enabled": true
|
||||
},
|
||||
@@ -100,16 +100,6 @@
|
||||
"globals": []
|
||||
},
|
||||
"overrides": [
|
||||
{
|
||||
"include": ["__test__/s3_integration.test.ts"],
|
||||
"linter": {
|
||||
"rules": {
|
||||
"style": {
|
||||
"useNamingConvention": "off"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"include": [
|
||||
"**/*.ts",
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
import {
|
||||
Table as ArrowTable,
|
||||
Binary,
|
||||
BufferType,
|
||||
DataType,
|
||||
Field,
|
||||
FixedSizeBinary,
|
||||
@@ -38,68 +37,14 @@ import {
|
||||
type makeTable,
|
||||
vectorFromArray,
|
||||
} from "apache-arrow";
|
||||
import { Buffers } from "apache-arrow/data";
|
||||
import { type EmbeddingFunction } from "./embedding/embedding_function";
|
||||
import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry";
|
||||
import {
|
||||
sanitizeField,
|
||||
sanitizeSchema,
|
||||
sanitizeTable,
|
||||
sanitizeType,
|
||||
} from "./sanitize";
|
||||
import { sanitizeField, sanitizeSchema, sanitizeType } from "./sanitize";
|
||||
export * from "apache-arrow";
|
||||
export type SchemaLike =
|
||||
| Schema
|
||||
| {
|
||||
fields: FieldLike[];
|
||||
metadata: Map<string, string>;
|
||||
get names(): unknown[];
|
||||
};
|
||||
export type FieldLike =
|
||||
| Field
|
||||
| {
|
||||
type: string;
|
||||
name: string;
|
||||
nullable?: boolean;
|
||||
metadata?: Map<string, string>;
|
||||
};
|
||||
|
||||
export type DataLike =
|
||||
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
|
||||
| import("apache-arrow").Data<Struct<any>>
|
||||
| {
|
||||
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
|
||||
type: any;
|
||||
length: number;
|
||||
offset: number;
|
||||
stride: number;
|
||||
nullable: boolean;
|
||||
children: DataLike[];
|
||||
get nullCount(): number;
|
||||
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
|
||||
values: Buffers<any>[BufferType.DATA];
|
||||
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
|
||||
typeIds: Buffers<any>[BufferType.TYPE];
|
||||
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
|
||||
nullBitmap: Buffers<any>[BufferType.VALIDITY];
|
||||
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
|
||||
valueOffsets: Buffers<any>[BufferType.OFFSET];
|
||||
};
|
||||
|
||||
export type RecordBatchLike =
|
||||
| RecordBatch
|
||||
| {
|
||||
schema: SchemaLike;
|
||||
data: DataLike;
|
||||
};
|
||||
|
||||
export type TableLike =
|
||||
| ArrowTable
|
||||
| { schema: SchemaLike; batches: RecordBatchLike[] };
|
||||
|
||||
export type IntoVector = Float32Array | Float64Array | number[];
|
||||
|
||||
export function isArrowTable(value: object): value is TableLike {
|
||||
export function isArrowTable(value: object): value is ArrowTable {
|
||||
if (value instanceof ArrowTable) return true;
|
||||
return "schema" in value && "batches" in value;
|
||||
}
|
||||
@@ -190,7 +135,7 @@ export function isFixedSizeList(value: unknown): value is FixedSizeList {
|
||||
}
|
||||
|
||||
/** Data type accepted by NodeJS SDK */
|
||||
export type Data = Record<string, unknown>[] | TableLike;
|
||||
export type Data = Record<string, unknown>[] | ArrowTable;
|
||||
|
||||
/*
|
||||
* Options to control how a column should be converted to a vector array
|
||||
@@ -217,7 +162,7 @@ export class MakeArrowTableOptions {
|
||||
* The schema must be specified if there are no records (e.g. to make
|
||||
* an empty table)
|
||||
*/
|
||||
schema?: SchemaLike;
|
||||
schema?: Schema;
|
||||
|
||||
/*
|
||||
* Mapping from vector column name to expected type
|
||||
@@ -365,7 +310,7 @@ export function makeArrowTable(
|
||||
if (opt.schema !== undefined && opt.schema !== null) {
|
||||
opt.schema = sanitizeSchema(opt.schema);
|
||||
opt.schema = validateSchemaEmbeddings(
|
||||
opt.schema as Schema,
|
||||
opt.schema,
|
||||
data,
|
||||
options?.embeddingFunction,
|
||||
);
|
||||
@@ -449,7 +394,7 @@ export function makeArrowTable(
|
||||
// `new ArrowTable(schema, batches)` which does not do any schema inference
|
||||
const firstTable = new ArrowTable(columns);
|
||||
const batchesFixed = firstTable.batches.map(
|
||||
(batch) => new RecordBatch(opt.schema as Schema, batch.data),
|
||||
(batch) => new RecordBatch(opt.schema!, batch.data),
|
||||
);
|
||||
let schema: Schema;
|
||||
if (metadata !== undefined) {
|
||||
@@ -462,9 +407,9 @@ export function makeArrowTable(
|
||||
}
|
||||
}
|
||||
|
||||
schema = new Schema(opt.schema.fields as Field[], schemaMetadata);
|
||||
schema = new Schema(opt.schema.fields, schemaMetadata);
|
||||
} else {
|
||||
schema = opt.schema as Schema;
|
||||
schema = opt.schema;
|
||||
}
|
||||
return new ArrowTable(schema, batchesFixed);
|
||||
}
|
||||
@@ -480,7 +425,7 @@ export function makeArrowTable(
|
||||
* Create an empty Arrow table with the provided schema
|
||||
*/
|
||||
export function makeEmptyTable(
|
||||
schema: SchemaLike,
|
||||
schema: Schema,
|
||||
metadata?: Map<string, string>,
|
||||
): ArrowTable {
|
||||
return makeArrowTable([], { schema }, metadata);
|
||||
@@ -618,17 +563,18 @@ async function applyEmbeddingsFromMetadata(
|
||||
async function applyEmbeddings<T>(
|
||||
table: ArrowTable,
|
||||
embeddings?: EmbeddingFunctionConfig,
|
||||
schema?: SchemaLike,
|
||||
schema?: Schema,
|
||||
): Promise<ArrowTable> {
|
||||
if (schema !== undefined && schema !== null) {
|
||||
schema = sanitizeSchema(schema);
|
||||
}
|
||||
if (schema?.metadata.has("embedding_functions")) {
|
||||
return applyEmbeddingsFromMetadata(table, schema! as Schema);
|
||||
return applyEmbeddingsFromMetadata(table, schema!);
|
||||
} else if (embeddings == null || embeddings === undefined) {
|
||||
return table;
|
||||
}
|
||||
|
||||
if (schema !== undefined && schema !== null) {
|
||||
schema = sanitizeSchema(schema);
|
||||
}
|
||||
|
||||
// Convert from ArrowTable to Record<String, Vector>
|
||||
const colEntries = [...Array(table.numCols).keys()].map((_, idx) => {
|
||||
const name = table.schema.fields[idx].name;
|
||||
@@ -704,7 +650,7 @@ async function applyEmbeddings<T>(
|
||||
`When using embedding functions and specifying a schema the schema should include the embedding column but the column ${destColumn} was missing`,
|
||||
);
|
||||
}
|
||||
return alignTable(newTable, schema as Schema);
|
||||
return alignTable(newTable, schema);
|
||||
}
|
||||
return newTable;
|
||||
}
|
||||
@@ -798,7 +744,7 @@ export async function fromRecordsToStreamBuffer(
|
||||
export async function fromTableToBuffer(
|
||||
table: ArrowTable,
|
||||
embeddings?: EmbeddingFunctionConfig,
|
||||
schema?: SchemaLike,
|
||||
schema?: Schema,
|
||||
): Promise<Buffer> {
|
||||
if (schema !== undefined && schema !== null) {
|
||||
schema = sanitizeSchema(schema);
|
||||
@@ -825,7 +771,7 @@ export async function fromDataToBuffer(
|
||||
schema = sanitizeSchema(schema);
|
||||
}
|
||||
if (isArrowTable(data)) {
|
||||
return fromTableToBuffer(sanitizeTable(data), embeddings, schema);
|
||||
return fromTableToBuffer(data, embeddings, schema);
|
||||
} else {
|
||||
const table = await convertToTable(data, embeddings, { schema });
|
||||
return fromTableToBuffer(table);
|
||||
@@ -843,7 +789,7 @@ export async function fromDataToBuffer(
|
||||
export async function fromTableToStreamBuffer(
|
||||
table: ArrowTable,
|
||||
embeddings?: EmbeddingFunctionConfig,
|
||||
schema?: SchemaLike,
|
||||
schema?: Schema,
|
||||
): Promise<Buffer> {
|
||||
const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema);
|
||||
const writer = RecordBatchStreamWriter.writeAll(tableWithEmbeddings);
|
||||
@@ -908,6 +854,7 @@ function validateSchemaEmbeddings(
|
||||
for (let field of schema.fields) {
|
||||
if (isFixedSizeList(field.type)) {
|
||||
field = sanitizeField(field);
|
||||
|
||||
if (data.length !== 0 && data?.[0]?.[field.name] === undefined) {
|
||||
if (schema.metadata.has("embedding_functions")) {
|
||||
const embeddings = JSON.parse(
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import { Data, Schema, SchemaLike, TableLike } from "./arrow";
|
||||
import { Table as ArrowTable, Data, Schema } from "./arrow";
|
||||
import { fromTableToBuffer, makeEmptyTable } from "./arrow";
|
||||
import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry";
|
||||
import { Connection as LanceDbConnection } from "./native";
|
||||
@@ -50,7 +50,7 @@ export interface CreateTableOptions {
|
||||
* The default is true while the new format is in beta
|
||||
*/
|
||||
useLegacyFormat?: boolean;
|
||||
schema?: SchemaLike;
|
||||
schema?: Schema;
|
||||
embeddingFunction?: EmbeddingFunctionConfig;
|
||||
}
|
||||
|
||||
@@ -167,12 +167,12 @@ export abstract class Connection {
|
||||
/**
|
||||
* Creates a new Table and initialize it with new data.
|
||||
* @param {string} name - The name of the table.
|
||||
* @param {Record<string, unknown>[] | TableLike} data - Non-empty Array of Records
|
||||
* @param {Record<string, unknown>[] | ArrowTable} data - Non-empty Array of Records
|
||||
* to be inserted into the table
|
||||
*/
|
||||
abstract createTable(
|
||||
name: string,
|
||||
data: Record<string, unknown>[] | TableLike,
|
||||
data: Record<string, unknown>[] | ArrowTable,
|
||||
options?: Partial<CreateTableOptions>,
|
||||
): Promise<Table>;
|
||||
|
||||
@@ -183,7 +183,7 @@ export abstract class Connection {
|
||||
*/
|
||||
abstract createEmptyTable(
|
||||
name: string,
|
||||
schema: import("./arrow").SchemaLike,
|
||||
schema: Schema,
|
||||
options?: Partial<CreateTableOptions>,
|
||||
): Promise<Table>;
|
||||
|
||||
@@ -235,7 +235,7 @@ export class LocalConnection extends Connection {
|
||||
nameOrOptions:
|
||||
| string
|
||||
| ({ name: string; data: Data } & Partial<CreateTableOptions>),
|
||||
data?: Record<string, unknown>[] | TableLike,
|
||||
data?: Record<string, unknown>[] | ArrowTable,
|
||||
options?: Partial<CreateTableOptions>,
|
||||
): Promise<Table> {
|
||||
if (typeof nameOrOptions !== "string" && "name" in nameOrOptions) {
|
||||
@@ -259,7 +259,7 @@ export class LocalConnection extends Connection {
|
||||
|
||||
async createEmptyTable(
|
||||
name: string,
|
||||
schema: import("./arrow").SchemaLike,
|
||||
schema: Schema,
|
||||
options?: Partial<CreateTableOptions>,
|
||||
): Promise<Table> {
|
||||
let mode: string = options?.mode ?? "create";
|
||||
|
||||
@@ -35,11 +35,6 @@ export interface FunctionOptions {
|
||||
[key: string]: any;
|
||||
}
|
||||
|
||||
export interface EmbeddingFunctionConstructor<
|
||||
T extends EmbeddingFunction = EmbeddingFunction,
|
||||
> {
|
||||
new (modelOptions?: T["TOptions"]): T;
|
||||
}
|
||||
/**
|
||||
* An embedding function that automatically creates vector representation for a given column.
|
||||
*/
|
||||
@@ -48,12 +43,6 @@ export abstract class EmbeddingFunction<
|
||||
T = any,
|
||||
M extends FunctionOptions = FunctionOptions,
|
||||
> {
|
||||
/**
|
||||
* @ignore
|
||||
* This is only used for associating the options type with the class for type checking
|
||||
*/
|
||||
// biome-ignore lint/style/useNamingConvention: we want to keep the name as it is
|
||||
readonly TOptions!: M;
|
||||
/**
|
||||
* Convert the embedding function to a JSON object
|
||||
* It is used to serialize the embedding function to the schema
|
||||
|
||||
@@ -13,29 +13,24 @@
|
||||
// limitations under the License.
|
||||
|
||||
import type OpenAI from "openai";
|
||||
import { type EmbeddingCreateParams } from "openai/resources";
|
||||
import { Float, Float32 } from "../arrow";
|
||||
import { EmbeddingFunction } from "./embedding_function";
|
||||
import { register } from "./registry";
|
||||
|
||||
export type OpenAIOptions = {
|
||||
apiKey: string;
|
||||
model: EmbeddingCreateParams["model"];
|
||||
apiKey?: string;
|
||||
model?: string;
|
||||
};
|
||||
|
||||
@register("openai")
|
||||
export class OpenAIEmbeddingFunction extends EmbeddingFunction<
|
||||
string,
|
||||
Partial<OpenAIOptions>
|
||||
OpenAIOptions
|
||||
> {
|
||||
#openai: OpenAI;
|
||||
#modelName: OpenAIOptions["model"];
|
||||
#modelName: string;
|
||||
|
||||
constructor(
|
||||
options: Partial<OpenAIOptions> = {
|
||||
model: "text-embedding-ada-002",
|
||||
},
|
||||
) {
|
||||
constructor(options: OpenAIOptions = { model: "text-embedding-ada-002" }) {
|
||||
super();
|
||||
const openAIKey = options?.apiKey ?? process.env.OPENAI_API_KEY;
|
||||
if (!openAIKey) {
|
||||
@@ -78,7 +73,7 @@ export class OpenAIEmbeddingFunction extends EmbeddingFunction<
|
||||
case "text-embedding-3-small":
|
||||
return 1536;
|
||||
default:
|
||||
throw new Error(`Unknown model: ${this.#modelName}`);
|
||||
return null as never;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,15 +12,21 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import {
|
||||
type EmbeddingFunction,
|
||||
type EmbeddingFunctionConstructor,
|
||||
} from "./embedding_function";
|
||||
import type { EmbeddingFunction } from "./embedding_function";
|
||||
import "reflect-metadata";
|
||||
import { OpenAIEmbeddingFunction } from "./openai";
|
||||
|
||||
export interface EmbeddingFunctionOptions {
|
||||
[key: string]: unknown;
|
||||
}
|
||||
|
||||
export interface EmbeddingFunctionFactory<
|
||||
T extends EmbeddingFunction = EmbeddingFunction,
|
||||
> {
|
||||
new (modelOptions?: EmbeddingFunctionOptions): T;
|
||||
}
|
||||
|
||||
interface EmbeddingFunctionCreate<T extends EmbeddingFunction> {
|
||||
create(options?: T["TOptions"]): T;
|
||||
create(options?: EmbeddingFunctionOptions): T;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -30,7 +36,7 @@ interface EmbeddingFunctionCreate<T extends EmbeddingFunction> {
|
||||
* or TextEmbeddingFunction and registering it with the registry
|
||||
*/
|
||||
export class EmbeddingFunctionRegistry {
|
||||
#functions = new Map<string, EmbeddingFunctionConstructor>();
|
||||
#functions: Map<string, EmbeddingFunctionFactory> = new Map();
|
||||
|
||||
/**
|
||||
* Register an embedding function
|
||||
@@ -38,9 +44,7 @@ export class EmbeddingFunctionRegistry {
|
||||
* @param func The function to register
|
||||
* @throws Error if the function is already registered
|
||||
*/
|
||||
register<
|
||||
T extends EmbeddingFunctionConstructor = EmbeddingFunctionConstructor,
|
||||
>(
|
||||
register<T extends EmbeddingFunctionFactory = EmbeddingFunctionFactory>(
|
||||
this: EmbeddingFunctionRegistry,
|
||||
alias?: string,
|
||||
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
|
||||
@@ -65,34 +69,18 @@ export class EmbeddingFunctionRegistry {
|
||||
* Fetch an embedding function by name
|
||||
* @param name The name of the function
|
||||
*/
|
||||
get<T extends EmbeddingFunction<unknown>, Name extends string = "">(
|
||||
name: Name extends "openai" ? "openai" : string,
|
||||
//This makes it so that you can use string constants as "types", or use an explicitly supplied type
|
||||
// ex:
|
||||
// `registry.get("openai") -> EmbeddingFunctionCreate<OpenAIEmbeddingFunction>`
|
||||
// `registry.get<MyCustomEmbeddingFunction>("my_func") -> EmbeddingFunctionCreate<MyCustomEmbeddingFunction> | undefined`
|
||||
//
|
||||
// the reason this is important is that we always know our built in functions are defined so the user isnt forced to do a non null/undefined
|
||||
// ```ts
|
||||
// const openai: OpenAIEmbeddingFunction = registry.get("openai").create()
|
||||
// ```
|
||||
): Name extends "openai"
|
||||
? EmbeddingFunctionCreate<OpenAIEmbeddingFunction>
|
||||
: EmbeddingFunctionCreate<T> | undefined {
|
||||
type Output = Name extends "openai"
|
||||
? EmbeddingFunctionCreate<OpenAIEmbeddingFunction>
|
||||
: EmbeddingFunctionCreate<T> | undefined;
|
||||
|
||||
get<T extends EmbeddingFunction<unknown> = EmbeddingFunction>(
|
||||
name: string,
|
||||
): EmbeddingFunctionCreate<T> | undefined {
|
||||
const factory = this.#functions.get(name);
|
||||
if (!factory) {
|
||||
return undefined as Output;
|
||||
return undefined;
|
||||
}
|
||||
|
||||
return {
|
||||
create: function (options?: T["TOptions"]) {
|
||||
return new factory(options);
|
||||
create: function (options: EmbeddingFunctionOptions) {
|
||||
return new factory(options) as unknown as T;
|
||||
},
|
||||
} as Output;
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -116,7 +104,7 @@ export class EmbeddingFunctionRegistry {
|
||||
name: string;
|
||||
sourceColumn: string;
|
||||
vectorColumn: string;
|
||||
model: EmbeddingFunction["TOptions"];
|
||||
model: EmbeddingFunctionOptions;
|
||||
};
|
||||
const functions = <FunctionConfig[]>(
|
||||
JSON.parse(metadata.get("embedding_functions")!)
|
||||
|
||||
@@ -300,9 +300,7 @@ export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
|
||||
*
|
||||
* By default "l2" is used.
|
||||
*/
|
||||
distanceType(
|
||||
distanceType: Required<IvfPqOptions>["distanceType"],
|
||||
): VectorQuery {
|
||||
distanceType(distanceType: string): VectorQuery {
|
||||
this.inner.distanceType(distanceType);
|
||||
return this;
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ export class RestfulLanceDBClient {
|
||||
return axios.create({
|
||||
baseURL: this.url,
|
||||
headers: {
|
||||
// biome-ignore lint: external API
|
||||
// biome-ignore lint/style/useNamingConvention: external api
|
||||
Authorization: `Bearer ${this.#apiKey}`,
|
||||
},
|
||||
transformResponse: decodeErrorData,
|
||||
|
||||
@@ -1,10 +1,5 @@
|
||||
import { Schema } from "apache-arrow";
|
||||
import {
|
||||
Data,
|
||||
SchemaLike,
|
||||
fromTableToStreamBuffer,
|
||||
makeEmptyTable,
|
||||
} from "../arrow";
|
||||
import { Data, fromTableToStreamBuffer, makeEmptyTable } from "../arrow";
|
||||
import {
|
||||
Connection,
|
||||
CreateTableOptions,
|
||||
@@ -161,7 +156,7 @@ export class RemoteConnection extends Connection {
|
||||
|
||||
async createEmptyTable(
|
||||
name: string,
|
||||
schema: SchemaLike,
|
||||
schema: Schema,
|
||||
options?: Partial<CreateTableOptions> | undefined,
|
||||
): Promise<Table> {
|
||||
if (options?.mode) {
|
||||
|
||||
@@ -20,12 +20,10 @@
|
||||
// comes from the exact same library instance. This is not always the case
|
||||
// and so we must sanitize the input to ensure that it is compatible.
|
||||
|
||||
import { BufferType, Data } from "apache-arrow";
|
||||
import type { IntBitWidth, TKeys, TimeBitWidth } from "apache-arrow/type";
|
||||
import {
|
||||
Binary,
|
||||
Bool,
|
||||
DataLike,
|
||||
DataType,
|
||||
DateDay,
|
||||
DateMillisecond,
|
||||
@@ -58,14 +56,9 @@ import {
|
||||
Map_,
|
||||
Null,
|
||||
type Precision,
|
||||
RecordBatch,
|
||||
RecordBatchLike,
|
||||
Schema,
|
||||
SchemaLike,
|
||||
SparseUnion,
|
||||
Struct,
|
||||
Table,
|
||||
TableLike,
|
||||
Time,
|
||||
TimeMicrosecond,
|
||||
TimeMillisecond,
|
||||
@@ -495,7 +488,7 @@ export function sanitizeField(fieldLike: unknown): Field {
|
||||
* instance because they might be using a different instance of apache-arrow
|
||||
* than lancedb is using.
|
||||
*/
|
||||
export function sanitizeSchema(schemaLike: SchemaLike): Schema {
|
||||
export function sanitizeSchema(schemaLike: unknown): Schema {
|
||||
if (schemaLike instanceof Schema) {
|
||||
return schemaLike;
|
||||
}
|
||||
@@ -521,68 +514,3 @@ export function sanitizeSchema(schemaLike: SchemaLike): Schema {
|
||||
);
|
||||
return new Schema(sanitizedFields, metadata);
|
||||
}
|
||||
|
||||
export function sanitizeTable(tableLike: TableLike): Table {
|
||||
if (tableLike instanceof Table) {
|
||||
return tableLike;
|
||||
}
|
||||
if (typeof tableLike !== "object" || tableLike === null) {
|
||||
throw Error("Expected a Table but object was null/undefined");
|
||||
}
|
||||
if (!("schema" in tableLike)) {
|
||||
throw Error(
|
||||
"The table passed in does not appear to be a table (no 'schema' property)",
|
||||
);
|
||||
}
|
||||
if (!("batches" in tableLike)) {
|
||||
throw Error(
|
||||
"The table passed in does not appear to be a table (no 'columns' property)",
|
||||
);
|
||||
}
|
||||
const schema = sanitizeSchema(tableLike.schema);
|
||||
|
||||
const batches = tableLike.batches.map(sanitizeRecordBatch);
|
||||
return new Table(schema, batches);
|
||||
}
|
||||
|
||||
function sanitizeRecordBatch(batchLike: RecordBatchLike): RecordBatch {
|
||||
if (batchLike instanceof RecordBatch) {
|
||||
return batchLike;
|
||||
}
|
||||
if (typeof batchLike !== "object" || batchLike === null) {
|
||||
throw Error("Expected a RecordBatch but object was null/undefined");
|
||||
}
|
||||
if (!("schema" in batchLike)) {
|
||||
throw Error(
|
||||
"The record batch passed in does not appear to be a record batch (no 'schema' property)",
|
||||
);
|
||||
}
|
||||
if (!("data" in batchLike)) {
|
||||
throw Error(
|
||||
"The record batch passed in does not appear to be a record batch (no 'data' property)",
|
||||
);
|
||||
}
|
||||
const schema = sanitizeSchema(batchLike.schema);
|
||||
const data = sanitizeData(batchLike.data);
|
||||
return new RecordBatch(schema, data);
|
||||
}
|
||||
function sanitizeData(
|
||||
dataLike: DataLike,
|
||||
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
|
||||
): import("apache-arrow").Data<Struct<any>> {
|
||||
if (dataLike instanceof Data) {
|
||||
return dataLike;
|
||||
}
|
||||
return new Data(
|
||||
dataLike.type,
|
||||
dataLike.offset,
|
||||
dataLike.length,
|
||||
dataLike.nullCount,
|
||||
{
|
||||
[BufferType.OFFSET]: dataLike.valueOffsets,
|
||||
[BufferType.DATA]: dataLike.values,
|
||||
[BufferType.VALIDITY]: dataLike.nullBitmap,
|
||||
[BufferType.TYPE]: dataLike.typeIds,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
@@ -17,7 +17,6 @@ import {
|
||||
Data,
|
||||
IntoVector,
|
||||
Schema,
|
||||
TableLike,
|
||||
fromDataToBuffer,
|
||||
fromTableToBuffer,
|
||||
fromTableToStreamBuffer,
|
||||
@@ -39,8 +38,6 @@ import {
|
||||
Table as _NativeTable,
|
||||
} from "./native";
|
||||
import { Query, VectorQuery } from "./query";
|
||||
import { sanitizeTable } from "./sanitize";
|
||||
export { IndexConfig } from "./native";
|
||||
|
||||
/**
|
||||
* Options for adding data to a table.
|
||||
@@ -384,7 +381,8 @@ export abstract class Table {
|
||||
abstract indexStats(name: string): Promise<IndexStatistics | undefined>;
|
||||
|
||||
static async parseTableData(
|
||||
data: Record<string, unknown>[] | TableLike,
|
||||
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
|
||||
data: Record<string, unknown>[] | ArrowTable<any>,
|
||||
options?: Partial<CreateTableOptions>,
|
||||
streaming = false,
|
||||
) {
|
||||
@@ -397,9 +395,9 @@ export abstract class Table {
|
||||
|
||||
let table: ArrowTable;
|
||||
if (isArrowTable(data)) {
|
||||
table = sanitizeTable(data);
|
||||
table = data;
|
||||
} else {
|
||||
table = makeArrowTable(data as Record<string, unknown>[], options);
|
||||
table = makeArrowTable(data, options);
|
||||
}
|
||||
if (streaming) {
|
||||
const buf = await fromTableToStreamBuffer(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-darwin-arm64",
|
||||
"version": "0.6.0",
|
||||
"version": "0.5.2",
|
||||
"os": ["darwin"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.darwin-arm64.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-darwin-x64",
|
||||
"version": "0.6.0",
|
||||
"version": "0.5.2",
|
||||
"os": ["darwin"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.darwin-x64.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
||||
"version": "0.6.0",
|
||||
"version": "0.5.2",
|
||||
"os": ["linux"],
|
||||
"cpu": ["arm64"],
|
||||
"main": "lancedb.linux-arm64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-linux-x64-gnu",
|
||||
"version": "0.6.0",
|
||||
"version": "0.5.2",
|
||||
"os": ["linux"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.linux-x64-gnu.node",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lancedb/lancedb-win32-x64-msvc",
|
||||
"version": "0.6.0",
|
||||
"version": "0.5.2",
|
||||
"os": ["win32"],
|
||||
"cpu": ["x64"],
|
||||
"main": "lancedb.win32-x64-msvc.node",
|
||||
|
||||
1403
nodejs/package-lock.json
generated
1403
nodejs/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -10,7 +10,7 @@
|
||||
"vector database",
|
||||
"ann"
|
||||
],
|
||||
"version": "0.6.0",
|
||||
"version": "0.5.2",
|
||||
"main": "dist/index.js",
|
||||
"exports": {
|
||||
".": "./dist/index.js",
|
||||
@@ -34,10 +34,9 @@
|
||||
"devDependencies": {
|
||||
"@aws-sdk/client-kms": "^3.33.0",
|
||||
"@aws-sdk/client-s3": "^3.33.0",
|
||||
"@aws-sdk/client-dynamodb": "^3.33.0",
|
||||
"@biomejs/biome": "^1.7.3",
|
||||
"@jest/globals": "^29.7.0",
|
||||
"@napi-rs/cli": "^2.18.3",
|
||||
"@napi-rs/cli": "^2.18.0",
|
||||
"@types/jest": "^29.1.2",
|
||||
"@types/tmp": "^0.2.6",
|
||||
"apache-arrow-old": "npm:apache-arrow@13.0.0",
|
||||
@@ -69,7 +68,7 @@
|
||||
"lint-ci": "biome ci .",
|
||||
"docs": "typedoc --plugin typedoc-plugin-markdown --out ../docs/src/js lancedb/index.ts",
|
||||
"lint": "biome check . && biome format .",
|
||||
"lint-fix": "biome check --write . && biome format --write .",
|
||||
"lint-fix": "biome check --apply-unsafe . && biome format --write .",
|
||||
"prepublishOnly": "napi prepublish -t npm",
|
||||
"test": "jest --verbose",
|
||||
"integration": "S3_TEST=1 npm run test",
|
||||
@@ -77,13 +76,9 @@
|
||||
"version": "napi version"
|
||||
},
|
||||
"dependencies": {
|
||||
"apache-arrow": "^15.0.0",
|
||||
"axios": "^1.7.2",
|
||||
"openai": "^4.29.2",
|
||||
"reflect-metadata": "^0.2.2"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"openai": "^4.29.2"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"apache-arrow": "^15.0.0"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,7 +89,7 @@ impl Connection {
|
||||
}
|
||||
|
||||
/// List all tables in the dataset.
|
||||
#[napi(catch_unwind)]
|
||||
#[napi]
|
||||
pub async fn table_names(
|
||||
&self,
|
||||
start_after: Option<String>,
|
||||
@@ -113,7 +113,7 @@ impl Connection {
|
||||
/// - name: The name of the table.
|
||||
/// - buf: The buffer containing the IPC file.
|
||||
///
|
||||
#[napi(catch_unwind)]
|
||||
#[napi]
|
||||
pub async fn create_table(
|
||||
&self,
|
||||
name: String,
|
||||
@@ -141,7 +141,7 @@ impl Connection {
|
||||
Ok(Table::new(tbl))
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
#[napi]
|
||||
pub async fn create_empty_table(
|
||||
&self,
|
||||
name: String,
|
||||
@@ -173,7 +173,7 @@ impl Connection {
|
||||
Ok(Table::new(tbl))
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
#[napi]
|
||||
pub async fn open_table(
|
||||
&self,
|
||||
name: String,
|
||||
@@ -197,7 +197,7 @@ impl Connection {
|
||||
}
|
||||
|
||||
/// Drop table with the name. Or raise an error if the table does not exist.
|
||||
#[napi(catch_unwind)]
|
||||
#[napi]
|
||||
pub async fn drop_table(&self, name: String) -> napi::Result<()> {
|
||||
self.get_inner()?
|
||||
.drop_table(&name)
|
||||
|
||||
@@ -30,7 +30,7 @@ impl RecordBatchIterator {
|
||||
Self { inner }
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
#[napi]
|
||||
pub async unsafe fn next(&mut self) -> napi::Result<Option<Buffer>> {
|
||||
if let Some(rst) = self.inner.next().await {
|
||||
let batch = rst.map_err(|e| {
|
||||
|
||||
@@ -31,7 +31,7 @@ impl NativeMergeInsertBuilder {
|
||||
this
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
#[napi]
|
||||
pub async fn execute(&self, buf: Buffer) -> napi::Result<()> {
|
||||
let data = ipc_file_to_batches(buf.to_vec())
|
||||
.and_then(IntoArrow::into_arrow)
|
||||
|
||||
@@ -62,7 +62,7 @@ impl Query {
|
||||
Ok(VectorQuery { inner })
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
#[napi]
|
||||
pub async fn execute(
|
||||
&self,
|
||||
max_batch_length: Option<u32>,
|
||||
@@ -136,7 +136,7 @@ impl VectorQuery {
|
||||
self.inner = self.inner.clone().limit(limit as usize);
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
#[napi]
|
||||
pub async fn execute(
|
||||
&self,
|
||||
max_batch_length: Option<u32>,
|
||||
|
||||
@@ -70,7 +70,7 @@ impl Table {
|
||||
}
|
||||
|
||||
/// Return Schema as empty Arrow IPC file.
|
||||
#[napi(catch_unwind)]
|
||||
#[napi]
|
||||
pub async fn schema(&self) -> napi::Result<Buffer> {
|
||||
let schema =
|
||||
self.inner_ref()?.schema().await.map_err(|e| {
|
||||
@@ -86,7 +86,7 @@ impl Table {
|
||||
})?))
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
#[napi]
|
||||
pub async fn add(&self, buf: Buffer, mode: String) -> napi::Result<()> {
|
||||
let batches = ipc_file_to_batches(buf.to_vec())
|
||||
.map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
|
||||
@@ -108,7 +108,7 @@ impl Table {
|
||||
})
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
#[napi]
|
||||
pub async fn count_rows(&self, filter: Option<String>) -> napi::Result<i64> {
|
||||
self.inner_ref()?
|
||||
.count_rows(filter)
|
||||
@@ -122,7 +122,7 @@ impl Table {
|
||||
})
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
#[napi]
|
||||
pub async fn delete(&self, predicate: String) -> napi::Result<()> {
|
||||
self.inner_ref()?.delete(&predicate).await.map_err(|e| {
|
||||
napi::Error::from_reason(format!(
|
||||
@@ -132,7 +132,7 @@ impl Table {
|
||||
})
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
#[napi]
|
||||
pub async fn create_index(
|
||||
&self,
|
||||
index: Option<&Index>,
|
||||
@@ -151,7 +151,7 @@ impl Table {
|
||||
builder.execute().await.default_error()
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
#[napi]
|
||||
pub async fn update(
|
||||
&self,
|
||||
only_if: Option<String>,
|
||||
@@ -167,17 +167,17 @@ impl Table {
|
||||
op.execute().await.default_error()
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
#[napi]
|
||||
pub fn query(&self) -> napi::Result<Query> {
|
||||
Ok(Query::new(self.inner_ref()?.query()))
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
#[napi]
|
||||
pub fn vector_search(&self, vector: Float32Array) -> napi::Result<VectorQuery> {
|
||||
self.query()?.nearest_to(vector)
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
#[napi]
|
||||
pub async fn add_columns(&self, transforms: Vec<AddColumnsSql>) -> napi::Result<()> {
|
||||
let transforms = transforms
|
||||
.into_iter()
|
||||
@@ -196,7 +196,7 @@ impl Table {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
#[napi]
|
||||
pub async fn alter_columns(&self, alterations: Vec<ColumnAlteration>) -> napi::Result<()> {
|
||||
for alteration in &alterations {
|
||||
if alteration.rename.is_none() && alteration.nullable.is_none() {
|
||||
@@ -222,7 +222,7 @@ impl Table {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
#[napi]
|
||||
pub async fn drop_columns(&self, columns: Vec<String>) -> napi::Result<()> {
|
||||
let col_refs = columns.iter().map(String::as_str).collect::<Vec<_>>();
|
||||
self.inner_ref()?
|
||||
@@ -237,7 +237,7 @@ impl Table {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
#[napi]
|
||||
pub async fn version(&self) -> napi::Result<i64> {
|
||||
self.inner_ref()?
|
||||
.version()
|
||||
@@ -246,7 +246,7 @@ impl Table {
|
||||
.default_error()
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
#[napi]
|
||||
pub async fn checkout(&self, version: i64) -> napi::Result<()> {
|
||||
self.inner_ref()?
|
||||
.checkout(version as u64)
|
||||
@@ -254,17 +254,17 @@ impl Table {
|
||||
.default_error()
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
#[napi]
|
||||
pub async fn checkout_latest(&self) -> napi::Result<()> {
|
||||
self.inner_ref()?.checkout_latest().await.default_error()
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
#[napi]
|
||||
pub async fn restore(&self) -> napi::Result<()> {
|
||||
self.inner_ref()?.restore().await.default_error()
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
#[napi]
|
||||
pub async fn optimize(&self, older_than_ms: Option<i64>) -> napi::Result<OptimizeStats> {
|
||||
let inner = self.inner_ref()?;
|
||||
|
||||
@@ -318,7 +318,7 @@ impl Table {
|
||||
})
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
#[napi]
|
||||
pub async fn list_indices(&self) -> napi::Result<Vec<IndexConfig>> {
|
||||
Ok(self
|
||||
.inner_ref()?
|
||||
@@ -330,14 +330,14 @@ impl Table {
|
||||
.collect::<Vec<_>>())
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
#[napi]
|
||||
pub async fn index_stats(&self, index_name: String) -> napi::Result<Option<IndexStatistics>> {
|
||||
let tbl = self.inner_ref()?.as_native().unwrap();
|
||||
let stats = tbl.index_stats(&index_name).await.default_error()?;
|
||||
Ok(stats.map(IndexStatistics::from))
|
||||
}
|
||||
|
||||
#[napi(catch_unwind)]
|
||||
#[napi]
|
||||
pub fn merge_insert(&self, on: Vec<String>) -> napi::Result<NativeMergeInsertBuilder> {
|
||||
let on: Vec<_> = on.iter().map(String::as_str).collect();
|
||||
Ok(self.inner_ref()?.merge_insert(on.as_slice()).into())
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[tool.bumpversion]
|
||||
current_version = "0.9.0"
|
||||
current_version = "0.9.0-beta.4"
|
||||
parse = """(?x)
|
||||
(?P<major>0|[1-9]\\d*)\\.
|
||||
(?P<minor>0|[1-9]\\d*)\\.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb-python"
|
||||
version = "0.9.0"
|
||||
version = "0.9.0-beta.4"
|
||||
edition.workspace = true
|
||||
description = "Python bindings for LanceDB"
|
||||
license.workspace = true
|
||||
|
||||
@@ -13,6 +13,7 @@ dependencies = [
|
||||
"packaging",
|
||||
"cachetools",
|
||||
"overrides>=0.7",
|
||||
"urllib3==1.26.19"
|
||||
]
|
||||
description = "lancedb"
|
||||
authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]
|
||||
|
||||
@@ -35,6 +35,7 @@ def connect(
|
||||
host_override: Optional[str] = None,
|
||||
read_consistency_interval: Optional[timedelta] = None,
|
||||
request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
) -> DBConnection:
|
||||
"""Connect to a LanceDB database.
|
||||
@@ -70,6 +71,9 @@ def connect(
|
||||
executor will be used for making requests. This is for LanceDB Cloud
|
||||
only and is only used when making batch requests (i.e., passing in
|
||||
multiple queries to the search method at once).
|
||||
storage_options: dict, optional
|
||||
Additional options for the storage backend. See available options at
|
||||
https://lancedb.github.io/lancedb/guides/storage/
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -105,12 +109,16 @@ def connect(
|
||||
region,
|
||||
host_override,
|
||||
request_thread_pool=request_thread_pool,
|
||||
storage_options=storage_options,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if kwargs:
|
||||
raise ValueError(f"Unknown keyword arguments: {kwargs}")
|
||||
return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval)
|
||||
return LanceDBConnection(
|
||||
uri,
|
||||
read_consistency_interval=read_consistency_interval,
|
||||
)
|
||||
|
||||
|
||||
async def connect_async(
|
||||
|
||||
@@ -28,11 +28,12 @@ from lancedb.common import data_to_reader, validate_schema
|
||||
|
||||
from ._lancedb import connect as lancedb_connect
|
||||
from .pydantic import LanceModel
|
||||
from .table import AsyncTable, LanceTable, Table, _sanitize_data, _table_path
|
||||
from .table import AsyncTable, LanceTable, Table, _sanitize_data
|
||||
from .util import (
|
||||
fs_from_uri,
|
||||
get_uri_location,
|
||||
get_uri_scheme,
|
||||
join_uri,
|
||||
validate_table_name,
|
||||
)
|
||||
|
||||
@@ -456,18 +457,16 @@ class LanceDBConnection(DBConnection):
|
||||
If True, ignore if the table does not exist.
|
||||
"""
|
||||
try:
|
||||
table_uri = _table_path(self.uri, name)
|
||||
filesystem, path = fs_from_uri(table_uri)
|
||||
filesystem.delete_dir(path)
|
||||
filesystem, path = fs_from_uri(self.uri)
|
||||
table_path = join_uri(path, name + ".lance")
|
||||
filesystem.delete_dir(table_path)
|
||||
except FileNotFoundError:
|
||||
if not ignore_missing:
|
||||
raise
|
||||
|
||||
@override
|
||||
def drop_database(self):
|
||||
dummy_table_uri = _table_path(self.uri, "dummy")
|
||||
uri = dummy_table_uri.removesuffix("dummy.lance")
|
||||
filesystem, path = fs_from_uri(uri)
|
||||
filesystem, path = fs_from_uri(self.uri)
|
||||
filesystem.delete_dir(path)
|
||||
|
||||
|
||||
|
||||
@@ -117,6 +117,8 @@ class Query(pydantic.BaseModel):
|
||||
|
||||
with_row_id: bool = False
|
||||
|
||||
fast_search: bool = False
|
||||
|
||||
|
||||
class LanceQueryBuilder(ABC):
|
||||
"""An abstract query builder. Subclasses are defined for vector search,
|
||||
@@ -125,12 +127,14 @@ class LanceQueryBuilder(ABC):
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
table: "Table",
|
||||
query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]],
|
||||
query_type: str,
|
||||
vector_column_name: str,
|
||||
ordering_field_name: str = None,
|
||||
cls,
|
||||
table: "Table",
|
||||
query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]],
|
||||
query_type: str,
|
||||
vector_column_name: str,
|
||||
ordering_field_name: Optional[str] = None,
|
||||
fts_columns: Union[str, List[str]] = [],
|
||||
fast_search: bool = False,
|
||||
) -> LanceQueryBuilder:
|
||||
"""
|
||||
Create a query builder based on the given query and query type.
|
||||
@@ -147,13 +151,18 @@ class LanceQueryBuilder(ABC):
|
||||
If "auto", the query type is inferred based on the query.
|
||||
vector_column_name: str
|
||||
The name of the vector column to use for vector search.
|
||||
fast_search: bool
|
||||
Skip flat search of unindexed data.
|
||||
"""
|
||||
if query is None:
|
||||
return LanceEmptyQueryBuilder(table)
|
||||
|
||||
# Check hybrid search first as it supports empty query pattern
|
||||
if query_type == "hybrid":
|
||||
# hybrid fts and vector query
|
||||
return LanceHybridQueryBuilder(table, query, vector_column_name)
|
||||
return LanceHybridQueryBuilder(
|
||||
table, query, vector_column_name, fts_columns=fts_columns
|
||||
)
|
||||
|
||||
if query is None:
|
||||
return LanceEmptyQueryBuilder(table)
|
||||
|
||||
# remember the string query for reranking purpose
|
||||
str_query = query if isinstance(query, str) else None
|
||||
@@ -165,12 +174,17 @@ class LanceQueryBuilder(ABC):
|
||||
)
|
||||
|
||||
if query_type == "hybrid":
|
||||
return LanceHybridQueryBuilder(table, query, vector_column_name)
|
||||
return LanceHybridQueryBuilder(
|
||||
table, query, vector_column_name, fts_columns=fts_columns
|
||||
)
|
||||
|
||||
if isinstance(query, str):
|
||||
# fts
|
||||
return LanceFtsQueryBuilder(
|
||||
table, query, ordering_field_name=ordering_field_name
|
||||
table,
|
||||
query,
|
||||
ordering_field_name=ordering_field_name,
|
||||
fts_columns=fts_columns,
|
||||
)
|
||||
|
||||
if isinstance(query, list):
|
||||
@@ -180,7 +194,9 @@ class LanceQueryBuilder(ABC):
|
||||
else:
|
||||
raise TypeError(f"Unsupported query type: {type(query)}")
|
||||
|
||||
return LanceVectorQueryBuilder(table, query, vector_column_name, str_query)
|
||||
return LanceVectorQueryBuilder(
|
||||
table, query, vector_column_name, str_query, fast_search
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _resolve_query(cls, table, query, query_type, vector_column_name):
|
||||
@@ -196,8 +212,6 @@ class LanceQueryBuilder(ABC):
|
||||
elif query_type == "auto":
|
||||
if isinstance(query, (list, np.ndarray)):
|
||||
return query, "vector"
|
||||
if isinstance(query, tuple):
|
||||
return query, "hybrid"
|
||||
else:
|
||||
conf = table.embedding_functions.get(vector_column_name)
|
||||
if conf is not None:
|
||||
@@ -224,9 +238,14 @@ class LanceQueryBuilder(ABC):
|
||||
def __init__(self, table: "Table"):
|
||||
self._table = table
|
||||
self._limit = 10
|
||||
self._offset = 0
|
||||
self._columns = None
|
||||
self._where = None
|
||||
self._prefilter = False
|
||||
self._with_row_id = False
|
||||
self._vector = None
|
||||
self._text = None
|
||||
self._ef = None
|
||||
|
||||
@deprecation.deprecated(
|
||||
deprecated_in="0.3.1",
|
||||
@@ -337,11 +356,13 @@ class LanceQueryBuilder(ABC):
|
||||
----------
|
||||
limit: int
|
||||
The maximum number of results to return.
|
||||
By default the query is limited to the first 10.
|
||||
Call this method and pass 0, a negative value,
|
||||
or None to remove the limit.
|
||||
*WARNING* if you have a large dataset, removing
|
||||
the limit can potentially result in reading a
|
||||
The default query limit is 10 results.
|
||||
For ANN/KNN queries, you must specify a limit.
|
||||
Entering 0, a negative number, or None will reset
|
||||
the limit to the default value of 10.
|
||||
*WARNING* if you have a large dataset, setting
|
||||
the limit to a large number, e.g. the table size,
|
||||
can potentially result in reading a
|
||||
large amount of data into memory and cause
|
||||
out of memory issues.
|
||||
|
||||
@@ -351,11 +372,33 @@ class LanceQueryBuilder(ABC):
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
if limit is None or limit <= 0:
|
||||
self._limit = None
|
||||
if isinstance(self, LanceVectorQueryBuilder):
|
||||
raise ValueError("Limit is required for ANN/KNN queries")
|
||||
else:
|
||||
self._limit = None
|
||||
else:
|
||||
self._limit = limit
|
||||
return self
|
||||
|
||||
def offset(self, offset: int) -> LanceQueryBuilder:
|
||||
"""Set the offset for the results.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
offset: int
|
||||
The offset to start fetching results from.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
if offset is None or offset <= 0:
|
||||
self._offset = 0
|
||||
else:
|
||||
self._offset = offset
|
||||
return self
|
||||
|
||||
def select(self, columns: Union[list[str], dict[str, str]]) -> LanceQueryBuilder:
|
||||
"""Set the columns to return.
|
||||
|
||||
@@ -417,6 +460,80 @@ class LanceQueryBuilder(ABC):
|
||||
self._with_row_id = with_row_id
|
||||
return self
|
||||
|
||||
def explain_plan(self, verbose: Optional[bool] = False) -> str:
|
||||
"""Return the execution plan for this query.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> import lancedb
|
||||
>>> db = lancedb.connect("./.lancedb")
|
||||
>>> table = db.create_table("my_table", [{"vector": [99, 99]}])
|
||||
>>> query = [100, 100]
|
||||
>>> plan = table.search(query).explain_plan(True)
|
||||
>>> print(plan) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
|
||||
ProjectionExec: expr=[vector@0 as vector, _distance@2 as _distance]
|
||||
GlobalLimitExec: skip=0, fetch=10
|
||||
FilterExec: _distance@2 IS NOT NULL
|
||||
SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST], preserve_partitioning=[false]
|
||||
KNNVectorDistance: metric=l2
|
||||
LanceScan: uri=..., projection=[vector], row_id=true, row_addr=false, ordered=false
|
||||
|
||||
Parameters
|
||||
----------
|
||||
verbose : bool, default False
|
||||
Use a verbose output format.
|
||||
|
||||
Returns
|
||||
-------
|
||||
plan : str
|
||||
""" # noqa: E501
|
||||
ds = self._table.to_lance()
|
||||
return ds.scanner(
|
||||
nearest={
|
||||
"column": self._vector_column,
|
||||
"q": self._query,
|
||||
"k": self._limit,
|
||||
"metric": self._metric,
|
||||
"nprobes": self._nprobes,
|
||||
"refine_factor": self._refine_factor,
|
||||
},
|
||||
prefilter=self._prefilter,
|
||||
filter=self._str_query,
|
||||
limit=self._limit,
|
||||
with_row_id=self._with_row_id,
|
||||
offset=self._offset,
|
||||
).explain_plan(verbose)
|
||||
|
||||
def vector(self, vector: Union[np.ndarray, list]) -> LanceQueryBuilder:
|
||||
"""Set the vector to search for.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
vector: np.ndarray or list
|
||||
The vector to search for.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def text(self, text: str) -> LanceQueryBuilder:
|
||||
"""Set the text to search for.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
text: str
|
||||
The text to search for.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
"""
|
||||
@@ -440,11 +557,12 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
table: "Table",
|
||||
query: Union[np.ndarray, list, "PIL.Image.Image"],
|
||||
vector_column: str,
|
||||
str_query: Optional[str] = None,
|
||||
self,
|
||||
table: "Table",
|
||||
query: Union[np.ndarray, list, "PIL.Image.Image"],
|
||||
vector_column: str,
|
||||
str_query: Optional[str] = None,
|
||||
fast_search: bool = False,
|
||||
):
|
||||
super().__init__(table)
|
||||
self._query = query
|
||||
@@ -455,13 +573,14 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
self._prefilter = False
|
||||
self._reranker = None
|
||||
self._str_query = str_query
|
||||
self._fast_search = fast_search
|
||||
|
||||
def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder:
|
||||
def metric(self, metric: Literal["L2", "cosine", "dot"]) -> LanceVectorQueryBuilder:
|
||||
"""Set the distance metric to use.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
metric: "L2" or "cosine"
|
||||
metric: "L2" or "cosine" or "dot"
|
||||
The distance metric to use. By default "L2" is used.
|
||||
|
||||
Returns
|
||||
@@ -469,7 +588,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
LanceVectorQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._metric = metric
|
||||
self._metric = metric.lower()
|
||||
return self
|
||||
|
||||
def nprobes(self, nprobes: int) -> LanceVectorQueryBuilder:
|
||||
@@ -494,6 +613,28 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
self._nprobes = nprobes
|
||||
return self
|
||||
|
||||
def ef(self, ef: int) -> LanceVectorQueryBuilder:
|
||||
"""Set the number of candidates to consider during search.
|
||||
|
||||
Higher values will yield better recall (more likely to find vectors if
|
||||
they exist) at the expense of latency.
|
||||
|
||||
This only applies to the HNSW-related index.
|
||||
The default value is 1.5 * limit.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ef: int
|
||||
The number of candidates to consider during search.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceVectorQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
self._ef = ef
|
||||
return self
|
||||
|
||||
def refine_factor(self, refine_factor: int) -> LanceVectorQueryBuilder:
|
||||
"""Set the refine factor to use, increasing the number of vectors sampled.
|
||||
|
||||
@@ -554,15 +695,11 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
refine_factor=self._refine_factor,
|
||||
vector_column=self._vector_column,
|
||||
with_row_id=self._with_row_id,
|
||||
offset=self._offset,
|
||||
fast_search=self._fast_search,
|
||||
ef=self._ef,
|
||||
)
|
||||
result_set = self._table._execute_query(query, batch_size)
|
||||
if self._reranker is not None:
|
||||
rs_table = result_set.read_all()
|
||||
result_set = self._reranker.rerank_vector(self._str_query, rs_table)
|
||||
# convert result_set back to RecordBatchReader
|
||||
result_set = pa.RecordBatchReader.from_batches(
|
||||
result_set.schema, result_set.to_batches()
|
||||
)
|
||||
|
||||
return result_set
|
||||
|
||||
@@ -591,7 +728,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||
return self
|
||||
|
||||
def rerank(
|
||||
self, reranker: Reranker, query_string: Optional[str] = None
|
||||
self, reranker: Reranker, query_string: Optional[str] = None
|
||||
) -> LanceVectorQueryBuilder:
|
||||
"""Rerank the results using the specified reranker.
|
||||
|
||||
@@ -756,12 +893,34 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
||||
|
||||
class LanceEmptyQueryBuilder(LanceQueryBuilder):
|
||||
def to_arrow(self) -> pa.Table:
|
||||
ds = self._table.to_lance()
|
||||
return ds.to_table(
|
||||
return self.to_batches().read_all()
|
||||
|
||||
def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
|
||||
query = Query(
|
||||
columns=self._columns,
|
||||
filter=self._where,
|
||||
limit=self._limit,
|
||||
k=self._limit or 10,
|
||||
with_row_id=self._with_row_id,
|
||||
vector=[],
|
||||
# not actually respected in remote query
|
||||
offset=self._offset or 0,
|
||||
)
|
||||
return self._table._execute_query(query)
|
||||
|
||||
def rerank(self, reranker: Reranker) -> LanceEmptyQueryBuilder:
|
||||
"""Rerank the results using the specified reranker.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
reranker: Reranker
|
||||
The reranker to use.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LanceEmptyQueryBuilder
|
||||
The LanceQueryBuilder object.
|
||||
"""
|
||||
raise NotImplementedError("Reranking is not yet supported.")
|
||||
|
||||
|
||||
class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||
|
||||
@@ -55,11 +55,13 @@ class RestfulLanceDBClient:
|
||||
region: str
|
||||
api_key: Credential
|
||||
host_override: Optional[str] = attrs.field(default=None)
|
||||
db_prefix: Optional[str] = attrs.field(default=None)
|
||||
|
||||
closed: bool = attrs.field(default=False, init=False)
|
||||
|
||||
connection_timeout: float = attrs.field(default=120.0, kw_only=True)
|
||||
read_timeout: float = attrs.field(default=300.0, kw_only=True)
|
||||
storage_options: Optional[Dict[str, str]] = attrs.field(default=None, kw_only=True)
|
||||
|
||||
@functools.cached_property
|
||||
def session(self) -> requests.Session:
|
||||
@@ -92,6 +94,18 @@ class RestfulLanceDBClient:
|
||||
headers["Host"] = f"{self.db_name}.{self.region}.api.lancedb.com"
|
||||
if self.host_override:
|
||||
headers["x-lancedb-database"] = self.db_name
|
||||
if self.storage_options:
|
||||
if self.storage_options.get("account_name") is not None:
|
||||
headers["x-azure-storage-account-name"] = self.storage_options[
|
||||
"account_name"
|
||||
]
|
||||
if self.storage_options.get("azure_storage_account_name") is not None:
|
||||
headers["x-azure-storage-account-name"] = self.storage_options[
|
||||
"azure_storage_account_name"
|
||||
]
|
||||
if self.db_prefix:
|
||||
headers["x-lancedb-database-prefix"] = self.db_prefix
|
||||
|
||||
return headers
|
||||
|
||||
@staticmethod
|
||||
@@ -158,6 +172,7 @@ class RestfulLanceDBClient:
|
||||
headers["content-type"] = content_type
|
||||
if request_id is not None:
|
||||
headers["x-request-id"] = request_id
|
||||
|
||||
with self.session.post(
|
||||
urljoin(self.url, uri),
|
||||
headers=headers,
|
||||
@@ -245,7 +260,6 @@ def retry_adapter(options: Dict[str, Any]) -> HTTPAdapter:
|
||||
connect=connect_retries,
|
||||
read=read_retries,
|
||||
backoff_factor=backoff_factor,
|
||||
backoff_jitter=backoff_jitter,
|
||||
status_forcelist=statuses,
|
||||
allowed_methods=methods,
|
||||
)
|
||||
|
||||
@@ -15,7 +15,7 @@ import inspect
|
||||
import logging
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Iterable, List, Optional, Union
|
||||
from typing import Dict, Iterable, List, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from cachetools import TTLCache
|
||||
@@ -44,20 +44,25 @@ class RemoteDBConnection(DBConnection):
|
||||
request_thread_pool: Optional[ThreadPoolExecutor] = None,
|
||||
connection_timeout: float = 120.0,
|
||||
read_timeout: float = 300.0,
|
||||
storage_options: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
"""Connect to a remote LanceDB database."""
|
||||
parsed = urlparse(db_url)
|
||||
if parsed.scheme != "db":
|
||||
raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://")
|
||||
self.db_name = parsed.netloc
|
||||
prefix = parsed.path.lstrip("/")
|
||||
self.db_prefix = None if not prefix else prefix
|
||||
self.api_key = api_key
|
||||
self._client = RestfulLanceDBClient(
|
||||
self.db_name,
|
||||
region,
|
||||
api_key,
|
||||
host_override,
|
||||
self.db_prefix,
|
||||
connection_timeout=connection_timeout,
|
||||
read_timeout=read_timeout,
|
||||
storage_options=storage_options,
|
||||
)
|
||||
self._request_thread_pool = request_thread_pool
|
||||
self._table_cache = TTLCache(maxsize=10000, ttl=300)
|
||||
|
||||
@@ -22,6 +22,7 @@ from lance import json_to_schema
|
||||
|
||||
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||
from lancedb.merge import LanceMergeInsertBuilder
|
||||
from lancedb.query import LanceQueryBuilder
|
||||
|
||||
from ..query import LanceVectorQueryBuilder
|
||||
from ..table import Query, Table, _sanitize_data
|
||||
@@ -228,10 +229,21 @@ class RemoteTable(Table):
|
||||
content_type=ARROW_STREAM_CONTENT_TYPE,
|
||||
)
|
||||
|
||||
def query(
|
||||
self,
|
||||
query: Union[VEC, str] = None,
|
||||
query_type: str = "vector",
|
||||
vector_column_name: Optional[str] = None,
|
||||
fast_search: bool = False,
|
||||
) -> LanceVectorQueryBuilder:
|
||||
return self.search(query, query_type, vector_column_name, fast_search)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: Union[VEC, str],
|
||||
query: Union[VEC, str] = None,
|
||||
query_type: str = "vector",
|
||||
vector_column_name: Optional[str] = None,
|
||||
fast_search: bool = False,
|
||||
) -> LanceVectorQueryBuilder:
|
||||
"""Create a search query to find the nearest neighbors
|
||||
of the given query vector. We currently support [vector search][search]
|
||||
@@ -278,6 +290,11 @@ class RemoteTable(Table):
|
||||
- If the table has multiple vector columns then the *vector_column_name*
|
||||
needs to be specified. Otherwise, an error is raised.
|
||||
|
||||
fast_search: bool, optional
|
||||
Skip a flat search of unindexed data. This may improve
|
||||
search performance but search results will not include unindexed data.
|
||||
|
||||
- *default False*.
|
||||
Returns
|
||||
-------
|
||||
LanceQueryBuilder
|
||||
@@ -293,7 +310,14 @@ class RemoteTable(Table):
|
||||
"""
|
||||
if vector_column_name is None:
|
||||
vector_column_name = inf_vector_column_query(self.schema)
|
||||
return LanceVectorQueryBuilder(self, query, vector_column_name)
|
||||
|
||||
return LanceQueryBuilder.create(
|
||||
self,
|
||||
query,
|
||||
query_type,
|
||||
vector_column_name=vector_column_name,
|
||||
fast_search=fast_search,
|
||||
)
|
||||
|
||||
def _execute_query(
|
||||
self, query: Query, batch_size: Optional[int] = None
|
||||
|
||||
@@ -4,7 +4,6 @@ from .colbert import ColbertReranker
|
||||
from .cross_encoder import CrossEncoderReranker
|
||||
from .linear_combination import LinearCombinationReranker
|
||||
from .openai import OpenaiReranker
|
||||
from .jina import JinaReranker
|
||||
|
||||
__all__ = [
|
||||
"Reranker",
|
||||
@@ -13,5 +12,4 @@ __all__ = [
|
||||
"LinearCombinationReranker",
|
||||
"OpenaiReranker",
|
||||
"ColbertReranker",
|
||||
"JinaReranker",
|
||||
]
|
||||
|
||||
@@ -1,103 +0,0 @@
|
||||
from functools import cached_property
|
||||
from typing import Union
|
||||
|
||||
import pyarrow as pa
|
||||
|
||||
from ..util import attempt_import_or_raise
|
||||
from .base import Reranker
|
||||
|
||||
|
||||
class JinaReranker(Reranker):
|
||||
"""
|
||||
Reranks the results using Jina reranker model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_name : str, default "jinaai/jina-reranker-v1-turbo-en"
|
||||
The name of the reranker to use. For all models, see
|
||||
https://huggingface.co/jinaai/jina-reranker-v1-turbo-en
|
||||
column : str, default "text"
|
||||
The name of the column to use as input to the cross encoder model.
|
||||
device : str, default None
|
||||
The device to use for the cross encoder model. If None, will use "cuda"
|
||||
if available, otherwise "cpu".
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = "jinaai/jina-reranker-v1-turbo-en",
|
||||
column: str = "text",
|
||||
device: Union[str, None] = None,
|
||||
return_score="relevance",
|
||||
):
|
||||
super().__init__(return_score)
|
||||
torch = attempt_import_or_raise("torch")
|
||||
self.model_name = model_name
|
||||
self.column = column
|
||||
self.device = device
|
||||
if self.device is None:
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
@cached_property
|
||||
def model(self):
|
||||
transformers = attempt_import_or_raise("transformers")
|
||||
model = transformers.AutoModelForSequenceClassification.from_pretrained(
|
||||
self.model_name, num_labels=1, trust_remote_code=True
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
def _rerank(self, result_set: pa.Table, query: str):
|
||||
passages = result_set[self.column].to_pylist()
|
||||
cross_inp = [[query, passage] for passage in passages]
|
||||
cross_scores = self.model.compute_score(cross_inp)
|
||||
result_set = result_set.append_column(
|
||||
"_relevance_score", pa.array(cross_scores, type=pa.float32())
|
||||
)
|
||||
|
||||
return result_set
|
||||
|
||||
def rerank_hybrid(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
combined_results = self.merge_results(vector_results, fts_results)
|
||||
combined_results = self._rerank(combined_results, query)
|
||||
# sort the results by _score
|
||||
if self.score == "relevance":
|
||||
combined_results = combined_results.drop_columns(["score", "_distance"])
|
||||
elif self.score == "all":
|
||||
raise NotImplementedError(
|
||||
"return_score='all' not implemented for CrossEncoderReranker"
|
||||
)
|
||||
combined_results = combined_results.sort_by(
|
||||
[("_relevance_score", "descending")]
|
||||
)
|
||||
|
||||
return combined_results
|
||||
|
||||
def rerank_vector(
|
||||
self,
|
||||
query: str,
|
||||
vector_results: pa.Table,
|
||||
):
|
||||
vector_results = self._rerank(vector_results, query)
|
||||
if self.score == "relevance":
|
||||
vector_results = vector_results.drop_columns(["_distance"])
|
||||
|
||||
vector_results = vector_results.sort_by([("_relevance_score", "descending")])
|
||||
return vector_results
|
||||
|
||||
def rerank_fts(
|
||||
self,
|
||||
query: str,
|
||||
fts_results: pa.Table,
|
||||
):
|
||||
fts_results = self._rerank(fts_results, query)
|
||||
if self.score == "relevance":
|
||||
fts_results = fts_results.drop_columns(["score"])
|
||||
|
||||
fts_results = fts_results.sort_by([("_relevance_score", "descending")])
|
||||
return fts_results
|
||||
@@ -30,7 +30,6 @@ from typing import (
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import lance
|
||||
import numpy as np
|
||||
@@ -48,7 +47,6 @@ from .pydantic import LanceModel, model_to_dict
|
||||
from .query import AsyncQuery, AsyncVectorQuery, LanceQueryBuilder, Query
|
||||
from .util import (
|
||||
fs_from_uri,
|
||||
get_uri_scheme,
|
||||
inf_vector_column_query,
|
||||
join_uri,
|
||||
safe_import_pandas,
|
||||
@@ -210,26 +208,6 @@ def _to_record_batch_generator(
|
||||
yield b
|
||||
|
||||
|
||||
def _table_path(base: str, table_name: str) -> str:
|
||||
"""
|
||||
Get a table path that can be used in PyArrow FS.
|
||||
|
||||
Removes any weird schemes (such as "s3+ddb") and drops any query params.
|
||||
"""
|
||||
uri = _table_uri(base, table_name)
|
||||
# Parse as URL
|
||||
parsed = urlparse(uri)
|
||||
# If scheme is s3+ddb, convert to s3
|
||||
if parsed.scheme == "s3+ddb":
|
||||
parsed = parsed._replace(scheme="s3")
|
||||
# Remove query parameters
|
||||
return parsed._replace(query=None).geturl()
|
||||
|
||||
|
||||
def _table_uri(base: str, table_name: str) -> str:
|
||||
return join_uri(base, f"{table_name}.lance")
|
||||
|
||||
|
||||
class Table(ABC):
|
||||
"""
|
||||
A Table is a collection of Records in a LanceDB Database.
|
||||
@@ -930,7 +908,7 @@ class LanceTable(Table):
|
||||
@classmethod
|
||||
def open(cls, db, name, **kwargs):
|
||||
tbl = cls(db, name, **kwargs)
|
||||
fs, path = fs_from_uri(tbl._dataset_path)
|
||||
fs, path = fs_from_uri(tbl._dataset_uri)
|
||||
file_info = fs.get_file_info(path)
|
||||
if file_info.type != pa.fs.FileType.Directory:
|
||||
raise FileNotFoundError(
|
||||
@@ -940,14 +918,9 @@ class LanceTable(Table):
|
||||
|
||||
return tbl
|
||||
|
||||
@cached_property
|
||||
def _dataset_path(self) -> str:
|
||||
# Cacheable since it's deterministic
|
||||
return _table_path(self._conn.uri, self.name)
|
||||
|
||||
@cached_property
|
||||
@property
|
||||
def _dataset_uri(self) -> str:
|
||||
return _table_uri(self._conn.uri, self.name)
|
||||
return join_uri(self._conn.uri, f"{self.name}.lance")
|
||||
|
||||
@property
|
||||
def _dataset(self) -> LanceDataset:
|
||||
@@ -1257,10 +1230,6 @@ class LanceTable(Table):
|
||||
)
|
||||
|
||||
def _get_fts_index_path(self):
|
||||
if get_uri_scheme(self._dataset_uri) != "file":
|
||||
raise NotImplementedError(
|
||||
"Full-text search is not supported on object stores."
|
||||
)
|
||||
return join_uri(self._dataset_uri, "_indices", "tantivy")
|
||||
|
||||
def add(
|
||||
|
||||
@@ -139,11 +139,8 @@ def join_uri(base: Union[str, pathlib.Path], *parts: str) -> str:
|
||||
# using pathlib for local paths make this windows compatible
|
||||
# `get_uri_scheme` returns `file` for windows drive names (e.g. `c:\path`)
|
||||
return str(pathlib.Path(base, *parts))
|
||||
else:
|
||||
# there might be query parameters in the base URI
|
||||
url = urlparse(base)
|
||||
new_path = "/".join([p.rstrip("/") for p in [url.path, *parts]])
|
||||
return url._replace(path=new_path).geturl()
|
||||
# for remote paths, just use os.path.join
|
||||
return "/".join([p.rstrip("/") for p in [base, *parts]])
|
||||
|
||||
|
||||
def attempt_import_or_raise(module: str, mitigation=None):
|
||||
|
||||
@@ -21,6 +21,7 @@ class FakeLanceDBClient:
|
||||
pass
|
||||
|
||||
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
||||
print(f"{query=}")
|
||||
assert table_name == "test"
|
||||
t = pa.schema([]).empty_table()
|
||||
return VectorQueryResult(t)
|
||||
@@ -39,3 +40,21 @@ def test_remote_db():
|
||||
table = conn["test"]
|
||||
table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
|
||||
table.search([1.0, 2.0]).to_pandas()
|
||||
|
||||
|
||||
def test_empty_query_with_filter():
|
||||
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
|
||||
setattr(conn, "_client", FakeLanceDBClient())
|
||||
|
||||
table = conn["test"]
|
||||
table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
|
||||
print(table.query().select(["vector"]).where("foo == bar").to_arrow())
|
||||
|
||||
|
||||
def test_fast_search_query_with_filter():
|
||||
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
|
||||
setattr(conn, "_client", FakeLanceDBClient())
|
||||
|
||||
table = conn["test"]
|
||||
table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
|
||||
print(table.query([0, 0], fast_search=True).select(["vector"]).where("foo == bar").to_arrow())
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import os
|
||||
|
||||
import lancedb
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -9,7 +11,6 @@ from lancedb.rerankers import (
|
||||
ColbertReranker,
|
||||
CrossEncoderReranker,
|
||||
OpenaiReranker,
|
||||
JinaReranker,
|
||||
)
|
||||
from lancedb.table import LanceTable
|
||||
|
||||
@@ -118,18 +119,136 @@ def test_linear_combination(tmp_path):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize(
|
||||
"reranker",
|
||||
[
|
||||
ColbertReranker(),
|
||||
OpenaiReranker(),
|
||||
CohereReranker(),
|
||||
CrossEncoderReranker(),
|
||||
JinaReranker(),
|
||||
],
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
|
||||
)
|
||||
def test_colbert_reranker(tmp_path, reranker):
|
||||
def test_cohere_reranker(tmp_path):
|
||||
pytest.importorskip("cohere")
|
||||
reranker = CohereReranker()
|
||||
table, schema = get_test_table(tmp_path)
|
||||
# Hybrid search setting
|
||||
result1 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(normalize="score", reranker=CohereReranker())
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
result2 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(reranker=reranker)
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
assert result1 == result2
|
||||
|
||||
query = "Our father who art in heaven"
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search((query_vector, query))
|
||||
.limit(30)
|
||||
.rerank(reranker=reranker)
|
||||
.to_arrow()
|
||||
)
|
||||
|
||||
assert len(result) == 30
|
||||
err = (
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
|
||||
# Vector search setting
|
||||
query = "Our father who art in heaven"
|
||||
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
|
||||
assert len(result) == 30
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
result_explicit = (
|
||||
table.search(query_vector)
|
||||
.rerank(reranker=reranker, query_string=query)
|
||||
.limit(30)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result_explicit) == 30
|
||||
with pytest.raises(
|
||||
ValueError
|
||||
): # This raises an error because vector query is provided without reanking query
|
||||
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
|
||||
|
||||
# FTS search setting
|
||||
result = (
|
||||
table.search(query, query_type="fts")
|
||||
.rerank(reranker=reranker)
|
||||
.limit(30)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result) > 0
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
|
||||
|
||||
def test_cross_encoder_reranker(tmp_path):
|
||||
pytest.importorskip("sentence_transformers")
|
||||
reranker = CrossEncoderReranker()
|
||||
table, schema = get_test_table(tmp_path)
|
||||
result1 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(normalize="score", reranker=reranker)
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
result2 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(reranker=reranker)
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
assert result1 == result2
|
||||
|
||||
query = "Our father who art in heaven"
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search((query_vector, query), query_type="hybrid")
|
||||
.limit(30)
|
||||
.rerank(reranker=reranker)
|
||||
.to_arrow()
|
||||
)
|
||||
|
||||
assert len(result) == 30
|
||||
|
||||
err = (
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
|
||||
# Vector search setting
|
||||
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
|
||||
assert len(result) == 30
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
|
||||
result_explicit = (
|
||||
table.search(query_vector)
|
||||
.rerank(reranker=reranker, query_string=query)
|
||||
.limit(30)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result_explicit) == 30
|
||||
with pytest.raises(
|
||||
ValueError
|
||||
): # This raises an error because vector query is provided without reanking query
|
||||
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
|
||||
|
||||
# FTS search setting
|
||||
result = (
|
||||
table.search(query, query_type="fts")
|
||||
.rerank(reranker=reranker)
|
||||
.limit(30)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result) > 0
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
|
||||
|
||||
def test_colbert_reranker(tmp_path):
|
||||
pytest.importorskip("transformers")
|
||||
reranker = ColbertReranker()
|
||||
table, schema = get_test_table(tmp_path)
|
||||
result1 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
@@ -186,3 +305,67 @@ def test_colbert_reranker(tmp_path, reranker):
|
||||
)
|
||||
assert len(result) > 0
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
|
||||
)
|
||||
def test_openai_reranker(tmp_path):
|
||||
pytest.importorskip("openai")
|
||||
table, schema = get_test_table(tmp_path)
|
||||
reranker = OpenaiReranker()
|
||||
result1 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(normalize="score", reranker=reranker)
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
result2 = (
|
||||
table.search("Our father who art in heaven", query_type="hybrid")
|
||||
.rerank(reranker=OpenaiReranker())
|
||||
.to_pydantic(schema)
|
||||
)
|
||||
assert result1 == result2
|
||||
|
||||
# test explicit hybrid query
|
||||
query = "Our father who art in heaven"
|
||||
query_vector = table.to_pandas()["vector"][0]
|
||||
result = (
|
||||
table.search((query_vector, query))
|
||||
.limit(30)
|
||||
.rerank(reranker=reranker)
|
||||
.to_arrow()
|
||||
)
|
||||
|
||||
assert len(result) == 30
|
||||
|
||||
err = (
|
||||
"The _relevance_score column of the results returned by the reranker "
|
||||
"represents the relevance of the result to the query & should "
|
||||
"be descending."
|
||||
)
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
|
||||
# Vector search setting
|
||||
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
|
||||
assert len(result) == 30
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
result_explicit = (
|
||||
table.search(query_vector)
|
||||
.rerank(reranker=reranker, query_string=query)
|
||||
.limit(30)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result_explicit) == 30
|
||||
with pytest.raises(
|
||||
ValueError
|
||||
): # This raises an error because vector query is provided without reanking query
|
||||
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
|
||||
# FTS search setting
|
||||
result = (
|
||||
table.search(query, query_type="fts")
|
||||
.rerank(reranker=reranker)
|
||||
.limit(30)
|
||||
.to_arrow()
|
||||
)
|
||||
assert len(result) > 0
|
||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||
|
||||
@@ -13,8 +13,6 @@
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
from datetime import timedelta
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
import pyarrow as pa
|
||||
@@ -27,7 +25,6 @@ CONFIG = {
|
||||
"aws_access_key_id": "ACCESSKEY",
|
||||
"aws_secret_access_key": "SECRETKEY",
|
||||
"aws_endpoint": "http://localhost:4566",
|
||||
"dynamodb_endpoint": "http://localhost:4566",
|
||||
"aws_region": "us-east-1",
|
||||
}
|
||||
|
||||
@@ -159,104 +156,3 @@ def test_s3_sse(s3_bucket: str, kms_key: str):
|
||||
validate_objects_encrypted(s3_bucket, path, kms_key)
|
||||
|
||||
asyncio.run(test())
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def commit_table():
|
||||
ddb = get_boto3_client("dynamodb", endpoint_url=CONFIG["dynamodb_endpoint"])
|
||||
table_name = "lance-integtest"
|
||||
try:
|
||||
ddb.delete_table(TableName=table_name)
|
||||
except ddb.exceptions.ResourceNotFoundException:
|
||||
pass
|
||||
ddb.create_table(
|
||||
TableName=table_name,
|
||||
KeySchema=[
|
||||
{"AttributeName": "base_uri", "KeyType": "HASH"},
|
||||
{"AttributeName": "version", "KeyType": "RANGE"},
|
||||
],
|
||||
AttributeDefinitions=[
|
||||
{"AttributeName": "base_uri", "AttributeType": "S"},
|
||||
{"AttributeName": "version", "AttributeType": "N"},
|
||||
],
|
||||
ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1},
|
||||
)
|
||||
yield table_name
|
||||
ddb.delete_table(TableName=table_name)
|
||||
|
||||
|
||||
@pytest.mark.s3_test
|
||||
def test_s3_dynamodb(s3_bucket: str, commit_table: str):
|
||||
storage_options = copy.copy(CONFIG)
|
||||
|
||||
uri = f"s3+ddb://{s3_bucket}/test?ddbTableName={commit_table}"
|
||||
data = pa.table({"x": [1, 2, 3]})
|
||||
|
||||
async def test():
|
||||
db = await lancedb.connect_async(
|
||||
uri,
|
||||
storage_options=storage_options,
|
||||
read_consistency_interval=timedelta(0),
|
||||
)
|
||||
|
||||
table = await db.create_table("test", data)
|
||||
|
||||
# Five concurrent writers
|
||||
async def insert():
|
||||
# independent table refs for true concurrent writes.
|
||||
table = await db.open_table("test")
|
||||
await table.add(data, mode="append")
|
||||
|
||||
tasks = [insert() for _ in range(5)]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
row_count = await table.count_rows()
|
||||
assert row_count == 3 * 6
|
||||
|
||||
asyncio.run(test())
|
||||
|
||||
|
||||
@pytest.mark.s3_test
|
||||
def test_s3_dynamodb_sync(s3_bucket: str, commit_table: str, monkeypatch):
|
||||
# Sync API doesn't support storage_options, so we have to provide as env vars
|
||||
for key, value in CONFIG.items():
|
||||
monkeypatch.setenv(key.upper(), value)
|
||||
|
||||
uri = f"s3+ddb://{s3_bucket}/test2?ddbTableName={commit_table}"
|
||||
data = pa.table({"x": ["a", "b", "c"]})
|
||||
|
||||
db = lancedb.connect(
|
||||
uri,
|
||||
read_consistency_interval=timedelta(0),
|
||||
)
|
||||
|
||||
table = db.create_table("test_ddb_sync", data)
|
||||
|
||||
# Five concurrent writers
|
||||
def insert():
|
||||
table = db.open_table("test_ddb_sync")
|
||||
table.add(data, mode="append")
|
||||
|
||||
threads = []
|
||||
for _ in range(5):
|
||||
thread = threading.Thread(target=insert)
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
row_count = table.count_rows()
|
||||
assert row_count == 3 * 6
|
||||
|
||||
# FTS indices should error since they are not supported yet.
|
||||
with pytest.raises(
|
||||
NotImplementedError, match="Full-text search is not supported on object stores."
|
||||
):
|
||||
table.create_fts_index("x")
|
||||
|
||||
# make sure list tables still works
|
||||
assert db.table_names() == ["test_ddb_sync"]
|
||||
db.drop_table("test_ddb_sync")
|
||||
assert db.table_names() == []
|
||||
db.drop_database()
|
||||
|
||||
@@ -735,7 +735,7 @@ def test_create_scalar_index(db):
|
||||
indices = table.to_lance().list_indices()
|
||||
assert len(indices) == 1
|
||||
scalar_index = indices[0]
|
||||
assert scalar_index["type"] == "Scalar"
|
||||
assert scalar_index["type"] == "BTree"
|
||||
|
||||
# Confirm that prefiltering still works with the scalar index column
|
||||
results = table.search().where("x = 'c'").to_arrow()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb-node"
|
||||
version = "0.6.0"
|
||||
version = "0.5.2"
|
||||
description = "Serverless, low-latency vector database for AI applications"
|
||||
license.workspace = true
|
||||
edition.workspace = true
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "lancedb"
|
||||
version = "0.6.0"
|
||||
version = "0.5.2"
|
||||
edition.workspace = true
|
||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||
license.workspace = true
|
||||
@@ -55,11 +55,10 @@ walkdir = "2"
|
||||
# For s3 integration tests (dev deps aren't allowed to be optional atm)
|
||||
# We pin these because the content-length check breaks with localstack
|
||||
# https://github.com/smithy-lang/smithy-rs/releases/tag/release-2024-05-21
|
||||
aws-sdk-dynamodb = { version = "=1.23.0" }
|
||||
aws-sdk-s3 = { version = "=1.23.0" }
|
||||
aws-sdk-kms = { version = "=1.21.0" }
|
||||
aws-config = { version = "1.0" }
|
||||
aws-smithy-runtime = { version = "=1.3.1" }
|
||||
aws-smithy-runtime = { version = "=1.3.0" }
|
||||
|
||||
[features]
|
||||
default = []
|
||||
|
||||
@@ -6,12 +6,3 @@
|
||||
LanceDB Rust SDK, a serverless vector database.
|
||||
|
||||
Read more at: https://lancedb.com/
|
||||
|
||||
> [!TIP]
|
||||
> A transitive dependency of `lancedb` is `lzma-sys`, which uses dynamic linking
|
||||
> by default. If you want to statically link `lzma-sys`, you should activate it's
|
||||
> `static` feature by adding the following to your dependencies:
|
||||
>
|
||||
> ```toml
|
||||
> lzma-sys = { version = "*", features = ["static"] }
|
||||
> ```
|
||||
|
||||
@@ -1889,7 +1889,6 @@ impl TableInternal for NativeTable {
|
||||
}
|
||||
columns.push(field.name.clone());
|
||||
}
|
||||
|
||||
let index_type = if is_vector {
|
||||
crate::index::IndexType::IvfPq
|
||||
} else {
|
||||
|
||||
@@ -25,9 +25,7 @@ const CONFIG: &[(&str, &str)] = &[
|
||||
("access_key_id", "ACCESS_KEY"),
|
||||
("secret_access_key", "SECRET_KEY"),
|
||||
("endpoint", "http://127.0.0.1:4566"),
|
||||
("dynamodb_endpoint", "http://127.0.0.1:4566"),
|
||||
("allow_http", "true"),
|
||||
("region", "us-east-1"),
|
||||
];
|
||||
|
||||
async fn aws_config() -> SdkConfig {
|
||||
@@ -290,126 +288,3 @@ async fn test_encryption() -> Result<()> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct DynamoDBCommitTable(String);
|
||||
|
||||
impl DynamoDBCommitTable {
|
||||
async fn new(name: &str) -> Self {
|
||||
let config = aws_config().await;
|
||||
let client = aws_sdk_dynamodb::Client::new(&config);
|
||||
|
||||
// In case it wasn't deleted earlier
|
||||
Self::delete_table(client.clone(), name).await;
|
||||
tokio::time::sleep(std::time::Duration::from_millis(200)).await;
|
||||
|
||||
use aws_sdk_dynamodb::types::*;
|
||||
|
||||
client
|
||||
.create_table()
|
||||
.table_name(name)
|
||||
.attribute_definitions(
|
||||
AttributeDefinition::builder()
|
||||
.attribute_name("base_uri")
|
||||
.attribute_type(ScalarAttributeType::S)
|
||||
.build()
|
||||
.unwrap(),
|
||||
)
|
||||
.attribute_definitions(
|
||||
AttributeDefinition::builder()
|
||||
.attribute_name("version")
|
||||
.attribute_type(ScalarAttributeType::N)
|
||||
.build()
|
||||
.unwrap(),
|
||||
)
|
||||
.key_schema(
|
||||
KeySchemaElement::builder()
|
||||
.attribute_name("base_uri")
|
||||
.key_type(KeyType::Hash)
|
||||
.build()
|
||||
.unwrap(),
|
||||
)
|
||||
.key_schema(
|
||||
KeySchemaElement::builder()
|
||||
.attribute_name("version")
|
||||
.key_type(KeyType::Range)
|
||||
.build()
|
||||
.unwrap(),
|
||||
)
|
||||
.provisioned_throughput(
|
||||
ProvisionedThroughput::builder()
|
||||
.read_capacity_units(1)
|
||||
.write_capacity_units(1)
|
||||
.build()
|
||||
.unwrap(),
|
||||
)
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
Self(name.to_string())
|
||||
}
|
||||
|
||||
async fn delete_table(client: aws_sdk_dynamodb::Client, name: &str) {
|
||||
match client
|
||||
.delete_table()
|
||||
.table_name(name)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|err| err.into_service_error())
|
||||
{
|
||||
Ok(_) => {}
|
||||
Err(e) if e.is_resource_not_found_exception() => {}
|
||||
Err(e) => panic!("Failed to delete table: {}", e),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for DynamoDBCommitTable {
|
||||
fn drop(&mut self) {
|
||||
let table_name = self.0.clone();
|
||||
tokio::task::spawn(async move {
|
||||
let config = aws_config().await;
|
||||
let client = aws_sdk_dynamodb::Client::new(&config);
|
||||
Self::delete_table(client, &table_name).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_dynamodb_commit() {
|
||||
// test concurrent commit on dynamodb
|
||||
let bucket = S3Bucket::new("test-dynamodb").await;
|
||||
let table = DynamoDBCommitTable::new("test_table").await;
|
||||
|
||||
let uri = format!("s3+ddb://{}?ddbTableName={}", bucket.0, table.0);
|
||||
let db = lancedb::connect(&uri)
|
||||
.storage_options(CONFIG.iter().cloned())
|
||||
.execute()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let data = test_data();
|
||||
let data = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema());
|
||||
|
||||
let table = db.create_table("test_table", data).execute().await.unwrap();
|
||||
|
||||
let data = test_data();
|
||||
|
||||
let mut tasks = vec![];
|
||||
for _ in 0..5 {
|
||||
let table = db.open_table("test_table").execute().await.unwrap();
|
||||
let data = data.clone();
|
||||
tasks.push(tokio::spawn(async move {
|
||||
let data = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema());
|
||||
table.add(data).execute().await.unwrap();
|
||||
}));
|
||||
}
|
||||
|
||||
for task in tasks {
|
||||
task.await.unwrap();
|
||||
}
|
||||
|
||||
table.checkout_latest().await.unwrap();
|
||||
let row_count = table.count_rows(None).await.unwrap();
|
||||
assert_eq!(row_count, 18);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user