Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@
from narwhals._typing_compat import assert_never
from narwhals._utils import (
Implementation,
Version,
generate_temporary_column_name,
is_list_of,
no_default,
not_implemented,
)
from narwhals.dependencies import is_numpy_array_1d
from narwhals.dtypes import _validate_cast_temporal_to_numeric
from narwhals.exceptions import InvalidOperationError, ShapeError

if TYPE_CHECKING:
Expand Down Expand Up @@ -61,7 +63,7 @@
)
from narwhals._compliant.series import HistData
from narwhals._typing import NoDefault
from narwhals._utils import Version, _LimitedContext
from narwhals._utils import _LimitedContext
from narwhals.dtypes import DType
from narwhals.typing import (
ClosedInterval,
Expand Down Expand Up @@ -569,7 +571,9 @@ def is_nan(self) -> Self:
return self._with_native(pc.is_nan(self.native), preserve_broadcast=True)

# Flattened diff hunk for `cast` in narwhals/_arrow/series.py: pre-change and
# post-change lines are interleaved without +/- markers or indentation.
def cast(self, dtype: IntoDType) -> Self:
# NOTE(review): the next line looks like the pre-change conversion that the
# walrus-based lines below replace — confirm against the PR diff.
data_type = narwhals_to_native_dtype(dtype, self._version)
# The temporal -> numeric validation is skipped when the version is Version.V1.
if (version := self._version) != Version.V1:
_validate_cast_temporal_to_numeric(source=self.dtype, target=dtype)
data_type = narwhals_to_native_dtype(dtype, version)
return self._with_native(pc.cast(self.native, data_type), preserve_broadcast=True)

def null_count(self, *, _return_py_scalar: bool = True) -> int:
Expand Down
22 changes: 17 additions & 5 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
from narwhals._pandas_like.utils import get_dtype_backend, native_to_narwhals_dtype
from narwhals._utils import (
Implementation,
Version,
generate_temporary_column_name,
no_default,
not_implemented,
)
from narwhals.dtypes import _validate_cast_temporal_to_numeric
from narwhals.exceptions import InvalidOperationError

if TYPE_CHECKING:
Expand All @@ -40,7 +42,7 @@
from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.namespace import DaskNamespace
from narwhals._typing import NoDefault
from narwhals._utils import Version, _LimitedContext
from narwhals._utils import _LimitedContext
from narwhals.typing import (
FillNullStrategy,
IntoDType,
Expand Down Expand Up @@ -613,11 +615,21 @@ def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
)

# Flattened diff hunk for `cast` in narwhals/_dask/expr.py: pre-change and
# post-change lines are interleaved without +/- markers or indentation.
def cast(self, dtype: IntoDType) -> Self:
# NOTE(review): the three lines below look like the pre-change per-series
# callable (paired with `_with_callable` further down) — confirm against the PR.
def func(expr: dx.Series) -> dx.Series:
native_dtype = narwhals_to_native_dtype(dtype, self._version)
return expr.astype(native_dtype)
# Frame-level callable: it receives the lazy frame so the schema can be checked.
def func(df: DaskLazyFrame) -> list[dx.Series]:
# The temporal -> numeric validation is skipped when the version is Version.V1.
if (version := self._version) != Version.V1:
schema = df.schema
# NOTE(review): the dtype validated is the frame column looked up by output
# name, not the dtype of the expression being cast; for chained expressions
# (e.g. a `.dt` accessor result cast to an integer) these may differ — confirm.
for name in self._evaluate_output_names(df):
_validate_cast_temporal_to_numeric(source=schema[name], target=dtype)

# NOTE(review): the next line looks like the pre-change return — confirm.
return self._with_callable(func)
native_dtype = narwhals_to_native_dtype(dtype, version)
return [expr.astype(native_dtype) for expr in self._call(df)]

# Build the new expression directly, reusing this expression's name resolvers
# and version.
return self.__class__(
func,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
)

def is_finite(self) -> Self:
import dask.array as da
Expand Down
15 changes: 13 additions & 2 deletions narwhals/_duckdb/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)
from narwhals._sql.expr import SQLExpr
from narwhals._utils import Implementation, Version, extend_bool, no_default
from narwhals.dtypes import _validate_cast_temporal_to_numeric

if TYPE_CHECKING:
from collections.abc import Sequence
Expand Down Expand Up @@ -267,13 +268,23 @@ def _fill_constant(expr: Expression, value: Any) -> Expression:

def cast(self, dtype: IntoDType) -> Self:
def func(df: DuckDBLazyFrame) -> list[Expression]:
if (version := self._version) != Version.V1:
schema = df.collect_schema()
for name in self._evaluate_output_names(df):
_validate_cast_temporal_to_numeric(source=schema[name], target=dtype)

