Add cohere embedding function (#550)

This commit is contained in:
Ayush Chaurasia
2023-10-13 16:27:34 +05:30
committed by GitHub
parent db7bdefe77
commit 683824f1e9
5 changed files with 118 additions and 1 deletions

View File

@@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
import numpy as np
import pandas as pd
@@ -123,3 +124,26 @@ def test_openclip(tmp_path):
arrow_table["vector"].combine_chunks().values.to_numpy(),
arrow_table["vec_from_bytes"].combine_chunks().values.to_numpy(),
)
@pytest.mark.slow
@pytest.mark.skipif(
os.environ.get("COHERE_API_KEY") is None, reason="COHERE_API_KEY not set"
) # also skip if cohere not installed
def test_cohere_embedding_function():
cohere = (
EmbeddingFunctionRegistry.get_instance()
.get("cohere")
.create(name="embed-multilingual-v2.0")
)
class TextModel(LanceModel):
text: str = cohere.SourceField()
vector: Vector(cohere.ndims()) = cohere.VectorField()
df = pd.DataFrame({"text": ["hello world", "goodbye world"]})
db = lancedb.connect("~/lancedb")
tbl = db.create_table("test", schema=TextModel, mode="overwrite")
tbl.add(df)
assert len(tbl.to_pandas()["vector"][0]) == cohere.ndims()