import pandas as pd
import numpy as np
import statsmodels.api as sm
import statsmodels.formula.api as smf
from statsmodels.discrete.discrete_model import MNLogit
from statsmodels.tools.tools import add_constant
import random
import math
from collections import Counter
from sklearn.preprocessing import OneHotEncoder
from sklearn.linear_model import LogisticRegression

######################
#(1) import clean datasets
#work to get these datasets in save_pitcherdata.R

#load both cleaned csv files (produced by save_pitcherdata.R) in a single pass:
#  pitcher_data - cleaned pitch data for all pitchers
#  pitcher_list - pitchers the user can choose from (pitcher, throw, id number)
pitcher_data, pitcher_list = (
    pd.read_csv(csv_path)
    for csv_path in ('clean_pitcherdata.csv', 'pitcher_list.csv')
)


######################
#(2) store user selected variables

#moved to app.py
#selected_pitcher = 434378
#selected_batter_stance = "R"

#filter data pitcher_data with selections

#app.py calls run_simulation with the user-selected inputs; it fits the models below and returns the simulated results
#row counts that must be EXCEEDED before a model is fit for a pitch type
#(the data is already filtered upstream, this is only a second safety net)
_MIN_SWING_OBS = 50
_MIN_CONTACT_OBS = 25

#height subtracted from the simulated vertical location before it is fed to the fitted models
#NOTE(review): assumes dist_x / dist_z in the csv are offsets from x = 0 and z = 2.5 - confirm in save_pitcherdata.R
_ZONE_CENTER_Z = 2.5

#number of pitches simulated per run_simulation call
_N_SIM_PITCHES = 1000

#predictors shared by the swing model and the contact model
_MODEL_TERMS = "dist_x + dist_z + balls + strikes"

#integer coding of swing results for the multinomial model - whiff (0) is the baseline
_RESULT_MAP = {
    'swinging_strike': 0,
    'foul': 1,
    'hit_into_play': 2
}

#(key suffix, model term) pairs used to build the contact parameter dict
_CONTACT_TERMS = (
    ('int', 'Intercept'),
    ('b_x', 'dist_x'),
    ('b_z', 'dist_z'),
    ('b_b', 'balls'),
    ('b_s', 'strikes'),
)


def _neutral_swing_params():
    """Return all-zero swing coefficients (a 50% swing probability everywhere)."""
    return {"intercept": 0.0, "b_x": 0.0, "b_z": 0.0, "b_b": 0.0, "b_s": 0.0}


def _neutral_contact_params():
    """Return all-zero contact coefficients (miss / foul / in play equally likely)."""
    return {
        f'{prefix}_{suffix}': 0.0
        for prefix in ('foul', 'inplay')
        for suffix, _ in _CONTACT_TERMS
    }


def _sigmoid(value):
    """Return 1 / (1 + exp(-value)) without overflowing for large |value|."""
    if value >= 0:
        return 1.0 / (1.0 + math.exp(-value))
    #for negative input exp(value) is < 1, so this form can never overflow
    exp_value = math.exp(value)
    return exp_value / (1.0 + exp_value)


def _fit_swing_params(p_data, selected_pitch):
    """Fit the swing / no-swing logistic regression for one pitch type.

    Returns a dict with keys intercept, b_x, b_z, b_b, b_s. Neutral (all zero)
    coefficients are returned when there are not more than _MIN_SWING_OBS rows
    or the fit raises, so the simulation can still run.
    """
    pitches = p_data[p_data["pitch_name"] == selected_pitch]

    if len(pitches) <= _MIN_SWING_OBS:
        print(f"Not enough data to fit swing model for {selected_pitch}: using neutral coefficients")
        return _neutral_swing_params()

    try:
        #horizontal and vertical distance from center, # balls and strikes
        coefs = smf.logit(
            "is_swing ~ " + _MODEL_TERMS,
            data=pitches
        ).fit_regularized(alpha=0.01, disp=0).params

        return {
            "intercept": coefs["Intercept"],
            "b_x": coefs["dist_x"],
            "b_z": coefs["dist_z"],
            "b_b": coefs["balls"],
            "b_s": coefs["strikes"]
        }
    except Exception as e:
        #same best-effort policy as the contact model: report and keep simulating
        print(f"Error fitting {selected_pitch}: {e}") #for debugging purposes
        return _neutral_swing_params()


