import geopandas as gpd
import pandas as pd
import json
import os
import matplotlib.pyplot as plt

# Root folder of the analysis. The simulated plans live in its "Simulation"
# subfolder as map_0.json, map_1.json, ..., map_5000.json.
_SEPFILE_DIR = "/Users/edieneuville/Desktop/graph decomp/SEPFILE"
MAPS_FOLDER = f"{_SEPFILE_DIR}/Simulation"

# Output files, all written directly under the root folder.
OUTPUTCSV = f"{_SEPFILE_DIR}/ensemble_results.csv"
SEATDISTCSV = f"{_SEPFILE_DIR}/seat_distributions.csv"
MMDISTCSV = f"{_SEPFILE_DIR}/majority_minority_districts.csv"
MMTCSV = f"{_SEPFILE_DIR}/majority_minority_counts_per_map.csv"

# Elections to analyze. Each entry of ELECTIONS is
# (Democratic vote column, Republican vote column, label), where the two
# column names are the label followed by "D" or "R".
_ELECTION_LABELS = (
    "PRES12",
    "GOV12",
    "SEN12",
    "USH12",
    "GOV14",
    "USH14",
    "PRES16",
    "SEN16",
    "USH16",
    "GOV18",
    "SEN18",
    "USH18",
)
ELECTIONS = [(f"{label}D", f"{label}R", label) for label in _ELECTION_LABELS]

# The JSON maps only carry a district assignment per ward, so every other
# ward attribute used below comes from this shapefile.
# The index is reset right away so that row numbers run 0..n-1 and can be
# matched against the node IDs used as keys in the JSON maps.
wards = gpd.read_file("WI.shp").reset_index(drop=True)

print("Loaded shapefile")
print("Rows:", len(wards))

# String form of the row number; this is the key looked up in each JSON map.
wards["node_id"] = wards.index.astype(str)

# Majority-minority flag per ward: the non-white population
# (TOTPOP - NH_WHITE) is strictly larger than NH_WHITE.
_ward_nh_white = wards["NH_WHITE"]
wards["NONWHITE_POP"] = wards["TOTPOP"] - _ward_nh_white
wards["MAJORITY_MINORITY_WARD"] = wards["NONWHITE_POP"] > _ward_nh_white

# Statewide number of majority-minority wards (the same for every map).
num_mm_wards = int(wards["MAJORITY_MINORITY_WARD"].sum())




def count_district_wins(gdf, dem_col, rep_col, district_col="district"):
    """Aggregate ward-level votes to districts and count wins per party.

    Args:
        gdf: Table with one row per ward, holding the two vote columns and
            the district assignment column.
        dem_col: Name of the Democratic vote column.
        rep_col: Name of the Republican vote column.
        district_col: Name of the column holding each ward's district.

    Returns:
        A tuple ``(dem_wins, rep_wins, ties, grouped)`` where the first three
        are the number of districts with more Democratic votes, more
        Republican votes, and exactly equal votes, and ``grouped`` is the
        district-level vote totals indexed by district.
    """
    # District-level totals for both parties
    grouped = gdf.groupby(district_col)[[dem_col, rep_col]].sum()
    dem_votes = grouped[dem_col]
    rep_votes = grouped[rep_col]

    # Strict comparisons, so an exact tie is counted for neither party
    dem_wins = int(dem_votes.gt(rep_votes).sum())
    rep_wins = int(rep_votes.gt(dem_votes).sum())
    ties = int(dem_votes.eq(rep_votes).sum())

    return dem_wins, rep_wins, ties, grouped


# Every .json file in the simulation folder, in lexicographic name order.
map_files = sorted(
    name for name in os.listdir(MAPS_FOLDER) if name.endswith(".json")
)

# Accumulators filled by the per-map loop: one summary dict per map, and one
# dict per majority-minority district found.
results, majority_minority_rows = [], []



# Per-map processing. For every simulated map:
#   - load the JSON
#   - assign its district labels to the wards
#   - compute majority-minority districts
#   - compute election outcomes by district
#   - save one summary row for that map

