import geopandas as gpd
import json
import matplotlib.pyplot as plt
import os


def draw_wards_map(
    shp_filename: str = "Wards_2021_Johnson_v_WEC.shp",
    json_filename: str = "wisconsin_recom.json",
    title: str = "ReCom District Map"
) -> None:
    """Plot a ward shapefile coloured by the district assigned in a JSON file.

    Both files are looked up in the directory that contains this script.
    The JSON document is expected to map the ward's row index (as a string)
    to an object that has a ``"district"`` entry; wards whose index is not
    present in the JSON get ``None`` as their district.

    Args:
        shp_filename: Name of the ward shapefile, relative to this script.
        json_filename: Name of the JSON file holding the district
            assignment, relative to this script.
        title: Title drawn above the map.

    Returns:
        None. The figure is displayed with ``plt.show()``.

    Raises:
        KeyError: If a ward's JSON entry exists but has no ``"district"`` key.
        OSError: If the JSON file cannot be opened.
    """
    # Find files relative to this script rather than the current working dir.
    script_dir = os.path.dirname(os.path.abspath(__file__))

    shp_path = os.path.join(script_dir, shp_filename)
    json_path = os.path.join(script_dir, json_filename)

    # Load shapefile.
    wards = gpd.read_file(shp_path)

    # Load recom output. The encoding is given explicitly so that decoding
    # does not depend on the platform's locale default.
    with open(json_path, "r", encoding="utf-8") as f:
        recom_data = json.load(f)

    def district_for(ward_key):
        # Wards without an entry in the JSON are left unassigned (None).
        if ward_key in recom_data:
            return recom_data[ward_key]["district"]
        return None

    # Match district labels from the JSON back onto wards. JSON object keys
    # are always strings, hence the cast of the index labels to str.
    wards["district"] = wards.index.astype(str).map(district_for)

    # Draw map.
    fig, ax = plt.subplots(figsize=(10, 10))

    # NOTE(review): no missing_kwds is passed, so wards whose district is
    # None are presumably not drawn by geopandas - confirm this is intended.
    wards.plot(
        column="district",
        cmap="tab20",
        linewidth=0.3,
        edgecolor="black",
        ax=ax,
        legend=False
    )

    ax.set_title(title)
    ax.set_axis_off()
    plt.show()