def _contact_params_from_coefficients(coeffs, categories):
    """Translate the multinomial coefficient table into the contact parameter dict.

    coeffs is the params DataFrame of the fitted mnlogit (rows = model terms).
    categories is the sorted list of result codes present in the training data.

    statsmodels drops the baseline (lowest sorted) outcome and labels the
    remaining equations positionally 0..J-2, so column j belongs to
    categories[j + 1] - NOT to the result code j.

    Raises ValueError when whiff is not the baseline or the table does not have
    one column per non-baseline outcome, because the coefficients could then
    not be interpreted relative to a whiff.
    """
    baseline, *modeled = categories
    if baseline != _RESULT_MAP['swinging_strike'] or coeffs.shape[1] != len(modeled):
        raise ValueError("multinomial fit does not use whiff as the baseline outcome")

    column_of = {outcome: position for position, outcome in enumerate(modeled)}

    params = {}
    #default intercepts are used only when an outcome never occurred in the data
    for prefix, outcome, default_intercept in (
        ('foul', _RESULT_MAP['foul'], -0.5),
        ('inplay', _RESULT_MAP['hit_into_play'], -1.0),
    ):
        position = column_of.get(outcome)
        for suffix, term in _CONTACT_TERMS:
            default = default_intercept if term == 'Intercept' else 0.0
            value = None if position is None else coeffs.iloc[:, position].get(term)
            #fall back to the default when the coefficient is missing or NaN
            params[f'{prefix}_{suffix}'] = default if value is None or pd.isna(value) else float(value)
    return params


def _fit_contact_params(p_data, selected_pitch):
    """Fit the multinomial (whiff / foul / in play) model on swings at one pitch type.

    Returns a dict with keys foul_int, foul_b_x, foul_b_z, foul_b_b, foul_b_s and
    the matching inplay_* keys. Neutral (all zero) coefficients are returned when
    there are not more than _MIN_CONTACT_OBS usable swings, fewer than two
    distinct results, or the fit fails.
    """
    #we know he swung, now need to predict what the result was on that swing
    swings = p_data[(p_data['pitch_name'] == selected_pitch) & (p_data['is_swing'] == 1)].copy()

    #map the results to integers and drop any result that is not in the map
    swings['result_idx'] = swings['simple_result'].map(_RESULT_MAP)
    swings = swings.dropna(subset=['result_idx'])
    categories = sorted(swings['result_idx'].unique())

    if len(swings) <= _MIN_CONTACT_OBS or len(categories) < 2:
        print(f"Not enough data to fit contact model for {selected_pitch}: using neutral coefficients")
        return _neutral_contact_params()

    try:
        #result is modeled using horizontal and vertical distance from center, # balls and strikes
        fitted = smf.mnlogit(
            formula="result_idx ~ " + _MODEL_TERMS, data=swings
        ).fit_regularized(method='l1', alpha=0.01, disp=0)
        return _contact_params_from_coefficients(fitted.params, categories)
    except Exception as e:
        #in case the modeling really doesn't work - assign everything as 0 so it doesn't fail
        print(f"Error fitting {selected_pitch}: {e}") #for debugging purposes
        return _neutral_contact_params()


