Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 25 additions & 23 deletions neo/io/klustakwikio.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@

import glob
import logging
import os.path
from pathlib import Path
import shutil

# note neo.core need only numpy and quantitie
# note neo.core need only numpy and quantities
import numpy as np


Expand Down Expand Up @@ -87,27 +87,31 @@ class KlustaKwikIO(BaseIO):
extensions = ['fet', 'clu', 'res', 'spk']

# Operates on directories
mode = 'file'
mode = 'dir'

def __init__(self, filename, sampling_rate=30000.):
def __init__(self, dirname, sampling_rate=30000.):
"""Create a new IO to operate on a directory

filename : the directory to contain the files
basename : string, basename of KlustaKwik format, or None
dirname : the directory to contain the files
sampling_rate : in Hz, necessary because the KlustaKwik files
stores data in samples.
"""
BaseIO.__init__(self)
# self.filename = os.path.normpath(filename)
self.filename, self.basename = os.path.split(os.path.abspath(filename))
self.dirname = Path(dirname)
# in case no basename is provided
if self.dirname.is_dir():
self.session_dir = self.dirname
else:
self.session_dir = self.dirname.parent
self.basename = self.dirname.name
self.sampling_rate = float(sampling_rate)

# error check
if not os.path.isdir(self.filename):
raise ValueError("filename must be a directory")
if not self.session_dir.is_dir():
raise ValueError("dirname must be in an existing directory")

# initialize a helper object to parse filenames
self._fp = FilenameParser(dirname=self.filename, basename=self.basename)
self._fp = FilenameParser(dirname=self.session_dir, basename=self.basename)

def read_block(self, lazy=False):
"""Returns a Block containing spike information.
Expand All @@ -130,7 +134,7 @@ def read_block(self, lazy=False):
return block

# Create a single segment to hold all of the data
seg = Segment(name='seg0', index=0, file_origin=self.filename)
seg = Segment(name='seg0', index=0, file_origin=str(self.session_dir / self.basename))
block.segments.append(seg)

# Load spike times from each group and store in a dict, keyed
Expand Down Expand Up @@ -367,15 +371,13 @@ def _make_all_file_handles(self, block):

def _new_group(self, id_group, nbClusters):
# generate filenames
fetfilename = os.path.join(self.filename,
self.basename + ('.fet.%d' % id_group))
clufilename = os.path.join(self.filename,
self.basename + ('.clu.%d' % id_group))
fetfilename = self.session_dir / (self.basename + ('.fet.%d' % id_group))
clufilename = self.session_dir / (self.basename + ('.clu.%d' % id_group))

# back up before overwriting
if os.path.exists(fetfilename):
if fetfilename.exists():
shutil.copyfile(fetfilename, fetfilename + '~')
if os.path.exists(clufilename):
if clufilename.exists():
shutil.copyfile(clufilename, clufilename + '~')

# create file handles
Expand Down Expand Up @@ -406,12 +408,12 @@ def __init__(self, dirname, basename=None):
will be used. An error is raised if files with multiple basenames
exist in the directory.
"""
self.dirname = os.path.normpath(dirname)
self.dirname = Path(dirname).absolute()
self.basename = basename

# error check
if not os.path.isdir(self.dirname):
raise ValueError("filename must be a directory")
if not self.dirname.is_dir():
raise ValueError("dirname must be a directory")

def read_filenames(self, typestring='fet'):
"""Returns filenames in the data directory matching the type.
Expand All @@ -430,13 +432,13 @@ def read_filenames(self, typestring='fet'):
a sequence of digits are valid. The digits are converted to an integer
and used as the group number.
"""
all_filenames = glob.glob(os.path.join(self.dirname, '*'))
all_filenames = self.dirname.glob('*')

# Fill the dict with valid filenames
d = {}
for v in all_filenames:
# Test whether matches format, ie ends with digits
split_fn = os.path.split(v)[1]
split_fn = v.name
m = glob.re.search((r'^(\w+)\.%s\.(\d+)$' % typestring), split_fn)
if m is not None:
# get basename from first hit if not specified
Expand Down
13 changes: 8 additions & 5 deletions neo/io/neuroshareapiio.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,26 +77,29 @@ class NeuroshareapiIO(BaseIO):
# This object operates on neuroshare files
mode = "file"

def __init__(self, filename=None, dllpath=None):
def __init__(self, filename=None, dllname=None):
"""
Arguments:
filename : the filename
dllname: the path of the library to use for reading
The init function will run automatically upon calling of the class, as
in: test = MultichannelIO(filename = filetoberead.mcd), therefore the first
operations with the file are set here, so that the user doesn't have to
remember to use another method, than the ones defined in the NEO library

"""
BaseIO.__init__(self)
self.filename = filename
self.filename = str(filename)
# set the flags for each event type
eventID = 1
analogID = 2
epochID = 3
# if a filename was given, create a dictionary with information that will
# be needed later on.
if self.filename is not None:
if dllpath is not None:
if dllname is not None:
# converting to string to also accept pathlib objects
dllpath = str(dllname)
name = os.path.splitext(os.path.basename(dllpath))[0]
library = ns.Library(name, dllpath)
else:
Expand Down Expand Up @@ -330,13 +333,13 @@ def read_spiketrain(self,
numIndx = endat - startat
# get the end point using segment duration
# create a numpy empty array to store the waveforms
waveforms = np.array(np.zeros([numIndx, tempSpks.max_sample_count]))
waveforms = np.array(np.zeros([numIndx, 1, tempSpks.max_sample_count]))
# loop through the data from the specific channel index
for i in range(startat, endat, 1):
# get cutout, timestamp, cutout duration, and spike unit
tempCuts, timeStamp, duration, unit = tempSpks.get_data(i)
# save the cutout in the waveform matrix
waveforms[i] = tempCuts[0]
waveforms[i, 0, :] = tempCuts[0]
# append time stamp to list
times.append(timeStamp)

Expand Down
2 changes: 1 addition & 1 deletion neo/io/neurosharectypesio.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def __init__(self, filename='', dllname=''):
"""
BaseIO.__init__(self)
self.dllname = dllname
self.filename = filename
self.filename = str(filename)

def read_segment(self, import_neuroshare_segment=True,
lazy=False):
Expand Down
Loading