mirror of
https://github.com/lancedb/lancedb.git
synced 2026-01-04 10:52:56 +00:00
Compare commits
1 Commits
ayush/gemi
...
codex/upda
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
914c975615 |
64
Cargo.lock
generated
64
Cargo.lock
generated
@@ -3044,8 +3044,8 @@ checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
|
||||
|
||||
[[package]]
|
||||
name = "fsst"
|
||||
version = "0.38.3-beta.9"
|
||||
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.9#5f603515786cdc3b5aadb3313131bf686d5e932b"
|
||||
version = "0.38.3-beta.10"
|
||||
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.10#c6c9249d2891577338e38045bcb2db6373652525"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"rand 0.9.2",
|
||||
@@ -4229,8 +4229,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance"
|
||||
version = "0.38.3-beta.9"
|
||||
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.9#5f603515786cdc3b5aadb3313131bf686d5e932b"
|
||||
version = "0.38.3-beta.10"
|
||||
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.10#c6c9249d2891577338e38045bcb2db6373652525"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-arith",
|
||||
@@ -4293,8 +4293,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-arrow"
|
||||
version = "0.38.3-beta.9"
|
||||
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.9#5f603515786cdc3b5aadb3313131bf686d5e932b"
|
||||
version = "0.38.3-beta.10"
|
||||
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.10#c6c9249d2891577338e38045bcb2db6373652525"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -4312,8 +4312,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-bitpacking"
|
||||
version = "0.38.3-beta.9"
|
||||
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.9#5f603515786cdc3b5aadb3313131bf686d5e932b"
|
||||
version = "0.38.3-beta.10"
|
||||
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.10#c6c9249d2891577338e38045bcb2db6373652525"
|
||||
dependencies = [
|
||||
"arrayref",
|
||||
"paste",
|
||||
@@ -4322,8 +4322,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-core"
|
||||
version = "0.38.3-beta.9"
|
||||
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.9#5f603515786cdc3b5aadb3313131bf686d5e932b"
|
||||
version = "0.38.3-beta.10"
|
||||
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.10#c6c9249d2891577338e38045bcb2db6373652525"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -4359,8 +4359,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-datafusion"
|
||||
version = "0.38.3-beta.9"
|
||||
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.9#5f603515786cdc3b5aadb3313131bf686d5e932b"
|
||||
version = "0.38.3-beta.10"
|
||||
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.10#c6c9249d2891577338e38045bcb2db6373652525"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-array",
|
||||
@@ -4389,8 +4389,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-datagen"
|
||||
version = "0.38.3-beta.9"
|
||||
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.9#5f603515786cdc3b5aadb3313131bf686d5e932b"
|
||||
version = "0.38.3-beta.10"
|
||||
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.10#c6c9249d2891577338e38045bcb2db6373652525"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-array",
|
||||
@@ -4407,8 +4407,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-encoding"
|
||||
version = "0.38.3-beta.9"
|
||||
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.9#5f603515786cdc3b5aadb3313131bf686d5e932b"
|
||||
version = "0.38.3-beta.10"
|
||||
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.10#c6c9249d2891577338e38045bcb2db6373652525"
|
||||
dependencies = [
|
||||
"arrow-arith",
|
||||
"arrow-array",
|
||||
@@ -4445,8 +4445,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-file"
|
||||
version = "0.38.3-beta.9"
|
||||
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.9#5f603515786cdc3b5aadb3313131bf686d5e932b"
|
||||
version = "0.38.3-beta.10"
|
||||
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.10#c6c9249d2891577338e38045bcb2db6373652525"
|
||||
dependencies = [
|
||||
"arrow-arith",
|
||||
"arrow-array",
|
||||
@@ -4479,8 +4479,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-index"
|
||||
version = "0.38.3-beta.9"
|
||||
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.9#5f603515786cdc3b5aadb3313131bf686d5e932b"
|
||||
version = "0.38.3-beta.10"
|
||||
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.10#c6c9249d2891577338e38045bcb2db6373652525"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-arith",
|
||||
@@ -4542,8 +4542,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-io"
|
||||
version = "0.38.3-beta.9"
|
||||
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.9#5f603515786cdc3b5aadb3313131bf686d5e932b"
|
||||
version = "0.38.3-beta.10"
|
||||
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.10#c6c9249d2891577338e38045bcb2db6373652525"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-arith",
|
||||
@@ -4583,8 +4583,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-linalg"
|
||||
version = "0.38.3-beta.9"
|
||||
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.9#5f603515786cdc3b5aadb3313131bf686d5e932b"
|
||||
version = "0.38.3-beta.10"
|
||||
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.10#c6c9249d2891577338e38045bcb2db6373652525"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-buffer",
|
||||
@@ -4607,8 +4607,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-namespace"
|
||||
version = "0.38.3-beta.9"
|
||||
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.9#5f603515786cdc3b5aadb3313131bf686d5e932b"
|
||||
version = "0.38.3-beta.10"
|
||||
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.10#c6c9249d2891577338e38045bcb2db6373652525"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"async-trait",
|
||||
@@ -4620,8 +4620,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-namespace-impls"
|
||||
version = "0.38.3-beta.9"
|
||||
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.9#5f603515786cdc3b5aadb3313131bf686d5e932b"
|
||||
version = "0.38.3-beta.10"
|
||||
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.10#c6c9249d2891577338e38045bcb2db6373652525"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-ipc",
|
||||
@@ -4653,8 +4653,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-table"
|
||||
version = "0.38.3-beta.9"
|
||||
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.9#5f603515786cdc3b5aadb3313131bf686d5e932b"
|
||||
version = "0.38.3-beta.10"
|
||||
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.10#c6c9249d2891577338e38045bcb2db6373652525"
|
||||
dependencies = [
|
||||
"arrow",
|
||||
"arrow-array",
|
||||
@@ -4692,8 +4692,8 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lance-testing"
|
||||
version = "0.38.3-beta.9"
|
||||
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.9#5f603515786cdc3b5aadb3313131bf686d5e932b"
|
||||
version = "0.38.3-beta.10"
|
||||
source = "git+https://github.com/lancedb/lance.git?tag=v0.38.3-beta.10#c6c9249d2891577338e38045bcb2db6373652525"
|
||||
dependencies = [
|
||||
"arrow-array",
|
||||
"arrow-schema",
|
||||
|
||||
28
Cargo.toml
28
Cargo.toml
@@ -15,20 +15,20 @@ categories = ["database-implementations"]
|
||||
rust-version = "1.78.0"
|
||||
|
||||
[workspace.dependencies]
|
||||
lance = { "version" = "=0.38.3-beta.9", default-features = false, "tag" = "v0.38.3-beta.9", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-core = { "version" = "=0.38.3-beta.9", "tag" = "v0.38.3-beta.9", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-datagen = { "version" = "=0.38.3-beta.9", "tag" = "v0.38.3-beta.9", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-file = { "version" = "=0.38.3-beta.9", "tag" = "v0.38.3-beta.9", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-io = { "version" = "=0.38.3-beta.9", default-features = false, "tag" = "v0.38.3-beta.9", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-index = { "version" = "=0.38.3-beta.9", "tag" = "v0.38.3-beta.9", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-linalg = { "version" = "=0.38.3-beta.9", "tag" = "v0.38.3-beta.9", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-namespace = { "version" = "=0.38.3-beta.9", "tag" = "v0.38.3-beta.9", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-namespace-impls = { "version" = "=0.38.3-beta.9", "tag" = "v0.38.3-beta.9", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-table = { "version" = "=0.38.3-beta.9", "tag" = "v0.38.3-beta.9", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-testing = { "version" = "=0.38.3-beta.9", "tag" = "v0.38.3-beta.9", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-datafusion = { "version" = "=0.38.3-beta.9", "tag" = "v0.38.3-beta.9", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-encoding = { "version" = "=0.38.3-beta.9", "tag" = "v0.38.3-beta.9", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-arrow = { "version" = "=0.38.3-beta.9", "tag" = "v0.38.3-beta.9", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance = { "version" = "=0.38.3-beta.10", default-features = false, "tag" = "v0.38.3-beta.10", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-core = { "version" = "=0.38.3-beta.10", "tag" = "v0.38.3-beta.10", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-datagen = { "version" = "=0.38.3-beta.10", "tag" = "v0.38.3-beta.10", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-file = { "version" = "=0.38.3-beta.10", "tag" = "v0.38.3-beta.10", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-io = { "version" = "=0.38.3-beta.10", default-features = false, "tag" = "v0.38.3-beta.10", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-index = { "version" = "=0.38.3-beta.10", "tag" = "v0.38.3-beta.10", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-linalg = { "version" = "=0.38.3-beta.10", "tag" = "v0.38.3-beta.10", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-namespace = { "version" = "=0.38.3-beta.10", "tag" = "v0.38.3-beta.10", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-namespace-impls = { "version" = "=0.38.3-beta.10", "tag" = "v0.38.3-beta.10", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-table = { "version" = "=0.38.3-beta.10", "tag" = "v0.38.3-beta.10", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-testing = { "version" = "=0.38.3-beta.10", "tag" = "v0.38.3-beta.10", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-datafusion = { "version" = "=0.38.3-beta.10", "tag" = "v0.38.3-beta.10", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-encoding = { "version" = "=0.38.3-beta.10", "tag" = "v0.38.3-beta.10", "git" = "https://github.com/lancedb/lance.git" }
|
||||
lance-arrow = { "version" = "=0.38.3-beta.10", "tag" = "v0.38.3-beta.10", "git" = "https://github.com/lancedb/lance.git" }
|
||||
ahash = "0.8"
|
||||
# Note that this one does not include pyarrow
|
||||
arrow = { version = "56.2", optional = false }
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
|
||||
import os
|
||||
from functools import cached_property
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -46,11 +46,10 @@ class GeminiText(TextEmbeddingFunction):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name: str, default "models/text-embedding-004"
|
||||
name: str, default "models/embedding-001"
|
||||
The name of the model to use. See the Gemini documentation for a list of
|
||||
available models.
|
||||
dims: int, optional
|
||||
The dimension of the embedding, otherwise it will be inferred.
|
||||
|
||||
query_task_type: str, default "retrieval_query"
|
||||
Sets the task type for the queries.
|
||||
source_task_type: str, default "retrieval_document"
|
||||
@@ -78,10 +77,9 @@ class GeminiText(TextEmbeddingFunction):
|
||||
|
||||
"""
|
||||
|
||||
name: str = "models/text-embedding-004"
|
||||
name: str = "models/embedding-001"
|
||||
query_task_type: str = "retrieval_query"
|
||||
source_task_type: str = "retrieval_document"
|
||||
dims: Optional[int] = None
|
||||
|
||||
if PYDANTIC_VERSION.major < 2: # Pydantic 1.x compat
|
||||
|
||||
@@ -91,18 +89,9 @@ class GeminiText(TextEmbeddingFunction):
|
||||
model_config = dict()
|
||||
model_config["ignored_types"] = (cached_property,)
|
||||
|
||||
@cached_property
|
||||
def _model(self):
|
||||
return self.client.get_model(self.name)
|
||||
|
||||
def ndims(self) -> int:
|
||||
if self.dims:
|
||||
return self.dims
|
||||
if hasattr(self._model, "output_dimensionality"):
|
||||
return self._model.output_dimensionality
|
||||
# Fallback for older versions of the library
|
||||
# or models that don't have the attribute
|
||||
return len(self.generate_embeddings(["lancedb"])[0])
|
||||
def ndims(self):
|
||||
# TODO: fix hardcoding
|
||||
return 768
|
||||
|
||||
def compute_query_embeddings(self, query: str, *args, **kwargs) -> List[np.array]:
|
||||
return self.compute_source_embeddings(query, task_type=self.query_task_type)
|
||||
@@ -130,8 +119,6 @@ class GeminiText(TextEmbeddingFunction):
|
||||
): # Provide a title to use existing API design
|
||||
title = "Embedding of a document"
|
||||
kwargs["title"] = title
|
||||
if self.dims:
|
||||
kwargs["output_dimensionality"] = self.dims
|
||||
|
||||
return [
|
||||
self.client.embed_content(model=self.name, content=text, **kwargs)[
|
||||
@@ -144,8 +131,6 @@ class GeminiText(TextEmbeddingFunction):
|
||||
def client(self):
|
||||
genai = attempt_import_or_raise("google.generativeai", "google.generativeai")
|
||||
|
||||
api_key = os.environ.get("GOOGLE_API_KEY")
|
||||
if not api_key:
|
||||
if not os.environ.get("GOOGLE_API_KEY"):
|
||||
api_key_not_found_help("google")
|
||||
genai.configure(api_key=api_key)
|
||||
return genai
|
||||
|
||||
@@ -308,7 +308,7 @@ def test_instructor_embedding(tmp_path):
|
||||
os.environ.get("GOOGLE_API_KEY") is None, reason="GOOGLE_API_KEY not set"
|
||||
)
|
||||
def test_gemini_embedding(tmp_path):
|
||||
model = get_registry().get("gemini-text").create(max_retries=0, dims=512)
|
||||
model = get_registry().get("gemini-text").create(max_retries=0)
|
||||
|
||||
class TextModel(LanceModel):
|
||||
text: str = model.SourceField()
|
||||
@@ -319,7 +319,7 @@ def test_gemini_embedding(tmp_path):
|
||||
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
|
||||
|
||||
tbl.add(df)
|
||||
assert len(tbl.to_pandas()["vector"][0]) == model.ndims() == 512
|
||||
assert len(tbl.to_pandas()["vector"][0]) == model.ndims()
|
||||
assert tbl.search("hello").limit(1).to_pandas()["text"][0] == "hello world"
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user