Skip to content

Commit

Permalink
In sonata module, fully implement filtering in node sets based on nod…
Browse files Browse the repository at this point in the history
…e population attributes; support step current injection
  • Loading branch information
apdavison committed Mar 28, 2019
1 parent 1f98a6d commit a36629c
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 65 deletions.
20 changes: 20 additions & 0 deletions pyNN/common/populations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1451,3 +1451,23 @@ def describe(self, template='assembly_default.txt', engine='default'):
context = {"label": self.label,
"populations": [p.describe(template=None) for p in self.populations]}
return descriptions.render(engine, template, context)

def get_annotations(self, annotation_keys, simplify=True):
"""
Get the values of the given annotations for each population in the Assembly.
"""
if isinstance(annotation_keys, basestring):
annotation_keys = (annotation_keys,)
annotations = defaultdict(list)

for key in annotation_keys:
is_array_annotation = False
for p in self.populations:
annotation = p.annotations[key]
annotations[key].append(annotation)
is_array_annotation = isinstance(annotation, numpy.ndarray)
if is_array_annotation:
annotations[key] = numpy.hstack(annotations[key])
if simplify:
annotations[key] = simplify_parameter_array(numpy.array(annotations[key]))
return annotations
37 changes: 31 additions & 6 deletions pyNN/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
"""

import sys
import inspect
from itertools import chain
try:
basestring
Expand All @@ -22,6 +24,23 @@ def __init__(self, *components):
self.views = set([])
self.assemblies = set([])
self.projections = set([])
self.add(*components)

@property
def sim(self):
"""Figure out which PyNN backend module this Network is using."""
# we assume there is only one. Could be mixed if using multiple simulators
# at once.
populations_module = inspect.getmodule(list(self.populations)[0].__class__)
return sys.modules[".".join(populations_module.__name__.split(".")[:-1])]

def count_neurons(self):
return sum(population.size for population in chain(self.populations))

def count_connections(self):
return sum(projection.size() for projection in chain(self.projections))

def add(self, *components):
for component in components:
if isinstance(component, Population):
self.populations.add(component)
Expand All @@ -37,18 +56,24 @@ def __init__(self, *components):
else:
raise TypeError()

def count_neurons(self):
return sum(population.size for population in chain(self.populations))

def count_connections(self):
return sum(projection.size() for projection in chain(self.projections))

def get_component(self, label):
for obj in chain(self.populations, self.views, self.assemblies, self.projections):
if obj.label == label:
return obj
return None

def filter(self, cell_types=None):
"""Return an Assembly of all components that have a cell type in the list"""
if cell_types is None:
raise NotImplementedError()
else:
if cell_types == "all":
return self.sim.Assembly(*(pop for pop in self.populations
if pop.celltype.injectable)) # or could use len(receptor_types) > 0
else:
return self.sim.Assembly(*(pop for pop in self.populations
if pop.celltype.__class__ in cell_types))

def record(self, variables, to_file=None, sampling_interval=None, include_spike_source=True):
for obj in chain(self.populations, self.assemblies):
if include_spike_source or obj.injectable: # spike sources are not injectable
Expand Down
187 changes: 128 additions & 59 deletions pyNN/serialization/sonata.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ def write(self, blocks):
Write a list of Blocks to SONATA HDF5 files.
"""
if not os.path.isdir(self.base_dir):
os.makedirs(self.base_dir)
# Write spikes
spike_file_path = join(self.base_dir, self.spike_file)
spikes_file = h5py.File(spike_file_path, 'w')
Expand All @@ -131,36 +133,38 @@ def write(self, blocks):
file_path = join(self.base_dir, file_name)

signal_file = h5py.File(file_path, 'w')
population_name = self.node_sets[report_metadata["cells"]]["population"]
node_ids = self.node_sets[report_metadata["cells"]]["node_id"]
targets = self.node_sets[report_metadata["cells"]]
for block in blocks:
if block.name == population_name:
if len(block.segments) > 1:
raise NotImplementedError()
signal = block.segments[0].filter(name=report_metadata["variable_name"])
if len(signal) != 1:
raise NotImplementedError()

