Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Patch for Multiple Python Modules #335

Merged
merged 2 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/pluginplay/module_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,9 @@ class ModuleBase {
*/
runtime_type& get_runtime() const;

// Is this a Python module?
bool is_python() const { return m_is_python_; }

/** @brief Compares two ModuleBase instances for equality.
*
* Two ModuleBase instances are equivalent if their algorithm is
Expand Down
34 changes: 30 additions & 4 deletions src/pluginplay/detail_/module_manager_pimpl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ struct ModuleManagerPIMPL {
/// A pointer to a runtime
using runtime_ptr = std::shared_ptr<runtime_type>;

/// Type of a map from key to Python implementation
// TODO: remove when a more elegant solution is determined
using py_base_map = std::map<type::key, const_module_base_ptr>;

///@}

ModuleManagerPIMPL(runtime_ptr runtime) : m_runtime_(runtime) {}
Expand Down Expand Up @@ -121,11 +125,21 @@ struct ModuleManagerPIMPL {
base->set_cache(internal_cache);
base->set_runtime(m_runtime_);
base->set_uuid(uuid);
std::type_index type(base->type());
if(!m_bases.count(type)) m_bases[type] = base;
auto module_cache = m_caches.get_or_make_module_cache(key);
auto pimpl = std::make_unique<ModulePIMPL>(m_bases[type], module_cache);
auto ptr = std::make_shared<Module>(std::move(pimpl));
std::unique_ptr<ModulePIMPL> pimpl;
if(base->is_python()) {
// This is a hacky patch to allow multiple python modules to be
// added while avoiding the type_index collisions.
// TODO: remove when a more elegant solution is determined
m_py_bases[key] = base;
pimpl =
std::make_unique<ModulePIMPL>(m_py_bases[key], module_cache);
} else {
std::type_index type(base->type());
if(!m_bases.count(type)) m_bases[type] = base;
pimpl = std::make_unique<ModulePIMPL>(m_bases[type], module_cache);
}
auto ptr = std::make_shared<Module>(std::move(pimpl));
m_modules.emplace(std::move(key), ptr);
}

Expand Down Expand Up @@ -197,6 +211,13 @@ struct ModuleManagerPIMPL {
if(m_modules.size() != rhs.m_modules.size()) return false;
if(m_defaults.size() != rhs.m_defaults.size()) return false;

// TODO: Remove with the rest of the python hack
if(m_py_bases.size() != rhs.m_py_bases.size()) return false;
for(const auto& [k, v] : rhs.m_py_bases) {
if(!m_py_bases.count(k)) return false;
if(*m_py_bases.at(k) != *v) return false;
}

// Skip checking the values b/c implementations are compared by type
for(const auto& [k, v] : rhs.m_bases) {
if(!m_bases.count(k)) return false;
Expand Down Expand Up @@ -241,6 +262,11 @@ struct ModuleManagerPIMPL {
// These are the Modules in the state set by the user
module_map m_modules;

// Part of the hacky patch to make multiple python modules work
// TODO: remove when a more elegant solution is determined
// These are the Python Modules in their developer state
py_base_map m_py_bases;

// These are the results of the modules running in the user's states
cache::ModuleManagerCache m_caches;

Expand Down
11 changes: 4 additions & 7 deletions tests/python/doc_snippets/test_python_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,12 @@ def test_modules(self):
a = [2.0, 0.0, 0.0]

mm = pp.ModuleManager()
ppe.load_modules(mm)

# There's a bug here that's messing up the property types somehow.
# The property seems to get "stuck" on the first add_module call
mm.add_module("My Coulomb's Law", clf.CoulombsLaw())
# mm.add_module("My Force", clf.ClassicalForce())
# mm.change_submod("My Force", "electric field", "Coulomb's Law")
mm.add_module("My Force", clf.ClassicalForce())
mm.change_submod("My Force", "electric field", "My Coulomb's Law")

field = mm.at("My Coulomb's Law").run_as(ppe.ElectricField(), r, pvc)
self.assertTrue(field == [1.5, 0.0, 0.0])
# cforce = mm.at("My Force").run_as(ppe.Force(), q, m, a, pvc)
# self.assertTrue(cforce == [5.5, 0.0, 0.0])
cforce = mm.at("My Force").run_as(ppe.Force(), q, m, a, pvc)
self.assertTrue(cforce == [5.5, 0.0, 0.0])