import sys
import os.path
import time
import matplotlib
matplotlib.use('Agg')
from astrometry.util.pyfits_utils import *
import pyfits
from pylab import *
from matplotlib.lines import Line2D
from astrometry.util.starutil_numpy import *
from astrometry.util.file import *
from astrometry.util.plotutils import *
#from astrometry.libkd.spherematch import spherematch_c
import numpy
import numpy.random as random
import colorsys
import cPickle as pickle
from matplotlib import rcParams
import scipy.linalg as linalg
from matplotlib.patches import Ellipse
from mwgc import parse_mwgc

# Module-wide image defaults: show pixels without interpolation, and put
# array row 0 at the bottom of imshow() plots (y increasing upward).
rcParams.update({
	'image.interpolation':'nearest',
	'image.origin':'lower',
	})

def savefig(fn, *args, **kwargs):
	import pylab
	print 'saving', fn, '...',
	pylab.savefig(fn, *args, **kwargs)
	print 'done.'

def make_one_plot(xball, yball, color, alpha=None, ms=None, xtail=None, ytail=None, plotorder=None):
	(x, xl, xrange) = xball
	(y, yl, yrange) = yball
	if alpha is None:
		alpha = ones(len(x))
	if ms is None:
		ms = ones(len(x)) * 3.
	uc = unique(color)
	ua = unique(alpha)
	ums = unique(ms)

	arra = array(alpha)
	arrc = array(color)
	arrms = array(ms)

	clf()

	if xtail is not None:
		(Ntails,nil) = xtail.shape

	if plotorder is not None:
		I = argsort(plotorder)
		if xtail is not None and ytail is not None:
			for c,a,msi,xi,yi,xti,yti in zip(color[I], alpha[I], ms[I], x[I], y[I], xtail[:,I].T, ytail[:,I].T):
				plot([xi], [yi], '.', color=c, alpha=a, mec=c, ms=msi)
				#print 'shape1:', array([xi])[:,newaxis].repeat(Ntails).shape
				#print 'shape2:', xti.shape
				plot(vstack((array([xi])[:,newaxis].repeat(Ntails), xti)),
					 vstack((array([yi])[:,newaxis].repeat(Ntails), yti)),
					 '-', color=c, alpha=a)
		else:
			for c,a,msi,xi,yi in zip(color[I], alpha[I], ms[I], x[I], y[I]):
				plot([xi], [yi], '.', color=c, alpha=a, mec=c, ms=msi)

	else:
		for a in ua:
			for c in uc:
				for ms in ums:
					I = (c == arrc) * (a == arra) * (ms == arrms)
					if sum(I) == 0:
						continue
					print 'plotting', sum(I), 'in color', c
					plot(x[I], y[I], '.', color=c, alpha=a, mec=c, ms=ms)
					if xtail is not None and ytail is not None:
						for i in find(I):
							plot(vstack((x[i,newaxis].repeat(Ntails), xtail[:,i])),
								 vstack((y[i,newaxis].repeat(Ntails), ytail[:,i])),
								 '-', color=c, alpha=a)

	xlim(xrange)
	ylim(yrange)
	xlabel(xl)
	ylabel(yl)

def plot_range(x, qlo=0.025, qhi=0.975, quantile=None):
	'''Return (low, high) sample quantiles of *x*, for use as plot limits.

	By default the 2.5th and 97.5th percentiles are returned.  If *quantile*
	is given, it overrides qlo/qhi with the central interval holding that
	fraction of the data: qlo = (1 - quantile)/2, qhi = 1 - qlo.
	The returned values are elements of *x* (no interpolation): the sorted
	values at indices floor(n*qlo) and ceil(n*qhi), clamped to [0, n-1].

	Raises ValueError if quantile is outside [0, 1] or *x* is empty.
	'''
	if quantile is not None:
		if not (0 <= quantile <= 1):
			raise ValueError('quantile must be in [0, 1], got %r' % (quantile,))
		qlo = (1. - quantile) / 2.
		qhi = 1. - qlo
	sx = numpy.sort(numpy.asarray(x).ravel())
	n = len(sx)
	if n == 0:
		raise ValueError('plot_range: empty input')
	# Indices must be ints, and must be clamped: e.g. n=10, qhi=0.975 gives
	# ceil(9.75) = 10, one past the end of the array.
	ilo = int(numpy.clip(numpy.floor(n * qlo), 0, n - 1))
	ihi = int(numpy.clip(numpy.ceil(n * qhi), 0, n - 1))
	return (sx[ilo], sx[ihi])

def dec_ticks(decrange, axis='y'):
	'''Label a sin(Dec)-scaled axis with round Dec values in degrees.

	Picks the first step from a fixed list that gives between 4 and 9 ticks
	across decrange (or the largest step if none does), then places ticks at
	sin(Dec) positions labelled with the Dec values.  axis='y' sets the y
	ticks; any other value sets the x ticks.
	'''
	(lo, hi) = (decrange[0], decrange[1])
	for step in [0.02, 0.05, 0.1, 0.2, 0.5, 1., 2., 5., 10., 20.]:
		if 3 < int((hi - lo) / step) < 10:
			break
	first = step * int(floor(lo / step))
	stop = step * int(ceil(hi / step) + 1)
	ticks = arange(first, stop, step)
	set_ticks = yticks if axis == 'y' else xticks
	set_ticks(sin(deg2rad(ticks)), ticks)

def make_plots(ra, dec, dist, pmra, pmdec, pmraerr, pmdecerr, color=None, alpha=None, squidplot=False, pairs=None):
	distlabel = 'Rix distance (kpc)'
	rlabel = 'r (mag)'
	pmralabel = 'proper motion RA (mas/yr)'
	pmdeclabel = 'proper motion Dec (mas/yr)'
	ralabel = 'RA (deg)'
	declabel = 'Dec (deg)'

	decrange = plot_range(dec)
	dec_ticks(decrange)

	distrange = plot_range(dist)
	raball = (ra, ralabel, plot_range(ra))
	decball = (sin(deg2rad(dec)), declabel, sin(deg2rad(array( plot_range(dec)))))
	distball = (dist, distlabel, distrange)

	if alpha is None:
		alpha = 0.2 * ones(ra.shape)
	if color is None:
		color = array(['r'] * len(ra))

	if squidplot:
		snapdist = exp(0.15*floor(log(dist)/0.15))
		ms = 60./snapdist
		print 'snapdist range:', snapdist.min(), snapdist.max()
		#alpha2 = alpha # * snapdist
		Ntail = 5.
		#dx = random.normal(size=Ntail)[:,newaxis]
		#dy = random.normal(size=Ntail)[:,newaxis]
		N = len(ra)
		dx = random.normal(size=(Ntail,N))
		dy = random.normal(size=(Ntail,N))
		yrs = 1e5
		ratail  = ra  - (pmra/cos(deg2rad(dec)) + dx * pmraerr ) * yrs / (1000.*3600.)
		dec2    = dec - (pmdec+ dy * pmdecerr) * yrs / (1000.*3600.)
		print 'shape:', ratail.shape
		dectail = sin(deg2rad(dec2))
		alpha2 = 0.5 * ones(ra.shape)
		make_one_plot(raball, decball, color, alpha=alpha2, ms=ms, xtail=ratail, ytail=dectail, plotorder=alpha)

		a = axis()
		if pairs is not None:
			for (ra1,ra2,dec1,dec2) in zip(*pairs):
				plot([ra1, ra2], [sin(deg2rad(dec1)), sin(deg2rad(dec2))], 'b-')
		axis(a)
		
	else:
		make_one_plot(raball, decball, color, alpha)
	dec_ticks(decrange)


def friends_of_friends(I1, I2, N):
	'''Group N points into friends-of-friends clusters from a list of pairs.

	I1, I2: equal-length sequences of point indices; (I1[k], I2[k]) are
	  "friends".  Friendship is symmetric and transitive within a group.
	N: total number of points; all indices must lie in [0, N).

	Returns (groups, groupnumber):
	  groups: list of sets of point indices, ordered by each group's
	    smallest member; points with no friends belong to no group.
	  groupnumber: int array of length N giving each point's index into
	    groups, or -1 for points with no friends.

	Raises ValueError if an index is outside [0, N).
	'''
	from collections import defaultdict, deque

	# Adjacency sets give O(1) duplicate suppression and membership tests.
	friends = defaultdict(set)
	for (a, b) in zip(I1, I2):
		(a, b) = (int(a), int(b))
		if not (0 <= a < N and 0 <= b < N):
			raise ValueError('friends_of_friends: index out of range [0, %i): (%i, %i)'
							 % (N, a, b))
		friends[a].add(b)
		friends[b].add(a)

	groups = []
	groupnumber = -numpy.ones(N, dtype=int)
	for i in range(N):
		# Skip points already placed in a group, and points with no friends.
		if groupnumber[i] != -1 or i not in friends:
			continue
		gnum = len(groups)
		group = set([i])
		queue = deque([i])
		# Breadth-first search over the friendship graph.
		while queue:
			current = queue.popleft()
			for f in friends[current]:
				if f not in group:
					group.add(f)
					queue.append(f)
		groups.append(group)
		for member in group:
			groupnumber[member] = gnum

	return (groups, groupnumber)

def plot_worm(x, y, lw, plot_pivots=True):
	'''Draw points (x, y) as a thick, randomly-hued "worm" line.

	The points are joined in order of whichever coordinate has the larger
	variance, so the line runs along the group's long axis.  Previously that
	sort order was computed but never applied.  A darker line one unit wider
	is drawn underneath as an outline; if plot_pivots is True the points
	themselves are also marked.
	'''
	x = numpy.asarray(x)
	y = numpy.asarray(y)
	if var(x) > var(y):
		I = argsort(x)
	else:
		I = argsort(y)
	(xs, ys) = (x[I], y[I])
	uc = random.uniform()
	colorcode = colorsys.hsv_to_rgb(uc, 1, 1)
	highlightcolor = colorsys.hsv_to_rgb(uc, 0.5, 1)
	highlightedgecolor = colorsys.hsv_to_rgb(uc, 1, 0.8)
	plot(xs, ys, '-', color=highlightedgecolor, lw=lw+1, alpha=1,
		 zorder=1, solid_capstyle='round', solid_joinstyle='round')
	plot(xs, ys, '-', color=highlightcolor, lw=lw, alpha=1,
		 zorder=2, solid_capstyle='round', solid_joinstyle='round')
	if plot_pivots:
		plot(xs, ys, '.', color=colorcode)

def plotra(ra, dec):
	'''Map RA (deg) to a plotting x coordinate: offset from RA=180 scaled by cos(Dec).'''
	offset = ra - 180.
	return 180. + offset * cos(deg2rad(dec))

