import json
import random
from collections import defaultdict, deque
random.seed(67)  # fixed seed so the random choices below repeat run-to-run

### GLOBALS ###
# Number of ReCom moves that must be ACCEPTED in one simulation before its map is saved.
TARGETACCEPTS = 10
# Allowed deviation from the ideal district population, as a fraction of the ideal (0.01 = +/-1%).
POPTOLERANCE = 1e-2

#Helper funcs for recom
def getedges(data):
    """Return the undirected edge list of the graph stored in ``data``.

    Every node's ``"adjacencies"`` list is read, and each (node, neighbor)
    pair is normalized to a sorted tuple so that (u, v) and (v, u) collapse
    into a single edge.
    """
    unique = {
        tuple(sorted([node, neighbor]))
        for node in data
        for neighbor in data[node]["adjacencies"]
    }
    return list(unique)

#edges that cross district borders
def findbordercrossings(data, edges):
    """Return the edges whose endpoints are in different districts.

    The input order of ``edges`` is preserved.
    """
    return [
        (u, v)
        for u, v in edges
        if data[u]["district"] != data[v]["district"]
    ]

#find which pairs of district labels are adjacent
def findadjacentdistricts(borderedges):
    """Return the set of district-label pairs that share at least one border edge.

    District labels are read from the module-level ``data`` mapping. Each
    pair is stored as a sorted 2-tuple, so (a, b) and (b, a) are one entry.
    """
    return {
        tuple(sorted([data[u]["district"], data[v]["district"]]))
        for u, v in borderedges
    }

#merge two districts into one subgraph for recom
def com(a, b):
    """Merge districts ``a`` and ``b`` into one subgraph.

    Returns ``(vertices, edges)``:
    - ``vertices`` lists every node of either district, in ``data`` order.
    - ``edges`` lists each ``(node, neighbor)`` adjacency whose endpoints are
      both in the merged set, in the order found while walking ``vertices``.
    """
    merged = [node for node in data if data[node]["district"] in (a, b)]
    inside = set(merged)
    # keep only adjacencies that stay within the merged district
    internal = [
        (node, neighbor)
        for node in merged
        for neighbor in data[node]["adjacencies"]
        if neighbor in inside
    ]
    return merged, internal


def isconnected(vertices, edges):
    """Return True if every vertex is reachable from ``vertices[0]``.

    ``edges`` are treated as undirected. An empty vertex list is reported as
    not connected.
    """
    if not vertices:
        return False

    neighbours = defaultdict(list)
    for u, v in edges:
        neighbours[u].append(v)
        neighbours[v].append(u)

    origin = vertices[0]
    reached = {origin}
    stack = [origin]
    # depth-first traversal: it reaches exactly the same set a BFS would
    while stack:
        current = stack.pop()
        for nxt in neighbours[current]:
            if nxt not in reached:
                reached.add(nxt)
                stack.append(nxt)

    return len(reached) == len(vertices)

#all nodes in one district
def districtvertices(label):
    """Return every node in ``data`` assigned to district ``label``, in ``data`` order."""
    members = []
    for node, attrs in data.items():
        if attrs["district"] == label:
            members.append(node)
    return members

#internal edges of a district
def districtinternaledges(label):
    """Return ``(vertices, edges)`` for district ``label``.

    ``vertices`` are the district's nodes in ``data`` order. ``edges`` lists
    each ``(node, neighbor)`` adjacency whose endpoints are both inside the
    district.
    """
    verts = [node for node in data if data[node]["district"] == label]
    members = set(verts)
    inner = [
        (node, nbr)
        for node in verts
        for nbr in data[node]["adjacencies"]
        if nbr in members
    ]
    return verts, inner

#return total population of one district
def districtpopulation(label):
    """Return the summed ``"population"`` of every node in district ``label``."""
    return sum(
        attrs["population"]
        for attrs in data.values()
        if attrs["district"] == label
    )


#Wilson's algorithm for generating a random spanning tree on the merged ("mega") district
def randomspanningtree(vertices, edges):
    """Sample a random spanning tree of the undirected graph (vertices, edges).

    Uses Wilson's algorithm. A random root starts the tree. From every vertex
    not yet in the tree, a random walk runs until it hits the tree, and the
    loop-erased path of that walk is added. Parallel edges are counted with
    their multiplicity.

    Args:
        vertices: list of nodes the tree must span.
        edges: list of (u, v) pairs, treated as undirected.

    Returns:
        A list of (u, parent) edges, one per non-root vertex. Returns [] when
        no spanning tree can be built: empty input, an isolated vertex, or a
        disconnected graph.
    """
    graph = defaultdict(list)

    for u, v in edges:
        graph[u].append(v)
        graph[v].append(u)

    if not vertices:
        return []

    #an isolated vertex can never be reached by a walk
    for v in vertices:
        if len(graph[v]) == 0:
            return []

    #a walk started in a component that does not contain the root would never
    #hit the tree, so reject disconnected input up front instead of looping
    #forever (this consumes no random numbers, so seeded output is unchanged)
    seen = {vertices[0]}
    frontier = deque([vertices[0]])
    while frontier:
        node = frontier.popleft()
        for nbr in graph[node]:
            if nbr not in seen:
                seen.add(nbr)
                frontier.append(nbr)
    if any(v not in seen for v in vertices):
        return []

    #start the tree from one random root
    root = random.choice(vertices)
    intree = {root}
    nextnode = {}
    #loop-erased random walk from every other vertex until it hits the tree,
    #then add that path to the tree
    for start in vertices:
        if start in intree:
            continue

        u = start
        #path[u] is overwritten on revisits, so it keeps only the LAST exit
        #from u; following it from start yields the loop-erased path
        path = {}

        #random walk until reaching the existing tree
        while u not in intree:
            v = random.choice(graph[u])
            path[u] = v
            u = v

        #add the loop-erased path into the tree
        u = start
        while u not in intree:
            intree.add(u)
            nextnode[u] = path[u]
            u = path[u]
    #return the edges of the tree
    return [(u, nextnode[u]) for u in nextnode]

