import numpy as np
import sympy as sp
import matplotlib.pyplot as plt
import pandas as pd
import scipy.optimize as so
import re

from matplotlib.lines import Line2D
from mpl_toolkits.axes_grid1.inset_locator import inset_axes, InsetPosition, mark_inset

import matplotlib.ticker as ticker
from pyfluids import Fluid, FluidsList, Input

from pathlib import Path

from TemperatureField import TemperatureField

plt.rcParams.update({'font.size': 22})
plt.rcParams.update({'mathtext.fontset':'cm'})
plt.rcParams.update({'font.family':'Times New Roman'})

def make_dict(paths, Dh, Pr, density, nu, Cp, k, H, A, h_area, Twall, new_lambda, lambda_coeffs, method, solver):
    """Build per-case comparison data for the periodic CFD results and the analytical model.

    For every CSV in ``paths`` the sampled CFD fields are read, the inlet, midway
    and outlet profiles are extracted, and a heat transfer coefficient, Nusselt
    number and friction factor are derived.  A ``TemperatureField`` is then built
    with the same period length and mean inlet velocity, and the same quantities
    are extracted from it.

    Parameters
    ----------
    paths : iterable of pathlib.Path
        CSV files with columns ``Points_0``, ``Points_1``, ``UMean_0`` and
        ``TMean``.  Each file's stem is used as the case key.
    Dh : float
        Length scale used in ``Nu = h*Dh/k`` and in the friction factor.
    Pr, density, nu, Cp, k : float
        Fluid properties.
    H : float
        Forwarded to ``TemperatureField``.
    A : float
        Cross-section area used for the mass flow rate ``density*A*U``.
    h_area : float
        Area used to turn the heat rate into ``h``.
    Twall : float
        Wall temperature used in the log-mean temperature difference.
    new_lambda, lambda_coeffs, method :
        Forwarded unchanged to ``TemperatureField``.
    solver : dict
        ``solver[case]['pg']`` must hold the pressure-gradient history of each case.

    Returns
    -------
    dict
        ``{case: {'periodic': {...}, 'analytical': {...}}}``.
    """

    def find_nearest(array, value):
        """Return the element of ``array`` closest to ``value``."""
        array = np.asarray(array)
        return array[np.abs(array - value).argmin()]

    def bulk_mean(T, U):
        """Velocity-weighted average of ``T`` over one cross-section."""
        return np.sum(np.multiply(T, U)) / np.sum(U)

    def heat_transfer_coefficient(m_dot, T_in, T_out):
        """h from the heat rate m_dot*Cp*(T_out - T_in) and the log-mean wall temperature difference."""
        dT_in = Twall - T_in
        dT_out = Twall - T_out
        temp_lm = (dT_out - dT_in) / np.log(dT_out / dT_in)
        return m_dot*Cp*(T_out - T_in)/(h_area*temp_lm)

    plot_dict = {}

    for path in paths:
        case = path.stem
        periodic = {}
        ana = {}
        plot_dict[case] = {'periodic': periodic, 'analytical': ana}

        #---------------------- Periodic (CFD) data ---------------------------------------
        df = pd.read_csv(path)
        x_coords = df['Points_0'].unique()
        min_x = np.min(x_coords)
        max_x = np.max(x_coords)
        L = max_x - min_x
        periodic['L'] = L
        # x_coords are absolute coordinates, so the midpoint must be offset by min_x.
        mid_x = find_nearest(x_coords, min_x + L/2)

        inlet_df = df[df['Points_0'] == min_x]
        outlet_df = df[df['Points_0'] == max_x]
        midway_df = df[df['Points_0'] == mid_x]

        periodic['x'] = df['Points_0'].values - min_x
        periodic['y'] = df['Points_1'].values
        periodic['yI'] = inlet_df['Points_1'].values
        periodic['yO'] = outlet_df['Points_1'].values
        periodic['yM'] = midway_df['Points_1'].values

        periodic['Ux'] = df['UMean_0'].values
        periodic['UI'] = inlet_df['UMean_0'].values
        periodic['UImean'] = np.mean(periodic['UI'])
        periodic['UO'] = outlet_df['UMean_0'].values
        periodic['UM'] = midway_df['UMean_0'].values

        periodic['m_dot'] = density*A*periodic['UImean']

        periodic['T'] = df['TMean'].values
        periodic['TI'] = inlet_df['TMean'].values
        periodic['TO'] = outlet_df['TMean'].values
        periodic['TM'] = midway_df['TMean'].values

        #---------------------- h for periodic ------------------------------------------------
        periodic_TI = bulk_mean(periodic['TI'], periodic['UI'])
        periodic_TO = bulk_mean(periodic['TO'], periodic['UO'])
        periodic['h'] = np.abs(heat_transfer_coefficient(periodic['m_dot'], periodic_TI, periodic_TO))
        periodic['Nu'] = periodic['h']*Dh/k

        # The logged gradient is scaled by density (presumably it is a kinematic value — confirm).
        # NOTE(review): [-1000:-1] keeps 999 samples and drops the final one; confirm
        # whether [-1000:] (the last 1000 samples) was intended.
        periodic['Pdrop'] = np.abs(solver[case]['pg'][-1000:-1])*density
        periodic['f'] = (periodic['Pdrop']*Dh)/(L*2*density*periodic['UImean']**2)

        #---------------------- Analytical Dictionary-------------------------------------------
        model = TemperatureField(L, H, Dh, Pr, periodic['UImean'], nu,
                                 new_lambda, lambda_coeffs, method=method)

        # Model fields are 2-D with x along axis 1 (column 0 is used as the inlet,
        # column -1 as the outlet), so the midway station is the middle column.
        mid_index = np.shape(model.yv)[1] // 2

        ana['x'] = model.xv
        ana['y'] = model.yv
        ana['yI'] = model.yv[:, 0]
        ana['yO'] = model.yv[:, -1]
        ana['yM'] = model.yv[:, mid_index]

        ana['Ux'] = model.Ux
        ana['UI'] = model.Ux[:, 0]
        ana['UImean'] = np.mean(ana['UI'])
        ana['UO'] = model.Ux[:, -1]
        ana['UM'] = model.Ux[:, mid_index]

        ana['m_dot'] = density*A*ana['UImean']

        ana['Yfield'] = model.Y_field
        # Scale the model temperature by (max CFD temperature - Twall) and offset it by Twall.
        ana['T'] = (model.T_field*(np.max(periodic['T']) - Twall)) + Twall
        ana['TI'] = ana['T'][:, 0]
        ana['TO'] = ana['T'][:, -1]
        ana['TM'] = ana['T'][:, mid_index]

        #---------------------- h for analytical ------------------------------------------------
        analytical_TI = bulk_mean(ana['TI'], ana['UI'])
        analytical_TO = bulk_mean(ana['TO'], ana['UO'])
        # abs() matches the periodic value so the relative error below compares like with like.
        ana['h'] = np.abs(heat_transfer_coefficient(ana['m_dot'], analytical_TI, analytical_TO))
        ana['Nu'] = ana['h']*Dh/k

        periodic['error'] = np.divide(np.abs(ana['h'] - periodic['h']), ana['h'])

    return plot_dict