def plot_radec_grid_lines():
	'''Overlay black lines of constant RA (mapped through plotra) at the
	current x tick positions, leaving the axis limits unchanged.'''
	saved_limits = axis()
	ax = gca()
	ra_ticks = ax.get_xticks()
	(dec_bottom, dec_top) = ax.get_ylim()
	dec_samples = linspace(dec_bottom, dec_top, 100)
	for tick_ra in ra_ticks:
		plot(plotra(tick_ra, dec_samples), dec_samples, 'k-', zorder=1)
	axis(saved_limits)

def plot_2d_hist(x, y, weights=None, bins=(500,500), range=None, log=False, logoffset=1,
				 cmap=antigray, pctlo=0, pcthi=95, minmax=None):
	'''Show a (optionally weighted, optionally log-scaled) 2-D histogram with imshow.

	range: [[xlo, xhi], [ylo, yhi]]; defaults to plot_range() of x and y.
	log: if True, display log(logoffset + counts).
	minmax: fixed colour limits; if None they are chosen from the pctlo/pcthi
	  percentiles by set_image_color_percentiles.
	Returns (image, minmax, extent, xedges, yedges), where image is the
	transposed histogram actually displayed.
	'''
	dolog = log
	if range is None:
		range = [plot_range(x), plot_range(y)]
	(counts, xedges, yedges) = histogram2d(x, y, bins=bins, range=range, weights=weights)
	# histogram2d indexes [x, y]; imshow wants rows = y, so transpose.
	image = counts.T
	if dolog:
		image = numpy.log(logoffset + image)
	extent = (xedges.min(), xedges.max(), yedges.min(), yedges.max())
	imshow(image, extent=extent, aspect='auto', cmap=cmap)
	if minmax is None:
		minmax = set_image_color_percentiles(image, pctlo, pcthi)
	else:
		gci().set_clim(minmax)
	return (image, minmax, extent, xedges, yedges)

# range is a 2-by-2 array-like: [[xlo,xhi],[ylo,yhi]]
def plot_mean_vector_field(x, y, vx, vy, bins=(20,20), range=None, scale=None, do_ellipses=False, vscale=0):
	import __builtin__
	datarange = range
	range = __builtin__.range
	if datarange is None:
		datarange = [plot_range(x), plot_range(y)]
	(xlo,xhi) = (datarange[0][0], datarange[0][1])
	(ylo,yhi) = (datarange[1][0], datarange[1][1])
	xedges = linspace(xlo, xhi, bins[0]+1)
	yedges = linspace(ylo, yhi, bins[1]+1)
	xbinw = xedges[1] - xedges[0]
	ybinw = yedges[1] - yedges[0]

	# bin centers
	xc = xedges[:-1] + xbinw / 2.
	yc = yedges[:-1] + ybinw / 2.

	mvx = zeros(bins)
	mvy = zeros(bins)
	ellipses = []

	for i in range(len(xedges)-1):
		for j in range(len(yedges)-1):
			I = find((x >= xedges[i]  ) *
					 (x <  xedges[i+1]) *
					 (y >= yedges[j]  ) *
					 (y <  yedges[j+1]))
			if len(I) == 0:
				continue
			mvx[j,i] = mean(vx[I])
			mvy[j,i] = mean(vy[I])

			if len(I) < 2:
				continue
			if not do_ellipses:
				continue
			c = cov((vx[I]-mvx[j,i]) * (vy[I]-mvy[j,i]))
			C = array([[var(vx[I]), c], [c, var(vy[I])]])
			(eigenvecs, eigenvals) = linalg.eigh(C)
			angle = arctan(-eigenvals[0,1] / eigenvals[1,1]) * 180./pi
			#print 'ellipse: w', escale * 2.*sqrt(eigenvecs[0])
			#print '         h', escale * 2.*sqrt(eigenvecs[1])
			#print '         angle', angle
			#thisellipse = Ellipse(array([xc[i],yc[j]]),
			#					  escale * 2.*sqrt(eigenvecs[0]),
			#					  escale * 2.*sqrt(eigenvecs[1]),
			#				  angle)
			#ellipses.append(thisellipse)

	print 'mean vector length:', mean(sqrt(mvx**2 + mvy**2).ravel())
	if vscale == 0:
		(nil,vscale) = plot_range(sqrt(mvx**2 + mvy**2).ravel())
	(figw,figh) = gcf().get_size_inches()
	scale = 1/3. * bins[0] * vscale / figw #(datarange[0][1] - datarange[0][0])

	(XC,YC) = meshgrid(xc, yc)
	#plot(XC.ravel(), YC.ravel(), 'b.')
	if scale is not None:
		# scale: length of vector that should fill a grid cell.
		scale *= bins[0]
	Q = quiver(XC, YC, mvx, mvy, pivot='middle', scale=scale)
	if do_ellipses:
		for e in ellipses:
			gca().add_artist(e)
			e.set_facecolor('none')
			# e.set_facecolor('r')
			# e.set_alpha(0.2)
	return (Q,vscale)

def read_cache(fn, func, args, kwargs={}, write_cache=True):
	gotit = False
	if os.path.exists(fn):
		try:
			print 'Reading pickle', fn
			X = pickle.loads(read_file(fn))
			gotit = True
		except:
			pass
	if not gotit:
		X = func(*args, **kwargs)
		if write_cache:
			print 'Writing pickle', fn
			write_file(pickle.dumps(X, pickle.HIGHEST_PROTOCOL), fn)
	return X

def plotdec(dec):
	'''Map Dec (deg) to sin(Dec), the equal-area plotting y coordinate.'''
	dec_rad = deg2rad(dec)
	return sin(dec_rad)

def all_pairs_plots(x, prefix=None):
	'''Grid of 1-D / 2-D histograms of catalogue quantities, saved as a PDF.

	Cell (i, j) of a len(balls) x len(balls) subplot grid shows a histogram of
	quantity i when i == j, and otherwise a 2-D histogram of quantity i
	(x axis) against quantity j (y axis).  Only the row whose y quantity is
	Fe/H ('feh') is currently drawn.  The figure goes to <prefix>-all.pdf.

	prefix: output filename prefix.  If None, a module-level ``prefix`` is
	  used when one exists (this function previously relied on that global
	  implicitly, raising NameError otherwise), else 'allpairs'.
	'''
	if prefix is None:
		prefix = globals().get('prefix', 'allpairs')

	# (axis label, values, plot range or None to compute, short name)
	balls = [
		('Rix distance (kpc)', x.dist, None, 'dist'),
		('mu_l*cos(b), no correction (mas/yr)', x.old_pml, None, 'origpml'),
		('mu_b, no correction (mas/yr)', x.old_pmb, None, 'origpmb'),
		('mu_l*cos(b), v_sun correction (mas/yr)', x.pml, None, 'pml'),
		('mu_b, v_sun correction (mas/yr)', x.pmb, None, 'pmb'),
		('v_l, no correction (km/s)', pmdisttovelocity(x.old_pml, x.dist), None, 'origvl'),
		('v_b, no correction (km/s)', pmdisttovelocity(x.old_pmb, x.dist), None, 'origvb'),
		('v_l, v_sun correction (km/s)', pmdisttovelocity(x.pml, x.dist), None, 'vl'),
		('v_b, v_sun correction (km/s)', pmdisttovelocity(x.pmb, x.dist), None, 'vb'),
		('l (deg)', x.l, None, 'l'),
		('b (deg)', x.b, None, 'b'),
		('RA (deg)', x.ra, None, 'ra'),
		('Dec (deg)', x.dec, None, 'dec'),
		('r (mag)', x.r, None, 'r'),
		('Fe/H (dex)', x.feh, None, 'feh'),
		]

	def ranged(k):
		# Fill in (and cache) the plot range of quantity k on first use.
		(label, vals, rng, name) = balls[k]
		if rng is None:
			balls[k] = (label, vals, plot_range(vals), name)
		return balls[k]

	nballs = len(balls)
	figure(figsize=(60,60))
	clf()
	for i in range(nballs):
		for j in range(nballs):
			if balls[j][3] != 'feh':
				continue
			(xl, xi, xr, xn) = ranged(i)
			(yl, yi, yr, yn) = ranged(j)
			subplot(nballs, nballs, j*nballs + i + 1)
			if i == j:
				hist(xi, 50, range=xr)
				xlabel(xl)
			else:
				plot_2d_hist(xi, yi, range=[xr,yr], pcthi=99)
				xlabel(xl)
				ylabel(yl)
				xlim(xr)
				ylim(yr)
	savefig(prefix + '-all.pdf')
	figure()

def density_plots(x, density, prefix):
	mwgc = parse_mwgc()
	print 'total of %i GCs' % len(mwgc)
	nearmwgc = [xx for xx in mwgc if xx[3] < 5.0]
	farmwgc  = [xx for xx in mwgc if xx[3] >= 5.0]
	print '  %i closer than 5 kpc' % len(nearmwgc)
	gcra = array([xx[1] for xx in nearmwgc])
	gcdec = array([xx[2] for xx in nearmwgc])
	gcnames = array([xx[0] for xx in nearmwgc])
	fargcra = array([xx[1] for xx in farmwgc])
	fargcdec = array([xx[2] for xx in farmwgc])
	fargcnames = array([xx[0] for xx in farmwgc])

	rar  = plot_range(x.ra)
	pdecr = plot_range(x.plotdec)
	decr = plot_range(x.dec)

	# Make RA,Dec plots of density.
	bins = (100,100)
	clf()
	(H,minmax,nil2) = plot_2d_hist(x.ra, x.plotdec, weights=density, bins=bins)
	plot(gcra, plotdec(gcdec), 'rx')
	#for (r,pd,n) in zip(gcra,plotdec(gcdec),gcnames):
	#	text(r,pd,n, horizontalalignment='left', verticalalignment='bottom')
	xlabel('RA (deg)')
	ylabel('Dec (deg)')
	dec_ticks(decr)
	xlim(rar)
	ylim(pdecr)
	savefig(prefix + '-density.png')
	
	clf()
	(H2,nil,nil2) = plot_2d_hist(x.ra, x.plotdec, minmax=minmax, bins=bins)
	minmax = (0, minmax[1] * median(H2.ravel()) / median(H.ravel()))
	clf()
	(H2,nil,ext) = plot_2d_hist(x.ra, x.plotdec, minmax=minmax, bins=bins)
	plot(gcra, plotdec(gcdec), 'rx')
	xlabel('RA (deg)')
	ylabel('Dec (deg)')
	dec_ticks(decr)
	xlim(rar)
	ylim(pdecr)
	savefig(prefix + '-density-unweighted.png')

	clf()
	R = H.copy()
	R[H2 == 0] = 0
	R /= maximum(H2,1)
	print 'R range:', R.min(), R.max()
	imshow(R, extent=ext, aspect='auto', cmap=antigray)
	plot(gcra, plotdec(gcdec), 'rx')
	plot(fargcra, plotdec(fargcdec), 'ro', mec='r', mfc='none')
	xlabel('RA (deg)')
	ylabel('Dec (deg)')
	dec_ticks(decr)
	xlim(rar)
	ylim(pdecr)
	savefig(prefix + '-density-divided.png')

