# Copyright 2011 Dustin Lang (Princeton) and David W. Hogg (NYU).
# All rights reserved.

# BUGS:
# -----
# - no known bugs (well, search the code!)

import matplotlib
matplotlib.use('Agg')
import multiprocessing
from glob import glob
import os
import datetime
import time
import sys
from math import pi
import numpy as np
from numpy.linalg import lstsq
import pylab as plt
# Older matplotlib versions lack pyplot.tick_params(); install a no-op
# stand-in so that manyd_plot() still runs there (tick labels are simply
# left at matplotlib's defaults).
if not hasattr(plt, 'tick_params'):
	print 'tick_params() is not available in this matplotlib version'
	def tick_params(**kwargs):
		# Accepts and ignores any keyword arguments.
		pass
	plt.tick_params = tick_params
from matplotlib.patches import Ellipse, Polygon
from scipy import interpolate
from scipy.special import gammaln
import markovpy
from astrometry.util.file import *
from astrometry.util import jpl
from astrometry.util import EXIF
from astrometry.util.util import Tan
from astrometry.util.starutil_numpy import jdtomjd, radectoxyz, deg2distsq, datetomjd, mjdtodate, arcsec2distsq, xyztoradec, arcsec_between, ecliptic_basis
import astrometry.util.celestial_mechanics as cm

# Gravitational parameter of the Sun, as defined by the celestial_mechanics
# module.  Used with distances in AU and times in days throughout this file
# (presumably AU^3/day^2 -- TODO confirm against celestial_mechanics).
GM_sun = cm.GM_sun

# prior parameters
# Scale of the isotropic Gaussian prior on the comet position x (lnprior).
PRIOR_RADIUS = 1. # AU
# Shape parameters of the Beta prior on v^2 / v^2_unbinding (lnprior).
PRIOR_ALPHA = 1.
PRIOR_BETA = 3.

def lnbeta(x, alpha, beta):
	'''
	Natural log of the Beta(alpha, beta) probability density at x.

	x may be a scalar or a numpy array and should satisfy 0 < x < 1; at the
	endpoints the log terms are infinite and numpy may return -inf or nan.

	The result is a log-density: callers (eg, lnprior) add it directly to
	a running log-probability, so it must not be exponentiated here.
	'''
	return (gammaln(alpha + beta) - gammaln(alpha) - gammaln(beta)
		+ np.log(x) * (alpha-1.) + np.log(1.-x) * (beta-1.))

def cosdeg(x):
	'''Cosine of an angle (scalar or array) given in degrees.'''
	radians = np.deg2rad(x)
	return np.cos(radians)

def EMB_ephem():
	'''
	Osculating orbital elements of the EMB (Earth-Moon barycentre).

	Parses a hard-coded JPL HORIZONS elements record for JD 2454101.5
	(2007-Jan-01) and returns (mjd, E): the record's epoch converted to
	MJD, and the first element tuple produced by jpl.parse_orbital_elements.
	'''
	horizons_record = '''2454101.500000000 = A.D. 2007-Jan-01 00:00:00.0000 (CT)
 EC= 1.670361927937051E-02 QR= 9.832911829245575E-01 IN= 9.028642170169823E-04
 OM= 1.762399911457168E+02 W = 2.867172565215373E+02 Tp=  2454104.323526433203
 N = 9.856169820212966E-01 MA= 3.572170843984495E+02 TA= 3.571221738148128E+02
 A = 9.999947139070439E-01 AD= 1.016698244889530E+00 PR= 3.652534468934519E+02'''
	jds, elements = jpl.parse_orbital_elements(horizons_record,
											   needSystemGM=False)
	epoch_mjd = jdtomjd(jds[0])
	return epoch_mjd, elements[0]

def EMB_xyz_at_times(times):
	'''
	Positions of the EMB at each MJD in `times`, stacked one row per time.

	The elements from EMB_ephem() are treated as a fixed Keplerian orbit:
	only the mean anomaly advances, at mean motion sqrt(GM_sun / a^3),
	from the ephemeris epoch to each requested time.
	'''
	epoch_mjd, elements = EMB_ephem()
	(a, e, inc, Omega, pomega, M_epoch, _unused) = elements
	mean_motion = np.sqrt(GM_sun / a**3)
	positions = []
	for t in times:
		mean_anomaly = M_epoch + mean_motion * (t - epoch_mjd)
		(pos, _vel) = cm.phase_space_coordinates_from_orbital_elements(
			a, e, inc, Omega, pomega, mean_anomaly, GM_sun)
		positions.append(pos)
	return np.array(positions)