tz = DeferredTimeZone(df.native)
native_dtype = narwhals_to_native_dtype(dtype, self._version, tz)
native_dtype = narwhals_to_native_dtype(dtype, version, tz)
return [expr.cast(native_dtype) for expr in self(df)]

def window_f(df: DuckDBLazyFrame, inputs: DuckDBWindowInputs) -> list[Expression]:
if (version := self._version) != Version.V1:
schema = df.collect_schema()
for name in self._evaluate_output_names(df):
_validate_cast_temporal_to_numeric(source=schema[name], target=dtype)

tz = DeferredTimeZone(df.native)
native_dtype = narwhals_to_native_dtype(dtype, self._version, tz)
native_dtype = narwhals_to_native_dtype(dtype, version, tz)
return [expr.cast(native_dtype) for expr in self.window_function(df, inputs)]

return self.__class__(
Expand Down
20 changes: 15 additions & 5 deletions narwhals/_ibis/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
not_implemented,
zip_strict,
)
from narwhals.dtypes import _validate_cast_temporal_to_numeric

if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
Expand Down Expand Up @@ -269,12 +270,21 @@ def _fill_null(expr: ir.Value, value: ir.Scalar) -> ir.Value:
return self._with_callable(_fill_null, value=value)

# Flattened diff hunk for `cast` in narwhals/_ibis/expr.py: pre-change and
# post-change lines are interleaved without +/- markers or indentation.
def cast(self, dtype: IntoDType) -> Self:
# NOTE(review): `_func` below looks like the pre-change per-column callable
# (paired with `_with_callable(_func)` further down) — confirm against the PR.
def _func(expr: ir.Column) -> ir.Value:
native_dtype = narwhals_to_native_dtype(dtype, self._version)
# ibis `cast` overloads do not include DataType, only literals
return expr.cast(native_dtype) # type: ignore[unused-ignore]
# Frame-level callable: it receives the lazy frame so the schema can be checked.
def func(df: IbisLazyFrame) -> list[ir.Value]:
# The temporal -> numeric validation is skipped when the version is Version.V1.
if (version := self._version) != Version.V1:
schema = df.collect_schema()
# NOTE(review): the dtype validated is the frame column looked up by output
# name, not the dtype of the expression being cast; for chained expressions
# these may differ — confirm.
for name in self._evaluate_output_names(df):
_validate_cast_temporal_to_numeric(source=schema[name], target=dtype)

native_dtype = narwhals_to_native_dtype(dtype, version)
return [expr.cast(native_dtype) for expr in self(df)] # type: ignore[misc]

# NOTE(review): the next line looks like the pre-change return — confirm.
return self._with_callable(_func)
# Build the new expression directly, reusing this expression's name resolvers
# and version.
return self.__class__(
func,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
)

def is_unique(self) -> Self:
return self._with_callable(
Expand Down
9 changes: 6 additions & 3 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@
set_index,
)
from narwhals._typing_compat import assert_never
from narwhals._utils import Implementation, is_list_of, no_default, parse_version
from narwhals._utils import Implementation, Version, is_list_of, no_default, parse_version
from narwhals.dependencies import is_numpy_array_1d, is_pandas_like_series
from narwhals.dtypes import _validate_cast_temporal_to_numeric
from narwhals.exceptions import InvalidOperationError

if TYPE_CHECKING:
Expand All @@ -44,7 +45,7 @@
from narwhals._pandas_like.namespace import PandasLikeNamespace
from narwhals._pandas_like.typing import NativeSeriesT
from narwhals._typing import NoDefault
from narwhals._utils import Version, _LimitedContext
from narwhals._utils import _LimitedContext
from narwhals.dtypes import DType
from narwhals.typing import (
ClosedInterval,
Expand Down Expand Up @@ -308,6 +309,8 @@ def _scatter_in_place(self, indices: Self, values: Self) -> None:
self.native.iloc[indices.native] = values_native

def cast(self, dtype: IntoDType) -> Self:
if (version := self._version) != Version.V1:
_validate_cast_temporal_to_numeric(source=self.dtype, target=dtype)
if self.dtype == dtype and self.native.dtype != "object":
# Avoid dealing with pandas' type-system if we can. Note that it's only
# safe to do this if we're not starting with object dtype, see tests/expr_and_series/cast_test.py::test_cast_object_pandas
Expand All @@ -317,7 +320,7 @@ def cast(self, dtype: IntoDType) -> Self:
dtype,
dtype_backend=get_dtype_backend(self.native.dtype, self._implementation),
implementation=self._implementation,
version=self._version,
version=version,
)
return self._with_native(self.native.astype(pd_dtype), preserve_broadcast=True)

