In [1]:
import glob 
import os 
import sqlite3
from collections import defaultdict

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np 
import matplotlib
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib.ticker import FuncFormatter
import matplotlib.colors as colors
import colgen.display
In [2]:
from utils import DATASETS, format_tree_labels, rack_layout, TREATMENTS

# Directory holding colgen's survival-fit CSVs and the figures exported by this notebook.
export_path = "03_survival/"
# exist_ok=True avoids the race of a separate os.path.exists() check.
os.makedirs(export_path, exist_ok=True)

SMALL_SIZE = 12
MEDIUM_SIZE = 14
BIGGER_SIZE = 16

# Global matplotlib font sizes applied to every figure in this notebook.
plt.rcParams.update({
    'font.size': SMALL_SIZE,          # default text size
    'axes.titlesize': SMALL_SIZE,     # axes title
    'axes.labelsize': MEDIUM_SIZE,    # x and y axis labels
    'xtick.labelsize': SMALL_SIZE,    # x tick labels
    'ytick.labelsize': SMALL_SIZE,    # y tick labels
    'legend.fontsize': SMALL_SIZE,    # legend text
    'figure.titlesize': BIGGER_SIZE,  # figure suptitle
})
In [3]:
# Read the survival fits written by colgen into export_path and store them in the
# SQLite database: per-lineage estimates go to table `survival_fit`, the final
# convergence state of every fit goes to table `survival_convergence`.
from contextlib import closing


def _parse_fit_name(fit):
    """Decode the metadata fields of a fit file stem.

    The stem is split on '-' into size, experiment, model, grid, phase and
    fitname; e.g. a stem such as ``L-LCS--modelbeta-grid100-p1-1500`` gives
    experiment ``L-LCS-``, model ``beta``, grid 100, phase 1, fitname ``1500``.
    The experiment name ``LCS-`` itself ends in '-', so it is temporarily
    encoded as ``LCSm`` to survive the split and decoded afterwards.

    Returns a dict with keys experiment, model, grid (int), phase (int) and
    fitname, in that order.
    """
    fields = dict(zip(['size', 'experiment', 'model', 'grid', 'phase', 'fitname'],
                      fit.replace('LCS-', 'LCSm').split('-')))
    size = fields.pop('size')
    fields['experiment'] = size + '-' + fields['experiment'].replace('m', '-')
    fields['model'] = fields['model'].replace('model', '')
    fields['grid'] = int(fields['grid'].replace('grid', ''))
    fields['phase'] = int(fields['phase'].replace('p', ''))
    return fields


f_all = []
cv_all = []
# sorted() makes the row order of both tables independent of filesystem order.
for file in sorted(glob.glob(os.path.join(export_path, "*model*csv"))):
    stem = os.path.splitext(file)[0]
    fit = os.path.basename(stem)
    f = pd.read_csv(file).loc[:, ["name", "survival_change", "survival_estimate"]]
    f['fit'] = fit
    # Only the last row (final iteration) of the convergence trace is kept.
    convergence = pd.read_csv(stem + ".convergence", index_col=0)
    cv = {'likelihood': convergence['likelihood'].iloc[-1],
          'sigma': convergence['sigma'].iloc[-1],
          'fit': fit}
    cv.update(_parse_fit_name(fit))
    f_all.append(f)
    cv_all.append(cv)

if not f_all:
    raise FileNotFoundError(f"no '*model*csv' survival fits found in {export_path!r}")

# sqlite3's connection context manager only commits/rolls back the transaction;
# it does not close the connection, hence the outer closing().
with closing(sqlite3.connect("lce_data.sqlite")) as database:
    with database:
        pd.concat(f_all).to_sql("survival_fit", database, index=False, if_exists="replace")
        pd.DataFrame(cv_all).to_sql("survival_convergence", database, index=False, if_exists="replace")
        # NOTE(review): both tables have a `fit` column, so `SELECT *` exposes it twice
        # in the view; `JOIN ... USING (fit)` would keep a single copy.
        database.execute('CREATE VIEW IF NOT EXISTS survival AS SELECT * FROM survival_fit'
                         ' JOIN survival_convergence on survival_fit.fit = survival_convergence.fit')
In [4]:
def make_gradient(data, saturation=1):
    """Build a diverging colour scale for ``data.survival_change`` centred on 0.

    Parameters
    ----------
    data : DataFrame
        Must have a ``survival_change`` column.
    saturation : float, optional
        Currently unused; kept for backward compatibility.

    Returns
    -------
    (colorize, norm, cmap)
        ``norm`` is a ``TwoSlopeNorm`` with vcenter 0 spanning the data's
        min/max, ``cmap`` is the reversed coolwarm colormap, and
        ``colorize(x)`` returns ``cmap(norm(x))`` (an RGBA colour).
    """
    mi = data.survival_change.min()
    mx = data.survival_change.max()
    # TwoSlopeNorm requires vmin < vcenter (0) < vmax. When every change has the
    # same sign (or is 0), push the offending limit just past 0 so the norm is valid.
    if mx <= 0:
        mx = 0.001
    if mi >= 0:
        mi = -0.001
    cmap = matplotlib.colormaps['coolwarm_r']
    norm = matplotlib.colors.TwoSlopeNorm(vcenter=0, vmin=mi, vmax=mx)
    # Use the returned colormap so tree colours and any colorbar stay consistent.
    colorize = lambda x: cmap(norm(x))
    return colorize, norm, cmap

