Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@
from narwhals._typing_compat import assert_never
from narwhals._utils import (
Implementation,
Version,
generate_temporary_column_name,
is_list_of,
no_default,
not_implemented,
)
from narwhals.dependencies import is_numpy_array_1d
from narwhals.dtypes import _validate_cast_temporal_to_numeric
from narwhals.exceptions import InvalidOperationError, ShapeError

if TYPE_CHECKING:
Expand Down Expand Up @@ -61,7 +63,7 @@
)
from narwhals._compliant.series import HistData
from narwhals._typing import NoDefault
from narwhals._utils import Version, _LimitedContext
from narwhals._utils import _LimitedContext
from narwhals.dtypes import DType
from narwhals.typing import (
ClosedInterval,
Expand Down Expand Up @@ -569,7 +571,9 @@ def is_nan(self) -> Self:
return self._with_native(pc.is_nan(self.native), preserve_broadcast=True)

def cast(self, dtype: IntoDType) -> Self:
    """Cast the series to `dtype`.

    Arguments:
        dtype: Data type that the series will be cast into.

    Returns:
        A new series wrapping the result of `pc.cast` on the native array.

    Notes:
        For every version other than `Version.V1`, the cast is first checked by
        `_validate_cast_temporal_to_numeric`, so the check runs before the
        dtype is translated to its native form.
    """
    if (version := self._version) != Version.V1:
        # V1 keeps its historical behaviour and skips the temporal -> numeric check.
        _validate_cast_temporal_to_numeric(source=self.dtype, target=dtype)
    # Translate the dtype exactly once, and only after validation has passed.
    data_type = narwhals_to_native_dtype(dtype, version)
    return self._with_native(pc.cast(self.native, data_type), preserve_broadcast=True)

def null_count(self, *, _return_py_scalar: bool = True) -> int:
Expand Down
22 changes: 17 additions & 5 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@
from narwhals._pandas_like.utils import get_dtype_backend, native_to_narwhals_dtype
from narwhals._utils import (
Implementation,
Version,
generate_temporary_column_name,
no_default,
not_implemented,
)
from narwhals.dtypes import _validate_cast_temporal_to_numeric
from narwhals.exceptions import InvalidOperationError

if TYPE_CHECKING:
Expand All @@ -40,7 +42,7 @@
from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.namespace import DaskNamespace
from narwhals._typing import NoDefault
from narwhals._utils import Version, _LimitedContext
from narwhals._utils import _LimitedContext
from narwhals.typing import (
FillNullStrategy,
IntoDType,
Expand Down Expand Up @@ -613,11 +615,21 @@ def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
)

def cast(self, dtype: IntoDType) -> Self:
    """Cast every output column of this expression to `dtype`.

    Arguments:
        dtype: Data type that the columns will be cast into.

    Returns:
        A new expression of the same class with unchanged output names and version.

    Notes:
        Validation happens inside the frame-level callable because the source
        dtypes are read from the frame's schema. It is skipped for `Version.V1`
        and for non-numeric targets, so the schema is only looked up when a
        temporal -> numeric cast is actually possible.
    """

    def func(df: DaskLazyFrame) -> list[dx.Series]:
        # The walrus always binds `version`, even when the check short-circuits.
        if (version := self._version) != Version.V1 and dtype.is_numeric():
            schema = df.schema
            for name in self._evaluate_output_names(df):
                _validate_cast_temporal_to_numeric(source=schema[name], target=dtype)

        native_dtype = narwhals_to_native_dtype(dtype, version)
        return [expr.astype(native_dtype) for expr in self._call(df)]

    return self.__class__(
        func,
        evaluate_output_names=self._evaluate_output_names,
        alias_output_names=self._alias_output_names,
        version=self._version,
    )

def is_finite(self) -> Self:
import dask.array as da
Expand Down
24 changes: 20 additions & 4 deletions narwhals/_duckdb/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)
from narwhals._sql.expr import SQLExpr
from narwhals._utils import Implementation, Version, extend_bool, no_default
from narwhals.dtypes import _validate_cast_temporal_to_numeric

