import scipy.io
from scipy import stats
import matplotlib.pyplot as plt
from matplotlib import rcParams, gridspec
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
import numpy as np
import pandas as pd
import networkx as nx
rcParams['font.family'] = 'sans-serif'
rcParams['font.sans-serif'] = ['Arial']
from matplotlib.patches import Patch, Ellipse, Rectangle
from utils import analyze
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.decomposition import PCA
def is_nonnum(value):
    """
    Report whether ``value`` cannot be converted with ``int()``.

    :param value: any object (e.g. a module-name suffix)
    :return: True if ``int(value)`` raises ValueError or TypeError, otherwise False
    """
    try:
        int(value)
    except (ValueError, TypeError):
        return True
    return False
def fig_to_array(fig):
    """
    Render a Matplotlib figure and return its pixels as an RGB array.

    A new Agg canvas is created for ``fig`` and drawn, so the result does not depend on
    the interactive backend in use.

    :param fig: matplotlib Figure to rasterise
    :return: uint8 array of shape (height, width, 3) holding the RGB pixels
    """
    canvas = FigureCanvas(fig)
    canvas.draw()
    # buffer_rgba() exposes the rendered pixels with their own (rows, cols, 4) shape, so no
    # manual reshape from get_width_height() is needed; tostring_rgb() was deprecated in
    # Matplotlib 3.8 and removed in 3.10.
    rgba = np.asarray(canvas.buffer_rgba())
    # Drop the alpha channel and copy so the result does not alias the canvas buffer.
    return rgba[..., :3].copy()
def moving_average(x, w):
    """
    Simple (unweighted) moving average of ``x`` over windows of ``w`` samples.

    Only windows that lie entirely inside ``x`` are kept (``mode="valid"``).

    :param x: 1-D sequence of numbers
    :param w: window length
    :return: numpy array of window means
    """
    window = np.ones(w)
    window_sums = np.convolve(x, window, mode="valid")
    return window_sums / w
def make_and_plot_ellipse(mean,
                          cov,
                          color,
                          label=None,
                          ax=None):
    """
    Draw a translucent ellipse summarising a 2-D covariance on an axes.

    The ellipse is centred on ``mean``. Its semi-axes are the square roots of the
    eigenvalues of ``cov`` (the 1-SD contour), oriented along the eigenvectors.

    :param mean: (x, y) centre of the ellipse
    :param cov: 2x2 covariance matrix
    :param color: face colour of the ellipse (drawn with alpha 0.25 and no edge)
    :param label: optional legend label
    :param ax: axes to draw on (default: current axes, ``plt.gca()``)
    :return: the Ellipse patch that was added
    """
    cov = np.asarray(cov, dtype=float)
    # eigh is the solver for symmetric matrices and always returns real eigenpairs;
    # eig can return complex values when round-off makes cov slightly asymmetric.
    eigenvalues, eigenvectors = np.linalg.eigh(cov)
    # Tiny negative eigenvalues from numerical noise would make sqrt return NaN.
    eigenvalues = np.clip(eigenvalues, 0, None)
    # Orientation of the first eigenvector; ``width`` below is measured along it.
    angle = np.degrees(np.arctan2(eigenvectors[1, 0], eigenvectors[0, 0]))
    ell = Ellipse(mean, width=2 * np.sqrt(eigenvalues[0]), height=2 * np.sqrt(eigenvalues[1]),
                  angle=angle, facecolor=color, alpha=0.25, label=label, edgecolor="none")
    if ax is None:
        ax = plt.gca()
    ax.add_patch(ell)
    return ell
