#!/usr/bin/python

import sys
import string
from math import sqrt,pi,sin,cos,atan2

# Command line: the first argument is the base picture; any further
# arguments are pictures compared against it.
if len(sys.argv) < 2:
	program = sys.argv[0] if sys.argv else 'this script'
	sys.stderr.write('usage: ' + program + ' base.ppm [other.ppm ...]\n')
	sys.exit(2)
base_picture_filename = sys.argv[1]
other_pictures = sys.argv[2:]
timing = 0	# set non-zero to report progress on stderr

class PPM:
	"""A binary ("P6") PPM image held entirely in memory.

	Attributes:
	  width, height -- image size in pixels
	  max           -- maximum sample value (only 255 is supported)
	  data          -- the raw pixel bytes exactly as read from the file,
	                   3 per pixel (R, G, B), row-major, 3*width*height long
	"""
	# Characters that may separate header fields (the Netpbm whitespace set).
	_WHITESPACE = (b' ', b'\t', b'\n', b'\r', b'\v', b'\f')

	def __init__(self,filename):
		"""Load the P6 PPM file *filename*.

		Header fields may be separated by any whitespace and interleaved
		with '#' comments.  Raises ValueError if the file is not a P6 PPM,
		its maximum sample value is not 255, or it holds fewer pixel bytes
		than the header promises.
		"""
		# Binary mode: the pixel payload must not be newline-translated.
		# The file is closed even if reading fails.
		f = open(filename,'rb')
		try:
			contents = f.read()
		finally:
			f.close()
		if contents[0:2] != b'P6' or contents[2:3] not in PPM._WHITESPACE:
			raise ValueError('not a binary (P6) PPM file: %r' % (filename,))
		(fields,offset) = PPM._header_fields(contents,2,3)
		(self.width,self.height,self.max) = fields
		if self.max != 255: raise ValueError("can't handle max != 255")
		expected = 3*self.width*self.height
		# Ignore any trailing bytes so data is always exactly one frame.
		self.data = contents[offset:offset+expected]
		if len(self.data) != expected:
			raise ValueError('truncated PPM file %r: expected %d pixel bytes, found %d'
				% (filename,expected,len(self.data)))

	@staticmethod
	def _header_fields(contents,offset,count):
		"""Parse *count* decimal header fields of *contents* starting at *offset*.

		Returns (list_of_ints, offset_of_first_pixel_byte).  Exactly one
		whitespace character after the last field is consumed, as the format
		requires, so pixel data that begins with a whitespace byte is kept.
		"""
		whitespace = PPM._WHITESPACE
		end = len(contents)
		fields = []
		while len(fields) < count:
			c = contents[offset:offset+1]
			if not c:
				raise ValueError('truncated PPM header')
			if c in whitespace:
				offset += 1
			elif c == b'#':
				# A comment runs to the end of its line.
				while offset < end and contents[offset:offset+1] not in (b'\n',b'\r'):
					offset += 1
			else:
				start = offset
				while offset < end and contents[offset:offset+1] not in whitespace \
						and contents[offset:offset+1] != b'#':
					offset += 1
				token = contents[start:offset]
				if not token.isdigit():
					raise ValueError('bad PPM header field: %r' % (token,))
				fields.append(int(token))
		if contents[offset:offset+1] not in whitespace:
			raise ValueError('missing whitespace after PPM header')
		return (fields,offset+1)

	def __getitem__(self,pos):
		"""Return pixel pos = (x, y) as an (r, g, b) tuple of ints."""
		(x,y) = pos
		start = 3*(x + self.width*y)
		# bytearray yields ints whether data is a Python 2 str or Python 3 bytes.
		pixel = bytearray(self.data[start:start+3])
		if len(pixel) != 3:
			raise IndexError('pixel %r is outside the image' % (pos,))
		return (pixel[0],pixel[1],pixel[2])

	def dump_colouronly(self,output):
		"""Write the image to *output* as a P6 PPM with brightness removed.

		Each pixel (r,g,b) becomes 255*(r,g,b)/|(r,g,b)|, truncated to ints,
		so only its colour direction remains.
		"""
		output.write('P6\n' + repr(self.width) + ' ' + repr(self.height) + ' 255 ')
		samples = bytearray(self.data)
		pixels = []
		for i in range(0,len(samples),3):
			(r,g,b) = (samples[i],samples[i+1],samples[i+2])
			# The epsilon keeps black pixels from dividing by zero.
			norm = sqrt(r*r+g*g+b*b) + 0.001
			pixels.append(chr(int(255*r/norm)) + chr(int(255*g/norm)) + chr(int(255*b/norm)))
		output.write(''.join(pixels))