def load_data(ex, model, grid, sigma, min_files=2):
    """Load and concatenate the survival-fit CSVs of one experiment.

    Files are matched in ``export_path`` with the glob pattern
    ``{ex}-model{model}-grid{grid}-*-{sigma}.csv`` (the wildcard presumably
    spans the fit phases).

    Parameters
    ----------
    ex, model, grid, sigma
        Components of the file-name pattern above.
    min_files : int, optional
        Minimum number of matching files required (default 2).

    Returns
    -------
    DataFrame or None
        Only rows with a non-missing ``parent``, missing ``survival_change``
        filled with 0, indexed by ``name``. ``None`` (after printing a message)
        when fewer than ``min_files`` files match.
    """
    pattern = f'{ex}-model{model}-grid{grid}-*-{sigma}.csv'
    # sorted() makes the concatenation order independent of filesystem order.
    files = sorted(glob.glob(os.path.join(export_path, pattern)))
    if len(files) < min_files:
        print(f"not enough files for {pattern}")
        return None
    data = pd.concat([pd.read_csv(x) for x in files])
    return (data[data['parent'].notna()]
            .fillna({'survival_change': 0})
            .set_index('name'))
In [5]:
def get_top(df, top=3, bot=1):
    """Pick the lineages with the largest survival gains and losses.

    Returns a Series of ``survival_change`` indexed by ``name``: the ``top``
    largest positive changes (largest first) followed by the ``bot`` most
    negative changes (most negative last). Zero changes are never selected.
    """
    by_name = df.reset_index().set_index('name')
    ranked = by_name.sort_values("survival_change", ascending=False)["survival_change"]
    selected = [ranked[ranked > 0].head(top), ranked[ranked < 0].tail(bot)]
    return pd.concat(selected)

def plot_improvment_on_tree(df, tree, colorize, ax=None, labels=None):
    """Draw a genealogy with every lineage coloured by its survival change.

    Parameters
    ----------
    df : DataFrame
        Indexed by lineage name; uses the ``survival_change`` column and, in the
        ``query`` below, the ``time`` and ``extinct`` columns.
    tree : dict
        Output of ``colgen.display.load_df``; its ``'branches'`` and ``'xinfo'``
        entries are used.
    colorize : callable
        Maps a survival change to a matplotlib colour.
    ax : Axes, optional
        Axes to draw on; a new 10x5 inch single-axes figure is created if omitted.
    labels : iterable of str, optional
        Lineage names to annotate with the last ``_``-separated part of the name,
        written in bold in the lineage's colour.

    Returns
    -------
    The axes returned by ``colgen.display.draw_tree``.
    """
    if ax is None:
        # draw_tree takes a single Axes (not the array plt.subplots(1, 2) returns).
        _, ax = plt.subplots(figsize=(10, 5))
    # Lineages in `coal` get full opacity, the rest 0.3. NOTE(review): presumably
    # `coal` holds the ancestors of lineages alive (extinct==0) at time 20.
    coal = colgen.display.coalescent(list(df.query('time==20 & extinct==0').index),
                                     tree['branches'])
    ax, scales = colgen.display.draw_tree(
        tree['branches'],
        tree['xinfo'],
        rack_layout(df.reset_index()),
        oinfo={name: 1 if name in coal else 0.3 for name in df.index},
        color={name: colorize(change) for name, change in df['survival_change'].items()},
        child_color_branch=True,
        ax=ax)
    format_tree_labels(tree['xinfo'].keys(), ax, scales)
    if labels is not None:
        for k in labels:
            tube = k.split('_')[-1]
            # NOTE(review): scales[0]/scales[1] presumably map a lineage name to its
            # x/y position; the label sits 5 units above the node.
            ax.text(scales[0](k), scales[1](k) + 5, tube,
                    color=colorize(df.loc[k].survival_change),
                    horizontalalignment='center',
                    verticalalignment='bottom',
                    font={'weight': 'bold'})
    return ax

