# How to Sample a Dirichlet-Multinomial Distribution

Consider the problem of sampling from a multinomial distribution $Mult(\vec{x}|\vec{p}, n)$, where $\vec{p}$ is sampled from a Dirichlet prior distribution $Dir(\vec{p}|\vec{\alpha})$.

A conceptually straightforward solution is to sample $\vec{p}$ from $Dir(\vec{p}|\vec\alpha)$, and then to generate $n$ samples from the discrete distribution defined by $\vec{p}$. As described on Wikipedia, sampling $\vec{p}=\{p_1,\ldots,p_K\}$ can be done by drawing $\{y_1,\ldots,y_K\}$ from $K$ Gamma distributions, $y_k \sim \Gamma(\alpha_k, 1)$ for $k\in[1,K]$, and then normalizing: $p_k = y_k/\sum_{j=1}^{K} y_j$. Also according to Wikipedia, if $\alpha_k$ is a positive integer, then $\sum_{i=1}^{\alpha_k} -\log U_i \sim \Gamma(\alpha_k, 1)$, where the $U_i$ are independent samples from the uniform distribution over $(0, 1]$. If the $\alpha_k$ are not positive integers, however, sampling from the Gamma distribution requires a considerably more complex procedure.

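Even if we implement an algorithm that draws samples from Gamma and then Dirichlet, it is not numerically robust. Many floating-point generators return values in $[0, 1)$ rather than $(0, 1]$; if one returns $U_i = 0$, then $-\log U_i$ is $+\infty$, and values of $U_i$ very close to 0 already produce very large terms. Another danger is that if all $K$ samples $y_k$ are 0 (for example through underflow, which small $\alpha_k$ make possible), $p_k = y_k/\sum_j y_j$ either triggers a divide-by-zero exception or makes $p_k$ NaN.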

Fortunately, we can exploit the conjugacy between the Dirichlet and the multinomial distributions. As explained in the textbook Pattern Recognition and Machine Learning, this conjugacy lets us interpret $\alpha_k$ as the effective number of prior observations of multinomial outcome $k$, and each new observation of outcome $k$ simply adds one to that count. This leads to the following simple sampling method, a Pólya urn scheme, which can be generalized further to sample from Dirichlet processes:

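1. $\vec{p} = \vec\alpha$, $\vec{x} = \vec{0}$, $i = 0$
2. $k \sim Discrete(\vec{p})$, i.e., draw $k$ with probability $p_k/\sum_j p_j$
3. $p_k = p_k+1$, $x_k=x_k+1$, $i=i+1$
4. If $i < n$, go to step 2; otherwise return $\vec{x}$.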
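Why this urn scheme samples the right distribution can be checked in a few lines. Write $\alpha_0 = \sum_k \alpha_k$, let $z_i$ be the outcome of the $i$-th draw, and let $\vec{x}^{(i)}$ be the counts after $i$ draws, so that step 2 sees $\vec p = \vec\alpha + \vec{x}^{(i)}$ (this notation is introduced only for the derivation). Since the posterior after $i$ draws is $Dir(\vec p \mid \vec\alpha + \vec x^{(i)})$, the probability of outcome $k$ on the next draw is its posterior mean:

$$
P(z_{i+1} = k \mid z_1,\ldots,z_i) = \int p_k \, Dir(\vec{p} \mid \vec\alpha + \vec{x}^{(i)}) \, d\vec{p} = \frac{\alpha_k + x^{(i)}_k}{\alpha_0 + i},
$$

which is exactly the probability with which step 2 picks $k$. Multiplying these factors over $i = 0, \ldots, n-1$, outcome $k$ contributes the numerators $\alpha_k(\alpha_k+1)\cdots(\alpha_k+x_k-1)$ and the denominators multiply to $\alpha_0(\alpha_0+1)\cdots(\alpha_0+n-1)$, so for any sequence with final counts $\vec x$

$$
P(z_1,\ldots,z_n) = \prod_{i=0}^{n-1} \frac{\alpha_{z_{i+1}} + x^{(i)}_{z_{i+1}}}{\alpha_0 + i} = \frac{\Gamma(\alpha_0)}{\Gamma(\alpha_0 + n)} \prod_{k=1}^{K} \frac{\Gamma(\alpha_k + x_k)}{\Gamma(\alpha_k)} = \int \prod_{i=1}^{n} p_{z_i} \, Dir(\vec p \mid \vec\alpha) \, d\vec p .
$$

The sequence, and therefore the count vector $\vec x$, has the same distribution as under the two-stage procedure of first drawing $\vec p$ and then $n$ outcomes from it.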

Full Go code is as follows:

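```go
package main

import (
	"fmt"
	"math/rand"
)

// sampleDirichletMultinomial returns the counts x of n draws from the
// Dirichlet-multinomial distribution with parameter alpha (urn scheme above).
func sampleDirichletMultinomial(alpha []float64, n int, rng *rand.Rand) []int {
	dist := make([]float64, len(alpha)) // current weights p = alpha + x
	copy(dist, alpha)
	hist := make([]int, len(alpha)) // counts x
	for i := 0; i < n; i++ {
		k := sampleDiscrete(dist, rng)
		dist[k] += 1.0 // one more observation of outcome k
		hist[k]++
	}
	return hist
}

// sampleDiscrete returns index i with probability dist[i]/sum(dist).
// The weights need not be normalized but must be non-negative and not all zero.
func sampleDiscrete(dist []float64, rng *rand.Rand) int {
	if len(dist) == 0 {
		panic("sample from empty distribution")
	}
	sum := 0.0
	for _, v := range dist {
		if !(v >= 0) { // rejects negative weights and NaN
			panic(fmt.Sprintf("bad dist: %v", dist))
		}
		sum += v
	}
	if sum == 0 {
		panic(fmt.Sprintf("all-zero dist: %v", dist))
	}
	u := rng.Float64() * sum // uniform on [0, sum)
	sum = 0
	for i, v := range dist {
		sum += v
		if u < sum {
			return i
		}
	}
	panic("sampleDiscrete fell through all possibilities")
}

func main() {
	rng := rand.New(rand.NewSource(1))
	fmt.Println(sampleDirichletMultinomial([]float64{0.5, 1, 2}, 10, rng))
}
```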