Consider the problem of sampling from a multinomial distribution $\mathrm{Mult}(\theta)$, where $\theta$ is sampled from a Dirichlet prior distribution $\mathrm{Dir}(\alpha)$.
A conceptually straightforward solution is to sample $\theta$ from $\mathrm{Dir}(\alpha)$, and then to generate $n$ samples from the discrete distribution defined by $\theta$. As described by Wikipedia, sampling $\theta$ can be done by drawing samples from $K$ Gamma distributions, $y_k \sim \mathrm{Gamma}(\alpha_k, 1)$, and then normalizing $y$ to get $\theta$: $\theta_k = y_k / \sum_{j=1}^{K} y_j$. According to Wikipedia, if $\alpha_k$ is a positive integer, we have $y_k = -\sum_{i=1}^{\alpha_k} \ln u_i$, where each $u_i$ is a sample drawn from the uniform distribution over $(0, 1]$. However, if the $\alpha_k$'s are not positive integers, sampling from the Gamma distribution becomes a complex procedure.
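The integer case of the Gamma-then-normalize recipe can be sketched in a few lines of Go. This is only an illustration under stated assumptions: the function name `sampleDirichletIntAlpha` and the choice to pass the uniform source as a `func() float64` are mine, not part of any library, and the sketch handles positive-integer parameters only.

```go
package main

import (
	"fmt"
	"math"
	"math/rand"
)

// sampleDirichletIntAlpha (hypothetical name) draws theta ~ Dir(alpha) for
// positive-integer alpha by normalizing Gamma(alpha_k, 1) samples.
// uniform must return values in [0, 1), as (*rand.Rand).Float64 does.
// There is no guard against a zero normalizer.
func sampleDirichletIntAlpha(alpha []int, uniform func() float64) []float64 {
	y := make([]float64, len(alpha))
	sum := 0.0
	for k, a := range alpha {
		// Gamma(a, 1) with integer a is a sum of a exponential samples.
		for i := 0; i < a; i++ {
			// 1 - u lies in (0, 1], so the logarithm is never -Inf.
			y[k] -= math.Log(1 - uniform())
		}
		sum += y[k]
	}
	theta := make([]float64, len(alpha))
	for k := range y {
		theta[k] = y[k] / sum
	}
	return theta
}

func main() {
	rng := rand.New(rand.NewSource(1))
	theta := sampleDirichletIntAlpha([]int{1, 2, 1}, rng.Float64)
	fmt.Println(theta) // three non-negative weights that sum to 1
}
```

Passing `rng.Float64` as a function value keeps the sketch easy to drive with a fixed source when checking it by hand.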
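Even if we can implement the algorithm that draws samples from Gamma distributions and then normalizes them into a Dirichlet sample, this algorithm would not be numerically robust. Consider that when a uniform sample $u_i$ is close to 0, $-\ln u_i$ would be Inf. Another dangerous point is that if all $K$ Gamma samples come out as $y_k = 0$, the normalizer $\sum_{k} y_k = 0$ would lead to either a divide-by-zero interrupt or a NaN value for every $\theta_k$.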
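Both failure modes are easy to reproduce. The following is a minimal sketch, with `negLog` and `normalize` as hypothetical helper names. In Go, `float64` arithmetic follows IEEE 754, so the zero normalizer produces NaN rather than a trap.

```go
package main

import (
	"fmt"
	"math"
)

// negLog returns -ln(u), the exponential sample obtained from a uniform u.
func negLog(u float64) float64 {
	return -math.Log(u)
}

// normalize divides each entry by the total, with no guard against a
// zero total.
func normalize(y []float64) []float64 {
	sum := 0.0
	for _, v := range y {
		sum += v
	}
	theta := make([]float64, len(y))
	for k, v := range y {
		theta[k] = v / sum
	}
	return theta
}

func main() {
	fmt.Println(negLog(0))                     // u == 0 gives +Inf
	fmt.Println(normalize([]float64{0, 0, 0})) // 0/0 gives NaN in every slot
}
```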
Fortunately, we can make use of the conjugacy between the Dirichlet and multinomial distributions. This conjugacy, as explained in the textbook Pattern Recognition and Machine Learning, states that $\alpha_k$ is the prior number of observations of the multinomial output $k$. This leads to the following simple sampling method, which can be generalized further to sample from Dirichlet processes:
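1. Set $i \leftarrow 0$.
2. Draw $k$ from the discrete distribution with weights proportional to $\alpha$, then update $\alpha_k \leftarrow \alpha_k + 1$ and $i \leftarrow i + 1$.
3. While $i < n$, go to step 2.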
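The reason this works can be written as one formula. With $\theta$ integrated out, the probability of the next draw depends only on the prior parameters and the counts so far. Here $x_t$ denotes the $t$-th draw and $c_k$ the number of earlier draws equal to $k$; both symbols are introduced for this illustration.

```latex
P(x_{i+1} = k \mid x_1, \dots, x_i)
  = \frac{\alpha_k + c_k}{\sum_{j=1}^{K} \alpha_j + i},
\qquad
c_k = \#\{\, t \le i : x_t = k \,\}
```

Sampling each draw from these probabilities in turn gives the same joint distribution over the draws as first sampling $\theta$ from the Dirichlet prior and then sampling from the discrete distribution it defines, so neither a Gamma sampler nor a normalization step is needed.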
Full Go code is as follows:
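```go
import (
	"fmt"
	"math/rand"
)

// sampleDirichletMultinomial draws n outcomes from the multinomial whose
// parameter has the Dirichlet prior alpha, and returns how many times each
// outcome was drawn.
func sampleDirichletMultinomial(alpha []float64, n int, rng *rand.Rand) []int {
	// dist holds the running pseudo-counts; alpha itself is left untouched.
	dist := make([]float64, len(alpha))
	copy(dist, alpha)
	hist := make([]int, len(alpha))
	for i := 0; i < n; i++ {
		k := sampleDiscrete(dist, rng)
		dist[k] += 1.0 // the new observation raises the weight of outcome k
		hist[k]++
	}
	return hist
}

// sampleDiscrete returns an index drawn with probability proportional to
// the non-negative, unnormalized weights in dist.
func sampleDiscrete(dist []float64, rng *rand.Rand) int {
	if len(dist) <= 0 {
		panic("sample from empty distribution")
	}
	sum := 0.0
	for _, v := range dist {
		if v < 0 {
			panic(fmt.Sprintf("bad dist: %v", dist))
		}
		sum += v
	}
	// Pick a point in [0, sum) and find the weight interval that contains it.
	u := rng.Float64() * sum
	sum = 0
	for i, v := range dist {
		sum += v
		if u < sum {
			return i
		}
	}
	panic("sampleDiscrete gets out of all possibilities")
}
```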