Demo 6 (PS): Simulating Pose Segmentation Data

Written and last updated September 10, 2025 by Sedona Ewbank, snewbank@stanford.edu

This demo shows how to use the simulation functions of MARIPoSA to generate three kinds of output data:

  1. Time-series categorical pose segmentation data (generated by measuring transition frequencies in existing data and applying them)

  2. Pose module usage data without labels

  3. Pose module usage data with labels from a regression model

import os
import sys
import numpy as np
import pandas as pd
import importlib
import matplotlib.pyplot as plt

# This block is only needed when running as a script
#script_dir = os.path.dirname(os.path.abspath(__file__))
mariposa_dir = "/Users/snewbank/PycharmProjects/MARIPoSA"
utils_dir = os.path.join(mariposa_dir, 'utils')
sys.path.append(utils_dir)
sys.path.append(mariposa_dir)

#import utils
from utils import metadata, analyze, plot, simulate

importlib.reload(metadata)
importlib.reload(analyze)
importlib.reload(plot)
importlib.reload(simulate)
<module 'utils.simulate' from '/Users/snewbank/PycharmProjects/MARIPoSA/utils/simulate.py'>
demo_dir="/Users/snewbank/Behavior/MARIPoSA_demo_data/"
config=metadata.load_project(demo_dir+"250719_test/config_PS.yaml")
save=True
save_path = demo_dir+"demo_figures/"
if not os.path.exists(save_path):
    os.makedirs(save_path)

6.1: Simulating time-series categorical pose segmentation data

The most granular form of simulation in MARIPoSA generates time-series categorical pose segmentation data that emulates existing data. This is achieved with simulate.generate_sequence(), which outputs a pose module labels dataframe in the same format as analyze.get_module_labels(), with pose module usage and transition probabilities based on the input labels_df. The simulated labels dataframe has the same group structure as the input data: if the input has subgroups, so does the output.
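To illustrate the general idea of simulating a label sequence from transition probabilities, here is a minimal sketch of a first-order Markov chain in plain NumPy. It is not MARIPoSA's implementation: the helper names `estimate_transitions` and `sample_sequence` and the toy label array are hypothetical, and the sketch assumes every module has at least one outgoing transition in the input.

```python
import numpy as np

def estimate_transitions(labels, n_modules):
    """Count first-order transitions and row-normalize them to probabilities."""
    counts = np.zeros((n_modules, n_modules))
    # counts[i, j] = number of frames where module i is followed by module j
    np.add.at(counts, (labels[:-1], labels[1:]), 1)
    row_sums = counts.sum(axis=1, keepdims=True)
    # Assumption of this sketch: every module has an outgoing transition
    if np.any(row_sums == 0):
        raise ValueError("a module has no outgoing transitions")
    return counts / row_sums

def sample_sequence(trans, start, n_frames, rng):
    """Draw a label sequence by repeatedly sampling the next module."""
    seq = np.empty(n_frames, dtype=int)
    seq[0] = start
    for t in range(1, n_frames):
        seq[t] = rng.choice(trans.shape[0], p=trans[seq[t - 1]])
    return seq

rng = np.random.default_rng(0)
# Hypothetical per-frame module labels for one subject (3 modules)
real = np.array([0, 0, 1, 2, 2, 2, 0, 1, 1, 2, 0, 0, 1, 2, 0])
trans = estimate_transitions(real, n_modules=3)
sim = sample_sequence(trans, start=real[0], n_frames=1000, rng=rng)
```

Because each frame depends only on the previous one, the simulated sequence approximately reproduces both the module usage and the transition frequencies of the input once it is long enough.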

labels_df = analyze.get_module_labels(config,0,1200,subgroups=["sal"])
sim_labels_df = simulate.generate_sequence(config,labels_df,1200*int(config["fps"]),n_subjs=20,verbose=True)
Analyzing input labels dataframe
/Users/snewbank/PycharmProjects/MARIPoSA/utils/simulate.py:218: RuntimeWarning: invalid value encountered in divide
  transition_mat = transition_mat / np.sum(transition_mat, axis=1, keepdims=True)
Generating data for subject 1 of 20
Generating data for subject 2 of 20
Generating data for subject 3 of 20
Generating data for subject 4 of 20
Generating data for subject 5 of 20
Generating data for subject 6 of 20
Generating data for subject 7 of 20
Generating data for subject 8 of 20
Generating data for subject 9 of 20
Generating data for subject 10 of 20
Generating data for subject 11 of 20
Generating data for subject 12 of 20
Generating data for subject 13 of 20
Generating data for subject 14 of 20
Generating data for subject 15 of 20
Generating data for subject 16 of 20
Generating data for subject 17 of 20
Generating data for subject 18 of 20
Generating data for subject 19 of 20
Generating data for subject 20 of 20
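The RuntimeWarning in the output above is the message NumPy emits for a 0/0 division. One plausible cause, stated here as an assumption because simulate.py is not shown in this demo, is a module that never occurs in the input, so its row of the transition count matrix sums to zero. A minimal sketch of a guarded row normalization, using a hypothetical count matrix:

```python
import numpy as np

# Hypothetical transition counts; module 1 never occurs, so its row is all zeros
counts = np.array([[2.0, 1.0, 0.0],
                   [0.0, 0.0, 0.0],
                   [1.0, 0.0, 3.0]])

row_sums = counts.sum(axis=1, keepdims=True)

# Divide only where the row sum is positive; empty rows stay at zero, not NaN
probs = np.divide(counts, row_sums,
                  out=np.zeros_like(counts),
                  where=row_sums > 0)
```

Whether empty rows should stay at zero, become uniform, or be dropped depends on how the downstream sampler treats modules that were never observed.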
real_usage = analyze.get_module_usage(config,labels_df)
plot.plot_module_usage(config,real_usage,cmap="Blues_r")
plt.title("Real data")
real_usage = analyze.get_module_usage(config,labels_df,binsize=60)
plot.module_usage_sandplot(config,real_usage)
plt.title("Real data")

sim_usage = analyze.get_module_usage(config,sim_labels_df)
plot.plot_module_usage(config,sim_usage,cmap="Reds_r")
plt.title("Simulated data")
sim_usage = analyze.get_module_usage(config,sim_labels_df,binsize=60)
plot.module_usage_sandplot(config,sim_usage)
plt.title("Simulated data")
Text(0.5, 1.0, 'Simulated data')
[Figures: module usage plot and sandplot for the real data; module usage plot and sandplot for the simulated data]
labels_df = analyze.get_module_labels(config,0,1200,subgroups=["sal","k1","k5","k10"])
sim_labels_df = simulate.generate_sequence(config,labels_df,1200*int(config["fps"]),n_subjs=5)
/Users/snewbank/PycharmProjects/MARIPoSA/utils/simulate.py:253: RuntimeWarning: invalid value encountered in divide
  transition_mat = transition_mat / np.sum(transition_mat, axis=1, keepdims=True)
/Users/snewbank/PycharmProjects/MARIPoSA/utils/simulate.py:253: RuntimeWarning: invalid value encountered in divide
  transition_mat = transition_mat / np.sum(transition_mat, axis=1, keepdims=True)
/Users/snewbank/PycharmProjects/MARIPoSA/utils/simulate.py:253: RuntimeWarning: invalid value encountered in divide
  transition_mat = transition_mat / np.sum(transition_mat, axis=1, keepdims=True)
/Users/snewbank/PycharmProjects/MARIPoSA/utils/simulate.py:253: RuntimeWarning: invalid value encountered in divide
  transition_mat = transition_mat / np.sum(transition_mat, axis=1, keepdims=True)
real_usage = analyze.get_module_usage(config,labels_df)
plot.plot_module_usage(config,real_usage,cmap="viridis",figH=4,figW=8)
plt.title("Real data")


sim_usage = analyze.get_module_usage(config,sim_labels_df)
plot.plot_module_usage(config,sim_usage,cmap="viridis",figH=4,figW=8)
plt.title("Simulated data")
Text(0.5, 1.0, 'Simulated data')
[Figures: module usage by group for the real data and for the simulated data]

6.2: Simulating module usage

Many analyses can be performed on pose module usage summaries without the full time-series categorical sequence of pose modules. For these cases, MARIPoSA can simulate pose module usage directly from input data, as shown below: simulate.generate_usage generates module usage with no group labels, and simulate.generate_usage_labeled generates module usage with continuous labels based on a regression model.
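As a rough illustration of what simulating usage from a multivariate Gaussian can look like, the sketch below fits a mean and covariance to a usage matrix and draws new rows from that distribution. This is an assumption about the general technique, not a description of what mode="multivariate_gaussian" does inside MARIPoSA; the usage matrix is made-up data, and the clip-and-renormalize step is one of several ways to keep samples valid as proportions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical usage matrix: 12 subjects x 4 modules, each row sums to 1
real_usage = rng.dirichlet([5, 3, 1, 1], size=12)

# Fit a multivariate Gaussian to the per-subject usage vectors
mean = real_usage.mean(axis=0)
cov = np.cov(real_usage, rowvar=False)

# Draw simulated subjects from the fitted distribution
sim = rng.multivariate_normal(mean, cov, size=1000)

# A Gaussian can produce negative usage; clip and renormalize each row
sim = np.clip(sim, 0, None)
sim = sim / sim.sum(axis=1, keepdims=True)
```

Sampling all modules jointly preserves the correlations between them, which independent per-module sampling would lose.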

labels_df = analyze.get_module_labels(config, 0, 1200, subgroups=["sal","k1","k5","k10"])
module_usage = analyze.get_module_usage(config, labels_df)
plot.plot_module_usage(config,module_usage,style="stacked",legend=False,alt_xticks=[0,1,5,10])

dose_dict={'sal': 0, 'k1': 1, 'k5': 5, 'k10': 10}
reg, dose_labels = analyze.regress(module_usage, dose_dict)

# 20 dose bins of width 0.5 covering 0 to 10, one [low, high] row per bin
bin_low = np.arange(0,10,0.5)
bin_high = bin_low+0.5
bins = np.array([bin_low,bin_high]).T

#sim_module_usage_sal = simulate.generate_usage(sal_module_usage, 1000, random_state=rs, mode="multivariate_gaussian")
sim_module_usage = simulate.generate_usage_labeled(module_usage, 30, bins, reg, max_iters = 1e100, mode="multivariate_gaussian",verbosity="medium")

# Label only the even-integer keys; leave the remaining tick labels blank
ticks = [str(int(i)) if i % 2 == 0 else "" for i in sim_module_usage.group_dict]

plot.plot_module_usage(config,sim_module_usage,style="stacked",legend=False,alt_xticks=ticks)
Iter 500: 69.3% samples generated (18 of 20 bins not filled)
Iter 1000: 95.3% samples generated (4 of 20 bins not filled)
Iter 1500: 99.8% samples generated (1 of 20 bins not filled)
[Figures: stacked module usage for the real data by dose group and for the simulated data by dose bin]
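The progress messages above ("bins not filled") suggest a fill-the-bins procedure. The sketch below shows one way such a procedure could work, as rejection sampling: draw usage vectors, predict a continuous label for each with a regression model, and keep draws until every label bin holds the requested number of samples. This is a guess at the general approach rather than MARIPoSA's code; the data, the least-squares regression, and all variable names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical real data: usage of 3 modules shifts linearly with dose
doses = np.repeat([0.0, 1.0, 5.0, 10.0], 8)
base = np.array([0.6, 0.3, 0.1])
slope = np.array([-0.04, 0.01, 0.03])
usage = base + np.outer(doses, slope) + rng.normal(0, 0.02, (doses.size, 3))

# Least-squares regression from usage (plus intercept) to dose
X = np.column_stack([usage, np.ones(doses.size)])
coef, *_ = np.linalg.lstsq(X, doses, rcond=None)

# Gaussian fitted to the pooled usage data, used as the proposal distribution
mean, cov = usage.mean(axis=0), np.cov(usage, rowvar=False)

# Five label bins of width 2 covering 0 to 10, as [low, high] rows
bin_low = np.arange(0, 10, 2.0)
bins = np.column_stack([bin_low, bin_low + 2.0])
n_per_bin = 30
filled = [[] for _ in bins]

for it in range(10_000):  # iteration cap so the loop always terminates
    draw = rng.multivariate_normal(mean, cov, size=200)
    pred = np.column_stack([draw, np.ones(len(draw))]) @ coef
    for b, (lo, hi) in enumerate(bins):
        need = n_per_bin - len(filled[b])
        if need > 0:
            # Keep only draws whose predicted label falls inside this bin
            hits = draw[(pred >= lo) & (pred < hi)][:need]
            filled[b].extend(hits)
    if all(len(f) == n_per_bin for f in filled):
        break

# Simulated usage keyed by the lower edge of each label bin
sim = {float(lo): np.array(f) for (lo, _), f in zip(bins, filled)}
```

Bins near the edges of the label range are hit less often by the proposal distribution, which would be consistent with the last few bins taking the most iterations to fill in the output above.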