16.1 The Dirichlet Updater Base Class

(Linux version)


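Create a file named dirichlet_updater.hpp and fill it with the following class declaration. Note that the DirichletUpdater class is derived from the class Updater and overrides two of the pure virtual functions declared in Updater (namely, proposeNewState and revert). It also declares two new pure virtual functions of its own, pullFromModel and pushToModel, so DirichletUpdater is itself an abstract class.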

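#pragma once

#include <algorithm>    // std::copy
#include <cassert>      // assert
#include <cmath>        // std::log, std::lgamma
#include <memory>       // std::shared_ptr
#include <numeric>      // std::accumulate
#include <vector>
#include "updater.hpp"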

namespace strom {
    
    class Chain;
    
    class DirichletUpdater : public Updater {

        friend class Chain;
        
        public:
            typedef std::vector<double>         point_t;
            typedef std::shared_ptr< DirichletUpdater > SharedPtr;

                                                DirichletUpdater();
            virtual                             ~DirichletUpdater();
        
            void                                clear();
            virtual double                      calcLogPrior();
        
        protected:
        
            virtual void                        pullFromModel() = 0;    
            virtual void                        pushToModel() = 0;    

            void                                proposeNewState();
            void                                revert();
        
            point_t                             _curr_point;
            point_t                             _prev_point;
    };
    
    // member function bodies go here
    
}   

Constructor and destructor

As usual, the constructor just calls the clear function to do its work, and the destructor is a placeholder that currently does nothing.

    inline DirichletUpdater::DirichletUpdater() {
        // std::cout << "Creating DirichletUpdater object" << std::endl;
        clear();
    }

    inline DirichletUpdater::~DirichletUpdater() {
        // std::cout << "Destroying DirichletUpdater object" << std::endl;
    }

The clear member function

This function returns the object to its just-constructed state (and is what actually does the work of the constructor). Note that this class is derived from Updater and first calls the Updater::clear function before doing additional work specific to this class.

    inline void DirichletUpdater::clear() {
        Updater::clear();
        _curr_point.clear();
        _prev_point.clear();
    }

The calcLogPrior member function

This function assumes that the _prior_parameters vector in the Updater base class has been filled with the appropriate number of Dirichlet prior parameters. If DirichletUpdater is being used to update a multivariate model parameter, that parameter must have a Dirichlet prior, so the vector holding the parameter's current value (_curr_point) should have the same length as _prior_parameters. Because all parameters governed by DirichletUpdater have a Dirichlet prior (or at least a transformed Dirichlet prior), this class can handle most or all of the log prior calculation, relieving derived classes of the need to do this job.

    inline double DirichletUpdater::calcLogPrior() {
        // Fill _curr_point from the model (pullFromModel is defined in the derived class)
        pullFromModel();
        assert(_curr_point.size() > 0);
        assert(_curr_point.size() == _prior_parameters.size());
        bool flat_prior = true;
        bool bad_point = false;
        double log_prior = 0.0;
        double prior_param_sum = 0.0;
        for (unsigned i = 0; i < _curr_point.size(); ++i) {
            if (_prior_parameters[i] != 1.0)
                flat_prior = false;
            if (_curr_point[i] == 0.0)
                bad_point = true;
            log_prior += (_prior_parameters[i] - 1.0)*std::log(_curr_point[i]);
            log_prior -= std::lgamma(_prior_parameters[i]);
            prior_param_sum += _prior_parameters[i];
        }
        // A flat Dirichlet has the constant density Gamma(n) everywhere on the simplex,
        // including its boundary, so the point itself does not matter
        if (flat_prior)
            return std::lgamma(prior_param_sum);
        // Otherwise a zero component makes the log terms above non-finite,
        // so return the log-zero sentinel instead
        if (bad_point)
            return Updater::_log_zero;
        log_prior += std::lgamma(prior_param_sum);
        return log_prior;
    }
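The arithmetic in calcLogPrior is the standard Dirichlet log density, log Γ(Σα_i) − Σ log Γ(α_i) + Σ(α_i − 1) log x_i. The free function below is a minimal standalone sketch of the same calculation, useful for experimenting outside strom. The name dirichletLogDensity is hypothetical. Negative infinity stands in for the Updater::_log_zero sentinel, whose actual value is defined in the Updater class.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical standalone version of the calculation in calcLogPrior:
// log density of Dirichlet(alpha) evaluated at the point x.
double dirichletLogDensity(const std::vector<double> & x,
                           const std::vector<double> & alpha) {
    assert(!x.empty() && x.size() == alpha.size());
    bool flat = true;          // are all alpha_i equal to 1?
    bool on_boundary = false;  // is any x_i equal to 0?
    double alpha_sum = 0.0;
    double log_density = 0.0;
    for (std::size_t i = 0; i < x.size(); ++i) {
        if (alpha[i] != 1.0)
            flat = false;
        if (x[i] == 0.0)
            on_boundary = true;
        alpha_sum += alpha[i];
        log_density -= std::lgamma(alpha[i]);
    }
    log_density += std::lgamma(alpha_sum);

    // Flat Dirichlet: density is the constant Gamma(n) everywhere on the simplex
    if (flat)
        return log_density;

    // Stand-in for Updater::_log_zero (assumption): zero component => zero density
    if (on_boundary)
        return -std::numeric_limits<double>::infinity();

    for (std::size_t i = 0; i < x.size(); ++i)
        log_density += (alpha[i] - 1.0) * std::log(x[i]);
    return log_density;
}
```

For a flat prior on four nucleotide frequencies, the result is log Γ(4) = log 6 regardless of the point. For Dirichlet(2, 2) evaluated at (0.5, 0.5), it is log 1.5.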

The pullFromModel and pushToModel pure virtual member functions

From the class declaration:

            virtual void                        pullFromModel() = 0;    
            virtual void                        pushToModel() = 0;    

These are placeholders for functions that must be defined in derived classes. Each derived class must provide a way to fill the _curr_point vector with values that have a Dirichlet prior and are updated using the focussed Dirichlet proposal implemented in this abstract base class. For example, the StateFreqUpdater will simply copy the state frequencies stored in the model to _curr_point in its pullFromModel function, and copy the values in _curr_point to the model in its pushToModel function.
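The pull/push contract can be illustrated with a self-contained toy. ToyModel, ToyDirichletUpdater and ToyStateFreqUpdater are hypothetical stand-ins written only for this sketch; they are not strom classes. The derived class plays the role described for StateFreqUpdater by copying frequencies between the model and _curr_point. The base class only ever reads and writes _curr_point. The method scaleFirstAndRenormalize is a deliberately simplified stand-in for proposeNewState.

```cpp
#include <cassert>
#include <cmath>
#include <memory>
#include <vector>

// Hypothetical stand-in for the Model class: holds only state frequencies
struct ToyModel {
    std::vector<double> state_freqs{0.25, 0.25, 0.25, 0.25};
};

// Hypothetical, stripped-down stand-in for DirichletUpdater
class ToyDirichletUpdater {
    public:
        virtual ~ToyDirichletUpdater() = default;

        // Simplified "proposal": pull, modify _curr_point, renormalize, push
        void scaleFirstAndRenormalize(double factor) {
            assert(factor > 0.0);
            pullFromModel();
            _curr_point[0] *= factor;
            double sum = 0.0;
            for (double v : _curr_point)
                sum += v;
            for (double & v : _curr_point)
                v /= sum;
            pushToModel();
        }

    protected:
        virtual void pullFromModel() = 0;
        virtual void pushToModel() = 0;
        std::vector<double> _curr_point;
};

// Hypothetical analogue of StateFreqUpdater: copies frequencies in and out
class ToyStateFreqUpdater : public ToyDirichletUpdater {
    public:
        explicit ToyStateFreqUpdater(std::shared_ptr<ToyModel> model) : _model(model) {}

    protected:
        void pullFromModel() override { _curr_point = _model->state_freqs; }
        void pushToModel() override { _model->state_freqs = _curr_point; }

    private:
        std::shared_ptr<ToyModel> _model;
};
```

Starting from equal frequencies, doubling the first one and renormalizing leaves the model holding (0.4, 0.2, 0.2, 0.2).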

The proposeNewState member function

This updater works by centering a sharp (low variance) Dirichlet distribution over the current value of the parameter (_curr_point), then choosing the proposed value from that Dirichlet distribution. Note that the tuning parameter _lambda controls the sharpness of the proposal distribution: larger values of _lambda (e.g. 1) mean bolder proposals that generate proposed states further away from the current state on average, while smaller _lambda values (e.g. 1/1000) result in a sharper proposal distribution that chooses proposed values close to the current state.
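To make "centering" concrete, take the forward parameters used in proposeNewState, α_i = 1 + x_i/λ, where x is the current point and n is its length. The derivation below uses standard Dirichlet moment formulas and assumes every x_i > 0, because the mode formula requires all α_i > 1. It shows that the mode of the proposal coincides with x, and that the total concentration α_0 grows as λ shrinks.

```latex
\alpha_i = 1 + \frac{x_i}{\lambda}, \qquad
\alpha_0 = \sum_{i=1}^{n} \alpha_i = n + \frac{1}{\lambda}
\quad \Big(\text{since } \textstyle\sum_{i} x_i = 1\Big)

\mathrm{mode}_i = \frac{\alpha_i - 1}{\alpha_0 - n}
               = \frac{x_i/\lambda}{1/\lambda} = x_i

\mathrm{Var}(y_i) = \frac{\alpha_i(\alpha_0 - \alpha_i)}{\alpha_0^2(\alpha_0 + 1)}
\;\longrightarrow\; 0 \quad \text{as } \lambda \to 0
```

For example, take n = 4 and x = (0.25, 0.25, 0.25, 0.25). With λ = 1, α_i = 1.25 and α_0 = 5, giving a standard deviation of about 0.177 per component. With λ = 1/1000, α_i = 251 and α_0 = 1004, giving about 0.0137. The proposal mean, (1 + x_i/λ)/(n + 1/λ), is pulled slightly toward 1/n, so "centered" here refers to the mode.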

Yes, it is a little confusing that we use distinct Dirichlet distributions for the prior and for the proposal. The advantage is that we can modify all 4 nucleotide frequencies (or all 6 exchangeabilities, all 61 codon frequencies, or all subset relative rates) at once while keeping the proposed values close to the current values. This proposal approach also automatically ensures that the proposed values sum to 1, thus maintaining the required constraint.

One additional complication is that this proposal is not symmetric, so we must calculate the Hastings ratio in this function as well as proposing a new state. The Hastings ratio is the conditional probability density of the current state given the proposed state divided by the conditional probability density of the proposed state given the current state. Said more simply, but less precisely, it is the ratio of the probability of the reverse proposal to the probability of the forward proposal.

What is the reverse proposal? It involves proposing the current state assuming that the Markov chain is currently sitting at the proposed state. Calculation of the Hastings ratio thus involves centering a sharp Dirichlet distribution over the proposed state and asking about the probability density of the current state were the current state to be drawn from that distribution.
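In symbols, let x be the current point (_prev_point), x* the proposed point (_curr_point after normalization), and Dir(y; α) the Dirichlet density, with 1 + x/λ applied elementwise. The log Hastings ratio stored in _log_hastings_ratio is then given below. Each density is evaluated at the point other than the one its parameters were built from.

```latex
\log \mathrm{Dir}(y;\alpha) = \log\Gamma\Big(\sum_{i=1}^{n}\alpha_i\Big)
  - \sum_{i=1}^{n}\log\Gamma(\alpha_i)
  + \sum_{i=1}^{n}(\alpha_i - 1)\log y_i

\log H = \underbrace{\log \mathrm{Dir}\big(x;\ 1 + x^{*}/\lambda\big)}_{\text{reverse: } q(x \mid x^{*})}
       - \underbrace{\log \mathrm{Dir}\big(x^{*};\ 1 + x/\lambda\big)}_{\text{forward: } q(x^{*} \mid x)}
```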

The comments attempt to explain which part of the process is being done by each chunk of code.

    inline void DirichletUpdater::proposeNewState() {
        // Copy current parameter values from the model into _curr_point and save its length.
        pullFromModel();
        unsigned dim = (unsigned)_curr_point.size();
        
        // Save copy of _curr_point in case revert is necessary.
        _prev_point.assign(_curr_point.begin(), _curr_point.end());
        
        // Determine parameters of Dirichlet forward proposal distribution and, at the same time,
        // draw gamma deviates that will be used to form the proposed point.
        std::vector<double> forward_params(dim, 0.0);
        for (unsigned i = 0; i < dim; ++i) {
            // Calculate ith forward parameter (always >= 1 when _lambda > 0,
            // so the lower bound below is purely defensive)
            double alpha_i = 1.0 + _prev_point[i]/_lambda;
            if (alpha_i < 1.e-12)
                alpha_i = 1.e-12;
            forward_params[i] = alpha_i;
            
            // Draw ith gamma deviate
            _curr_point[i] = 0.0;
            if (alpha_i > 0.0)
                _curr_point[i] = _lot->gamma(alpha_i, 1.0);
        }
        
        double sum_gamma_deviates     = std::accumulate(_curr_point.begin(), _curr_point.end(), 0.0);
        double sum_forward_parameters = std::accumulate(forward_params.begin(), forward_params.end(), 0.0);

        // Choose new state by sampling from forward proposal distribution.
        // We've already stored gamma deviates in _curr_point, now just need to normalize them.
        for (unsigned i = 0; i < dim; ++i) {
            _curr_point[i] /= sum_gamma_deviates;
        }
        
        // Log density of the proposed point (_curr_point) under the forward proposal,
        // which is centered over the previous point
        double log_forward_density = 0.0;
        for (unsigned i = 0; i < dim; ++i) {
            log_forward_density += (forward_params[i] - 1.0)*std::log(_curr_point[i]);
            log_forward_density -= std::lgamma(forward_params[i]);
        }
        log_forward_density += std::lgamma(sum_forward_parameters);
        
        // Determine parameters of Dirichlet reverse proposal distribution,
        // which is centered over the proposed point
        std::vector<double> reverse_params(dim, 0.0);
        for (unsigned i = 0; i < dim; ++i) {
            reverse_params[i] = 1.0 + _curr_point[i]/_lambda;
        }
        
        double sum_reverse_parameters = std::accumulate(reverse_params.begin(), reverse_params.end(), 0.0);

        // Log density of the previous point (_prev_point) under the reverse proposal
        double log_reverse_density = 0.0;
        for (unsigned i = 0; i < dim; ++i) {
            log_reverse_density += (reverse_params[i] - 1.0)*std::log(_prev_point[i]);
            log_reverse_density -= std::lgamma(reverse_params[i]);
        }
        log_reverse_density += std::lgamma(sum_reverse_parameters);
        
        // Calculate the logarithm of the Hastings ratio
        _log_hastings_ratio = log_reverse_density - log_forward_density;
        
        pushToModel();

        // This proposal invalidates all transition matrices and partials
        _tree_manipulator->selectAllPartials();
        _tree_manipulator->selectAllTMatrices();
    }
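The sketch below reproduces one complete proposal outside strom. It substitutes std::mt19937 and std::gamma_distribution for the Lot class; that substitution is made only for this sketch. The function names proposalParams, proposePoint, logDirichlet and logHastingsRatio are hypothetical. As in proposeNewState, the proposed point is formed by normalizing gamma deviates. The forward density is evaluated at the proposed point and the reverse density at the current point.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <random>
#include <vector>

// Log Dirichlet density of point y under parameters alpha (assumes all y_i > 0)
double logDirichlet(const std::vector<double> & y, const std::vector<double> & alpha) {
    assert(y.size() == alpha.size());
    double alpha_sum = 0.0;
    double result = 0.0;
    for (std::size_t i = 0; i < y.size(); ++i) {
        alpha_sum += alpha[i];
        result += (alpha[i] - 1.0)*std::log(y[i]) - std::lgamma(alpha[i]);
    }
    return result + std::lgamma(alpha_sum);
}

// Parameters of the focused proposal whose mode is the point x
std::vector<double> proposalParams(const std::vector<double> & x, double lambda) {
    std::vector<double> alpha(x.size());
    for (std::size_t i = 0; i < x.size(); ++i)
        alpha[i] = 1.0 + x[i]/lambda;
    return alpha;
}

// log q(current | proposed) - log q(proposed | current)
double logHastingsRatio(const std::vector<double> & current,
                        const std::vector<double> & proposed, double lambda) {
    double log_reverse = logDirichlet(current,  proposalParams(proposed, lambda));
    double log_forward = logDirichlet(proposed, proposalParams(current,  lambda));
    return log_reverse - log_forward;
}

// Draw a proposed point: gamma deviates (shape alpha_i, scale 1), then normalize
std::vector<double> proposePoint(const std::vector<double> & current, double lambda,
                                 std::mt19937 & rng) {
    std::vector<double> alpha = proposalParams(current, lambda);
    std::vector<double> proposed(current.size());
    double sum = 0.0;
    for (std::size_t i = 0; i < current.size(); ++i) {
        std::gamma_distribution<double> gamma(alpha[i], 1.0);
        proposed[i] = gamma(rng);
        sum += proposed[i];
    }
    for (double & v : proposed)
        v /= sum;
    return proposed;
}
```

Two quick checks follow directly from the definition. The log Hastings ratio is exactly zero when the proposed point equals the current point, and swapping the roles of the current and proposed points flips its sign.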

The revert member function

This function is called if the proposal is not accepted. It simply copies the _prev_point vector into the _curr_point vector. (The _prev_point vector was filled with the values from _curr_point at the beginning of the proposeNewState function before _curr_point was modified.)

    inline void DirichletUpdater::revert() {    
        std::copy(_prev_point.begin(), _prev_point.end(), _curr_point.begin());
        pushToModel();
    }   