Expand Down
9 changes: 6 additions & 3 deletions narwhals/_polars/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
narwhals_to_native_dtype,
native_to_narwhals_dtype,
)
from narwhals._utils import Implementation, no_default, requires
from narwhals._utils import Implementation, Version, no_default, requires
from narwhals.dependencies import is_numpy_array_1d, is_pandas_index
from narwhals.dtypes import _validate_cast_temporal_to_numeric

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping, Sequence
Expand All @@ -35,7 +36,7 @@
from narwhals._polars.dataframe import Method, PolarsDataFrame
from narwhals._polars.namespace import PolarsNamespace
from narwhals._typing import NoDefault
from narwhals._utils import Version, _LimitedContext
from narwhals._utils import _LimitedContext
from narwhals.dtypes import DType
from narwhals.series import Series
from narwhals.typing import (
Expand Down Expand Up @@ -288,7 +289,9 @@ def __getitem__(self, item: MultiIndexSelector[Self]) -> Any | Self:
return self._from_native_object(self.native.__getitem__(item))

# Flattened diff hunk for `cast` in narwhals/_polars/series.py: pre-change and
# post-change lines are interleaved without +/- markers or indentation.
def cast(self, dtype: IntoDType) -> Self:
# NOTE(review): the next line looks like the pre-change conversion that the
# walrus-based lines below replace — confirm against the PR diff.
dtype_pl = narwhals_to_native_dtype(dtype, self._version)
# The temporal -> numeric validation is skipped when the version is Version.V1.
if (version := self._version) != Version.V1:
_validate_cast_temporal_to_numeric(source=self.dtype, target=dtype)
dtype_pl = narwhals_to_native_dtype(dtype, version)
return self._with_native(self.native.cast(dtype_pl))

def clip(self, lower_bound: PolarsSeries, upper_bound: PolarsSeries) -> Self:
Expand Down
15 changes: 13 additions & 2 deletions narwhals/_spark_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
not_implemented,
zip_strict,
)
from narwhals.dtypes import _validate_cast_temporal_to_numeric

if TYPE_CHECKING:
from collections.abc import Iterator, Mapping, Sequence
Expand Down Expand Up @@ -247,16 +248,26 @@ def __invert__(self) -> Self:

def cast(self, dtype: IntoDType) -> Self:
def func(df: SparkLikeLazyFrame) -> Sequence[Column]:
if (version := self._version) != Version.V1:
schema = df.collect_schema()
for name in self._evaluate_output_names(df):
_validate_cast_temporal_to_numeric(source=schema[name], target=dtype)

spark_dtype = narwhals_to_native_dtype(
dtype, self._version, self._native_dtypes, df.native.sparkSession
dtype, version, self._native_dtypes, df.native.sparkSession
)
return [expr.cast(spark_dtype) for expr in self(df)]

def window_f(
df: SparkLikeLazyFrame, inputs: SparkWindowInputs
) -> Sequence[Column]:
if (version := self._version) != Version.V1:
schema = df.collect_schema()
for name in self._evaluate_output_names(df):
_validate_cast_temporal_to_numeric(source=schema[name], target=dtype)

spark_dtype = narwhals_to_native_dtype(
dtype, self._version, self._native_dtypes, df.native.sparkSession
dtype, version, self._native_dtypes, df.native.sparkSession
)
return [expr.cast(spark_dtype) for expr in self.window_function(df, inputs)]

Expand Down
23 changes: 23 additions & 0 deletions narwhals/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,29 @@ def _validate_into_dtype(dtype: Any) -> None:
raise TypeError(msg)


def _validate_cast_temporal_to_numeric(
    source: DType | type[DType], target: IntoDType
) -> None:
    """Validate that we're not casting from a temporal to a numeric type.

    The check applies to every numeric target (anything for which
    `target.is_numeric()` is true), not only to integer targets.

    Arguments:
        source: The source data type.
        target: The target data type to cast to.

    Raises:
        InvalidOperationError: If `source` is temporal and `target` is numeric.
    """
    # Guard clause: only the temporal -> numeric combination is rejected.
    if not (source.is_temporal() and target.is_numeric()):
        return

    msg = (
        "Casting from temporal type to numeric is not supported.\n\n"
        "Hint: Use `.dt` accessor methods instead, such as:\n"
        " - `.dt.timestamp()` for Unix timestamp.\n"
        " - `.dt.year()`, `.dt.month()`, `.dt.day()`, ..., for date components.\n"
        " - `.dt.total_seconds()`, `.dt.total_milliseconds()`, ..., for duration total time."
    )
    raise InvalidOperationError(msg)


