#
# Solve a sudoku puzzle by repeatedly applying human-style deduction rules.
#

import unittest


class Cell(object):
    """One square of the 9x9 grid.

    row_idx/col_idx are the 0-based position.  row, col and box start as
    None and are pointed at the containing groups by the board.  value is
    the placed digit (None while open) and candidates holds the digits that
    are still possible for this square.
    """

    def __init__(self, row_idx, col_idx):
        self.row_idx, self.col_idx = row_idx, col_idx
        self.row = self.col = self.box = None
        self.value = None
        self.candidates = set(range(1, 10))

    def __str__(self):
        # Open squares render as "." so a printed row stays 9 characters wide.
        if self.value is None:
            return "."
        return "%d" % self.value

    def description(self):
        """Return a 1-based label such as ``cell(3,5)`` for log messages."""
        one_based = (self.row_idx + 1, self.col_idx + 1)
        return "cell(%d,%d)" % one_based


#
# A Group holds the 9 cells of one row, column, or box.
#
class Group(list):
    """A list of exactly nine slots, all None until they are filled in.

    name is a short tag such as "row", "col" or "box"; row_idx and col_idx
    record the group's 0-based position for the subclasses' descriptions.
    """

    def __init__(self, name, row_idx, col_idx):
        super(Group, self).__init__(None for _ in range(9))
        self.name, self.row_idx, self.col_idx = name, row_idx, col_idx


class Row(Group):
    """One horizontal line of nine cells, identified by its 0-based index."""

    def __init__(self, row_idx):
        super(Row, self).__init__("row", row_idx, 0)

    def description(self):
        """Return a 1-based label such as ``row 4``."""
        label_number = self.row_idx + 1
        return "row %d" % label_number

class Column(Group):
    """One vertical line of nine cells, identified by its 0-based index."""

    def __init__(self, col_idx):
        super(Column, self).__init__("col", 0, col_idx)

    def description(self):
        """Return a 1-based label such as ``col 7``."""
        label_number = self.col_idx + 1
        return "col %d" % label_number

class Box(Group):
    """One 3x3 block of cells; (row_idx, col_idx) is its 0-based block position."""

    def __init__(self, row_idx, col_idx):
        super(Box, self).__init__("box", row_idx, col_idx)

    def description(self):
        """Return a 1-based label such as ``box(1,2)``."""
        position = (self.row_idx + 1, self.col_idx + 1)
        return "box(%d,%d)" % position



class FillIn(object):
    """A deduction result: *number* is the solution for *cell*.

    Rule methods return lists of these and Board.solve() applies each one
    with solve_cell().  Derives from object (like Cell and Board) so it is
    a new-style class on Python 2 as well.
    """

    def __init__(self, cell, number):
        self.cell = cell
        self.number = number

    def __repr__(self):
        # Readable form for debugging / test failure messages.
        return "FillIn(cell=%r, number=%r)" % (self.cell, self.number)