def loop_plots_T(plots, H, colormap, Twall):
    """Create one contour figure per case comparing analytical and periodic temperatures.

    Temperatures are shown as ``(T - Twall)`` divided by the mean inlet value
    ``mean(T_I - Twall)`` of the same data set.  The analytical field is contoured
    on its own grid.  The periodic CFD points with ``y <= 0`` are contoured with
    ``tricontour``.  Both are drawn against ``x/L`` and ``y/(H/2)``.

    Parameters
    ----------
    plots : dict
        Output of ``make_dict``.
    H : float
        Height used to normalise y (by ``H/2``).
    colormap : str or Colormap
        Colormap for both contour sets.
    Twall : float
        Wall temperature used as the reference.

    Returns
    -------
    tuple(dict, dict)
        ``(figs, axs)`` keyed by case name.
    """
    fig_x, fig_y = 12, 10
    H2 = H/2
    figs = {}
    axs = {}
    for case, data in plots.items():
        periodic = data['periodic']
        analytical = data['analytical']
        L = periodic['L']

        figs[case], axs[case] = plt.subplots(figsize=(fig_x, fig_y), sharex=True)
        ax = axs[case]
        plt.subplots_adjust(hspace=0.1)
        figs[case].tight_layout()
        figs[case].subplots_adjust(top=0.94)
        figs[case].suptitle(case)

        analytical_T = (analytical['T'] - Twall)/np.mean(analytical['TI'] - Twall)
        periodic_T = (periodic['T'] - Twall)/np.mean(periodic['TI'] - Twall)

        # Normalise the analytical x by L so both data sets share the x/L axis.
        cf1 = ax.contour(analytical['x']/L, analytical['y']/H2, analytical_T,
                         levels=20, cmap=colormap)

        lower = periodic['y'] <= 0
        cf3 = ax.tricontour(periodic['x'][lower]/L, periodic['y'][lower]/H2,
                            periodic_T[lower], levels=20, cmap=colormap)

        # Divider along y = 0 across the full normalised length (x/L from 0 to 1).
        ax.plot([0, 1], [0, 0], 'k', linewidth=3)

        ax.clabel(cf1, inline=True, fontsize=20)
        ax.clabel(cf3, inline=True, fontsize=20)

        ax.set_xlabel('x/L')
        # NOTE(review): y is divided by H/2 but labelled y/H — confirm which is intended.
        ax.set_ylabel('y/H')

    return figs, axs