if TYPE_CHECKING:
from collections.abc import Sequence
Expand All @@ -42,6 +43,12 @@
from narwhals._utils import _LimitedContext
from narwhals.typing import FillNullStrategy, IntoDType, RollingInterpolationMethod

try:
import duckdb.sqltypes as duckdb_dtypes
except ModuleNotFoundError:
# DuckDB pre 1.3
import duckdb.typing as duckdb_dtypes

DuckDBWindowFunction = WindowFunction[DuckDBLazyFrame, Expression]
DuckDBWindowInputs = WindowInputs[Expression]

Expand Down Expand Up @@ -266,14 +273,23 @@ def _fill_constant(expr: Expression, value: Any) -> Expression:
return self._with_elementwise(_fill_constant, value=value)

def cast(self, dtype: IntoDType) -> Self:
def func(df: DuckDBLazyFrame) -> list[Expression]:
def _validated_dtype(
dtype: IntoDType, df: DuckDBLazyFrame
) -> duckdb_dtypes.DuckDBPyType:
if (version := self._version) != Version.V1 and dtype.is_numeric():
schema = df.collect_schema()
for name in self._evaluate_output_names(df):
_validate_cast_temporal_to_numeric(source=schema[name], target=dtype)

tz = DeferredTimeZone(df.native)
native_dtype = narwhals_to_native_dtype(dtype, self._version, tz)
return narwhals_to_native_dtype(dtype, version, tz)

def func(df: DuckDBLazyFrame) -> list[Expression]:
native_dtype = _validated_dtype(dtype, df)
return [expr.cast(native_dtype) for expr in self(df)]

def window_f(df: DuckDBLazyFrame, inputs: DuckDBWindowInputs) -> list[Expression]:
tz = DeferredTimeZone(df.native)
native_dtype = narwhals_to_native_dtype(dtype, self._version, tz)
native_dtype = _validated_dtype(dtype, df)
return [expr.cast(native_dtype) for expr in self.window_function(df, inputs)]

return self.__class__(
Expand Down
20 changes: 15 additions & 5 deletions narwhals/_ibis/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
not_implemented,
zip_strict,
)
from narwhals.dtypes import _validate_cast_temporal_to_numeric

if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
Expand Down Expand Up @@ -269,12 +270,21 @@ def _fill_null(expr: ir.Value, value: ir.Scalar) -> ir.Value:
return self._with_callable(_fill_null, value=value)

def cast(self, dtype: IntoDType) -> Self:
    """Cast every output column of this expression to `dtype`.

    Arguments:
        dtype: Data type that the columns will be cast into.

    Returns:
        A new expression of the same class with unchanged output names and version.

    Notes:
        Validation happens inside the frame-level callable because the source
        dtypes are read from `df.collect_schema()`. It is skipped for
        `Version.V1` and for non-numeric targets.
    """

    def func(df: IbisLazyFrame) -> list[ir.Value]:
        # The walrus always binds `version`, even when the check short-circuits.
        if (version := self._version) != Version.V1 and dtype.is_numeric():
            schema = df.collect_schema()
            for name in self._evaluate_output_names(df):
                _validate_cast_temporal_to_numeric(source=schema[name], target=dtype)

        native_dtype = narwhals_to_native_dtype(dtype, version)
        # ibis `cast` overloads do not include DataType, only literals
        return [expr.cast(native_dtype) for expr in self(df)]  # pyright: ignore[reportArgumentType, reportCallIssue]

    return self.__class__(
        func,
        evaluate_output_names=self._evaluate_output_names,
        alias_output_names=self._alias_output_names,
        version=self._version,
    )

def is_unique(self) -> Self:
return self._with_callable(
Expand Down
9 changes: 6 additions & 3 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@
set_index,
)
from narwhals._typing_compat import assert_never
from narwhals._utils import Implementation, is_list_of, no_default, parse_version
from narwhals._utils import Implementation, Version, is_list_of, no_default, parse_version
from narwhals.dependencies import is_numpy_array_1d, is_pandas_like_series
from narwhals.dtypes import _validate_cast_temporal_to_numeric
from narwhals.exceptions import InvalidOperationError