def compute_density(inds, N):
	'''Count match pairs per point.

	inds: (M, 2) array of point-index pairs (I1, I2).
	N: number of points.
	Returns (I1, I2, leftindex, rightindex, density):
	  I1, I2: the two columns of inds;
	  leftindex, rightindex: I1 and I2 restricted to pairs with I1 < I2;
	  density: float array of length N; density[i] is how many times i
	    appears in I1.
	Raises IndexError if an I1 entry is >= N.
	'''
	print('Computing density...')
	inds = numpy.asarray(inds)
	I1 = inds[:,0]
	I2 = inds[:,1]
	ordered = (I1 < I2)
	leftindex = I1[ordered]
	rightindex = I2[ordered]
	# bincount replaces a per-pair Python loop.
	counts = numpy.bincount(I1.astype(int), minlength=N)
	if len(counts) > N:
		raise IndexError('compute_density: index in inds[:,0] out of range for N=%i' % N)
	density = counts.astype(float)
	return (I1,I2,leftindex,rightindex,density)


def mean_and_covar(x, y):
	'''Return (mean of x, mean of y, 2x2 sample covariance matrix of (x, y)).'''
	return (mean(x), mean(y), cov(x, y))

# Draw the 1-sigma ellipse of a 2-D Gaussian on the current axes.
def plot_ellipse_from_mean_and_covar(mx, my, c):
	'''Add an Ellipse centred on (mx, my) whose semi-axes are sqrt(eigenvalues of c).

	c: 2x2 symmetric covariance matrix.  Uses eigh, which gives real
	eigenvalues and orthonormal eigenvector columns.  The general eig
	returned complex eigenvalues, which made the ellipse widths complex.
	Returns the Ellipse artist so callers can style it.
	'''
	(eigvals, eigvecs) = linalg.eigh(c)
	# Ellipse's "width" axis is rotated by `angle`, so point it along the
	# eigenvector for eigvals[0] (column 0).  atan2 avoids a zero divide.
	angle = rad2deg(arctan2(eigvecs[1,0], eigvecs[0,0]))
	e = Ellipse(array([mx, my]), 2 * sqrt(eigvals[0]),
				2 * sqrt(eigvals[1]), angle)
	a=gca()
	e.set_clip_box(a.bbox)
	a.add_artist(e)
	return e

def ln_gaussian_2d(x, mx, varx):
	'''Natural log of a 2-D Gaussian density, evaluated row by row.

	x:    (N, 2) points.
	mx:   (N, 2) means.
	varx: (N, 2, 2) covariance matrices.
	Returns an (N,) array of ln N(x_i | mx_i, varx_i).
	Raises ValueError if the shapes are inconsistent.
	'''
	# Work in float: with integer covariances, zeros_like(varx) and integer
	# division silently truncated the inverse.
	x = numpy.asarray(x, dtype=float)
	mx = numpy.asarray(mx, dtype=float)
	varx = numpy.asarray(varx, dtype=float)
	(N, two) = x.shape
	(N2, twom) = mx.shape
	(N3, twoa, twob) = varx.shape
	if (two, twom, twoa, twob) != (2, 2, 2, 2) or not (N == N2 == N3):
		raise ValueError('ln_gaussian_2d: expected shapes (N,2), (N,2), (N,2,2); got %s, %s, %s'
						 % (x.shape, mx.shape, varx.shape))

	# Closed-form 2x2 determinant and inverse.
	detvar = varx[:,0,0] * varx[:,1,1] - varx[:,0,1] * varx[:,1,0]
	inv00 =  varx[:,1,1] / detvar
	inv11 =  varx[:,0,0] / detvar
	inv01 = -varx[:,0,1] / detvar
	inv10 = -varx[:,1,0] / detvar
	dx = x - mx
	# Mahalanobis distance dx^T varx^-1 dx.
	d2 = (dx[:,0] * inv00 * dx[:,0] +
		  dx[:,1] * inv10 * dx[:,0] +
		  dx[:,0] * inv01 * dx[:,1] +
		  dx[:,1] * inv11 * dx[:,1])
	return -0.5 * d2 - numpy.log(2.*numpy.pi) - 0.5 * numpy.log(detvar)

def ln_gaussian_1d(x, mx, varx):
	'''Natural log of a 1-D Gaussian density N(x | mx, varx).

	x, mx, varx: scalars or broadcast-compatible arrays (varx is the
	variance).  Returns the elementwise log-density with the broadcast shape.
	'''
	resid = x - mx
	chi2 = resid * (1. / varx) * resid
	return -0.5 * chi2 - 0.5 * log(2. * pi * varx)

def flat_background_pm_lnprob(x, pmrarange, pmdecrange):
	'''Log-density of a uniform distribution over the (pmra, pmdec) box.

	*x* is unused; it is accepted to match the other background terms.
	'''
	box_area = (pmrarange[1] - pmrarange[0]) * (pmdecrange[1] - pmdecrange[0])
	return -log(box_area)

def gaussian_background_pm_lnprob(x, components, pmmeans, pmvars):
	'''Log-density of each star's (pmra, pmdec) under its component's 2-D Gaussian.

	components: per-star integer index selecting a row of pmmeans (2-vectors)
	and of pmvars (2x2 covariances).
	'''
	pm = vstack((x.pmra, x.pmdec)).T
	star_means = pmmeans[components,:]
	star_vars = pmvars[components,:,:]
	return ln_gaussian_2d(pm, star_means, star_vars)

def background_radecdistfeh_lnprob(x, components, radec_area, distfeh_amplitudes):
	'''Per-star background log-density of position and (dist, [Fe/H]) cell.

	Adds -log(radec_area) (uniform over the sky patch) and
	log(distfeh_amplitudes[component]); each term is skipped when its
	argument is None.  Returns an array shaped like x.ra.
	'''
	result = zeros_like(x.ra)
	if radec_area is not None:
		result -= log(radec_area)
	if distfeh_amplitudes is not None:
		result += log(distfeh_amplitudes[components])
	return result

def logsum(x, y):
	'''Return log(exp(x) + exp(y)) elementwise, without overflow or underflow.

	Uses numpy.logaddexp, which (unlike subtracting max(x, y) by hand)
	returns -inf rather than NaN when both inputs are -inf.
	'''
	return numpy.logaddexp(x, y)

def background_data_lnprob(x, components, radec_area, distfeh_amplitudes,
			   pmmeans=None, pmvars=None, pmrarange=None, pmdecrange=None,
			   beta=0.01):
	'''Per-star log-likelihood under the background model.

	x: table with x.ra, x.dec, x.pmra, x.pmdec, ...
	components: per-star Gaussian component index.
	radec_area: solid angle in deg^2 (cos(Dec) already accounted for).
	distfeh_amplitudes: per-component (dist, [Fe/H]) amplitudes.
	pmmeans, pmvars, pmrarange, pmdecrange: proper-motion model.  When all
	  four are given, a PM term is added: a mixture of a uniform density over
	  the PM box (weight beta) and the component's Gaussian (weight 1-beta).
	  When any is None (the default), the PM term is omitted.
	'''
	lnp = background_radecdistfeh_lnprob(x, components, radec_area, distfeh_amplitudes)
	if (pmmeans is not None and pmvars is not None and
		pmrarange is not None and pmdecrange is not None):
		lnp_flat = flat_background_pm_lnprob(x, pmrarange, pmdecrange)
		lnp_gauss = gaussian_background_pm_lnprob(x, components, pmmeans, pmvars)
		# log(beta * p_flat + (1 - beta) * p_gauss), computed in log space.
		lnp += logsum(log(beta) + lnp_flat, log(1.-beta) + lnp_gauss)
	return lnp

def distance_from_radec_line(ra, dec, racenter, deccenter, angle):
	'''Perpendicular distance (deg) of each (ra, dec) from a line on the sky.

	The line passes through (racenter, deccenter), all in degrees, along
	direction angle (radians, measured from +RA towards +Dec, as produced by
	arctan2(pmdec, pmra)).  Approximate: this is a flat tangent-plane
	calculation with RA offsets scaled by cos(deccenter), so it fails near
	the poles and for large separations.
	'''
	east = (ra - racenter) * cos(deg2rad(deccenter))
	north = dec - deccenter
	offsets = vstack((east, north)).T
	# Unit normal to the line direction (cos(angle), sin(angle)).
	normal = array([-sin(angle), cos(angle)])
	normal = normal / norm(normal)
	return absolute(dot(offsets, normal))

def foreground_data_lnprob(x, ra, dec, dist, feh, pmra, pmdec, streamwidth, streamlength, angle_offset=0., includepm=True, fehsigma=0.2):
	'''Per-star log-likelihood under a stream (foreground) model.

	The stream passes through (ra, dec) along the proper-motion direction
	arctan2(pmdec, pmra) + angle_offset (radians).  Terms are summed only for
	the model parameters that are not None:
	  position: Gaussian in perpendicular distance from the stream line
	            (stddev streamwidth, deg) times a tophat of length
	            streamlength (deg) along it;
	  dist:     Gaussian with variance x.dd_in_kpc**2 (presumably the
	            per-star distance error -- confirm);
	  feh:      Gaussian with stddev fehsigma (dex);
	  pmra/pmdec (if includepm): Gaussians with variances x.dpmra**2 and
	            x.dpmdec**2.
	Returns an array shaped like x.ra.
	'''
	fehvar = fehsigma**2
	lnp = zeros_like(x.ra)
	if ra is not None and dec is not None:
		angle = arctan2(pmdec, pmra) + angle_offset
		lnp += ln_gaussian_1d(distance_from_radec_line(x.ra, x.dec, ra, dec, angle),
							  0., streamwidth**2) - log(streamlength)
	if dist is not None:
		lnp += ln_gaussian_1d(x.dist, dist, x.dd_in_kpc**2)
	if feh is not None:
		lnp += ln_gaussian_1d(x.feh, feh, fehvar)
	if includepm:
		# Each PM component is optional, consistent with the other terms.
		if pmra is not None:
			lnp += ln_gaussian_1d(x.pmra, pmra, x.dpmra**2)
		if pmdec is not None:
			lnp += ln_gaussian_1d(x.pmdec, pmdec, x.dpmdec**2)
	return lnp