report_group = signal_file.create_group("report")
population_group = report_group.create_group(population_name)
dataset = population_group.create_dataset("data", data=signal[0].magnitude)
dataset.attrs["units"] = signal[0].units.dimensionality.string
dataset.attrs["variable_name"] = report_metadata["variable_name"]
n = dataset.shape[1]
mapping_group = population_group.create_group("mapping")
mapping_group.create_dataset("node_ids", data=node_ids)
# "gids" not in the spec, but expected by some bmtk utils
mapping_group.create_dataset("gids", data=node_ids)
#mapping_group.create_dataset("index_pointers", data=np.zeros((n,)))
mapping_group.create_dataset("index_pointer", data=np.arange(0, n+1)) # ??spec unclear
mapping_group.create_dataset("element_ids", data=np.zeros((n,)))
mapping_group.create_dataset("element_pos", data=np.zeros((n,)))
time_ds = mapping_group.create_dataset("time",
data=(float(signal[0].t_start),
float(signal[0].t_stop),
float(signal[0].sampling_period)))
time_ds.attrs["units"] = "ms"
logger.info("Wrote block {} to {}".format(block.name, file_path))
for (assembly, mask) in targets:
if block.name == assembly.label:
if len(block.segments) > 1:
raise NotImplementedError()
signal = block.segments[0].filter(name=report_metadata["variable_name"])
if len(signal) != 1:
raise NotImplementedError()

node_ids = np.arange(assembly.size)[mask]

report_group = signal_file.create_group("report")
population_group = report_group.create_group(assembly.label)
dataset = population_group.create_dataset("data", data=signal[0].magnitude)
dataset.attrs["units"] = signal[0].units.dimensionality.string
dataset.attrs["variable_name"] = report_metadata["variable_name"]
n = dataset.shape[1]
mapping_group = population_group.create_group("mapping")
mapping_group.create_dataset("node_ids", data=node_ids)
# "gids" not in the spec, but expected by some bmtk utils
mapping_group.create_dataset("gids", data=node_ids)
#mapping_group.create_dataset("index_pointers", data=np.zeros((n,)))
mapping_group.create_dataset("index_pointer", data=np.arange(0, n+1)) # ??spec unclear
mapping_group.create_dataset("element_ids", data=np.zeros((n,)))
mapping_group.create_dataset("element_pos", data=np.zeros((n,)))
time_ds = mapping_group.create_dataset("time",
data=(float(signal[0].t_start.rescale('ms')),
float(signal[0].t_stop.rescale('ms')),
float(signal[0].sampling_period.rescale('ms'))))
time_ds.attrs["units"] = "ms"
logger.info("Wrote block {} to {}".format(block.name, file_path))
signal_file.close()