def loop_plots_Ux(plots, H, colormap):
    """Create one contour figure per case comparing analytical and periodic streamwise velocity.

    Each velocity field is divided by its own mean inlet velocity (``UImean``).
    The analytical field is contoured on its own grid.  The periodic CFD points
    with ``y <= 0`` are contoured with ``tricontour``.  Both are drawn against
    ``x/L`` and ``y/(H/2)``.

    Parameters
    ----------
    plots : dict
        Output of ``make_dict``.
    H : float
        Height used to normalise y (by ``H/2``).
    colormap : str or Colormap
        Colormap for both contour sets.

    Returns
    -------
    tuple(dict, dict)
        ``(figs, axs)`` keyed by case name.
    """
    fig_x, fig_y = 12, 10
    H2 = H/2
    figs = {}
    axs = {}
    for case, data in plots.items():
        periodic = data['periodic']
        analytical = data['analytical']
        L = periodic['L']

        figs[case], axs[case] = plt.subplots(figsize=(fig_x, fig_y), sharex=True)
        ax = axs[case]
        plt.subplots_adjust(hspace=0.1)
        figs[case].tight_layout()
        figs[case].subplots_adjust(top=0.94)
        figs[case].suptitle(case)

        cf1 = ax.contour(analytical['x']/L, analytical['y']/H2,
                         analytical['Ux']/analytical['UImean'], levels=20, cmap=colormap)

        lower = periodic['y'] <= 0
        cf3 = ax.tricontour(periodic['x'][lower]/L, periodic['y'][lower]/H2,
                            periodic['Ux'][lower]/periodic['UImean'], levels=20, cmap=colormap)

        # Divider along y = 0 across the full normalised length (x/L from 0 to 1).
        ax.plot([0, 1], [0, 0], 'k', linewidth=3)

        ax.clabel(cf1, inline=True, fontsize=20)
        ax.clabel(cf3, inline=True, fontsize=20)

        ax.set_xlabel('x/L')
        # NOTE(review): y is divided by H/2 but labelled y/H — confirm which is intended.
        ax.set_ylabel('y/H')

    return figs, axs