# returns:
#  cut-down of x
#  solid angle in deg^2
#  distrange in kpc
#  fehrange in dex
#  pmrarange, pmdecrange in mas/yr
def make_6d_cut(ra, dec, x, r=10.):
	fehmin = -2
	fehmax = -1.4
	I = points_within_radius(ra, dec, r, x.ra, x.dec)
	print '%i within RA,Dec range' % sum(I)

	y = x[I]
	distrange = plot_range(y.dist, quantile=0.99)
	fehrange  = plot_range(y.feh,  quantile=0.99)
	fehrange = (max(fehrange[0], fehmin), min(fehrange[1], fehmax))
	I = ((y.dist >= distrange[0]) * (y.dist <= distrange[1]) *
		 (y.feh >= fehrange[0]) * (y.feh <= fehrange[1]))
	# points within the ra,dec circle and dist,FeH ranges.
	y = y[I]
	print '%i within dist,FeH range' % len(y)
	pmrarange = plot_range(y.pmra, quantile=0.99)
	pmdecrange = plot_range(y.pmdec, quantile=0.99)
	I = ((y.pmra >= pmrarange[0]) * (y.pmra <= pmrarange[1]) *
		 (y.pmdec >= pmdecrange[0]) * (y.pmdec <= pmdecrange[1]))
	y = y[I]
	print '%i within pm range' % len(y)
	return (y, pi*r**2, distrange, fehrange, pmrarange, pmdecrange)

def is_edge_of_survey(ra, dec, x, rin=8, rout=10, nazbins=20):
	'''Return True if the annulus around (ra, dec) is not fully populated.

	Stars of *x* between rin and rout degrees from (ra, dec) are binned by
	position angle into nazbins equal azimuthal bins.  If any bin is empty,
	the position is treated as being near the edge of the survey.
	'''
	ring = x[points_within_radius_range(ra, dec, rin, rout, x.ra, x.dec)]
	theta = pi + arctan2(dec - ring.dec, (ra - ring.ra) * cos(deg2rad(dec)))
	anglebin = floor(theta / (2.*pi/float(nazbins))).astype(int)
	# arctan2 can return exactly +pi, making theta == 2*pi; fold that into the
	# last bin rather than creating an extra one that masks an empty bin.
	anglebin = minimum(anglebin, nazbins - 1)
	return (len(unique(anglebin)) < nazbins)

def build_component_model(y, distrange, fehrange, distbins=4, fehbins=3):
	'''Bin stars into a (distance, [Fe/H]) grid and fit a PM Gaussian per cell.

	The grid has distbins x fehbins equal cells spanning distrange x fehrange;
	cells are numbered row by row in [Fe/H], then distance.

	Returns (components, distfeh_amplitudes, pmmeans, pmvars):
	  components         -- per-star cell index (-1 if outside the grid);
	  distfeh_amplitudes -- per cell, fraction of stars divided by cell area;
	  pmmeans            -- (NC, 2) mean (pmra, pmdec) per cell;
	  pmvars             -- (NC, 2, 2) PM covariance per cell.  Cells with
	                        fewer than 2 stars use an isotropic default of
	                        variance 3 (in PM units squared).
	'''
	xe = linspace(distrange[0], distrange[1], endpoint=True, num=distbins+1)
	ye = linspace(fehrange[0], fehrange[1], endpoint=True, num=fehbins+1)

	NC = (len(ye)-1) * (len(xe)-1)
	components = ones_like(y.ra).astype(int) * -1
	pmmeans = zeros((NC,2))
	pmvars = zeros((NC,2,2))
	distfeh_amplitudes = zeros((NC))

	c = 0
	for ynum,(ylo,yhi) in enumerate(zip(ye[:-1],ye[1:])):
		# The first bin includes its lower edge: range ends are data values,
		# so at least one star sits exactly on fehrange[0] / distrange[0] and
		# would otherwise get component -1 (silently indexing the last cell).
		iny = ((y.feh >= ylo) if ynum == 0 else (y.feh > ylo)) * (y.feh <= yhi)
		for xnum,(xlo,xhi) in enumerate(zip(xe[:-1],xe[1:])):
			inx = ((y.dist >= xlo) if xnum == 0 else (y.dist > xlo)) * (y.dist <= xhi)
			I = inx * iny
			yi = y[I]
			n = len(yi)
			if n >= 2:
				(mpmra, mpmdec, C) = mean_and_covar(yi.pmra, yi.pmdec)
			else:
				# Too few stars for a covariance (and none for a mean when n == 0).
				mpmra = mean(yi.pmra) if n else 0.
				mpmdec = mean(yi.pmdec) if n else 0.
				C = array([[3,0],[0,3]])
			components[I] = c
			pmmeans[c,0] = mpmra
			pmmeans[c,1] = mpmdec
			pmvars[c,:,:] = C
			# quasi-safe...
			distfeh_amplitudes[c] = n / float(len(y)) / ((xhi-xlo)*(yhi-ylo))
			c += 1
	return (components, distfeh_amplitudes, pmmeans, pmvars)

def brutishly_find(x, overall_prefix):
	ralo = floor(min(x.ra))
	rahi = ceil(max(x.ra))
	declo = floor(min(x.dec))
	dechi = ceil(max(x.dec))

	x.deltalogprob = zeros_like(x.ra)
	x.bestalpha = zeros_like(x.ra)
	x.edgeofsurvey = zeros_like(x.ra).astype(bool)

	picklefn = overall_prefix + '-brutish.pickle'

	declo = 8.

	# number of azimuth bins for detecting edge-of-survey
	nazbins = 20.

	for dlo in arange(declo, dechi):
		D = (x.dec > dlo) * (x.dec <= dlo + 1.)
		for rlo in arange(ralo, rahi):
			# points within the 1-square-degree box.
			R = D * (x.ra > rlo) * (x.ra <= rlo + 1.)
			if sum(R) < 2:
				continue
			print 'rlo', rlo, 'dlo', dlo, 'Nstars', sum(R)
			eos = is_edge_of_survey(rlo+0.5, dlo+0.5, x)
			x[R].edgeofsurvey = eos
			if eos:
				print 'edge of survey.'
				continue
			(y, solidangle, distrange, fehrange, pmrarange, pmdecrange) = (
				make_6d_cut(rlo+0.5, dlo+0.5, x))

			(components, distfeh_amplitudes, pmmeans, pmvars) = (
				build_component_model(y, distrange, fehrange))

			bg = background_data_lnprob(y, components, solidangle, distfeh_amplitudes, pmmeans, pmvars, pmrarange, pmdecrange)
			print 'bg', min(bg), median(bg), max(bg)
			for jboss in find(R):
				boss = x[int(jboss)]
				streamw = 1. # deg
				streamlen = 2. * sqrt(solidangle/pi)
				fg = foreground_data_lnprob(y, boss.ra, boss.dec, boss.dist,
											boss.feh, boss.pmra, boss.pmdec,
											streamw, streamlen)
				print 'fg', min(fg), median(fg), max(fg)
				alphas = 10.**arange(-4, -1+0.1, 0.5)
				ps = array([sum(log(alpha*exp(fg) + (1.-alpha)*exp(bg)))
							for alpha in alphas])
				
				besti = argmax(ps)
				strawman = sum(bg)
				x.deltalogprob[jboss] = ps[besti] - strawman
				x.bestalpha   [jboss] = alphas[besti]
				if x.deltalogprob[jboss] < 4:
					continue

				# check a finer grid.
				loalpha = max(log10(1e-5), log10(min(alphas[ps > (ps[besti] - 10.)])) - 1)
				hialpha = log10(max(alphas[ps > (ps[besti] - 10.)])) + 1
				alphas = 10.**(arange(loalpha, hialpha+0.0001, 0.01))
				ps = [sum(log(alpha*exp(fg) + (1.-alpha)*exp(bg)))
					  for alpha in alphas]
				besti = argmax(ps)
				besta = alphas[besti]

				x.deltalogprob[jboss] = ps[besti] - strawman
				x.bestalpha   [jboss] = alphas[besti]

				prefix = overall_prefix + '-%.5f-%06i' % (x.bestalpha[jboss], jboss)

				clf()
				semilogx(alphas, ps, 'r-')
				xlabel('alpha')
				ylabel('log(p(data))')
				tit = 'dist=%.1f, [FeH]=%.1f' % (boss.dist, boss.feh)
				title(tit)
				ylim(max(ps)-10, max(ps)+1)
				xlim(10.**(loalpha),10.**(hialpha))
				savefig(prefix + '-alpha.png')

				clf()
				plot(y.ra, y.dec, 'k,', alpha=0.3)
				dyr = 0.1
				plot([boss.ra, boss.ra+boss.pmra/cos(deg2rad(boss.dec))*dyr],
					 [boss.dec, boss.dec+boss.pmdec*dyr], 'k-', lw=5)
				plot([boss.ra], [boss.dec], 'ko', ms=10, mfc='none')
				xlabel('RA (deg)')
				ylabel('Dec (deg)')
				# assign stars to the stream.
				logfrac = (log(besta)+fg) - (log(1.-besta)+bg)
				NFG = int(0.5 + round(besta * len(y)))
				print 'best alpha:', besta
				print 'N fg:', NFG
				IN = argsort(-logfrac)[:NFG]
				plot(y.ra[IN], y.dec[IN], 'k.')
				title(tit)
				savefig(prefix + '-radec.png')

				print 'Saving result to', picklefn
				write_file(pickle.dumps(x, pickle.HIGHEST_PROTOCOL), picklefn)

	print 'Saving result to', picklefn
	write_file(pickle.dumps(x, pickle.HIGHEST_PROTOCOL), picklefn)


