Tuesday, 15 March 2022

Slow dnf to cnf in pycosat

Question in short

To have a proper input for pycosat, is there a way to speed up calculation from dnf to cnf, or to circumvent it altogether?

Question in detail

I have been watching this video from Raymond Hettinger about modern solvers. I downloaded the code, and implemented a solver for the game Towers in it. Below I share the code to do so.

Example Tower puzzle (solved):

    3 3 2 1    
---------------
3 | 2 1 3 4 | 1
3 | 1 3 4 2 | 2
2 | 3 4 2 1 | 3
1 | 4 2 1 3 | 2
---------------
    1 2 3 2    

The problem I encounter is that the conversion from dnf to cnf takes forever. Let's say that you know there are 3 towers visible from a certain line of sight. This leads to 35 possible permutations 1-5 in that row.

[('AA 1', 'AB 2', 'AC 5', 'AD 3', 'AE 4'),
 ('AA 1', 'AB 2', 'AC 5', 'AD 4', 'AE 3'),
 ...
 ('AA 3', 'AB 4', 'AC 5', 'AD 1', 'AE 2'),
 ('AA 3', 'AB 4', 'AC 5', 'AD 2', 'AE 1')]

This is a disjunctive normal form: an OR of several AND statements. This needs to be converted into a conjunctive normal form: an AND of several OR statements. This is however very slow. On my Macbook Pro, it didn't finish calculating this cnf after 5 minutes for a single row. For the entire puzzle, this should be done up to 20 times (for a 5x5 grid).

What would be the best way to optimize this code, in order to make the computer able to solve this Towers puzzle?

This code is also available from this Github repository.

import string

import itertools
from sys import intern
from typing import Collection, Dict, List

from sat_utils import basic_fact, from_dnf, one_of, solve_one

Point = str


def comb(point: Point, value: int) -> str:
    """
    Format a fact (a value assigned to a given point), and store it into the interned strings table

    :param point: Point on the grid, characterized by two letters, e.g. AB
    :param value: Value of the cell on that point, e.g. 2
    :return: Fact string 'AB 2'
    """

    return intern(f'{point} {value}')


def visible_from_line(line: Collection[int], reverse: bool = False) -> int:
    """
    Return how many towers are visible from the given line

    >>> visible_from_line([1, 2, 3, 4])
    4
    >>> visible_from_line([1, 4, 3, 2])
    2
    """

    visible = 0
    highest_seen = 0
    for number in reversed(line) if reverse else line:
        if number > highest_seen:
            visible += 1
            highest_seen = number
    return visible


