Utilizing conjugacy for the Multinomial distribution
conjugacy
bayesian
multinomial
dirichlet-multinomial
Author
Michael Green
Published
December 26, 2024
Modified
December 29, 2024
Introduction
As outlined in Sensoy, Kaplan, and Kandemir (2018), Evidential Deep Learning can be used to quantify uncertainty in classification applications. The evidential framework depends on the ability to specify conjugate priors for the parameters of the likelihood of interest. In classification tasks we naturally lean on the Categorical distribution in the general case and the Bernoulli distribution in the binary case. The Bernoulli and the Categorical distributions are both useful for describing an individual sample of a class outcome, i.e., true or false for a given class. But what if your data are not individual classification samples but rather aggregated counts of how many times each class came up across several samples? Let’s take the example of a cancer classification problem. In the individual classification setting we would get a true (cancer) or false (not cancer) outcome per individual we test. If we instead have a dataset showing how many patients were diagnosed with cancer vs. not cancer per region, then the Binomial distribution would best describe this data. The same logic holds for the Categorical distribution, which then maps to the Multinomial distribution.
A quick primer on Bayesian probability
There is, of course, some terminology in Bayesian modeling that’s necessary to know and be able to relate to. I will follow Tu (2014) and outline the most important terms.
The likelihood is a measure of how well a chosen distribution and its given parameters explain the observed data. Strictly speaking, it is not a probability distribution over the parameters: it is evaluated for fixed observed data as a function of the parameters, and as such it is not normalized over them.
The prior is the assumed distribution you put on your parameters before observing the data. It can be interpreted as a kind of regularization of your model to make sure it doesn’t do anything crazy.
The posterior is the distribution of your parameters after observing the data. Thus, it’s a distribution over your parameters where your assumptions and the data have been weighed together to form a coherent opinion about the problem.
The evidence, or marginal, is the probability of observing the data after all parameters have been integrated out. This is important, as it is arguably what we would like to maximize for a given problem. Note that this is distinct from the likelihood, which doesn’t factor in the prior assumptions.
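Mathematically, we can put these items into one glorious equation
\[\underbrace{P(𝐩|𝐗)}_{\text{posterior}}=\frac{\overbrace{P(𝐗|𝐩)}^{\text{likelihood}}\;\overbrace{P(𝐩)}^{\text{prior}}}{\underbrace{P(𝐗)}_{\text{evidence}}}\tag{1}\]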
where I have chosen a parameterization that matches the one we will use for the Multinomial distribution. Usually you would have written this with \(𝛉\) instead of \(𝐩\).
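To make the four terms concrete, here is a minimal Julia sketch. It replaces the continuous \(𝐩\) with a coarse grid of three candidate values for a single Bernoulli success probability, so the integral in the evidence becomes a sum. The grid, the uniform prior over it, and the 7-out-of-10 data are made-up illustrative assumptions.

```julia
# Hypothetical grid of candidate success probabilities with a uniform prior over the grid.
ps = [0.25, 0.5, 0.75]
prior = fill(1 / 3, 3)

# Illustrative data: 7 successes in 10 Bernoulli trials.
successes, trials = 7, 10

# Likelihood P(X | p) for each candidate p (Binomial pmf).
likelihood = [binomial(trials, successes) * p^successes * (1 - p)^(trials - successes) for p in ps]

# Evidence P(X): the prior-weighted likelihood summed over the grid (the integral becomes a sum).
evidence = sum(likelihood .* prior)

# Posterior P(p | X) via Bayes' rule; it sums to one by construction.
posterior = likelihood .* prior ./ evidence

println("evidence  = ", evidence)
println("posterior = ", posterior)
```

The posterior concentrates on \(p=0.75\), the grid value closest to the observed rate of 0.7.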
Now, there’s one more quantity I would like to mention, which is also an integration exercise. Here we are instead interested in the distribution of a new datapoint \(𝐱\) once the parameters have been integrated out under the posterior. It’s called the posterior predictive distribution. In this setting we will use \(𝐗\) to represent the entire dataset \([𝐱_i]\).
\[P(𝐱|𝐗)=\int P(𝐱|𝐩)\,P(𝐩|𝐗)\,d𝐩\tag{2}\]
In order to really seal the deal here, I want to be explicit about the marginal distribution, which is over an entire dataset. We will use the same formalism as in the above equation. Assuming the datapoints are independent given \(𝐩\), we can state the marginal data distribution as
\[P(𝐗)=\int P(𝐗|𝐩)\,P(𝐩)\,d𝐩=\int\left[\prod_{i=1}^{N}P(𝐱_i|𝐩)\right]P(𝐩)\,d𝐩\tag{3}\]
where you can see the inclusion of each datapoint \(𝐱_i\) in the dataset \(𝐗\).
The Multinomial distribution
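In order to describe the Multinomial distribution for a random variable \(𝐱\), we need the total number of cases \(n\) as well as the probability of each class occurring, \(𝐩\). Here \(𝐱\) and \(𝐩\) are vectors of \(K\) elements representing the \(K\) classes, with \(n=\sum_{k=1}^{K}x_k\) and \(\sum_{k=1}^{K}p_k=1\).
\[P(𝐱|n,𝐩)=\mathrm{Mult}(𝐱;n,𝐩)=\frac{Γ(n+1)}{\prod_{k=1}^{K}Γ(x_k+1)}\prod_{k=1}^{K}p_k^{x_k}\tag{4}\]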
This is the probability mass function for the Multinomial distribution. The gamma functions in the equation are just a way of writing the factorials in the multinomial coefficient, which counts the number of ways the \(n\) cases can be distributed over the \(K\) classes. To get a feel for what is happening, it’s worth remembering that \(Γ(n+1)=n!\).
The Multinomial distribution is a generalization (\(K≥2\)) of the Binomial (\(K=2\)) distribution. Likewise, it’s also a generalization (\(n≥1\)) of the Categorical (\(n=1\)) distribution. Table 1 summarizes the different distributions mentioned here.
Table 1: Description of when the different distributions are applicable as well as their relation.

| K | n | Distribution |
|---|---|---|
| 2 | 1 | Bernoulli |
| 2 | >1 | Binomial |
| >2 | 1 | Categorical |
| >2 | >1 | Multinomial |
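A minimal Julia sketch of Equation 4 in log space follows. It uses `loggamma` from SpecialFunctions.jl, the same package the post relies on further down. The names `multinomial_logpmf` and `multinomial_pmf` are my own. The checks at the end confirm two rows of Table 1 numerically: with \(K=2\) the function reduces to the Binomial pmf, and with \(n=1\) it returns \(p_k\) for the observed class, just like a Categorical.

```julia
using SpecialFunctions  # provides loggamma

# Log of Equation 4: log Γ(n+1) - Σₖ log Γ(xₖ+1) + Σₖ xₖ log pₖ, with n = Σₖ xₖ.
function multinomial_logpmf(x::AbstractVector{<:Integer}, p::AbstractVector{<:Real})
    n = sum(x)
    logcoef = loggamma(n + 1) - sum(loggamma.(x .+ 1))
    # Classes with xₖ = 0 contribute nothing, even if pₖ = 0 (avoids 0 * log(0) = NaN).
    logterms = sum((xk == 0 ? 0.0 : xk * log(pk)) for (xk, pk) in zip(x, p))
    return logcoef + logterms
end

multinomial_pmf(x, p) = exp(multinomial_logpmf(x, p))

# K = 2: the Multinomial reduces to the Binomial, C(n, k) pᵏ (1-p)ⁿ⁻ᵏ.
p = 0.3
binomial_pmf = binomial(10, 4) * p^4 * (1 - p)^6
mult_k2 = multinomial_pmf([4, 6], [p, 1 - p])

# n = 1: the Multinomial reduces to the Categorical, returning pₖ for the observed class.
probs = [0.2, 0.3, 0.5]
categorical = multinomial_pmf([0, 1, 0], probs)

println(mult_k2, " ≈ ", binomial_pmf)
println(categorical)
```

Working in log space keeps the gamma terms from overflowing for large counts, which is why the same style is used in the loss function later on.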
Conjugate priors
There’s a lot of research on conjugate distributions for the exponential family, which in my opinion is really good. In fact, I would argue we still need a lot more. Anyway, the point here is that if you looked at the equations above you might wonder where the heck the prior distribution comes from. You’d be forgiven for wondering. In principle it can be any distribution you like, as long as you can write it down. Sounds great! The problem is that it’s very often expensive to compute, because the integrals in the evidence and the marginals above are hard to solve and most choices of prior do not give a closed-form solution. This is where conjugate priors come into the picture. A conjugate prior is a prior that, combined with a given likelihood, yields a posterior in the same family as the prior, so the posterior has a closed-form expression. This is a big deal since we won’t have to rely on MCMC techniques to do the integrations.
Ok, but so what? What is the darn prior distribution that we want to use? Well, in the case of the Multinomial likelihood, the conjugate prior is the Dirichlet distribution. In our notation above we will parameterize this distribution with \(𝛂\), where every \(α_k>0\). Thus the prior distribution for the parameters of our Multinomial distribution will be \(P(𝐩)=P(𝐩|𝛂)=\mathrm{Dir}(𝐩;𝛂)\). Written out explicitly, it is
\[\mathrm{Dir}(𝐩;𝛂)=\frac{1}{\mathrm{B}(𝛂)}\prod_{k=1}^{K}p_k^{α_k-1},\qquad\mathrm{B}(𝛂)=\frac{\prod_{k=1}^{K}Γ(α_k)}{Γ(α_0)},\qquad α_0=\sum_{k=1}^{K}α_k\]
The \(\mathrm{B}(𝛂)\) is the multivariate beta function (Wikipedia contributors 2024a) which you can read about in Wikipedia contributors (2024b). A rather convenient yet annoying fact is that the Dirichlet distribution is the conjugate prior for both the Categorical distribution as well as the Multinomial.
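As a quick numeric check of the Dirichlet density and the multivariate beta function, here is a small Julia sketch. `logmvbeta` and `dirichlet_logpdf` are hypothetical helper names. \(𝐩\) is assumed to lie strictly inside the simplex, meaning all \(p_k>0\) and \(\sum_k p_k=1\).

```julia
using SpecialFunctions  # provides loggamma

# log B(α) = Σₖ log Γ(αₖ) - log Γ(α₀), with α₀ = Σₖ αₖ
logmvbeta(α) = sum(loggamma.(α)) - loggamma(sum(α))

# log Dir(p; α) = Σₖ (αₖ - 1) log pₖ - log B(α); assumes every pₖ > 0 and Σₖ pₖ = 1
dirichlet_logpdf(p, α) = sum((α .- 1) .* log.(p)) - logmvbeta(α)

# A flat Dirichlet (all αₖ = 1) over K = 3 classes has constant density Γ(3) = 2 on the simplex.
flat = exp(dirichlet_logpdf([0.2, 0.3, 0.5], [1.0, 1.0, 1.0]))

# With K = 2 the Dirichlet is a Beta distribution; Beta(2, 2) has density 6p(1-p), i.e. 1.5 at p = 0.5.
beta22 = exp(dirichlet_logpdf([0.5, 0.5], [2.0, 2.0]))

println(flat, "  ", beta22)
```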
Note
The Dirichlet distribution is the conjugate prior for both the Categorical and the Multinomial distribution.
The closed form solutions we’ve been waiting for
With the Bayesian primer and our Multinomial distribution (Equation 4) out of the way we’re now ready to have a look at the model that this whole post is motivated by, namely the “Dirichlet-Multinomial”. It’s a model for several samples from a Categorical distribution with a prior on the class probabilities.
Since we know from the section on conjugate priors that this model indeed is conjugate we can get explicit analytical expressions for the posterior (Equation 1) and marginal (Equation 3). The posterior distribution is also Dirichlet.
\[P(𝐩|𝐗)=\mathrm{Dir}(𝐩;𝛂^′)\]
where \(𝛂^′=[α_1^′,\ldots,α_K^′]\), such that \(α_k^′=α_k+\sum_{i=1}^{N}x_i^k\) and \(x_i^k\) is the number of observed instances of class \(k\) in sample \(i\). Thus, the posterior distribution is formed by adding the observed counts from the dataset \(𝐗\) to the prior pseudocounts \(𝛂\).
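In code, the conjugate update is a single addition. The sketch below pools the two rows of Table 2 (further down) as if they were draws sharing one \(𝐩\). It also assumes a flat prior \(𝛂=[1,1,1,1]\) purely for illustration. Neither choice comes from the post.

```julia
# Flat Dirichlet prior pseudocounts over the K = 4 CMS classes (illustrative assumption).
α = [1.0, 1.0, 1.0, 1.0]

# Observed counts as a K×N matrix: rows are CMS1..CMS4, columns are the two samples in Table 2.
X = [ 4  5;
     10 15;
      3  7;
     15 11]

# Conjugate update α′ₖ = αₖ + Σᵢ xᵢᵏ (sum the counts over samples, i.e. along dims=2).
α_post = α .+ vec(sum(X, dims=2))

# Posterior mean of the class probabilities: E[pₖ | X] = α′ₖ / α′₀
posterior_mean = α_post ./ sum(α_post)

println(α_post)
println(posterior_mean)
```

With only 4 pseudocounts against 70 observed cases, the posterior mean is driven almost entirely by the data.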
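With this, it’s time to move on to the marginal, where we integrate the parameters \(𝐩\) out. Below you can see the procedure for a single datapoint \(𝐱\) with \(n=\sum_k x_k\).
\[\begin{aligned}P(𝐱|𝛂)&=\int\mathrm{Mult}(𝐱;n,𝐩)\,\mathrm{Dir}(𝐩;𝛂)\,d𝐩\\&=\frac{Γ(n+1)}{\prod_{k=1}^{K}Γ(x_k+1)}\,\frac{1}{\mathrm{B}(𝛂)}\int\prod_{k=1}^{K}p_k^{x_k+α_k-1}\,d𝐩\\&=\frac{Γ(n+1)}{\prod_{k=1}^{K}Γ(x_k+1)}\,\frac{\mathrm{B}(𝛂+𝐱)}{\mathrm{B}(𝛂)}\\&=\frac{Γ(n+1)\,Γ(α_0)}{Γ(n+α_0)}\prod_{k=1}^{K}\frac{Γ(x_k+α_k)}{Γ(α_k)\,Γ(x_k+1)}\end{aligned}\]
Again, \(α_0=\sum_kα_k\) as before. The step from the second to the third line works because the integrand is an unnormalized Dirichlet density with parameters \(𝛂+𝐱\), so the integral equals \(\mathrm{B}(𝛂+𝐱)\). It’s important to note that there’s no conditioning on previously observed data here: the likelihood only enters through the single \(𝐱\). This is the marginal data distribution, i.e., the Dirichlet-Multinomial probability mass function. The posterior predictive distribution is obtained by using the conditioned \(𝛂^′=𝛂+\sum_{i=1}^{N}𝐱_i\) instead of \(𝛂\).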
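The closed-form Dirichlet-Multinomial pmf, the same expression implemented as `dmlogpmf` further down, translates directly into code. The sketch below uses a hypothetical single-vector version, `dm_logpmf`. It checks one property that must hold for any valid pmf: summed over every count vector with a fixed \(n\), the probabilities add up to one. The \(𝛂\) values are arbitrary.

```julia
using SpecialFunctions  # provides loggamma

# Dirichlet-Multinomial log pmf for one count vector x, following the closed form.
function dm_logpmf(x::AbstractVector{<:Integer}, α::AbstractVector{<:Real})
    n, α₀ = sum(x), sum(α)
    return loggamma(n + 1) + loggamma(α₀) - loggamma(n + α₀) +
           sum(loggamma.(x .+ α) .- loggamma.(α) .- loggamma.(x .+ 1))
end

# Enumerate every count vector with K = 3 classes and n = 4 trials and sum the pmf.
α = [0.5, 2.0, 1.5]
n = 4
total = sum(exp(dm_logpmf([a, b, n - a - b], α)) for a in 0:n for b in 0:(n - a))
println("total probability = ", total)
```

Passing the updated \(𝛂^′\) instead of \(𝛂\) turns the same function into the posterior predictive.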
Note
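The posterior predictive distribution of a new datapoint \(𝐱\) is given by
\[P(𝐱|𝐗,𝛂)=\frac{Γ(n+1)\,Γ(α_0^′)}{Γ(n+α_0^′)}\prod_{k=1}^{K}\frac{Γ(x_k+α_k^′)}{Γ(α_k^′)\,Γ(x_k+1)},\qquad α_0^′=\sum_{k=1}^{K}α_k^′\]
This is crazy useful, as this compound distribution allows us to model overdispersed data. These are counts that vary more from sample to sample than a Multinomial with a single fixed \(𝐩\) allows, which the original Multinomial distribution would have a hard time capturing.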
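To see the overdispersion concretely, the sketch below compares two variances of one class count. The first is under a Dirichlet-Multinomial. The second is under a Multinomial whose \(𝐩\) is fixed at the Dirichlet mean \(\bar{𝐩}=𝛂/α_0\). It relies on the standard closed-form variance \(\mathrm{Var}(x_k)=n\bar p_k(1-\bar p_k)\frac{n+α_0}{1+α_0}\) and checks it by brute-force enumeration. The \(𝛂\) and \(n\) values are arbitrary.

```julia
using SpecialFunctions  # provides loggamma

# Dirichlet-Multinomial log pmf for one count vector (same closed form as dmlogpmf).
function dm_logpmf(x::AbstractVector{<:Integer}, α::AbstractVector{<:Real})
    n, α₀ = sum(x), sum(α)
    return loggamma(n + 1) + loggamma(α₀) - loggamma(n + α₀) +
           sum(loggamma.(x .+ α) .- loggamma.(α) .- loggamma.(x .+ 1))
end

α = [2.0, 3.0, 5.0]
α₀ = sum(α)
n = 10
pbar = α ./ α₀  # Dirichlet mean of p: [0.2, 0.3, 0.5]

# Brute-force mean and variance of x₁ by enumerating all count vectors with K = 3 and total n.
support = [[a, b, n - a - b] for a in 0:n for b in 0:(n - a)]
probs = [exp(dm_logpmf(x, α)) for x in support]
mean1 = sum(pr * x[1] for (pr, x) in zip(probs, support))
var1 = sum(pr * (x[1] - mean1)^2 for (pr, x) in zip(probs, support))

# Multinomial variance with p fixed at pbar, and the inflated Dirichlet-Multinomial variance.
var_mult = n * pbar[1] * (1 - pbar[1])
var_dm = var_mult * (n + α₀) / (1 + α₀)
println("Multinomial: ", var_mult, "  Dirichlet-Multinomial: ", var_dm, "  enumerated: ", var1)
```

The inflation factor \((n+α_0)/(1+α_0)\) is at least 1 and approaches 1 as \(α_0\) grows, so a very confident Dirichlet recovers the plain Multinomial.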
How to apply with Neural Networks
Now, none of this really matters unless we bring in some conditioning to deliver the \(𝛂\). This is what we want a model to do: look at the input data and predict the \(𝛂\) that is fed into the Dirichlet-Multinomial to explain the observed counts. In this section I will be using Julia (Bezanson et al. 2017), which in my opinion is a nice language for playing around with these concepts.
First I’ll define some dummy data inspired by Guinney et al. (2015), where they define four main consensus molecular subtypes (CMSs) of colorectal cancer.
Table 2: A fake cancer dataset for illustration purposes.

| Gender | Income | CMS1 | CMS2 | CMS3 | CMS4 |
|---|---|---|---|---|---|
| 0 | 1 | 4 | 10 | 3 | 15 |
| 1 | 2 | 5 | 15 | 7 | 11 |
With this small dataset from Table 2 we can start to define a model. We have 4 classes, so we need 4 outputs. Later on I’ll build a stupid, overparameterized neural network for this problem. Don’t judge me. It’s for illustration purposes. But I’m getting ahead of myself: first we need to implement the marginal data distribution, which is what we want to maximize. Since this is just the compound Dirichlet-Multinomial from above, we can implement its probability mass function and take the negative logarithm. That gives us a negative type II (marginal) log-likelihood, which we minimize. The function below expects the counts \(𝐱\) and the parameters \(𝛂\) as \(K×N\) matrices, with one column per datapoint.
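```julia
using SpecialFunctions

# Log pmf of the Dirichlet-Multinomial for each column (datapoint) of x.
# x: K×N matrix of observed counts, α: K×N matrix of Dirichlet parameters.
function dmlogpmf(x, α)
    α₀ = sum(α, dims=1)  # α₀ = Σₖ αₖ per datapoint (1×N)
    n = sum(x, dims=1)   # total count n per datapoint (1×N)
    # First term: log[Γ(α₀) Γ(n+1) / Γ(n+α₀)]
    ft = loggamma.(α₀) + loggamma.(n .+ 1) - loggamma.(n + α₀)
    # Second term: Σₖ log[Γ(xₖ+αₖ) / (Γ(αₖ) Γ(xₖ+1))]
    st = sum(loggamma.(α .+ x) - (loggamma.(α) + loggamma.(x .+ 1)), dims=1)
    ft + st
end

# Negative log-likelihood summed over all datapoints (the training loss).
dmnegloglik(x, α) = sum(-dmlogpmf(x, α))
```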
dmnegloglik (generic function with 1 method)
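A quick way to sanity-check `dmnegloglik` and the \(K×N\) layout (classes in rows, datapoints in columns) is to set every \(α_k=1\). In that case the Dirichlet-Multinomial is uniform over all \(\binom{n+K-1}{K-1}\) count vectors with total \(n\), so the loss must equal the sum of the logs of those counts. The sketch repeats the two definitions so it runs on its own. The all-ones \(𝛂\) is just a test value.

```julia
using SpecialFunctions  # provides loggamma

# Same definitions as above: counts x and parameters α are K×N matrices.
function dmlogpmf(x, α)
    α₀ = sum(α, dims=1)
    n = sum(x, dims=1)
    ft = loggamma.(α₀) + loggamma.(n .+ 1) - loggamma.(n + α₀)
    st = sum(loggamma.(α .+ x) - (loggamma.(α) + loggamma.(x .+ 1)), dims=1)
    ft + st
end
dmnegloglik(x, α) = sum(-dmlogpmf(x, α))

# Counts from Table 2 with classes CMS1..CMS4 in rows and the two samples in columns.
counts = [ 4  5;
          10 15;
           3  7;
          15 11]

# Flat parameters: every count vector with the same total n is equally likely.
α_flat = ones(size(counts))
loss = dmnegloglik(counts, α_flat)

# Totals are 32 and 38, so the loss should be log C(35, 3) + log C(41, 3).
expected = log(binomial(35, 3)) + log(binomial(41, 3))
println(loss, " ≈ ", expected)
```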
Next we’ll define the neural network and train it.
Let’s unpack this a bit. First we made sure that we put the training data in a format that Flux prefers. Then we define a loss function based on the compound distribution we got due to conjugacy. That’s the DirichletMultinomial distribution. We want a neural network that produces the parameters of this distribution. We then evaluate the negative log likelihood for the observed counts we have and update the network until it matches the data better.
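The original network code isn’t reproduced here, so the following is only a sketch of what such a setup can look like with Flux.jl. It assumes a recent release with the `Flux.setup` / `Flux.withgradient` training API. Several choices are my own assumptions and not the configuration behind the results below: the layer sizes, the `softplus` output (which keeps every \(α_k>0\)), the Adam learning rate, and the number of epochs.

```julia
using Flux, SpecialFunctions, Random

Random.seed!(42)  # make the random initialisation reproducible

# Dirichlet-Multinomial negative log-likelihood, repeated from above so this block runs on its own.
function dmlogpmf(x, α)
    α₀ = sum(α, dims=1)
    n = sum(x, dims=1)
    ft = loggamma.(α₀) .+ loggamma.(n .+ 1) .- loggamma.(n .+ α₀)
    st = sum(loggamma.(α .+ x) .- loggamma.(α) .- loggamma.(x .+ 1), dims=1)
    return ft .+ st
end
dmnegloglik(x, α) = -sum(dmlogpmf(x, α))

# Table 2 in Flux's features × samples layout.
X = Float32[0 1;    # Gender
            1 2]    # Income
Y = Float32[ 4  5;  # CMS1
            10 15;  # CMS2
             3  7;  # CMS3
            15 11]  # CMS4

# Hypothetical architecture: the softplus output layer keeps every αₖ > 0.
model = Chain(Dense(2 => 32, relu), Dense(32 => 4, softplus))

opt_state = Flux.setup(Adam(0.01), model)
initial_loss = dmnegloglik(Y, model(X))
for epoch in 1:1000
    loss, grads = Flux.withgradient(m -> dmnegloglik(Y, m(X)), model)
    Flux.update!(opt_state, model, grads[1])
end
final_loss = dmnegloglik(Y, model(X))
println("loss: ", initial_loss, " -> ", final_loss)
```

The network’s output is used directly as \(𝛂\), so the loss is the Dirichlet-Multinomial negative log-likelihood evaluated at whatever the network predicts for each input column.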
As you can see the loss decreased so the network indeed got trained. I know, big surprise right? Appears that 704 parameters can help when you have 2 datapoints. 😄
Now let’s look at a more understandable performance metric than the likelihood. For this type of output I would just use the standard mean absolute percentage error (MAPE) per class. We’re also interested in the epistemic uncertainty for each datapoint, which represents the model’s uncertainty about its own prediction. To get these, we define a predict function for predictions and an uncert function for estimating the epistemic uncertainty.
With these functions we can now see how well our model fits the data. Since we have such a small and simple dataset, we can just print the predicted numbers and inspect them in Table 3. Comparing this to Table 2, it’s clear that the predicted numbers are in the vicinity of the observed ones. It’s not obvious, however, which datapoint is predicted better or worse. This is where our ability to estimate epistemic uncertainty helps, since it gives us a quantitative view of how sure the model is. By inspecting the uncertainty column we can see that the model is a lot more confident regarding the second datapoint.
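Since the post’s own `predict` and `uncert` code isn’t shown here, the following is a sketch under two assumptions. Predictions are the Dirichlet-Multinomial expected counts \(n\,α_k/α_0\). Epistemic uncertainty is the evidential uncertainty mass \(u=K/α_0\) from Sensoy, Kaplan, and Kandemir (2018). The \(𝛂\) matrix is made up for the example.

```julia
# Hypothetical helpers. α is a K×N matrix of Dirichlet parameters (one column per datapoint)
# and n is a 1×N row of total counts per datapoint.

# Expected counts under the Dirichlet-Multinomial: E[xₖ] = n αₖ / α₀
predict(α, n) = n .* α ./ sum(α, dims=1)

# Evidential uncertainty mass u = K / α₀: it shrinks as the total evidence α₀ grows.
uncert(α) = size(α, 1) ./ sum(α, dims=1)

# Made-up parameters for the two datapoints; the second column carries more total evidence.
α = [2.0  4.0;
     4.0 10.0;
     1.0  3.0;
     5.0  8.0]
n = [32 38]   # totals per datapoint from Table 2

println(round.(predict(α, n); digits=1))
println(uncert(α))
```

Because \(u\) depends only on \(α_0\), two datapoints can have identical expected proportions but very different uncertainty, depending on how much total evidence the model assigns.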
Table 3: The predicted counts for each cancer subtype we modeled, together with the estimated epistemic uncertainty.

| CMS1 | CMS2 | CMS3 | CMS4 | Uncertainty |
|---|---|---|---|---|
| 5 | 11 | 4 | 13 | 0.15 |
| 5 | 14 | 7 | 12 | 0.08 |
Now we might also ask what the classwise performance is, i.e., how well are the counts in each class predicted? Here, we’ll turn to MAPE and use the function from above. The numbers are in Table 4 below.
Table 4: The performance and prevalence of each class.

| Type | CMS1 | CMS2 | CMS3 | CMS4 |
|---|---|---|---|---|
| MAPE | 12% | 8% | 17% | 11% |
| Prevalence | 13% | 36% | 14% | 37% |
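The table can be reproduced, up to rounding, from Tables 2 and 3 alone. The sketch below makes two assumptions. MAPE is taken as the per-class mean over datapoints of \(|\hat x - x|/x\). Prevalence is each class’s share of all observed cases. Because it uses the rounded predictions from Table 3, it gives 12.5%, 8.3%, 16.7%, and 11.2% for MAPE, in line with Table 4.

```julia
# Observed counts (Table 2) and rounded predicted counts (Table 3).
# Rows are datapoints, columns are CMS1..CMS4.
observed  = [4 10 3 15;
             5 15 7 11]
predicted = [5 11 4 13;
             5 14 7 12]

# Per-class MAPE: mean over datapoints of |predicted - observed| / observed (no zero counts here).
mape = vec(sum(abs.(predicted .- observed) ./ observed, dims=1)) ./ size(observed, 1)

# Prevalence: each class's share of all observed cases.
prevalence = vec(sum(observed, dims=1)) ./ sum(observed)

println("MAPE:       ", round.(100 .* mape; digits=1))
println("Prevalence: ", round.(100 .* prevalence; digits=1))
```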
Conclusion
So we’ve seen that using a conjugate prior distribution allows us to get an explicit solution for our posterior probability as well as the marginal data distribution. We’ve also seen how the resulting compound distribution can be used as a loss function through the negative log likelihood. This lends itself nicely to using arbitrarily complex models like neural networks to predict the parameters of the posterior distribution.
As a bonus we can readily estimate the epistemic uncertainty of our model, which in the case of cancer prediction can be pretty darn important.
References
Bezanson, Jeff, Alan Edelman, Stefan Karpinski, and Viral B. Shah. 2017. “Julia: A fresh approach to numerical computing.” SIAM Review 59 (1): 65–98. https://doi.org/10.1137/141000671.
Guinney, Justin, Rodrigo Dienstmann, Xin Wang, Aurélien De Reynies, Andreas Schlicker, Charlotte Soneson, Laetitia Marisa, et al. 2015. “The Consensus Molecular Subtypes of Colorectal Cancer.” Nature Medicine 21 (11): 1350–56.
Sensoy, Murat, Lance Kaplan, and Melih Kandemir. 2018. “Evidential Deep Learning to Quantify Classification Uncertainty.” https://arxiv.org/abs/1806.01768.