def candidate_plot(x, prefix, rac, decc, r):
	print '%i objects' % len(x)

	if False:
		fehrange = plot_range(x.feh)
		C = bluegrayred(minimum(1, maximum(0, (x.feh - fehrange[0])/(fehrange[1]-fehrange[0]))))
		X = x.ra
		Y = x.dec
		U = x.pmra
		V = x.pmdec
		(nil,vscale) = plot_range(sqrt(U**2 + V**2).ravel())
		(figw,nil) = gcf().get_size_inches()
		scale = 1000/3. * vscale / figw
		rar = plot_range(x.ra)
		decr = plot_range(x.dec)

	I = points_within_radius(rac, decc, r, x.ra, x.dec)
	print '%i within RA,Dec range' % sum(I)
	y = x[I]
	distrange = plot_range(y.dist, quantile=0.99)
	fehrange  = plot_range(y.feh,  quantile=0.99)
	fehrange = (max(fehrange[0], -2), min(fehrange[1], -1.4))
	I = ((y.dist >= distrange[0]) * (y.dist <= distrange[1]) *
		 (y.feh >= fehrange[0]) * (y.feh <= fehrange[1]))
	y = y[I]
	print '%i within dist,FeH range' % len(y)

	clf()
	(H,nil1,nil2,xe,ye) = plot_2d_hist(y.dist, y.feh, bins=(4,3),
									   range=[distrange, fehrange])
	xlabel('dist (kpc)')
	ylabel('[Fe/H] (dex)')
	(xc,yc) = meshgrid((xe[1:] + xe[:-1])/2., (ye[1:] + ye[:-1])/2.)
	colorbar()
	a = axis()
	(xpp,ypp) = get_pixel_scales()
	for xi,yi,n in zip(xc.ravel(), yc.ravel(), H.ravel()):
		for dx,dy in [(-1,-1),(-1,1),(1,1),(1,-1)]:
			text(xi + dx*xpp, yi + dy*ypp, int(n), color='k', fontsize=20, horizontalalignment='center')
		text(xi, yi, int(n), color=(0,1,0), fontsize=20, horizontalalignment='center')
	axis(a)
	savefig(prefix + '.png')

	NC = (len(ye)-1) * (len(xe)-1)
	components = ones_like(y.ra).astype(int) * -1
	pmmeans = zeros((NC,2))
	pmvars = zeros((NC,2,2))
	distfeh_amplitudes = zeros((NC))

	clf()
	pmrar = plot_range(y.pmra, qlo=0.005, qhi=0.995)
	pmdecr = plot_range(y.pmdec, qlo=0.005, qhi=0.995)
	c = 0
	for ynum,(ylo,yhi) in enumerate(zip(ye[:-1],ye[1:])):
		for xnum,(xlo,xhi) in enumerate(zip(xe[:-1],xe[1:])):
			I = (y.dist > xlo) * (y.dist <= xhi) * (y.feh > ylo) * (y.feh <= yhi)
			yi = y[I]
			subplot(len(ye)-1, len(xe)-1, (len(ye)-2-ynum)*(len(xe)-1) + xnum + 1)
			plot(yi.pmra, yi.pmdec, 'k,')
			MC = mean_and_covar(yi.pmra, yi.pmdec)
			if len(yi) < 2:
				MC = list(MC)
				MC[2] = array([[3,0],[0,3]])
			e = plot_ellipse_from_mean_and_covar(*MC)
			e.set_facecolor('none')
			e.set_edgecolor('b')
			e.set_zorder(10)
			text((pmrar[0]+pmrar[1])/2., (pmdecr[0]+pmdecr[1])/2., len(yi), color=(1,0,0), fontsize=20, horizontalalignment='center')
			xlim(pmrar)
			ylim(pmdecr)
			if not (ynum == 0 and xnum == 0):
				xticks([],[])
				yticks([],[])

			components[I] = c
			pmmeans[c,0] = MC[0]
			pmmeans[c,1] = MC[1]
			pmvars[c,:,:] = MC[2]
			# quasi-safe...
			distfeh_amplitudes[c] = len(yi) / float(len(y)) / ((xhi-xlo)*(yhi-ylo))
			c += 1
	subplots_adjust(wspace=0.05, hspace=0.05)
	savefig(prefix + '-gridpmra.png')

	boss = uniform(len(y))

	bg = background_data_lnprob(y, components, pi*r**2, distfeh_amplitudes, pmmeans, pmvars)
	print 'bg', min(bg), median(bg), max(bg)

	if False:
		test1d = exp(ln_gaussian_1d(linspace(0, 100, 100), 50, 10**2))
		print 'test1d:', sum(test1d)
		NX = 100
		NY = 100
		mx = ones((NX*NY,2)) * 50
		varx = zeros((NX*NY,2,2))
		varx[:,0,0] = 10**2
		varx[:,1,1] = 10**2
		X,Y = meshgrid(linspace(0,NX,NX), linspace(0,NY,NY))
		test2d = exp(ln_gaussian_2d(vstack((X.ravel(),Y.ravel())).T, mx, varx))
		print 'test2d:', sum(test2d)

	J = argsort(-radecdotproducts(rac, decc, y.ra, y.dec))
	for j,jboss in enumerate(J[:50]):
		boss = y[int(jboss)]
		fg = foreground_data_lnprob(y, boss.ra, boss.dec, boss.dist,
									boss.feh, boss.pmra, boss.pmdec,
									1., 2.*r)
		print 'fg', min(fg), median(fg), max(fg)

		alphas = 10.**arange(-4, -1+0.1, 0.5)
		ps = array([sum(log(alpha*exp(fg) + (1.-alpha)*exp(bg)))
					for alpha in alphas])
		clf()
		#print 'alphas', alphas
		#print 'ps', ps
		semilogx(alphas, ps, 'r-')
		xlabel('alpha')
		ylabel('log(p(data))')
		tit = 'dist=%.1f, [FeH]=%.1f' % (boss.dist, boss.feh)
		title(tit)
		ylim(max(ps)-10, max(ps)+1)
		savefig(prefix + '-boss%03i'%j + '-alpha.png')

		besti = argmax(ps)
		if besti != 0 and ps[besti] > ps[0] + 4:

			clf()
			loalpha = max(log10(1e-5), log10(min(alphas[ps > (ps[besti] - 10.)])) - 1)
			hialpha = log10(max(alphas[ps > (ps[besti] - 10.)])) + 1
			alphas = 10.**(arange(loalpha, hialpha+0.0001, 0.01))
			ps = [sum(log(alpha*exp(fg) + (1.-alpha)*exp(bg)))
				  for alpha in alphas]
			besti = argmax(ps)
			besta = alphas[besti]
			semilogx(alphas, ps, 'r-')
			xlabel('alpha')
			ylabel('log(p(data))')
			title(tit)
			ylim(max(ps)-10, max(ps)+1)
			xlim(10.**(loalpha),10.**(hialpha))
			savefig(prefix + '-boss%03i'%j + '-alpha.png')

			clf()
			plot(y.ra, y.dec, 'k,', alpha=0.3)
			dyr = 0.1
			plot([boss.ra, boss.ra+boss.pmra/cos(deg2rad(boss.dec))*dyr],
				 [boss.dec, boss.dec+boss.pmdec*dyr], 'k-', lw=5)
			plot([boss.ra], [boss.dec], 'ko', ms=10, mfc='none')
			xlabel('RA (deg)')
			ylabel('Dec (deg)')
			# assign stars to the stream.
			logfrac = (log(besta)+fg) - (log(1.-besta)+bg)
			NFG = int(0.5 + round(besta * len(y)))
			print 'best alpha:', besta
			print 'N fg:', NFG
			IN = argsort(-logfrac)[:NFG]
			plot(y.ra[IN], y.dec[IN], 'k.')
			title(tit)
			savefig(prefix + '-boss%03i'%j + '-radec.png')

			clf()
			plot(y.g - y.r, y.r, 'k,', alpha=0.3)
			plot([boss.g - boss.r], [boss.r], 'ko', ms=10, mfc='none')
			xlabel('g - r (mag)')
			ylabel('r (mag)')
			plot([y.g[IN] - y.r[IN]], [y.r[IN]], 'k.')
			title(tit)
			savefig(prefix + '-boss%03i'%j + '-colormag.png')


	return

	dstep = 0.5
	for dlo in arange(0, 3.99, dstep):
		I = (x.dist > dlo) * (x.dist <= (dlo+dstep))
		print '%i with dlo=%f' % (sum(I),dlo)
		if sum(I) == 0:
			continue
		clf()
		quiver(X[I], Y[I], U[I], V[I], C[I], pivot='middle', scale=scale, alpha=0.5)
		xlim(rar)
		ylim(decr)
		xlabel('RA (deg)')
		ylabel('Dec (deg)')
		savefig(prefix + '-%.1f' % dlo + '.png')


def read_data_file(datafn, prefix, scramble=False):
	prefix = prefix + datafn
	x = read_cache(prefix + '.pickle', text_table_fields, (datafn,))
	cols = x.columns()
	print 'data file columns:', cols
	# Rix changed his column names...
	if not 'd_kpc' in cols:
		x.d_kpc = x.d_in_kpc
		x.delete_column('d_in_kpc')
	if scramble:
		print 'Scrambling proper motions and distances...'
		random.seed(42)
		I = random.permutation(len(x))
		# we don't want to scramble ra,dec
		x.pmra  = x.pmra [I]
		x.pmdec = x.pmdec[I]
		x.d_kpc = x.d_kpc[I]
		x.feh   = x.feh  [I]
		prefix += '-scramble'

	(l,b,x.old_pml,x.old_pmb) = pm_radectolb(x.ra, x.dec, x.pmra, x.pmdec)
	(x.old_pmra, x.old_pmdec) = (x.pmra, x.pmdec)
	(x.pmra, x.pmdec) = remove_solar_motion(x.ra, x.dec, x.d_kpc, x.pmra, x.pmdec)
	(x.l, x.b, x.pml, x.pmb) = pm_radectolb(x.ra, x.dec, x.pmra, x.pmdec)
	# (pml, pmb) are proper motions in the galactic coordinates l,b
	#     in mas/yr.  pml is d(l*cos(b))/dt.
	x.dist = x.d_kpc
	x.delete_column('d_kpc')
	x.plotdec = plotdec(x.dec)
	x.plotra = plotra(x.ra, x.dec)
	return (x, prefix)

def analyze_data_file(datafn, prefix, scramble=False):
	'''Load *datafn* via read_data_file and run the brute-force stream search.'''
	(catalog, outprefix) = read_data_file(datafn, prefix, scramble)
	# Earlier targeted diagnostics, kept for reference:
	#   candidate_plot(catalog, outprefix + '-candy', 170., 16.5, 10.)
	#   candidate_plot(catalog, outprefix + '-steffi', 190., 30., 10.)
	brutishly_find(catalog, outprefix)

def OLD_DELETE_ME_analyze_data_file():
	"""Obsolete analysis driver, kept for reference; not called in this chunk.

	NOTE(review): takes no arguments but reads `x` and `prefix`, which are not
	defined locally.  It can only run if those exist as module globals (the
	`__main__` block assigns both at module level); otherwise NameError.

	When it does run it:
	  * builds (or, if '<prefix>.kd' exists, reopens) a kd-tree over the scaled
	    vectors (ra*cos(dec), dec, dist, pmra, pmdec, feh), making some
	    diagnostic plots first when building;
	  * self-matches the tree within radius 1 (cached in a pickle);
	  * runs compute_density, pickles {x, density, I1, I2} to 'state.pickle';
	  * terminates the process with sys.exit(0).
	"""
	from astrometry.libkd.spherematch import spherematch_c

	#rar = plot_range(x.ra)
	rar = (130,260)
	decr = array([-3, 63])
	pdecr = plotdec(decr)

	kdfn = prefix + '.kd'
	if not os.path.exists(kdfn):
		# Data-derived plot ranges replace the fixed ones above.
		rar  = plot_range(x.ra)
		pdecr = plot_range(x.plotdec)
		decr = plot_range(x.dec)
		
		if False:
			# check that the proper motions in ra,dec and l,b
			# have the same magnitudes.
			clf()
			I = argmax(abs(sqrt(x.old_pml**2 + x.old_pmb**2) - sqrt(x.old_pmra**2 + x.old_pmdec**2)))
			print 'largest |pm| difference:', max(abs(sqrt(x.old_pml**2 + x.old_pmb**2) - sqrt(x.old_pmra**2 + x.old_pmdec**2)))
			H = plot_2d_hist(sqrt(x.old_pml**2 + x.old_pmb**2),
							 sqrt(x.old_pmra**2 + x.old_pmdec**2))
			set_image_color_percentiles(H, 0, 100)
			plot([0,30],[0,30], 'r--')
			savefig('magnitudes-old.png')
			clf()
			I = argmax(abs(sqrt(x.pml**2 + x.pmb**2) - sqrt(x.pmra**2 + x.pmdec**2)))
			print 'largest |pm| difference:', max(abs(sqrt(x.pml**2 + x.pmb**2) - sqrt(x.pmra**2 + x.pmdec**2)))
			H = plot_2d_hist(sqrt(x.pml**2 + x.pmb**2),
							 sqrt(x.pmra**2 + x.pmdec**2))
			set_image_color_percentiles(H, 0, 100)
			plot([0,30],[0,30], 'r--')
			savefig('magnitudes.png')

			# check that the transformation is conformal.
			# scramble proper motions
			clf()
			J = random.permutation(len(x))
			pmra2 = x.pmra[J]
			pmdec2 = x.pmdec[J]
			(nil1,nil2,pml2,pmb2) = pm_radectolb(x.ra, x.dec, pmra2, pmdec2)
			plot_2d_hist(x.pmra * pmra2 + x.pmdec * pmdec2,
						 x.pml * pml2 + x.pmb * pmb2)
			# NOTE(review): the plot_2d_hist result above is discarded, so H
			# here is still the histogram from the previous plot.
			set_image_color_percentiles(H, 0, 100)
			savefig('dotproducts.png')

		if False:
			xl = plot_range(x.l)
			yl = plot_range(x.b)
			vscale = 0
			t0 = time.clock()
			for fehcut in arange(-0.7, -2.1, -0.1):
				infix = '%.1f' % -fehcut
				clf()
				I = x.feh < fehcut
				thisx = x[I]
				(Q,v) = plot_mean_vector_field(thisx.l, thisx.b, thisx.pml, thisx.pmb, vscale=vscale)
				# NOTE(review): identity test against an int literal; `== 0`
				# is the reliable comparison.
				if vscale is 0:
					vscale = v
				xlabel('$\ell$ (deg)')
				ylabel('b (deg)')
				title('Proper motions, corrected for solar motion, [Fe/H] < %.1f' % fehcut)
				xlim(xl)
				ylim(yl)
				savefig(prefix + '-allsky-grid-pmlb-' + infix + '.png')
				clf()
				plot_mean_vector_field(thisx.l, thisx.b, thisx.old_pml, thisx.old_pmb, vscale=vscale)
				xlabel('$\ell$ (deg)')
				ylabel('b (deg)')
				title('Proper motions, not corrected for solar motion, [Fe/H] < %.1f' % fehcut)
				xlim(xl)
				ylim(yl)
				savefig(prefix + '-allsky-grid-oldpmlb-' + infix + '.png')
			t1 = time.clock()
			print 'time:', (t1-t0)

		if False:
			# NOTE(review): old_pmra/old_pmdec/old_pml/old_pmb are not defined
			# anywhere in this function; presumably x.old_pmra etc. were meant.
			print 'Plotting vector fields...'

			clf()
			(Q,vscale) = plot_mean_vector_field(x.ra, x.plotdec, old_pmra, old_pmdec)
			dec_ticks(decr)
			xlabel('RA (deg)')
			ylabel('Dec (deg)')
			title('Proper motions, no solar motion correction')
			xlim(rar)
			ylim(pdecr)
			savefig(prefix + '-allsky-grid-oldpm.png')

			clf()
			plot_mean_vector_field(x.ra, x.plotdec, x.pmra, x.pmdec, vscale=vscale)
			dec_ticks(decr)
			xlabel('RA (deg)')
			ylabel('Dec (deg)')
			title('Proper motions, solar motion correction')
			xlim(rar)
			ylim(pdecr)
			savefig(prefix + '-allsky-grid-pm.png')

			clf()
			(Q,vscale) = plot_mean_vector_field(x.l, x.b, old_pml, old_pmb)
			xlabel('$\ell$ (deg)')
			ylabel('b (deg)')
			title('Proper motions, no solar motion correction')
			xlim(plot_range(x.l))
			ylim(plot_range(x.b))
			savefig(prefix + '-allsky-grid-oldpmlb.png')

		print 'Making summary plots...'

		# RA/Dec 2-D histograms of metal-poor subsets, binned at roughly one
		# bin per 3 axes pixels.
		for fehcut in [-0.8, -1.4]:
			I = x.feh < fehcut
			thisx = x[I]
			infix = '%.1f' % -fehcut
			(w,h) = get_axes_pixel_size()
			w += 1
			h += 1
			clf()
			plot_2d_hist(thisx.ra, thisx.plotdec, bins=(w/3,h/3))
			xlabel('RA (deg)')
			ylabel('Dec (deg)')
			dec_ticks(decr)
			xlim(rar)
			ylim(pdecr)
			savefig(prefix + '-allsky-ra-dec-' + infix + '.png')

		# generating all pairs of plots takes a long time...
		all_pairs_plots(x)

		# prep data for kdtree insertion
		# Each coordinate is divided by its scale so that a match radius of 1
		# in the tree means "within about one scale length in every dimension".
		rascale = 4.0 # in degrees
		decscale = rascale
		distscale = 0.45 # in kpc
		pmrascale = 4. # in mas/yr
		pmdecscale = pmrascale
		fehscale = 0.2 # in dex

		X = vstack((x.ra * cos(deg2rad(x.dec)) / rascale,
					x.dec / decscale,
					x.dist / distscale,
					x.pmra / pmrascale,
					x.pmdec / pmdecscale,
					x.feh / fehscale)).transpose()
		print 'Building kdtree...'
		kd = spherematch_c.kdtree_build(X)
		print kd
		print 'Writing kdtree to', kdfn
		rtn = spherematch_c.kdtree_write(kd, kdfn)
		print 'rtn=', rtn
	else:
		print 'Reading kdtree from', kdfn
		kd = spherematch_c.kdtree_open(kdfn)


	if False:
		# NOTE(review): kd2 is only assigned when kd2fn is being built; if the
		# file already exists, the read_cache call below hits an undefined kd2.
		# The 'Done!' print also never formats its '%i'.
		kd2fn = prefix + '-xyz.kd'
		if not os.path.exists(kd2fn):
			X = vstack((x.ra * cos(deg2rad(x.dec)), x.dec)).T
			print 'Building kdtree...'
			kd2 = spherematch_c.kdtree_build(X)
			print kd2
			print 'Writing kdtree to', kd2fn
			rtn = spherematch_c.kdtree_write(kd2, kd2fn)
		def match_radec(kd, r):
			print 'Matching...'
			(inds,nil) = spherematch_c.match(kd, kd, r)
			print 'Done!  Got %i pairs', len(inds)
			return inds
		M = read_cache(prefix + '-radecmatch.pickle', match_radec, (kd2, 2))
		print 'Got %i matches' % len(M)

	# Self-match the 6-D tree.  Only one radius is tried: the sys.exit(0) at
	# the end of the loop body terminates the process after the first pass.
	for radius in [1.]:
		fn = prefix + '-match-%.2f.pickle' % radius
		def match_6d(kd, r):
			print 'Matching with radius', r, '...'
			(inds,nil) = spherematch_c.match(kd, kd, r)
			print 'Done!'
			return inds
		inds = read_cache(fn, match_6d, (kd, radius))
		print 'Got %i matches' % len(inds)

		(I1,I2,leftindex,rightindex,density) = compute_density(inds, len(x))
		#density_plots(x, density, prefix)
		print 'Max density:', density.max()

		state = {'x':x,
				 'density':density,
				 'I1':I1,
				 'I2':I2,
				 }
		statefn = 'state.pickle'
		print 'pickling...'
		write_file(pickle.dumps(state, pickle.HIGHEST_PROTOCOL), statefn)
		print 'done.'
		sys.exit(0)

def tube_density(x, I1, I2, density):
	streamw = 1.
	tubedensity = zeros(len(density))
	I = argsort(-density)
	for ii,left in enumerate(I):
		J = (I1 == left) * (I2 != left)
		right = I2[J]
		#print 'Density:', density[left], ', left point', left, ' number of right points:', len(right)
		# compute difference in the local (equal-area) ra,dec space
		u0 = x.ra [left]  * cos(deg2rad(x.dec[left]))
		v0 = x.dec[left]
		u =  x.ra [right] * cos(deg2rad(x.dec[right]))
		v =  x.dec[right]
		#pmlen = 0.1
		#clf()
		#plot(u, v, 'ro', mfc='none', mec='r')
		#plot([u0], [v0], 'bo')
		#plot([u0, u0 + pmlen * x.pmra[left]], [v0, v0 + pmlen * x.pmdec[left]], 'b-')
		#circle(x=u0, y=v0, radius=2., facecolor='none', edgecolor='b')
		#pmvec = array([x.pmra[left], x.pmdec[left]])
		#pmunit = pmvec / norm(pmvec)
		#pmorth = array([-pmvec[1], pmvec[0]])
		# Vector orthogonal to the proper motion.
		pmorth = array([-x.pmdec[left], x.pmra[left]])
		pmorth /= norm(pmorth)
		#p0 = array([u0,v0])
		#sA = p0 + streamw/2. * pmorth - 10 * pmunit
		#sB = p0 + streamw/2. * pmorth + 10 * pmunit
		#sC = p0 - streamw/2. * pmorth - 10 * pmunit
		#sD = p0 - streamw/2. * pmorth + 10 * pmunit
		#plot([[sA[0],sC[0]],[sB[0],sD[0]]], [[sA[1],sC[1]],[sB[1],sD[1]]], 'b-')

		uv = vstack((u,v)).T
		uv0 = vstack(([u0,v0])).T
		#print 'uv shape:', uv.shape
		#print 'dot prod shape:', dot(uv-uv0, pmorth)
		intube = (absolute(dot(uv-uv0, pmorth)) < streamw/2.)
		tubedensity[left] = sum(intube)
		#print left,density[left],tubedensity[left]

		#plot(u[intube], v[intube], 'k.')
		#rai = int(round(x.ra[left]))
		#xt = arange(rai-2, rai+3)
		#xticks(xt * cos(deg2rad(x.dec[left])), xt)
		#xlabel('RA (deg)')
		#ylabel('Dec (deg)')
		#xlim(u0-2, u0+2)
		#ylim(v0-2, v0+2)
		#axis('scaled')
		#xlim(u0-2, u0+2)
		#ylim(v0-2, v0+2)
		#savefig('uv-%03i.png' % ii)

		if ii % 1000 == 0:
			density_plots(x, tubedensity, 'tube-%06i' % ii)
			print ii

	density_plots(x, tubedensity, 'tube')

	return tubedensity

