import geopandas as gpd
import pandas as pd
import json
from collections import defaultdict, deque


# Name of the column that uniquely identifies each ward.
# It is always (re)created below from the row index, so a column with this
# name already present in the shapefile would be overwritten.
UNIQUE_ID_COL = "node_id"

# Load the ward shapefile; the relative path is resolved against the current
# working directory.
wards = gpd.read_file("WI.shp")
# Fix invalid geometries with a zero-width buffer before any overlay work.
wards["geometry"] = wards.buffer(0)
# Reproject to EPSG:3071.
# NOTE(review): presumably the Wisconsin Transverse Mercator projection with
# metre units, which is what makes the later distance comparisons meaningful
# -- confirm.
wards = wards.to_crs(epsg=3071)

# Unique node IDs as strings: the ReCom code expects string node IDs in the
# JSON. The index is reset to 0..n-1 first, so each ID is the string form of
# the ward's row position and can be converted back with int() to index
# `wards`.
wards = wards.reset_index(drop=True)
wards[UNIQUE_ID_COL] = wards.index.astype(str)
# Rook adjacency: two wards are neighbours when they share more than isolated points (i.e. a boundary segment).


#makes a dictionary of ward ids to list of adj ward ids
def calculate_rook_adjacency(gdf, unique_col):
    """Build a rook-adjacency mapping for the geometries in ``gdf``.

    The frame is intersected with itself; every pair of distinct rows whose
    intersection is anything other than a Point or MultiPoint is treated as
    adjacent. A progress message is printed before the overlay runs.

    Args:
        gdf: GeoDataFrame with a ``"geometry"`` column and ``unique_col``.
        unique_col: Name of the column holding one distinct ID per row.

    Returns:
        A ``defaultdict(list)`` mapping every ID in ``gdf[unique_col]`` to a
        sorted, de-duplicated list of neighbouring IDs (empty for rows with
        no neighbours).

    Raises:
        ValueError: If any value of ``unique_col`` occurs more than once.
    """
    # Every ID must occur exactly once, otherwise pairs would be ambiguous.
    occurrences = gdf[unique_col].value_counts(dropna=False)
    if (occurrences != 1).any():
        raise ValueError(f"{unique_col} is not unique")

    print("Running self-overlay for adjacency...")

    # Self-overlay: each output row is the intersection of one pair of rows.
    left = gdf[[unique_col, "geometry"]]
    right = gdf[[unique_col, "geometry"]]
    pairs = gpd.overlay(left, right, how="intersection", keep_geom_type=False)

    # overlay() suffixes the clashing ID column names with _1 / _2.
    id_a = f"{unique_col}_1"
    id_b = f"{unique_col}_2"

    # Drop each geometry's intersection with itself.
    is_distinct_pair = pairs[id_a] != pairs[id_b]
    distinct_pairs = pairs[is_distinct_pair].copy()

    # Rook rule: contact at points only does not count as adjacency.
    touches_at_points_only = distinct_pairs.geom_type.isin(["Point", "MultiPoint"])
    shared_borders = distinct_pairs[~touches_at_points_only].copy()

    rook_dict = defaultdict(list)
    for source, target in zip(shared_borders[id_a], shared_borders[id_b]):
        rook_dict[source].append(target)

    # Isolated rows still get an (empty) entry.
    for node in gdf[unique_col]:
        rook_dict.setdefault(node, [])

    # Collapse repeated neighbours and order each list.
    for node, neighbours in rook_dict.items():
        rook_dict[node] = sorted(set(neighbours))

    return rook_dict


adj = calculate_rook_adjacency(wards, UNIQUE_ID_COL)
print("Built rook adjacency")

# Assemble the node dictionary in the layout the ReCom code reads, keyed by
# node ID:
#   "district"    <- the "CD" column (district label), as int
#   "population"  <- the "TOTPOP" column, as int
#   "adjacencies" <- the ward's rook neighbours (the same list object held
#                    in `adj`, not a copy)
data = {
    node: {
        "district": int(cd),
        "population": int(totpop),
        "adjacencies": adj[node],
    }
    for node, cd, totpop in zip(
        wards[UNIQUE_ID_COL], wards["CD"], wards["TOTPOP"]
    )
}

#helper funcs for checking connectivity of initial districts before ReCom loop
def district_vertices(d):
    """Return the IDs of all nodes in ``data`` assigned to district ``d``.

    Nodes come back in the iteration order of the module-level ``data``
    dictionary.
    """
    members = []
    for node, info in data.items():
        if info["district"] == d:
            members.append(node)
    return members