if TYPE_CHECKING:
Expand All @@ -44,7 +45,7 @@
from narwhals._pandas_like.namespace import PandasLikeNamespace
from narwhals._pandas_like.typing import NativeSeriesT
from narwhals._typing import NoDefault
from narwhals._utils import Version, _LimitedContext
from narwhals._utils import _LimitedContext
from narwhals.dtypes import DType
from narwhals.typing import (
ClosedInterval,
Expand Down Expand Up @@ -308,6 +309,8 @@ def _scatter_in_place(self, indices: Self, values: Self) -> None:
self.native.iloc[indices.native] = values_native

def cast(self, dtype: IntoDType) -> Self:
if (version := self._version) != Version.V1:
_validate_cast_temporal_to_numeric(source=self.dtype, target=dtype)
if self.dtype == dtype and self.native.dtype != "object":
# Avoid dealing with pandas' type-system if we can. Note that it's only
# safe to do this if we're not starting with object dtype, see tests/expr_and_series/cast_test.py::test_cast_object_pandas
Expand All @@ -317,7 +320,7 @@ def cast(self, dtype: IntoDType) -> Self:
dtype,
dtype_backend=get_dtype_backend(self.native.dtype, self._implementation),
implementation=self._implementation,
version=self._version,
version=version,
)
return self._with_native(self.native.astype(pd_dtype), preserve_broadcast=True)

Expand Down
9 changes: 6 additions & 3 deletions narwhals/_polars/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
narwhals_to_native_dtype,
native_to_narwhals_dtype,
)
from narwhals._utils import Implementation, no_default, requires
from narwhals._utils import Implementation, Version, no_default, requires
from narwhals.dependencies import is_numpy_array_1d, is_pandas_index
from narwhals.dtypes import _validate_cast_temporal_to_numeric

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping, Sequence
Expand All @@ -35,7 +36,7 @@
from narwhals._polars.dataframe import Method, PolarsDataFrame
from narwhals._polars.namespace import PolarsNamespace
from narwhals._typing import NoDefault
from narwhals._utils import Version, _LimitedContext
from narwhals._utils import _LimitedContext
from narwhals.dtypes import DType
from narwhals.series import Series
from narwhals.typing import (
Expand Down Expand Up @@ -288,7 +289,9 @@ def __getitem__(self, item: MultiIndexSelector[Self]) -> Any | Self:
return self._from_native_object(self.native.__getitem__(item))

def cast(self, dtype: IntoDType) -> Self:
    """Cast the series to `dtype`.

    Arguments:
        dtype: Data type that the series will be cast into.

    Returns:
        A new series wrapping the result of the native `cast` call.

    Notes:
        For every version other than `Version.V1`, the cast is first checked by
        `_validate_cast_temporal_to_numeric`, so the check runs before the
        dtype is translated to its native form.
    """
    if (version := self._version) != Version.V1:
        # V1 keeps its historical behaviour and skips the temporal -> numeric check.
        _validate_cast_temporal_to_numeric(source=self.dtype, target=dtype)
    # Translate the dtype exactly once, and only after validation has passed.
    dtype_pl = narwhals_to_native_dtype(dtype, version)
    return self._with_native(self.native.cast(dtype_pl))

def clip(self, lower_bound: PolarsSeries, upper_bound: PolarsSeries) -> Self:
Expand Down
31 changes: 23 additions & 8 deletions narwhals/_spark_like/expr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import operator
from contextlib import suppress
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal, cast

from narwhals._spark_like.expr_dt import SparkLikeExprDateTimeNamespace
Expand All @@ -23,6 +24,7 @@
not_implemented,
zip_strict,
)
from narwhals.dtypes import DType, _validate_cast_temporal_to_numeric

if TYPE_CHECKING:
from collections.abc import Iterator, Mapping, Sequence
Expand All @@ -40,6 +42,7 @@
)
from narwhals._spark_like.dataframe import SparkLikeLazyFrame
from narwhals._spark_like.namespace import SparkLikeNamespace
from narwhals._spark_like.utils import _NativeDType
from narwhals._typing import NoDefault
from narwhals._utils import _LimitedContext
from narwhals.typing import FillNullStrategy, IntoDType, RankMethod
Expand Down Expand Up @@ -246,19 +249,31 @@ def __invert__(self) -> Self:
return self._with_elementwise(invert)

