#!/usr/bin/python
# -*- coding: utf8 -*-

from rtree import Rtree
import OsmSax, sys
from shapely.geometry import Point, LineString

# Snapping tolerance, in the same planar units as coords() (lat/lon pairs,
# so presumably degrees; 2e-6 deg of latitude is roughly 0.2 m -- confirm).
DIST_MIN = 2e-6

def coords(n):
    """Return the ``(lat, lon)`` tuple of node *n*.

    Latitude comes first, so callers that feed this tuple to shapely or
    rtree use latitude as the x axis.
    """
    lat = n.lat
    lon = n.lon
    return lat, lon



def procheWay(nd, p1, p2, dist_min=None):
    """Return True if node *nd* lies close to the interior of segment p1-p2.

    The test is planar and uses the ``(lat, lon)`` pairs from coords(), so
    distances are in coordinate units.

    :param nd: candidate node (needs ``.lat`` / ``.lon``).
    :param p1: first end point of the segment.
    :param p2: second end point of the segment.
    :param dist_min: snapping threshold; ``None`` means the module-level
        ``DIST_MIN``.
    :returns: True when the distance from *nd* to the segment is below the
        threshold AND strictly smaller than its distance to each end point,
        i.e. the closest point of the segment is not one of its ends.
    """
    if dist_min is None:
        dist_min = DIST_MIN
    node_pt = Point(coords(nd))
    d = node_pt.distance(LineString([coords(p1), coords(p2)]))
    if d >= dist_min:
        # Cheap rejection: no need to measure the end-point distances.
        return False
    # A closest point that is an end point would give d == that distance,
    # so strict comparisons keep only projections inside the segment.
    return (d < node_pt.distance(Point(coords(p1)))
            and d < node_pt.distance(Point(coords(p2))))



class Node(object):
    """An OSM node: id, optional position and tags.

    ``lon``/``lat`` are converted to float when ``lon`` is given. Otherwise
    both are set to None, so the attributes always exist.
    ``inWay``/``inRel`` start as empty sets.
    """

    def __init__(self, id=None, lon=None, lat=None, tags=None):
        self.id = id
        if lon is not None:
            self.lon, self.lat = float(lon), float(lat)
        else:
            # Keep the attributes defined (instead of missing) for nodes
            # built without a position.
            self.lon = self.lat = None
        # A non-empty tags dict is stored as-is; otherwise use a fresh dict.
        self.tags = tags if tags else {}
        self.inWay = set()
        self.inRel = set()

    def __repr__(self):
        return "Node(id=%r, lon=%r, lat=%r, tags=%r)" % (
            self.id, self.lon, self.lat, self.tags)

class Way(object):
    """An OSM way: id, ordered list of node ids, and tags.

    Non-empty ``nodes``/``tags`` arguments are stored as-is (not copied), so
    later in-place edits of ``self.nodes`` also affect the caller's list.
    Missing or empty arguments are replaced by fresh containers.
    """

    def __init__(self, id=None, nodes=None, tags=None):
        self.id = id
        self.nodes = nodes if nodes else []
        self.tags = tags if tags else {}

    def __repr__(self):
        return "Way(id=%r, nodes=%r, tags=%r)" % (self.id, self.nodes, self.tags)

class Relation(object):
    """An OSM relation: id, member list and tags.

    Non-empty ``members``/``tags`` are kept by reference; missing or empty
    ones become fresh containers.
    """

    def __init__(self, id, members=None, tags=None):
        self.id = id
        self.members = members if members else []
        self.tags = tags if tags else {}

    def __repr__(self):
        fields = (self.id, self.members, self.tags)
        return "Relation(id=%r, members=%r, tags=%r)" % fields

