Skip to content

Commit

Permalink
Day 05 update
Browse files Browse the repository at this point in the history
  • Loading branch information
vsedov committed Dec 5, 2024
1 parent 56f6112 commit c5dc3c1
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 0 deletions.
109 changes: 109 additions & 0 deletions src/aoc/aoc2024/day_05.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from typing import Tuple

import numpy as np
import numpy.typing as npt
from numba import njit

from src.aoc.aoc2024 import YEAR, get_day
from src.aoc.aoc_helper import Aoc


@njit
def parse_rules(rules_array: npt.NDArray) -> npt.NDArray:
n = np.max(rules_array) + 1
dep_matrix = np.zeros((n, n), dtype=np.int8)
for rule in rules_array:
dep_matrix[rule[1], rule[0]] = 1
return dep_matrix


# @njit
# def check_order(deps: npt.NDArray, sequence: npt.NDArray) -> bool:
# n = len(sequence)
# for i in range(n):
# curr = sequence[i]
# for j in range(i + 1, n):
# if deps[curr, sequence[j]] == 1:
# return False
# return True
# @njit
# def check_order(deps: npt.NDArray, sequence: npt.NDArray) -> bool:
# seen = np.zeros_like(deps[0], dtype=np.bool_)
# for num in sequence:
# if np.any(deps[num] & seen):
# return False
# seen[num] = True
# return True
@njit
def check_order(deps: npt.NDArray, sequence: npt.NDArray) -> bool:
n = len(sequence)
seq_deps = deps[sequence]
for i in range(n - 1):
if np.any(seq_deps[i, sequence[i + 1 :]]):
return False
return True


@njit
def find_valid_order(deps: npt.NDArray, numbers: npt.NDArray) -> npt.NDArray:
n = len(numbers)
result = np.empty(n, dtype=np.int64)
used = np.zeros(n, dtype=np.bool_)
in_degree = np.zeros(n, dtype=np.int64)
for i in range(n):
in_degree[i] = np.sum(deps[numbers, numbers[i]])

# O(n) topological sort
for pos in range(n):
next_idx = np.where(~used & (in_degree == 0))[0][0]
result[pos] = numbers[next_idx]
used[next_idx] = True
# amortisation
for j in range(n):
if not used[j] and deps[numbers[next_idx], numbers[j]]:
in_degree[j] -= 1

return result


def parse(txt: str) -> Tuple[npt.NDArray, npt.NDArray]:
rules_txt, sequences_txt = txt.split("\n\n")
rules = np.array(
[[int(x) for x in line.split("|")] for line in rules_txt.splitlines()],
dtype=np.int64,
)

sequences = []
for line in sequences_txt.splitlines():
seq = np.fromstring(line, sep=",", dtype=np.int64)
sequences.append(seq)
return rules, np.array(sequences, dtype=object)


def part_a(txt: str) -> int:
rules_array, sequences = parse(txt)
deps = parse_rules(rules_array)

return sum(seq[len(seq) // 2] for seq in sequences if check_order(deps, seq))


def part_b(txt: str) -> int:
rules_array, sequences = parse(txt)
deps = parse_rules(rules_array)

total = 0
for seq in sequences:
if not check_order(deps, seq):
fixed_order = find_valid_order(deps, seq)
total += fixed_order[len(fixed_order) // 2]
return total


def main(txt: str) -> None:
print("part_a: ", part_a(txt))
print("part_b: ", part_b(txt))


if __name__ == "__main__":
aoc = Aoc(day=get_day(), years=YEAR)
aoc.run(main, submit=True, part="both", readme_update=True)
42 changes: 42 additions & 0 deletions tests/aoc2024/2024_day_05_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pytest

from src.aoc.aoc2024 import day_05 as d

TEST_INPUT = """
47|53
97|13
97|61
97|47
75|29
61|13
75|53
29|13
97|29
53|29
61|53
97|53
61|29
47|13
75|47
97|75
47|61
75|61
47|29
75|13
53|13
75,47,61,53,29
97,61,53,29,13
75,29,13
75,97,47,61,53
61,13,29
97,13,75,29,47
""".strip()


def test_a() -> None:
assert d.part_a(TEST_INPUT) == 143


def test_b() -> None:
assert d.part_b(TEST_INPUT) == 123

0 comments on commit c5dc3c1

Please sign in to comment.