Node SDK Client middleware for HTTP Requests (#1130)

Adds client-side middleware to LanceDB Node SDK to instrument HTTP
Requests

Example - adding `x-request-id` request header:
```js
class HttpMiddleware {
    constructor({ requestId }) {
        this.requestId = requestId
    }

    onRemoteRequest(req, next) {
        req.headers.set('x-request-id', this.requestId)
        return next(req)
    }
}

const db = await lancedb.connect({
  uri: 'db://remote-123',
  apiKey: 'sk_...',
})

let tables = await db.withMiddleware(new HttpMiddleware({ requestId: '123' })).tableNames();

```

---------

Co-authored-by: Weston Pace <weston.pace@gmail.com>
This commit is contained in:
Bert
2024-03-22 11:58:05 -04:00
committed by Weston Pace
parent db2631c2ad
commit 1e41232f28
4 changed files with 315 additions and 93 deletions

View File

@@ -24,6 +24,7 @@ import { RemoteConnection } from './remote'
import { Query } from './query'
import { isEmbeddingFunction } from './embedding/embedding_function'
import { type Literal, toSQL } from './util'
import { type HttpMiddleware } from './middleware'
const {
databaseNew,
@@ -302,6 +303,18 @@ export interface Connection {
* @param name The name of the table to drop.
*/
dropTable(name: string): Promise<void>
/**
 * Instrument the behavior of this Connection with middleware.
 *
 * The middleware will be called in the order they are added.
 *
 * Currently this functionality is only supported for remote Connections;
 * local connections ignore the middleware and return themselves.
 *
 * @param {HttpMiddleware} middleware - Middleware which will instrument the Connection.
 * @returns - a Connection instrumented by the passed middleware. Remote
 *   connections return a copy and leave the receiver untouched.
 */
withMiddleware(middleware: HttpMiddleware): Connection
}
/**
@@ -541,6 +554,18 @@ export interface Table<T = number[]> {
* names (e.g. "a").
*/
dropColumns(columnNames: string[]): Promise<void>
/**
 * Instrument the behavior of this Table with middleware.
 *
 * The middleware will be called in the order they are added.
 *
 * Currently this functionality is only supported for remote tables;
 * local tables ignore the middleware and return themselves.
 *
 * @param {HttpMiddleware} middleware - Middleware which will instrument the Table.
 * @returns - a Table instrumented by the passed middleware. Remote tables
 *   return a copy and leave the receiver untouched.
 */
withMiddleware(middleware: HttpMiddleware): Table<T>
}
/**
@@ -795,6 +820,10 @@ export class LocalConnection implements Connection {
async dropTable (name: string): Promise<void> {
await databaseDropTable.call(this._db, name)
}
/**
 * No-op for local connections: the middleware argument is ignored and this
 * same connection is returned.
 */
withMiddleware (middleware: HttpMiddleware): Connection {
return this
}
}
export class LocalTable<T = number[]> implements Table<T> {
@@ -1105,6 +1134,10 @@ export class LocalTable<T = number[]> implements Table<T> {
async dropColumns (columnNames: string[]): Promise<void> {
return tableDropColumns.call(this._tbl, columnNames)
}
/**
 * No-op for local tables: the middleware argument is ignored and this same
 * table is returned.
 */
withMiddleware (middleware: HttpMiddleware): Table<T> {
return this
}
}
export interface CleanupStats {

58
node/src/middleware.ts Normal file
View File

@@ -0,0 +1,58 @@
// Copyright 2024 LanceDB Developers.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/**
 * Middleware for a remote LanceDB Connection or Table.
 *
 * Attach an implementation with `withMiddleware` to intercept the HTTP
 * requests sent by the instrumented connection or table.
 */
export interface HttpMiddleware {
  /**
   * A callback that can be used to instrument the behavior of http requests to remote
   * tables. It can be used to add headers, modify the request, or even short-circuit
   * the request and return a response without making the request to the remote endpoint.
   * It can also be used to modify the response from the remote endpoint.
   *
   * @param req - Request about to be sent to the remote endpoint.
   * @param next - Callback to advance the middleware chain. Call it with the
   *   (possibly modified) request to obtain the downstream response; skip it
   *   to short-circuit the chain.
   * @returns The response to hand back to the caller.
   */
  onRemoteRequest(
    req: RemoteRequest,
    next: (req: RemoteRequest) => Promise<RemoteResponse>,
  ): Promise<RemoteResponse>
}
/**
 * HTTP methods supported by a remote request.
 *
 * Numeric enum: `GET` is 0 and `POST` is 1.
 */
export enum Method {
  GET = 0,
  POST = 1
}
/**
 * A LanceDB Remote HTTP Request.
 *
 * This is the object handed to each middleware; middleware may modify it
 * before passing it on to `next`.
 */
export interface RemoteRequest {
// URL the request is sent to.
uri: string
// HTTP method used for the request.
method: Method
// HTTP headers. This is a Map: add or change entries with `headers.set(name, value)`.
headers: Map<string, string>
// Optional URL parameters sent with the request.
params?: Map<string, string>
// Optional payload; it is sent for POST requests and not used for GET requests.
body?: any
}
/**
 * A LanceDB Remote HTTP Response.
 *
 * This is what the middleware chain resolves to and what middleware may
 * inspect, replace, or return directly when short-circuiting a request.
 */
export interface RemoteResponse {
// HTTP status code.
status: number
// HTTP status text accompanying the status code.
statusText: string
// Response headers keyed by header name.
headers: Map<string, string>
// Asynchronous accessor for the response payload.
body: () => Promise<any>
}

View File

@@ -12,13 +12,101 @@
// See the License for the specific language governing permissions and
// limitations under the License.
import axios, { type AxiosResponse } from 'axios'
import axios, { type AxiosResponse, type ResponseType } from 'axios'
import { tableFromIPC, type Table as ArrowTable } from 'apache-arrow'
import { type RemoteResponse, type RemoteRequest, Method } from '../middleware'
/**
 * Middleware contract used internally by the HTTP client.
 *
 * It has the same shape as `HttpMiddleware` from '../middleware': a single
 * `onRemoteRequest` hook that receives the request and a `next` callback
 * which advances the chain.
 */
interface HttpLancedbClientMiddleware {
onRemoteRequest(
req: RemoteRequest,
next: (req: RemoteRequest) => Promise<RemoteResponse>,
): Promise<RemoteResponse>
}
/**
 * Invoke the middleware chain and, once it is exhausted, call the remote
 * endpoint.
 *
 * Middlewares run in array order. Each one receives the current request and a
 * `next` callback; calling `next` hands the request to the following
 * middleware, or to axios when none is left. A middleware that returns
 * without calling `next` short-circuits the chain and no HTTP request is made.
 *
 * @param req - The request to send.
 * @param middlewares - Middleware to run, in invocation order.
 * @param opts - Optional settings forwarded to axios (`responseType`). They
 *   apply to both GET and POST requests.
 * @returns The response produced by the chain.
 * @throws Anything thrown by a middleware or by axios; nothing is caught here.
 */
async function callWithMiddlewares (
  req: RemoteRequest,
  middlewares: HttpLancedbClientMiddleware[],
  opts?: MiddlewareInvocationOptions
): Promise<RemoteResponse> {
  // Terminal step of the chain: translate the request into an axios call.
  async function send (finalReq: RemoteRequest): Promise<RemoteResponse> {
    const config = {
      // axios expects plain objects, so flatten the Maps.
      headers: Object.fromEntries(finalReq.headers.entries()),
      params: Object.fromEntries(finalReq.params?.entries() ?? []),
      timeout: 10000,
      // Previously only POST honoured responseType; forward it for GET too.
      responseType: opts?.responseType
    }
    const res = finalReq.method === Method.POST
      ? await axios.post(finalReq.uri, finalReq.body, config)
      : await axios.get(finalReq.uri, config)
    return toLanceRes(res)
  }

  // Run the middleware at `index`, or send the request when none is left.
  async function dispatch (
    index: number,
    current: RemoteRequest
  ): Promise<RemoteResponse> {
    if (index >= middlewares.length) {
      return await send(current)
    }
    return await middlewares[index].onRemoteRequest(
      current,
      async (nextReq) => await dispatch(index + 1, nextReq)
    )
  }

  return await dispatch(0, req)
}
/**
 * Per-call options for `callWithMiddlewares` that are not part of the
 * request object seen by middleware.
 */
interface MiddlewareInvocationOptions {
// axios response type (e.g. 'arraybuffer') to request for the final HTTP call.
responseType?: ResponseType
}
/**
 * Marshall the library (axios) response into a LanceDB response.
 *
 * Only the response's own header entries are copied; properties inherited
 * through the prototype chain are ignored. A response without headers yields
 * an empty header map.
 *
 * @param res - The axios response to convert.
 * @returns A `RemoteResponse` with the same status, status text and headers,
 *   whose `body()` resolves to `res.data`.
 */
function toLanceRes (res: AxiosResponse): RemoteResponse {
  const headers = new Map()
  // Object.entries visits own enumerable keys only, unlike for..in.
  for (const [name, value] of Object.entries(res.headers ?? {})) {
    headers.set(name, value)
  }

  return {
    status: res.status,
    statusText: res.statusText,
    headers,
    body: async () => res.data
  }
}
export class HttpLancedbClient {
private readonly _url: string
private readonly _apiKey: () => string
private readonly _middlewares: HttpLancedbClientMiddleware[]
public constructor (
url: string,
@@ -27,6 +115,7 @@ export class HttpLancedbClient {
) {
this._url = url
this._apiKey = () => apiKey
this._middlewares = []
}
get uri (): string {
@@ -43,74 +132,61 @@ export class HttpLancedbClient {
columns?: string[],
filter?: string
): Promise<ArrowTable<any>> {
const response = await axios.post(
`${this._url}/v1/table/${tableName}/query/`,
{
vector,
k,
nprobes,
refineFactor,
columns,
filter,
prefilter
},
{
headers: {
'Content-Type': 'application/json',
'x-api-key': this._apiKey(),
...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {})
},
responseType: 'arraybuffer',
timeout: 10000
}
).catch((err) => {
console.error('error: ', err)
if (err.response === undefined) {
throw new Error(`Network Error: ${err.message as string}`)
}
return err.response
})
if (response.status !== 200) {
const errorData = new TextDecoder().decode(response.data)
throw new Error(
`Server Error, status: ${response.status as number}, ` +
`message: ${response.statusText as string}: ${errorData}`
)
}
const table = tableFromIPC(response.data)
const result = await this.post(
`/v1/table/${tableName}/query/`,
{
vector,
k,
nprobes,
refineFactor,
columns,
filter,
prefilter
},
undefined,
undefined,
'arraybuffer'
)
const table = tableFromIPC(await result.body())
return table
}
/**
* Send a GET request.
*/
public async get (path: string, params?: Record<string, string | number>): Promise<AxiosResponse> {
const response = await axios.get(
`${this._url}${path}`,
{
headers: {
'Content-Type': 'application/json',
'x-api-key': this._apiKey(),
...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {})
},
params,
timeout: 10000
}
).catch((err) => {
public async get (path: string, params?: Record<string, string>): Promise<RemoteResponse> {
const req = {
uri: `${this._url}${path}`,
method: Method.GET,
headers: new Map(Object.entries({
'Content-Type': 'application/json',
'x-api-key': this._apiKey(),
...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {})
})),
params: new Map(Object.entries(params ?? {}))
}
let response
try {
response = await callWithMiddlewares(req, this._middlewares)
return response
} catch (err: any) {
console.error('error: ', err)
if (err.response === undefined) {
throw new Error(`Network Error: ${err.message as string}`)
}
return err.response
})
response = toLanceRes(err.response)
}
if (response.status !== 200) {
const errorData = new TextDecoder().decode(response.data)
const errorData = new TextDecoder().decode(await response.body())
throw new Error(
`Server Error, status: ${response.status as number}, ` +
`message: ${response.statusText as string}: ${errorData}`
`Server Error, status: ${response.status}, ` +
`message: ${response.statusText}: ${errorData}`
)
}
return response
}
@@ -120,35 +196,65 @@ export class HttpLancedbClient {
public async post (
path: string,
data?: any,
params?: Record<string, string | number>,
content?: string | undefined
): Promise<AxiosResponse> {
const response = await axios.post(
`${this._url}${path}`,
data,
{
headers: {
'Content-Type': content ?? 'application/json',
'x-api-key': this._apiKey(),
...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {})
},
params,
timeout: 30000
}
).catch((err) => {
params?: Record<string, string>,
content?: string | undefined,
responseType?: ResponseType | undefined
): Promise<RemoteResponse> {
const req = {
uri: `${this._url}${path}`,
method: Method.POST,
headers: new Map(Object.entries({
'Content-Type': content ?? 'application/json',
'x-api-key': this._apiKey(),
...(this._dbName !== undefined ? { 'x-lancedb-database': this._dbName } : {})
})),
params: new Map(Object.entries(params ?? {})),
body: data
}
let response
try {
response = await callWithMiddlewares(req, this._middlewares, { responseType })
// return response
} catch (err: any) {
console.error('error: ', err)
if (err.response === undefined) {
throw new Error(`Network Error: ${err.message as string}`)
}
return err.response
})
response = toLanceRes(err.response)
}
if (response.status !== 200) {
const errorData = new TextDecoder().decode(response.data)
const errorData = new TextDecoder().decode(await response.body())
throw new Error(
`Server Error, status: ${response.status as number}, ` +
`message: ${response.statusText as string}: ${errorData}`
`Server Error, status: ${response.status}, ` +
`message: ${response.statusText}: ${errorData}`
)
}
return response
}
/**
 * Create a copy of this client with one more middleware appended to its chain.
 *
 * The receiver is not modified: the middleware is pushed onto the copy's own
 * middleware list.
 *
 * @param mw - The middleware that instruments the client
 * @returns - a new client instance instrumented with the middleware
 */
public withMiddleware (mw: HttpLancedbClientMiddleware): HttpLancedbClient {
  const instrumented = this.clone()
  instrumented._middlewares.push(mw)
  return instrumented
}
/**
 * Build a new client with the same url, API key and database name, carrying
 * over the currently registered middleware in the same order.
 */
private clone (): HttpLancedbClient {
  const copy = new HttpLancedbClient(this._url, this._apiKey(), this._dbName)
  copy._middlewares.push(...this._middlewares)
  return copy
}
}

View File

@@ -39,12 +39,13 @@ import {
fromTableToStreamBuffer
} from '../arrow'
import { toSQL } from '../util'
import { type HttpMiddleware } from '../middleware'
/**
* Remote connection.
*/
export class RemoteConnection implements Connection {
private readonly _client: HttpLancedbClient
private _client: HttpLancedbClient
private readonly _dbName: string
constructor (opts: ConnectionOptions) {
@@ -84,10 +85,11 @@ export class RemoteConnection implements Connection {
limit: number = 10
): Promise<string[]> {
const response = await this._client.get('/v1/table/', {
limit,
limit: `${limit}`,
page_token: pageToken
})
return response.data.tables
const body = await response.body()
return body.tables
}
async openTable (name: string): Promise<Table>
@@ -163,7 +165,7 @@ export class RemoteConnection implements Connection {
throw new Error(
`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`
`message: ${res.statusText}: ${await res.body()}`
)
}
@@ -177,6 +179,17 @@ export class RemoteConnection implements Connection {
async dropTable (name: string): Promise<void> {
await this._client.post(`/v1/table/${name}/drop/`)
}
/**
 * Return a copy of this connection whose HTTP client has `middleware`
 * appended to its chain. This connection is left unchanged.
 */
withMiddleware (middleware: HttpMiddleware): Connection {
  const instrumented = this.clone()
  const client = instrumented._client.withMiddleware(middleware)
  instrumented._client = client
  return instrumented
}
/**
 * Shallow-copy this connection: a new object with the RemoteConnection
 * prototype receives all of this instance's own enumerable properties.
 * The constructor is not run.
 */
private clone (): RemoteConnection {
  const copy: RemoteConnection = Object.create(RemoteConnection.prototype)
  Object.assign(copy, this)
  return copy
}
}
export class RemoteQuery<T = number[]> extends Query<T> {
@@ -229,7 +242,7 @@ export class RemoteQuery<T = number[]> extends Query<T> {
// we are using extend until we have next next version release
// Table and Connection has both been refactored to interfaces
export class RemoteTable<T = number[]> implements Table<T> {
private readonly _client: HttpLancedbClient
private _client: HttpLancedbClient
private readonly _embeddings?: EmbeddingFunction<T>
private readonly _name: string
@@ -256,15 +269,15 @@ export class RemoteTable<T = number[]> implements Table<T> {
get schema (): Promise<any> {
return this._client
.post(`/v1/table/${this._name}/describe/`)
.then((res) => {
.then(async (res) => {
if (res.status !== 200) {
throw new Error(
`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`
`message: ${res.statusText}: ${await res.body()}`
)
}
return res.data?.schema
return (await res.body())?.schema
})
}
@@ -320,7 +333,7 @@ export class RemoteTable<T = number[]> implements Table<T> {
throw new Error(
`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`
`message: ${res.statusText}: ${await res.body()}`
)
}
}
@@ -346,7 +359,7 @@ export class RemoteTable<T = number[]> implements Table<T> {
throw new Error(
`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`
`message: ${res.statusText}: ${await res.body()}`
)
}
return tbl.numRows
@@ -372,7 +385,7 @@ export class RemoteTable<T = number[]> implements Table<T> {
throw new Error(
`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`
`message: ${res.statusText}: ${await res.body()}`
)
}
return tbl.numRows
@@ -415,7 +428,7 @@ export class RemoteTable<T = number[]> implements Table<T> {
throw new Error(
`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`
`message: ${res.statusText}: ${await res.body()}`
)
}
}
@@ -436,14 +449,14 @@ export class RemoteTable<T = number[]> implements Table<T> {
throw new Error(
`Server Error, status: ${res.status}, ` +
// eslint-disable-next-line @typescript-eslint/restrict-template-expressions
`message: ${res.statusText}: ${res.data}`
`message: ${res.statusText}: ${await res.body()}`
)
}
}
async countRows (): Promise<number> {
const result = await this._client.post(`/v1/table/${this._name}/describe/`)
return result.data?.stats?.num_rows
return (await result.body())?.stats?.num_rows
}
async delete (filter: string): Promise<void> {
@@ -476,7 +489,7 @@ export class RemoteTable<T = number[]> implements Table<T> {
const results = await this._client.post(
`/v1/table/${this._name}/index/list/`
)
return results.data.indexes?.map((index: any) => ({
return (await results.body()).indexes?.map((index: any) => ({
columns: index.columns,
name: index.index_name,
uuid: index.index_uuid
@@ -487,9 +500,10 @@ export class RemoteTable<T = number[]> implements Table<T> {
const results = await this._client.post(
`/v1/table/${this._name}/index/${indexUuid}/stats/`
)
const body = await results.body()
return {
numIndexedRows: results.data.num_indexed_rows,
numUnindexedRows: results.data.num_unindexed_rows
numIndexedRows: body?.num_indexed_rows,
numUnindexedRows: body?.num_unindexed_rows
}
}
@@ -504,4 +518,15 @@ export class RemoteTable<T = number[]> implements Table<T> {
async dropColumns (columnNames: string[]): Promise<void> {
throw new Error('Drop columns is not yet supported in LanceDB Cloud.')
}
/**
 * Return a copy of this table whose HTTP client has `middleware` appended to
 * its chain. This table is left unchanged.
 */
withMiddleware (middleware: HttpMiddleware): Table<T> {
  const instrumented = this.clone()
  const client = instrumented._client.withMiddleware(middleware)
  instrumented._client = client
  return instrumented
}
/**
 * Shallow-copy this table: a new object with the RemoteTable prototype
 * receives all of this instance's own enumerable properties. The constructor
 * is not run.
 */
private clone (): RemoteTable<T> {
  const copy: RemoteTable<T> = Object.create(RemoteTable.prototype)
  Object.assign(copy, this)
  return copy
}
}