class CometMCMCxv:
	#			  AU,  AU/day
	paramnames = ['x', 'v']
	def __init__(self):
		self.x = np.zeros(3)
		self.v = np.zeros(3)
		self.xstep = 0.0003
		self.vstep = 0.00004
		# epoch of the orbital elements (ie, M)
		# copy-n-pasted from JPL below
		self.epoch = jdtomjd(2454418.5)
		self.times = None
		self.earthxyz = None
		# spline subsampling
		self.Nspline = 200
		self.pgood = 0.85
		self.pexif = 0.75
		self.centering = 2.5
		self.set_times(datetomjd(datetime.datetime(2007, 7, 1)),
					   datetomjd(datetime.datetime(2008, 5, 1)),
					   0.05)

	def set_times(self, tmin, tmax, dtdays=None):
		self.tmax = tmax
		self.tmin = tmin
		if dtdays is None:
			dtdays = self.dtdays
			if dtdays is None:
				# default
				dtdays = 1.
		self.dtdays = dtdays
		self.times = np.arange(tmin, tmax, self.dtdays)
		self.set_ptime(np.ones_like(self.times))
		self.earthxyz = EMB_xyz_at_times(self.times)

	def set_ptime(self, pt):
		assert(pt.shape == self.times.shape)
		assert(np.sum(pt) > 0)
		self.ptime = pt.astype(float)
		self.ptime /= (self.ptime.sum() * self.dtdays)

	def set_ptime_from_samples(self, times, binwidth):
		bins = np.arange(self.tmin, self.tmax + binwidth, binwidth)
		H,xe = np.histogram(times, bins=bins)
		# np.digitize is 1-indexed
		inds = np.clip(np.digitize(self.times, bins) - 1, 0, len(bins)-1)
		# MAGIC +1: regularization
		self.set_ptime(H[inds].astype(float) + 1.)

	def set_x(self, x):
		self.x[:] = x[:]

	def set_v(self, v):
		self.v[:] = v[:]

	def set_epoch(self, mjd):
		self.epoch = mjd

	def set_params(self, p):
		self.x[:] = p[:3]
		self.v[:] = p[3:6]
		self.pgood = p[6]
		self.pexif = p[7]
		self.centering = np.exp(p[8])

	def get_params(self):
		return np.hstack((self.get_x(), self.get_v(),
				  self.pgood, self.pexif, np.log(self.centering)))

	def get_orbital_elements(self):
		return cm.orbital_elements_from_phase_space_coordinates(self.get_x(), self.get_v(), GM_sun)

	def get_x(self):
		return self.x.copy()

	def get_v(self):
		return self.v.copy()

	def set_params_from_jpl_string(self, s):
		x,v,jd = jpl.parse_phase_space(s)
		x = x[0]
		v = v[0]
		jd = jd[0]
		self.set_x(x)
		self.set_v(v)
		self.set_epoch(jdtomjd(jd))

	# year ~2000 visit
	def set_last_pass_params(self):
		s = '''2455638.500000000 = A.D. 2011-Mar-18 00:00:00.0000 (CT)
		-5.048281449175005E+00  4.365257402143395E-02 -9.429924038754340E-01
		9.663748139985810E-04 -5.542825718109530E-03 -1.422501670129140E-03
	    2.966181942514266E-02  5.135784829124947E+00 -7.358334729180353E-04
		'''
		self.set_params_from_jpl_string(s)
		
	def set_true_params(self):
		# from test-holmes-1.txt
		s = '''2454418.500000000 = A.D. 2007-Nov-14 00:00:00.0000 (CT)
		1.248613009072901E+00  2.025389080777020E+00  8.242272661173784E-01
		-7.973747689948190E-03  9.391385050634651E-03  1.214741189900433E-03
		1.454305558379576E-02  2.518052017168866E+00  3.997656275386785E-03
		'''
		self.set_params_from_jpl_string(s)

	def lnposterior(self, data):
		return self.lnprior() + self.lnlikelihood(data)

	# the prior is of the form
	# p(x,v,pgood,pexif) = p(x) p(v|x) p(pgood) p(pexif)
	# p(x) is a gaussian
	# p(v|x) is a beta distribution in v^2 between 0 and the unbinding v^2
	# p(pgood) = p(pexif) = uniform(0,1)
	def lnprior(self):
		lnp = 0.
		if self.tmax <= self.tmin:
			print 'lnprior(): tmax <= tmin -- punishing prior'
			lnp += -1000.
		if self.pgood <= 0 or self.pgood >= 1:
			print 'lnprior(): pgood=', self.pgood, '-- punishing prior'
			lnp += -1000.
		if self.pexif <= 0 or self.pexif >= 1:
			print 'lnprior(): pexif=', self.pexif, '-- punishing prior'
			lnp += -1000.
		if self.centering < 1.:
			print 'lnprior(): centering=', self.centering, '-- punishing prior'
			lnp += -1000.
		if self.get_energy() > 0:
			print 'lnprior(): unbound orbit -- punishing'
			lnp += -1000.
			return lnp # this must be here or else death and destruction
		x = self.get_x()
		lnp += -0.5 * np.dot(x,x) / PRIOR_RADIUS**2
		v2max = -2.0 * cm.potential_energy_from_position(x, GM_sun)
		v = self.get_v()
		lnp += lnbeta(np.dot(v,v) / v2max, PRIOR_ALPHA, PRIOR_BETA)
		lnp += -1. * np.log(2. * pi * cm.norm(v)) # Jacobian factor converts [d^3v] to [dv^2]
		return lnp

	def get_energy(self):
		return cm.energy_from_phase_space_coordinates(self.get_x(), self.get_v(), GM_sun)

	def xyz_at_times(self, times=None, light_travel=True):
		C = list(self.get_orbital_elements()) + [GM_sun]
		CM0 = C[5]
		CdMdt = np.sqrt(GM_sun / C[0]**3)
		if times is None:
			times = self.times
			earthxyz = self.earthxyz
			Nsp = self.Nspline
		else:
			Nsp = 1
			earthxyz = EMB_xyz_at_times(times)
		while len(times)/Nsp < 6 and Nsp > 1:
			Nsp = max(1, int(Nsp / 2))

		SS = slice(0, len(times), Nsp)
		N = len(times[SS])

		XYZ = np.empty((N,3))
		for i,(ex,t) in enumerate(zip(earthxyz[SS], times[SS])):
			C[5] = CM0 + CdMdt * (t - self.epoch)
			XYZ[i,:] = cm.orbital_elements_to_xyz(C, ex)
		
		if Nsp > 1:
			(S,u) = interpolate.splprep([XYZ[:,0], XYZ[:,1], XYZ[:,2]], u=times[SS], s=0)
			X,Y,Z = interpolate.splev(times, S)
			XYZ = np.vstack((X,Y,Z)).T

		# just in case any of our shih is hucked up.
		XYZ /= np.sqrt(np.sum(XYZ**2, axis=1))[:,np.newaxis]
		return XYZ

	def radec_at_times(self, times=None, light_travel=True):
		xyz = self.xyz_at_times(times, light_travel)
		return xyztoradec(xyz)

	# POSSIBLE BUG: what is the zero-indexing or one-indexing?
	def find_points_in_wcs(self, wcs, xyz,
						   xlo=None, xhi=None, ylo=None, yhi=None):
		if xlo is None:
			xlo = 0.5
		if xhi is None:
			xhi = wcs.imagew + 0.5
		if ylo is None:
			ylo = 0.5
		if yhi is None:
			yhi = wcs.imageh + 0.5
		xyzc = np.array(wcs.xyzcenter())
		r2 = arcsec2distsq(wcs.pixel_scale() *
						   np.hypot(xhi-xlo, yhi-ylo)/2.)
		R2 = ((xyz - xyzc[np.newaxis,:])**2).sum(axis=1)
		I = np.flatnonzero(R2 < r2)
		if len(I) == 0:
			return np.array([])
		x,y = wcs.xyz2pixelxy(xyz[I,:])
		x,y = np.array(x), np.array(y)
		J = (x >= xlo) * (x <= xhi) * (y >= ylo) * (y <= yhi)
		return I[J]

	# God's pwn function
	# MAGIC: 0.5 days = 12 hrs = span of time zones
	def find_times_slice_within_12_hours_of_mjd(self, time):
		return (self.times > (time - 0.5)) * (self.times < (time + 0.5))

	# NOTE: this likelihood assumes that self.times are evenly spaced.
	def lnlikelihood(self, data):
		try:
			cxyz = self.xyz_at_times()
		except:
			return -100. * len(data)
		pgood = np.clip(self.pgood, 0.001, 0.999)
		pexif = np.clip(self.pexif, 0.001, 0.999)
		# MAGIC 1/(4*pi..yadda) = the sky in arcsec^2
		bg = (1. - pgood) / (4. * pi * 206265.**2)
		if self.centering > 1:
			sfc = np.sqrt(float(self.centering))
			f1 = 0.5 - 0.5 / sfc
			f2 = 0.5 + 0.5 / sfc
		else:
			f1 = 0.
			f2 = 1.
		lnp = 0.
		wcs = Tan()
		for i, (wcsvals, date) in enumerate(data):
			# MAGIC: we set bad dates to zero
			if date > 1:
				H = self.find_times_slice_within_12_hours_of_mjd(date)
				if H.sum() > 0:
					ptime = (1. - pexif) * self.ptime
					# BUG / MAGIC: the following lines rely on dtdays being in days!
					ptime[H] += pexif
					ptime /= (np.sum(ptime) * self.dtdays)
			else:
				ptime = self.ptime
			wcs.set(*wcsvals)
			xlo,xhi = f1 * wcs.imagew, f2 * wcs.imagew
			ylo,yhi = f1 * wcs.imageh, f2 * wcs.imageh
			J = self.find_points_in_wcs(wcs, cxyz, xlo=xlo, xhi=xhi, ylo=ylo, yhi=yhi)
			if len(J) == 0:
				totalp = 0.
			else:
				totalp = ptime[J].sum() * self.dtdays
			fg = totalp * pgood / (wcs.pixel_scale()**2 * (yhi-ylo)*(xhi-xlo))
			# print 'lnlikelihood(): sum(J), 'in image, totalp=', totalp, 'fg', fg, 'bg', bg
			lnp += np.log(fg + bg)
		return lnp

# Module-level (not nested in main) so multiprocessing can pickle it.
def evallnprob(args):
	'''
	Evaluate the log-posterior for one walker.

	`args` is a (comet, params, data) triple: the comet model is updated
	in place with `params`, and comet.lnposterior(data) is returned.
	'''
	comet, params, data = args
	comet.set_params(params)
	lnp = comet.lnposterior(data)
	return lnp

def main():
	import optparse
	parser = optparse.OptionParser()
	parser.add_option('--small-only', dest='smallonly', default=False, action='store_true', help='Use only (angularly) small images')
	parser.add_option('--hack-out-crap', '--hoc', dest='hoc', default=False, action='store_true', help='Remove images that might be screwing us')
	parser.add_option('--threads', dest='threads', default=16, type=int, help='Use this many concurrent processors')
	parser.add_option('--walkers', dest='walkers', default=64, type=int, help='Use this many MCMC walkers')
	parser.add_option('--binwidth', dest='binwidth', default=8, type=int, help='Make cheater prior with this bin width (days)')
	parser.add_option('--maxiter', dest='maxiter', default=500, type=int, help='Do no more than this number of iterations')
	parser.add_option('--exif-plots', dest='exifplots', default=False, action='store_true', help='Make EXIF plots and exit.')
	parser.add_option('--emp-r', dest='emprad', type=float, default=None, help='Radius for empirical initialization')
	opt,args = parser.parse_args()
	np.random.seed(42)

	print 'main(): Munging data...'
	print '  reading WCS...'
	wcsfns = glob('2010-04-16-holmes-dedup/holmes-*.wcs')
	wcsfns.sort()
	wcs = [Tan(fn,0) for fn in wcsfns]

	# ad-hockery!
	if opt.hoc:
		keep = []
		for i,w in enumerate(wcs):
			rmin,rmax,decmin,decmax = w.radec_bounds()
			#print 'RA', rmin, rmax, 'Dec', decmin, decmax
			if rmin > 40. and rmax < 44. and decmin > 30. and decmax < 35.:
				print 'main(): removing image', wcsfns[i].replace('.wcs','.jpg')
			else:
				keep.append(i)
		wcsfns = [wcsfns[i] for i in keep]
		wcs = [wcs[i] for i in keep]

	# if asked, keep only (angularly) smallest quartile
	if opt.smallonly:
		scales = np.array([w.pixel_scale()*np.sqrt(w.imagew*w.imageh) for w in wcs])
		I = np.argsort(scales)[:len(wcs)/2]
		wcs = [wcs[i] for i in I]

	# parse dates from EXIF data in jpegs
	print '  reading EXIF...'
	dates = []
	exifobjs = []
	for fn in wcsfns:
		imgfn = fn.replace('.wcs', '.jpg')
		exif = EXIF.process_file(open(imgfn), details=False)
		exifobjs.append((imgfn, exif))
		imgdate = exif.get('EXIF DateTimeOriginal')
		if not imgdate:
			imgdate = exif.get('Image DateTime')
		if imgdate:
			dates.append(datetime.datetime.strptime(str(imgdate),
													'%Y:%m:%d %H:%M:%S'))
		else:
			dates.append(None)

	gooddates = [d for d in dates if d is not None]
	Igooddates = np.array([i for i,d in enumerate(dates) if d is not None])
	print 'main(): Got', len(gooddates), 'dates from EXIF'
	goodmjds = [datetomjd(d) for d in gooddates]

	data = zip(wcs, [datetomjd(d) if d is not None else 0 for d in dates])
	# unpack WCS structure into a list of doubles for multiprocessing
	data2 = [((wcs.crval[0], wcs.crval[1], wcs.crpix[0], wcs.crpix[1],
			   wcs.cd[0], wcs.cd[1], wcs.cd[2], wcs.cd[3],
			   wcs.imagew, wcs.imageh), mjd) for (wcs, mjd) in data]
	print 'main(): burp.'
	print 'main(): Number of images:', len(data)

	# make comet
	C = CometMCMCxv()
	C.set_true_params()
	ptrue = C.get_params()
	C.set_ptime_from_samples(goodmjds, opt.binwidth)

	if False:
		for centering in np.arange(1.,4.01,0.25):
			C.set_ptime_from_samples(goodmjds, 8.)
			C.centering = centering
			print centering, C.lnposterior(data2)
		sys.exit(0)

	if False:
		for binwidth in np.arange(5.,13.01):
			C.centering = 2.5
			C.set_ptime_from_samples(goodmjds, binwidth)
			print binwidth, C.lnposterior(data2)
		sys.exit(0)

	if False:
		C.centering = 2.5
		C.set_ptime_from_samples(goodmjds, 8.)
		for pgood in np.arange(0.05,1.,0.1):
			C.pexif = 0.75
			C.pgood = pgood
			print C.pgood, C.pexif, C.lnposterior(data2)
		for pexif in np.arange(0.05,1.,0.1):
			C.pgood = 0.85
			C.pexif = pexif
			print C.pgood, C.pexif, C.lnposterior(data2)
		sys.exit(0)

	if False:
		ra,dec = C.radec_at_times()
		D = np.sqrt((np.diff(ra*np.cos(np.deg2rad(dec))))**2 + np.diff(dec)**2)
		print 'max angular velocity', max(D)/C.dtdays, 'deg/day'
		print 'min image radius:', min(data[2])
		print 'fastest traversal:', min(data[2]) / (max(D)/C.dtdays), 'days'

	if opt.exifplots:
		from exifplots import exifplots
		exifplots(C, data2, exifobjs)
		sys.exit(0)

	epoch = None
	if opt.emprad:
		print 'Initializing from the data with radius', opt.emprad
		# Initialize orbit using images close to the median date.
		medmjd = np.median(goodmjds)
		nearby = (np.abs(goodmjds - medmjd) < 7)
		print 'number of images within 7 days:', sum(nearby)
		# Move the epoch to the nearest integer MJD.
		epoch = medmjd
		epoch = float(round(epoch))
		print 'epoch', epoch

		# True radius
		embxyz0 = EMB_xyz_at_times(np.array([epoch]))
		embxyz0 = embxyz0[0,:]
		print 'emb xyz0:', embxyz0
		E = list(C.get_orbital_elements()) + [GM_sun]
		(cxyz0,v0) = cm.phase_space_coordinates_from_orbital_elements(*E)
		print 'comet xyz0:', cxyz0
		print 'R:', np.sqrt(np.sum((embxyz0 - cxyz0)**2))

		Inear = Igooddates[nearby]
		nearwcs = [data[i][0] for i in Inear]
		nrd = np.array([wcs.radec_center() for wcs in nearwcs])
		nra  = nrd[:,0]
		ndec = nrd[:,1]
		nrad = np.array([wcs.pixel_scale() * np.hypot(wcs.imagew, wcs.imageh)
						 for wcs in nearwcs]) / 3600.
		# cut on RA,Dec
		mra = np.median(nra)
		mdec = np.median(ndec)
		Jnear = np.hypot(nra - mra, ndec - mdec) < 5.
		Inear = Inear[Jnear]
		nra = nra[Jnear]
		ndec = ndec[Jnear]
		nrad = nrad[Jnear]
		nt = np.array(goodmjds)[nearby][Jnear]

		N = len(nra)
		sqrtW = 1./nrad
		X = np.zeros((N,2))
		X[:,0] = 1. * sqrtW
		X[:,1] = (nt-epoch) * sqrtW
		y = np.zeros((N,2))
		y[:,0] = nra * sqrtW
		y[:,1] = ndec * sqrtW
		(b,resid,rank,eigs) = lstsq(X, y)
		print 'b', b
		ra0,dec0 = b[0,:]
		dradt,ddecdt = b[1,:]

		plt.clf()
		plt.subplot(2,1,1)
		plt.errorbar(nt-epoch, nra, yerr=nrad, fmt=None)
		a = plt.axis()
		tt = np.array(a[:2])
		plt.plot(tt, ra0 + dradt * tt, 'b-')
		plt.xlim(a[0],a[1])
		#plt.ylim(ra0-5, ra0+5)
		plt.ylabel('RA (deg)')
		plt.subplot(2,1,2)
		plt.errorbar(nt-epoch, ndec, yerr=nrad, fmt=None)
		plt.plot(tt, dec0 + ddecdt * tt, 'b-')
		plt.xlim(a[0],a[1])
		#plt.ylim(dec0-5, dec0+5)
		plt.ylabel('Dec (deg)')
		plt.xlabel('dt (days)')
		plt.savefig('ravst.png')

		# Find velocity and compute orbital elements from x,v.
		dt = 1.
		ra1  = ra0  + dradt  * dt
		dec1 = dec0 + ddecdt * dt
		xyz0 = radectoxyz(ra0, dec0)[0]
		xyz1 = radectoxyz(ra1, dec1)[0]
		#print 'xyz:', xyz0, xyz1
		# xyz are in celestial coords; convert to solar-system coords
		(antieq, antisol, antipole) = ecliptic_basis(eclipticangle = -23.439281)
		xyz0 = xyz0[0] * antieq + xyz0[1] * antisol + xyz0[2] * antipole
		xyz1 = xyz1[0] * antieq + xyz1[1] * antisol + xyz1[2] * antipole

		# Find observer (EMB) positions at epoch
		if False:
			t0_emb,E_emb = EMB_ephem()
			E_emb = list(E_emb)
			(a,e,I,Omega,pomega,Mx,nil) = E_emb
			dMdt_emb = np.sqrt(GM_sun / a**3)
			M0 = Mx + dMdt_emb * (epoch - t0_emb)
			(embx0,v) = cm.phase_space_coordinates_from_orbital_elements(a,e,I,Omega,pomega,M0,GM_sun)
			M1 = Mx + dMdt_emb * ((epoch+dt) - t0_emb)
			(embx1,v) = cm.phase_space_coordinates_from_orbital_elements(a,e,I,Omega,pomega,M1,GM_sun)

			#print 'R = ', R
			#print 'angle between cx0, cv0:', rad2deg(arccos(dot(cx0, cv0) / norm(cx0) / norm(cv0)))
			#print 'kinetic energy:', 0.5 * dot(cv0,cv0)
			#print 'potential energy:', GM_sun / norm(cx0)
			#E = 0.5 * dot(cv0,cv0) - GM_sun / norm(cx0)

		embxyz = EMB_xyz_at_times(np.array([epoch, epoch + dt]))
		embx0 = embxyz[0,:]
		embx1 = embxyz[1,:]

		rtru,dtru = C.radec_at_times()

		C.set_epoch(epoch)
		plt.clf()
		# Comet distance from EMB
		N = 30
		RR = np.linspace(0.5, 4, N)
		CC = np.linspace(0, 1, N)
		for R,cc in zip(RR, CC):
			#(0.5, (0,0,1)),
			#		(1,   (0,0.5,0.5)),
			#		(2,(0,1,0))]: # <- between 1 and 2 is best
			c = (0, 1-cc, cc)
			# Comet position and velocity...
			cx0 = embx0 + R * xyz0
			cx1 = embx1 + R * xyz1
			cv0 = (cx1 - cx0) / dt
			C.set_x(cx0)
			C.set_v(cv0)
			r,d = C.radec_at_times()
			plt.plot(r, d, '-', color=c, alpha=0.5)
		plt.plot(rtru, dtru, 'r-')
		plt.savefig('c2.png')

		RR = np.linspace(0.5, 4, 40)
		lnp = []
		for R in RR:
			# Comet position and velocity...
			cx0 = embx0 + R * xyz0
			cx1 = embx1 + R * xyz1
			cv0 = (cx1 - cx0) / dt
			C.set_x(cx0)
			C.set_v(cv0)
			lnp.append(C.lnposterior(data2))
		plt.clf()
		plt.plot(RR, lnp, 'k-')
		plt.savefig('R.png')

		R = opt.emprad
		cx0 = embx0 + R * xyz0
		cx1 = embx1 + R * xyz1
		cv0 = (cx1 - cx0) / dt
		C.set_x(cx0)
		C.set_v(cv0)
		#p2 = c2.get_params()
		#print 'c2 params', p2
		#print 'true params', C.get_params()
		#lnptrue = C.lnposterior(data2)
		#print 'true lnp:', lnptrue
		#C.set_params(p2)
		#C.set_epoch(epoch)
		#lnp2 = C.lnposterior(data2)
		#print 'empirical lnp:', lnp2

	nwalkers = opt.walkers
	ndim	 = len(C.get_params())
	# the following step sizes were set by Hogg on 2011-03-18
	stepsizes = np.array([1e-5]*3 + [1e-6]*3 + [0.001]*3)

	initpos = []
	for i in range(nwalkers):
		p = C.get_params()
		p += np.random.normal(size=p.shape) * stepsizes
		initpos.append(p)
	pos = np.array(initpos)

	pool = None
	if opt.threads > 1:
		pool = multiprocessing.Pool(opt.threads)

	def manylnprob(manyparams, data):
		if pool:
			M = pool.map
		else:
			M = map
		lnp = np.array(M(evallnprob, [(C, p, data) for p in manyparams]))
		print 'manylnprob(): lnp =', lnp
		print 'manylnprob(): p[0] =', manyparams[0]
		return lnp

	sampler = markovpy.EnsembleSampler(nwalkers, ndim, None,
									   manylnprob, postargs=data2)

	scale = np.array([1,1,1,365.25,365.25,365.25])
	postrue = ptrue[:6]
	lims = dict([(i,[postrue[i]*scale[i] - 3e-1,
					 postrue[i]*scale[i] + 3e-1])
				  for i in range(len(scale))])
	lims2 = dict([(0,[0., 1.]), (1, [0., 1.]), (2, [0., 2.])])
	labels2 = ['pgood', 'pexif', 'ln(centering)']

	state = None
	matplotlib.rc('font', family='computer modern roman')
	matplotlib.rc('font', size=14)
	matplotlib.rc('text', usetex=True)

	#dpi = 300
	dpi = 100
	matplotlib.rc('savefig', dpi=dpi)
	# figure(dpi=) does nothing.  Use savefig(dpi=) instead.
	plt.figure(figsize=(6,6), dpi=dpi)
	#spargs = dict(bottom=0.15, left=0.15, right=0.95, top=0.95)
	spargs = dict(bottom=0.1, left=0.1, right=0.98, top=0.98)
	
	suffix = ''
	if opt.smallonly: suffix += '-so'
	if opt.hoc: suffix += '-hoc'
	if suffix == '': suffix += '-default'
	suffix += '-nw%03d' % nwalkers
	suffix += '-bw%03d' % opt.binwidth
	if opt.emprad == 1.:
		suffix += '-emp'
	elif opt.emprad is not None:
		suffix += '-emp%.2f' % opt.emprad
	nproposed = np.zeros(nwalkers)
	naccepted = np.zeros(nwalkers)
	lnprob = None
	for k in range(opt.maxiter):
		fn = 'traj%s-%03d.png' % (suffix, k)
		if (k%10 == 0) and (not os.path.exists(fn)):
			print 'Plotting trajectory', fn
			plt.clf()
			plt.subplots_adjust(**spargs)
			plot_trajectory(pos, ptrue, data2, epoch=epoch)
			# Check -- plot the JPL ephemeris in RA,Dec
			# this is from Earth (399) not EMB
			# Looks good!
			#(ra,dec,jd) = jpl.parse_radec(open('holmes-ephem.txt').read())
			#plt.plot(ra, dec, 'b-', alpha=0.5)
			#plt.axis([70, 35, 30, 60])
			plt.savefig(fn)
			if (k%110 == 0):
				fn2 = 'traj%s-%03d.pdf' % (suffix, k)
				plt.savefig(fn2)

		if False:
			fn = 'manyd%s-%03d.png' % (suffix, k)
			if (k%10 == 0) and (not os.path.exists(fn)):
				print 'Plotting distribution', fn
				plt.clf()
				plt.subplots_adjust(**spargs)
				manyd_plot(6,pos[:,:6]*scale,ptrue[:6]*scale,lims)
				plt.savefig(fn)

			fn = 'manyd2%s-%03d.png' % (suffix, k)
			if (k%10 == 0) and (not os.path.exists(fn)):
				print 'Plotting distribution', fn
				plt.clf()
				plt.subplots_adjust(**spargs)
				manyd_plot(3,pos[:,6:9],ptrue[6:9],lims2,labels=labels2)
				plt.savefig(fn)

		fn = 'mcmc%s-%03i.pickle' % (suffix, k)
		if os.path.exists(fn):
			print 'Reading pickle', fn
			(nil,pos,lnprob,state) = unpickle_from_file(fn)
		else:
			print 'Running MCMC', fn
			t0 = datetime.datetime.now()
			niter = 1
			pos,lnprob,state = sampler.run_mcmc(pos, state, niter, lnprobinit=lnprob)
			nproposed[:] += float(niter)
			print 'main(): current acceptance fraction:', 100. * sampler.naccepted.sum() / nproposed.sum()
			t1 = datetime.datetime.now()
			dt = t1-t0
			dt = (dt.microseconds + (dt.seconds + dt.days * 24 * 3600.) * 1e6) / 1e6
			print 'MCMC took', dt, 'sec'
			pickle_to_file((k,pos,lnprob,state), fn)

def plot_trajectory(pos, ptrue, data, npts=16,
					S=None, epoch=None):
	'''
	Plot sampled comet trajectories over the image footprints in RA, Dec.

	pos:   2-D array of parameter vectors, one per row (CometMCMCxv
	       get_params() layout).  Up to `npts` rows are drawn as red lines.
	ptrue: reference (JPL) parameter vector, drawn dotted.  Its (x, v) are
	       taken to be at the epoch set by CometMCMCxv.set_true_params().
	data:  list of (wcsvals, mjd) pairs; wcsvals are Tan() constructor args.
	S:     optional indices into `data` whose footprints are outlined in red.
	epoch: if given, the MJD epoch of the sampled (x, v) in `pos`.
	'''
	if S is None:
		S = []
	c = CometMCMCxv()
	c.set_true_params()
	# Remember the JPL epoch so ptrue is not propagated from the samples'
	# epoch when `epoch` is given.
	true_epoch = c.epoch
	if epoch is not None:
		c.set_epoch(epoch)
	# Coarse 1-day grid for plotting.  Assigning c.dtdays alone would leave
	# c.times (and c.earthxyz) on the fine grid built in __init__.
	c.set_times(c.tmin, c.tmax, 1.)
	c.Nspline = 10

	wcs = [Tan(*d) for (d, foo) in data]
	scales = np.array([w.pixel_scale()*np.sqrt(w.imagew*w.imageh) for w in wcs])
	# footprints are added in order of decreasing angular size
	I = np.argsort(-scales)

	rds = []
	alphas = []
	plt.clf()
	for s,w in zip(scales[I], [wcs[i] for i in I]):
		W,H = w.imagew, w.imageh
		r,d = w.pixelxy2radec([0, 0, W, W, 0], [0, H, H, 0, 0])
		rd = np.vstack((r,d)).T
		radius = s / 3600.
		# lighter shading for angularly larger images
		alphas.append(0.2/radius**2)
		rds.append(rd)
	for rd,alpha in zip(rds,alphas):
		plt.gca().add_artist(Polygon(rd, fc='0.5', ec='none',
									 alpha=np.clip(alpha,0.,0.5)))
	for rd,alpha in zip(rds,alphas):
		plt.gca().add_artist(Polygon(rd, fc='none', ec='k', lw=0.05,
									 alpha=0.5))

	# Never index past the number of samples actually supplied.
	nshow = min(npts, len(pos))
	p1 = None
	samplealpha = 0.5
	for posi in pos[:nshow]:
		c.set_params(posi)
		(ras,decs) = c.radec_at_times()
		p1 = plt.plot(ras, decs, 'r-', alpha=samplealpha, linewidth=0.5)
		plt.plot([ras[0],ras[-1]], [decs[0],decs[-1]], 'r.', alpha=samplealpha)

	c.set_params(ptrue)
	c.set_epoch(true_epoch)
	(ras,decs) = c.radec_at_times()
	p2 = plt.plot(ras, decs, 'r:', alpha=1, linewidth=2.0)

	# label some dates
	if False:
		ltimes = [datetime.datetime(x/12, x%12+1, 1)
				  for x in range(2007*12+5, 2008*12+4, 2)]
		c.Nspline = 1
		(lras,ldecs) = c.radec_at_times(np.array([datetomjd(d) for d in ltimes]))
		plt.plot(lras, ldecs, 'o', mec='r', mfc='none')
		for t,ra,dec in zip(ltimes,lras,ldecs):
			plt.text(ra-0.5, dec-0.5, str(t.date()), color='r',
					 size=10,
					 horizontalalignment='left',
					 verticalalignment='top',
					 bbox=dict(facecolor='w', ec='r', alpha=0.25))

	for day,ha,va in [ ((2007,  8, 1), 'left',  'bottom'),
					   ((2007, 10, 1), 'right', 'top'),
					   ((2007, 12, 1), 'left',  'bottom'),
					   ((2008,  2, 1), 'left',  'top'),
					   ((2008,  4, 1), 'left',  'top') ]:
		D = datetime.datetime(*day)
		d = datetomjd(D)
		ra,dec = c.radec_at_times(np.array([d]))
		ra,dec = ra[0],dec[0]
		plt.plot(ra, dec, 'o', mec='r', mfc='none')
		# offset the label away from the marker
		dr = 0.5 * (1. if ha == 'right' else -1.)
		dd = 0.5 * (1. if va == 'bottom' else -1.)

		plt.text(ra+dr, dec+dd, str(D.date()), color='r', size=10,
				 ha=ha, va=va,
				 bbox=dict(facecolor='w', ec='r', alpha=0.25))

	# plot the last pass
	#c.set_last_pass_params()
	#c.set_times(datetomjd(datetime.datetime(2000, 7, 1)),
	#datetomjd(datetime.datetime(2000, 12, 1)),
	#				1)

	for i in S:
		w = wcs[i]
		W,H = w.imagew, w.imageh
		r,d = w.pixelxy2radec([0, 0, W, W, 0], [0, H, H, 0, 0])
		rd = np.vstack((r,d)).T
		plt.gca().add_artist(Polygon(rd, fc='none', ec='r'))

	if p1 is not None:
		plt.legend((p2,p1), ('JPL', 'samples'), prop=dict(size=12))
	plt.xlabel('RA (deg)')
	plt.ylabel('Dec (deg)')
	plt.axis([70,35,31,59])
	xt, foo = plt.xlim()
	yt, foo = plt.ylim()
	# MAGIC 0.5s etc
	plt.text(xt-0.5, yt + 0.5 + 2. * 0.70,
			 r'($p_{\mathrm{good}}$ %s)' % (' '.join(['%.2f' % pos[i,6] for i in range(nshow)])),
			 color='r', alpha=samplealpha, fontsize=8)
	plt.text(xt-0.5, yt + 0.5 + 1. * 0.70,
			 r'($p_{\mathrm{EXIF}}$ %s)' % (' '.join(['%.2f' % pos[i,7] for i in range(nshow)])),
			 color='r', alpha=samplealpha, fontsize=8)
	plt.text(xt-0.5, yt + 0.5 + 0. * 0.70,
			 r'($\eta$ %s)' % (' '.join(['%.2f' % np.exp(pos[i,8]) for i in range(nshow)])),
			 color='r', alpha=samplealpha, fontsize=8)

def manyd_plot(ndim,pos,ptrue,lims=None,labels=None):
	'''
	Corner-style grid of pairwise scatter plots plus diagonal histograms.

	ndim:   number of parameters (grid is ndim x ndim subplots).
	pos:    2-D array of samples, one per row, at least ndim columns.
	ptrue:  reference values, drawn as green lines.
	lims:   optional dict {param index: (lo, hi)} of axis limits.  Missing
	        entries are filled in from the first plot that uses that
	        parameter; a caller-supplied dict is updated in place.
	labels: axis labels; defaults to 'p00', 'p01', ...
	'''
	# A fresh dict per call: the filled-in limits must not leak between
	# calls through a shared default argument.
	if lims is None:
		lims = {}
	if labels is None:
		labels = ['p%02d' % i for i in range(ndim)]
	for i in range(ndim):
		for j in range(ndim):
			plt.subplot(ndim, ndim, i*ndim + j + 1)
			plt.axvline(ptrue[j], color='g', alpha=0.5, lw=0.5)
			if i == j:
				plt.hist(pos[:,i], 20, range=lims.get(i, None))
			else:
				plt.axhline(ptrue[i], color='g', alpha=0.5, lw=0.5)
				plt.plot(pos[:,j], pos[:,i], 'k.', alpha=0.25)
			if j in lims:
				plt.xlim(*lims[j])
			else:
				lims[j] = plt.xlim()
			if i != j:
				if i in lims:
					plt.ylim(*lims[i])
				else:
					lims[i] = plt.ylim()
			plt.xlabel('')
			plt.ylabel('')
			# only outer panels get tick labels and axis labels
			plt.tick_params(labelbottom=False, labelleft=False,
							labelsize=8)
			if i == 0:
				plt.tick_params(labeltop=True)
			if i == ndim-1:
				plt.tick_params(labelbottom=True)
				plt.xlabel(labels[j])
			if j == 0:
				plt.tick_params(labelleft=True)
				plt.ylabel(labels[i])
			if j == ndim-1:
				plt.tick_params(labelright=True)
			if i == j:
				# histogram counts: no y tick labels
				plt.tick_params(labelleft= False, labelright=False)
				plt.ylabel('')

# Script entry point: run the command-line driver when executed directly.
if __name__ == '__main__':
	main()

