Monday, October 17, 2022

Parallelized optimization now an option in fitMk, plus the advantage of optimizing on logarithmic scale

I recently developed a parallelized version of the phytools function fitMk for fitting the extended Mk model for the evolution of a discrete character on a phylogeny.

I originally built this as a separate function because I wasn't quite sure how it would work; however, yesterday I decided to add opt.method="optimParallel" as a choice of optimization method within phytools::fitMk, while also leaving fitMk.parallel in the phytools namespace.

In addition to parallelization, one thing that fitMk.parallel also did by default was optimize on a log rather than a linear scale. I found that this improved convergence quite a bit; however, optimizing on a log scale was already possible in fitMk (via the argument logscale=TRUE).
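To make the idea of optimizing on a log scale concrete, here is a minimal sketch using a toy problem rather than an Mk model. It estimates a single positive rate by maximum likelihood, but hands the optimizer log(rate) and back-transforms inside the likelihood function. The simulated data, the function name negLogLik, and the starting value are all made up for illustration; this is not the code inside fitMk.

```r
## toy example (not an Mk model): estimate a positive rate by ML,
## optimizing over log(rate) instead of rate itself
set.seed(1)
x<-rexp(200,rate=0.02)
## negative log-likelihood written as a function of log(rate)
negLogLik<-function(log_rate,x){
    rate<-exp(log_rate) ## back-transform, so rate is always > 0
    -sum(dexp(x,rate=rate,log=TRUE))
}
## arbitrary starting value, given on the log scale
fit<-optim(par=log(0.1),fn=negLogLik,x=x,method="BFGS")
rate_hat<-exp(fit$par)
## closed-form MLE for an exponential rate, for comparison
rate_mle<-1/mean(x)
c(optimized=rate_hat,closed_form=rate_mle)
```

Because the optimizer works with log(rate), it can never propose a negative rate, and rates that differ by orders of magnitude sit on a comparable scale. That is probably a large part of why the same trick helps with the many small transition rates in Q.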

Just to see how these different implementations work, here is a quick demo of fitMk with logscale=FALSE, fitMk with logscale=TRUE, fitMk.parallel (by running fitMk and setting opt.method="optimParallel"), and geiger::fitDiscrete. For the example, I will use the subsampled tree and dataset for digit number evolution in squamate reptiles that Luke Harmon & I feature in our Princeton University Press book, but that originally derives from Brandley et al. (2008).

library(phytools)
packageVersion("phytools")
## [1] '1.2.6'
library(geiger)
## read tree and data from file
sqTree<-read.nexus(file="http://www.phytools.org/Rbook/6/squamate.tre")
sqData<-read.csv(file="http://www.phytools.org/Rbook/6/squamate-data.csv",
    row.names=1)
## pull out trait vector
toes<-as.factor(setNames(sqData[[1]],rownames(sqData)))
## subsample tree and data to match
chk<-name.check(sqTree,sqData)
summary(chk)
## 139 taxa are present in the tree but not the data:
##     Abronia_graminea,
##     Acontias_litoralis,
##     Acontophiops_lineatus,
##     Acrochordus_granulatus,
##     Agamodon_anguliceps,
##     Agkistrodon_contortrix,
##     ....
## 1 taxon is present in the data but not the tree:
##     Trachyboa_boulengeri
## 
## To see complete list of mis-matched taxa, print object.
sqTree<-drop.tip(sqTree,chk$tree_not_data)
toes<-toes[sqTree$tip.label]

Now let's fit our model four times, using each of the aforementioned approaches. I'm going to use the base R function system.time so we can see the timings for each of our optimizations.
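As a quick reminder of what system.time returns, here is a tiny sketch using a stand-in call (Sys.sleep) in place of a real model fit. The object name timing is just for illustration.

```r
## time a stand-in expression that simply waits about 0.2 seconds
timing<-system.time(Sys.sleep(0.2))
## the result is a named numeric vector of class "proc_time"
names(timing)
## wall-clock time in seconds
timing[["elapsed"]]
```

The elapsed value is wall-clock time, so it is the one to compare across runs. When the work is farmed out to separate worker processes, the user time reported for the main R session can be very small even though the elapsed time is not.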

## fitMk, default settings
system.time(
    fitARD.1<-fitMk(sqTree,toes,model="ARD",pi="fitzjohn")
)
##    user  system elapsed 
##  196.18    2.59  199.50
fitARD.1
## Object of class "fitMk".
## 
## Fitted (or set) value of Q:
##          0         1         2         3         4         5
## 0 0.000000  0.000000  0.000000  0.000000  0.000000  0.000000
## 1 0.013927 -0.013927  0.000000  0.000000  0.000000  0.000000
## 2 0.009505  0.000000 -0.011555  0.002050  0.000000  0.000000
## 3 0.008962  0.000000  0.008875 -0.017837  0.000000  0.000000
## 4 0.000234  0.002032  0.000000  0.011390 -0.033613  0.019957
## 5 0.000000  0.001998  0.000462  0.000000  0.001811 -0.004271
## 
## Fitted (or set) value of pi:
##        0        1        2        3        4        5 
## 0.000000 0.000000 0.000000 0.000000 0.260594 0.739406 
## due to treating the root prior as (a) nuisance.
## 
## Log-likelihood: -110.823838 
## 
## Optimization method used was "nlminb"
## fitMk, logscale=TRUE
system.time(
    fitARD.2<-fitMk(sqTree,toes,model="ARD",pi="fitzjohn",
        logscale=TRUE)
)
##    user  system elapsed 
##  134.90    2.19  137.47
fitARD.2
## Object of class "fitMk".
## 
## Fitted (or set) value of Q:
##          0         1         2         3         4        5
## 0 0.000000  0.000000  0.000000  0.000000  0.000000  0.00000
## 1 0.013157 -0.013157  0.000000  0.000000  0.000000  0.00000
## 2 0.012577  0.000000 -0.023516  0.010939  0.000000  0.00000
## 3 0.017438  0.000000  0.023221 -0.040659  0.000000  0.00000
## 4 0.000000  0.003780  0.000000  0.012048 -0.049627  0.03380
## 5 0.000000  0.001648  0.000237  0.000000  0.002795 -0.00468
## 
## Fitted (or set) value of pi:
##        0        1        2        3        4        5 
## 0.000000 0.000000 0.000000 0.000000 0.328326 0.671674 
## due to treating the root prior as (a) nuisance.
## 
## Log-likelihood: -109.417016 
## 
## Optimization method used was "nlminb"
## fitMk.parallel, 8 cores, default settings
system.time(
    fitARD.3<-fitMk(sqTree,toes,model="ARD",pi="fitzjohn",
        opt.method="optimParallel")
)
##    user  system elapsed 
##    0.81    0.42  128.48
fitARD.3
## Object of class "fitMk".
## 
## Fitted (or set) value of Q:
##           0         1         2         3         4         5
## 0 -0.000003  0.000001  0.000001  0.000000  0.000000  0.000000
## 1  0.013129 -0.013131  0.000000  0.000001  0.000000  0.000000
## 2  0.012523  0.000003 -0.022217  0.009686  0.000002  0.000002
## 3  0.017446  0.000002  0.021999 -0.039452  0.000003  0.000003
## 4  0.000002  0.003536  0.000005  0.012066 -0.049368  0.033760
## 5  0.000000  0.001663  0.000237  0.000000  0.002770 -0.004670
## 
## Fitted (or set) value of pi:
##        0        1        2        3        4        5 
## 0.000000 0.000000 0.000000 0.000000 0.330107 0.669893 
## due to treating the root prior as (a) nuisance.
## 
## Log-likelihood: -109.422074 
## 
## Optimization method used was "optimParallel"
## geiger::fitDiscrete, default settings
system.time(
    fitARD.4<-fitDiscrete(sqTree,toes,model="ARD")
)
##    user  system elapsed 
##  249.51    3.91  259.35
fitARD.4
## GEIGER-fitted comparative model of discrete data
##  fitted Q matrix:
##                   0             1             2             3             4
##     0 -5.068207e-08  6.154024e-10  3.486143e-09  2.000740e-08  2.010647e-08
##     1  1.315725e-02 -1.315726e-02  4.296115e-09  2.595663e-09  2.223224e-09
##     2  1.262284e-02  2.611535e-11 -2.363226e-02  1.100942e-02  5.206479e-09
##     3  1.740235e-02  8.849326e-09  2.331708e-02 -4.071955e-02  9.433599e-08
##     4  1.981897e-11  3.779308e-03  4.994664e-07  1.204940e-02 -4.963837e-02
##     5  3.987151e-11  1.648293e-03  2.364779e-04  2.778118e-15  2.796235e-03
##                   5
##     0  6.466654e-09
##     1  1.770830e-10
##     2  4.298561e-10
##     3  8.131348e-09
##     4  3.380917e-02
##     5 -4.681005e-03
## 
##  model summary:
##  log-likelihood = -109.417070
##  AIC = 278.834139
##  AICc = 299.970503
##  free parameters = 30
## 
## Convergence diagnostics:
##  optimization iterations = 100
##  failed iterations = 0
##  number of iterations with same best fit = 1
##  frequency of best fit = 0.01
## 
##  object summary:
##  'lik' -- likelihood function
##  'bnd' -- bounds for likelihood search
##  'res' -- optimization iteration summary
##  'opt' -- maximum likelihood parameter estimates

From this we should see that we've obtained quite similar solutions for the latter three optimizations (log-likelihoods of about -109.42), but a noticeably worse solution (log-likelihood of -110.82) using fitMk and its default settings.
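To put numbers on that comparison, here is a small sketch that uses only the log-likelihoods printed above, typed in by hand. The vector names are just labels chosen for this example.

```r
## log-likelihoods copied from the printed results above
lnL<-c(fitMk_default=-110.823838,
    fitMk_logscale=-109.417016,
    fitMk_optimParallel=-109.422074,
    fitDiscrete=-109.417070)
## difference of each fit from the best log-likelihood found
dlnL<-lnL-max(lnL)
round(dlnL,3)
```

With the fitted objects still in memory, the same vector could instead be built by calling logLik on each of fitARD.1 through fitARD.4, as is done in the plotting code further down.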

We also might have found that parallelization has sped up the optimization (about 128 seconds elapsed here, compared to about 200 seconds for fitMk with its default settings), although this will depend on our specific computer hardware.
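For a rough sense of the relative timings, we can divide each elapsed time by the elapsed time of the parallelized run. The numbers below are copied by hand from the system.time output above, and the vector names are only labels.

```r
## elapsed times (in seconds) copied from the output above
elapsed<-c(fitMk_default=199.50,
    fitMk_logscale=137.47,
    fitMk_optimParallel=128.48,
    fitDiscrete=259.35)
## elapsed time of each run relative to the parallelized run
rel_time<-elapsed/elapsed[["fitMk_optimParallel"]]
round(rel_time,2)
```

These are single runs on one machine, so the ratios should be read as a rough guide rather than as a benchmark.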

Lastly, we probably noticed that only one of 100 optimizations converged to the same best solution using fitDiscrete. This is concerning, without doubt, but our fears should be slightly allayed by the fact that fitMk with logscale=TRUE and fitMk.parallel also found more or less the same best value of Q.
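To illustrate what a "frequency of best fit" can mean in general, here is a toy sketch of multi-start optimization in base R. It is not geiger's internal procedure: the objective function obj, the range of starting values, and the tolerance are all invented for this example.

```r
set.seed(42)
## toy objective with several local minima (not an Mk likelihood)
obj<-function(x) sin(3*x)+0.1*x^2
## run a local optimizer from 100 random starting values
starts<-runif(100,min=-6,max=6)
fits<-lapply(starts,function(s) optim(s,obj,method="BFGS"))
values<-sapply(fits,function(f) f$value)
## best (lowest) objective value found across all starts
best<-min(values)
## proportion of starts that reached numerically the same best value
freq_best<-mean(abs(values-best)<1e-6)
freq_best
```

On a surface with many local optima, only a fraction of starting points end up at the best solution found. A low frequency tells us the surface is hard to optimize, which is why it is reassuring when independent methods agree on the best value.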

Now let's visualize our four different results.

par(mfrow=c(2,2),mar=c(1.1,1.1,3.1,1.1))
plot(fitARD.1,width=TRUE,color=TRUE,tol=1e-3,show.zeros=FALSE,
    offset=0.03)
title(main="(a) fitMk, default settings",cex.main=1.2,line=0)
title(main=paste("log(L) = ",round(logLik(fitARD.1),2)),
    cex.main=0.8,line=-1)
plot(fitARD.2,width=TRUE,color=TRUE,tol=1e-3,show.zeros=FALSE,
    offset=0.03)
title(main="(b) fitMk, logscale=TRUE",cex.main=1.2,line=0)
title(main=paste("log(L) = ",round(logLik(fitARD.2),2)),
    cex.main=0.8,line=-1)
plot(fitARD.3,width=TRUE,color=TRUE,tol=1e-3,show.zeros=FALSE,
    offset=0.03)
title(main="(c) fitMk, opt.method=\"optimParallel\"",cex.main=1.2,
    line=0)
title(main=paste("log(L) = ",round(logLik(fitARD.3),2)),
    cex.main=0.8,line=-1)
plot(fitARD.4,width=TRUE,color=TRUE,tol=1e-3,show.zeros=FALSE,
    offset=0.03)
title(main="(d) fitDiscrete, default settings",cex.main=1.2,line=0)
title(main=paste("log(L) = ",round(logLik(fitARD.4),2)),
    cex.main=0.8,line=-1)

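[Figure: the four fitted models plotted in a 2 x 2 grid. (a) fitMk, default settings; (b) fitMk, logscale=TRUE; (c) fitMk, opt.method="optimParallel"; (d) fitDiscrete, default settings.]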

That's it!
