# =============================================================================
# SYNTHETIC COUPLED OCEAN–ATMOSPHERE + DRIFTER (FIXED VERSION)
# St. Lawrence Estuary (2.5 km grid)
# Author: Hatem Yazidi - May 2025
# =============================================================================
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
# =============================================================================
# 1. GRID (2.5 km)
# =============================================================================
# Regular lon/lat grid over the St. Lawrence Estuary with a nominal spacing of
# dx_km in BOTH directions.
dx_km = 2.5
deg_per_km = 1 / 111.0  # degrees of latitude per km (1 deg lat ~ 111 km)
LON_MIN, LON_MAX = -69.5, -67.0  # domain bounds, degrees east
LAT_MIN, LAT_MAX = 47.8, 49.5    # domain bounds, degrees north
# Latitude step: one degree of latitude is ~111 km everywhere.
ddeg = dx_km * deg_per_km
# Longitude step: one degree of longitude is only ~111*cos(lat) km, so reusing
# `ddeg` would give ~1.7 km zonal spacing here. Scale by the domain's
# mid-latitude so the zonal spacing is also ~dx_km.
lat_mid = 0.5 * (LAT_MIN + LAT_MAX)
ddeg_lon = ddeg / np.cos(np.deg2rad(lat_mid))
lon = np.arange(LON_MIN, LON_MAX, ddeg_lon)
lat = np.arange(LAT_MIN, LAT_MAX, ddeg)
nt = 48               # number of time steps (48 hours at 1-hour intervals)
time = np.arange(nt)  # hours since start: 0, 1, ..., nt - 1
ny, nx = len(lat), len(lon)
lon2d, lat2d = np.meshgrid(lon, lat)  # 2-D coordinate arrays, shape (ny, nx)
# =============================================================================
# 2. OCEAN FIELDS (time, lat, lon)
# =============================================================================
# Every ocean variable is stored on the full (time, lat, lon) grid.
shape3d = (nt, ny, nx)
# Steady synthetic surface currents: one 2-D (lat, lon) pattern copied to every
# time step. np.broadcast_to returns a read-only view, so .copy() materialises
# an independent, writable array.
uo_2d = 0.3 * np.sin(2 * np.pi * lat2d / 10)
vo_2d = 0.1 * np.cos(2 * np.pi * lon2d / 10)
uo = np.broadcast_to(uo_2d, shape3d).copy()
vo = np.broadcast_to(vo_2d, shape3d).copy()
# Sea-surface temperature: spatially uniform, 24 h sinusoidal cycle of
# amplitude 2 around 285. The 1-D time series is spread over every grid point.
tos_1d = 285 + 2 * np.sin(2 * np.pi * time / 24)
tos = np.broadcast_to(tos_1d[:, None, None], shape3d).copy()
sos = np.full(shape3d, 30.0)   # sea-surface salinity, constant in space and time
h_ml = np.full(shape3d, 20.0)  # mixed-layer depth, constant in space and time
dims3d = ["time", "lat", "lon"]
ocean = xr.Dataset(
    {
        "uo": (dims3d, uo, {"long_name": "eastward surface current", "units": "m s-1"}),
        "vo": (dims3d, vo, {"long_name": "northward surface current", "units": "m s-1"}),
        "tos": (dims3d, tos, {"long_name": "sea surface temperature", "units": "K"}),
        # The salinity scale (PSU vs g/kg) is not specified by this synthetic setup.
        "sos": (dims3d, sos, {"long_name": "sea surface salinity"}),
        "h_ml": (dims3d, h_ml, {"long_name": "mixed layer depth", "units": "m"}),
    },
    coords={
        # Integer hour offsets; the units attribute lets NetCDF readers interpret them.
        "time": ("time", time, {"long_name": "time since start of run", "units": "hours"}),
        "lat": ("lat", lat, {"units": "degrees_north"}),
        "lon": ("lon", lon, {"units": "degrees_east"}),
    },
)
ocean.to_netcdf("ocean.nc")
# =============================================================================
# 3. ATMOSPHERE (time, lat, lon)
# =============================================================================
def _expand_to_grid(series):
    """Broadcast a length-nt time series onto the (nt, ny, nx) model grid.

    Every grid point receives the same value at a given time step. The result
    is a writable float64 array: constant series built from the integer
    ``time`` array (e.g. ``2 + 0*time``) would otherwise be stored as int64.
    """
    series = np.asarray(series, dtype=float)
    return np.broadcast_to(series[:, None, None], (series.size, ny, nx)).copy()


