Merge branch 'main' of https://github.com/lancedb/lancedb into docs_march

This commit is contained in:
ayush chaurasia
2024-03-23 20:35:03 +05:30
21 changed files with 534 additions and 138 deletions

View File

@@ -14,10 +14,10 @@ keywords = ["lancedb", "lance", "database", "vector", "search"]
categories = ["database-implementations"]
[workspace.dependencies]
lance = { "version" = "=0.10.4", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.10.4" }
lance-linalg = { "version" = "=0.10.4" }
lance-testing = { "version" = "=0.10.4" }
lance = { "version" = "=0.10.5", "features" = ["dynamodb"] }
lance-index = { "version" = "=0.10.5" }
lance-linalg = { "version" = "=0.10.5" }
lance-testing = { "version" = "=0.10.5" }
# Note that this one does not include pyarrow
arrow = { version = "50.0", optional = false }
arrow-array = "50.0"
@@ -39,3 +39,5 @@ pin-project = "1.0.7"
snafu = "0.7.4"
url = "2"
num-traits = "0.2"
regex = "1.10"
lazy_static = "1"

View File

@@ -1,11 +1,79 @@
document.addEventListener("DOMContentLoaded", function () {
var script = document.createElement("script");
script.src = "https://widget.kapa.ai/kapa-widget.bundle.js";
script.setAttribute("data-website-id", "c5881fae-cec0-490b-b45e-d83d131d4f25");
script.setAttribute("data-project-name", "LanceDB");
script.setAttribute("data-project-color", "#000000");
script.setAttribute("data-project-logo", "https://avatars.githubusercontent.com/u/108903835?s=200&v=4");
script.setAttribute("data-modal-example-questions","Help me create an IVF_PQ index,How do I do an exhaustive search?,How do I create a LanceDB table?,Can I use my own embedding function?");
script.async = true;
document.head.appendChild(script);
});
// Creates an SVG robot icon (from Lucide)
function robotSVG() {
var svg = document.createElementNS("http://www.w3.org/2000/svg", "svg");
svg.setAttribute("width", "24");
svg.setAttribute("height", "24");
svg.setAttribute("viewBox", "0 0 24 24");
svg.setAttribute("fill", "none");
svg.setAttribute("stroke", "currentColor");
svg.setAttribute("stroke-width", "2");
svg.setAttribute("stroke-linecap", "round");
svg.setAttribute("stroke-linejoin", "round");
svg.setAttribute("class", "lucide lucide-bot-message-square");
var path1 = document.createElementNS("http://www.w3.org/2000/svg", "path");
path1.setAttribute("d", "M12 6V2H8");
svg.appendChild(path1);
var path2 = document.createElementNS("http://www.w3.org/2000/svg", "path");
path2.setAttribute("d", "m8 18-4 4V8a2 2 0 0 1 2-2h12a2 2 0 0 1 2 2v8a2 2 0 0 1-2 2Z");
svg.appendChild(path2);
var path3 = document.createElementNS("http://www.w3.org/2000/svg", "path");
path3.setAttribute("d", "M2 12h2");
svg.appendChild(path3);
var path4 = document.createElementNS("http://www.w3.org/2000/svg", "path");
path4.setAttribute("d", "M9 11v2");
svg.appendChild(path4);
var path5 = document.createElementNS("http://www.w3.org/2000/svg", "path");
path5.setAttribute("d", "M15 11v2");
svg.appendChild(path5);
var path6 = document.createElementNS("http://www.w3.org/2000/svg", "path");
path6.setAttribute("d", "M20 12h2");
svg.appendChild(path6);
return svg
}
// Creates the Fluidic Chatbot buttom
function fluidicButton() {
var btn = document.createElement("a");
btn.href = "https://asklancedb.com";
btn.target = "_blank";
btn.style.position = "fixed";
btn.style.fontWeight = "bold";
btn.style.fontSize = ".8rem";
btn.style.right = "10px";
btn.style.bottom = "10px";
btn.style.width = "80px";
btn.style.height = "80px";
btn.style.background = "linear-gradient(135deg, #7C5EFF 0%, #625eff 100%)";
btn.style.color = "white";
btn.style.borderRadius = "5px";
btn.style.display = "flex";
btn.style.flexDirection = "column";
btn.style.justifyContent = "center";
btn.style.alignItems = "center";
btn.style.zIndex = "1000";
btn.style.opacity = "0";
btn.style.boxShadow = "0 0 0 rgba(0, 0, 0, 0)";
btn.style.transition = "opacity 0.2s ease-in, box-shadow 0.2s ease-in";
setTimeout(function() {
btn.style.opacity = "1";
btn.style.boxShadow = "0 0 .2rem #0000001a,0 .2rem .4rem #0003"
}, 0);
return btn
}
document.addEventListener("DOMContentLoaded", function() {
var btn = fluidicButton()
btn.appendChild(robotSVG());
var text = document.createTextNode("Ask AI");
btn.appendChild(text);
document.body.appendChild(btn);
});

View File

@@ -24,6 +24,7 @@ import { RemoteConnection } from './remote'
import { Query } from './query'
import { isEmbeddingFunction } from './embedding/embedding_function'
import { type Literal, toSQL } from './util'
import { type HttpMiddleware } from './middleware'
const {
databaseNew,
@@ -302,6 +303,18 @@ export interface Connection {
* @param name The name of the table to drop.
*/
dropTable(name: string): Promise<void>
/**
* Instrument the behavior of this Connection with middleware.
*
* The middleware will be called in the order they are added.
*
* Currently this functionality is only supported for remote Connections.
*
* @param {HttpMiddleware} - Middleware which will instrument the Connection.
* @returns - this Connection instrumented by the passed middleware
*/
withMiddleware(middleware: HttpMiddleware): Connection
}
/**
@@ -541,6 +554,18 @@ export interface Table<T = number[]> {
* names (e.g. "a").
*/
dropColumns(columnNames: string[]): Promise<void>
/**
* Instrument the behavior of this Table with middleware.
*
* The middleware will be called in the order they are added.
*
* Currently this functionality is only supported for remote tables.
*
* @param {HttpMiddleware} - Middleware which will instrument the Table.
* @returns - this Table instrumented by the passed middleware
*/
withMiddleware(middleware: HttpMiddleware): Table<T>
}
/**
@@ -795,6 +820,10 @@ export class LocalConnection implements Connection {
async dropTable (name: string): Promise<void> {
await databaseDropTable.call(this._db, name)
}
withMiddleware (middleware: HttpMiddleware): Connection {
return this
}
}
export class LocalTable<T = number[]> implements Table<T> {
@@ -1105,6 +1134,10 @@ export class LocalTable<T = number[]> implements Table<T> {
async dropColumns (columnNames: string[]): Promise<void> {
return tableDropColumns.call(this._tbl, columnNames)
}
withMiddleware (middleware: HttpMiddleware): Table<T> {
return this
}
}
export interface CleanupStats {

58
node/src/middleware.ts Normal file
View File

@@ -0,0 +1,58 @@
// Copyright 2024 LanceDB Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/**
* Middleware for Remote LanceDB Connection or Table
*/
export interface HttpMiddleware {
/**
* A callback that can be used to instrument the behavior of http requests to remote
* tables. It can be used to add headers, modify the request, or even short-circuit
* the request and return a response without making the request to the remote endpoint.
* It can also be used to modify the response from the remote endpoint.
*
* @param {RemoteResponse} res - Request to the remote endpoint
* @param {onRemoteRequestNext} next - Callback to advance the middleware chain
*/
onRemoteRequest(
req: RemoteRequest,
next: (req: RemoteRequest) => Promise<RemoteResponse>,
): Promise<RemoteResponse>
};
export enum Method {
GET,
POST
}
/**
* A LanceDB Remote HTTP Request
*/
export interface RemoteRequest {
uri: string
method: Method
headers: Map<string, string>
params?: Map<string, string>
body?: any
}
/**
* A LanceDB Remote HTTP Response
*/
export interface RemoteResponse {
status: number
statusText: string
headers: Map<string, string>
body: () => Promise<any>
}

View File

@@ -12,13 +12,101 @@
// See the License for the specific language governing permissions and
// limitations under the License.
import axios, { type AxiosResponse } from 'axios'
import axios, { type AxiosResponse, type ResponseType } from 'axios'
import { tableFromIPC, type Table as ArrowTable } from 'apache-arrow'
import { type RemoteResponse, type RemoteRequest, Method } from '../middleware'
interface HttpLancedbClientMiddleware {
onRemoteRequest(
req: RemoteRequest,
next: (req: RemoteRequest) => Promise<RemoteResponse>,
): Promise<RemoteResponse>
}
/**
* Invoke the middleware chain and at the end call the remote endpoint
*/
async function callWithMiddlewares (
req: RemoteRequest,
middlewares: HttpLancedbClientMiddleware[],
opts?: MiddlewareInvocationOptions
): Promise<RemoteResponse> {
async function call (
i: number,
req: RemoteRequest
): Promise<RemoteResponse> {
// if we have reached the end of the middleware chain, make the request
if (i > middlewares.length) {
const headers = Object.fromEntries(req.headers.entries())
const params = Object.fromEntries(req.params?.entries() ?? [])
const timeout = 10000
let res
if (req.method === Method.POST) {
res = await axios.post(
req.uri,
req.body,
{
headers,
params,
timeout,
responseType: opts?.responseType
}
)
} else {
res = await axios.get(
req.uri,
{
headers,
params,
timeout
}
)
}
return toLanceRes(res)
}
// call next middleware in chain
return await middlewares[i - 1].onRemoteRequest(
req,
async (req) => {
return await call(i + 1, req)
}
)
}
return await call(1, req)
}
interface MiddlewareInvocationOptions {
responseType?: ResponseType
}
/**
* Marshall the library response into a LanceDB response
*/
function toLanceRes (res: AxiosResponse): RemoteResponse {
const headers = new Map()
for (const h in res.headers) {
headers.set(h, res.headers[h])
}
return {
status: res.status,
statusText: res.statusText,
headers,
body: async () => {
return res.data
}
}
}
export class HttpLancedbClient {
private readonly _url: string
private readonly _apiKey: () => string
private readonly _middlewares: HttpLancedbClientMiddleware[]
public constructor (
url: string,
@@ -27,6 +115,7 @@ export class HttpLancedbClient {
) {
this._url = url
this._apiKey = () => apiKey
this._middlewares = []
}
get uri (): string {
@@ -43,74 +132,61 @@ export class HttpLancedbClient {
columns?: string[],
filter?: string
): Promise<ArrowTable<any>> {
const response = await axios.post(
`${this._url}/v1/table/${tableName}/query/`,
{
vector,
k,
nprobes,
refineFactor,
columns,
filter,
prefilter
},
{
headers: {
'Content-Type': 'application/json',
'x-api-key': this._apiKey(),
...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {})
},
responseType: 'arraybuffer',
timeout: 10000
}
).catch((err) => {
console.error('error: ', err)
if (err.response === undefined) {
throw new Error(`Network Error: ${err.message as string}`)
}
return err.response
})
if (response.status !== 200) {
const errorData = new TextDecoder().decode(response.data)
throw new Error(
`Server Error, status: ${response.status as number}, ` +
`message: ${response.statusText as string}: ${errorData}`
)
}
const table = tableFromIPC(response.data)
const result = await this.post(
`/v1/table/${tableName}/query/`,
{
vector,
k,
nprobes,
refineFactor,
columns,
filter,
prefilter
},
undefined,
undefined,
'arraybuffer'
)
const table = tableFromIPC(await result.body())
return table
}
/**
* Sent GET request.
*/
public async get (path: string, params?: Record<string, string | number>): Promise<AxiosResponse> {
const response = await axios.get(
`${this._url}${path}`,
{
headers: {
'Content-Type': 'application/json',
'x-api-key': this._apiKey(),
...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {})
},
params,
timeout: 10000
}
).catch((err) => {
public async get (path: string, params?: Record<string, string>): Promise<RemoteResponse> {
const req = {
uri: `${this._url}${path}`,
method: Method.GET,
headers: new Map(Object.entries({
'Content-Type': 'application/json',
'x-api-key': this._apiKey(),
...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {})
})),
params: new Map(Object.entries(params ?? {}))
}
let response
try {
response = await callWithMiddlewares(req, this._middlewares)
return response
} catch (err: any) {
console.error('error: ', err)
if (err.response === undefined) {
throw new Error(`Network Error: ${err.message as string}`)
}
return err.response
})
response = toLanceRes(err.response)
}
if (response.status !== 200) {
const errorData = new TextDecoder().decode(response.data)
const errorData = new TextDecoder().decode(await response.body())
throw new Error(
`Server Error, status: ${response.status as number}, ` +
`message: ${response.statusText as string}: ${errorData}`
`Server Error, status: ${response.status}, ` +
`message: ${response.statusText}: ${errorData}`
)
}
return response
}
@@ -120,35 +196,65 @@ export class HttpLancedbClient {
public async post (
path: string,
data?: any,
params?: Record<string, string | number>,
content?: string | undefined
): Promise<AxiosResponse> {
const response = await axios.post(
`${this._url}${path}`,
data,
{
headers: {
'Content-Type': content ?? 'application/json',
'x-api-key': this._apiKey(),
...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {})
},
params,
timeout: 30000
}
).catch((err) => {
params?: Record<string, string>,
content?: string | undefined,
responseType?: ResponseType | undefined
): Promise<RemoteResponse> {
const req = {
uri: `${this._url}${path}`,
method: Method.POST,
headers: new Map(Object.entries({
'Content-Type': content ?? 'application/json',
'x-api-key': this._apiKey(),
...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {})
})),
params: new Map(Object.entries(params ?? {})),
body: data
}
let response
try {
response = await callWithMiddlewares(req, this._middlewares, { responseType })
// return response
} catch (err: any) {
console.error('error: ', err)
if (err.response === undefined) {
throw new Error(`Network Error: ${err.message as string}`)
}
return err.response
})
response = toLanceRes(err.response)
}
if (response.status !== 200) {
const errorData = new TextDecoder().decode(response.data)
const errorData = new TextDecoder().decode(await response.body())
throw new Error(
`Server Error, status: ${response.status as number}, ` +
`message: ${response.statusText as string}: ${errorData}`
`Server Error, status: ${response.status}, ` +
`message: ${response.statusText}: ${errorData}`
)
}
return response
}
/**
* Instrument this client with middleware
* @param mw - The middleware that instruments the client
* @returns - an instance of this client instrumented with the middleware
*/
public withMiddleware (mw: HttpLancedbClientMiddleware): HttpLancedbClient {
const wrapped = this.clone()
wrapped._middlewares.push(mw)
return wrapped
}
/**
* Make a clone of this client
*/
private clone (): HttpLancedbClient {
const clone = new HttpLancedbClient(this._url, this._apiKey(), this._dbName)
for (const mw of this._middlewares) {
clone._middlewares.push(mw)
}
return clone
}
}

