from xml.sax import make_parser, handler
from math import sqrt
from xml.sax.saxutils import escape
import osr

class OSMNode:
    """A single OSM ``<node>``: an id, a geometry and a tag dictionary.

    ``geom`` is stored exactly as given; ``OSMFile.OSMHandler`` fills it with
    a ``(lon, lat)`` tuple of floats.
    """

    def __init__( self, id=None, geom=None, tags=None):
        self.id = id
        self.geom = geom
        # Build a fresh dict per node so nodes never share one tag mapping.
        self.tags = {} if tags is None else tags

    def __str__( self ):
        return "%s, %s, %s" % (str(self.id), str(self.geom), str(self.tags))
        
class OSMWay:
    """A single OSM ``<way>``: an id, an ordered list of node ids and tags."""

    def __init__( self, id=None, nodes=None, tags=None ):
        self.id = id
        # Fresh containers per way so ways never share mutable state.
        self.nodes = [] if nodes is None else nodes
        self.tags = {} if tags is None else tags

    def __str__( self ):
        # '%s' (not '%d') so a way whose id is still None can be printed,
        # matching OSMNode.__str__.
        return "%s, %s, %s" % ( self.id, str( self.nodes ), str(self.tags) )

class OSMFile:
    """In-memory OpenStreetMap data read from an OSM XML file.

    ``nodes`` maps node id -> OSMNode and ``ways`` maps way id -> OSMWay.
    """

    class OSMHandler(handler.ContentHandler):
        """SAX content handler that collects ``<node>`` and ``<way>`` elements.

        After parsing, ``nodes`` maps node id -> OSMNode and ``ways`` maps
        way id -> OSMWay.  ``<tag>`` elements are attached to the enclosing
        node or way; tags of any other element (e.g. ``<relation>``) are
        ignored.
        """

        def __init__(self):
            handler.ContentHandler.__init__(self)
            self._parent = None     # 'node' / 'way' while inside one, else None
            self._curr_node = None  # OSMNode currently being built
            self._curr_way = None   # OSMWay currently being built
            self.nodes = {}
            self.ways = {}

        @staticmethod
        def _text(value):
            """Return an attribute value as a native ``str``.

            ``str()`` is tried first, so ASCII values (and every value on
            Python 3) come out exactly as ``str()`` would give them.  On
            Python 2, ``str()`` of non-ASCII ``unicode`` raises
            UnicodeEncodeError, so such values are encoded to UTF-8 bytes
            instead, matching the encoding declared by ``to_osm``.
            """
            try:
                return str(value)
            except UnicodeEncodeError:
                return value.encode('utf-8')

        def startElement(self, name, attrs):
            if name == 'node':
                self._parent = 'node'
                self._curr_node = OSMNode()
                self._curr_node.id = int(attrs['id'])
                self._curr_node.geom = (float(attrs['lon']), float(attrs['lat']))
            elif name == 'way':
                self._parent = 'way'
                self._curr_way = OSMWay()
                self._curr_way.id = int(attrs['id'])
            elif name == 'tag':
                if self._parent == 'node':
                    owner = self._curr_node
                elif self._parent == 'way':
                    owner = self._curr_way
                else:
                    return  # tag of an element we do not collect
                owner.tags[self._text(attrs['k'])] = self._text(attrs['v'])
            elif name == 'nd' and self._parent == 'way':
                self._curr_way.nodes.append(int(attrs['ref']))

        def endElement(self, name):
            if name == 'node':
                self.nodes[self._curr_node.id] = self._curr_node
                self._curr_node = None
                # Clear the parent so later <tag>s (e.g. inside a <relation>)
                # are not attached to an already-stored node.
                self._parent = None
            elif name == 'way':
                self.ways[self._curr_way.id] = self._curr_way
                self._curr_way = None
                self._parent = None

    def __init__( self, filename=None ):
        """Create an empty OSMFile, or parse *filename* if one is given."""
        self.nodes = {}  # node id -> OSMNode
        self.ways = {}   # way id -> OSMWay

        if filename is not None:
            self.parse_osmfile( filename )

    def parse_osmfile( self, filename ):
        """Parse *filename* and replace ``nodes`` / ``ways`` with its contents.

        *filename* is handed to xml.sax's ``parser.parse``, which accepts a
        path/URL or a file-like object.
        """
        osm_handler = OSMFile.OSMHandler()

        parser = make_parser()
        parser.setContentHandler(osm_handler)
        parser.parse(filename)

        self.nodes = osm_handler.nodes
        self.ways = osm_handler.ways

    def connected_nodes( self ):
        """Return ids of nodes referenced more than once across all ways.

        Every reference counts: a node shared by two ways qualifies, and so
        does a node repeated within one way (e.g. the first/last node of a
        closed way, or a self-intersection).  Each id appears at most once,
        in the order its second reference was encountered.
        """
        seen = set()
        repeated = set()
        result = []

        for way in self.ways.values():
            for node in way.nodes:
                if node not in seen:
                    seen.add(node)
                elif node not in repeated:
                    repeated.add(node)
                    result.append(node)

        return result

    def split_way( self, way, split_nodes, top=True ):
        """Split *way* at every interior node contained in *split_nodes*.

        Endpoints are never split points.  Consecutive pieces share their
        boundary node.  Piece ids start at ``way.id * 1000`` when *top* is
        true (leaving room for the extra ids), otherwise at ``way.id``, and
        increase by one per piece.  Each piece gets its own copy of the tags.

        *split_nodes* only needs to support ``in``; pass a set for speed.
        Returns a list of OSMWay objects (a single piece if no split occurs).
        """
        next_id = way.id * 1000 if top else way.id
        nodes = way.nodes

        pieces = []
        start = 0
        for i in range(1, len(nodes) - 1):  # skip both endpoints
            if nodes[i] in split_nodes:
                pieces.append(OSMWay(next_id, nodes[start:i + 1], dict(way.tags)))
                next_id += 1
                start = i

        pieces.append(OSMWay(next_id, nodes[start:], dict(way.tags)))
        return pieces

    def split_ways( self ):
        """Return a new OSMFile whose ways are split at connected nodes.

        The returned object shares this object's ``nodes`` dict.
        NOTE: a way split into 1000+ pieces produces ids that collide with
        the id range of way ``way.id + 1``.
        """
        # A set makes the per-node membership test in split_way O(1).
        split_nodes = set(self.connected_nodes())

        split = {}
        for way in self.ways.values():
            for piece in self.split_way( way, split_nodes ):
                split[piece.id] = piece

        ret = OSMFile()
        ret.nodes = self.nodes
        ret.ways = split

        return ret

    # Class-level (shared) coordinate transformation used by project().
    # NOTE(review): GDAL >= 3 uses authority axis order (lat, lon) for
    # EPSG:4326 unless SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
    # is called, while project() passes (lon, lat) -- confirm against the
    # GDAL version in use.
    from_proj = osr.SpatialReference()
    from_proj.SetWellKnownGeogCS( "EPSG:4326" )

    to_proj = osr.SpatialReference()
    to_proj.ImportFromProj4('+proj=merc')

    tr = osr.CoordinateTransformation( from_proj, to_proj )

    def project( self, point ):
        """Transform *point* (``(lon, lat)`` as stored in OSMNode.geom) with
        ``tr`` and return the first two output coordinates swapped,
        i.e. ``(pt[1], pt[0])``."""
        pt = self.tr.TransformPoint( point[0], point[1] )
        return (pt[1], pt[0])

    def way_length( self, way_id ):
        """Return the length of way *way_id* in '+proj=merc' projected units.

        Sums straight-line segment lengths between consecutive projected
        nodes; a way with fewer than two nodes has length 0.  Raises KeyError
        if the way, or any node it references, is not loaded.
        NOTE(review): Mercator distances grow away from the equator, so this
        is not ground distance -- confirm that is acceptable to callers.
        """
        way = self.ways[way_id]

        geom = [self.project( self.nodes[node_id].geom ) for node_id in way.nodes]

        accum = 0
        for (x1, y1), (x2, y2) in zip(geom, geom[1:]):
            accum += sqrt( (x2 - x1) ** 2 + (y2 - y1) ** 2 )

        return accum

    @staticmethod
    def _tag_lines( tags ):
        """Return the indented ``<tag>`` lines for a tag dictionary.

        Keys and values sit inside double-quoted attributes, so ``"`` is
        escaped in addition to ``&``, ``<`` and ``>``.
        """
        quote = {'"': '&quot;'}
        return ['    <tag k="%s" v="%s" />' % (escape(k, quote), escape(v, quote))
                for k, v in tags.items()]

    def to_osm( self ):
        """Serialize all nodes and ways as an OSM XML document string.

        Coordinates are written with 7 decimal places (OSM's own coordinate
        precision) so a parse/serialize round trip does not lose precision.
        """
        ret = ["<?xml version='1.0' encoding='UTF-8'?>",
               "<osm version='0.5' generator='libosm'>"]

        for node_id, node in self.nodes.items():
            ret.append("  <node id='%d' visible='true' lat='%.7f' lon='%.7f'>"
                       % (node_id, node.geom[1], node.geom[0]))
            ret.extend(self._tag_lines(node.tags))
            ret.append("  </node>")

        for way_id, way in self.ways.items():
            ret.append("  <way id='%d' visible='true'>" % way_id)
            ret.extend("    <nd ref='%d' />" % node for node in way.nodes)
            ret.extend(self._tag_lines(way.tags))
            ret.append("  </way>")

        ret.append("</osm>")

        return "\n".join(ret)
        
