diff --git a/node/src/index.ts b/node/src/index.ts index a11b5d85..72153d48 100644 --- a/node/src/index.ts +++ b/node/src/index.ts @@ -24,6 +24,7 @@ import { RemoteConnection } from './remote' import { Query } from './query' import { isEmbeddingFunction } from './embedding/embedding_function' import { type Literal, toSQL } from './util' +import { type HttpMiddleware } from './middleware' const { databaseNew, @@ -302,6 +303,18 @@ export interface Connection { * @param name The name of the table to drop. */ dropTable(name: string): Promise + + /** + * Instrument the behavior of this Connection with middleware. + * + * The middleware will be called in the order they are added. + * + * Currently this functionality is only supported for remote Connections. + * + * @param {HttpMiddleware} - Middleware which will instrument the Connection. + * @returns - this Connection instrumented by the passed middleware + */ + withMiddleware(middleware: HttpMiddleware): Connection } /** @@ -541,6 +554,18 @@ export interface Table { * names (e.g. "a"). */ dropColumns(columnNames: string[]): Promise + + /** + * Instrument the behavior of this Table with middleware. + * + * The middleware will be called in the order they are added. + * + * Currently this functionality is only supported for remote tables. + * + * @param {HttpMiddleware} - Middleware which will instrument the Table. + * @returns - this Table instrumented by the passed middleware + */ + withMiddleware(middleware: HttpMiddleware): Table } /** @@ -795,6 +820,10 @@ export class LocalConnection implements Connection { async dropTable (name: string): Promise { await databaseDropTable.call(this._db, name) } + + withMiddleware (middleware: HttpMiddleware): Connection { + return this + } } export class LocalTable implements Table { @@ -1105,6 +1134,10 @@ export class LocalTable implements Table { async dropColumns (columnNames: string[]): Promise { return tableDropColumns.call(this._tbl, columnNames) } + + withMiddleware (middleware: HttpMiddleware): Table { + return this + } } export interface CleanupStats { diff --git a/node/src/middleware.ts b/node/src/middleware.ts new file mode 100644 index 00000000..2eb5446d --- /dev/null +++ b/node/src/middleware.ts @@ -0,0 +1,58 @@ +// Copyright 2024 LanceDB Developers. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/** + * Middleware for Remote LanceDB Connection or Table + */ +export interface HttpMiddleware { + /** + * A callback that can be used to instrument the behavior of http requests to remote + * tables. It can be used to add headers, modify the request, or even short-circuit + * the request and return a response without making the request to the remote endpoint. + * It can also be used to modify the response from the remote endpoint. + * + * @param {RemoteResponse} res - Request to the remote endpoint + * @param {onRemoteRequestNext} next - Callback to advance the middleware chain + */ + onRemoteRequest( + req: RemoteRequest, + next: (req: RemoteRequest) => Promise, + ): Promise +}; + +export enum Method { + GET, + POST +} + +/** + * A LanceDB Remote HTTP Request + */ +export interface RemoteRequest { + uri: string + method: Method + headers: Map + params?: Map + body?: any +} + +/** + * A LanceDB Remote HTTP Response + */ +export interface RemoteResponse { + status: number + statusText: string + headers: Map + body: () => Promise +} diff --git a/node/src/remote/client.ts b/node/src/remote/client.ts index 3d4d59a2..91fa6857 100644 --- a/node/src/remote/client.ts +++ b/node/src/remote/client.ts @@ -12,13 +12,101 @@ // See the License for the specific language governing permissions and // limitations under the License. -import axios, { type AxiosResponse } from 'axios' +import axios, { type AxiosResponse, type ResponseType } from 'axios' import { tableFromIPC, type Table as ArrowTable } from 'apache-arrow' +import { type RemoteResponse, type RemoteRequest, Method } from '../middleware' + +interface HttpLancedbClientMiddleware { + onRemoteRequest( + req: RemoteRequest, + next: (req: RemoteRequest) => Promise, + ): Promise +} + +/** + * Invoke the middleware chain and at the end call the remote endpoint + */ +async function callWithMiddlewares ( + req: RemoteRequest, + middlewares: HttpLancedbClientMiddleware[], + opts?: MiddlewareInvocationOptions +): Promise { + async function call ( + i: number, + req: RemoteRequest + ): Promise { + // if we have reached the end of the middleware chain, make the request + if (i > middlewares.length) { + const headers = Object.fromEntries(req.headers.entries()) + const params = Object.fromEntries(req.params?.entries() ?? []) + const timeout = 10000 + let res + if (req.method === Method.POST) { + res = await axios.post( + req.uri, + req.body, + { + headers, + params, + timeout, + responseType: opts?.responseType + } + ) + } else { + res = await axios.get( + req.uri, + { + headers, + params, + timeout + } + ) + } + + return toLanceRes(res) + } + + // call next middleware in chain + return await middlewares[i - 1].onRemoteRequest( + req, + async (req) => { + return await call(i + 1, req) + } + ) + } + + return await call(1, req) +} + +interface MiddlewareInvocationOptions { + responseType?: ResponseType +} + +/** + * Marshall the library response into a LanceDB response + */ +function toLanceRes (res: AxiosResponse): RemoteResponse { + const headers = new Map() + for (const h in res.headers) { + headers.set(h, res.headers[h]) + } + + return { + status: res.status, + statusText: res.statusText, + headers, + body: async () => { + return res.data + } + } +} + export class HttpLancedbClient { private readonly _url: string private readonly _apiKey: () => string + private readonly _middlewares: HttpLancedbClientMiddleware[] public constructor ( url: string, @@ -27,6 +115,7 @@ export class HttpLancedbClient { ) { this._url = url this._apiKey = () => apiKey + this._middlewares = [] } get uri (): string { @@ -43,74 +132,61 @@ export class HttpLancedbClient { columns?: string[], filter?: string ): Promise> { - const response = await axios.post( - `${this._url}/v1/table/${tableName}/query/`, - { - vector, - k, - nprobes, - refineFactor, - columns, - filter, - prefilter - }, - { - headers: { - 'Content-Type': 'application/json', - 'x-api-key': this._apiKey(), - ...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {}) - }, - responseType: 'arraybuffer', - timeout: 10000 - } - ).catch((err) => { - console.error('error: ', err) - if (err.response === undefined) { - throw new Error(`Network Error: ${err.message as string}`) - } - return err.response - }) - if (response.status !== 200) { - const errorData = new TextDecoder().decode(response.data) - throw new Error( - `Server Error, status: ${response.status as number}, ` + - `message: ${response.statusText as string}: ${errorData}` - ) - } - - const table = tableFromIPC(response.data) + const result = await this.post( + `/v1/table/${tableName}/query/`, + { + vector, + k, + nprobes, + refineFactor, + columns, + filter, + prefilter + }, + undefined, + undefined, + 'arraybuffer' + ) + const table = tableFromIPC(await result.body()) return table } /** * Sent GET request. */ - public async get (path: string, params?: Record): Promise { - const response = await axios.get( - `${this._url}${path}`, - { - headers: { - 'Content-Type': 'application/json', - 'x-api-key': this._apiKey(), - ...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {}) - }, - params, - timeout: 10000 - } - ).catch((err) => { + public async get (path: string, params?: Record): Promise { + const req = { + uri: `${this._url}${path}`, + method: Method.GET, + headers: new Map(Object.entries({ + 'Content-Type': 'application/json', + 'x-api-key': this._apiKey(), + ...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {}) + })), + params: new Map(Object.entries(params ?? {})) + } + + let response + try { + response = await callWithMiddlewares(req, this._middlewares) + return response + } catch (err: any) { console.error('error: ', err) if (err.response === undefined) { throw new Error(`Network Error: ${err.message as string}`) } - return err.response - }) + + response = toLanceRes(err.response) + } + if (response.status !== 200) { - const errorData = new TextDecoder().decode(response.data) + const errorData = new TextDecoder().decode(await response.body()) throw new Error( - `Server Error, status: ${response.status as number}, ` + - `message: ${response.statusText as string}: ${errorData}` + `Server Error, status: ${response.status}, ` + + `message: ${response.statusText}: ${errorData}` ) } + return response } @@ -120,35 +196,65 @@ export class HttpLancedbClient { public async post ( path: string, data?: any, - params?: Record, - content?: string | undefined - ): Promise { - const response = await axios.post( - `${this._url}${path}`, - data, - { - headers: { - 'Content-Type': content ?? 'application/json', - 'x-api-key': this._apiKey(), - ...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {}) - }, - params, - timeout: 30000 - } - ).catch((err) => { + params?: Record, + content?: string | undefined, + responseType?: ResponseType | undefined + ): Promise { + const req = { + uri: `${this._url}${path}`, + method: Method.POST, + headers: new Map(Object.entries({ + 'Content-Type': content ?? 'application/json', + 'x-api-key': this._apiKey(), + ...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {}) + })), + params: new Map(Object.entries(params ?? {})), + body: data + } + + let response + try { + response = await callWithMiddlewares(req, this._middlewares, { responseType }) + + // return response + } catch (err: any) { console.error('error: ', err) if (err.response === undefined) { throw new Error(`Network Error: ${err.message as string}`) } - return err.response - }) + response = toLanceRes(err.response) + } + if (response.status !== 200) { - const errorData = new TextDecoder().decode(response.data) + const errorData = new TextDecoder().decode(await response.body()) throw new Error( - `Server Error, status: ${response.status as number}, ` + - `message: ${response.statusText as string}: ${errorData}` + `Server Error, status: ${response.status}, ` + + `message: ${response.statusText}: ${errorData}` ) } + return response } + + /** + * Instrument this client with middleware + * @param mw - The middleware that instruments the client + * @returns - an instance of this client instrumented with the middleware + */ + public withMiddleware (mw: HttpLancedbClientMiddleware): HttpLancedbClient { + const wrapped = this.clone() + wrapped._middlewares.push(mw) + return wrapped + } + + /** + * Make a clone of this client + */ + private clone (): HttpLancedbClient { + const clone = new HttpLancedbClient(this._url, this._apiKey(), this._dbName) + for (const mw of this._middlewares) { + clone._middlewares.push(mw) + } + return clone + } } diff --git a/node/src/remote/index.ts b/node/src/remote/index.ts index 6e6e590a..78807b14 100644 --- a/node/src/remote/index.ts +++ b/node/src/remote/index.ts @@ -39,12 +39,13 @@ import { fromTableToStreamBuffer } from '../arrow' import { toSQL } from '../util' +import { type HttpMiddleware } from '../middleware' /** * Remote connection. */ export class RemoteConnection implements Connection { - private readonly _client: HttpLancedbClient + private _client: HttpLancedbClient private readonly _dbName: string constructor (opts: ConnectionOptions) { @@ -84,10 +85,11 @@ export class RemoteConnection implements Connection { limit: number = 10 ): Promise { const response = await this._client.get('/v1/table/', { - limit, + limit: `${limit}`, page_token: pageToken }) - return response.data.tables + const body = await response.body() + return body.tables } async openTable (name: string): Promise @@ -163,7 +165,7 @@ export class RemoteConnection implements Connection { throw new Error( `Server Error, status: ${res.status}, ` + // eslint-disable-next-line @typescript-eslint/restrict-template-expressions - `message: ${res.statusText}: ${res.data}` + `message: ${res.statusText}: ${await res.body()}` ) } @@ -177,6 +179,17 @@ export class RemoteConnection implements Connection { async dropTable (name: string): Promise { await this._client.post(`/v1/table/${name}/drop/`) } + + withMiddleware (middleware: HttpMiddleware): Connection { + const wrapped = this.clone() + wrapped._client = wrapped._client.withMiddleware(middleware) + return wrapped + } + + private clone (): RemoteConnection { + const clone: RemoteConnection = Object.create(RemoteConnection.prototype) + return Object.assign(clone, this) + } } export class RemoteQuery extends Query { @@ -229,7 +242,7 @@ export class RemoteQuery extends Query { // we are using extend until we have next next version release // Table and Connection has both been refactored to interfaces export class RemoteTable implements Table { - private readonly _client: HttpLancedbClient + private _client: HttpLancedbClient private readonly _embeddings?: EmbeddingFunction private readonly _name: string @@ -256,15 +269,15 @@ export class RemoteTable implements Table { get schema (): Promise { return this._client .post(`/v1/table/${this._name}/describe/`) - .then((res) => { + .then(async (res) => { if (res.status !== 200) { throw new Error( `Server Error, status: ${res.status}, ` + // eslint-disable-next-line @typescript-eslint/restrict-template-expressions - `message: ${res.statusText}: ${res.data}` + `message: ${res.statusText}: ${await res.body()}` ) } - return res.data?.schema + return (await res.body())?.schema }) } @@ -320,7 +333,7 @@ export class RemoteTable implements Table { throw new Error( `Server Error, status: ${res.status}, ` + // eslint-disable-next-line @typescript-eslint/restrict-template-expressions - `message: ${res.statusText}: ${res.data}` + `message: ${res.statusText}: ${await res.body()}` ) } } @@ -346,7 +359,7 @@ export class RemoteTable implements Table { throw new Error( `Server Error, status: ${res.status}, ` + // eslint-disable-next-line @typescript-eslint/restrict-template-expressions - `message: ${res.statusText}: ${res.data}` + `message: ${res.statusText}: ${await res.body()}` ) } return tbl.numRows @@ -372,7 +385,7 @@ export class RemoteTable implements Table { throw new Error( `Server Error, status: ${res.status}, ` + // eslint-disable-next-line @typescript-eslint/restrict-template-expressions - `message: ${res.statusText}: ${res.data}` + `message: ${res.statusText}: ${await res.body()}` ) } return tbl.numRows @@ -415,7 +428,7 @@ export class RemoteTable implements Table { throw new Error( `Server Error, status: ${res.status}, ` + // eslint-disable-next-line @typescript-eslint/restrict-template-expressions - `message: ${res.statusText}: ${res.data}` + `message: ${res.statusText}: ${await res.body()}` ) } } @@ -436,14 +449,14 @@ export class RemoteTable implements Table { throw new Error( `Server Error, status: ${res.status}, ` + // eslint-disable-next-line @typescript-eslint/restrict-template-expressions - `message: ${res.statusText}: ${res.data}` + `message: ${res.statusText}: ${await res.body()}` ) } } async countRows (): Promise { const result = await this._client.post(`/v1/table/${this._name}/describe/`) - return result.data?.stats?.num_rows + return (await result.body())?.stats?.num_rows } async delete (filter: string): Promise { @@ -476,7 +489,7 @@ export class RemoteTable implements Table { const results = await this._client.post( `/v1/table/${this._name}/index/list/` ) - return results.data.indexes?.map((index: any) => ({ + return (await results.body()).indexes?.map((index: any) => ({ columns: index.columns, name: index.index_name, uuid: index.index_uuid @@ -487,9 +500,10 @@ export class RemoteTable implements Table { const results = await this._client.post( `/v1/table/${this._name}/index/${indexUuid}/stats/` ) + const body = await results.body() return { - numIndexedRows: results.data.num_indexed_rows, - numUnindexedRows: results.data.num_unindexed_rows + numIndexedRows: body?.num_indexed_rows, + numUnindexedRows: body?.num_unindexed_rows } } @@ -504,4 +518,15 @@ export class RemoteTable implements Table { async dropColumns (columnNames: string[]): Promise { throw new Error('Drop columns is not yet supported in LanceDB Cloud.') } + + withMiddleware(middleware: HttpMiddleware): Table { + const wrapped = this.clone() + wrapped._client = wrapped._client.withMiddleware(middleware) + return wrapped + } + + private clone (): RemoteTable { + const clone: RemoteTable = Object.create(RemoteTable.prototype) + return Object.assign(clone, this) + } }