How to Sample a Dirichlet-Multinomial Distribution

Consider the problem of sampling from a multinomial distribution Mult(\vec{x}|\vec{p}, n), where \vec{p} is sampled from a Dirichlet prior distribution Dir(\vec{p}|\vec{\alpha}).

A conceptually straightforward solution is to sample \vec{p} from Dir(\vec{p}|\vec\alpha), and then to generate n samples from the discrete distribution defined by \vec{p}. As described by Wikipedia, sampling \vec{p}=\{p_1,\ldots,p_K\} can be done by drawing samples \{y_1,\ldots,y_K\} from K Gamma distributions, y_k \sim \Gamma(\alpha_k, 1) \text{,  } k\in[1,K], and then obtaining \vec{p} by normalizing the y_k: p_k = y_k/(\sum_j y_j). According to Wikipedia, if \alpha_k is a positive integer, then \sum_{i=1}^{\alpha_k} - \log U_i \sim \Gamma(\alpha_k, 1), where the U_i are independent samples drawn from the uniform distribution over (0, 1]. However, if the \alpha_k are not positive integers, sampling from the Gamma distribution requires a considerably more complex procedure.

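Even if we implement an algorithm that draws samples from the Gamma distribution and then from the Dirichlet, it would not be numerically robust. If the uniform generator can return exactly 0, \log U_i is -Inf, so y_k becomes infinite and the normalized p_k becomes Inf/Inf = NaN. Another danger is that all K of the y_k may underflow to 0 (for example, when the \alpha_k are very small); then p_k=y_k/(\sum_j y_j) divides by zero, which either raises a divide-by-zero error or makes p_k NaN, depending on the language and floating-point settings.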

Fortunately, we can make use of the conjugacy between the Dirichlet and multinomial distributions. As explained in the textbook Pattern Recognition and Machine Learning, this conjugacy lets us interpret \alpha_k as the effective number of prior observations of the multinomial outcome k. This leads to the following simple sampling method, which can be generalized further to sample from Dirichlet processes:

  1. \vec{p} = \vec\alpha, \vec{x} = \vec{0}, i = 0
  2. if i = n, return \vec{x}
  3. k \sim Discrete(\vec{p}), where P(k) = p_k/\sum_j p_j
  4. p_k = p_k+1, x_k = x_k+1, i = i+1, then go to 2.
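
The urn scheme is exact, not an approximation. Write z_i for the outcome of the i-th draw, x^{(i)}_k for the number of times k occurred among the first i draws, and A = \sum_j \alpha_j (this notation is introduced here). Integrating \vec{p} out under its posterior Dir(\vec{p}|\vec\alpha + \vec{x}^{(i)}) gives the predictive probability in the first line below, which is exactly the probability with which the algorithm picks k, because at that point the current \vec{p} equals \vec\alpha + \vec{x}^{(i)}. Chaining the n predictive steps gives the second line:

```latex
P(z_{i+1} = k \mid z_1, \ldots, z_i)
  = \int p_k \,\mathrm{Dir}(\vec{p} \mid \vec\alpha + \vec{x}^{(i)})\, d\vec{p}
  = \frac{\alpha_k + x^{(i)}_k}{A + i}

P(z_1, \ldots, z_n)
  = \prod_{i=0}^{n-1} \frac{\alpha_{z_{i+1}} + x^{(i)}_{z_{i+1}}}{A + i}
  = \frac{\Gamma(A)}{\Gamma(A + n)} \prod_{k=1}^{K} \frac{\Gamma(\alpha_k + x_k)}{\Gamma(\alpha_k)}
```

Multiplying by the number of orderings, n!/\prod_k x_k!, gives the Dirichlet-multinomial probability of the final counts \vec{x}. This sequential construction is known as the Pólya urn scheme.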

Full Go code is as follows:

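package dirmult

import (
	"fmt"
	"math/rand"
)

// sampleDirichletMultinomial draws x ~ Dirichlet-multinomial(alpha, n) with the
// urn scheme above and returns the histogram x.
func sampleDirichletMultinomial(alpha []float64, n int, rng *rand.Rand) []int {
	dist := make([]float64, len(alpha)) // current pseudo-counts, starting at alpha
	copy(dist, alpha)
	hist := make([]int, len(alpha)) // x_k: how many draws produced outcome k
	for i := 0; i < n; i++ {
		k := sampleDiscrete(dist, rng)
		dist[k] += 1.0 // the new observation adds one pseudo-count to k
		hist[k]++
	}
	return hist
}

// sampleDiscrete returns index i with probability dist[i] / sum(dist).
func sampleDiscrete(dist []float64, rng *rand.Rand) int {
	if len(dist) == 0 {
		panic("sample from empty distribution")
	}
	sum := 0.0
	for _, v := range dist {
		if v < 0 {
			panic(fmt.Sprintf("bad dist: %v", dist))
		}
		sum += v
	}
	u := rng.Float64() * sum // uniform in [0, sum)
	sum = 0
	for i, v := range dist {
		sum += v
		if u < sum {
			return i
		}
	}
	panic("sampleDiscrete gets out of all possibilities")
}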