In [1]:
import glob
import os
import sqlite3
from collections import defaultdict
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import matplotlib
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.ticker import FuncFormatter
import matplotlib.colors as colors
import colgen.display
In [2]:
from utils import DATASETS, format_tree_labels, rack_layout, TREATMENTS
# Directory that holds the survival fits read by this notebook and receives
# the figures it exports.
export_path = "03_survival/"
# exist_ok=True replaces the exists()/makedirs() pair, which could fail if the
# directory appeared between the check and the creation.
os.makedirs(export_path, exist_ok=True)

# Font sizes applied to every figure through matplotlib's rc settings.
SMALL_SIZE = 12
MEDIUM_SIZE = 14
BIGGER_SIZE = 16

plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title
plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize
plt.rc('figure', titlesize=BIGGER_SIZE)  # fontsize of the figure title
In [3]:
# Read the data from colgen's output and add it to the database.
#
# Every "<fit>.csv" matched in export_path is paired with a "<fit>.convergence"
# file. The CSV contributes per-node rows to the `survival_fit` table; the last
# row of the convergence file contributes one record to `survival_convergence`.
fit_frames = []
convergence_records = []
for csv_file in glob.glob(os.path.join(export_path, "*model*csv")):
    stem, _ = os.path.splitext(csv_file)
    fit = os.path.basename(stem)

    fit_frame = pd.read_csv(csv_file).loc[:, ["name", "survival_change", "survival_estimate"]]
    fit_frame['fit'] = fit

    # Only the final row of the convergence trace is stored. Columns are read
    # one by one so each value keeps its own dtype.
    convergence = pd.read_csv(stem + ".convergence", index_col=0)
    record = {
        'likelihood': convergence['likelihood'].iloc[-1],
        'sigma': convergence['sigma'].iloc[-1],
        'fit': fit,
    }

    # The fit name is split on '-'. 'LCS-' is rewritten to 'LCSm' first so the
    # minus sign that belongs to the experiment name is not taken for a
    # separator; it is turned back into '-' once the experiment is rebuilt.
    fields = ['size', 'experiment', 'model', 'grid', 'phase', 'fitname']
    record.update(zip(fields, fit.replace('LCS-', 'LCSm').split('-')))
    record['experiment'] = record.pop('size') + '-' + record['experiment'].replace('m', '-')
    record['model'] = record['model'].replace('model', '')
    record['grid'] = int(record['grid'].replace('grid', ''))
    record['phase'] = int(record['phase'].replace('p', ''))

    fit_frames.append(fit_frame)
    convergence_records.append(record)

# Fail with an explicit message instead of pandas' "No objects to concatenate",
# and before the existing tables are touched.
if not fit_frames:
    raise FileNotFoundError(f"no '*model*csv' fit files found in {export_path!r}")

# sqlite3's connection context manager only commits or rolls back the
# transaction; it does not close the connection, hence the explicit close().
database = sqlite3.connect("lce_data.sqlite")
try:
    with database:
        pd.concat(fit_frames).to_sql("survival_fit", database, index=False, if_exists="replace")
        pd.DataFrame(convergence_records).to_sql("survival_convergence", database, index=False,
                                                 if_exists="replace")
        database.execute('CREATE VIEW IF NOT EXISTS survival AS SELECT * FROM survival_fit'
                         ' JOIN survival_convergence on survival_fit.fit = survival_convergence.fit')
finally:
    database.close()
In [4]:
def make_gradient(data, saturation=1):
    """Build a diverging colour scale centred on zero for ``data.survival_change``.

    Parameters
    ----------
    data : object exposing a ``survival_change`` column with ``min()``/``max()``
        (a pandas DataFrame in this notebook).
    saturation : currently unused; kept so existing calls stay valid.

    Returns
    -------
    (colorize, norm, cmap)
        ``colorize(x)`` maps a survival change to a colour of the
        ``coolwarm_r`` colormap, ``norm`` is the ``TwoSlopeNorm`` used for the
        mapping and ``cmap`` is the colormap itself.
    """
    lowest = data.survival_change.min()
    highest = data.survival_change.max()
    # TwoSlopeNorm requires vmin < vcenter < vmax. When the data do not
    # straddle zero (all positive, all negative, all zero, or empty so that
    # min/max are NaN), substitute a tiny bound on the missing side so the
    # norm can still be built.
    if not highest > 0:
        highest = 0.001
    if not lowest < 0:
        lowest = -0.001

    cmap = matplotlib.colormaps['coolwarm_r']
    norm = matplotlib.colors.TwoSlopeNorm(0, lowest, highest)

    def colorize(value):
        # Use the returned cmap so tree colours and colorbar always agree.
        return cmap(norm(value))

    return colorize, norm, cmap
