From f9b2cbf69a74172d266de6dfff6df9fdaa564c1f Mon Sep 17 00:00:00 2001 From: stanislas Date: Wed, 4 Dec 2024 16:14:01 +0100 Subject: [PATCH] day 4 - part 2 --- src/aoc_2024/day4/part2.py | 75 +++++++++++++++++++++++++++++++ tests/aoc_2024/day4/test_part2.py | 33 ++++++++++++++ 2 files changed, 108 insertions(+) create mode 100644 src/aoc_2024/day4/part2.py create mode 100644 tests/aoc_2024/day4/test_part2.py diff --git a/src/aoc_2024/day4/part2.py b/src/aoc_2024/day4/part2.py new file mode 100644 index 0000000..aed5b56 --- /dev/null +++ b/src/aoc_2024/day4/part2.py @@ -0,0 +1,75 @@ +from pathlib import Path + +import numpy as np + + +def is_diagonal_forwards_downwards(data, i, j): + if data[i - 1, j - 1] == "M": + if data[i + 1, j + 1] == "S": + return True + return False + + +def is_diagonal_forwards_upwards(data, i, j): + if data[i + 1, j - 1] == "M": + if data[i - 1, j + 1] == "S": + return True + return False + + +def is_diagonal_backwards_upwards(data, i, j): + if data[i + 1, j + 1] == "M": + if data[i - 1, j - 1] == "S": + return True + return False + + +def is_diagonal_backwards_downwards(data, i, j): + if data[i - 1, j + 1] == "M": + if data[i + 1, j - 1] == "S": + return True + return False + + +def is_diagonal1(data, i, j): + return is_diagonal_forwards_downwards(data, i, j) or is_diagonal_backwards_upwards( + data, i, j + ) + + +def is_diagonal2(data, i, j): + return is_diagonal_backwards_downwards(data, i, j) or is_diagonal_forwards_upwards( + data, i, j + ) + + +def find_x_mas(data: np.array): + score = 0 + for i in range(1, data.shape[0] - 1): + for j in range(1, data.shape[1] - 1): + if data[i, j] == "A": + if is_diagonal1(data, i, j) and is_diagonal2(data, i, j): + score += 1 + + return score + + +def main(file: Path) -> int: + data_str = file.read_text() + x_axis_size = data_str.find("\n") + y_axis_size = data_str.count("\n") + 1 + data = np.full((x_axis_size, y_axis_size), "") + i, j = 0, 0 + for character in data_str: + if character != "\n": + data[i, j] = character + j += 1 + else: + i += 1 + j = 0 + return find_x_mas(data) + + +if __name__ == "__main__": + input_file = Path(__file__).parent / "input-data" + print(main(input_file)) diff --git a/tests/aoc_2024/day4/test_part2.py b/tests/aoc_2024/day4/test_part2.py new file mode 100644 index 0000000..b48876b --- /dev/null +++ b/tests/aoc_2024/day4/test_part2.py @@ -0,0 +1,33 @@ +from pathlib import Path + +from aoc_2024.day4 import part2 + +import numpy as np + + +def test_find_all(): + data_str = """MMMSXXMASM +MSAMXMSMSA +AMXSXMAAMM +MSAMASMSMX +XMASAMXAMM +XXAMMXXAMA +SMSMSASXSS +SAXAMASAAA +MAMMMXMMMM +MXMXAXMASX""" + data = np.full((10, 10), "") + i, j = 0, 0 + for character in data_str: + if character != "\n": + data[i, j] = character + j += 1 + else: + i += 1 + j = 0 + assert part2.find_x_mas(data) == 9 + + +def test_main(): + data_file = Path(__file__).parent / "test-data" + assert part2.main(data_file) == 9