from numpy import *
import ImageFilter
import numpy
import Image

#-- replace each element of a 2D numpy array with the median of its 
#   neighbors using the PIL median filter
def medfilt(img, smooth_size=5):
    """Median-smooth a numpy array by round-tripping it through PIL.

    img         -- numpy array to smooth; it is cast to uint8 before
                   filtering, so fractional parts and values outside
                   0..255 are not preserved.
    smooth_size -- window size handed to ImageFilter.MedianFilter.

    Returns the filtered image as a float32 numpy array.
    """
    # PIL works on 8-bit image data, hence the cast before the conversion.
    as_bytes = img.astype('uint8')
    pil_img = Image.fromarray(as_bytes)

    median_kernel = ImageFilter.MedianFilter(smooth_size)
    smoothed = pil_img.filter(median_kernel)

    # Hand the result back as a floating-point numpy array.
    return numpy.asarray(smoothed).astype('float32')


#-- dilates a 2D numpy array holding a binary image
def dilate(a, win=1):
    """Dilate a 2D binary image with a disk-shaped neighborhood.

    Every element of ``a`` that is greater than zero switches on all
    pixels whose Euclidean distance to it is at most ``win``.

    a   -- 2D numpy array; elements > 0 are treated as foreground.
    win -- non-negative integer radius of the disk, in pixels.

    Returns a new array of zeros and ones (numpy's default float dtype)
    with the same shape as ``a``; the input is not modified.
    """
    r = numpy.zeros(a.shape)
    yy, xx = numpy.where(a > 0)

    # Offsets of the (2*win+1) x (2*win+1) window, reduced to the ones that
    # fall inside the disk of radius ``win``.  Comparing squared integer
    # distances avoids a floating-point sqrt.
    span = numpy.arange(-win, win + 1)
    dist_sq = span[:, None] ** 2 + span[None, :] ** 2
    row_sel, col_sel = numpy.where(dist_sq <= win ** 2)
    y_off = span[row_sel]
    x_off = span[col_sel]

    # Pair every foreground pixel with every offset through broadcasting
    # (result shape: n_pixels x n_offsets).  This replaces the former
    # tile()/flatten(1) bookkeeping; flatten(1) relied on an integer
    # ``order`` argument that current NumPy rejects with a TypeError.
    ny = yy[:, None] + y_off[None, :]
    nx = xx[:, None] + x_off[None, :]

    # Bounds checking: clamp neighbors that fall outside the image.
    # Clamping only moves a neighbor closer to its source pixel, so the
    # clamped position is still inside that pixel's disk.
    ny = numpy.clip(ny, 0, a.shape[0] - 1)
    nx = numpy.clip(nx, 0, a.shape[1] - 1)

    r[ny, nx] = 1
    return r