#cut an edge and return the resulting components
def getcomponents(treeedges, cutedge):
    """Remove ``cutedge`` from a tree and return the resulting connected components.

    Args:
        treeedges: list of (u, v) edges.
        cutedge: the edge to remove; it matches in either orientation.

    Returns:
        A list of components, each a list of nodes in BFS order. Components
        are ordered by the first appearance of one of their nodes in
        ``treeedges``. The result therefore does not depend on string hash
        randomization, and seeded runs stay reproducible. Cutting one edge of
        a spanning tree yields exactly two components.
    """
    graph = defaultdict(list)
    #dict (not set) so nodes are visited in first-seen order
    nodes = {}

    # Build graph without the cut edge
    for u, v in treeedges:
        nodes[u] = None
        nodes[v] = None
        if (u, v) != cutedge and (v, u) != cutedge:
            graph[u].append(v)
            graph[v].append(u)

    visited = set()
    components = []

    #iterate over ALL nodes
    for node in nodes:
        if node in visited:
            continue
        comp = []
        queue = deque([node])
        visited.add(node)

        while queue:
            n = queue.popleft()
            comp.append(n)

            for nbr in graph[n]:
                if nbr not in visited:
                    visited.add(nbr)
                    queue.append(nbr)

        components.append(comp)

    return components

#check if components are balanced relative to the ideal district population
def balancedsplit(comp1, comp2):
    """Return True if both components are within tolerance of the ideal population.

    The allowed deviation is ``POPTOLERANCE * idealDistrictPopulation``
    (module globals). Each component's summed ``"population"`` (from the
    module-level ``data``) must be within that amount of
    ``idealDistrictPopulation``.
    """
    populations = (
        sum(data[v]["population"] for v in comp1),
        sum(data[v]["population"] for v in comp2),
    )
    allowed = POPTOLERANCE * idealDistrictPopulation
    return all(abs(p - idealDistrictPopulation) <= allowed for p in populations)

#Run chained simulations sim = 1..1499. Each one loads the previous map,
#attempts ReCom moves until TARGETACCEPTS are accepted (or 5000 attempts are
#used up), and saves the result as Simulation/map_{sim}.json.
for sim in range(1, 1500):
    #for tracking
    accepted = 0
    attempts = 0
    skippeddisconnected = 0
    skippednocut = 0

    #load json from RookBuild.py for map 0, then run on the previous map for subsequent sims
    with open(f"Simulation/map_{sim-1}.json", "r", encoding="utf-8") as f:
        data = json.load(f)

    print("Loaded graph JSON")

    #initial district stats
    #counts the districts and computes the ideal district population for balance checks
    districtlabels = sorted(set(data[node]["district"] for node in data))
    numDistricts = len(districtlabels)
    totalPopulation = sum(data[node]["population"] for node in data)
    idealDistrictPopulation = totalPopulation / numDistricts

    #adjacencies never change within a sim (only district labels do), so the
    #edge list is built once here rather than on every attempt
    E = getedges(data)

    #5000 protects against an infinite loop if the acceptance rate is very low
    while accepted < TARGETACCEPTS and attempts < 5000:
        print(f"\nAttempt {attempts} | Accepted so far: {accepted}")
        attempts += 1

        #find district boundary edges
        borderedges = findbordercrossings(data, E)

        #determine which district labels touch
        adjdists = findadjacentdistricts(borderedges)
        if not adjdists:
            #no two districts share a border, so no ReCom move exists
            print("No adjacent district pairs, stopping this sim")
            break

        #choose one touching pair uniformly at random. Sort first: set
        #iteration order of str values varies between processes (hash
        #randomization), which would otherwise defeat random.seed
        a, b = random.choice(sorted(adjdists))

        #merge the two districts into one subgraph
        vX, eX = com(a, b)

        if len(vX) == 0:
            continue

        #if the merged district is disconnected skip
        if not isconnected(vX, eX):
            skippeddisconnected += 1
            continue

        #build a Wilson spanning tree on the merged district
        treeedges = randomspanningtree(vX, eX)

        if len(treeedges) == 0:
            continue

        #check every possible tree cut and keep the population-balanced ones
        validcuts = []
        for edge in treeedges:
            #cutting one edge of a spanning tree leaves exactly two components
            comp1, comp2 = getcomponents(treeedges, edge)
            if balancedsplit(comp1, comp2):
                validcuts.append((edge, comp1, comp2))

        if not validcuts:
            skippednocut += 1
            continue

        #uniformly choose one valid cut
        _, comp1, comp2 = random.choice(validcuts)

        #reassign the two components back to districts a and b
        for v in comp1:
            data[v]["district"] = a
        for v in comp2:
            data[v]["district"] = b

        accepted += 1

    print(
        f"Sim {sim}: accepted {accepted} in {attempts} attempts "
        f"(skipped {skippeddisconnected} disconnected, {skippednocut} with no balanced cut)"
    )

    #save
    with open(f"Simulation/map_{sim}.json", "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)

    print(f"Saved Simulation/map_{sim}.json")





from draw_wards import draw_wards_map
#draw the most recent map written by the simulation loop: after the loop,
#sim keeps its final value (1499 for range(1, 1500)), the last file saved
draw_wards_map(
    shp_filename="WI.shp",
    json_filename=f"Simulation/map_{sim}.json",
    title="Wisconsin ReCom District Map"
)