diff --git a/.gitignore b/.gitignore index 5a050af5..df02ef97 100644 --- a/.gitignore +++ b/.gitignore @@ -117,3 +117,5 @@ notebooks/*.html .vscode .DS_Store + +.idea/ diff --git a/notebooks/2c-Input-Driven-Transitions-and-Observations-GLM-HMM.ipynb b/notebooks/2c-Input-Driven-Transitions-and-Observations-GLM-HMM.ipynb new file mode 100644 index 00000000..245bc99d --- /dev/null +++ b/notebooks/2c-Input-Driven-Transitions-and-Observations-GLM-HMM.ipynb @@ -0,0 +1,938 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "321343a2", + "metadata": {}, + "source": [ + "# Input Driven Observations and Transitions (GLM-HMM)" + ] + }, + { + "cell_type": "markdown", + "id": "3244b244", + "metadata": {}, + "source": [ + "This notebook was written by Zeinab Mohammadi. It defines a class called \"HMM_TO\", which implements a GLM-HMM with input-driven observations and transitions. The model is an HMM enriched with two sets of per-state GLMs: a Bernoulli GLM, which models the observations (the mouse's choices), and a multinomial GLM, which models the transitions between states. This framework captures the interplay between covariates, mouse choices, and state transitions. As a result, it offers a more refined description of behavioral activity than classical models." + ] + }, + { + "cell_type": "markdown", + "id": "b6d71ee3", + "metadata": {}, + "source": [ + "Within this model, we define two sets of covariates: $\\mathbf{U}^{ob}={u}^{ob}_{1}, ..., {u}^{ob}_{T}$ are the observation covariates, and $\\mathbf{U}^{tr}={u}^{tr}_{1}, ..., {u}^{tr}_{T}$ are the transition covariates, where $T$ is the number of trials considered. The input at each trial has size ${M}_{obs}$ for the observation model and ${M}_{tran}$ for the transition model. 
Additionally, we have a set of latent states denoted as $\\mathbf{Z}={z}_{1}, ..., {z}_{T}$, and the corresponding observations denoted as $\\mathbf{Y}={y}_{1}, ..., {y}_{T}$. For more detailed information about the model, please refer to our paper: [Identifying the factors governing internal state switches during nonstationary sensory decision-making](https://www.biorxiv.org/content/10.1101/2024.02.02.578482v2) " + ] + }, + { + "cell_type": "markdown", + "id": "95cffc78", + "metadata": {}, + "source": [ + "If you have any questions, please do not hesitate to contact me at zm6112 at princeton dot edu. Your engagement and feedback are highly appreciated for the improvement of this work." + ] + }, + { + "cell_type": "markdown", + "id": "7db06eb3", + "metadata": {}, + "source": [ + "## Bernoulli GLM for observation" + ] + }, + { + "cell_type": "markdown", + "id": "3de2de69", + "metadata": {}, + "source": [ + "We use a Bernoulli GLM to map a set of covariates to the animal's binary decision through a vector of weights. These weights describe how the inputs of the model (e.g., stimulus features) influence the output (i.e., the animal's choice on each trial). The logit link is the most widely used link function for a two-alternative forced-choice (2AFC) GLM; it can be written as $\\log(p/(1-p)) = F \\beta$, where $p$ is the probability of one of the two choices, $F$ is the design matrix, and $\\beta$ is a vector of coefficients. 
\n", + "\n", + "Consequently, the observation GLM can be written as follows, where the animal's choice at trial $t$, denoted by $y_{t}$, takes the value 1 or 0 when the mouse turns the wheel to the right or to the left, respectively:" + ] + }, + { + "cell_type": "markdown", + "id": "f95d31fd", + "metadata": {}, + "source": [ + "$$\n", + "\\begin{align}\n", + "\\Pr(y_t = 1 \\mid z_{t} = k, u_{t}^{ob}) = \n", + "\\frac{\\exp\\{(w_{kt}^{ob})^\\mathsf{T} u_{t}^{ob} \\}}\n", + "{1+\\exp\\{(w_{kt}^{ob})^\\mathsf{T} u_{t}^{ob} \\}}\n", + "\\end{align}\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "id": "59aa0211", + "metadata": {}, + "source": [ + "In this equation, $u_{t}^{ob} \\in \\mathbb{R}^{{M}_{obs}}$ are the observation covariates and $w_{kt}^{ob}$ are the observation weights at trial $t$ for state $z_{t} = k$." + ] + }, + { + "cell_type": "markdown", + "id": "cddddbbb", + "metadata": {}, + "source": [ + "## Multinomial GLM for transition" + ] + }, + { + "cell_type": "markdown", + "id": "845ea356", + "metadata": {}, + "source": [ + "The multinomial GLM is an extension of the Generalized Linear Model, specifically designed to handle data with multiple categories. It is also known as softmax regression or the maximum entropy classifier. Unlike logistic regression, which deals with binary outcomes, a multinomial GLM handles outcomes with more than two categories. It relates the independent variables to a categorical dependent variable and yields the probability of each category.\n", + "\n", + "We explore the GLM-HMM with multinomial GLM outputs, a method for estimating the likelihood of the next state. 
In this framework, each state is equipped with a multinomial GLM, allowing it to model the complex relationship between transition covariates $u_{t}^{tr} \\in \\mathbb{R}^{{M}_{tran}}$, such as previous choice and reward, and the corresponding transition probabilities. This can be written as:" + ] + }, + { + "cell_type": "markdown", + "id": "3a5e6c2a", + "metadata": {}, + "source": [ + "$$\n", + "\\begin{align}\n", + "\\Pr(z_t=k \\mid u_{t}^{tr}) = \n", + "\\frac{\\exp\\{(w_{kt}^{tr})^\\mathsf{T} u_{t}^{tr} \\}}\n", + "{\\sum_{j=1}^{K} \\exp\\{(w_{jt}^{tr})^\\mathsf{T} u_{t}^{tr} \\}}\n", + "\\end{align}\n", + "$$" + ] + }, + { + "cell_type": "markdown", + "id": "c0cf811c", + "metadata": {}, + "source": [ + "where $w_{jt}^{tr}$ corresponds to the transition weights associated with j-th state at trial $t$ and $K$ represents the total number of states. " + ] + }, + { + "cell_type": "markdown", + "id": "59eaccf2", + "metadata": {}, + "source": [ + "## 1. Setup\n", + "The line `import ssm` imports the package for use. Here, we have also imported a few other packages for plotting. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "d285dbf4", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import numpy.random as npr\n", + "import matplotlib.pyplot as plt\n", + "import ssm\n", + "import random\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "from ssm.util import find_permutation\n", + "\n", + "%matplotlib inline\n", + "npr.seed(0)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "7556a945", + "metadata": {}, + "outputs": [], + "source": [ + "# Set the parameters of the GLM-HMM framework\n", + "time_bins = 5000 # number of data points\n", + "num_states = 3 # number of discrete states\n", + "obs_dim = 1 # number of observed dimensions\n", + "num_categories = 2 # number of categories for output\n", + "input_dim_T = 2 # Transition input dimensions\n", + "input_dim_O = 2 # Observation input dimensions" + ] + }, + { + "cell_type": "markdown", + "id": "7f7077f9", + "metadata": {}, + "source": [ + "# 2. Defining matrices of regressors for the model\n", + "\n", + "Here, we generate two design matrices, one for observation and one for transition regressors. Within each matrix, a column represents a covariate. These covariates may include elements such as past choices or past stimuli, which are deemed important in influencing the animal's decision-making process." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "a91b420d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "inpt_pc= [-1 -1 1 ... -1 -1 1]\n", + "inpt_wsls= [-1 -1 1 ... 
1 1 1]\n" + ] + } + ], + "source": [ + "# Specifying the regressors (past choice, past stimuli, etc.)\n", + "inpt_pc = np.array(random.choices([-1, 1], k=time_bins)) \n", + "inpt_wsls = np.array(random.choices([-1, 1], k=time_bins)) \n", + "print('inpt_pc=', inpt_pc)\n", + "print('inpt_wsls=', inpt_wsls)" + ] + }, + { + "cell_type": "markdown", + "id": "a1b8a3f7", + "metadata": {}, + "source": [ + "# 2a. Designing input matrix for Transition \n", + "In this section, we define the inputs for the transition matrix. While we're illustrating a simple case here, these specific regressors play a significant role in the transitions between different states within the GLM-HMM. To exemplify the inclusion of time and history in the regressors, we have applied an exponential filter to the transition regressors." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "2ab033bc", + "metadata": {}, + "outputs": [], + "source": [ + "design_mat_T = np.zeros((time_bins, input_dim_T)) # transition design matrix" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "e465e5e7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "design_mat_T= [[-1. -1. ]\n", + " [-1. -1. 
]\n", + " [-0.05263158 -0.05263158]\n", + " ...\n", + " [-0.23778504 -0.00444354]\n", + " [-0.4918567 0.33037098]\n", + " [ 0.00542887 0.55358065]]\n" + ] + } + ], + "source": [ + "# Creating an exponential filter for the transition regressors\n", + "def ewma_time_series(values, period): \n", + " df_ewma = pd.DataFrame(data=np.array(values))\n", + " ewma_data = df_ewma.ewm(span=period)\n", + " ewma_data_mean = ewma_data.mean()\n", + " return ewma_data_mean\n", + "\n", + "Taus = [5] # time constant for the exponential filter\n", + "add = np.array(Taus).shape[0]\n", + "\n", + "for i, tau in enumerate(Taus):\n", + " design_mat_T[:, i] = ewma_time_series(inpt_pc, tau)[0] \n", + " \n", + "for i, tau in enumerate(Taus):\n", + " design_mat_T[:, i+add] = ewma_time_series(inpt_wsls, tau)[0] \n", + "\n", + "transition_input = design_mat_T\n", + "print('design_mat_T=', design_mat_T)" + ] + }, + { + "cell_type": "markdown", + "id": "bce5725f", + "metadata": {}, + "source": [ + "# 2b. Designing input matrix for Observation " + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "3a887a05", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "design_mat_Obs= [[-1. -1.]\n", + " [-1. -1.]\n", + " [ 1. 1.]\n", + " ...\n", + " [-1. 1.]\n", + " [-1. 1.]\n", + " [ 1. 1.]]\n" + ] + } + ], + "source": [ + "design_mat_Obs = np.zeros((time_bins, input_dim_O)) # observation design matrix\n", + "design_mat_Obs[:, 0] = inpt_pc \n", + "design_mat_Obs[:, 1] = inpt_wsls \n", + "observation_input = design_mat_Obs\n", + "print('design_mat_Obs=', design_mat_Obs)" + ] + }, + { + "cell_type": "markdown", + "id": "ce9ffee4", + "metadata": {}, + "source": [ + "# 3. 
Creating a GLM-HMM\n", + "In this section, we create a GLM-HMM with the following transition and observation components:\n", + "\n", + "\n", + "```python\n", + "true_glmhmm = ssm.HMM_TO(num_states, obs_dim, M_trans, M_obs, observations=\"input_driven_obs_diff_inputs\", observation_kwargs=dict(C=num_categories), transitions=\"inputdrivenalt\")\n", + "```\n", + "\n", + "This constructor call has two parts:\n", + "\n", + "**a) Observation model:**\n", + "The observation model is a categorical class indicated by `observations=\"input_driven_obs_diff_inputs\"`. Within this model, the animal's choices are influenced by a set of inputs to the system. This observation class is intended for models in which the transitions are also driven by external inputs, so the observation and transition models each receive their own covariates. Also, `M_obs` is the number of covariates used in the observation model.\n", + "\n", + "The number of response categories is set with `observation_kwargs=dict(C=num_categories)`. Here, `C = 2` since there are only two possible choices for the animal, making the observations binary. \n", + "\n", + "\n", + "**b) Transition model:**\n", + "The transition model, specified as `transitions=\"inputdrivenalt\"`, is a multiclass logistic regression in which `M_trans` is the number of covariates influencing the transitions between states. In this model, several regressors can jointly determine the transitions between states.\n", + "\n", + "The model's number of states is set by `num_states`, which represents the number of hypothesized strategies in this task."
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a055856e", + "metadata": {}, + "outputs": [], + "source": [ + "# Make a GLM-HMM\n", + "true_glmhmm = ssm.HMM_TO(num_states, obs_dim, M_trans=input_dim_T, M_obs=input_dim_O,\n", + " observations=\"input_driven_obs_diff_inputs\", observation_kwargs=dict(C=num_categories),\n", + " transitions=\"inputdrivenalt\")" + ] + }, + { + "cell_type": "markdown", + "id": "8777a51c", + "metadata": {}, + "source": [ + "# 3a. Initializing the model\n", + "To ensure that the model accurately reflects mouse behavior, we must bring the GLM-HMM into an appropriate parameter regime (see Mohammadi et al. (2024)) and provide a good initialization for the observation and transition weights.\n", + "\n", + "By choosing these parameters carefully, we aim to create a model that closely mirrors the behavioral patterns observed in mice. This makes the simulated data more representative of real experiments and the model more reliable for studying mouse behavior." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "bca3525a", + "metadata": {}, + "outputs": [], + "source": [ + "gen_weights = np.array([[[5, 2]], [[2, -4]], [[-3, 4]]])\n", + "gen_log_trans_mat = np.log(np.array([[[0.94, 0.03, 0.03], [0.05, 0.90, 0.05], [0.04, 0.04, 0.92]]]))\n", + "Ws_transition = np.array([[[3, 1], [2, .6]]])\n", + "\n", + "true_glmhmm.observations.params = gen_weights\n", + "true_glmhmm.transitions.params[0][:] = gen_log_trans_mat\n", + "true_glmhmm.transitions.params[1][None] = Ws_transition" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "53fff4d5", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Plot generative parameters:\n", + "fig = plt.figure(figsize=(10, 3), dpi=80, facecolor='w', edgecolor='k')\n", + "cols = ['darkviolet', 'gold', 'chocolate']\n", + "\n", + "# 1) Observation\n", + "gen_weights_obs = gen_weights\n", + "\n", + "plt.subplot(1, 2, 1)\n", + "for k in range(num_states):\n", + " if k ==0:\n", + " plt.plot(range(input_dim_O), gen_weights_obs[k][0], marker='D',\n", + " color=cols[k], linestyle='-', lw=1.5)\n", + " else:\n", + " plt.plot(range(input_dim_O), gen_weights_obs[k][0], marker='D',\n", + " color=cols[k], linestyle='-', lw=1.5)\n", + "\n", + "plt.yticks(fontsize=10)\n", + "plt.ylabel(\"Weight value\", fontsize=15)\n", + "plt.xticks([0, 1], ['Obs_in1', 'Obs_in2'], fontsize=12, rotation=45)\n", + "plt.axhline(y=0, color=\"k\", alpha=0.5, ls=\"--\")\n", + "plt.title(\"Observation GLM\")\n", + "\n", + "# 2) transition \n", + "gen_log_trans_mat = true_glmhmm.transitions.params[0]\n", + "gen_weights_Trans = true_glmhmm.transitions.params[1]\n", + "generative_weights_Trans = true_glmhmm.trans_weights_K(true_glmhmm.params, num_states)\n", + "\n", + "plt.subplot(1, 2, 2)\n", + "for k in range(num_states):\n", + " if k ==0:\n", + " plt.plot(range(input_dim_T), generative_weights_Trans[k], marker='D',\n", + " color=cols[k], linestyle='-', lw=1.5)\n", + " else:\n", + " plt.plot(range(input_dim_T), generative_weights_Trans[k], marker='D',\n", + " color=cols[k], linestyle='-', lw=1.5)\n", + "\n", + "plt.yticks(fontsize=10)\n", + "plt.axhline(y=0, color=\"k\", alpha=0.5, ls=\"--\")\n", + "plt.xticks([0, 1], ['Tran_in1', 'Tran_in2'], fontsize=12, rotation=45)\n", + "plt.title(\"Transition GLM\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "6d6d1729", + "metadata": {}, + "source": [ + "# 3b. 
Sample data from the GLM-HMM" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "d3048b9f", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/5r/j_22d45j5wn_6j9p3016x8p00000gw/T/ipykernel_56319/3193879353.py:20: MatplotlibDeprecationWarning: The 'b' parameter of grid() has been renamed 'visible' since Matplotlib 3.5; support for the old name will be dropped two minor releases later.\n", + " plt.grid(b=None)\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAgMAAACfCAYAAABk4NpGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAxOAAAMTgF/d4wjAAALkElEQVR4nO3df0zV9R7H8dc5nB8gB1EEEWQKscJm+KMfuiPdBe3ems3sF9dbOS1tXq8t2ZzuunJGLlOv1+o/dfmDaN47nJvFMu9abWUtydnKudB7USemE7ygWYgCHvjcP0wSAYGEiPt+PjY2+P78nPPlfM9z56fHOecEAADM8vb3AAAAQP8iBgAAMI4YAADAOGIAAADjiAEAAIwjBgAAMI4YAADAOGIAGCD27dunxx57TKNGjVIwGFRycrLC4bAWL17cusz69ev19ttv39R+Vq1apffee+/mBgtgQPHwoUPAb98HH3yg6dOnKzc3V/PmzVNKSoqqqqr01VdfqaSkRKdOnZIk3XHHHUpMTNSnn376i/cVCoWUn59/01EBYODw9fcAAHRt7dq1ysjI0Icffiif7+eb7ZNPPqm1a9f248gA/D/gaQJgADh79qwSExPbhMBVXu+Vm3F6errKy8u1Z88eeTweeTwepaenS5IaGhq0ePFiTZgwQfHx8UpISFA4HFZpaWmbbXk8HtXX16u4uLh1G7m5ua3zq6urNX/+fKWlpSkQCCgjI0MrVqxQJBJps50NGzZo/PjxCoVCiouL05gxY/TSSy/17pUCoNfwyAAwAITDYW3evFkFBQWaOXOm7rzzTvn9/jbLvPvuu8rPz1d8fLzWr18vSQoGg5KkxsZGnTt3TkuWLNHIkSPV1NSkjz/+WI8//riKioo0e/ZsSVJZWZnuv/9+5eXlafny5ZKkwYMHS7oSApMmTZLX69XLL7+szMxMlZWVaeXKlaqsrFRRUZEkqaSkRM8//7wWLlyodevWyev16ujRozp06NCvcl0B+AUcgN+82tpad++99zpJTpLz+/1uypQpbvXq1a6urq51ubFjx7r77ruvy+1FIhF3+fJl99xzz7mJEye2mRcbG+ueeeaZduvMnz/fhUIhd+LEiTbT161b5yS58vJy55xzL7zwghsyZEjPLySAfsPTBMAAMGzYMH3++efav3+/1qxZo0ceeUQVFRV68cUXlZ2drdra2i63sWPHDuXk5CgUCsnn88nv92vLli06fPhwt8awa9cu5eXlKTU1VZFIpPVn6tSpkqQ9e/ZIkiZNmqTz58/rqaeeUmlpabfGBqB/EQPAAHL33Xdr6dKl2rFjh06fPq1FixapsrKyyxcR7ty5UzNmzNDIkSO1bds2lZWVaf/+/Zo7d64aGhq6te8zZ87o/fffl9/vb/MzduxYSWq90581a5a2bt2qEydO6IknntDw4cM1efJkffTRRzd34QH0GV4zAAxQfr9f
hYWFevPNN/Xtt9/ecNlt27YpIyND27dvl8fjaZ3e2NjY7f0lJiZq3Lhxeu211zqcn5qa2vr7nDlzNGfOHNXX1+uzzz5TYWGhpk2bpoqKCo0ePbrb+wTw6yAGgAGgqqpKKSkp7aZffYj/6h1xMBjUpUuX2i3n8XgUCATahEB1dXW7dxPcaBvTpk3T7t27lZmZqaFDh3Zr3LGxsZo6daqampr06KOPqry8nBgAfoP40CFgABg3bpzS0tL08MMPa8yYMWppadGBAwf0+uuvq66uTnv37lV2draeffZZlZSUqLi4WLfccouio6OVnZ2toqIizZ07VwsWLFB+fr5OnjypV199VV6vV0eOHNG1p4Hc3FwdPnxYmzdvVkpKiuLi4pSVlaWqqiqFw2HFxMSooKBAWVlZamhoUGVlpXbv3q2NGzcqLS1N8+bNU0xMjHJycpSSkqLq6mqtXr1ax48f15EjR5SUlNSP1ySADvXzCxgBdMP27dvd008/7W699VYXCoWc3+93o0aNcrNmzXKHDh1qXa6ystI98MADLi4uzklyo0ePbp23Zs0al56e7oLBoLv99tvdpk2bXGFhobv+NHDgwAGXk5PjBg0a5CS1eXdCTU2NKygocBkZGc7v97uEhAR31113uWXLlrkLFy4455wrLi52eXl5Ljk52QUCAZeamupmzJjhDh482KfXEYBfjkcGAAAwjncTAABgHDEAAIBxxAAAAMYRAwAAGEcMAABgXI8/dCjKH6WUJE/XC17jvz+G5K1v6umu2mmOC+jab1Pw1UWklhZJUmRwQOrZsLptxKAL7aZVXwx1uKzvh84up0eReP+1f2pETPvtXutiS5Qunonq7jC7FvArEtP+SvIFIkr0df2RtNX1Ifl+vPnj2Jnrj+9V1x9nX11Eci2t8+NHNMrv6aODL6m2KtDpvEh85/M65HEaEVN/ZV3nVHspTh4nRXXjenVBv5IS6nu2v59UN8bKd+5y1/uICai5g4sUE2xSfNSVMZ6pCynqQtfj9SZ5ldCN/6tfy42OY5+KilLi8Eutx7s3eH0tGh642On864938+CAkmOvnG++bw6qsdHf2aq9wlffIl33tdY34+r5fXhMnbx9dFvvtf+Pn453X6tr8am+IVqS5L0seS92fpuMxAc0YtAF1ZyN6vRTR3scA9FDo/Xd16ldL3iNu15ZoMS3ynq6q3ZOLpqixoSf7wSyVlao+ew5SVLl0rAisX3zLsljf9rYblrm9r+0m+ZpljKXfNnhNrzR0apYMaH175aYFh2f/tYN9/vOj4n6x5i0ng32BtykCTqWH91u+u0TT2jXbf/qcv2sLQuUvvzmj2NnTi2coobhLe2mZ606puaaGklS5V/Dynzj32r+/vvW+W/sK9PYQEyfjevB1Akdz/BG6eiKe3q0rZZBzTr+8CZJ0vHLF/T7nUsU+MGrUa/s7XLdpvvv0SdbN/Vof1fd9tlsZTx5sMvlLj40Wad/1/5km5/7pf6WfECSNG7d80p5o+vxDisdqn9mfNLjsfaVB0dOlPrhndS+0aP1wRelrce7Nwy+5by+uaek0/nXH+/vFk/R4T9f+Vrr+afC+viL8b0yjs5kbaxV83+O9tr2ji8Lqzno9PUf31S8t/dv682uRQ+NvLNXtnX1ePe1v5/L1MaP/iBJSvzGoyHvdH5uPrY8rKMzN2jU3Z1HCk8TAABgHDEAAIBxxAAAAMYRAwAAGEcMAABgHDEAAIBxxAAAAMYRAwAAGEcMAABgHDEAAIBxxAAAAMYRAwAAGEcMAABgHDEAAIBxxAAAAMYRAwAAGEcMAABgHDEAAIBxxAAAAMYRAwAAGEcMAABgHDEAAIBxxAAAAMYRAwAAGEcMAABgHDEAAIBxxAAAAMYRAwAAGEcMAABgHDEAAIBxxAAAAMYRAwAAGEcMAABgHDEAAIBxxAAAAMYRAwAAGEcMAABgHDEAAIBxxAAAAMYRAwAAGEcMAABgHDEAAIBxxAAAAMYR
AwAAGEcMAABgHDEAAIBxxAAAAMYRAwAAGEcMAABgHDEAAIBxxAAAAMYRAwAAGEcMAABgHDEAAIBxxAAAAMYRAwAAGEcMAABgHDEAAIBxxAAAAMYRAwAAGEcMAABgHDEAAIBxxAAAAMYRAwAAGEcMAABgHDEAAIBxxAAAAMYRAwAAGEcMAABgHDEAAIBxxAAAAMYRAwAAGEcMAABgHDEAAIBxxAAAAMYRAwAAGEcMAABgHDEAAIBxxAAAAMYRAwAAGEcMAABgHDEAAIBxxAAAAMYRAwAAGEcMAABgHDEAAIBxxAAAAMYRAwAAGEcMAABgHDEAAIBxxAAAAMYRAwAAGEcMAABgHDEAAIBxxAAAAMYRAwAAGEcMAABgHDEAAIBxxAAAAMYRAwAAGEcMAABgHDEAAIBxxAAAAMYRAwAAGEcMAABgHDEAAIBxxAAAAMYRAwAAGEcMAABgHDEAAIBxxAAAAMYRAwAAGEcMAABgHDEAAIBxxAAAAMYRAwAAGEcMAABgHDEAAIBxxAAAAMZ5nHOuJysEg0ElJSX11XgAAEAfqKmpUWNjY4fzehwDAADg/wtPEwAAYBwxAACAccQAAADGEQMAABhHDAAAYBwxAACAccQAAADGEQMAABhHDAAAYBwxAACAcf8DVhT6wN4GiWMAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Sample some data from the GLM-HMM\n", + "true_states, obs = true_glmhmm.sample(time_bins, transition_input=transition_input, observation_input=observation_input)\n", + "\n", + "# Plot the data\n", + "T= 500\n", + "fig = plt.figure(figsize=(8, 2), dpi=80, facecolor='w', edgecolor='k')\n", + "plt.imshow(true_states[None, :], aspect=\"auto\")\n", + "plt.xticks([])\n", + "plt.xlim(0, T)\n", + "plt.yticks([])\n", + "plt.title(\"States\", fontsize=15)\n", + "\n", + "# For visualizing categorical observations, we create a Cmap. \n", + "obs_flat = np.array([x[0] for x in obs]) # SSM initially provides categorical observations as a list of lists, and we transform them into a 1D array to facilitate plotting.\n", + "fig = plt.figure(figsize=(8, 2), dpi=80, facecolor='w', edgecolor='k')\n", + "plt.imshow(obs_flat[None,:], aspect=\"auto\")\n", + "plt.xlim(0, T)\n", + "plt.xlabel(\"trial #\", fontsize=12)\n", + "plt.yticks([])\n", + "plt.grid(b=None)\n", + "plt.title(\"Observations\", fontsize=15)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "b3e6b0df", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "true_lp = -2030.8011892459606\n" + ] + } + ], + "source": [ + "# Calculate the actual log likelihood by summing over discrete states\n", + "true_lp = true_glmhmm.log_probability(obs, transition_input=transition_input, observation_input=observation_input)\n", + "print(\"true_lp = \" + str(true_lp))" + ] + }, + { + "cell_type": "markdown", + "id": "0f99279f", + "metadata": {}, + "source": [ + "# 4. Make a new HMM for fitting purpose\n", + "\n", + "In this section, we instantiate a new GLM-HMM and assess its ability to recover generative parameters using Maximum Likelihood Estimation (MLE) on simulated data. 
MLE enables us to optimize model parameters to match the underlying data generation process, providing insights into the model's performance in emulating various scenarios." + ] + }, + { + "cell_type": "markdown", + "id": "944eafb0", + "metadata": {}, + "source": [ + "## EM for fitting" + ] + }, + { + "cell_type": "markdown", + "id": "ff4c3e26", + "metadata": {}, + "source": [ + "We employ the Expectation-Maximization (EM) method to fit the data. The EM algorithm consists of two main steps: the E-step and the M-step. The algorithm starts with an initial guess for the model parameters and iterates until the log marginal likelihood converges. During the E-step, the algorithm computes the expected value of the complete-data log-likelihood based on the model parameters estimate and the observed data (animal choices). Next, the M-Step tries to find the parameters that maximize the expected log-likelihood." + ] + }, + { + "cell_type": "markdown", + "id": "2b7e7566", + "metadata": {}, + "source": [ + "To elaborate further, during each trial and based on the specified GLM-HMM parameters, we compute the joint probability distribution encompassing both the states and the animals' decisions (left or right). Subsequently, the log-likelihood of the model is evaluated using this joint probability distribution. 
This relationship can be expressed as follows:\n", + "\n", + "$$\n", + "\\begin{align}\n", + "\\log \\left[ p(\\mathbf{Y} \\mid \\theta, \\mathbf{U}^{ob}, \\mathbf{U}^{tr})\\right]= \\log \\left[ \\sum_{\\mathbf{Z}} \n", + "p(\\mathbf{Y}, \\mathbf{Z} \\mid \\theta, \\mathbf{U}^{ob}, \\mathbf{U}^{tr})\\right]\n", + "\\end{align}\n", + "$$\n", + "\n", + "In this model, as mentioned, we have defined two sets of covariates, $\\mathbf{U}^{ob}={u}^{ob}_{1}, ..., {u}^{ob}_{T}$ and $\\mathbf{U}^{tr}={u}^{tr}_{1}, ..., {u}^{tr}_{T}$, which correspond to the observation and transition covariates respectively, and a set of latent states $\\mathbf{Z}={z}_{1}, ..., {z}_{T}$; the sum runs over all possible state sequences.\n", + "The model parameters, $\\theta=\\{ \\mathbf{w}^{tr}, \\mathbf{w}^{ob}, \\pi\\}$, comprise the transition weights, the observation weights for all states, and the initial state distribution $\\pi$." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "7589af90", + "metadata": {}, + "outputs": [], + "source": [ + "glmhmm = ssm.HMM_TO(num_states, obs_dim, input_dim_T, input_dim_O, observations=\"input_driven_obs_diff_inputs\", \n", + " observation_kwargs=dict(C=num_categories), transitions=\"inputdrivenalt\")" + ] + }, + { + "cell_type": "markdown", + "id": "259f87b9", + "metadata": {}, + "source": [ + "# 4a. Fit the new HMM\n", + "Here, we fit the model to the simulated data. Through this fitting process, the model tries to capture the underlying dependencies that characterize the behavior of interest. Comparing the fitted parameters with the generative ones shows how well the model aligns with the data and how well it might generalize to other cases." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "70103025", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2d9b0344720444a0b5fe8c4fd5c50196", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/200 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Visualize the Learned Parameters\n", + "# Plot the log probabilities of the true and fit models\n", + "fig = plt.figure(figsize=(4, 3), dpi=80, facecolor='w', edgecolor='k')\n", + "plt.plot(hmm_lps, label=\"EM\")\n", + "plt.plot([0, N_iters], true_lp * np.ones(2), ':k', label=\"True\")\n", + "plt.legend(loc=\"lower right\")\n", + "plt.xlabel(\"EM Iteration\")\n", + "plt.xlim(0, len(hmm_lps))\n", + "plt.ylabel(\"Log Probability\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "c990ef26", + "metadata": {}, + "outputs": [], + "source": [ + "glmhmm.permute(find_permutation(true_states, glmhmm.most_likely_states(obs, transition_input=transition_input, observation_input=observation_input)))" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "a3794b24", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Plot generative parameters \n", + "# 1) Observation \n", + "gen_weights_obs = gen_weights\n", + "recovered_weights = glmhmm.observations.params\n", + "\n", + "fig = plt.figure(figsize=(10, 3), dpi=80, facecolor='w', edgecolor='k')\n", + "plt.subplot(1, 2, 1)\n", + "for k in range(num_states):\n", + " if k == 0:\n", + " plt.plot(range(input_dim_O), gen_weights_obs[k][0], marker='D', color=cols[k], linestyle='-',\n", + " lw=1.5, label=\"generative\")\n", + " else:\n", + " plt.plot(range(input_dim_O), gen_weights_obs[k][0], marker='D', color=cols[k], linestyle='-',\n", + " lw=1.5, label=\"\")\n", + "\n", + "plt.yticks(fontsize=10)\n", + "plt.ylabel(\"Weight value\", fontsize=12)\n", + "plt.xticks([0, 1], ['Obs_inp1', 'Obs_inp2'], fontsize=12, rotation=30)\n", + "plt.axhline(y=0, color=\"k\", alpha=0.5, ls=\"--\")\n", + "plt.title(\"Observation GLM\", fontsize=12)\n", + "\n", + "for k in range(num_states):\n", + " if k == 0:\n", + " plt.plot(range(input_dim_O), recovered_weights[k][0], color=cols[k],\n", + " lw=1.5, label = \"recovered\", linestyle = '--')\n", + " else:\n", + " plt.plot(range(input_dim_O), recovered_weights[k][0], color=cols[k],\n", + " lw=1.5, label = '', linestyle = '--')\n", + "plt.yticks(fontsize=10)\n", + "plt.axhline(y=0, color=\"k\", alpha=0.5, ls=\"--\")\n", + "plt.legend()\n", + "\n", + "# 2) transition \n", + "gen_log_trans_mat = true_glmhmm.transitions.params[0]\n", + "gen_weights_Trans = true_glmhmm.transitions.params[1]\n", + "recovered_trans_mat = np.exp(glmhmm.transitions.params[0])\n", + "\n", + "plt.subplot(1, 2, 2)\n", + "recovered_weights_Trans = glmhmm.trans_weights_K(glmhmm.params, num_states)\n", + "generative_weights_Trans = true_glmhmm.trans_weights_K(true_glmhmm.params, num_states)\n", + "\n", + "for k in range(num_states): \n", + " if k == 0:\n", + " plt.plot(range(input_dim_T), generative_weights_Trans[k], marker='D', color=cols[k], 
linestyle='-',\n", + " lw=1.5, label=\"generative\")\n", + " plt.plot(range(input_dim_T), recovered_weights_Trans[k], color=cols[k],\n", + " lw=1.5, label = \"recovered\", linestyle = '--')\n", + " else:\n", + " plt.plot(range(input_dim_T), generative_weights_Trans[k], marker='D',\n", + " color=cols[k], linestyle='-', lw=1.5, label=\"\")\n", + " plt.plot(range(input_dim_T), recovered_weights_Trans[k], color=cols[k],\n", + " lw=1.5, label = '', linestyle = '--')\n", + " \n", + "plt.yticks(fontsize=10)\n", + "plt.title(\"Transition GLM\", fontsize=12)\n", + "plt.axhline(y=0, color=\"k\", alpha=0.5, ls=\"--\")\n", + "plt.xticks([0, 1], ['Tran_inp1', 'Tran_inp2'], fontsize=12, rotation=30)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "cfb67801", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# transition matrix\n", + "recovered_matrix = glmhmm.Ps_matrix(data=obs, transition_input=transition_input, observation_input=observation_input) # , train_mask=train_mask)[0]\n", + "gen_trans_mat = np.exp(gen_log_trans_mat)\n", + "\n", + "fig = plt.figure(figsize=(6, 3), dpi=80, facecolor='w', edgecolor='k')\n", + "plt.subplot(1, 2, 1)\n", + "plt.imshow(gen_trans_mat, vmin=-0.8, vmax=1, cmap='bone')\n", + "for i in range(gen_trans_mat.shape[0]):\n", + " for j in range(gen_trans_mat.shape[1]):\n", + " text = plt.text(j, i, str(np.around(gen_trans_mat[i, j], decimals=2)), ha=\"center\", va=\"center\",\n", + " color=\"k\", fontsize=12)\n", + "plt.xlim(-0.5, num_states - 0.5)\n", + "plt.ylim(num_states - 0.5, -0.5)\n", + "plt.xticks(range(0, num_states), ('1', '2', '3'), fontsize=10)\n", + "plt.yticks(range(0, num_states), ('1', '2', '3'), fontsize=10)\n", + "plt.ylabel(\"state t\", fontsize=15)\n", + "plt.xlabel(\"state t+1\", fontsize=15)\n", + "plt.title(\"generative\", fontsize=15)\n", + "\n", + "plt.subplot(1, 2, 2)\n", + "plt.imshow(np.mean(recovered_matrix, axis=0), vmin=-0.8, vmax=1, cmap='bone')\n", + "for i in range(np.mean(recovered_matrix, axis=0).shape[0]):\n", + " for j in range(np.mean(recovered_matrix, axis=0).shape[1]):\n", + " text = plt.text(j, i, str(np.around(np.mean(recovered_matrix, axis=0)[i, j], decimals=2)), ha=\"center\", va=\"center\",\n", + " color=\"k\", fontsize=12)\n", + "plt.xlim(-0.5, num_states - 0.5)\n", + "plt.ylim(num_states - 0.5, -0.5)\n", + "plt.xticks(range(0, num_states), ('1', '2', '3'), fontsize=10)\n", + "plt.yticks(range(0, num_states), ('1', '2', '3'), fontsize=10)\n", + "plt.title(\"recovered\", fontsize=15)\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "0741dcef", + "metadata": {}, + "source": [ + "# 4c. 
Analysis of the acquired states\n" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "d0350e34", + "metadata": {}, + "outputs": [], + "source": [ + "# Get expected states\n", + "posterior_probs = [glmhmm.expected_states(data = obs, transition_input = transition_input, observation_input=observation_input)[0]]\n", + "\n", + "# Determine the state with the highest posterior probability\n", + "posterior_max = np.argmax(posterior_probs[0], axis = 1)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "e65b35c8", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# calculate state fractional occupancies\n", + "# bincount with minlength keeps one entry per state, even for a state that is never the most likely one\n", + "occur_for_state = np.bincount(posterior_max, minlength=num_states)\n", + "sum_all = np.sum(occur_for_state)\n", + "occur_for_state = occur_for_state/sum_all\n", + "\n", + "fig = plt.figure(figsize=(2.5, 2.5), dpi=80, facecolor='w', edgecolor='k')\n", + "\n", + "for state, occur in enumerate(occur_for_state):\n", + " occur_perc = occur * 100\n", + " plt.bar(state, occur_perc, width = 0.7, color = cols[state])\n", + " \n", + "# the bars are in percent, so the y-axis limit must be in percent too\n", + "plt.ylim((0, 60))\n", + "plt.xticks([0, 1, 2], ['1', '2', '3'], fontsize=12)\n", + "plt.yticks([0, 25, 50], ['0', '25', '50'], fontsize=12)\n", + "plt.xlabel('state', fontsize=15)\n", + "plt.ylabel('Occupancy (%)', fontsize=15)\n", + "plt.gca().spines['right'].set_visible(False)\n", + "plt.gca().spines['top'].set_visible(False)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "66269f94", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig = plt.figure(figsize=(12, 2.5), dpi=80, facecolor='w', edgecolor='k')\n", + "for k in range(num_states):\n", + " plt.plot(posterior_probs[0][0:200, k], label=\"State \" + str(k + 1), lw=1, marker='*',\n", + " color=cols[k]) \n", + "\n", + "plt.ylim((-0.01, 1.01))\n", + "plt.yticks([0, 0.5, 1], fontsize=10)\n", + "plt.xticks(fontsize=12)\n", + "plt.xlabel(\"trial number\", fontsize=15)\n", + "plt.ylabel(\"Posterior prob.\", fontsize=15)\n", + "plt.gca().spines['right'].set_visible(False)\n", + "plt.gca().spines['top'].set_visible(False)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "610c21a7", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# plot choices and latents:\n", + "plt.figure(figsize=(8, 3.5))\n", + "time_bin= 500\n", + "\n", + "plt.subplot(211)\n", + "plt.imshow(true_states[None, :], aspect=\"auto\")\n", + "plt.xticks([])\n", + "plt.xlim(0, time_bin)\n", + "plt.ylabel(\"true\\nstate\", fontsize=14)\n", + "plt.yticks([])\n", + "\n", + "plt.subplot(212)\n", + "inferred_states = glmhmm.most_likely_states(obs, transition_input=transition_input, observation_input=observation_input)\n", + "plt.imshow(inferred_states[None, :], aspect=\"auto\")\n", + "plt.xlim(0, time_bin)\n", + "plt.ylabel(\"inferred\\nstate\", fontsize=14)\n", + "plt.yticks([])\n", + "plt.xlabel(\"trial #\", fontsize=12)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14b48101", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/2c-Input-Driven-Transitions-and-Observations-GLM-HMM.py b/notebooks/2c-Input-Driven-Transitions-and-Observations-GLM-HMM.py new file mode 100644 index 00000000..27a97969 --- /dev/null +++ b/notebooks/2c-Input-Driven-Transitions-and-Observations-GLM-HMM.py @@ -0,0 +1,454 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: light +# format_version: '1.5' +# jupytext_version: 1.16.0 +# kernelspec: +# display_name: Python 3 (ipykernel) +# language: python +# name: python3 +# --- + +""" +Input Driven Observations and Transitions (GLM-HMM) 
+=================================================== +""" + +# # Input Driven Observations and Transitions (GLM-HMM) + +# This notebook is written by Zeinab Mohammadi. Here, a class called "HMM_TO" is defined which contains GLM-HMM with Input-Driven Observations and Transitions. We used an HMM enriched with two sets of per-state GLM. These GLMs consist of a Bernoulli GLM, which models observations (mice choices), and a multinomial GLM, which handles transitions between different states. This sophisticated framework effectively captures the dynamic interplay between covariates, mouse choices, and state transitions. As a result, it offers a more refined description of behavioral activity when compared to classical models. + +# Therefore, within the context of our model, we define two sets of covariates: $\mathbf{U}^{ob}={u}^{ob}_{1}, ..., {u}^{ob}_{T}$ represents the observation covariates, and $\mathbf{U}^{tr}={u}^{tr}_{1}, ..., {u}^{tr}_{T}$ represents the transition covariates where $T$ is the number of considered trials. Here, the input at each trial has a size of ${M}_{obs}$ for the observation model and ${M}_{tran}$ for the transition model. Additionally, we have a set of latent states denoted as $\mathbf{Z}={z}_{1}, ..., {z}_{T} $, and corresponding observations for these states denoted as $\mathbf{Y}={y}_{1}, ..., {y}_{T}$. For more detailed information about the model, please refer to our paper: [Identifying the factors governing internal state switches during nonstationary sensory decision-making](https://www.biorxiv.org/content/10.1101/2024.02.02.578482v2) + +# If you have any questions, please do not hesitate to contact me at zm6112 at princeton dot edu. Your engagement and feedback are highly appreciated for the improvement of this work. + +# ## Bernoulli GLM for observation + +# We employed a Bernoulli GLM to map the binary values of the animal's decision to a set of covariates. 
The GLM weights describe how the inputs of the model (e.g., stimulus features) influence the output (i.e., the animal's choice on each trial). The logit link is the most widely used link function for a 2AFC choice GLM, and it can be expressed as $\log(p/(1-p)) = F \beta$, where $p$ is the choice probability, $F$ is the design matrix, and $\beta$ is a vector of coefficients. +# +# Consequently, the observation GLM can be written as follows, where the animal's choice at trial $t$, denoted by $y_{t}$, takes the value 1 or 0, indicating that the mouse turned the wheel to the right or to the left, respectively: + +# $$ +# \begin{align} +# \Pr(y_t = 1 \mid z_{t} = k, u_{t}^{ob}) = +# \frac{\exp\{(w_{kt}^{ob})^\mathsf{T} u_{t}^{ob} \}} +# {1+\exp\{(w_{kt}^{ob})^\mathsf{T} u_{t}^{ob} \}} +# \end{align} +# $$ + +# In this equation, $u_{t}^{ob} \in \mathbb{R}^{{M}_{obs}}$ are the observation covariates and $w_{kt}^{ob}$ are the observation weights at trial $t$ for state $z_{t} = k$. + +# ## Multinomial GLM for transition + +# The multinomial GLM extends logistic regression from binary outcomes to outcomes with more than two categories. It is also known as softmax regression or the maximum entropy classifier. It relates a set of independent variables to a categorical dependent variable and assigns a probability to each category. +# +# In the GLM-HMM, we use multinomial GLM outputs to model the probability of the next state. Each state is equipped with a multinomial GLM that maps the transition covariates $u_{t}^{tr} \in \mathbb{R}^{{M}_{tran}}$, such as previous choice and reward, to the corresponding transition probabilities.
This can be written as: + +# $$ +# \begin{align} +# \Pr(z_t=k \mid u_{t}^{tr}) = +# \frac{\exp\{(w_{kt}^{tr})^\mathsf{T} u_{t}^{tr} \}} +# {\sum_{j=1}^{K} \exp\{(w_{jt}^{tr})^\mathsf{T} u_{t}^{tr} \}} +# \end{align} +# $$ + +# where $w_{jt}^{tr}$ corresponds to the transition weights associated with j-th state at trial $t$ and $K$ represents the total number of states. + +# ## 1. Setup +# The line `import ssm` imports the package for use. Here, we have also imported a few other packages for plotting. + +# + +import numpy as np +import numpy.random as npr +import matplotlib.pyplot as plt +import ssm +import random +import pandas as pd +import seaborn as sns +from ssm.util import find_permutation + +# %matplotlib inline +npr.seed(0) +# - + +# Set the parameters of the GLM-HMM framework +time_bins = 5000 # number of data points +num_states = 3 # number of discrete states +obs_dim = 1 # number of observed dimensions +num_categories = 2 # number of categories for output +input_dim_T = 2 # Transition input dimensions +input_dim_O = 2 # Observation input dimensions + +# # 2. Defining matrices of regressors for the model +# +# Here, we generate two design matrices, one for observation and one for transition regressors. Within each matrix, a column represents a covariate. These covariates may include elements such as past choices or past stimuli, which are deemed pivotal in influencing the animal's decision-making process. + +# Specifying the regressors (past choice, past stimuli, etc.) +inpt_pc = np.array(random.choices([-1, 1], k=time_bins)) +inpt_wsls = np.array(random.choices([-1, 1], k=time_bins)) +print('inpt_pc=', inpt_pc) +print('inpt_wsls=', inpt_wsls) + +# # 2a. Designing input matrix for Transition +# In this section, we define the inputs for the transition matrix. While we're illustrating a simple case here, these specific regressors play a pivotal role in the transitions between different states within the GLM-HMM. 
To exemplify the inclusion of time and history in the regressors, we have applied an exponential filter to the transition regressors. + +design_mat_T = np.zeros((time_bins, input_dim_T)) # transition design matrix + + +# + +# Creating an exponential filter for the transition regressors +def ewma_time_series(values, period): + df_ewma = pd.DataFrame(data=np.array(values)) + ewma_data = df_ewma.ewm(span=period) + ewma_data_mean = ewma_data.mean() + return ewma_data_mean + +Taus = [5] # time constant for the exponential filter +add = np.array(Taus).shape[0] + +for i, tau in enumerate(Taus): + design_mat_T[:, i] = ewma_time_series(inpt_pc, tau)[0] + +for i, tau in enumerate(Taus): + design_mat_T[:, i+add] = ewma_time_series(inpt_wsls, tau)[0] + +transition_input = design_mat_T +print('design_mat_T=', design_mat_T) +# - + +# # 2b. Designing input matrix for Observation + +design_mat_Obs = np.zeros((time_bins, input_dim_O)) # observation design matrix +design_mat_Obs[:, 0] = inpt_pc +design_mat_Obs[:, 1] = inpt_wsls +observation_input = design_mat_Obs +print('design_mat_Obs=', design_mat_Obs) + +# # 3. Creating a GLM-HMM +# In this section, we make a GLM-HMM with the following transition and observation components: +# +# +# ```python +# true_glmhmm = ssm.HMM_TO(num_states, obs_dim, M_trans, M_obs, observations="input_driven_obs_diff_inputs", observation_kwargs=dict(C=num_categories), transitions="inputdrivenalt") +# ``` +# +# This function has two sections: +# +# **a) Observation model:** +# The observation model is a categorical class indicated by `observations="input_driven_obs_diff_inputs"`. Within this model, the animal's choices are influenced by a range of inputs into the system. This observation class is particularly suited for scenarios where transitions are driven by external inputs. Also, `M_obs` represents the number of covariates utilized in the observation model. 
+# +# We can determine the number of response categories with `observation_kwargs=dict(C=num_categories)`. Here, `C = 2` since there are only two possible choices for the animal, making the observations binary. +# +# +# **b) Transition model:** +# The transition model, specified as `transitions="inputdrivenalt"` is a multiclass logistic regression in which `M_trans` is the number of covariates influencing the transitions between states. In this model, multiple regressors play a key role in determining the transitions between states. +# +# The model's number of states is determined by `num_states`, which represents the number of hypothesized strategies in this task. + +# Make a GLM-HMM +true_glmhmm = ssm.HMM_TO(num_states, obs_dim, M_trans=input_dim_T, M_obs=input_dim_O, + observations="input_driven_obs_diff_inputs", observation_kwargs=dict(C=num_categories), + transitions="inputdrivenalt") + +# # 3a. Initializing the model +# To ensure that the model accurately reflects the mouse behavior, we must bring the GLM-HMM into an appropriate parameter regime (see Mohammadi et al. (2024)), and provide a good initialization for the observation and transition weights. +# +# By carefully adjusting these parameters and initializing the model accurately, we aim to create a model that closely mirrors the behavioral patterns observed in mice. This fine-tuning process will contribute to a better representation of real-world scenarios, enhancing the model's reliability in studying mice behavior. 
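As a rough illustration of the transition model described above (a multiclass logistic regression on the transition covariates), the following NumPy sketch combines a baseline log transition matrix with input weights and normalizes each row with a softmax. It is not the `ssm` implementation of `transitions="inputdrivenalt"`: the function name `transition_probs`, the `(K, M)` weight layout, and the choice of the last state as a zero-weight reference are assumptions made only for this example.

```python
import numpy as np


def transition_probs(log_P0, W, u):
    """Per-trial K x K transition matrix from a baseline and input weights.

    log_P0 : (K, K) baseline log transition matrix (row = current state)
    W      : (K, M) weights, one row per destination state (hypothetical layout)
    u      : (M,)   transition covariates for a single trial
    """
    # Add each destination state's input drive to every row of the baseline.
    logits = log_P0 + (W @ u)[None, :]
    # Softmax over destination states, shifted by the row maximum for stability.
    logits = logits - logits.max(axis=1, keepdims=True)
    P = np.exp(logits)
    return P / P.sum(axis=1, keepdims=True)


log_P0 = np.log(np.array([[0.94, 0.03, 0.03],
                          [0.05, 0.90, 0.05],
                          [0.04, 0.04, 0.92]]))
# Assumption: the last state is the reference, so its weights are zero.
W = np.array([[3.0, 1.0], [2.0, 0.6], [0.0, 0.0]])

P_baseline = transition_probs(log_P0, W, np.zeros(2))           # no input drive
P_driven = transition_probs(log_P0, W, np.array([1.0, -1.0]))   # one example trial
print(np.round(P_baseline, 2))
print(np.round(P_driven, 2))
```

With zero input the sketch returns the baseline matrix unchanged; a nonzero input shifts probability mass toward the states with the largest input drive while each row still sums to one.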
+ +# + +gen_weights = np.array([[[5, 2]], [[2, -4]], [[-3, 4]]]) +gen_log_trans_mat = np.log(np.array([[[0.94, 0.03, 0.03], [0.05, 0.90, 0.05], [0.04, 0.04, 0.92]]])) +Ws_transition = np.array([[[3, 1], [2, .6]]]) + +true_glmhmm.observations.params = gen_weights +true_glmhmm.transitions.params[0][:] = gen_log_trans_mat +true_glmhmm.transitions.params[1][None] = Ws_transition + +# + +# Plot generative parameters: +fig = plt.figure(figsize=(10, 3), dpi=80, facecolor='w', edgecolor='k') +cols = ['darkviolet', 'gold', 'chocolate'] + +# 1) Observation +gen_weights_obs = gen_weights + +plt.subplot(1, 2, 1) +for k in range(num_states): + if k ==0: + plt.plot(range(input_dim_O), gen_weights_obs[k][0], marker='D', + color=cols[k], linestyle='-', lw=1.5) + else: + plt.plot(range(input_dim_O), gen_weights_obs[k][0], marker='D', + color=cols[k], linestyle='-', lw=1.5) + +plt.yticks(fontsize=10) +plt.ylabel("Weight value", fontsize=15) +plt.xticks([0, 1], ['Obs_in1', 'Obs_in2'], fontsize=12, rotation=45) +plt.axhline(y=0, color="k", alpha=0.5, ls="--") +plt.title("Observation GLM") + +# 2) transition +gen_log_trans_mat = true_glmhmm.transitions.params[0] +gen_weights_Trans = true_glmhmm.transitions.params[1] +generative_weights_Trans = true_glmhmm.trans_weights_K(true_glmhmm.params, num_states) + +plt.subplot(1, 2, 2) +for k in range(num_states): + if k ==0: + plt.plot(range(input_dim_T), generative_weights_Trans[k], marker='D', + color=cols[k], linestyle='-', lw=1.5) + else: + plt.plot(range(input_dim_T), generative_weights_Trans[k], marker='D', + color=cols[k], linestyle='-', lw=1.5) + +plt.yticks(fontsize=10) +plt.axhline(y=0, color="k", alpha=0.5, ls="--") +plt.xticks([0, 1], ['Tran_in1', 'Tran_in2'], fontsize=12, rotation=45) +plt.title("Transition GLM") +plt.show() +# - + +# # 3b. 
Sample data from the GLM-HMM + +# + +# Sample some data from the GLM-HMM +true_states, obs = true_glmhmm.sample(time_bins, transition_input=transition_input, observation_input=observation_input) + +# Plot the first T trials of the data +T = 500 +fig = plt.figure(figsize=(8, 2), dpi=80, facecolor='w', edgecolor='k') +plt.imshow(true_states[None, :], aspect="auto") +plt.xticks([]) +plt.xlim(0, T) +plt.yticks([]) +plt.title("States", fontsize=15) + +# SSM returns categorical observations as a list of lists, so we flatten them into a 1D array for plotting. +obs_flat = np.array([x[0] for x in obs]) +fig = plt.figure(figsize=(8, 2), dpi=80, facecolor='w', edgecolor='k') +plt.imshow(obs_flat[None,:], aspect="auto") +plt.xlim(0, T) +plt.xlabel("trial #", fontsize=12) +plt.yticks([]) +plt.grid(False) # turn the grid off (the `b` keyword is deprecated since Matplotlib 3.5) +plt.title("Observations", fontsize=15) +plt.show() +# - + +# Compute the log likelihood of the data under the true model, marginalizing over the discrete states +true_lp = true_glmhmm.log_probability(obs, transition_input=transition_input, observation_input=observation_input) +print("true_lp = " + str(true_lp)) + +# # 4. Make a new HMM for fitting +# +# In this section, we instantiate a new GLM-HMM and assess whether it recovers the generative parameters when fit by Maximum Likelihood Estimation (MLE) on the simulated data. MLE chooses the parameters that make the simulated data most probable, so comparing them with the generative parameters shows how well the model can be recovered. + +# ## EM for fitting + +# We use the Expectation-Maximization (EM) algorithm to fit the model. EM alternates between two steps, the E-step and the M-step. It starts from an initial guess for the model parameters and iterates until the log marginal likelihood converges.
During the E-step, the algorithm computes the expected value of the complete-data log-likelihood given the current parameter estimates and the observed data (animal choices). Next, the M-step finds the parameters that maximize this expected log-likelihood. + +# To elaborate, given the GLM-HMM parameters, we compute the joint probability distribution over the states and the animal's decisions (left or right) across trials. The log-likelihood of the model is then obtained by summing this joint distribution over all possible state sequences: +# +# $$ +# \begin{align} +# \log \left[ p(\mathbf{Y}|\theta, \mathbf{U}^{ob}, \mathbf{U}^{tr})\right]= \log \left[ \sum_{\mathbf{Z}} +# p(\mathbf{Y}, \mathbf{Z}|\theta, \mathbf{U}^{ob}, \mathbf{U}^{tr})\right] +# \end{align} +# $$ +# +# As defined earlier, $\mathbf{U}^{ob}={u}^{ob}_{1}, ..., {u}^{ob}_{T}$ and $\mathbf{U}^{tr}={u}^{tr}_{1}, ..., {u}^{tr}_{T}$ are the observation and transition covariates, respectively, and $\mathbf{Z}={z}_{1}, ..., {z}_{T} $ is the set of latent states. +# The model parameters, represented as $\theta=\{ \mathbf{w}^{tr}, \mathbf{w}^{ob}, \pi\}$, comprise the transition weights, the observation weights, and the initial state distribution for all states. + +glmhmm = ssm.HMM_TO(num_states, obs_dim, input_dim_T, input_dim_O, observations="input_driven_obs_diff_inputs", + observation_kwargs=dict(C=num_categories), transitions="inputdrivenalt") + +# # 4a. Fit the new HMM +# Here, we fit the model to the simulated data. Fitting lets the model capture the dependencies that characterize the behavior of interest, and it shows how well the model aligns with the data and how far its findings may generalize to other cases.
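The marginal log-likelihood that EM monitors, the sum over all state sequences written above, is not computed by enumeration in practice. A standard way to obtain it is the forward recursion in log space. The sketch below is a minimal NumPy illustration under stated assumptions, not the code used inside `ssm`: the names `logsumexp`, `forward_log_likelihood`, and `brute_force_log_likelihood` are hypothetical, the per-trial transition matrices and per-state observation log-likelihoods are random placeholders, and the brute-force version exists only to check the recursion on a tiny problem.

```python
import itertools

import numpy as np


def logsumexp(a, axis=None):
    """Numerically stable log(sum(exp(a))) along the given axis."""
    a_max = np.max(a, axis=axis, keepdims=True)
    summed = np.sum(np.exp(a - a_max), axis=axis, keepdims=True)
    return np.squeeze(a_max + np.log(summed), axis=axis)


def forward_log_likelihood(log_pi0, log_Ps, log_likes):
    """log p(Y) via the forward recursion.

    log_pi0   : (K,)        log initial state distribution
    log_Ps    : (T-1, K, K) per-trial log transition matrices
    log_likes : (T, K)      log p(y_t | z_t = k) for every trial and state
    """
    alpha = log_pi0 + log_likes[0]
    for t in range(1, log_likes.shape[0]):
        # Sum over the previous state j of alpha[j] * P(z_t = k | z_{t-1} = j).
        alpha = logsumexp(alpha[:, None] + log_Ps[t - 1], axis=0) + log_likes[t]
    return float(logsumexp(alpha))


def brute_force_log_likelihood(log_pi0, log_Ps, log_likes):
    """Same quantity by explicit enumeration of all K**T state sequences."""
    T, K = log_likes.shape
    path_log_probs = []
    for path in itertools.product(range(K), repeat=T):
        lp = log_pi0[path[0]] + log_likes[0, path[0]]
        for t in range(1, T):
            lp += log_Ps[t - 1, path[t - 1], path[t]] + log_likes[t, path[t]]
        path_log_probs.append(lp)
    return float(logsumexp(np.array(path_log_probs)))


# Tiny random problem (placeholder values, not taken from the notebook).
rng = np.random.default_rng(0)
T, K = 4, 3
log_pi0 = np.log(np.full(K, 1.0 / K))
log_Ps = np.log(rng.dirichlet(np.ones(K), size=(T - 1, K)))
log_likes = np.log(rng.uniform(0.1, 0.9, size=(T, K)))

ll_forward = forward_log_likelihood(log_pi0, log_Ps, log_likes)
ll_brute = brute_force_log_likelihood(log_pi0, log_Ps, log_likes)
print(ll_forward, ll_brute)
```

The recursion costs on the order of T times K squared operations, whereas enumeration grows as K to the power T, which is why the forward pass is the usual choice for sequences of thousands of trials.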
+ +# Fitting the model +N_iters = 200 +hmm_lps = glmhmm.fit(obs, transition_input=transition_input, observation_input=observation_input, method="em", num_iters=N_iters, tolerance=10**-4) + +# # 4b. Graphically represent the acquired parameters +# + +# Visualize the Learned Parameters +# Plot the log probabilities of the true and fit models +fig = plt.figure(figsize=(4, 3), dpi=80, facecolor='w', edgecolor='k') +plt.plot(hmm_lps, label="EM") +plt.plot([0, N_iters], true_lp * np.ones(2), ':k', label="True") +plt.legend(loc="lower right") +plt.xlabel("EM Iteration") +plt.xlim(0, len(hmm_lps)) +plt.ylabel("Log Probability") +plt.show() + +glmhmm.permute(find_permutation(true_states, glmhmm.most_likely_states(obs, transition_input=transition_input, observation_input=observation_input))) + +# + +# Plot generative parameters +# 1) Observation +gen_weights_obs = gen_weights +recovered_weights = glmhmm.observations.params + +fig = plt.figure(figsize=(10, 3), dpi=80, facecolor='w', edgecolor='k') +plt.subplot(1, 2, 1) +for k in range(num_states): + if k == 0: + plt.plot(range(input_dim_O), gen_weights_obs[k][0], marker='D', color=cols[k], linestyle='-', + lw=1.5, label="generative") + else: + plt.plot(range(input_dim_O), gen_weights_obs[k][0], marker='D', color=cols[k], linestyle='-', + lw=1.5, label="") + +plt.yticks(fontsize=10) +plt.ylabel("Weight value", fontsize=12) +plt.xticks([0, 1], ['Obs_inp1', 'Obs_inp2'], fontsize=12, rotation=30) +plt.axhline(y=0, color="k", alpha=0.5, ls="--") +plt.title("Observation GLM", fontsize=12) + +for k in range(num_states): + if k == 0: + plt.plot(range(input_dim_O), recovered_weights[k][0], color=cols[k], + lw=1.5, label = "recovered", linestyle = '--') + else: + plt.plot(range(input_dim_O), recovered_weights[k][0], color=cols[k], + lw=1.5, label = '', linestyle = '--') +plt.yticks(fontsize=10) +plt.axhline(y=0, color="k", alpha=0.5, ls="--") +plt.legend() + +# 2) transition +gen_log_trans_mat = true_glmhmm.transitions.params[0] 
+gen_weights_Trans = true_glmhmm.transitions.params[1] +recovered_trans_mat = np.exp(glmhmm.transitions.params[0]) + +plt.subplot(1, 2, 2) +recovered_weights_Trans = glmhmm.trans_weights_K(glmhmm.params, num_states) +generative_weights_Trans = true_glmhmm.trans_weights_K(true_glmhmm.params, num_states) + +for k in range(num_states): + if k == 0: + plt.plot(range(input_dim_T), generative_weights_Trans[k], marker='D', color=cols[k], linestyle='-', + lw=1.5, label="generative") + plt.plot(range(input_dim_T), recovered_weights_Trans[k], color=cols[k], + lw=1.5, label = "recovered", linestyle = '--') + else: + plt.plot(range(input_dim_T), generative_weights_Trans[k], marker='D', + color=cols[k], linestyle='-', lw=1.5, label="") + plt.plot(range(input_dim_T), recovered_weights_Trans[k], color=cols[k], + lw=1.5, label = '', linestyle = '--') + +plt.yticks(fontsize=10) +plt.title("Transition GLM", fontsize=12) +plt.axhline(y=0, color="k", alpha=0.5, ls="--") +plt.xticks([0, 1], ['Tran_inp1', 'Tran_inp2'], fontsize=12, rotation=30) +plt.show() + +# + +# transition matrix +recovered_matrix = glmhmm.Ps_matrix(data=obs, transition_input=transition_input, observation_input=observation_input) # , train_mask=train_mask)[0] +gen_trans_mat = np.exp(gen_log_trans_mat) + +fig = plt.figure(figsize=(6, 3), dpi=80, facecolor='w', edgecolor='k') +plt.subplot(1, 2, 1) +plt.imshow(gen_trans_mat, vmin=-0.8, vmax=1, cmap='bone') +for i in range(gen_trans_mat.shape[0]): + for j in range(gen_trans_mat.shape[1]): + text = plt.text(j, i, str(np.around(gen_trans_mat[i, j], decimals=2)), ha="center", va="center", + color="k", fontsize=12) +plt.xlim(-0.5, num_states - 0.5) +plt.ylim(num_states - 0.5, -0.5) +plt.xticks(range(0, num_states), ('1', '2', '3'), fontsize=10) +plt.yticks(range(0, num_states), ('1', '2', '3'), fontsize=10) +plt.ylabel("state t", fontsize=15) +plt.xlabel("state t+1", fontsize=15) +plt.title("generative", fontsize=15) + +plt.subplot(1, 2, 2) 
+plt.imshow(np.mean(recovered_matrix, axis=0), vmin=-0.8, vmax=1, cmap='bone') +for i in range(np.mean(recovered_matrix, axis=0).shape[0]): + for j in range(np.mean(recovered_matrix, axis=0).shape[1]): + text = plt.text(j, i, str(np.around(np.mean(recovered_matrix, axis=0)[i, j], decimals=2)), ha="center", va="center", + color="k", fontsize=12) +plt.xlim(-0.5, num_states - 0.5) +plt.ylim(num_states - 0.5, -0.5) +plt.xticks(range(0, num_states), ('1', '2', '3'), fontsize=10) +plt.yticks(range(0, num_states), ('1', '2', '3'), fontsize=10) +plt.title("recovered", fontsize=15) +plt.show() +# - + +# # 4c. Analysis of the acquired states +# + +# + +# Get expected states +posterior_probs = [glmhmm.expected_states(data = obs, transition_input = transition_input, observation_input=observation_input)[0]] + +# Determine the state with the highest posterior probability +posterior_max = np.argmax(posterior_probs[0], axis = 1) + + +# + +# calculate state fractional occupancies +# bincount with minlength keeps one entry per state, even for a state that is never the most likely one +occur_for_state = np.bincount(posterior_max, minlength=num_states) +sum_all = np.sum(occur_for_state) +occur_for_state = occur_for_state/sum_all + +fig = plt.figure(figsize=(2.5, 2.5), dpi=80, facecolor='w', edgecolor='k') + +for state, occur in enumerate(occur_for_state): + occur_perc = occur * 100 + plt.bar(state, occur_perc, width = 0.7, color = cols[state]) + +# the bars are in percent, so the y-axis limit must be in percent too +plt.ylim((0, 60)) +plt.xticks([0, 1, 2], ['1', '2', '3'], fontsize=12) +plt.yticks([0, 25, 50], ['0', '25', '50'], fontsize=12) +plt.xlabel('state', fontsize=15) +plt.ylabel('Occupancy (%)', fontsize=15) +plt.gca().spines['right'].set_visible(False) +plt.gca().spines['top'].set_visible(False) +plt.show() + +# + +fig = plt.figure(figsize=(12, 2.5), dpi=80, facecolor='w', edgecolor='k') +for k in range(num_states): + plt.plot(posterior_probs[0][0:200, k], label="State " + str(k + 1), lw=1, marker='*', + color=cols[k]) + +plt.ylim((-0.01, 1.01)) +plt.yticks([0, 0.5, 1], fontsize=10) +plt.xticks(fontsize=12) +plt.xlabel("trial number", fontsize=15)
+plt.ylabel("Posterior prob.", fontsize=15) +plt.gca().spines['right'].set_visible(False) +plt.gca().spines['top'].set_visible(False) +plt.show() + +# + +# plot choices and latents: +plt.figure(figsize=(8, 3.5)) +time_bin= 500 + +plt.subplot(211) +plt.imshow(true_states[None, :], aspect="auto") +plt.xticks([]) +plt.xlim(0, time_bin) +plt.ylabel("true\nstate", fontsize=14) +plt.yticks([]) + +plt.subplot(212) +inferred_states = glmhmm.most_likely_states(obs, transition_input=transition_input, observation_input=observation_input) +plt.imshow(inferred_states[None, :], aspect="auto") +plt.xlim(0, time_bin) +plt.ylabel("inferred\nstate", fontsize=14) +plt.yticks([]) +plt.xlabel("trial #", fontsize=12) +plt.show() +# - + + diff --git a/setup.cfg b/setup.cfg index c6be352f..98d4a402 100644 --- a/setup.cfg +++ b/setup.cfg @@ -27,7 +27,7 @@ doc = memory_profiler # measuring memory during docs building mkl jupytext - myst-nb + myst-nb myst-parser numpydoc sphinx diff --git a/ssm/__init__.py b/ssm/__init__.py index eb52d12d..e694ed9b 100644 --- a/ssm/__init__.py +++ b/ssm/__init__.py @@ -1,4 +1,4 @@ # Default imports for SSM - +from .hmm_TO import * from .hmm import * from .lds import * \ No newline at end of file diff --git a/ssm/hmm_TO.py b/ssm/hmm_TO.py new file mode 100644 index 00000000..20d04bb0 --- /dev/null +++ b/ssm/hmm_TO.py @@ -0,0 +1,551 @@ +from functools import partial +from tqdm.auto import trange + +import autograd.numpy as np +import autograd.numpy.random as npr +from autograd import value_and_grad + +from ssm.optimizers import adam_step, rmsprop_step, sgd_step, convex_combination +from ssm.primitives import hmm_normalizer +from ssm.messages import hmm_expected_states, hmm_filter, hmm_sample, viterbi +from ssm.util import ensure_args_are_lists, ensure_args_not_none_modified, ensure_args_not_none, \ + ensure_slds_args_not_none, ensure_variational_args_are_lists, \ + replicate, collapse, ssm_pbar, ensure_args_are_lists_modified + +import ssm.observations as obs 
+import ssm.transitions as trans +import ssm.init_state_distns as isd +import ssm.hierarchical as hier +import ssm.emissions as emssn + +__all__ = ['HMM_TO'] + + +class HMM_TO(object): + """ + Base class for hidden Markov models with observation and transition inputs. + + Notation: + K: number of discrete latent states + D: dimensionality of observations + M_obs: dimensionality of observation inputs + M_trans: dimensionality of transition inputs + + In the code we will sometimes refer to the discrete + latent state sequence as z and the data as x. + """ + + def __init__(self, K, D, M_trans=0, M_obs=0, init_state_distn=None, + transitions='standard', + transition_kwargs=None, + hierarchical_transition_tags=None, + observations="gaussian", observation_kwargs=None, + hierarchical_observation_tags=None, **kwargs): + + # Make the initial state distribution + if init_state_distn is None: + init_state_distn = isd.InitialStateDistribution(K, D, M=M_trans) + if not isinstance(init_state_distn, isd.InitialStateDistribution): + raise TypeError("'init_state_distn' must be a subclass of" + " ssm.init_state_distns.InitialStateDistribution.") + + # Make the transition model + transition_classes = dict( + standard=trans.StationaryTransitions, + stationary=trans.StationaryTransitions, + constrained=trans.ConstrainedStationaryTransitions, + sticky=trans.StickyTransitions, + inputdriven=trans.InputDrivenTransitions, + inputdrivenalt=trans.InputDrivenTransitionsAlternativeFormulation, + recurrent=trans.RecurrentTransitions, + recurrent_only=trans.RecurrentOnlyTransitions, + rbf_recurrent=trans.RBFRecurrentTransitions, + nn_recurrent=trans.NeuralNetworkRecurrentTransitions + ) + + if isinstance(transitions, str): + if transitions not in transition_classes: + raise Exception("Invalid transition model: {}. Must be one of {}". 
+ format(transitions, list(transition_classes.keys()))) + + transition_kwargs = transition_kwargs or {} + transitions = \ + hier.HierarchicalTransitions(transition_classes[transitions], K, D, M=M_trans, + tags=hierarchical_transition_tags, + **transition_kwargs) \ + if hierarchical_transition_tags is not None \ + else transition_classes[transitions](K, D, M=M_trans, **transition_kwargs) + if not isinstance(transitions, trans.Transitions): + raise TypeError("'transitions' must be a subclass of" + " ssm.transitions.Transitions") + + # This is the master list of observation classes. + # When you create a new observation class, add it here. + observation_classes = dict( + gaussian=obs.GaussianObservations, + diagonal_gaussian=obs.DiagonalGaussianObservations, + studentst=obs.MultivariateStudentsTObservations, + t=obs.MultivariateStudentsTObservations, + diagonal_t=obs.StudentsTObservations, + diagonal_studentst=obs.StudentsTObservations, + exponential=obs.ExponentialObservations, + bernoulli=obs.BernoulliObservations, + categorical=obs.CategoricalObservations, + input_driven_obs=obs.InputDrivenObservations, + input_driven_obs_diff_inputs=obs.InputDrivenObservationsDiffInputs, + poisson=obs.PoissonObservations, + vonmises=obs.VonMisesObservations, + ar=obs.AutoRegressiveObservations, + autoregressive=obs.AutoRegressiveObservations, + no_input_ar=obs.AutoRegressiveObservationsNoInput, + diagonal_ar=obs.AutoRegressiveDiagonalNoiseObservations, + diagonal_autoregressive=obs.AutoRegressiveDiagonalNoiseObservations, + independent_ar=obs.IndependentAutoRegressiveObservations, + robust_ar=obs.RobustAutoRegressiveObservations, + no_input_robust_ar=obs.RobustAutoRegressiveObservationsNoInput, + robust_autoregressive=obs.RobustAutoRegressiveObservations, + diagonal_robust_ar=obs.RobustAutoRegressiveDiagonalNoiseObservations, + diagonal_robust_autoregressive=obs.RobustAutoRegressiveDiagonalNoiseObservations, + ) + + if isinstance(observations, str): + observations = 
observations.lower() + if observations not in observation_classes: + raise Exception("Invalid observation model: {}. Must be one of {}". + format(observations, list(observation_classes.keys()))) + + observation_kwargs = observation_kwargs or {} + observations = \ + hier.HierarchicalObservations(observation_classes[observations], K, D, M_obs=M_obs, + tags=hierarchical_observation_tags, + **observation_kwargs) \ + if hierarchical_observation_tags is not None \ + else observation_classes[observations](K, D, M_obs=M_obs, **observation_kwargs) + if not isinstance(observations, obs.Observations): + raise TypeError("'observations' must be a subclass of" + " ssm.observations.Observations") + + self.K, self.D, self.M_trans, self.M_obs = K, D, M_trans, M_obs + self.init_state_distn = init_state_distn + self.transitions = transitions + self.observations = observations + + @property + def params(self): + return self.init_state_distn.params, \ + self.transitions.params, \ + self.observations.params + + @params.setter + def params(self, value): + self.init_state_distn.params = value[0] + self.transitions.params = value[1] + self.observations.params = value[2] + + @ensure_args_are_lists_modified + def initialize(self, datas, transition_input=None, observation_input=None, masks=None, tags=None, + init_method="random"): + """ + Initialize parameters given data. + """ + self.init_state_distn.initialize(datas, inputs=observation_input, masks=masks, tags=tags) + self.transitions.initialize(datas, inputs=transition_input, masks=masks, tags=tags) + self.observations.initialize(datas, inputs=observation_input, masks=masks, tags=tags, init_method=init_method) + + def permute(self, perm): + """ + Permute the discrete latent states. 
+ """ + assert np.all(np.sort(perm) == np.arange(self.K)) + self.init_state_distn.permute(perm) + self.transitions.permute(perm) + self.observations.permute(perm) + + def sample(self, T, prefix=None, transition_input=None, observation_input=None, tag=None, with_noise=True): + """ + Sample synthetic data from the model. Optionally, condition on a given + prefix (preceding discrete states and data). + + Parameters + ---------- + T : int + number of time steps to sample + + prefix : (zpre, xpre) + Optional prefix of discrete states (zpre) and continuous states (xpre) + zpre must be an array of integers taking values 0...num_states-1. + xpre must be an array of the same length that has preceding observations. + + transition_input : (T, transition_input_dim) array_like + Optional transition inputs to specify for sampling + + observation_input : (T, observation_input_dim) array_like + Optional observation inputs to specify for sampling + + tag : object + Optional tag indicating which "type" of sampled data + + with_noise : bool + Whether or not to sample data with noise. 
+ + Returns + ------- + z_sample : array_like of type int + Sequence of sampled discrete states + + x_sample : (T x observation_dim) array_like + Array of sampled data + """ + K = self.K + D = (self.D,) if isinstance(self.D, int) else self.D + M_trans = (self.M_trans,) if isinstance(self.M_trans, int) else self.M_trans + M_obs = (self.M_obs,) if isinstance(self.M_obs, int) else self.M_obs + + assert isinstance(D, tuple) + assert isinstance(M_trans, tuple) + assert isinstance(M_obs, tuple) + assert T > 0 + + # Check the transition_input + if transition_input is not None: + assert transition_input.shape == (T,) + M_trans + + # Check the observation_input + if observation_input is not None: + assert observation_input.shape == (T,) + M_obs + + # Get the type of the observations + if isinstance(self.observations, obs.InputDrivenObservationsDiffInputs): + dtype = int + else: + dummy_data = self.observations.sample_x(0, np.empty((0,) + D)) + dtype = dummy_data.dtype + + # Fit the data array + if prefix is None: + # No prefix is given. Sample the initial state as the prefix. 
+ pad = 1 + z = np.zeros(T, dtype=int) + data = np.zeros((T,) + D, dtype=dtype) + transition_input = np.zeros((T,) + M_trans) if transition_input is None else transition_input + observation_input = np.zeros((T,) + M_obs) if observation_input is None else observation_input + + mask = np.ones((T,) + D, dtype=bool) + + # Sample the first state from the initial distribution + pi0 = self.init_state_distn.initial_state_distn + z[0] = npr.choice(self.K, p=pi0) + data[0] = self.observations.sample_x(z[0], data[:0], observation_input=observation_input[0], + with_noise=with_noise) + + # We only need to sample T-1 data points now + T = T - 1 + + else: + # Check that the prefix is of the right type + zpre, xpre = prefix + pad = len(zpre) + assert zpre.dtype == int and zpre.min() >= 0 and zpre.max() < K + assert xpre.shape == (pad,) + D + + # Construct the states, data, transition_input, observation_input and mask arrays + z = np.concatenate((zpre, np.zeros(T, dtype=int))) + data = np.concatenate((xpre, np.zeros((T,) + D, dtype))) + transition_input = np.zeros((T + pad,) + M_trans) if transition_input is None else np.concatenate( + (np.zeros((pad,) + M_trans), transition_input)) + observation_input = np.zeros((T + pad,) + M_obs) if observation_input is None else np.concatenate( + (np.zeros((pad,) + M_obs), observation_input)) + mask = np.ones((T + pad,) + D, dtype=bool) + + # Fill in the rest of the data + for t in range(pad, pad + T): + Pt = self.transitions.transition_matrices(data[t - 1:t + 1], transition_input[t - 1:t + 1], + mask=mask[t - 1:t + 1], tag=tag)[0] + z[t] = npr.choice(self.K, p=Pt[z[t - 1]]) + data[t] = self.observations.sample_x(z[t], data[:t], observation_input=observation_input[t], tag=tag, + with_noise=with_noise) + + # Return the whole data if no prefix is given. + # Otherwise, just return the simulated part. 
+ if prefix is None: + return z, data + else: + return z[pad:], data[pad:] + + @ensure_args_not_none_modified + def expected_states(self, data, transition_input=None, observation_input=None, mask=None, tag=None): + pi0 = self.init_state_distn.initial_state_distn + Ps = self.transitions.transition_matrices(data, transition_input, mask, tag) + log_likes = self.observations.log_likelihoods(data, observation_input, mask, tag) + return hmm_expected_states(pi0, Ps, log_likes) + + def Ps_matrix(self, data, transition_input=None, observation_input=None, mask=None, tag=None): + Ps = self.transitions.transition_matrices(data, transition_input, mask, tag) + return Ps + + @ensure_args_not_none_modified + def most_likely_states(self, data, transition_input=None, observation_input=None, mask=None, tag=None): + pi0 = self.init_state_distn.initial_state_distn + Ps = self.transitions.transition_matrices(data, transition_input, mask, tag) + log_likes = self.observations.log_likelihoods(data, observation_input, mask, tag) + return viterbi(pi0, Ps, log_likes) + + @ensure_args_not_none_modified + def filter(self, data, transition_input=None, observation_input=None, mask=None, tag=None): + pi0 = self.init_state_distn.initial_state_distn + Ps = self.transitions.transition_matrices(data, transition_input, mask, tag) + log_likes = self.observations.log_likelihoods(data, observation_input, mask, tag) + return hmm_filter(pi0, Ps, log_likes) + + @ensure_args_not_none_modified + def smooth(self, data, transition_input=None, observation_input=None, mask=None, tag=None): + """ + Compute the mean observation under the posterior distribution + of latent discrete states. 
+ """ + Ez, _, _ = self.expected_states(data, transition_input, observation_input, mask, tag) + return self.observations.smooth(Ez, data, observation_input, tag) + + def log_prior(self): + """ + Compute the log prior probability of the model parameters + """ + return self.init_state_distn.log_prior() + \ + self.transitions.log_prior() + \ + self.observations.log_prior() + + @ensure_args_are_lists_modified + def log_likelihood(self, datas, transition_input=None, observation_input=None, masks=None, tags=None): + """ + Compute the log probability of the data under the current + model parameters. + + :param datas: single array or list of arrays of data. + :return total log probability of the data. + """ + ll = 0 + for data, transition_input, observation_input, mask, tag in zip(datas, transition_input, observation_input, + masks, tags): + pi0 = self.init_state_distn.initial_state_distn + Ps = self.transitions.transition_matrices(data, transition_input, mask, tag) + log_likes = self.observations.log_likelihoods(data, observation_input, mask, tag) + ll += hmm_normalizer(pi0, Ps, log_likes) + assert np.isfinite(ll) + return ll + + @ensure_args_are_lists_modified + def log_probability(self, datas, transition_input=None, observation_input=None, masks=None, tags=None): + return self.log_likelihood(datas, transition_input, observation_input, masks, tags) + self.log_prior() + + def expected_log_likelihood(self, expectations, datas, transition_inputs=None, observation_inputs=None, masks=None, + tags=None): + """ + Compute log-likelihood given current model parameters. + + :param datas: single array or list of arrays of data. + :return total log probability of the data. 
+ """ + ell = 0.0 + for (Ez, Ezzp1, _), data, transition_input, observation_input, mask, tag in \ + zip(expectations, datas, transition_inputs, observation_inputs, masks, tags): + pi0 = self.init_state_distn.initial_state_distn + log_Ps = self.transitions.log_transition_matrices(data, transition_input, mask, tag) + log_likes = self.observations.log_likelihoods(data, observation_input, mask, tag) + + ell += np.sum(Ez[0] * np.log(pi0)) + ell += np.sum(Ezzp1 * log_Ps) + ell += np.sum(Ez * log_likes) + assert np.isfinite(ell) + + return ell + + def expected_log_probability(self, expectations, datas, transition_inputs=None, observation_inputs=None, masks=None, + tags=None): + """ + Compute the log-probability of the data given current + model parameters. + """ + ell = self.expected_log_likelihood(expectations, datas, transition_inputs=transition_inputs, + observation_inputs=observation_inputs, masks=masks, tags=tags) + return ell + self.log_prior() + + def trans_weights_K(self, hmm_params, K): + """ + Standardize the GLM transition weights. + """ + trans_weight_append_zero = np.vstack((hmm_params[1][1], np.zeros((1, hmm_params[1][1].shape[1])))) + permutation = range(K) + trans_weight_append_zero_standard = trans_weight_append_zero + v1 = - np.mean(trans_weight_append_zero, axis=0) + trans_weight_append_zero_standard[-1, :] = v1 + for i in range(K - 1): + trans_weight_append_zero_standard[i, :] = v1 + trans_weight_append_zero[i, :] # vi = v1 + wi + weight_vectors_trans = trans_weight_append_zero_standard[permutation] + return weight_vectors_trans + + # Model fitting + def _fit_sgd(self, optimizer, datas, transition_input, observation_input, masks, tags, verbose=2, num_iters=1000, + **kwargs): + """ + Fit the model with maximum marginal likelihood. 
+ """ + T = sum([data.shape[0] for data in datas]) + + def _objective(params, itr): + self.params = params + obj = self.log_probability(datas, transition_input, observation_input, masks, tags) + return -obj / T + + # Set up the progress bar + lls = [-_objective(self.params, 0) * T] + pbar = ssm_pbar(num_iters, verbose, "Epoch {} Itr {} LP: {:.1f}", [0, 0, lls[-1]]) + + # Run the optimizer + step = dict(sgd=sgd_step, rmsprop=rmsprop_step, adam=adam_step)[optimizer] + state = None + for itr in pbar: + self.params, val, g, state = step(value_and_grad(_objective), self.params, itr, state, **kwargs) + lls.append(-val * T) + if verbose == 2: + pbar.set_description("LP: {:.1f}".format(lls[-1])) + pbar.update(1) + return lls + + def _fit_stochastic_em(self, optimizer, datas, transition_input, observation_input, masks, tags, verbose=2, + num_epochs=100, **kwargs): + """ + Replace the M-step of EM with a stochastic gradient update using the ELBO computed + on a minibatch of data. + """ + M = len(datas) + T = sum([data.shape[0] for data in datas]) + + # A helper to grab a minibatch of data + perm = [np.random.permutation(M) for _ in range(num_epochs)] + + def _get_minibatch(itr): + epoch = itr // M + m = itr % M + i = perm[epoch][m] + return datas[i], transition_input[i], observation_input[i], masks[i], tags[i] + + # Define the objective (negative ELBO) + def _objective(params, itr): + # Grab a minibatch of data + data, transition_input, observation_input, mask, tag = _get_minibatch(itr) + Ti = data.shape[0] + + # E step: compute expected latent states with current parameters + Ez, Ezzp1, _ = self.expected_states(data, transition_input, observation_input, mask, tag) + + # M step: set the parameter and compute the (normalized) objective function + self.params = params + pi0 = self.init_state_distn.initial_state_distn + log_Ps = self.transitions.log_transition_matrices(data, transition_input, mask, tag) + log_likes = self.observations.log_likelihoods(data, observation_input, 
mask, tag) + + # Compute the expected log probability + # (Scale by number of length of this minibatch.) + obj = self.log_prior() + obj += np.sum(Ez[0] * np.log(pi0)) * M + obj += np.sum(Ezzp1 * log_Ps) * (T - M) / (Ti - 1) + obj += np.sum(Ez * log_likes) * T / Ti + assert np.isfinite(obj) + + return -obj / T + + # Set up the progress bar + lls = [-_objective(self.params, 0) * T] + pbar = ssm_pbar(num_epochs * M, verbose, "Epoch {} Itr {} LP: {:.1f}", [0, 0, lls[-1]]) + + # Run the optimizer + step = dict(sgd=sgd_step, rmsprop=rmsprop_step, adam=adam_step)[optimizer] + state = None + for itr in pbar: + self.params, val, _, state = step(value_and_grad(_objective), self.params, itr, state, **kwargs) + epoch = itr // M + m = itr % M + lls.append(-val * T) + if verbose == 2: + pbar.set_description("Epoch {} Itr {} LP: {:.1f}".format(epoch, m, lls[-1])) + pbar.update(1) + return lls + + def _fit_em(self, datas, transition_input, observation_input, masks, tags, verbose=2, num_iters=100, tolerance=0, + init_state_mstep_kwargs={}, + transitions_mstep_kwargs={}, + observations_mstep_kwargs={}, + **kwargs): + """ + Fit the parameters with expectation maximization. + + E step: compute E[z_t] and E[z_t, z_{t+1}] with message passing; + M-step: analytical maximization of E_{p(z | x)} [log p(x, z; theta)]. 
+ """ + lls = [self.log_probability(datas, transition_input, observation_input, masks, tags)] + pbar = ssm_pbar(num_iters, verbose, "LP: {:.1f}", [lls[-1]]) + + for itr in pbar: + # E step: compute expected latent states with current parameters + expectations = [self.expected_states(data, transition_input, observation_input, mask, tag) + for data, transition_input, observation_input, mask, tag, + in zip(datas, transition_input, observation_input, masks, tags)] + + # M step: maximize expected log joint wrt parameters + self.init_state_distn.m_step_modified(expectations, datas, transition_input, observation_input, masks, tags, + **init_state_mstep_kwargs) + self.transitions.m_step(expectations, datas, transition_input, masks, tags, **transitions_mstep_kwargs) + self.observations.m_step(expectations, datas, observation_input, masks, tags, **observations_mstep_kwargs) + + # Store progress + lls.append(self.log_prior() + sum([ll for (_, _, ll) in expectations])) + + if verbose == 2: + pbar.set_description("LP: {:.1f}".format(lls[-1])) + + # Check for convergence + if itr > 0 and abs(lls[-1] - lls[-2]) < tolerance: + if verbose == 2: + pbar.set_description("Converged to LP: {:.1f}".format(lls[-1])) + break + + return lls + + @ensure_args_are_lists_modified + def fit(self, datas, transition_input=None, observation_input=None, masks=None, tags=None, + verbose=2, method="em", + initialize=True, + init_method="random", + **kwargs): + + _fitting_methods = \ + dict(sgd=partial(self._fit_sgd, "sgd"), + adam=partial(self._fit_sgd, "adam"), + em=self._fit_em, + stochastic_em=partial(self._fit_stochastic_em, "adam"), + stochastic_em_sgd=partial(self._fit_stochastic_em, "sgd"), + ) + + if method not in _fitting_methods: + raise Exception("Invalid method: {}. Options are {}". 
+ format(method, _fitting_methods.keys())) + + if initialize: + self.initialize(datas, + transition_input=None, observation_input=None, + masks=masks, + tags=tags, + init_method=init_method) + + if isinstance(self.transitions, + trans.ConstrainedStationaryTransitions): + if method != "em": + raise Exception("Only EM is implemented for constrained transitions.") + + return _fitting_methods[method](datas, + transition_input=transition_input, + observation_input=observation_input, + masks=masks, + tags=tags, + verbose=verbose, + **kwargs) diff --git a/ssm/init_state_distns.py b/ssm/init_state_distns.py index 21c22c25..187cb369 100644 --- a/ssm/init_state_distns.py +++ b/ssm/init_state_distns.py @@ -47,6 +47,9 @@ def m_step(self, expectations, datas, inputs, masks, tags, **kwargs): pi0 = sum([Ez[0] for Ez, _, _ in expectations]) + 1e-8 self.log_pi0 = np.log(pi0 / pi0.sum()) + def m_step_modified(self, expectations, datas, transition_input, observation_input, masks, tags, **kwargs): + pi0 = sum([Ez[0] for Ez, _, _ in expectations]) + 1e-8 + self.log_pi0 = np.log(pi0 / pi0.sum()) class FixedInitialStateDistribution(InitialStateDistribution): def __init__(self, K, D, pi0=None, M=0): diff --git a/ssm/observations.py b/ssm/observations.py index a27824df..7d008984 100644 --- a/ssm/observations.py +++ b/ssm/observations.py @@ -819,6 +819,297 @@ def smooth(self, expectations, data, input, tag): """ raise NotImplementedError +class InputDrivenObservationsDiffInputs(Observations): + + def __init__(self, K, D, M_obs=0, C=2, prior_mean=0, prior_sigma=1000): + """ + @param K: number of states + @param D: dimensionality of output + @param C: number of distinct classes for each dimension of output + @param prior_sigma: parameter governing strength of prior. 
Prior on GLM weights is multivariate + normal distribution with mean 'prior_mean' and diagonal covariance matrix (prior_sigma is on diagonal) + """ + super(InputDrivenObservationsDiffInputs, self).__init__(K, D, M_obs) + self.C = C + self.M_obs = M_obs + self.D = D + self.K = K + self.prior_mean = prior_mean + self.prior_sigma = prior_sigma + # Parameters linking input to distribution over output classes + self.Wk = npr.randn(K, C - 1, M_obs) + + @property + def params(self): + return self.Wk + + @params.setter + def params(self, value): + self.Wk = value + + def permute(self, perm): + self.Wk = self.Wk[perm] + + def log_prior(self): + lp = 0 + for k in range(self.K): + for c in range(self.C - 1): + weights = self.Wk[k][c] + lp += stats.multivariate_normal_logpdf(weights, mus=np.repeat(self.prior_mean, (self.M_obs)), + Sigmas=((self.prior_sigma) ** 2) * np.identity(self.M_obs)) + return lp + + # Calculate time dependent logits - output is matrix of size TxKxC + # Input is size TxM + def calculate_logits(self, observation_input): + """ + Return array of size TxKxC containing log(pr(yt=C|zt=k)) + :param observation_input: observation_input array of covariates of size TxM_obs + :return: array of size TxKxC containing log(pr(yt=c|zt=k, ut)) for all c in {1, ..., C} and k in {1, ..., K} + """ + # Transpose array dimensions, so that array is now of shape ((C-1)xKx(M+1)) + Wk_tranpose = np.transpose(self.Wk, (1, 0, 2)) + # Stack column of zeros to transform array from size ((C-1)xKx(M_obs+1)) to ((C)xKx(M_obs+1)) and then transform shape back to (KxCx(M_obs+1)) + Wk = np.transpose(np.vstack([Wk_tranpose, np.zeros((1, Wk_tranpose.shape[1], Wk_tranpose.shape[2]))]), + (1, 0, 2)) + # Input effect; transpose so that output has dims TxKxC + time_dependent_logits = np.transpose(np.dot(Wk, observation_input.T), (2, 0, + 1)) # Note: this has an unexpected effect when both input (and thus Wk) are empty arrays and returns an array of zeros + time_dependent_logits = 
time_dependent_logits - logsumexp(time_dependent_logits, axis=2, keepdims=True) + return time_dependent_logits + + def log_likelihoods(self, data, observation_input, mask, tag): + if observation_input.ndim == 1 and observation_input.shape == ( + self.M_obs,): # if input is vector of size self.M_obs (one time point), expand dims to be (1, M_obs) + observation_input = np.expand_dims(observation_input, axis=0) + time_dependent_logits = self.calculate_logits(observation_input) + assert self.D == 1, "InputDrivenObservationsDiffInputs written for D = 1!" + mask = np.ones_like(data, dtype=bool) if mask is None else mask + return stats.categorical_logpdf(data[:, None, :], time_dependent_logits[:, :, None, :], mask=mask[:, None, :]) + + def sample_x(self, z, xhist, observation_input=None, tag=None, with_noise=True): + assert self.D == 1, "InputDrivenObservationsDiffInputs written for D = 1!" + if observation_input.ndim == 1 and observation_input.shape == (self.M_obs,): + observation_input = np.expand_dims(observation_input, axis=0) + time_dependent_logits = self.calculate_logits(observation_input) # size TxKxC + ps = np.exp(time_dependent_logits) + T = time_dependent_logits.shape[0] + + if T == 1: + sample = np.array([npr.choice(self.C, p=ps[t, z]) for t in range(T)]) + elif T > 1: + sample = np.array([npr.choice(self.C, p=ps[t, z[t]]) for t in range(T)]) + return sample + + def m_step(self, expectations, datas, observation_input, masks, tags, optimizer="bfgs", **kwargs): + + T = sum([data.shape[0] for data in datas]) # total number of data points: time_bins + + def _multisoftplus(X): + ''' + computes f(X) = log(1+sum(exp(X), axis =1)) and its first derivative + :param X: array of size Tx(C-1) + :return f(X) of size T and df of size (Tx(C-1)) + ''' + X_augmented = np.append(X, np.zeros((X.shape[0], 1)), + 1) # append a column of zeros to X for rowmax calculation + rowmax = np.max(X_augmented, axis=1, + keepdims=1) # get max along column for log-sum-exp trick, rowmax is size 
T + # compute f: + f = np.log(np.exp(-rowmax[:, 0]) + np.sum(np.exp(X - rowmax), axis=1)) + rowmax[:, 0] + # compute df + df = np.exp(X - rowmax) / np.expand_dims((np.exp(-rowmax[:, 0]) + np.sum(np.exp(X - rowmax), axis=1)), + axis=1) + return f, df + + def _objective(params, k): + ''' + computes term in negative expected complete loglikelihood that depends on weights for state k + :param params: vector of size (C-1)xM_obs + :return term in negative expected complete LL that depends on weights for state k; scalar value + ''' + W = np.reshape(params, (self.C - 1, self.M_obs)) + obj = 0 + for data, input, mask, tag, (expected_states, _, _) \ + in zip(datas, observation_input, masks, tags, expectations): + xproj = input @ W.T # projection of input onto weight matrix for particular state, size is Tx(C-1) + f, _ = _multisoftplus(xproj) + assert data.shape[1] == 1, "InputDrivenObservationsDiffInputs written for D = 1!" + data_one_hot = one_hot(data[:, 0], self.C) # convert to one-hot representation of size TxC + temp_obj = (-np.sum(data_one_hot[:, :-1] * xproj, axis=1) + f) @ expected_states[:, k] + obj += temp_obj + + # add contribution of prior: + if self.prior_sigma != 0: + obj += 1 / (2 * self.prior_sigma ** 2) * np.sum(W ** 2) + return obj / T + + def _gradient(params, k): + ''' + Explicit calculation of gradient of _objective w.r.t weight matrix for state k, W_{k} + :param params: vector of size (C-1)xM_obs + :param k: state whose parameters we are currently optimizing + :return gradient of objective with respect to parameters; vector of size (C-1)xM_obs + ''' + W = np.reshape(params, (self.C - 1, self.M_obs)) + grad = np.zeros((self.C - 1, self.M_obs)) + for data, input, mask, tag, (expected_states, _, _) \ + in zip(datas, observation_input, masks, tags, expectations): + xproj = input @ W.T # projection of input onto weight matrix for particular state, size is Tx(C-1) + _, df = _multisoftplus(xproj) + assert data.shape[1] == 1, "InputDrivenObservationsDiffInputs 
written for D = 1!" + data_one_hot = one_hot(data[:, 0], self.C) # convert to one-hot representation of size TxC + grad += (df - data_one_hot[:, :-1]).T @ ( + expected_states[:, [k]] * input) # gradient is shape (C-1,M_obs) + # Add contribution to gradient from prior: + if self.prior_sigma != 0: + grad += (1 / (self.prior_sigma) ** 2) * W + # Now flatten grad into a vector: + grad = grad.flatten() + return grad / T + + def _hess(params, k): + ''' + Explicit calculation of hessian of _objective w.r.t weight matrix for state k, W_{k} + :param params: vector of size (C-1)xM_obs + :param k: state whose parameters we are currently optimizing + :return hessian of objective with respect to parameters; matrix of size ((C-1)xM_obs) x ((C-1)xM_obs) + ''' + W = np.reshape(params, (self.C - 1, self.M_obs)) + hess = np.zeros(((self.C - 1) * self.M_obs, (self.C - 1) * self.M_obs)) + for data, input, mask, tag, (expected_states, _, _) \ + in zip(datas, observation_input, masks, tags, expectations): + xproj = input @ W.T # projection of input onto weight matrix for particular state + _, df = _multisoftplus(xproj) + # center blocks: + dftensor = np.expand_dims(df, axis=2) # dims are now (T, (C-1), 1) + Xdf = np.expand_dims(input, + axis=1) * dftensor # multiply every input covariate term with every class derivative term for a given time step; dims are now (T, (C-1), M) + # reshape Xdf to (T, (C-1)*M_obs) + Xdf = np.reshape(Xdf, (Xdf.shape[0], -1)) + # weight Xdf by posterior state probabilities + pXdf = expected_states[:, [k]] * Xdf # output is size (T, (C-1)*M_obs) + # outer product with input vector, size (M_obs, (C-1)*M_obs) + XXdf = input.T @ pXdf + # center blocks of hessian: + temp_hess = np.zeros(((self.C - 1) * self.M_obs, (self.C - 1) * self.M_obs)) + for c in range(1, self.C): + inds = range((c - 1) * self.M_obs, c * self.M_obs) + temp_hess[np.ix_(inds, inds)] = XXdf[:, inds] + # off diagonal entries: + hess += temp_hess - Xdf.T @ pXdf + # add contribution of prior to 
hessian + if self.prior_sigma != 0: + hess += (1 / self.prior_sigma ** 2) * np.identity((self.C - 1) * self.M_obs) + return hess / T + + from scipy.optimize import minimize + # Optimize weights for each state separately: + for k in range(self.K): + def _objective_k(params): + return _objective(params, k) + + def _gradient_k(params): + return _gradient(params, k) + + def _hess_k(params): + return _hess(params, k) + + sol = minimize(_objective_k, self.params[k].reshape(((self.C - 1) * self.M_obs)), hess=_hess_k, + jac=_gradient_k, method="trust-ncg") + self.params[k] = np.reshape(sol.x, (self.C - 1, + self.M_obs)) # for InputDrivenObservationsDiffInputs class: comment out if you want to stop observation weights being updated + + def smooth(self, expectations, data, observation_input, tag): + """ + Compute the mean observation under the posterior distribution + of latent discrete states. + """ + raise NotImplementedError + + +class _AutoRegressiveObservationsBase(Observations): + """ + Base class for autoregressive observations of the form, + + E[x_t | x_{t-1}, z_t=k, u_t] + = \sum_{l=1}^{L} A_k^{(l)} x_{t-l} + b_k + V_k u_t. + + where L is the number of lags and u_t is the input. 
+ """ + + def __init__(self, K, D, M=0, lags=1): + super(_AutoRegressiveObservationsBase, self).__init__(K, D, M) + + # Distribution over initial point + self.mu_init = np.zeros((K, D)) + + # AR parameters + assert lags > 0 + self.lags = lags + self.bs = npr.randn(K, D) + self.Vs = npr.randn(K, D, M) + + # Inheriting classes may treat _As differently + self._As = None + + @property + def As(self): + return self._As + + @As.setter + def As(self, value): + self._As = value + + @property + def params(self): + return self.As, self.bs, self.Vs + + @params.setter + def params(self, value): + self.As, self.bs, self.Vs = value + + def permute(self, perm): + self.mu_init = self.mu_init[perm] + self.As = self.As[perm] + self.bs = self.bs[perm] + self.Vs = self.Vs[perm] + + def _compute_mus(self, data, input, mask, tag): + # assert np.all(mask), "ARHMM cannot handle missing data" + K, M = self.K, self.M + T, D = data.shape + As, bs, Vs, mu0s = self.As, self.bs, self.Vs, self.mu_init + + # Instantaneous inputs + mus = np.empty((K, T, D)) + mus = [] + for k, (A, b, V, mu0) in enumerate(zip(As, bs, Vs, mu0s)): + # Initial condition + mus_k_init = mu0 * np.ones((self.lags, D)) + + # Subsequent means are determined by the AR process + mus_k_ar = np.dot(input[self.lags:, :M], V.T) + for l in range(self.lags): + Al = A[:, l * D:(l + 1) * D] + mus_k_ar = mus_k_ar + np.dot(data[self.lags - l - 1:-l - 1], Al.T) + mus_k_ar = mus_k_ar + b + + # Append concatenated mean + mus.append(np.vstack((mus_k_init, mus_k_ar))) + + return np.array(mus) + + def smooth(self, expectations, data, input, tag): + """ + Compute the mean observation under the posterior distribution + of latent discrete states. 
+ """ + T = expectations.shape[0] + mask = np.ones((T, self.D), dtype=bool) + mus = np.swapaxes(self._compute_mus(data, input, mask, tag), 0, 1) + return (expectations[:, :, None] * mus).sum(1) + class _AutoRegressiveObservationsBase(Observations): """ diff --git a/ssm/transitions.py b/ssm/transitions.py index 79d5e432..23dc3b07 100644 --- a/ssm/transitions.py +++ b/ssm/transitions.py @@ -262,6 +262,99 @@ def neg_hessian_expected_log_trans_prob(self, data, input, mask, tag, expected_j T, D = data.shape return np.zeros((T-1, D, D)) + +class InputDrivenTransitionsAlternativeFormulation(StickyTransitions): + # This class contains K-1 weight vectors so as to cope with degeneracy + """ + Hidden Markov Model whose transition probabilities are + determined by a generalized linear model applied to the + exogenous input. This has K-1 weight vectors so as to cope with degeneracy. + """ + def __init__(self, K, D, M, prior_sigma=1000, alpha=1, kappa=0): + """ + @param K: number of states + @param D: dimensionality of output + @param M: dimensionality of the exogenous input + @param prior_sigma: parameter governing strength of prior. Prior on GLM weights is a zero-mean multivariate + normal distribution with diagonal covariance matrix (prior_sigma ** 2 is on the diagonal) + """ + + super(InputDrivenTransitionsAlternativeFormulation, self).__init__(K, D, M=M, alpha=alpha, kappa=kappa) + + # Parameters linking input to state distribution + self.Ws = npr.randn(K-1, M) + + # Regularization of Ws + # self.l2_penalty = l2_penalty + self.prior_sigma = prior_sigma + # self.global_fit = global_fit + + @property + def params(self): + return [self.log_Ps, self.Ws] + + @params.setter + def params(self, value): + [self.log_Ps, self.Ws] = value + + def permute(self, perm): + """ + Permute the discrete latent states.
+ """ + self.log_Ps = self.log_Ps[np.ix_(perm, perm)] + self.Ws = np.vstack([self.Ws, np.zeros((1, self.Ws.shape[1]))]) + self.Ws = self.Ws[perm] + + def log_prior(self): + lp = super(InputDrivenTransitionsAlternativeFormulation, self).log_prior() + lp = lp + np.sum(-0.5 * (1 / (self.prior_sigma ** 2)) * self.Ws**2) + return lp + + def log_transition_matrices(self, data, input, mask, tag): + T = np.array(data).shape[0] + assert np.array(input).shape[0] == T + # Previous state effect + log_Ps = np.tile(self.log_Ps[None, :, :], (T-1, 1, 1)) + # Append row of zeros so that Ws_with_zeros is now KxM + Ws_with_zeros = np.vstack([self.Ws, np.zeros((1, self.Ws.shape[1]))]) + if self.Ws.shape[0] == self.K: # Ws already has K rows (e.g. after permute added the row of zeros) + Ws_with_zeros=self.Ws + # Input effect + log_Ps = log_Ps + np.dot(input[1:], Ws_with_zeros.T)[:, None, :] + normalized_Ps = log_Ps - logsumexp(log_Ps, axis=2, keepdims=True) + return normalized_Ps + + def m_step(self, expectations, datas, inputs, masks, tags, + optimizer="lbfgs", num_iters=1000, **kwargs): + optimizer = dict(sgd=sgd, adam=adam, rmsprop=rmsprop, bfgs=bfgs, lbfgs=lbfgs)[optimizer] + # Maximize the expected log joint + def _expected_log_joint(expectations): + elbo = self.log_prior() + for data, input, mask, tag, (expected_states, expected_joints, _) \ + in zip(datas, inputs, masks, tags, expectations): + log_Ps = self.log_transition_matrices(data, input, mask, tag) + K = np.array(log_Ps).shape[1] + elbo += np.sum(expected_joints * log_Ps) + return elbo + + T = sum([data.shape[0] for data in datas]) + + def _objective(params, itr): + self.params = params + obj = _expected_log_joint(expectations) + return -obj / T + + # Call the optimizer. Persist state (e.g.
SGD momentum) across calls to m_step + optimizer_state = self.optimizer_state if hasattr(self, "optimizer_state") else None + self.params, self.optimizer_state = \ + optimizer(_objective, self.params, num_iters=num_iters, + state=optimizer_state, full_output=True, **kwargs) + + def neg_hessian_expected_log_trans_prob(self, data, input, mask, tag, expected_joints): + # Return (T-1, D, D) array of blocks for the diagonal of the Hessian + T, D = data.shape + return np.zeros((T-1, D, D)) + class RecurrentTransitions(InputDrivenTransitions): """ Generalization of the input driven HMM in which the observations serve as future inputs diff --git a/ssm/util.py b/ssm/util.py index 093b9cc6..3eb60eae 100644 --- a/ssm/util.py +++ b/ssm/util.py @@ -113,6 +113,42 @@ def wrapper(self, datas, inputs=None, masks=None, tags=None, **kwargs): return wrapper +def ensure_args_are_lists_modified(f): + def wrapper(self, datas, transition_input=None, observation_input=None, masks=None, tags=None, **kwargs): + + datas = [datas] if not isinstance(datas, (list, tuple)) else datas + + M_obs = (self.M_obs,) if isinstance(self.M_obs, int) else self.M_obs + assert isinstance(M_obs, tuple) + + M_trans = (self.M_trans,) if isinstance(self.M_trans, int) else self.M_trans + assert isinstance(M_trans, tuple) + + if transition_input is None: + transition_input = [np.zeros((data.shape[0],) + M_trans) for data in datas] + elif not isinstance(transition_input, (list, tuple)): + transition_input = [transition_input] + + if observation_input is None: + observation_input = [np.zeros((data.shape[0],) + M_obs) for data in datas] + elif not isinstance(observation_input, (list, tuple)): + observation_input = [observation_input] + + if masks is None: + masks = [np.ones_like(data, dtype=bool) for data in datas] + elif not isinstance(masks, (list, tuple)): + masks = [masks] + + if tags is None: + tags = [None] * len(datas) + elif not isinstance(tags, (list, tuple)): + tags = [tags] + + return f(self, datas, 
transition_input=transition_input, observation_input=observation_input, masks=masks, tags=tags, **kwargs) + + return wrapper + + def ensure_variational_args_are_lists(f): def wrapper(self, arg0, datas, inputs=None, masks=None, tags=None, **kwargs): datas = [datas] if not isinstance(datas, (list, tuple)) else datas @@ -158,6 +194,24 @@ def wrapper(self, data, input=None, mask=None, tag=None, **kwargs): return f(self, data, input=input, mask=mask, tag=tag, **kwargs) return wrapper +def ensure_args_not_none_modified(f): + + def wrapper(self, data, transition_input=None, observation_input=None, mask=None, tag=None, **kwargs): + assert data is not None + + M_obs = (self.M_obs,) if isinstance(self.M_obs, int) else self.M_obs + assert isinstance(M_obs, tuple) + + M_trans = (self.M_trans,) if isinstance(self.M_trans, int) else self.M_trans + assert isinstance(M_trans, tuple) + + transition_input = np.zeros((data.shape[0],) + M_trans) if transition_input is None else transition_input + observation_input = np.zeros((data.shape[0],) + M_obs) if observation_input is None else observation_input + + mask = np.ones_like(data, dtype=bool) if mask is None else mask + + return f(self, data, transition_input=transition_input, observation_input=observation_input, mask=mask, tag=tag, **kwargs) + return wrapper def ensure_slds_args_not_none(f): def wrapper(self, variational_mean, data, input=None, mask=None, tag=None, **kwargs):