#!/usr/bin/env python
# arch-tag: 43b684ec-eb85-41d9-9b9a-c879f5b8b732
# Author: Sascha Silbe -- http://sascha.silbe.org/
# "License": Public Domain
import csv
import xml.dom.minidom
import xml.xpath
import sys


# Written into the "generator" attribute of the root element of the output.
generatorName = "osmcsvdiff"
# Element types that are handled; output is produced in this order.
elemTypes = ["node", "way", "relation"]
# CSV columns that extractTags() strips because they are not tags.
nonTagAttrs = ["type", "id", "lat", "lon"]


def extractTags(entry):
  """Return a copy of a CSV row without the non-tag columns.

  entry -- mapping of column name to value; it is left unmodified.

  Every column listed in nonTagAttrs is dropped from the copy; whatever
  remains is returned as the tag dictionary.
  """
  remaining = dict(entry)
  for column in nonTagAttrs:
    # pop() with a default tolerates rows that lack the column.
    remaining.pop(column, None)

  return remaining

def _readCSV(fName):
  """Read one CSV file into a {type: {id: {'tags': {...}}}} mapping.

  The file is read completely and closed again before returning.
  """
  byType = dict((type_, {}) for type_ in elemTypes)
  # Binary mode as in the original code (what the Python 2 csv module expects).
  csvFile = open(fName, "rb")
  try:
    for entry in csv.DictReader(csvFile):
      byType[entry['type']][int(entry['id'])] = {'tags': extractTags(entry)}
  finally:
    # try/finally rather than "with" to stay usable on old interpreters.
    csvFile.close()

  return byType

def readCSVs(fNameOld, fNameNew):
  """Load the old and the new CSV file.

  fNameOld, fNameNew -- paths of the CSV files; each needs a header row
  containing at least the columns 'type' and 'id'.

  Returns a tuple (oldByType, newByType). Each is a dict with one entry
  per name in elemTypes, mapping the integer element id to
  {'tags': <row without the non-tag columns>}.

  Raises KeyError if a row has a type that is not in elemTypes (or if the
  'type'/'id' column is missing) and ValueError if an id is not an integer.
  """
  return (_readCSV(fNameOld), _readCSV(fNameNew))

def uniq(l):
  """Return the distinct items of l (items must be hashable).

  The result is the key collection of a dict built from l, so each item
  appears exactly once.
  """
  distinct = dict.fromkeys(l)
  return distinct.keys()

def genChangesType(oldVals, newVals):
  """Compute the changes between two states of the elements of one type.

  oldVals, newVals -- mappings of element id to {'tags': {name: value}}.

  Returns a dict mapping element id to a change description:
    {'action': 'delete'}                         id is missing from newVals
    {'action': 'modify', 'setTags': {name: value}, 'delTags': [name, ...]}
                                                 the tags of the id differ
  'setTags' holds the tags whose value is new or changed, 'delTags' the
  names of tags that only exist in the old state. Ids with equal tags are
  not part of the result.

  Raises NotImplementedError if newVals contains an id that oldVals lacks.
  """
  diffs = {}
  # A set union does not rely on keys() returning a list, and sorting makes
  # the order in which changes are generated reproducible.
  for elemId in sorted(set(oldVals) | set(newVals)):
    if elemId not in oldVals:
      raise NotImplementedError("Creation not yet supported")

    if elemId not in newVals:
      diffs[elemId] = {'action': 'delete'}
      continue

    oldTags = oldVals[elemId]['tags']
    newTags = newVals[elemId]['tags']
    if oldTags == newTags:
      continue

    setTags = dict((name, value) for (name, value) in newTags.items()
                   if oldTags.get(name) != value)
    delTags = [name for name in oldTags if name not in newTags]
    # Both may be empty although the dicts differ (a tag that only exists in
    # newTags with the value None); such an element is not reported.
    if setTags or delTags:
      diffs[elemId] = {'action': 'modify', 'setTags': setTags, 'delTags': delTags}

  return diffs

def genChanges(oldVals, newVals):
  """Compute the changes for every element type.

  oldVals, newVals -- dicts keyed by element type, each value being the
  per-id mapping that genChangesType() takes.

  Returns a dict with one entry per name in elemTypes whose value is the
  result of genChangesType() for that type.
  """
  return dict(
    (elemType, genChangesType(oldVals[elemType], newVals[elemType]))
    for elemType in elemTypes
  )


def _findTagNodes(elemNode, name):
  """Return the direct <tag> child elements of elemNode whose k attribute equals name."""
  return [child for child in elemNode.childNodes
          if child.nodeType == child.ELEMENT_NODE
          and child.tagName == 'tag'
          and child.getAttribute('k') == name]

