from math import floor
from collections import namedtuple

# Axis-aligned rectangle given by two corners: (x1, y1) is the minimum corner
# and (x2, y2) the maximum corner, so SpatialHash expects x1 <= x2, y1 <= y2.
Rect = namedtuple('Rect', 'x1 y1 x2 y2')

class SpatialHash(object):
    """Uniform-grid spatial hash for axis-aligned rectangles.

    The plane is divided into square cells of side ``cell_size``. Each object
    is registered in every cell its bounding rectangle overlaps, so a query
    only has to inspect the cells overlapped by the query rectangle.

    Rectangles are any objects exposing ``x1``, ``y1``, ``x2``, ``y2`` where
    ``(x1, y1)`` is the minimum corner and ``(x2, y2)`` the maximum corner.

    Attributes:
        cell_size: side length of one grid cell, stored as a float.
        d: dict mapping integer ``(cx, cy)`` cell coordinates to the set of
            objects registered in that cell. Cells that become empty are
            deleted, so every stored set is non-empty.
    """

    def __init__(self, cell_size=10.0):
        """Create an empty spatial hash with square cells of side cell_size.

        Raises:
            ValueError: if ``cell_size`` is not a finite positive number.
                Cell indices are computed as ``floor(coord / cell_size)``,
                which is only meaningful for such a size.
        """
        cell_size = float(cell_size)
        # Written as a chained comparison so NaN (for which every comparison
        # is False) is rejected along with zero, negatives and infinity.
        if not (0.0 < cell_size < float('inf')):
            raise ValueError(
                'cell_size must be a finite positive number, got %r'
                % (cell_size,))
        self.cell_size = cell_size
        self.d = {}

    def _add(self, cell_coord, o):
        """Add the object o to the cell at cell_coord, creating the cell if needed."""
        # setdefault inserts the empty set on first use; it never raises KeyError.
        self.d.setdefault(cell_coord, set()).add(o)

    def _remove(self, cell_coord, o):
        """Remove the object o from the cell at cell_coord.

        Raises:
            KeyError: if the cell does not exist or o is not in it.
        """
        cell = self.d[cell_coord]
        cell.remove(o)

        # Delete the cell from the hash if it is empty, keeping d sparse.
        if not cell:
            del self.d[cell_coord]

    def _cells_for_rect(self, r):
        """Return the set of integer (cx, cy) cell coordinates that r overlaps.

        Both corners are mapped to cells with the same ``floor(coord / size)``
        rule and every cell between them (inclusive) is returned. A coordinate
        lying exactly on a cell boundary belongs to the cell above it.

        Raises:
            OverflowError: if a coordinate is infinite.
            ValueError: if a coordinate is NaN.
        """
        size = self.cell_size
        # int(floor(...)) works on both Python 2 (floor returns a float) and
        # Python 3 (floor returns an int), and fails loudly on inf/NaN rather
        # than looping forever as an open-ended while-loop would.
        x_first = int(floor(r.x1 / size))
        x_last = int(floor(r.x2 / size))
        y_first = int(floor(r.y1 / size))
        y_last = int(floor(r.y2 / size))
        return set((cx, cy)
                   for cy in range(y_first, y_last + 1)
                   for cx in range(x_first, x_last + 1))

    def add_rect(self, r, obj):
        """Register obj in every cell overlapped by its bounds r."""
        for c in self._cells_for_rect(r):
            self._add(c, obj)

    def remove_rect(self, r, obj):
        """Unregister obj, which must have been added with the same bounds r.

        Raises:
            KeyError: if obj is not present in one of the cells r overlaps.
        """
        for c in self._cells_for_rect(r):
            self._remove(c, obj)

    def potential_collisions(self, r, obj):
        """Return the set of objects sharing at least one cell with rect r.

        This is a broad-phase check: returned objects share a grid cell with
        r but do not necessarily intersect it. obj itself is excluded.
        """
        potentials = set()
        for c in self._cells_for_rect(r):
            # Missing cells contribute nothing; () avoids allocating a set.
            potentials.update(self.d.get(c, ()))
        potentials.discard(obj)  # obj cannot collide with itself
        return potentials


def test_cells_for_rect():
    """_cells_for_rect reports every default-size cell a rect overlaps."""
    sh = SpatialHash()

    # One column of cells (x 1..9), two rows (y 2..12).
    tall = sh._cells_for_rect(Rect(1, 2, 9, 12))
    assert tall == set([(0, 0), (0, 1)]), tall

    # Two columns of cells (x 7..13), one row (y 15..19).
    wide = sh._cells_for_rect(Rect(7, 15, 13, 19))
    assert wide == set([(0, 1), (1, 1)]), wide

def test_add():
    """A rect lying inside cell (0, 0) is stored in that cell."""
    spatial = SpatialHash()
    spatial.add_rect(Rect(x1=1, y1=2, x2=3, y2=4), 'foo')
    origin_cell = spatial.d[(0, 0)]
    assert 'foo' in origin_cell

def test_add_spanning():
    """A rect straddling x=0 and y=10 is stored in all four touched cells."""
    spatial = SpatialHash()
    spatial.add_rect(Rect(-1, 9, 2, 12), 'foo')
    for cell in [(0, 0), (0, 1), (-1, 0), (-1, 1)]:
        assert 'foo' in spatial.d[cell]

def test_remove():
    """Removing the only object from its cells deletes those cells."""
    spatial = SpatialHash()
    bounds = Rect(1, 2, 7, 12)
    spatial.add_rect(bounds, 'foo')
    spatial.remove_rect(bounds, 'foo')
    for cell in ((0, 0), (0, 1)):
        assert cell not in spatial.d

def test_collide():
    """An object sharing a cell with 'bar' is reported as a potential collision."""
    h = SpatialHash()
    # With the default cell size of 10, 'foo' occupies cells (0, 0) and (0, 1).
    h.add_rect(Rect(3, 8, 4, 11), 'foo')

    # 'bar' occupies cells (0, 1) and (1, 1), sharing (0, 1) with 'foo'.
    r = Rect(7, 15, 13, 19)
    h.add_rect(r, 'bar')
    # The leftover debug `print h.d` was removed: it is Python-2-only syntax
    # and only added noise to the test output.
    collisions = h.potential_collisions(r, 'bar')
    # 'bar' itself must be excluded from its own candidate set.
    assert collisions == set(['foo']), collisions

