Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
'matplotlib',
'pandas',
'requests',
'numba',
'joblib'
],
extras_require={
'deepdish': ['deepdish'],
Expand Down
90 changes: 67 additions & 23 deletions spykes/plot/neurovis.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from __future__ import absolute_import

import numpy as np
from warnings import warn

import matplotlib.pyplot as plt
import numba
import numpy as np
from xarray import DataArray

from .. import utils
from ..config import DEFAULT_POPULATION_COLORS
Expand All @@ -17,12 +21,20 @@ class NeuroVis(object):
spiketimes (Numpy array): Array of spike times.
name (str): The name of the visualization.
'''
def __init__(self, spiketimes, name='neuron'):

def __init__(self, spiketimes=None, name='neuron', lfp: DataArray = None):
self.name = name
self.spiketimes = np.squeeze(np.sort(spiketimes))
n_seconds = (self.spiketimes[-1] - self.spiketimes[0])
n_spikes = np.size(spiketimes)
self.firingrate = (n_spikes / n_seconds)
if spiketimes is None and lfp is None:
raise ValueError('Either spiketimes or lfp should not be none!')
if spiketimes is None:
self.spiketimes = None
self.firingrate = None
else:
self.spiketimes = np.squeeze(np.sort(spiketimes))
n_seconds = (self.spiketimes[-1] - self.spiketimes[0])
n_spikes = np.size(spiketimes)
self.firingrate = (n_spikes / n_seconds)
self.lfp = lfp

def get_raster(self, event=None, conditions=None, df=None,
window=[-100, 500], binsize=10, plot=True,
Expand Down Expand Up @@ -55,6 +67,15 @@ def get_raster(self, event=None, conditions=None, df=None,
raster for each unique entry of :data:`df['conditions']`.
'''

# JIT-compiled wrappers for the two hot calls made once per event in the
# raster loop below. Defined as nested functions so numba compiles them
# lazily on first use of get_raster.
@numba.jit(forceobj=True)
def searchsorted_jit(_a, _v):
    # Thin wrapper around np.searchsorted. forceobj=True runs in object
    # mode — presumably because the caller passes a numba.typed.List of
    # window edges, which nopython mode cannot handle here; TODO confirm.
    return np.searchsorted(_a, _v)

@numba.jit()
def numba_histogram(v, b):
    """njit won't work since np.histogram does not convert types properly"""
    # NOTE(review): plain jit (not njit) falls back to object mode when
    # nopython compilation fails, per the docstring above — the speedup
    # over calling np.histogram directly is therefore modest at best.
    return np.histogram(v, b)

if not type(df) is dict:
df = df.reset_index()

Expand All @@ -64,7 +85,7 @@ def get_raster(self, event=None, conditions=None, df=None,
# Get a set of binary indicators for trials of interest
if conditions:
trials = dict()
for cond_id in np.sort(df[conditions].unique()):
for cond_id in (df[conditions].unique()):
trials[cond_id] = \
np.where((df[conditions] == cond_id).apply(
lambda x: (0, 1)[x]).values)[0]
Expand All @@ -89,24 +110,46 @@ def get_raster(self, event=None, conditions=None, df=None,
raster = []

bin_template = 1e-3 * \
np.arange(window[0], window[1] + binsize, binsize)
for event_time in selected_events:
bins = event_time + bin_template
np.arange(window[0], window[1] + binsize, binsize)

# consider only spikes within window
if self.lfp is not None and (window[1]/1000 + selected_events.max()) > self.lfp['Time'].values.max():
raise ValueError()

# consider only spikes within window
searchsorted_idx = np.squeeze(np.searchsorted(self.spiketimes,
[event_time + 1e-3 *
window[0],
event_time + 1e-3 *
window[1]]))
for event_time in selected_events.dropna():

# bin the spikes into time bins
bins = event_time + bin_template

spike_counts = np.histogram(
self.spiketimes[searchsorted_idx[0]:searchsorted_idx[1]],
bins)[0]
# Skip histogram if this neuron has no spikes in window:
if (self.lfp is None and
(min(self.spiketimes) > (event_time + 1e-3 * window[1]) or # 1st spike after end of window
max(self.spiketimes) < (event_time + 1e-3 * window[0])) # last spike before start of window
):
signal = np.zeros(len(bins) - 1)
elif self.spiketimes is not None:
from numba.typed import List
searchsorted_idx = np.squeeze(searchsorted_jit(self.spiketimes, # mmyros original np.searchsorted
List([event_time + 1e-3 *
window[0],
event_time + 1e-3 *
window[1]])))

# bin the spikes into time bins
signal = numba_histogram(self.spiketimes[searchsorted_idx[0]:
searchsorted_idx[1]],
bins)[0]
else:
# raise error if sel hits edge
if event_time+(window[1]/1000) > self.lfp['Time'].max().values:
warn(f'Event time {event_time+(window[1]/1000)} is greater than recording span {self.lfp["Time"].max().values}')
else:
signal = self.lfp.sel({'Time': slice(event_time+window[0]/1000, # Convert "window" to seconds
event_time+window[1]/1000)}).data

raster.append(spike_counts)
raster.append(signal)
if raster and self.lfp is not None:
min_size = min([sig.shape[0] for sig in raster])
raster = [sig[:min_size] for sig in raster]

rasters['data'][cond_id] = np.array(raster)

Expand All @@ -118,7 +161,7 @@ def get_raster(self, event=None, conditions=None, df=None,
# Return all the rasters
return rasters

def plot_raster(self, rasters, cond_id=None, cond_name=None, sortby=None,
def plot_raster(self, rasters=None, cond_id=None, cond_name=None, sortby=None,
sortorder='descend', cmap='Greys', has_title=True):
'''Plot a single raster.

Expand Down Expand Up @@ -192,7 +235,7 @@ def plot_raster(self, rasters, cond_id=None, cond_name=None, sortby=None,
else:
print('No trials for this condition!')

def get_psth(self, event=None, df=None, conditions=None, cond_id=None,
def get_psth(self, event=None, df=None, conditions=None,
window=[-100, 500], binsize=10, plot=True, event_name=None,
conditions_names=None, ylim=None,
colors=DEFAULT_POPULATION_COLORS):
Expand Down Expand Up @@ -228,6 +271,7 @@ def get_psth(self, event=None, df=None, conditions=None, cond_id=None,

window = [np.floor(window[0] / binsize) * binsize,
np.ceil(window[1] / binsize) * binsize]

# Get all the rasters first
rasters = self.get_raster(event=event, df=df,
conditions=conditions,
Expand Down
40 changes: 29 additions & 11 deletions spykes/plot/popvis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
from __future__ import division
from __future__ import print_function

import warnings

import numpy as np
import matplotlib.pyplot as plt
import copy
from collections import defaultdict

from fractions import gcd

try:
from fractions import gcd
except ImportError:
from math import gcd
from .neurovis import NeuroVis
from .. import utils
from ..config import DEFAULT_POPULATION_COLORS
Expand Down Expand Up @@ -44,7 +47,7 @@ def n_neurons(self):

def get_all_psth(self, event=None, df=None, conditions=None,
window=[-100, 500], binsize=10, conditions_names=None,
plot=True, colors=DEFAULT_PSTH_COLORS):
plot=True, colors=DEFAULT_PSTH_COLORS, use_parallel=True):
'''Iterates through all neurons and computes their PSTH's.

Args:
Expand All @@ -62,6 +65,7 @@ def get_all_psth(self, event=None, df=None, conditions=None,
Default are the unique values in :data:`df['conditions']`.
plot (bool): If set, automatically plot; otherwise, don't.
colors (list): List of colors for heatmap (only if plot is True).
use_parallel (bool): If set, parallelize PSTH computation

Returns:
dict: With keys :data:`event`, :data:`conditions`, :data:`binsize`,
Expand All @@ -70,6 +74,7 @@ def get_all_psth(self, event=None, df=None, conditions=None,
each :data:`cond_id` that correspond to the means for that
condition.
'''

all_psth = {
'window': window,
'binsize': binsize,
Expand All @@ -78,17 +83,30 @@ def get_all_psth(self, event=None, df=None, conditions=None,
'data': defaultdict(list),
}

for i, neuron in enumerate(self.neuron_list):
psth = neuron.get_psth(
if use_parallel:
from joblib import Parallel, delayed
psths = Parallel(n_jobs=-1)(delayed(neuron.get_psth)(
event=event,
df=df,
conditions=conditions,
window=window,
binsize=binsize,
plot=False,
)
plot=False) for neuron in self.neuron_list)
else:
psths=[neuron.get_psth(
event=event,
df=df,
conditions=conditions,
window=window,
binsize=binsize,
plot=False) for neuron in self.neuron_list]

for psth in psths:
for cond_id in np.sort(list(psth['data'].keys())):
all_psth['data'][cond_id].append(psth['data'][cond_id]['mean'])
all_psth['data'][cond_id].append(
psth['data'][cond_id]['mean']
)


for cond_id in np.sort(list(all_psth['data'].keys())):
all_psth['data'][cond_id] = np.stack(all_psth['data'][cond_id])
Expand Down Expand Up @@ -203,7 +221,7 @@ def plot_population_psth(self, all_psth=None, event=None, df=None,
conditions=None, cond_id=None, window=[-100, 500],
binsize=10, conditions_names=None,
event_name='event_onset', ylim=None,
colors=DEFAULT_POPULATION_COLORS, show=False):
colors=DEFAULT_POPULATION_COLORS, show=True):
'''Plots population PSTH's.

This involves two steps. First, it normalizes each neuron's PSTH across
Expand Down Expand Up @@ -315,6 +333,6 @@ def _get_normed_data(self, data, normalize):
elif normalize is None:
norm_factors = np.ones([data.shape[0], 1])
else:
raise ValueError('Invalid norm factors: {}'.format(norm_factors))
raise ValueError('Invalid norm factors: {}'.format(normalize))

return data / norm_factors
2 changes: 1 addition & 1 deletion tests/ml/test_strf.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_strf():
strf_.visualize_gaussian_basis(spatial_basis)

# Design temporal basis
time_points = np.linspace(-100., 100., 10.)
time_points = np.linspace(-100., 100., 10)
centers = [-75., -50., -25., 0, 25., 50., 75.]
width = 10.
temporal_basis = strf_.make_raised_cosine_temporal_basis(
Expand Down
109 changes: 108 additions & 1 deletion tests/plot/test_popvis.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,115 @@ def test_popvis():
df[condition_num] = np.random.rand(num_trials)
df[condition_bool] = df[condition_num] < 0.5

df = pd.DataFrame()

event = 'anotherRealCueTime'
condition_num = 'responseNum'
condition_bool = 'responseBool'

start_times = rand_spiketimes[0::int(num_spikes/num_trials)]

df['trialStart'] = start_times

df[event] = df['trialStart'] + np.random.rand(num_trials)

event_times = ((start_times[:-1] + start_times[1:]) / 2).tolist()
event_times.append(start_times[-1] + np.random.rand())

df[event] = event_times

df[condition_num] = np.random.rand(num_trials)
df[condition_bool] = df[condition_num] < 0.5

all_psth = pop.get_all_psth(event=event, conditions=condition_bool, df=df,
plot=True, binsize=binsize, window=window,
use_parallel=False)

assert_equal(all_psth['window'], window)
assert_equal(all_psth['binsize'], binsize)
assert_equal(all_psth['event'], event)
assert_equal(all_psth['conditions'], condition_bool)

for cond_id in all_psth['data'].keys():

assert_true(cond_id in df[condition_bool])
assert_equal(all_psth['data'][cond_id].shape[0],
num_neurons)
assert_equal(all_psth['data'][cond_id].shape[1],
(window[1] - window[0]) / binsize)

assert_raises(ValueError, pop.plot_heat_map, all_psth,
sortby=list(range(num_trials-1)))

pop.plot_heat_map(all_psth, sortby=list(range(num_trials)))
pop.plot_heat_map(all_psth, sortby='rate')
pop.plot_heat_map(all_psth, sortby='latency')
pop.plot_heat_map(all_psth, sortorder='ascend')

pop.plot_population_psth(all_psth=all_psth)

def test_popvis_parallel():

np.random.seed()

num_spikes = 500
num_trials = 10

binsize = 100
window = [-500, 1500]

num_neurons = 10
neuron_list = list()

for i in range(num_neurons):
rand_spiketimes = num_trials * np.random.rand(num_spikes)
neuron_list.append(NeuroVis(rand_spiketimes))

pop = PopVis(neuron_list)

df = pd.DataFrame()

event = 'realCueTime'
condition_num = 'responseNum'
condition_bool = 'responseBool'

start_times = rand_spiketimes[0::int(num_spikes/num_trials)]

df['trialStart'] = start_times

df[event] = df['trialStart'] + np.random.rand(num_trials)

event_times = ((start_times[:-1] + start_times[1:]) / 2).tolist()
event_times.append(start_times[-1] + np.random.rand())

df[event] = event_times

df[condition_num] = np.random.rand(num_trials)
df[condition_bool] = df[condition_num] < 0.5

df = pd.DataFrame()

event = 'anotherRealCueTime'
condition_num = 'responseNum'
condition_bool = 'responseBool'

start_times = rand_spiketimes[0::int(num_spikes/num_trials)]

df['trialStart'] = start_times

df[event] = df['trialStart'] + np.random.rand(num_trials)

event_times = ((start_times[:-1] + start_times[1:]) / 2).tolist()
event_times.append(start_times[-1] + np.random.rand())

df[event] = event_times

df[condition_num] = np.random.rand(num_trials)
df[condition_bool] = df[condition_num] < 0.5

all_psth = pop.get_all_psth(event=event, conditions=condition_bool, df=df,
plot=True, binsize=binsize, window=window)
plot=True, binsize=binsize, window=window,
use_parallel=True)

assert_equal(all_psth['window'], window)
assert_equal(all_psth['binsize'], binsize)
Expand Down