#! /Library/Frameworks/Python.framework/Versions/Current/bin/python

"""
Names:  James G. Malcolm
        Shawn M. Lankton
        Jesus H. Christ

Project Number: 2

Notes:

  To view intermediate representations (labelmaps, clusters, etc.)
  set the following variable to True
"""

# When True, Main additionally stores each frame's connected-component label
# map (labels scaled by 75 and converted to RGB) in the
# 'INTERMEDIATE_labelmap' sequence.
INTERMEDIATE_SEQUENCES = False

"""
Summary of Algorithm:

  :First Frame:
    Detect Skyline & Ignore it
      - Find row in image where most bright pixels are above
      - Pixels above this row will be ignored from now on
    Detect Hot Pixels
      - Threshold to find very-bright pixels
      - Threshold to find nearby less-bright pixels
      - Hot pixels are less-bright pixels near very-bright pixels
    Find Clusters
      - Count hot pixels in local windows
      - Clusters are connected components with high density of hot pixels
    Initialize Planes
      - Initialize one plane per detected cluster
      - Each plane's initial position and bounding box match its cluster

  :For Each Subsequent Frame:
    Detect Hot Pixels (as before)
    Find Clusters (as before)
    Associate Planes with Clusters
      - Find distance from each predicted centroid to each cluster centroid
      - Associate closest nearby cluster to the plane
      - Planes with no corresponding clusters
         - are propagated based on velocity information only
         - have their "confidence" decremented (planes with 0 confidence are killed)
      - Planes that share a cluster
         - are propagated based on velocity information only
         - have their "confidence" incremented
      - Planes that are associated to unique clusters
         - update velocity and window size estimates with a recursive average
         - are propagated based on new velocity estimate
         - have their "confidence" incremented

Insights:
  Problem: Sky and Plane have similar intensities
  Solution: We were able to account for the sky by finding the majority 
            of bright pixels and ignoring them.  This could cause trouble
            if a plane soars too high (like Icarus, who flew too close to
            the sun and then drowned in the Aegean Sea).

  Problem: High intensity pixels are scattered throughout the image
  Solution: Planes should be a large cluster of high intensity pixels.  Filter to
            suppress isolated pixels of high intensity, and then look for remaining
            blobs.

  Problem: Planes sometimes occlude each other in flight.
  Solution: We rely on estimated dynamics to continue tracking until the planes
            separate and can be uniquely identified again.  If the planes alter their
            trajectory while they occlude each other, it can be difficult to maintain 
            track.  In some cases, tracks will "switch."  This is most likely in
            sequences like #9 where the occluding planes take on the motion 
            characteristics of the other during the occlusion.

  Problem: Flight dynamics were not well modeled as constant velocity
  Solution: Added "wind friction" and "gravity" to account for planes slowing and
            dropping.
 
  Problem: Some spurious objects are detected as planes (bright clothes, body parts, etc.)
  Solution: Maintain a "confidence" measure for each plane based on how many times we've
            seen it.  If this drops too low, stop tracking the object.  The non-plane
            objects tend to be detected less reliably than real planes.  

  Problem: Because much of the algorithm is based on analysis of the first frame,
           there is a tendency to over-tune parameters.
  Solution: We ran tests with the sequences starting at frames 1, 2, 3, 4, 5, etc. to
            ensure that the algorithm was robust.

Interesting Facts:
  12 newborns will be given to the wrong parents daily.
  A cockroach can live several weeks with its head cut off.
  Hummingbirds typically weigh less than a nickel

Bibliography:
  We do all our own stunts.

"""

from FrameWork import *
from numpy import *
import ImageFilter
import numpy
import Image

# --- hot-pixel detection (see detect_white) ---------------------------------
h_thresh = 240            # gray level above which a pixel seeds a detection
l_thresh = 200            # lower gray level kept near a seed (hysteresis)
hysteresis_radius = 7     # radius, in pixels, grown around each seed

# --- clustering (see cluster_filter) ----------------------------------------
cluster_density = 0.08    # required hot fraction of a (2w+1)^2 window
cluster_window = 7        # w: radius of the clustering neighbourhood

# --- motion model (see filter_position / filter_window) ---------------------
gravity_factor = 0.2      # added to the vertical velocity every frame
friction_factor = 0.05    # fractional damping of horizontal velocity per frame

velocity_gamma = 0.5      # weight of the previous velocity in its running average
win_gamma = 0.4           # weight of the previous window size in its running average