def plot_module_usage(config,
                      usage_feats,
                      figW=4,
                      figH=2,
                      style="bar_scatter",
                      cmap="jet",
                      legend_pos="outside_right",
                      remap=False,
                      legend=True,
                      long_legend=False,
                      alt_labels=None,
                      alt_xticks=None,
                      title=None,
                      plot_stats=False):
    """
    Plot module usage per module, pooled (one group) or split by group.

    :param config: config; ``config["data_source"]`` is used in the x-axis label and
        ``config["remappings"]`` when ``remap`` is set
    :param usage_feats: output of analyze.get_module_usage (a ModuleUsage object)
    :param figW: figure width (default: 4)
    :param figH: figure height (default: 2)
    :param style: plot style; "bar_scatter", "bar_error", "points", or "stacked"
        ("stacked" is only drawn when the data contain more than one group)
    :param cmap: colormap (default: jet)
    :param legend_pos: legend position; "inside", "outside_right" or "outside_above"
        (default: outside_right)
    :param remap: whether to remap modules according to config["remappings"]
        (stacked style only; default: False)
    :param legend: boolean to include or not include legend (default: True)
    :param long_legend: append the number of modules to each remapped behaviour label (default: False)
    :param alt_labels: alternative group label dictionary, group name -> label (default: None)
    :param alt_xticks: alternative xticklabels (list), for stacked plot style only (default: None)
    :param title: plot title (default: None)
    :param plot_stats: mark modules whose ``usage_feats.f_oneway()`` uncorrected p-value is
        below 0.05 / 0.01 / 0.001 with * / ** / *** (grouped, non-stacked plots only; default: False)
    :return: fig
    :raises ValueError: if usage_feats is not a ModuleUsage object or style is unknown
    """
    if usage_feats.__class__.__name__ != "ModuleUsage":
        raise ValueError(f'usage_feats object class must be ModuleUsage, not {usage_feats.__class__}')
    if style not in ("bar_scatter", "bar_error", "stacked", "points"):
        raise ValueError(f'style must be one of "bar_scatter", "bar_error", "stacked", or "points", not {style}')
    data_subgrouped = len(np.unique(usage_feats.group_labels)) != 1
    modules = usage_feats.feat_names
    n_modules = len(modules)
    x = np.arange(n_modules)
    # Short tick labels: "module12" -> "12"; fall back to the raw names otherwise.
    try:
        tick_labels = [name.split("module")[1] for name in modules]
    except (AttributeError, IndexError):
        tick_labels = modules

    if not data_subgrouped:
        usage_df = usage_feats.to_df().drop("group", axis=1)
        bar_heights = usage_df.mean(axis=0)
        # SEM over rows (observations), consistent with the grouped branch below.
        bar_sems = usage_df.std(axis=0) / np.sqrt(usage_df.shape[0])
        fig, ax = plt.subplots(figsize=(figW, figH), dpi=100)
        cmap = plt.get_cmap(cmap)
        if style in ("bar_scatter", "bar_error"):
            ax.bar(x=x, height=bar_heights, width=0.8, alpha=0.5, color=cmap([0.1]))
            if style == "bar_scatter":
                # One lightly jittered dot per row so individual observations stay visible.
                for i in range(len(usage_df.index)):
                    ax.scatter(x + np.random.normal(0, 0.01, n_modules), usage_df.iloc[i],
                               color="black", s=0.5)
            else:
                ax.errorbar(x=x, y=bar_heights, yerr=bar_sems, linestyle="none", linewidth=0.6,
                            color="black", capsize=1, markeredgewidth=0.75)
        elif style == "points":
            ax.errorbar(x=x, y=bar_heights, yerr=bar_sems, color=cmap([0.1]), linestyle="none",
                        marker="o", markersize=2.5, linewidth=0.75, capsize=2, markeredgewidth=0.75)
        ax.set_xlabel(config["data_source"] + ' Pose Label')
        ax.set_ylabel('Usage')
        ax.set_xticks(x)
        ax.set_xticklabels(tick_labels)
        ax.tick_params(axis='x', rotation=90, labelsize=plt.rcParams['font.size'] * 0.5, pad=2)
    else:
        # Group names in group_dict order.
        groupnames = list(usage_feats.group_dict.keys())
        n_groups = len(groupnames)
        usage_df = usage_feats.to_df()
        bar_heights = np.zeros([n_groups, n_modules])
        bar_sems = np.zeros([n_groups, n_modules])
        subgroup_frames = []
        for g, group in enumerate(groupnames):
            subgroup_usage = usage_df[usage_df["group"] == group].drop("group", axis=1)
            subgroup_frames.append(subgroup_usage)
            bar_heights[g, :] = subgroup_usage.mean(axis=0)
            bar_sems[g, :] = subgroup_usage.std(axis=0) / np.sqrt(subgroup_usage.shape[0])
        fig, ax = plt.subplots(figsize=(figW, figH), dpi=100)
        # Width of one group's bar; the extra 0.7 leaves a gap between neighbouring modules.
        scale = 1 / (n_groups + .7)
        cmap = plt.get_cmap(cmap)
        colors = [cmap([i]) for i in np.linspace(0, 1, n_groups)]

        if style in ("bar_scatter", "bar_error"):
            if style == "bar_scatter":
                for g, subgroup_usage in enumerate(subgroup_frames):
                    for _, row in subgroup_usage.iterrows():
                        # Offsetting an integer arange (rather than a float-stepped arange)
                        # guarantees exactly n_modules positions.
                        ax.scatter(x + scale * g + np.random.normal(0, 0.1 * scale, n_modules), row,
                                   color="black", s=0.5)
                bar_alpha = 0.5
            else:
                for g in range(n_groups):
                    ax.errorbar(x=x + scale * g, y=bar_heights[g], yerr=bar_sems[g],
                                linestyle="none", linewidth=0.6, color="black",
                                capsize=0.3, markeredgewidth=0.75, alpha=0.8)
                bar_alpha = 0.85
            for g in range(n_groups):
                label_g = groupnames[g] if alt_labels is None else alt_labels[groupnames[g]]
                ax.bar(x=x + scale * g, height=bar_heights[g], width=scale, alpha=bar_alpha,
                       color=colors[g], label=label_g)
        elif style == "points":
            for g in range(n_groups):
                label_g = groupnames[g] if alt_labels is None else alt_labels[groupnames[g]]
                ax.errorbar(x=x, y=bar_heights[g], yerr=bar_sems[g], color=colors[g], label=label_g,
                            linestyle="none", marker="o", markersize=4, linewidth=0.75,
                            capsize=2, markeredgewidth=0.75)

        if style != "stacked":
            if legend:
                if legend_pos == "inside":
                    ax.legend()
                elif legend_pos == "outside_right":
                    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
                elif legend_pos == "outside_above":
                    ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.2), ncol=4)
                else:
                    print("Warning - invalid legend position given (should be inside, outside_right, or outside_above), using default")
                    ax.legend()
            ax.set_xlabel(config["data_source"] + ' Pose Label')
            ax.set_ylabel('Usage')
            all_num_modules = not any(isinstance(module, str) for module in modules)
            if all_num_modules and n_modules >= 20:
                # Many numeric modules: label every 5th tick to keep the axis legible.
                xticks = np.arange(0, n_modules, 5, dtype=int)
                ax.set_xticks(xticks)
                # feat_names is a plain list (list.index is used on it below), so it
                # cannot be fancy-indexed with an array.
                ax.set_xticklabels([modules[i] for i in xticks])
            else:
                ax.set_xticks(x)
                ax.set_xticklabels(tick_labels)
            ax.tick_params(axis='x', rotation=90, labelsize=plt.rcParams['font.size'] * 0.2 * n_groups, pad=2)
            if plot_stats:
                ylim = ax.get_ylim()
                # Leave headroom above the data for the significance bars.
                ax.set_ylim(ylim[0], ylim[1] * 1.2)
                stat_result = usage_feats.f_oneway()
                for m, module in enumerate(stat_result["module"]):
                    p_value = stat_result[stat_result["module"] == module]["p_uncorr"].values[0]
                    if p_value < 0.001:
                        stars = "***"
                    elif p_value < 0.01:
                        stars = "**"
                    elif p_value < 0.05:
                        stars = "*"
                    else:
                        continue
                    ax.plot([m - 0.1, m + 0.7], [ylim[1] * 1.05] * 2, color="black", linewidth=0.5)
                    ax.text(m + 0.3, ylim[1] * 1.08, stars, ha="center")
        else:
            if not remap:
                # Stack every module's mean usage on the previous ones: one bar per group.
                bar_bottom = np.zeros(n_groups)
                for m in range(n_modules):
                    ax.bar(np.arange(n_groups), bar_heights[:, m], bottom=bar_bottom, align='center',
                           width=0.99)
                    bar_bottom = bar_bottom + bar_heights[:, m]
            else:
                # Convert "module<N>" names into the integer ids used in config["remappings"].
                module_ids = [int(name.split("module")[1]) for name in modules]
                try:
                    behaviors = sorted(np.unique([entry[1] for entry in config["remappings"]]))
                except TypeError as err:
                    raise TypeError("Issue remapping modules to behavior classes - is the config remapping section up-to-date?") from err
                if "other" not in behaviors:
                    behaviors = behaviors + ["other"]
                n_behaviors = len(behaviors)
                leg_handles_col0 = []
                leg_handles_col1 = []
                # With ncol=2 the first n_behaviors entries fill column 1: those get blank labels
                # so column 1 shows only the lighter swatch and column 2 carries the names.
                leg_labels = [""] * n_behaviors
                cm_list = ['magma', 'YlOrBr', 'Blues', 'copper', 'Greys_r', 'BuGn', 'Reds_r', 'magma_r',
                           'Purples', 'YlGn', 'YlGnBu', 'YlOrRd']
                modules_plotted = []  # To account for modules not seen in remapping
                bar_bottom = np.zeros(n_groups)
                # Use a separate loop name so the ``remap`` parameter is still intact for the legend below.
                for b, behavior in enumerate(behaviors):
                    sub_modules = [int(entry[0][0]) for entry in config["remappings"] if entry[1] == behavior]
                    # Cycle colormaps if there are more behaviours than entries in cm_list.
                    behavior_cmap = plt.get_cmap(cm_list[b % len(cm_list)])
                    colors = [behavior_cmap([0.4]), behavior_cmap([0.5])]
                    leg_handles_col0.append(Patch(facecolor=colors[0], edgecolor='none'))
                    leg_handles_col1.append(Patch(facecolor=colors[1], edgecolor='none'))
                    if long_legend:
                        leg_labels.append(behavior + " (" + str(len(sub_modules)) + " modules)")
                    else:
                        leg_labels.append(behavior)
                    for c in sub_modules:
                        if c in module_ids:
                            col = module_ids.index(c)
                            # Alternate two shades within a behaviour so adjacent modules stay distinct.
                            ax.bar(np.arange(n_groups), bar_heights[:, col], bottom=bar_bottom,
                                   align='center', width=0.97, color=colors[c % 2])
                            bar_bottom = bar_bottom + bar_heights[:, col]
                            modules_plotted.append(c)
                # Report modules that no behaviour class claimed (checked after all classes,
                # so the result does not depend on where "other" sorts).
                print([mod for mod in module_ids if mod not in modules_plotted])
            ax.spines['right'].set_visible(False)
            ax.spines['top'].set_visible(False)
            ax.set_ylabel("Proportion of Time \nSpent in Pose Module")
            ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1])
            ax.set_xlim([-0.5, n_groups - 0.5])
            ax.set_xticks(np.arange(n_groups))
            ax.set_xticklabels(groupnames if alt_xticks is None else alt_xticks)
            if legend:
                if not remap:
                    legend_ids = [int(name.split("module")[1]) for name in modules]
                    if len(legend_ids) > 10:
                        # Bars were drawn without explicit colours, presumably cycling through
                        # Matplotlib's default 10-colour cycle, so modules sharing id % 10 are
                        # listed together. Empty buckets keep a blank entry to preserve order.
                        leg_modules = []
                        for bucket in range(10):
                            items = [i for i in legend_ids if int(i % 10) == bucket]
                            if len(items) > 1:
                                leg_modules.append("Modules " + ', '.join(map(str, items)))
                            elif items:
                                leg_modules.append("Module " + str(items[0]))
                            else:
                                leg_modules.append("")
                    else:
                        leg_modules = legend_ids
                    ax.legend(leg_modules, bbox_to_anchor=(1.05, 1))
                else:
                    ax.legend(handles=leg_handles_col0 + leg_handles_col1,
                              labels=leg_labels,
                              ncol=2, handletextpad=0.5, handlelength=1.0, columnspacing=-0.5,
                              bbox_to_anchor=(1.05, 1))
    if title is not None:
        ax.set_title(title)
    plt.tight_layout()
    return fig
