diff --git a/Cargo.toml b/Cargo.toml index 854650cd..d08e9ae8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,10 +14,10 @@ keywords = ["lancedb", "lance", "database", "vector", "search"] categories = ["database-implementations"] [workspace.dependencies] -lance = { "version" = "=0.10.4", "features" = ["dynamodb"] } -lance-index = { "version" = "=0.10.4" } -lance-linalg = { "version" = "=0.10.4" } -lance-testing = { "version" = "=0.10.4" } +lance = { "version" = "=0.10.5", "features" = ["dynamodb"] } +lance-index = { "version" = "=0.10.5" } +lance-linalg = { "version" = "=0.10.5" } +lance-testing = { "version" = "=0.10.5" } # Note that this one does not include pyarrow arrow = { version = "50.0", optional = false } arrow-array = "50.0" @@ -39,3 +39,5 @@ pin-project = "1.0.7" snafu = "0.7.4" url = "2" num-traits = "0.2" +regex = "1.10" +lazy_static = "1" diff --git a/docs/src/extra_js/init_ask_ai_widget.js b/docs/src/extra_js/init_ask_ai_widget.js index 366e5071..02e7ce37 100644 --- a/docs/src/extra_js/init_ask_ai_widget.js +++ b/docs/src/extra_js/init_ask_ai_widget.js @@ -1,11 +1,79 @@ -document.addEventListener("DOMContentLoaded", function () { - var script = document.createElement("script"); - script.src = "https://widget.kapa.ai/kapa-widget.bundle.js"; - script.setAttribute("data-website-id", "c5881fae-cec0-490b-b45e-d83d131d4f25"); - script.setAttribute("data-project-name", "LanceDB"); - script.setAttribute("data-project-color", "#000000"); - script.setAttribute("data-project-logo", "https://avatars.githubusercontent.com/u/108903835?s=200&v=4"); - script.setAttribute("data-modal-example-questions","Help me create an IVF_PQ index,How do I do an exhaustive search?,How do I create a LanceDB table?,Can I use my own embedding function?"); - script.async = true; - document.head.appendChild(script); - }); \ No newline at end of file +// Creates an SVG robot icon (from Lucide) +function robotSVG() { + var svg = document.createElementNS("http://www.w3.org/2000/svg", "svg"); + svg.setAttribute("width", "24"); + svg.setAttribute("height", "24"); + svg.setAttribute("viewBox", "0 0 24 24"); + svg.setAttribute("fill", "none"); + svg.setAttribute("stroke", "currentColor"); + svg.setAttribute("stroke-width", "2"); + svg.setAttribute("stroke-linecap", "round"); + svg.setAttribute("stroke-linejoin", "round"); + svg.setAttribute("class", "lucide lucide-bot-message-square"); + + var path1 = document.createElementNS("http://www.w3.org/2000/svg", "path"); + path1.setAttribute("d", "M12 6V2H8"); + svg.appendChild(path1); + + var path2 = document.createElementNS("http://www.w3.org/2000/svg", "path"); + path2.setAttribute("d", "m8 18-4 4V8a2 2 0 0 1 2-2h12a2 2 0 0 1 2 2v8a2 2 0 0 1-2 2Z"); + svg.appendChild(path2); + + var path3 = document.createElementNS("http://www.w3.org/2000/svg", "path"); + path3.setAttribute("d", "M2 12h2"); + svg.appendChild(path3); + + var path4 = document.createElementNS("http://www.w3.org/2000/svg", "path"); + path4.setAttribute("d", "M9 11v2"); + svg.appendChild(path4); + + var path5 = document.createElementNS("http://www.w3.org/2000/svg", "path"); + path5.setAttribute("d", "M15 11v2"); + svg.appendChild(path5); + + var path6 = document.createElementNS("http://www.w3.org/2000/svg", "path"); + path6.setAttribute("d", "M20 12h2"); + svg.appendChild(path6); + + return svg +} + +// Creates the Fluidic Chatbot buttom +function fluidicButton() { + var btn = document.createElement("a"); + btn.href = "https://asklancedb.com"; + btn.target = "_blank"; + btn.style.position = "fixed"; + btn.style.fontWeight = "bold"; + btn.style.fontSize = ".8rem"; + btn.style.right = "10px"; + btn.style.bottom = "10px"; + btn.style.width = "80px"; + btn.style.height = "80px"; + btn.style.background = "linear-gradient(135deg, #7C5EFF 0%, #625eff 100%)"; + btn.style.color = "white"; + btn.style.borderRadius = "5px"; + btn.style.display = "flex"; + btn.style.flexDirection = "column"; + btn.style.justifyContent = "center"; + btn.style.alignItems = "center"; + btn.style.zIndex = "1000"; + btn.style.opacity = "0"; + btn.style.boxShadow = "0 0 0 rgba(0, 0, 0, 0)"; + btn.style.transition = "opacity 0.2s ease-in, box-shadow 0.2s ease-in"; + + setTimeout(function() { + btn.style.opacity = "1"; + btn.style.boxShadow = "0 0 .2rem #0000001a,0 .2rem .4rem #0003" + }, 0); + + return btn +} + +document.addEventListener("DOMContentLoaded", function() { + var btn = fluidicButton() + btn.appendChild(robotSVG()); + var text = document.createTextNode("Ask AI"); + btn.appendChild(text); + document.body.appendChild(btn); +}); diff --git a/node/src/index.ts b/node/src/index.ts index a11b5d85..72153d48 100644 --- a/node/src/index.ts +++ b/node/src/index.ts @@ -24,6 +24,7 @@ import { RemoteConnection } from './remote' import { Query } from './query' import { isEmbeddingFunction } from './embedding/embedding_function' import { type Literal, toSQL } from './util' +import { type HttpMiddleware } from './middleware' const { databaseNew, @@ -302,6 +303,18 @@ export interface Connection { * @param name The name of the table to drop. */ dropTable(name: string): Promise + + /** + * Instrument the behavior of this Connection with middleware. + * + * The middleware will be called in the order they are added. + * + * Currently this functionality is only supported for remote Connections. + * + * @param {HttpMiddleware} - Middleware which will instrument the Connection. + * @returns - this Connection instrumented by the passed middleware + */ + withMiddleware(middleware: HttpMiddleware): Connection } /** @@ -541,6 +554,18 @@ export interface Table { * names (e.g. "a"). */ dropColumns(columnNames: string[]): Promise + + /** + * Instrument the behavior of this Table with middleware. + * + * The middleware will be called in the order they are added. + * + * Currently this functionality is only supported for remote tables. + * + * @param {HttpMiddleware} - Middleware which will instrument the Table. + * @returns - this Table instrumented by the passed middleware + */ + withMiddleware(middleware: HttpMiddleware): Table } /** @@ -795,6 +820,10 @@ export class LocalConnection implements Connection { async dropTable (name: string): Promise { await databaseDropTable.call(this._db, name) } + + withMiddleware (middleware: HttpMiddleware): Connection { + return this + } } export class LocalTable implements Table { @@ -1105,6 +1134,10 @@ export class LocalTable implements Table { async dropColumns (columnNames: string[]): Promise { return tableDropColumns.call(this._tbl, columnNames) } + + withMiddleware (middleware: HttpMiddleware): Table { + return this + } } export interface CleanupStats { diff --git a/node/src/middleware.ts b/node/src/middleware.ts new file mode 100644 index 00000000..2eb5446d --- /dev/null +++ b/node/src/middleware.ts @@ -0,0 +1,58 @@ +// Copyright 2024 LanceDB Developers. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/** + * Middleware for Remote LanceDB Connection or Table + */ +export interface HttpMiddleware { + /** + * A callback that can be used to instrument the behavior of http requests to remote + * tables. It can be used to add headers, modify the request, or even short-circuit + * the request and return a response without making the request to the remote endpoint. + * It can also be used to modify the response from the remote endpoint. + * + * @param {RemoteResponse} res - Request to the remote endpoint + * @param {onRemoteRequestNext} next - Callback to advance the middleware chain + */ + onRemoteRequest( + req: RemoteRequest, + next: (req: RemoteRequest) => Promise, + ): Promise +}; + +export enum Method { + GET, + POST +} + +/** + * A LanceDB Remote HTTP Request + */ +export interface RemoteRequest { + uri: string + method: Method + headers: Map + params?: Map + body?: any +} + +/** + * A LanceDB Remote HTTP Response + */ +export interface RemoteResponse { + status: number + statusText: string + headers: Map + body: () => Promise +} diff --git a/node/src/remote/client.ts b/node/src/remote/client.ts index 3d4d59a2..91fa6857 100644 --- a/node/src/remote/client.ts +++ b/node/src/remote/client.ts @@ -12,13 +12,101 @@ // See the License for the specific language governing permissions and // limitations under the License. -import axios, { type AxiosResponse } from 'axios' +import axios, { type AxiosResponse, type ResponseType } from 'axios' import { tableFromIPC, type Table as ArrowTable } from 'apache-arrow' +import { type RemoteResponse, type RemoteRequest, Method } from '../middleware' + +interface HttpLancedbClientMiddleware { + onRemoteRequest( + req: RemoteRequest, + next: (req: RemoteRequest) => Promise, + ): Promise +} + +/** + * Invoke the middleware chain and at the end call the remote endpoint + */ +async function callWithMiddlewares ( + req: RemoteRequest, + middlewares: HttpLancedbClientMiddleware[], + opts?: MiddlewareInvocationOptions +): Promise { + async function call ( + i: number, + req: RemoteRequest + ): Promise { + // if we have reached the end of the middleware chain, make the request + if (i > middlewares.length) { + const headers = Object.fromEntries(req.headers.entries()) + const params = Object.fromEntries(req.params?.entries() ?? []) + const timeout = 10000 + let res + if (req.method === Method.POST) { + res = await axios.post( + req.uri, + req.body, + { + headers, + params, + timeout, + responseType: opts?.responseType + } + ) + } else { + res = await axios.get( + req.uri, + { + headers, + params, + timeout + } + ) + } + + return toLanceRes(res) + } + + // call next middleware in chain + return await middlewares[i - 1].onRemoteRequest( + req, + async (req) => { + return await call(i + 1, req) + } + ) + } + + return await call(1, req) +} + +interface MiddlewareInvocationOptions { + responseType?: ResponseType +} + +/** + * Marshall the library response into a LanceDB response + */ +function toLanceRes (res: AxiosResponse): RemoteResponse { + const headers = new Map() + for (const h in res.headers) { + headers.set(h, res.headers[h]) + } + + return { + status: res.status, + statusText: res.statusText, + headers, + body: async () => { + return res.data + } + } +} + export class HttpLancedbClient { private readonly _url: string private readonly _apiKey: () => string + private readonly _middlewares: HttpLancedbClientMiddleware[] public constructor ( url: string, @@ -27,6 +115,7 @@ export class HttpLancedbClient { ) { this._url = url this._apiKey = () => apiKey + this._middlewares = [] } get uri (): string { @@ -43,74 +132,61 @@ export class HttpLancedbClient { columns?: string[], filter?: string ): Promise> { - const response = await axios.post( - `${this._url}/v1/table/${tableName}/query/`, - { - vector, - k, - nprobes, - refineFactor, - columns, - filter, - prefilter - }, - { - headers: { - 'Content-Type': 'application/json', - 'x-api-key': this._apiKey(), - ...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {}) - }, - responseType: 'arraybuffer', - timeout: 10000 - } - ).catch((err) => { - console.error('error: ', err) - if (err.response === undefined) { - throw new Error(`Network Error: ${err.message as string}`) - } - return err.response - }) - if (response.status !== 200) { - const errorData = new TextDecoder().decode(response.data) - throw new Error( - `Server Error, status: ${response.status as number}, ` + - `message: ${response.statusText as string}: ${errorData}` - ) - } - - const table = tableFromIPC(response.data) + const result = await this.post( + `/v1/table/${tableName}/query/`, + { + vector, + k, + nprobes, + refineFactor, + columns, + filter, + prefilter + }, + undefined, + undefined, + 'arraybuffer' + ) + const table = tableFromIPC(await result.body()) return table } /** * Sent GET request. */ - public async get (path: string, params?: Record): Promise { - const response = await axios.get( - `${this._url}${path}`, - { - headers: { - 'Content-Type': 'application/json', - 'x-api-key': this._apiKey(), - ...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {}) - }, - params, - timeout: 10000 - } - ).catch((err) => { + public async get (path: string, params?: Record): Promise { + const req = { + uri: `${this._url}${path}`, + method: Method.GET, + headers: new Map(Object.entries({ + 'Content-Type': 'application/json', + 'x-api-key': this._apiKey(), + ...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {}) + })), + params: new Map(Object.entries(params ?? {})) + } + + let response + try { + response = await callWithMiddlewares(req, this._middlewares) + return response + } catch (err: any) { console.error('error: ', err) if (err.response === undefined) { throw new Error(`Network Error: ${err.message as string}`) } - return err.response - }) + + response = toLanceRes(err.response) + } + if (response.status !== 200) { - const errorData = new TextDecoder().decode(response.data) + const errorData = new TextDecoder().decode(await response.body()) throw new Error( - `Server Error, status: ${response.status as number}, ` + - `message: ${response.statusText as string}: ${errorData}` + `Server Error, status: ${response.status}, ` + + `message: ${response.statusText}: ${errorData}` ) } + return response } @@ -120,35 +196,65 @@ export class HttpLancedbClient { public async post ( path: string, data?: any, - params?: Record, - content?: string | undefined - ): Promise { - const response = await axios.post( - `${this._url}${path}`, - data, - { - headers: { - 'Content-Type': content ?? 'application/json', - 'x-api-key': this._apiKey(), - ...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {}) - }, - params, - timeout: 30000 - } - ).catch((err) => { + params?: Record, + content?: string | undefined, + responseType?: ResponseType | undefined + ): Promise { + const req = { + uri: `${this._url}${path}`, + method: Method.POST, + headers: new Map(Object.entries({ + 'Content-Type': content ?? 'application/json', + 'x-api-key': this._apiKey(), + ...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {}) + })), + params: new Map(Object.entries(params ?? {})), + body: data + } + + let response + try { + response = await callWithMiddlewares(req, this._middlewares, { responseType }) + + // return response + } catch (err: any) { console.error('error: ', err) if (err.response === undefined) { throw new Error(`Network Error: ${err.message as string}`) } - return err.response - }) + response = toLanceRes(err.response) + } + if (response.status !== 200) { - const errorData = new TextDecoder().decode(response.data) + const errorData = new TextDecoder().decode(await response.body()) throw new Error( - `Server Error, status: ${response.status as number}, ` + - `message: ${response.statusText as string}: ${errorData}` + `Server Error, status: ${response.status}, ` + + `message: ${response.statusText}: ${errorData}` ) } + return response } + + /** + * Instrument this client with middleware + * @param mw - The middleware that instruments the client + * @returns - an instance of this client instrumented with the middleware + */ + public withMiddleware (mw: HttpLancedbClientMiddleware): HttpLancedbClient { + const wrapped = this.clone() + wrapped._middlewares.push(mw) + return wrapped + } + + /** + * Make a clone of this client + */ + private clone (): HttpLancedbClient { + const clone = new HttpLancedbClient(this._url, this._apiKey(), this._dbName) + for (const mw of this._middlewares) { + clone._middlewares.push(mw) + } + return clone + } } diff --git a/node/src/remote/index.ts b/node/src/remote/index.ts index 6e6e590a..78807b14 100644 --- a/node/src/remote/index.ts +++ b/node/src/remote/index.ts @@ -39,12 +39,13 @@ import { fromTableToStreamBuffer } from '../arrow' import { toSQL } from '../util' +import { type HttpMiddleware } from '../middleware' /** * Remote connection. */ export class RemoteConnection implements Connection { - private readonly _client: HttpLancedbClient + private _client: HttpLancedbClient private readonly _dbName: string constructor (opts: ConnectionOptions) { @@ -84,10 +85,11 @@ export class RemoteConnection implements Connection { limit: number = 10 ): Promise { const response = await this._client.get('/v1/table/', { - limit, + limit: `${limit}`, page_token: pageToken }) - return response.data.tables + const body = await response.body() + return body.tables } async openTable (name: string): Promise @@ -163,7 +165,7 @@ export class RemoteConnection implements Connection { throw new Error( `Server Error, status: ${res.status}, ` + // eslint-disable-next-line @typescript-eslint/restrict-template-expressions - `message: ${res.statusText}: ${res.data}` + `message: ${res.statusText}: ${await res.body()}` ) } @@ -177,6 +179,17 @@ export class RemoteConnection implements Connection { async dropTable (name: string): Promise { await this._client.post(`/v1/table/${name}/drop/`) } + + withMiddleware (middleware: HttpMiddleware): Connection { + const wrapped = this.clone() + wrapped._client = wrapped._client.withMiddleware(middleware) + return wrapped + } + + private clone (): RemoteConnection { + const clone: RemoteConnection = Object.create(RemoteConnection.prototype) + return Object.assign(clone, this) + } } export class RemoteQuery extends Query { @@ -229,7 +242,7 @@ export class RemoteQuery extends Query { // we are using extend until we have next next version release // Table and Connection has both been refactored to interfaces export class RemoteTable implements Table { - private readonly _client: HttpLancedbClient + private _client: HttpLancedbClient private readonly _embeddings?: EmbeddingFunction private readonly _name: string @@ -256,15 +269,15 @@ export class RemoteTable implements Table { get schema (): Promise { return this._client .post(`/v1/table/${this._name}/describe/`) - .then((res) => { + .then(async (res) => { if (res.status !== 200) { throw new Error( `Server Error, status: ${res.status}, ` + // eslint-disable-next-line @typescript-eslint/restrict-template-expressions - `message: ${res.statusText}: ${res.data}` + `message: ${res.statusText}: ${await res.body()}` ) } - return res.data?.schema + return (await res.body())?.schema }) } @@ -320,7 +333,7 @@ export class RemoteTable implements Table { throw new Error( `Server Error, status: ${res.status}, ` + // eslint-disable-next-line @typescript-eslint/restrict-template-expressions - `message: ${res.statusText}: ${res.data}` + `message: ${res.statusText}: ${await res.body()}` ) } } @@ -346,7 +359,7 @@ export class RemoteTable implements Table { throw new Error( `Server Error, status: ${res.status}, ` + // eslint-disable-next-line @typescript-eslint/restrict-template-expressions - `message: ${res.statusText}: ${res.data}` + `message: ${res.statusText}: ${await res.body()}` ) } return tbl.numRows @@ -372,7 +385,7 @@ export class RemoteTable implements Table { throw new Error( `Server Error, status: ${res.status}, ` + // eslint-disable-next-line @typescript-eslint/restrict-template-expressions - `message: ${res.statusText}: ${res.data}` + `message: ${res.statusText}: ${await res.body()}` ) } return tbl.numRows @@ -415,7 +428,7 @@ export class RemoteTable implements Table { throw new Error( `Server Error, status: ${res.status}, ` + // eslint-disable-next-line @typescript-eslint/restrict-template-expressions - `message: ${res.statusText}: ${res.data}` + `message: ${res.statusText}: ${await res.body()}` ) } } @@ -436,14 +449,14 @@ export class RemoteTable implements Table { throw new Error( `Server Error, status: ${res.status}, ` + // eslint-disable-next-line @typescript-eslint/restrict-template-expressions - `message: ${res.statusText}: ${res.data}` + `message: ${res.statusText}: ${await res.body()}` ) } } async countRows (): Promise { const result = await this._client.post(`/v1/table/${this._name}/describe/`) - return result.data?.stats?.num_rows + return (await result.body())?.stats?.num_rows } async delete (filter: string): Promise { @@ -476,7 +489,7 @@ export class RemoteTable implements Table { const results = await this._client.post( `/v1/table/${this._name}/index/list/` ) - return results.data.indexes?.map((index: any) => ({ + return (await results.body()).indexes?.map((index: any) => ({ columns: index.columns, name: index.index_name, uuid: index.index_uuid @@ -487,9 +500,10 @@ export class RemoteTable implements Table { const results = await this._client.post( `/v1/table/${this._name}/index/${indexUuid}/stats/` ) + const body = await results.body() return { - numIndexedRows: results.data.num_indexed_rows, - numUnindexedRows: results.data.num_unindexed_rows + numIndexedRows: body?.num_indexed_rows, + numUnindexedRows: body?.num_unindexed_rows } } @@ -504,4 +518,15 @@ export class RemoteTable implements Table { async dropColumns (columnNames: string[]): Promise { throw new Error('Drop columns is not yet supported in LanceDB Cloud.') } + + withMiddleware(middleware: HttpMiddleware): Table { + const wrapped = this.clone() + wrapped._client = wrapped._client.withMiddleware(middleware) + return wrapped + } + + private clone (): RemoteTable { + const clone: RemoteTable = Object.create(RemoteTable.prototype) + return Object.assign(clone, this) + } } diff --git a/nodejs/__test__/table.test.ts b/nodejs/__test__/table.test.ts index 3a673675..e2c10cbe 100644 --- a/nodejs/__test__/table.test.ts +++ b/nodejs/__test__/table.test.ts @@ -240,7 +240,7 @@ describe("When creating an index", () => { ) .column("vec") .toArrow(), - ).rejects.toThrow(/.*does not match the dimension.*/); + ).rejects.toThrow(/.* query dim=64, expected vector dim=32.*/); const query64 = Array(64) .fill(1) diff --git a/python/.bumpversion.cfg b/python/.bumpversion.cfg index 906fddc8..2d68af90 100644 --- a/python/.bumpversion.cfg +++ b/python/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.6.4 +current_version = 0.6.5 commit = True message = [python] Bump version: {current_version} → {new_version} tag = True diff --git a/python/pyproject.toml b/python/pyproject.toml index 292bd5f0..4c0dcfa9 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -1,9 +1,9 @@ [project] name = "lancedb" -version = "0.6.4" +version = "0.6.5" dependencies = [ "deprecation", - "pylance==0.10.4", + "pylance==0.10.5", "ratelimiter~=1.0", "retry>=0.9.2", "tqdm>=4.27.0", @@ -94,13 +94,11 @@ lancedb = "lancedb.cli.cli:cli" requires = ["maturin>=1.4"] build-backend = "maturin" - [tool.ruff.lint] select = ["F", "E", "W", "I", "G", "TCH", "PERF"] [tool.pytest.ini_options] addopts = "--strict-markers --ignore-glob=lancedb/embeddings/*.py" - markers = [ "slow: marks tests as slow (deselect with '-m \"not slow\"')", "asyncio", diff --git a/python/python/lancedb/db.py b/python/python/lancedb/db.py index 486cd30b..0bd98016 100644 --- a/python/python/lancedb/db.py +++ b/python/python/lancedb/db.py @@ -31,7 +31,13 @@ from lancedb.utils.events import register_event from ._lancedb import connect as lancedb_connect from .pydantic import LanceModel from .table import AsyncTable, LanceTable, Table, _sanitize_data -from .util import fs_from_uri, get_uri_location, get_uri_scheme, join_uri +from .util import ( + fs_from_uri, + get_uri_location, + get_uri_scheme, + join_uri, + validate_table_name, +) if TYPE_CHECKING: from datetime import timedelta @@ -387,6 +393,7 @@ class LanceDBConnection(DBConnection): """ if mode.lower() not in ["create", "overwrite"]: raise ValueError("mode must be either 'create' or 'overwrite'") + validate_table_name(name) tbl = LanceTable.create( self, diff --git a/python/python/lancedb/remote/db.py b/python/python/lancedb/remote/db.py index f2ded712..9dff65c5 100644 --- a/python/python/lancedb/remote/db.py +++ b/python/python/lancedb/remote/db.py @@ -26,6 +26,7 @@ from ..db import DBConnection from ..embeddings import EmbeddingFunctionConfig from ..pydantic import LanceModel from ..table import Table, _sanitize_data +from ..util import validate_table_name from .arrow import to_ipc_binary from .client import ARROW_STREAM_CONTENT_TYPE, RestfulLanceDBClient from .errors import LanceDBClientError @@ -223,6 +224,7 @@ class RemoteDBConnection(DBConnection): LanceTable(table4) """ + validate_table_name(name) if data is None and schema is None: raise ValueError("Either data or schema must be provided.") if embedding_functions is not None: diff --git a/python/python/lancedb/util.py b/python/python/lancedb/util.py index f5987c06..4470754d 100644 --- a/python/python/lancedb/util.py +++ b/python/python/lancedb/util.py @@ -25,6 +25,8 @@ import numpy as np import pyarrow as pa import pyarrow.fs as pa_fs +from ._lancedb import validate_table_name as native_validate_table_name + def safe_import_adlfs(): try: @@ -286,3 +288,8 @@ def deprecated(func): return func(*args, **kwargs) return new_func + + +def validate_table_name(name: str): + """Verify the table name is valid.""" + native_validate_table_name(name) diff --git a/python/python/tests/test_db.py b/python/python/tests/test_db.py index c84c0800..fc4420ba 100644 --- a/python/python/tests/test_db.py +++ b/python/python/tests/test_db.py @@ -521,3 +521,15 @@ def test_prefilter_with_index(tmp_path): .to_arrow() ) assert table.num_rows == 1 + + +def test_create_table_with_invalid_names(tmp_path): + db = lancedb.connect(uri=tmp_path) + data = [{"vector": np.random.rand(128), "item": "foo"} for i in range(10)] + with pytest.raises(ValueError): + db.create_table("foo/bar", data) + with pytest.raises(ValueError): + db.create_table("foo bar", data) + with pytest.raises(ValueError): + db.create_table("foo$$bar", data) + db.create_table("foo.bar", data) diff --git a/python/src/lib.rs b/python/src/lib.rs index 558668cb..9d1f0a80 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -42,6 +42,7 @@ pub fn _lancedb(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_function(wrap_pyfunction!(connect, m)?)?; + m.add_function(wrap_pyfunction!(util::validate_table_name, m)?)?; m.add("__version__", env!("CARGO_PKG_VERSION"))?; Ok(()) } diff --git a/python/src/util.rs b/python/src/util.rs index 893e8089..19662fac 100644 --- a/python/src/util.rs +++ b/python/src/util.rs @@ -3,7 +3,7 @@ use std::sync::Mutex; use lancedb::DistanceType; use pyo3::{ exceptions::{PyRuntimeError, PyValueError}, - PyResult, + pyfunction, PyResult, }; /// A wrapper around a rust builder @@ -49,3 +49,9 @@ pub fn parse_distance_type(distance_type: impl AsRef) -> PyResult PyResult<()> { + lancedb::utils::validate_table_name(table_name) + .map_err(|e| PyValueError::new_err(e.to_string())) +} diff --git a/rust/lancedb/Cargo.toml b/rust/lancedb/Cargo.toml index 93f38691..f7317a79 100644 --- a/rust/lancedb/Cargo.toml +++ b/rust/lancedb/Cargo.toml @@ -22,6 +22,7 @@ chrono = { workspace = true } object_store = { workspace = true } snafu = { workspace = true } half = { workspace = true } +lazy_static.workspace = true lance = { workspace = true } lance-index = { workspace = true } lance-linalg = { workspace = true } @@ -34,11 +35,10 @@ bytes = "1" futures.workspace = true num-traits.workspace = true url.workspace = true +regex.workspace = true serde = { version = "^1" } serde_json = { version = "1" } - # For remote feature - reqwest = { version = "0.11.24", features = ["gzip", "json"], optional = true } [dev-dependencies] diff --git a/rust/lancedb/src/connection.rs b/rust/lancedb/src/connection.rs index 54ae8d27..06bf9c41 100644 --- a/rust/lancedb/src/connection.rs +++ b/rust/lancedb/src/connection.rs @@ -31,6 +31,7 @@ use crate::arrow::IntoArrow; use crate::error::{CreateDirSnafu, Error, InvalidTableNameSnafu, Result}; use crate::io::object_store::MirroringObjectStoreWrapper; use crate::table::{NativeTable, WriteOptions}; +use crate::utils::validate_table_name; use crate::Table; pub const LANCE_FILE_EXTENSION: &str = "lance"; @@ -675,13 +676,18 @@ impl Database { /// Get the URI of a table in the database. fn table_uri(&self, name: &str) -> Result { + validate_table_name(name)?; + let path = Path::new(&self.uri); let table_uri = path.join(format!("{}.{}", name, LANCE_FILE_EXTENSION)); let mut uri = table_uri .as_path() .to_str() - .context(InvalidTableNameSnafu { name })? + .context(InvalidTableNameSnafu { + name, + reason: "Name is not valid URL", + })? .to_string(); // If there are query string set on the connection, propagate to lance diff --git a/rust/lancedb/src/error.rs b/rust/lancedb/src/error.rs index 8baed35d..a528a177 100644 --- a/rust/lancedb/src/error.rs +++ b/rust/lancedb/src/error.rs @@ -20,8 +20,8 @@ use snafu::Snafu; #[derive(Debug, Snafu)] #[snafu(visibility(pub(crate)))] pub enum Error { - #[snafu(display("Invalid table name: {name}"))] - InvalidTableName { name: String }, + #[snafu(display("Invalid table name (\"{name}\"): {reason}"))] + InvalidTableName { name: String, reason: String }, #[snafu(display("Invalid input, {message}"))] InvalidInput { message: String }, #[snafu(display("Table '{name}' was not found"))] diff --git a/rust/lancedb/src/lib.rs b/rust/lancedb/src/lib.rs index 2485e965..4b1269e9 100644 --- a/rust/lancedb/src/lib.rs +++ b/rust/lancedb/src/lib.rs @@ -230,9 +230,9 @@ pub enum DistanceType { impl From for LanceDistanceType { fn from(value: DistanceType) -> Self { match value { - DistanceType::L2 => LanceDistanceType::L2, - DistanceType::Cosine => LanceDistanceType::Cosine, - DistanceType::Dot => LanceDistanceType::Dot, + DistanceType::L2 => Self::L2, + DistanceType::Cosine => Self::Cosine, + DistanceType::Dot => Self::Dot, } } } @@ -240,9 +240,9 @@ impl From for LanceDistanceType { impl From for DistanceType { fn from(value: LanceDistanceType) -> Self { match value { - LanceDistanceType::L2 => DistanceType::L2, - LanceDistanceType::Cosine => DistanceType::Cosine, - LanceDistanceType::Dot => DistanceType::Dot, + LanceDistanceType::L2 => Self::L2, + LanceDistanceType::Cosine => Self::Cosine, + LanceDistanceType::Dot => Self::Dot, } } } @@ -251,7 +251,7 @@ impl<'a> TryFrom<&'a str> for DistanceType { type Error = >::Error; fn try_from(value: &str) -> std::prelude::v1::Result { - LanceDistanceType::try_from(value).map(DistanceType::from) + LanceDistanceType::try_from(value).map(Self::from) } } diff --git a/rust/lancedb/src/table.rs b/rust/lancedb/src/table.rs index 252267ef..e3e3ab2d 100644 --- a/rust/lancedb/src/table.rs +++ b/rust/lancedb/src/table.rs @@ -854,6 +854,7 @@ impl NativeTable { .to_str() .ok_or(Error::InvalidTableName { name: uri.to_string(), + reason: "Table name is not valid URL".to_string(), })?; Ok(name.to_string()) } @@ -1185,15 +1186,26 @@ impl NativeTable { let field = ds_ref.schema().field(&column).ok_or(Error::Schema { message: format!("Column {} not found in dataset schema", column), })?; - if !matches!(field.data_type(), arrow_schema::DataType::FixedSizeList(f, dim) if f.data_type().is_floating() && dim == query_vector.len() as i32) - { - return Err(Error::Schema { - message: format!( - "Vector column '{}' does not match the dimension of the query vector: dim={}", - column, - query_vector.len(), - ), - }); + if let arrow_schema::DataType::FixedSizeList(f, dim) = field.data_type() { + if !f.data_type().is_floating() { + return Err(Error::InvalidInput { + message: format!( + "The data type of the vector column '{}' is not a floating point type", + column + ), + }); + } + if dim != query_vector.len() as i32 { + return Err(Error::InvalidInput { + message: format!( + "The dimension of the query vector does not match with the dimension of the vector column '{}': + query dim={}, expected vector dim={}", + column, + query_vector.len(), + dim, + ), + }); + } } let query_vector = query_vector.as_primitive::(); scanner.nearest( diff --git a/rust/lancedb/src/utils.rs b/rust/lancedb/src/utils.rs index 05499017..d6578f81 100644 --- a/rust/lancedb/src/utils.rs +++ b/rust/lancedb/src/utils.rs @@ -1,12 +1,30 @@ +// Copyright 2024 LanceDB Developers. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + use std::sync::Arc; use arrow_schema::Schema; - use lance::dataset::{ReadParams, WriteParams}; use lance::io::{ObjectStoreParams, WrappingObjectStore}; +use lazy_static::lazy_static; use crate::error::{Error, Result}; +lazy_static! { + static ref TABLE_NAME_REGEX: regex::Regex = regex::Regex::new(r"^[a-zA-Z0-9_\-\.]+$").unwrap(); +} + pub trait PatchStoreParam { fn patch_with_store_wrapper( self, @@ -64,6 +82,25 @@ impl PatchReadParam for ReadParams { } } +/// Validate table name. +pub fn validate_table_name(name: &str) -> Result<()> { + if name.is_empty() { + return Err(Error::InvalidTableName { + name: name.to_string(), + reason: "Table names cannot be empty strings".to_string(), + }); + } + if !TABLE_NAME_REGEX.is_match(name) { + return Err(Error::InvalidTableName { + name: name.to_string(), + reason: + "Table names can only contain alphanumeric characters, underscores, hyphens, and periods" + .to_string(), + }); + } + Ok(()) +} + /// Find one default column to create index. pub(crate) fn default_vector_column(schema: &Schema, dim: Option) -> Result { // Try to find one fixed size list array column. @@ -145,4 +182,20 @@ mod tests { .to_string() .contains("More than one")); } + + #[test] + fn test_validate_table_name() { + assert!(validate_table_name("my_table").is_ok()); + assert!(validate_table_name("my_table_1").is_ok()); + assert!(validate_table_name("123mytable").is_ok()); + assert!(validate_table_name("_12345table").is_ok()); + assert!(validate_table_name("table.12345").is_ok()); + assert!(validate_table_name("table.._dot_..12345").is_ok()); + + assert!(validate_table_name("").is_err()); + assert!(validate_table_name("my_table!").is_err()); + assert!(validate_table_name("my/table").is_err()); + assert!(validate_table_name("my@table").is_err()); + assert!(validate_table_name("name with space").is_err()); + } }