class Board(object):
    def __init__(self):
        self.verbose = False
        self.rows    = [ Row(i)    for i in range(9) ]
        self.cols    = [ Column(j) for j in range(9) ]
        self.boxes = Group("boxes",0,0)

        for i in range(3):
            for j in range (3):
                self.boxes[3*i+j] = Box(i,j)

        for i in range(9):
            for j in range(9):
                box = self.boxes[(i/3)*3+(j/3)]
                cell = Cell(i,j)
                cell.box = box
                cell.row = self.rows[i]
                cell.col = self.cols[j]
                self.rows[i][j] = self.cols[j][i] = box[(i%3)*3+(j%3)] = cell
        
        # self.init_boxes()
        self.groups = [ box for box in self.boxes ] + self.rows + self.cols

        self.rules = [
            ("cell has 1 candidate",        self.cell_has_1_candidate),
            ("candidate appears once",      self.candidate_appears_once),
            ("box candidates in row/col",   self.box_candidates_in_single_row_or_column),
            ("N candidates in N cells",     self.full_multi_cells )
        ]

    #
    # Run all rules in order. The rules are ordered from the simplest to
    # the most complex.
    # If a rule makes a change to the board (fills a cell or removes candidates),
    # restart the rules from the start.
    #
    def solve(self):
        self.remove_candidates()
        self.changed = True
        while self.changed:
            self.changed = False
            self.log(self)
            self.log("new round of rules")
            for (rule_name, rule) in self.rules:
                self.log("applying rule: %s" % rule_name )
                fill_ins = rule()
                self.log("solved %d cells" % len(fill_ins))
                if len(fill_ins) > 0:
                    for fill_in in fill_ins:
                        self.solve_cell(fill_in.cell, fill_in.number)
                if self.changed:
                    break
        self.log("done")


    #
    # Rule: if a cell has only one remaining candidate,
    #       then that candidate is the solution for the cell
    #
    def cell_has_1_candidate(self):
        fill_ins = []
        for row in self.rows:
            for cell in row:
                if cell.value is not None : continue
                if len(cell.candidates) == 1:
                    number = cell.candidates.pop()
                    self.log("%s has only one candidate #%d" %
                             (cell.description(),number))
                    fill_ins.append( FillIn(cell=cell, number=number) )
        return fill_ins

    #
    # Rule: if a candidate appears only once in a row (or column or box),
    #       then that candidate is the solution for its cell
    #
    def candidate_appears_once(self):
        fill_ins = []
        for group in self.groups:
            fills = self.candidate_appears_once_in_group(group)
            # fill_ins.extend( fills )
            if len(fills) > 0:      # stop searching to apply simpler rules
                return fills;
        return fill_ins

    def candidate_appears_once_in_group(self, group):
        map = self.candidates_to_cells_map(group)
        fill_ins = [ FillIn(cell=map[i][0], number=i)
                        for i in range(1,10)
                        if len(map[i]) == 1 ]
        if self.verbose:
            for fillin in fill_ins:
                self.log("Candidate %d at %s appears only once in %s" %
                         (fillin.number, fillin.cell.description(), group.description()) )
        return fill_ins

    def candidates_to_cells_map(self, group):    
        map = [ [] for i in range(10) ] # item 0 not used
        for cell in group:
            if cell.value is None :
                for candidate in cell.candidates:
                    # candidates are 1-based
                    map[candidate].append(cell)  
        return map

    #
    # Rule: If a candidate number appears more than once in one box,
    #       but all the appearances belong to the same row,
    #       then that number cannot appear anywere else along the same row
    #       except in the same box.
    #       The same holds for candidates in the same column
    def box_candidates_in_single_row_or_column(self):
        for box in self.boxes:
             self.box_candidates(box)
        return []

    def box_candidates(self,box):
        map = self.candidates_to_cells_map(box)
        for num in range(1,10) :
            cells = map[num]
            if len(cells) < 2 : continue
            row = self.in_same_row( cells )
            if row is not None :
                removed = self.remove_other_candidates( number=num, remove=row, keep=cells )
                self.log_row_description( box, num, len(cells), row, removed )
                continue
            col = self.in_same_col( cells )
            if col is not None :
                removed = self.remove_other_candidates( number=num, remove=col, keep=cells )
                self.log_col_description( box, num, len(cells), col, removed )

    def in_same_row(self,cells):
        row = cells[0].row
        for cell in cells:
            if cell.row != row:
                return None
        return row

    def in_same_col(self,cells):
        col = cells[0].col
        for cell in cells:
            if cell.col != col:
                return None
        return col

    def remove_other_candidates(self, number, remove, keep ):
        removed = []
        for cell in remove:
            if cell.value is None  and  cell not in keep:
                if number in cell.candidates :
                    cell.candidates.remove( number )
                    removed.append( cell )
                    self.changed = True
        return removed

    def log_row_description(self, box, num, times, row, removed ):
        if( self.verbose and len(removed) > 0 ):
            self.log( ( "Candidate %d appears %d times in %s, but only in %s. " +
                        "Removed it from column(s) %s in the row." ) %
                      ( num, times, box.description(), row.description(),
                        ",".join([str(cell.col_idx+1) for cell in removed]) ) )

    def log_col_description(self, box, num, times, col, removed ):
        if( self.verbose and len(removed) > 0 ):
            self.log( ( "Candidate %d appears %d times in %s, but only in %s. " +
                        "Removed it from rows(s) %s in the column." ) %
                      ( num, times, box.description(), col.description(),
                        ",".join([str(cell.row_idx+1) for cell in removed]) ) )

    #
    # Rule: Find two numbers that appear in exactly two cells within a row,
    # column, or a box. No other numbers have space to share these two cells.
    # Eliminate them (the "stragglers").
    # Continue finding tree numbers in three cells, and so on.
    #
    def full_multi_cells(self):
        for n in range(2,10):
            for group in self.groups:
                (cells, numbers) = self.find_n_numbers_in_n_cells( group, n )
                if len(cells) > 0:
                    removed = self.remove_stragglers( numbers=numbers, cells=cells )
                    self.log_multi_cells_description(n,numbers,cells,removed,group)
                    if self.changed:
                        return  [] # Stop searching to re-run easier rules
        return []

    # Remove all candidates except the ones listed in 'numbers'
    def remove_stragglers(self, numbers, cells ):
        removed = {} # hash of cell:[array of removed candidates from the cell]
        for cell in cells:
            for candidate in cell.candidates.copy():
                if candidate not in numbers:
                    if cell not in removed:
                        removed[cell] = []
                    removed[cell].append(candidate)
                    cell.candidates.remove( candidate )
                    self.changed = True
        return removed

    def find_n_numbers_in_n_cells(self, group, n):
        map = self.candidates_to_cells_map(group)
        numbers_appearing_n_times = filter( lambda i : len(map[i]) == n, range(1,10) )
        if len( numbers_appearing_n_times ) < n :
            return ( [], [] )
        if (     len( numbers_appearing_n_times ) == n
             and self.all_cell_sets_equal( numbers_appearing_n_times, map ) ):
            return ( map[numbers_appearing_n_times[0]],
                     numbers_appearing_n_times )
        #
        # If there are four numbers that appear in two cells each,
        # We will check all possible pairs
        for subset_of_n in self.all_subsets( set(numbers_appearing_n_times), n ):
            if self.all_cell_sets_equal( subset_of_n, map ):
                for i in subset_of_n:   # Just get a random number from subset_of_n
                    return (map[i],subset_of_n)
        return ( [], [] )

    def all_subsets(self,a_set,n):    # Yields all subsets of length n
        if len(a_set) < n:
            return
        if len(a_set) ==n:
            yield a_set
            return
        if n == 1:
            for elt in a_set:
                yield set([elt])
        else:
            c = a_set.copy()
            while len(c) > 0 :
                elt = c.pop()
                for s in self.all_subsets( c, n-1 ):
                    yield s.union([elt])
            
    def all_cell_sets_equal(self, numbers, numbers_to_cells_map):
        set0 = None
        for i in numbers:
            if set0 is None:
                set0 = set(numbers_to_cells_map[i])
            elif set0 != set(numbers_to_cells_map[i]):
                return False
        return True

    def log_multi_cells_description(self,n,numbers,cells,removed,group):
        if( len(removed)>0  and  self.verbose ):
            self.log(self.full_multi_cells_description(n,numbers,cells,removed,group))

    def full_multi_cells_description(self,n,numbers,cells,removed,group):
        numbers_text = ",".join( str(number) for number in numbers )
        cells_text = ",".join( cell.description() for cell in cells )
        removed_text=", ".join([ "%s from %s" %
                                ( ",".join([str(num) for num in removed[cell]]) ,
                                  cell.description())
                                        for cell in removed ])
        return(("In %s, %d candidates, %s, appear in %d cells, %s. "+
                "Removed stragglers %s.") %
               (group.description(),n,numbers_text, n, cells_text, removed_text) )

    #
    #     Utilities
    #

    #
    # Fill a number into a cell, then remove that number from possible
    # candidates in the same row, column, and box
    #
    def solve_cell(self, cell, value):
        self.log("solve %s = %d" % (cell.description(),value))
        self.changed = True
        cell.value = value
        self.remove_candidates_for_cell( cell )

    def remove_candidates(self):
        fill_ins = []
        for row in self.rows:
            for cell in row:
                if cell.value is None : continue
                fill_ins.extend( self.remove_candidates_for_cell(cell) )
        return fill_ins

    def remove_candidates_for_cell(self, cell):
        self.remove_candidates_in_group( cell.value, cell.row )
        self.remove_candidates_in_group( cell.value, cell.col )
        self.remove_candidates_in_group( cell.value, cell.box )
        return []

    def remove_candidates_in_group( self, value, group ):
        for cell in group:
            if cell.value is not None : continue
            if value in cell.candidates:
                cell.candidates.remove(value)
                self.changed = True

    def __str__(self):
        return "\n".join([ self.row_to_string(row) + ("","\n")[row[0].row_idx%3==2] for row in self.rows ])

    def row_to_string(self,row):
        return "".join([ cell.__str__() + (""," ")[cell.col_idx%3==2] for cell in row ])


    def detailed_description(self):
        box_line  = " +" + ("-"*17 + "+")*3
        cell_line = " |" + (" "*17 + "|")*3
        lines = []

        for box_x in range(3):
            for cell_x in range(3*box_x, 3*box_x+3):
                lines.append((cell_line, box_line)[cell_x % 3 == 0])
                for x in range(3):
                    line = ""
                    for box_y in range(3):
                        for cell_y in range(3*box_y, 3*box_y+3):
                            line+=("   "," | ")[cell_y%3 == 0]
                            for y in range(3):
                                cell = self.rows[cell_x][cell_y]
                                if cell.value is not None:
                                    line+= (" ",str(cell.value))[x == 1 and y == 1]
                                else:
                                    candidate = 3*x+y+1
                                    line+=(".",str(candidate))[candidate in cell.candidates]
                    line +=" |"
                    lines.append(line)
        lines.append(box_line)
        return "\n".join(lines)
                

    def init_from_string(self, str):
        row = 0
        col = 0
        for char in str:
            if( char == "." ):
                col+=1
                continue
            if( char >= "0" and char <= "9" ):
                if col >= 9:
                    raise Exception("Expected 9 elements in row %d, but got %d" % (row+1,col+1))
                number = ord(char)-ord("0")
                self.solve_cell( self.rows[row][col], number )
                col+=1
                continue
            if( char == "\n" ):
                if col == 0:    # Ignore empty lines
                    continue
                if col != 9:
                    raise Exception("Expected 9 elements in row %d, but got %d" % (row+1,col+1))
                col=0
                row+=1
                if( row == 9 ):
                    break
        if( row != 9 or col != 0 ):
            raise Exception("Incomplete board: %d rows, %d columns" % (row+1,col+1))

    def log(self, str):
        if self.verbose:
            print str