def plot_action_units(config,
                      action_units,
                      figW=4,
                      figH=2,
                      style="bar_scatter",
                      cmap="jet",
                      legend_pos="outside_right",
                      legend=True,
                      alt_labels=None,
                      alt_xticks=None,
                      title=None,
                      plot_stats=False):
    """
    Plot mean action-unit scores per feature, pooled (one group) or split by group.

    :param config: config; ``config["data_source"]`` is used in the x-axis label
    :param action_units: output of analyze.get_action_units (an ActionUnits object)
    :param figW: figure width (default: 4)
    :param figH: figure height (default: 2)
    :param style: plot style; "bar_scatter", "bar_error", "points"
    :param cmap: colormap (default: jet)
    :param legend_pos: legend position; "inside", "outside_right" or "outside_above"
        (default: outside_right)
    :param legend: boolean to include or not include legend; grouped data only (default: True)
    :param alt_labels: alternative group label dictionary, group name -> label (default: None)
    :param alt_xticks: kept for signature compatibility; not used by this function (default: None)
    :param title: plot title (default: None)
    :param plot_stats: mark features whose ``action_units.f_oneway()`` uncorrected p-value is
        below 0.05 / 0.01 / 0.001 with * / ** / *** (grouped data only; default: False)
    :return: fig
    :raises ValueError: if action_units is not an ActionUnits object or style is unknown
    """
    if action_units.__class__.__name__ != "ActionUnits":
        raise ValueError(f'action_units object class must be ActionUnits, not {action_units.__class__}')
    if style not in ("bar_scatter", "bar_error", "points"):
        raise ValueError(f'style must be one of "bar_scatter", "bar_error", or "points", not {style}')
    data_subgrouped = len(np.unique(action_units.group_labels)) != 1
    modules = action_units.feat_names
    n_modules = len(modules)
    x = np.arange(n_modules)
    # Short tick labels: "AU<id>_<suffix>" -> "<id>"; fall back to the raw names otherwise.
    try:
        tick_labels = [name.split("_")[0].split("AU")[1] for name in modules]
    except (AttributeError, IndexError):
        tick_labels = modules

    if not data_subgrouped:
        usage_df = action_units.to_df().drop("group", axis=1)
        bar_heights = usage_df.mean(axis=0)
        # SEM over rows (observations), consistent with the grouped branch below.
        bar_sems = usage_df.std(axis=0) / np.sqrt(usage_df.shape[0])
        fig, ax = plt.subplots(figsize=(figW, figH), dpi=100)
        cmap = plt.get_cmap(cmap)
        if style in ("bar_scatter", "bar_error"):
            ax.bar(x=x, height=bar_heights, width=0.8, alpha=0.5, color=cmap([0.1]))
            if style == "bar_scatter":
                # One lightly jittered dot per row so individual observations stay visible.
                for i in range(len(usage_df.index)):
                    ax.scatter(x + np.random.normal(0, 0.01, n_modules), usage_df.iloc[i],
                               color="black", s=0.5)
            else:
                ax.errorbar(x=x, y=bar_heights, yerr=bar_sems, linestyle="none", linewidth=0.6,
                            color="black", capsize=1, markeredgewidth=0.75)
        else:  # style == "points"
            ax.errorbar(x=x, y=bar_heights, yerr=bar_sems, color=cmap([0.1]), linestyle="none",
                        marker="o", markersize=2.5, linewidth=0.75, capsize=2, markeredgewidth=0.75)
        ax.set_xlabel(config["data_source"] + ' Action Unit')
        ax.set_ylabel('Mean Score')
        ax.set_xticks(x)
        ax.set_xticklabels(tick_labels)
        ax.tick_params(axis='x', rotation=90, labelsize=plt.rcParams['font.size'] * 0.5, pad=2)
    else:
        # Group names in group_dict order.
        groupnames = list(action_units.group_dict.keys())
        n_groups = len(groupnames)
        usage_df = action_units.to_df()
        bar_heights = np.zeros([n_groups, n_modules])
        bar_sems = np.zeros([n_groups, n_modules])
        subgroup_frames = []
        for g, group in enumerate(groupnames):
            subgroup_usage = usage_df[usage_df["group"] == group].drop("group", axis=1)
            subgroup_frames.append(subgroup_usage)
            bar_heights[g, :] = subgroup_usage.mean(axis=0)
            bar_sems[g, :] = subgroup_usage.std(axis=0) / np.sqrt(subgroup_usage.shape[0])
        fig, ax = plt.subplots(figsize=(figW, figH), dpi=100)
        # Width of one group's bar; the extra 0.7 leaves a gap between neighbouring features.
        scale = 1 / (n_groups + .7)
        cmap = plt.get_cmap(cmap)
        colors = [cmap([i]) for i in np.linspace(0, 1, n_groups)]

        if style in ("bar_scatter", "bar_error"):
            if style == "bar_scatter":
                for g, subgroup_usage in enumerate(subgroup_frames):
                    for _, row in subgroup_usage.iterrows():
                        # Offsetting an integer arange guarantees exactly n_modules positions.
                        ax.scatter(x + scale * g + np.random.normal(0, 0.1 * scale, n_modules), row,
                                   color="black", s=0.5)
                bar_alpha = 0.5
            else:
                for g in range(n_groups):
                    ax.errorbar(x=x + scale * g, y=bar_heights[g], yerr=bar_sems[g],
                                linestyle="none", linewidth=0.6, color="black",
                                capsize=0.3, markeredgewidth=0.75, alpha=0.8)
                bar_alpha = 0.85
            for g in range(n_groups):
                label_g = groupnames[g] if alt_labels is None else alt_labels[groupnames[g]]
                ax.bar(x=x + scale * g, height=bar_heights[g], width=scale, alpha=bar_alpha,
                       color=colors[g], label=label_g)
        else:  # style == "points"
            for g in range(n_groups):
                label_g = groupnames[g] if alt_labels is None else alt_labels[groupnames[g]]
                ax.errorbar(x=x, y=bar_heights[g], yerr=bar_sems[g], color=colors[g], label=label_g,
                            linestyle="none", marker="o", markersize=4, linewidth=0.75,
                            capsize=2, markeredgewidth=0.75)

        if legend:
            if legend_pos == "inside":
                ax.legend()
            elif legend_pos == "outside_right":
                ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
            elif legend_pos == "outside_above":
                ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.2), ncol=4)
            else:
                print("Warning - invalid legend position given (should be inside, outside_right, or outside_above), using default")
                ax.legend()
        ax.set_xlabel(config["data_source"] + ' Action Unit')
        ax.set_ylabel('Mean Score')
        all_num_modules = not any(isinstance(module, str) for module in modules)
        if all_num_modules and n_modules >= 20:
            # Many numeric features: label every 5th tick to keep the axis legible.
            xticks = np.arange(0, n_modules, 5, dtype=int)
            ax.set_xticks(xticks)
            # Element-wise lookup works whether feat_names is a list or an array.
            ax.set_xticklabels([modules[i] for i in xticks])
        else:
            ax.set_xticks(x)
            ax.set_xticklabels(tick_labels)
        ax.tick_params(axis='x', rotation=90, labelsize=plt.rcParams['font.size'] * 0.2 * n_groups, pad=2)
        if plot_stats:
            ylim = ax.get_ylim()
            # Leave headroom above the data for the significance bars.
            ax.set_ylim(ylim[0], ylim[1] * 1.2)
            stat_result = action_units.f_oneway()
            for m, module in enumerate(stat_result["module"]):
                p_value = stat_result[stat_result["module"] == module]["p_uncorr"].values[0]
                if p_value < 0.001:
                    stars = "***"
                elif p_value < 0.01:
                    stars = "**"
                elif p_value < 0.05:
                    stars = "*"
                else:
                    continue
                ax.plot([m - 0.1, m + 0.7], [ylim[1] * 1.05] * 2, color="black", linewidth=0.5)
                ax.text(m + 0.3, ylim[1] * 1.08, stars, ha="center")
    if title is not None:
        ax.set_title(title)
    plt.tight_layout()
    return fig