# --- track management -------------------------------------------------------
conf_interval = 10        # confidence cap; a plane at confidence 0 stops being tracked

class Plane(object) :
    """One tracked object: predicted position, bounding box and filter state."""

    def __init__(self, pposn=None, bbox=None, state=None) :
        self.state = state   # State instance carrying the recursive filter
        self.bbox = bbox     # (x0, y0, x1, y1) bounding box
        self.pposn = pposn   # predicted (x, y) position for the next frame
class State(object) :
    """Recursive-filter state for one plane."""

    def __init__(self, pos=None, win=None) :
        self.pos = pos     # current (x, y) position estimate
        self.win = win     # (xwin, ywin) half-size of the bounding box
        self.vel = (0,0)   # (dx, dy) velocity estimate, per frame
        # Start half-way to the confidence cap.  Floor division keeps the
        # value an integer under Python 3 exactly as '/' did under Python 2.
        self.conf = conf_interval // 2

def _find_clusters(img, row) :
    """Run the detection pipeline on one RGB float frame.

    Blacks out rows [0, row), converts to gray, thresholds for hot pixels,
    keeps dense clusters and labels them.  Returns connected_components'
    (cc, lmap) pair.  img itself is not modified.
    """
    gimg = mean(skyline_eliminate(img, row), axis=2)
    m = detect_white(gimg, l_thresh, h_thresh, hysteresis_radius)
    m = cluster_filter(m, cluster_window, cluster_density)
    return connected_components(m)

def _plane_from_cluster(c) :
    """Create a Plane whose position and box match cluster c = ((x, y), (xwin, ywin))."""
    (x, y), (xwin, ywin) = c
    b = (x - xwin, y - ywin, x + xwin, y + ywin)
    return Plane((x, y), b, State((x, y), (xwin, ywin)))

def Main(SeqName='20061129-01', start=1850, end=1859) :
    """Track planes through frames start..end (inclusive) of sequence SeqName.

    The first frame fixes the skyline cut-off for the whole run and seeds
    one Plane per detected cluster.  Every frame in [start, end] is then
    processed: clusters are re-detected, matched to the existing planes,
    plane states are filtered, and the frame is marked with the tracks and
    stored in the 'result' output sequence.
    """
    MyTrial = Trial(Update=False)
    # print_parameters(MyTrial) # Extra outputs for debugging
    sin = Seq(MyTrial, SeqName)
    sout = AirplaneSeq(MyTrial, 'result')

    # interesting intermediate outputs
    if INTERMEDIATE_SEQUENCES :
        labelmap = Seq(MyTrial,'INTERMEDIATE_labelmap')

    # determine skyline and number of planes from first frame
    img = numpy.asarray(sin.Load_Frame(start)).astype('float32')
    row = skyline_row(img)
    cc, lmap = _find_clusters(img, row)
    ps = [_plane_from_cluster(c) for c in cc]   # one Plane per cluster

    for fn in range(start, end+1) :
        # print 'on %(current)d of %(total)d.' %{"current": fn-start, "total": end-start}

        # find planes (the raw frame is kept intact for the output image)
        img = numpy.asarray(sin.Load_Frame(fn)).astype('float32')
        cc, lmap = _find_clusters(img, row)
        cc = correlate_components(cc, ps)
        cc = mark_duplicates(cc)
        ps = update_state(ps, cc)

        # output
        if INTERMEDIATE_SEQUENCES :
            labelmap.Store_Frame(fn,Image.fromarray((g2rgb(lmap*75)).astype('uint8')))
        bb,pp = get_track_data(ps)
        output = Image.fromarray(img.astype('uint8'))
        sout.Mark_Airplane_Positions(fn, output, bb, pp)
        sout.Store_Frame(fn,output)

    MyTrial.Done()

def print_parameters(t) :
    """Write every tuning parameter to trial t as 'name = value' lines."""
    settings = (
        ('h_thresh', h_thresh),
        ('l_thresh', l_thresh),
        ('hysteresis_radius', hysteresis_radius),
        ('cluster_density', cluster_density),
        ('cluster_window', cluster_window),
        ('gravity_factor', gravity_factor),
        ('friction_factor', friction_factor),
        ('velocity_gamma', velocity_gamma),
        ('win_gamma', win_gamma),
        ('conf_interval', conf_interval),
    )
    t.Print('')
    t.Print('===== PARAMETERS ===========================')
    for name, value in settings :
        t.Print(name + ' = ' + str(value))
    t.Print('============================================')
    

