import one.alf.io as alfio
from one.alf.spec import is_session_path
from one.alf.exceptions import ALFObjectNotFound
from ibllib.io.raw_data_loaders import load_data
from ibllib.oneibl.registration import _read_settings_json_compatibility_enforced, _get_session_times
from ibllib.io.extractors.base import get_pipeline, get_session_extractor_type
from ibllib.plots.snapshot import ReportSnapshot
from iblutil.numerical import ismember
from brainbox.behavior import training
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.lines import Line2D
from datetime import datetime
import seaborn as sns
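
# Map each training status to a (rank, RGBA colour) pair: the rank enforces the ordering of the
# hierarchy (see pass_through_training_hierachy) and the colour is used to mark status changes on
# the plots (see add_training_lines).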
TRAINING_STATUS = {'not_computed': (-2, (0, 0, 0, 0)),
'habituation': (-1, (0, 0, 0, 0)),
'in training': (0, (0, 0, 0, 0)),
'trained 1a': (1, (195, 90, 80, 255)),
'trained 1b': (2, (255, 153, 20, 255)),
'ready4ephysrig': (3, (28, 20, 255, 255)),
'ready4delay': (4, (117, 117, 117, 255)),
'ready4recording': (5, (20, 255, 91, 255))}
def get_trials_task(session_path, one):
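    """
    Return the trials extraction task for a session, chosen according to its pipeline (training, ephys, or a
    project-specific pipeline from the personal projects repository). Returns None if no suitable task is found.
    :param session_path: session path
    :param one: ONE instance
    :return: task instance or None
    """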
# TODO this eventually needs to be updated for dynamic pipeline tasks
pipeline = get_pipeline(session_path)
if pipeline == 'training':
from ibllib.pipes.training_preprocessing import TrainingTrials
task = TrainingTrials(session_path, one=one)
elif pipeline == 'ephys':
from ibllib.pipes.ephys_preprocessing import EphysTrials
task = EphysTrials(session_path, one=one)
else:
try:
            # look for a custom extractor in the personal projects extraction classes
import projects.base
task_type = get_session_extractor_type(session_path)
PipelineClass = projects.base.get_pipeline(task_type)
pipeline = PipelineClass(session_path, one)
trials_task_name = next(task for task in pipeline.tasks if 'Trials' in task)
task = pipeline.tasks.get(trials_task_name)
except Exception:
task = None
return task
def save_path(subj_path):
return Path(subj_path).joinpath('training.csv')
def save_dataframe(df, subj_path):
"""
Save training dataframe to disk
:param df: dataframe to save
:param subj_path: path to subject folder
:return:
"""
df.to_csv(save_path(subj_path), index=False)
def load_existing_dataframe(subj_path):
    """
    Load the training dataframe from disk; returns None if the dataframe doesn't exist
:param subj_path: path to subject folder
:return:
"""
df_location = save_path(subj_path)
if df_location.exists():
return pd.read_csv(df_location)
else:
df_location.parent.mkdir(exist_ok=True, parents=True)
return None
def load_trials(sess_path, one):
    """
    Load trials data for a session. First attempts to load from the local session path; if this fails, attempts to
    download the trials via ONE; if this also fails, attempts to re-extract the trials locally
:param sess_path: session path
:param one: ONE instance
:return:
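
    A minimal usage sketch (assumes a configured ONE instance; the session path below is purely illustrative):
    >>> from one.api import ONE
    >>> one = ONE()
    >>> trials = load_trials(Path('/mnt/data/Subjects/SW_001/2023-01-01/001'), one)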
"""
# try and load trials locally
try:
trials = alfio.load_object(sess_path.joinpath('alf'), 'trials')
if 'probabilityLeft' not in trials.keys():
raise ALFObjectNotFound
except ALFObjectNotFound:
try:
# attempt to download trials using ONE
trials = one.load_object(one.path2eid(sess_path), 'trials')
if 'probabilityLeft' not in trials.keys():
raise ALFObjectNotFound
except Exception:
try:
task = get_trials_task(sess_path, one=one)
if task is not None:
task.run()
trials = alfio.load_object(sess_path.joinpath('alf'), 'trials')
if 'probabilityLeft' not in trials.keys():
raise ALFObjectNotFound
else:
trials = None
            except Exception:  # TODO how can I make this more specific
trials = None
return trials
def load_combined_trials(sess_paths, one):
    """
    Load and concatenate the trials for multiple sessions. Used when we want to concatenate the trials for sessions
    recorded on the same day
:param sess_paths: list of paths to sessions
:param one: ONE instance
:return:
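
    Sessions whose trials cannot be loaded are skipped; the remaining trials are concatenated with
    training.concatenate_trials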
"""
trials_dict = {}
for sess_path in sess_paths:
trials = load_trials(Path(sess_path), one)
if trials is not None:
            trials_dict[Path(sess_path).stem] = trials
return training.concatenate_trials(trials_dict)
def find_earliest_recompute_date(df):
    """
    Find the earliest date from which the training status needs to be recomputed. The training status depends on
    previous sessions, so if a session was missing and has now been added, everything from that date onwards must be
    recomputed
:param df:
:return:
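
    Returns an empty list if no sessions are marked 'not_computed', otherwise all dates from the first
    'not_computed' session onwards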
"""
missing_df = df[df['training_status'] == 'not_computed']
if len(missing_df) == 0:
return []
missing_df = missing_df.sort_values('date')
first_index = missing_df.index[0]
return df[first_index:].date.values
def compute_training_status(df, compute_date, one):
    """
    Compute the training status for the given date based on the sessions from that date and the two previous days
:param df: training dataframe
:param compute_date: date to compute training on
:param one: ONE instance
:return:
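
    A minimal usage sketch (assumes df was built with get_training_info_for_session and a configured ONE instance):
    >>> for date in find_earliest_recompute_date(df):
    ...     df = compute_training_status(df, date, one)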
"""
# compute_date = str(one.path2ref(session_path)['date'])
df_temp = df[df['date'] <= compute_date]
df_temp = df_temp.drop_duplicates('session_path')
    df_temp = df_temp.sort_values('date')
dates = df_temp.date.values
n_dates = np.min([3, len(dates)]).astype(int)
compute_dates = dates[(-1 * n_dates):]
assert compute_dates[-1] == compute_date
df_temp_group = df_temp.groupby('date')
trials = {}
n_delay = 0
ephys_sessions = []
protocol = []
status = []
for date in compute_dates:
df_date = df_temp_group.get_group(date)
# If habituation skip
if df_date.iloc[-1]['task_protocol'] == 'habituation':
continue
trials[df_date.iloc[-1]['date']] = load_combined_trials(df_date.session_path.values, one)
protocol.append(df_date.iloc[-1]['task_protocol'])
status.append(df_date.iloc[-1]['training_status'])
if df_date.iloc[-1]['combined_n_delay'] >= 900: # delay of 15 mins
n_delay += 1
if df_date.iloc[-1]['location'] == 'ephys_rig':
ephys_sessions.append(df_date.iloc[-1]['date'])
n_status = np.max([-2, -1 * len(status)])
training_status, _ = training.get_training_status(trials, protocol, ephys_sessions, n_delay)
training_status = pass_through_training_hierachy(training_status, status[n_status])
df.loc[df['date'] == compute_date, 'training_status'] = training_status
return df
def pass_through_training_hierachy(status_new, status_old):
    """
    Ensures that the new training status is not lower than the one from the previous day, i.e. a subject cannot
    regress in training status
:param status_new: latest training status
:param status_old: previous training status
:return:
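
    For example, given the TRAINING_STATUS ranks 'in training' = 0 and 'trained 1a' = 1:
    >>> pass_through_training_hierachy('in training', 'trained 1a')
    'trained 1a'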
"""
if TRAINING_STATUS[status_old][0] > TRAINING_STATUS[status_new][0]:
return status_old
else:
return status_new
def compute_session_duration_delay_location(sess_path):
    """
    Get meta information about a session. Extracts the session duration, the delay before the task started and the
    location of the session
:param sess_path: session path
:return:
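
    Returns the session duration in minutes, the start delay in seconds (SESSION_START_DELAY_SEC, 0 if absent) and
    the location, 'ephys_rig' or 'training_rig', inferred from the PYBPOD_BOARD setting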
"""
settings_file = list(sess_path.glob('**/raw_behavior_data/_iblrig_taskSettings.raw*.json'))
md = _read_settings_json_compatibility_enforced(settings_file[0])
sess_data = load_data(sess_path)
start_time, end_time = _get_session_times(sess_path, md, sess_data)
session_duration = int((end_time - start_time).total_seconds() / 60)
session_delay = md.get('SESSION_START_DELAY_SEC', 0)
    if 'ephys' in md.get('PYBPOD_BOARD', ''):
session_location = 'ephys_rig'
else:
session_location = 'training_rig'
return session_duration, session_delay, session_location
def get_training_info_for_session(session_paths, one):
"""
Extract the training information needed for plots for each session
:param session_paths: list of session paths on same date
:param one: ONE instance
:return:
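
    Returns a list of dicts, one per session, each holding the performance metrics, the psychometric fit parameters
    and their 'combined_*' counterparts computed across all sessions on the same date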
"""
# return list of dicts to add
sess_dicts = []
for session_path in session_paths:
session_path = Path(session_path)
sess_dict = {}
sess_dict['date'] = str(one.path2ref(session_path)['date'])
sess_dict['session_path'] = str(session_path)
sess_dict['task_protocol'] = get_session_extractor_type(session_path)
if sess_dict['task_protocol'] == 'habituation':
nan_array = np.array([np.nan])
sess_dict['performance'], sess_dict['contrasts'], _ = (nan_array, nan_array, np.nan)
sess_dict['performance_easy'] = np.nan
sess_dict['reaction_time'] = np.nan
sess_dict['n_trials'] = np.nan
sess_dict['sess_duration'] = np.nan
sess_dict['n_delay'] = np.nan
sess_dict['location'] = np.nan
sess_dict['training_status'] = 'habituation'
sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapsehigh_50'], sess_dict['lapselow_50'] = \
(np.nan, np.nan, np.nan, np.nan)
sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapsehigh_20'], sess_dict['lapselow_20'] = \
(np.nan, np.nan, np.nan, np.nan)
sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapsehigh_80'], sess_dict['lapselow_80'] = \
(np.nan, np.nan, np.nan, np.nan)
else:
            # if the trials cannot be loaded, skip this session
trials = load_trials(session_path, one)
if trials is None:
continue
sess_dict['performance'], sess_dict['contrasts'], _ = training.compute_performance(trials, prob_right=True)
if sess_dict['task_protocol'] == 'training':
sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapsehigh_50'], sess_dict['lapselow_50'] = \
training.compute_psychometric(trials)
sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapsehigh_20'], sess_dict['lapselow_20'] = \
(np.nan, np.nan, np.nan, np.nan)
sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapsehigh_80'], sess_dict['lapselow_80'] = \
(np.nan, np.nan, np.nan, np.nan)
else:
sess_dict['bias_50'], sess_dict['thres_50'], sess_dict['lapsehigh_50'], sess_dict['lapselow_50'] = \
training.compute_psychometric(trials, block=0.5)
sess_dict['bias_20'], sess_dict['thres_20'], sess_dict['lapsehigh_20'], sess_dict['lapselow_20'] = \
training.compute_psychometric(trials, block=0.2)
sess_dict['bias_80'], sess_dict['thres_80'], sess_dict['lapsehigh_80'], sess_dict['lapselow_80'] = \
training.compute_psychometric(trials, block=0.8)
sess_dict['performance_easy'] = training.compute_performance_easy(trials)
sess_dict['reaction_time'] = training.compute_median_reaction_time(trials)
sess_dict['n_trials'] = training.compute_n_trials(trials)
sess_dict['sess_duration'], sess_dict['n_delay'], sess_dict['location'] = \
compute_session_duration_delay_location(session_path)
sess_dict['training_status'] = 'not_computed'
sess_dicts.append(sess_dict)
protocols = [s['task_protocol'] for s in sess_dicts]
if len(protocols) > 0 and len(set(protocols)) != 1:
print(f'Different protocols on same date {sess_dicts[0]["date"]} : {protocols}')
if len(sess_dicts) > 1 and len(set(protocols)) == 1: # Only if all protocols are the same
print(f'{len(sess_dicts)} sessions being combined for date {sess_dicts[0]["date"]}')
combined_trials = load_combined_trials(session_paths, one)
performance, contrasts, _ = training.compute_performance(combined_trials, prob_right=True)
psychs = {}
        psychs['50'] = training.compute_psychometric(combined_trials, block=0.5)
        psychs['20'] = training.compute_psychometric(combined_trials, block=0.2)
        psychs['80'] = training.compute_psychometric(combined_trials, block=0.8)
performance_easy = training.compute_performance_easy(combined_trials)
reaction_time = training.compute_median_reaction_time(combined_trials)
n_trials = training.compute_n_trials(combined_trials)
sess_duration = np.nansum([s['sess_duration'] for s in sess_dicts])
n_delay = np.nanmax([s['n_delay'] for s in sess_dicts])
for sess_dict in sess_dicts:
sess_dict['combined_performance'] = performance
sess_dict['combined_contrasts'] = contrasts
sess_dict['combined_performance_easy'] = performance_easy
sess_dict['combined_reaction_time'] = reaction_time
sess_dict['combined_n_trials'] = n_trials
sess_dict['combined_sess_duration'] = sess_duration
sess_dict['combined_n_delay'] = n_delay
for bias in [50, 20, 80]:
sess_dict[f'combined_bias_{bias}'] = psychs[f'{bias}'][0]
sess_dict[f'combined_thres_{bias}'] = psychs[f'{bias}'][1]
sess_dict[f'combined_lapsehigh_{bias}'] = psychs[f'{bias}'][2]
sess_dict[f'combined_lapselow_{bias}'] = psychs[f'{bias}'][3]
            # Case where two sessions on the same day have different numbers of contrasts! Oh boy
if sess_dict['combined_performance'].size != sess_dict['performance'].size:
sess_dict['performance'] = \
np.r_[sess_dict['performance'],
np.full(sess_dict['combined_performance'].size - sess_dict['performance'].size, np.nan)]
sess_dict['contrasts'] = \
np.r_[sess_dict['contrasts'],
np.full(sess_dict['combined_contrasts'].size - sess_dict['contrasts'].size, np.nan)]
else:
for sess_dict in sess_dicts:
sess_dict['combined_performance'] = sess_dict['performance']
sess_dict['combined_contrasts'] = sess_dict['contrasts']
sess_dict['combined_performance_easy'] = sess_dict['performance_easy']
sess_dict['combined_reaction_time'] = sess_dict['reaction_time']
sess_dict['combined_n_trials'] = sess_dict['n_trials']
sess_dict['combined_sess_duration'] = sess_dict['sess_duration']
sess_dict['combined_n_delay'] = sess_dict['n_delay']
for bias in [50, 20, 80]:
sess_dict[f'combined_bias_{bias}'] = sess_dict[f'bias_{bias}']
sess_dict[f'combined_thres_{bias}'] = sess_dict[f'thres_{bias}']
sess_dict[f'combined_lapsehigh_{bias}'] = sess_dict[f'lapsehigh_{bias}']
sess_dict[f'combined_lapselow_{bias}'] = sess_dict[f'lapselow_{bias}']
return sess_dicts
def check_up_to_date(subj_path, df):
    """
    Check which sessions on the local file system are missing from the computed training table
    :param subj_path: path to the subject folder
    :param df: existing training dataframe, may be None
    :return:
"""
session_dates = subj_path.glob('*')
df_session = pd.DataFrame()
for sess_date in session_dates:
sess_paths = list(sess_date.glob('00*'))
date = sess_date.stem
if len(sess_paths) > 0:
for sess in sess_paths:
if is_session_path(sess):
df_session = pd.concat([df_session, pd.DataFrame({'date': date, 'session_path': str(sess)}, index=[0])],
ignore_index=True)
if df is None or 'combined_thres_50' not in df.columns:
return df_session
else:
# recorded_session_paths = df['session_path'].values
isin, _ = ismember(df_session.date.unique(), df.date.unique())
missing_dates = df_session.date.unique()[~isin]
return df_session[df_session['date'].isin(missing_dates)].sort_values('date')
def plot_trial_count_and_session_duration(df, subject):
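    """
    Plot the trial counts and session durations over days for a subject
    :param df: training dataframe for the subject
    :param subject: subject nickname
    :return: matplotlib axis
    """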
df = df.drop_duplicates('date').reset_index(drop=True)
y1 = {'column': 'combined_n_trials',
'title': 'Trial counts',
'lim': None,
'color': 'k',
'join': True}
y2 = {'column': 'combined_sess_duration',
'title': 'Session duration (mins)',
'lim': None,
'color': 'r',
'log': False,
'join': True}
ax = plot_over_days(df, subject, y1, y2)
return ax
def plot_fit_params(df, subject):
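    """
    Plot the psychometric fit parameters (bias, threshold and lapse rates) over days, one colour per
    probabilityLeft block
    :param df: training dataframe for the subject
    :param subject: subject nickname
    :return: array of matplotlib axes
    """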
fig, axs = plt.subplots(2, 2, figsize=(12, 6))
axs = axs.ravel()
df = df.drop_duplicates('date').reset_index(drop=True)
cmap = sns.diverging_palette(20, 220, n=3, center="dark")
y50 = {'column': 'combined_bias_50',
'title': 'Bias',
'lim': [-100, 100],
'color': cmap[1],
'join': False}
y80 = {'column': 'combined_bias_80',
'title': 'Bias',
'lim': [-100, 100],
'color': cmap[2],
'join': False}
y20 = {'column': 'combined_bias_20',
'title': 'Bias',
'lim': [-100, 100],
'color': cmap[0],
'join': False}
plot_over_days(df, subject, y50, ax=axs[0], legend=False, title=False)
plot_over_days(df, subject, y80, ax=axs[0], legend=False, title=False)
plot_over_days(df, subject, y20, ax=axs[0], legend=False, title=False)
axs[0].axhline(16, linewidth=2, linestyle='--', color='k')
axs[0].axhline(-16, linewidth=2, linestyle='--', color='k')
    y50['column'] = 'combined_thres_50'
    y50['title'] = 'Threshold'
    y50['lim'] = [0, 100]
    y80['column'] = 'combined_thres_80'
    y80['title'] = 'Threshold'
    y80['lim'] = [0, 100]
    y20['column'] = 'combined_thres_20'
    y20['title'] = 'Threshold'
    y20['lim'] = [0, 100]
plot_over_days(df, subject, y50, ax=axs[1], legend=False, title=False)
plot_over_days(df, subject, y80, ax=axs[1], legend=False, title=False)
plot_over_days(df, subject, y20, ax=axs[1], legend=False, title=False)
axs[1].axhline(19, linewidth=2, linestyle='--', color='k')
    y50['column'] = 'combined_lapselow_50'
    y50['title'] = 'Lapse Low'
    y50['lim'] = [0, 1]
    y80['column'] = 'combined_lapselow_80'
    y80['title'] = 'Lapse Low'
    y80['lim'] = [0, 1]
    y20['column'] = 'combined_lapselow_20'
    y20['title'] = 'Lapse Low'
    y20['lim'] = [0, 1]
plot_over_days(df, subject, y50, ax=axs[2], legend=False, title=False)
plot_over_days(df, subject, y80, ax=axs[2], legend=False, title=False)
plot_over_days(df, subject, y20, ax=axs[2], legend=False, title=False)
axs[2].axhline(0.2, linewidth=2, linestyle='--', color='k')
    y50['column'] = 'combined_lapsehigh_50'
    y50['title'] = 'Lapse High'
    y50['lim'] = [0, 1]
    y80['column'] = 'combined_lapsehigh_80'
    y80['title'] = 'Lapse High'
    y80['lim'] = [0, 1]
    y20['column'] = 'combined_lapsehigh_20'
    y20['title'] = 'Lapse High'
    y20['lim'] = [0, 1]
plot_over_days(df, subject, y50, ax=axs[3], legend=False, title=False, training_lines=True)
plot_over_days(df, subject, y80, ax=axs[3], legend=False, title=False, training_lines=False)
plot_over_days(df, subject, y20, ax=axs[3], legend=False, title=False, training_lines=False)
axs[3].axhline(0.2, linewidth=2, linestyle='--', color='k')
fig.suptitle(f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')
lines, labels = axs[3].get_legend_handles_labels()
fig.legend(lines, labels, loc='upper center', bbox_to_anchor=(0.5, 0.1), fancybox=True, shadow=True, ncol=5)
legend_elements = [Line2D([0], [0], marker='o', color='w', label='p=0.5', markerfacecolor=cmap[1], markersize=8),
Line2D([0], [0], marker='o', color='w', label='p=0.2', markerfacecolor=cmap[0], markersize=8),
Line2D([0], [0], marker='o', color='w', label='p=0.8', markerfacecolor=cmap[2], markersize=8)]
legend2 = plt.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.1, -0.2), fancybox=True, shadow=True)
fig.add_artist(legend2)
return axs
def plot_psychometric_curve(df, subject, one):
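    """
    Plot the psychometric curve for the most recent session of a subject
    :param df: training dataframe for the subject
    :param subject: subject nickname
    :param one: ONE instance
    :return: matplotlib axis
    """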
df = df.drop_duplicates('date').reset_index(drop=True)
sess_path = Path(df.iloc[-1]["session_path"])
trials = load_trials(sess_path, one)
fig, ax1 = plt.subplots(figsize=(8, 6))
training.plot_psychometric(trials, ax=ax1, title=f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')
return ax1
def plot_over_days(df, subject, y1, y2=None, ax=None, legend=True, title=True, training_lines=True):
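    """
    Plot one column of the training dataframe over days.

    y1 and y2 are dicts describing the left and (optional) right y-axes, with keys 'column' (dataframe column to
    plot), 'title' (axis label), 'lim' (y-limits or None), 'color', 'join' (connect the points with a line) and,
    for y2 only, 'log' (use a log scale).
    :param df: training dataframe for the subject
    :param subject: subject nickname
    :return: matplotlib axis
    """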
if ax is None:
fig, ax1 = plt.subplots(figsize=(12, 6))
else:
ax1 = ax
dates = [datetime.strptime(dat, '%Y-%m-%d') for dat in df['date']]
if y1['join']:
ax1.plot(dates, df[y1['column']], color=y1['color'])
ax1.scatter(dates, df[y1['column']], color=y1['color'])
ax1.set_ylabel(y1['title'])
ax1.set_ylim(y1['lim'])
if y2 is not None:
ax2 = ax1.twinx()
if y2['join']:
ax2.plot(dates, df[y2['column']], color=y2['color'])
ax2.scatter(dates, df[y2['column']], color=y2['color'])
ax2.set_ylabel(y2['title'])
ax2.yaxis.label.set_color(y2['color'])
ax2.tick_params(axis='y', colors=y2['color'])
ax2.set_ylim(y2['lim'])
if y2['log']:
ax2.set_yscale('log')
ax2.spines['right'].set_visible(False)
ax2.spines['top'].set_visible(False)
ax2.spines['left'].set_visible(False)
month_format = mdates.DateFormatter('%b %Y')
month_locator = mdates.MonthLocator()
ax1.xaxis.set_major_locator(month_locator)
ax1.xaxis.set_major_formatter(month_format)
week_locator = mdates.WeekdayLocator(byweekday=mdates.MO, interval=1)
ax1.xaxis.set_minor_locator(week_locator)
ax1.grid(True, which='minor', axis='x', linestyle='--')
ax1.spines['left'].set_visible(False)
ax1.spines['right'].set_visible(False)
ax1.spines['top'].set_visible(False)
if training_lines:
ax1 = add_training_lines(df, ax1)
if title:
ax1.set_title(f'{subject} {df.iloc[-1]["date"]}: {df.iloc[-1]["training_status"]}')
# Put a legend below current axis
box = ax1.get_position()
ax1.set_position([box.x0, box.y0 + box.height * 0.1,
box.width, box.height * 0.9])
if legend:
ax1.legend(loc='upper center', bbox_to_anchor=(0.5, -0.1),
fancybox=True, shadow=True, ncol=5)
return ax1
def add_training_lines(df, ax):
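    """
    Add a vertical line to the plot on each date where the training status changed
    :param df: training dataframe for the subject
    :param ax: matplotlib axis to annotate
    :return: matplotlib axis
    """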
status = df.drop_duplicates(subset='training_status', keep='first')
for _, st in status.iterrows():
if TRAINING_STATUS[st['training_status']][0] <= 0:
continue
ax.axvline(datetime.strptime(st['date'], '%Y-%m-%d'), linewidth=2,
color=np.array(TRAINING_STATUS[st['training_status']][1]) / 255, label=st['training_status'])
return ax
def make_plots(session_path, one, df=None, save=False, upload=False):
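    """
    Make the training plots for a subject and optionally save them to disk and register them to Alyx as snapshots
    attached to the subject
    :param session_path: session path, used to infer the subject
    :param one: ONE instance
    :param df: training dataframe, if None it is loaded from the subject folder
    :param save: whether to save the figures to the subject folder
    :param upload: whether to register the figures to Alyx via ReportSnapshot
    :return:
    """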
subject = one.path2ref(session_path)['subject']
subj_path = session_path.parent.parent
df = load_existing_dataframe(subj_path) if df is None else df
df = df[df['task_protocol'] != 'habituation']
if len(df) == 0:
return
ax1 = plot_trial_count_and_session_duration(df, subject)
ax2 = plot_performance_easy_median_reaction_time(df, subject)
ax3 = plot_heatmap_performance_over_days(df, subject)
ax4 = plot_fit_params(df, subject)
ax5 = plot_psychometric_curve(df, subject, one)
outputs = []
if save:
        save_dir = Path(subj_path)
        save_name = save_dir.joinpath('subj_trial_count_session_duration.png')
        outputs.append(save_name)
        ax1.get_figure().savefig(save_name, bbox_inches='tight')
        save_name = save_dir.joinpath('subj_performance_easy_reaction_time.png')
        outputs.append(save_name)
        ax2.get_figure().savefig(save_name, bbox_inches='tight')
        save_name = save_dir.joinpath('subj_performance_heatmap.png')
        outputs.append(save_name)
        ax3.get_figure().savefig(save_name, bbox_inches='tight')
        save_name = save_dir.joinpath('subj_psychometric_fit_params.png')
        outputs.append(save_name)
        ax4[0].get_figure().savefig(save_name, bbox_inches='tight')
        save_name = save_dir.joinpath('subj_psychometric_curve.png')
        outputs.append(save_name)
        ax5.get_figure().savefig(save_name, bbox_inches='tight')
if upload:
subj = one.alyx.rest('subjects', 'list', nickname=subject)[0]
snp = ReportSnapshot(session_path, subj['id'], content_type='subject', one=one)
snp.outputs = outputs
snp.register_images(widths=['orig'])