def plot_keypoint_kinematics(config,
                             kinematics,
                             figW=4,
                             figH=2,
                             style="bar_scatter",
                             cmap="jet",
                             legend_pos="outside_right",
                             legend=True,
                             alt_labels=None,
                             alt_xticks=None,
                             title=None,
                             plot_stats=False):
    """
    Plot the mean keypoint kinematic feature per keypoint, pooled (one group) or split by group.

    :param config: config; ``config["data_source"]`` is used in the x-axis label
    :param kinematics: output of analyze.get_keypoint_kinematics (a KeypointFeature object);
        its ``feature_type`` is used as the y-axis label
    :param figW: figure width (default: 4)
    :param figH: figure height (default: 2)
    :param style: plot style; "bar_scatter", "bar_error", "points"
    :param cmap: colormap (default: jet)
    :param legend_pos: legend position; "inside", "outside_right" or "outside_above"
        (default: outside_right)
    :param legend: boolean to include or not include legend; grouped data only (default: True)
    :param alt_labels: alternative group label dictionary, group name -> label (default: None)
    :param alt_xticks: kept for signature compatibility; not used by this function (default: None)
    :param title: plot title (default: None)
    :param plot_stats: mark keypoints whose ``kinematics.f_oneway()`` uncorrected p-value is
        below 0.05 / 0.01 / 0.001 with * / ** / *** (grouped data only; default: False)
    :return: fig
    :raises ValueError: if kinematics is not a KeypointFeature object or style is unknown
    """
    if kinematics.__class__.__name__ != "KeypointFeature":
        raise ValueError(f'kinematics object class must be KeypointFeature, not {kinematics.__class__}')
    if style not in ("bar_scatter", "bar_error", "points"):
        raise ValueError(f'style must be one of "bar_scatter", "bar_error", or "points", not {style}')
    data_subgrouped = len(np.unique(kinematics.group_labels)) != 1
    modules = kinematics.feat_names
    n_modules = len(modules)
    x = np.arange(n_modules)
    # Same tick-label parsing as plot_action_units; names whose first "_"-separated token
    # contains no "AU" fall back to the raw feature names.
    try:
        tick_labels = [name.split("_")[0].split("AU")[1] for name in modules]
    except (AttributeError, IndexError):
        tick_labels = modules

    if not data_subgrouped:
        usage_df = kinematics.to_df().drop("group", axis=1)
        bar_heights = usage_df.mean(axis=0)
        # SEM over rows (observations), consistent with the grouped branch below.
        bar_sems = usage_df.std(axis=0) / np.sqrt(usage_df.shape[0])
        fig, ax = plt.subplots(figsize=(figW, figH), dpi=100)
        cmap = plt.get_cmap(cmap)
        if style in ("bar_scatter", "bar_error"):
            ax.bar(x=x, height=bar_heights, width=0.8, alpha=0.5, color=cmap([0.1]))
            if style == "bar_scatter":
                # One lightly jittered dot per row so individual observations stay visible.
                for i in range(len(usage_df.index)):
                    ax.scatter(x + np.random.normal(0, 0.01, n_modules), usage_df.iloc[i],
                               color="black", s=0.5)
            else:
                ax.errorbar(x=x, y=bar_heights, yerr=bar_sems, linestyle="none", linewidth=0.6,
                            color="black", capsize=1, markeredgewidth=0.75)
        else:  # style == "points"
            ax.errorbar(x=x, y=bar_heights, yerr=bar_sems, color=cmap([0.1]), linestyle="none",
                        marker="o", markersize=2.5, linewidth=0.75, capsize=2, markeredgewidth=0.75)
        ax.set_xlabel(config["data_source"] + ' Keypoint')
        ax.set_ylabel(kinematics.feature_type)
        ax.set_xticks(x)
        ax.set_xticklabels(tick_labels)
        ax.tick_params(axis='x', rotation=90, labelsize=plt.rcParams['font.size'] * 0.5, pad=2)
    else:
        # Group names in group_dict order.
        groupnames = list(kinematics.group_dict.keys())
        n_groups = len(groupnames)
        usage_df = kinematics.to_df()
        bar_heights = np.zeros([n_groups, n_modules])
        bar_sems = np.zeros([n_groups, n_modules])
        subgroup_frames = []
        for g, group in enumerate(groupnames):
            subgroup_usage = usage_df[usage_df["group"] == group].drop("group", axis=1)
            subgroup_frames.append(subgroup_usage)
            bar_heights[g, :] = subgroup_usage.mean(axis=0)
            bar_sems[g, :] = subgroup_usage.std(axis=0) / np.sqrt(subgroup_usage.shape[0])
        fig, ax = plt.subplots(figsize=(figW, figH), dpi=100)
        # Width of one group's bar; the extra 0.7 leaves a gap between neighbouring keypoints.
        scale = 1 / (n_groups + .7)
        cmap = plt.get_cmap(cmap)
        colors = [cmap([i]) for i in np.linspace(0, 1, n_groups)]

        if style in ("bar_scatter", "bar_error"):
            if style == "bar_scatter":
                for g, subgroup_usage in enumerate(subgroup_frames):
                    for _, row in subgroup_usage.iterrows():
                        # Offsetting an integer arange guarantees exactly n_modules positions.
                        ax.scatter(x + scale * g + np.random.normal(0, 0.1 * scale, n_modules), row,
                                   color="black", s=0.5)
                bar_alpha = 0.5
            else:
                for g in range(n_groups):
                    ax.errorbar(x=x + scale * g, y=bar_heights[g], yerr=bar_sems[g],
                                linestyle="none", linewidth=0.6, color="black",
                                capsize=0.3, markeredgewidth=0.75, alpha=0.8)
                bar_alpha = 0.85
            for g in range(n_groups):
                label_g = groupnames[g] if alt_labels is None else alt_labels[groupnames[g]]
                ax.bar(x=x + scale * g, height=bar_heights[g], width=scale, alpha=bar_alpha,
                       color=colors[g], label=label_g)
        else:  # style == "points"
            for g in range(n_groups):
                label_g = groupnames[g] if alt_labels is None else alt_labels[groupnames[g]]
                ax.errorbar(x=x, y=bar_heights[g], yerr=bar_sems[g], color=colors[g], label=label_g,
                            linestyle="none", marker="o", markersize=4, linewidth=0.75,
                            capsize=2, markeredgewidth=0.75)

        if legend:
            if legend_pos == "inside":
                ax.legend()
            elif legend_pos == "outside_right":
                ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
            elif legend_pos == "outside_above":
                ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.2), ncol=4)
            else:
                print("Warning - invalid legend position given (should be inside, outside_right, or outside_above), using default")
                ax.legend()
        ax.set_xlabel(config["data_source"] + ' Keypoint')
        ax.set_ylabel(kinematics.feature_type)
        all_num_modules = not any(isinstance(module, str) for module in modules)
        if all_num_modules and n_modules >= 20:
            # Many numeric features: label every 5th tick to keep the axis legible.
            xticks = np.arange(0, n_modules, 5, dtype=int)
            ax.set_xticks(xticks)
            # Element-wise lookup works whether feat_names is a list or an array.
            ax.set_xticklabels([modules[i] for i in xticks])
        else:
            ax.set_xticks(x)
            ax.set_xticklabels(tick_labels)
        ax.tick_params(axis='x', rotation=90, labelsize=plt.rcParams['font.size'] * 0.2 * n_groups, pad=2)
        if plot_stats:
            ylim = ax.get_ylim()
            # Leave headroom above the data for the significance bars.
            ax.set_ylim(ylim[0], ylim[1] * 1.2)
            stat_result = kinematics.f_oneway()
            for m, module in enumerate(stat_result["module"]):
                p_value = stat_result[stat_result["module"] == module]["p_uncorr"].values[0]
                if p_value < 0.001:
                    stars = "***"
                elif p_value < 0.01:
                    stars = "**"
                elif p_value < 0.05:
                    stars = "*"
                else:
                    continue
                ax.plot([m - 0.1, m + 0.7], [ylim[1] * 1.05] * 2, color="black", linewidth=0.5)
                ax.text(m + 0.3, ylim[1] * 1.08, stars, ha="center")
    if title is not None:
        ax.set_title(title)
    plt.tight_layout()
    return fig