def mark_duplicates(cc) :
    """Tag each (centroid, window) pair with whether another entry equals it.

    Returns a new list of (centroid, window, is_duplicate) triples in the
    same order as cc.  An entry is a duplicate when some *other* position
    in cc holds an equal pair (e.g. two planes matched to the same cluster).
    """
    tagged = []
    for idx, pair in enumerate(cc) :
        shared = any(pair == other for j, other in enumerate(cc) if j != idx)
        centroid, win = pair
        tagged.append((centroid, win, shared))
    return tagged
    

def filter_window(s, win_cur, gamma=None) :
    """Blend an observed window size into the plane's running estimate.

    s       : State whose .win holds the previous (xwin, ywin) estimate.
    win_cur : (xwin, ywin) observed for the matched cluster this frame.
    gamma   : weight kept on the previous estimate; defaults to the module
              setting win_gamma.

    The blended window is written back to s.win, so successive calls form a
    true recursive average (previously s.win was never updated and every
    frame was blended with the initial size).  Returns (s, (xwin, ywin)).
    """
    if gamma is None :
        gamma = win_gamma
    xwin_cur, ywin_cur = win_cur
    xwin, ywin = s.win

    xwin = gamma*xwin + (1-gamma)*xwin_cur
    ywin = gamma*ywin + (1-gamma)*ywin_cur

    s.win = (xwin, ywin)
    return s, (xwin, ywin)

def filter_position(s, posn_cur, gamma=None, friction=None, gravity=None) :
    """Update position/velocity in s from an observed centroid and predict ahead.

    s        : State with .pos (x, y) and .vel (dx, dy); updated in place.
    posn_cur : (x, y) centroid observed this frame.
    gamma    : weight kept on the previous velocity (default velocity_gamma).
    friction : fractional damping of the horizontal velocity
               (default friction_factor).
    gravity  : amount added to the vertical velocity each frame
               (default gravity_factor).

    Returns (s, (x_next, y_next)), the position predicted for the next frame.
    """
    if gamma is None :
        gamma = velocity_gamma
    if friction is None :
        friction = friction_factor
    if gravity is None :
        gravity = gravity_factor

    x_cur, y_cur = posn_cur
    (x, y), (dx, dy) = s.pos, s.vel   # unpack state

    # recursive average of velocity; friction damps x, gravity biases y
    dx = (gamma*dx + (1-gamma)*(x_cur - x))*(1-friction)
    dy = gamma*dy + (1-gamma)*(y_cur - y) + gravity

    # advance the position estimate by the filtered velocity
    x = x + dx
    y = y + dy

    s.pos = (x, y)
    s.vel = (dx, dy)
    # predict location in next frame
    return s, (x + dx, y + dy)

def get_track_data(planes) :
    """Return ([bbox, ...], [predicted position, ...]) in the order of planes."""
    boxes = [plane.bbox for plane in planes]
    predictions = [plane.pposn for plane in planes]
    return boxes, predictions

def update_state(planes, cc) :
    """Advance every plane's filter state using its matched cluster.

    cc must be index-aligned with planes; each entry is a
    (centroid, window, is_duplicate) triple.  Planes whose confidence is 0
    are parked at (-1, -1) with a zero window and zero velocity.  The plane
    objects are updated in place and the same list is returned.
    """
    for idx, plane in enumerate(planes) :
        state = plane.state
        (mx, my), size, shared = cc[idx]

        # a cluster claimed by several planes is not trusted: fall back on
        # the plane's own prediction and its previous window size
        if shared :
            (mx, my) = plane.pposn
            size = state.win

        state, prediction = filter_position(state, (mx, my))
        state, (half_w, half_h) = filter_window(state, size)

        box = (mx - half_w, my - half_h, mx + half_w, my + half_h)

        if plane.state.conf == 0 :
            prediction = (-1, -1)
            box = (-1,-1,-1,-1)
            state.pos = (-1,-1)
            state.win = (0,0)
            state.vel = (0,0)

        plane.state = state
        plane.pposn = prediction
        plane.bbox = box
    return planes
        
def tuple_dist(a, b) :
    """Return the Euclidean distance between 2-D points a = (x, y) and b = (p, q)."""
    # explicit unpacking instead of tuple parameters, which Python 3 removed
    (x, y), (p, q) = a, b
    return sqrt((x-p)**2 + (y-q)**2)

