{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# From Classic Archetypal Analysis to Multimodal Deep Archetypal Analysis" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "### Archetypal Analysis" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Archetypal analysis is a technique in statistics and machine learning for uncovering the extreme points of a dataset, termed archetypes. The archetypes represent the most distinct or extreme manifestations within the data space, and every data point is approximated as a mixture of them. \n", "Given a dataset represented by the matrix $X \\in \\mathbb{R}^{n \\times d}$, where each row is a $d$-dimensional data point, archetypal analysis seeks two matrices: $B \\in \\mathbb{R}^{k \\times n}$, which defines the $k$ archetypes $Z = BX$ as convex combinations of the data points, and $A \\in \\mathbb{R}^{n \\times k}$, whose rows contain the coefficients that express each data point in $X$ as a convex combination of the archetypes.\n", "\n", "The core optimization problem in archetypal analysis is to minimize the reconstruction error between the dataset $X$ and its approximation $ABX$, formally expressed as:\n", "\n", "```{math}\n", "\\min_{A, B} \\|X - ABX\\|^2_F\n", "```\n", "\n", "subject to constraints on both $A$ and $B$ that make the combinations convex: $A_{ij} \\geq 0$ and $\\sum_{j} A_{ij} = 1$ for all $i$, so that the coefficients of each data point are non-negative and sum to one. \n", "Similarly, $B_{ij} \\geq 0$ and $\\sum_{j} B_{ij} = 1$ for all $i$, so that each archetype (each row of $BX$) can be interpreted as a mixture of data points. \n", "These constraints ensure that each data point in $X$ is represented as a convex combination of archetypes, and each archetype as a convex combination of data points, making the solution interpretable and reflective of the underlying structure of the dataset.\n", "\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "The original algorithm proposed by [Cutler and Breiman](https://www.tandfonline.com/doi/abs/10.1080/00401706.1994.10485840) solves alternating least-squares problems. It iterates between two steps: updating the coefficients $A$ for a fixed set of archetypes $Z = BX$, and then updating the archetypes (through $B$) given $A$. Initially, the archetypes are set to randomly selected data points. In each iteration, $A$ is updated to represent each data point as a convex combination of the current archetypes, and then $B$ is updated so that the archetypes better fit the data given the new coefficients. The process stops when the change in the archetypes between iterations, measured using the Frobenius norm $\\|Z_{new} - Z_{old}\\|_F$, falls below a predefined convergence threshold, or when a maximum number of iterations is reached." 
] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Let us first define some convenience functions:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Some code was taken/adapted from the amazing implementation in\n", "# https://github.com/aleixalcacer/archetypes\n", "\n", "import numpy as np\n", "from scipy.optimize import nnls\n", "\n", "\n", "def frobenius_norm_difference(M1, M2):\n", "    \"\"\"Compute the Frobenius norm of the difference between two matrices.\"\"\"\n", "    return np.linalg.norm(M1 - M2, 'fro')\n", "\n", "\n", "def initialize_archetypes(X, k):\n", "    \"\"\"Randomly initialize the coefficients A (n x k) and archetype weights B (k x n).\"\"\"\n", "    # Each row is sampled from a Dirichlet distribution,\n", "    # so it is non-negative and sums to one\n", "    B = np.random.dirichlet(np.ones(X.shape[0]), k)\n", "    A = np.random.dirichlet(np.ones(k), X.shape[0])\n", "    return A, B\n", "\n", "\n", "def optimize_nnls(M1, M2):\n", "    \"\"\"Solve M1 ~ res @ M2 row by row with non-negative least squares.\"\"\"\n", "    # Append a constant column to both matrices: a soft penalty that\n", "    # pushes each row of the solution to sum to one (convexity)\n", "    M1 = np.pad(M1, ((0, 0), (0, 1)), \"constant\", constant_values=20)\n", "    M2 = np.pad(M2, ((0, 0), (0, 1)), \"constant\", constant_values=20)\n", "    res = np.empty((M1.shape[0], M2.shape[0]))\n", "\n", "    # Solve one non-negative least-squares problem per row of M1\n", "    for i in range(res.shape[0]):\n", "        res[i], _ = nnls(M2.T, M1[i])\n", "\n", "    # Normalize rows to sum exactly to one; all-zero rows become NaN\n", "    # and are replaced by uniform weights\n", "    res /= res.sum(1)[:, None]\n", "    res[np.isnan(res)] = 1 / res.shape[1]\n", "    return res\n", "\n", "\n", "def update_archetypes(X, H):\n", "    \"\"\"Update the archetype weights B (archetypes Z = B @ X) for fixed coefficients.\"\"\"\n", "    return optimize_nnls(H, X)\n", "\n", "\n", "def update_weights(X, H):\n", "    \"\"\"Update the coefficients A for fixed archetypes H.\"\"\"\n", "    return optimize_nnls(X, H)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We then run the training loop and have a look at the results. For this example we choose a toy dataset of body measurements for 3 different species of penguins." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Alternating optimization loop for archetypal analysis\n", "def archetypal_analysis(X, k, max_iter=30, tol=1e-3):\n", "    \"\"\"Perform archetypal analysis and return the coefficients A and archetype weights B.\"\"\"\n", "    A, B = initialize_archetypes(X, k)\n", "    H = B @ X  # current archetypes\n", "    loss_old = None\n", "    for _ in range(max_iter):\n", "        # Update the coefficients for fixed archetypes\n", "        A = update_weights(X, H)\n", "        # Unconstrained least-squares archetypes, then project them\n", "        # back onto convex combinations of the data points\n", "        H = np.linalg.pinv(A) @ X\n", "        B = update_archetypes(X, H)\n", "        H = B @ X\n", "        loss = frobenius_norm_difference(X, A @ H)\n", "        # Stop when the reconstruction error no longer improves by at least tol\n", "        if loss_old is not None and loss_old - loss < tol:\n", "            break\n", "        loss_old = loss\n", "    return A, B" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<table>\n",
"<thead>\n",
"<tr><th></th><th>species</th><th>island</th><th>bill_length_mm</th><th>bill_depth_mm</th><th>flipper_length_mm</th><th>body_mass_g</th><th>sex</th></tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr><th>0</th><td>Adelie</td><td>Torgersen</td><td>39.1</td><td>18.7</td><td>181.0</td><td>3750.0</td><td>Male</td></tr>\n",
"<tr><th>1</th><td>Adelie</td><td>Torgersen</td><td>39.5</td><td>17.4</td><td>186.0</td><td>3800.0</td><td>Female</td></tr>\n",
"<tr><th>2</th><td>Adelie</td><td>Torgersen</td><td>40.3</td><td>18.0</td><td>195.0</td><td>3250.0</td><td>Female</td></tr>\n",
"<tr><th>3</th><td>Adelie</td><td>Torgersen</td><td>NaN</td><td>NaN</td><td>NaN</td><td>NaN</td><td>NaN</td></tr>\n",
"<tr><th>4</th><td>Adelie</td><td>Torgersen</td><td>36.7</td><td>19.3</td><td>193.0</td><td>3450.0</td><td>Female</td></tr>\n",
"</tbody>\n",
"</table>\n",
"