def oldjunk():
	"""Leftover friends-of-friends / "worm" plotting code; not called in this chunk.

	NOTE(review): reads `prefix`, `radius`, `leftindex`, `rightindex`, `x` and
	`density`, none of which are defined in this function.  It raises NameError
	unless they exist as module globals.  Everything after the first
	sys.exit(0) below is unreachable.
	"""
	if True:
		# Friends-of-friends groups, cached in a pickle keyed on the radius.
		fn = prefix + '-fof-%f.pickle' % radius
		FOF = None
		if os.path.exists(fn):
			try:
				print 'Trying to read pickle', fn
				FOF = pickle.loads(read_file(fn))
			except:
				# Any failure to read the cache falls through to recomputation.
				pass
		if FOF is None:
			print 'starting fof...'
			FOF = friends_of_friends(leftindex, rightindex, len(x.ra))
			print 'done fof'
			print 'Writing pickle', fn
			write_file(pickle.dumps(FOF, pickle.HIGHEST_PROTOCOL), fn)
			
		(fofgroups,groupnums) = FOF

		# all-sky worms.
		# Only groups with at least 5 members are drawn; the width argument
		# scales as 10 / mean distance of the group.
		clf()
		for g in fofgroups:
			if len(g) < 5:
				continue
			LG = list(g)
			plot_worm(x.ra[LG], x.plotdec[LG],
					  10. / mean(x.dist[LG]), False)
		xlabel('RA (deg)')
		ylabel('Dec (deg)')
		decrange = plot_range(x.dec)
		dec_ticks(decrange)
		xlim(plot_range(x.ra))
		ylim(plot_range(x.plotdec))
		savefig(prefix + '-allsky-worms.png')

		clf()
		for g in fofgroups:
			if len(g) < 5:
				continue
			LG = list(g)
			plot_worm(x.plotra[LG], x.dec[LG],
					  10. / mean(x.dist[LG]), False)
		xlabel('RA*cos(Dec) (deg)')
		ylabel('Dec (deg)')
		decrange = plot_range(x.dec)
		xlim(plot_range(x.plotra))
		ylim(plot_range(x.dec))
		plot_radec_grid_lines()
		savefig(prefix + '-allsky-worms-2.png')

		# Everything below this exit is unreachable.
		sys.exit(0)

		# spatial cut
		Z = (x.ra > 230) * (x.ra < 250) * (x.dec > 51) * (x.dec < 61)
		#Z = (x.ra > 152) * (x.ra < 172) * (x.dec > 38) * (x.dec < 48)

		print 'Largest group touching window:', max([len(fofgroups[gi]) for gi in unique(groupnums[Z])])

		largestgroup = fofgroups[unique(groupnums[Z])[argmax([len(fofgroups[gi]) for gi in unique(groupnums[Z])])]]
		print 'largest group:', largestgroup
		print 'size:',len(largestgroup)

		print 'full-sky density min:', density.min(), 'max:', density.max()
		# Density rescaled to [0,1] within the spatial cut.
		reldensity = (density[Z] - density[Z].min()) / (density[Z].max()-density[Z].min())
		print 'spatial cut density min:', density[Z].min(), 'max:', density[Z].max()

		# Pairs are selected by whether their LEFT member lies in the cut.
		leftinside = leftindex[Z[leftindex]]
		rightinside = rightindex[Z[leftindex]]
		
		# Linear blend from grey (low density) to dark red (high density);
		# (3,1) column vectors broadcast against reldensity to give (3, N).
		colorlo = array([[0.5,0.5,0.5]]).T
		colorhi = array([[0.5,0,0]]).T
		colors = colorlo + (colorhi - colorlo) * reldensity
		colors = array(['#%02x%02x%02x' % (int(c[0]*255.999), int(c[1]*255.999), int(c[2]*255.999)) for c in colors.T])

		alphas = 0.1 + 0.5 * reldensity

		print 'number of points in spatial cut:', sum(Z)
		print 'number of unique colors in cut:', len(unique(colors))

		#radecpairs = (x.ra[leftinside], x.ra[rightinside], x.dec[leftinside], x.dec[rightinside])
		radecpairs = None
		print 'making plots...'
		clf()
		#plot(x.ra[Z], x.plotdec[Z], '.', ms=10, color='0.5')
		make_plots(x.ra[Z], x.dec[Z], x.dist[Z], x.pmra[Z], x.pmdec[Z], x.dpmra[Z], x.dpmdec[Z], colors, alphas,
				   squidplot=True, pairs=radecpairs)
		fn = prefix + '-worm-%f' % radius + '.png'
		savefig(fn)
		print 'done plot', fn


		a = axis()
		for g in [fofgroups[i] for i in unique(groupnums[Z])]:
			if len(g) < 5:
				continue
			print 'Highlighting group of length', len(g)
			LG = list(g)
			plot_worm(x.ra[LG], x.plotdec[LG],
					  50. / mean(x.dist[LG]))

		if False:
			for indx in largestgroup:
				I = (leftindex == indx)
				L = leftindex[I]
				R = rightindex[I]
				plot(vstack((x.ra[L], x.ra[R])),
					 vstack((x.plotdec[L], x.plotdec[R])),
					 '-', color='y', lw=10., alpha=0.5)
		axis(a)
		savefig(prefix + '-highlight.png')




		if False:
			# NOTE(review): colors/alphas were computed for the earlier Z subset
			# only, so indexing them with this new full-length Z likely mismatches.
			make_plots(x.ra, x.dec, x.dist, x.pmra, x.pmdec, x.dpmra, x.dpmdec, suffix='-allsky%.1f' % radius)
			Z = (x.ra > 230) * (x.ra < 250) * (x.dec > 40) * (x.dec < 60)
			make_plots(x.ra[Z], x.dec[Z], x.dist[Z], x.pmra[Z], x.pmdec[Z], x.dpmra[Z], x.dpmdec[Z], colors[Z], alphas[Z], '-zA-%.1f' % radius)

# Blue -> green -> red colour map, used to colour stream candidates by [Fe/H]:
# 0 maps to pure blue, 0.5 to half-intensity green, 1 to pure red.
# Each channel is a sequence of (x, value_below_x, value_above_x) anchors.
# The -1 entries occupy the slots matplotlib ignores (below x=0, above x=1).
bgrmap = LinearSegmentedColormap('bgr', dict(
	red=((0., -1, 0),
		 (0.5, 0, 0),
		 (1., 1, -1)),
	green=((0., -1, 0),
		   (0.5, 0.5, 0.5),
		   (1., 0, -1)),
	blue=((0., -1, 1),
		  (0.5, 0, 0),
		  (1., 0, -1)),
))
class ColorbarMappable(object):
	"""Bare-bones object carrying a colour map and a norm.

	It supplies constant/no-op versions of autoscale_None, get_alpha,
	add_observer and set_colorbar.  This is presumably so it can be passed to
	matplotlib's colorbar() in place of a real ScalarMappable: the `__main__`
	block builds one and sets `.norm`, but its colorbar calls are commented
	out.  TODO confirm against the matplotlib version in use.
	"""

	def __init__(self, cmap):
		# norm starts unset; callers assign one afterwards (e.g. a Normalize).
		self.norm = None
		self.cmap = cmap

	def get_alpha(self):
		"""Always fully opaque."""
		return 1.0

	def autoscale_None(self):
		"""There is no data to autoscale; intentionally does nothing."""
		return None

	def add_observer(self, x):
		"""Observers are not tracked; the argument is ignored."""
		return None

	def set_colorbar(self, x, y):
		"""Colorbar back-references are not stored; the arguments are ignored."""
		return None

from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize

def amber_plot(x, stream, xalpha=0.02, xcolor='k', lwscale=10., dotms=5., dyr=0.2,
			   streamalpha=0.5, fehlo=-2.0, fehhi=-1.4):
	"""Overlay stream candidates on an RA/Dec scatter of all stars.

	Every star in `x` is drawn as a faint pixel.  Each candidate in `stream`
	is drawn as a dot plus a line segment along its proper motion.

	Parameters
	----------
	x : table with ra, dec columns (the background stars)
	stream : indexable table of candidates with ra, dec, pmra, pmdec, dist,
	    feh fields
	xalpha, xcolor : alpha and colour of the background points
	lwscale : segment line width is lwscale / dist (nearer candidates are thicker)
	dotms : marker size of the candidate dots
	dyr : length multiplier of the proper-motion segment (presumably years of
	    motion at mas/yr -> deg scaling; TODO confirm units)
	streamalpha : alpha of candidate dots and segments
	fehlo, fehhi : [Fe/H] values mapped to the blue and red ends of bgrmap.
	    Values outside the range are clamped.  The defaults match the
	    Normalize(vmin=-2.0, vmax=-1.4) used for the colorbar in __main__.

	Sets the x/y limits to plot_range of the background RA/Dec.
	"""
	plot(x.ra, x.dec, ',', color=xcolor, alpha=xalpha, zorder=1000)
	fehspan = fehhi - fehlo
	# Indexed loop: the candidate table is only known to support len() and
	# integer indexing.
	for i in range(len(stream)):
		s = stream[i]
		# Linear [Fe/H] -> [0, 1] colour coordinate, clamped at both ends.
		c = bgrmap(min(1, max(0, (s.feh - fehlo) / fehspan)))
		# pmra is divided by cos(dec) to turn it into an RA-coordinate offset.
		rastep = dyr * s.pmra / cos(deg2rad(s.dec))
		plot([s.ra, s.ra + rastep],
			 [s.dec, s.dec + dyr * s.pmdec], '-', lw=lwscale/s.dist,
			 color=c, alpha=streamalpha, zorder=2000)
		plot([s.ra], [s.dec], 'o', mfc=c, alpha=streamalpha, ms=dotms, zorder=3000)
	xlim(plot_range(x.ra))
	ylim(plot_range(x.dec))