def loop_plots_all(plots, H, Tw, colormap):
    """Create one three-panel figure per case: temperature contours, velocity profiles, temperature profiles.

    Panel 0 contours ``(T - Tw)/mean(T_I - Tw)`` for the analytical field and for the
    periodic CFD points with ``y <= 0``.  Panels 1 and 2 compare the inlet (x=0),
    midway (x=L/2) and outlet (x=L) profiles.  Every ``step``-th CFD sample is drawn
    as a marker.  The analytical profile is drawn as a dashed line, mirrored about y=0.

    Parameters
    ----------
    plots : dict
        Output of ``make_dict``.
    H : float
        Height used to normalise y (by ``H/2``).
    Tw : float
        Wall temperature used as the reference.
    colormap : str or Colormap
        Colormap for the contour panel.

    Returns
    -------
    tuple(dict, dict)
        ``(figs, axs)`` keyed by case name; each ``axs`` entry holds the three axes.
    """
    fig_x, fig_y = 20, 12
    H2 = H/2
    step = 3  # draw every 3rd CFD sample so the markers stay readable
    # (key suffix used in make_dict, marker, colour, legend label) for each station
    stations = (('I', 'o', 'C0', 'x=0'),
                ('M', '^', 'C1', 'x=L/2'),
                ('O', 'd', 'C2', 'x=L'))
    x_ticks = np.arange(0, 2.5, 0.5)
    figs = {}
    axs = {}
    for case, data in plots.items():
        periodic = data['periodic']
        analytical = data['analytical']
        L = periodic['L']

        figs[case], axs[case] = plt.subplots(1, 3, figsize=(fig_x, fig_y), width_ratios=[2, 1, 1], sharey=True)
        ax_field, ax_u, ax_t = axs[case]
        plt.subplots_adjust(hspace=0.1)
        figs[case].tight_layout()
        figs[case].subplots_adjust(top=0.95)
        figs[case].suptitle(case)

        # Reference excess temperature: mean inlet temperature above the wall.
        periodic_T0 = np.mean(periodic['TI'] - Tw)
        analytical_T0 = np.mean(analytical['TI'] - Tw)

        #---------------- Temperature contours ---------------------------------
        lower = periodic['y'] <= 0
        cf1 = ax_field.contour(analytical['x']/L, analytical['y']/H2,
                               (analytical['T'] - Tw)/analytical_T0, levels=20, cmap=colormap)
        cf2 = ax_field.tricontour(periodic['x'][lower]/L, periodic['y'][lower]/H2,
                                  (periodic['T'][lower] - Tw)/periodic_T0, levels=20, cmap=colormap)
        # Divider along y = 0 across the full normalised length (x/L from 0 to 1).
        ax_field.plot([0, 1], [0, 0], 'k', linewidth=3)

        #---------------- Velocity and temperature profiles --------------------
        for key, marker, color, label in stations:
            # Each station is plotted against its own y samples, which may differ in count.
            y_per = periodic['y' + key][::step]/H2
            y_ana = analytical['y' + key]/H2
            u_ana = analytical['U' + key]/analytical['UImean']
            t_ana = (analytical['T' + key] - Tw)/analytical_T0

            ax_u.plot(periodic['U' + key][::step]/periodic['UImean'], y_per, marker, color=color, label=label)
            ax_u.plot(u_ana, y_ana, '--', color=color)
            ax_u.plot(u_ana, -y_ana, '--', color=color)

            ax_t.plot((periodic['T' + key][::step] - Tw)/periodic_T0, y_per, marker, color=color, label=label)
            ax_t.plot(t_ana, y_ana, '--', color=color)
            ax_t.plot(t_ana, -y_ana, '--', color=color)

        ax_field.clabel(cf1, inline=True, fontsize=20)
        ax_field.clabel(cf2, inline=True, fontsize=20)

        ax_field.set_xlabel('x/L')
        # NOTE(review): y is divided by H/2 but labelled y/H — confirm which is intended.
        ax_field.set_ylabel('y/H')

        ax_u.set_xlabel('$U/U_0$')
        ax_t.set_xlabel('$T/T_0$')
        # Use the same fixed x ticks, formatted to one decimal, on both profile panels.
        for ax in (ax_u, ax_t):
            ax.set_xticks(x_ticks)
            ax.xaxis.set_major_formatter(ticker.FormatStrFormatter('%.1f'))
            ax.legend()

    return figs, axs

def search_str(file_path):
    """Scan a solver log for time stamps and pressure-gradient values.

    A line that begins with ``Time = `` contributes its third whitespace-separated
    token to the time series.  A line that begins with ``Pressure gradient``
    contributes its last token to the pressure-gradient series.  All other lines
    are ignored.

    Returns
    -------
    tuple of numpy.ndarray
        ``(time, pg)``.  The two series are collected independently, so their
        lengths need not match.
    """
    times = []
    gradients = []
    with open(file_path) as log:
        for raw in log:
            if raw.startswith('Time = '):
                times.append(float(raw.split()[2]))
                continue
            if raw.startswith('Pressure gradient'):
                gradients.append(float(raw.split()[-1]))
    return np.array(times), np.array(gradients)


def _plot_against_reynolds(plot_dict, key, correlation_value, ylabel, ylim, out_file):
    """Scatter the periodic and analytical ``key`` values against Re, with a zoomed inset.

    The Reynolds number is parsed from each case name as ``float(case[2:])``.  The
    correlation value is drawn at Re = 0.  The figure is saved to ``out_file`` and
    then closed.
    """
    fig, ax = plt.subplots(figsize=(12, 10), facecolor='w')
    # Inset in parent-axes coordinates [left, bottom, width, height].
    axins = ax.inset_axes([0.4, 0.2, 0.5, 0.5])

    for case in plot_dict:
        Re = float(case[2:])
        for target in (ax, axins):
            target.scatter(Re, np.abs(plot_dict[case]['periodic'][key]), marker='o', color='C0')
            target.scatter(Re, np.abs(plot_dict[case]['analytical'][key]), marker='v', color='C1')

    for target in (ax, axins):
        target.scatter(0.0, correlation_value, marker='d', linewidth=3, color='c')

    legend_elements = [
        Line2D([0], [0], marker='o', color='w', label='periodic', markerfacecolor='C0', markersize=12),
        Line2D([0], [0], marker='v', color='w', label='analytical', markerfacecolor='C1', markersize=12),
        Line2D([0], [0], marker='d', color='w', label='From Correlation', markerfacecolor='c', markersize=12),
    ]

    ax.set_xlabel('Reynolds Number')
    ax.set_ylabel(ylabel)
    ax.set_ylim(ylim)
    ax.legend(handles=legend_elements, ncol=3, bbox_to_anchor=(1.02, 1.09))
    fig.savefig(out_file, facecolor='w', dpi=300, bbox_inches='tight')
    plt.close(fig)