def tuple_norm(x) :
    """Return the Euclidean length of the 2-D vector x."""
    return tuple_dist(x,(0,0))
                   

def correlate_components(cc, planes) :
    """Choose one component per plane and adjust each plane's confidence.

    Returns a list aligned with planes whose k-th entry is the
    (centroid, window) chosen for planes[k] by find_best_component.  A plane
    that found a match gains one confidence point (capped at conf_interval);
    one that did not loses a point (floored at 0).
    """
    chosen = []
    for plane in planes :
        best, found = find_best_component(plane, cc)
        chosen.append(best)
        conf = plane.state.conf
        plane.state.conf = min(conf + 1, conf_interval) if found else max(conf - 1, 0)
    return chosen

# For one plane, pick the component whose centroid is nearest the plane's
# predicted position, searching within twice the length of its window vector.
def find_best_component(p,cc) :
    """Return ((centroid, window), found) for the component nearest plane p.

    Only components strictly closer than 2 * |p.state.win| to p.pposn are
    considered; on ties the earlier component in cc wins.  When none
    qualifies, the plane's own (pposn, win) is returned with found=False.
    """
    px, py = p.pposn
    radius = tuple_norm(p.state.win)*2
    best = (p.pposn, p.state.win)   # default: keep the plane's own prediction
    found = False
    for (cx, cy), win in cc :
        d = sqrt((cx - px)**2 + (cy - py)**2)
        if d < radius :
            radius = d
            best = (cx, cy), win
            found = True
    return best, found