def plot_one_tree(ax, ex, model, sigma, grid, colorize=None, norm=None, cmap=None, cbar=False):
    """Plot one experiment's genealogy coloured by survival change.

    Loads the fit with ``load_data`` and the genealogy from
    ``01_genealogies/{ex}.csv``. Lineages are coloured with ``make_gradient``,
    and the 7 largest gains plus the largest loss are labelled. Optionally a
    percentage colorbar is added to the right of ``ax``.

    Parameters
    ----------
    ax : Axes
        Axes to draw on.
    ex, model, sigma, grid
        Passed to ``load_data`` to select the fit files.
    colorize, norm, cmap
        Ignored: the colour scale is always recomputed from the loaded data
        (kept for backward compatibility).
    cbar : bool, optional
        Whether to add a colorbar.

    Returns
    -------
    The axes, or ``None`` when ``load_data`` found no data (nothing is drawn).
    """
    data = load_data(ex, model, grid, sigma)
    # Must be checked before make_gradient, which cannot handle None.
    if data is None:
        return None
    colorize, norm, cmap = make_gradient(data)
    top = get_top(data, 7, 1)

    tree, _ = colgen.display.load_df(f"01_genealogies/{ex}.csv")

    plot_improvment_on_tree(data, tree, colorize, ax=ax, labels=top.index)
    ax.set_title(ex, font={'weight': 'bold'})

    if cbar:
        # Thin (2% wide) colorbar axes attached to the right edge of ax.
        cbaxes = make_axes_locatable(ax).append_axes('right', size='2%', pad=0.05)
        colorbar = matplotlib.colorbar.Colorbar(cbaxes, cmap=cmap, norm=norm,
                                                orientation='vertical',
                                                label="Survival Change")
        # NOTE(review): presumably makes tick spacing linear in survival change
        # rather than following the two-slope norm.
        colorbar.ax.set_yscale('linear')
        colorbar.ax.yaxis.set_major_formatter(FuncFormatter(lambda x, pos: '{:.1%}'.format(x)))
        colorbar.minorticks_on()
    return ax


def plot_all(model, sigma, grid):
    """Plot the survival-change trees of all four experiments in one 2x2 figure.

    Panel layout: A = S-LCS+, B = S-LCS-, C = L-LCS+, D = L-LCS-. Each panel
    gets its own colorbar, because ``plot_one_tree`` builds the colour scale
    per experiment.

    Returns
    -------
    (fig, axd)
        The figure and the dict mapping panel letters to their Axes.
    """
    fig = plt.figure(layout="constrained", figsize=(2 * 12 * 1.5, 2 * 9 * 1.5))
    axd = fig.subplot_mosaic([["A", "B"],
                              ["C", "D"]], gridspec_kw=dict(wspace=0.05))
    fig.suptitle(f"Colgen Survival Procedure Output (Model:{model}, Grid {grid})")
    for ex, panel in [("L-LCS-", 'D'), ("S-LCS-", 'B'), ("L-LCS+", 'C'), ("S-LCS+", 'A')]:
        plot_one_tree(axd[panel], ex=ex, model=model, sigma=sigma, grid=grid, cbar=True)
    return fig, axd
In [6]:
# Fit selection used by the overview and the single-tree figures below.
MODEL, SIGMA, GRID = 'beta', "1500", 100

# Overview: all four experiments for the selected model, sigma and grid.
fig, ax = plot_all(model=MODEL, sigma=SIGMA, grid=GRID)
overview_pdf = os.path.join(export_path, 'all_improvement.pdf')
fig.savefig(overview_pdf, dpi=200, bbox_inches='tight')
No description has been provided for this image
In [7]:
# Single-panel version of the L-LCS- tree.
fig, ax = plt.subplots(figsize=(12 * 1.5, 9 * 1.5))
plot_one_tree(ax, ex="L-LCS-", model=MODEL, sigma=SIGMA, grid=GRID, cbar=True)
fig.savefig(os.path.join(export_path, 'llcsm_improvement.pdf'), bbox_inches='tight')
No description has been provided for this image
In [8]:
# Single-panel version of the S-LCS- tree.
fig, ax = plt.subplots(figsize=(12 * 1.5, 9 * 1.5))
plot_one_tree(ax, ex="S-LCS-", model=MODEL, sigma=SIGMA, grid=GRID, cbar=True)
fig.savefig(os.path.join(export_path, 'slcsm_improvement.pdf'), bbox_inches='tight')
No description has been provided for this image
In [9]:
# Single-panel version of the L-LCS+ tree.
fig, ax = plt.subplots(figsize=(12 * 1.5, 9 * 1.5))
plot_one_tree(ax, ex="L-LCS+", model=MODEL, sigma=SIGMA, grid=GRID, cbar=True)
fig.savefig(os.path.join(export_path, 'llcsp_improvement.pdf'), bbox_inches='tight')
No description has been provided for this image
In [10]:
# Single-panel version of the S-LCS+ tree.
fig, ax = plt.subplots(figsize=(12 * 1.5, 9 * 1.5))
plot_one_tree(ax, ex="S-LCS+", model=MODEL, sigma=SIGMA, grid=GRID, cbar=True)
fig.savefig(os.path.join(export_path, 'slcsp_improvement.pdf'), bbox_inches='tight')
No description has been provided for this image
In [11]:
# Same four-panel overview for the 'jump' model; sigma '01' is the value that
# appears in that model's fit file names.
fig, ax = plot_all(model='jump', sigma='01', grid=100)
jump_pdf = os.path.join(export_path, 'all_improvement_jump.pdf')
fig.savefig(jump_pdf, dpi=200, bbox_inches='tight')
No description has been provided for this image
In [ ]:
 
In [ ]: