from astrometry.util.defaults import *
from astrometry.util.pyfits_utils import *
from astrometry.util.file import *
from astrometry.util.starutil_numpy import *
import sanitycheckrix as scr
import os.path
from numpy import *
from astromemtry.util.mcmc import *
import numpy
from pylab import *
from astrometry.util.defaults import *

# Radius of the circular "dinner-plate" region passed to scr.make_6d_cut()
# around the candidate's (RA, Dec).  NOTE(review): presumably degrees, like
# the other angles in this file -- confirm against sanitycheckrix.make_6d_cut.
dinnerplate_r = 20.

# Workflow notes: shell commands used to produce this script's inputs and to
# archive its outputs.  This is a bare string expression with no runtime effect.
'''
Notes:

rm *.pickle
python sanitycheckrix.py brute > brute.log 2>&1 &
mv 1_4kpc_metal_poor_with_PM_July23_09.txt-brutish.pickle 1_4kpc_metal_poor_with_PM_July23_09.txt-brutish3.pickle
python sanitycheckrix.py
# produces cands.fits
fitscopy cands.fits"[RA<220 && RA>200 && DEC > 10 && DEC < 20]" greeny.fits


python fitstream.py
#-- ran 100 x 1000 steps: data in step1/*
cp param* stream.pickle step1/
#-- continued with another 1000 x 1000 steps: data in step2/*
cp stream.pickle param* radec.png radec2.png lnp.png distfeh.png pm.png step2/
# ~rev 13686

fitscopy cands.fits"[RA>215 && RA < 220 && DEC > 29 && DEC < 30]" bluey.fits

'''

class Stream():
    '''
    A candidate stellar stream: its model parameters plus the per-star
    foreground (stream) / background likelihood evaluation, which is
    delegated to the sanitycheckrix module.

    The MCMC sees the parameters as a flat list (get/set_mcmc_params):
        [logalpha, ra, dec, dist, feh, pmra, pmdec, width, angle]
    While in burn-in mode (the default) the trailing `angle` is held fixed
    and omitted from that list and from the step sizes.

    set_background() must be called before any likelihood is evaluated.
    '''
    def __init__(self, logalpha=None, ra=None, dec=None, dist=None, feh=None, pmra=None, pmdec=None, width=None, angle=0):
        '''
        logalpha: natural log of the stream (foreground) mixture fraction
                  alpha (dimensionless)
        (RA,Dec) in deg
        dist in kpc
        feh in dex
        (pmra, pmdec) in mas/yr?
        width in deg
        angle in deg
        '''
        self.logalpha = logalpha
        self.ra = ra
        self.dec = dec
        self.dist = dist
        self.feh = feh
        self.pmra = pmra
        self.pmdec = pmdec
        self.width = width
        self.angle = angle
        self.set_burnin(True)
        # Both are filled in by set_background().
        self.bg = None
        self.streamlength = None

    def set_burnin(self, yesno):
        '''Enable (True) or disable (False) burn-in mode, in which `angle` is not sampled.'''
        self.burnin = yesno

    def get_burnin(self):
        '''Return the current burn-in flag.'''
        return self.burnin

    def set_background(self, data, distrange, fehrange, pmrarange, pmdecrange, solidangle):
        '''
        Build the background model for `data` and cache its per-star
        log-probabilities in self.bg.  Also sets self.streamlength to
        2*sqrt(solidangle/pi), the diameter of a circle of area `solidangle`.
        '''
        (components, distfeh_amplitudes, pmmeans, pmvars) = (
            scr.build_component_model(data, distrange, fehrange))
        self.bg = scr.background_data_lnprob(data, components, solidangle, distfeh_amplitudes, pmmeans, pmvars, pmrarange, pmdecrange)
        self.streamlength = 2. * sqrt(solidangle/pi)

    def fg_lnlikelihood(self, data):
        '''
        Per-star log-probability of `data` under the stream (foreground)
        model at the current parameters.

        Raises RuntimeError if set_background() has not been called yet,
        since the stream length is derived there.
        '''
        if self.streamlength is None:
            raise RuntimeError('set_background() must be called before evaluating likelihoods')
        return scr.foreground_data_lnprob(data, self.ra, self.dec, self.dist,
                                          self.feh, self.pmra, self.pmdec,
                                          self.width, self.streamlength,
                                          angle_offset=deg2rad(self.angle))

    def lnlikelihoods(self, data):
        '''
        Per-star log of the mixture likelihood divided by the background
        likelihood:
            ln( alpha * fg / bg + (1 - alpha) )
        so a star with a positive value is better explained with the stream
        than by the background alone.
        '''
        # Evaluate the foreground first so a missing background fails with
        # the clear RuntimeError from fg_lnlikelihood().
        fg = self.fg_lnlikelihood(data)
        alpha = exp(self.logalpha)
        # log(alpha) is just logalpha; log1p(-alpha) is accurate for small alpha.
        return scr.logsum(self.logalpha + fg - self.bg,
                          log1p(-alpha))

    def lnlikelihood(self, data):
        '''Total log-likelihood ratio of `data` (sum over stars of lnlikelihoods).'''
        return sum(self.lnlikelihoods(data))

    def get_mcmc_params(self):
        '''Return the sampled parameters as a list (without `angle` during burn-in).'''
        p = [self.logalpha, self.ra, self.dec, self.dist, self.feh,
             self.pmra, self.pmdec, self.width, self.angle]
        if self.burnin:
            return p[:-1]
        return p

    def set_mcmc_params(self, p):
        '''Inverse of get_mcmc_params(): expects 8 values in burn-in, else 9.'''
        if self.burnin:
            (self.logalpha, self.ra, self.dec, self.dist, self.feh,
             self.pmra, self.pmdec, self.width) = p
        else:
            (self.logalpha, self.ra, self.dec, self.dist, self.feh,
             self.pmra, self.pmdec, self.width, self.angle) = p

    def get_mcmc_stepsizes(self):
        '''Proposal step sizes, aligned element-for-element with get_mcmc_params().'''
        s = [0.1, 0.1, 0.1, 0.02, 0.02, 0.1, 0.1, 0.01, 1.]
        if self.burnin:
            return s[:-1]
        else:
            return s

    def set_mcmc_stepsizes(self, s):
        '''Not supported: step sizes are fixed in get_mcmc_stepsizes().'''
        raise NotImplementedError('unimplemented')



class mcmcI:
    '''
    Empty attribute bag threaded through mcmc() into proposal(),
    lnlikelihood() and record(); the driver hangs the Stream, step sizes,
    counters and ln(p) traces on it.
    '''

def lnlikelihood(data, params, I):
    '''
    Log-likelihood of `data` at `params`: loads the parameters into
    I.stream, then returns that stream's total log-likelihood.
    '''
    model = I.stream
    model.set_mcmc_params(params)
    return model.lnlikelihood(data)

def lnprior(params, I):
    '''
    Log-prior of `params`: flat (improper), so the result is always 0
    whatever the parameters or the bookkeeping object I.
    '''
    flat = 0
    return flat

def lnposterior(data, params, I):
    '''
    Unnormalized log-posterior: ln(prior) + ln(likelihood) at `params`.
    This is the target density handed to mcmc().
    '''
    prior_term = lnprior(params, I)
    return prior_term + lnlikelihood(data, params, I)

def proposal(oldparams, I):
    '''
    Propose a new MCMC state by perturbing one randomly chosen parameter.

    Parameter k is drawn uniformly from range(len(oldparams)), and a
    Gaussian step of width I.stepsizes[k] is added to it.  k is stored in
    I.paramnum so that record() can attribute the try/accept to it.

    Returns (newparams, I).  oldparams itself is never modified.
    '''
    import copy
    I.paramnum = numpy.random.randint(len(oldparams))
    # copy.copy() keeps the container type: a list stays a list and an
    # ndarray stays an ndarray.  Unlike oldparams[:], it also gives a real
    # copy for numpy arrays, whose slices are views onto the caller's data.
    p = copy.copy(oldparams)
    p[I.paramnum] += I.stepsizes[I.paramnum] * numpy.random.normal()
    return (p, I)

def record(oldparams, oldlnp, newparams, lnp, randnum,
           accept, link, nlinks, I):
    verbose=False
    if verbose:
        print
        print 'oldp', oldparams
        print 'newp', newparams
        print 'oldlnp', oldlnp
        print 'newlnp', lnp
        print 'dlnp:', (lnp - oldlnp), 'logrand:', log(randnum), 'accept:', accept

    if lnp > I.bestlnp:
        I.bestparams = newparams
        I.bestlnp = lnp
    if link % 100 == 0 or link == nlinks-1:
        p = I.bestparams
        print 'link %i/%i' % (link,nlinks), 'best lnp %g' % I.bestlnp, 'best params', #['%.3g'%x for x in params]
        print 'alpha %.5g,' % exp(p[0]), '(%.5g, %.5g),'% (p[1],p[2]), 'dist %.5g,' % p[3], 'FeH %.5g,' % p[4], 'PM (%.5g,%.5g),'% (p[5],p[6]), 'width %.5g'%p[7],
        if len(p) == 9:
            print 'angle %.5g deg' % p[8]
        else:
            print
    # Record number of tries and accepts.
    # 'I.paramnum' is set by 'proposal()'
    I.tries[I.paramnum] += 1
    if lnp > oldlnp:
        I.autoaccepts[I.paramnum] += 1
    if accept:
        I.accepts[I.paramnum] += 1
        I.lnps.append(lnp)
    else:
        I.lnps.append(oldlnp)
    I.proplnps.append(lnp)
    return I

def plotstream(stream, data, bgparams, prefix):
    (components, radec_area, pmrarange, pmdecrange, distfeh_amplitudes, pmmeans, pmvars) = bgparams

    bg0 = scr.background_data_lnprob(data, components, radec_area, distfeh_amplitudes, pmmeans, pmvars, pmrarange, pmdecrange)
    fg0 = scr.foreground_data_lnprob(data, stream.ra, stream.dec, stream.dist, stream.feh, stream.pmra, stream.pmdec, stream.width, stream.streamlength, deg2rad(stream.angle))
    fgi = exp(fg0) / (exp(fg0) + exp(bg0))
    bgi = 1. - fgi

    print '%i stars in the dinner-plate.' % len(fgi)
    print '%i are definitely background' % sum(bgi == 1.)
    print '%i are definitely foreground' % sum(fgi == 1.)
    print '%i are >50pct foreground' % sum(fgi >= 0.5)

    # Omit RA,Dec from model.
    bg1 = scr.background_data_lnprob(data, components, None, distfeh_amplitudes, pmmeans, pmvars, pmrarange, pmdecrange)
    fg1 = scr.foreground_data_lnprob(data, None, None, stream.dist, stream.feh, stream.pmra, stream.pmdec, stream.width, stream.streamlength, deg2rad(stream.angle))
    # Omit dist,FeH from model.
    bg2 = scr.background_data_lnprob(data, components, radec_area, None, pmmeans, pmvars, pmrarange, pmdecrange)
    fg2 = scr.foreground_data_lnprob(data, stream.ra, stream.dec, None, None, stream.pmra, stream.pmdec, stream.width, stream.streamlength, deg2rad(stream.angle))
    # Omit pmra,pmdec from model.
    bg3 = scr.background_data_lnprob(data, components, radec_area, distfeh_amplitudes, None, None, None, None)
    fg3 = scr.foreground_data_lnprob(data, stream.ra, stream.dec, stream.dist, stream.feh, stream.pmra, stream.pmdec, stream.width, stream.streamlength, deg2rad(stream.angle), includepm=False)

    fgi0 = exp(fg0) / (exp(fg0) + exp(bg0))
    fgi1 = exp(fg1) / (exp(fg1) + exp(bg1))
    fgi2 = exp(fg2) / (exp(fg2) + exp(bg2))
    fgi3 = exp(fg3) / (exp(fg3) + exp(bg3))

    clf()
    subplot(2,2,1)
    hist(fgi0, 100, log=True)
    title('Whole model')
    subplot(2,2,2)
    hist(fgi1, 100, log=True)
    title('No RA,Dec')
    subplot(2,2,3)
    hist(fgi2, 100, log=True)
    title('No dist,FeH')
    subplot(2,2,4)
    hist(fgi3, 100, log=True)
    title('No PM')
    savefig(prefix + 'fghists.png')
    

    a = 0.2
    pargs = {'alpha':a, 'markeredgecolor':None, 'markeredgewidth':0}

    clf()
    fgi = fgi1
    bgi = 1. - fgi
    print
    print 'Without RA,Dec:'
    print '%i are definitely background' % sum(bgi == 1.)
    print '%i are definitely foreground' % sum(fgi == 1.)
    print '%i are >50pct foreground' % sum(fgi >= 0.5)
    B = (bgi >= 0.99)
    plot(data.ra[B], data.dec[B], 'b.', **pargs)
    F = (fgi >= 0.99)
    I = logical_and(logical_not(B), logical_not(F))
    for i in find(I):
        plot([data.ra[i]], [data.dec[i]], '.', color=(fgi[i], 0, bgi[i]), **pargs)
    plot(data.ra[F], data.dec[F], 'r.', **pargs)
    axis('equal')
    xlabel('RA (deg)')
    ylabel('Dec (deg)')
    savefig(prefix + 'radec.png')

    clf()
    fgi = fgi0
    bgi = 1. - fgi
    B = (bgi >= 0.99)
    plot(data.ra[B], data.dec[B], 'b.', **pargs)
    F = (fgi >= 0.99)
    I = logical_and(logical_not(B), logical_not(F))
    for i in find(I):
        plot([data.ra[i]], [data.dec[i]], '.', color=(fgi[i], 0, bgi[i]), **pargs)
    plot(data.ra[F], data.dec[F], 'r.', **pargs)
    axis('equal')
    xlabel('RA (deg)')
    ylabel('Dec (deg)')
    savefig(prefix + 'radec2.png')

    clf()
    fgi = fgi2
    bgi = 1. - fgi
    print
    print 'Without dist,FeH:'
    print '%i are definitely background' % sum(bgi == 1.)
    print '%i are definitely foreground' % sum(fgi == 1.)
    print '%i are >50pct foreground' % sum(fgi >= 0.5)
    B = (bgi >= 0.99)
    plot(data.dist[B], data.feh[B], 'b.', **pargs)
    F = (fgi >= 0.99)
    I = logical_and(logical_not(B), logical_not(F))
    for i in find(I):
        plot([data.dist[i]], [data.feh[i]], '.', color=(fgi[i], 0, bgi[i]), **pargs)
    plot(data.dist[F], data.feh[F], 'r.', **pargs)
    xlabel('Distance (kpc)')
    ylabel('FeH (dex)')
    #ylim(-2, -1.4)
    savefig(prefix + 'distfeh.png')

    clf()
    fgi = fgi3
    bgi = 1. - fgi
    print
    print 'Without pmRA,pmDec:'
    print '%i are definitely background' % sum(bgi == 1.)
    print '%i are definitely foreground' % sum(fgi == 1.)
    print '%i are >50pct foreground' % sum(fgi >= 0.5)
    B = (bgi >= 0.99)
    plot(data.pmra[B], data.pmdec[B], 'b.', **pargs)
    F = (fgi >= 0.99)
    I = logical_and(logical_not(B), logical_not(F))
    for i in find(I):
        plot([data.pmra[i]], [data.pmdec[i]], '.', color=(fgi[i], 0, bgi[i]), **pargs)
    plot(data.pmra[F], data.pmdec[F], 'r.', **pargs)
    axis('equal')
    axhline(0, color='b')
    axvline(0, color='b')
    xlabel('Proper motion RA')
    ylabel('Proper motion Dec')
    savefig(prefix + 'pm.png')


