#!/usr/bin/env python
# Wenchang Yang (wenchang@princeton.edu)
# Wed Dec 15 22:09:06 EST 2021
"""Build a multi-member ensemble of hemispheric-seasonal-mean PI vmax.

Usage: python <this script> EXPNAME [savefig]

For every ``*.nc`` file in ``idir`` whose year range (parsed from the file
name) covers ``time_span``, the script:

1. selects a Southern-Hemisphere season and a Northern-Hemisphere season
   (FMA/ASO when ``season_size == 3``, DJFMA/JJASO when ``season_size == 4``);
2. masks points where ``flagland(da) >= 0.1``
   (presumably land fraction -- confirm in misc.landmask);
3. averages each season per calendar year and keeps ``years``;
4. merges the SH field (lat < 0) with the NH field (lat >= 0).

All files are concatenated along a ``model_member`` dimension and written to
``ofile``, a relative path in the current working directory. The script exits
early, without recomputing, if ``ofile`` already exists. On the first file it
also saves a land-mask check figure next to ``ofile``.
"""
if __name__ == '__main__':
    import sys
    from misc.timer import Timer
    s = ' '
    tt = Timer(f'start {s.join(sys.argv)}')
import sys
import os.path
import os
import glob
import datetime
import re

import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# more imports
from misc.landmask import flagland

if __name__ == '__main__':
    tt.check('end import')

# start from here
if len(sys.argv) < 2:
    # fail with a usage message instead of a bare IndexError
    sys.exit(f'usage: {sys.argv[0]} EXPNAME [savefig]')
expname = sys.argv[1]  # experiment name, e.g. 'hist-nat'
season_size = 3  # 3 -> FMA (SH) / ASO (NH); 4 -> DJFMA (SH) / JJASO (NH)
dsname = 'PI'
daname = 'vmax'
units = 'm/s'
time_span = slice('1981', '2014')  # monthly period selected from each file
# Years kept after seasonal averaging. Starting at 1982 presumably keeps only
# years whose DJFMA season has its preceding December inside time_span.
years = slice(1982, 2014)
idir = f'/tigress/wenchang/data/cmip6/variables/{dsname}/{expname}/wy_regrid_all_members'


def func_years(ncfile):
    """Return the [start, end] year strings encoded in *ncfile*'s name.

    The second-to-last dot-separated field holds the year range, e.g.
    ts.historical.UKESM1-0-LL.r9i1p1f2.850hPa.1850-2014.nc -> ['1850', '2014'].
    """
    return ncfile.split('.')[-2].split('-')


# Keep only files covering time_span. Years are compared as strings, which is
# correct for zero-padded 4-digit years.
ncfiles = [f for f in os.listdir(idir)
           if f.endswith('.nc')
           and func_years(f)[0] <= time_span.start
           and func_years(f)[1] >= time_span.stop]


def func_key(ncfile):
    """Sort key: [model, r, i, p, f] with the ripf variant indices as ints.

    e.g. 'x.exp.MODEL.r9i1p1f2.y.nc' -> ['MODEL', 9, 1, 1, 2], so that r10
    sorts after r9 rather than after r1.
    """
    variant = ncfile.split('.')[3][1:]  # 'r9i1p1f2' -> '9i1p1f2'
    return ncfile.split('.')[2:3] + [int(s) for s in re.split('[ipf]', variant)]


# sort by model/member
ncfiles.sort(key=func_key)

# Disabled alternative region selection, kept for reference:
# hem = 'NH'  # from sys.argv[2]
# if hem == 'NH':
#     latmin, latmax = 10, 30
#     lonmin, lonmax = 40, 360-20
# elif hem == 'SH':
#     latmin, latmax = -30, -10
#     lonmin, lonmax = 30, 360-150

# keep only points where flagland(...) < 0.1; everything else becomes NaN
mask_land = lambda da: da.where(flagland(da) < 0.1)

# month predicates operate element-wise on an array of month numbers
if season_size == 3:
    isSeasonSH = lambda month: (month >= 2) & (month <= 4)  # FMA
    isSeasonNH = lambda month: (month >= 8) & (month <= 10)  # ASO
    season = 'FMA-ASO'