def plot_distance_results(module_feature_object,
                          distance_results,
                          figW=3,
                          figH=3,
                          cmap="Blues",
                          title=None,
                          ref_group=0):
    """
    Box-plot the distances between a reference group and every group.

    :param module_feature_object: feature object providing ``group_labels`` (numeric group ids)
        and ``group_dict`` (group name -> group id)
    :param distance_results: table (e.g. a pandas DataFrame) with columns ``grp_i``, ``grp_j``
        and ``distance_ij``
    :param figW: figure width (default: 3)
    :param figH: figure height (default: 3)
    :param cmap: colormap name used for the boxes (default: Blues)
    :param title: plot title (default: None)
    :param ref_group: group id every box is compared against (default: 0)
    :return: fig
    """
    fig = plt.figure(figsize=(figW, figH))
    group_ids = sorted(np.unique(module_feature_object.group_labels))
    grp_i = distance_results["grp_i"]
    grp_j = distance_results["grp_j"]
    # One box per group: every distance between ref_group and that group, whichever
    # side of the (i, j) pair each group appears on.
    dst = {}
    for grp in group_ids:
        pair_mask = ((grp_i == ref_group) & (grp_j == grp)) | ((grp_i == grp) & (grp_j == ref_group))
        dst[grp] = distance_results[pair_mask]["distance_ij"].values
    cmap = plt.get_cmap(cmap)
    line_color = cmap([0.7])
    face_color = cmap([0.2])
    bplot = plt.boxplot([dst[grp] for grp in group_ids],
                        positions=group_ids,
                        patch_artist=True,
                        medianprops={'color': line_color, 'linewidth': 2},
                        boxprops={'color': line_color},
                        whiskerprops={'color': line_color, 'linewidth': 2},
                        capprops={'color': line_color, 'linewidth': 2},
                        flierprops={'markeredgecolor': line_color, 'marker': 'o'})
    for patch in bplot['boxes']:
        patch.set_facecolor(face_color)
    # group_dict maps name -> id; invert it to label the x axis by group name.
    reverse_grp_dict = {v: k for k, v in module_feature_object.group_dict.items()}
    plt.xticks(group_ids, labels=[reverse_grp_dict[grp] for grp in group_ids], rotation=90)
    if title is not None:
        plt.title(title)
    plt.ylabel("Distance")
    plt.tight_layout()
    return fig
def network_plot(config,
                 labels_df=None,
                 module_usage=None,
                 module_transitions=None,
                 cmap="bwr",
                 include_labels=True,
                 scaling=1,
                 tscale=6,
                 figW=2.8,
                 figH=2.5,
                 alt_labels=None):
    """
    Plot a network comparison between two groups (group ids 0 and 1).

    Nodes are modules: colour and size encode the t-statistic of module usage counts
    (group 1 vs group 0). Edges are transitions: colour encodes the sign and width the
    magnitude of the per-transition t-statistic. Self-transitions are not drawn.
    You must provide EITHER labels_df OR both module_usage and module_transitions.

    :param config: project config object
    :param labels_df: labels_df for two groups to be compared; if provided, ModuleUsage and ModuleTransitions will be computed
    :param module_usage: ModuleUsage object for comparison between two groups (not needed if labels_df is provided)
    :param module_transitions: ModuleTransitions object for comparison between two groups (not needed if labels_df is provided)
    :param cmap: color map
    :param include_labels: True or False; draw module numbers on the nodes; default True
    :param scaling: controls size of nodes in network plot; larger --> bigger; default 1
    :param tscale: t-score limits (-tscale, tscale) for node colours and the colour bar; default 6
    :param figW: figure width
    :param figH: figure height
    :param alt_labels: optional dictionary mapping group name -> display label for the colour-bar annotations
    :return: fig
    :raises ValueError: if neither labels_df nor both module_usage and module_transitions are given
    """
    if labels_df is not None:
        print("Defaulting to using labels_df to get module_usage and module_transitions")
        print("Getting module usage")
        module_usage = analyze.get_module_usage(config, labels_df)
        print("Getting module transitions")
        module_transitions = analyze.get_module_transitions(config, labels_df)
    elif module_usage is None or module_transitions is None:
        raise ValueError("Provide either labels_df, or both module_usage and module_transitions")

    group_labels = np.array(module_usage.group_labels)
    g1 = module_usage.label_counts[group_labels == 0, :]
    g2 = module_usage.label_counts[group_labels == 1, :]
    n_modules = g1.shape[1]
    # Per-module t-statistic; positive means higher usage in group 1 than in group 0.
    usage_stats = stats.ttest_ind(g2, g1, axis=0).statistic
    transition_counts = np.array(module_transitions.transition_count_matrices)
    transition_stats = stats.ttest_ind(transition_counts[group_labels == 1],
                                       transition_counts[group_labels == 0]).statistic
    np.fill_diagonal(transition_stats, 0)  # drop self-transitions (zero entries become no edge)

    fig, ax = plt.subplots(figsize=(figW, figH))
    G = nx.from_numpy_array(transition_stats)
    edges = list(G.edges())
    signed_weights = [G[u][v]['weight'] for u, v in edges]
    colormap = plt.get_cmap(cmap)
    edge_colors = []
    for weight in signed_weights:
        if weight >= 0:
            edge_colors.append(colormap([0.95]))
        elif weight < 0:
            edge_colors.append(colormap([0.05]))
        else:  # NaN t-statistic
            edge_colors.append(colormap([0]))
    # Normalise sizes by the number of modules so crowded graphs stay readable.
    scaling = scaling / n_modules
    edge_widths = list(np.abs(np.asarray(signed_weights, dtype=float)) * 5000 * scaling * 5 / 4000)
    pos = nx.circular_layout(G)
    nx.draw_circular(G,
                     node_color=usage_stats,
                     cmap=colormap,
                     # Same limits as the colour bar below, so node colours can be read off it.
                     vmin=-tscale,
                     vmax=tscale,
                     node_size=np.absolute(usage_stats) * 1000 * scaling,
                     edge_color=edge_colors,
                     edgecolors=None,
                     with_labels=False,  # module numbers are drawn below with size-aware fonts
                     width=edge_widths)
    if include_labels:
        for m in range(n_modules):
            x, y = pos[m]
            label_font_size = 2 + 30 * scaling * (abs(usage_stats[m]) + 1) / 2
            ax.text(x, y, str(m), color="white", fontsize=label_font_size, fontweight="bold",
                    ha="center", va="center")

    smap = plt.cm.ScalarMappable(cmap=colormap, norm=plt.Normalize(vmin=-tscale, vmax=tscale))
    smap.set_array([])
    cbar_ax = fig.add_axes([0.85, 0.1, 0.05, 0.8])
    group_names = list(module_usage.group_dict.keys())
    if alt_labels is None:
        label1, label2 = group_names[0], group_names[1]
    else:
        label1, label2 = alt_labels[group_names[0]], alt_labels[group_names[1]]
    cbar_ax.text(0, tscale * 1.1, f"{label2}>{label1}", va="bottom", color=colormap([0.9]))
    cbar_ax.text(0, -tscale * 1.1, f"{label2}<{label1}", va="top", color=colormap([0.1]))
    cbar = fig.colorbar(smap, cax=cbar_ax)
    cbar.set_label('t-score')
    # NOTE(review): fig.add_axes makes cbar_ax the current axes, so this margins call
    # presumably targets the colour-bar axes rather than the network axes -- confirm intent.
    plt.margins(x=0.4, y=0.4)
    plt.subplots_adjust(right=0.85)
    return fig
