Sunday, May 22, 2022

Parallel version of fitMk -- with a catch

Recently, a phytools user sent me the following message:

Hope you are doing well. Question for you - I'm fitting an Mk model using the fitMk function and then running make.simmap in parallel following your helpful code. Is there a way to do something similar to speedup the fitMk function? I do not see one but I wanted to make sure. Fitting an ARD model (e.g. fit.ARD<-fitMk(phy,x,model="ARD")) with a tree of around 130 species and a number of discrete states still takes a while, even on my 10 core laptop.

First of all – 10 core laptop?! I want it.

Second, in response I did write a parallel version of fitMk using the CRAN package optimParallel. optimParallel runs a parallelized version of stats::optim for the box-constrained ("L-BFGS-B") method.
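
For anyone who hasn't used optimParallel before, here is a minimal sketch of the basic pattern on a toy problem that has nothing to do with phylogenies. The simulated data, the negll function, and the choice of 2 cores are all just assumptions for illustration. As I read the package documentation, each L-BFGS-B step evaluates the function and a finite-difference approximation of its gradient in parallel, which amounts to 1+2p function evaluations for p parameters.

```r
## toy example: estimate the mean & sd of a normal sample by ML
library(parallel)
library(optimParallel)
## simulate data (values are arbitrary)
set.seed(13)
x<-rnorm(1000,mean=5,sd=2)
## negative log-likelihood to be minimized
negll<-function(par,x) -sum(dnorm(x,mean=par[1],sd=par[2],log=TRUE))
## create cluster & register it as the default
cl<-makeCluster(2)
setDefaultCluster(cl=cl)
## optimize; additional arguments (here x) are passed on to negll
fit<-optimParallel(par=c(1,1),fn=negll,x=x,lower=c(-Inf,1e-4))
## stop cluster
stopCluster(cl)
fit$par
```

The returned object has the same structure as the output of stats::optim, so fit$par and fit$value can be used in the usual way.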

Unfortunately, my function does not consistently run faster than fitMk when the fitMk default optimization method ("nlminb") is used. To be honest, I don't know why, because I've verified that it is using all the cores on my machine.

Here's the function:

fitMk.parallel<-function(tree,x,model="SYM",ncores=1,...){
    ## compute states
    ss<-sort(unique(x))
    m<-length(ss)
    ## set pi
    if(hasArg(pi)) pi<-list(...)$pi
    else pi<-"equal"
    if(is.numeric(pi)) root.prior<-"given"
    if(pi[1]=="equal"){ 
        pi<-setNames(rep(1/m,m),ss)
        root.prior<-"flat"
    } else if(pi[1]=="fitzjohn") root.prior<-"nuisance"
    if(is.numeric(pi)){ 
        pi<-pi/sum(pi)
        if(is.null(names(pi))) pi<-setNames(pi,ss)
        pi<-pi[ss]
    } 
    ## create object of class "fitMk"
    unfitted<-fitMk(tree,x,model=model,pi=pi,opt.method="none")
    ## get initial values for optimization
    if(hasArg(q.init)) q.init<-list(...)$q.init
    else q.init<-rep(m/sum(tree$edge.length),
        max(unfitted$index.matrix,na.rm=TRUE))
    ## create likelihood function
    loglik<-function(par,lik,index.matrix){
        Q<-phytools:::makeQ(nrow(index.matrix),exp(par),index.matrix)
        -lik(Q)
    }
    ## create cluster
    cl<-makeCluster(ncores)
    setDefaultCluster(cl=cl)
    ## optimize model
    fit.parallel<-optimParallel(
        log(q.init),
        loglik,lik=unfitted$lik,
        index.matrix=unfitted$index.matrix,
        lower=log(1e-12),
        upper=log(max(nodeHeights(tree))*100)
    )
    ## stop cluster
    stopCluster(cl=cl)
    ## get Q matrix
    estQ<-phytools:::makeQ(nrow(unfitted$index.matrix),
        exp(fit.parallel$par),
        unfitted$index.matrix)
    colnames(estQ)<-rownames(estQ)<-ss
    ## get object
    object<-fitMk(tree,x,fixedQ=estQ,pi=pi)
    object$method<-"optimParallel"
    object
}
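
Note that the function optimizes on the log scale: the optimizer proposes log-rates, and these are exponentiated and dropped into Q according to the index matrix before the likelihood is computed. The following is a rough base R sketch of that mapping. make_q is a hypothetical stand-in written for this example, not the internal phytools:::makeQ itself, and the layout of the index matrix for a two-state "ARD" model is my assumption.

```r
## hypothetical stand-in for the internal function phytools:::makeQ
make_q<-function(rates,index.matrix){
    m<-nrow(index.matrix)
    Q<-matrix(0,m,m,dimnames=dimnames(index.matrix))
    ## cells of Q that have been assigned a rate parameter
    ii<-!is.na(index.matrix)&index.matrix>0
    Q[ii]<-rates[index.matrix[ii]]
    ## each row of Q must sum to zero
    diag(Q)<-0
    diag(Q)<- -rowSums(Q)
    Q
}
## assumed index matrix for a two-state "ARD" model
ss<-c("non-saxicolous","saxicolous")
index.matrix<-matrix(c(NA,1,2,NA),2,2,dimnames=list(ss,ss))
## parameters as the optimizer sees them (log scale)
log.q<-log(c(0.022388,0.001552))
Q<-make_q(exp(log.q),index.matrix)
Q
rowSums(Q)
```

Working on the log scale keeps every rate positive, which is why the lower and upper bounds in the function above are also given as logs.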

Let's load packages & pull down a dataset to try with our function. As in a prior post, I'll use data from Ramm et al. (2019), which is slated to feature in my upcoming book with Luke Harmon.

## load packages
library(optimParallel)
library(phytools)
library(geiger)
## read tree and data from file
lizard.tree<-read.nexus(file=
    "http://www.phytools.org/Rbook/7/lizard_tree.nex")
lizard.tree
## 
## Phylogenetic tree with 4162 tips and 4161 internal nodes.
## 
## Tip labels:
##   Sphenodon_punctatus, Dibamus_bourreti, Dibamus_greeri, Dibamus_montanus, Anelytropsis_papillosus, Dibamus_tiomanensis, ...
## 
## Rooted; includes branch lengths.
## read data from file
lizard.data<-read.csv(file=
    "http://www.phytools.org/Rbook/7/lizard_spines.csv",
    stringsAsFactors=TRUE,row.names=1)
head(lizard.data)
##                                   habitat tail.spines
## Ablepharus_budaki          non-saxicolous   non-spiny
## Abronia_anzuetoi           non-saxicolous   non-spiny
## Abronia_frosti             non-saxicolous   non-spiny
## Acanthodactylus_aureus     non-saxicolous   non-spiny
## Acanthodactylus_opheodurus non-saxicolous   non-spiny
## Acanthodactylus_pardalis   non-saxicolous   non-spiny
## check names
chk<-name.check(lizard.tree,lizard.data)
summary(chk)
## 3504 taxa are present in the tree but not the data:
##     Ablepharus_chernovi,
##     Ablepharus_kitaibelii,
##     Ablepharus_pannonicus,
##     Abronia_aurita,
##     Abronia_campbelli,
##     Abronia_chiszari,
##     ....
## 
## To see complete list of mis-matched taxa, print object.
## prune tree to include only taxa in data
lizard.tree<-drop.tip(lizard.tree,chk$tree_not_data)
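
The matching & pruning step above can also be sketched using only ape and base R, which makes it easier to see what is going on. The tree and data frame below are invented for illustration (a random 6-taxon tree and four rows of made-up data), so none of the values mean anything.

```r
library(ape)
## toy tree with tips labeled t1 through t6 & made-up data
set.seed(1)
tree<-rtree(6)
dat<-data.frame(
    habitat=factor(c("saxicolous","non-saxicolous",
        "non-saxicolous","saxicolous")),
    row.names=c("t1","t2","t3","t9"))
## taxa in the tree but not the data, and the converse
tree_not_data<-setdiff(tree$tip.label,rownames(dat))
data_not_tree<-setdiff(rownames(dat),tree$tip.label)
## prune the tree, then subset & reorder the data to match
pruned<-drop.tip(tree,tree_not_data)
dat<-dat[pruned$tip.label,,drop=FALSE]
## check that tree & data now contain the same taxa
setequal(pruned$tip.label,rownames(dat))
```

In this sketch the data are also subset to the tips of the pruned tree, which matters only if some rows of the data have no match in the tree.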

This dataset consists of trait data for two different binary characters.

For fun, let's use this (admittedly hacky) solution to plot the two characters at the tips of a fan-style tree.

plotTree(lizard.tree,type="fan",lwd=1,ftype="off")
obj<-get("last_plot.phylo",envir=.PlotPhyloEnv)
n<-Ntip(lizard.tree)
col.hab<-setNames(c("darkgreen","grey"),
    levels(lizard.data$habitat))
col.morph<-setNames(palette()[c(4,2)],
    levels(lizard.data$tail.spines))
par(lend=3)
for(i in 1:n){
    cc<-if(obj$xx[i]>0) 5 else -5
    th<-atan(obj$yy[i]/obj$xx[i])
    segments(obj$xx[i],obj$yy[i],
        obj$xx[i]+cc*cos(th),
        obj$yy[i]+cc*sin(th),
        lwd=4,
        col=col.hab[lizard.data[lizard.tree$tip.label[i],
        "habitat"]])
    segments(obj$xx[i]+cc*cos(th),
        obj$yy[i]+cc*sin(th),
        obj$xx[i]+2*cc*cos(th),
        obj$yy[i]+2*cc*sin(th),lwd=4,
        col=col.morph[lizard.data[lizard.tree$tip.label[i],
        "tail.spines"]])
}
legend("topleft",c(levels(lizard.data$habitat),
    levels(lizard.data$tail.spines)),
    pch=15,col=c(col.hab,col.morph),
    pt.cex=1.5,cex=0.8,bty="n")

[Figure: fan-style lizard tree with habitat and tail spine states plotted at the tips]

We'll just analyze one trait here. Let's pull out habitat as a new vector.

## extract discrete character
habitat<-setNames(lizard.data$habitat,
    rownames(lizard.data))
head(habitat)
##          Ablepharus_budaki           Abronia_anzuetoi 
##             non-saxicolous             non-saxicolous 
##             Abronia_frosti     Acanthodactylus_aureus 
##             non-saxicolous             non-saxicolous 
## Acanthodactylus_opheodurus   Acanthodactylus_pardalis 
##             non-saxicolous             non-saxicolous 
## Levels: non-saxicolous saxicolous

Ready.

Now, let's run our function! We can use base::system.time to keep track of how long it takes.

system.time(
    fit.optimParallel<-fitMk.parallel(lizard.tree,habitat,
        model="ARD",pi="fitzjohn",ncores=8)
)
##    user  system elapsed 
##    0.90    0.06    8.99
fit.optimParallel
## Object of class "fitMk".
## 
## Fitted (or set) value of Q:
##                non-saxicolous saxicolous
## non-saxicolous      -0.001552   0.001552
## saxicolous           0.022388  -0.022388
## 
## Fitted (or set) value of pi:
## non-saxicolous     saxicolous 
##        0.03763        0.96237 
## due to treating the root prior as (a) nuisance.
## 
## Log-likelihood: -216.724702 
## 
## Optimization method used was "optimParallel"

Awesome. It works!

Only there's a catch….

system.time(
    fit.nlminb<-fitMk(lizard.tree,habitat,model="ARD",
        pi="fitzjohn")
)
##    user  system elapsed 
##   10.67    0.11   10.81
fit.nlminb
## Object of class "fitMk".
## 
## Fitted (or set) value of Q:
##                non-saxicolous saxicolous
## non-saxicolous      -0.001552   0.001552
## saxicolous           0.022388  -0.022388
## 
## Fitted (or set) value of pi:
## non-saxicolous     saxicolous 
##        0.03763        0.96237 
## due to treating the root prior as (a) nuisance.
## 
## Log-likelihood: -216.724702 
## 
## Optimization method used was "nlminb"

Serial fitMk takes about the same amount of time (10.81 seconds elapsed, compared to 8.99 for fitMk.parallel on 8 cores), and I could tell that fitMk.parallel was using all my CPU, because I checked system monitoring and even a text editor became laggy.
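
One thing to keep in mind is that this comparison changes two things at once: the optimizer (L-BFGS-B instead of nlminb) and whether or not it runs in parallel. A like-for-like serial baseline would be fitMk with opt.method="optim", which, as far as I know, calls stats::optim with method "L-BFGS-B". Below is a rough sketch of that comparison on simulated data. The tree size, the generating Q matrix, and the seed are arbitrary assumptions, and timings will vary from machine to machine, so I haven't shown any.

```r
library(phytools)
## simulate a tree & a binary character (all values arbitrary)
set.seed(99)
tree<-pbtree(n=200,scale=1)
Q<-matrix(c(-1,1,1,-1),2,2,
    dimnames=list(c("a","b"),c("a","b")))
x<-sim.Mk(tree,Q)
## serial fit using the default optimizer
t.nlminb<-system.time(
    fit.nlminb<-fitMk(tree,x,model="ARD",opt.method="nlminb")
)
## serial fit using stats::optim
t.optim<-system.time(
    fit.optim<-fitMk(tree,x,model="ARD",opt.method="optim")
)
## compare elapsed times & log-likelihoods
c(nlminb=t.nlminb[["elapsed"]],optim=t.optim[["elapsed"]])
c(nlminb=fit.nlminb$logLik,optim=fit.optim$logLik)
```

If the two serial fits differ a lot in running time, then some of the difference between fitMk.parallel and fitMk may be due to the optimizer rather than to parallelization.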

Obviously, this is not a satisfactory solution. Maybe someone who reads this blog can tell me how to make it better…?

Before you respond that this can be sped up by using diversitree, let me say that I already thought of that & will make it the subject of a subsequent post.
