diff --git a/include/pluginplay/module_base.hpp b/include/pluginplay/module_base.hpp index cd7b3cebb..c57a36b20 100644 --- a/include/pluginplay/module_base.hpp +++ b/include/pluginplay/module_base.hpp @@ -290,6 +290,9 @@ class ModuleBase { */ runtime_type& get_runtime() const; + // Is this a Python module? + bool is_python() const { return m_is_python_; } + /** @brief Compares two ModuleBase instances for equality. * * Two ModuleBase instances are equivalent if their algorithm is diff --git a/src/pluginplay/detail_/module_manager_pimpl.hpp b/src/pluginplay/detail_/module_manager_pimpl.hpp index 8918c6e76..548174fd9 100644 --- a/src/pluginplay/detail_/module_manager_pimpl.hpp +++ b/src/pluginplay/detail_/module_manager_pimpl.hpp @@ -61,6 +61,10 @@ struct ModuleManagerPIMPL { /// A pointer to a runtime using runtime_ptr = std::shared_ptr; + /// Type of a map from key to Python implementation + // TODO: remove when a more elegant solution is determined + using py_base_map = std::map; + ///@} ModuleManagerPIMPL(runtime_ptr runtime) : m_runtime_(runtime) {} @@ -121,11 +125,21 @@ struct ModuleManagerPIMPL { base->set_cache(internal_cache); base->set_runtime(m_runtime_); base->set_uuid(uuid); - std::type_index type(base->type()); - if(!m_bases.count(type)) m_bases[type] = base; auto module_cache = m_caches.get_or_make_module_cache(key); - auto pimpl = std::make_unique(m_bases[type], module_cache); - auto ptr = std::make_shared(std::move(pimpl)); + std::unique_ptr pimpl; + if(base->is_python()) { + // This is a hacky patch to allow multiple python modules to be + // added while avoiding the type_index collisions. + // TODO: remove when a more elegant solution is determined + m_py_bases[key] = base; + pimpl = + std::make_unique(m_py_bases[key], module_cache); + } else { + std::type_index type(base->type()); + if(!m_bases.count(type)) m_bases[type] = base; + pimpl = std::make_unique(m_bases[type], module_cache); + } + auto ptr = std::make_shared(std::move(pimpl)); m_modules.emplace(std::move(key), ptr); } @@ -197,6 +211,13 @@ struct ModuleManagerPIMPL { if(m_modules.size() != rhs.m_modules.size()) return false; if(m_defaults.size() != rhs.m_defaults.size()) return false; + // TODO: Remove with the rest of the python hack + if(m_py_bases.size() != rhs.m_py_bases.size()) return false; + for(const auto& [k, v] : rhs.m_py_bases) { + if(!m_py_bases.count(k)) return false; + if(*m_py_bases.at(k) != *v) return false; + } + // Skip checking the values b/c implementations are compared by type for(const auto& [k, v] : rhs.m_bases) { if(!m_bases.count(k)) return false; @@ -241,6 +262,11 @@ struct ModuleManagerPIMPL { // These are the Modules in the state set by the user module_map m_modules; + // Part of the hacky patch to make multiple python modules work + // TODO: remove when a more elegant solution is determined + // These are the Python Modules in their developer state + py_base_map m_py_bases; + // These are the results of the modules running in the user's states cache::ModuleManagerCache m_caches; diff --git a/tests/python/doc_snippets/test_python_modules.py b/tests/python/doc_snippets/test_python_modules.py index 69269f891..d3def8709 100644 --- a/tests/python/doc_snippets/test_python_modules.py +++ b/tests/python/doc_snippets/test_python_modules.py @@ -28,15 +28,12 @@ def test_modules(self): a = [2.0, 0.0, 0.0] mm = pp.ModuleManager() - ppe.load_modules(mm) - # There's a bug here that's messing up the property types somehow. - # The property seems to get "stuck" on the first add_module call mm.add_module("My Coulomb's Law", clf.CoulombsLaw()) - # mm.add_module("My Force", clf.ClassicalForce()) - # mm.change_submod("My Force", "electric field", "Coulomb's Law") + mm.add_module("My Force", clf.ClassicalForce()) + mm.change_submod("My Force", "electric field", "My Coulomb's Law") field = mm.at("My Coulomb's Law").run_as(ppe.ElectricField(), r, pvc) self.assertTrue(field == [1.5, 0.0, 0.0]) - # cforce = mm.at("My Force").run_as(ppe.Force(), q, m, a, pvc) - # self.assertTrue(cforce == [5.5, 0.0, 0.0]) + cforce = mm.at("My Force").run_as(ppe.Force(), q, m, a, pvc) + self.assertTrue(cforce == [5.5, 0.0, 0.0])