def _save_case_figures(figs, save_path, flow_type, subcase, fluid, tail):
    """Save each per-case figure as ``<flow_type>_<case>_<subcase>_<fluid><tail>`` and close that figure."""
    for case, fig in figs.items():
        fig.savefig(save_path.joinpath(f'{flow_type}_{case}_{subcase}_{fluid}{tail}'),
                    facecolor='w', dpi=300, bbox_inches='tight')
        plt.close(fig)


solver_type = ['steady', 'unsteady']
fluid_types = ['air', 'water']

flow_type = 'HP'

# Case-independent wall temperature and geometry.
Twall = 293.15
L = 1.0
H = 0.5
W = 0.02
A = H*W
h_area = 2*L*W
Dh = 2*H

# Nusselt number used for the correlation marker (presumably the fully developed
# laminar parallel-plate value for uniform wall temperature — confirm).
Nu_correlation = 7.54

for subcase in solver_type:
    for fluid in fluid_types:

        data_directory = Path(r'./').joinpath('Re_runs_' + subcase, fluid)
        periodic_paths = list(data_directory.glob('*.csv'))
        solver_files = sorted(data_directory.rglob('log.cyclic*Foam'))

        save_path = data_directory

        # Grab the pressure gradient history of every case from its solver log.
        solver = {}
        for file in solver_files:
            time, pg = search_str(file)
            solver[file.parent.stem] = {'time': time, 'pg': pg}

        fluid_kind = FluidsList.Air if fluid == 'air' else FluidsList.Water
        working_fluid = Fluid(fluid_kind).with_state(
            Input.pressure(101325), Input.temperature(35)
        )

        Pr = working_fluid.prandtl
        density = working_fluid.density
        mu = working_fluid.dynamic_viscosity
        nu = mu/density
        Cp = working_fluid.specific_heat
        k = Cp*mu/Pr  # from Pr = mu*Cp/k

        plot_dict = make_dict(periodic_paths, Dh, Pr, density, nu, Cp, k, H, A, h_area, Twall, True, 15, 'Wang', solver)

        # Heat transfer coefficient (h is W/(m^2 K), not W/(m K)).
        # NOTE(review): the file name always ends in "_steady" even for the unsteady
        # subcase; it is kept unchanged so downstream consumers still find it.
        _plot_against_reynolds(
            plot_dict, 'h', Nu_correlation*k/Dh,
            'Heat Transfer Coefficient [W/(m$^2$K)]',
            [0.0, 0.25] if fluid == 'air' else [0.0, 5.5],
            save_path.joinpath(flow_type + '_' + subcase + '_' + fluid + '_ht_coefficient_steady.pdf'),
        )

        # Nusselt number.
        _plot_against_reynolds(
            plot_dict, 'Nu', Nu_correlation,
            'Nusselt Number',
            [0.0, 8.5] if fluid == 'air' else [0.0, 9.0],
            save_path.joinpath(flow_type + '_' + subcase + '_' + fluid + '_Nusselt.pdf'),
        )

        # Temperature fields.
        figs, _ = loop_plots_T(plot_dict, H, 'twilight', Twall)
        _save_case_figures(figs, save_path, flow_type, subcase, fluid, '_T.pdf')

        # Velocity fields.
        figs, _ = loop_plots_Ux(plot_dict, H, 'twilight')
        _save_case_figures(figs, save_path, flow_type, subcase, fluid, '_Ux.pdf')

        # All fields combined.
        figs, _ = loop_plots_all(plot_dict, H, Twall, 'twilight')
        _save_case_figures(figs, save_path, flow_type, subcase, fluid, '_combined.pdf')