if __name__ == '__main__':
	# Command-line driver.  Words recognised anywhere in argv:
	#   'brute'    -- run analyze_data_file on the raw catalogue, then exit.
	#   'scramble' -- work on the scrambled data set; output names gain '-scramble'.
	#   'eos'      -- recompute edge-of-survey flags for candidate streams,
	#                 rewrite the '-brutish3.pickle' file, then exit.
	#   'rerun'    -- refit the stream amplitude for each candidate, save
	#                 '-brutish4.pickle', then continue with the plots below.
	# Otherwise it loads the '-brutish3.pickle' results and writes summary
	# plots, 'cands.fits', and per-candidate diagnostic plots.
	args = sys.argv[1:]
	figure(dpi=100)

	scramble = 'scramble' in args
	eos = 'eos' in args
	rerun = 'rerun' in args
	fn = '1_4kpc_metal_poor_with_PM_July23_09.txt'
	prefix = ''
	if scramble:
		infix = '-scramble'
	else:
		infix = ''
	if 'brute' in args:
		analyze_data_file(fn, prefix, scramble)
		sys.exit(0)

		
	picklefn = prefix + fn + infix + '-brutish3.pickle'
	x = pickle.loads(read_file(picklefn))

	if eos:
		print 'Recomputing edge-of-survey flags...'
		I = (x.bestalpha > 1e-3) * (x.deltalogprob > 4)
		# NOTE(review): relies on the table type supporting `in` for column
		# names.  If the column is missing, every star starts flagged as edge
		# (True) and only the selected candidates are recomputed.
		if not 'edgeofsurvey' in x:
			x.edgeofsurvey = array([True] * len(x))
		print 'fixing %i' % sum(I)
		for i in find(I):
			x.edgeofsurvey[i] = is_edge_of_survey(x.ra[i], x.dec[i], x)
			print '.', x.edgeofsurvey[i]
		x.edgeofsurvey = x.edgeofsurvey.astype(bool)
		write_file(pickle.dumps(x, pickle.HIGHEST_PROTOCOL), picklefn)
		sys.exit(0)


	# Candidate selection: amplitude > 1e-3, log-likelihood gain > 4, not at
	# the survey edge.
	I = (x.bestalpha > 1e-3) * (x.deltalogprob > 4) * (x.edgeofsurvey == 0)
	print '%i streams' % sum(I)

	if rerun:
		for i in find(I):
			boss = x[i]
			# NOTE(review): make_6d_cut gets 20. as a 4th argument here, but not
			# in the plotting loop at the end of this block; confirm they agree.
			(y, solidangle, distrange, fehrange, pmrarange, pmdecrange) = (
				make_6d_cut(boss.ra, boss.dec, x, 20.))

			(components, distfeh_amplitudes, pmmeans, pmvars) = (
				build_component_model(y, distrange, fehrange))

			# Per-star log-likelihoods under the background model; their sum is
			# the background-only (alpha = 0) log-likelihood.
			bg = background_data_lnprob(y, components, solidangle,
										distfeh_amplitudes, pmmeans, pmvars, pmrarange, pmdecrange)
			print 'bg', min(bg), median(bg), max(bg)
			strawman = sum(bg)

			streamw = 1. # deg
			streamlen = 2. * sqrt(solidangle/pi)
			fg = foreground_data_lnprob(y, boss.ra, boss.dec, boss.dist,
										boss.feh, boss.pmra, boss.pmdec,
										streamw, streamlen)
			print 'fg', min(fg), median(fg), max(fg)

			# Amplitude grid from bestalpha/30 to bestalpha*30 in 0.01-dex steps;
			# ps[k] is the log-likelihood of the mixture alpha*fg + (1-alpha)*bg.
			loalpha = log10(boss.bestalpha / 30.)
			hialpha = log10(boss.bestalpha * 30.)
			alphas = 10.**(arange(loalpha, hialpha+0.0001, 0.01))
			# NOTE(review): exp(fg)/exp(bg) underflow to 0 for very negative
			# log-probabilities, giving log(0) = -inf.
			# numpy.logaddexp(log(alpha)+fg, log(1.-alpha)+bg) is the
			# numerically stable equivalent.
			ps = [sum(log(alpha*exp(fg) + (1.-alpha)*exp(bg)))
				  for alpha in alphas]
			besti = argmax(ps)
			besta = alphas[besti]  # unused below; bestalpha is stored from alphas[besti]

			x.deltalogprob[i] = ps[besti] - strawman
			x.bestalpha   [i] = alphas[besti]

			# Only candidates that still beat the background by 4 get a plot.
			if ps[besti] - strawman < 4:
				continue

			clf()
			semilogx(alphas, ps, 'r-')
			axhline(strawman)
			xlabel('alpha')
			ylabel('log(p(data))')
			tit = '(RA,Dec)=(%.1f,%.1f), dist=%.1f, [FeH]=%.1f' % (boss.ra, boss.dec, boss.dist, boss.feh)
			title(tit)
			ylim(max(ps)-10, max(ps)+1)
			xlim(min(alphas), max(alphas))
			savefig('rerun' + infix + '-%06i-alpha.png' % i)

		outpicklefn = prefix + fn + infix + '-brutish4.pickle'
		write_file(pickle.dumps(x, pickle.HIGHEST_PROTOCOL), outpicklefn)




	# Stricter selection for the plots and catalogue: log-likelihood gain > 10.
	I = (x.bestalpha > 1e-3) * (x.deltalogprob > 10) * (x.edgeofsurvey == 0)
	print '%i streams' % sum(I)
	stream = x[I]

	clf()
	plot(stream.bestalpha, stream.deltalogprob, 'k.', alpha=0.5)
	xlabel('Best alpha')
	ylabel('Delta log prob')
	savefig('alphaprob' + infix + '.png')

	clf()
	amber_plot(x, stream)
	cm = ColorbarMappable(bgrmap)
	cm.norm = Normalize(vmin=-2.0, vmax=-1.4)
	#colorbar(cm, ticks=arange(-2.0, -1.399, 0.2))
	#colorbar(ScalarMappable(cmap=bgrmap, norm=Normalize(vmin=-2.0, vmax=-1.4)))
	#colorbar(cmap=bgrmap, norm=Normalize(vmin=-2.0, vmax=-1.4))
	savefig('amber' + infix + '.png')

	# Write the selected candidates, sorted by decreasing deltalogprob, to a
	# FITS binary table.
	bosses = x[I]
	J = argsort(-bosses.deltalogprob)
	bosses = bosses[J]
	t = pyfits.new_table([pyfits.Column(name='ra', format='D', array=bosses.ra),
			      pyfits.Column(name='dec', format='D', array=bosses.dec),
			      pyfits.Column(name='dist', format='D', array=bosses.dist),
			      pyfits.Column(name='feh', format='D', array=bosses.feh),
			      pyfits.Column(name='pmra', format='D', array=bosses.pmra),
			      pyfits.Column(name='pmdec', format='D', array=bosses.pmdec),
			      pyfits.Column(name='dlogprob', format='D', array=bosses.deltalogprob),
			      pyfits.Column(name='amplitude', format='D', array=bosses.bestalpha)])

	t.writeto('cands.fits', clobber=True)


	# Per-candidate diagnostic plots: RA/Dec (with two amber_plot insets),
	# distance vs [Fe/H], and proper-motion space.
	for i in find(I):
		boss = x[i]
		(y, solidangle, distrange, fehrange, pmrarange, pmdecrange) = (
			make_6d_cut(boss.ra, boss.dec, x))
		(components, distfeh_amplitudes, pmmeans, pmvars) = (
			build_component_model(y, distrange, fehrange))
		bg = background_data_lnprob(y, components, solidangle,
									distfeh_amplitudes, pmmeans, pmvars,
									pmrarange, pmdecrange)
		strawman = sum(bg)
		streamw = 1. # deg
		streamlen = 2. * sqrt(solidangle/pi)
		fg = foreground_data_lnprob(y, boss.ra, boss.dec, boss.dist,
									boss.feh, boss.pmra, boss.pmdec,
									streamw, streamlen)

		besta = boss.bestalpha

		tit = '(RA,Dec)=(%.1f,%.1f), pm=(%.0f,%.0f), dist=%.1f, [FeH]=%.1f' % (boss.ra, boss.dec, boss.pmra, boss.pmdec, boss.dist, boss.feh)
		clf()
		plot(y.ra, y.dec, 'k,', alpha=0.3)
		dyr = 0.1
		plot([boss.ra, boss.ra+boss.pmra/cos(deg2rad(boss.dec))*dyr],
			 [boss.dec, boss.dec+boss.pmdec*dyr], 'k-', lw=5)
		plot([boss.ra], [boss.dec], 'ko', ms=10, mfc='none')
		xlabel('RA (deg)')
		ylabel('Dec (deg)')
		# assign stars to the stream.
		# Rank stars by log(foreground/background) odds and take the top NFG,
		# where NFG is besta*len(y) rounded.  round() already yields an
		# integral value, so the extra +0.5 before int() has no effect here.
		logfrac = (log(besta)+fg) - (log(1.-besta)+bg)
		NFG = int(0.5 + round(besta * len(y)))
		print 'best alpha:', besta
		print 'N fg:', NFG
		IN = argsort(-logfrac)[:NFG]
		plot(y.ra[IN], y.dec[IN], 'k.')
		title(tit)

		# Inset (lower right): whole-sky overview with crosshairs at the candidate.
		suba = axes([0.74,0.11,0.25,0.25])
		amber_plot(x, stream, xalpha=0.01, xcolor='0.4', lwscale=4.,
				   dotms=2., dyr=0.4, streamalpha=1.)
		a = axis()
		axvline(boss.ra, color='k', alpha=0.5, zorder=1500)
		axhline(boss.dec, color='k', alpha=0.5, zorder=1500)
		xticks([],[])
		yticks([],[])
		axis(a)

		# Inset (upper right): zoom to +-20 deg in Dec (RA widened by 1/cos(dec)).
		subb = axes([0.74,0.9-0.01-0.25,0.25,0.25])
		amber_plot(x, stream, xalpha=0.02, xcolor='0.4', lwscale=6.,
				   dotms=4., dyr=0.4, streamalpha=0.7)
		axvline(boss.ra, color='k', alpha=0.5, zorder=1500)
		axhline(boss.dec, color='k', alpha=0.5, zorder=1500)
		xticks([],[])
		yticks([],[])
		radius=20.
		dra = radius/cos(deg2rad(boss.dec))
		axis([boss.ra - dra, boss.ra + dra, boss.dec - radius, boss.dec + radius])

		savefig('radec' + infix + '-%06i'%i + '.png')

		clf()
		plot(y.dist, y.feh, 'k,', alpha=0.3)
		plot([boss.dist], [boss.feh], 'ko', ms=10, mfc='none')
		plot(y.dist[IN], y.feh[IN], 'k.')
		title(tit)
		xlabel('Dist (kpc)')
		ylabel('[Fe/H] (dex)')
		title(tit)
		savefig('distfeh' + infix + '-%06i'%i + '.png')
		
		clf()
		plot(y.pmra, y.pmdec, 'k,', alpha=0.3)
		plot([boss.pmra], [boss.pmdec], 'ko', ms=10, mfc='none')
		plot(y.pmra[IN], y.pmdec[IN], 'k.')
		title(tit)
		xlabel('pm RA (mas/yr)')
		ylabel('pm Dec (mas/yr)')
		axhline(0, color='0.5')
		axvline(0, color='0.5')
		#xlim(plot_range(x.pmra,  quantile=0.99))
		#ylim(plot_range(x.pmdec, quantile=0.99))
		xlim(-50,50)
		ylim(-50,50)
		title(tit)
		savefig('pm' + infix + '-%06i'%i + '.png')