# Spatially uniform forcing; the diurnal terms have a 24 h period.
ua = _expand_to_grid(5 + 1 * np.sin(2 * np.pi * time / 24))      # eastward wind
va = _expand_to_grid(np.full(nt, 2.0))                            # northward wind, constant
ta = _expand_to_grid(280 + 0.5 * np.cos(2 * np.pi * time / 24))   # air temperature
qa = _expand_to_grid(np.full(nt, 0.005))                          # specific humidity, constant
sw = _expand_to_grid(200 + 50 * np.sin(2 * np.pi * time / 24))    # downward shortwave radiation
lw = _expand_to_grid(np.full(nt, 300.0))                          # downward longwave radiation, constant
precip = _expand_to_grid(np.zeros(nt))                            # precipitation (none)
dims3d = ["time", "lat", "lon"]
atm = xr.Dataset(
    {
        "ua": (dims3d, ua, {"long_name": "eastward wind", "units": "m s-1"}),
        "va": (dims3d, va, {"long_name": "northward wind", "units": "m s-1"}),
        "ta": (dims3d, ta, {"long_name": "air temperature", "units": "K"}),
        "qa": (dims3d, qa, {"long_name": "specific humidity", "units": "kg kg-1"}),
        "sw_down": (dims3d, sw, {"long_name": "downward shortwave radiation", "units": "W m-2"}),
        "lw_down": (dims3d, lw, {"long_name": "downward longwave radiation", "units": "W m-2"}),
        # Units left unspecified: the field is identically zero in this setup.
        "precip": (dims3d, precip, {"long_name": "precipitation"}),
    },
    coords={
        # Integer hour offsets; the units attribute lets NetCDF readers interpret them.
        "time": ("time", time, {"long_name": "time since start of run", "units": "hours"}),
        "lat": ("lat", lat, {"units": "degrees_north"}),
        "lon": ("lon", lon, {"units": "degrees_east"}),
    },
)
atm.to_netcdf("atmosphere.nc")
# =============================================================================
# 4. DRIFTER MODEL
# =============================================================================
def interp(field, lon_p, lat_p):
    """Sample ``field`` at a single (lon, lat) point via its ``.interp`` method.

    The point coordinates are cast to plain ``float`` and wrapped in
    one-element lists, so the returned object keeps ``lon`` and ``lat`` as
    length-1 dimensions; callers reduce it to a scalar with
    ``.values.item()``. xarray's default interpolation method is used.
    """
    target = {"lon": [float(lon_p)], "lat": [float(lat_p)]}
    return field.interp(target)
# Forward-Euler advection of one surface particle through the hourly fields.
# Drift velocity = current + windage * (wind - current).
lon_p = -68.383276  # release longitude, degrees east
lat_p = 48.629327   # release latitude, degrees north
dt = 3600           # time step in seconds (one step per hourly field)
windage = 0.02      # fraction of (wind - current) added to the current
R = 6371000         # Earth radius in metres
# Seed with the release point so traj_*[0] is the true start; each step then
# appends one position (nt + 1 points when the particle stays in the domain).
traj_lon = [lon_p]
traj_lat = [lat_p]
for t in range(nt):
    ocean_t = ocean.isel(time=t)
    atm_t = atm.isel(time=t)
    # Point-specific names: reusing uo/vo/ua/va here would overwrite the
    # module-level gridded arrays of the same names.
    uo_p = interp(ocean_t["uo"], lon_p, lat_p).values.item()
    vo_p = interp(ocean_t["vo"], lon_p, lat_p).values.item()
    ua_p = interp(atm_t["ua"], lon_p, lat_p).values.item()
    va_p = interp(atm_t["va"], lon_p, lat_p).values.item()
    u = uo_p + windage * (ua_p - uo_p)
    v = vo_p + windage * (va_p - vo_p)
    # Metres -> degrees; both increments use the latitude at the start of the step.
    dlat = np.rad2deg(v * dt / R)
    dlon = np.rad2deg(u * dt / (R * np.cos(np.deg2rad(lat_p))))
    new_lon, new_lat = lon_p + dlon, lat_p + dlat
    if not (np.isfinite(new_lon) and np.isfinite(new_lat)):
        # xarray's interp yields NaN outside the grid by default: the particle
        # has left the domain, so stop instead of propagating NaNs.
        print(f"Drifter left the model domain at step {t}; trajectory truncated.")
        break
    lon_p, lat_p = new_lon, new_lat
    traj_lon.append(lon_p)
    traj_lat.append(lat_p)
traj_lon = np.array(traj_lon)
traj_lat = np.array(traj_lat)
# =============================================================================
# 5. CARTOPY PLOT (ZOOM + COASTLINE)
# =============================================================================
lon_min, lon_max = -69.5, -67.0
lat_min, lat_max = 47.8, 49.5
data_crs = ccrs.PlateCarree()  # extent and trajectory are plain lon/lat degrees
fig = plt.figure(figsize=(9, 7))
ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
# Pass crs explicitly: otherwise cartopy interprets the extent in the axes'
# own projection, which would break silently if the map projection changed.
ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=data_crs)
# Background first (ocean, then land), coastline drawn on top.
ax.add_feature(cfeature.OCEAN, facecolor="white")
ax.add_feature(cfeature.LAND, facecolor="lightgray")
ax.add_feature(cfeature.COASTLINE, linewidth=1.2)
gl = ax.gridlines(draw_labels=True, linestyle="--", alpha=0.5)
gl.top_labels = False
gl.right_labels = False
# Single-particle trajectory with start/end markers.
ax.plot(traj_lon, traj_lat,
        color="black",
        linewidth=2,
        transform=data_crs,
        label="Trajectory")
ax.scatter(traj_lon[0], traj_lat[0],
           color="green", s=60,
           transform=data_crs,
           label="Start")
ax.scatter(traj_lon[-1], traj_lat[-1],
           color="red", s=60,
           transform=data_crs,
           label="End")
ax.set_title("48h Drifter – St. Lawrence Synthetic Coupled Model")
ax.legend()
plt.show()