class Cache:
    """In-memory store used as the target of ``data.CopyTo(cache)``.

    Nodes, ways and relations are kept in three dicts keyed by id. Each
    ``*Create`` callback receives a dict read here with the keys "id" plus
    "lon"/"lat"/"tag" (nodes), "nd"/"tag" (ways) or "member"/"tag"
    (relations).
    """

    def __init__(self):
        self.nods, self.ways, self.rels = {}, {}, {}

    def NodeCreate(self, data):
        """Store a Node built from a parsed node record."""
        node = Node(id=data["id"], lon=data["lon"], lat=data["lat"],
                    tags=data["tag"])
        self.nods[data["id"]] = node

    def WayCreate(self, data):
        """Store a Way built from a parsed way record."""
        way = Way(id=data["id"], nodes=data["nd"], tags=data["tag"])
        self.ways[data["id"]] = way

    def RelationCreate(self, data):
        """Store a Relation built from a parsed relation record."""
        rel = Relation(id=data["id"], tags=data["tag"], members=data["member"])
        self.rels[data["id"]] = rel

###########################################################################

# Command line: argv[1] is the input OSM file, argv[2] the output path (fout).
if len(sys.argv) < 3:
    # Fail with a usage hint instead of an IndexError traceback.
    sys.stderr.write("usage: %s input.osm output.osm\n" % sys.argv[0])
    sys.exit(1)
fout = sys.argv[2]

data = OsmSax.OsmSaxReader(sys.argv[1])
cache = Cache()
print('Parse du fichier...')
# Presumably CopyTo calls back Cache.NodeCreate/WayCreate/RelationCreate.
data.CopyTo(cache)


idxNode = Rtree()
tabindx = {}
print('Indexation...')
# Spatial index of every node position. rtree entry ids are 1-based
# sequence numbers, and tabindx maps each id back to its Node.
for i, node in enumerate(cache.nods.values(), 1):
    idxNode.insert(i, coords(node))
    tabindx[i] = node


print('Snap...')
# Sweep every way segment by segment. Any existing node lying close to a
# segment's interior (see procheWay) is inserted into the way there and
# tagged fixme=Snaper.
for w in cache.ways.values():
    # O(1) "already in this way" test, kept in sync with w.nodes below.
    in_way = set(w.nodes)
    i = 1
    while i < len(w.nodes):
        p1 = cache.nods.get(w.nodes[i - 1])
        p2 = cache.nods.get(w.nodes[i])
        if p1 is None or p2 is None:
            # The way references a node absent from the input file:
            # there is nothing to measure this segment against.
            i += 1
            continue
        (x1, y1), (x2, y2) = coords(p1), coords(p2)
        # Every node closer than DIST_MIN to the segment lies inside its
        # bounding box grown by DIST_MIN, so this query yields all
        # candidates. A fixed-size nearest() query could miss some.
        window = (min(x1, x2) - DIST_MIN, min(y1, y2) - DIST_MIN,
                  max(x1, x2) + DIST_MIN, max(y1, y2) + DIST_MIN)
        for idx in idxNode.intersection(window):
            cand = tabindx[idx]
            if cand.id in in_way:
                continue
            if procheWay(cand, p1, p2):
                cand.tags['fixme'] = 'Snaper'
                w.nodes.insert(i, cand.id)
                in_way.add(cand.id)
                # Stay on index i: segment (i-1, i) now ends at the new
                # node and is scanned again for further candidates.
                break
        else:
            # Nothing snapped onto this segment: move on to the next one.
            i += 1
                
 
print('Ecriture...')
out = OsmSax.OsmSaxWriter(fout, "UTF-8")
out.startDocument()
out.startElement("osm", {'version': '0.6'})
for n in cache.nods.values():
    out.NodeCreate({'id': n.id, 'lon': n.lon, 'lat': n.lat, 'tag': n.tags})
for w in cache.ways.values():
    out.WayCreate({'id': w.id, 'nd': w.nodes, 'tag': w.tags})
for r in cache.rels.values():
    out.RelationCreate({'id': r.id, 'member': r.members, 'tag': r.tags})
out.endElement("osm")
# Pair startDocument() with endDocument() so the writer flushes its output.
# NOTE(review): assumes OsmSaxWriter follows the xml.sax ContentHandler API
# (it already exposes startDocument/startElement/endElement) -- confirm.
out.endDocument()