class ScalarImage:
	"""Per-pixel dissimilarity between a PPM's colours and a target colour.

	Each value is int(255*(1 - s)), where s is the cosine similarity between
	the pixel's (r,g,b) and the target colour.  The pixel's norm is inflated
	by 10 first, which lowers s most for dark pixels.  For non-negative
	colours the values lie in 0..255, and pixels of the target colour are
	near 0.
	"""
	def __init__(self,ppm,colour):
		"""Build the map for *ppm* against *colour*, an (r, g, b) tuple.

		Raises ValueError for colour (0,0,0), which has no direction.
		"""
		self.thresholded = 0
		self.width = ppm.width
		self.height = ppm.height
		self.data = [0] * (ppm.height*ppm.width)
		(c_r,c_g,c_b) = colour
		if colour == (0,0,0): raise ValueError("can't look for black")
		c_norm = sqrt(c_r*c_r + c_g*c_g + c_b*c_b) + 0.001  # tiny fudge factor
		# Unit-length target colour, used to tint dump(applycolour=1).
		self.colour = (colour[0]/c_norm,colour[1]/c_norm,colour[2]/c_norm)
		# Reads the PPM's raw pixel bytes.  bytearray yields ints for both a
		# Python 2 str and Python 3 bytes.
		samples = bytearray(ppm.data)
		for i in range(0,len(samples),3):
			(r,g,b) = (samples[i],samples[i+1],samples[i+2])
			norm = sqrt(r*r+g*g+b*b) + 10  # damp the similarity of dark pixels
			self.data[i//3] = int(255*(1-(r*c_r + g*c_g + b*c_b)/(norm*c_norm)))

	def __getitem__(self,pos):
		"""Return the value at pos = (x, y)."""
		(x,y) = pos
		return self.data[x + y*self.width]

	def dump(self,output,applycolour=0):
		"""Write the map to *output* as a P6 PPM.

		Values are written as grey, or tinted by the unit target colour when
		*applycolour* is true.
		"""
		output.write('P6\n' + repr(self.width) + ' ' + repr(self.height) + ' 255 ')
		if applycolour:
			(k_r,k_g,k_b) = self.colour
			# The tinted values are floats; chr() needs ints.
			pixels = [chr(int(v*k_r)) + chr(int(v*k_g)) + chr(int(v*k_b)) for v in self.data]
		else:
			pixels = [chr(v)*3 for v in self.data]
		output.write(''.join(pixels))

	def threshold(self,unitile=0.1,percentile=None):
		"""Binarise in place: values above the threshold become 255, the rest 0.

		The threshold is the value at the (1 - unitile) quantile, so roughly
		the top *unitile* fraction of values (fewer when values tie) end up
		at 255.  *percentile*, when given, overrides unitile as percent.
		The quantile index is clamped so unitile outside (0, 1] cannot index
		past either end.
		"""
		if percentile is not None: unitile = percentile / 100.0
		ordered = sorted(self.data)
		index = int(len(ordered)*(1.0-unitile))
		index = max(0,min(index,len(ordered)-1))
		thresh = ordered[index]
		sys.stderr.write('Threshold is ' + repr(thresh) + '\n')
		self.data[:] = [255*(v > thresh) for v in self.data]
		self.thresholded = 1

class Trellis:
	"""Samples of a scalar image along rays cast out from a centre point.

	self.data maps (theta, r) to the image value at the pixel nearest
	(x0 + r*cos(theta), y0 + r*sin(theta)).  theta_used holds
	number_of_radii equally spaced angles starting at 0.  r_used holds the
	radii min, min+radii_step_along, ... below an outer radius that is
	clamped to keep the rays within the image bounds.
	"""
	def __init__(self,x0,y0,scalarimage,radii_step_along=1.0,number_of_radii=100,max=None,min=1.0):
		"""Sample *scalarimage* around (x0, y0).

		number_of_radii is the number of rays (angles).  min and max bound
		the sampled radii.  Both keep their historic names for keyword
		callers; they shadow the builtins only inside this method.
		"""
		# Maximum change of radius index between neighbouring angles in shortest_path.
		self.largest_radii_jump_for_algorithm = 2
		if max is None: max=scalarimage.width
		# Clamp the outer radius so no ray leaves the image.
		if x0+max > scalarimage.width: max = scalarimage.width - x0
		if y0+max > scalarimage.height: max = scalarimage.height - y0
		if max > x0: max = x0
		if max > y0: max = y0
		delta_theta = 2*pi / number_of_radii
		# Each angle is k*delta rather than a running sum, so rounding can
		# never add an extra angle that duplicates 2*pi.
		self.theta_used = [k*delta_theta for k in range(int(number_of_radii))]
		self.r_used = []
		r = min
		while r < max:
			self.r_used.append(r)
			r = r + radii_step_along
		self.data = {}
		for theta in self.theta_used:
			(c,s) = (cos(theta),sin(theta))
			for r in self.r_used:
				(x,y) = (int(round(x0 + r*c)),int(round(y0 + r*s)))
				self.data[(theta,r)] = scalarimage[(x,y)]

	def __getitem__(self,pos):
		"""Return the sample stored at pos = (theta, r)."""
		return self.data[pos]

	def thetae(self):
		"""Return a copy of the sampled angles, in increasing order."""
		return self.theta_used[:]

	def rs_used(self):
		"""Return a copy of the sampled radii, in increasing order."""
		return self.r_used[:]

	def dump(self,output):
		"""Write the trellis as a P6 PPM: one column per angle, one row per radius.

		Integer samples are written inverted (255 - value) as grey, and
		(r,g,b) tuples as-is.  Any other sample type raises TypeError before
		anything is written, rather than producing a short, corrupt image.
		"""
		pixels = []
		for r in self.r_used:
			for theta in self.theta_used:
				val = self.data[(theta,r)]
				if isinstance(val,int):
					pixels.append(chr(255-val)*3)
				elif isinstance(val,tuple) and len(val)==3:
					pixels.append(chr(val[0])+chr(val[1])+chr(val[2]))
				else:
					raise TypeError('cannot dump trellis sample ' + repr(val))
		output.write('P6\n' + repr(len(self.theta_used)) + ' ' + repr(len(self.r_used)) + ' 255 ')
		output.write(''.join(pixels))

	def shortest_path(self):
		"""Return the minimum-total-cost path through the trellis.

		Dynamic programming over the angle columns.  The cost of reaching
		(theta, r) is its own sample plus the cheapest accumulated cost among
		the previous angle's radii within largest_radii_jump_for_algorithm
		index steps.  Samples must therefore be numbers.  Ties go to the
		smaller radius.  Returns one (theta, r) pair per angle, ordered by
		increasing theta.  The start and end radii are not forced to meet.

		Raises ValueError if the trellis has no radii.
		"""
		thetae = self.theta_used
		radii = self.r_used
		r_l = len(radii)
		if r_l == 0 or not thetae:
			raise ValueError('empty trellis: no radii between min and the clamped max')
		jump = self.largest_radii_jump_for_algorithm
		data = self.data
		first = thetae[0]
		# cost[i]: cheapest accumulated cost of a path ending at radius index i
		# in the current column.
		cost = [data[(first,r)] for r in radii]
		# backtrack[k][i]: radius index in column k of the best predecessor
		# of radius index i in column k+1.
		backtrack = []
		for theta in thetae[1:]:
			new_cost = []
			choices = []
			for i in range(r_l):
				lo = max(0,i-jump)
				hi = min(r_l,i+jump+1)
				(best,pos) = min([(cost[p],p) for p in range(lo,hi)])
				new_cost.append(best + data[(theta,radii[i])])
				choices.append(pos)
			backtrack.append(choices)
			cost = new_cost
		idx = min([(cost[i],i) for i in range(r_l)])[1]
		path = [(thetae[-1],radii[idx])]
		for k in range(len(backtrack)-1,-1,-1):
			idx = backtrack[k][idx]
			path.append((thetae[k],radii[idx]))
		path.reverse()
		return path

class MaskOfSegmentedObject:
	"""Inside/outside test for the region enclosed by a polar contour.

	The contour is a list of (theta, r) pairs, such as Trellis.shortest_path
	returns, measured from point_in_object_centre.  A pixel is inside when
	its distance from the centre is at most the contour radius interpolated
	at the pixel's angle.  ``image * mask`` returns a SparseImage that holds
	only the pixels inside.
	"""
	def __init__(self,shortest_path,image,point_in_object_centre):
		"""*image* only supplies the width and height the mask applies to.

		Raises ValueError if *shortest_path* is empty.
		"""
		(self.x0,self.y0) = point_in_object_centre
		self.height = image.height
		self.width = image.width
		self.path = shortest_path[:]
		if not self.path: raise ValueError('empty contour path')
		radii = [r for (theta,r) in self.path]
		self.maximum_radius = max(radii)
		self.maximum_radius_squared = self.maximum_radius * self.maximum_radius
		minimum_radius = min(radii)
		self.minimum_radius_squared = minimum_radius * minimum_radius
		# Pixels found inside; computed on the first multiplication and reused.
		self.in_points = None

	def radius_at_angle(self,angle):
		"""Return the contour radius at *angle* (radians), linearly interpolated.

		Relies on self.path being sorted by increasing theta within one turn.
		Angles outside [0, 2*pi] are wrapped into range.  Beyond the last
		sample the radius is interpolated towards the first sample one turn
		later, and before the first sample it is interpolated from the last
		sample one turn earlier.
		"""
		two_pi = 2*pi
		if angle < 0 or angle > two_pi: angle = angle % two_pi
		(last_theta,last_r) = self.path[-1]
		# The sample just below `angle`.  Before the first sample, that is
		# the last sample seen one turn earlier.
		(lower_theta,lower_r) = (last_theta - two_pi,last_r)
		for (theta,r) in self.path:
			if theta == angle: return r
			if theta < angle:
				(lower_theta,lower_r) = (theta,r)
			else:
				fraction = (angle - lower_theta) / (theta - lower_theta)
				return lower_r + (r - lower_r) * fraction
		# Past the last sample: interpolate towards the first radius.
		(first_theta,first_r) = self.path[0]
		fraction = (angle - last_theta) / (first_theta + two_pi - last_theta)
		return last_r + (first_r - last_r) * fraction

	def is_point_inside(self,x,y):
		"""Return 1 if pixel (x, y) lies inside the contour, else 0."""
		xdiff = x - self.x0
		ydiff = y - self.y0
		if xdiff == 0 and ydiff == 0: return 1
		radius_sq = xdiff*xdiff + ydiff*ydiff
		# Cheap decisions from the contour's radius range before any trigonometry.
		if radius_sq > self.maximum_radius_squared: return 0
		if radius_sq < self.minimum_radius_squared: return 1
		if self.radius_at_angle(atan2(ydiff,xdiff)) >= sqrt(radius_sq): return 1
		return 0

	def __rmul__(self,otherimage):
		"""``otherimage * mask``: a SparseImage of otherimage's pixels inside the contour.

		Raises ValueError if otherimage's size differs from the mask's.
		"""
		if otherimage.height != self.height or otherimage.width != self.width:
			raise ValueError('incompatible image sizes')
		if self.in_points is None:
			self.in_points = self._points_inside()
		new_image_data = []
		for (x,y) in self.in_points:
			(r,g,b) = otherimage[(x,y)]
			new_image_data.append((x,y,r,g,b))
		return SparseImage(self.height,self.width,new_image_data)

	def _points_inside(self):
		"""List the pixels inside the contour in row-major (y, then x) order.

		Only the square around the centre that bounds maximum_radius is
		scanned.  Every pixel outside it is further than maximum_radius from
		the centre, so is_point_inside would reject it anyway.
		"""
		reach = int(self.maximum_radius) + 1
		(cx,cy) = (int(self.x0),int(self.y0))
		xs = range(max(0,cx - reach),min(self.width,cx + reach + 1))
		ys = range(max(0,cy - reach),min(self.height,cy + reach + 1))
		return [(x,y) for y in ys for x in xs if self.is_point_inside(x,y)]
class SparseImage:
	"""An image that stores only some pixels; every other pixel reads as black.

	Attributes: height and width (the full image size, used by dump) and
	data, a dict mapping (x, y) to (r, g, b).
	"""
	def __init__(self,height,width,data):
		"""data is a list of (x,y,r,g,b); height is total image height (for dumps), likewise width.

		A later entry for the same (x, y) replaces an earlier one.
		"""
		self.height = height
		self.width = width
		self.data = dict(((x,y),(r,g,b)) for (x,y,r,g,b) in data)

	def __getitem__(self,pos):
		"""Return the (r,g,b) stored at pos = (x, y), or (0,0,0) if none is stored."""
		return self.data.get(pos,(0,0,0))

	def dump(self,output):
		"""Write the full width x height image to *output* as a P6 PPM."""
		output.write('P6\n' + repr(self.width) + ' ' + repr(self.height) + ' 255 ')
		pixels = []
		for y in range(self.height):
			for x in range(self.width):
				(r,g,b) = self[(x,y)]
				pixels.append(chr(r) + chr(g) + chr(b))
		output.write(''.join(pixels))

	def mean_brightness_inside(self):
		"""Return the TOTAL r+g+b over all stored pixels, as a float (0.0 if none).

		Despite its name this is a sum, not a mean.  sum_of_squares uses the
		value as its normalisation factor, so it is deliberately kept a sum.
		"""
		total = 0.0
		for (r,g,b) in self.data.values():
			total = total + r+g+b
		return total

def sum_of_squares(image1,image2):
	"""Squared distance between the brightness-normalised colours of two sparse images.

	Every channel is divided by its image's mean_brightness_inside(), so a
	uniform change of exposure cancels out.  The squared per-channel
	differences are summed over every position stored in image1, and image2
	is read at those same positions.  Smaller means more similar.

	An image whose brightness total is 0 (all stored pixels black) is
	treated as the zero vector instead of raising ZeroDivisionError.
	"""
	brightness1 = image1.mean_brightness_inside()
	brightness2 = image2.mean_brightness_inside()
	# A zero total means every channel is 0, so dividing by 1 instead still
	# yields the (all-zero) normalised values.
	if brightness1 == 0: brightness1 = 1.0
	if brightness2 == 0: brightness2 = 1.0
	total = 0.0
	for pos in image1.data:
		(r1,g1,b1) = image1[pos]
		(r2,g2,b2) = image2[pos]
		rd = r1/brightness1 - r2/brightness2
		gd = g1/brightness1 - g2/brightness2
		bd = b1/brightness1 - b2/brightness2
		total = total + rd*rd + gd*gd + bd*bd
	return total



# Segment the object in the base picture.  Score every pixel by how far its
# colour direction is from mirror1colour, find the cheapest contour around
# a point assumed to lie inside the object, and keep the pixels within it.
base_image = PPM(base_picture_filename)
if timing:
	sys.stderr.write('Image loaded\n')
mirror1colour = (0xc3, 0x8d, 0x54)
scalar = ScalarImage(base_image,mirror1colour)
# Thresholding stays disabled: per the original note, the yellow walls defeat it.
#	if timing: sys.stderr.write('About to do threshold\n')
#	scalar.threshold(percentile=10)
(obj_x,obj_y) = (180,180)	# hard-coded point inside the object
trellis = Trellis(obj_x,obj_y,scalar,max=100)
mask = MaskOfSegmentedObject(trellis.shortest_path(),base_image,(obj_x,obj_y))
inside_the_mirror = base_image * mask

# Report, for each other picture, how far the same masked region differs
# from the base picture (smaller means more similar).
for filename in other_pictures:
	this_image = PPM(filename)
	this_sparse = this_image * mask
	sys.stdout.write('%s %s\n' % (filename,sum_of_squares(inside_the_mirror,this_sparse)))

# Debugging aids:
#	base_image.dump_colouronly(sys.stdout)
#	scalar.dump(sys.stdout)
#	sys.stderr.write("Thetae:" + repr(trellis.thetae()) + '\n')
#	sys.stderr.write("Rs: " + repr(trellis.rs_used()) + '\n')