# Statewide vote totals do not depend on the district plan, so they are
# computed once here rather than once per map and election inside the loop.
statewide_totals = {}
for dem_col, rep_col, label in ELECTIONS:
    statewide_totals[f"{label}_D_total_votes"] = int(wards[dem_col].sum())
    statewide_totals[f"{label}_R_total_votes"] = int(wards[rep_col].sum())

# Only the population and vote columns are needed per map. Working from this
# plain table avoids deep-copying the whole shapefile (geometry included) for
# every map.
vote_cols = [col for dem_col, rep_col, _ in ELECTIONS for col in (dem_col, rep_col)]
ward_table = pd.DataFrame(wards[["node_id", "TOTPOP", "NH_WHITE", *vote_cols]])
node_ids = ward_table["node_id"].tolist()

for idx, filename in enumerate(map_files):
    if idx % 100 == 0:
        print(f"Processing map {idx} / {len(map_files)}")

    # Load one simulated map
    with open(os.path.join(MAPS_FOLDER, filename), "r", encoding="utf-8") as f:
        map_data = json.load(f)

    # Look up each ward's district in row order. A missing ward (or a ward
    # entry without a "district" key) still raises KeyError, but now the
    # message says which file is at fault.
    try:
        districts = [map_data[node]["district"] for node in node_ids]
    except KeyError as err:
        raise KeyError(
            f"{filename}: missing key {err} while reading ward district assignments"
        ) from err

    # New table for this map; ward_table itself is left untouched.
    gdf = ward_table.assign(district=districts)

    # Start one summary row for this map. Key order determines the column
    # order of the results CSV.
    row_result = {
        "map_file": filename,
        "majority_minority_wards_statewide": num_mm_wards,
        **statewide_totals,
    }

    # Aggregate TOTPOP and NH_WHITE to district level
    district_demo = (
        gdf.groupby("district")[["TOTPOP", "NH_WHITE"]]
        .sum()
        .reset_index()
    )
    district_demo["NONWHITE_POP"] = district_demo["TOTPOP"] - district_demo["NH_WHITE"]
    district_demo["MAJORITY_MINORITY_DISTRICT"] = (
        district_demo["NONWHITE_POP"] > district_demo["NH_WHITE"]
    )

    # Save each majority-minority district for this map
    mm_districts = district_demo[district_demo["MAJORITY_MINORITY_DISTRICT"]]
    majority_minority_rows.extend(
        {
            "map_file": filename,
            "district": int(rec.district),
            "TOTPOP": int(rec.TOTPOP),
            "NH_WHITE": int(rec.NH_WHITE),
            "NONWHITE_POP": int(rec.NONWHITE_POP),
        }
        for rec in mm_districts.itertuples(index=False)
    )

    # Also save how many majority-minority districts this map has
    row_result["num_majority_minority_districts"] = len(mm_districts)

    # For each election, count Democratic and Republican district wins
    for dem_col, rep_col, label in ELECTIONS:
        dem_wins, rep_wins, ties, grouped = count_district_wins(gdf, dem_col, rep_col)

        # Save seat counts into the results row
        row_result[f"{label}_D_wins"] = dem_wins
        row_result[f"{label}_R_wins"] = rep_wins
        row_result[f"{label}_ties"] = ties

        # Mean over districts of the Democratic share of the two-party vote
        # (an unweighted average of district shares, not the statewide share).
        vote_share = grouped[dem_col] / (grouped[dem_col] + grouped[rep_col])
        row_result[f"{label}_D_vote_share_mean"] = vote_share.mean()

    # Store this map's full summary row
    results.append(row_result)


# Save the per-map summary table (one row per simulated map).
results_df = pd.DataFrame(results)
results_df.to_csv(OUTPUTCSV, index=False)

# Detail file: one row per majority-minority district, giving the map it
# occurred in, the district label, and its district-level population totals.
# The columns are named explicitly so that the CSV still has a header (and
# the "map_file" column still exists for the grouping below) when no
# majority-minority district was found in any map.
majority_minority_df = pd.DataFrame(
    majority_minority_rows,
    columns=["map_file", "district", "TOTPOP", "NH_WHITE", "NONWHITE_POP"],
)
majority_minority_df.to_csv(MMDISTCSV, index=False)