if __name__ == '__main__':
    import sys

    def _arg(index, default):
        """Return command-line argument *index*, or *default* if absent."""
        return sys.argv[index] if len(sys.argv) > index else default

    def _write(path, text):
        """Write *text* to *path*, closing the file even if writing fails."""
        # Text-typed output (unicode on Python 2, str on Python 3) is encoded
        # to UTF-8 to match the encoding='UTF-8' declaration from to_osm.
        if not isinstance(text, bytes):
            text = text.encode('utf-8')
        with open(path, 'wb') as fp:
            fp.write(text)

    # Usage: <script> [input.osm [reconstituted.osm [split.osm]]]
    in_path = _arg(1, "/home/brandon/osm/seattle.osm")
    reconstituted_path = _arg(2, "/home/brandon/osm/seattle_reconstituted.osm")
    split_path = _arg(3, "/home/brandon/osm/seattle_split.osm")

    # sys.stdout.write keeps the progress output identical on Python 2 and 3.
    sys.stdout.write("Reading OSM file...\n")
    osmfile = OSMFile( in_path )
    sys.stdout.write("Done\n")

    sys.stdout.write("%d nodes, %d ways\n" % (len(osmfile.nodes), len(osmfile.ways)))

    sys.stdout.write("Writing reconstituted osm file... ")
    _write(reconstituted_path, osmfile.to_osm())
    sys.stdout.write("done\n")

    sys.stdout.write("Splitting OSM file... ")
    split_osmfile = osmfile.split_ways()
    sys.stdout.write("done\n")

    sys.stdout.write("Writing split OSM file... ")
    _write(split_path, split_osmfile.to_osm())
    sys.stdout.write("done\n")