#------------------------------------------------------------------------------
#
#    Unit Tests
#
#------------------------------------------------------------------------------

class TestBoard(unittest.TestCase):
    """Unit tests for Board's helpers and deduction rules."""

    def setUp(self):
        self.board = Board()

    def _subsets(self, a_set, n):
        # Normalise all_subsets() output to a sorted list of sorted lists so
        # the assertions do not depend on set iteration order.
        return sorted(sorted(s) for s in self.board.all_subsets(a_set, n))

    def test_remove_candidates_in_group(self):
        group = [Cell(0, 1), Cell(0, 2), Cell(0, 3)]
        self.assertEqual(group[0].candidates, set(range(1, 10)))
        self.assertEqual(group[1].candidates, set(range(1, 10)))
        self.assertEqual(group[2].candidates, set(range(1, 10)))

        self.board.remove_candidates_in_group(3, group)

        self.assertEqual(group[0].candidates, set([1, 2,  4, 5, 6, 7, 8, 9]))
        self.assertEqual(group[1].candidates, set([1, 2,  4, 5, 6, 7, 8, 9]))
        self.assertEqual(group[2].candidates, set([1, 2,  4, 5, 6, 7, 8, 9]))

        # Removing an already-removed candidate is a no-op.
        self.board.remove_candidates_in_group(3, group)
        self.assertEqual(group[0].candidates, set([1, 2,  4, 5, 6, 7, 8, 9]))

        self.board.remove_candidates_in_group(5, group)
        self.assertEqual(group[0].candidates, set([1, 2,  4,  6, 7, 8, 9]))

    def test_cell_has_1_candidate(self):
        cell = self.board.rows[4][7]
        cell.candidates = set([6])
        fill_ins = self.board.cell_has_1_candidate()
        self.assertEqual([(f.cell, f.number) for f in fill_ins], [(cell, 6)])

    def test_candidate_appears_once_in_group(self):
        group = [Cell(0, 1), Cell(0, 2), Cell(0, 3), Cell(0, 4)]
        group[3].value = 7
        group[0].candidates &= set([1,  3, 4])
        group[1].candidates &= set([  2, 3  ])  # <- 2 is the only value that appears only once
        group[2].candidates &= set([1,  3, 4])
        group[3].candidates &= set([1, 2, 3, 4])  # <- this does not count because value is set

        fill_ins = self.board.candidate_appears_once_in_group(group)

        self.assertEqual(len(fill_ins), 1)
        fill_in = fill_ins[0]
        self.assertEqual((fill_in.cell, fill_in.number), (group[1], 2))

    def test_in_same_row(self):
        rows = self.board.rows
        row = self.board.in_same_row([rows[3][1], rows[3][5], rows[3][8]])
        self.assertIs(row, rows[3])
        row = self.board.in_same_row([rows[3][1], rows[6][5], rows[3][8]])
        self.assertIsNone(row)
        row = self.board.in_same_row([rows[4][1]])
        self.assertIs(row, rows[4])

    def test_all_cell_sets_equal(self):
        rows = self.board.rows
        c1, c2, c3 = rows[3][1], rows[7][5], rows[3][8]
        numbers = [1, 2, 3]
        numbers_to_cells_map = {1: [c1, c2, c3], 2: [c2, c3, c1], 3: [c3, c1, c2]}
        self.assertTrue(self.board.all_cell_sets_equal(numbers, numbers_to_cells_map))

        numbers_to_cells_map = {1: [c1, c2, c3], 2: [c2, c3, c1], 3: [c3, c2]}
        self.assertFalse(self.board.all_cell_sets_equal(numbers, numbers_to_cells_map))

    def test_candidates_to_cells_map(self):
        cell1 = Cell(0, 1); cell1.candidates = [1]
        cell2 = Cell(0, 2); cell2.candidates = [5, 6, 7]
        cell3 = Cell(0, 3); cell3.candidates = [1, 2, 3]
        cell4 = Cell(0, 4); cell4.candidates = [2, 4]
        cell5 = Cell(0, 5); cell5.candidates = [1, 2, 5, 6, 7]
        cell6 = Cell(0, 6); cell6.candidates = [7]
        group = [cell1, cell2, cell3, cell4, cell5, cell6]

        expected_map = [[],
                        [cell1, cell3, cell5],  # 1
                        [cell3, cell4, cell5],  # 2
                        [cell3],                # 3
                        [cell4],                # 4
                        [cell2, cell5],         # 5
                        [cell2, cell5],         # 6
                        [cell2, cell5, cell6],  # 7
                        [],                     # 8
                        [],                     # 9
                        ]

        actual_map = self.board.candidates_to_cells_map(group)
        self.assertEqual(expected_map, actual_map)

    def test_find_n_numbers_in_n_cells_1(self):
        cell1 = Cell(0, 1); cell1.candidates = [1]
        cell2 = Cell(0, 2); cell2.candidates = [5, 6, 7]        # <-- 5 and 6
        cell3 = Cell(0, 3); cell3.candidates = [1, 2, 3]
        cell4 = Cell(0, 4); cell4.candidates = [2, 4]
        cell5 = Cell(0, 5); cell5.candidates = [1, 2, 5, 6, 7]  # <-- 5 and 6
        cell6 = Cell(0, 6); cell6.candidates = [7]
        group = [cell1, cell2, cell3, cell4, cell5, cell6]

        cells, numbers = self.board.find_n_numbers_in_n_cells(group, 2)

        self.assertEqual(set(numbers), set([5, 6]))
        self.assertEqual(set(cells), set([cell2, cell5]))

    def test_find_n_numbers_in_n_cells_2(self):
        cell1 = Cell(0, 1); cell1.candidates = [1]              # 1 appears twice
        cell2 = Cell(0, 2); cell2.candidates = [5, 6, 7]        # <-- 5 and 6
        cell3 = Cell(0, 3); cell3.candidates = [2, 3]
        cell4 = Cell(0, 4); cell4.candidates = [2, 4]
        cell5 = Cell(0, 5); cell5.candidates = [1, 2, 5, 6, 7]  # <-- 5 and 6
        cell6 = Cell(0, 6); cell6.candidates = [7]
        group = [cell1, cell2, cell3, cell4, cell5, cell6]

        cells, numbers = self.board.find_n_numbers_in_n_cells(group, 2)

        self.assertEqual(set(numbers), set([5, 6]))
        self.assertEqual(set(cells), set([cell2, cell5]))

    def test_find_n_numbers_in_n_cells_3(self):
        cell1 = Cell(0, 1); cell1.candidates = [1]              # 1 appears twice
        cell2 = Cell(0, 2); cell2.candidates = [5, 7]           # 5 appears twice
        cell3 = Cell(0, 3); cell3.candidates = [2, 3]
        cell4 = Cell(0, 4); cell4.candidates = [2, 4]
        cell5 = Cell(0, 5); cell5.candidates = [1, 2, 5, 7]     # 5 appears twice
        cell6 = Cell(0, 6); cell6.candidates = [7]
        group = [cell1, cell2, cell3, cell4, cell5, cell6]

        cells, numbers = self.board.find_n_numbers_in_n_cells(group, 2)

        self.assertEqual(numbers, [])
        self.assertEqual(cells, [])

    def test_all_subsets(self):
        self.assertEqual(self._subsets(set([99]), 1), [[99]])
        self.assertEqual(self._subsets(set([1, 2, 3]), 1), [[1], [2], [3]])
        self.assertEqual(self._subsets(set([1, 2, 3]), 3), [[1, 2, 3]])
        self.assertEqual(self._subsets(set([1, 2, 3]), 2), [[1, 2], [1, 3], [2, 3]])
        self.assertEqual(self._subsets(set([1, 2, 3, 4]), 2),
                         [[1, 2], [1, 3], [1, 4], [2, 3], [2, 4], [3, 4]])
        self.assertEqual(self._subsets(set([1]), 2), [])

    def test_remove_stragglers(self):
        cell1 = Cell(0, 1); cell1.candidates = set([1])
        cell2 = Cell(0, 2); cell2.candidates = set([5, 6, 7])        # <-- 5 and 6
        cell3 = Cell(0, 3); cell3.candidates = set([2, 3])
        cell4 = Cell(0, 4); cell4.candidates = set([2, 4])
        cell5 = Cell(0, 5); cell5.candidates = set([1, 2, 5, 6, 7])  # <-- 5 and 6
        cell6 = Cell(0, 6); cell6.candidates = set([7])
        cells = [cell2, cell5]
        numbers = [5, 6]
        self.board.remove_stragglers(numbers=numbers, cells=cells)
        self.assertEqual(cell1.candidates, set([1]))
        self.assertEqual(cell2.candidates, set([5, 6]))           # <--- 7 removed
        self.assertEqual(cell3.candidates, set([2, 3]))
        self.assertEqual(cell4.candidates, set([2, 4]))
        self.assertEqual(cell5.candidates, set([5, 6]))           # <--- 1,2,7 removed
        self.assertEqual(cell6.candidates, set([7]))

    def test_remove_other_candidates(self):
        cell1 = Cell(0, 1); cell1.candidates = set([1])
        cell2 = Cell(0, 2); cell2.candidates = set([5, 6, 7])
        cell3 = Cell(0, 3); cell3.candidates = set([2, 3])         # 2, keep this
        cell4 = Cell(0, 4); cell4.candidates = set([2, 4])         # 2, keep this
        cell5 = Cell(0, 5); cell5.candidates = set([1, 2, 5, 6, 7])  # 2, remove this
        cell6 = Cell(0, 6); cell6.candidates = set([7])
        group = [cell1, cell2, cell3, cell4, cell5, cell6]
        cells = [cell3, cell4]
        self.board.remove_other_candidates(number=2, remove=group, keep=cells)
        self.assertEqual(cell1.candidates, set([1]))
        self.assertEqual(cell2.candidates, set([5, 6, 7]))
        self.assertEqual(cell3.candidates, set([2, 3]))           # kept 2
        self.assertEqual(cell4.candidates, set([2, 4]))           # kept 2
        self.assertEqual(cell5.candidates, set([1, 5, 6, 7]))     # <--- 2 removed
        self.assertEqual(cell6.candidates, set([7]))

    def test_full_multi_cells_description(self):
        n = 2
        numbers = [2, 5]
        cell1 = self.board.rows[2][4]
        cell2 = self.board.rows[2][5]
        cells = [cell1, cell2]
        removed = {cell1: [1, 3, 4], cell2: [7]}
        group = self.board.boxes[1]
        result = self.board.full_multi_cells_description(n, numbers, cells, removed, group)
        expected = ("In box(1,2), 2 candidates, 2,5, appear in 2 cells, cell(3,5),cell(3,6). " +
                    "Removed stragglers 1,3,4 from cell(3,5), 7 from cell(3,6).")
        self.assertEqual(expected, result)

