diff --git a/src/aoc_2024/day4/part2.py b/src/aoc_2024/day4/part2.py new file mode 100644 index 0000000..aed5b56 --- /dev/null +++ b/src/aoc_2024/day4/part2.py @@ -0,0 +1,75 @@ +from pathlib import Path + +import numpy as np + + +def is_diagonal_forwards_downwards(data, i, j): + if data[i - 1, j - 1] == "M": + if data[i + 1, j + 1] == "S": + return True + return False + + +def is_diagonal_forwards_upwards(data, i, j): + if data[i + 1, j - 1] == "M": + if data[i - 1, j + 1] == "S": + return True + return False + + +def is_diagonal_backwards_upwards(data, i, j): + if data[i + 1, j + 1] == "M": + if data[i - 1, j - 1] == "S": + return True + return False + + +def is_diagonal_backwards_downwards(data, i, j): + if data[i - 1, j + 1] == "M": + if data[i + 1, j - 1] == "S": + return True + return False + + +def is_diagonal1(data, i, j): + return is_diagonal_forwards_downwards(data, i, j) or is_diagonal_backwards_upwards( + data, i, j + ) + + +def is_diagonal2(data, i, j): + return is_diagonal_backwards_downwards(data, i, j) or is_diagonal_forwards_upwards( + data, i, j + ) + + +def find_x_mas(data: np.array): + score = 0 + for i in range(1, data.shape[0] - 1): + for j in range(1, data.shape[1] - 1): + if data[i, j] == "A": + if is_diagonal1(data, i, j) and is_diagonal2(data, i, j): + score += 1 + + return score + + +def main(file: Path) -> int: + data_str = file.read_text() + x_axis_size = data_str.find("\n") + y_axis_size = data_str.count("\n") + 1 + data = np.full((x_axis_size, y_axis_size), "") + i, j = 0, 0 + for character in data_str: + if character != "\n": + data[i, j] = character + j += 1 + else: + i += 1 + j = 0 + return find_x_mas(data) + + +if __name__ == "__main__": + input_file = Path(__file__).parent / "input-data" + print(main(input_file)) diff --git a/tests/aoc_2024/day4/test_part2.py b/tests/aoc_2024/day4/test_part2.py new file mode 100644 index 0000000..b48876b --- /dev/null +++ b/tests/aoc_2024/day4/test_part2.py @@ -0,0 +1,33 @@ +from pathlib import Path + +from aoc_2024.day4 import part2 + +import numpy as np + + +def test_find_all(): + data_str = """MMMSXXMASM +MSAMXMSMSA +AMXSXMAAMM +MSAMASMSMX +XMASAMXAMM +XXAMMXXAMA +SMSMSASXSS +SAXAMASAAA +MAMMMXMMMM +MXMXAXMASX""" + data = np.full((10, 10), "") + i, j = 0, 0 + for character in data_str: + if character != "\n": + data[i, j] = character + j += 1 + else: + i += 1 + j = 0 + assert part2.find_x_mas(data) == 9 + + +def test_main(): + data_file = Path(__file__).parent / "test-data" + assert part2.main(data_file) == 9