View File

@@ -39,12 +39,13 @@ import {
fromTableToStreamBuffer
} from '../arrow'
import { toSQL } from '../util'
import { type HttpMiddleware } from '../middleware'
/**
* Remote connection.
*/
export class RemoteConnection implements Connection {
private readonly _client: HttpLancedbClient
private _client: HttpLancedbClient
private readonly _dbName: string
constructor (opts: ConnectionOptions) {
@@ -84,10 +85,11 @@ export class RemoteConnection implements Connection {
limit: number = 10
): Promise<string[]> {
const response = await this._client.get('/v1/table/', {
limit,
limit: `${limit}`,
page_token: pageToken
})
return response.data.tables
const body = await response.body()
return body.tables
}
async openTable (name: string): Promise<Table>
@@ -163,7 +165,7 @@ export class RemoteConnection implements Connection {
throw new Error(
`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`
`message: ${res.statusText}: ${await res.body()}`
)
}
@@ -177,6 +179,17 @@ export class RemoteConnection implements Connection {
async dropTable (name: string): Promise<void> {
await this._client.post(`/v1/table/${name}/drop/`)
}
withMiddleware (middleware: HttpMiddleware): Connection {
const wrapped = this.clone()
wrapped._client = wrapped._client.withMiddleware(middleware)
return wrapped
}
private clone (): RemoteConnection {
const clone: RemoteConnection = Object.create(RemoteConnection.prototype)
return Object.assign(clone, this)
}
}
export class RemoteQuery<T = number[]> extends Query<T> {
@@ -229,7 +242,7 @@ export class RemoteQuery<T = number[]> extends Query<T> {
// we are using extend until we have next next version release
// Table and Connection has both been refactored to interfaces
export class RemoteTable<T = number[]> implements Table<T> {
private readonly _client: HttpLancedbClient
private _client: HttpLancedbClient
private readonly _embeddings?: EmbeddingFunction<T>
private readonly _name: string
@@ -256,15 +269,15 @@ export class RemoteTable<T = number[]> implements Table<T> {
get schema (): Promise<any> {
return this._client
.post(`/v1/table/${this._name}/describe/`)
.then((res) => {
.then(async (res) => {
if (res.status !== 200) {
throw new Error(
`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`
`message: ${res.statusText}: ${await res.body()}`
)
}
return res.data?.schema
return (await res.body())?.schema
})
}
@@ -320,7 +333,7 @@ export class RemoteTable<T = number[]> implements Table<T> {
throw new Error(
`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`
`message: ${res.statusText}: ${await res.body()}`
)
}
}
@@ -346,7 +359,7 @@ export class RemoteTable<T = number[]> implements Table<T> {
throw new Error(
`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`
`message: ${res.statusText}: ${await res.body()}`
)
}
return tbl.numRows
@@ -372,7 +385,7 @@ export class RemoteTable<T = number[]> implements Table<T> {
throw new Error(
`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`
`message: ${res.statusText}: ${await res.body()}`
)
}
return tbl.numRows
@@ -415,7 +428,7 @@ export class RemoteTable<T = number[]> implements Table<T> {
throw new Error(
`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`
`message: ${res.statusText}: ${await res.body()}`
)
}
}
@@ -436,14 +449,14 @@ export class RemoteTable<T = number[]> implements Table<T> {
throw new Error(
`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`
`message: ${res.statusText}: ${await res.body()}`
)
}
}
async countRows (): Promise<number> {
const result = await this._client.post(`/v1/table/${this._name}/describe/`)
return result.data?.stats?.num_rows
return (await result.body())?.stats?.num_rows
}
async delete (filter: string): Promise<void> {
@@ -476,7 +489,7 @@ export class RemoteTable<T = number[]> implements Table<T> {
const results = await this._client.post(
`/v1/table/${this._name}/index/list/`
)
return results.data.indexes?.map((index: any) => ({
return (await results.body()).indexes?.map((index: any) => ({
columns: index.columns,
name: index.index_name,
uuid: index.index_uuid
@@ -487,9 +500,10 @@ export class RemoteTable<T = number[]> implements Table<T> {
const results = await this._client.post(
`/v1/table/${this._name}/index/${indexUuid}/stats/`
)
const body = await results.body()
return {
numIndexedRows: results.data.num_indexed_rows,
numUnindexedRows: results.data.num_unindexed_rows
numIndexedRows: body?.num_indexed_rows,
numUnindexedRows: body?.num_unindexed_rows
}
}
@@ -504,4 +518,15 @@ export class RemoteTable<T = number[]> implements Table<T> {
async dropColumns (columnNames: string[]): Promise<void> {
throw new Error('Drop columns is not yet supported in LanceDB Cloud.')
}
withMiddleware(middleware: HttpMiddleware): Table<T> {
const wrapped = this.clone()
wrapped._client = wrapped._client.withMiddleware(middleware)
return wrapped
}
private clone (): RemoteTable<T> {
const clone: RemoteTable<T> = Object.create(RemoteTable.prototype)
return Object.assign(clone, this)
}
}

View File

@@ -240,7 +240,7 @@ describe("When creating an index", () => {
)
.column("vec")
.toArrow(),
).rejects.toThrow(/.*does not match the dimension.*/);
).rejects.toThrow(/.* query dim=64, expected vector dim=32.*/);
const query64 = Array(64)
.fill(1)

View File

@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.6.4
current_version = 0.6.5
commit = True
message = [python] Bump version: {current_version} → {new_version}
tag = True

View File

@@ -1,9 +1,9 @@
[project]
name = "lancedb"
version = "0.6.4"
version = "0.6.5"
dependencies = [
"deprecation",
"pylance==0.10.4",
"pylance==0.10.5",
"ratelimiter~=1.0",
"retry>=0.9.2",
"tqdm>=4.27.0",
@@ -94,13 +94,11 @@ lancedb = "lancedb.cli.cli:cli"
requires = ["maturin>=1.4"]
build-backend = "maturin"
[tool.ruff.lint]
select = ["F", "E", "W", "I", "G", "TCH", "PERF"]
[tool.pytest.ini_options]
addopts = "--strict-markers --ignore-glob=lancedb/embeddings/*.py"
markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
"asyncio",

View File

@@ -31,7 +31,13 @@ from lancedb.utils.events import register_event
from ._lancedb import connect as lancedb_connect
from .pydantic import LanceModel
from .table import AsyncTable, LanceTable, Table, _sanitize_data
from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri
from .util import (
fs_from_uri,
get_uri_location,
get_uri_scheme,
join_uri,
validate_table_name,
)
if TYPE_CHECKING:
from datetime import timedelta
@@ -387,6 +393,7 @@ class LanceDBConnection(DBConnection):
"""
if mode.lower() not in ["create", "overwrite"]:
raise ValueError("mode must be either 'create' or 'overwrite'")
validate_table_name(name)
tbl = LanceTable.create(
self,

View File

@@ -26,6 +26,7 @@ from ..db import DBConnection
from ..embeddings import EmbeddingFunctionConfig
from ..pydantic import LanceModel
from ..table import Table, _sanitize_data
from ..util import validate_table_name
from .arrow import to_ipc_binary
from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient
from .errors import LanceDBClientError
@@ -223,6 +224,7 @@ class RemoteDBConnection(DBConnection):
LanceTable(table4)
"""
validate_table_name(name)
if data is None and schema is None:
raise ValueError("Either data or schema must be provided.")
if embedding_functions is not None:

View File

@@ -25,6 +25,8 @@ import numpy as np
import pyarrow as pa
import pyarrow.fs as pa_fs
from ._lancedb import validate_table_name as native_validate_table_name
def safe_import_adlfs():
try:
@@ -286,3 +288,8 @@ def deprecated(func):
return func(*args, **kwargs)
return new_func
def validate_table_name(name: str):
"""Verify the table name is valid."""
native_validate_table_name(name)

View File

@@ -521,3 +521,15 @@ def test_prefilter_with_index(tmp_path):
.to_arrow()
)
assert table.num_rows == 1
def test_create_table_with_invalid_names(tmp_path):
db = lancedb.connect(uri=tmp_path)
data = [{"vector": np.random.rand(128), "item": "foo"} for i in range(10)]
with pytest.raises(ValueError):
db.create_table("foo/bar", data)
with pytest.raises(ValueError):
db.create_table("foo bar", data)
with pytest.raises(ValueError):
db.create_table("foo$$bar", data)
db.create_table("foo.bar", data)

View File

@@ -42,6 +42,7 @@ pub fn _lancedb(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<VectorQuery>()?;
m.add_class::<RecordBatchStream>()?;
m.add_function(wrap_pyfunction!(connect, m)?)?;
m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?;
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
Ok(())
}

View File

@@ -3,7 +3,7 @@ use std::sync::Mutex;
use lancedb::DistanceType;
use pyo3::{
exceptions::{PyRuntimeError, PyValueError},
PyResult,
pyfunction, PyResult,
};
/// A wrapper around a rust builder
@@ -49,3 +49,9 @@ pub fn parse_distance_type(distance_type: impl AsRef<str>) -> PyResult<DistanceT
))),
}
}
#[pyfunction]
pub(crate) fn validate_table_name(table_name: &str) -> PyResult<()> {
lancedb::utils::validate_table_name(table_name)
.map_err(|e| PyValueError::new_err(e.to_string()))
}

View File

@@ -22,6 +22,7 @@ chrono = { workspace = true }
object_store = { workspace = true }
snafu = { workspace = true }
half = { workspace = true }
lazy_static.workspace = true
lance = { workspace = true }
lance-index = { workspace = true }
lance-linalg = { workspace = true }
@@ -34,11 +35,10 @@ bytes = "1"
futures.workspace = true
num-traits.workspace = true
url.workspace = true
regex.workspace = true
serde = { version = "^1" }
serde_json = { version = "1" }
# For remote feature
reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true }
[dev-dependencies]

View File

@@ -31,6 +31,7 @@ use crate::arrow::IntoArrow;
use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result};
use crate::io::object_store::MirroringObjectStoreWrapper;
use crate::table::{NativeTable, WriteOptions};
use crate::utils::validate_table_name;
use crate::Table;
pub const LANCE_FILE_EXTENSION: &str = "lance";
@@ -675,13 +676,18 @@ impl Database {
/// Get the URI of a table in the database.
fn table_uri(&self, name: &str) -> Result<String> {
validate_table_name(name)?;
let path = Path::new(&self.uri);
let table_uri = path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION));
let mut uri = table_uri
.as_path()
.to_str()
.context(InvalidTableNameSnafu { name })?
.context(InvalidTableNameSnafu {
name,
reason: "Name is not valid URL",
})?
.to_string();
// If there are query string set on the connection, propagate to lance

View File

@@ -20,8 +20,8 @@ use snafu::Snafu;
#[derive(Debug, Snafu)]
#[snafu(visibility(pub(crate)))]
pub enum Error {
#[snafu(display("Invalid table name: {name}"))]
InvalidTableName { name: String },
#[snafu(display("Invalid table name (\"{name}\"): {reason}"))]
InvalidTableName { name: String, reason: String },
#[snafu(display("Invalid input, {message}"))]
InvalidInput { message: String },
#[snafu(display("Table '{name}' was not found"))]

View File

@@ -230,9 +230,9 @@ pub enum DistanceType {
impl From<DistanceType> for LanceDistanceType {
fn from(value: DistanceType) -> Self {
match value {
DistanceType::L2 => LanceDistanceType::L2,
DistanceType::Cosine => LanceDistanceType::Cosine,
DistanceType::Dot => LanceDistanceType::Dot,
DistanceType::L2 => Self::L2,
DistanceType::Cosine => Self::Cosine,
DistanceType::Dot => Self::Dot,
}
}
}
@@ -240,9 +240,9 @@ impl From<DistanceType> for LanceDistanceType {
impl From<LanceDistanceType> for DistanceType {
fn from(value: LanceDistanceType) -> Self {
match value {
LanceDistanceType::L2 => DistanceType::L2,
LanceDistanceType::Cosine => DistanceType::Cosine,
LanceDistanceType::Dot => DistanceType::Dot,
LanceDistanceType::L2 => Self::L2,
LanceDistanceType::Cosine => Self::Cosine,
LanceDistanceType::Dot => Self::Dot,
}
}
}
@@ -251,7 +251,7 @@ impl<'a> TryFrom<&'a str> for DistanceType {
type Error = <LanceDistanceType as TryFrom<&'a str>>::Error;
fn try_from(value: &str) -> std::prelude::v1::Result<Self, Self::Error> {
LanceDistanceType::try_from(value).map(DistanceType::from)
LanceDistanceType::try_from(value).map(Self::from)
}
}

View File

@@ -854,6 +854,7 @@ impl NativeTable {
.to_str()
.ok_or(Error::InvalidTableName {
name: uri.to_string(),
reason: "Table name is not valid URL".to_string(),
})?;
Ok(name.to_string())
}
@@ -1185,15 +1186,26 @@ impl NativeTable {
let field = ds_ref.schema().field(&column).ok_or(Error::Schema {
message: format!("Column {} not found in dataset schema", column),
})?;
if !matches!(field.data_type(), arrow_schema::DataType::FixedSizeList(f, dim) if f.data_type().is_floating() && dim == query_vector.len() as i32)
{
return Err(Error::Schema {
message: format!(
"Vector column '{}' does not match the dimension of the query vector: dim={}",
column,
query_vector.len(),
),
});
if let arrow_schema::DataType::FixedSizeList(f, dim) = field.data_type() {
if !f.data_type().is_floating() {
return Err(Error::InvalidInput {
message: format!(
"The data type of the vector column '{}' is not a floating point type",
column
),
});
}
if dim != query_vector.len() as i32 {
return Err(Error::InvalidInput {
message: format!(
"The dimension of the query vector does not match with the dimension of the vector column '{}':
query dim={}, expected vector dim={}",
column,
query_vector.len(),
dim,
),
});
}
}
let query_vector = query_vector.as_primitive::<Float32Type>();
scanner.nearest(

View File

@@ -1,12 +1,30 @@
// Copyright 2024 LanceDB Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::sync::Arc;
use arrow_schema::Schema;
use lance::dataset::{ReadParams, WriteParams};
use lance::io::{ObjectStoreParams, WrappingObjectStore};
use lazy_static::lazy_static;
use crate::error::{Error, Result};
lazy_static! {
static ref TABLE_NAME_REGEX: regex::Regex = regex::Regex::new(r"^[a-zA-Z0-9_\-\.]+$").unwrap();
}
pub trait PatchStoreParam {
fn patch_with_store_wrapper(
self,
@@ -64,6 +82,25 @@ impl PatchReadParam for ReadParams {
}
}
/// Validate table name.
pub fn validate_table_name(name: &str) -> Result<()> {
if name.is_empty() {
return Err(Error::InvalidTableName {
name: name.to_string(),
reason: "Table names cannot be empty strings".to_string(),
});
}
if !TABLE_NAME_REGEX.is_match(name) {
return Err(Error::InvalidTableName {
name: name.to_string(),
reason:
"Table names can only contain alphanumeric characters, underscores, hyphens, and periods"
.to_string(),
});
}
Ok(())
}
/// Find one default column to create index.
pub(crate) fn default_vector_column(schema: &Schema, dim: Option<i32>) -> Result<String> {
// Try to find one fixed size list array column.
@@ -145,4 +182,20 @@ mod tests {
.to_string()
.contains("More than one"));
}
#[test]
fn test_validate_table_name() {
assert!(validate_table_name("my_table").is_ok());
assert!(validate_table_name("my_table_1").is_ok());
assert!(validate_table_name("123mytable").is_ok());
assert!(validate_table_name("_12345table").is_ok());
assert!(validate_table_name("table.12345").is_ok());
assert!(validate_table_name("table.._dot_..12345").is_ok());
assert!(validate_table_name("").is_err());
assert!(validate_table_name("my_table!").is_err());
assert!(validate_table_name("my/table").is_err());
assert!(validate_table_name("my@table").is_err());
assert!(validate_table_name("name with space").is_err());
}
}