def load_data(ex, model, grid, sigma):
    """Load and merge the survival fit files of one experiment.

    Files named ``{ex}-model{model}-grid{grid}-*-{sigma}.csv`` are looked up in
    ``export_path`` and concatenated.

    Returns
    -------
    pandas.DataFrame or None
        The merged rows indexed by ``name``, without the rows whose ``parent``
        is missing and with missing ``survival_change`` values set to 0.
        ``None`` (after printing a message) when fewer than two files match.
    """
    pattern = f'{ex}-model{model}-grid{grid}-*-{sigma}.csv'
    # Sorted so the merged row order does not depend on the order in which the
    # file system happens to list the files.
    files = sorted(glob.glob(os.path.join(export_path, pattern)))
    if len(files) < 2:
        print(f"not enough files for {pattern}")
        return None

    merged = pd.concat([pd.read_csv(path) for path in files])
    # Chained, non-in-place operations: nothing mutates a filtered slice.
    return (
        merged[merged.parent.notna()]
        .fillna({'survival_change': 0})
        .set_index('name')
    )
In [5]:
def get_top(df, top=3, bot=1):
    """Return the most extreme survival changes, keyed by ``name``.

    The result is a Series holding, in descending order, up to ``top`` of the
    largest strictly positive ``survival_change`` values followed by up to
    ``bot`` of the most negative ones.
    """
    ranked = (
        df.reset_index()
        .set_index('name')
        .sort_values("survival_change", ascending=False)
    )
    changes = ranked['survival_change']
    gains = changes[changes > 0].head(top)
    losses = changes[changes < 0].tail(bot)
    return pd.concat([gains, losses])
def plot_improvment_on_tree(df, tree, colorize, ax=None, labels=None):
    """Draw the genealogy in ``tree`` coloured by each node's survival change.

    Parameters
    ----------
    df : DataFrame indexed by node name; the ``survival_change``, ``time`` and
        ``extinct`` columns are read here.
    tree : mapping with the ``'branches'`` and ``'xinfo'`` entries handed to
        ``colgen.display``.
    colorize : callable turning a survival change into a colour.
    ax : axes to draw on; created when omitted.
    labels : optional iterable of node names to annotate with the last
        ``'_'``-separated part of their name.

    Returns
    -------
    The axes returned by ``colgen.display.draw_tree``.
    """
    if ax is None:
        # NOTE(review): this creates a 1x2 grid, so `ax` is an array of two
        # Axes rather than a single one - confirm draw_tree accepts that.
        _, ax = plt.subplots(1, 2, figsize=(10, 5))

    survivors = list(df.query('time==20 & extinct==0').index)
    coal = colgen.display.coalescent(survivors, tree['branches'])
    layout = rack_layout(df.reset_index())

    # Nodes returned by coalescent() are drawn opaque, every other node faded.
    opacity = {}
    for name, _row in df.iterrows():
        opacity[name] = 1 if name in coal else 0.3
    branch_color = {}
    for name, row in df.iterrows():
        branch_color[name] = colorize(row.survival_change)

    ax, scales = colgen.display.draw_tree(
        tree['branches'],
        tree['xinfo'],
        layout,
        oinfo=opacity,
        color=branch_color,
        child_color_branch=True,
        ax=ax,
    )
    format_tree_labels(tree['xinfo'].keys(), ax, scales)

    if labels is None:
        return ax

    for node in labels:
        tube = node.split('_')[-1]
        # Label placed 5 units above the node, in the node's own colour.
        ax.text(
            scales[0](node),
            scales[1](node) + 5,
            f"{tube}",
            color=colorize(df.loc[node].survival_change),
            horizontalalignment='center',
            verticalalignment='bottom',
            font={'weight': 'bold'},
        )
    return ax