def module_usage_sandplot(config,
                          module_usage,
                          remap=False,
                          title=None,
                          legend=True,
                          long_legend=True,
                          figW=7,
                          figH=3,
                          convolve=False,
                          window=5):
    """
    Stacked-area ("sand") plot of mean module usage over time.

    :param config: the config object; ``config["remappings"]`` is read when ``remap`` is set
    :param module_usage: ModuleUsage object; ``module_usage.usage_density`` supplies the data
    :param remap: if True, group and colour modules by behaviour class using config["remappings"]
    :param title: plot title (default: None)
    :param legend: plot legend or not
    :param long_legend: append the number of modules to each behaviour label (remap only)
    :param figW: figure width (default: 7)
    :param figH: figure height (default: 3)
    :param convolve: forwarded to ``usage_density``
    :param window: forwarded to ``usage_density``
    :return: fig
    """
    fig, ax = plt.subplots(figsize=(figW, figH), dpi=100)
    data = module_usage.usage_density(convolve=convolve, window=window)
    # Average over the entries of ``data``; after .T rows follow data[0].columns (time)
    # and columns follow data[0].index (modules).
    data_mean = np.mean(data, axis=0).T
    bottom = np.zeros(data_mean[:, 0].shape)
    # Column labels divided by 60 for the "Time (m)" axis (presumably seconds -> minutes).
    xpts = np.array(data[0].columns / 60, dtype=float)
    if not remap:
        leg_handles = []
        df_columns = module_usage.to_df().columns
        mods = np.unique([c.split("module")[1].split("_t")[0] for c in df_columns if "module" in c])
        if any(analyze.is_nonnum(m) for m in mods):
            leg_handles = mods
        else:
            # Same grouping as plot_module_usage: modules sharing index % 10 share a legend row.
            for bucket in range(10):
                m_in_col = [i for i in range(data_mean.shape[1]) if i % 10 == bucket]
                if len(m_in_col) > 1:
                    leg_handles.append("Modules " + ', '.join(map(str, m_in_col)))
                elif m_in_col:
                    leg_handles.append("Module " + str(m_in_col[0]))
        for m in range(data_mean.shape[1]):
            y_i = np.array(data_mean[:, m], dtype=float)
            if m < len(leg_handles):
                ax.fill_between(xpts, bottom, bottom + y_i, label=leg_handles[m])
            else:
                ax.fill_between(xpts, bottom, bottom + y_i)
            bottom = y_i + bottom
        if legend:
            ax.legend(bbox_to_anchor=(1.5, 1))
    else:
        modules = list(data[0].index)
        try:
            behaviors = sorted(np.unique([entry[1] for entry in config["remappings"]]))
        except TypeError as err:
            raise TypeError("Issue remapping modules to behavior classes - is the config remapping section up-to-date?") from err
        if "other" not in behaviors:
            behaviors = behaviors + ["other"]
        n_behaviors = len(behaviors)
        leg_handles_col0 = []
        leg_handles_col1 = []
        # With ncol=2 the first n_behaviors entries fill column 1: blank labels there, so
        # column 1 shows only the lighter swatch and column 2 carries the names.
        leg_labels = [""] * n_behaviors
        cm_list = ['magma', 'YlOrBr', 'Blues', 'copper', 'Greys_r', 'BuGn', 'Reds_r', 'magma_r', 'Purples', 'YlGn',
                   'YlGnBu', 'YlOrRd']
        # Loop name differs from the ``remap`` parameter so the parameter is not clobbered.
        for b, behavior in enumerate(behaviors):
            sub_modules = [int(entry[0][0]) for entry in config["remappings"] if entry[1] == behavior]
            # Cycle colormaps if there are more behaviours than entries in cm_list.
            behavior_cmap = plt.get_cmap(cm_list[b % len(cm_list)])
            colors = [behavior_cmap([0.4]), behavior_cmap([0.5])]
            leg_handles_col0.append(Patch(facecolor=colors[0], edgecolor='none'))
            leg_handles_col1.append(Patch(facecolor=colors[1], edgecolor='none'))
            if long_legend:
                leg_labels.append(behavior + " (" + str(len(sub_modules)) + " modules)")
            else:
                leg_labels.append(behavior)
            for c in sub_modules:
                if c in modules:
                    y_i = np.array(data_mean[:, modules.index(c)], dtype=float)
                    # Alternate two shades within a behaviour so adjacent modules stay distinct.
                    ax.fill_between(xpts, bottom, bottom + y_i, color=colors[c % 2])
                    bottom = y_i + bottom
        if legend:
            ax.legend(handles=leg_handles_col0 + leg_handles_col1,
                      labels=leg_labels,
                      ncol=2, handletextpad=0.5, handlelength=1.0, columnspacing=-0.5,
                      bbox_to_anchor=(1.5, 1))
    ax.set_ylim([0, 1])
    ax.set_xlim([xpts[0], xpts[-1]])
    ax.set_xlabel("Time (m)")
    ax.set_ylabel('Moving Average Proportion \nof Time Spent in Pose')
    if title is not None:
        ax.set_title(title)
    plt.tight_layout()
    return fig
