#!/usr/bin/python
# -*- coding: utf8 -*-

from rtree import Rtree
import OsmSax, sys

DIST_MIN = 1.0e-12

def coords(n):
    return (n.lat, n.lon)

def distance2(a,b):
    xa, ya = coords(a)
    xb, yb = coords(b)
    return (xa-xb)**2 + (ya-yb)**2


class Node(object):
    def __init__(self, id=None, lon=None, lat=None, tags=None):
        self.id = id
        if lon != None: self.lon, self.lat = float(lon), float(lat)
        if tags:
            self.tags = tags
        else:
            self.tags = {}
        self.inWay = set()
        self.inRel = set()

class Way(object):
    def __init__(self, id=None, nodes=None, tags=None):
        self.id = id
        if nodes:
            self.nodes = nodes
        else:
            self.nodes = []
        if tags:
            self.tags = tags
        else:
            self.tags = {}

class Relation(object):
    def __init__(self, id, members=None, tags=None):
        self.id = id
        if members:
          self.members = members
        else:
          self.members = []
        if tags:
          self.tags = tags
        else:
          self.tags = {}
    def __repr__(self):
      return "Relation(id=%r, members=%r, tags=%r)" % (self.id, self.members, self.tags)

class Cache:
    def __init__(self):
        self.nods = {}
        self.ways = {}
        self.rels = {}
    def NodeCreate(self, data):
        self.nods[data["id"]] = Node(id=data["id"],lon=data["lon"],lat=data["lat"],tags=data["tag"])
    def WayCreate(self, data):
        self.ways[data["id"]] = Way(id=data["id"],nodes=data["nd"],tags=data["tag"])
    def RelationCreate(self, data):
        self.rels[data["id"]] = Relation(id=data["id"],tags=data["tag"],members=data["member"])

###########################################################################

fout = sys.argv[2]
data = OsmSax.OsmSaxReader(sys.argv[1])
cache = Cache()
print 'Parse du fichier...'
data.CopyTo(cache)


idxNode = Rtree()
tabindx = {}
print 'Indexation...'
i = 0
for k in cache.nods.keys():
    i += 1
    idxNode.insert(i, coords(cache.nods[k]))
    tabindx[i] = cache.nods[k]

# set des chemins utilisant un noeud
for w in cache.ways.values():
    for nid in w.nodes: cache.nods[nid].inWay.add(w)
# set des relations utilisant un noeud
for r in cache.rels.values():
    for m in r.members:
        if m['type'] == 'node': cache.nodes[m['ref']].inRel.add(r)

print 'Simplification des noeuds...'
# balayage des noeuds à simplifier
for noeud in cache.nods.values():
#    print 'traitment', noeud.id
    # le noeud a-t-il déjà été supprimé
    if not cache.nods.has_key(noeud.id): continue
    # recherche des noeuds proches
    for i in idxNode.nearest(coords(noeud),4):
        np = tabindx[i]
        if np == noeud: continue
        if distance2(noeud, np) < DIST_MIN:
            noeud.tags['fixme']='simplify'
            #remplacement du np par noeud dans les ways
            for w in np.inWay:
                while np.id in w.nodes :
                    ind = w.nodes.index(np.id)
                    w.nodes[ind]=noeud.id
            #suppression np de l'index
            idxNode.delete(i,coords(np))
            #suppression de la liste des noeuds
            del cache.nods[np.id]
 
print 'Nettoyage des chemins...'
for w in cache.ways.values():
    i = 1
    # balayage des segments d'un way
    while  (len(w.nodes) > 1) & (i < len(w.nodes)):
        if w.nodes[i-1] == w.nodes[i]:
            w.nodes.pop(i)
            continue
        i += 1

print 'Ecriture...'
out = OsmSax.OsmSaxWriter(fout, "UTF-8")
out.startDocument()
out.startElement("osm", {'version':'0.6'})
for n in cache.nods.values():
    out.NodeCreate({'id':n.id,'lon':n.lon,'lat':n.lat,'tag':n.tags})
for w in cache.ways.values():
    out.WayCreate({'id':w.id,'nd':w.nodes,'tag':w.tags})
for r in cache.rels.values():
    out.RelationCreate({'id':r.id,'member':r.members,'tag':r.tags})

out.endElement("osm")

