Commit
Merge pull request #638 from NeuralEnsemble/sonata
Merge SONATA branch
apdavison authored Dec 3, 2019
2 parents ed1cddd + 6d1e062 commit f740854
Showing 3 changed files with 179 additions and 65 deletions.
20 changes: 20 additions & 0 deletions pyNN/common/populations.py
@@ -1470,3 +1470,23 @@ def describe(self, template='assembly_default.txt', engine='default'):
context = {"label": self.label,
"populations": [p.describe(template=None) for p in self.populations]}
return descriptions.render(engine, template, context)

def get_annotations(self, annotation_keys, simplify=True):
"""
Get the values of the given annotations for each population in the Assembly.
"""
if isinstance(annotation_keys, basestring):
annotation_keys = (annotation_keys,)
annotations = defaultdict(list)

for key in annotation_keys:
is_array_annotation = False
for p in self.populations:
annotation = p.annotations[key]
annotations[key].append(annotation)
is_array_annotation = isinstance(annotation, numpy.ndarray)
if is_array_annotation:
annotations[key] = numpy.hstack(annotations[key])
if simplify:
annotations[key] = simplify_parameter_array(numpy.array(annotations[key]))
return annotations
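
For illustration, a usage sketch (assuming a backend imported as sim and annotations previously attached with Population.annotate(); labels and values here are hypothetical):

p1 = sim.Population(5, sim.IF_cond_exp(), label="p1")
p1.annotate(region="cortex", depth=numpy.linspace(0.0, 1.0, 5))
p2 = sim.Population(3, sim.IF_cond_exp(), label="p2")
p2.annotate(region="cortex", depth=numpy.linspace(0.0, 1.0, 3))
all_cells = p1 + p2  # adding populations gives an Assembly
ann = all_cells.get_annotations(("region", "depth"))
# ann["region"] should simplify to the single value "cortex";
# ann["depth"] should be an hstacked array of length 8, since the
# per-population values are arrays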
37 changes: 31 additions & 6 deletions pyNN/network.py
@@ -3,6 +3,8 @@
"""

import sys
import inspect
from itertools import chain
try:
basestring
@@ -22,6 +24,23 @@ def __init__(self, *components):
self.views = set([])
self.assemblies = set([])
self.projections = set([])
self.add(*components)

@property
def sim(self):
"""Figure out which PyNN backend module this Network is using."""
# we assume there is only one. Could be mixed if using multiple simulators
# at once.
populations_module = inspect.getmodule(list(self.populations)[0].__class__)
return sys.modules[".".join(populations_module.__name__.split(".")[:-1])]
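
A quick sketch of what this property resolves in practice, assuming the pyNN.nest backend and PyNN's usual module layout:

import pyNN.nest as sim
from pyNN.network import Network

net = Network(sim.Population(10, sim.IF_cond_exp()))
assert net.sim is sim  # the backend module is recovered from the Population class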

def count_neurons(self):
return sum(population.size for population in chain(self.populations))

def count_connections(self):
return sum(projection.size() for projection in chain(self.projections))

def add(self, *components):
for component in components:
if isinstance(component, Population):
self.populations.add(component)
@@ -37,18 +56,24 @@ def __init__(self, *components):
else:
raise TypeError()

- def count_neurons(self):
- return sum(population.size for population in chain(self.populations))
-
- def count_connections(self):
- return sum(projection.size() for projection in chain(self.projections))

def get_component(self, label):
for obj in chain(self.populations, self.views, self.assemblies, self.projections):
if obj.label == label:
return obj
return None

+ def filter(self, cell_types=None):
+ """Return an Assembly of all components that have a cell type in the list"""
+ if cell_types is None:
+ raise NotImplementedError()
+ else:
+ if cell_types == "all":
+ return self.sim.Assembly(*(pop for pop in self.populations
+ if pop.celltype.injectable))  # or could use len(receptor_types) > 0
+ else:
+ return self.sim.Assembly(*(pop for pop in self.populations
+ if pop.celltype.__class__ in cell_types))

def record(self, variables, to_file=None, sampling_interval=None, include_spike_source=True):
for obj in chain(self.populations, self.assemblies):
if include_spike_source or obj.injectable: # spike sources are not injectable
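Putting the new Network API together, a usage sketch (the backend, labels and sizes are illustrative, not part of this commit):

import pyNN.nest as sim
from pyNN.network import Network

sim.setup(timestep=0.1)
exc = sim.Population(80, sim.IF_cond_exp(), label="exc")
inh = sim.Population(20, sim.IF_cond_exp(), label="inh")
prj = sim.Projection(exc, inh, sim.AllToAllConnector(),
                     sim.StaticSynapse(weight=0.01), label="exc->inh")
net = Network(exc, inh, prj)  # add() dispatches on component type

assert net.count_neurons() == 100
assert net.count_connections() == 1600  # 80 x 20, all-to-all
assert net.get_component("exc->inh") is prj
virtual = net.filter(cell_types=[sim.SpikeSourceArray])  # here an empty Assembly
net.record("spikes")  # records from populations and assemblies alike
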
187 changes: 128 additions & 59 deletions pyNN/serialization/sonata.py
@@ -105,6 +105,8 @@ def write(self, blocks):
Write a list of Blocks to SONATA HDF5 files.
"""
if not os.path.isdir(self.base_dir):
os.makedirs(self.base_dir)
# Write spikes
spike_file_path = join(self.base_dir, self.spike_file)
spikes_file = h5py.File(spike_file_path, 'w')
@@ -131,36 +133,38 @@ def write(self, blocks):
file_path = join(self.base_dir, file_name)

signal_file = h5py.File(file_path, 'w')
- population_name = self.node_sets[report_metadata["cells"]]["population"]
- node_ids = self.node_sets[report_metadata["cells"]]["node_id"]
+ targets = self.node_sets[report_metadata["cells"]]
for block in blocks:
- if block.name == population_name:
- if len(block.segments) > 1:
- raise NotImplementedError()
- signal = block.segments[0].filter(name=report_metadata["variable_name"])
- if len(signal) != 1:
- raise NotImplementedError()
-
- report_group = signal_file.create_group("report")
- population_group = report_group.create_group(population_name)
- dataset = population_group.create_dataset("data", data=signal[0].magnitude)
- dataset.attrs["units"] = signal[0].units.dimensionality.string
- dataset.attrs["variable_name"] = report_metadata["variable_name"]
- n = dataset.shape[1]
- mapping_group = population_group.create_group("mapping")
- mapping_group.create_dataset("node_ids", data=node_ids)
- # "gids" not in the spec, but expected by some bmtk utils
- mapping_group.create_dataset("gids", data=node_ids)
- #mapping_group.create_dataset("index_pointers", data=np.zeros((n,)))
- mapping_group.create_dataset("index_pointer", data=np.arange(0, n+1))  # ??spec unclear
- mapping_group.create_dataset("element_ids", data=np.zeros((n,)))
- mapping_group.create_dataset("element_pos", data=np.zeros((n,)))
- time_ds = mapping_group.create_dataset("time",
- data=(float(signal[0].t_start),
- float(signal[0].t_stop),
- float(signal[0].sampling_period)))
- time_ds.attrs["units"] = "ms"
- logger.info("Wrote block {} to {}".format(block.name, file_path))
+ for (assembly, mask) in targets:
+ if block.name == assembly.label:
+ if len(block.segments) > 1:
+ raise NotImplementedError()
+ signal = block.segments[0].filter(name=report_metadata["variable_name"])
+ if len(signal) != 1:
+ raise NotImplementedError()
+
+ node_ids = np.arange(assembly.size)[mask]
+
+ report_group = signal_file.create_group("report")
+ population_group = report_group.create_group(assembly.label)
+ dataset = population_group.create_dataset("data", data=signal[0].magnitude)
+ dataset.attrs["units"] = signal[0].units.dimensionality.string
+ dataset.attrs["variable_name"] = report_metadata["variable_name"]
+ n = dataset.shape[1]
+ mapping_group = population_group.create_group("mapping")
+ mapping_group.create_dataset("node_ids", data=node_ids)
+ # "gids" not in the spec, but expected by some bmtk utils
+ mapping_group.create_dataset("gids", data=node_ids)
+ #mapping_group.create_dataset("index_pointers", data=np.zeros((n,)))
+ mapping_group.create_dataset("index_pointer", data=np.arange(0, n+1))  # ??spec unclear
+ mapping_group.create_dataset("element_ids", data=np.zeros((n,)))
+ mapping_group.create_dataset("element_pos", data=np.zeros((n,)))
+ time_ds = mapping_group.create_dataset("time",
+ data=(float(signal[0].t_start.rescale('ms')),
+ float(signal[0].t_stop.rescale('ms')),
+ float(signal[0].sampling_period.rescale('ms'))))
+ time_ds.attrs["units"] = "ms"
+ logger.info("Wrote block {} to {}".format(block.name, file_path))
signal_file.close()
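
For reference, a sketch of reading back a report written this way (the file name and population label are hypothetical; the group layout follows the datasets created above):

import h5py

with h5py.File("output/v_report.h5", "r") as f:
    pop_group = f["report/All"]
    data = pop_group["data"][:]                  # shape (n_timesteps, n_cells)
    node_ids = pop_group["mapping/node_ids"][:]
    t_start, t_stop, dt = pop_group["mapping/time"][:]  # ms, per the "units" attr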


@@ -232,6 +236,7 @@ def condense(value, types_array):
from "/edges/<population_name>/edge_type_id" that applies to this group.
Needed to construct parameter arrays.
"""
# todo: use lazyarray
if isinstance(value, np.ndarray):
return value
elif isinstance(value, dict):
@@ -240,7 +245,12 @@
if np.all(value_array == value_array[0]):
return value_array[0]
else:
- new_value = np.ones_like(types_array) * np.nan
+ if np.issubdtype(value_array.dtype, np.number):
+ new_value = np.ones_like(types_array) * np.nan
+ elif np.issubdtype(value_array.dtype, np.str_):
+ new_value = np.array(["UNDEFINED"] * types_array.size)
+ else:
+ raise TypeError("Cannot handle annotations that are neither numbers nor strings")
for node_type_id, val in value.items():
new_value[types_array == node_type_id] = val
return new_value
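
To make the intended behaviour concrete (values invented for illustration):

types_array = np.array([101, 101, 102])
condense(np.array([1.0, 2.0, 3.0]), types_array)  # -> the array, unchanged
condense({101: 0.5, 102: 0.5}, types_array)       # -> 0.5 (all values equal)
condense({101: 0.5, 102: 0.7}, types_array)       # -> array([0.5, 0.5, 0.7])
condense({101: "exc", 102: "inh"}, types_array)   # -> array(["exc", "exc", "inh"])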
@@ -584,10 +594,10 @@ def import_from_sonata(config_file, sim):
net = Network()
for node_population in sonata_node_populations:
assembly = node_population.to_assembly(sim)
- net.assemblies.add(assembly)
+ net.add(assembly)
for edge_population in sonata_edge_populations:
projections = edge_population.to_projections(net, sim)
- net.projections.update(projections)
+ net.add(*projections)

return net
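
Typical usage, assuming this function is re-exported from pyNN.serialization (config path hypothetical):

import pyNN.nest as sim
from pyNN.serialization import import_from_sonata

net = import_from_sonata("circuit_config.json", sim)
print(net.count_neurons(), "neurons in", len(net.assemblies), "node populations")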

@@ -777,7 +787,7 @@ def to_population(self, sim):
if name in cell_type_cls.default_parameters:
parameters[name] = condense(value, self.node_types_array)
else:
- annotations[name] = value
+ annotations[name] = condense(value, self.node_types_array)
# todo: handle spatial structure - nodes_file["nodes"][np_label][ng_label]['x'], etc.

# temporary hack to work around problem with 300 Intfire cell example
@@ -1072,28 +1082,21 @@ def setup(self, sim):
self.sim = sim
sim.setup(timestep=self.run_config["dt"])

- def _get_target(self, config, node_sets, net):
+ def _get_target(self, config, net):
if "node_set" in config:  # input config
- target = node_sets[config["node_set"]]
- elif "cells" in config:  # recording config
+ targets = self.node_set_map[config["node_set"]]
+ elif "cells" in config:  # recording config
# inconsistency in SONATA spec? Why not call this "node_set" also?
- target = node_sets[config["cells"]]
- if "model_type" in target:
- raise NotImplementedError()
- if "location" in target:
- raise NotImplementedError()
- if "gids" in target:
- raise NotImplementedError()
- if "population" in target:
- assembly = net.get_component(target["population"])
- if "node_id" in target:
- indices = target["node_id"]
- assembly = assembly[indices]
- return assembly
+ targets = self.node_set_map[config["cells"]]
+ return targets

- def _set_input_spikes(self, input_config, node_sets, net):
+ def _set_input_spikes(self, input_config, net):
# determine which assembly the spikes are for
- assembly = self._get_target(input_config, node_sets, net)
+ targets = self._get_target(input_config, net)
+ if len(targets) != 1:
+ raise NotImplementedError()
+ base_assembly, mask = targets[0]
+ assembly = base_assembly[mask]
assert isinstance(assembly, self.sim.Assembly)

# load spike data from file
@@ -1111,22 +1114,88 @@ def _set_input_spikes(self, input_config, node_sets, net):
if len(spiketrains) != assembly.size:
raise NotImplementedError()
# todo: map cell ids in spikes file to ids/index in the population
- #logger.info("SETTING SPIKETIMES")
- #logger.info(spiketrains)
assembly.set(spike_times=[Sequence(st.times.rescale('ms').magnitude) for st in spiketrains])
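
For context, a "spikes" entry from the inputs section of a simulation config, shown as the parsed dict this method receives (keys follow the usage visible above; values hypothetical):

input_config = {
    "input_type": "spikes",
    "input_file": "inputs/external_spikes.h5",
    "node_set": "external"  # resolved through self.node_set_map
}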

+ def _set_input_currents(self, input_config, net):
+ # determine which assembly the currents are for
+ if "input_file" in input_config:
+ raise NotImplementedError("Current clamp from source file not yet supported.")
+ targets = self._get_target(input_config, net)
+ if len(targets) != 1:
+ raise NotImplementedError()
+ base_assembly, mask = targets[0]
+ assembly = base_assembly[mask]
+ assert isinstance(assembly, self.sim.Assembly)
+ amplitude = input_config["amp"]  # nA
+ if self.target_simulator == "NEST":
+ amplitude = input_config["amp"]/1000.0  # pA
+
+ current_source = self.sim.DCSource(amplitude=amplitude,
+ start=input_config["delay"],
+ stop=input_config["delay"] + input_config["duration"])
+ assembly.inject(current_source)
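
The corresponding "current_clamp" entry might look like this (again only keys the code reads, with hypothetical values; note the NEST branch above rescales "amp"):

input_config = {
    "input_type": "current_clamp",
    "node_set": "biophysical",  # hypothetical node set name
    "amp": 190.0,
    "delay": 100.0,    # ms
    "duration": 500.0  # ms
}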

+ def _calculate_node_set_map(self, net):
+ # for each "node set" in the config, determine which populations
+ # and node_ids it corresponds to
+ self.node_set_map = {}
+
+ # first handle implicit node sets - i.e. each node population is an implicit node set
+ for assembly in net.assemblies:
+ self.node_set_map[assembly.label] = [(assembly, slice(None))]
+
+ # now handle explicitly-declared node sets
+ # todo: handle compound node sets
+ for node_set_name, node_set_definition in self.node_sets.items():
+ if isinstance(node_set_definition, dict):  # basic node set
+ filters = node_set_definition
+ if "population" in filters:
+ assemblies = [net.get_component(filters["population"])]
+ else:
+ assemblies = list(net.assemblies)
+
+ self.node_set_map[node_set_name] = []
+ for assembly in assemblies:
+ mask = True
+ for attr_name, attr_value in filters.items():
+ if attr_name == "population":
+ continue
+ elif attr_name == "node_id":
+ # convert integer mask to boolean mask
+ node_mask = np.zeros(assembly.size, dtype=bool)
+ node_mask[attr_value] = True
+ mask = np.logical_and(mask, node_mask)
+ else:
+ values = assembly.get_annotations(attr_name)[attr_name]
+ mask = np.logical_and(mask, values == attr_value)
+ if isinstance(mask, (bool, np.bool_)) and mask == True:
+ mask = slice(None)
+ self.node_set_map[node_set_name].append((assembly, mask))
+ elif isinstance(node_set_definition, list):  # compound node set
+ raise NotImplementedError("Compound node sets not yet supported")
+ else:
+ raise TypeError("Expecting node set definition to be a list or dict")
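
A sketch of how node set definitions resolve to (assembly, mask) pairs (labels and annotation names hypothetical):

node_sets = {
    "l4_exc": {"population": "cortex", "ei": "e", "layer": 4},
    "probed": {"population": "cortex", "node_id": [0, 5, 7]}
}
# after _calculate_node_set_map(net):
#   node_set_map["cortex"] == [(cortex_assembly, slice(None))]  # implicit set
#   node_set_map["l4_exc"] == [(cortex_assembly, <boolean mask from annotations>)]
#   node_set_map["probed"] == [(cortex_assembly, <boolean mask from node_id list>)]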

def execute(self, net):
+ self._calculate_node_set_map(net)

# create/configure inputs
for input_name, input_config in self.inputs.items():
- if input_config["input_type"] != "spikes":
- raise NotImplementedError()
- self._set_input_spikes(input_config, self.node_sets, net)
+ if input_config["input_type"] == "spikes":
+ self._set_input_spikes(input_config, net)
+ elif input_config["input_type"] == "current_clamp":
+ self._set_input_currents(input_config, net)
+ else:
+ raise NotImplementedError("Only 'spikes' and 'current_clamp' supported")

# configure recording
net.record('spikes', include_spike_source=False) # SONATA requires that we record spikes from all non-virtual nodes
for report_name, report_config in self.reports.items():
- assembly = self._get_target(report_config, self.node_sets, net)
- assembly.record(report_config["variable_name"])
+ targets = self._get_target(report_config, net)
+ for (base_assembly, mask) in targets:
+ assembly = base_assembly[mask]
+ assembly.record(report_config["variable_name"])

# run simulation
self.sim.run(self.run_config["tstop"])
@@ -1141,7 +1210,7 @@ def execute(self, net):
spikes_file=self.output.get("spikes_file", "spikes.h5"),
spikes_sort_order=self.output["spikes_sort_order"],
report_config=self.reports,
- node_sets=self.node_sets)
+ node_sets=self.node_set_map)
# todo: handle reports
net.write_data(io)

