The Maths Behind Generative AI

By Nicolas Muntwyler • Posted March 6th, 2023

Generative AI applications like ChatGPT, DALL-E 2 and Stable Diffusion have taken the world by storm, and hearing the buzzword "Generative AI" is now unavoidable. But have you ever wondered which mathematical foundations lie beneath the technology? What does it really mean for AI to be generative, and what other types of AI are there?

(Figure: Email classification example flow)

The term generative model originated in the field of statistical classification, a branch of statistics concerned with deciding which category a given object belongs to. Take a data point (for example, a piece of text or an image of an animal) and call it $X$. Statistical classification assigns this data point to one specific category $Y$ from a given list of categories; for example, "formal" and "casual" are two possible categories of text, and "dogs" and "cats" are categories of animals.

There are three main approaches to solving such a task:

  • Discriminant Models: Directly predict $Y$
  • Probabilistic Discriminant Models: Infer $P(Y \mid X)$
  • Probabilistic Generative Models: Infer $P(X, Y)$

We will first look at the most obvious type of model, discriminant models, and why they are not sufficient for most needs.

Discriminant Models

In this case we completely ignore that there exists an underlying distribution and directly predict $Y$. A common framework for such a purely discriminant approach is statistical learning theory. We choose a classifier $c$ and a loss function $L(y, c(x))$, which measures the cost of predicting $c(x)$ when the true label is $y$. This lets us formulate the expected risk that we want to minimize:

$$\mathbb{E}_{X,Y}\left[L(Y, c(X))\right]$$

We can't calculate this, though, since we don't know the underlying probability distribution $p(x, y)$, and we also don't want to make any assumptions about it. Therefore we approximate it empirically on our $n$ training samples:

$$\frac{1}{n} \sum_{i=1}^{n} L(y_i, c(x_i))$$

Here $(x_i, y_i)$ is the $i$-th data sample. Now we just need an algorithm that minimizes this empirical risk. Here are some famous examples:

  • Support Vector Machine (SVM)
  • Perceptron
  • Fisher's linear discriminant
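As a minimal sketch of empirical risk minimization, the Python below implements the perceptron from the list above for labels $y \in \{-1, +1\}$ and evaluates the empirical risk with the 0-1 loss. The function names and the toy data are illustrative assumptions, not taken from any library. Strictly speaking, the perceptron update rule minimizes a surrogate loss rather than the 0-1 loss itself, but on linearly separable data it reaches zero empirical 0-1 risk.

```python
def zero_one_loss(y, y_hat):
    """L(y, c(x)): 1 if the prediction is wrong, 0 otherwise."""
    return 0 if y == y_hat else 1


def empirical_risk(classifier, data):
    """(1/n) * sum_i L(y_i, c(x_i)) over a list of (x, y) pairs."""
    return sum(zero_one_loss(y, classifier(x)) for x, y in data) / len(data)


def train_perceptron(data, epochs=100, lr=1.0):
    """Perceptron for labels y in {-1, +1} with c(x) = sign(w . x + b).

    Every misclassified point nudges (w, b) towards classifying it correctly.
    """
    dim = len(data[0][0])
    w, b = [0.0] * dim, 0.0

    def score(x):
        # Reads the current w and b (closure over the enclosing variables).
        return sum(wi * xi for wi, xi in zip(w, x)) + b

    def classifier(x):
        return 1 if score(x) > 0 else -1

    for _ in range(epochs):
        mistakes = 0
        for x, y in data:
            if y * score(x) <= 0:  # misclassified or exactly on the boundary
                w = [wi + lr * y * xi for wi, xi in zip(w, x)]
                b += lr * y
                mistakes += 1
        if mistakes == 0:  # every training point is classified correctly
            break
    return classifier


if __name__ == "__main__":
    # Hypothetical, linearly separable toy data.
    toy = [((2.0, 1.0), 1), ((3.0, 2.0), 1), ((-1.0, -2.0), -1), ((-2.0, -1.0), -1)]
    c = train_perceptron(toy)
    print(empirical_risk(c, toy))  # 0.0
```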

Probabilistic Discriminant Model

In the approach above (the purely discriminant model) we did not make any assumption about the underlying distribution $p(x, y)$ when optimizing for the expected risk. Here we do. More precisely, we make an assumption about the form of $p(y \mid x)$. For example, for a binary label $y \in \{0, 1\}$ we could model it as $p(y = 1 \mid x) = \sigma(w^\top x + w_0)$, where $\sigma$ is the logistic sigmoid. (Note that this choice introduces a bias about how $X$ and $Y$ relate to each other, which is a common tradeoff: here we force the log-odds of $y$ to be linear in $x$, i.e. a linear decision boundary.) What is left is to find the optimal parameters $w$ (including the bias $w_0$). We can do this through maximum likelihood estimation on our training set:

$$\arg\max_w \log \prod_{i=1}^{n} p(x_i, y_i \mid w) = \arg\max_w \log \prod_{i=1}^{n} p(y_i \mid x_i, w)\,\underbrace{p(x_i \mid w)}_{=\,p(x_i)} = \arg\max_w \sum_{i=1}^{n} \log p(y_i \mid x_i, w)$$

What this means is that we choose our parameters $w$ so that the samples from our training data are as likely as possible in the model's world. Note that this maximization has no closed-form solution. Since the objective is differentiable (and, for the logistic model, concave), the common approach is to optimize it iteratively with gradient-based methods, i.e. gradient ascent on the log-likelihood or, equivalently, gradient descent on its negative.
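A minimal sketch of this fit for the logistic model, assuming binary labels $y \in \{0, 1\}$ and plain batch gradient ascent with a fixed learning rate (both are our choices for illustration, as are the function names). It uses the fact that the gradient of $\sum_i \log p(y_i \mid x_i, w)$ with respect to $w$ is $\sum_i (y_i - \sigma(w^\top x_i + w_0))\, x_i$, and with respect to $w_0$ is $\sum_i (y_i - \sigma(w^\top x_i + w_0))$.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def prob_y1(w, w0, x):
    """p(y = 1 | x, w) = sigma(w^T x + w0)."""
    return sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + w0)


def log_likelihood(w, w0, data):
    """sum_i log p(y_i | x_i, w) for labels y_i in {0, 1}."""
    total = 0.0
    for x, y in data:
        p = prob_y1(w, w0, x)
        total += math.log(p) if y == 1 else math.log(1.0 - p)
    return total


def fit_logistic_regression(data, lr=0.1, steps=1000):
    """Maximize the log-likelihood by batch gradient ascent."""
    dim = len(data[0][0])
    w, w0 = [0.0] * dim, 0.0
    for _ in range(steps):
        grad_w, grad_w0 = [0.0] * dim, 0.0
        for x, y in data:
            err = y - prob_y1(w, w0, x)  # (y_i - sigma(w^T x_i + w0))
            grad_w = [g + err * xi for g, xi in zip(grad_w, x)]
            grad_w0 += err
        # Step uphill: ascent on the log-likelihood = descent on its negative.
        w = [wi + lr * g for wi, g in zip(w, grad_w)]
        w0 += lr * grad_w0
    return w, w0


if __name__ == "__main__":
    # Hypothetical 1-D toy data that is NOT linearly separable.
    toy = [((-2.0,), 0), ((-1.0,), 0), ((0.5,), 0), ((-0.5,), 1), ((1.0,), 1), ((2.0,), 1)]
    w, w0 = fit_logistic_regression(toy)
    print(w, w0, log_likelihood(w, w0, toy))
```

The toy data deliberately overlap: on perfectly separable data the likelihood has no finite maximizer and the weights keep growing, which is one reason a regularization term is usually added in practice.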



Generative Model

Now to the one we are all so hyped about: the generative model approach. Here we try to model the whole underlying data distribution, namely $p(X, Y)$. The nice thing about this is that, if we get it right, we have a full understanding of the whole distribution. This means we can detect outliers, express a degree of belief in our predictions and, most importantly, generate new samples, meaning that we can create new images of cats and dogs. The usual approach is again to choose a family of parametric probabilistic models and then infer its parameters. Note that the following holds for any joint probability distribution:

$$p(x, y) = p(x \mid y)\, p(y)$$

Therefore we make an assumption about $p(x \mid y)$ (for example, that it is Gaussian for each class $y$) and additionally about $p(y)$; for example, we could model $y \sim \mathrm{Bernoulli}(\beta)$. Now we again look for the optimal parameters, i.e. those of $p(x \mid y)$ together with $\beta$, typically by maximum likelihood. Here are some famous examples:

  • Mixture of Gaussians
  • Latent Dirichlet Allocation (LDA)
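A minimal sketch of such a generative classifier in Python, assuming one-dimensional inputs, a Gaussian $p(x \mid y) = \mathcal{N}(\mu_y, \sigma_y^2)$ per class and $y \sim \mathrm{Bernoulli}(\beta)$. The Gaussian choice, the toy data and all function names are our assumptions for illustration. Maximum likelihood has a closed form here (class frequency, per-class mean and variance), and the fitted joint $p(x, y) = p(x \mid y)\,p(y)$ supports the uses mentioned above: classification through Bayes' rule, outlier detection through a low $p(x)$, and sampling new data.

```python
import math
import random


def fit_generative(data):
    """Maximum-likelihood fit of y ~ Bernoulli(beta) and p(x | y) = N(mu_y, var_y)."""
    beta = sum(y for _, y in data) / len(data)  # fraction of samples with y = 1
    params = {}
    for label in (0, 1):
        xs = [x for x, y in data if y == label]
        mu = sum(xs) / len(xs)
        var = sum((x - mu) ** 2 for x in xs) / len(xs)  # MLE variance (divides by n)
        params[label] = (mu, var)
    return beta, params


def gaussian_pdf(x, mu, var):
    return math.exp(-((x - mu) ** 2) / (2 * var)) / math.sqrt(2 * math.pi * var)


def joint(x, y, beta, params):
    """p(x, y) = p(x | y) p(y)."""
    prior = beta if y == 1 else 1.0 - beta
    return gaussian_pdf(x, *params[y]) * prior


def posterior_y1(x, beta, params):
    """Bayes' rule: p(y = 1 | x) = p(x, 1) / (p(x, 0) + p(x, 1))."""
    p0, p1 = joint(x, 0, beta, params), joint(x, 1, beta, params)
    return p1 / (p0 + p1)


def marginal(x, beta, params):
    """p(x) = sum_y p(x, y); a very small value flags x as a possible outlier."""
    return joint(x, 0, beta, params) + joint(x, 1, beta, params)


def sample(beta, params, rng):
    """Draw a new (x, y): first y ~ Bernoulli(beta), then x ~ N(mu_y, var_y)."""
    y = 1 if rng.random() < beta else 0
    mu, var = params[y]
    return rng.gauss(mu, math.sqrt(var)), y


if __name__ == "__main__":
    # Hypothetical labelled 1-D toy data.
    toy = [(-2.0, 0), (-1.5, 0), (-1.0, 0), (1.0, 1), (1.5, 1), (2.0, 1), (2.5, 1)]
    beta, params = fit_generative(toy)
    print(posterior_y1(2.0, beta, params))  # close to 1
    print(marginal(10.0, beta, params))  # tiny: x = 10 looks like an outlier
    print(sample(beta, params, random.Random(0)))  # a brand-new (x, y) pair
```

If both classes shared one variance, the resulting $p(y = 1 \mid x)$ would be exactly a sigmoid of a linear function of $x$, i.e. the logistic model from the probabilistic discriminant section; the generative model additionally describes how $x$ itself is distributed.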

Introduction of Deep Learning

Up until now we always made explicit assumptions about the underlying distribution $p(x, y)$, for example that it is a mixture of Gaussians. However, such assumptions rarely hold exactly in practice. With the introduction of deep learning, we can instead model a distribution with a deep neural network: $p(x, y) = \text{DeepNeuralNetwork}_{\theta}$, parametrized by its weights $\theta$, which we optimize. You can already imagine that this is a big game changer. We no longer have to hand-pick a simple parametric family and can model the underlying distribution much more directly (the network architecture is still a modeling choice, but a far more flexible one). The only remaining question is how well the deep neural network can approximate the distribution. Luckily, a lot of progress has been made in recent years, and for many domains, such as text, images and even audio, this is now possible. A lot of effort has gone into finding deep neural network architectures that do this task well, and, perhaps surprisingly, different architectures are better suited to different domains. Here are some famous examples:

  • Transformers (the GPT series, e.g. GPT-3 and ChatGPT)
  • Generative Adversarial Networks (StyleGAN)
  • Diffusion Models (DALL-E 2)
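To make $p(x, y) = \text{DeepNeuralNetwork}_\theta$ slightly more concrete, here is a sketch of how autoregressive transformers such as the GPT series are commonly set up; it is one option among several, and GANs and diffusion models are trained differently. There is no separate label here: a text $x$ is treated as a token sequence $x_1, \dots, x_T$, the network outputs each conditional distribution, and $\theta$ is fitted by maximum likelihood over $n$ training texts $x^{(1)}, \dots, x^{(n)}$, the same principle used for the probabilistic discriminant model:

$$p_\theta(x) = \prod_{t=1}^{T} p_\theta(x_t \mid x_1, \dots, x_{t-1}), \qquad \theta^\ast = \arg\max_\theta \sum_{i=1}^{n} \log p_\theta\big(x^{(i)}\big)$$

Generating new text then amounts to sampling $x_1$, then $x_2$ given $x_1$, and so on.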

Summary

Generative models got their name because they can generate new samples. They can do so because they model the complete underlying distribution. This is a very hard problem, and deep neural networks are what have made it feasible in practice.