# get connected components, bounding boxes, and centroids
def connected_components(m) :
    """Label the 4-connected foreground regions of a 2-D mask.

    m : 2-D array; nonzero entries are foreground.  m is NOT modified.

    Returns (cc, lmap):
      cc   : one ((cx, cy), (xwin, ywin)) entry per component, in the order
             the components are first met in row-major order.  (cx, cy) is
             the floor of the mean pixel coordinate; xwin/ywin are
             (max - min)//2 - 1 of the component's x / y extent.
      lmap : float array of m's shape with labels 1, 2, ... on component
             pixels and 0 elsewhere.
    """
    m = m.copy()   # private copy: the flood fill below erases visited pixels
    height, width = m.shape
    lmap = zeros(m.shape)
    currentlabel = 1
    cc = list()
    [yy, xx] = where(m)

    for i in range(len(xx)) :
        x, y = xx[i], yy[i]
        if lmap[y, x] : continue

        ymin, ymax = height, 0
        xmin, xmax = width, 0
        area, cx, cy = 0, 0, 0

        ol = [(x, y)]   # open list of pixels still to visit (depth-first)
        while ol :
            x, y = ol.pop()
            m[y, x] = 0
            if lmap[y, x] : continue
            lmap[y, x] = currentlabel

            # extreme points for the bounding box
            if y < ymin : ymin = y
            if y > ymax : ymax = y
            if x < xmin : xmin = x
            if x > xmax : xmax = x

            # running sums for the centroid
            cx += x
            cy += y
            area += 1

            # queue unvisited 4-neighbours
            if x > 0 and m[y, x-1] : ol.append((x-1, y))
            if y > 0 and m[y-1, x] : ol.append((x, y-1))
            if x < width-1 and m[y, x+1] : ol.append((x+1, y))
            if y < height-1 and m[y+1, x] : ol.append((x, y+1))

        currentlabel += 1
        # '//' pins the integer results Python 2's '/' produced
        cc.append(((cx//area, cy//area), ((xmax-xmin)//2 - 1, (ymax-ymin)//2 - 1)))

    return cc, lmap

# thresholding with hysteresis on a grayscale image
def detect_white(gimg,tlow,thigh,hysteresis_win) :
    """Return a 0/1 float mask of bright pixels found with hysteresis.

    Pixels above thigh seed the mask; isolated seeds are removed with a
    3x3 median filter; the seeds are then dilated by hysteresis_win, and
    anything below tlow is cleared again.
    """
    seeds = (gimg > thigh).astype('float64')   # very bright pixels
    seeds = medfilt(seeds, 3)                  # drop isolated specks
    grown = dilate(seeds, hysteresis_win)      # reach out to nearby pixels...
    grown[gimg < tlow] = 0                     # ...but keep only fairly bright ones
    return grown

# finds the cut-off index used to black out the bright sky
def skyline_row(img) :
    """Return the skyline cut-off index for an (H, W, 3) frame.

    A pixel is "bright" when its squared RGB distance from gray level 100
    exceeds eps.  Bright pixels are counted per column and accumulated left
    to right; the first index where the running total exceeds 12% of all
    bright pixels is returned.

    NOTE(review): the accumulation runs over columns (axis 1), so the value
    is a column index, yet skyline_eliminate uses it as a row.  The 0.12
    fraction was presumably tuned against this behaviour, so it is left
    as-is -- confirm intent before switching the axis.
    """
    eps = 6.6e4
    bright = ((img - 100)**2).sum(axis=2) > eps
    running = bright.sum(axis=0).cumsum()
    return where(running > .12 * running[-1])[0][0]

# blacks out the sky portion of an image
def skyline_eliminate(img, row) :
    """Return a copy of the 3-D image img with rows [0, row) set to 0."""
    out = img.copy()        # never touch the caller's frame
    out[:row, :, :] = 0
    return out

# converts a grayscale image to RGB
def g2rgb(a):
    """Replicate a 2-D grayscale array into an (H, W, 3) float64 array."""
    height, width = a.shape[0], a.shape[1]
    rgb = zeros((height, width, 3))
    for channel in range(3):
        rgb[:, :, channel] = a
    return rgb

#dilates 2D binary image
def dilate(a,win):
    """Dilate a 2-D binary image with a disk of radius win.

    Every pixel within Euclidean distance win of a nonzero pixel of a is
    set to 1.  Returns a new float64 0/1 array of a's shape; a is not
    modified.  Any image size is supported: neighbours that fall outside
    the image are skipped (previously bounds were hard-coded to 640x480).
    """
    r = numpy.zeros(a.shape)
    height, width = a.shape
    yy, xx = numpy.where(a > 0)
    for dy in range(-win, win + 1) :
        for dx in range(-win, win + 1) :
            if dx*dx + dy*dy > win*win :
                continue   # keep the neighbourhood round
            ny = yy + dy
            nx = xx + dx
            inside = (ny >= 0) & (ny < height) & (nx >= 0) & (nx < width)
            r[ny[inside], nx[inside]] = 1
    return r

# keeps only pixels that lie in a dense cluster of hot pixels
def cluster_filter(a,win,thresh):
    """Return a 0/1 map of pixels surrounded by enough hot pixels.

    For every pixel, count the nonzero pixels of a within Euclidean distance
    win of it.  The pixel is kept (set to 1) when that count exceeds
    thresh * (2*win+1)**2, i.e. thresh is a density relative to the full
    (2*win+1) x (2*win+1) window.  Returns a new float64 array of a's shape;
    a is not modified.  Neighbours outside the image are ignored.
    """
    counts = numpy.zeros(a.shape)
    height, width = a.shape
    yy, xx = numpy.where(a > 0)
    for dy in range(-win, win + 1) :
        for dx in range(-win, win + 1) :
            if dx*dx + dy*dy > win*win :
                continue   # round neighbourhood
            ny = yy + dy
            nx = xx + dx
            inside = (ny >= 0) & (ny < height) & (nx >= 0) & (nx < width)
            # Accumulate one offset at a time: for a fixed offset every target
            # pixel appears at most once, so this fancy-indexed += really
            # counts.  (Repeated indices within ONE fancy assignment are only
            # applied once, which is why a single r[ny,nx]+1 cannot count.)
            counts[ny[inside], nx[inside]] += 1

    t = thresh*(2*win+1)**2
    r = numpy.zeros(a.shape)
    r[counts > t] = 1   # sub-threshold pixels stay 0
    return r
    
def medfilt(a,win_size):
    """Median-filter a 2-D array with PIL's MedianFilter of size win_size.

    The data are cast to uint8 before filtering and the result is returned
    as a float32 array.
    """
    as_bytes = a.astype('uint8')
    filtered = Image.fromarray(as_bytes).filter(ImageFilter.MedianFilter(win_size))
    return numpy.asarray(filtered).astype('float32')

# These were my calls to make the program run; uncomment the sequence wanted.
# The guard keeps the tracker from running when this file is merely imported.
if __name__ == '__main__' :
    # Main('20080217-01',1,100) # girlL,1plane
    # Main('20080217-07',1,100) # girlR,1plane
    # Main('20080217-08',1,100) # boyL,1plane
    # Main('20080217-09',1,100) # boygirlR,2planes
    Main('20080217-11',1,100) # boygirlLR,2planes





