#!/usr/bin/env python2

import binascii
import bsddb
import os
import random
import subprocess
import sys
import tempfile

r = random.SystemRandom()

domain = "org"
if len(sys.argv) >= 2:
  domain = sys.argv[1]

servers = []

if len(sys.argv) >= 3:
  for i in range(2,len(sys.argv)):
    servers.append(sys.argv[i])
else:
  x = subprocess.Popen(["dig","+short","ns",domain],stdout=subprocess.PIPE).stdout
  for s in x:
    y = subprocess.Popen(["dig","+short",s.strip()],stdout=subprocess.PIPE).stdout
    for i in y:
      servers.append(i.strip())

print "domain",domain
for i in range(len(servers)):
  print "server",servers[i]

# Scratch B-tree holding the still-unexplored intervals of the hash space.
# The backing file is created inside a freshly made private directory instead
# of at a predictable /tmp/collect.todo.<pid> path, so another local user
# cannot pre-create or symlink it.  As before, the file is unlinked right after
# opening: the open handle keeps working and nothing is left behind on exit.
tmpdir = tempfile.mkdtemp(prefix="collect.todo.")
fntodo = os.path.join(tmpdir,"todo")
todo = bsddb.btopen(fntodo,"n")
os.remove(fntodo)
os.rmdir(tmpdir)

# salt stays 0 until the first NSEC3 record reveals the zone's salt.
salt = 0
# Counters reported in the "stats" lines and in the final summary.
numhashes = 0
numqueries = 0

# Known NSEC3 records: owner hash -> next hash.
nexthash = {}

while len(todo) > 0 or len(nexthash) == 0:
  if salt:
    x = hashes.readline().split()
    h = x[0]
    guess = x[1]
    numhashes += 1
    knownguess = 1
    if h in todo:
      print "stillneed",h,todo[h]
      knownguess = 0
    elif len(todo) > 0:
      x = todo.last()
      if x[1] < x[0] and (x[0] <= h or h <= x[1]):
        print "stillneed",x[0],x[1]
        knownguess = 0
      elif x[0] <= h and h <= x[1]:
        print "stillneed",x[0],x[1]
        knownguess = 0
      elif todo.first()[0] <= h and h <= todo.last()[0]:
        todo.set_location(h)
	x = todo.previous()
	if x[0] <= h and h <= x[1]:
          print "stillneed",x[0],x[1]
          knownguess = 0
  else:
    guess = str(r.getrandbits(40))+"."+domain
    h = "unknownsalt"
    knownguess = 0
  if not knownguess:
    print "querying",guess,h
    numqueries += 1
    server = servers[r.randrange(len(servers))]
    query = subprocess.Popen(["./query",guess,server],stdout=subprocess.PIPE).stdout
    for x in query:
      y = x.strip().split(' ')
      if y[0] == "ns":
        print "ns",y[1],y[2]
      # if y[0] == "nsec3" and y[1] in nexthash:
      #   print "again",y[1],y[2]
      if y[0] == "nsec3" and not y[1] in nexthash:
        if not salt:
          print "salt",y[3]
          print "iterations",y[4]
          salt = binascii.a2b_hex(y[3])
          iterations = int(y[4])
	  hashprocess = subprocess.Popen(["./randomhashes",domain,y[4],y[3]],stdout=subprocess.PIPE)
	  hashes = hashprocess.stdout
	if salt != binascii.a2b_hex(y[3]):
	  print "newsalt",binascii.a2b_hex(y[3])
        print "nexthash "+y[1]+" "+y[2]
        nexthash[y[1]] = y[2]
        if len(todo) == 0:
          todo[y[2]] = y[1]
        else:
	  if y[1] in todo:
	    if y[1] < y[2] and y[2] < todo[y[1]]:
	      todo[y[2]] = todo[y[1]]
	      del todo[y[1]]
	    elif todo[y[1]] < y[1] and y[1] < y[2]:
	      todo[y[2]] = todo[y[1]]
	      del todo[y[1]]
	    elif y[2] == todo[y[1]]:
	      del todo[y[1]]
	    elif y[2] == nexthash[todo[y[1]]]:
	      del todo[y[1]]
	  else:
            x = todo.last()
	    if x[1] < x[0] and ((x[0] <= y[1] and y[1] <= y[2]) or (x[0] <= y[1] and y[2] <= x[1]) or (y[1] <= y[2] and y[2] <= x[1])):
              if y[2] != x[1]:
                todo[y[2]] = x[1]
              todo[x[0]] = y[1]
	    elif x[1] < x[0] and ((x[0] <= y[1] and y[2] == nexthash[x[1]]) or (y[1] <= x[1] and y[2] == nexthash[x[1]])):
              todo[x[0]] = y[1]
            elif x[0] <= y[1] and y[1] < y[2] and y[2] <= x[1]:
              if y[2] != x[1]:
                todo[y[2]] = x[1]
              todo[x[0]] = y[1]
	    elif y[2] == nexthash[x[1]] and x[0] <= y[1] and y[1] <= x[1]:
	      del todo[x[0]]
	    elif todo.first()[0] <= y[1] and y[1] <= todo.last()[0]:
	      todo.set_location(y[1])
	      x = todo.previous()
              if x[0] <= y[1] and y[1] < y[2] and y[2] <= x[1]:
                if y[2] != x[1]:
                  todo[y[2]] = x[1]
                todo[x[0]] = y[1]
	      elif y[2] == nexthash[x[1]] and x[0] <= y[1] and y[1] <= x[1]:
	        del todo[x[0]]
	if len(nexthash) > len(todo):
	  print "stats",len(todo),numhashes,numqueries,len(nexthash),"maybe",(len(nexthash)*len(nexthash))/(len(nexthash)-len(todo))
	else:
	  print "stats",len(todo),numhashes,numqueries,len(nexthash)
        sys.stdout.flush()

print "numqueries",str(numqueries)
print "numhashes",str(numhashes)

if salt:
  hashprocess.terminate()
