#!/usr/bin/python

# The OSM abstraction code in this module was gratuitously stolen from 
# graphserver (http://graphserver.sourceforge.net/)
# Copyright (c) 2007, Brandon Martin-Anderson

import xml.sax
import copy
import sys
from math import *
from xml.dom.minidom import Document


def latlng_dist(src_lat, src_lng, dest_lat, dest_lng):
    """Return the great-circle distance between two points, in meters.

    Uses the spherical law of cosines. Coordinates are in degrees and may be
    anything ``float()`` accepts (numbers or numeric strings).

    Returns 0.0 immediately when both coordinate pairs compare equal.
    """
    if src_lat == dest_lat and src_lng == dest_lng:
        return 0.0
    theta = float(src_lng) - float(dest_lng)
    src_lat_radians = radians(float(src_lat))
    dest_lat_radians = radians(float(dest_lat))
    cos_angle = (sin(src_lat_radians) * sin(dest_lat_radians) +
                 cos(src_lat_radians) * cos(dest_lat_radians) *
                 cos(radians(theta)))
    # Rounding can push the cosine a hair outside [-1, 1] for nearly
    # coincident (or antipodal) points; acos() would then raise ValueError.
    cos_angle = max(-1.0, min(1.0, cos_angle))
    arc_degrees = degrees(acos(cos_angle))
    # 60 nautical miles per degree, * 1.1515 (approx. statute miles per
    # nautical mile), * 1.609344 km per mile, * 1000 m per km.
    return arc_degrees * (60.0 * 1.1515 * 1.609344 * 1000.0)

def way_dist(orig_way, geobase_way, orig_osm, geobase_osm):
    """Return how badly the endpoints of two ways disagree, in meters.

    The first/last nodes of ``orig_way`` (resolved through
    ``orig_osm.nodes``) are paired with the first/last nodes of
    ``geobase_way`` (resolved through ``geobase_osm.nodes``). Both the
    same-direction and the reversed pairing are tried, and the smaller sum
    of the two endpoint distances is returned.

    Raises KeyError if a node ref is missing from its ``nodes`` dict, and
    IndexError if either way has no nodes.
    """
    orig_first = orig_osm.nodes[orig_way.nds[0]]
    orig_last = orig_osm.nodes[orig_way.nds[-1]]
    geo_first = geobase_osm.nodes[geobase_way.nds[0]]
    geo_last = geobase_osm.nodes[geobase_way.nds[-1]]

    def _between(a, b):
        return latlng_dist(a.lat, a.lon, b.lat, b.lon)

    same_direction = _between(orig_first, geo_first) + _between(orig_last, geo_last)
    # The two datasets may digitize the same road in opposite directions.
    reversed_direction = _between(orig_last, geo_first) + _between(orig_first, geo_last)
    return min(same_direction, reversed_direction)


class Node:
    """A parsed OSM <node>: its id, position and tag dictionary."""

    def __init__(self, id, lon, lat):
        # The ``id`` parameter name shadows the builtin but is kept to
        # preserve the constructor signature.
        self.id, self.lon, self.lat = id, lon, lat
        self.tags = dict()

class Way:
    """A parsed OSM <way>: an ordered list of node refs plus tags.

    ``osm`` is a back-reference to the container object the way belongs to.
    """

    def __init__(self, id, osm):
        self.id = id
        self.osm = osm
        self.nds = list()   # node ids, in file order
        self.tags = dict()