def applyDiff(oldNode, diff):
  """Return a modified deep copy of an element; oldNode itself is not changed.

  oldNode -- DOM element whose tags are <tag k="..." v="..."/> children.
  diff    -- change description with the entries
             'setTags': {name: value}  tags to set
             'delTags': [name, ...]    tags to remove

  A tag to be set is updated if it exists and appended as a new <tag>
  child otherwise. Tags to be removed that the element does not have are
  ignored.

  Raises KeyError if diff lacks 'setTags' or 'delTags'.
  """
  newNode = oldNode.cloneNode(deep=True)

  # Tags are looked up by scanning the children rather than by building an
  # XPath expression from the name, so names containing quotes are safe.
  for (name, value) in diff['setTags'].items():
    tagNodes = _findTagNodes(newNode, name)
    if tagNodes:
      tagNodes[0].setAttribute("v", value)
    else:
      # The tag is new for this element: create it instead of failing.
      tagNode = newNode.ownerDocument.createElement("tag")
      tagNode.setAttribute("k", name)
      tagNode.setAttribute("v", value)
      newNode.appendChild(tagNode)

  for name in diff['delTags']:
    for tagNode in _findTagNodes(newNode, name):
      newNode.removeChild(tagNode)
      tagNode.unlink()

  return newNode

def formatChangeOldNew(type_, id, xmlNode, diff):
  """Format the change of one element for the "oldNew" output.

  type_   -- element type, used as the name of the wrapping element
  id      -- integer id of the element
  xmlNode -- DOM node of the element in its old state (only used for
             'modify' changes)
  diff    -- change description; diff['action'] is written to the
             "action" attribute

  Returns a list of output lines: the opening tag, for 'modify' changes
  the old state wrapped in <old> and the new state wrapped in <new>, and
  the closing tag.
  """
  lines = ['<%s id="%d" action="%s">' % (type_, id, diff['action'])]

  if diff['action'] == 'modify':
    lines += ['<old>', xmlNode.toxml(encoding='utf-8'), '</old>']
    # The new state has to be closed with </new>; it used to be terminated
    # by a second </old>, which made the document malformed.
    lines += ['<new>', applyDiff(xmlNode, diff).toxml(encoding='utf-8'), '</new>']

  lines.append('</%s>' % (type_,))
  return lines

def formatChangesOldNew(oldXml, diffsByType):
  """Render all changes as an <osmOldNew> document.

  oldXml      -- DOM document of the old state; the elements are looked
                 up below its <osm> root by type and id
  diffsByType -- dict keyed by element type, mapping element id to its
                 change description

  Returns the document as a single string with one item per line.
  Raises IndexError if a changed element is not found in oldXml.
  """
  xmlDecl = "<?xml version='1.0' encoding='UTF-8'?>"
  rootOpen = '<osmOldNew version="1.0" generator="%s">' % (generatorName,)
  output = [xmlDecl, rootOpen]

  for elemType in elemTypes:
    for (elemId, change) in diffsByType[elemType].items():
      query = "/osm/%s[@id='%d']" % (elemType, elemId)
      matches = xml.xpath.Evaluate(query, oldXml)
      output.extend(formatChangeOldNew(elemType, elemId, matches[0], change))

  output.append('</osmOldNew>')
  return "\n".join(output)

def formatChangesNew(oldXml, diffsByType):
  """Render the new state of all modified elements as an <osm> document.

  oldXml      -- DOM document of the old state; the elements are looked
                 up below its <osm> root by type and id
  diffsByType -- dict keyed by element type, mapping element id to its
                 change description

  Returns the document as a single string with one item per line.

  Raises NotImplementedError for any change whose action is not 'modify':
  such a change carries no 'setTags'/'delTags', so there is no new state
  to write for it (this used to surface as a KeyError).
  Raises IndexError if a changed element is not found in oldXml.
  """
  lines = ["<?xml version='1.0' encoding='UTF-8'?>", '<osm version="0.5" generator="%s">' % (generatorName,)]
  for type_ in elemTypes:
    for (elemId, diff) in diffsByType[type_].items():
      if diff['action'] != 'modify':
        raise NotImplementedError(
          "Action %r (%s %d) is not supported by output format 'new'"
          % (diff['action'], type_, elemId))

      elemNode = xml.xpath.Evaluate("/osm/%s[@id='%d']" % (type_, elemId), oldXml)[0]
      lines.append(applyDiff(elemNode, diff).toxml(encoding='utf-8'))

  lines.append('</osm>')
  return "\n".join(lines)


def printSyntax(myName) :
  print "Syntax: "+myName+" <outputFormat> <old.xml> <old.csv> <new.csv>"
  print "Available output formats: oldNew, new"
  return 100

def main(myName, args) :
  if len(args) != 4 :
    return printSyntax(myName)

  outputFormat, oldXmlFName, oldCsvFName, newCsvFName = args
  oldXml = xml.dom.minidom.parse(oldXmlFName)
  if outputFormat == "oldNew" :
    print formatChangesOldNew(oldXml, genChanges(*readCSVs(oldCsvFName, newCsvFName)))
  elif outputFormat == "new" :
    print formatChangesNew(oldXml, genChanges(*readCSVs(oldCsvFName, newCsvFName)))
  else :
    raise ValueError("Unknown output format: "+`outputFormat`)


# Only run when executed as a script, so that importing the module does not
# process sys.argv and terminate the interpreter.
if __name__ == "__main__":
  sys.exit(main(sys.argv[0], sys.argv[1:]))