def _ab_find_location(mu_x, sigma_x, mu_z, sigma_z):
    """Draw one pitch location inside the 1-standard-deviation ellipse around the target.

    mu_x / mu_z are the location chosen by the user, sigma_x / sigma_z the spread
    on each axis. Candidates are drawn from independent normal distributions and
    rejected until one lands inside the ellipse. All four inputs must be finite,
    otherwise no candidate is ever accepted (run_simulation checks this).

    Returns {'x': ..., 'z': ...}.
    """
    def scaled_sq(offset, sigma):
        #a zero spread pins the pitch to the target on that axis, so it adds no distance
        return offset ** 2 / sigma ** 2 if sigma else 0.0

    while True:
        pitch_x = random.gauss(mu_x, sigma_x)
        pitch_z = random.gauss(mu_z, sigma_z)

        #elliptical distance squared - similar to how circle is drawn around pitch showing sd
        #its square root is how many standard deviations the pitch is from the chosen location
        dist_sq = scaled_sq(pitch_x - mu_x, sigma_x) + scaled_sq(pitch_z - mu_z, sigma_z)

        if dist_sq <= 1:
            return {'x': pitch_x, 'z': pitch_z}


def _ab_batter(L_x, L_z, B, S, swing_dec_params):
    """Decide whether the batter swings at a pitch at (L_x, L_z) in a B-S count.

    Returns {"decision": 1 for swing / 0 for take, "prob": swing probability}.
    """
    params = swing_dec_params

    logit = (params['intercept'] + (params['b_x'] * L_x) +
             (params['b_z'] * (L_z - _ZONE_CENTER_Z)) + (params['b_b'] * B) + (params['b_s'] * S))

    #overflow-safe conversion of the logit into a probability between 0 and 1
    prob_swing = _sigmoid(logit)

    #random number between 0-1: below the swing probability means the batter swings
    decision = 1 if random.random() < prob_swing else 0

    return {"decision": decision, "prob": prob_swing}


def _ab_swing(L_x, L_z, B, S, contact_params):
    """Pick the result of a swing: "Miss", "Foul" or "In Play".

    Returns {"outcome": chosen result, "probs": [p_whiff, p_foul, p_inplay]}.
    """
    params = contact_params
    z_offset = L_z - _ZONE_CENTER_Z

    #whiff = 0 because it is the baseline category, foul and inplay are relative to it
    u_foul = (params['foul_int'] + (params['foul_b_x'] * L_x) + (params['foul_b_z'] * z_offset)
              + (params['foul_b_b'] * B) + (params['foul_b_s'] * S))
    u_inplay = (params['inplay_int'] + (params['inplay_b_x'] * L_x) + (params['inplay_b_z'] * z_offset)
                + (params['inplay_b_b'] * B) + (params['inplay_b_s'] * S))

    #softmax: subtracting the max first leaves the probabilities unchanged but avoids overflow
    utilities = np.array([0.0, u_foul, u_inplay], dtype=float)
    weights = np.exp(utilities - utilities.max())
    probs = (weights / weights.sum()).tolist()

    #random weighted draw decides the final result of the swing
    outcome = np.random.choice(["Miss", "Foul", "In Play"], p=probs)

    return {"outcome": outcome, "probs": probs}


def _ab_umpire(L_x, L_z, sz_top=3.5, sz_bot=1.5):
    """Simulate the umpire's call: 1 = called strike, 0 = ball.

    L_x / L_z may be scalars or equal-length arrays. A single pitch returns a
    single value, several pitches return an array of calls.
    """
    L_x = np.atleast_1d(L_x)
    L_z = np.atleast_1d(L_z)

    #horizontal boundary of the strike zone, measured from the middle of the plate
    plate_limit = 0.7083

    #how far the pitch is inside each edge of the zone - negative means outside
    inside_x = plate_limit - np.abs(L_x)
    half_height = (sz_top - sz_bot) / 2
    zone_center = (sz_top + sz_bot) / 2 #follows sz_top / sz_bot instead of assuming 2.5
    inside_z = half_height - np.abs(L_z - zone_center)

    #elementwise: whichever edge the pitch is closest to (or furthest outside of)
    min_inside_dist = np.minimum(inside_x, inside_z)

    #k = steepness of the curve - controls how accurate the umpire is
    #30 gives a sharp boundary: probability of a strike drops quickly once the pitch is outside
    k = 30
    logit = k * min_inside_dist

    #overflow-safe sigmoid: exp is only ever taken of a non-positive number
    damp = np.exp(-np.abs(logit))
    prob_strike = np.where(logit >= 0, 1 / (1 + damp), damp / (1 + damp))

    #random number below prob_strike means called strike
    calls = (np.random.rand(len(prob_strike)) < prob_strike).astype(int)

    return calls[0] if len(calls) == 1 else calls