# Simpler file: how many majority-minority districts each map had.
# Grouping the detail rows only yields maps that have at least one such
# district, so the counts are re-indexed over the full map list to give every
# map a row, with 0 for maps that have none.
mm_count_per_map = (
    majority_minority_df.groupby("map_file")
    .size()
    .reindex(map_files, fill_value=0)
    .rename_axis("map_file")
    .reset_index(name="num_majority_minority_districts")
)

mm_count_per_map.to_csv(MMTCSV, index=False)

print("Saved majority-minority district counts per map to:")
print(MMTCSV)


# Seat distributions. For each election, the loop below:
#   - counts how many maps gave Democrats k seats
#   - counts how many maps gave Republicans k seats
#   - records those distributions
#   - plots histograms
seat_distribution_rows = []

# Number of districts in a plan. Each district is counted as exactly one of
# a Democratic win, a Republican win, or a tie (assuming no missing vote
# totals), so the three counts add up to the number of districts in that map.
# Taking the largest win count seen for a single party instead would
# understate this whenever no party swept every district in any map.
numDistricts = max(
    int(
        (
            results_df[f"{label}_D_wins"]
            + results_df[f"{label}_R_wins"]
            + results_df[f"{label}_ties"]
        ).max()
    )
    for _, _, label in ELECTIONS
)

# If you know you always have 8 districts, you can just do:
# numDistricts = 8

def _seat_tally(wins_column, max_seats):
    """Return, for k = 0..max_seats, the number of maps with exactly k wins."""
    tally = wins_column.value_counts().sort_index()
    return [tally.get(k, 0) for k in range(max_seats + 1)]


def _show_seat_histogram(wins_column, max_seats, xlabel, title):
    """Plot and show a histogram of seat counts with one bar per integer."""
    plt.figure(figsize=(8, 5))
    # Bin edges 0..max_seats+1 with align="left" centre each bar on its
    # integer seat count.
    plt.hist(
        wins_column,
        bins=range(0, max_seats + 2),
        align="left",
        rwidth=0.8
    )
    plt.xticks(range(0, max_seats + 1))
    plt.xlabel(xlabel)
    plt.ylabel("Number of maps")
    plt.title(title)
    plt.show()


for _, _, label in ELECTIONS:
    print(f"{label} seat distribution")

    dem_wins_col = results_df[f"{label}_D_wins"]
    rep_wins_col = results_df[f"{label}_R_wins"]

    # Number of maps giving each party k seats, for k = 0..numDistricts
    dem_array = _seat_tally(dem_wins_col, numDistricts)
    rep_array = _seat_tally(rep_wins_col, numDistricts)

    print("Democratic seat count array:")
    print(dem_array)

    print("Republican seat count array:")
    print(rep_array)

    # One long-format row per party for every possible seat count
    # (Democratic row first, then Republican, for each seat count).
    for seats, (dem_maps, rep_maps) in enumerate(zip(dem_array, rep_array)):
        seat_distribution_rows.append({
            "election": label,
            "party": "D",
            "seats_won": seats,
            "num_maps": dem_maps
        })
        seat_distribution_rows.append({
            "election": label,
            "party": "R",
            "seats_won": seats,
            "num_maps": rep_maps
        })

    # Democratic histogram first, then Republican
    _show_seat_histogram(
        dem_wins_col,
        numDistricts,
        "Democratic districts won",
        f"{label} Democratic Seat Distribution",
    )
    _show_seat_histogram(
        rep_wins_col,
        numDistricts,
        "Republican districts won",
        f"{label} Republican Seat Distribution",
    )


# Write the long-format seat distributions (one row per election, party and
# seat count) collected in seat_distribution_rows.
seat_dist_df = pd.DataFrame(seat_distribution_rows)
seat_dist_df.to_csv(SEATDISTCSV, index=False)

print("\nSaved seat distributions to:", SEATDISTCSV, sep="\n")