class OSM:
    """In-memory index of the nodes and ways found in an OSM XML document."""

    def __init__(self, filename_or_stream):
        """Parse *filename_or_stream*, which may be a filename or a file object.

        Populates:
          nodes          -- node id attribute -> Node
          ways           -- way id attribute -> Way; ways referencing a node
                            absent from the document are dropped
          node_histogram -- node id -> number of references to it from the
                            ways that were kept
        """
        nodes = {}
        ways = {}

        superself = self

        class OSMHandler(xml.sax.ContentHandler):
            """SAX handler that collects Node and Way objects."""

            def __init__(self):
                xml.sax.ContentHandler.__init__(self)
                # Node/Way currently being filled in; None outside of one,
                # so tags belonging to other elements (e.g. relations) are
                # not attached to whatever was parsed last.
                self.currElem = None

            def startElement(self, name, attrs):
                if name == 'node':
                    self.currElem = Node(attrs['id'], float(attrs['lon']), float(attrs['lat']))
                elif name == 'way':
                    self.currElem = Way(attrs['id'], superself)
                elif name == 'tag':
                    if self.currElem is not None:
                        self.currElem.tags[attrs['k']] = attrs['v']
                elif name == 'nd':
                    if self.currElem is not None:
                        self.currElem.nds.append(attrs['ref'])

            def endElement(self, name):
                if name == 'node':
                    nodes[self.currElem.id] = self.currElem
                    self.currElem = None
                elif name == 'way':
                    ways[self.currElem.id] = self.currElem
                    self.currElem = None

        xml.sax.parse(filename_or_stream, OSMHandler())

        self.nodes = nodes
        self.ways = ways

        # Count the times each node is used by the ways we keep.
        self.node_histogram = dict.fromkeys(self.nodes, 0)
        # Iterate over a snapshot: ways are deleted from the dict as we go.
        for way_id, way in list(self.ways.items()):
            # Toss out any ways that don't have all nodes on the map, before
            # counting, so dropped ways never contribute to the histogram.
            # FIXME: we cannot do this for the final version
            if not all(ref in self.nodes for ref in way.nds):
                del self.ways[way_id]
                continue
            for ref in way.nds:
                self.node_histogram[ref] += 1

# Ways whose summed endpoint deviation from their closest GeoBase way is
# below this many meters are reported as likely matches.
_MATCH_THRESHOLD_METERS = 20

# ``unicode`` on Python 2, ``str`` on Python 3.
_TEXT_TYPE = type(u"")


def _utf8(value):
    """Return *value* in a printable form (UTF-8 bytes on Python 2)."""
    if str is _TEXT_TYPE:  # Python 3: str is already text
        return str(value)
    return _TEXT_TYPE(value).encode("utf-8")


def _print_tags(tags):
    """Print every tag as a ' - key: value' line."""
    for key, value in tags.items():
        print(" - %s: %s" % (_utf8(key), _utf8(value)))


def split_ways_on_junctions(osm):
    """Split every way in *osm* at junction nodes and return the segments.

    A junction is a node whose ``osm.node_histogram`` count exceeds 1.
    Each segment is a shallow copy of its parent way (so ``tags`` is shared)
    with its own ``nds`` list and a fresh integer ``id``. Segments with fewer
    than two nodes are discarded.

    Returns a dict mapping each segment's integer id to the segment.
    """
    new_ways = {}
    for way in osm.ways.values():
        segment = None
        for node in way.nds:
            if segment is not None:
                segment.nds.append(node)
                # Register the segment once it has at least two nodes.
                if len(segment.nds) == 2:
                    segment.id = len(new_ways)
                    new_ways[segment.id] = segment
            if segment is None or osm.node_histogram[node] > 1:
                # Start a new segment; the junction node ends the previous
                # segment and begins this one.
                segment = copy.copy(way)
                segment.nds = [node]
    return new_ways


def _closest_way(orig_way, orig_osm, geobase_osm):
    """Return (geobase_way, deviation) with the smallest way_dist to *orig_way*.

    Returns (None, None) when *geobase_osm* has no ways.
    """
    closest_way = None
    closest_dist = None
    for geobase_way in geobase_osm.ways.values():
        dist = way_dist(orig_way, geobase_way, orig_osm, geobase_osm)
        if closest_way is None or dist < closest_dist:
            closest_way, closest_dist = geobase_way, dist
    return closest_way, closest_dist


def _main(argv):
    """Compare an OSM file against a GeoBase OSM file; return an exit code."""
    if len(argv) < 3:
        print("Usage: %s: <original osm file> <geobase osm file>" % argv[0])
        return 1

    orig_osm = OSM(argv[1])
    geobase_osm = OSM(argv[2])

    # First, we split all OSM ways on junctions. There's no way we can do a
    # proper GeoBase comparison otherwise.
    orig_osm.ways = split_ways_on_junctions(orig_osm)

    # Try to match the closest segment by comparing nodes on junctions.
    for orig_way in orig_osm.ways.values():
        closest_way, deviation = _closest_way(orig_way, orig_osm, geobase_osm)
        if closest_way is not None and deviation < _MATCH_THRESHOLD_METERS:
            print("Match between geobase + osm (deviation: %s)" % deviation)
            print("OSM:")
            _print_tags(orig_way.tags)
            print("GeoBase:")
            _print_tags(closest_way.tags)
        else:
            print("NO match between osm + geobase (deviation: %s)" % deviation)
            print("OSM:")
            _print_tags(orig_way.tags)
    return 0


if __name__ == "__main__":
    sys.exit(_main(sys.argv))
