-
-
Notifications
You must be signed in to change notification settings - Fork 1.2k
fix: where for cupy.array
#11026
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
fix: where for cupy.array
#11026
Conversation
|
|
||
| if xp == np: | ||
| # numpy currently doesn't have a astype: | ||
| if xp is np or not hasattr(xp, "astype"): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| if xp is np or not hasattr(xp, "astype"): | |
| if not hasattr(xp, "astype"): |
wdyt?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No objections, what I originally did until I saw this below:
xarray/xarray/core/duck_array_ops.py
Line 275 in 3c6b050
| if xp is np or not hasattr(xp, "astype"): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done in abd935a
max-sixty
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(I don't know this code at all, looks very reasonable, if anyone knows it better...)
This reverts commit abd935a.
|
Reverted abd935a as it made a lot of tests fail. |
whats-new.rstapi.rstNoticed the following will fail; I copied the logic from the
asarrayfunction just below, and added a simple test.--------------------------------------------------------------------------- AttributeError Traceback (most recent call last) Cell In[1], line 11 8 mask = xp.isnan(arr) 10 da = xr.DataArray(arr, dims=("x", "y"), name="example") ---> 11 da.where(mask, 0) File [~/projects/xarray/xarray/core/common.py:1251](http://localhost:8888/home/shh/projects/xarray/xarray/core/common.py#line=1250), in DataWithCoords.where(self, cond, other, drop) 1247 cond = cond.isel(**indexers) 1249 from xarray.computation import ops -> 1251 return ops.where_method(self, cond, other) File [~/projects/xarray/xarray/computation/ops.py:184](http://localhost:8888/home/shh/projects/xarray/xarray/computation/ops.py#line=183), in where_method(self, cond, other) 182 # alignment for three arguments is complicated, so don't support it yet 183 join: Literal["inner", "exact"] = "inner" if other is dtypes.NA else "exact" --> 184 return apply_ufunc( 185 duck_array_ops.where_method, 186 self, 187 cond, 188 other, 189 join=join, 190 dataset_join=join, 191 dask="allowed", 192 keep_attrs=True, 193 ) File [~/projects/xarray/xarray/computation/apply_ufunc.py:1269](http://localhost:8888/home/shh/projects/xarray/xarray/computation/apply_ufunc.py#line=1268), in apply_ufunc(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, meta, dask_gufunc_kwargs, on_missing_core_dim, *args) 1267 # feed DataArray apply_variable_ufunc through apply_dataarray_vfunc 1268 elif any(isinstance(a, DataArray) for a in args): -> 1269 return apply_dataarray_vfunc( 1270 variables_vfunc, 1271 *args, 1272 signature=signature, 1273 join=join, 1274 exclude_dims=exclude_dims, 1275 keep_attrs=keep_attrs, 1276 ) 1277 # feed Variables directly through apply_variable_ufunc 1278 elif any(isinstance(a, Variable) for a in args): File [~/projects/xarray/xarray/computation/apply_ufunc.py:312](http://localhost:8888/home/shh/projects/xarray/xarray/computation/apply_ufunc.py#line=311), in apply_dataarray_vfunc(func, signature, join, exclude_dims, keep_attrs, *args) 307 result_coords, result_indexes = build_output_coords_and_indexes( 308 args, signature, exclude_dims, combine_attrs=keep_attrs 309 ) 311 data_vars = [getattr(a, "variable", a) for a in args] --> 312 result_var = func(*data_vars) 314 out: tuple[DataArray, ...] | DataArray 315 if signature.num_outputs > 1: File [~/projects/xarray/xarray/computation/apply_ufunc.py:820](http://localhost:8888/home/shh/projects/xarray/xarray/computation/apply_ufunc.py#line=819), in apply_variable_ufunc(func, signature, exclude_dims, dask, output_dtypes, vectorize, keep_attrs, dask_gufunc_kwargs, *args) 815 elif vectorize: 816 func = _vectorize( 817 func, signature, output_dtypes=output_dtypes, exclude_dims=exclude_dims 818 ) --> 820 result_data = func(*input_data) 822 if signature.num_outputs == 1: 823 result_data = (result_data,) File [~/projects/xarray/xarray/core/duck_array_ops.py:414](http://localhost:8888/home/shh/projects/xarray/xarray/core/duck_array_ops.py#line=413), in where_method(data, cond, other) 412 if other is dtypes.NA: 413 other = dtypes.get_fill_value(data.dtype) --> 414 return where(cond, data, other) File [~/projects/xarray/xarray/core/duck_array_ops.py:404](http://localhost:8888/home/shh/projects/xarray/xarray/core/duck_array_ops.py#line=403), in where(condition, x, y) 402 condition = asarray(condition, dtype=dtype, xp=xp) 403 else: --> 404 condition = astype(condition, dtype=dtype, xp=xp) 406 promoted_x, promoted_y = as_shared_dtype([x, y], xp=xp) 408 return xp.where(condition, promoted_x, promoted_y) File [~/projects/xarray/xarray/core/duck_array_ops.py:258](http://localhost:8888/home/shh/projects/xarray/xarray/core/duck_array_ops.py#line=257), in astype(data, dtype, xp, **kwargs) 256 if xp is np: 257 return data.astype(dtype, **kwargs) --> 258 return xp.astype(data, dtype, **kwargs) File [~/.local/conda/envs/holoviz/lib/python3.13/site-packages/cupy/__init__.py:1087](http://localhost:8888/home/shh/.local/conda/envs/holoviz/lib/python3.13/site-packages/cupy/__init__.py#line=1086), in __getattr__(name) 1084 if name in _deprecated_apis: 1085 return getattr(_numpy, name) -> 1087 raise AttributeError(f"module 'cupy' has no attribute {name!r}") AttributeError: module 'cupy' has no attribute 'astype'