Create a file named dirichlet_updater.hpp and fill it with the following class declaration. Note that the DirichletUpdater class is derived from the class Updater and overrides two of the pure virtual functions specified in Updater (namely proposeNewState and revert).
#pragma once

#include "updater.hpp"

namespace strom {

    class Chain;

    class DirichletUpdater : public Updater {

        friend class Chain;

        public:
            typedef std::vector<double> point_t;
            typedef std::shared_ptr< DirichletUpdater > SharedPtr;

            DirichletUpdater();
            virtual ~DirichletUpdater();

            void clear();
            virtual double calcLogPrior();

        protected:
            virtual void pullFromModel() = 0;
            virtual void pushToModel() = 0;

            void proposeNewState();
            void revert();

            point_t _curr_point;
            point_t _prev_point;
    };

    // member function bodies go here

}
As usual, the constructor just calls the clear function to do its work, and the destructor is a placeholder that currently does nothing.
inline DirichletUpdater::DirichletUpdater() {
    // std::cout << "Creating DirichletUpdater object" << std::endl;
    clear();
}

inline DirichletUpdater::~DirichletUpdater() {
    // std::cout << "Destroying DirichletUpdater object" << std::endl;
}
The clear function returns the object to its just-constructed state (and is what actually does the work of the constructor). Because this class is derived from Updater, it first calls the Updater::clear function before doing additional work specific to this class.
inline void DirichletUpdater::clear() {
    Updater::clear();
    _prev_point.clear();
}
The calcLogPrior function assumes that the _prior_parameters vector in the Updater base class has been filled with the appropriate number of Dirichlet prior parameters. The assumption is that if DirichletUpdater is being used to update a multivariate model parameter, that parameter must have a Dirichlet prior, and thus the vector representing the current value of the parameter (_curr_point) should have the same length as the vector _prior_parameters. Because all parameters governed by DirichletUpdater have a Dirichlet prior (or at least a transformed Dirichlet prior), this class can handle calculation of most or all of the log prior, thus relieving derived classes of the need to do this job.
inline double DirichletUpdater::calcLogPrior() {
    pullFromModel();
    assert(_curr_point.size() > 0);
    assert(_curr_point.size() == _prior_parameters.size());
    bool flat_prior = true;
    bool bad_point = false;
    double log_prior = 0.0;
    double prior_param_sum = 0.0;
    for (unsigned i = 0; i < _curr_point.size(); ++i) {
        if (_prior_parameters[i] != 1.0)
            flat_prior = false;
        if (_curr_point[i] == 0.0)
            bad_point = true;
        log_prior += (_prior_parameters[i] - 1.0)*std::log(_curr_point[i]);
        log_prior -= std::lgamma(_prior_parameters[i]);
        prior_param_sum += _prior_parameters[i];
    }
    if (flat_prior)
        return std::lgamma(prior_param_sum);
    else if (bad_point)
        return Updater::_log_zero;
    else
        log_prior += std::lgamma(prior_param_sum);
    return log_prior;
}
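For reference, the quantity accumulated in calcLogPrior is the log of the Dirichlet density. Here x_i stands for _curr_point[i], a_i for _prior_parameters[i], and n for their common length (these symbols are introduced only for illustration and do not appear in the code):

```latex
\log p(x_1,\dots,x_n \mid a_1,\dots,a_n)
  = \log\Gamma\!\left(\sum_{i=1}^{n} a_i\right)
  - \sum_{i=1}^{n} \log\Gamma(a_i)
  + \sum_{i=1}^{n} (a_i - 1)\,\log x_i
```

When every a_i equals 1, the last two terms vanish and only log Γ(n) remains, which is the value returned by the flat_prior branch.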
From the class declaration:
virtual void pullFromModel() = 0;
virtual void pushToModel() = 0;
These are placeholders for functions that must be defined in derived classes. Each derived class must provide a way to fill the _curr_point vector with values that have a Dirichlet prior and are updated using the focussed Dirichlet proposal implemented in this abstract base class. For example, the StateFreqUpdater will simply copy the state frequencies stored in the model to _curr_point in its pullFromModel function, and copy the values in _curr_point to the model in its pushToModel function.
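The pull/push pattern described above can be sketched in isolation. Everything in this example is hypothetical: ToyModel, ToyDirichletBase, ToyStateFreqUpdater and normalizeInModel are stand-ins invented for illustration, not the tutorial's Model, DirichletUpdater or StateFreqUpdater classes. The sketch only shows how a base class can operate on _curr_point while leaving the transfer to and from the model to derived classes.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical stand-in for the tutorial's model; it only holds state frequencies.
struct ToyModel {
    std::vector<double> state_freqs{0.1, 0.2, 0.3, 0.4};
};

// Hypothetical, stripped-down analogue of DirichletUpdater: it owns _curr_point
// and delegates the transfer to and from the model to derived classes.
class ToyDirichletBase {
public:
    virtual ~ToyDirichletBase() {}

    // Pull the current point, rescale it to sum to 1, and push it back.
    // (A real updater would propose a new point here instead of rescaling.)
    void normalizeInModel() {
        pullFromModel();
        double sum = 0.0;
        for (double v : _curr_point) sum += v;
        for (double & v : _curr_point) v /= sum;
        pushToModel();
    }

protected:
    virtual void pullFromModel() = 0;
    virtual void pushToModel() = 0;

    std::vector<double> _curr_point;
};

// Derived class: knows which part of the model _curr_point corresponds to.
class ToyStateFreqUpdater : public ToyDirichletBase {
public:
    explicit ToyStateFreqUpdater(ToyModel & model) : _model(model) {}

protected:
    void pullFromModel() override { _curr_point = _model.state_freqs; }
    void pushToModel() override { _model.state_freqs = _curr_point; }

private:
    ToyModel & _model;
};
```

With this arrangement the base class never needs to know whether _curr_point holds state frequencies, exchangeabilities, or subset relative rates; only the two small overrides differ between derived classes.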
This updater works by centering a sharp (low variance) Dirichlet distribution over the current value of the parameter (_curr_point), then choosing the proposed value from that Dirichlet distribution. Note that the tuning parameter _lambda controls the sharpness of the proposal distribution: larger values of _lambda (e.g. 1) mean bolder proposals that generate proposed states further away from the current state on average, while smaller _lambda values (e.g. 1/1000) result in a sharper proposal distribution that chooses proposed values close to the current state.
Yes, it is a little confusing that we are using distinct Dirichlet distributions for the prior and the proposal distribution, but the advantage of this is that we modify all 4 nucleotide frequencies (or 6 exchangeabilities, or 61 codon frequencies, or all subset relative rates) at once but in a way that keeps the proposed values close to the current values. This proposal approach also automatically ensures that the proposed values add to 1, thus maintaining the constraint that is required.
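The effect of the tuning parameter can be explored with a small standalone sketch. It assumes the proposal parameters have the form 1 + x_i/lambda (as in the proposeNewState code) and uses std::gamma_distribution from the standard library as a stand-in for the tutorial's _lot->gamma call. The function names proposeFocusedDirichlet and meanProposalDistance are hypothetical helpers, not part of the tutorial's classes.

```cpp
#include <cassert>
#include <cmath>
#include <random>
#include <vector>

// Draw one point from Dirichlet(1 + curr_i/lambda) by normalizing gamma deviates.
// std::gamma_distribution stands in for the tutorial's _lot->gamma(shape, scale).
inline std::vector<double> proposeFocusedDirichlet(const std::vector<double> & curr,
                                                   double lambda,
                                                   std::mt19937 & rng) {
    std::vector<double> proposed(curr.size(), 0.0);
    double sum = 0.0;
    for (std::size_t i = 0; i < curr.size(); ++i) {
        std::gamma_distribution<double> gamma_dist(1.0 + curr[i]/lambda, 1.0);
        proposed[i] = gamma_dist(rng);
        sum += proposed[i];
    }
    // Normalizing independent gamma deviates yields a Dirichlet deviate
    for (double & v : proposed)
        v /= sum;
    return proposed;
}

// Average Euclidean distance between the current point and nreps proposals.
inline double meanProposalDistance(const std::vector<double> & curr,
                                   double lambda,
                                   unsigned nreps,
                                   std::mt19937 & rng) {
    double total = 0.0;
    for (unsigned rep = 0; rep < nreps; ++rep) {
        std::vector<double> p = proposeFocusedDirichlet(curr, lambda, rng);
        double ss = 0.0;
        for (std::size_t i = 0; i < curr.size(); ++i)
            ss += (p[i] - curr[i])*(p[i] - curr[i]);
        total += std::sqrt(ss);
    }
    return total/nreps;
}
```

Calling meanProposalDistance with the same starting point and lambda values of 1 and 0.001 should give a much smaller average distance for the smaller lambda, and every proposed point sums to 1 because of the normalization step.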
One additional complication is that this proposal is not symmetric, so we must calculate the Hastings ratio in this function as well as proposing a new state. The Hastings ratio is the conditional probability density of the current state given the proposed state divided by the conditional probability density of the proposed state given the current state. Said more simply, but less precisely, it is the ratio of the probability of the reverse proposal to the probability of the forward proposal.
What is the reverse proposal? It involves proposing the current state assuming that the Markov chain is currently sitting at the proposed state. Calculation of the Hastings ratio thus involves centering a sharp Dirichlet distribution over the proposed state and asking about the probability density of the current state were the current state to be drawn from that distribution.
The comments attempt to explain which part of the process is being done by each chunk of code.
inline void DirichletUpdater::proposeNewState() {
    // Refresh _curr_point from the model and save its length.
    pullFromModel();
    unsigned dim = (unsigned)_curr_point.size();

    // Save copy of _curr_point in case revert is necessary.
    _prev_point.assign(_curr_point.begin(), _curr_point.end());

    // Determine parameters of Dirichlet forward proposal distribution and, at the same time,
    // draw gamma deviates that will be used to form the proposed point.
    std::vector<double> forward_params(dim, 0.0);
    for (unsigned i = 0; i < dim; ++i) {
        // Calculate ith forward parameter
        double alpha_i = 1.0 + _prev_point[i]/_lambda;
        if (alpha_i < 1.e-12)
            alpha_i = 1.e-12;
        forward_params[i] = alpha_i;

        // Draw ith gamma deviate
        _curr_point[i] = 0.0;
        if (alpha_i > 0.0)
            _curr_point[i] = _lot->gamma(alpha_i, 1.0);
    }

    double sum_gamma_deviates = std::accumulate(_curr_point.begin(), _curr_point.end(), 0.0);
    double sum_forward_parameters = std::accumulate(forward_params.begin(), forward_params.end(), 0.0);

    // Choose new state by sampling from forward proposal distribution.
    // We've already stored gamma deviates in _curr_point, now just need to normalize them.
    for (unsigned i = 0; i < dim; ++i) {
        _curr_point[i] /= sum_gamma_deviates;
    }

    // Determine probability density of the forward proposal: the proposed point
    // (now in _curr_point) under the distribution centered on _prev_point
    double log_forward_density = 0.0;
    for (unsigned i = 0; i < dim; ++i) {
        log_forward_density += (forward_params[i] - 1.0)*std::log(_curr_point[i]);
        log_forward_density -= std::lgamma(forward_params[i]);
    }
    log_forward_density += std::lgamma(sum_forward_parameters);

    // Determine parameters of Dirichlet reverse proposal distribution
    std::vector<double> reverse_params(dim, 0.0);
    for (unsigned i = 0; i < dim; ++i) {
        reverse_params[i] = 1.0 + _curr_point[i]/_lambda;
    }

    double sum_reverse_parameters = std::accumulate(reverse_params.begin(), reverse_params.end(), 0.0);

    // Determine probability density of the reverse proposal: the original point
    // (saved in _prev_point) under the distribution centered on the proposed point
    double log_reverse_density = 0.0;
    for (unsigned i = 0; i < dim; ++i) {
        log_reverse_density += (reverse_params[i] - 1.0)*std::log(_prev_point[i]);
        log_reverse_density -= std::lgamma(reverse_params[i]);
    }
    log_reverse_density += std::lgamma(sum_reverse_parameters);

    // Calculate the logarithm of the Hastings ratio
    _log_hastings_ratio = log_reverse_density - log_forward_density;

    pushToModel();

    // This proposal invalidates all transition matrices and partials
    _tree_manipulator->selectAllPartials();
    _tree_manipulator->selectAllTMatrices();
}
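The definition of the Hastings ratio given earlier (density of the current state given the proposed state, divided by density of the proposed state given the current state) can also be written as a small standalone function, which makes it easy to check by hand. This is a minimal sketch under the assumption that the proposal parameters are 1 + x_i/lambda; logFocusedDirichletDensity and logHastingsRatio are hypothetical helper names, not members of DirichletUpdater.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Log density of Dirichlet(1 + center_i/lambda) evaluated at point.
inline double logFocusedDirichletDensity(const std::vector<double> & point,
                                         const std::vector<double> & center,
                                         double lambda) {
    assert(point.size() == center.size());
    double log_density = 0.0;
    double param_sum = 0.0;
    for (std::size_t i = 0; i < point.size(); ++i) {
        double alpha_i = 1.0 + center[i]/lambda;
        log_density += (alpha_i - 1.0)*std::log(point[i]);
        log_density -= std::lgamma(alpha_i);
        param_sum += alpha_i;
    }
    return log_density + std::lgamma(param_sum);
}

// log Hastings ratio = log q(curr | proposed) - log q(proposed | curr)
inline double logHastingsRatio(const std::vector<double> & curr,
                               const std::vector<double> & proposed,
                               double lambda) {
    // Reverse move: density of curr under the distribution centered on proposed
    double log_reverse = logFocusedDirichletDensity(curr, proposed, lambda);
    // Forward move: density of proposed under the distribution centered on curr
    double log_forward = logFocusedDirichletDensity(proposed, curr, lambda);
    return log_reverse - log_forward;
}
```

Two sanity checks follow directly from the definition: swapping the roles of the current and proposed points flips the sign of the log ratio, and proposing the current point itself gives a log ratio of zero.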
The revert function is called if the proposal is not accepted. It simply copies the _prev_point vector into the _curr_point vector and then pushes the restored values back to the model. (The _prev_point vector was filled with the values from _curr_point at the beginning of the proposeNewState function before _curr_point was modified.)
inline void DirichletUpdater::revert() {
    std::copy(_prev_point.begin(), _prev_point.end(), _curr_point.begin());
    pushToModel();
}