#
#
#
def solveSudoku( puzzle, verbose=True ):
    """Parse *puzzle*, print it, solve it as far as the rules allow, print
    the result and return the Board.

    verbose -- when true (the default, as before) every deduction step is
    logged too; pass False to print only the puzzle and the result.
    """
    board = Board()
    board.verbose = verbose
    board.init_from_string( puzzle )
    # Single-argument print(...) produces the same output on Python 2 and 3.
    print("")
    print("Puzzle is:\n%s" % board)

    board.solve()

    print("Solution is:\n%s" % board)
    return board



if __name__ == "__main__":
    import sys

    # Sample puzzles. A digit is a given and "." an empty cell; spaces and
    # blank lines are skipped by Board.init_from_string.
    puzzle = """
    .5...12.9
    .6.29....
    4...6....
    8.....4.1
    .73......
    ..4......
    ......6.4
    ..5......
    7...8...2
    """
    hard = """
    ..9 6.. ..8
    ... ..7 .53
    3.2 8.5 ..9
    .2. 3.. 5..
    9.. ... ..2
    ..8 ..2 .4.
    8.. 2.6 9.4
    65. 1.. ...
    2.. ..3 6..
    """


    easy = """
    .1. 5.. .47
    ... ... 98.
    ... 2.6 ..3
    
    .38 .2. 6.1
    2.. .5. ..9
    9.1 .3. 47.
    
    1.. 8.2 ...
    .97 ... ...
    32. ..5 .6.
    """

    hard2 = """
    .9. .26 ...
    183 9.. ...
    7.2 1.. 4..

    ... ..5 ..4
    .18 ... 65.
    2.. 8.. ...

    ..1 ..3 7.2
    ... ..9 361
    ... 27. .4.
    """
    # hard2="_9__26___+1839____6+7621_84__+_7___5__4+_18__265_+2__8_____+__1__37_2+_27__9361+_3_271_4_"

    puzzles = {"puzzle": puzzle, "hard": hard, "easy": easy, "hard2": hard2}

    # Usage:  <script> [test | easy | hard | hard2 | puzzle]
    # "test" runs the unit tests; a puzzle name solves that puzzle.  With no
    # argument the "hard2" puzzle is solved, as before.
    choice = sys.argv[1] if len(sys.argv) > 1 else "hard2"
    run_unit_tests = (choice == "test")

    if run_unit_tests:
        # Keep our own argument away from unittest's command-line parser.
        unittest.main(argv=sys.argv[:1])
    elif choice in puzzles:
        solveSudoku(puzzles[choice])
    else:
        sys.exit("Unknown puzzle %r; expected 'test' or one of: %s"
                 % (choice, ", ".join(sorted(puzzles))))