def plot_keypoint_travel(keypoint_feature, cmap="viridis", plottype="band", figW=6, figH=3):
    """
    Plot keypoint displacement per group, either across feature columns (time bins) or as one bar per group.

    :param keypoint_feature: keypoint feature object. This function reads ``group_dict`` (group name ->
        label), ``group_labels`` (one label per observation), ``feat_names`` (one entry per feature
        column, used as x values) and ``keypoint_feature`` (2-D array indexed [observation, feature]).
    :param cmap: matplotlib colormap name used to color the groups
    :param plottype: "band" (line with shaded SEM), "errorbar" (line with SEM error bars) or "bar".
        Forced to "bar" when there is only one feature column.
    :param figW: figure width
    :param figH: figure height
    :return: matplotlib figure
    :raises ValueError: if ``plottype`` is not one of the supported values
    """
    if len(keypoint_feature.feat_names) == 1:
        # A single feature column has no x axis to draw a line along.
        plottype = "bar"
    if plottype not in ("band", "errorbar", "bar"):
        raise ValueError(f"plottype must be 'band', 'errorbar' or 'bar', not {plottype!r}")

    groups = list(keypoint_feature.group_dict.keys())
    n_groups = len(groups)
    fig, ax = plt.subplots(figsize=(figW, figH), dpi=100)
    cmap = plt.get_cmap(cmap)
    colors = [cmap([i]) for i in np.linspace(0, 1, n_groups)]
    labels = np.array(keypoint_feature.group_labels)
    xvals = keypoint_feature.feat_names
    bar_positions = np.arange(n_groups)

    for g, group in enumerate(groups):
        group_data = keypoint_feature.keypoint_feature[labels == keypoint_feature.group_dict[group], :]
        group_mean = group_data.mean(axis=0)
        # Standard error of the mean: normalize by the number of observations (rows),
        # not by the number of feature columns.
        group_sem = group_data.std(axis=0) / np.sqrt(group_data.shape[0])
        if plottype == "errorbar":
            ax.errorbar(xvals, group_mean, label=group, yerr=group_sem, marker="o", capsize=2, color=colors[g])
        elif plottype == "band":
            ax.plot(xvals, group_mean, label=group, marker="o", color=colors[g])
            ax.fill_between(xvals, group_mean - group_sem, group_mean + group_sem, alpha=0.2, color=colors[g],
                            edgecolor="none")
        else:
            ax.bar(bar_positions[g], group_mean, color=colors[g], alpha=0.6)
            ax.errorbar(bar_positions[g], group_mean, label=group, yerr=group_sem, capsize=2, color="black",
                        linestyle="none")

    if plottype == "bar":
        ax.set_xticks(bar_positions)
        ax.set_xticklabels(groups)
    elif n_groups:
        ax.legend()
        ax.set_xlabel("Time (m)")
    ax.set_ylabel("Keypoint travel (pix)")
    return fig
def plot_embeddings(module_feature_object, embeddings_object, figW=3, figH=3, cmap="viridis", title=None,
                    legend=False, draw_ellipse=True, alt_legend=None):
    """
    Scatter the first two embedding dimensions of every observation, colored by group.

    :param module_feature_object: feature object (ModuleUsage, ModuleTransitions, KeypointFeature or
        ActionUnits). Its feature matrix is passed to ``embeddings_object.transform``. Its ``group_labels``
        (integer label per observation) and ``group_dict`` (group name -> label) drive colors and legend.
    :param embeddings_object: fitted embedding object (e.g. LDA or PCA from analyze.embed)
    :param figW: figure width
    :param figH: figure height
    :param cmap: matplotlib colormap name
    :param title: title string, or None
    :param legend: True or False
    :param draw_ellipse: draw one covariance ellipse per group (only for LDA embeddings)
    :param alt_legend: optional dict mapping group name -> legend text
    :return: fig
    :raises TypeError: if ``module_feature_object`` is not one of the supported feature classes
    """
    # Attribute holding the feature matrix for each supported feature class.
    class_name = module_feature_object.__class__.__name__
    feature_attr = {
        "ModuleUsage": "label_counts",
        "ModuleTransitions": "transition_counts",
        "KeypointFeature": "keypoint_feature",
        "ActionUnits": "action_units",
    }.get(class_name)
    if feature_attr is None:
        raise TypeError(f"Unsupported feature object of class {class_name}")
    X_tfm = embeddings_object.transform(getattr(module_feature_object, feature_attr))

    # group_labels may be a plain list; convert it so that `labels == r` is elementwise.
    labels = np.asarray(module_feature_object.group_labels)
    unique_labels = np.unique(labels)

    # Colors are indexed by the integer label value, so provide a color for every label
    # in group_dict and for every label actually present.
    n_colors = len(module_feature_object.group_dict)
    if unique_labels.size:
        n_colors = max(n_colors, int(unique_labels.max()) + 1)
    cmap = plt.get_cmap(cmap)
    colors = [cmap([v]) for v in np.linspace(0, 1, n_colors)]

    fig = plt.figure(figsize=(figW, figH))
    if isinstance(embeddings_object, (LDA, PCA)):
        prefix = "LD" if isinstance(embeddings_object, LDA) else "PC"
        explained_variance = embeddings_object.explained_variance_ratio_
        plt.xlabel(f'{prefix}1 ({explained_variance[0]*100:.2f}% variance explained)')
        plt.ylabel(f'{prefix}2 ({explained_variance[1]*100:.2f}% variance explained)')
    if title is not None:
        plt.title(title)

    reverse_group_dict = {v: k for k, v in module_feature_object.group_dict.items()}
    labelled = set()  # groups that already have a legend entry
    for obs in range(X_tfm.shape[0]):
        group_label = labels[obs]
        if group_label in labelled:
            plt.scatter(X_tfm[obs, 0], X_tfm[obs, 1], color=colors[group_label])
            continue
        name = reverse_group_dict[group_label]
        leg = alt_legend[name] if alt_legend is not None else name
        plt.scatter(X_tfm[obs, 0], X_tfm[obs, 1], color=colors[group_label], label=leg)
        labelled.add(group_label)

    if draw_ellipse:
        if isinstance(embeddings_object, LDA):
            for r in unique_labels:
                # Use only the two plotted dimensions; LDA may return more components.
                emb = X_tfm[labels == r, :2]
                mean = np.mean(emb, axis=0)
                cov = np.cov(emb, rowvar=False)
                make_and_plot_ellipse(mean, cov, color=colors[r])
        else:
            print(f"Ellipse only drawn for embeddings_object of class LDA, not {embeddings_object.__class__}")
    if legend:
        plt.legend(bbox_to_anchor=[1.05, 1])
    plt.tight_layout()
    return fig
def plot_distance_matrix(module_feature_object, dist_mat, cmap="Greens", figW=3, figH=3, alt_labels=None,
                         title=None):
    """
    Plot an observation-by-centroid distance matrix as a heatmap with group separators.

    :param module_feature_object: feature object. Its ``group_labels`` (one integer label per row of
        ``dist_mat``) and ``group_dict`` (group name -> label) are read.
    :param dist_mat: 2-D array of distances; rows are observations, columns are centroids. X ticks are
        placed at the group label values, so column j is taken to be the centroid of label j.
    :param cmap: matplotlib colormap name
    :param figW: figure width
    :param figH: figure height
    :param alt_labels: optional dict mapping group name -> display label
    :param title: title string, or None
    :return: fig
    """
    fig = plt.figure(figsize=(figW, figH), dpi=100)
    plt.imshow(dist_mat, aspect="auto", cmap=cmap)
    if alt_labels is not None:
        names = {v: alt_labels[k] for k, v in module_feature_object.group_dict.items()}
    else:
        names = {v: k for k, v in module_feature_object.group_dict.items()}

    labels = np.array(module_feature_object.group_labels)
    groups = sorted(np.unique(labels))
    # Vertical separators between centroid columns.
    for group in groups:
        if group != 0:
            plt.axvline(group - 0.5, color="black", linewidth=0.5)
    plt.xticks(ticks=groups, labels=[names[g] for g in groups], rotation=90)

    # Horizontal separators wherever the label changes between consecutive rows.
    edges = np.where(labels[:-1] != labels[1:])[0]
    for edge in edges:
        plt.axhline(edge + 0.5, color="black", linewidth=0.5)
    # One y tick at the first row of each contiguous run, labeled with that run's own group.
    # This stays correct when rows are not sorted by label or a group is split into several runs.
    run_starts = np.concatenate([[0], edges + 1]) if labels.size else np.array([], dtype=int)
    plt.yticks(ticks=run_starts, labels=[names[labels[s]] for s in run_starts])
    plt.ylabel("Observations")
    plt.xlabel("Centroids")
    if title is not None:
        plt.title(title)
    plt.tight_layout()
    return fig