def print_targets(data, stream):
    lnl = stream.lnlikelihoods(y)
    I = argsort(-lnl)
    #fg = stream.fg_lnlikelihood(y)
    #bg = stream.bg
    #J = argsort(-(fg - bg))
    #print 'I == J?', all(I == J)
    print '%i stars are more likely FG than BG.' % sum(lnl > 0)
    targets = y[I[:100]]
    for i in range(len(targets)):
        t = targets[i]
        print i+1,
        print 'dLnl: %.3g,' % lnl[I[i]],
        print 'RA,Dec (%8.5g,% -8.5g),' % (t.ra,t.dec),
        print 'dist %-5.5g kpc,' % t.dist,
        print 'Fe/H %-6.5g dex,' % t.feh,
        print 'u %-7.5g, g %-7.5g, r %-7.5g' % (t.u, t.g, t.r)


if __name__ == '__main__':

    #prefix = 'greeny'
    prefix = 'bluey'

    cands = table_fields(prefix + '.fits')
    print 'Read %i candidate streams' % len(cands)

    candnum = 0
    cand = cands[candnum]

    pfn = prefix + '.pickle'
    if os.path.exists(pfn):
        print 'Reading from pickle', pfn
        (pcandnum, y, solidangle, distrange, fehrange, pmrarange, pmdecrange) = unpickle_from_file(pfn)
        assert(pcandnum == candnum)
         
    else:
        datafn = '1_4kpc_metal_poor_with_PM_July23_09.txt'
        (x,nil) = scr.read_data_file(datafn, '', False)
        print 'Read %i stars' % len(x)
        (y, solidangle, distrange, fehrange, pmrarange, pmdecrange) = (
            scr.make_6d_cut(cand.ra, cand.dec, x, dinnerplate_r))

        print 'Saving to pickle', pfn
        pickle_to_file((candnum, y, solidangle, distrange, fehrange, pmrarange, pmdecrange), pfn)
        

    (components, distfeh_amplitudes, pmmeans, pmvars) = (
        scr.build_component_model(y, distrange, fehrange))

    bg = scr.background_data_lnprob(y, components, solidangle, distfeh_amplitudes, pmmeans, pmvars, pmrarange, pmdecrange)

    streamw = 1. # deg
    streamangle = 0.
    logalpha = log(1e-2)

    prefix += '-'

    pfn = prefix + 'stream.pickle'
    if os.path.exists(pfn):
        print 'Reading stream parameters from pickle', pfn
        (pcandnum, burnin, params) = unpickle_from_file(pfn)
        assert(pcandnum == candnum)
    else:
        burnin = True
        params = (logalpha, cand.ra, cand.dec, cand.dist, cand.feh,
                  cand.pmra, cand.pmdec, streamw) #, streamangle)

    stream = Stream()
    stream.set_burnin(burnin)
    stream.set_mcmc_params(params)
    stream.set_background(y, distrange, fehrange, pmrarange, pmdecrange, solidangle)

    print_targets(y, stream)
    plotstream(stream, y, (components, solidangle, pmrarange, pmdecrange, distfeh_amplitudes, pmmeans, pmvars), prefix)

    I = mcmcI()
    I.stepsizes = stream.get_mcmc_stepsizes()
    I.stream = stream

    I.lnps = []
    I.proplnps = []
    I.bestlnp = -1e100

    bigchain = []

    for steps in range(1000):
        p = stream.get_mcmc_params()
        I.tries = zeros(len(p)).astype(int)
        I.accepts = zeros_like(I.tries)
        I.autoaccepts = zeros_like(I.tries)

        (bestparams, bestlnp, chain, step_info, naccept) = (
            mcmc(y, p, proposal, lnposterior, I, I, 1000, record=record, record_info=I,
                 keepchain=True, verbose=False))

        print 'Acceptance percentages:',
        print '  '.join(['%.1f (%i/%i)' % (100*a/t, a, t) for (a,t) in zip(I.accepts, I.tries.astype(float))])
        print 'Automatic acceptance percentages:',
        print '  '.join(['%.1f (%i/%i)' % (100*a/t, a, t) for (a,t) in zip(I.autoaccepts, I.tries.astype(float))])

        stream.set_mcmc_params(bestparams)

        clf()
        prop = array(I.proplnps)
        lnps = array(I.lnps)
        i = find(prop != lnps)
        plot(i, prop[i], 'b.')
        plot(lnps, 'r-')
        ylabel('ln(p)')
        xlabel('iteration')
        ylim(bestlnp-10, bestlnp+1)
        savefig(prefix + 'lnp.png')

        print 'Saving parameters to', pfn
        print 'Best params:', bestparams
        pickle_to_file((candnum, stream.get_burnin(), bestparams), pfn)

        stream.set_burnin(False)
        I.stepsizes = stream.get_mcmc_stepsizes()

        bigchain += chain

        clf()
        allparams = [p for (lnp,p) in bigchain]
        for i,name in enumerate(['log(alpha)', 'RA (deg)', 'Dec (deg)', 'dist (kpc)', 'Fe/H (dex)', 'pmRA', 'pmDec', 'stream width (deg)', 'angle offset (deg)']):
            if i >= len(allparams[-1]):
                break
            x = array([p[i] for p in allparams])
            clf()
            hist(x, 25)
            xlabel(name)
            title('%i x 1000 steps' % (steps+1))
            savefig(prefix + 'param%i.png' % i)


    #print_targets(y, stream)


