mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-07 12:22:59 +00:00
Merge branch 'main' of https://github.com/lancedb/lancedb into docs_march
This commit is contained in:
10
Cargo.toml
10
Cargo.toml
@@ -14,10 +14,10 @@ keywords = ["lancedb", "lance", "database", "vector", "search"]
|
||||
categories = ["database-implementations"]
|
||||
|
||||
[workspace.dependencies]
|
||||
lance = { "version" = "=0.10.4", "features" = ["dynamodb"] }
|
||||
lance-index = { "version" = "=0.10.4" }
|
||||
lance-linalg = { "version" = "=0.10.4" }
|
||||
lance-testing = { "version" = "=0.10.4" }
|
||||
lance = { "version" = "=0.10.5", "features" = ["dynamodb"] }
|
||||
lance-index = { "version" = "=0.10.5" }
|
||||
lance-linalg = { "version" = "=0.10.5" }
|
||||
lance-testing = { "version" = "=0.10.5" }
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "50.0", optional = false }
|
||||
arrow-array = "50.0"
|
||||
@@ -39,3 +39,5 @@ pin-project = "1.0.7"
|
||||
snafu = "0.7.4"
|
||||
url = "2"
|
||||
num-traits = "0.2"
|
||||
regex = "1.10"
|
||||
lazy_static = "1"
|
||||
|
||||
@@ -1,11 +1,79 @@
|
||||
document.addEventListener("DOMContentLoaded", function () {
|
||||
var script = document.createElement("script");
|
||||
script.src = "https://widget.kapa.ai/kapa-widget.bundle.js";
|
||||
script.setAttribute("data-website-id", "c5881fae-cec0-490b-b45e-d83d131d4f25");
|
||||
script.setAttribute("data-project-name", "LanceDB");
|
||||
script.setAttribute("data-project-color", "#000000");
|
||||
script.setAttribute("data-project-logo", "https://avatars.githubusercontent.com/u/108903835?s=200&v=4");
|
||||
script.setAttribute("data-modal-example-questions","Help me create an IVF_PQ index,How do I do an exhaustive search?,How do I create a LanceDB table?,Can I use my own embedding function?");
|
||||
script.async = true;
|
||||
document.head.appendChild(script);
|
||||
});
|
||||
// Creates an SVG robot icon (from Lucide)
|
||||
function robotSVG() {
|
||||
var svg = document.createElementNS("http://www.w3.org/2000/svg", "svg");
|
||||
svg.setAttribute("width", "24");
|
||||
svg.setAttribute("height", "24");
|
||||
svg.setAttribute("viewBox", "0 0 24 24");
|
||||
svg.setAttribute("fill", "none");
|
||||
svg.setAttribute("stroke", "currentColor");
|
||||
svg.setAttribute("stroke-width", "2");
|
||||
svg.setAttribute("stroke-linecap", "round");
|
||||
svg.setAttribute("stroke-linejoin", "round");
|
||||
svg.setAttribute("class", "lucide lucide-bot-message-square");
|
||||
|
||||
var path1 = document.createElementNS("http://www.w3.org/2000/svg", "path");
|
||||
path1.setAttribute("d", "M12 6V2H8");
|
||||
svg.appendChild(path1);
|
||||
|
||||
var path2 = document.createElementNS("http://www.w3.org/2000/svg", "path");
|
||||
path2.setAttribute("d", "m8 18-4 4V8a2 2 0 0 1 2-2h12a2 2 0 0 1 2 2v8a2 2 0 0 1-2 2Z");
|
||||
svg.appendChild(path2);
|
||||
|
||||
var path3 = document.createElementNS("http://www.w3.org/2000/svg", "path");
|
||||
path3.setAttribute("d", "M2 12h2");
|
||||
svg.appendChild(path3);
|
||||
|
||||
var path4 = document.createElementNS("http://www.w3.org/2000/svg", "path");
|
||||
path4.setAttribute("d", "M9 11v2");
|
||||
svg.appendChild(path4);
|
||||
|
||||
var path5 = document.createElementNS("http://www.w3.org/2000/svg", "path");
|
||||
path5.setAttribute("d", "M15 11v2");
|
||||
svg.appendChild(path5);
|
||||
|
||||
var path6 = document.createElementNS("http://www.w3.org/2000/svg", "path");
|
||||
path6.setAttribute("d", "M20 12h2");
|
||||
svg.appendChild(path6);
|
||||
|
||||
return svg
|
||||
}
|
||||
|
||||
// Creates the Fluidic Chatbot buttom
|
||||
function fluidicButton() {
|
||||
var btn = document.createElement("a");
|
||||
btn.href = "https://asklancedb.com";
|
||||
btn.target = "_blank";
|
||||
btn.style.position = "fixed";
|
||||
btn.style.fontWeight = "bold";
|
||||
btn.style.fontSize = ".8rem";
|
||||
btn.style.right = "10px";
|
||||
btn.style.bottom = "10px";
|
||||
btn.style.width = "80px";
|
||||
btn.style.height = "80px";
|
||||
btn.style.background = "linear-gradient(135deg, #7C5EFF 0%, #625eff 100%)";
|
||||
btn.style.color = "white";
|
||||
btn.style.borderRadius = "5px";
|
||||
btn.style.display = "flex";
|
||||
btn.style.flexDirection = "column";
|
||||
btn.style.justifyContent = "center";
|
||||
btn.style.alignItems = "center";
|
||||
btn.style.zIndex = "1000";
|
||||
btn.style.opacity = "0";
|
||||
btn.style.boxShadow = "0 0 0 rgba(0, 0, 0, 0)";
|
||||
btn.style.transition = "opacity 0.2s ease-in, box-shadow 0.2s ease-in";
|
||||
|
||||
setTimeout(function() {
|
||||
btn.style.opacity = "1";
|
||||
btn.style.boxShadow = "0 0 .2rem #0000001a,0 .2rem .4rem #0003"
|
||||
}, 0);
|
||||
|
||||
return btn
|
||||
}
|
||||
|
||||
document.addEventListener("DOMContentLoaded", function() {
|
||||
var btn = fluidicButton()
|
||||
btn.appendChild(robotSVG());
|
||||
var text = document.createTextNode("Ask AI");
|
||||
btn.appendChild(text);
|
||||
document.body.appendChild(btn);
|
||||
});
|
||||
|
||||
@@ -24,6 +24,7 @@ import { RemoteConnection } from './remote'
|
||||
import { Query } from './query'
|
||||
import { isEmbeddingFunction } from './embedding/embedding_function'
|
||||
import { type Literal, toSQL } from './util'
|
||||
import { type HttpMiddleware } from './middleware'
|
||||
|
||||
const {
|
||||
databaseNew,
|
||||
@@ -302,6 +303,18 @@ export interface Connection {
|
||||
* @param name The name of the table to drop.
|
||||
*/
|
||||
dropTable(name: string): Promise<void>
|
||||
|
||||
/**
|
||||
* Instrument the behavior of this Connection with middleware.
|
||||
*
|
||||
* The middleware will be called in the order they are added.
|
||||
*
|
||||
* Currently this functionality is only supported for remote Connections.
|
||||
*
|
||||
* @param {HttpMiddleware} - Middleware which will instrument the Connection.
|
||||
* @returns - this Connection instrumented by the passed middleware
|
||||
*/
|
||||
withMiddleware(middleware: HttpMiddleware): Connection
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -541,6 +554,18 @@ export interface Table<T = number[]> {
|
||||
* names (e.g. "a").
|
||||
*/
|
||||
dropColumns(columnNames: string[]): Promise<void>
|
||||
|
||||
/**
|
||||
* Instrument the behavior of this Table with middleware.
|
||||
*
|
||||
* The middleware will be called in the order they are added.
|
||||
*
|
||||
* Currently this functionality is only supported for remote tables.
|
||||
*
|
||||
* @param {HttpMiddleware} - Middleware which will instrument the Table.
|
||||
* @returns - this Table instrumented by the passed middleware
|
||||
*/
|
||||
withMiddleware(middleware: HttpMiddleware): Table<T>
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -795,6 +820,10 @@ export class LocalConnection implements Connection {
|
||||
async dropTable (name: string): Promise<void> {
|
||||
await databaseDropTable.call(this._db, name)
|
||||
}
|
||||
|
||||
withMiddleware (middleware: HttpMiddleware): Connection {
|
||||
return this
|
||||
}
|
||||
}
|
||||
|
||||
export class LocalTable<T = number[]> implements Table<T> {
|
||||
@@ -1105,6 +1134,10 @@ export class LocalTable<T = number[]> implements Table<T> {
|
||||
async dropColumns (columnNames: string[]): Promise<void> {
|
||||
return tableDropColumns.call(this._tbl, columnNames)
|
||||
}
|
||||
|
||||
withMiddleware (middleware: HttpMiddleware): Table<T> {
|
||||
return this
|
||||
}
|
||||
}
|
||||
|
||||
export interface CleanupStats {
|
||||
|
||||
58
node/src/middleware.ts
Normal file
58
node/src/middleware.ts
Normal file
@@ -0,0 +1,58 @@
|
||||
// Copyright 2024 LanceDB Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
/**
|
||||
* Middleware for Remote LanceDB Connection or Table
|
||||
*/
|
||||
export interface HttpMiddleware {
|
||||
/**
|
||||
* A callback that can be used to instrument the behavior of http requests to remote
|
||||
* tables. It can be used to add headers, modify the request, or even short-circuit
|
||||
* the request and return a response without making the request to the remote endpoint.
|
||||
* It can also be used to modify the response from the remote endpoint.
|
||||
*
|
||||
* @param {RemoteResponse} res - Request to the remote endpoint
|
||||
* @param {onRemoteRequestNext} next - Callback to advance the middleware chain
|
||||
*/
|
||||
onRemoteRequest(
|
||||
req: RemoteRequest,
|
||||
next: (req: RemoteRequest) => Promise<RemoteResponse>,
|
||||
): Promise<RemoteResponse>
|
||||
};
|
||||
|
||||
export enum Method {
|
||||
GET,
|
||||
POST
|
||||
}
|
||||
|
||||
/**
|
||||
* A LanceDB Remote HTTP Request
|
||||
*/
|
||||
export interface RemoteRequest {
|
||||
uri: string
|
||||
method: Method
|
||||
headers: Map<string, string>
|
||||
params?: Map<string, string>
|
||||
body?: any
|
||||
}
|
||||
|
||||
/**
|
||||
* A LanceDB Remote HTTP Response
|
||||
*/
|
||||
export interface RemoteResponse {
|
||||
status: number
|
||||
statusText: string
|
||||
headers: Map<string, string>
|
||||
body: () => Promise<any>
|
||||
}
|
||||
@@ -12,13 +12,101 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
import axios, { type AxiosResponse } from 'axios'
|
||||
import axios, { type AxiosResponse, type ResponseType } from 'axios'
|
||||
|
||||
import { tableFromIPC, type Table as ArrowTable } from 'apache-arrow'
|
||||
|
||||
import { type RemoteResponse, type RemoteRequest, Method } from '../middleware'
|
||||
|
||||
interface HttpLancedbClientMiddleware {
|
||||
onRemoteRequest(
|
||||
req: RemoteRequest,
|
||||
next: (req: RemoteRequest) => Promise<RemoteResponse>,
|
||||
): Promise<RemoteResponse>
|
||||
}
|
||||
|
||||
/**
|
||||
* Invoke the middleware chain and at the end call the remote endpoint
|
||||
*/
|
||||
async function callWithMiddlewares (
|
||||
req: RemoteRequest,
|
||||
middlewares: HttpLancedbClientMiddleware[],
|
||||
opts?: MiddlewareInvocationOptions
|
||||
): Promise<RemoteResponse> {
|
||||
async function call (
|
||||
i: number,
|
||||
req: RemoteRequest
|
||||
): Promise<RemoteResponse> {
|
||||
// if we have reached the end of the middleware chain, make the request
|
||||
if (i > middlewares.length) {
|
||||
const headers = Object.fromEntries(req.headers.entries())
|
||||
const params = Object.fromEntries(req.params?.entries() ?? [])
|
||||
const timeout = 10000
|
||||
let res
|
||||
if (req.method === Method.POST) {
|
||||
res = await axios.post(
|
||||
req.uri,
|
||||
req.body,
|
||||
{
|
||||
headers,
|
||||
params,
|
||||
timeout,
|
||||
responseType: opts?.responseType
|
||||
}
|
||||
)
|
||||
} else {
|
||||
res = await axios.get(
|
||||
req.uri,
|
||||
{
|
||||
headers,
|
||||
params,
|
||||
timeout
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
return toLanceRes(res)
|
||||
}
|
||||
|
||||
// call next middleware in chain
|
||||
return await middlewares[i - 1].onRemoteRequest(
|
||||
req,
|
||||
async (req) => {
|
||||
return await call(i + 1, req)
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
return await call(1, req)
|
||||
}
|
||||
|
||||
interface MiddlewareInvocationOptions {
|
||||
responseType?: ResponseType
|
||||
}
|
||||
|
||||
/**
|
||||
* Marshall the library response into a LanceDB response
|
||||
*/
|
||||
function toLanceRes (res: AxiosResponse): RemoteResponse {
|
||||
const headers = new Map()
|
||||
for (const h in res.headers) {
|
||||
headers.set(h, res.headers[h])
|
||||
}
|
||||
|
||||
return {
|
||||
status: res.status,
|
||||
statusText: res.statusText,
|
||||
headers,
|
||||
body: async () => {
|
||||
return res.data
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export class HttpLancedbClient {
|
||||
private readonly _url: string
|
||||
private readonly _apiKey: () => string
|
||||
private readonly _middlewares: HttpLancedbClientMiddleware[]
|
||||
|
||||
public constructor (
|
||||
url: string,
|
||||
@@ -27,6 +115,7 @@ export class HttpLancedbClient {
|
||||
) {
|
||||
this._url = url
|
||||
this._apiKey = () => apiKey
|
||||
this._middlewares = []
|
||||
}
|
||||
|
||||
get uri (): string {
|
||||
@@ -43,74 +132,61 @@ export class HttpLancedbClient {
|
||||
columns?: string[],
|
||||
filter?: string
|
||||
): Promise<ArrowTable<any>> {
|
||||
const response = await axios.post(
|
||||
`${this._url}/v1/table/${tableName}/query/`,
|
||||
{
|
||||
vector,
|
||||
k,
|
||||
nprobes,
|
||||
refineFactor,
|
||||
columns,
|
||||
filter,
|
||||
prefilter
|
||||
},
|
||||
{
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': this._apiKey(),
|
||||
...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {})
|
||||
},
|
||||
responseType: 'arraybuffer',
|
||||
timeout: 10000
|
||||
}
|
||||
).catch((err) => {
|
||||
console.error('error: ', err)
|
||||
if (err.response === undefined) {
|
||||
throw new Error(`Network Error: ${err.message as string}`)
|
||||
}
|
||||
return err.response
|
||||
})
|
||||
if (response.status !== 200) {
|
||||
const errorData = new TextDecoder().decode(response.data)
|
||||
throw new Error(
|
||||
`Server Error, status: ${response.status as number}, ` +
|
||||
`message: ${response.statusText as string}: ${errorData}`
|
||||
)
|
||||
}
|
||||
|
||||
const table = tableFromIPC(response.data)
|
||||
const result = await this.post(
|
||||
`/v1/table/${tableName}/query/`,
|
||||
{
|
||||
vector,
|
||||
k,
|
||||
nprobes,
|
||||
refineFactor,
|
||||
columns,
|
||||
filter,
|
||||
prefilter
|
||||
},
|
||||
undefined,
|
||||
undefined,
|
||||
'arraybuffer'
|
||||
)
|
||||
const table = tableFromIPC(await result.body())
|
||||
return table
|
||||
}
|
||||
|
||||
/**
|
||||
* Sent GET request.
|
||||
*/
|
||||
public async get (path: string, params?: Record<string, string | number>): Promise<AxiosResponse> {
|
||||
const response = await axios.get(
|
||||
`${this._url}${path}`,
|
||||
{
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': this._apiKey(),
|
||||
...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {})
|
||||
},
|
||||
params,
|
||||
timeout: 10000
|
||||
}
|
||||
).catch((err) => {
|
||||
public async get (path: string, params?: Record<string, string>): Promise<RemoteResponse> {
|
||||
const req = {
|
||||
uri: `${this._url}${path}`,
|
||||
method: Method.GET,
|
||||
headers: new Map(Object.entries({
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': this._apiKey(),
|
||||
...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {})
|
||||
})),
|
||||
params: new Map(Object.entries(params ?? {}))
|
||||
}
|
||||
|
||||
let response
|
||||
try {
|
||||
response = await callWithMiddlewares(req, this._middlewares)
|
||||
return response
|
||||
} catch (err: any) {
|
||||
console.error('error: ', err)
|
||||
if (err.response === undefined) {
|
||||
throw new Error(`Network Error: ${err.message as string}`)
|
||||
}
|
||||
return err.response
|
||||
})
|
||||
|
||||
response = toLanceRes(err.response)
|
||||
}
|
||||
|
||||
if (response.status !== 200) {
|
||||
const errorData = new TextDecoder().decode(response.data)
|
||||
const errorData = new TextDecoder().decode(await response.body())
|
||||
throw new Error(
|
||||
`Server Error, status: ${response.status as number}, ` +
|
||||
`message: ${response.statusText as string}: ${errorData}`
|
||||
`Server Error, status: ${response.status}, ` +
|
||||
`message: ${response.statusText}: ${errorData}`
|
||||
)
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
@@ -120,35 +196,65 @@ export class HttpLancedbClient {
|
||||
public async post (
|
||||
path: string,
|
||||
data?: any,
|
||||
params?: Record<string, string | number>,
|
||||
content?: string | undefined
|
||||
): Promise<AxiosResponse> {
|
||||
const response = await axios.post(
|
||||
`${this._url}${path}`,
|
||||
data,
|
||||
{
|
||||
headers: {
|
||||
'Content-Type': content ?? 'application/json',
|
||||
'x-api-key': this._apiKey(),
|
||||
...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {})
|
||||
},
|
||||
params,
|
||||
timeout: 30000
|
||||
}
|
||||
).catch((err) => {
|
||||
params?: Record<string, string>,
|
||||
content?: string | undefined,
|
||||
responseType?: ResponseType | undefined
|
||||
): Promise<RemoteResponse> {
|
||||
const req = {
|
||||
uri: `${this._url}${path}`,
|
||||
method: Method.POST,
|
||||
headers: new Map(Object.entries({
|
||||
'Content-Type': content ?? 'application/json',
|
||||
'x-api-key': this._apiKey(),
|
||||
...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {})
|
||||
})),
|
||||
params: new Map(Object.entries(params ?? {})),
|
||||
body: data
|
||||
}
|
||||
|
||||
let response
|
||||
try {
|
||||
response = await callWithMiddlewares(req, this._middlewares, { responseType })
|
||||
|
||||
// return response
|
||||
} catch (err: any) {
|
||||
console.error('error: ', err)
|
||||
if (err.response === undefined) {
|
||||
throw new Error(`Network Error: ${err.message as string}`)
|
||||
}
|
||||
return err.response
|
||||
})
|
||||
response = toLanceRes(err.response)
|
||||
}
|
||||
|
||||
if (response.status !== 200) {
|
||||
const errorData = new TextDecoder().decode(response.data)
|
||||
const errorData = new TextDecoder().decode(await response.body())
|
||||
throw new Error(
|
||||
`Server Error, status: ${response.status as number}, ` +
|
||||
`message: ${response.statusText as string}: ${errorData}`
|
||||
`Server Error, status: ${response.status}, ` +
|
||||
`message: ${response.statusText}: ${errorData}`
|
||||
)
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
/**
|
||||
* Instrument this client with middleware
|
||||
* @param mw - The middleware that instruments the client
|
||||
* @returns - an instance of this client instrumented with the middleware
|
||||
*/
|
||||
public withMiddleware (mw: HttpLancedbClientMiddleware): HttpLancedbClient {
|
||||
const wrapped = this.clone()
|
||||
wrapped._middlewares.push(mw)
|
||||
return wrapped
|
||||
}
|
||||
|
||||
/**
|
||||
* Make a clone of this client
|
||||
*/
|
||||
private clone (): HttpLancedbClient {
|
||||
const clone = new HttpLancedbClient(this._url, this._apiKey(), this._dbName)
|
||||
for (const mw of this._middlewares) {
|
||||
clone._middlewares.push(mw)
|
||||
}
|
||||
return clone
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,12 +39,13 @@ import {
|
||||
fromTableToStreamBuffer
|
||||
} from '../arrow'
|
||||
import { toSQL } from '../util'
|
||||
import { type HttpMiddleware } from '../middleware'
|
||||
|
||||
/**
|
||||
* Remote connection.
|
||||
*/
|
||||
export class RemoteConnection implements Connection {
|
||||
private readonly _client: HttpLancedbClient
|
||||
private _client: HttpLancedbClient
|
||||
private readonly _dbName: string
|
||||
|
||||
constructor (opts: ConnectionOptions) {
|
||||
@@ -84,10 +85,11 @@ export class RemoteConnection implements Connection {
|
||||
limit: number = 10
|
||||
): Promise<string[]> {
|
||||
const response = await this._client.get('/v1/table/', {
|
||||
limit,
|
||||
limit: `${limit}`,
|
||||
page_token: pageToken
|
||||
})
|
||||
return response.data.tables
|
||||
const body = await response.body()
|
||||
return body.tables
|
||||
}
|
||||
|
||||
async openTable (name: string): Promise<Table>
|
||||
@@ -163,7 +165,7 @@ export class RemoteConnection implements Connection {
|
||||
throw new Error(
|
||||
`Server Error, status: ${res.status}, ` +
|
||||
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
|
||||
`message: ${res.statusText}: ${res.data}`
|
||||
`message: ${res.statusText}: ${await res.body()}`
|
||||
)
|
||||
}
|
||||
|
||||
@@ -177,6 +179,17 @@ export class RemoteConnection implements Connection {
|
||||
async dropTable (name: string): Promise<void> {
|
||||
await this._client.post(`/v1/table/${name}/drop/`)
|
||||
}
|
||||
|
||||
withMiddleware (middleware: HttpMiddleware): Connection {
|
||||
const wrapped = this.clone()
|
||||
wrapped._client = wrapped._client.withMiddleware(middleware)
|
||||
return wrapped
|
||||
}
|
||||
|
||||
private clone (): RemoteConnection {
|
||||
const clone: RemoteConnection = Object.create(RemoteConnection.prototype)
|
||||
return Object.assign(clone, this)
|
||||
}
|
||||
}
|
||||
|
||||
export class RemoteQuery<T = number[]> extends Query<T> {
|
||||
@@ -229,7 +242,7 @@ export class RemoteQuery<T = number[]> extends Query<T> {
|
||||
// we are using extend until we have next next version release
|
||||
// Table and Connection has both been refactored to interfaces
|
||||
export class RemoteTable<T = number[]> implements Table<T> {
|
||||
private readonly _client: HttpLancedbClient
|
||||
private _client: HttpLancedbClient
|
||||
private readonly _embeddings?: EmbeddingFunction<T>
|
||||
private readonly _name: string
|
||||
|
||||
@@ -256,15 +269,15 @@ export class RemoteTable<T = number[]> implements Table<T> {
|
||||
get schema (): Promise<any> {
|
||||
return this._client
|
||||
.post(`/v1/table/${this._name}/describe/`)
|
||||
.then((res) => {
|
||||
.then(async (res) => {
|
||||
if (res.status !== 200) {
|
||||
throw new Error(
|
||||
`Server Error, status: ${res.status}, ` +
|
||||
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
|
||||
`message: ${res.statusText}: ${res.data}`
|
||||
`message: ${res.statusText}: ${await res.body()}`
|
||||
)
|
||||
}
|
||||
return res.data?.schema
|
||||
return (await res.body())?.schema
|
||||
})
|
||||
}
|
||||
|
||||
@@ -320,7 +333,7 @@ export class RemoteTable<T = number[]> implements Table<T> {
|
||||
throw new Error(
|
||||
`Server Error, status: ${res.status}, ` +
|
||||
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
|
||||
`message: ${res.statusText}: ${res.data}`
|
||||
`message: ${res.statusText}: ${await res.body()}`
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -346,7 +359,7 @@ export class RemoteTable<T = number[]> implements Table<T> {
|
||||
throw new Error(
|
||||
`Server Error, status: ${res.status}, ` +
|
||||
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
|
||||
`message: ${res.statusText}: ${res.data}`
|
||||
`message: ${res.statusText}: ${await res.body()}`
|
||||
)
|
||||
}
|
||||
return tbl.numRows
|
||||
@@ -372,7 +385,7 @@ export class RemoteTable<T = number[]> implements Table<T> {
|
||||
throw new Error(
|
||||
`Server Error, status: ${res.status}, ` +
|
||||
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
|
||||
`message: ${res.statusText}: ${res.data}`
|
||||
`message: ${res.statusText}: ${await res.body()}`
|
||||
)
|
||||
}
|
||||
return tbl.numRows
|
||||
@@ -415,7 +428,7 @@ export class RemoteTable<T = number[]> implements Table<T> {
|
||||
throw new Error(
|
||||
`Server Error, status: ${res.status}, ` +
|
||||
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
|
||||
`message: ${res.statusText}: ${res.data}`
|
||||
`message: ${res.statusText}: ${await res.body()}`
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -436,14 +449,14 @@ export class RemoteTable<T = number[]> implements Table<T> {
|
||||
throw new Error(
|
||||
`Server Error, status: ${res.status}, ` +
|
||||
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
|
||||
`message: ${res.statusText}: ${res.data}`
|
||||
`message: ${res.statusText}: ${await res.body()}`
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
async countRows (): Promise<number> {
|
||||
const result = await this._client.post(`/v1/table/${this._name}/describe/`)
|
||||
return result.data?.stats?.num_rows
|
||||
return (await result.body())?.stats?.num_rows
|
||||
}
|
||||
|
||||
async delete (filter: string): Promise<void> {
|
||||
@@ -476,7 +489,7 @@ export class RemoteTable<T = number[]> implements Table<T> {
|
||||
const results = await this._client.post(
|
||||
`/v1/table/${this._name}/index/list/`
|
||||
)
|
||||
return results.data.indexes?.map((index: any) => ({
|
||||
return (await results.body()).indexes?.map((index: any) => ({
|
||||
columns: index.columns,
|
||||
name: index.index_name,
|
||||
uuid: index.index_uuid
|
||||
@@ -487,9 +500,10 @@ export class RemoteTable<T = number[]> implements Table<T> {
|
||||
const results = await this._client.post(
|
||||
`/v1/table/${this._name}/index/${indexUuid}/stats/`
|
||||
)
|
||||
const body = await results.body()
|
||||
return {
|
||||
numIndexedRows: results.data.num_indexed_rows,
|
||||
numUnindexedRows: results.data.num_unindexed_rows
|
||||
numIndexedRows: body?.num_indexed_rows,
|
||||
numUnindexedRows: body?.num_unindexed_rows
|
||||
}
|
||||
}
|
||||
|
||||
@@ -504,4 +518,15 @@ export class RemoteTable<T = number[]> implements Table<T> {
|
||||
async dropColumns (columnNames: string[]): Promise<void> {
|
||||
throw new Error('Drop columns is not yet supported in LanceDB Cloud.')
|
||||
}
|
||||
|
||||
withMiddleware(middleware: HttpMiddleware): Table<T> {
|
||||
const wrapped = this.clone()
|
||||
wrapped._client = wrapped._client.withMiddleware(middleware)
|
||||
return wrapped
|
||||
}
|
||||
|
||||
private clone (): RemoteTable<T> {
|
||||
const clone: RemoteTable<T> = Object.create(RemoteTable.prototype)
|
||||
return Object.assign(clone, this)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -240,7 +240,7 @@ describe("When creating an index", () => {
|
||||
)
|
||||
.column("vec")
|
||||
.toArrow(),
|
||||
).rejects.toThrow(/.*does not match the dimension.*/);
|
||||
).rejects.toThrow(/.* query dim=64, expected vector dim=32.*/);
|
||||
|
||||
const query64 = Array(64)
|
||||
.fill(1)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 0.6.4
|
||||
current_version = 0.6.5
|
||||
commit = True
|
||||
message = [python] Bump version: {current_version} → {new_version}
|
||||
tag = True
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
[project]
|
||||
name = "lancedb"
|
||||
version = "0.6.4"
|
||||
version = "0.6.5"
|
||||
dependencies = [
|
||||
"deprecation",
|
||||
"pylance==0.10.4",
|
||||
"pylance==0.10.5",
|
||||
"ratelimiter~=1.0",
|
||||
"retry>=0.9.2",
|
||||
"tqdm>=4.27.0",
|
||||
@@ -94,13 +94,11 @@ lancedb = "lancedb.cli.cli:cli"
|
||||
requires = ["maturin>=1.4"]
|
||||
build-backend = "maturin"
|
||||
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["F", "E", "W", "I", "G", "TCH", "PERF"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--strict-markers --ignore-glob=lancedb/embeddings/*.py"
|
||||
|
||||
markers = [
|
||||
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
||||
"asyncio",
|
||||
|
||||
@@ -31,7 +31,13 @@ from lancedb.utils.events import register_event
|
||||
from ._lancedb import connect as lancedb_connect
|
||||
from .pydantic import LanceModel
|
||||
from .table import AsyncTable, LanceTable, Table, _sanitize_data
|
||||
from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri
|
||||
from .util import (
|
||||
fs_from_uri,
|
||||
get_uri_location,
|
||||
get_uri_scheme,
|
||||
join_uri,
|
||||
validate_table_name,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import timedelta
|
||||
@@ -387,6 +393,7 @@ class LanceDBConnection(DBConnection):
|
||||
"""
|
||||
if mode.lower() not in ["create", "overwrite"]:
|
||||
raise ValueError("mode must be either 'create' or 'overwrite'")
|
||||
validate_table_name(name)
|
||||
|
||||
tbl = LanceTable.create(
|
||||
self,
|
||||
|
||||
@@ -26,6 +26,7 @@ from ..db import DBConnection
|
||||
from ..embeddings import EmbeddingFunctionConfig
|
||||
from ..pydantic import LanceModel
|
||||
from ..table import Table, _sanitize_data
|
||||
from ..util import validate_table_name
|
||||
from .arrow import to_ipc_binary
|
||||
from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient
|
||||
from .errors import LanceDBClientError
|
||||
@@ -223,6 +224,7 @@ class RemoteDBConnection(DBConnection):
|
||||
LanceTable(table4)
|
||||
|
||||
"""
|
||||
validate_table_name(name)
|
||||
if data is None and schema is None:
|
||||
raise ValueError("Either data or schema must be provided.")
|
||||
if embedding_functions is not None:
|
||||
|
||||
@@ -25,6 +25,8 @@ import numpy as np
|
||||
import pyarrow as pa
|
||||
import pyarrow.fs as pa_fs
|
||||
|
||||
from ._lancedb import validate_table_name as native_validate_table_name
|
||||
|
||||
|
||||
def safe_import_adlfs():
|
||||
try:
|
||||
@@ -286,3 +288,8 @@ def deprecated(func):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return new_func
|
||||
|
||||
|
||||
def validate_table_name(name: str):
|
||||
"""Verify the table name is valid."""
|
||||
native_validate_table_name(name)
|
||||
|
||||
@@ -521,3 +521,15 @@ def test_prefilter_with_index(tmp_path):
|
||||
.to_arrow()
|
||||
)
|
||||
assert table.num_rows == 1
|
||||
|
||||
|
||||
def test_create_table_with_invalid_names(tmp_path):
|
||||
db = lancedb.connect(uri=tmp_path)
|
||||
data = [{"vector": np.random.rand(128), "item": "foo"} for i in range(10)]
|
||||
with pytest.raises(ValueError):
|
||||
db.create_table("foo/bar", data)
|
||||
with pytest.raises(ValueError):
|
||||
db.create_table("foo bar", data)
|
||||
with pytest.raises(ValueError):
|
||||
db.create_table("foo$$bar", data)
|
||||
db.create_table("foo.bar", data)
|
||||
|
||||
@@ -42,6 +42,7 @@ pub fn _lancedb(_py: Python, m: &PyModule) -> PyResult<()> {
|
||||
m.add_class::<VectorQuery>()?;
|
||||
m.add_class::<RecordBatchStream>()?;
|
||||
m.add_function(wrap_pyfunction!(connect, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?;
|
||||
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::sync::Mutex;
|
||||
use lancedb::DistanceType;
|
||||
use pyo3::{
|
||||
exceptions::{PyRuntimeError, PyValueError},
|
||||
PyResult,
|
||||
pyfunction, PyResult,
|
||||
};
|
||||
|
||||
/// A wrapper around a rust builder
|
||||
@@ -49,3 +49,9 @@ pub fn parse_distance_type(distance_type: impl AsRef<str>) -> PyResult<DistanceT
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
pub(crate) fn validate_table_name(table_name: &str) -> PyResult<()> {
|
||||
lancedb::utils::validate_table_name(table_name)
|
||||
.map_err(|e| PyValueError::new_err(e.to_string()))
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ chrono = { workspace = true }
|
||||
object_store = { workspace = true }
|
||||
snafu = { workspace = true }
|
||||
half = { workspace = true }
|
||||
lazy_static.workspace = true
|
||||
lance = { workspace = true }
|
||||
lance-index = { workspace = true }
|
||||
lance-linalg = { workspace = true }
|
||||
@@ -34,11 +35,10 @@ bytes = "1"
|
||||
futures.workspace = true
|
||||
num-traits.workspace = true
|
||||
url.workspace = true
|
||||
regex.workspace = true
|
||||
serde = { version = "^1" }
|
||||
serde_json = { version = "1" }
|
||||
|
||||
# For remote feature
|
||||
|
||||
reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
@@ -31,6 +31,7 @@ use crate::arrow::IntoArrow;
|
||||
use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result};
|
||||
use crate::io::object_store::MirroringObjectStoreWrapper;
|
||||
use crate::table::{NativeTable, WriteOptions};
|
||||
use crate::utils::validate_table_name;
|
||||
use crate::Table;
|
||||
|
||||
pub const LANCE_FILE_EXTENSION: &str = "lance";
|
||||
@@ -675,13 +676,18 @@ impl Database {
|
||||
|
||||
/// Get the URI of a table in the database.
|
||||
fn table_uri(&self, name: &str) -> Result<String> {
|
||||
validate_table_name(name)?;
|
||||
|
||||
let path = Path::new(&self.uri);
|
||||
let table_uri = path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION));
|
||||
|
||||
let mut uri = table_uri
|
||||
.as_path()
|
||||
.to_str()
|
||||
.context(InvalidTableNameSnafu { name })?
|
||||
.context(InvalidTableNameSnafu {
|
||||
name,
|
||||
reason: "Name is not valid URL",
|
||||
})?
|
||||
.to_string();
|
||||
|
||||
// If there are query string set on the connection, propagate to lance
|
||||
|
||||
@@ -20,8 +20,8 @@ use snafu::Snafu;
|
||||
#[derive(Debug, Snafu)]
|
||||
#[snafu(visibility(pub(crate)))]
|
||||
pub enum Error {
|
||||
#[snafu(display("Invalid table name: {name}"))]
|
||||
InvalidTableName { name: String },
|
||||
#[snafu(display("Invalid table name (\"{name}\"): {reason}"))]
|
||||
InvalidTableName { name: String, reason: String },
|
||||
#[snafu(display("Invalid input, {message}"))]
|
||||
InvalidInput { message: String },
|
||||
#[snafu(display("Table '{name}' was not found"))]
|
||||
|
||||
@@ -230,9 +230,9 @@ pub enum DistanceType {
|
||||
impl From<DistanceType> for LanceDistanceType {
|
||||
fn from(value: DistanceType) -> Self {
|
||||
match value {
|
||||
DistanceType::L2 => LanceDistanceType::L2,
|
||||
DistanceType::Cosine => LanceDistanceType::Cosine,
|
||||
DistanceType::Dot => LanceDistanceType::Dot,
|
||||
DistanceType::L2 => Self::L2,
|
||||
DistanceType::Cosine => Self::Cosine,
|
||||
DistanceType::Dot => Self::Dot,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -240,9 +240,9 @@ impl From<DistanceType> for LanceDistanceType {
|
||||
impl From<LanceDistanceType> for DistanceType {
|
||||
fn from(value: LanceDistanceType) -> Self {
|
||||
match value {
|
||||
LanceDistanceType::L2 => DistanceType::L2,
|
||||
LanceDistanceType::Cosine => DistanceType::Cosine,
|
||||
LanceDistanceType::Dot => DistanceType::Dot,
|
||||
LanceDistanceType::L2 => Self::L2,
|
||||
LanceDistanceType::Cosine => Self::Cosine,
|
||||
LanceDistanceType::Dot => Self::Dot,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -251,7 +251,7 @@ impl<'a> TryFrom<&'a str> for DistanceType {
|
||||
type Error = <LanceDistanceType as TryFrom<&'a str>>::Error;
|
||||
|
||||
fn try_from(value: &str) -> std::prelude::v1::Result<Self, Self::Error> {
|
||||
LanceDistanceType::try_from(value).map(DistanceType::from)
|
||||
LanceDistanceType::try_from(value).map(Self::from)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -854,6 +854,7 @@ impl NativeTable {
|
||||
.to_str()
|
||||
.ok_or(Error::InvalidTableName {
|
||||
name: uri.to_string(),
|
||||
reason: "Table name is not valid URL".to_string(),
|
||||
})?;
|
||||
Ok(name.to_string())
|
||||
}
|
||||
@@ -1185,15 +1186,26 @@ impl NativeTable {
|
||||
let field = ds_ref.schema().field(&column).ok_or(Error::Schema {
|
||||
message: format!("Column {} not found in dataset schema", column),
|
||||
})?;
|
||||
if !matches!(field.data_type(), arrow_schema::DataType::FixedSizeList(f, dim) if f.data_type().is_floating() && dim == query_vector.len() as i32)
|
||||
{
|
||||
return Err(Error::Schema {
|
||||
message: format!(
|
||||
"Vector column '{}' does not match the dimension of the query vector: dim={}",
|
||||
column,
|
||||
query_vector.len(),
|
||||
),
|
||||
});
|
||||
if let arrow_schema::DataType::FixedSizeList(f, dim) = field.data_type() {
|
||||
if !f.data_type().is_floating() {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!(
|
||||
"The data type of the vector column '{}' is not a floating point type",
|
||||
column
|
||||
),
|
||||
});
|
||||
}
|
||||
if dim != query_vector.len() as i32 {
|
||||
return Err(Error::InvalidInput {
|
||||
message: format!(
|
||||
"The dimension of the query vector does not match with the dimension of the vector column '{}':
|
||||
query dim={}, expected vector dim={}",
|
||||
column,
|
||||
query_vector.len(),
|
||||
dim,
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
let query_vector = query_vector.as_primitive::<Float32Type>();
|
||||
scanner.nearest(
|
||||
|
||||
@@ -1,12 +1,30 @@
|
||||
// Copyright 2024 LanceDB Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_schema::Schema;
|
||||
|
||||
use lance::dataset::{ReadParams, WriteParams};
|
||||
use lance::io::{ObjectStoreParams, WrappingObjectStore};
|
||||
use lazy_static::lazy_static;
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
|
||||
lazy_static! {
|
||||
static ref TABLE_NAME_REGEX: regex::Regex = regex::Regex::new(r"^[a-zA-Z0-9_\-\.]+$").unwrap();
|
||||
}
|
||||
|
||||
pub trait PatchStoreParam {
|
||||
fn patch_with_store_wrapper(
|
||||
self,
|
||||
@@ -64,6 +82,25 @@ impl PatchReadParam for ReadParams {
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate table name.
|
||||
pub fn validate_table_name(name: &str) -> Result<()> {
|
||||
if name.is_empty() {
|
||||
return Err(Error::InvalidTableName {
|
||||
name: name.to_string(),
|
||||
reason: "Table names cannot be empty strings".to_string(),
|
||||
});
|
||||
}
|
||||
if !TABLE_NAME_REGEX.is_match(name) {
|
||||
return Err(Error::InvalidTableName {
|
||||
name: name.to_string(),
|
||||
reason:
|
||||
"Table names can only contain alphanumeric characters, underscores, hyphens, and periods"
|
||||
.to_string(),
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Find one default column to create index.
|
||||
pub(crate) fn default_vector_column(schema: &Schema, dim: Option<i32>) -> Result<String> {
|
||||
// Try to find one fixed size list array column.
|
||||
@@ -145,4 +182,20 @@ mod tests {
|
||||
.to_string()
|
||||
.contains("More than one"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_table_name() {
|
||||
assert!(validate_table_name("my_table").is_ok());
|
||||
assert!(validate_table_name("my_table_1").is_ok());
|
||||
assert!(validate_table_name("123mytable").is_ok());
|
||||
assert!(validate_table_name("_12345table").is_ok());
|
||||
assert!(validate_table_name("table.12345").is_ok());
|
||||
assert!(validate_table_name("table.._dot_..12345").is_ok());
|
||||
|
||||
assert!(validate_table_name("").is_err());
|
||||
assert!(validate_table_name("my_table!").is_err());
|
||||
assert!(validate_table_name("my/table").is_err());
|
||||
assert!(validate_table_name("my@table").is_err());
|
||||
assert!(validate_table_name("name with space").is_err());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user