def plot_distance_box(module_feature_object, dist_mat, cmap="Blues", figW=3, figH=3, alt_labels=None, title=None,
                      centroid=0):
    """
    Box plot, one box per group, of each observation's distance to a single group centroid.

    :param module_feature_object: feature object. Its ``group_labels`` (one integer label per row of
        ``dist_mat``) and ``group_dict`` (group name -> label) are read.
    :param dist_mat: 2-D array of distances with rows = observations; column ``centroid`` is plotted
    :param cmap: matplotlib colormap name (box fill at 0.2, outlines/median/whiskers/caps at 0.7)
    :param figW: figure width
    :param figH: figure height
    :param alt_labels: optional dict mapping group name -> display label (ignored if empty)
    :param title: title string, or None
    :param centroid: label of the group whose centroid distances are plotted, i.e. the ``dist_mat``
        column to use. Defaults to 0, the previously hard-coded group.
    :return: fig
    """
    fig = plt.figure(figsize=(figW, figH), dpi=100)
    labels = np.array(module_feature_object.group_labels)
    groups = sorted(np.unique(labels))
    cmap = plt.get_cmap(cmap)
    line_color = cmap([0.7])
    bplot = plt.boxplot([dist_mat[labels == grp, centroid] for grp in groups],
                        positions=groups,
                        patch_artist=True,
                        medianprops={'color': line_color, 'linewidth': 2},
                        boxprops={'color': line_color},
                        whiskerprops={'color': line_color, 'linewidth': 2},
                        capprops={'color': line_color, 'linewidth': 2},
                        flierprops={'markeredgecolor': line_color, 'marker': 'o'})
    for patch in bplot['boxes']:
        patch.set_facecolor(cmap([0.2]))
    if alt_labels:
        names = {v: alt_labels[k] for k, v in module_feature_object.group_dict.items()}
    else:
        names = {v: k for k, v in module_feature_object.group_dict.items()}
    plt.xticks(groups, labels=[names[g] for g in groups], rotation=90)
    if title is not None:
        plt.title(title)
    plt.ylabel(f"Distance to {names[centroid]} Centroid")
    plt.tight_layout()
    return fig
# def plot_dist_boxplot(dists, dist_from, figW=5, figH=3,alt_labels=None):
# """
# Plot distance boxplots
#
# :param dists: dists returned from lda_result.get_mahalanobis_distance()
# :param dist_from: a group from the dataset from which to calculate all the distances
# :param figW:
# :param figH:
# :return:
# """
#
# centroid=np.sum(["CENTROID" in i for i in list(dists.keys())])>0
#
# if centroid:
# pairings = [i for i in list(dists.keys()) if dist_from+"-CENTROID" in i]
# centroid_lab = " Centroid"
# else:
# pairings = [i for i in list(dists.keys()) if dist_from in i]
# centroid_lab=""
#
# pairings_ticks = [i.replace("____vs____", "").replace(dist_from, "").replace("-CENTROID", "") for i in pairings]
# if alt_labels is not None:
# for p,tick in enumerate(pairings_ticks):
# if tick=="":
# pairings_ticks[p]=dist_from
# pairings_ticks = [alt_labels[p] for p in pairings_ticks]
#
# fig = plt.figure(figsize=(figW, figH))
# xticks = []
# for p, pair in enumerate(pairings):
# xticks.append(p)
# c1 = "#4198B5"
# c2 = "#D8EBF1"
#
# box = plt.boxplot(np.array(dists[pair]), positions=[p], widths=0.5,
# patch_artist=True,
# boxprops=dict(facecolor=c2, color=c1),
# whiskerprops=dict(color=c1),
# capprops=dict(color=c1),
# medianprops=dict(color=c1),
# flierprops=dict(markerfacecolor=c1, marker='o', markersize=4))
#
# dist_from_name=dist_from
# if alt_labels is not None:
# dist_from_name=alt_labels[dist_from]
#
# plt.ylabel('Mahalanoubis Distance\nfrom '+dist_from_name+centroid_lab)
# plt.yticks([])
#
# plt.xticks(ticks=xticks, labels=pairings_ticks, rotation=90)
# plt.tight_layout()
#
# return fig
#
# def plot_pc_weights(pca,cmap="PuOr"):
# """
# Plot PCA weights
#
# :param pca: pca object from sklearn
# :return: fig
# """
# components = pca.components_
# fig = plt.figure(figsize=(4,2),dpi=100)
# pc_labels=["PC"+str(i+1) for i in range(pca.components_.shape[0])]
# plt.imshow(components,cmap=cmap,vmin=-1,vmax=1)
# plt.title("Principle Component Weights")
# plt.xticks(np.arange(0,pca.components_.shape[1],1))
# plt.xlabel("Pose Modules")
# plt.yticks([0,1],labels=pc_labels)
# plt.colorbar(cmap=cmap)
# plt.tight_layout()
# return fig
#
def BORIS_to_pose_matrix_plot(config, boris_to_pose_output, figW=4, figH=2.5, cmap="Greens", outline_top_match=True):
    """
    Heatmap of manually scored (BORIS) behaviors against pose modules.

    :param config: config dict; ``config["data_source"]`` is used in the x-axis label
    :param boris_to_pose_output: DataFrame with one row per scored behavior (index) and one column per
        pose module; values must be convertible to float
    :param figW: figure width
    :param figH: figure height
    :param cmap: matplotlib colormap name
    :param outline_top_match: if True, outline the cell holding each column's maximum, but only when
        that maximum is unique within the column
    :return: fig
    """
    fig, ax = plt.subplots(figsize=(figW, figH), dpi=100)
    data = boris_to_pose_output.to_numpy(dtype='float')
    plt.imshow(data, cmap=cmap, aspect="auto", interpolation="none")
    n_rows, n_cols = data.shape

    if outline_top_match and n_rows > 0:
        for col in range(n_cols):
            column = data[:, col]
            max_row = np.argmax(column)
            # Only outline unambiguous best matches (no ties for the maximum).
            if np.sum(column == column[max_row]) == 1:
                rect = Rectangle((col - 0.5, max_row - 0.5), 1, 1, edgecolor='purple', facecolor='none',
                                 linewidth=0.5)
                ax.add_patch(rect)

    if n_cols >= 20:
        # Too many modules to label individually: tick every 5th column, labeled by position.
        xticks = np.arange(0, n_cols, 5, dtype=int)
        xticklabels = xticks
    else:
        xticks = range(n_cols)
        xticklabels = boris_to_pose_output.columns
    plt.xticks(xticks, labels=xticklabels)
    plt.yticks(range(len(boris_to_pose_output.index)), labels=boris_to_pose_output.index)
    plt.xlabel(config["data_source"] + " Pose Module")
    plt.ylabel("Manually Scored\nBehavior")
    plt.tick_params(axis='x', rotation=90,
                    labelsize=plt.rcParams['font.size'] * 0.7, pad=2)
    plt.tight_layout()
    return fig