class DTypeClass(type):
"""Metaclass for DType classes.

Expand Down
9 changes: 9 additions & 0 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,15 @@ def cast(self, dtype: IntoDType) -> Self:
Arguments:
dtype: Data type that the object will be cast into.

Note:
Unlike polars, casting from a temporal to a numeric data type is not allowed.
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TIL: polars allows also casting to Float, not only to Integer


Use `.dt` accessor methods instead, such as:

* `.dt.timestamp()` for Unix timestamp.
* `.dt.year()`, `.dt.month()`, `.dt.day()`, ..., for date components.
* `.dt.total_seconds()`, `.dt.total_milliseconds()`, ..., for duration total time.

Examples:
>>> import pandas as pd
>>> import narwhals as nw
Expand Down
9 changes: 9 additions & 0 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,15 @@ def cast(self, dtype: IntoDType) -> Self:
Arguments:
dtype: Data type that the object will be cast into.

Note:
Unlike polars, casting from a temporal to a numeric data type is not allowed.

Use `.dt` accessor methods instead, such as:

* `.dt.timestamp()` for Unix timestamp.
* `.dt.year()`, `.dt.month()`, `.dt.day()`, ..., for date components.
* `.dt.total_seconds()`, `.dt.total_milliseconds()`, ..., for duration total time.

Examples:
>>> import pyarrow as pa
>>> import narwhals as nw
Expand Down
57 changes: 57 additions & 0 deletions tests/expr_and_series/cast_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

import narwhals as nw
from narwhals.exceptions import InvalidOperationError
from tests.utils import (
PANDAS_VERSION,
PYARROW_VERSION,
Expand Down Expand Up @@ -442,3 +443,59 @@ def test_cast_object_pandas() -> None:
s = nw.from_native(pd.DataFrame({"a": [2, 3, None]}, dtype=object))["a"]
assert s[0] == 2
assert s.cast(nw.String)[0] == "2"


# Numeric target dtypes used to parametrize the temporal-to-numeric cast tests
# in this module (signed ints, floats, and 32/64-bit unsigned ints).
NUMERIC_DTYPES = [
nw.Int8,
nw.Int16,
nw.Int32,
nw.Int64,
nw.Float32,
nw.Float64,
nw.UInt32,
nw.UInt64,
]


@pytest.mark.parametrize(
    "values", [[datetime(2000, 1, 1, 12, 0), None], [timedelta(365, 59), None]]
)
@pytest.mark.parametrize("target_dtype", NUMERIC_DTYPES)
def test_cast_temporal_to_numeric_raises_expr(
    constructor: Constructor,
    request: pytest.FixtureRequest,
    values: list[datetime] | list[timedelta],
    target_dtype: nw.dtypes.DType,
) -> None:
    """Casting a temporal column to a numeric dtype via an expression must raise.

    The frame is made lazy and collected, so the error is expected at the latest
    when the selection is evaluated.
    """
    backend = str(constructor)

    # Expected failures, with the reasons recorded on the xfail markers.
    if "polars" in backend:
        request.applymarker(
            pytest.mark.xfail(reason="Polars expressions wrap native expressions")
        )
    if "spark" in backend and isinstance(values[0], timedelta):
        request.applymarker(pytest.mark.xfail(reason="interval not implemented"))

    lazy_frame = nw.from_native(constructor({"a": values})).lazy()
    expected_message = "Casting from temporal type to numeric"
    with pytest.raises(InvalidOperationError, match=expected_message):
        lazy_frame.select(nw.col("a").cast(target_dtype)).collect()


@pytest.mark.parametrize(
    "values",
    [
        [datetime(2000, 1, 1, 12, 0), datetime(2000, 1, 2, 12, 0), None],
        [timedelta(2, 59), timedelta(1, 59), None],
    ],
)
@pytest.mark.parametrize("target_dtype", NUMERIC_DTYPES)
def test_cast_temporal_to_numeric_raises_series(
    constructor_eager: ConstructorEager,
    values: list[datetime] | list[timedelta],
    target_dtype: nw.dtypes.DType,
) -> None:
    """Casting an eager temporal series to a numeric dtype must raise."""
    eager_frame = nw.from_native(constructor_eager({"a": values}), eager_only=True)
    # Column access stays outside the `raises` block so only `cast` is under test.
    column = eager_frame["a"]
    expected_message = "Casting from temporal type to numeric"
    with pytest.raises(InvalidOperationError, match=expected_message):
        column.cast(target_dtype)
Loading