elif season_size == 4:
    isSeasonSH = lambda month: (month == 12) | ((month >= 1) & (month <= 4))  # DJFMA
    isSeasonNH = lambda month: (month >= 6) & (month <= 10)  # JJASO
    season = 'DJFMA-JJASO'
else:
    # otherwise isSeasonSH/isSeasonNH/season would be undefined further down
    raise ValueError(f'unsupported season_size: {season_size!r} (expected 3 or 4)')
print(idir)
print(season)
if not ncfiles:
    # xr.concat below cannot handle an empty list; fail early with context
    sys.exit(f'no .nc files covering {time_span.start}-{time_span.stop} in {idir}')
n_files = len(ncfiles)
ofile = f'{dsname}_{expname}_{season}_ens{n_files}_{years.start}-{years.stop}.nc'
if os.path.exists(ofile):
    print('[exists]:', ofile)
    sys.exit()
long_name = f'{expname} {season} PI vmax'

das = []
models = []
members = []
model_members = []
note = ''
for ii, ncfile in enumerate(ncfiles, start=1):
    model, member = ncfile.split('.')[2:4]
    print(f'{ii:3d} of {n_files:3d}: {daname} {expname} {model}, {member}, {ncfile}')
    ifile = os.path.join(idir, ncfile)
    # Close each file once its data has been read; opening one dataset per
    # member without closing it leaks a file handle every iteration.
    with xr.open_dataset(ifile) as ds:
        da = ds[daname]
        da_seasonSH = da.sel(time=isSeasonSH(da.time.dt.month)).sel(time=time_span)
        da_seasonNH = da.sel(time=isSeasonNH(da.time.dt.month)).sel(time=time_span)
        if ii == 1:  # check data selection
            print('**check data selection**:')
            print(da_seasonSH)
            print(da_seasonNH)
        # must happen before the file is closed at the end of this `with`
        da_seasonSH.load()
        da_seasonNH.load()
    da_seasonSH = da_seasonSH.pipe(mask_land)
    da_seasonNH = da_seasonNH.pipe(mask_land)
    if ii == 1:  # check land mask
        da_seasonSH.isel(time=0).plot()
        plt.title(f'land mask check: {long_name}')
        plt.savefig(ofile.replace('.nc', '.landmask.png'))

    # seasonal mean
    if season_size == 4:  # season SH is 'DJFMA'
        # shift time by one month start so year0 Dec is grouped with year1 Jan-Apr
        da_seasonSH['time'] = da_seasonSH.indexes['time'].shift(1, 'MS').values
        if ii == 1:
            print(f'**shifted time coords (1MS)**: {da_seasonSH.time.values}')
    da_seasonSH = da_seasonSH.groupby('time.year').mean('time').sel(year=years)
    da_seasonNH = da_seasonNH.groupby('time.year').mean('time').sel(year=years)
    # merge different seasons from SH (lat < 0) and NH (lat >= 0)
    da = da_seasonSH.where(da_seasonSH.lat < 0, other=da_seasonNH)
    das.append(da)
    models.append(model)
    members.append(member)
    model_members.append(f'{model}_{member}')

print('concat...')
# NOTE(review): the member coordinate is named 'members' (plural), unlike
# 'model'. It is left unchanged because existing readers of the output may
# rely on that name.
da = xr.concat(das, dim=pd.Index(model_members, name='model_member')) \
    .assign_coords(model=('model_member', models)) \
    .assign_coords(members=('model_member', members)) \
    .assign_attrs(note=note) \
    .assign_attrs(long_name=long_name) \
    .assign_attrs(units=units)
da.to_dataset(name=daname).to_netcdf(ofile)
print('[saved]:', ofile)

if __name__ == '__main__':
    # savefig
    if 'savefig' in sys.argv[1:]:
        # wysavefig used to come from the (commented-out) wyconfig star import,
        # so this branch always raised NameError. Import it explicitly, and fall
        # back to plain matplotlib when the author's plot config is unavailable.
        try:
            from wyconfig import wysavefig  # my plot settings
        except ImportError:
            wysavefig = plt.savefig
        figname = __file__.replace('.py', '.png')
        wysavefig(figname)
    tt.check('**Done**')
    # plt.show()