class TowersPuzzle:
    def __init__(self):
        self.visible_from_top = [3, 3, 2, 1]
        self.visible_from_bottom = [1, 2, 3, 2]
        self.visible_from_left = [3, 3, 2, 1]
        self.visible_from_right = [1, 2, 3, 2]
        self.given_numbers = {'AC': 3}

        # self.visible_from_top = [3, 2, 1, 4, 2]
        # self.visible_from_bottom = [2, 2, 4, 1, 2]
        # self.visible_from_left = [3, 2, 3, 1, 3]
        # self.visible_from_right = [2, 2, 1, 3, 2]

        self._cnf = None
        self._solution = None

    def display_puzzle(self):
        print('*** Puzzle ***')
        self._display(self.given_numbers)

    def display_solution(self):
        print('*** Solution ***')
        point_to_value = {point: value for point, value in [fact.split() for fact in self.solution]}
        self._display(point_to_value)

    @property
    def n(self) -> int:
        """
        :return: Size of the grid
        """

        return len(self.visible_from_top)

    @property
    def points(self) -> List[Point]:
        return [''.join(letters) for letters in itertools.product(string.ascii_uppercase[:self.n], repeat=2)]

    @property
    def rows(self) -> List[List[Point]]:
        """
        :return: Points, grouped per row
        """

        return [self.points[i:i + self.n] for i in range(0, self.n * self.n, self.n)]

    @property
    def cols(self) -> List[List[Point]]:
        """
        :return: Points, grouped per column
        """

        return [self.points[i::self.n] for i in range(self.n)]

    @property
    def values(self) -> List[int]:
        return list(range(1, self.n + 1))

    @property
    def cnf(self):
        if self._cnf is None:
            cnf = []

            # Each point assigned exactly one value
            for point in self.points:
                cnf += one_of(comb(point, value) for value in self.values)

            # Each value gets assigned to exactly one point in each row
            for row in self.rows:
                for value in self.values:
                    cnf += one_of(comb(point, value) for point in row)

            # Each value gets assigned to exactly one point in each col
            for col in self.cols:
                for value in self.values:
                    cnf += one_of(comb(point, value) for point in col)

            # Set visible from left
            if self.visible_from_left:
                for index, row in enumerate(self.rows):
                    target_visible = self.visible_from_left[index]
                    if not target_visible:
                        continue
                    possible_perms = []
                    for perm in itertools.permutations(range(1, self.n + 1)):
                        if visible_from_line(perm) == target_visible:
                            possible_perms.append(tuple(
                                comb(point, value)
                                for point, value in zip(row, perm)
                            ))
                    cnf += from_dnf(possible_perms)

            # Set visible from right
            if self.visible_from_right:
                for index, row in enumerate(self.rows):
                    target_visible = self.visible_from_right[index]
                    if not target_visible:
                        continue
                    possible_perms = []
                    for perm in itertools.permutations(range(1, self.n + 1)):
                        if visible_from_line(perm, reverse=True) == target_visible:
                            possible_perms.append(tuple(
                                comb(point, value)
                                for point, value in zip(row, perm)
                            ))
                    cnf += from_dnf(possible_perms)

            # Set visible from top
            if self.visible_from_top:
                for index, col in enumerate(self.cols):
                    target_visible = self.visible_from_top[index]
                    if not target_visible:
                        continue
                    possible_perms = []
                    for perm in itertools.permutations(range(1, self.n + 1)):
                        if visible_from_line(perm) == target_visible:
                            possible_perms.append(tuple(
                                comb(point, value)
                                for point, value in zip(col, perm)
                            ))
                    cnf += from_dnf(possible_perms)

            # Set visible from bottom
            if self.visible_from_bottom:
                for index, col in enumerate(self.cols):
                    target_visible = self.visible_from_bottom[index]
                    if not target_visible:
                        continue
                    possible_perms = []
                    for perm in itertools.permutations(range(1, self.n + 1)):
                        if visible_from_line(perm, reverse=True) == target_visible:
                            possible_perms.append(tuple(
                                comb(point, value)
                                for point, value in zip(col, perm)
                            ))
                    cnf += from_dnf(possible_perms)

            # Set given numbers
            for point, value in self.given_numbers.items():
                cnf += basic_fact(comb(point, value))

            self._cnf = cnf

        return self._cnf

    @property
    def solution(self):
        if self._solution is None:
            self._solution = solve_one(self.cnf)
        return self._solution

    def _display(self, facts: Dict[Point, int]):
        top_line = '    ' + ' '.join([str(elem) if elem else ' ' for elem in self.visible_from_top]) + '    '
        print(top_line)
        print('-' * len(top_line))
        for index, row in enumerate(self.rows):
            elems = [str(self.visible_from_left[index]) or ' ', '|'] + \
                    [str(facts.get(point, ' ')) for point in row] + \
                    ['|', str(self.visible_from_right[index]) or ' ']
            print(' '.join(elems))
        print('-' * len(top_line))
        bottom_line = '    ' + ' '.join([str(elem) if elem else ' ' for elem in self.visible_from_bottom]) + '    '
        print(bottom_line)
        print()


if __name__ == '__main__':
    puzzle = TowersPuzzle()
    puzzle.display_puzzle()
    puzzle.display_solution()


from Slow dnf to cnf in pycosat

No comments:

Post a Comment