diff --git a/backend/spellbook/variants.py b/backend/spellbook/variants.py index 696c2e4d..52341c3a 100644 --- a/backend/spellbook/variants.py +++ b/backend/spellbook/variants.py @@ -29,6 +29,8 @@ class Node: @dataclass class CardNode(Node): card: Card + features: list['FeatureNode'] + combos: list['ComboNode'] def __hash__(self): return hash(self.card) + 7 * hash('card') @@ -37,6 +39,7 @@ def __hash__(self): @dataclass class TemplateNode(Node): template: Template + combos: list['ComboNode'] def __hash__(self): return hash(self.template) + 7 * hash('template') @@ -226,9 +229,17 @@ class Graph: def __init__(self, data: Data): if data is not None: self.data = data - self.cnodes = dict[int, CardNode]((card.id, CardNode(card)) for card in data.cards) - self.tnodes = dict[int, TemplateNode]((template.id, TemplateNode(template)) for template in data.templates) - self.fnodes = dict[int, FeatureNode]((feature.id, FeatureNode(feature, [self.cnodes[i.id] for i in feature.cards.all()], [], [])) for feature in data.features) + self.cnodes = dict[int, CardNode]((card.id, CardNode(card, [], [])) for card in data.cards) + self.tnodes = dict[int, TemplateNode]((template.id, TemplateNode(template, [])) for template in data.templates) + self.fnodes = dict[int, FeatureNode]() + for feature in data.features: + node = FeatureNode(feature, + cards=[self.cnodes[i.id] for i in feature.cards.all()], + produced_by_combos=[], + needed_by_combos=[]) + self.fnodes[feature.id] = node + for card in feature.cards.all(): + self.cnodes[card.id].features.append(node) self.bnodes = dict[int, ComboNode]() for combo in data.combos: node = ComboNode(combo, @@ -243,6 +254,10 @@ def __init__(self, data: Data): for feature in combo.needs.all(): featureNode = self.fnodes[feature.id] featureNode.needed_by_combos.append(node) + for card in combo.uses.all(): + self.cnodes[card.id].combos.append(node) + for template in combo.requires.all(): + self.tnodes[template.id].combos.append(node) else: raise Exception('Invalid arguments') @@ -275,6 +290,10 @@ def variants(self, combo_id: int) -> Iterable[VariantIngredients]: for node in nodes.copy(): ups = set() match node: + case CardNode(_, _, _): + ups = self._card_nodes_up(node) + case TemplateNode(_, _): + ups = self._template_nodes_up(node) case FeatureNode(_, _, _): ups = self._feature_nodes_up(node) case ComboNode(_, _, _, _): @@ -372,6 +391,34 @@ def _feature_nodes_up(self, feature: FeatureNode) -> set[Node]: other.update(self._combo_nodes_up(c)) c.state = NodeState.VISITED return combos | other + + def _card_nodes_up(self, card: CardNode) -> set[Node]: + card.state = NodeState.VISITING + features: set[FeatureNode] = set() + combos: set[ComboNode] = set() + other: set[Node] = set() + for f in card.features: + if f.state == NodeState.NOT_VISITED: + features.add(f) + other.update(self._feature_nodes_up(f)) + f.state = NodeState.VISITED + for c in card.combos: + if c.state == NodeState.NOT_VISITED: + combos.add(c) + other.update(self._combo_nodes_up(c)) + c.state = NodeState.VISITED + return features | combos | other + + def _template_nodes_up(self, template: TemplateNode) -> set[Node]: + template.state = NodeState.VISITING + combos: set[ComboNode] = set() + other: set[Node] = set() + for c in template.combos: + if c.state == NodeState.NOT_VISITED: + combos.add(c) + other.update(self._combo_nodes_up(c)) + c.state = NodeState.VISITED + return combos | other def unique_id_from_cards_and_templates_ids(cards: list[int], templates: list[int]) -> str: