mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-05 11:22:58 +00:00
Compare commits
16 Commits
python-v0.
...
ayush/jina
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2bb7b2d2e7 | ||
|
|
020a437230 | ||
|
|
34f1aeb84c | ||
|
|
5c3a88b6b2 | ||
|
|
e780b2f51c | ||
|
|
b8a1719174 | ||
|
|
ccded130ed | ||
|
|
48f8d1b3b7 | ||
|
|
865ed99881 | ||
|
|
d6485f1215 | ||
|
|
79a1667753 | ||
|
|
a866b78a31 | ||
|
|
c7d37b3e6e | ||
|
|
4b71552b73 | ||
|
|
5ce5f64da3 | ||
|
|
c582b0fc63 |
@@ -1,5 +1,5 @@
|
|||||||
[tool.bumpversion]
|
[tool.bumpversion]
|
||||||
current_version = "0.5.2-final.1"
|
current_version = "0.6.0"
|
||||||
parse = """(?x)
|
parse = """(?x)
|
||||||
(?P<major>0|[1-9]\\d*)\\.
|
(?P<major>0|[1-9]\\d*)\\.
|
||||||
(?P<minor>0|[1-9]\\d*)\\.
|
(?P<minor>0|[1-9]\\d*)\\.
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ runs:
|
|||||||
args: ${{ inputs.args }}
|
args: ${{ inputs.args }}
|
||||||
docker-options: "-e PIP_EXTRA_INDEX_URL=https://pypi.fury.io/lancedb/"
|
docker-options: "-e PIP_EXTRA_INDEX_URL=https://pypi.fury.io/lancedb/"
|
||||||
working-directory: python
|
working-directory: python
|
||||||
- uses: actions/upload-artifact@v4
|
- uses: actions/upload-artifact@v3
|
||||||
with:
|
with:
|
||||||
name: windows-wheels
|
name: windows-wheels
|
||||||
path: python\target\wheels
|
path: python\target\wheels
|
||||||
|
|||||||
4
.github/workflows/docs_test.yml
vendored
4
.github/workflows/docs_test.yml
vendored
@@ -24,7 +24,7 @@ env:
|
|||||||
jobs:
|
jobs:
|
||||||
test-python:
|
test-python:
|
||||||
name: Test doc python code
|
name: Test doc python code
|
||||||
runs-on: "buildjet-8vcpu-ubuntu-2204"
|
runs-on: "warp-ubuntu-latest-x64-4x"
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
@@ -56,7 +56,7 @@ jobs:
|
|||||||
for d in *; do cd "$d"; echo "$d".py; python "$d".py; cd ..; done
|
for d in *; do cd "$d"; echo "$d".py; python "$d".py; cd ..; done
|
||||||
test-node:
|
test-node:
|
||||||
name: Test doc nodejs code
|
name: Test doc nodejs code
|
||||||
runs-on: "buildjet-8vcpu-ubuntu-2204"
|
runs-on: "warp-ubuntu-latest-x64-4x"
|
||||||
timeout-minutes: 60
|
timeout-minutes: 60
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ repos:
|
|||||||
hooks:
|
hooks:
|
||||||
- id: local-biome-check
|
- id: local-biome-check
|
||||||
name: biome check
|
name: biome check
|
||||||
entry: npx @biomejs/biome@1.7.3 check --config-path nodejs/biome.json nodejs/
|
entry: npx @biomejs/biome@1.8.3 check --config-path nodejs/biome.json nodejs/
|
||||||
language: system
|
language: system
|
||||||
types: [text]
|
types: [text]
|
||||||
files: "nodejs/.*"
|
files: "nodejs/.*"
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ arrow-schema = "51.0"
|
|||||||
arrow-arith = "51.0"
|
arrow-arith = "51.0"
|
||||||
arrow-cast = "51.0"
|
arrow-cast = "51.0"
|
||||||
async-trait = "0"
|
async-trait = "0"
|
||||||
chrono = "=0.4.39"
|
chrono = "0.4.35"
|
||||||
datafusion-physical-plan = "37.1"
|
datafusion-physical-plan = "37.1"
|
||||||
half = { "version" = "=2.4.1", default-features = false, features = [
|
half = { "version" = "=2.4.1", default-features = false, features = [
|
||||||
"num-traits",
|
"num-traits",
|
||||||
|
|||||||
@@ -57,6 +57,8 @@ plugins:
|
|||||||
- https://arrow.apache.org/docs/objects.inv
|
- https://arrow.apache.org/docs/objects.inv
|
||||||
- https://pandas.pydata.org/docs/objects.inv
|
- https://pandas.pydata.org/docs/objects.inv
|
||||||
- mkdocs-jupyter
|
- mkdocs-jupyter
|
||||||
|
- render_swagger:
|
||||||
|
allow_arbitrary_locations : true
|
||||||
|
|
||||||
markdown_extensions:
|
markdown_extensions:
|
||||||
- admonition
|
- admonition
|
||||||
@@ -158,6 +160,7 @@ nav:
|
|||||||
- API reference:
|
- API reference:
|
||||||
- 🐍 Python: python/saas-python.md
|
- 🐍 Python: python/saas-python.md
|
||||||
- 👾 JavaScript: javascript/modules.md
|
- 👾 JavaScript: javascript/modules.md
|
||||||
|
- REST API: cloud/rest.md
|
||||||
|
|
||||||
- Quick start: basic.md
|
- Quick start: basic.md
|
||||||
- Concepts:
|
- Concepts:
|
||||||
@@ -228,6 +231,7 @@ nav:
|
|||||||
- API reference:
|
- API reference:
|
||||||
- 🐍 Python: python/saas-python.md
|
- 🐍 Python: python/saas-python.md
|
||||||
- 👾 JavaScript: javascript/modules.md
|
- 👾 JavaScript: javascript/modules.md
|
||||||
|
- REST API: cloud/rest.md
|
||||||
|
|
||||||
extra_css:
|
extra_css:
|
||||||
- styles/global.css
|
- styles/global.css
|
||||||
|
|||||||
479
docs/openapi.yml
Normal file
479
docs/openapi.yml
Normal file
@@ -0,0 +1,479 @@
|
|||||||
|
openapi: 3.1.0
|
||||||
|
info:
|
||||||
|
version: 1.0.0
|
||||||
|
title: LanceDB Cloud API
|
||||||
|
description: |
|
||||||
|
LanceDB Cloud API is a RESTful API that allows users to access and modify data stored in LanceDB Cloud.
|
||||||
|
Table actions are considered temporary resource creations and all use POST method.
|
||||||
|
contact:
|
||||||
|
name: LanceDB support
|
||||||
|
url: https://lancedb.com
|
||||||
|
email: contact@lancedb.com
|
||||||
|
|
||||||
|
servers:
|
||||||
|
- url: https://{db}.{region}.api.lancedb.com
|
||||||
|
description: LanceDB Cloud REST endpoint.
|
||||||
|
variables:
|
||||||
|
db:
|
||||||
|
default: ""
|
||||||
|
description: the name of DB
|
||||||
|
region:
|
||||||
|
default: "us-east-1"
|
||||||
|
description: the service region of the DB
|
||||||
|
|
||||||
|
security:
|
||||||
|
- key_auth: []
|
||||||
|
|
||||||
|
components:
|
||||||
|
securitySchemes:
|
||||||
|
key_auth:
|
||||||
|
name: x-api-key
|
||||||
|
type: apiKey
|
||||||
|
in: header
|
||||||
|
parameters:
|
||||||
|
table_name:
|
||||||
|
name: name
|
||||||
|
in: path
|
||||||
|
description: name of the table
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
responses:
|
||||||
|
invalid_request:
|
||||||
|
description: Invalid request
|
||||||
|
content:
|
||||||
|
text/plain:
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
not_found:
|
||||||
|
description: Not found
|
||||||
|
content:
|
||||||
|
text/plain:
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
unauthorized:
|
||||||
|
description: Unauthorized
|
||||||
|
content:
|
||||||
|
text/plain:
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
requestBodies:
|
||||||
|
arrow_stream_buffer:
|
||||||
|
description: Arrow IPC stream buffer
|
||||||
|
required: true
|
||||||
|
content:
|
||||||
|
application/vnd.apache.arrow.stream:
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
format: binary
|
||||||
|
|
||||||
|
paths:
|
||||||
|
/v1/table/:
|
||||||
|
get:
|
||||||
|
description: List tables, optionally with pagination.
|
||||||
|
tags:
|
||||||
|
- Tables
|
||||||
|
summary: List Tables
|
||||||
|
operationId: listTables
|
||||||
|
parameters:
|
||||||
|
- name: limit
|
||||||
|
in: query
|
||||||
|
description: Limits the number of items to return.
|
||||||
|
schema:
|
||||||
|
type: integer
|
||||||
|
- name: page_token
|
||||||
|
in: query
|
||||||
|
description: Specifies the starting position of the next query
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: Successfully returned a list of tables in the DB
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
tables:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: string
|
||||||
|
page_token:
|
||||||
|
type: string
|
||||||
|
|
||||||
|
"400":
|
||||||
|
$ref: "#/components/responses/invalid_request"
|
||||||
|
"401":
|
||||||
|
$ref: "#/components/responses/unauthorized"
|
||||||
|
"404":
|
||||||
|
$ref: "#/components/responses/not_found"
|
||||||
|
|
||||||
|
/v1/table/{name}/create/:
|
||||||
|
post:
|
||||||
|
description: Create a new table
|
||||||
|
summary: Create a new table
|
||||||
|
operationId: createTable
|
||||||
|
tags:
|
||||||
|
- Tables
|
||||||
|
parameters:
|
||||||
|
- $ref: "#/components/parameters/table_name"
|
||||||
|
requestBody:
|
||||||
|
$ref: "#/components/requestBodies/arrow_stream_buffer"
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: Table successfully created
|
||||||
|
"400":
|
||||||
|
$ref: "#/components/responses/invalid_request"
|
||||||
|
"401":
|
||||||
|
$ref: "#/components/responses/unauthorized"
|
||||||
|
"404":
|
||||||
|
$ref: "#/components/responses/not_found"
|
||||||
|
|
||||||
|
/v1/table/{name}/query/:
|
||||||
|
post:
|
||||||
|
description: Vector Query
|
||||||
|
url: https://{db-uri}.{aws-region}.api.lancedb.com/v1/table/{name}/query/
|
||||||
|
tags:
|
||||||
|
- Data
|
||||||
|
summary: Vector Query
|
||||||
|
parameters:
|
||||||
|
- $ref: "#/components/parameters/table_name"
|
||||||
|
requestBody:
|
||||||
|
required: true
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
vector:
|
||||||
|
type: FixedSizeList
|
||||||
|
description: |
|
||||||
|
The targeted vector to search for. Required.
|
||||||
|
vector_column:
|
||||||
|
type: string
|
||||||
|
description: |
|
||||||
|
The column to query. It can be inferred from the schema if there is only one vector column.
|
||||||
|
prefilter:
|
||||||
|
type: boolean
|
||||||
|
description: |
|
||||||
|
Whether to prefilter the data. Optional.
|
||||||
|
k:
|
||||||
|
type: integer
|
||||||
|
description: |
|
||||||
|
The number of search results to return. Default is 10.
|
||||||
|
distance_type:
|
||||||
|
type: string
|
||||||
|
description: |
|
||||||
|
The distance metric to use for search. L2, Cosine, Dot and Hamming are supported. Default is L2.
|
||||||
|
bypass_vector_index:
|
||||||
|
type: boolean
|
||||||
|
description: |
|
||||||
|
Whether to bypass vector index. Optional.
|
||||||
|
filter:
|
||||||
|
type: string
|
||||||
|
description: |
|
||||||
|
A filter expression that specifies the rows to query. Optional.
|
||||||
|
columns:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: string
|
||||||
|
description: |
|
||||||
|
The columns to return. Optional.
|
||||||
|
nprobe:
|
||||||
|
type: integer
|
||||||
|
description: |
|
||||||
|
The number of probes to use for search. Optional.
|
||||||
|
refine_factor:
|
||||||
|
type: integer
|
||||||
|
description: |
|
||||||
|
The refine factor to use for search. Optional.
|
||||||
|
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: top k results if query is successfully executed
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
results:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
id:
|
||||||
|
type: integer
|
||||||
|
selected_col_1_to_return:
|
||||||
|
type: col_1_type
|
||||||
|
selected_col_n_to_return:
|
||||||
|
type: col_n_type
|
||||||
|
_distance:
|
||||||
|
type: float
|
||||||
|
|
||||||
|
"400":
|
||||||
|
$ref: "#/components/responses/invalid_request"
|
||||||
|
"401":
|
||||||
|
$ref: "#/components/responses/unauthorized"
|
||||||
|
"404":
|
||||||
|
$ref: "#/components/responses/not_found"
|
||||||
|
|
||||||
|
/v1/table/{name}/insert/:
|
||||||
|
post:
|
||||||
|
description: Insert new data to the Table.
|
||||||
|
tags:
|
||||||
|
- Data
|
||||||
|
operationId: insertData
|
||||||
|
summary: Insert new data.
|
||||||
|
parameters:
|
||||||
|
- $ref: "#/components/parameters/table_name"
|
||||||
|
requestBody:
|
||||||
|
$ref: "#/components/requestBodies/arrow_stream_buffer"
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: Insert successful
|
||||||
|
"400":
|
||||||
|
$ref: "#/components/responses/invalid_request"
|
||||||
|
"401":
|
||||||
|
$ref: "#/components/responses/unauthorized"
|
||||||
|
"404":
|
||||||
|
$ref: "#/components/responses/not_found"
|
||||||
|
/v1/table/{name}/merge_insert/:
|
||||||
|
post:
|
||||||
|
description: Create a "merge insert" operation
|
||||||
|
This operation can add rows, update rows, and remove rows all in a single
|
||||||
|
transaction. See python method `lancedb.table.Table.merge_insert` for examples.
|
||||||
|
tags:
|
||||||
|
- Data
|
||||||
|
summary: Merge Insert
|
||||||
|
operationId: mergeInsert
|
||||||
|
parameters:
|
||||||
|
- $ref: "#/components/parameters/table_name"
|
||||||
|
- name: on
|
||||||
|
in: query
|
||||||
|
description: |
|
||||||
|
The column to use as the primary key for the merge operation.
|
||||||
|
required: true
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
- name: when_matched_update_all
|
||||||
|
in: query
|
||||||
|
description: |
|
||||||
|
Rows that exist in both the source table (new data) and
|
||||||
|
the target table (old data) will be updated, replacing
|
||||||
|
the old row with the corresponding matching row.
|
||||||
|
required: false
|
||||||
|
schema:
|
||||||
|
type: boolean
|
||||||
|
- name: when_matched_update_all_filt
|
||||||
|
in: query
|
||||||
|
description: |
|
||||||
|
If present then only rows that satisfy the filter expression will
|
||||||
|
be updated
|
||||||
|
required: false
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
- name: when_not_matched_insert_all
|
||||||
|
in: query
|
||||||
|
description: |
|
||||||
|
Rows that exist only in the source table (new data) will be
|
||||||
|
inserted into the target table (old data).
|
||||||
|
required: false
|
||||||
|
schema:
|
||||||
|
type: boolean
|
||||||
|
- name: when_not_matched_by_source_delete
|
||||||
|
in: query
|
||||||
|
description: |
|
||||||
|
Rows that exist only in the target table (old data) will be
|
||||||
|
deleted. An optional condition (`when_not_matched_by_source_delete_filt`)
|
||||||
|
can be provided to limit what data is deleted.
|
||||||
|
required: false
|
||||||
|
schema:
|
||||||
|
type: boolean
|
||||||
|
- name: when_not_matched_by_source_delete_filt
|
||||||
|
in: query
|
||||||
|
description: |
|
||||||
|
The filter expression that specifies the rows to delete.
|
||||||
|
required: false
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
requestBody:
|
||||||
|
$ref: "#/components/requestBodies/arrow_stream_buffer"
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: Merge Insert successful
|
||||||
|
"400":
|
||||||
|
$ref: "#/components/responses/invalid_request"
|
||||||
|
"401":
|
||||||
|
$ref: "#/components/responses/unauthorized"
|
||||||
|
"404":
|
||||||
|
$ref: "#/components/responses/not_found"
|
||||||
|
/v1/table/{name}/delete/:
|
||||||
|
post:
|
||||||
|
description: Delete rows from a table.
|
||||||
|
tags:
|
||||||
|
- Data
|
||||||
|
summary: Delete rows from a table
|
||||||
|
operationId: deleteData
|
||||||
|
parameters:
|
||||||
|
- $ref: "#/components/parameters/table_name"
|
||||||
|
requestBody:
|
||||||
|
required: true
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
predicate:
|
||||||
|
type: string
|
||||||
|
description: |
|
||||||
|
A filter expression that specifies the rows to delete.
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: Delete successful
|
||||||
|
"401":
|
||||||
|
$ref: "#/components/responses/unauthorized"
|
||||||
|
/v1/table/{name}/drop/:
|
||||||
|
post:
|
||||||
|
description: Drop a table
|
||||||
|
tags:
|
||||||
|
- Tables
|
||||||
|
summary: Drop a table
|
||||||
|
operationId: dropTable
|
||||||
|
parameters:
|
||||||
|
- $ref: "#/components/parameters/table_name"
|
||||||
|
requestBody:
|
||||||
|
$ref: "#/components/requestBodies/arrow_stream_buffer"
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: Drop successful
|
||||||
|
"401":
|
||||||
|
$ref: "#/components/responses/unauthorized"
|
||||||
|
|
||||||
|
/v1/table/{name}/describe/:
|
||||||
|
post:
|
||||||
|
description: Describe a table and return Table Information.
|
||||||
|
tags:
|
||||||
|
- Tables
|
||||||
|
summary: Describe a table
|
||||||
|
operationId: describeTable
|
||||||
|
parameters:
|
||||||
|
- $ref: "#/components/parameters/table_name"
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: Table information
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
table:
|
||||||
|
type: string
|
||||||
|
version:
|
||||||
|
type: integer
|
||||||
|
schema:
|
||||||
|
type: string
|
||||||
|
stats:
|
||||||
|
type: object
|
||||||
|
"401":
|
||||||
|
$ref: "#/components/responses/unauthorized"
|
||||||
|
"404":
|
||||||
|
$ref: "#/components/responses/not_found"
|
||||||
|
|
||||||
|
/v1/table/{name}/index/list/:
|
||||||
|
post:
|
||||||
|
description: List indexes of a table
|
||||||
|
tags:
|
||||||
|
- Tables
|
||||||
|
summary: List indexes of a table
|
||||||
|
operationId: listIndexes
|
||||||
|
parameters:
|
||||||
|
- $ref: "#/components/parameters/table_name"
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: Available list of indexes on the table.
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
indexes:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
columns:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: string
|
||||||
|
index_name:
|
||||||
|
type: string
|
||||||
|
index_uuid:
|
||||||
|
type: string
|
||||||
|
"401":
|
||||||
|
$ref: "#/components/responses/unauthorized"
|
||||||
|
"404":
|
||||||
|
$ref: "#/components/responses/not_found"
|
||||||
|
/v1/table/{name}/create_index/:
|
||||||
|
post:
|
||||||
|
description: Create vector index on a Table
|
||||||
|
tags:
|
||||||
|
- Tables
|
||||||
|
summary: Create vector index on a Table
|
||||||
|
operationId: createIndex
|
||||||
|
parameters:
|
||||||
|
- $ref: "#/components/parameters/table_name"
|
||||||
|
requestBody:
|
||||||
|
required: true
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
column:
|
||||||
|
type: string
|
||||||
|
metric_type:
|
||||||
|
type: string
|
||||||
|
nullable: false
|
||||||
|
description: |
|
||||||
|
The metric type to use for the index. L2, Cosine, Dot are supported.
|
||||||
|
index_type:
|
||||||
|
type: string
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: Index successfully created
|
||||||
|
"400":
|
||||||
|
$ref: "#/components/responses/invalid_request"
|
||||||
|
"401":
|
||||||
|
$ref: "#/components/responses/unauthorized"
|
||||||
|
"404":
|
||||||
|
$ref: "#/components/responses/not_found"
|
||||||
|
/v1/table/{name}/create_scalar_index/:
|
||||||
|
post:
|
||||||
|
description: Create a scalar index on a table
|
||||||
|
tags:
|
||||||
|
- Tables
|
||||||
|
summary: Create a scalar index on a table
|
||||||
|
operationId: createScalarIndex
|
||||||
|
parameters:
|
||||||
|
- $ref: "#/components/parameters/table_name"
|
||||||
|
requestBody:
|
||||||
|
required: true
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
column:
|
||||||
|
type: string
|
||||||
|
index_type:
|
||||||
|
type: string
|
||||||
|
required: false
|
||||||
|
responses:
|
||||||
|
"200":
|
||||||
|
description: Scalar Index successfully created
|
||||||
|
"400":
|
||||||
|
$ref: "#/components/responses/invalid_request"
|
||||||
|
"401":
|
||||||
|
$ref: "#/components/responses/unauthorized"
|
||||||
|
"404":
|
||||||
|
$ref: "#/components/responses/not_found"
|
||||||
@@ -2,4 +2,5 @@ mkdocs==1.5.3
|
|||||||
mkdocs-jupyter==0.24.1
|
mkdocs-jupyter==0.24.1
|
||||||
mkdocs-material==9.5.3
|
mkdocs-material==9.5.3
|
||||||
mkdocstrings[python]==0.20.0
|
mkdocstrings[python]==0.20.0
|
||||||
pydantic
|
mkdocs-render-swagger-plugin
|
||||||
|
pydantic
|
||||||
|
|||||||
1
docs/src/cloud/rest.md
Normal file
1
docs/src/cloud/rest.md
Normal file
@@ -0,0 +1 @@
|
|||||||
|
!!swagger ../../openapi.yml!!
|
||||||
@@ -193,13 +193,13 @@ from lancedb.pydantic import LanceModel, Vector
|
|||||||
|
|
||||||
model = get_registry().get("huggingface").create(name='facebook/bart-base')
|
model = get_registry().get("huggingface").create(name='facebook/bart-base')
|
||||||
|
|
||||||
class TextModel(LanceModel):
|
class Words(LanceModel):
|
||||||
text: str = model.SourceField()
|
text: str = model.SourceField()
|
||||||
vector: Vector(model.ndims()) = model.VectorField()
|
vector: Vector(model.ndims()) = model.VectorField()
|
||||||
|
|
||||||
df = pd.DataFrame({"text": ["hi hello sayonara", "goodbye world"]})
|
df = pd.DataFrame({"text": ["hi hello sayonara", "goodbye world"]})
|
||||||
table = db.create_table("greets", schema=Words)
|
table = db.create_table("greets", schema=Words)
|
||||||
table.add()
|
table.add(df)
|
||||||
query = "old greeting"
|
query = "old greeting"
|
||||||
actual = table.search(query).limit(1).to_pydantic(Words)[0]
|
actual = table.search(query).limit(1).to_pydantic(Words)[0]
|
||||||
print(actual.text)
|
print(actual.text)
|
||||||
|
|||||||
@@ -265,6 +265,108 @@ For **read-only access**, LanceDB will need a policy such as:
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### DynamoDB Commit Store for concurrent writes
|
||||||
|
|
||||||
|
By default, S3 does not support concurrent writes. Having two or more processes
|
||||||
|
writing to the same table at the same time can lead to data corruption. This is
|
||||||
|
because S3, unlike other object stores, does not have any atomic put or copy
|
||||||
|
operation.
|
||||||
|
|
||||||
|
To enable concurrent writes, you can configure LanceDB to use a DynamoDB table
|
||||||
|
as a commit store. This table will be used to coordinate writes between
|
||||||
|
different processes. To enable this feature, you must modify your connection
|
||||||
|
URI to use the `s3+ddb` scheme and add a query parameter `ddbTableName` with the
|
||||||
|
name of the table to use.
|
||||||
|
|
||||||
|
=== "Python"
|
||||||
|
|
||||||
|
```python
|
||||||
|
import lancedb
|
||||||
|
db = await lancedb.connect_async(
|
||||||
|
"s3+ddb://bucket/path?ddbTableName=my-dynamodb-table",
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "JavaScript"
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
const lancedb = require("lancedb");
|
||||||
|
|
||||||
|
const db = await lancedb.connect(
|
||||||
|
"s3+ddb://bucket/path?ddbTableName=my-dynamodb-table",
|
||||||
|
);
|
||||||
|
```
|
||||||
|
|
||||||
|
The DynamoDB table must be created with the following schema:
|
||||||
|
|
||||||
|
- Hash key: `base_uri` (string)
|
||||||
|
- Range key: `version` (number)
|
||||||
|
|
||||||
|
You can create this programmatically with:
|
||||||
|
|
||||||
|
=== "Python"
|
||||||
|
|
||||||
|
<!-- skip-test -->
|
||||||
|
```python
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
dynamodb = boto3.client("dynamodb")
|
||||||
|
table = dynamodb.create_table(
|
||||||
|
TableName=table_name,
|
||||||
|
KeySchema=[
|
||||||
|
{"AttributeName": "base_uri", "KeyType": "HASH"},
|
||||||
|
{"AttributeName": "version", "KeyType": "RANGE"},
|
||||||
|
],
|
||||||
|
AttributeDefinitions=[
|
||||||
|
{"AttributeName": "base_uri", "AttributeType": "S"},
|
||||||
|
{"AttributeName": "version", "AttributeType": "N"},
|
||||||
|
],
|
||||||
|
ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1},
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
=== "JavaScript"
|
||||||
|
|
||||||
|
<!-- skip-test -->
|
||||||
|
```javascript
|
||||||
|
import {
|
||||||
|
CreateTableCommand,
|
||||||
|
DynamoDBClient,
|
||||||
|
} from "@aws-sdk/client-dynamodb";
|
||||||
|
|
||||||
|
const dynamodb = new DynamoDBClient({
|
||||||
|
region: CONFIG.awsRegion,
|
||||||
|
credentials: {
|
||||||
|
accessKeyId: CONFIG.awsAccessKeyId,
|
||||||
|
secretAccessKey: CONFIG.awsSecretAccessKey,
|
||||||
|
},
|
||||||
|
endpoint: CONFIG.awsEndpoint,
|
||||||
|
});
|
||||||
|
const command = new CreateTableCommand({
|
||||||
|
TableName: table_name,
|
||||||
|
AttributeDefinitions: [
|
||||||
|
{
|
||||||
|
AttributeName: "base_uri",
|
||||||
|
AttributeType: "S",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
AttributeName: "version",
|
||||||
|
AttributeType: "N",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
KeySchema: [
|
||||||
|
{ AttributeName: "base_uri", KeyType: "HASH" },
|
||||||
|
{ AttributeName: "version", KeyType: "RANGE" },
|
||||||
|
],
|
||||||
|
ProvisionedThroughput: {
|
||||||
|
ReadCapacityUnits: 1,
|
||||||
|
WriteCapacityUnits: 1,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
await dynamodb.send(command);
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
#### S3-compatible stores
|
#### S3-compatible stores
|
||||||
|
|
||||||
LanceDB can also connect to S3-compatible stores, such as MinIO. To do so, you must specify both region and endpoint:
|
LanceDB can also connect to S3-compatible stores, such as MinIO. To do so, you must specify both region and endpoint:
|
||||||
|
|||||||
@@ -116,21 +116,21 @@ This guide will show how to create tables, insert data into them, and update the
|
|||||||
|
|
||||||
### From a Polars DataFrame
|
### From a Polars DataFrame
|
||||||
|
|
||||||
LanceDB supports [Polars](https://pola.rs/), a modern, fast DataFrame library
|
LanceDB supports [Polars](https://pola.rs/), a modern, fast DataFrame library
|
||||||
written in Rust. Just like in Pandas, the Polars integration is enabled by PyArrow
|
written in Rust. Just like in Pandas, the Polars integration is enabled by PyArrow
|
||||||
under the hood. A deeper integration between LanceDB Tables and Polars DataFrames
|
under the hood. A deeper integration between LanceDB Tables and Polars DataFrames
|
||||||
is on the way.
|
is on the way.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import polars as pl
|
import polars as pl
|
||||||
|
|
||||||
data = pl.DataFrame({
|
data = pl.DataFrame({
|
||||||
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
"vector": [[3.1, 4.1], [5.9, 26.5]],
|
||||||
"item": ["foo", "bar"],
|
"item": ["foo", "bar"],
|
||||||
"price": [10.0, 20.0]
|
"price": [10.0, 20.0]
|
||||||
})
|
})
|
||||||
table = db.create_table("pl_table", data=data)
|
table = db.create_table("pl_table", data=data)
|
||||||
```
|
```
|
||||||
|
|
||||||
### From an Arrow Table
|
### From an Arrow Table
|
||||||
=== "Python"
|
=== "Python"
|
||||||
|
|||||||
1481
docs/src/notebooks/lancedb_reranking.ipynb
Normal file
1481
docs/src/notebooks/lancedb_reranking.ipynb
Normal file
File diff suppressed because one or more lines are too long
4
node/package-lock.json
generated
4
node/package-lock.json
generated
@@ -1,12 +1,12 @@
|
|||||||
{
|
{
|
||||||
"name": "vectordb",
|
"name": "vectordb",
|
||||||
"version": "0.5.2",
|
"version": "0.6.0",
|
||||||
"lockfileVersion": 3,
|
"lockfileVersion": 3,
|
||||||
"requires": true,
|
"requires": true,
|
||||||
"packages": {
|
"packages": {
|
||||||
"": {
|
"": {
|
||||||
"name": "vectordb",
|
"name": "vectordb",
|
||||||
"version": "0.5.2",
|
"version": "0.6.0",
|
||||||
"cpu": [
|
"cpu": [
|
||||||
"x64",
|
"x64",
|
||||||
"arm64"
|
"arm64"
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
{
|
{
|
||||||
"name": "vectordb",
|
"name": "vectordb",
|
||||||
"version": "0.5.2-final.1",
|
"version": "0.6.0",
|
||||||
"description": " Serverless, low-latency vector database for AI applications",
|
"description": " Serverless, low-latency vector database for AI applications",
|
||||||
"main": "dist/index.js",
|
"main": "dist/index.js",
|
||||||
"types": "dist/index.d.ts",
|
"types": "dist/index.d.ts",
|
||||||
"scripts": {
|
"scripts": {
|
||||||
"tsc": "tsc -b",
|
"tsc": "tsc -b",
|
||||||
"build": "npm run tsc && cargo-cp-artifact --artifact cdylib lancedb_node index.node -- cargo build --message-format=json",
|
"build": "npm run tsc && cargo-cp-artifact --artifact cdylib lancedb_node index.node -- cargo build -p lancedb-node --message-format=json",
|
||||||
"build-release": "npm run build -- --release",
|
"build-release": "npm run build -- --release",
|
||||||
"test": "npm run tsc && mocha -recursive dist/test",
|
"test": "npm run tsc && mocha -recursive dist/test",
|
||||||
"integration-test": "npm run tsc && mocha -recursive dist/integration_test",
|
"integration-test": "npm run tsc && mocha -recursive dist/integration_test",
|
||||||
|
|||||||
@@ -15,11 +15,11 @@ crate-type = ["cdylib"]
|
|||||||
arrow-ipc.workspace = true
|
arrow-ipc.workspace = true
|
||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
lancedb = { path = "../rust/lancedb" }
|
lancedb = { path = "../rust/lancedb" }
|
||||||
napi = { version = "2.15", default-features = false, features = [
|
napi = { version = "2.16.8", default-features = false, features = [
|
||||||
"napi7",
|
"napi9",
|
||||||
"async",
|
"async",
|
||||||
] }
|
] }
|
||||||
napi-derive = "2"
|
napi-derive = "2.16.4"
|
||||||
|
|
||||||
# Prevent dynamic linking of lzma, which comes from datafusion
|
# Prevent dynamic linking of lzma, which comes from datafusion
|
||||||
lzma-sys = { version = "*", features = ["static"] }
|
lzma-sys = { version = "*", features = ["static"] }
|
||||||
|
|||||||
@@ -63,6 +63,7 @@ describe("Registry", () => {
|
|||||||
return data.map(() => [1, 2, 3]);
|
return data.map(() => [1, 2, 3]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const func = getRegistry()
|
const func = getRegistry()
|
||||||
.get<MockEmbeddingFunction>("mock-embedding")!
|
.get<MockEmbeddingFunction>("mock-embedding")!
|
||||||
.create();
|
.create();
|
||||||
|
|||||||
@@ -14,6 +14,11 @@
|
|||||||
|
|
||||||
/* eslint-disable @typescript-eslint/naming-convention */
|
/* eslint-disable @typescript-eslint/naming-convention */
|
||||||
|
|
||||||
|
import {
|
||||||
|
CreateTableCommand,
|
||||||
|
DeleteTableCommand,
|
||||||
|
DynamoDBClient,
|
||||||
|
} from "@aws-sdk/client-dynamodb";
|
||||||
import {
|
import {
|
||||||
CreateKeyCommand,
|
CreateKeyCommand,
|
||||||
KMSClient,
|
KMSClient,
|
||||||
@@ -38,6 +43,7 @@ const CONFIG = {
|
|||||||
awsAccessKeyId: "ACCESSKEY",
|
awsAccessKeyId: "ACCESSKEY",
|
||||||
awsSecretAccessKey: "SECRETKEY",
|
awsSecretAccessKey: "SECRETKEY",
|
||||||
awsEndpoint: "http://127.0.0.1:4566",
|
awsEndpoint: "http://127.0.0.1:4566",
|
||||||
|
dynamodbEndpoint: "http://127.0.0.1:4566",
|
||||||
awsRegion: "us-east-1",
|
awsRegion: "us-east-1",
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -66,7 +72,6 @@ class S3Bucket {
|
|||||||
} catch {
|
} catch {
|
||||||
// It's fine if the bucket doesn't exist
|
// It's fine if the bucket doesn't exist
|
||||||
}
|
}
|
||||||
// biome-ignore lint/style/useNamingConvention: we dont control s3's api
|
|
||||||
await client.send(new CreateBucketCommand({ Bucket: name }));
|
await client.send(new CreateBucketCommand({ Bucket: name }));
|
||||||
return new S3Bucket(name);
|
return new S3Bucket(name);
|
||||||
}
|
}
|
||||||
@@ -79,32 +84,27 @@ class S3Bucket {
|
|||||||
static async deleteBucket(client: S3Client, name: string) {
|
static async deleteBucket(client: S3Client, name: string) {
|
||||||
// Must delete all objects before we can delete the bucket
|
// Must delete all objects before we can delete the bucket
|
||||||
const objects = await client.send(
|
const objects = await client.send(
|
||||||
// biome-ignore lint/style/useNamingConvention: we dont control s3's api
|
|
||||||
new ListObjectsV2Command({ Bucket: name }),
|
new ListObjectsV2Command({ Bucket: name }),
|
||||||
);
|
);
|
||||||
if (objects.Contents) {
|
if (objects.Contents) {
|
||||||
for (const object of objects.Contents) {
|
for (const object of objects.Contents) {
|
||||||
await client.send(
|
await client.send(
|
||||||
// biome-ignore lint/style/useNamingConvention: we dont control s3's api
|
|
||||||
new DeleteObjectCommand({ Bucket: name, Key: object.Key }),
|
new DeleteObjectCommand({ Bucket: name, Key: object.Key }),
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// biome-ignore lint/style/useNamingConvention: we dont control s3's api
|
|
||||||
await client.send(new DeleteBucketCommand({ Bucket: name }));
|
await client.send(new DeleteBucketCommand({ Bucket: name }));
|
||||||
}
|
}
|
||||||
|
|
||||||
public async assertAllEncrypted(path: string, keyId: string) {
|
public async assertAllEncrypted(path: string, keyId: string) {
|
||||||
const client = S3Bucket.s3Client();
|
const client = S3Bucket.s3Client();
|
||||||
const objects = await client.send(
|
const objects = await client.send(
|
||||||
// biome-ignore lint/style/useNamingConvention: we dont control s3's api
|
|
||||||
new ListObjectsV2Command({ Bucket: this.name, Prefix: path }),
|
new ListObjectsV2Command({ Bucket: this.name, Prefix: path }),
|
||||||
);
|
);
|
||||||
if (objects.Contents) {
|
if (objects.Contents) {
|
||||||
for (const object of objects.Contents) {
|
for (const object of objects.Contents) {
|
||||||
const metadata = await client.send(
|
const metadata = await client.send(
|
||||||
// biome-ignore lint/style/useNamingConvention: we dont control s3's api
|
|
||||||
new HeadObjectCommand({ Bucket: this.name, Key: object.Key }),
|
new HeadObjectCommand({ Bucket: this.name, Key: object.Key }),
|
||||||
);
|
);
|
||||||
expect(metadata.ServerSideEncryption).toBe("aws:kms");
|
expect(metadata.ServerSideEncryption).toBe("aws:kms");
|
||||||
@@ -143,7 +143,6 @@ class KmsKey {
|
|||||||
|
|
||||||
public async delete() {
|
public async delete() {
|
||||||
const client = KmsKey.kmsClient();
|
const client = KmsKey.kmsClient();
|
||||||
// biome-ignore lint/style/useNamingConvention: we dont control s3's api
|
|
||||||
await client.send(new ScheduleKeyDeletionCommand({ KeyId: this.keyId }));
|
await client.send(new ScheduleKeyDeletionCommand({ KeyId: this.keyId }));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -224,3 +223,91 @@ maybeDescribe("storage_options", () => {
|
|||||||
await bucket.assertAllEncrypted("test/table2.lance", kmsKey.keyId);
|
await bucket.assertAllEncrypted("test/table2.lance", kmsKey.keyId);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
class DynamoDBCommitTable {
|
||||||
|
name: string;
|
||||||
|
constructor(name: string) {
|
||||||
|
this.name = name;
|
||||||
|
}
|
||||||
|
|
||||||
|
static dynamoClient() {
|
||||||
|
return new DynamoDBClient({
|
||||||
|
region: CONFIG.awsRegion,
|
||||||
|
credentials: {
|
||||||
|
accessKeyId: CONFIG.awsAccessKeyId,
|
||||||
|
secretAccessKey: CONFIG.awsSecretAccessKey,
|
||||||
|
},
|
||||||
|
endpoint: CONFIG.awsEndpoint,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
public static async create(name: string): Promise<DynamoDBCommitTable> {
|
||||||
|
const client = DynamoDBCommitTable.dynamoClient();
|
||||||
|
const command = new CreateTableCommand({
|
||||||
|
TableName: name,
|
||||||
|
AttributeDefinitions: [
|
||||||
|
{
|
||||||
|
AttributeName: "base_uri",
|
||||||
|
AttributeType: "S",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
AttributeName: "version",
|
||||||
|
AttributeType: "N",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
KeySchema: [
|
||||||
|
{ AttributeName: "base_uri", KeyType: "HASH" },
|
||||||
|
{ AttributeName: "version", KeyType: "RANGE" },
|
||||||
|
],
|
||||||
|
ProvisionedThroughput: {
|
||||||
|
ReadCapacityUnits: 1,
|
||||||
|
WriteCapacityUnits: 1,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
await client.send(command);
|
||||||
|
return new DynamoDBCommitTable(name);
|
||||||
|
}
|
||||||
|
|
||||||
|
public async delete() {
|
||||||
|
const client = DynamoDBCommitTable.dynamoClient();
|
||||||
|
await client.send(new DeleteTableCommand({ TableName: this.name }));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
maybeDescribe("DynamoDB Lock", () => {
|
||||||
|
let bucket: S3Bucket;
|
||||||
|
let commitTable: DynamoDBCommitTable;
|
||||||
|
|
||||||
|
beforeAll(async () => {
|
||||||
|
bucket = await S3Bucket.create("lancedb2");
|
||||||
|
commitTable = await DynamoDBCommitTable.create("commitTable");
|
||||||
|
});
|
||||||
|
|
||||||
|
afterAll(async () => {
|
||||||
|
await commitTable.delete();
|
||||||
|
await bucket.delete();
|
||||||
|
});
|
||||||
|
|
||||||
|
it("can be used to configure a DynamoDB table for commit log", async () => {
|
||||||
|
const uri = `s3+ddb://${bucket.name}/test?ddbTableName=${commitTable.name}`;
|
||||||
|
const db = await connect(uri, {
|
||||||
|
storageOptions: CONFIG,
|
||||||
|
readConsistencyInterval: 0,
|
||||||
|
});
|
||||||
|
|
||||||
|
const table = await db.createTable("test", [{ a: 1, b: 2 }]);
|
||||||
|
|
||||||
|
// 5 concurrent appends
|
||||||
|
const futs = Array.from({ length: 5 }, async () => {
|
||||||
|
// Open a table so each append has a separate table reference. Otherwise
|
||||||
|
// they will share the same table reference and the internal ReadWriteLock
|
||||||
|
// will prevent any real concurrency.
|
||||||
|
const table = await db.openTable("test");
|
||||||
|
await table.add([{ a: 2, b: 3 }]);
|
||||||
|
});
|
||||||
|
await Promise.all(futs);
|
||||||
|
|
||||||
|
const rowCount = await table.countRows();
|
||||||
|
expect(rowCount).toBe(6);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|||||||
@@ -39,7 +39,9 @@ describe.each([arrow, arrowOld])("Given a table", (arrow: any) => {
|
|||||||
let tmpDir: tmp.DirResult;
|
let tmpDir: tmp.DirResult;
|
||||||
let table: Table;
|
let table: Table;
|
||||||
|
|
||||||
const schema = new arrow.Schema([
|
const schema:
|
||||||
|
| import("apache-arrow").Schema
|
||||||
|
| import("apache-arrow-old").Schema = new arrow.Schema([
|
||||||
new arrow.Field("id", new arrow.Float64(), true),
|
new arrow.Field("id", new arrow.Float64(), true),
|
||||||
]);
|
]);
|
||||||
|
|
||||||
@@ -315,7 +317,7 @@ describe("When creating an index", () => {
|
|||||||
.query()
|
.query()
|
||||||
.limit(2)
|
.limit(2)
|
||||||
.nearestTo(queryVec)
|
.nearestTo(queryVec)
|
||||||
.distanceType("DoT")
|
.distanceType("dot")
|
||||||
.toArrow();
|
.toArrow();
|
||||||
expect(rst.numRows).toBe(2);
|
expect(rst.numRows).toBe(2);
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"$schema": "https://biomejs.dev/schemas/1.7.3/schema.json",
|
"$schema": "https://biomejs.dev/schemas/1.8.3/schema.json",
|
||||||
"organizeImports": {
|
"organizeImports": {
|
||||||
"enabled": true
|
"enabled": true
|
||||||
},
|
},
|
||||||
@@ -100,6 +100,16 @@
|
|||||||
"globals": []
|
"globals": []
|
||||||
},
|
},
|
||||||
"overrides": [
|
"overrides": [
|
||||||
|
{
|
||||||
|
"include": ["__test__/s3_integration.test.ts"],
|
||||||
|
"linter": {
|
||||||
|
"rules": {
|
||||||
|
"style": {
|
||||||
|
"useNamingConvention": "off"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"include": [
|
"include": [
|
||||||
"**/*.ts",
|
"**/*.ts",
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
import {
|
import {
|
||||||
Table as ArrowTable,
|
Table as ArrowTable,
|
||||||
Binary,
|
Binary,
|
||||||
|
BufferType,
|
||||||
DataType,
|
DataType,
|
||||||
Field,
|
Field,
|
||||||
FixedSizeBinary,
|
FixedSizeBinary,
|
||||||
@@ -37,14 +38,68 @@ import {
|
|||||||
type makeTable,
|
type makeTable,
|
||||||
vectorFromArray,
|
vectorFromArray,
|
||||||
} from "apache-arrow";
|
} from "apache-arrow";
|
||||||
|
import { Buffers } from "apache-arrow/data";
|
||||||
import { type EmbeddingFunction } from "./embedding/embedding_function";
|
import { type EmbeddingFunction } from "./embedding/embedding_function";
|
||||||
import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry";
|
import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry";
|
||||||
import { sanitizeField, sanitizeSchema, sanitizeType } from "./sanitize";
|
import {
|
||||||
|
sanitizeField,
|
||||||
|
sanitizeSchema,
|
||||||
|
sanitizeTable,
|
||||||
|
sanitizeType,
|
||||||
|
} from "./sanitize";
|
||||||
export * from "apache-arrow";
|
export * from "apache-arrow";
|
||||||
|
export type SchemaLike =
|
||||||
|
| Schema
|
||||||
|
| {
|
||||||
|
fields: FieldLike[];
|
||||||
|
metadata: Map<string, string>;
|
||||||
|
get names(): unknown[];
|
||||||
|
};
|
||||||
|
export type FieldLike =
|
||||||
|
| Field
|
||||||
|
| {
|
||||||
|
type: string;
|
||||||
|
name: string;
|
||||||
|
nullable?: boolean;
|
||||||
|
metadata?: Map<string, string>;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type DataLike =
|
||||||
|
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
|
||||||
|
| import("apache-arrow").Data<Struct<any>>
|
||||||
|
| {
|
||||||
|
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
|
||||||
|
type: any;
|
||||||
|
length: number;
|
||||||
|
offset: number;
|
||||||
|
stride: number;
|
||||||
|
nullable: boolean;
|
||||||
|
children: DataLike[];
|
||||||
|
get nullCount(): number;
|
||||||
|
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
|
||||||
|
values: Buffers<any>[BufferType.DATA];
|
||||||
|
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
|
||||||
|
typeIds: Buffers<any>[BufferType.TYPE];
|
||||||
|
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
|
||||||
|
nullBitmap: Buffers<any>[BufferType.VALIDITY];
|
||||||
|
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
|
||||||
|
valueOffsets: Buffers<any>[BufferType.OFFSET];
|
||||||
|
};
|
||||||
|
|
||||||
|
export type RecordBatchLike =
|
||||||
|
| RecordBatch
|
||||||
|
| {
|
||||||
|
schema: SchemaLike;
|
||||||
|
data: DataLike;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type TableLike =
|
||||||
|
| ArrowTable
|
||||||
|
| { schema: SchemaLike; batches: RecordBatchLike[] };
|
||||||
|
|
||||||
export type IntoVector = Float32Array | Float64Array | number[];
|
export type IntoVector = Float32Array | Float64Array | number[];
|
||||||
|
|
||||||
export function isArrowTable(value: object): value is ArrowTable {
|
export function isArrowTable(value: object): value is TableLike {
|
||||||
if (value instanceof ArrowTable) return true;
|
if (value instanceof ArrowTable) return true;
|
||||||
return "schema" in value && "batches" in value;
|
return "schema" in value && "batches" in value;
|
||||||
}
|
}
|
||||||
@@ -135,7 +190,7 @@ export function isFixedSizeList(value: unknown): value is FixedSizeList {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/** Data type accepted by NodeJS SDK */
|
/** Data type accepted by NodeJS SDK */
|
||||||
export type Data = Record<string, unknown>[] | ArrowTable;
|
export type Data = Record<string, unknown>[] | TableLike;
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Options to control how a column should be converted to a vector array
|
* Options to control how a column should be converted to a vector array
|
||||||
@@ -162,7 +217,7 @@ export class MakeArrowTableOptions {
|
|||||||
* The schema must be specified if there are no records (e.g. to make
|
* The schema must be specified if there are no records (e.g. to make
|
||||||
* an empty table)
|
* an empty table)
|
||||||
*/
|
*/
|
||||||
schema?: Schema;
|
schema?: SchemaLike;
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Mapping from vector column name to expected type
|
* Mapping from vector column name to expected type
|
||||||
@@ -310,7 +365,7 @@ export function makeArrowTable(
|
|||||||
if (opt.schema !== undefined && opt.schema !== null) {
|
if (opt.schema !== undefined && opt.schema !== null) {
|
||||||
opt.schema = sanitizeSchema(opt.schema);
|
opt.schema = sanitizeSchema(opt.schema);
|
||||||
opt.schema = validateSchemaEmbeddings(
|
opt.schema = validateSchemaEmbeddings(
|
||||||
opt.schema,
|
opt.schema as Schema,
|
||||||
data,
|
data,
|
||||||
options?.embeddingFunction,
|
options?.embeddingFunction,
|
||||||
);
|
);
|
||||||
@@ -394,7 +449,7 @@ export function makeArrowTable(
|
|||||||
// `new ArrowTable(schema, batches)` which does not do any schema inference
|
// `new ArrowTable(schema, batches)` which does not do any schema inference
|
||||||
const firstTable = new ArrowTable(columns);
|
const firstTable = new ArrowTable(columns);
|
||||||
const batchesFixed = firstTable.batches.map(
|
const batchesFixed = firstTable.batches.map(
|
||||||
(batch) => new RecordBatch(opt.schema!, batch.data),
|
(batch) => new RecordBatch(opt.schema as Schema, batch.data),
|
||||||
);
|
);
|
||||||
let schema: Schema;
|
let schema: Schema;
|
||||||
if (metadata !== undefined) {
|
if (metadata !== undefined) {
|
||||||
@@ -407,9 +462,9 @@ export function makeArrowTable(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
schema = new Schema(opt.schema.fields, schemaMetadata);
|
schema = new Schema(opt.schema.fields as Field[], schemaMetadata);
|
||||||
} else {
|
} else {
|
||||||
schema = opt.schema;
|
schema = opt.schema as Schema;
|
||||||
}
|
}
|
||||||
return new ArrowTable(schema, batchesFixed);
|
return new ArrowTable(schema, batchesFixed);
|
||||||
}
|
}
|
||||||
@@ -425,7 +480,7 @@ export function makeArrowTable(
|
|||||||
* Create an empty Arrow table with the provided schema
|
* Create an empty Arrow table with the provided schema
|
||||||
*/
|
*/
|
||||||
export function makeEmptyTable(
|
export function makeEmptyTable(
|
||||||
schema: Schema,
|
schema: SchemaLike,
|
||||||
metadata?: Map<string, string>,
|
metadata?: Map<string, string>,
|
||||||
): ArrowTable {
|
): ArrowTable {
|
||||||
return makeArrowTable([], { schema }, metadata);
|
return makeArrowTable([], { schema }, metadata);
|
||||||
@@ -563,17 +618,16 @@ async function applyEmbeddingsFromMetadata(
|
|||||||
async function applyEmbeddings<T>(
|
async function applyEmbeddings<T>(
|
||||||
table: ArrowTable,
|
table: ArrowTable,
|
||||||
embeddings?: EmbeddingFunctionConfig,
|
embeddings?: EmbeddingFunctionConfig,
|
||||||
schema?: Schema,
|
schema?: SchemaLike,
|
||||||
): Promise<ArrowTable> {
|
): Promise<ArrowTable> {
|
||||||
if (schema?.metadata.has("embedding_functions")) {
|
|
||||||
return applyEmbeddingsFromMetadata(table, schema!);
|
|
||||||
} else if (embeddings == null || embeddings === undefined) {
|
|
||||||
return table;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (schema !== undefined && schema !== null) {
|
if (schema !== undefined && schema !== null) {
|
||||||
schema = sanitizeSchema(schema);
|
schema = sanitizeSchema(schema);
|
||||||
}
|
}
|
||||||
|
if (schema?.metadata.has("embedding_functions")) {
|
||||||
|
return applyEmbeddingsFromMetadata(table, schema! as Schema);
|
||||||
|
} else if (embeddings == null || embeddings === undefined) {
|
||||||
|
return table;
|
||||||
|
}
|
||||||
|
|
||||||
// Convert from ArrowTable to Record<String, Vector>
|
// Convert from ArrowTable to Record<String, Vector>
|
||||||
const colEntries = [...Array(table.numCols).keys()].map((_, idx) => {
|
const colEntries = [...Array(table.numCols).keys()].map((_, idx) => {
|
||||||
@@ -650,7 +704,7 @@ async function applyEmbeddings<T>(
|
|||||||
`When using embedding functions and specifying a schema the schema should include the embedding column but the column ${destColumn} was missing`,
|
`When using embedding functions and specifying a schema the schema should include the embedding column but the column ${destColumn} was missing`,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
return alignTable(newTable, schema);
|
return alignTable(newTable, schema as Schema);
|
||||||
}
|
}
|
||||||
return newTable;
|
return newTable;
|
||||||
}
|
}
|
||||||
@@ -744,7 +798,7 @@ export async function fromRecordsToStreamBuffer(
|
|||||||
export async function fromTableToBuffer(
|
export async function fromTableToBuffer(
|
||||||
table: ArrowTable,
|
table: ArrowTable,
|
||||||
embeddings?: EmbeddingFunctionConfig,
|
embeddings?: EmbeddingFunctionConfig,
|
||||||
schema?: Schema,
|
schema?: SchemaLike,
|
||||||
): Promise<Buffer> {
|
): Promise<Buffer> {
|
||||||
if (schema !== undefined && schema !== null) {
|
if (schema !== undefined && schema !== null) {
|
||||||
schema = sanitizeSchema(schema);
|
schema = sanitizeSchema(schema);
|
||||||
@@ -771,7 +825,7 @@ export async function fromDataToBuffer(
|
|||||||
schema = sanitizeSchema(schema);
|
schema = sanitizeSchema(schema);
|
||||||
}
|
}
|
||||||
if (isArrowTable(data)) {
|
if (isArrowTable(data)) {
|
||||||
return fromTableToBuffer(data, embeddings, schema);
|
return fromTableToBuffer(sanitizeTable(data), embeddings, schema);
|
||||||
} else {
|
} else {
|
||||||
const table = await convertToTable(data, embeddings, { schema });
|
const table = await convertToTable(data, embeddings, { schema });
|
||||||
return fromTableToBuffer(table);
|
return fromTableToBuffer(table);
|
||||||
@@ -789,7 +843,7 @@ export async function fromDataToBuffer(
|
|||||||
export async function fromTableToStreamBuffer(
|
export async function fromTableToStreamBuffer(
|
||||||
table: ArrowTable,
|
table: ArrowTable,
|
||||||
embeddings?: EmbeddingFunctionConfig,
|
embeddings?: EmbeddingFunctionConfig,
|
||||||
schema?: Schema,
|
schema?: SchemaLike,
|
||||||
): Promise<Buffer> {
|
): Promise<Buffer> {
|
||||||
const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema);
|
const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema);
|
||||||
const writer = RecordBatchStreamWriter.writeAll(tableWithEmbeddings);
|
const writer = RecordBatchStreamWriter.writeAll(tableWithEmbeddings);
|
||||||
@@ -854,7 +908,6 @@ function validateSchemaEmbeddings(
|
|||||||
for (let field of schema.fields) {
|
for (let field of schema.fields) {
|
||||||
if (isFixedSizeList(field.type)) {
|
if (isFixedSizeList(field.type)) {
|
||||||
field = sanitizeField(field);
|
field = sanitizeField(field);
|
||||||
|
|
||||||
if (data.length !== 0 && data?.[0]?.[field.name] === undefined) {
|
if (data.length !== 0 && data?.[0]?.[field.name] === undefined) {
|
||||||
if (schema.metadata.has("embedding_functions")) {
|
if (schema.metadata.has("embedding_functions")) {
|
||||||
const embeddings = JSON.parse(
|
const embeddings = JSON.parse(
|
||||||
|
|||||||
@@ -12,7 +12,7 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
import { Table as ArrowTable, Data, Schema } from "./arrow";
|
import { Data, Schema, SchemaLike, TableLike } from "./arrow";
|
||||||
import { fromTableToBuffer, makeEmptyTable } from "./arrow";
|
import { fromTableToBuffer, makeEmptyTable } from "./arrow";
|
||||||
import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry";
|
import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry";
|
||||||
import { Connection as LanceDbConnection } from "./native";
|
import { Connection as LanceDbConnection } from "./native";
|
||||||
@@ -50,7 +50,7 @@ export interface CreateTableOptions {
|
|||||||
* The default is true while the new format is in beta
|
* The default is true while the new format is in beta
|
||||||
*/
|
*/
|
||||||
useLegacyFormat?: boolean;
|
useLegacyFormat?: boolean;
|
||||||
schema?: Schema;
|
schema?: SchemaLike;
|
||||||
embeddingFunction?: EmbeddingFunctionConfig;
|
embeddingFunction?: EmbeddingFunctionConfig;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -167,12 +167,12 @@ export abstract class Connection {
|
|||||||
/**
|
/**
|
||||||
* Creates a new Table and initialize it with new data.
|
* Creates a new Table and initialize it with new data.
|
||||||
* @param {string} name - The name of the table.
|
* @param {string} name - The name of the table.
|
||||||
* @param {Record<string, unknown>[] | ArrowTable} data - Non-empty Array of Records
|
* @param {Record<string, unknown>[] | TableLike} data - Non-empty Array of Records
|
||||||
* to be inserted into the table
|
* to be inserted into the table
|
||||||
*/
|
*/
|
||||||
abstract createTable(
|
abstract createTable(
|
||||||
name: string,
|
name: string,
|
||||||
data: Record<string, unknown>[] | ArrowTable,
|
data: Record<string, unknown>[] | TableLike,
|
||||||
options?: Partial<CreateTableOptions>,
|
options?: Partial<CreateTableOptions>,
|
||||||
): Promise<Table>;
|
): Promise<Table>;
|
||||||
|
|
||||||
@@ -183,7 +183,7 @@ export abstract class Connection {
|
|||||||
*/
|
*/
|
||||||
abstract createEmptyTable(
|
abstract createEmptyTable(
|
||||||
name: string,
|
name: string,
|
||||||
schema: Schema,
|
schema: import("./arrow").SchemaLike,
|
||||||
options?: Partial<CreateTableOptions>,
|
options?: Partial<CreateTableOptions>,
|
||||||
): Promise<Table>;
|
): Promise<Table>;
|
||||||
|
|
||||||
@@ -235,7 +235,7 @@ export class LocalConnection extends Connection {
|
|||||||
nameOrOptions:
|
nameOrOptions:
|
||||||
| string
|
| string
|
||||||
| ({ name: string; data: Data } & Partial<CreateTableOptions>),
|
| ({ name: string; data: Data } & Partial<CreateTableOptions>),
|
||||||
data?: Record<string, unknown>[] | ArrowTable,
|
data?: Record<string, unknown>[] | TableLike,
|
||||||
options?: Partial<CreateTableOptions>,
|
options?: Partial<CreateTableOptions>,
|
||||||
): Promise<Table> {
|
): Promise<Table> {
|
||||||
if (typeof nameOrOptions !== "string" && "name" in nameOrOptions) {
|
if (typeof nameOrOptions !== "string" && "name" in nameOrOptions) {
|
||||||
@@ -259,7 +259,7 @@ export class LocalConnection extends Connection {
|
|||||||
|
|
||||||
async createEmptyTable(
|
async createEmptyTable(
|
||||||
name: string,
|
name: string,
|
||||||
schema: Schema,
|
schema: import("./arrow").SchemaLike,
|
||||||
options?: Partial<CreateTableOptions>,
|
options?: Partial<CreateTableOptions>,
|
||||||
): Promise<Table> {
|
): Promise<Table> {
|
||||||
let mode: string = options?.mode ?? "create";
|
let mode: string = options?.mode ?? "create";
|
||||||
|
|||||||
@@ -35,6 +35,11 @@ export interface FunctionOptions {
|
|||||||
[key: string]: any;
|
[key: string]: any;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface EmbeddingFunctionConstructor<
|
||||||
|
T extends EmbeddingFunction = EmbeddingFunction,
|
||||||
|
> {
|
||||||
|
new (modelOptions?: T["TOptions"]): T;
|
||||||
|
}
|
||||||
/**
|
/**
|
||||||
* An embedding function that automatically creates vector representation for a given column.
|
* An embedding function that automatically creates vector representation for a given column.
|
||||||
*/
|
*/
|
||||||
@@ -43,6 +48,12 @@ export abstract class EmbeddingFunction<
|
|||||||
T = any,
|
T = any,
|
||||||
M extends FunctionOptions = FunctionOptions,
|
M extends FunctionOptions = FunctionOptions,
|
||||||
> {
|
> {
|
||||||
|
/**
|
||||||
|
* @ignore
|
||||||
|
* This is only used for associating the options type with the class for type checking
|
||||||
|
*/
|
||||||
|
// biome-ignore lint/style/useNamingConvention: we want to keep the name as it is
|
||||||
|
readonly TOptions!: M;
|
||||||
/**
|
/**
|
||||||
* Convert the embedding function to a JSON object
|
* Convert the embedding function to a JSON object
|
||||||
* It is used to serialize the embedding function to the schema
|
* It is used to serialize the embedding function to the schema
|
||||||
|
|||||||
@@ -13,24 +13,29 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
import type OpenAI from "openai";
|
import type OpenAI from "openai";
|
||||||
|
import { type EmbeddingCreateParams } from "openai/resources";
|
||||||
import { Float, Float32 } from "../arrow";
|
import { Float, Float32 } from "../arrow";
|
||||||
import { EmbeddingFunction } from "./embedding_function";
|
import { EmbeddingFunction } from "./embedding_function";
|
||||||
import { register } from "./registry";
|
import { register } from "./registry";
|
||||||
|
|
||||||
export type OpenAIOptions = {
|
export type OpenAIOptions = {
|
||||||
apiKey?: string;
|
apiKey: string;
|
||||||
model?: string;
|
model: EmbeddingCreateParams["model"];
|
||||||
};
|
};
|
||||||
|
|
||||||
@register("openai")
|
@register("openai")
|
||||||
export class OpenAIEmbeddingFunction extends EmbeddingFunction<
|
export class OpenAIEmbeddingFunction extends EmbeddingFunction<
|
||||||
string,
|
string,
|
||||||
OpenAIOptions
|
Partial<OpenAIOptions>
|
||||||
> {
|
> {
|
||||||
#openai: OpenAI;
|
#openai: OpenAI;
|
||||||
#modelName: string;
|
#modelName: OpenAIOptions["model"];
|
||||||
|
|
||||||
constructor(options: OpenAIOptions = { model: "text-embedding-ada-002" }) {
|
constructor(
|
||||||
|
options: Partial<OpenAIOptions> = {
|
||||||
|
model: "text-embedding-ada-002",
|
||||||
|
},
|
||||||
|
) {
|
||||||
super();
|
super();
|
||||||
const openAIKey = options?.apiKey ?? process.env.OPENAI_API_KEY;
|
const openAIKey = options?.apiKey ?? process.env.OPENAI_API_KEY;
|
||||||
if (!openAIKey) {
|
if (!openAIKey) {
|
||||||
@@ -73,7 +78,7 @@ export class OpenAIEmbeddingFunction extends EmbeddingFunction<
|
|||||||
case "text-embedding-3-small":
|
case "text-embedding-3-small":
|
||||||
return 1536;
|
return 1536;
|
||||||
default:
|
default:
|
||||||
return null as never;
|
throw new Error(`Unknown model: ${this.#modelName}`);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,21 +12,15 @@
|
|||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
import type { EmbeddingFunction } from "./embedding_function";
|
import {
|
||||||
|
type EmbeddingFunction,
|
||||||
|
type EmbeddingFunctionConstructor,
|
||||||
|
} from "./embedding_function";
|
||||||
import "reflect-metadata";
|
import "reflect-metadata";
|
||||||
|
import { OpenAIEmbeddingFunction } from "./openai";
|
||||||
export interface EmbeddingFunctionOptions {
|
|
||||||
[key: string]: unknown;
|
|
||||||
}
|
|
||||||
|
|
||||||
export interface EmbeddingFunctionFactory<
|
|
||||||
T extends EmbeddingFunction = EmbeddingFunction,
|
|
||||||
> {
|
|
||||||
new (modelOptions?: EmbeddingFunctionOptions): T;
|
|
||||||
}
|
|
||||||
|
|
||||||
interface EmbeddingFunctionCreate<T extends EmbeddingFunction> {
|
interface EmbeddingFunctionCreate<T extends EmbeddingFunction> {
|
||||||
create(options?: EmbeddingFunctionOptions): T;
|
create(options?: T["TOptions"]): T;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -36,7 +30,7 @@ interface EmbeddingFunctionCreate<T extends EmbeddingFunction> {
|
|||||||
* or TextEmbeddingFunction and registering it with the registry
|
* or TextEmbeddingFunction and registering it with the registry
|
||||||
*/
|
*/
|
||||||
export class EmbeddingFunctionRegistry {
|
export class EmbeddingFunctionRegistry {
|
||||||
#functions: Map<string, EmbeddingFunctionFactory> = new Map();
|
#functions = new Map<string, EmbeddingFunctionConstructor>();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Register an embedding function
|
* Register an embedding function
|
||||||
@@ -44,7 +38,9 @@ export class EmbeddingFunctionRegistry {
|
|||||||
* @param func The function to register
|
* @param func The function to register
|
||||||
* @throws Error if the function is already registered
|
* @throws Error if the function is already registered
|
||||||
*/
|
*/
|
||||||
register<T extends EmbeddingFunctionFactory = EmbeddingFunctionFactory>(
|
register<
|
||||||
|
T extends EmbeddingFunctionConstructor = EmbeddingFunctionConstructor,
|
||||||
|
>(
|
||||||
this: EmbeddingFunctionRegistry,
|
this: EmbeddingFunctionRegistry,
|
||||||
alias?: string,
|
alias?: string,
|
||||||
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
|
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
|
||||||
@@ -69,18 +65,34 @@ export class EmbeddingFunctionRegistry {
|
|||||||
* Fetch an embedding function by name
|
* Fetch an embedding function by name
|
||||||
* @param name The name of the function
|
* @param name The name of the function
|
||||||
*/
|
*/
|
||||||
get<T extends EmbeddingFunction<unknown> = EmbeddingFunction>(
|
get<T extends EmbeddingFunction<unknown>, Name extends string = "">(
|
||||||
name: string,
|
name: Name extends "openai" ? "openai" : string,
|
||||||
): EmbeddingFunctionCreate<T> | undefined {
|
//This makes it so that you can use string constants as "types", or use an explicitly supplied type
|
||||||
|
// ex:
|
||||||
|
// `registry.get("openai") -> EmbeddingFunctionCreate<OpenAIEmbeddingFunction>`
|
||||||
|
// `registry.get<MyCustomEmbeddingFunction>("my_func") -> EmbeddingFunctionCreate<MyCustomEmbeddingFunction> | undefined`
|
||||||
|
//
|
||||||
|
// the reason this is important is that we always know our built in functions are defined so the user isnt forced to do a non null/undefined
|
||||||
|
// ```ts
|
||||||
|
// const openai: OpenAIEmbeddingFunction = registry.get("openai").create()
|
||||||
|
// ```
|
||||||
|
): Name extends "openai"
|
||||||
|
? EmbeddingFunctionCreate<OpenAIEmbeddingFunction>
|
||||||
|
: EmbeddingFunctionCreate<T> | undefined {
|
||||||
|
type Output = Name extends "openai"
|
||||||
|
? EmbeddingFunctionCreate<OpenAIEmbeddingFunction>
|
||||||
|
: EmbeddingFunctionCreate<T> | undefined;
|
||||||
|
|
||||||
const factory = this.#functions.get(name);
|
const factory = this.#functions.get(name);
|
||||||
if (!factory) {
|
if (!factory) {
|
||||||
return undefined;
|
return undefined as Output;
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
create: function (options: EmbeddingFunctionOptions) {
|
create: function (options?: T["TOptions"]) {
|
||||||
return new factory(options) as unknown as T;
|
return new factory(options);
|
||||||
},
|
},
|
||||||
};
|
} as Output;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -104,7 +116,7 @@ export class EmbeddingFunctionRegistry {
|
|||||||
name: string;
|
name: string;
|
||||||
sourceColumn: string;
|
sourceColumn: string;
|
||||||
vectorColumn: string;
|
vectorColumn: string;
|
||||||
model: EmbeddingFunctionOptions;
|
model: EmbeddingFunction["TOptions"];
|
||||||
};
|
};
|
||||||
const functions = <FunctionConfig[]>(
|
const functions = <FunctionConfig[]>(
|
||||||
JSON.parse(metadata.get("embedding_functions")!)
|
JSON.parse(metadata.get("embedding_functions")!)
|
||||||
|
|||||||
@@ -300,7 +300,9 @@ export class VectorQuery extends QueryBase<NativeVectorQuery, VectorQuery> {
|
|||||||
*
|
*
|
||||||
* By default "l2" is used.
|
* By default "l2" is used.
|
||||||
*/
|
*/
|
||||||
distanceType(distanceType: string): VectorQuery {
|
distanceType(
|
||||||
|
distanceType: Required<IvfPqOptions>["distanceType"],
|
||||||
|
): VectorQuery {
|
||||||
this.inner.distanceType(distanceType);
|
this.inner.distanceType(distanceType);
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ export class RestfulLanceDBClient {
|
|||||||
return axios.create({
|
return axios.create({
|
||||||
baseURL: this.url,
|
baseURL: this.url,
|
||||||
headers: {
|
headers: {
|
||||||
// biome-ignore lint/style/useNamingConvention: external api
|
// biome-ignore lint: external API
|
||||||
Authorization: `Bearer ${this.#apiKey}`,
|
Authorization: `Bearer ${this.#apiKey}`,
|
||||||
},
|
},
|
||||||
transformResponse: decodeErrorData,
|
transformResponse: decodeErrorData,
|
||||||
|
|||||||
@@ -1,5 +1,10 @@
|
|||||||
import { Schema } from "apache-arrow";
|
import { Schema } from "apache-arrow";
|
||||||
import { Data, fromTableToStreamBuffer, makeEmptyTable } from "../arrow";
|
import {
|
||||||
|
Data,
|
||||||
|
SchemaLike,
|
||||||
|
fromTableToStreamBuffer,
|
||||||
|
makeEmptyTable,
|
||||||
|
} from "../arrow";
|
||||||
import {
|
import {
|
||||||
Connection,
|
Connection,
|
||||||
CreateTableOptions,
|
CreateTableOptions,
|
||||||
@@ -156,7 +161,7 @@ export class RemoteConnection extends Connection {
|
|||||||
|
|
||||||
async createEmptyTable(
|
async createEmptyTable(
|
||||||
name: string,
|
name: string,
|
||||||
schema: Schema,
|
schema: SchemaLike,
|
||||||
options?: Partial<CreateTableOptions> | undefined,
|
options?: Partial<CreateTableOptions> | undefined,
|
||||||
): Promise<Table> {
|
): Promise<Table> {
|
||||||
if (options?.mode) {
|
if (options?.mode) {
|
||||||
|
|||||||
@@ -20,10 +20,12 @@
|
|||||||
// comes from the exact same library instance. This is not always the case
|
// comes from the exact same library instance. This is not always the case
|
||||||
// and so we must sanitize the input to ensure that it is compatible.
|
// and so we must sanitize the input to ensure that it is compatible.
|
||||||
|
|
||||||
|
import { BufferType, Data } from "apache-arrow";
|
||||||
import type { IntBitWidth, TKeys, TimeBitWidth } from "apache-arrow/type";
|
import type { IntBitWidth, TKeys, TimeBitWidth } from "apache-arrow/type";
|
||||||
import {
|
import {
|
||||||
Binary,
|
Binary,
|
||||||
Bool,
|
Bool,
|
||||||
|
DataLike,
|
||||||
DataType,
|
DataType,
|
||||||
DateDay,
|
DateDay,
|
||||||
DateMillisecond,
|
DateMillisecond,
|
||||||
@@ -56,9 +58,14 @@ import {
|
|||||||
Map_,
|
Map_,
|
||||||
Null,
|
Null,
|
||||||
type Precision,
|
type Precision,
|
||||||
|
RecordBatch,
|
||||||
|
RecordBatchLike,
|
||||||
Schema,
|
Schema,
|
||||||
|
SchemaLike,
|
||||||
SparseUnion,
|
SparseUnion,
|
||||||
Struct,
|
Struct,
|
||||||
|
Table,
|
||||||
|
TableLike,
|
||||||
Time,
|
Time,
|
||||||
TimeMicrosecond,
|
TimeMicrosecond,
|
||||||
TimeMillisecond,
|
TimeMillisecond,
|
||||||
@@ -488,7 +495,7 @@ export function sanitizeField(fieldLike: unknown): Field {
|
|||||||
* instance because they might be using a different instance of apache-arrow
|
* instance because they might be using a different instance of apache-arrow
|
||||||
* than lancedb is using.
|
* than lancedb is using.
|
||||||
*/
|
*/
|
||||||
export function sanitizeSchema(schemaLike: unknown): Schema {
|
export function sanitizeSchema(schemaLike: SchemaLike): Schema {
|
||||||
if (schemaLike instanceof Schema) {
|
if (schemaLike instanceof Schema) {
|
||||||
return schemaLike;
|
return schemaLike;
|
||||||
}
|
}
|
||||||
@@ -514,3 +521,68 @@ export function sanitizeSchema(schemaLike: unknown): Schema {
|
|||||||
);
|
);
|
||||||
return new Schema(sanitizedFields, metadata);
|
return new Schema(sanitizedFields, metadata);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Convert a table-like object into a `Table` built by this module's copy of
 * apache-arrow.
 *
 * A value that is already an instance of this module's `Table` is returned
 * unchanged. Otherwise the value must be a non-null object exposing `schema`
 * and `batches` properties: the schema is passed through `sanitizeSchema`,
 * each batch through `sanitizeRecordBatch`, and a new `Table` is constructed
 * from the results.
 *
 * @param tableLike - the table, or structurally table-like object, to sanitize
 * @returns a `Table` created with this module's apache-arrow import
 * @throws Error if `tableLike` is not a non-null object, or if it lacks a
 *   `schema` or `batches` property
 */
export function sanitizeTable(tableLike: TableLike): Table {
  if (tableLike instanceof Table) {
    return tableLike;
  }
  if (typeof tableLike !== "object" || tableLike === null) {
    throw Error("Expected a Table but object was null/undefined");
  }
  if (!("schema" in tableLike)) {
    throw Error(
      "The table passed in does not appear to be a table (no 'schema' property)",
    );
  }
  if (!("batches" in tableLike)) {
    // Report the property that is actually checked ('batches'), so the
    // message matches the condition that failed.
    throw Error(
      "The table passed in does not appear to be a table (no 'batches' property)",
    );
  }
  const schema = sanitizeSchema(tableLike.schema);

  const batches = tableLike.batches.map(sanitizeRecordBatch);
  return new Table(schema, batches);
}
|
||||||
|
|
||||||
|
/**
 * Rebuild a record-batch-like object as a `RecordBatch` from this module's
 * copy of apache-arrow.
 *
 * An existing `RecordBatch` instance is passed through untouched. Any other
 * value must be a non-null object carrying both `schema` and `data`
 * properties; these are sanitized with `sanitizeSchema` and `sanitizeData`
 * respectively and used to construct the new batch.
 *
 * @param batchLike - the record batch, or structurally similar object
 * @returns a `RecordBatch` created with this module's apache-arrow import
 * @throws Error if `batchLike` is not a non-null object, or if it lacks a
 *   `schema` or `data` property
 */
function sanitizeRecordBatch(batchLike: RecordBatchLike): RecordBatch {
  if (batchLike instanceof RecordBatch) {
    return batchLike;
  }

  if (batchLike === null || typeof batchLike !== "object") {
    throw new Error("Expected a RecordBatch but object was null/undefined");
  }

  if (!("schema" in batchLike)) {
    throw new Error(
      "The record batch passed in does not appear to be a record batch (no 'schema' property)",
    );
  }

  if (!("data" in batchLike)) {
    throw new Error(
      "The record batch passed in does not appear to be a record batch (no 'data' property)",
    );
  }

  // Schema is sanitized before data, matching argument evaluation order.
  return new RecordBatch(
    sanitizeSchema(batchLike.schema),
    sanitizeData(batchLike.data),
  );
}
|
||||||
|
/**
 * Rebuild a data-like object as a `Data` instance from this module's copy of
 * apache-arrow.
 *
 * An existing `Data` instance is returned as-is. Otherwise a new `Data` is
 * constructed from the object's `type`, `offset`, `length` and `nullCount`,
 * together with its `valueOffsets`, `values`, `nullBitmap` and `typeIds`
 * buffers keyed by the corresponding `BufferType` entries.
 *
 * NOTE(review): only these five constructor arguments are forwarded; nothing
 * else from `dataLike` is copied. Confirm that is sufficient for struct data.
 *
 * @param dataLike - the data, or structurally similar object, to sanitize
 * @returns a `Data` created with this module's apache-arrow import
 */
function sanitizeData(
  dataLike: DataLike,
  // biome-ignore lint/suspicious/noExplicitAny: field types of the struct are not known here
): import("apache-arrow").Data<Struct<any>> {
  if (dataLike instanceof Data) {
    return dataLike;
  }

  const { type, offset, length, nullCount } = dataLike;
  return new Data(type, offset, length, nullCount, {
    [BufferType.OFFSET]: dataLike.valueOffsets,
    [BufferType.DATA]: dataLike.values,
    [BufferType.VALIDITY]: dataLike.nullBitmap,
    [BufferType.TYPE]: dataLike.typeIds,
  });
}
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import {
|
|||||||
Data,
|
Data,
|
||||||
IntoVector,
|
IntoVector,
|
||||||
Schema,
|
Schema,
|
||||||
|
TableLike,
|
||||||
fromDataToBuffer,
|
fromDataToBuffer,
|
||||||
fromTableToBuffer,
|
fromTableToBuffer,
|
||||||
fromTableToStreamBuffer,
|
fromTableToStreamBuffer,
|
||||||
@@ -38,6 +39,8 @@ import {
|
|||||||
Table as _NativeTable,
|
Table as _NativeTable,
|
||||||
} from "./native";
|
} from "./native";
|
||||||
import { Query, VectorQuery } from "./query";
|
import { Query, VectorQuery } from "./query";
|
||||||
|
import { sanitizeTable } from "./sanitize";
|
||||||
|
export { IndexConfig } from "./native";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Options for adding data to a table.
|
* Options for adding data to a table.
|
||||||
@@ -381,8 +384,7 @@ export abstract class Table {
|
|||||||
abstract indexStats(name: string): Promise<IndexStatistics | undefined>;
|
abstract indexStats(name: string): Promise<IndexStatistics | undefined>;
|
||||||
|
|
||||||
static async parseTableData(
|
static async parseTableData(
|
||||||
// biome-ignore lint/suspicious/noExplicitAny: <explanation>
|
data: Record<string, unknown>[] | TableLike,
|
||||||
data: Record<string, unknown>[] | ArrowTable<any>,
|
|
||||||
options?: Partial<CreateTableOptions>,
|
options?: Partial<CreateTableOptions>,
|
||||||
streaming = false,
|
streaming = false,
|
||||||
) {
|
) {
|
||||||
@@ -395,9 +397,9 @@ export abstract class Table {
|
|||||||
|
|
||||||
let table: ArrowTable;
|
let table: ArrowTable;
|
||||||
if (isArrowTable(data)) {
|
if (isArrowTable(data)) {
|
||||||
table = data;
|
table = sanitizeTable(data);
|
||||||
} else {
|
} else {
|
||||||
table = makeArrowTable(data, options);
|
table = makeArrowTable(data as Record<string, unknown>[], options);
|
||||||
}
|
}
|
||||||
if (streaming) {
|
if (streaming) {
|
||||||
const buf = await fromTableToStreamBuffer(
|
const buf = await fromTableToStreamBuffer(
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-darwin-arm64",
|
"name": "@lancedb/lancedb-darwin-arm64",
|
||||||
"version": "0.5.2-final.1",
|
"version": "0.6.0",
|
||||||
"os": ["darwin"],
|
"os": ["darwin"],
|
||||||
"cpu": ["arm64"],
|
"cpu": ["arm64"],
|
||||||
"main": "lancedb.darwin-arm64.node",
|
"main": "lancedb.darwin-arm64.node",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-darwin-x64",
|
"name": "@lancedb/lancedb-darwin-x64",
|
||||||
"version": "0.5.2-final.1",
|
"version": "0.6.0",
|
||||||
"os": ["darwin"],
|
"os": ["darwin"],
|
||||||
"cpu": ["x64"],
|
"cpu": ["x64"],
|
||||||
"main": "lancedb.darwin-x64.node",
|
"main": "lancedb.darwin-x64.node",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
"name": "@lancedb/lancedb-linux-arm64-gnu",
|
||||||
"version": "0.5.2-final.1",
|
"version": "0.6.0",
|
||||||
"os": ["linux"],
|
"os": ["linux"],
|
||||||
"cpu": ["arm64"],
|
"cpu": ["arm64"],
|
||||||
"main": "lancedb.linux-arm64-gnu.node",
|
"main": "lancedb.linux-arm64-gnu.node",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-linux-x64-gnu",
|
"name": "@lancedb/lancedb-linux-x64-gnu",
|
||||||
"version": "0.5.2-final.1",
|
"version": "0.6.0",
|
||||||
"os": ["linux"],
|
"os": ["linux"],
|
||||||
"cpu": ["x64"],
|
"cpu": ["x64"],
|
||||||
"main": "lancedb.linux-x64-gnu.node",
|
"main": "lancedb.linux-x64-gnu.node",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@lancedb/lancedb-win32-x64-msvc",
|
"name": "@lancedb/lancedb-win32-x64-msvc",
|
||||||
"version": "0.5.2-final.1",
|
"version": "0.6.0",
|
||||||
"os": ["win32"],
|
"os": ["win32"],
|
||||||
"cpu": ["x64"],
|
"cpu": ["x64"],
|
||||||
"main": "lancedb.win32-x64-msvc.node",
|
"main": "lancedb.win32-x64-msvc.node",
|
||||||
|
|||||||
1403
nodejs/package-lock.json
generated
1403
nodejs/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -10,7 +10,7 @@
|
|||||||
"vector database",
|
"vector database",
|
||||||
"ann"
|
"ann"
|
||||||
],
|
],
|
||||||
"version": "0.5.2-final.1",
|
"version": "0.6.0",
|
||||||
"main": "dist/index.js",
|
"main": "dist/index.js",
|
||||||
"exports": {
|
"exports": {
|
||||||
".": "./dist/index.js",
|
".": "./dist/index.js",
|
||||||
@@ -34,9 +34,10 @@
|
|||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"@aws-sdk/client-kms": "^3.33.0",
|
"@aws-sdk/client-kms": "^3.33.0",
|
||||||
"@aws-sdk/client-s3": "^3.33.0",
|
"@aws-sdk/client-s3": "^3.33.0",
|
||||||
|
"@aws-sdk/client-dynamodb": "^3.33.0",
|
||||||
"@biomejs/biome": "^1.7.3",
|
"@biomejs/biome": "^1.7.3",
|
||||||
"@jest/globals": "^29.7.0",
|
"@jest/globals": "^29.7.0",
|
||||||
"@napi-rs/cli": "^2.18.0",
|
"@napi-rs/cli": "^2.18.3",
|
||||||
"@types/jest": "^29.1.2",
|
"@types/jest": "^29.1.2",
|
||||||
"@types/tmp": "^0.2.6",
|
"@types/tmp": "^0.2.6",
|
||||||
"apache-arrow-old": "npm:apache-arrow@13.0.0",
|
"apache-arrow-old": "npm:apache-arrow@13.0.0",
|
||||||
@@ -68,7 +69,7 @@
|
|||||||
"lint-ci": "biome ci .",
|
"lint-ci": "biome ci .",
|
||||||
"docs": "typedoc --plugin typedoc-plugin-markdown --out ../docs/src/js lancedb/index.ts",
|
"docs": "typedoc --plugin typedoc-plugin-markdown --out ../docs/src/js lancedb/index.ts",
|
||||||
"lint": "biome check . && biome format .",
|
"lint": "biome check . && biome format .",
|
||||||
"lint-fix": "biome check --apply-unsafe . && biome format --write .",
|
"lint-fix": "biome check --write . && biome format --write .",
|
||||||
"prepublishOnly": "napi prepublish -t npm",
|
"prepublishOnly": "napi prepublish -t npm",
|
||||||
"test": "jest --verbose",
|
"test": "jest --verbose",
|
||||||
"integration": "S3_TEST=1 npm run test",
|
"integration": "S3_TEST=1 npm run test",
|
||||||
@@ -76,9 +77,13 @@
|
|||||||
"version": "napi version"
|
"version": "napi version"
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"apache-arrow": "^15.0.0",
|
|
||||||
"axios": "^1.7.2",
|
"axios": "^1.7.2",
|
||||||
"openai": "^4.29.2",
|
|
||||||
"reflect-metadata": "^0.2.2"
|
"reflect-metadata": "^0.2.2"
|
||||||
|
},
|
||||||
|
"optionalDependencies": {
|
||||||
|
"openai": "^4.29.2"
|
||||||
|
},
|
||||||
|
"peerDependencies": {
|
||||||
|
"apache-arrow": "^15.0.0"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -89,7 +89,7 @@ impl Connection {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// List all tables in the dataset.
|
/// List all tables in the dataset.
|
||||||
#[napi]
|
#[napi(catch_unwind)]
|
||||||
pub async fn table_names(
|
pub async fn table_names(
|
||||||
&self,
|
&self,
|
||||||
start_after: Option<String>,
|
start_after: Option<String>,
|
||||||
@@ -113,7 +113,7 @@ impl Connection {
|
|||||||
/// - name: The name of the table.
|
/// - name: The name of the table.
|
||||||
/// - buf: The buffer containing the IPC file.
|
/// - buf: The buffer containing the IPC file.
|
||||||
///
|
///
|
||||||
#[napi]
|
#[napi(catch_unwind)]
|
||||||
pub async fn create_table(
|
pub async fn create_table(
|
||||||
&self,
|
&self,
|
||||||
name: String,
|
name: String,
|
||||||
@@ -141,7 +141,7 @@ impl Connection {
|
|||||||
Ok(Table::new(tbl))
|
Ok(Table::new(tbl))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi(catch_unwind)]
|
||||||
pub async fn create_empty_table(
|
pub async fn create_empty_table(
|
||||||
&self,
|
&self,
|
||||||
name: String,
|
name: String,
|
||||||
@@ -173,7 +173,7 @@ impl Connection {
|
|||||||
Ok(Table::new(tbl))
|
Ok(Table::new(tbl))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi(catch_unwind)]
|
||||||
pub async fn open_table(
|
pub async fn open_table(
|
||||||
&self,
|
&self,
|
||||||
name: String,
|
name: String,
|
||||||
@@ -197,7 +197,7 @@ impl Connection {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Drop table with the name. Or raise an error if the table does not exist.
|
/// Drop table with the name. Or raise an error if the table does not exist.
|
||||||
#[napi]
|
#[napi(catch_unwind)]
|
||||||
pub async fn drop_table(&self, name: String) -> napi::Result<()> {
|
pub async fn drop_table(&self, name: String) -> napi::Result<()> {
|
||||||
self.get_inner()?
|
self.get_inner()?
|
||||||
.drop_table(&name)
|
.drop_table(&name)
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ impl RecordBatchIterator {
|
|||||||
Self { inner }
|
Self { inner }
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi(catch_unwind)]
|
||||||
pub async unsafe fn next(&mut self) -> napi::Result<Option<Buffer>> {
|
pub async unsafe fn next(&mut self) -> napi::Result<Option<Buffer>> {
|
||||||
if let Some(rst) = self.inner.next().await {
|
if let Some(rst) = self.inner.next().await {
|
||||||
let batch = rst.map_err(|e| {
|
let batch = rst.map_err(|e| {
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ impl NativeMergeInsertBuilder {
|
|||||||
this
|
this
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi(catch_unwind)]
|
||||||
pub async fn execute(&self, buf: Buffer) -> napi::Result<()> {
|
pub async fn execute(&self, buf: Buffer) -> napi::Result<()> {
|
||||||
let data = ipc_file_to_batches(buf.to_vec())
|
let data = ipc_file_to_batches(buf.to_vec())
|
||||||
.and_then(IntoArrow::into_arrow)
|
.and_then(IntoArrow::into_arrow)
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ impl Query {
|
|||||||
Ok(VectorQuery { inner })
|
Ok(VectorQuery { inner })
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi(catch_unwind)]
|
||||||
pub async fn execute(
|
pub async fn execute(
|
||||||
&self,
|
&self,
|
||||||
max_batch_length: Option<u32>,
|
max_batch_length: Option<u32>,
|
||||||
@@ -136,7 +136,7 @@ impl VectorQuery {
|
|||||||
self.inner = self.inner.clone().limit(limit as usize);
|
self.inner = self.inner.clone().limit(limit as usize);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi(catch_unwind)]
|
||||||
pub async fn execute(
|
pub async fn execute(
|
||||||
&self,
|
&self,
|
||||||
max_batch_length: Option<u32>,
|
max_batch_length: Option<u32>,
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ impl Table {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Return Schema as empty Arrow IPC file.
|
/// Return Schema as empty Arrow IPC file.
|
||||||
#[napi]
|
#[napi(catch_unwind)]
|
||||||
pub async fn schema(&self) -> napi::Result<Buffer> {
|
pub async fn schema(&self) -> napi::Result<Buffer> {
|
||||||
let schema =
|
let schema =
|
||||||
self.inner_ref()?.schema().await.map_err(|e| {
|
self.inner_ref()?.schema().await.map_err(|e| {
|
||||||
@@ -86,7 +86,7 @@ impl Table {
|
|||||||
})?))
|
})?))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi(catch_unwind)]
|
||||||
pub async fn add(&self, buf: Buffer, mode: String) -> napi::Result<()> {
|
pub async fn add(&self, buf: Buffer, mode: String) -> napi::Result<()> {
|
||||||
let batches = ipc_file_to_batches(buf.to_vec())
|
let batches = ipc_file_to_batches(buf.to_vec())
|
||||||
.map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
|
.map_err(|e| napi::Error::from_reason(format!("Failed to read IPC file: {}", e)))?;
|
||||||
@@ -108,7 +108,7 @@ impl Table {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi(catch_unwind)]
|
||||||
pub async fn count_rows(&self, filter: Option<String>) -> napi::Result<i64> {
|
pub async fn count_rows(&self, filter: Option<String>) -> napi::Result<i64> {
|
||||||
self.inner_ref()?
|
self.inner_ref()?
|
||||||
.count_rows(filter)
|
.count_rows(filter)
|
||||||
@@ -122,7 +122,7 @@ impl Table {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi(catch_unwind)]
|
||||||
pub async fn delete(&self, predicate: String) -> napi::Result<()> {
|
pub async fn delete(&self, predicate: String) -> napi::Result<()> {
|
||||||
self.inner_ref()?.delete(&predicate).await.map_err(|e| {
|
self.inner_ref()?.delete(&predicate).await.map_err(|e| {
|
||||||
napi::Error::from_reason(format!(
|
napi::Error::from_reason(format!(
|
||||||
@@ -132,7 +132,7 @@ impl Table {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi(catch_unwind)]
|
||||||
pub async fn create_index(
|
pub async fn create_index(
|
||||||
&self,
|
&self,
|
||||||
index: Option<&Index>,
|
index: Option<&Index>,
|
||||||
@@ -151,7 +151,7 @@ impl Table {
|
|||||||
builder.execute().await.default_error()
|
builder.execute().await.default_error()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi(catch_unwind)]
|
||||||
pub async fn update(
|
pub async fn update(
|
||||||
&self,
|
&self,
|
||||||
only_if: Option<String>,
|
only_if: Option<String>,
|
||||||
@@ -167,17 +167,17 @@ impl Table {
|
|||||||
op.execute().await.default_error()
|
op.execute().await.default_error()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi(catch_unwind)]
|
||||||
pub fn query(&self) -> napi::Result<Query> {
|
pub fn query(&self) -> napi::Result<Query> {
|
||||||
Ok(Query::new(self.inner_ref()?.query()))
|
Ok(Query::new(self.inner_ref()?.query()))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi(catch_unwind)]
|
||||||
pub fn vector_search(&self, vector: Float32Array) -> napi::Result<VectorQuery> {
|
pub fn vector_search(&self, vector: Float32Array) -> napi::Result<VectorQuery> {
|
||||||
self.query()?.nearest_to(vector)
|
self.query()?.nearest_to(vector)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi(catch_unwind)]
|
||||||
pub async fn add_columns(&self, transforms: Vec<AddColumnsSql>) -> napi::Result<()> {
|
pub async fn add_columns(&self, transforms: Vec<AddColumnsSql>) -> napi::Result<()> {
|
||||||
let transforms = transforms
|
let transforms = transforms
|
||||||
.into_iter()
|
.into_iter()
|
||||||
@@ -196,7 +196,7 @@ impl Table {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi(catch_unwind)]
|
||||||
pub async fn alter_columns(&self, alterations: Vec<ColumnAlteration>) -> napi::Result<()> {
|
pub async fn alter_columns(&self, alterations: Vec<ColumnAlteration>) -> napi::Result<()> {
|
||||||
for alteration in &alterations {
|
for alteration in &alterations {
|
||||||
if alteration.rename.is_none() && alteration.nullable.is_none() {
|
if alteration.rename.is_none() && alteration.nullable.is_none() {
|
||||||
@@ -222,7 +222,7 @@ impl Table {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi(catch_unwind)]
|
||||||
pub async fn drop_columns(&self, columns: Vec<String>) -> napi::Result<()> {
|
pub async fn drop_columns(&self, columns: Vec<String>) -> napi::Result<()> {
|
||||||
let col_refs = columns.iter().map(String::as_str).collect::<Vec<_>>();
|
let col_refs = columns.iter().map(String::as_str).collect::<Vec<_>>();
|
||||||
self.inner_ref()?
|
self.inner_ref()?
|
||||||
@@ -237,7 +237,7 @@ impl Table {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi(catch_unwind)]
|
||||||
pub async fn version(&self) -> napi::Result<i64> {
|
pub async fn version(&self) -> napi::Result<i64> {
|
||||||
self.inner_ref()?
|
self.inner_ref()?
|
||||||
.version()
|
.version()
|
||||||
@@ -246,7 +246,7 @@ impl Table {
|
|||||||
.default_error()
|
.default_error()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi(catch_unwind)]
|
||||||
pub async fn checkout(&self, version: i64) -> napi::Result<()> {
|
pub async fn checkout(&self, version: i64) -> napi::Result<()> {
|
||||||
self.inner_ref()?
|
self.inner_ref()?
|
||||||
.checkout(version as u64)
|
.checkout(version as u64)
|
||||||
@@ -254,17 +254,17 @@ impl Table {
|
|||||||
.default_error()
|
.default_error()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi(catch_unwind)]
|
||||||
pub async fn checkout_latest(&self) -> napi::Result<()> {
|
pub async fn checkout_latest(&self) -> napi::Result<()> {
|
||||||
self.inner_ref()?.checkout_latest().await.default_error()
|
self.inner_ref()?.checkout_latest().await.default_error()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi(catch_unwind)]
|
||||||
pub async fn restore(&self) -> napi::Result<()> {
|
pub async fn restore(&self) -> napi::Result<()> {
|
||||||
self.inner_ref()?.restore().await.default_error()
|
self.inner_ref()?.restore().await.default_error()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi(catch_unwind)]
|
||||||
pub async fn optimize(&self, older_than_ms: Option<i64>) -> napi::Result<OptimizeStats> {
|
pub async fn optimize(&self, older_than_ms: Option<i64>) -> napi::Result<OptimizeStats> {
|
||||||
let inner = self.inner_ref()?;
|
let inner = self.inner_ref()?;
|
||||||
|
|
||||||
@@ -318,7 +318,7 @@ impl Table {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi(catch_unwind)]
|
||||||
pub async fn list_indices(&self) -> napi::Result<Vec<IndexConfig>> {
|
pub async fn list_indices(&self) -> napi::Result<Vec<IndexConfig>> {
|
||||||
Ok(self
|
Ok(self
|
||||||
.inner_ref()?
|
.inner_ref()?
|
||||||
@@ -330,14 +330,14 @@ impl Table {
|
|||||||
.collect::<Vec<_>>())
|
.collect::<Vec<_>>())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi(catch_unwind)]
|
||||||
pub async fn index_stats(&self, index_name: String) -> napi::Result<Option<IndexStatistics>> {
|
pub async fn index_stats(&self, index_name: String) -> napi::Result<Option<IndexStatistics>> {
|
||||||
let tbl = self.inner_ref()?.as_native().unwrap();
|
let tbl = self.inner_ref()?.as_native().unwrap();
|
||||||
let stats = tbl.index_stats(&index_name).await.default_error()?;
|
let stats = tbl.index_stats(&index_name).await.default_error()?;
|
||||||
Ok(stats.map(IndexStatistics::from))
|
Ok(stats.map(IndexStatistics::from))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[napi]
|
#[napi(catch_unwind)]
|
||||||
pub fn merge_insert(&self, on: Vec<String>) -> napi::Result<NativeMergeInsertBuilder> {
|
pub fn merge_insert(&self, on: Vec<String>) -> napi::Result<NativeMergeInsertBuilder> {
|
||||||
let on: Vec<_> = on.iter().map(String::as_str).collect();
|
let on: Vec<_> = on.iter().map(String::as_str).collect();
|
||||||
Ok(self.inner_ref()?.merge_insert(on.as_slice()).into())
|
Ok(self.inner_ref()?.merge_insert(on.as_slice()).into())
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
[tool.bumpversion]
|
[tool.bumpversion]
|
||||||
current_version = "0.9.0-beta.8"
|
current_version = "0.9.0"
|
||||||
parse = """(?x)
|
parse = """(?x)
|
||||||
(?P<major>0|[1-9]\\d*)\\.
|
(?P<major>0|[1-9]\\d*)\\.
|
||||||
(?P<minor>0|[1-9]\\d*)\\.
|
(?P<minor>0|[1-9]\\d*)\\.
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "lancedb-python"
|
name = "lancedb-python"
|
||||||
version = "0.9.0-beta.8"
|
version = "0.9.0"
|
||||||
edition.workspace = true
|
edition.workspace = true
|
||||||
description = "Python bindings for LanceDB"
|
description = "Python bindings for LanceDB"
|
||||||
license.workspace = true
|
license.workspace = true
|
||||||
@@ -19,8 +19,6 @@ lancedb = { path = "../rust/lancedb" }
|
|||||||
env_logger = "0.10"
|
env_logger = "0.10"
|
||||||
pyo3 = { version = "0.20", features = ["extension-module", "abi3-py38"] }
|
pyo3 = { version = "0.20", features = ["extension-module", "abi3-py38"] }
|
||||||
pyo3-asyncio = { version = "0.20", features = ["attributes", "tokio-runtime"] }
|
pyo3-asyncio = { version = "0.20", features = ["attributes", "tokio-runtime"] }
|
||||||
base64ct = "=1.6.0" # workaround for https://github.com/RustCrypto/formats/issues/1684
|
|
||||||
chrono = "=0.4.39"
|
|
||||||
|
|
||||||
# Prevent dynamic linking of lzma, which comes from datafusion
|
# Prevent dynamic linking of lzma, which comes from datafusion
|
||||||
lzma-sys = { version = "*", features = ["static"] }
|
lzma-sys = { version = "*", features = ["static"] }
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ dependencies = [
|
|||||||
"packaging",
|
"packaging",
|
||||||
"cachetools",
|
"cachetools",
|
||||||
"overrides>=0.7",
|
"overrides>=0.7",
|
||||||
"urllib3==1.26.19"
|
|
||||||
]
|
]
|
||||||
description = "lancedb"
|
description = "lancedb"
|
||||||
authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]
|
authors = [{ name = "LanceDB Devs", email = "dev@lancedb.com" }]
|
||||||
|
|||||||
@@ -35,7 +35,6 @@ def connect(
|
|||||||
host_override: Optional[str] = None,
|
host_override: Optional[str] = None,
|
||||||
read_consistency_interval: Optional[timedelta] = None,
|
read_consistency_interval: Optional[timedelta] = None,
|
||||||
request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
|
request_thread_pool: Optional[Union[int, ThreadPoolExecutor]] = None,
|
||||||
storage_options: Optional[Dict[str, str]] = None,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> DBConnection:
|
) -> DBConnection:
|
||||||
"""Connect to a LanceDB database.
|
"""Connect to a LanceDB database.
|
||||||
@@ -71,9 +70,6 @@ def connect(
|
|||||||
executor will be used for making requests. This is for LanceDB Cloud
|
executor will be used for making requests. This is for LanceDB Cloud
|
||||||
only and is only used when making batch requests (i.e., passing in
|
only and is only used when making batch requests (i.e., passing in
|
||||||
multiple queries to the search method at once).
|
multiple queries to the search method at once).
|
||||||
storage_options: dict, optional
|
|
||||||
Additional options for the storage backend. See available options at
|
|
||||||
https://lancedb.github.io/lancedb/guides/storage/
|
|
||||||
|
|
||||||
Examples
|
Examples
|
||||||
--------
|
--------
|
||||||
@@ -109,16 +105,12 @@ def connect(
|
|||||||
region,
|
region,
|
||||||
host_override,
|
host_override,
|
||||||
request_thread_pool=request_thread_pool,
|
request_thread_pool=request_thread_pool,
|
||||||
storage_options=storage_options,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if kwargs:
|
if kwargs:
|
||||||
raise ValueError(f"Unknown keyword arguments: {kwargs}")
|
raise ValueError(f"Unknown keyword arguments: {kwargs}")
|
||||||
return LanceDBConnection(
|
return LanceDBConnection(uri, read_consistency_interval=read_consistency_interval)
|
||||||
uri,
|
|
||||||
read_consistency_interval=read_consistency_interval,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def connect_async(
|
async def connect_async(
|
||||||
|
|||||||
@@ -28,12 +28,11 @@ from lancedb.common import data_to_reader, validate_schema
|
|||||||
|
|
||||||
from ._lancedb import connect as lancedb_connect
|
from ._lancedb import connect as lancedb_connect
|
||||||
from .pydantic import LanceModel
|
from .pydantic import LanceModel
|
||||||
from .table import AsyncTable, LanceTable, Table, _sanitize_data
|
from .table import AsyncTable, LanceTable, Table, _sanitize_data, _table_path
|
||||||
from .util import (
|
from .util import (
|
||||||
fs_from_uri,
|
fs_from_uri,
|
||||||
get_uri_location,
|
get_uri_location,
|
||||||
get_uri_scheme,
|
get_uri_scheme,
|
||||||
join_uri,
|
|
||||||
validate_table_name,
|
validate_table_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -457,16 +456,18 @@ class LanceDBConnection(DBConnection):
|
|||||||
If True, ignore if the table does not exist.
|
If True, ignore if the table does not exist.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
filesystem, path = fs_from_uri(self.uri)
|
table_uri = _table_path(self.uri, name)
|
||||||
table_path = join_uri(path, name + ".lance")
|
filesystem, path = fs_from_uri(table_uri)
|
||||||
filesystem.delete_dir(table_path)
|
filesystem.delete_dir(path)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
if not ignore_missing:
|
if not ignore_missing:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def drop_database(self):
|
def drop_database(self):
|
||||||
filesystem, path = fs_from_uri(self.uri)
|
dummy_table_uri = _table_path(self.uri, "dummy")
|
||||||
|
uri = dummy_table_uri.removesuffix("dummy.lance")
|
||||||
|
filesystem, path = fs_from_uri(uri)
|
||||||
filesystem.delete_dir(path)
|
filesystem.delete_dir(path)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -117,8 +117,6 @@ class Query(pydantic.BaseModel):
|
|||||||
|
|
||||||
with_row_id: bool = False
|
with_row_id: bool = False
|
||||||
|
|
||||||
fast_search: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
class LanceQueryBuilder(ABC):
|
class LanceQueryBuilder(ABC):
|
||||||
"""An abstract query builder. Subclasses are defined for vector search,
|
"""An abstract query builder. Subclasses are defined for vector search,
|
||||||
@@ -127,14 +125,12 @@ class LanceQueryBuilder(ABC):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(
|
def create(
|
||||||
cls,
|
cls,
|
||||||
table: "Table",
|
table: "Table",
|
||||||
query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]],
|
query: Optional[Union[np.ndarray, str, "PIL.Image.Image", Tuple]],
|
||||||
query_type: str,
|
query_type: str,
|
||||||
vector_column_name: str,
|
vector_column_name: str,
|
||||||
ordering_field_name: Optional[str] = None,
|
ordering_field_name: str = None,
|
||||||
fts_columns: Union[str, List[str]] = [],
|
|
||||||
fast_search: bool = False,
|
|
||||||
) -> LanceQueryBuilder:
|
) -> LanceQueryBuilder:
|
||||||
"""
|
"""
|
||||||
Create a query builder based on the given query and query type.
|
Create a query builder based on the given query and query type.
|
||||||
@@ -151,19 +147,14 @@ class LanceQueryBuilder(ABC):
|
|||||||
If "auto", the query type is inferred based on the query.
|
If "auto", the query type is inferred based on the query.
|
||||||
vector_column_name: str
|
vector_column_name: str
|
||||||
The name of the vector column to use for vector search.
|
The name of the vector column to use for vector search.
|
||||||
fast_search: bool
|
|
||||||
Skip flat search of unindexed data.
|
|
||||||
"""
|
"""
|
||||||
# Check hybrid search first as it supports empty query pattern
|
|
||||||
if query_type == "hybrid":
|
|
||||||
# hybrid fts and vector query
|
|
||||||
return LanceHybridQueryBuilder(
|
|
||||||
table, query, vector_column_name, fts_columns=fts_columns
|
|
||||||
)
|
|
||||||
|
|
||||||
if query is None:
|
if query is None:
|
||||||
return LanceEmptyQueryBuilder(table)
|
return LanceEmptyQueryBuilder(table)
|
||||||
|
|
||||||
|
if query_type == "hybrid":
|
||||||
|
# hybrid fts and vector query
|
||||||
|
return LanceHybridQueryBuilder(table, query, vector_column_name)
|
||||||
|
|
||||||
# remember the string query for reranking purpose
|
# remember the string query for reranking purpose
|
||||||
str_query = query if isinstance(query, str) else None
|
str_query = query if isinstance(query, str) else None
|
||||||
|
|
||||||
@@ -174,17 +165,12 @@ class LanceQueryBuilder(ABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if query_type == "hybrid":
|
if query_type == "hybrid":
|
||||||
return LanceHybridQueryBuilder(
|
return LanceHybridQueryBuilder(table, query, vector_column_name)
|
||||||
table, query, vector_column_name, fts_columns=fts_columns
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(query, str):
|
if isinstance(query, str):
|
||||||
# fts
|
# fts
|
||||||
return LanceFtsQueryBuilder(
|
return LanceFtsQueryBuilder(
|
||||||
table,
|
table, query, ordering_field_name=ordering_field_name
|
||||||
query,
|
|
||||||
ordering_field_name=ordering_field_name,
|
|
||||||
fts_columns=fts_columns,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(query, list):
|
if isinstance(query, list):
|
||||||
@@ -194,9 +180,7 @@ class LanceQueryBuilder(ABC):
|
|||||||
else:
|
else:
|
||||||
raise TypeError(f"Unsupported query type: {type(query)}")
|
raise TypeError(f"Unsupported query type: {type(query)}")
|
||||||
|
|
||||||
return LanceVectorQueryBuilder(
|
return LanceVectorQueryBuilder(table, query, vector_column_name, str_query)
|
||||||
table, query, vector_column_name, str_query, fast_search
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _resolve_query(cls, table, query, query_type, vector_column_name):
|
def _resolve_query(cls, table, query, query_type, vector_column_name):
|
||||||
@@ -212,6 +196,8 @@ class LanceQueryBuilder(ABC):
|
|||||||
elif query_type == "auto":
|
elif query_type == "auto":
|
||||||
if isinstance(query, (list, np.ndarray)):
|
if isinstance(query, (list, np.ndarray)):
|
||||||
return query, "vector"
|
return query, "vector"
|
||||||
|
if isinstance(query, tuple):
|
||||||
|
return query, "hybrid"
|
||||||
else:
|
else:
|
||||||
conf = table.embedding_functions.get(vector_column_name)
|
conf = table.embedding_functions.get(vector_column_name)
|
||||||
if conf is not None:
|
if conf is not None:
|
||||||
@@ -238,14 +224,9 @@ class LanceQueryBuilder(ABC):
|
|||||||
def __init__(self, table: "Table"):
|
def __init__(self, table: "Table"):
|
||||||
self._table = table
|
self._table = table
|
||||||
self._limit = 10
|
self._limit = 10
|
||||||
self._offset = 0
|
|
||||||
self._columns = None
|
self._columns = None
|
||||||
self._where = None
|
self._where = None
|
||||||
self._prefilter = False
|
|
||||||
self._with_row_id = False
|
self._with_row_id = False
|
||||||
self._vector = None
|
|
||||||
self._text = None
|
|
||||||
self._ef = None
|
|
||||||
|
|
||||||
@deprecation.deprecated(
|
@deprecation.deprecated(
|
||||||
deprecated_in="0.3.1",
|
deprecated_in="0.3.1",
|
||||||
@@ -356,13 +337,11 @@ class LanceQueryBuilder(ABC):
|
|||||||
----------
|
----------
|
||||||
limit: int
|
limit: int
|
||||||
The maximum number of results to return.
|
The maximum number of results to return.
|
||||||
The default query limit is 10 results.
|
By default the query is limited to the first 10.
|
||||||
For ANN/KNN queries, you must specify a limit.
|
Call this method and pass 0, a negative value,
|
||||||
Entering 0, a negative number, or None will reset
|
or None to remove the limit.
|
||||||
the limit to the default value of 10.
|
*WARNING* if you have a large dataset, removing
|
||||||
*WARNING* if you have a large dataset, setting
|
the limit can potentially result in reading a
|
||||||
the limit to a large number, e.g. the table size,
|
|
||||||
can potentially result in reading a
|
|
||||||
large amount of data into memory and cause
|
large amount of data into memory and cause
|
||||||
out of memory issues.
|
out of memory issues.
|
||||||
|
|
||||||
@@ -372,33 +351,11 @@ class LanceQueryBuilder(ABC):
|
|||||||
The LanceQueryBuilder object.
|
The LanceQueryBuilder object.
|
||||||
"""
|
"""
|
||||||
if limit is None or limit <= 0:
|
if limit is None or limit <= 0:
|
||||||
if isinstance(self, LanceVectorQueryBuilder):
|
self._limit = None
|
||||||
raise ValueError("Limit is required for ANN/KNN queries")
|
|
||||||
else:
|
|
||||||
self._limit = None
|
|
||||||
else:
|
else:
|
||||||
self._limit = limit
|
self._limit = limit
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def offset(self, offset: int) -> LanceQueryBuilder:
|
|
||||||
"""Set the offset for the results.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
offset: int
|
|
||||||
The offset to start fetching results from.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
LanceQueryBuilder
|
|
||||||
The LanceQueryBuilder object.
|
|
||||||
"""
|
|
||||||
if offset is None or offset <= 0:
|
|
||||||
self._offset = 0
|
|
||||||
else:
|
|
||||||
self._offset = offset
|
|
||||||
return self
|
|
||||||
|
|
||||||
def select(self, columns: Union[list[str], dict[str, str]]) -> LanceQueryBuilder:
|
def select(self, columns: Union[list[str], dict[str, str]]) -> LanceQueryBuilder:
|
||||||
"""Set the columns to return.
|
"""Set the columns to return.
|
||||||
|
|
||||||
@@ -460,80 +417,6 @@ class LanceQueryBuilder(ABC):
|
|||||||
self._with_row_id = with_row_id
|
self._with_row_id = with_row_id
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def explain_plan(self, verbose: Optional[bool] = False) -> str:
|
|
||||||
"""Return the execution plan for this query.
|
|
||||||
|
|
||||||
Examples
|
|
||||||
--------
|
|
||||||
>>> import lancedb
|
|
||||||
>>> db = lancedb.connect("./.lancedb")
|
|
||||||
>>> table = db.create_table("my_table", [{"vector": [99, 99]}])
|
|
||||||
>>> query = [100, 100]
|
|
||||||
>>> plan = table.search(query).explain_plan(True)
|
|
||||||
>>> print(plan) # doctest: +ELLIPSIS, +NORMALIZE_WHITESPACE
|
|
||||||
ProjectionExec: expr=[vector@0 as vector, _distance@2 as _distance]
|
|
||||||
GlobalLimitExec: skip=0, fetch=10
|
|
||||||
FilterExec: _distance@2 IS NOT NULL
|
|
||||||
SortExec: TopK(fetch=10), expr=[_distance@2 ASC NULLS LAST], preserve_partitioning=[false]
|
|
||||||
KNNVectorDistance: metric=l2
|
|
||||||
LanceScan: uri=..., projection=[vector], row_id=true, row_addr=false, ordered=false
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
verbose : bool, default False
|
|
||||||
Use a verbose output format.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
plan : str
|
|
||||||
""" # noqa: E501
|
|
||||||
ds = self._table.to_lance()
|
|
||||||
return ds.scanner(
|
|
||||||
nearest={
|
|
||||||
"column": self._vector_column,
|
|
||||||
"q": self._query,
|
|
||||||
"k": self._limit,
|
|
||||||
"metric": self._metric,
|
|
||||||
"nprobes": self._nprobes,
|
|
||||||
"refine_factor": self._refine_factor,
|
|
||||||
},
|
|
||||||
prefilter=self._prefilter,
|
|
||||||
filter=self._str_query,
|
|
||||||
limit=self._limit,
|
|
||||||
with_row_id=self._with_row_id,
|
|
||||||
offset=self._offset,
|
|
||||||
).explain_plan(verbose)
|
|
||||||
|
|
||||||
def vector(self, vector: Union[np.ndarray, list]) -> LanceQueryBuilder:
|
|
||||||
"""Set the vector to search for.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
vector: np.ndarray or list
|
|
||||||
The vector to search for.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
LanceQueryBuilder
|
|
||||||
The LanceQueryBuilder object.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def text(self, text: str) -> LanceQueryBuilder:
|
|
||||||
"""Set the text to search for.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
text: str
|
|
||||||
The text to search for.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
LanceQueryBuilder
|
|
||||||
The LanceQueryBuilder object.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class LanceVectorQueryBuilder(LanceQueryBuilder):
|
class LanceVectorQueryBuilder(LanceQueryBuilder):
|
||||||
"""
|
"""
|
||||||
@@ -557,12 +440,11 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
table: "Table",
|
table: "Table",
|
||||||
query: Union[np.ndarray, list, "PIL.Image.Image"],
|
query: Union[np.ndarray, list, "PIL.Image.Image"],
|
||||||
vector_column: str,
|
vector_column: str,
|
||||||
str_query: Optional[str] = None,
|
str_query: Optional[str] = None,
|
||||||
fast_search: bool = False,
|
|
||||||
):
|
):
|
||||||
super().__init__(table)
|
super().__init__(table)
|
||||||
self._query = query
|
self._query = query
|
||||||
@@ -573,14 +455,13 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
|||||||
self._prefilter = False
|
self._prefilter = False
|
||||||
self._reranker = None
|
self._reranker = None
|
||||||
self._str_query = str_query
|
self._str_query = str_query
|
||||||
self._fast_search = fast_search
|
|
||||||
|
|
||||||
def metric(self, metric: Literal["L2", "cosine", "dot"]) -> LanceVectorQueryBuilder:
|
def metric(self, metric: Literal["L2", "cosine"]) -> LanceVectorQueryBuilder:
|
||||||
"""Set the distance metric to use.
|
"""Set the distance metric to use.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
metric: "L2" or "cosine" or "dot"
|
metric: "L2" or "cosine"
|
||||||
The distance metric to use. By default "L2" is used.
|
The distance metric to use. By default "L2" is used.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
@@ -588,7 +469,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
|||||||
LanceVectorQueryBuilder
|
LanceVectorQueryBuilder
|
||||||
The LanceQueryBuilder object.
|
The LanceQueryBuilder object.
|
||||||
"""
|
"""
|
||||||
self._metric = metric.lower()
|
self._metric = metric
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def nprobes(self, nprobes: int) -> LanceVectorQueryBuilder:
|
def nprobes(self, nprobes: int) -> LanceVectorQueryBuilder:
|
||||||
@@ -613,28 +494,6 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
|||||||
self._nprobes = nprobes
|
self._nprobes = nprobes
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def ef(self, ef: int) -> LanceVectorQueryBuilder:
|
|
||||||
"""Set the number of candidates to consider during search.
|
|
||||||
|
|
||||||
Higher values will yield better recall (more likely to find vectors if
|
|
||||||
they exist) at the expense of latency.
|
|
||||||
|
|
||||||
This only applies to the HNSW-related index.
|
|
||||||
The default value is 1.5 * limit.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
ef: int
|
|
||||||
The number of candidates to consider during search.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
LanceVectorQueryBuilder
|
|
||||||
The LanceQueryBuilder object.
|
|
||||||
"""
|
|
||||||
self._ef = ef
|
|
||||||
return self
|
|
||||||
|
|
||||||
def refine_factor(self, refine_factor: int) -> LanceVectorQueryBuilder:
|
def refine_factor(self, refine_factor: int) -> LanceVectorQueryBuilder:
|
||||||
"""Set the refine factor to use, increasing the number of vectors sampled.
|
"""Set the refine factor to use, increasing the number of vectors sampled.
|
||||||
|
|
||||||
@@ -695,11 +554,15 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
|||||||
refine_factor=self._refine_factor,
|
refine_factor=self._refine_factor,
|
||||||
vector_column=self._vector_column,
|
vector_column=self._vector_column,
|
||||||
with_row_id=self._with_row_id,
|
with_row_id=self._with_row_id,
|
||||||
offset=self._offset,
|
|
||||||
fast_search=self._fast_search,
|
|
||||||
ef=self._ef,
|
|
||||||
)
|
)
|
||||||
result_set = self._table._execute_query(query, batch_size)
|
result_set = self._table._execute_query(query, batch_size)
|
||||||
|
if self._reranker is not None:
|
||||||
|
rs_table = result_set.read_all()
|
||||||
|
result_set = self._reranker.rerank_vector(self._str_query, rs_table)
|
||||||
|
# convert result_set back to RecordBatchReader
|
||||||
|
result_set = pa.RecordBatchReader.from_batches(
|
||||||
|
result_set.schema, result_set.to_batches()
|
||||||
|
)
|
||||||
|
|
||||||
return result_set
|
return result_set
|
||||||
|
|
||||||
@@ -728,7 +591,7 @@ class LanceVectorQueryBuilder(LanceQueryBuilder):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
def rerank(
|
def rerank(
|
||||||
self, reranker: Reranker, query_string: Optional[str] = None
|
self, reranker: Reranker, query_string: Optional[str] = None
|
||||||
) -> LanceVectorQueryBuilder:
|
) -> LanceVectorQueryBuilder:
|
||||||
"""Rerank the results using the specified reranker.
|
"""Rerank the results using the specified reranker.
|
||||||
|
|
||||||
@@ -893,34 +756,12 @@ class LanceFtsQueryBuilder(LanceQueryBuilder):
|
|||||||
|
|
||||||
class LanceEmptyQueryBuilder(LanceQueryBuilder):
|
class LanceEmptyQueryBuilder(LanceQueryBuilder):
|
||||||
def to_arrow(self) -> pa.Table:
|
def to_arrow(self) -> pa.Table:
|
||||||
return self.to_batches().read_all()
|
ds = self._table.to_lance()
|
||||||
|
return ds.to_table(
|
||||||
def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReader:
|
|
||||||
query = Query(
|
|
||||||
columns=self._columns,
|
columns=self._columns,
|
||||||
filter=self._where,
|
filter=self._where,
|
||||||
k=self._limit or 10,
|
limit=self._limit,
|
||||||
with_row_id=self._with_row_id,
|
|
||||||
vector=[],
|
|
||||||
# not actually respected in remote query
|
|
||||||
offset=self._offset or 0,
|
|
||||||
)
|
)
|
||||||
return self._table._execute_query(query)
|
|
||||||
|
|
||||||
def rerank(self, reranker: Reranker) -> LanceEmptyQueryBuilder:
|
|
||||||
"""Rerank the results using the specified reranker.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
reranker: Reranker
|
|
||||||
The reranker to use.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
LanceEmptyQueryBuilder
|
|
||||||
The LanceQueryBuilder object.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError("Reranking is not yet supported.")
|
|
||||||
|
|
||||||
|
|
||||||
class LanceHybridQueryBuilder(LanceQueryBuilder):
|
class LanceHybridQueryBuilder(LanceQueryBuilder):
|
||||||
|
|||||||
@@ -55,13 +55,11 @@ class RestfulLanceDBClient:
|
|||||||
region: str
|
region: str
|
||||||
api_key: Credential
|
api_key: Credential
|
||||||
host_override: Optional[str] = attrs.field(default=None)
|
host_override: Optional[str] = attrs.field(default=None)
|
||||||
db_prefix: Optional[str] = attrs.field(default=None)
|
|
||||||
|
|
||||||
closed: bool = attrs.field(default=False, init=False)
|
closed: bool = attrs.field(default=False, init=False)
|
||||||
|
|
||||||
connection_timeout: float = attrs.field(default=120.0, kw_only=True)
|
connection_timeout: float = attrs.field(default=120.0, kw_only=True)
|
||||||
read_timeout: float = attrs.field(default=300.0, kw_only=True)
|
read_timeout: float = attrs.field(default=300.0, kw_only=True)
|
||||||
storage_options: Optional[Dict[str, str]] = attrs.field(default=None, kw_only=True)
|
|
||||||
|
|
||||||
@functools.cached_property
|
@functools.cached_property
|
||||||
def session(self) -> requests.Session:
|
def session(self) -> requests.Session:
|
||||||
@@ -94,18 +92,6 @@ class RestfulLanceDBClient:
|
|||||||
headers["Host"] = f"{self.db_name}.{self.region}.api.lancedb.com"
|
headers["Host"] = f"{self.db_name}.{self.region}.api.lancedb.com"
|
||||||
if self.host_override:
|
if self.host_override:
|
||||||
headers["x-lancedb-database"] = self.db_name
|
headers["x-lancedb-database"] = self.db_name
|
||||||
if self.storage_options:
|
|
||||||
if self.storage_options.get("account_name") is not None:
|
|
||||||
headers["x-azure-storage-account-name"] = self.storage_options[
|
|
||||||
"account_name"
|
|
||||||
]
|
|
||||||
if self.storage_options.get("azure_storage_account_name") is not None:
|
|
||||||
headers["x-azure-storage-account-name"] = self.storage_options[
|
|
||||||
"azure_storage_account_name"
|
|
||||||
]
|
|
||||||
if self.db_prefix:
|
|
||||||
headers["x-lancedb-database-prefix"] = self.db_prefix
|
|
||||||
|
|
||||||
return headers
|
return headers
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -172,7 +158,6 @@ class RestfulLanceDBClient:
|
|||||||
headers["content-type"] = content_type
|
headers["content-type"] = content_type
|
||||||
if request_id is not None:
|
if request_id is not None:
|
||||||
headers["x-request-id"] = request_id
|
headers["x-request-id"] = request_id
|
||||||
|
|
||||||
with self.session.post(
|
with self.session.post(
|
||||||
urljoin(self.url, uri),
|
urljoin(self.url, uri),
|
||||||
headers=headers,
|
headers=headers,
|
||||||
@@ -260,6 +245,7 @@ def retry_adapter(options: Dict[str, Any]) -> HTTPAdapter:
|
|||||||
connect=connect_retries,
|
connect=connect_retries,
|
||||||
read=read_retries,
|
read=read_retries,
|
||||||
backoff_factor=backoff_factor,
|
backoff_factor=backoff_factor,
|
||||||
|
backoff_jitter=backoff_jitter,
|
||||||
status_forcelist=statuses,
|
status_forcelist=statuses,
|
||||||
allowed_methods=methods,
|
allowed_methods=methods,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import inspect
|
|||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from typing import Dict, Iterable, List, Optional, Union
|
from typing import Iterable, List, Optional, Union
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from cachetools import TTLCache
|
from cachetools import TTLCache
|
||||||
@@ -44,25 +44,20 @@ class RemoteDBConnection(DBConnection):
|
|||||||
request_thread_pool: Optional[ThreadPoolExecutor] = None,
|
request_thread_pool: Optional[ThreadPoolExecutor] = None,
|
||||||
connection_timeout: float = 120.0,
|
connection_timeout: float = 120.0,
|
||||||
read_timeout: float = 300.0,
|
read_timeout: float = 300.0,
|
||||||
storage_options: Optional[Dict[str, str]] = None,
|
|
||||||
):
|
):
|
||||||
"""Connect to a remote LanceDB database."""
|
"""Connect to a remote LanceDB database."""
|
||||||
parsed = urlparse(db_url)
|
parsed = urlparse(db_url)
|
||||||
if parsed.scheme != "db":
|
if parsed.scheme != "db":
|
||||||
raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://")
|
raise ValueError(f"Invalid scheme: {parsed.scheme}, only accepts db://")
|
||||||
self.db_name = parsed.netloc
|
self.db_name = parsed.netloc
|
||||||
prefix = parsed.path.lstrip("/")
|
|
||||||
self.db_prefix = None if not prefix else prefix
|
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self._client = RestfulLanceDBClient(
|
self._client = RestfulLanceDBClient(
|
||||||
self.db_name,
|
self.db_name,
|
||||||
region,
|
region,
|
||||||
api_key,
|
api_key,
|
||||||
host_override,
|
host_override,
|
||||||
self.db_prefix,
|
|
||||||
connection_timeout=connection_timeout,
|
connection_timeout=connection_timeout,
|
||||||
read_timeout=read_timeout,
|
read_timeout=read_timeout,
|
||||||
storage_options=storage_options,
|
|
||||||
)
|
)
|
||||||
self._request_thread_pool = request_thread_pool
|
self._request_thread_pool = request_thread_pool
|
||||||
self._table_cache = TTLCache(maxsize=10000, ttl=300)
|
self._table_cache = TTLCache(maxsize=10000, ttl=300)
|
||||||
|
|||||||
@@ -15,14 +15,13 @@ import logging
|
|||||||
import uuid
|
import uuid
|
||||||
from concurrent.futures import Future
|
from concurrent.futures import Future
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Dict, Iterable, Optional, Union, Literal
|
from typing import Dict, Iterable, Optional, Union
|
||||||
|
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
from lance import json_to_schema
|
from lance import json_to_schema
|
||||||
|
|
||||||
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
|
from lancedb.common import DATA, VEC, VECTOR_COLUMN_NAME
|
||||||
from lancedb.merge import LanceMergeInsertBuilder
|
from lancedb.merge import LanceMergeInsertBuilder
|
||||||
from lancedb.query import LanceQueryBuilder
|
|
||||||
|
|
||||||
from ..query import LanceVectorQueryBuilder
|
from ..query import LanceVectorQueryBuilder
|
||||||
from ..table import Query, Table, _sanitize_data
|
from ..table import Query, Table, _sanitize_data
|
||||||
@@ -82,7 +81,6 @@ class RemoteTable(Table):
|
|||||||
def create_scalar_index(
|
def create_scalar_index(
|
||||||
self,
|
self,
|
||||||
column: str,
|
column: str,
|
||||||
index_type: Literal["BTREE", "BITMAP", "LABEL_LIST", "scalar"] = "scalar",
|
|
||||||
):
|
):
|
||||||
"""Creates a scalar index
|
"""Creates a scalar index
|
||||||
Parameters
|
Parameters
|
||||||
@@ -91,6 +89,8 @@ class RemoteTable(Table):
|
|||||||
The column to be indexed. Must be a boolean, integer, float,
|
The column to be indexed. Must be a boolean, integer, float,
|
||||||
or string column.
|
or string column.
|
||||||
"""
|
"""
|
||||||
|
index_type = "scalar"
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"column": column,
|
"column": column,
|
||||||
"index_type": index_type,
|
"index_type": index_type,
|
||||||
@@ -228,21 +228,10 @@ class RemoteTable(Table):
|
|||||||
content_type=ARROW_STREAM_CONTENT_TYPE,
|
content_type=ARROW_STREAM_CONTENT_TYPE,
|
||||||
)
|
)
|
||||||
|
|
||||||
def query(
|
|
||||||
self,
|
|
||||||
query: Union[VEC, str] = None,
|
|
||||||
query_type: str = "vector",
|
|
||||||
vector_column_name: Optional[str] = None,
|
|
||||||
fast_search: bool = False,
|
|
||||||
) -> LanceVectorQueryBuilder:
|
|
||||||
return self.search(query, query_type, vector_column_name, fast_search)
|
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
query: Union[VEC, str] = None,
|
query: Union[VEC, str],
|
||||||
query_type: str = "vector",
|
|
||||||
vector_column_name: Optional[str] = None,
|
vector_column_name: Optional[str] = None,
|
||||||
fast_search: bool = False,
|
|
||||||
) -> LanceVectorQueryBuilder:
|
) -> LanceVectorQueryBuilder:
|
||||||
"""Create a search query to find the nearest neighbors
|
"""Create a search query to find the nearest neighbors
|
||||||
of the given query vector. We currently support [vector search][search]
|
of the given query vector. We currently support [vector search][search]
|
||||||
@@ -289,11 +278,6 @@ class RemoteTable(Table):
|
|||||||
- If the table has multiple vector columns then the *vector_column_name*
|
- If the table has multiple vector columns then the *vector_column_name*
|
||||||
needs to be specified. Otherwise, an error is raised.
|
needs to be specified. Otherwise, an error is raised.
|
||||||
|
|
||||||
fast_search: bool, optional
|
|
||||||
Skip a flat search of unindexed data. This may improve
|
|
||||||
search performance but search results will not include unindexed data.
|
|
||||||
|
|
||||||
- *default False*.
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
LanceQueryBuilder
|
LanceQueryBuilder
|
||||||
@@ -309,14 +293,7 @@ class RemoteTable(Table):
|
|||||||
"""
|
"""
|
||||||
if vector_column_name is None:
|
if vector_column_name is None:
|
||||||
vector_column_name = inf_vector_column_query(self.schema)
|
vector_column_name = inf_vector_column_query(self.schema)
|
||||||
|
return LanceVectorQueryBuilder(self, query, vector_column_name)
|
||||||
return LanceQueryBuilder.create(
|
|
||||||
self,
|
|
||||||
query,
|
|
||||||
query_type,
|
|
||||||
vector_column_name=vector_column_name,
|
|
||||||
fast_search=fast_search,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _execute_query(
|
def _execute_query(
|
||||||
self, query: Query, batch_size: Optional[int] = None
|
self, query: Query, batch_size: Optional[int] = None
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from .colbert import ColbertReranker
|
|||||||
from .cross_encoder import CrossEncoderReranker
|
from .cross_encoder import CrossEncoderReranker
|
||||||
from .linear_combination import LinearCombinationReranker
|
from .linear_combination import LinearCombinationReranker
|
||||||
from .openai import OpenaiReranker
|
from .openai import OpenaiReranker
|
||||||
|
from .jina import JinaReranker
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Reranker",
|
"Reranker",
|
||||||
@@ -12,4 +13,5 @@ __all__ = [
|
|||||||
"LinearCombinationReranker",
|
"LinearCombinationReranker",
|
||||||
"OpenaiReranker",
|
"OpenaiReranker",
|
||||||
"ColbertReranker",
|
"ColbertReranker",
|
||||||
|
"JinaReranker",
|
||||||
]
|
]
|
||||||
|
|||||||
103
python/python/lancedb/rerankers/jina.py
Normal file
103
python/python/lancedb/rerankers/jina.py
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
from functools import cached_property
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import pyarrow as pa
|
||||||
|
|
||||||
|
from ..util import attempt_import_or_raise
|
||||||
|
from .base import Reranker
|
||||||
|
|
||||||
|
|
||||||
|
class JinaReranker(Reranker):
|
||||||
|
"""
|
||||||
|
Reranks the results using Jina reranker model.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
model_name : str, default "jinaai/jina-reranker-v1-turbo-en"
|
||||||
|
The name of the reranker to use. For all models, see
|
||||||
|
https://huggingface.co/jinaai/jina-reranker-v1-turbo-en
|
||||||
|
column : str, default "text"
|
||||||
|
The name of the column to use as input to the cross encoder model.
|
||||||
|
device : str, default None
|
||||||
|
The device to use for the cross encoder model. If None, will use "cuda"
|
||||||
|
if available, otherwise "cpu".
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str = "jinaai/jina-reranker-v1-turbo-en",
|
||||||
|
column: str = "text",
|
||||||
|
device: Union[str, None] = None,
|
||||||
|
return_score="relevance",
|
||||||
|
):
|
||||||
|
super().__init__(return_score)
|
||||||
|
torch = attempt_import_or_raise("torch")
|
||||||
|
self.model_name = model_name
|
||||||
|
self.column = column
|
||||||
|
self.device = device
|
||||||
|
if self.device is None:
|
||||||
|
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def model(self):
|
||||||
|
transformers = attempt_import_or_raise("transformers")
|
||||||
|
model = transformers.AutoModelForSequenceClassification.from_pretrained(
|
||||||
|
self.model_name, num_labels=1, trust_remote_code=True
|
||||||
|
)
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def _rerank(self, result_set: pa.Table, query: str):
|
||||||
|
passages = result_set[self.column].to_pylist()
|
||||||
|
cross_inp = [[query, passage] for passage in passages]
|
||||||
|
cross_scores = self.model.compute_score(cross_inp)
|
||||||
|
result_set = result_set.append_column(
|
||||||
|
"_relevance_score", pa.array(cross_scores, type=pa.float32())
|
||||||
|
)
|
||||||
|
|
||||||
|
return result_set
|
||||||
|
|
||||||
|
def rerank_hybrid(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
vector_results: pa.Table,
|
||||||
|
fts_results: pa.Table,
|
||||||
|
):
|
||||||
|
combined_results = self.merge_results(vector_results, fts_results)
|
||||||
|
combined_results = self._rerank(combined_results, query)
|
||||||
|
# sort the results by _score
|
||||||
|
if self.score == "relevance":
|
||||||
|
combined_results = combined_results.drop_columns(["score", "_distance"])
|
||||||
|
elif self.score == "all":
|
||||||
|
raise NotImplementedError(
|
||||||
|
"return_score='all' not implemented for CrossEncoderReranker"
|
||||||
|
)
|
||||||
|
combined_results = combined_results.sort_by(
|
||||||
|
[("_relevance_score", "descending")]
|
||||||
|
)
|
||||||
|
|
||||||
|
return combined_results
|
||||||
|
|
||||||
|
def rerank_vector(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
vector_results: pa.Table,
|
||||||
|
):
|
||||||
|
vector_results = self._rerank(vector_results, query)
|
||||||
|
if self.score == "relevance":
|
||||||
|
vector_results = vector_results.drop_columns(["_distance"])
|
||||||
|
|
||||||
|
vector_results = vector_results.sort_by([("_relevance_score", "descending")])
|
||||||
|
return vector_results
|
||||||
|
|
||||||
|
def rerank_fts(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
fts_results: pa.Table,
|
||||||
|
):
|
||||||
|
fts_results = self._rerank(fts_results, query)
|
||||||
|
if self.score == "relevance":
|
||||||
|
fts_results = fts_results.drop_columns(["score"])
|
||||||
|
|
||||||
|
fts_results = fts_results.sort_by([("_relevance_score", "descending")])
|
||||||
|
return fts_results
|
||||||
@@ -30,6 +30,7 @@ from typing import (
|
|||||||
Tuple,
|
Tuple,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import lance
|
import lance
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -47,6 +48,7 @@ from .pydantic import LanceModel, model_to_dict
|
|||||||
from .query import AsyncQuery, AsyncVectorQuery, LanceQueryBuilder, Query
|
from .query import AsyncQuery, AsyncVectorQuery, LanceQueryBuilder, Query
|
||||||
from .util import (
|
from .util import (
|
||||||
fs_from_uri,
|
fs_from_uri,
|
||||||
|
get_uri_scheme,
|
||||||
inf_vector_column_query,
|
inf_vector_column_query,
|
||||||
join_uri,
|
join_uri,
|
||||||
safe_import_pandas,
|
safe_import_pandas,
|
||||||
@@ -208,6 +210,26 @@ def _to_record_batch_generator(
|
|||||||
yield b
|
yield b
|
||||||
|
|
||||||
|
|
||||||
|
def _table_path(base: str, table_name: str) -> str:
|
||||||
|
"""
|
||||||
|
Get a table path that can be used in PyArrow FS.
|
||||||
|
|
||||||
|
Removes any weird schemes (such as "s3+ddb") and drops any query params.
|
||||||
|
"""
|
||||||
|
uri = _table_uri(base, table_name)
|
||||||
|
# Parse as URL
|
||||||
|
parsed = urlparse(uri)
|
||||||
|
# If scheme is s3+ddb, convert to s3
|
||||||
|
if parsed.scheme == "s3+ddb":
|
||||||
|
parsed = parsed._replace(scheme="s3")
|
||||||
|
# Remove query parameters
|
||||||
|
return parsed._replace(query=None).geturl()
|
||||||
|
|
||||||
|
|
||||||
|
def _table_uri(base: str, table_name: str) -> str:
|
||||||
|
return join_uri(base, f"{table_name}.lance")
|
||||||
|
|
||||||
|
|
||||||
class Table(ABC):
|
class Table(ABC):
|
||||||
"""
|
"""
|
||||||
A Table is a collection of Records in a LanceDB Database.
|
A Table is a collection of Records in a LanceDB Database.
|
||||||
@@ -908,7 +930,7 @@ class LanceTable(Table):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def open(cls, db, name, **kwargs):
|
def open(cls, db, name, **kwargs):
|
||||||
tbl = cls(db, name, **kwargs)
|
tbl = cls(db, name, **kwargs)
|
||||||
fs, path = fs_from_uri(tbl._dataset_uri)
|
fs, path = fs_from_uri(tbl._dataset_path)
|
||||||
file_info = fs.get_file_info(path)
|
file_info = fs.get_file_info(path)
|
||||||
if file_info.type != pa.fs.FileType.Directory:
|
if file_info.type != pa.fs.FileType.Directory:
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(
|
||||||
@@ -918,9 +940,14 @@ class LanceTable(Table):
|
|||||||
|
|
||||||
return tbl
|
return tbl
|
||||||
|
|
||||||
@property
|
@cached_property
|
||||||
|
def _dataset_path(self) -> str:
|
||||||
|
# Cacheable since it's deterministic
|
||||||
|
return _table_path(self._conn.uri, self.name)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
def _dataset_uri(self) -> str:
|
def _dataset_uri(self) -> str:
|
||||||
return join_uri(self._conn.uri, f"{self.name}.lance")
|
return _table_uri(self._conn.uri, self.name)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _dataset(self) -> LanceDataset:
|
def _dataset(self) -> LanceDataset:
|
||||||
@@ -1230,6 +1257,10 @@ class LanceTable(Table):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _get_fts_index_path(self):
|
def _get_fts_index_path(self):
|
||||||
|
if get_uri_scheme(self._dataset_uri) != "file":
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Full-text search is not supported on object stores."
|
||||||
|
)
|
||||||
return join_uri(self._dataset_uri, "_indices", "tantivy")
|
return join_uri(self._dataset_uri, "_indices", "tantivy")
|
||||||
|
|
||||||
def add(
|
def add(
|
||||||
|
|||||||
@@ -139,8 +139,11 @@ def join_uri(base: Union[str, pathlib.Path], *parts: str) -> str:
|
|||||||
# using pathlib for local paths make this windows compatible
|
# using pathlib for local paths make this windows compatible
|
||||||
# `get_uri_scheme` returns `file` for windows drive names (e.g. `c:\path`)
|
# `get_uri_scheme` returns `file` for windows drive names (e.g. `c:\path`)
|
||||||
return str(pathlib.Path(base, *parts))
|
return str(pathlib.Path(base, *parts))
|
||||||
# for remote paths, just use os.path.join
|
else:
|
||||||
return "/".join([p.rstrip("/") for p in [base, *parts]])
|
# there might be query parameters in the base URI
|
||||||
|
url = urlparse(base)
|
||||||
|
new_path = "/".join([p.rstrip("/") for p in [url.path, *parts]])
|
||||||
|
return url._replace(path=new_path).geturl()
|
||||||
|
|
||||||
|
|
||||||
def attempt_import_or_raise(module: str, mitigation=None):
|
def attempt_import_or_raise(module: str, mitigation=None):
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ class FakeLanceDBClient:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
def query(self, table_name: str, query: VectorQuery) -> VectorQueryResult:
|
||||||
print(f"{query=}")
|
|
||||||
assert table_name == "test"
|
assert table_name == "test"
|
||||||
t = pa.schema([]).empty_table()
|
t = pa.schema([]).empty_table()
|
||||||
return VectorQueryResult(t)
|
return VectorQueryResult(t)
|
||||||
@@ -40,21 +39,3 @@ def test_remote_db():
|
|||||||
table = conn["test"]
|
table = conn["test"]
|
||||||
table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
|
table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
|
||||||
table.search([1.0, 2.0]).to_pandas()
|
table.search([1.0, 2.0]).to_pandas()
|
||||||
|
|
||||||
|
|
||||||
def test_empty_query_with_filter():
|
|
||||||
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
|
|
||||||
setattr(conn, "_client", FakeLanceDBClient())
|
|
||||||
|
|
||||||
table = conn["test"]
|
|
||||||
table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
|
|
||||||
print(table.query().select(["vector"]).where("foo == bar").to_arrow())
|
|
||||||
|
|
||||||
|
|
||||||
def test_fast_search_query_with_filter():
|
|
||||||
conn = lancedb.connect("db://client-will-be-injected", api_key="fake")
|
|
||||||
setattr(conn, "_client", FakeLanceDBClient())
|
|
||||||
|
|
||||||
table = conn["test"]
|
|
||||||
table.schema = pa.schema([pa.field("vector", pa.list_(pa.float32(), 2))])
|
|
||||||
print(table.query([0, 0], fast_search=True).select(["vector"]).where("foo == bar").to_arrow())
|
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
import os
|
|
||||||
|
|
||||||
import lancedb
|
import lancedb
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
@@ -11,6 +9,7 @@ from lancedb.rerankers import (
|
|||||||
ColbertReranker,
|
ColbertReranker,
|
||||||
CrossEncoderReranker,
|
CrossEncoderReranker,
|
||||||
OpenaiReranker,
|
OpenaiReranker,
|
||||||
|
JinaReranker,
|
||||||
)
|
)
|
||||||
from lancedb.table import LanceTable
|
from lancedb.table import LanceTable
|
||||||
|
|
||||||
@@ -119,136 +118,18 @@ def test_linear_combination(tmp_path):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.slow
|
||||||
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
|
@pytest.mark.parametrize(
|
||||||
|
"reranker",
|
||||||
|
[
|
||||||
|
ColbertReranker(),
|
||||||
|
OpenaiReranker(),
|
||||||
|
CohereReranker(),
|
||||||
|
CrossEncoderReranker(),
|
||||||
|
JinaReranker(),
|
||||||
|
],
|
||||||
)
|
)
|
||||||
def test_cohere_reranker(tmp_path):
|
def test_colbert_reranker(tmp_path, reranker):
|
||||||
pytest.importorskip("cohere")
|
|
||||||
reranker = CohereReranker()
|
|
||||||
table, schema = get_test_table(tmp_path)
|
|
||||||
# Hybrid search setting
|
|
||||||
result1 = (
|
|
||||||
table.search("Our father who art in heaven", query_type="hybrid")
|
|
||||||
.rerank(normalize="score", reranker=CohereReranker())
|
|
||||||
.to_pydantic(schema)
|
|
||||||
)
|
|
||||||
result2 = (
|
|
||||||
table.search("Our father who art in heaven", query_type="hybrid")
|
|
||||||
.rerank(reranker=reranker)
|
|
||||||
.to_pydantic(schema)
|
|
||||||
)
|
|
||||||
assert result1 == result2
|
|
||||||
|
|
||||||
query = "Our father who art in heaven"
|
|
||||||
query_vector = table.to_pandas()["vector"][0]
|
|
||||||
result = (
|
|
||||||
table.search((query_vector, query))
|
|
||||||
.limit(30)
|
|
||||||
.rerank(reranker=reranker)
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(result) == 30
|
|
||||||
err = (
|
|
||||||
"The _relevance_score column of the results returned by the reranker "
|
|
||||||
"represents the relevance of the result to the query & should "
|
|
||||||
"be descending."
|
|
||||||
)
|
|
||||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
|
||||||
|
|
||||||
# Vector search setting
|
|
||||||
query = "Our father who art in heaven"
|
|
||||||
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
|
|
||||||
assert len(result) == 30
|
|
||||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
|
||||||
result_explicit = (
|
|
||||||
table.search(query_vector)
|
|
||||||
.rerank(reranker=reranker, query_string=query)
|
|
||||||
.limit(30)
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
assert len(result_explicit) == 30
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError
|
|
||||||
): # This raises an error because vector query is provided without reanking query
|
|
||||||
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
|
|
||||||
|
|
||||||
# FTS search setting
|
|
||||||
result = (
|
|
||||||
table.search(query, query_type="fts")
|
|
||||||
.rerank(reranker=reranker)
|
|
||||||
.limit(30)
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
assert len(result) > 0
|
|
||||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
|
||||||
|
|
||||||
|
|
||||||
def test_cross_encoder_reranker(tmp_path):
|
|
||||||
pytest.importorskip("sentence_transformers")
|
|
||||||
reranker = CrossEncoderReranker()
|
|
||||||
table, schema = get_test_table(tmp_path)
|
|
||||||
result1 = (
|
|
||||||
table.search("Our father who art in heaven", query_type="hybrid")
|
|
||||||
.rerank(normalize="score", reranker=reranker)
|
|
||||||
.to_pydantic(schema)
|
|
||||||
)
|
|
||||||
result2 = (
|
|
||||||
table.search("Our father who art in heaven", query_type="hybrid")
|
|
||||||
.rerank(reranker=reranker)
|
|
||||||
.to_pydantic(schema)
|
|
||||||
)
|
|
||||||
assert result1 == result2
|
|
||||||
|
|
||||||
query = "Our father who art in heaven"
|
|
||||||
query_vector = table.to_pandas()["vector"][0]
|
|
||||||
result = (
|
|
||||||
table.search((query_vector, query), query_type="hybrid")
|
|
||||||
.limit(30)
|
|
||||||
.rerank(reranker=reranker)
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(result) == 30
|
|
||||||
|
|
||||||
err = (
|
|
||||||
"The _relevance_score column of the results returned by the reranker "
|
|
||||||
"represents the relevance of the result to the query & should "
|
|
||||||
"be descending."
|
|
||||||
)
|
|
||||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
|
||||||
|
|
||||||
# Vector search setting
|
|
||||||
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
|
|
||||||
assert len(result) == 30
|
|
||||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
|
||||||
|
|
||||||
result_explicit = (
|
|
||||||
table.search(query_vector)
|
|
||||||
.rerank(reranker=reranker, query_string=query)
|
|
||||||
.limit(30)
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
assert len(result_explicit) == 30
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError
|
|
||||||
): # This raises an error because vector query is provided without reanking query
|
|
||||||
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
|
|
||||||
|
|
||||||
# FTS search setting
|
|
||||||
result = (
|
|
||||||
table.search(query, query_type="fts")
|
|
||||||
.rerank(reranker=reranker)
|
|
||||||
.limit(30)
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
assert len(result) > 0
|
|
||||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
|
||||||
|
|
||||||
|
|
||||||
def test_colbert_reranker(tmp_path):
|
|
||||||
pytest.importorskip("transformers")
|
|
||||||
reranker = ColbertReranker()
|
|
||||||
table, schema = get_test_table(tmp_path)
|
table, schema = get_test_table(tmp_path)
|
||||||
result1 = (
|
result1 = (
|
||||||
table.search("Our father who art in heaven", query_type="hybrid")
|
table.search("Our father who art in heaven", query_type="hybrid")
|
||||||
@@ -305,67 +186,3 @@ def test_colbert_reranker(tmp_path):
|
|||||||
)
|
)
|
||||||
assert len(result) > 0
|
assert len(result) > 0
|
||||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
os.environ.get("OPENAI_API_KEY") is None, reason="OPENAI_API_KEY not set"
|
|
||||||
)
|
|
||||||
def test_openai_reranker(tmp_path):
|
|
||||||
pytest.importorskip("openai")
|
|
||||||
table, schema = get_test_table(tmp_path)
|
|
||||||
reranker = OpenaiReranker()
|
|
||||||
result1 = (
|
|
||||||
table.search("Our father who art in heaven", query_type="hybrid")
|
|
||||||
.rerank(normalize="score", reranker=reranker)
|
|
||||||
.to_pydantic(schema)
|
|
||||||
)
|
|
||||||
result2 = (
|
|
||||||
table.search("Our father who art in heaven", query_type="hybrid")
|
|
||||||
.rerank(reranker=OpenaiReranker())
|
|
||||||
.to_pydantic(schema)
|
|
||||||
)
|
|
||||||
assert result1 == result2
|
|
||||||
|
|
||||||
# test explicit hybrid query
|
|
||||||
query = "Our father who art in heaven"
|
|
||||||
query_vector = table.to_pandas()["vector"][0]
|
|
||||||
result = (
|
|
||||||
table.search((query_vector, query))
|
|
||||||
.limit(30)
|
|
||||||
.rerank(reranker=reranker)
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(result) == 30
|
|
||||||
|
|
||||||
err = (
|
|
||||||
"The _relevance_score column of the results returned by the reranker "
|
|
||||||
"represents the relevance of the result to the query & should "
|
|
||||||
"be descending."
|
|
||||||
)
|
|
||||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
|
||||||
|
|
||||||
# Vector search setting
|
|
||||||
result = table.search(query).rerank(reranker=reranker).limit(30).to_arrow()
|
|
||||||
assert len(result) == 30
|
|
||||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
|
||||||
result_explicit = (
|
|
||||||
table.search(query_vector)
|
|
||||||
.rerank(reranker=reranker, query_string=query)
|
|
||||||
.limit(30)
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
assert len(result_explicit) == 30
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError
|
|
||||||
): # This raises an error because vector query is provided without reanking query
|
|
||||||
table.search(query_vector).rerank(reranker=reranker).limit(30).to_arrow()
|
|
||||||
# FTS search setting
|
|
||||||
result = (
|
|
||||||
table.search(query, query_type="fts")
|
|
||||||
.rerank(reranker=reranker)
|
|
||||||
.limit(30)
|
|
||||||
.to_arrow()
|
|
||||||
)
|
|
||||||
assert len(result) > 0
|
|
||||||
assert np.all(np.diff(result.column("_relevance_score").to_numpy()) <= 0), err
|
|
||||||
|
|||||||
@@ -13,6 +13,8 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import copy
|
import copy
|
||||||
|
from datetime import timedelta
|
||||||
|
import threading
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pyarrow as pa
|
import pyarrow as pa
|
||||||
@@ -25,6 +27,7 @@ CONFIG = {
|
|||||||
"aws_access_key_id": "ACCESSKEY",
|
"aws_access_key_id": "ACCESSKEY",
|
||||||
"aws_secret_access_key": "SECRETKEY",
|
"aws_secret_access_key": "SECRETKEY",
|
||||||
"aws_endpoint": "http://localhost:4566",
|
"aws_endpoint": "http://localhost:4566",
|
||||||
|
"dynamodb_endpoint": "http://localhost:4566",
|
||||||
"aws_region": "us-east-1",
|
"aws_region": "us-east-1",
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -156,3 +159,104 @@ def test_s3_sse(s3_bucket: str, kms_key: str):
|
|||||||
validate_objects_encrypted(s3_bucket, path, kms_key)
|
validate_objects_encrypted(s3_bucket, path, kms_key)
|
||||||
|
|
||||||
asyncio.run(test())
|
asyncio.run(test())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def commit_table():
|
||||||
|
ddb = get_boto3_client("dynamodb", endpoint_url=CONFIG["dynamodb_endpoint"])
|
||||||
|
table_name = "lance-integtest"
|
||||||
|
try:
|
||||||
|
ddb.delete_table(TableName=table_name)
|
||||||
|
except ddb.exceptions.ResourceNotFoundException:
|
||||||
|
pass
|
||||||
|
ddb.create_table(
|
||||||
|
TableName=table_name,
|
||||||
|
KeySchema=[
|
||||||
|
{"AttributeName": "base_uri", "KeyType": "HASH"},
|
||||||
|
{"AttributeName": "version", "KeyType": "RANGE"},
|
||||||
|
],
|
||||||
|
AttributeDefinitions=[
|
||||||
|
{"AttributeName": "base_uri", "AttributeType": "S"},
|
||||||
|
{"AttributeName": "version", "AttributeType": "N"},
|
||||||
|
],
|
||||||
|
ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1},
|
||||||
|
)
|
||||||
|
yield table_name
|
||||||
|
ddb.delete_table(TableName=table_name)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.s3_test
|
||||||
|
def test_s3_dynamodb(s3_bucket: str, commit_table: str):
|
||||||
|
storage_options = copy.copy(CONFIG)
|
||||||
|
|
||||||
|
uri = f"s3+ddb://{s3_bucket}/test?ddbTableName={commit_table}"
|
||||||
|
data = pa.table({"x": [1, 2, 3]})
|
||||||
|
|
||||||
|
async def test():
|
||||||
|
db = await lancedb.connect_async(
|
||||||
|
uri,
|
||||||
|
storage_options=storage_options,
|
||||||
|
read_consistency_interval=timedelta(0),
|
||||||
|
)
|
||||||
|
|
||||||
|
table = await db.create_table("test", data)
|
||||||
|
|
||||||
|
# Five concurrent writers
|
||||||
|
async def insert():
|
||||||
|
# independent table refs for true concurrent writes.
|
||||||
|
table = await db.open_table("test")
|
||||||
|
await table.add(data, mode="append")
|
||||||
|
|
||||||
|
tasks = [insert() for _ in range(5)]
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
row_count = await table.count_rows()
|
||||||
|
assert row_count == 3 * 6
|
||||||
|
|
||||||
|
asyncio.run(test())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.s3_test
|
||||||
|
def test_s3_dynamodb_sync(s3_bucket: str, commit_table: str, monkeypatch):
|
||||||
|
# Sync API doesn't support storage_options, so we have to provide as env vars
|
||||||
|
for key, value in CONFIG.items():
|
||||||
|
monkeypatch.setenv(key.upper(), value)
|
||||||
|
|
||||||
|
uri = f"s3+ddb://{s3_bucket}/test2?ddbTableName={commit_table}"
|
||||||
|
data = pa.table({"x": ["a", "b", "c"]})
|
||||||
|
|
||||||
|
db = lancedb.connect(
|
||||||
|
uri,
|
||||||
|
read_consistency_interval=timedelta(0),
|
||||||
|
)
|
||||||
|
|
||||||
|
table = db.create_table("test_ddb_sync", data)
|
||||||
|
|
||||||
|
# Five concurrent writers
|
||||||
|
def insert():
|
||||||
|
table = db.open_table("test_ddb_sync")
|
||||||
|
table.add(data, mode="append")
|
||||||
|
|
||||||
|
threads = []
|
||||||
|
for _ in range(5):
|
||||||
|
thread = threading.Thread(target=insert)
|
||||||
|
threads.append(thread)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
for thread in threads:
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
row_count = table.count_rows()
|
||||||
|
assert row_count == 3 * 6
|
||||||
|
|
||||||
|
# FTS indices should error since they are not supported yet.
|
||||||
|
with pytest.raises(
|
||||||
|
NotImplementedError, match="Full-text search is not supported on object stores."
|
||||||
|
):
|
||||||
|
table.create_fts_index("x")
|
||||||
|
|
||||||
|
# make sure list tables still works
|
||||||
|
assert db.table_names() == ["test_ddb_sync"]
|
||||||
|
db.drop_table("test_ddb_sync")
|
||||||
|
assert db.table_names() == []
|
||||||
|
db.drop_database()
|
||||||
|
|||||||
@@ -735,7 +735,7 @@ def test_create_scalar_index(db):
|
|||||||
indices = table.to_lance().list_indices()
|
indices = table.to_lance().list_indices()
|
||||||
assert len(indices) == 1
|
assert len(indices) == 1
|
||||||
scalar_index = indices[0]
|
scalar_index = indices[0]
|
||||||
assert scalar_index["type"] == "BTree"
|
assert scalar_index["type"] == "Scalar"
|
||||||
|
|
||||||
# Confirm that prefiltering still works with the scalar index column
|
# Confirm that prefiltering still works with the scalar index column
|
||||||
results = table.search().where("x = 'c'").to_arrow()
|
results = table.search().where("x = 'c'").to_arrow()
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "lancedb-node"
|
name = "lancedb-node"
|
||||||
version = "0.5.2-final.1"
|
version = "0.6.0"
|
||||||
description = "Serverless, low-latency vector database for AI applications"
|
description = "Serverless, low-latency vector database for AI applications"
|
||||||
license.workspace = true
|
license.workspace = true
|
||||||
edition.workspace = true
|
edition.workspace = true
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "lancedb"
|
name = "lancedb"
|
||||||
version = "0.5.2-final.1"
|
version = "0.6.0"
|
||||||
edition.workspace = true
|
edition.workspace = true
|
||||||
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
description = "LanceDB: A serverless, low-latency vector database for AI applications"
|
||||||
license.workspace = true
|
license.workspace = true
|
||||||
@@ -55,10 +55,11 @@ walkdir = "2"
|
|||||||
# For s3 integration tests (dev deps aren't allowed to be optional atm)
|
# For s3 integration tests (dev deps aren't allowed to be optional atm)
|
||||||
# We pin these because the content-length check breaks with localstack
|
# We pin these because the content-length check breaks with localstack
|
||||||
# https://github.com/smithy-lang/smithy-rs/releases/tag/release-2024-05-21
|
# https://github.com/smithy-lang/smithy-rs/releases/tag/release-2024-05-21
|
||||||
|
aws-sdk-dynamodb = { version = "=1.23.0" }
|
||||||
aws-sdk-s3 = { version = "=1.23.0" }
|
aws-sdk-s3 = { version = "=1.23.0" }
|
||||||
aws-sdk-kms = { version = "=1.21.0" }
|
aws-sdk-kms = { version = "=1.21.0" }
|
||||||
aws-config = { version = "1.0" }
|
aws-config = { version = "1.0" }
|
||||||
aws-smithy-runtime = { version = "=1.3.0" }
|
aws-smithy-runtime = { version = "=1.3.1" }
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = []
|
default = []
|
||||||
|
|||||||
@@ -6,3 +6,12 @@
|
|||||||
LanceDB Rust SDK, a serverless vector database.
|
LanceDB Rust SDK, a serverless vector database.
|
||||||
|
|
||||||
Read more at: https://lancedb.com/
|
Read more at: https://lancedb.com/
|
||||||
|
|
||||||
|
> [!TIP]
|
||||||
|
> A transitive dependency of `lancedb` is `lzma-sys`, which uses dynamic linking
|
||||||
|
> by default. If you want to statically link `lzma-sys`, you should activate it's
|
||||||
|
> `static` feature by adding the following to your dependencies:
|
||||||
|
>
|
||||||
|
> ```toml
|
||||||
|
> lzma-sys = { version = "*", features = ["static"] }
|
||||||
|
> ```
|
||||||
|
|||||||
@@ -1889,6 +1889,7 @@ impl TableInternal for NativeTable {
|
|||||||
}
|
}
|
||||||
columns.push(field.name.clone());
|
columns.push(field.name.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
let index_type = if is_vector {
|
let index_type = if is_vector {
|
||||||
crate::index::IndexType::IvfPq
|
crate::index::IndexType::IvfPq
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -25,7 +25,9 @@ const CONFIG: &[(&str, &str)] = &[
|
|||||||
("access_key_id", "ACCESS_KEY"),
|
("access_key_id", "ACCESS_KEY"),
|
||||||
("secret_access_key", "SECRET_KEY"),
|
("secret_access_key", "SECRET_KEY"),
|
||||||
("endpoint", "http://127.0.0.1:4566"),
|
("endpoint", "http://127.0.0.1:4566"),
|
||||||
|
("dynamodb_endpoint", "http://127.0.0.1:4566"),
|
||||||
("allow_http", "true"),
|
("allow_http", "true"),
|
||||||
|
("region", "us-east-1"),
|
||||||
];
|
];
|
||||||
|
|
||||||
async fn aws_config() -> SdkConfig {
|
async fn aws_config() -> SdkConfig {
|
||||||
@@ -288,3 +290,126 @@ async fn test_encryption() -> Result<()> {
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct DynamoDBCommitTable(String);
|
||||||
|
|
||||||
|
impl DynamoDBCommitTable {
|
||||||
|
async fn new(name: &str) -> Self {
|
||||||
|
let config = aws_config().await;
|
||||||
|
let client = aws_sdk_dynamodb::Client::new(&config);
|
||||||
|
|
||||||
|
// In case it wasn't deleted earlier
|
||||||
|
Self::delete_table(client.clone(), name).await;
|
||||||
|
tokio::time::sleep(std::time::Duration::from_millis(200)).await;
|
||||||
|
|
||||||
|
use aws_sdk_dynamodb::types::*;
|
||||||
|
|
||||||
|
client
|
||||||
|
.create_table()
|
||||||
|
.table_name(name)
|
||||||
|
.attribute_definitions(
|
||||||
|
AttributeDefinition::builder()
|
||||||
|
.attribute_name("base_uri")
|
||||||
|
.attribute_type(ScalarAttributeType::S)
|
||||||
|
.build()
|
||||||
|
.unwrap(),
|
||||||
|
)
|
||||||
|
.attribute_definitions(
|
||||||
|
AttributeDefinition::builder()
|
||||||
|
.attribute_name("version")
|
||||||
|
.attribute_type(ScalarAttributeType::N)
|
||||||
|
.build()
|
||||||
|
.unwrap(),
|
||||||
|
)
|
||||||
|
.key_schema(
|
||||||
|
KeySchemaElement::builder()
|
||||||
|
.attribute_name("base_uri")
|
||||||
|
.key_type(KeyType::Hash)
|
||||||
|
.build()
|
||||||
|
.unwrap(),
|
||||||
|
)
|
||||||
|
.key_schema(
|
||||||
|
KeySchemaElement::builder()
|
||||||
|
.attribute_name("version")
|
||||||
|
.key_type(KeyType::Range)
|
||||||
|
.build()
|
||||||
|
.unwrap(),
|
||||||
|
)
|
||||||
|
.provisioned_throughput(
|
||||||
|
ProvisionedThroughput::builder()
|
||||||
|
.read_capacity_units(1)
|
||||||
|
.write_capacity_units(1)
|
||||||
|
.build()
|
||||||
|
.unwrap(),
|
||||||
|
)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
Self(name.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn delete_table(client: aws_sdk_dynamodb::Client, name: &str) {
|
||||||
|
match client
|
||||||
|
.delete_table()
|
||||||
|
.table_name(name)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|err| err.into_service_error())
|
||||||
|
{
|
||||||
|
Ok(_) => {}
|
||||||
|
Err(e) if e.is_resource_not_found_exception() => {}
|
||||||
|
Err(e) => panic!("Failed to delete table: {}", e),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for DynamoDBCommitTable {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
let table_name = self.0.clone();
|
||||||
|
tokio::task::spawn(async move {
|
||||||
|
let config = aws_config().await;
|
||||||
|
let client = aws_sdk_dynamodb::Client::new(&config);
|
||||||
|
Self::delete_table(client, &table_name).await;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_concurrent_dynamodb_commit() {
|
||||||
|
// test concurrent commit on dynamodb
|
||||||
|
let bucket = S3Bucket::new("test-dynamodb").await;
|
||||||
|
let table = DynamoDBCommitTable::new("test_table").await;
|
||||||
|
|
||||||
|
let uri = format!("s3+ddb://{}?ddbTableName={}", bucket.0, table.0);
|
||||||
|
let db = lancedb::connect(&uri)
|
||||||
|
.storage_options(CONFIG.iter().cloned())
|
||||||
|
.execute()
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let data = test_data();
|
||||||
|
let data = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema());
|
||||||
|
|
||||||
|
let table = db.create_table("test_table", data).execute().await.unwrap();
|
||||||
|
|
||||||
|
let data = test_data();
|
||||||
|
|
||||||
|
let mut tasks = vec![];
|
||||||
|
for _ in 0..5 {
|
||||||
|
let table = db.open_table("test_table").execute().await.unwrap();
|
||||||
|
let data = data.clone();
|
||||||
|
tasks.push(tokio::spawn(async move {
|
||||||
|
let data = RecordBatchIterator::new(vec![Ok(data.clone())], data.schema());
|
||||||
|
table.add(data).execute().await.unwrap();
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
for task in tasks {
|
||||||
|
task.await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
table.checkout_latest().await.unwrap();
|
||||||
|
let row_count = table.count_rows(None).await.unwrap();
|
||||||
|
assert_eq!(row_count, 18);
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user