def cast(self, dtype: IntoDType) -> Self:
def func(df: SparkLikeLazyFrame) -> Sequence[Column]:
spark_dtype = narwhals_to_native_dtype(
dtype, self._version, self._native_dtypes, df.native.sparkSession
def _validated_dtype(dtype: IntoDType, df: SparkLikeLazyFrame) -> _NativeDType:
if (version := self._version) != Version.V1 and dtype.is_numeric():
schema: dict[str, DType] = {}
with suppress(Exception):
schema = df.collect_schema()

if schema:
for name in self._evaluate_output_names(df):
_validate_cast_temporal_to_numeric(
source=schema[name], target=dtype
)
Comment on lines +254 to +262
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not particularly proud of this piece of code 🤔


return narwhals_to_native_dtype(
dtype, version, self._native_dtypes, df.native.sparkSession
)
return [expr.cast(spark_dtype) for expr in self(df)]

def func(df: SparkLikeLazyFrame) -> Sequence[Column]:
native_dtype = _validated_dtype(dtype, df)
return [expr.cast(native_dtype) for expr in self(df)]

def window_f(
df: SparkLikeLazyFrame, inputs: SparkWindowInputs
) -> Sequence[Column]:
spark_dtype = narwhals_to_native_dtype(
dtype, self._version, self._native_dtypes, df.native.sparkSession
)
return [expr.cast(spark_dtype) for expr in self.window_function(df, inputs)]
native_dtype = _validated_dtype(dtype, df)
return [expr.cast(native_dtype) for expr in self.window_function(df, inputs)]

return self.__class__(
func,
Expand Down
23 changes: 23 additions & 0 deletions narwhals/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,29 @@ def _validate_into_dtype(dtype: Any) -> None:
raise TypeError(msg)


def _validate_cast_temporal_to_numeric(
    source: DType | type[DType], target: IntoDType
) -> None:
    """Validate that we're not casting from a temporal to a numeric data type.

    Arguments:
        source: The source data type.
        target: The target data type to cast to.

    Raises:
        InvalidOperationError: If `source` is temporal and `target` is numeric.
    """
    if source.is_temporal() and target.is_numeric():
        msg = (
            "Casting from temporal type to numeric is not supported.\n\n"
            "Hint: Use `.dt` accessor methods instead, such as:\n"
            " - `.dt.timestamp()` for Unix timestamp.\n"
            " - `.dt.year()`, `.dt.month()`, `.dt.day()`, ..., for date components.\n"
            " - `.dt.total_seconds()`, `.dt.total_milliseconds()`, ..., for duration total time."
        )
        raise InvalidOperationError(msg)


class DTypeClass(type):
"""Metaclass for DType classes.

Expand Down
9 changes: 9 additions & 0 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,15 @@ def cast(self, dtype: IntoDType) -> Self:
Arguments:
dtype: Data type that the object will be cast into.

Note:
Unlike polars, casting from a temporal to a numeric data type is not allowed.
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TIL: polars allows also casting to Float, not only to Integer


Use `.dt` accessor methods instead, such as:

* `.dt.timestamp()` for Unix timestamp.
* `.dt.year()`, `.dt.month()`, `.dt.day()`, ..., for date components.
* `.dt.total_seconds()`, `.dt.total_milliseconds()`, ..., for duration total time.

Examples:
>>> import pandas as pd
>>> import narwhals as nw
Expand Down
9 changes: 9 additions & 0 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,15 @@ def cast(self, dtype: IntoDType) -> Self:
Arguments:
dtype: Data type that the object will be cast into.

Note:
Unlike polars, casting from a temporal to a numeric data type is not allowed.

Use `.dt` accessor methods instead, such as:

* `.dt.timestamp()` for Unix timestamp.
* `.dt.year()`, `.dt.month()`, `.dt.day()`, ..., for date components.
* `.dt.total_seconds()`, `.dt.total_milliseconds()`, ..., for duration total time.

Examples:
>>> import pyarrow as pa
>>> import narwhals as nw
Expand Down
Loading
Loading