Expand Down Expand Up @@ -232,6 +236,7 @@ def condense(value, types_array):
from "/edges/<population_name>/edge_type_id" that applies to this group.
Needed to construct parameter arrays.
"""
# todo: use lazyarray
if isinstance(value, np.ndarray):
return value
elif isinstance(value, dict):
Expand All @@ -240,7 +245,12 @@ def condense(value, types_array):
if np.all(value_array == value_array[0]):
return value_array[0]
else:
new_value = np.ones_like(types_array) * np.nan
if np.issubdtype(value_array.dtype, np.number):
new_value = np.ones_like(types_array) * np.nan
elif np.issubdtype(value_array.dtype, np.str_):
new_value = np.array(["UNDEFINED"] * types_array.size)
else:
raise TypeError("Cannot handle annotations that are neither numbers or strings")
for node_type_id, val in value.items():
new_value[types_array == node_type_id] = val
return new_value
Expand Down Expand Up @@ -584,10 +594,10 @@ def import_from_sonata(config_file, sim):
net = Network()
for node_population in sonata_node_populations:
assembly = node_population.to_assembly(sim)
net.assemblies.add(assembly)
net.add(assembly)
for edge_population in sonata_edge_populations:
projections = edge_population.to_projections(net, sim)
net.projections.update(projections)
net.add(*projections)

return net

Expand Down Expand Up @@ -777,7 +787,7 @@ def to_population(self, sim):
if name in cell_type_cls.default_parameters:
parameters[name] = condense(value, self.node_types_array)
else:
annotations[name] = value
annotations[name] = condense(value, self.node_types_array)
# todo: handle spatial structure - nodes_file["nodes"][np_label][ng_label]['x'], etc.

# temporary hack to work around problem with 300 Intfire cell example
Expand Down Expand Up @@ -1072,28 +1082,21 @@ def setup(self, sim):
self.sim = sim
sim.setup(timestep=self.run_config["dt"])

def _get_target(self, config, node_sets, net):
def _get_target(self, config, net):
if "node_set" in config: # input config
target = node_sets[config["node_set"]]
elif "cells" in config: # recording config
targets = self.node_set_map[config["node_set"]]
elif "cells" in config: # recording config
# inconsistency in SONATA spec? Why not call this "node_set" also?
target = node_sets[config["cells"]]
if "model_type" in target:
raise NotImplementedError()
if "location" in target:
raise NotImplementedError()
if "gids" in target:
raise NotImplementedError()
if "population" in target:
assembly = net.get_component(target["population"])
if "node_id" in target:
indices = target["node_id"]
assembly = assembly[indices]
return assembly
targets = self.node_set_map[config["cells"]]
return targets

def _set_input_spikes(self, input_config, node_sets, net):
def _set_input_spikes(self, input_config, net):
# determine which assembly the spikes are for
assembly = self._get_target(input_config, node_sets, net)
targets = self._get_target(input_config, net)
if len(targets) != 1:
raise NotImplementedError()
base_assembly, mask = targets[0]
assembly = base_assembly[mask]
assert isinstance(assembly, self.sim.Assembly)

# load spike data from file
Expand All @@ -1111,22 +1114,88 @@ def _set_input_spikes(self, input_config, node_sets, net):
if len(spiketrains) != assembly.size:
raise NotImplementedError()
# todo: map cell ids in spikes file to ids/index in the population
#logger.info("SETTING SPIKETIMES")
#logger.info(spiketrains)
assembly.set(spike_times=[Sequence(st.times.rescale('ms').magnitude) for st in spiketrains])

def _set_input_currents(self, input_config, net):
# determine which assembly the currents are for
if "input_file" in input_config:
raise NotImplementedError("Current clamp from source file not yet supported.")
targets = self._get_target(input_config, net)
if len(targets) != 1:
raise NotImplementedError()
base_assembly, mask = targets[0]
assembly = base_assembly[mask]
assert isinstance(assembly, self.sim.Assembly)
amplitude = input_config["amp"] # nA
if self.target_simulator == "NEST":
amplitude = input_config["amp"]/1000.0 # pA

current_source = self.sim.DCSource(amplitude=amplitude,
start=input_config["delay"],
stop=input_config["delay"] + input_config["duration"])
assembly.inject(current_source)

def _calculate_node_set_map(self, net):
# for each "node set" in the config, determine which populations
# and node_ids it corresponds to
self.node_set_map = {}

# first handle implicit node sets - i.e. each node population is an implicit node set
for assembly in net.assemblies:
self.node_set_map[assembly.label] = [(assembly, slice(None))]

# now handle explictly-declared node sets
# todo: handle compound node sets
for node_set_name, node_set_definition in self.node_sets.items():
if isinstance(node_set_definition, dict): # basic node set
filters = node_set_definition
if "population" in filters:
assemblies = [net.get_component(filters["population"])]
else:
assemblies = list(net.assemblies)

self.node_set_map[node_set_name] = []
for assembly in assemblies:
mask = True
for attr_name, attr_value in filters.items():
print(attr_name, attr_value, "____")
if attr_name == "population":
continue
elif attr_name == "node_id":
# convert integer mask to boolean mask
node_mask = np.zeros(assembly.size, dtype=bool)
node_mask[attr_value] = True
mask = np.logical_and(mask, node_mask)
else:
values = assembly.get_annotations(attr_name)[attr_name]
mask = np.logical_and(mask, values == attr_value)
if isinstance(mask, (bool, np.bool_)) and mask == True:
mask = slice(None)
self.node_set_map[node_set_name].append((assembly, mask))
elif isinstance(node_set_definition, list): # compound node set
raise NotImplementedError("Compound node sets not yet supported")
else:
raise TypeError("Expecting node set definition to be a list or dict")

def execute(self, net):
self._calculate_node_set_map(net)

# create/configure inputs
for input_name, input_config in self.inputs.items():
if input_config["input_type"] != "spikes":
raise NotImplementedError()
self._set_input_spikes(input_config, self.node_sets, net)
if input_config["input_type"] == "spikes":
self._set_input_spikes(input_config, net)
elif input_config["input_type"] == "current_clamp":
self._set_input_currents(input_config, net)
else:
raise NotImplementedError("Only 'spikes' and 'current_clamp' supported")

# configure recording
net.record('spikes', include_spike_source=False) # SONATA requires that we record spikes from all non-virtual nodes
for report_name, report_config in self.reports.items():
assembly = self._get_target(report_config, self.node_sets, net)
assembly.record(report_config["variable_name"])
targets = self._get_target(report_config, net)
for (base_assembly, mask) in targets:
assembly = base_assembly[mask]
assembly.record(report_config["variable_name"])

# run simulation
self.sim.run(self.run_config["tstop"])
Expand All @@ -1141,7 +1210,7 @@ def execute(self, net):
spikes_file=self.output.get("spikes_file", "spikes.h5"),
spikes_sort_order=self.output["spikes_sort_order"],
report_config=self.reports,
node_sets=self.node_sets)
node_sets=self.node_set_map)
# todo: handle reports
net.write_data(io)

Expand Down

0 comments on commit a36629c

Please sign in to comment.