add tests for rust

This commit is contained in:
albertlockett
2024-11-21 06:58:51 -05:00
parent 134258308c
commit 60ad82b6ad

View File

@@ -54,11 +54,15 @@ impl<S: HttpSend> RemoteTable<S> {
}
async fn describe(&self) -> Result<TableDescription> {
let version = self.current_version().await;
self.describe_version(version).await
}
async fn describe_version(&self, version: Option<u64>) -> Result<TableDescription> {
let mut request = self
.client
.post(&format!("/v1/table/{}/describe/", self.name));
let version = self.current_version().await;
let body = serde_json::json!({ "version": version });
request = request.json(&body);
@@ -326,9 +330,16 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
self.describe().await.map(|desc| desc.version)
}
async fn checkout(&self, version: u64) -> Result<()> {
// TODO check that the version exists
// we can do this when the list_versions changes land
// https://github.com/lancedb/lancedb/pull/1850
// check that the version exists
self.describe_version(Some(version)).await.map_err(|e| match e {
// try to map the error to a more user-friendly error telling them
// specifically that the version does not exist
Error::TableNotFound { name } => Error::TableNotFound {
name: format!("{} (version: {})", name, version),
},
e => e,
})?;
let mut write_guard = self.version.write().await;
*write_guard = Some(version);
Ok(())
@@ -702,7 +713,6 @@ impl<S: HttpSend> TableInternal for RemoteTable<S> {
Ok(Some(stats))
}
async fn table_definition(&self) -> Result<TableDefinition> {
self.check_mutable().await?;
Err(Error::NotSupported {
message: "table_definition is not supported on LanceDB cloud.".into(),
})
@@ -861,7 +871,7 @@ mod tests {
request.headers().get("Content-Type").unwrap(),
JSON_CONTENT_TYPE
);
assert_eq!(request.body().unwrap().as_bytes().unwrap(), br#"{}"#);
assert_eq!(request.body().unwrap().as_bytes().unwrap(), br#"{"version":null}"#);
http::Response::builder().status(200).body("42").unwrap()
});
@@ -878,7 +888,7 @@ mod tests {
);
assert_eq!(
request.body().unwrap().as_bytes().unwrap(),
br#"{"predicate":"a > 10"}"#
br#"{"predicate":"a > 10","version":null}"#
);
http::Response::builder().status(200).body("42").unwrap()
@@ -1179,6 +1189,7 @@ mod tests {
"nprobes": 20,
"ef": Option::<usize>::None,
"refine_factor": null,
"version": null,
});
// Pass vector separately to make sure it matches f32 precision.
expected_body["vector"] = vec![0.1f32, 0.2, 0.3].into();
@@ -1225,6 +1236,7 @@ mod tests {
"nprobes": 12,
"ef": Option::<usize>::None,
"refine_factor": 2,
"version": null,
});
// Pass vector separately to make sure it matches f32 precision.
expected_body["vector"] = vec![0.1f32, 0.2, 0.3].into();
@@ -1280,6 +1292,7 @@ mod tests {
"k": 10,
"vector": [],
"with_row_id": true,
"version": null
});
assert_eq!(body, expected_body);
@@ -1509,4 +1522,176 @@ mod tests {
let indices = table.index_stats("my_index").await.unwrap();
assert!(indices.is_none());
}
#[tokio::test]
async fn test_passes_version() {
let table = Table::new_with_handler("my_table", |request| {
let body = request.body().unwrap().as_bytes().unwrap();
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
let version = body.as_object().unwrap().get("version").unwrap().as_u64().unwrap();
assert_eq!(version, 42);
let response_body = match request.url().path() {
"/v1/table/my_table/describe/" => {
serde_json::json!({
"version": 42,
"schema": { "fields": [] }
})
},
"/v1/table/my_table/index/list/" => {
serde_json::json!({
"indexes": []
})
},
"/v1/table/my_table/index/my_idx/stats/" => {
serde_json::json!({
"num_indexed_rows": 100000,
"num_unindexed_rows": 0,
"index_type": "IVF_PQ",
"distance_type": "l2"
})
},
"/v1/table/my_table/count_rows/" => {
serde_json::json!(1000)
}
"/v1/table/my_table/query/" => {
let expected_data = RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let expected_data_ref = expected_data.clone();
let response_body = write_ipc_file(&expected_data_ref);
return http::Response::builder()
.status(200)
.header(CONTENT_TYPE, ARROW_FILE_CONTENT_TYPE)
.body(response_body)
.unwrap();
},
path => panic!("Unexpected path: {}", path),
};
http::Response::builder()
.status(200)
.body(serde_json::to_string(&response_body).unwrap().as_bytes().to_vec())
.unwrap()
});
table.checkout(42).await.unwrap();
// ensure that version is passed to the /describe endpoint
let version = table.version().await.unwrap();
assert_eq!(version, 42);
// ensure it's passed to other read API calls
table.list_indices().await.unwrap();
table.index_stats("my_idx").await.unwrap();
table.count_rows(None).await.unwrap();
table
.query()
.nearest_to(vec![0.1, 0.2, 0.3])
.unwrap()
.execute()
.await
.unwrap();
}
#[tokio::test]
async fn test_fails_if_checkout_version_doesnt_exist() {
let table = Table::new_with_handler("my_table", |request| {
let body = request.body().unwrap().as_bytes().unwrap();
let body: serde_json::Value = serde_json::from_slice(body).unwrap();
let version = body.as_object().unwrap().get("version").unwrap().as_u64().unwrap();
if version != 42 {
return http::Response::builder()
.status(404)
.body(format!("Table my_table (version: {}) not found", version))
.unwrap();
}
let response_body = match request.url().path() {
"/v1/table/my_table/describe/" => {
serde_json::json!({
"version": 42,
"schema": { "fields": [] }
})
},
_ => panic!("Unexpected path"),
};
http::Response::builder()
.status(200)
.body(serde_json::to_string(&response_body).unwrap())
.unwrap()
});
let res = table.checkout(43).await;
println!("{:?}", res);
assert!(matches!(res, Err(Error::TableNotFound { name }) if name == "my_table (version: 43)"));
}
#[tokio::test]
async fn test_timetravel_immutable() {
let table = Table::new_with_handler::<String>("my_table", |request| {
let response_body = match request.url().path() {
"/v1/table/my_table/describe/" => {
serde_json::json!({
"version": 42,
"schema": { "fields": [] }
})
},
_ => panic!("Should not have made a request: {:?}", request),
};
http::Response::builder()
.status(200)
.body(serde_json::to_string(&response_body).unwrap())
.unwrap()
});
table.checkout(42).await.unwrap();
// Ensure that all mutable operations fail.
let res = table
.update()
.column("a", "a + 1")
.column("b", "b - 1")
.only_if("b > 10")
.execute()
.await;
assert!(matches!(res, Err(Error::NotSupported { .. })));
let batch = RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let data = Box::new(RecordBatchIterator::new(
[Ok(batch.clone())],
batch.schema(),
));
let res = table
.merge_insert(&["some_col"])
.execute(data)
.await;
assert!(matches!(res, Err(Error::NotSupported { .. })));
let res = table.delete("id in (1, 2, 3)").await;
assert!(matches!(res, Err(Error::NotSupported { .. })));
let data = RecordBatch::try_new(
Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])),
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
)
.unwrap();
let res = table
.add(RecordBatchIterator::new([Ok(data.clone())], data.schema()))
.execute()
.await;
assert!(matches!(res, Err(Error::NotSupported { .. })));
let res = table.create_index(&["a"], Index::IvfPq(Default::default())).execute().await;
assert!(matches!(res, Err(Error::NotSupported { .. })));
}
}