Thursday, May 16, 2024

Fitting a multi-state threshold model using the discrete approximation in phytools

Last week I posted some very basic code illustrating how we could use the discrete approximation of the likelihood function originally described by Boucher & Démery (2016) to fit a simple version of the threshold model. This will be part of my talk for the upcoming Evolution Meeting in Montreal, so I’ve been working on it a little more in the meantime.

As a quick reminder, the threshold model is a discrete character evolution model in which each character level is underlain by the value of an unobserved continuous character referred to as “liability.” Every time the evolving trait (liability) crosses some pre-determined (but unknown) threshold, the discrete character changes state! Here’s a quick reminder about what discrete character evolution under the threshold model looks like on a phylogenetic tree.

library(phytools)
layout(matrix(c(1,2),1,2),widths=c(0.55,0.45))
par(mar=c(5.1,4.1,1.1,1.1),las=1,cex.axis=0.8)
obj<-bmPlot(pbtree(b=0.03,n=80,t=100,type="discrete",quiet=TRUE),
  type="threshold",thresholds=c(0,1.5),anc=0.5,sig2=1/100,
  ngen=100,colors=hcl.colors(3),return.tree=TRUE,
  bty="n")
plot(obj$tree,colors=setNames(hcl.colors(3),c("a","b","c")),
  direction="upwards",ftype="off",mar=c(5.1,1.1,1.1,1.1))

(Figure: simulated evolution of the liability under the threshold model, alongside the resulting discrete character mapped onto the tree.)

Well, now I have it working pretty much perfectly, and for a multi-state, ordered threshold trait with unknown thresholds between adjacent character levels. (Under the threshold model, the specific scale of the liability is arbitrary, as is its central value, so we invariably must fix the rate of evolution of the liability and one of the thresholds between states. In my implementation, I arbitrarily fixed \(\sigma^2 = 1.0\) and set the “lowest” threshold at \(1/k\) of the liability range above its minimum (for a trait with \(k\) levels), with the range of the liability itself fixed to span the central 99% of its anticipated distribution under Brownian evolution. Some of these decisions are a bit arbitrary.)
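(To see why the liability scale is unidentifiable, here’s a quick sketch that is separate from the implementation below: if we multiply the simulated liabilities and the thresholds by the same constant, the resulting discrete character is completely unchanged.)

## quick sketch: rescaling liabilities & thresholds together leaves the
## discrete character unchanged, so the scale of the liability is arbitrary
liab<-fastBM(pbtree(n=50,scale=1))
states<-cut(liab,breaks=c(-Inf,0,1,Inf),labels=c("a","b","c"))
rescaled<-cut(2*liab,breaks=c(-Inf,0,2,Inf),labels=c("a","b","c"))
all(states==rescaled) ## TRUE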

Some of the advantages of this approach compared to my previous MCMC implementation (described in Revell 2013) are that it is much faster and that, even more importantly for many users, the likelihood is the probability of the original discrete character data under the model (and is thus directly comparable to the likelihood of an Mk model), rather than the posterior probability density of the liabilities on a continuous scale.

Working on this over the past couple of days, the most significant improvement I made was to appreciate that, because we use a constant value of Q for the discrete approximation, I can exponentiate it just once for each edge of the tree and then reuse these exponentiated matrices in every evaluation of the log-likelihood while the positions of the thresholds between character levels are being optimized.
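In other words, rather than re-exponentiating Q inside the likelihood function at every iteration, the edge-wise transition matrices are computed once up front and passed in. In sketch form (this mirrors the corresponding lines of fitThresh below, where Q is the discretized transition matrix and the expm package supplies the matrix exponential):

## Q and the edge lengths never change while the thresholds are being
## optimized, so compute P = expm(Q*t) once per edge of the tree...
P<-lapply(tree$edge.length,function(e,Q) expm::expm(e*Q),Q=Q)
## ...and then reuse the list P in every call to the likelihood function
## (in fitThresh the tree is first re-ordered to postorder)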

Here is my full code, which includes methods for the print, logLik, and ancr generics. (This will be on GitHub soon.)

## load expm
library(expm)

## compute trace (used internally)
tr<-function(X) sum(diag(X))

## convert discrete character to liability matrix
thresh2bin<-function(liability,threshold,X){
  Y<-matrix(0,length(liability),nrow(X),
    dimnames=list(round(liability,3),rownames(X)))
  interval<-mean(liability[2:length(liability)]-
      liability[1:(length(liability)-1)])
  down<-liability-interval/2
  ## down[1]<--Inf
  up<-liability+interval/2
  ## up[length(up)]<-Inf
  xx<-rep(0,length(liability))
  for(i in 1:(length(threshold)-1)){
    xx[]<-0
    ind<-intersect(which(up>threshold[i]),
      which(down<threshold[i+1]))
    xx[ind]<-1
    ind<-setdiff(ind,
      intersect(which(down>threshold[i]),
        which(up<threshold[i+1])))
    if(length(ind)>0){
      for(j in 1:length(ind)){
        xx[ind[j]]<-if(down[ind[j]]>threshold[i]&&
            down[ind[j]]<threshold[i+1]) 
          (threshold[i+1]-down[ind[j]])/interval else
            abs(up[ind[j]]-threshold[i])/interval
      }
    }
    Y[,which(X[,i]==1)]<-xx
  }
  t(Y)
}

## fit threshold model using discrete approximation
fitThresh<-function(tree,x,sequence=NULL,...){
  if(hasArg(trace)) trace<-list(...)$trace
  else trace<-0
  if(hasArg(levs)) levs<-list(...)$levs
  else levs<-200
  if(hasArg(root)) root<-list(...)$root
  else root<-"fitzjohn"
  C<-vcv(tree)
  N<-Ntip(tree)
  Ed<-tr(C)/N-sum(C)/(N^2)
  lims<-qnorm(c(0.005,0.995),sd=sqrt(Ed))
  liability<-seq(lims[1],lims[2],length.out=levs)
  if(!is.factor(x)) x<-setNames(as.factor(x),names(x))
  if(is.null(sequence)) sequence<-levels(x)
  X<-to.matrix(x,sequence)
  q<-1/(2*(diff(lims)/levs)^2)
  model<-matrix(0,levs,levs,
    dimnames=list(round(liability,3),round(liability,3)))
  ind<-cbind(1:(nrow(model)-1),2:nrow(model))
  model[ind]<-1
  model[ind[,2:1]]<-1
  Q<-model*q
  diag(Q)<--rowSums(Q)
  pw<-if(!is.null(attr(tree,"order"))&&
      attr(tree,"order")=="postorder") tree else 
        reorder(tree,"postorder")
  P<-lapply(pw$edge.length,function(e,Q) expm(e*Q),
    Q=Q)
  if(length(sequence)==2){
    Y<-thresh2bin(liability,c(-Inf,0,Inf),X)
    lnL<-lik_thresh(c(),pw,fixed_threshold=0,
      liability,X,P,trace=trace,pi=root)
    attr(lnL,"df")<-1
    threshold<-0
    opt<-NULL ## no numerical optimization needed in the binary case
  } else if(length(sequence)>2){
    fixed_threshold<-min(liability)+
      diff(range(liability))/length(sequence)
    if(length(sequence)==3){
      opt<-optimize(lik_thresh,
        c(fixed_threshold,max(liability)),
        pw=pw,fixed_threshold=fixed_threshold,
        liability=liability,x=X,P.all=P,trace=trace,
        pi=root,maximum=TRUE)
      lnL<-opt$objective
      attr(lnL,"df")<-2
      threshold<-c(fixed_threshold,opt$maximum)
    } else if(length(sequence)>3){
      init<-seq(fixed_threshold,max(liability),
        length.out=length(sequence))[
          -c(1,length(sequence))]
      opt<-optim(init,function(p) -lik_thresh(p,pw=pw,
        fixed_threshold=fixed_threshold,liability=liability,
        x=X,pi=root,P.all=P,trace=trace),method="L-BFGS-B",
        lower=rep(fixed_threshold,length(init)),
        upper=rep(max(liability),length(init)))
      lnL<--opt$value
      attr(lnL,"df")<-length(sequence)-1
      threshold<-c(fixed_threshold,sort(opt$par))
    }
  }
  Y<-thresh2bin(liability,c(-Inf,threshold,Inf),X)
  fit<-list(
    logLik=lnL,
    rates=q,
    index.matrix=model,
    states=colnames(Y),
    pi=lik_thresh(threshold,pw,c(),liability,X,P,
      trace=0,pi=root,return="pi"),
    opt_results=opt,
    data=Y,
    tree=tree)
  class(fit)<-"fitMk"
  object<-list(sigsq=1.0,bounds=lims,ncat=levs,
    liability=liability,threshold=threshold,
    logLik=lnL,tree=tree,data=X,mk_fit=fit)
  class(object)<-c("fitThresh","fitMk")
  object
}

## print for "fitThresh" object
print.fitThresh<-function(x, digits=6, ...){
  spacer<-if(length(x$threshold)>2) "\n        " else ""
  cat("Object of class \"fitThresh\".\n\n")
  cat("    Set value of sigsq (of the liability) = 1.0\n\n")
  cat(paste("    Set or estimated threshold(s) =",spacer,"[",
    paste(round(x$threshold,digits),collapse=", "),"]*\n\n"))
  cat(paste("    Log-likelihood:", round(x$logLik, digits),
    "\n\n"))
  cat("(*lowermost threshold is fixed)\n\n")
  
}

## logLik for "fitThresh" object
logLik.fitThresh<-function(object,...) object$logLik

## marginal ancestral states of "fitThresh" object
ancr.fitThresh<-function(object,...){
  anc_mk<-ancr(object$mk_fit,type="marginal")
  anc<-matrix(0,object$tree$Nnode,ncol(object$data),
    dimnames=list(1:object$tree$Nnode+Ntip(object$tree),
      colnames(object$data)))
  tmp<-diag(rep(1,ncol(object$data)))
  colnames(tmp)<-colnames(object$data)
  xx<-thresh2bin(object$liability,
    c(-Inf,object$threshold,Inf),tmp)
  for(i in 1:nrow(anc)){
    for(j in 1:nrow(xx)){
      anc[i,j]<-sum(anc_mk$ace[i,]*xx[j,])
    }
  }
  result<-list(ace=anc,logLik=logLik(object$mk_fit))
  attr(result,"type")<-"marginal"
  attr(result,"tree")<-object$tree
  attr(result,"data")<-object$data
  class(result)<-"ancr"
  result
}

## likelihood function for thresholds
lik_thresh<-function(threshold,pw,fixed_threshold,
  liability,x,P.all,trace=1,...){
  if(hasArg(return)) return<-list(...)$return
  else return<-"likelihood"
  th<-c(-Inf,fixed_threshold,sort(threshold),Inf)
  Y<-thresh2bin(liability,th,x)
  k<-ncol(Y)
  if(hasArg(pi)) pi<-list(...)$pi
  else pi<-rep(1/k,k)
  L<-rbind(Y[pw$tip.label,],
    matrix(0,pw$Nnode,k,
      dimnames=list(1:pw$Nnode+Ntip(pw))))
  nn<-unique(pw$edge[,1])
  pp<-vector(mode="numeric",length=length(nn))
  root<-min(nn)
  for(i in 1:length(nn)){
    ee<-which(pw$edge[,1]==nn[i])
    PP<-matrix(NA,length(ee),k)
    for(j in 1:length(ee)){
      P<-P.all[[ee[j]]]
      PP[j,]<-P%*%L[pw$edge[ee[j],2],]
    }
    L[nn[i],]<-apply(PP,2,prod)
    if(nn[i]==root){
      if(pi[1]=="fitzjohn") pi<-L[nn[i],]/sum(L[nn[i],])
      else if(pi[1]=="mle") 
        pi<-as.numeric(L[nn[i],]==max(L[nn[i],]))
      L[nn[i],]<-pi*L[nn[i],]
    }
    pp[i]<-sum(L[nn[i],])
    L[nn[i],]<-L[nn[i],]/pp[i]
  }
  prob<-sum(log(pp))
  if(trace>0) print(c(fixed_threshold,sort(threshold),
    prob))
  if(return=="likelihood") 
    if(is.na(prob)||is.nan(prob)) 
      return(-Inf) else return(prob)
  else if(return=="conditional") L
  else if(return=="pi") pi
}

Now, to test it out I’ll simulate data under the threshold model. That involves first simulating liabilities for the threshold trait, and then “thresholding” these liabilities (i.e., translating each species’ simulated continuous trait value into its discrete character state).

Here’s what that looks like for a discrete character with a total of four levels (and thus three thresholds between character states). In this case, the threshold positions are \([0, 0.8, 2]\).

tree<-pbtree(n=250,scale=1)
liabilities<-fastBM(tree,a=0.5,internal=TRUE)
thresh<-function(x) if(x<0) "a" else 
  if(x>0&&x<0.8) "b" else 
    if(x>0.8&&x<2) "c" else "d"
y_all<-as.factor(sapply(liabilities,thresh))
y_obs<-y_all[tree$tip.label]
head(y_obs,30)
## t197 t198 t128 t118 t199 t200  t24 t124 t125 t129 t130  t76  t14 t116 t117 t106 t107  t40  t41 t173 t174 
##    b    b    b    b    b    b    a    b    a    b    a    b    a    a    a    b    c    a    a    a    a 
##  t98 t175 t176 t132 t133  t30  t31  t18  t47 
##    a    a    b    a    a    a    a    b    b 
## Levels: a b c d

Now let’s fit our model. The function allows us to supply the ordering of the character levels via its sequence argument, but will otherwise order the levels of our input trait alphanumerically.
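Our levels (“a” through “d”) already sort alphanumerically into the correct order, so the default is fine here; but just to illustrate the argument, the equivalent call with the ordering supplied explicitly would look like the following (below I just use the default).

## equivalent call with the level ordering given explicitly (optional
## here, since "a" < "b" < "c" < "d" is also the default ordering)
fitThresh(tree,y_obs,sequence=c("a","b","c","d"))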

thresh_fit<-fitThresh(tree,y_obs)
thresh_fit
## Object of class "fitThresh".
## 
##     Set value of sigsq (of the liability) = 1.0
## 
##     Set or estimated threshold(s) = 
##          [ -1.127676, -0.339556, 0.793517 ]*
## 
##     Log-likelihood: -190.637214 
## 
## (*lowermost threshold is fixed)

Our estimated thresholds don’t much resemble the generating values at first glance, but remember that only their relative positions matter, and we simulated with a lowest threshold of 0.

thresh_fit$threshold-min(thresh_fit$threshold)
## [1] 0.0000000 0.7881201 1.9211932

Compare this to our generating values of \([0, 0.8, 2]\). Not bad.

We can compare this approximation of the log-likelihood to the log-likelihood of a standard Mk model. Let’s try it. To give the Mk model its best shot, I’ll make it ordered, with an ordering that matches the known order of our trait.

ordered<-matrix(c(
  0,1,0,0,
  2,0,3,0,
  0,4,0,5,
  0,0,6,0),4,4)
mk_fit<-fitMk(tree,y_obs,model=ordered,
  pi="fitzjohn")
mk_fit
## Object of class "fitMk".
## 
## Fitted (or set) value of Q:
##           a         b         c         d
## a -2.149170  2.149170  0.000000  0.000000
## b  3.066453 -3.999178  0.932726  0.000000
## c  0.000000  1.857538 -2.253100  0.395561
## d  0.000000  0.000000  4.696889 -4.696889
## 
## Fitted (or set) value of pi:
##        a        b        c        d 
## 0.020451 0.071092 0.422495 0.485963 
## due to treating the root prior as (a) nuisance.
## 
## Log-likelihood: -196.346033 
## 
## Optimization method used was "nlminb"
## 
## R thinks it has found the ML solution.
anova(mk_fit,thresh_fit)

##               log(L) d.f.      AIC       weight
## mk_fit     -196.3460    6 404.6921 0.0001650959
## thresh_fit -190.6372    3 387.2744 0.9998349041

This tells us that the threshold model explains our data much better than the ordered Mk model does.
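For the record, the AIC scores and Akaike weights in that table come straight from the log-likelihoods and degrees of freedom shown, via \(\mathrm{AIC} = -2\log(L) + 2 \times \mathrm{d.f.}\); here’s a quick check using the values printed above.

## sanity check: AIC = -2*log(L) + 2*d.f. for each fitted model
aic<-c(mk_fit=-2*(-196.3460)+2*6,thresh_fit=-2*(-190.6372)+2*3)
aic ## 404.692 387.274, matching the table
## Akaike weights are exp(-0.5*dAIC), normalized to sum to 1
exp(-0.5*(aic-min(aic)))/sum(exp(-0.5*(aic-min(aic))))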

Finally, last time I illustrated how this can actually be consequential for ancestral state estimation. Let’s see whether that’s also true in this instance. To measure the accuracy of our ancestral states, I’ll first convert the true states into a binary matrix, and then simply compute the mean squared difference between our estimated states and these known values.

Let’s try it. First under the ordered Mk model.

anc_mk<-ancr(mk_fit)
anc_mk
## Marginal ancestral state estimates:
##            a        b        c        d
## 251 0.000995 0.012030 0.424869 0.562106
## 252 0.004212 0.065235 0.671346 0.259207
## 253 0.321149 0.615953 0.061806 0.001092
## 254 0.463081 0.520253 0.016543 0.000122
## 255 0.529549 0.456087 0.014249 0.000115
## 256 0.494236 0.495255 0.010469 0.000041
## ...
## 
## Log-likelihood = -196.346033
true_states<-to.matrix(y_all[1:tree$Nnode+Ntip(tree)],
  levels(y_all))
head(true_states)
##     a b c d
## 251 0 1 0 0
## 252 1 0 0 0
## 253 1 0 0 0
## 254 1 0 0 0
## 255 1 0 0 0
## 256 0 1 0 0
mean((true_states-anc_mk$ace)^2)
## [1] 0.08465672

Now under the threshold model. (This takes a second, but could probably be sped up substantially if I put a little effort into it.)

anc_thresh<-ancr(thresh_fit)
anc_thresh
## Marginal ancestral state estimates:
##            a        b        c d
## 251 0.070839 0.825943 0.103219 0
## 252 0.096181 0.829316 0.074503 0
## 253 0.513837 0.485117 0.001046 0
## 254 0.691465 0.308313 0.000222 0
## 255 0.715267 0.283944 0.000790 0
## 256 0.573314 0.425312 0.001373 0
## ...
## 
## Log-likelihood = -190.637214
mean((true_states-anc_thresh$ace)^2)
## [1] 0.05507425

(In some cases the improvement can be larger – in other cases, it may be marginal.)

Here’s a visualization of our ancestral states:

plot(anc_thresh,direction="upwards",ftype="off",
  type="arc",arc_height=0.2,
  args.nodelabels=list(piecol=hcl.colors(4),
    cex=0.4),
  args.tiplabels=list(cex=0.2))

(Figure: marginal ancestral state estimates under the threshold model, plotted as pie charts at the nodes of the tree.)

That’s it, but I promise to write more about this later!
