import os
from time import sleep
import numpy as np
from . import global_vars
# Select the array backend used by the xp-based routines in this module:
# CuPy when global_vars.cupy_enabled is set, plain NumPy otherwise.
if global_vars.cupy_enabled:
    import cupy as xp
else:
    xp = np
from matplotlib import pyplot as plt
from matplotlib import colors as clr
from matplotlib import cm
from matplotlib.colors import Normalize
from matplotlib.colors import LogNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.interpolate import interp1d as sp_interp1d
from . import io
def zeros_like(obj):
    """Return a zero-filled copy of a nested dict/list structure.

    Dicts and lists are traversed recursively (subclasses such as
    ``OrderedDict`` included, the result being a plain dict/list), strings
    are returned unchanged, and every other leaf is replaced by
    ``np.zeros_like(leaf)``.
    """
    # isinstance (rather than an exact type check) keeps container subclasses
    # from falling through to np.zeros_like as opaque objects.
    if isinstance(obj, dict):
        return {key: zeros_like(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [zeros_like(item) for item in obj]
    if isinstance(obj, str):
        return obj
    return np.zeros_like(obj)
def plus(a, b):
    """Add two identically structured nested dict/list objects leaf by leaf.

    The structure of ``a`` drives the traversal: dicts are combined key by
    key, lists index by index, strings are taken from ``a`` unchanged, and
    all other leaves are combined with ``a + b``.
    """
    # isinstance keeps dict/list/str subclasses on the structural branches
    # instead of attempting ``a + b`` on the containers themselves.
    if isinstance(a, dict):
        return {key: plus(value, b[key]) for key, value in a.items()}
    if isinstance(a, list):
        return [plus(item, b[i]) for i, item in enumerate(a)]
    if isinstance(a, str):
        return a
    return a + b
def times(a, b):
    """Multiply every leaf of a nested dict/list structure by ``b``.

    Dicts and lists are traversed recursively with the same ``b`` passed
    down unchanged, strings are returned untouched, and every other leaf
    becomes ``leaf * b``.
    """
    # isinstance keeps dict/list/str subclasses on the structural branches
    # instead of attempting ``a * b`` on the containers themselves.
    if isinstance(a, dict):
        return {key: times(value, b) for key, value in a.items()}
    if isinstance(a, list):
        return [times(item, b) for item in a]
    if isinstance(a, str):
        return a
    return a * b
# Convert binary (.bin) files to athdf files, each with a companion .xdmf file
def bin_to_athdf(binary_fname, athdf_fname):
    """Convert a single binary file to an athdf file plus an xdmf file.

    The data read with ``io.read_binary`` is written to ``athdf_fname`` and
    an xdmf file is written to ``athdf_fname + ".xdmf"``; the base name of
    the athdf file is passed to ``io.write_xdmf_for``.
    """
    contents = io.read_binary(binary_fname)
    io.write_athdf(athdf_fname, contents)
    athdf_basename = os.path.basename(athdf_fname)
    io.write_xdmf_for(athdf_fname + ".xdmf", athdf_basename, contents)
def bins_to_athdfs(binpath, athdfpath, overwrite=False, info=True):
    """Convert every ``*.bin`` file in ``binpath`` to athdf/xdmf in ``athdfpath``.

    Parameters:
        binpath: directory scanned (non-recursively) for files ending in ".bin".
        athdfpath: output directory; created, including missing parents,
            when it does not exist.
        overwrite: when False, a file is skipped if both its ".athdf" and
            ".athdf.xdmf" outputs already exist.
        info: when True, print one "Converting"/"Skipping" line per file.
    """
    # makedirs also handles nested output paths, where os.mkdir would fail.
    os.makedirs(athdfpath, exist_ok=True)
    suffix = ".bin"
    for file in sorted(os.listdir(binpath)):
        if not file.endswith(suffix):
            continue
        binary_fname = os.path.join(binpath, file)
        # Swap only the trailing extension; str.replace would also rewrite
        # any ".bin" occurring earlier in the file name.
        athdf_fname = os.path.join(athdfpath, file[:-len(suffix)] + ".athdf")
        xdmf_fname = athdf_fname + ".xdmf"
        up_to_date = os.path.exists(athdf_fname) and os.path.exists(xdmf_fname)
        if overwrite or not up_to_date:
            if info:
                print(f"Converting {file}")
            bin_to_athdf(binary_fname, athdf_fname)
        elif info:
            print(f"Skipping {file}")
# Tabulated log10 of the cooling rate on a uniform grid of 102 points in
# log10(T) from 4.12 to 8.16 (spacing 0.04).
# original data from Shure et al. paper, covers 4.12 < logt < 8.16
_SHURE_LOG_COOL = (
    -22.5977, -21.9689, -21.5972, -21.4615, -21.4789, -21.5497, -21.6211, -21.6595,
    -21.6426, -21.5688, -21.4771, -21.3755, -21.2693, -21.1644, -21.0658, -20.9778,
    -20.8986, -20.8281, -20.7700, -20.7223, -20.6888, -20.6739, -20.6815, -20.7051,
    -20.7229, -20.7208, -20.7058, -20.6896, -20.6797, -20.6749, -20.6709, -20.6748,
    -20.7089, -20.8031, -20.9647, -21.1482, -21.2932, -21.3767, -21.4129, -21.4291,
    -21.4538, -21.5055, -21.5740, -21.6300, -21.6615, -21.6766, -21.6886, -21.7073,
    -21.7304, -21.7491, -21.7607, -21.7701, -21.7877, -21.8243, -21.8875, -21.9738,
    -22.0671, -22.1537, -22.2265, -22.2821, -22.3213, -22.3462, -22.3587, -22.3622,
    -22.3590, -22.3512, -22.3420, -22.3342, -22.3312, -22.3346, -22.3445, -22.3595,
    -22.3780, -22.4007, -22.4289, -22.4625, -22.4995, -22.5353, -22.5659, -22.5895,
    -22.6059, -22.6161, -22.6208, -22.6213, -22.6184, -22.6126, -22.6045, -22.5945,
    -22.5831, -22.5707, -22.5573, -22.5434, -22.5287, -22.5140, -22.4992, -22.4844,
    -22.4695, -22.4543, -22.4392, -22.4237, -22.4087, -22.3928,
)


def _cool_fn_shure_scalar(T):
    """Cooling function for one scalar temperature; NaN input gives 0.0."""
    if np.isnan(T):
        return 0.0
    lhd = _SHURE_LOG_COOL
    logt = np.log10(T)
    # for log10(T) <= 4.2, use Koyama & Inutsuka
    if logt <= 4.2:
        temp = pow(10.0, logt)
        return (2.0e-19*np.exp(-1.184e5/(temp + 1.0e3))
                + 2.8e-28*np.sqrt(temp)*np.exp(-92.0/temp))
    # for temperatures above 10^8.15 use CGOLS fit
    if logt > 8.15:
        return pow(10.0, (0.45*logt - 26.065))
    # in between values of 4.2 < log(T) < 8.15
    # linear interpolation of tabulated SPEX cooling rate
    ipps = int(25.0*logt) - 103
    # Clamp to a valid lower-node index so that lhd[ipps+1] stays in range.
    ipps = min(max(ipps, 0), 100)
    x0 = 4.12 + 0.04*ipps
    dx = logt - x0
    tcool = (lhd[ipps+1]*dx - lhd[ipps]*(dx - 0.04))*25.0
    return pow(10.0, tcool)


# Element-wise wrapper around the scalar routine. Fixing ``otypes`` lets
# np.vectorize handle zero-size input, which it cannot do when it has to
# infer the output type from a first call.
CoolFnShure_vec = np.vectorize(
    _cool_fn_shure_scalar, otypes=[float],
    doc="Element-wise cooling function of temperature T (NaN input gives 0.0).")


def _cool_fn_shure_array(T, lib):
    """Array implementation shared by the NumPy and ``xp`` entry points.

    ``lib`` is the array module (``np`` or ``xp``) used for every operation.
    """
    T = lib.asarray(T)
    log_t_tab = lib.linspace(4.12, 8.16, 102, endpoint=True)
    log_c_tab = lib.asarray(_SHURE_LOG_COOL)
    logt = lib.log10(T)
    cool_rate = lib.zeros(T.shape)
    loc_c = logt <= 4.2
    loc_h = logt > 8.15
    loc_w = lib.logical_not(lib.logical_or(loc_c, loc_h))
    # for log10(T) <= 4.2, use Koyama & Inutsuka
    temp = T[loc_c]
    cool_rate[loc_c] = (2.0e-19*lib.exp(-1.184e5/(temp + 1.0e3))
                        + 2.8e-28*lib.sqrt(temp)*lib.exp(-92.0/temp))
    # for temperatures above 10^8.15 use CGOLS fit
    cool_rate[loc_h] = 10.0**(0.45*logt[loc_h] - 26.065)
    # in between values of 4.2 < log(T) < 8.15
    # linear interpolation of tabulated SPEX cooling rate
    cool_rate[loc_w] = 10.0**lib.interp(logt[loc_w], log_t_tab, log_c_tab)
    return cool_rate


def CoolFnShure_numpy(T):
    """Cooling function of temperature ``T`` evaluated with NumPy.

    ``T`` is converted with ``np.asarray``; the result is a float array of
    the same shape. Three regimes are used: an analytic expression for
    log10(T) <= 4.2, a power-law fit for log10(T) > 8.15, and linear
    interpolation (in log-log space) of the tabulated values in between.
    """
    return _cool_fn_shure_array(T, np)


def CoolFnShure(T):
    """Cooling function of temperature ``T`` on the active array backend.

    Inputs that are not ``xp.ndarray`` instances are handed to
    ``CoolFnShure_numpy``; ``xp`` arrays are evaluated with ``xp`` itself.
    """
    if not isinstance(T, xp.ndarray):
        return CoolFnShure_numpy(T)
    return _cool_fn_shure_array(T, xp)
# profile solver
def RK4(func, x, y, h):
    """Advance ``y`` by one classical fourth-order Runge-Kutta step.

    Parameters:
        func: callable ``func(x, y)`` returning dy/dx.
        x: current value of the independent variable.
        y: current solution value (scalar or array).
        h: step size.

    Returns:
        The solution at ``x + h``. The inputs are left untouched: the
        result is a new object rather than an in-place update of ``y``.
    """
    half_step = 0.5*h
    k1 = func(x, y)
    x_mid = x + half_step
    k2 = func(x_mid, y + 0.5*k1*h)
    k3 = func(x_mid, y + 0.5*k2*h)
    k4 = func(x_mid + half_step, y + k3*h)
    # Build a new value instead of ``y += ...`` so that a caller's array is
    # not modified behind its back (and integer arrays do not fail to cast).
    return y + 1/6*(k1 + 2*k2 + 2*k3 + k4)*h
def NFWMass(r, ms, rs):
    """Evaluate ``ms*(ln(1 + r/rs) - r/(rs + r))``.

    Judging by the name this is the NFW enclosed-mass profile at radius
    ``r`` with mass scale ``ms`` and scale radius ``rs``.
    """
    log_term = np.log(1 + r/rs)
    ratio_term = r/(rs + r)
    return ms*(log_term - ratio_term)
def NFWDens(r, ms, rs):
    """Evaluate ``ms/(4*pi*rs**3) / ((r/rs)*(1 + r/rs)**2)``.

    Judging by the name this is the NFW density profile at radius ``r``
    with mass scale ``ms`` and scale radius ``rs``.
    """
    normalisation = ms/(4*np.pi*rs**3)
    shape = r/rs*(1+r/rs)**2
    return normalisation/shape
##########################################################################################
## cooling time as a function of temperature and density
##########################################################################################
def rho_T_t_cool(cooling_rho=np.logspace(-4,4,400),cooling_temp=np.logspace(0,8,400)):
    """Tabulate the cooling time on a (density, temperature) mesh.

    The two 1-D inputs are expanded with ``np.meshgrid`` and the cooling
    time is computed as
    ``k_boltzmann_cgs*T/rho/CoolFnShure(T)/(gamma-1)/myr_cgs``.

    Returns the density mesh, the temperature mesh and the cooling-time mesh.

    NOTE(review): ``k_boltzmann_cgs``, ``gamma`` and ``myr_cgs`` are not
    defined in the visible part of this file; presumably they are
    module-level constants -- confirm they exist before calling.
    """
    rho_mesh, temp_mesh = np.meshgrid(cooling_rho, cooling_temp)
    thermal = k_boltzmann_cgs*temp_mesh/rho_mesh
    tcool_mesh = thermal/CoolFnShure(temp_mesh)/(gamma-1)/myr_cgs
    return rho_mesh, temp_mesh, tcool_mesh
##########################################################################################
## Some useful functions
##########################################################################################
[docs]
def ave(a,n):
end = -(a.size%n) if(-(a.size%n)) else None
return np.average((a.ravel()[:end]).reshape(-1,n),axis=1)
def smooth(x,s=3,mode='nearest',**kwargs):
    """Gaussian-smooth ``x`` with ``scipy.ndimage.gaussian_filter``.

    ``s`` is passed as the filter's sigma and ``mode`` as its boundary
    mode; any further keyword arguments are forwarded unchanged.
    """
    # Imported lazily so scipy.ndimage is only needed when smoothing is used.
    from scipy.ndimage import gaussian_filter

    smoothed = gaussian_filter(x, s, mode=mode, **kwargs)
    return smoothed
# TODO(@mhguo): this is just a collection of functions, need to be organized
def pro_from_hist2d(hist,x='r',y='dens'):
    """Reduce a 2-D histogram to a profile of weighted mean ``y`` per ``x`` bin.

    ``hist['centers']`` supplies the bin centres for the ``x`` and ``y``
    keys and ``hist['dat']`` the counts, indexed as ``[x_bin, y_bin]``.
    Returns the ``x`` centres and a list holding, for each ``x`` bin, the
    average of the ``y`` centres weighted by that bin's row of counts.
    """
    centers = hist['centers']
    x_centers, y_centers = centers[x], centers[y]
    y_means = [np.average(y_centers, weights=row) for row in hist['dat']]
    return x_centers, y_means
# quantile with weights
def quantile(x,q,weights=None,axis=None):
    """Quantile(s) ``q`` of ``x``, optionally weighted.

    Parameters:
        x: sample values. With ``weights`` they serve as the interpolation
            ordinates, so they are presumably expected in ascending order.
        q: quantile or array of quantiles in [0, 1].
        weights: optional weights. When given, the normalised cumulative
            sum of the weights is used as the abscissa of a linear
            interpolation onto ``x``.
        axis: without weights, the axis handed to ``np.quantile``; with
            weights, the axis of ``weights`` along which each 1-D slice is
            treated as one weight vector for ``x``.
    """
    if weights is None:
        # Forward ``axis`` (previously ignored); the default None keeps the
        # flattened behaviour.
        return np.quantile(x, q, axis=axis)
    if axis is None:
        return np.interp(q, np.cumsum(weights)/np.sum(weights), x)

    def _one_slice(w):
        return np.interp(q, np.cumsum(w)/np.sum(w), x)

    return np.apply_along_axis(_one_slice, axis, weights)
##########################################################################################
## Scipy Measurements Label with boundary correction
##########################################################################################
import scipy.ndimage as sn
# Default structuring element for labelling: rank 3 with connectivity 3,
# i.e. face, edge and corner neighbours all count as connected.
default_struct = sn.generate_binary_structure(3,3)
def clean_tuples(tuples):
    """Return the sorted list of unique pairs, each ordered as (low, high)."""
    ordered_pairs = {(min(pair), max(pair)) for pair in tuples}
    return sorted(ordered_pairs)
def merge_tuples_unionfind(tuples):
# use classic algorithms union find with path compression
# https://enp.wikipedia.org/wiki/Disjoint-set_data_structure
parent_dict = {}
def subfind(x):
# update roots while visiting parents
if parent_dict[x] != x:
parent_dict[x] = subfind(parent_dict[x])
return parent_dict[x]
def find(x):
if x not in parent_dict:
# x forms new set and becomes a root
parent_dict[x] = x
return x
if parent_dict[x] != x:
# follow chain of parents of parents to find root
parent_dict[x] = subfind(parent_dict[x])
return parent_dict[x]
# each tuple represents a connection between two items
# so merge them by setting root to be the lower root.
for p0,p1 in list(tuples):
r0 = find(p0)
r1 = find(p1)
if r0 < r1:
parent_dict[r1] = r0
elif r1 < r0:
parent_dict[r0] = r1
# for unique parents, subfind the root, replace occurrences with root
vs = set(parent_dict.values())
for parent in vs:
sp = subfind(parent)
if sp != parent:
for key in parent_dict:
if parent_dict[key] == parent:
parent_dict[key] = sp
return parent_dict
def make_dict(mask,struct,boundary,bargs):
    """Label ``mask`` and merge labels that ``boundary`` reports as connected.

    Parameters:
        mask: array labelled with ``scipy.ndimage.label``.
        struct: structuring element passed to the labelling.
        boundary: callable ``boundary(label, bargs)`` returning pairs of
            labels that are to be treated as one object.
        bargs: extra argument handed through to ``boundary``.

    Returns:
        ``(outdict, ownerof)``: ``outdict`` maps each surviving label to the
        positions of its cells as delivered by ``labeled_comprehension``
        with ``pass_positions=True``; ``ownerof`` maps every label that
        appeared in a boundary pair to the label that absorbed it.
    """
    label, things = sn.label(mask, structure=struct)
    label_ids = range(1, things + 1)
    cs = clean_tuples(boundary(label, bargs))

    def _positions_only(values, positions):
        # Only where the cells are matters, not the mask values.
        return positions

    slc = sn.labeled_comprehension(mask, label, label_ids,
                                   _positions_only,
                                   list,
                                   None,
                                   pass_positions=True)
    outdict = dict(zip(label_ids, slc))
    ownerof = merge_tuples_unionfind(cs)
    for key, owner in ownerof.items():
        if key == owner:
            continue
        # add key to its owner and remove key
        outdict[owner] = np.append(outdict[owner], outdict[key])
        del outdict[key]
    return outdict, ownerof
def shear_periodic(label,axis,cell_shear,shear_axis):
    """Label pairs connected across a shearing-periodic boundary on ``axis``.

    The first and last faces of ``label`` along ``axis`` are extracted, the
    last face is rolled by ``cell_shear`` cells along ``shear_axis``, and
    the two faces are handed to ``connect_faces``.
    """
    def _face(index):
        selector = [slice(None)]*label.ndim
        selector[axis] = index
        return label[tuple(selector)]

    # 1. get faces
    near_face = _face(0)
    far_face = _face(label.shape[axis] - 1)
    # 2. now cell shear: a face has one dimension fewer than label, so the
    # shear axis must be re-indexed among the remaining axes.
    remaining_axes = list(range(label.ndim))
    remaining_axes.remove(axis)
    roll_axis = remaining_axes.index(shear_axis)
    far_face = np.roll(far_face, cell_shear, axis=roll_axis)
    return connect_faces(near_face, far_face)
def periodic(label,axis):
    """Label pairs connected across a periodic boundary on ``axis``.

    The first and last faces of ``label`` along ``axis`` are extracted and
    handed to ``connect_faces``.
    """
    def _face(index):
        selector = [slice(None)]*label.ndim
        selector[axis] = index
        return label[tuple(selector)]

    near_face = _face(0)
    far_face = _face(label.shape[axis] - 1)
    return connect_faces(near_face, far_face)
def tigress_shear(label,cell_shear):
    """Boundary connections: periodic on axis 1, shearing-periodic on axis 0.

    No connection is made across axis 2 (open boundary). The shear of
    ``cell_shear`` cells is applied along axis 1. Returns the set of
    connected label pairs.
    """
    # periodic in Y, so axis = 1
    periodic_pairs = periodic(label, 1)
    # shear periodic in X, so axis = 0, shear_axis = 1
    shear_pairs = shear_periodic(label, 0, cell_shear, 1)
    return set().union(periodic_pairs, shear_pairs)
def tigress(label,cell_shear):
    """Boundary connections for a box that is periodic on axes 0, 1 and 2.

    ``cell_shear`` is accepted for signature compatibility with the other
    boundary functions but is not used. Returns the set of connected label
    pairs.
    """
    per_axis_pairs = [periodic(label, axis) for axis in (0, 1, 2)]
    return set().union(*per_axis_pairs)
def tigress_nob(label,cell_shear):
    """Boundary function that connects nothing: always returns an empty set.

    Both arguments are accepted for signature compatibility with the other
    boundary functions and are ignored.
    """
    return set()
def connect_faces_simple(lf1,lf2):
    """Pair up labels that sit directly opposite each other on two faces.

    ``lf1`` and ``lf2`` are 2-D label faces of equal shape. Wherever the
    product of the two labels is positive, the pair ``(lf1, lf2)`` at that
    position is recorded. Returns the set of such pairs (as floats).
    """
    touching = lf1*lf2 > 0
    paired = np.zeros(list(lf1.shape) + [2])
    paired[:, :, 0] = lf1
    paired[:, :, 1] = lf2
    return {tuple(pair) for pair in paired[touching]}
def connect_faces_rank(lf1,lf2):
    """Pair up labels on two faces that touch under ``default_struct``.

    The two faces are stacked into a two-layer array, its non-zero cells
    are labelled with ``default_struct``, and within every connected
    region each distinct original label is paired with the smallest label
    of that region. Returns the set of ``(owner, other)`` pairs.
    """
    layers = np.zeros([2] + list(lf1.shape))
    layers[0] = lf1
    layers[1] = lf2
    region_label, n_regions = sn.label(layers > 0, structure=default_struct)
    if n_regions == 0:
        return set()

    def _distinct(values):
        return list(set(values))

    regions = sn.labeled_comprehension(layers, region_label, range(1, n_regions + 1),
                                       _distinct,
                                       list,
                                       None,
                                       pass_positions=False)
    links = set()
    for region in regions:
        if len(region) == 0:
            continue
        owner = np.min(region)
        links.update((owner, cell) for cell in region if cell != owner)
    return links


# Face-connection routine used by periodic() and shear_periodic().
connect_faces = connect_faces_rank
##########################################################################################
## plots
##########################################################################################
def mgcolors(name='default'):
    """Return the four-colour hex palette as a new list.

    ``name`` is accepted, but every value currently yields the same palette.
    """
    palette = ['#3369E8', '#009925', '#FBBC05', '#EA4335']
    return palette
def colors(n,cmap='nipy_spectral',x1=0.0,x2=0.88,beta=0.99):
    """Return ``n`` colours sampled evenly from a matplotlib colormap.

    Parameters:
        n: number of colours.
        cmap: colormap name or object understood by ``plt.get_cmap``.
        x1, x2: first and last sampling positions within the colormap.
        beta: factor applied to every component of each sampled colour.

    Returns:
        A list of ``n`` tuples, one per colour.
    """
    clmap = plt.get_cmap(cmap)
    # max(1, ...) avoids a division by zero when a single colour is requested.
    span = max(1, n - 1)
    # NOTE(review): beta multiplies the whole RGBA tuple returned by the
    # colormap, so the alpha component is scaled as well -- confirm intent.
    return [tuple(np.array(clmap(x1 + i/span*(x2 - x1)))*beta) for i in range(n)]
def subplots(nrows=2, ncols=2, figsize=(7.2, 5.0), dpi=120, sharex=False, squeeze=False,
             constrained_layout=False, top=0.94, bottom=0.1, left=0.125, right=0.9,
             wspace=0.02, hspace=0.0, raw=False, **kwargs):
    """Create a figure with a styled grid of axes.

    Thin wrapper around ``plt.subplots``. ``top``, ``bottom``, ``left``,
    ``right``, ``wspace`` and ``hspace`` go to ``fig.subplots_adjust``; all
    other named parameters and ``**kwargs`` go to ``plt.subplots``.

    Unless ``raw`` is True, the axes are also styled: when there is more
    than one column the last column gets its y ticks and label on the
    right, and every axes gets a dashed grid and inward ticks on all sides.

    Returns:
        ``(fig, axes)`` exactly as produced by ``plt.subplots`` (so the
        shape of ``axes`` follows ``squeeze``).
    """
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, dpi=dpi, sharex=sharex,
                             constrained_layout=constrained_layout, squeeze=squeeze,
                             **kwargs)
    # Manual margins and constrained layout are mutually exclusive in
    # matplotlib; only apply the margins when constrained layout is off.
    if not constrained_layout:
        fig.subplots_adjust(top=top, bottom=bottom, left=left, right=right,
                            wspace=wspace, hspace=hspace)
    if raw:
        return fig, axes
    # With squeeze=True matplotlib may hand back a 1-D array or a single
    # Axes; view it as an (nrows, ncols) grid so the styling below works.
    grid = np.asarray(axes, dtype=object).reshape(nrows, ncols)
    if ncols > 1:
        for ax in grid[:, -1]:
            ax.yaxis.set_label_position("right")
            ax.yaxis.tick_right()
    for ax in grid.flat:
        ax.grid(linestyle='--')
        ax.tick_params(bottom=True, top=True, left=True, right=True,
                       which='both', direction="in")
    return fig, axes
def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
    """Build a colormap from the ``[minval, maxval]`` portion of ``cmap``.

    ``cmap`` may be a colormap object or a name (resolved with
    ``plt.get_cmap``). The source map is sampled at ``n`` evenly spaced
    positions and the samples define a new ``LinearSegmentedColormap``
    named ``trunc(<name>,<minval>,<maxval>)``.
    """
    if type(cmap) is str:
        cmap = plt.get_cmap(cmap)
    trunc_name = f'trunc({cmap.name},{minval:.2f},{maxval:.2f})'
    sample_points = np.linspace(minval, maxval, n)
    return clr.LinearSegmentedColormap.from_list(trunc_name, cmap(sample_points))
# generate a rgb image showing different species
def get_rgb(x,c=[[1,0,0],[0.98,0.6,0.02],[0.0,0.6,0.1],[0.05,0.1,1.0]]):
    """Blend a sequence of 2-D maps into one RGB image.

    Map ``x[i]`` is tinted with colour ``c[i]`` (three components) and the
    tinted maps are summed per channel. Returns an array with the channel
    as the last axis, i.e. shape ``x[0].shape + (3,)``.
    """
    channels = [0., 0., 0.]
    for i, layer in enumerate(x):
        tint = c[i]
        for ch in range(3):
            channels[ch] = channels[ch] + tint[ch]*layer
    # Stack as (channel, row, col), then move the channel axis to the end.
    return np.array(channels).transpose([1, 2, 0])