21  Mixture Distributions and Mixture Modeling

Abstract
This chapter introduces mixture distributions and mixture modeling. These tools use simple probability distributions to model and understand complex empirical distributions.

21.1 Review 1

We’ll start with a review of multivariate normal distributions. In particular, this exercise demonstrates the impact of the variance-covariance matrix on the shape of multivariate normal distributions.

Exercise 1
  1. Load library(mvtnorm).
  2. Copy and paste the following code. This code will obviously not run as is. We will add tibbles to independent, sigma1, and sigma2 in the steps below.
bind_rows(
  independent = ,
  sigma1 = ,
  sigma2 = ,
  .id = "source"
)
  3. Create a tibble with V1 and V2. For both variables, use rnorm() to sample 1,000 observations from a standard normal distribution. Add the results to independent.
  4. Using the following variance-covariance matrix, sample 1,000 observations from a multivariate normal distribution with rmvnorm(). Add the results to sigma1 and convert them with as_tibble().
sigma1 <- matrix(
  c(1, 0,
    0, 1), 
  nrow = 2, ncol = 2, byrow = TRUE
)
  5. Using the following variance-covariance matrix, sample 1,000 observations from a multivariate normal distribution with rmvnorm(). Add the results to sigma2 and convert them with as_tibble().
sigma2 <- matrix(
  c(1, 0.8,
    0.8, 1), 
  nrow = 2, ncol = 2, byrow = TRUE
)
  6. Create a scatter plot with V1 on the x-axis and V2 on the y-axis. Facet based on source.
library(mvtnorm)

sigma1 <- matrix(
  c(1, 0,
    0, 1), 
  nrow = 2, ncol = 2, byrow = TRUE
)

sigma2 <- matrix(
  c(1, 0.8,
    0.8, 1), 
  nrow = 2, ncol = 2, byrow = TRUE
)

bind_rows(
  independent = tibble(
    V1 = rnorm(n = 1000),
    V2 = rnorm(n = 1000)
  ),
  sigma1 = rmvnorm(
    n = 1000, 
    sigma = sigma1
  ) |>
    as_tibble(),
  sigma2 = rmvnorm(
    n = 1000, 
    sigma = sigma2
  ) |>
    as_tibble(),
  .id = "source"
) |>
  ggplot(aes(V1, V2)) +
  geom_point() +
  facet_wrap(~ source)

21.2 A New Type of Random Variable

We learned about common univariate and multivariate distributions. For each of the distributions, there are well-defined and straightforward ways to sample values from the distribution. We can also manipulate these distributions to calculate probabilities.

The real world is complicated, and we will quickly come across data that we struggle to describe with common probability distributions.

Figure 21.1 shows a relative frequency histogram of the duration of eruptions at Old Faithful in Yellowstone National Park.

# faithful is a data set built into R
faithful |>
  ggplot(aes(eruptions, y = after_stat(density))) +
  geom_histogram()
Figure 21.1: Distribution of eruption durations at the Old Faithful geyser in Yellowstone National Park.

This distribution looks very complicated. But what if we break this distribution into pieces? In this case, what if we think of the distribution as a combination of two normal distributions?

Code
# show the geyser data as a mixture of two normal distributions
library(mclust)

gmm_geyser <- Mclust(
  data = dplyr::select(faithful, eruptions), 
  G = 2
)


bind_cols(
  faithful,
  cluster = gmm_geyser$classification
) |>
  ggplot(aes(eruptions, y = after_stat(density), 
             fill = factor(cluster))) +
  geom_histogram() +
  guides(fill = "none")

Latent Variable

A latent variable is a variable that isn’t directly observed but can be inferred through other variables and modeling. Sometimes the latent variable is meaningful but unobserved. Sometimes it isn’t meaningful.

Latent variables are sometimes called hidden variables.

Breaking complex problems into smaller pieces is good. These latent variables will allow us to do some cool things:

  1. Express complicated probability distributions in simple terms
  2. Make inferences about complex populations
  3. Cluster data

In this set of notes, we’ll use latent variables to

  1. Construct mixture distributions
  2. Cluster data

Let’s consider a “data generation story” different than anything we considered in Chapter 5. Instead of sampling directly from one known probability distribution, we will sample in two stages (Hastie, Tibshirani, and Friedman 2009).

  1. Sample from a discrete probability distribution with \(k\) unique values (i.e. Bernoulli distribution when \(k = 2\) and categorical distribution when \(k > 2\)).
  2. Sample from one of \(k\) different distributions conditional on the outcome of step 1.

This new sampling procedure aligns closely with the idea of hierarchical sampling and hierarchical models. It is also sometimes called ancestral sampling (Bishop 2006, 430).

This two-step approach dramatically increases the types of distributions at our disposal because we are no longer limited to individual common univariate distributions like a single normal distribution or a single uniform distribution. The two-step approach is also the foundation of two related tools:

  1. Mixture distributions: Distributions expressed as the linear combination of other distributions. Mixture distributions can be very complicated distributions expressed in terms of simple distributions with known properties.
  2. Mixture modeling: Statistical inference about sub-populations made using only pooled data without labels for the sub-populations.

With mixture distributions, we care about the overall distribution and don’t care about the latent variables.

With mixture modeling, we use the overall distribution to learn about the latent variables/sub-populations/clusters in the data.

21.3 Mixture Distributions

Mixture Distribution

A mixture distribution is a probabilistic model that is a linear combination of common probability distributions.

A discrete mixture distribution can be expressed as

\[ p_{mixture}(x) = \sum_{k = 1}^K \pi_kp_k(x) \]

where \(K\) is the number of components and \(\pi_k\) is the weight on the \(k\)th PMF, \(p_k(x)\), included in the mixture distribution.

A continuous mixture distribution can be expressed as

\[ f_{mixture}(x) = \sum_{k = 1}^K \pi_kf_k(x) \]

where \(K\) is the number of components and \(\pi_k\) is the weight on the \(k\)th PDF, \(f_k(x)\), included in the mixture distribution.
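
For example, an equal-weight mixture of a \(N(0, 1)\) component and a \(N(4, 1)\) component has density \(0.5f_1(x) + 0.5f_2(x)\). Here is a minimal sketch that evaluates and plots that density; the components and weights are chosen only for illustration:

mixture <- tibble(
  x = seq(-4, 8, 0.01),
  # weighted sum of the two component densities
  f_x = 0.5 * dnorm(x, mean = 0, sd = 1) + 0.5 * dnorm(x, mean = 4, sd = 1)
)

ggplot(mixture, aes(x, f_x)) +
  geom_line()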

21.3.1 Example 1

Let’s consider a concrete example with a Bernoulli distribution and two normal distributions.

  1. Sample \(X \sim Bern(p = 0.25)\)
  2. Sample from \(Y \sim N(\mu = 0, \sigma = 2)\) if \(X = 0\) and \(Y \sim N(\mu = 5, \sigma = 1)\) if \(X = 1\).

Now, let’s sample from a Bernoulli distribution and then sample from one of two normal distributions using R code.

generate_data <- function(n) {
  
  # step 1: sample the latent group (0 or 1) for each observation
  step1 <- sample(x = c(0, 1), size = n, replace = TRUE, prob = c(0.75, 0.25))
  
  step1 <- sort(step1)
  
  # step 2: sample from the normal distribution that matches each group
  step2 <- c(
    rnorm(n = sum(step1 == 0), mean = 0, sd = 2),
    rnorm(n = sum(step1 == 1), mean = 5, sd = 1)
  )
  
  tibble::tibble(
    x = step1,
    y = step2
  )

}

set.seed(1)

generate_data(n = 1000) |>
  ggplot(aes(x = y, y = after_stat(density))) +
  geom_histogram()

This marginal distribution looks complex, but the process that created it is simple.

In fact, consider this quote from Bishop (2006) (Page 111):

By using a sufficient number of Gaussians, and by adjusting their means and covariances as well as the coefficients in the linear combination, almost any continuous density can be approximated to arbitrary accuracy.

Component

A component is each common probability distribution that is combined to create a mixture distribution. For example, a mixture of two Gaussian distributions has two components.

Mixing Coefficient

A mixing coefficient is the probability associated with a component in a mixture distribution. Mixing coefficients must sum to 1.

We’ll use \(\pi_k\) for population mixing coefficients and \(p_k\) for sample mixing coefficients. Mixing coefficients are also called mixing weights and mixing probabilities.

Mixture distributions are often overparameterized, which means they have an excessive number of parameters. For a univariate mixture of normals with \(k\) components, we have \(k\) means, \(k\) standard deviations, and \(k\) mixing coefficients (only \(k - 1\) of which are free because the mixing coefficients must sum to 1).

Exercise 2
  1. Sample 1,000 observations from a mixture of three normal distributions with the following parameters:
  • \(p_1 = p_2 = p_3\)
  • \(\mu_1 = -3\), \(\mu_2 = 0\), \(\mu_3 = 3\)
  • \(\sigma_1 = \sigma_2 = \sigma_3 = 1\)
  2. Create a relative frequency histogram of the values.
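
One possible solution sketch follows; there are many equally valid ways to sample. This version draws the latent component first and then passes a vector of means to rnorm().

component <- sample(1:3, size = 1000, replace = TRUE, prob = c(1/3, 1/3, 1/3))

means <- c(-3, 0, 3)

tibble(y = rnorm(n = 1000, mean = means[component], sd = 1)) |>
  ggplot(aes(x = y, y = after_stat(density))) +
  geom_histogram()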

21.3.2 Example 2

Suppose we used statistical inference to infer some parameters for the geyser example above. We will describe how to estimate these parameters later.

  • \(p_1 =\) 0.3485696 and \(p_2 =\) 0.6514304
  • \(\bar{x_1} =\) 2.0189927 and \(\bar{x_2} =\) 4.2737083
  • \(s_1 =\) 0.2362355 and \(s_2 =\) 0.4365146

The mixture density is

\[ f_{mixture}(x) = p_1f(x|\mu = \bar{x_1}, \sigma = s_1) + p_2f(x|\mu = \bar{x_2},\sigma=s_2) \tag{21.1}\]

geyser_density <- function(x, model) {
  
  probs <- model$parameters$pro
  
  d1 <- dnorm(
    x, 
    mean =  model$parameters$mean[1], 
    sd = sqrt(model$parameters$variance$sigmasq[1])
  )
  
  d2 <- dnorm(
    x, 
    mean =  model$parameters$mean[2], 
    sd = sqrt(model$parameters$variance$sigmasq[2])
  )
  
  probs[1] * d1 + probs[2] * d2
  
}

mm <- tibble(
  x = seq(0, 5, 0.01),
  f_x = map_dbl(x, geyser_density, model = gmm_geyser)
) 

ggplot() +
  geom_histogram(data = faithful, mapping = aes(x = eruptions, y = after_stat(density))) +
  geom_line(data = mm, mapping = aes(x, f_x), color = "red") + 
  labs(
    subtitle = "Observed data in black, inferred distribution in red"
  )

21.4 Review 2

21.4.1 Multivariate Normal Distribution

The multivariate normal distribution is a higher-dimensional version of the univariate normal distribution. The MVN distribution has a vector of means of length \(k\) and a \(k\)-by-\(k\) variance-covariance matrix.

We denote a multivariate normally distributed random vector with

\[ \vec{X} \sim \mathcal{N}(\vec\mu, \boldsymbol\Sigma) \tag{21.2}\]

The PDF of a multivariate normally distributed random variable is

\[ f(x) = (2\pi)^{-k/2}det(\boldsymbol\Sigma)^{-1/2}\exp\left(-\frac{1}{2}(\vec{x} - \vec\mu)^T\boldsymbol\Sigma^{-1}(\vec{x} - \vec\mu)\right) \tag{21.3}\]
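
To make Equation 21.3 concrete, here is a small sketch (assuming library(mvtnorm) is loaded, as in Review 1) that evaluates the PDF by hand and checks it against dmvnorm():

mu <- c(0, 0)
sigma <- matrix(c(1, 0.8, 0.8, 1), nrow = 2)
x <- c(1, 0.5)

# evaluate Equation 21.3 directly
k <- length(mu)
quad <- drop(t(x - mu) %*% solve(sigma) %*% (x - mu))
(2 * pi)^(-k / 2) * det(sigma)^(-1 / 2) * exp(-0.5 * quad)

# compare to the built-in density function
dmvnorm(x, mean = mu, sigma = sigma)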

21.4.2 K-Means Clustering

K-Means Clustering is a heuristic-based approach to finding latent groups in data. The algorithm assigns each observation to one and only one group through a two-step iteration that minimizes the Euclidean distance between observations and the centroids of each group.

Consider the following data set.

Code
data <- tibble(x = c(1, 2, 1, 4, 7, 10, 8),
               y = c(5, 4, 4, 3, 7, 8, 5))

ggplot() +
  geom_point(data = data, aes(x, y), size = 2) +
  scale_x_continuous(limits = c(0, 10)) +
  scale_y_continuous(limits = c(0, 10)) +
  coord_equal() +
  theme_minimal()

Step 1: Randomly place K centroids in your n-dimensional vector space

Code
centroids <- tibble(x = c(2, 5),
                  y = c(5, 5),
                  cluster = c("a", "b"))

ggplot() +
  geom_point(data = data, aes(x, y), size = 2) +
  geom_point(data = centroids, aes(x, y, color = cluster), size = 4) +
  scale_x_continuous(limits = c(0, 10)) +
  scale_y_continuous(limits = c(0, 10)) +
  coord_equal() +
  theme_minimal()

Step 2: Calculate the nearest centroid for each point using a distance measure

Code
centroids <- tibble(x = c(2, 5),
                  y = c(5, 5),
                  cluster = c("a", "b"))

ggplot() +
  geom_point(data = data, aes(x, y), size = 2) +
  geom_point(data = centroids, aes(x, y, color = cluster), size = 4) +
  geom_line(aes(x = c(4, 2), y = c(3, 5)), linetype = "dashed") +  
  geom_line(aes(x = c(4, 5), y = c(3, 5)), linetype = "dashed") +
  scale_x_continuous(limits = c(0, 10)) +
  scale_y_continuous(limits = c(0, 10)) +
  coord_equal() +
  theme_minimal()

Step 3: Assign each point to the nearest centroid

Code
data$cluster <- c("a", "a", "a", "b", "b", "b", "b")

ggplot() +
  geom_point(data = data, aes(x, y, color = cluster), size = 2) +
  geom_point(data = centroids, aes(x, y, color = cluster), size = 4) +
  scale_x_continuous(limits = c(0, 10)) +
  scale_y_continuous(limits = c(0, 10)) +
  coord_equal() +
  theme_minimal()

Step 4: Recalculate the position of the centroids based on the means of the assigned points

Code
centroids2 <- data %>%
  group_by(cluster) %>%
  summarize(x = mean(x), y = mean(y))

ggplot() +
  geom_point(data = data, aes(x, y, color = cluster), size = 2) +
  geom_point(data = centroids, aes(x, y), size = 4, alpha = 0.25) +
  geom_point(data = centroids2, aes(x, y, color = cluster), size = 4) +  
  scale_x_continuous(limits = c(0, 10)) +
  scale_y_continuous(limits = c(0, 10)) +
  coord_equal() +
  theme_minimal()

Step 5: Repeat steps 2-4 until no points change cluster assignments

Code
data$cluster <- c("a", "a", "a", "a", "b", "b", "b")

ggplot() +
  geom_point(data = data, aes(x, y, color = cluster), size = 2) +
  geom_point(data = centroids2, aes(x, y, color = cluster), size = 4) +  
  scale_x_continuous(limits = c(0, 10)) +
  scale_y_continuous(limits = c(0, 10)) +
  coord_equal() +
  theme_minimal()

Exercise 3
  1. Use library(tidyclust) to cluster the faithful data into three clusters.
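
One possible sketch, assuming library(tidyclust) and its tidymodels-style fitting functions are available; treat the exact arguments as a starting point rather than the definitive answer:

library(tidyclust)

set.seed(20230612)

# specify a K-Means model with three clusters
kmeans_spec <- k_means(num_clusters = 3)

# fit the specification to the faithful data
kmeans_fit <- kmeans_spec |>
  fit(~ eruptions + waiting, data = faithful)

# extract the cluster assigned to each observation
extract_cluster_assignment(kmeans_fit)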

21.5 Mixture Modeling/Model-Based Clustering

Until now, we’ve assumed that we know all of the parameters when working with mixture distributions. What if we want to learn these parameters or make inferences about them?

The process of making inferences about latent groups is related to K-Means Clustering. While K-Means Clustering is heuristic-based, mixture modeling formalizes the process of making inferences about latent groups using probability models. Gaussian mixture models (GMM) are a popular mixture model.

Mixture Modeling

Mixture modeling is the process of making inferences about sub-populations using pooled data that contain sub-populations but no labels for the sub-populations.

21.6 Gaussian Mixture Modeling (GMM)

Gaussian Mixture Modeling (GMM)

Gaussian mixture modeling (GMM) is mixture modeling that uses normal and multivariate normal distributions.

Hard Assignment

Hard assignment assigns an observation in a clustering model to one and only one group.

Soft Assignment

Soft assignment assigns an observation in a clustering model to all groups with varying weights or probabilities.

Responsibilities

Soft assignments are quantified with responsibilities. Responsibilities are the probability that a given observation belongs to a given group. The soft assignments for an observation sum to 1.

For mixture distributions, we specified the mixing coefficients \(\pi_k\) up front. For mixture modeling, the responsibilities are parameters we will infer.
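
For example, the gmm_geyser model fit earlier with library(mclust) stores its responsibilities in the z element (an \(n\)-by-\(k\) matrix of probabilities):

# each row contains one observation's responsibilities across the k groups
head(gmm_geyser$z)

# the responsibilities for each observation sum to 1
head(rowSums(gmm_geyser$z))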

There are two main differences between K-Means Clustering and GMM.

  1. Instead of calculating Euclidean distance from each observation to each group centroid, we use multivariate normal distributions to calculate the probability that an observation belongs to each group.
    • Observations close to the means of a mixture will have a high relative probability of belonging to that mixture.
    • Observations far from the means of a mixture will have a low relative probability of belonging to that mixture.
  2. Instead of simply updating \(k\) group centroids, we must update \(k\) multivariate normal distributions. This requires calculating a vector of means and a variance-covariance matrix for each of the \(k\) groups.

21.6.1 Example 3

The parameters in example 2 were estimated using GMM. Let’s repeat a similar exercise with the faithful data using eruptions and waiting instead of just eruptions. We’ll assume there are three groups.

# fit GMM
gmm2_geyser <- Mclust(faithful, G = 3)

Let’s plot the multivariate normal distributions. Figure 21.2 shows the centroids (stars) and shapes (ellipses) of the distributions in black. The colors represent hard assignments to groups, and the size of the points represents the uncertainty of the assignments, with larger points having more uncertainty.

# plot fitted model
plot(gmm2_geyser, what = "uncertainty")
Figure 21.2: Uncertainty plot from a GMM

We can also summarize the model with library(broom).

library(broom)

augment(gmm2_geyser)
# A tibble: 272 × 4
   eruptions waiting .class .uncertainty
       <dbl>   <dbl> <fct>         <dbl>
 1      3.6       79 1          2.82e- 2
 2      1.8       54 2          8.60e-13
 3      3.33      74 1          3.26e- 3
 4      2.28      62 2          3.14e- 7
 5      4.53      85 3          1.17e- 2
 6      2.88      55 2          3.09e- 3
 7      4.7       88 3          2.99e- 3
 8      3.6       85 1          2.39e- 2
 9      1.95      51 2          5.23e-12
10      4.35      85 3          5.52e- 2
# ℹ 262 more rows
tidy(gmm2_geyser)
# A tibble: 3 × 5
  component  size proportion mean.eruptions mean.waiting
      <int> <int>      <dbl>          <dbl>        <dbl>
1         1    40      0.166           3.79         77.5
2         2    97      0.356           2.04         54.5
3         3   135      0.478           4.46         80.8
glance(gmm2_geyser)
# A tibble: 1 × 7
  model     G    BIC logLik    df hypvol  nobs
  <chr> <int>  <dbl>  <dbl> <dbl>  <dbl> <int>
1 EEE       3 -2314. -1126.    11     NA   272

21.6.2 mclust

The previous example uses library(mclust)1 and library(broom).

Mclust() is the main function for fitting Gaussian mixture models. The function supports several different models for the variance-covariance matrices of the multivariate normal distributions. The defaults are sensible. G is the number of groups. If G isn’t specified, then Mclust() will try 1:9 and pick the G with the lowest BIC (defined below).

plot() with what = "uncertainty" creates a very useful data visualization for seeing the multivariate normal distributions and classifications for low-dimensional GMM.

glance(), tidy(), and augment() from library(broom) return important information about the assignments, groups, and model diagnostics.

21.6.3 Estimation

Suppose we have \(n\) observations, \(k\) groups, and \(p\) variables. A single GMM will have

  • an \(n\) by \(k\) matrix of responsibilities
  • \(k\) vectors of means of length \(p\)
  • \(k\) \(p\) by \(p\) variance-covariance matrices

We want the maximum likelihood estimates for all of the parameters in the model. For technical reasons, it is very difficult to get these estimates using popular methods like stochastic gradient descent.

Instead, we will use expectation maximization (EM) to find the parameters. We also used an EM-style two-step iteration for K-Means Clustering.

  1. Randomly initialize all of the parameters. Calculate the log-likelihood.
  2. E-Step: Update the responsibilities assuming the means and variance-covariance matrices are known.
  3. M-Step: Estimate new means and variance-covariance matrices assuming the responsibilities are known. The means and variance-covariance matrices are calculated using weighted MLE where the responsibilities are the weights.
  4. Calculate the log-likelihood. Go back to step 2 if the log-likelihood improves by at least as much as the stopping threshold.

This algorithm is computationally efficient, but it is possible for it to find a local maximum log-likelihood without finding the global maximum log-likelihood.
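
To make these steps concrete, here is a minimal sketch of EM for a univariate mixture of two normals applied to the eruptions data. It is a simplified illustration of the steps above, not the implementation that library(mclust) uses:

set.seed(20230612)

x <- faithful$eruptions
n <- length(x)
K <- 2

# step 1: randomly initialize the parameters
pi_k <- rep(1 / K, K)
mu_k <- sample(x, size = K)
sd_k <- rep(sd(x), K)
log_lik <- -Inf

for (iter in 1:100) {
  
  # E-step: update responsibilities assuming the means and sds are known
  dens <- sapply(1:K, function(k) pi_k[k] * dnorm(x, mean = mu_k[k], sd = sd_k[k]))
  resp <- dens / rowSums(dens)
  
  # M-step: weighted MLE with the responsibilities as weights
  n_k <- colSums(resp)
  pi_k <- n_k / n
  mu_k <- colSums(resp * x) / n_k
  sq_dev <- (matrix(x, nrow = n, ncol = K) - matrix(mu_k, nrow = n, ncol = K, byrow = TRUE))^2
  sd_k <- sqrt(colSums(resp * sq_dev) / n_k)
  
  # stop when the log-likelihood no longer improves by more than the threshold
  log_lik_new <- sum(log(rowSums(dens)))
  if (log_lik_new - log_lik < 1e-6) break
  log_lik <- log_lik_new
  
}

pi_k
mu_k
sd_k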

For a more mathematical description of this process, see Elements of Statistical Learning Section 6.8 (Hastie, Tibshirani, and Friedman 2009). A highly descriptive comparison to kmeans (with Python code) can be seen here.

21.6.4 Example 4

Let’s consider a policy-relevant example using data from the Small Area Health Insurance Estimates (SAHIE) Program.

First, we pull the 2016 county-level estimates of the uninsured rate. We label a state as an expansion state if it expanded Medicaid before 2015-01-01. We use this date with 2016 data because of policy lags.

library(censusapi)

sahie <- getCensus(
  name = "timeseries/healthins/sahie",
  key = Sys.getenv("CENSUS_KEY"),
  vars = c("GEOID", "PCTUI_PT"),
  region = "county:*",
  time = 2016
) |>
  as_tibble()

Next, we pull data from the Kaiser Family Foundation about the expansion dates of Medicaid under the Patient Protection and Affordable Care Act.

states <- tribble(
  ~state, ~state_fips, ~implementation_date,
  "Alabama", "01", NA,
  "Alaska", "02", "2015-09-15",
  "Arizona", "04", "2014-01-01",
  "Arkansas", "05", "2014-01-01",
  "California", "06", "2014-01-01",
  "Colorado", "08", "2014-01-01",
  "Connecticut", "09", "2014-01-01",
  "Delaware", "10", "2014-01-01",
  "District of Columbia", "11", "2014-01-01",
  "Florida", "12", NA,
  "Georgia", "13", NA,
  "Hawaii", "15", "2014-01-01",
  "Idaho", "16", "2020-01-01",
  "Illinois", "17", "2014-01-01",
  "Indiana", "18", "2015-02-01",
  "Iowa", "19", "2014-01-01",
  "Kansas", "20", NA,
  "Kentucky", "21", "2014-01-01", 
  "Louisiana", "22", "2016-07-01",
  "Maine", "23", "2018-07-02",
  "Maryland", "24", "2014-01-01",
  "Massachusetts", "25", "2014-01-01",
  "Michigan", "26", "2014-04-01",
  "Minnesota", "27", "2014-01-01",
  "Mississippi", "28", NA,
  "Missouri", "29", "2021-07-01",
  "Montana", "30", "2016-01-01",
  "Nebraska", "31", "2020-10-01",
  "Nevada", "32", "2014-01-01", 
  "New Hampshire", "33", "2014-08-15",
  "New Jersey", "34", "2014-01-01",
  "New Mexico", "35", "2014-01-01",
  "New York", "36", "2014-01-01", 
  "North Carolina", "37", NA,
  "North Dakota", "38", "2014-01-01", 
  "Ohio", "39", "2014-01-01",
  "Oklahoma", "40", "2021-07-01", 
  "Oregon", "41", "2014-01-01", 
  "Pennsylvania", "42", "2015-01-01", 
  "Rhode Island", "44", "2014-01-01", 
  "South Carolina", "45", NA,
  "South Dakota", "46", "2023-07-01", 
  "Tennessee", "47", NA,
  "Texas", "48", NA,
  "Utah", "49", "2020-01-01",
  "Vermont", "50", "2014-01-01",
  "Virginia", "51", "2019-01-01", 
  "Washington", "53", "2014-01-01",
  "West Virginia", "54", "2014-01-01",
  "Wisconsin", "55", NA,
  "Wyoming", "56", NA
) %>%
  mutate(implementation_date = ymd(implementation_date))

sahie <- left_join(
  sahie, 
  states,
  by = c("state" = "state_fips")
) |>
  filter(!is.na(PCTUI_PT)) |> 
  mutate(expanded = implementation_date < "2015-01-01") %>%
  mutate(expanded = replace_na(expanded, FALSE))

We use GMM to cluster the data.

uni <- select(sahie, PCTUI_PT)

set.seed(1)
uni_mc <- Mclust(uni, G = 2)

glance(uni_mc)
# A tibble: 1 × 7
  model     G     BIC logLik    df hypvol  nobs
  <chr> <int>   <dbl>  <dbl> <dbl>  <dbl> <int>
1 V         2 -18542. -9251.     5     NA  3141

When we compare .class to expanded, we see that the model does a good job of labeling counties’ expansion status without observing counties’ expansion status.

bind_cols(
  sahie,
  augment(uni_mc)
) |>
  count(expanded, .class)
# A tibble: 4 × 3
  expanded .class     n
  <lgl>    <fct>  <int>
1 FALSE    1        465
2 FALSE    2       1486
3 TRUE     1       1042
4 TRUE     2        148

21.6.5 BIC

Likelihood quantifies how likely observed data are given a set of parameters. If \(\theta\) is a vector of parameters, then \(L(\theta |x) = f(x |\theta)\) is the likelihood function.

We often don’t know the exact number of latent groups in the data. We need a way to compare models with varying numbers of groups. Simply picking the model with the maximum likelihood will lead to models with too many groups.

The Bayesian information criterion (BIC) is an alternative to likelihoods that penalizes models for having many parameters. Let \(L\) be the likelihood, \(m\) the number of free parameters, and \(n\) the number of observations.

\[ BIC = -2\log(L) + m\log(n) \tag{21.4}\]

We will choose models that minimize BIC. Ideally, we will use v-fold cross validation for this process.
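
As a quick check of Equation 21.4, we can recompute BIC for gmm2_geyser from the log-likelihood, the number of free parameters, and the number of observations reported by glance(). Note that library(mclust) reports BIC with the opposite sign, which we discuss in the next example, so the two values should agree up to their signs:

fit <- glance(gmm2_geyser)

# BIC from Equation 21.4: -2log(L) + m log(n)
-2 * fit$logLik + fit$df * log(fit$nobs)

# the BIC reported by library(mclust) (opposite sign convention)
fit$BIC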

21.6.6 Example 5

The Mclust() function will try G = 1:9 when G isn’t specified. Mclust() will also try 14 different variance models for the mixture models.

Important

We want to minimize BIC, but library(mclust) reports BIC with the opposite sign. So we want to maximize the BIC plotted by library(mclust). You can read more here.

We can plot the BICs with plot() and view the optimal model with glance().

faithful_gmm <- Mclust(faithful)

plot(faithful_gmm, what = "BIC")

glance(faithful_gmm)
# A tibble: 1 × 7
  model     G    BIC logLik    df hypvol  nobs
  <chr> <int>  <dbl>  <dbl> <dbl>  <dbl> <int>
1 EEE       3 -2314. -1126.    11     NA   272

21.7 Bernoulli Mixture Modeling (BMM)

Let’s consider a data generation story based on the Bernoulli distribution. Now, each variable, \(X_1, X_2, ..., X_D\), is drawn from a mixture of \(K\) Bernoulli distributions.

\[ X_d = \begin{cases} Bern(p_{1d}) \text{ with probability }\pi_1 \\ Bern(p_{2d}) \text{ with probability }\pi_2 \\ \vdots \\ Bern(p_{Kd}) \text{ with probability }\pi_K \end{cases} \tag{21.5}\]

Let \(k\) index the mixture components and \(d\) index the variables. Conditional on component \(k\), the joint probability mass function of \(\vec{X} = (X_1, X_2, ..., X_D)\) is written as

\[ P(\vec{x}|k) = \prod_{d = 1}^D p_{kd}^{x_d} (1 - p_{kd})^{1 - x_d} \tag{21.6}\]

and the overall mixture PMF is \(P(\vec{x}) = \sum_{k = 1}^K \pi_k P(\vec{x}|k)\).
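
As a small illustration of this data generation story, here is a sketch that samples binary vectors from a mixture of \(K = 2\) Bernoulli components with \(D = 3\) variables each. The mixing coefficients and success probabilities are made up for illustration:

set.seed(20230612)

n <- 1000
pi_k <- c(0.4, 0.6)          # mixing coefficients
p <- rbind(
  c(0.9, 0.1, 0.5),          # component 1: p_11, p_12, p_13
  c(0.2, 0.8, 0.5)           # component 2: p_21, p_22, p_23
)

# step 1: sample the latent component for each observation
z <- sample(1:2, size = n, replace = TRUE, prob = pi_k)

# step 2: sample each variable from the Bernoulli distribution of that component
x <- t(sapply(z, function(k) rbinom(n = 3, size = 1, prob = p[k, ])))

head(x)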

Let’s consider a classic example from Bishop (2006) and Murphy (2022). The example uses the MNIST database, which contains 70,000 handwritten digits. The digits are stored in 784 variables, from a 28 by 28 grid, with values ranging from 0 to 255, which indicate the darkness of the pixel.

To prepare the data, we divide each pixel by 255 and then turn the pixels into indicators, with values under 0.5 set to 0 and values over 0.5 set to 1. Figure 21.3 visualizes the first four digits after reading in the data and applying this pre-processing.

source(here::here("R", "visualize_digit.R"))

mnist <- read_csv(here::here("data", "mnist_binary.csv"))

glimpse(dplyr::select(mnist, 1:10))
Rows: 60,000
Columns: 10
$ label    <dbl> 5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7, 2, 8, 6, 9, 4…
$ pix_28_1 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ pix_28_2 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ pix_28_3 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ pix_28_4 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ pix_28_5 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ pix_28_6 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ pix_28_7 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ pix_28_8 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ pix_28_9 <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
visualize_digit(mnist, 1)
visualize_digit(mnist, 2)
visualize_digit(mnist, 3)
visualize_digit(mnist, 4)
(a) 5
(b) 0
(c) 4
(d) 1
Figure 21.3: First Four Digits

The digits are labelled in the MNIST data set but we will ignore the labels and use Bernoulli Mixture Modeling to learn the latent labels or groups. We will treat each pixel as its own Bernoulli distribution and cluster observations using mixtures of 784 Bernoulli distributions. This means each cluster will contain \(784\) parameters.

21.7.1 Two Digit Example

Let’s start with a simple example using just the digits “1” and “8”. We’ll use library(flexmix) by Leisch (2004). library(flexmix) is powerful but uses different syntax than we are used to.

  1. The function flexmix() expects a matrix.
  2. The formula expects the entire matrix on the left side of the ~.
  3. We specify the distribution used during the maximization (M) step with model = FLXMCmvbinary().
library(flexmix)
mnist_18 <- mnist |>
  filter(label %in% c("1", "8")) |>
  dplyr::select(-label) |>
  as.matrix()

The starting assignments are random, so we set a seed.

set.seed(20230612)
mnist_18_clust <- flexmix(
  formula = mnist_18 ~ 1, 
  k = 2, 
  model = FLXMCmvbinary(), 
  control = list(iter.max = 100)
)

The MNIST data are already labelled, so we can compare our assignments to the labels if we convert the “soft assignments” to “hard assignments”. Note that most applications won’t have labels.

mnist |>
  filter(label %in% c("1", "8")) |>  
  bind_cols(cluster = mnist_18_clust@cluster) |>
  count(label, cluster)
# A tibble: 4 × 3
  label cluster     n
  <dbl>   <int> <int>
1     1       1   482
2     1       2  6260
3     8       1  5610
4     8       2   241

Figure 21.4 shows the estimated \(p_i\) for each pixel for each cluster. The figure shows 784 \(p_i\) for \(k = 1\) and 784 \(p_i\) for \(k = 2\). We see that the estimated parameters closely resemble the digits.

Of course, each digit can differ from these images because everyone writes differently. In some ways, these are average digits across many versions of the digits.

means_18 <- rbind(
  t(parameters(mnist_18_clust, component = 1)),
  t(parameters(mnist_18_clust, component = 2))
) |>
  as_tibble() |>
  mutate(label = NA)
visualize_digit(means_18, 1)
visualize_digit(means_18, 2)
(a) 8
(b) 1
Figure 21.4: Estimated Parameters for Each Cluster

The BMM does a good job of labeling the digits and recovering the average shape of the digits.

21.7.2 Ten Digit Example

Let’s now consider an example that uses all 10 digits.

In most applications, we won’t know the number of latent variables. First, we sample 1,0002 digits and run the model with \(k = 2, 3, ..., 12\). We’ll calculate the BIC for each hyperparameter and pick the \(k\) with lowest BIC.

set.seed(20230613)
mnist_sample <- mnist |>
  slice_sample(n = 1000) |>
  dplyr::select(-label) |>
  as.matrix()

steps <- stepFlexmix(
  formula = mnist_sample ~ 1, 
  model = FLXMCmvbinary(), 
  control = list(iter.max = 100, minprior = 0),
  k = 2:12, 
  nrep = 1
)

\(k = 7\) provides the lowest BIC. This is probably because digits like 3 and 8 are very similar.

steps

Call:
stepFlexmix(formula = mnist_sample ~ 1, model = FLXMCmvbinary(), 
    control = list(iter.max = 100, minprior = 0), k = 2:12, nrep = 1)

   iter converged  k k0    logLik      AIC      BIC      ICL
2    43      TRUE  2  2 -196191.7 395521.4 403221.6 403227.9
3    30      TRUE  3  3 -188722.8 382153.6 393706.4 393713.9
4    32      TRUE  4  4 -182949.0 372176.0 387581.4 387585.1
5    27      TRUE  5  5 -178955.2 365758.4 385016.4 385019.7
6    35      TRUE  6  6 -175448.7 360315.5 383426.1 383428.5
7    37      TRUE  7  7 -171697.0 354381.9 381345.1 381347.8
8    37      TRUE  8  8 -171282.5 355123.1 385938.8 385941.1
9    38      TRUE  9  9 -169213.3 352554.6 387223.0 387224.9
10   25      TRUE 10 10 -165521.6 346741.2 385262.2 385263.7
11   34      TRUE 11 11 -162919.3 343106.5 385480.1 385481.8
12   26      TRUE 12 12 -162253.5 343345.0 389571.1 389572.7

Next, we run the BMM on the full data with \(k = 7\).

mnist_full <- mnist |>
  dplyr::select(-label) |>
  as.matrix()

mnist_clust <- flexmix(
  formula = mnist_full ~ 1, 
  k = 7, 
  model = FLXMCmvbinary(), 
  control = list(iter.max = 200, minprior = 0)
)

The MNIST data are already labelled, so we can compare our assignments to the labels if we convert the “soft assignments” to “hard assignments”. Note that most applications won’t have labels. The rows of the table are the digits. The columns of the table are the clusters. We can see, for example, that most of the 0’s are clustered in cluster 5.

labels <- mnist |>
  bind_cols(cluster = mnist_clust@cluster)

table(labels$label, labels$cluster)
   
       1    2    3    4    5    6    7
  0    5  357  282  289 4875    1  114
  1   36  288   40   35    0 6319   24
  2  114  166  652   73   50  163 4740
  3  263  473 4786   80   41  260  228
  4 3384 1779    4  139    7   53  476
  5  325 2315 2367  173  109   59   73
  6    9   86   41 4365   42  128 1247
  7 3560 2395   21    0   25  234   30
  8  257 2582 2369   57   32  445  109
  9 3739 1797  109    5   26  136  137

Figure 21.5 shows the estimated \(p_i\) for each pixel for each cluster. The figure visualizes the \(784K\) parameters that we estimated: 784 \(p_i\) for each of the \(k = 1, 2, ..., 7\) clusters. We see that the estimated parameters closely resemble the digits.

means <- rbind(
  t(parameters(mnist_clust, component = 1)),
  t(parameters(mnist_clust, component = 2)),
  t(parameters(mnist_clust, component = 3)),
  t(parameters(mnist_clust, component = 4)),
  t(parameters(mnist_clust, component = 5)),
  t(parameters(mnist_clust, component = 6)),
  t(parameters(mnist_clust, component = 7))
) |>
  as_tibble() |>
  mutate(label = NA)

visualize_digit(means, 1)
visualize_digit(means, 2)
visualize_digit(means, 3)
visualize_digit(means, 4)
visualize_digit(means, 5)
visualize_digit(means, 6)
visualize_digit(means, 7)
(a) 3, 7, and 9
(b) 5, 7, and 8
(c) 3
(d) 6
(e) 0
(f) 1
(g) 2
Figure 21.5: Estimated Parameters for Each Cluster

The example with all digits doesn’t result in 10 distinct mixtures, but it does a fairly good job of finding structure in the data. Without labels and considering the variety of messy handwriting, this is a useful model.

21.8 Considerations

Mixture modeling is difficult for a couple of reasons:

  1. We need to assume a model. It can be difficult to assume a multivariate distribution that fits the data in all dimensions of interest.
  2. The models are overparameterized and can take a very long time to fit.

  1. library(tidyclust) currently doesn’t support mixture modeling. I hope this will change in the future.↩︎

  2. This is solely to save computation time.↩︎