def district_edges(d):
    """Return the nodes of district ``d`` and the edges internal to it.

    Returns:
        A ``(verts, edges)`` tuple. ``verts`` is the list from
        ``district_vertices(d)``. ``edges`` holds one ``(node, neighbour)``
        tuple per adjacency entry whose two endpoints both lie in the
        district, so an edge stored on both endpoints appears once in each
        direction.
    """
    verts = district_vertices(d)
    members = set(verts)  # O(1) membership tests while scanning neighbours
    edges = [
        (node, nbr)
        for node in verts
        for nbr in data[node]["adjacencies"]
        if nbr in members
    ]
    return verts, edges

from collections import defaultdict, deque

def connected_components(vertices, edges):
    """Split an undirected graph into its connected components.

    Args:
        vertices: Iterable of hashable node IDs used as search starting
            points, in order.
        edges: Iterable of ``(u, v)`` pairs; each pair is treated as
            undirected. Repeated or mirrored pairs are harmless.

    Returns:
        A list of components, each a list of nodes in breadth-first order
        from the first not-yet-seen vertex. Nodes that appear only in
        ``edges`` are included when reachable from a listed vertex.
    """
    # Undirected neighbour lists: record every edge on both endpoints.
    neighbours = defaultdict(list)
    for a, b in edges:
        neighbours[a].append(b)
        neighbours[b].append(a)

    seen = set()
    components = []

    for start in vertices:
        # Already swept up by an earlier search.
        if start in seen:
            continue

        seen.add(start)
        frontier = deque((start,))
        members = []

        # Breadth-first sweep of everything reachable from `start`.
        while frontier:
            here = frontier.popleft()
            members.append(here)
            for nxt in neighbours[here]:
                if nxt in seen:
                    continue
                # Mark on enqueue so a node is never queued twice.
                seen.add(nxt)
                frontier.append(nxt)

        components.append(members)

    return components

# If a district has multiple connected components, we keep the largest as the mainland
# and connect each smaller component to the nearest ward in the mainland using one bidirectional edge.
def fix_islands_by_connecting_to_mainland():
    """Make every district connected by linking islands to its mainland.

    For each district label found in the module-level ``data``, the
    district's internal graph is split into connected components. The
    largest component is the mainland; every other component (an island)
    gets one bidirectional edge between its closest pair of
    (island ward, mainland ward), measured as geometry-to-geometry distance
    between rows of the module-level ``wards`` frame.

    ``data`` is modified in place (entries are appended to the
    ``"adjacencies"`` lists) and progress is printed; nothing is returned.

    Node IDs are converted with ``int()`` to look up rows in ``wards``, so
    they must be the string form of that frame's index labels.
    """
    labels = sorted({info["district"] for info in data.values()})
    # One column lookup up front; indexing this Series by label is far
    # cheaper than materialising a whole row with wards.loc[...] per pair.
    geometries = wards["geometry"]

    print("\nConnecting island components to their own mainland...\n")
    for d in labels:
        verts, edges = district_edges(d)
        comps = connected_components(verts, edges)
        # Already connected: nothing to repair for this district.
        if len(comps) <= 1:
            continue

        # Largest component is the mainland; the sort is stable, so ties
        # keep discovery order.
        comps = sorted(comps, key=len, reverse=True)
        mainland = comps[0]
        islands = comps[1:]

        print(f"District {d} has {len(islands)} island component(s).")

        # Mainland geometries do not change between islands, so fetch them
        # once per district instead of once per (island ward, mainland ward).
        mainland_geoms = [
            (node, geometries.loc[int(node)]) for node in mainland
        ]

        for comp in islands:
            best_island = None
            best_mainland = None
            best_dist = float("inf")

            # Exhaustive search for the closest island/mainland ward pair;
            # the strict "<" keeps the first pair found on ties.
            for island_node in comp:
                island_geom = geometries.loc[int(island_node)]
                for mainland_node, mainland_geom in mainland_geoms:
                    dist = island_geom.distance(mainland_geom)
                    if dist < best_dist:
                        best_dist = dist
                        best_island = island_node
                        best_mainland = mainland_node

            # No usable pair was found (e.g. no distance compared below
            # infinity), so leave this island untouched.
            if best_island is None or best_mainland is None:
                continue

            # Record the edge on both endpoints, avoiding duplicates.
            if best_mainland not in data[best_island]["adjacencies"]:
                data[best_island]["adjacencies"].append(best_mainland)
            if best_island not in data[best_mainland]["adjacencies"]:
                data[best_mainland]["adjacencies"].append(best_island)

# Patch disconnected districts in `data` (in place) before it is written out.
fix_islands_by_connecting_to_mainland()

# Save the node dictionary as indented JSON in the current working directory.
with open("Map_0.json", "w") as f:
    json.dump(data, f, indent=2)


# Imported here, after the JSON exists, rather than at the top of the file.
# NOTE(review): draw_wards is a separate module not shown here; presumably it
# renders the wards coloured by the districts in the JSON -- confirm.
from draw_wards import draw_wards_map
draw_wards_map(
    shp_filename="WI.shp",
    json_filename="Map_0.json",
    title="Wisconsin ReCom District Map"
)