diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index e726139f29..27f290f435 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -28,7 +28,7 @@ evaluate_output_names_and_aliases, ) from narwhals._sql.namespace import SQLNamespace -from narwhals._utils import Implementation, zip_strict +from narwhals._utils import Implementation, requires, zip_strict if TYPE_CHECKING: from collections.abc import Iterable, Mapping @@ -187,6 +187,7 @@ def func(df: DuckDBLazyFrame) -> list[Expression]: version=self._version, ) + @requires.backend_version((1, 3)) def struct(self, *exprs: DuckDBExpr) -> DuckDBExpr: version = self._version diff --git a/narwhals/_spark_like/utils.py b/narwhals/_spark_like/utils.py index e0e0c8c857..d7ac768928 100644 --- a/narwhals/_spark_like/utils.py +++ b/narwhals/_spark_like/utils.py @@ -2,7 +2,7 @@ import operator from collections.abc import Callable -from functools import lru_cache +from functools import lru_cache, partial from importlib import import_module from operator import attrgetter from types import ModuleType @@ -182,15 +182,13 @@ def narwhals_to_native_dtype( # noqa: C901 return native.ArrayType( elementType=narwhals_to_native_dtype(dtype.inner, version, native, session) ) - if isinstance_or_issubclass(dtype, dtypes.Struct): # pragma: no cover + if isinstance_or_issubclass(dtype, dtypes.Struct): + to_native = partial( + narwhals_to_native_dtype, version=version, spark_types=native, session=session + ) return native.StructType( fields=[ - native.StructField( - name=field.name, - dataType=narwhals_to_native_dtype( - field.dtype, version, native, session - ), - ) + native.StructField(name=field.name, dataType=to_native(field.dtype)) for field in dtype.fields ] ) diff --git a/tests/expr_and_series/cast_test.py b/tests/expr_and_series/cast_test.py index cfadaff347..90282ea596 100644 --- a/tests/expr_and_series/cast_test.py +++ b/tests/expr_and_series/cast_test.py @@ -1,7 +1,7 @@ from __future__ import annotations from datetime import datetime, time, timedelta, timezone -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING import pytest @@ -19,7 +19,6 @@ if TYPE_CHECKING: from collections.abc import Mapping - from narwhals._native import NativeSQLFrame from narwhals.typing import NonNestedDType DATA = { @@ -270,7 +269,7 @@ def test_cast_datetime_utc( def test_cast_struct(request: pytest.FixtureRequest, constructor: Constructor) -> None: - if any(backend in str(constructor) for backend in ("dask", "cudf", "sqlframe")): + if any(backend in str(constructor) for backend in ("dask", "cudf")): request.applymarker(pytest.mark.xfail) if "pandas" in str(constructor): @@ -278,40 +277,27 @@ def test_cast_struct(request: pytest.FixtureRequest, constructor: Constructor) - pytest.skip() pytest.importorskip("pyarrow") - data = { - "a": [{"movie ": "Cars", "rating": 4.5}, {"movie ": "Toy Story", "rating": 4.9}] - } + from_dtype = nw.Struct( + [nw.Field("movie", nw.String()), nw.Field("rating", nw.Float64())] + ) - native_df = constructor(data) - - # NOTE: This branch needs to be rewritten to **not depend** on private `SparkLikeLazyFrame` properties - if "spark" in str(constructor): # pragma: no cover - # Special handling for pyspark as it natively maps the input to - # a column of type MAP - native_ldf = cast("NativeSQLFrame", native_df) - _tmp_nw_compliant_frame = nw.from_native(native_ldf)._compliant_frame - F = _tmp_nw_compliant_frame._F # type: ignore[attr-defined] - T = _tmp_nw_compliant_frame._native_dtypes # type: ignore[attr-defined] # noqa: N806 - - native_ldf = native_ldf.withColumn( - "a", - F.struct( - F.col("a.movie ").cast(T.StringType()).alias("movie "), - F.col("a.rating").cast(T.DoubleType()).alias("rating"), - ), - ) - assert nw.from_native(native_ldf).collect_schema() == nw.Schema( - { - "a": nw.Struct( - [nw.Field("movie ", nw.String()), nw.Field("rating", nw.Float64())] - ) - } + if "spark" in str(constructor): + data = {"movie": ["Cars", "Toy Story"], "rating": [4.5, 4.9]} + dframe = nw.from_native(constructor(data)).select( + a=nw.struct("movie", "rating").cast(from_dtype) ) - native_df = native_ldf - dtype = nw.Struct([nw.Field("movie ", nw.String()), nw.Field("rating", nw.Float32())]) - result = nw.from_native(native_df).select(nw.col("a").cast(dtype)).lazy().collect() - assert result.schema == {"a": dtype} + else: + data = { + "a": [{"movie": "Cars", "rating": 4.5}, {"movie": "Toy Story", "rating": 4.9}] + } + dframe = nw.from_native(constructor(data)).select(nw.col("a").cast(from_dtype)) + + to_dtype = nw.Struct( + [nw.Field("movie", nw.String()), nw.Field("rating", nw.Float32())] + ) + result = dframe.select(nw.col("a").cast(to_dtype)) + assert result.collect_schema() == {"a": to_dtype} def test_raise_if_polars_dtype(constructor: Constructor) -> None: