Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ include = [
"python/lance/schema.py",
"python/lance/file.py",
"python/lance/util.py",
"python/lance/arrow.py",
"python/tests/test_arrow.py",
]
# Dependencies like pyarrow make this difficult to enforce strictly.
reportMissingTypeStubs = "warning"
Expand Down
56 changes: 36 additions & 20 deletions python/python/lance/arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
"""Extensions to PyArrows."""

import json
import typing
from pathlib import Path
from typing import Callable, Iterable, Optional, Union

import pyarrow as pa
import pyarrow.compute as pc

from ._arrow.bf16 import ( # noqa: F401
BFloat16,
Expand All @@ -19,8 +21,10 @@
from .lance import bfloat16_array

__all__ = [
"BFloat16",
"BFloat16Array",
"BFloat16Type",
"PandasBFloat16Array",
"bfloat16_array",
"cast",
"EncodedImageType",
Expand Down Expand Up @@ -220,14 +224,15 @@ def from_uris(
<lance.arrow.ImageURIArray object at 0x...>
['file::///tmp/1.png']
"""
storage: pa.Array
if isinstance(uris, (pa.StringArray, pa.LargeStringArray)):
pass
storage = uris
elif isinstance(uris, Iterable):
uris = pa.array((str(uri) for uri in uris), type=pa.string())
storage = pa.array((str(uri) for uri in uris), type=pa.string())
else:
raise TypeError("Cannot build a ImageURIArray from {}".format(type(uris)))

return cls.from_storage(ImageURIType(uris.type), uris)
return cls.from_storage(ImageURIType(storage.type), storage)

def read_uris(self, storage_type=pa.binary()) -> "EncodedImageArray":
"""
Expand Down Expand Up @@ -268,7 +273,8 @@ def download(url):
print("Failed to reach the server: ", e.reason)
elif hasattr(e, "code"):
print(
"The server could not fulfill the request. Error code: ", e.code
"The server could not fulfill the request. Error code: ",
getattr(e, "code"),
)

images = []
Expand All @@ -277,7 +283,9 @@ def download(url):
if parsed_uri.scheme in ("http", "https"):
images.append(download(uri))
else:
filesystem, path = fs.FileSystem.from_uri(uri.as_py())
filesystem, path = fs.FileSystem.from_uri( # pyright: ignore[reportPrivateImportUsage]
uri.as_py()
)
with filesystem.open_input_stream(path) as f:
images.append(f.read())

Expand All @@ -297,7 +305,7 @@ def __repr__(self):
def pillow_metadata_decoder(images):
import io

from PIL import Image
from PIL import Image # pyright: ignore[reportMissingImports]

img = Image.open(io.BytesIO(images[0].as_py()))
return img
Expand Down Expand Up @@ -365,22 +373,27 @@ def to_tensor(

if not decoder:

def pillow_decoder(images):
def pillow_decoder(images) -> "np.ndarray":
import io

from PIL import Image
from PIL import Image # pyright: ignore[reportMissingImports]

return np.stack(
[Image.open(io.BytesIO(img)) for img in images.to_pylist()]
[
np.asarray(Image.open(io.BytesIO(img)))
for img in images.to_pylist()
]
)

def tensorflow_decoder(images):
def tensorflow_decoder(images) -> "np.ndarray":
import tensorflow as tf

decoded_to_tensor = tuple(
tf.io.decode_image(img) for img in images.to_pylist()
)
return tf.stack(decoded_to_tensor, axis=0).numpy()
return tf.stack( # pyright: ignore[reportOptionalCall]
decoded_to_tensor, axis=0
).numpy()

decoders = [
("tensorflow", tensorflow_decoder),
Expand All @@ -401,9 +414,10 @@ def tensorflow_decoder(images):

image_array = decoder(self.storage)
if isinstance(image_array, pa.FixedShapeTensorType):
shape = image_array.shape
arrow_type = image_array.storage_type
tensor_array = image_array
tensor = typing.cast("pa.Array", image_array)
shape = tensor.shape
arrow_type = tensor.storage_type
tensor_array = tensor
else:
shape = image_array.shape[1:]
arrow_type = pa.from_numpy_dtype(image_array.dtype)
Expand Down Expand Up @@ -476,7 +490,7 @@ def to_encoded(self, encoder=None, storage_type=pa.binary()) -> "EncodedImageArr
def pillow_encoder(x):
import io

from PIL import Image
from PIL import Image # pyright: ignore[reportMissingImports]

encoded_images = []
for y in x:
Expand Down Expand Up @@ -571,7 +585,8 @@ def cast(
+ f"got: {target_type}"
)
np_arr = arr.to_numpy()
float_arr = np_arr.astype(target_type.to_pandas_dtype())
float_type = typing.cast("pa.DataType", target_type)
float_arr = np_arr.astype(float_type.to_pandas_dtype())
return pa.array(float_arr)
elif isinstance(target_type, BFloat16Type) or target_type in ["bfloat16", "bf16"]:
if not pa.types.is_floating(arr.type):
Expand All @@ -586,15 +601,16 @@ def cast(
target_type
):
# Casting fixed size list to fixed size list
if arr.type.list_size != target_type.list_size:
list_type = typing.cast("pa.DataType", target_type)
if arr.type.list_size != list_type.list_size:
raise ValueError(
"Only support casting fixed size list to fixed size list "
f"with the same size, got: {arr.type} to {target_type}"
)
values = cast(arr.values, target_type.value_type)
values = cast(arr.values, list_type.value_type)
return pa.FixedSizeListArray.from_arrays(
values=values, list_size=target_type.list_size
values=values, list_size=list_type.list_size
)

# Fallback to normal cast.
return pa.compute.cast(arr, target_type, *args, **kwargs)
return pc.cast(arr, target_type, *args, **kwargs)
2 changes: 1 addition & 1 deletion python/python/lance/lance/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ class BFloat16:
def __gt__(self, other: BFloat16) -> bool: ...
def __ge__(self, other: BFloat16) -> bool: ...

def bfloat16_array(values: List[str | None]) -> BFloat16Array: ...
def bfloat16_array(values: Sequence[float | None]) -> BFloat16Array: ...

class PyFullTextQuery:
@staticmethod
Expand Down
3 changes: 2 additions & 1 deletion python/python/tests/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
from pathlib import Path

import lance
import lance.arrow
import numpy as np
import pandas as pd
import pyarrow as pa
import pytest
import pytest # pyright: ignore[reportMissingImports]
from lance.arrow import (
BFloat16,
BFloat16Array,
Expand Down
Loading