def plot_one_tree(ax, ex, model, sigma, grid, colorize=None, norm=None, cmap=None, cbar=False, ):
    """Plot the survival-change tree of one experiment on ``ax``.

    Parameters
    ----------
    ax : axes to draw on.
    ex, model, sigma, grid : select the fit files (see ``load_data``); ``ex``
        also names the genealogy file ``01_genealogies/{ex}.csv`` and is used
        as the axes title.
    colorize, norm, cmap : optional colour scale. Any piece left as ``None`` is
        taken from ``make_gradient(data)``; pass all three together to impose a
        consistent external scale.
    cbar : when true, a "Survival Change" colorbar is appended to the right of
        the axes.

    Returns
    -------
    None. Nothing is drawn when ``load_data`` finds no data.
    """
    data = load_data(ex, model, grid, sigma)
    # Check for missing data *before* deriving the gradient: make_gradient
    # reads data.survival_change and would fail on None.
    if data is None:
        return

    # Honour a caller-supplied colour scale instead of overwriting it.
    auto_colorize, auto_norm, auto_cmap = make_gradient(data)
    if colorize is None:
        colorize = auto_colorize
    if norm is None:
        norm = auto_norm
    if cmap is None:
        cmap = auto_cmap

    top = get_top(data, 7, 1)
    tree, _ = colgen.display.load_df(f"01_genealogies/{ex}.csv")
    plot_improvment_on_tree(data, tree, colorize,
                            ax=ax, labels=top.index)
    ax.set_title(ex, font={'weight': 'bold'})

    if cbar:
        divider = make_axes_locatable(ax)
        cbaxes = divider.append_axes('right', size='2%', pad=0.05)

        def as_percent(x, pos):
            return '{:.1%}'.format(x)

        # Separate name so the boolean `cbar` flag is not shadowed.
        colorbar = matplotlib.colorbar.ColorbarBase(cbaxes, cmap=cmap, norm=norm,
                                                    orientation='vertical',
                                                    label="Survival Change")
        colorbar.ax.set_yscale('linear')
        colorbar.ax.yaxis.set_major_formatter(FuncFormatter(as_percent))
        colorbar.minorticks_on()
def plot_all(model, sigma, grid):
    """Draw the survival-change trees of the four experiments on a 2x2 figure.

    Panel layout: S-LCS+ (top left), S-LCS- (top right), L-LCS+ (bottom left),
    L-LCS- (bottom right). Each panel gets its own colorbar.

    Returns
    -------
    (fig, ax)
        The figure and the dict of axes keyed by panel letter ``'A'``-``'D'``.
    """
    fig = plt.figure(layout="constrained", figsize=(2*12*1.5, 2*9*1.5))
    ax = fig.subplot_mosaic("AB\nCD", gridspec_kw=dict(wspace=0.05))
    # Title typo fixed ("Surival" -> "Survival"); it ends up in the exported PDF.
    fig.suptitle(f"Colgen Survival Procedure Output (Model:{model}, Grid {grid})")
    panels = [("L-LCS-", 'D'), ("S-LCS-", 'B'), ("L-LCS+", 'C'), ("S-LCS+", 'A')]
    for ex, a in panels:
        plot_one_tree(ax[a], ex=ex, model=model, sigma=sigma, grid=grid, cbar=True)
    return fig, ax
In [6]:
# Fit selection used for the overview figure and for the per-experiment
# figures that follow.
MODEL = 'beta'
SIGMA = "1500"  # kept as text: it is inserted into the fit file-name pattern
GRID = 100

fig, ax = plot_all(model=MODEL, sigma=SIGMA, grid=GRID)
fig.savefig(
    os.path.join(export_path, 'all_improvement.pdf'),
    dpi=200,
    bbox_inches='tight',
)
In [7]:
# One standalone figure per experiment. The four copy-pasted cells differed
# only in the experiment name and the output file stem, so they are driven by
# a single table instead.
for ex, stem in [("L-LCS-", "llcsm"),
                 ("S-LCS-", "slcsm"),
                 ("L-LCS+", "llcsp"),
                 ("S-LCS+", "slcsp")]:
    fig, ax = plt.subplots(1, 1, figsize=(12*1.5, 9*1.5))
    plot_one_tree(ax, ex, MODEL, SIGMA, GRID, cbar=True)
    fig.savefig(os.path.join(export_path, f'{stem}_improvement.pdf'), bbox_inches='tight')
In [11]:
# Same 2x2 overview, this time for the 'jump' model fits.
fig, ax = plot_all(model='jump', sigma='01', grid=100)
fig.savefig(
    os.path.join(export_path, 'all_improvement_jump.pdf'),
    dpi=200,
    bbox_inches='tight',
)
In [ ]:
In [ ]: