Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,7 @@ jobs:
name: "TypeChecking: pixi run typing"
runs-on: ubuntu-latest
needs: [should-skip-ci, cache-pixi-lock]
# TODO v4: Enable typechecking again
# needs.should-skip-ci.outputs.value == 'false'
if: false
if: needs.should-skip-ci.outputs.value == 'false'
steps:
- name: Checkout
uses: actions/checkout@v5
Expand Down
18 changes: 13 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,13 @@ known-first-party = ["parcels"]

[tool.mypy]
files = [
"parcels/_typing.py",
"parcels/tools/*.py",
"parcels/grid.py",
"parcels/field.py",
"parcels/fieldset.py",
"src/parcels/_typing.py",
"src/parcels/_core/xgrid.py",
"src/parcels/_core/uxgrid.py",
"src/parcels/_core/field.py",
"src/parcels/_core/fieldset.py",
]
disable_error_code = "attr-defined,assignment,operator,call-overload,index,valid-type,override,misc,union-attr"

[[tool.mypy.overrides]]
module = [
Expand All @@ -174,9 +175,16 @@ module = [
"scipy.spatial",
"sklearn.cluster",
"zarr",
"zarr.storage",
"uxarray",
"xgcm",
"cftime",
"pykdtree.kdtree",
"netCDF4",
"pooch",
]
ignore_missing_imports = true

[[tool.mypy.overrides]] # TODO: This module should stabilize before release of v4
module = "parcels.interpolators"
ignore_errors = true
4 changes: 2 additions & 2 deletions src/parcels/_core/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,9 @@ def __init__(
self.igrid = U.igrid

if W is None:
_assert_same_time_interval((U, V))
_assert_same_time_interval([U, V])
else:
_assert_same_time_interval((U, V, W))
_assert_same_time_interval([U, V, W])

self.time_interval = U.time_interval

Expand Down
16 changes: 13 additions & 3 deletions src/parcels/_core/fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from parcels._logger import logger
from parcels._reprs import fieldset_repr
from parcels._typing import Mesh
from parcels.convert import _ds_rename_using_standard_names
from parcels.interpolators import (
CGrid_Velocity,
Ux_Velocity,
Expand Down Expand Up @@ -182,7 +181,7 @@ def add_constant(self, name, value):

@property
def gridset(self) -> list[BaseGrid]:
grids = []
grids: list[BaseGrid] = []
for field in self.fields.values():
if field.grid not in grids:
grids.append(field.grid)
Expand Down Expand Up @@ -416,7 +415,8 @@ def _datetime_to_msg(example_datetime: TimeLike) -> str:
return msg


def _format_calendar_error_message(field: Field, reference_datetime: TimeLike) -> str:
def _format_calendar_error_message(field: Field | VectorField, reference_datetime: TimeLike) -> str:
assert field.time_interval is not None
return f"Expected field {field.name!r} to have calendar compatible with datetime object {_datetime_to_msg(reference_datetime)}. Got field with calendar {_datetime_to_msg(field.time_interval.left)}. Have you considered using xarray to update the time dimension of the dataset to have a compatible calendar?"


Expand Down Expand Up @@ -456,6 +456,16 @@ def _format_calendar_error_message(field: Field, reference_datetime: TimeLike) -
}


def _ds_rename_using_standard_names(ds: xr.Dataset | ux.UxDataset, name_dict: dict[str, str]) -> xr.Dataset:
for standard_name, rename_to in name_dict.items():
name = ds.cf[standard_name].name
ds = ds.rename({name: rename_to})
logger.info(
f"cf_xarray found variable {name!r} with CF standard name {standard_name!r} in dataset, renamed it to {rename_to!r} for Parcels simulation."
)
return ds


Comment on lines +459 to +468
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should calling this function not be part of convert?

Once a user requests a FieldSet, do we really want to still change field names? There would be no way to override that for a user anymore.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm only moving these here since its only use is in the ux fieldset ingestion - and that hasn't yet moved to convert. I'm happy either way (I thought having it here with the other functionality was neater, especially since we're probably going to outright drop this functionality due to us not doing U and V discovery in convert.py)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, clear. Then perhaps add a TODO that we will likely drop this function in the near future?

def _discover_ux_U_and_V(ds: ux.UxDataset) -> ux.UxDataset:
# Common variable names for U and V found in UxDatasets
common_ux_UV = [("unod", "vnod"), ("u", "v")]
Expand Down
11 changes: 6 additions & 5 deletions src/parcels/_core/index_search.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING

import numpy as np
Expand All @@ -9,8 +8,8 @@
from parcels._core.utils.time import timedelta_to_float

if TYPE_CHECKING:
from parcels import XGrid
from parcels._core.field import Field
from parcels.xgrid import XGrid


GRID_SEARCH_ERROR = -3
Expand All @@ -21,7 +20,7 @@
def _search_1d_array(
arr: np.array,
x: float,
) -> tuple[int, int]:
) -> tuple[np.array[int], np.array[float]]:
"""
Searches for particle locations in a 1D array and returns barycentric coordinate along dimension.

Expand Down Expand Up @@ -63,14 +62,14 @@ def _search_1d_array(
return np.atleast_1d(index), np.atleast_1d(bcoord)


def _search_time_index(field: Field, time: datetime):
def _search_time_index(field: Field, time: float):
"""Find and return the index and relative coordinate in the time array associated with a given time.

Parameters
----------
field: Field

time: datetime
time: float
This is the amount of time, in seconds (time_delta), in unix epoch
Note that we normalize to either the first or the last index
if the sampled value is outside the time value range.
Expand Down Expand Up @@ -172,6 +171,8 @@ def _search_indices_curvilinear_2d(
"""
if np.any(xi):
# If an initial guess is provided, we first perform a point in cell check for all guessed indices
assert xi is not None
assert yi is not None
Comment on lines +174 to +175
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my own understanding; why are these new lines needed?

is_in_cell, coords = curvilinear_point_in_cell(grid, y, x, yi, xi)
y_check = y[is_in_cell == 0]
x_check = x[is_in_cell == 0]
Expand Down
6 changes: 1 addition & 5 deletions src/parcels/_core/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import types
import warnings
from typing import TYPE_CHECKING

import numpy as np

Expand All @@ -24,9 +23,6 @@
AdvectionRK45,
)

if TYPE_CHECKING:
from collections.abc import Callable

__all__ = ["Kernel"]


Expand Down Expand Up @@ -84,7 +80,7 @@ def __init__(
# if (pyfunc is AdvectionRK4_3D) and fieldset.U.gridindexingtype == "croco":
# pyfunc = AdvectionRK4_3D_CROCO

self._pyfuncs: list[Callable] = pyfuncs
self._pyfuncs: list[types.FunctionType] = pyfuncs

@property #! Ported from v3. To be removed in v4? (/find another way to name kernels in output file)
def funcname(self):
Expand Down
4 changes: 2 additions & 2 deletions src/parcels/_core/particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Variable:
def __init__(
self,
name,
dtype: np.dtype = np.float32,
dtype: type[np.float32 | np.float64 | np.int32 | np.int64] = np.float32,
initial=0,
to_write: bool | Literal["once"] = True,
attrs: dict | None = None,
Expand Down Expand Up @@ -122,7 +122,7 @@ def _assert_no_duplicate_variable_names(*, existing_vars: list[Variable], new_va
raise ValueError(f"Variable name already exists: {var.name}")


def get_default_particle(spatial_dtype: np.float32 | np.float64) -> ParticleClass:
def get_default_particle(spatial_dtype: type[np.float32 | np.float64]) -> ParticleClass:
if spatial_dtype not in [np.float32, np.float64]:
raise ValueError(f"spatial_dtype must be np.float32 or np.float64. Got {spatial_dtype=!r}")

Expand Down
22 changes: 14 additions & 8 deletions src/parcels/_core/particleset.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations

import datetime
import sys
import types
import warnings
from collections.abc import Iterable
from typing import Literal
from typing import TYPE_CHECKING, Literal

import numpy as np
import xarray as xr
Expand All @@ -21,6 +24,9 @@
from parcels._logger import logger
from parcels._reprs import _format_zarr_output_location, particleset_repr

if TYPE_CHECKING:
from parcels import FieldSet, ParticleClass, ParticleFile

__all__ = ["ParticleSet"]


Expand Down Expand Up @@ -58,10 +64,10 @@ class ParticleSet:

def __init__(
self,
fieldset,
pclass=Particle,
lon=None,
lat=None,
fieldset: FieldSet,
pclass: ParticleClass = Particle,
lon: np.array[float] = None,
lat: np.array[float] = None,
z=None,
time=None,
trajectory_ids=None,
Expand Down Expand Up @@ -376,12 +382,12 @@ def set_variable_write_status(self, var, write_status):

def execute(
self,
pyfunc,
pyfunc: types.FunctionType | Kernel,
dt: datetime.timedelta | np.timedelta64 | float,
endtime: np.timedelta64 | np.datetime64 | None = None,
runtime: datetime.timedelta | np.timedelta64 | float | None = None,
output_file=None,
verbose_progress=True,
output_file: ParticleFile = None,
verbose_progress: bool = True,
):
"""Execute a given kernel function over the particle set for multiple timesteps.

Expand Down
2 changes: 1 addition & 1 deletion src/parcels/_core/utils/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def phi1D_quad(xsi: float) -> list[float]:
return phi


def phi2D_lin(eta: float, xsi: float) -> list[float]:
def phi2D_lin(eta: float, xsi: float) -> np.ndarray:
phi = np.column_stack([(1-xsi) * (1-eta),
xsi * (1-eta),
xsi * eta ,
Expand Down
4 changes: 2 additions & 2 deletions src/parcels/_core/utils/sgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def to_attrs(self) -> dict[str, str | int]:
d["vertical_dimensions"] = dump_mappings(self.vertical_dimensions)
return d

def rename(self, names_dict: dict[str, str]) -> Self:
def rename(self, names_dict: dict[str, str]) -> Grid2DMetadata:
return _metadata_rename(self, names_dict)

def get_value_by_id(self, id: str) -> str:
Expand Down Expand Up @@ -285,7 +285,7 @@ def to_attrs(self) -> dict[str, str | int]:
d["node_coordinates"] = dump_mappings(self.node_coordinates)
return d

def rename(self, dims_dict: dict[str, str]) -> Self:
def rename(self, dims_dict: dict[str, str]) -> Grid3DMetadata:
return _metadata_rename(self, dims_dict)

def get_value_by_id(self, id: str) -> str:
Expand Down
13 changes: 7 additions & 6 deletions src/parcels/_core/utils/time.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from collections.abc import Callable
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, TypeVar
from typing import TYPE_CHECKING, Generic, Literal, TypeVar

import cftime
import numpy as np
Expand All @@ -14,7 +15,7 @@
T = TypeVar("T", bound="TimeLike")


class TimeInterval:
class TimeInterval(Generic[T]):
"""A class representing a time interval between two datetime or np.timedelta64 objects.

Parameters
Expand All @@ -29,7 +30,7 @@ class TimeInterval:
For the purposes of this codebase, the interval can be thought of as closed on the left and right.
"""

def __init__(self, left: T, right: T) -> None:
def __init__(self, left: T, right: T):
if not isinstance(left, (np.timedelta64, datetime, cftime.datetime, np.datetime64)):
raise ValueError(
f"Expected right to be a np.timedelta64, datetime, cftime.datetime, or np.datetime64. Got {type(left)}."
Expand Down Expand Up @@ -130,7 +131,7 @@ def get_datetime_type_calendar(
return type(example_datetime), calendar


_TD_PRECISION_GETTER_FOR_UNIT = (
_TD_PRECISION_GETTER_FOR_UNIT: tuple[tuple[Callable[[timedelta], int], Literal["D", "s", "us"]], ...] = (
(lambda dt: dt.days, "D"),
(lambda dt: dt.seconds, "s"),
(lambda dt: dt.microseconds, "us"),
Expand All @@ -142,14 +143,14 @@ def maybe_convert_python_timedelta_to_numpy(dt: timedelta | np.timedelta64) -> n
return dt

try:
dts = []
dts: list[np.timedelta64] = []
for get_value_for_unit, np_unit in _TD_PRECISION_GETTER_FOR_UNIT:
value = get_value_for_unit(dt)
if value != 0:
dts.append(np.timedelta64(value, np_unit))

if dts:
return sum(dts)
return np.sum(dts)
else:
return np.timedelta64(0, "s")
except Exception as e:
Expand Down
2 changes: 1 addition & 1 deletion src/parcels/_core/uxgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class UxGrid(BaseGrid):
for interpolation on unstructured grids.
"""

def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray, mesh) -> UxGrid:
def __init__(self, grid: ux.grid.Grid, z: ux.UxDataArray, mesh):
"""
Initializes the UxGrid with a uxarray grid and vertical coordinate array.

Expand Down
13 changes: 6 additions & 7 deletions src/parcels/_core/xgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,15 @@
from parcels._core.basegrid import BaseGrid
from parcels._core.index_search import _search_1d_array, _search_indices_curvilinear_2d
from parcels._reprs import xgrid_repr
from parcels._typing import assert_valid_mesh
from parcels._typing import CfAxis, assert_valid_mesh

_XGRID_AXES = Literal["X", "Y", "Z"]
_XGRID_AXES_ORDERING: Sequence[_XGRID_AXES] = "ZYX"

_XGCM_AXIS_DIRECTION = Literal["X", "Y", "Z", "T"]
_XGCM_AXIS_POSITION = Literal["center", "left", "right", "inner", "outer"]
_XGCM_AXES = Mapping[_XGCM_AXIS_DIRECTION, xgcm.Axis]
_XGCM_AXES = Mapping[CfAxis, xgcm.Axis]

_FIELD_DATA_ORDERING: Sequence[_XGCM_AXIS_DIRECTION] = "TZYX"
_FIELD_DATA_ORDERING: Sequence[CfAxis] = "TZYX"

_DEFAULT_XGCM_KWARGS = {"periodic": False}

Expand Down Expand Up @@ -282,7 +281,7 @@ def _gtype(self):

TODO: Remove
"""
from parcels.grid import GridType
from parcels._core.basegrid import GridType

if len(self.lon.shape) <= 1:
if self.depth is None or len(self.depth.shape) <= 1:
Expand Down Expand Up @@ -384,7 +383,7 @@ def get_axis_dim_mapping(self, dims: list[str]) -> dict[_XGRID_AXES, str]:
return result


def get_axis_from_dim_name(axes: _XGCM_AXES, dim: str) -> _XGCM_AXIS_DIRECTION | None:
def get_axis_from_dim_name(axes: _XGCM_AXES, dim: Hashable) -> CfAxis | None:
"""For a given dimension name in a grid, returns the direction axis it is on."""
for axis_name, axis in axes.items():
if dim in axis.coords.values():
Expand Down Expand Up @@ -421,7 +420,7 @@ def assert_valid_field_array(da: xr.DataArray, axes: _XGCM_AXES):
assert_all_dimensions_correspond_with_axis(da, axes)

dim_to_axis = {dim: get_axis_from_dim_name(axes, dim) for dim in da.dims}
dim_to_axis = cast(dict[Hashable, _XGCM_AXIS_DIRECTION], dim_to_axis)
dim_to_axis = cast(dict[Hashable, CfAxis], dim_to_axis)

# Assert all dimensions are present
if set(dim_to_axis.values()) != {"T", "Z", "Y", "X"}:
Expand Down
2 changes: 2 additions & 0 deletions src/parcels/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import numpy as np
from cftime import datetime as cftime_datetime

CfAxis = Literal["X", "Y", "Z", "T"]

InterpMethodOption = Literal[
"linear",
"nearest",
Expand Down
Loading
Loading