def _run_sim(n_pitches, mu_x, mu_z, sigma_x, sigma_z, B, S, swing_dec_params, contact_params):
    """Simulate n_pitches pitches and return one DataFrame row per pitch.

    Columns: x, z, outcome, true_call, p_swing, p_whiff, p_foul, p_inplay.
    The three contact probabilities are NaN when the batter did not swing.
    """
    results = []

    for _ in range(n_pitches):
        location = _ab_find_location(mu_x, sigma_x, mu_z, sigma_z)
        L_x = location['x']
        L_z = location['z']

        #the umpire's call for this pitch, recorded regardless of swing
        true_call_val = _ab_umpire(L_x, L_z, sz_top=3.5, sz_bot=1.5)
        true_call = "Called Strike" if true_call_val == 1 else "Ball"

        swing_decision = _ab_batter(L_x, L_z, B, S, swing_dec_params)

        if swing_decision['decision'] == 1:
            swing = _ab_swing(L_x, L_z, B, S, contact_params)
            final_outcome = swing['outcome']
            p_whiff, p_foul, p_inplay = swing['probs']
        else:
            #no swing: reuse the call above so outcome and true_call can never disagree
            final_outcome = true_call
            p_whiff = p_foul = p_inplay = np.nan

        results.append({
            "x": L_x,
            "z": L_z,
            "outcome": final_outcome,
            "true_call": true_call,
            "p_swing": swing_decision['prob'],
            "p_whiff": p_whiff,
            "p_foul": p_foul,
            "p_inplay": p_inplay
        })

    return pd.DataFrame(results)


def run_simulation(selected_pitcher, selected_batter_stance, selected_pitch, balls, strikes, sd_x, sd_z, mu_x, mu_z):
    """Fit the swing and contact models for the selection and simulate 1,000 pitches.

    Args:
        selected_pitcher: value matched against the 'pitcher' column of pitcher_data.
        selected_batter_stance: value matched against the 'stand' column.
        selected_pitch: value matched against the 'pitch_name' column.
        balls, strikes: current count, used as model predictors.
        sd_x, sd_z: spread of the simulated location on each axis.
        mu_x, mu_z: target location chosen by the user.

    Returns:
        DataFrame with one row per simulated pitch and the columns
        x, z, outcome, true_call, p_swing, p_whiff, p_foul, p_inplay.

    Raises:
        ValueError: if any of sd_x, sd_z, mu_x, mu_z is NaN or infinite
            (the location sampler would otherwise never terminate).
    """
    if not all(math.isfinite(value) for value in (sd_x, sd_z, mu_x, mu_z)):
        raise ValueError("sd_x, sd_z, mu_x and mu_z must all be finite numbers")

    #only grab data for the current pitcher and batter stance selected
    p_data = pitcher_data[(pitcher_data['pitcher'] == selected_pitcher) & (pitcher_data['stand'] == selected_batter_stance)]

    swing_dec_params = _fit_swing_params(p_data, selected_pitch)
    contact_params = _fit_contact_params(p_data, selected_pitch)

    #return the output - will get used in app.py
    return _run_sim(
        _N_SIM_PITCHES, mu_x=mu_x, mu_z=mu_z, sigma_x=sd_x, sigma_z=sd_z,
        B=balls, S=strikes,
        swing_dec_params=swing_dec_params, contact_params=contact_params
    )