This commit is contained in:
Chang She
2024-03-03 21:48:36 -08:00
parent 408988abce
commit 2084fbcff4
17 changed files with 48 additions and 36 deletions

View File

@@ -37,30 +37,26 @@ import numpy as np
import pyarrow as pa import pyarrow as pa
import pydantic import pydantic
import semver import semver
from pydantic.fields import FieldInfo
from lance.arrow import ( from lance.arrow import (
EncodedImageArray,
EncodedImageScalar,
EncodedImageType, EncodedImageType,
ImageURIArray,
ImageURIScalar,
ImageURIType,
) )
from pydantic.fields import FieldInfo
from pydantic_core import core_schema
PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
try:
from pydantic_core import CoreSchema, core_schema
from pydantic import GetJsonSchemaHandler
from pydantic.json_schema import JsonSchemaValue
except ImportError:
if PYDANTIC_VERSION >= (2,):
raise
if TYPE_CHECKING: if TYPE_CHECKING:
from pydantic.fields import FieldInfo from pydantic.fields import FieldInfo
from .embeddings import EmbeddingFunctionConfig from .embeddings import EmbeddingFunctionConfig
PYDANTIC_VERSION = semver.Version.parse(pydantic.__version__)
try:
from pydantic import GetJsonSchemaHandler
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import CoreSchema
except ImportError:
if PYDANTIC_VERSION >= (2,):
raise
class FixedSizeListMixin(ABC): class FixedSizeListMixin(ABC):
@staticmethod @staticmethod
@@ -135,7 +131,7 @@ def Vector(
@classmethod @classmethod
def __get_pydantic_core_schema__( def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
) -> CoreSchema: ) -> "CoreSchema":
return core_schema.no_info_after_validator_function( return core_schema.no_info_after_validator_function(
cls, cls,
core_schema.list_schema( core_schema.list_schema(
@@ -238,7 +234,7 @@ def EncodedImage():
@classmethod @classmethod
def __get_pydantic_core_schema__( def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler cls, _source_type: Any, _handler: pydantic.GetCoreSchemaHandler
) -> CoreSchema: ) -> "CoreSchema":
from_bytes_schema = core_schema.bytes_schema() from_bytes_schema = core_schema.bytes_schema()
return core_schema.json_or_python_schema( return core_schema.json_or_python_schema(
@@ -256,8 +252,8 @@ def EncodedImage():
@classmethod @classmethod
def __get_pydantic_json_schema__( def __get_pydantic_json_schema__(
cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler cls, _core_schema: "CoreSchema", handler: "GetJsonSchemaHandler"
) -> JsonSchemaValue: ) -> "JsonSchemaValue":
return handler(core_schema.bytes_schema()) return handler(core_schema.bytes_schema())
@classmethod @classmethod

View File

@@ -1,4 +1,5 @@
from click.testing import CliRunner from click.testing import CliRunner
from lancedb.cli.cli import cli from lancedb.cli.cli import cli
from lancedb.utils import CONFIG from lancedb.utils import CONFIG

View File

@@ -13,6 +13,7 @@
import pandas as pd import pandas as pd
import pytest import pytest
from lancedb.context import contextualize from lancedb.context import contextualize

View File

@@ -19,6 +19,8 @@ import numpy as np
import pandas as pd import pandas as pd
import pyarrow as pa import pyarrow as pa
import pytest import pytest
import lancedb
from lancedb.pydantic import LanceModel, Vector from lancedb.pydantic import LanceModel, Vector

View File

@@ -13,6 +13,7 @@
import numpy as np import numpy as np
import pytest import pytest
from lancedb import LanceDBConnection from lancedb import LanceDBConnection
# TODO: setup integ test mark and script # TODO: setup integ test mark and script

View File

@@ -13,10 +13,11 @@
import sys import sys
import lance import lance
import lancedb
import numpy as np import numpy as np
import pyarrow as pa import pyarrow as pa
import pytest import pytest
import lancedb
from lancedb.conftest import MockTextEmbeddingFunction from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.embeddings import ( from lancedb.embeddings import (
EmbeddingFunctionConfig, EmbeddingFunctionConfig,

View File

@@ -14,11 +14,12 @@ import importlib
import io import io
import os import os
import lancedb
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import pytest import pytest
import requests import requests
import lancedb
from lancedb.embeddings import get_registry from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector from lancedb.pydantic import LanceModel, Vector
@@ -184,9 +185,10 @@ def test_imagebind(tmp_path):
import shutil import shutil
import tempfile import tempfile
import lancedb.embeddings.imagebind
import pandas as pd import pandas as pd
import requests import requests
import lancedb.embeddings.imagebind
from lancedb.embeddings import get_registry from lancedb.embeddings import get_registry
from lancedb.pydantic import LanceModel, Vector from lancedb.pydantic import LanceModel, Vector

View File

@@ -14,11 +14,12 @@ import os
import random import random
from unittest import mock from unittest import mock
import lancedb as ldb
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import pytest import pytest
import lancedb as ldb
pytest.importorskip("lancedb.fts") pytest.importorskip("lancedb.fts")
tantivy = pytest.importorskip("tantivy") tantivy = pytest.importorskip("tantivy")

View File

@@ -13,9 +13,10 @@
import os import os
import lancedb
import pytest import pytest
import lancedb
# You need to setup AWS credentials an a base path to run this test. Example # You need to setup AWS credentials an a base path to run this test. Example
# AWS_PROFILE=default TEST_S3_BASE_URL=s3://my_bucket/dataset pytest tests/test_io.py # AWS_PROFILE=default TEST_S3_BASE_URL=s3://my_bucket/dataset pytest tests/test_io.py

View File

@@ -14,18 +14,17 @@
import io import io
import json import json
import sys
import os import os
import sys
from datetime import date, datetime from datetime import date, datetime
from pathlib import Path
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from pathlib import Path
import pyarrow as pa import pyarrow as pa
import pydantic import pydantic
import pytest import pytest
import pytz
from pydantic import Field
from lance.arrow import EncodedImageType from lance.arrow import EncodedImageType
from pydantic import Field
from lancedb.pydantic import ( from lancedb.pydantic import (
PYDANTIC_VERSION, PYDANTIC_VERSION,

View File

@@ -18,6 +18,7 @@ import numpy as np
import pandas.testing as tm import pandas.testing as tm
import pyarrow as pa import pyarrow as pa
import pytest import pytest
from lancedb.db import LanceDBConnection from lancedb.db import LanceDBConnection
from lancedb.pydantic import LanceModel, Vector from lancedb.pydantic import LanceModel, Vector
from lancedb.query import LanceVectorQueryBuilder, Query from lancedb.query import LanceVectorQueryBuilder, Query

View File

@@ -17,6 +17,7 @@ import pandas as pd
import pyarrow as pa import pyarrow as pa
import pytest import pytest
from aiohttp import web from aiohttp import web
from lancedb.remote.client import RestfulLanceDBClient, VectorQuery from lancedb.remote.client import RestfulLanceDBClient, VectorQuery

View File

@@ -11,8 +11,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import lancedb
import pyarrow as pa import pyarrow as pa
import lancedb
from lancedb.remote.client import VectorQuery, VectorQueryResult from lancedb.remote.client import VectorQuery, VectorQueryResult

View File

@@ -1,8 +1,9 @@
import os import os
import lancedb
import numpy as np import numpy as np
import pytest import pytest
import lancedb
from lancedb.conftest import MockTextEmbeddingFunction # noqa from lancedb.conftest import MockTextEmbeddingFunction # noqa
from lancedb.embeddings import EmbeddingFunctionRegistry from lancedb.embeddings import EmbeddingFunctionRegistry
from lancedb.pydantic import LanceModel, Vector from lancedb.pydantic import LanceModel, Vector

View File

@@ -12,30 +12,31 @@
# limitations under the License. # limitations under the License.
import functools import functools
from copy import copy
from datetime import date, datetime, timedelta
import io import io
import os import os
from copy import copy
from datetime import date, datetime, timedelta
from pathlib import Path from pathlib import Path
from time import sleep from time import sleep
from typing import List from typing import List
from unittest.mock import PropertyMock, patch from unittest.mock import PropertyMock, patch
import lance import lance
import lancedb
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import polars as pl import polars as pl
import pyarrow as pa import pyarrow as pa
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from lance.arrow import EncodedImageType
from pydantic import BaseModel
import lancedb
from lancedb.conftest import MockTextEmbeddingFunction from lancedb.conftest import MockTextEmbeddingFunction
from lancedb.db import AsyncConnection, LanceDBConnection from lancedb.db import AsyncConnection, LanceDBConnection
from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry from lancedb.embeddings import EmbeddingFunctionConfig, EmbeddingFunctionRegistry
from lancedb.pydantic import EncodedImage, LanceModel, Vector from lancedb.pydantic import EncodedImage, LanceModel, Vector
from lancedb.table import LanceTable from lancedb.table import LanceTable
from pydantic import BaseModel
from lance.arrow import EncodedImageArray, EncodedImageType, ImageURIType
class MockDB: class MockDB:

View File

@@ -1,7 +1,8 @@
import json import json
import lancedb
import pytest import pytest
import lancedb
from lancedb.utils.events import _Events from lancedb.utils.events import _Events

View File

@@ -15,6 +15,7 @@ import os
import pathlib import pathlib
import pytest import pytest
from lancedb.util import get_uri_scheme, join_uri from lancedb.util import get_uri_scheme, join_uri