import numpy as np
import sympy as sp
import scipy.optimize as so


class TemperatureField:
    """Series-solution temperature field for flow between parallel plates.

    On a uniform ``ny`` x ``nx`` grid covering ``0 <= x <= L`` and
    ``0 <= y <= H/2`` the field is

        T(x, y) = Y(y) * exp(-(8/3) * lam**2 * 2*x / (H * Pr * Re_a))

    where ``Y`` is a power series whose coefficients follow either the
    'Wang' or the 'Li' recurrence, and ``lam`` is either supplied directly
    or found numerically as a root of the truncated coefficient sum.

    Parameters
    ----------
    L : float
        Length of the domain in x.
    H : float
        Channel height; the y grid covers the half-height ``[0, H/2]``.
    Dh : float
        Accepted for interface compatibility; not used by this class.
    Pr : float
        Prandtl number.
    Um : float
        Mean velocity, used for the profile ``Ux`` and for ``Re_a``.
    viscosity : float
        Viscosity in ``Re_a = 2*Um*H/viscosity`` (presumably kinematic --
        confirm units with the caller).
    new_lambda : bool
        If True, solve for ``lam`` numerically; if False, use
        ``lambda_coeffs`` directly as ``lam``.
    lambda_coeffs : int or float
        Dual use: the number of series terms in the root-finding expression
        when ``new_lambda`` is True, otherwise the eigenvalue ``lam`` itself.
    method : {'Wang', 'Li'}
        Coefficient recurrence to use.
    nx, ny : int, keyword-only
        Number of grid points in x and y (default 100 each).
    lambda_guess : float, keyword-only
        Starting point for the root finder (default 1.6).
    n_terms : int, keyword-only
        Number of series terms used to evaluate ``Y`` on the grid (default 50).

    Raises
    ------
    ValueError
        If ``method`` is not 'Wang' or 'Li'.
    """

    def __init__(self,
                 L=1.0,
                 H=0.5,
                 Dh=1.0,
                 Pr=0.7,
                 Um=1e-03,
                 viscosity=7.0e-5,
                 new_lambda=True,
                 lambda_coeffs=20,
                 method='Wang',
                 *,
                 nx=100,
                 ny=100,
                 lambda_guess=1.6,
                 n_terms=50):
        # Each scheme pairs a coefficient-sum function (whose roots give lam)
        # with the matching field evaluator.
        schemes = {
            'Wang': (self.coeffs_Wang, self.YField_Wang),
            'Li': (self.coeffs_Li, self.YField_Li),
        }
        if method not in schemes:
            # Fail early with a clear message instead of an AttributeError on
            # self.Y further down.
            raise ValueError(
                f"Unknown method {method!r}; expected one of {sorted(schemes)}")
        coeff_sum, y_field = schemes[method]

        self.Um = Um
        self.x = np.linspace(0.0, L, nx)
        self.y = np.linspace(0.0, H/2, ny)
        self.xv, self.yv = np.meshgrid(self.x, self.y)

        # Parabolic velocity profile; stored for callers, not used in the
        # field computation below.
        self.Ux = (3/2)*self.Um*(1 - (2*self.yv/H)**2)
        self.Pr = Pr
        self.Re_a = (2*self.Um*H)/viscosity

        if new_lambda:
            self.lam = self._solve_lambda(coeff_sum, int(lambda_coeffs),
                                          lambda_guess)
        else:
            self.lam = lambda_coeffs

        self.Y = y_field(H, N=n_terms)

        # Broadcast the 1-D profile Y(y) along x: Y_field[i, j] == Y[i].
        _, self.Y_field = np.meshgrid(self.x, self.Y)

        decay = np.exp((-8/3)*self.lam**2*((2*self.xv)/(H*self.Pr*self.Re_a)))
        self.T_field = np.multiply(self.Y_field, decay)

    def _solve_lambda(self, coeff_sum, n_coeffs, guess):
        """Find ``lam`` as a root of ``coeff_sum(x, n_coeffs)``.

        Builds the truncated coefficient sum symbolically (stored as
        ``self.expr``), converts it to a NumPy function and runs
        ``scipy.optimize.newton`` from ``guess`` (result stored as
        ``self.e_values``). Returns the root.
        """
        x_sym = sp.Symbol('x')
        self.expr = coeff_sum(x_sym, n_coeffs)
        f = sp.lambdify(x_sym, self.expr, 'numpy')
        # No derivative is supplied, so scipy falls back to the secant method.
        self.e_values = so.newton(f, guess, maxiter=1000)
        return np.min(self.e_values)

    def _m_n(self, x):
        """Return the recurrence constants ``(m, n)`` for eigenvalue ``x``.

        ``x`` may be a number or a SymPy expression.
        """
        n = x**2
        m = -(1 + ((8*x)/(3*self.Pr*self.Re_a))**2)*x**2
        return m, n

    def _wang_coefficients(self, x, N):
        """Return the Wang series coefficients ``B_0 .. B_{N-1}`` at ``x``.

        The recurrence is ``(2k)(2k-1) B_k = m B_{k-1} + n B_{k-2}`` with
        ``B_0 = 1``. At least the two initial terms are always returned.
        """
        m, n = self._m_n(x)
        coeffs = [1, m/2]
        for k in range(2, N):
            coeffs.append((m*coeffs[k-1] + n*coeffs[k-2]) / ((2*k)*(2*k - 1)))
        return coeffs

    def _li_coefficients(self, x, N):
        """Return the Li series coefficients ``B_0 .. B_{N-1}`` at ``x``.

        At least the four initial terms are always returned.
        """
        m, n = self._m_n(x)
        # NOTE(review): b2 = (m+n)/2 and b3 = n/3 are consistent with a
        # k*(k-1) denominator, but the loop below divides by (k+1)*(k+2).
        # The sign of the 2*n term also depends on whether xi = eta + 1
        # (as used in YField_Li) or eta - 1. Verify against the Li reference
        # before relying on this scheme.
        coeffs = [1, 0, (m + n)/2, n/3]
        for k in range(4, N):
            coeffs.append(((m + n)*coeffs[k-2] + 2*n*coeffs[k-3]
                           + n*coeffs[k-4]) / ((k + 1)*(k + 2)))
        return coeffs

    def coeffs_Wang(self, x, N):
        """Return the sum of the first ``N`` Wang coefficients at ``x``.

        This equals the Wang series ``Y`` evaluated at ``2*y/H = 1``; its
        roots in ``x`` are used as the eigenvalue. ``x`` may be a SymPy
        symbol (as used by the root finder) or a number.
        """
        return sum(self._wang_coefficients(x, N), 0.0)

    def YField_Wang(self, H, N=50):
        """Evaluate ``Y(y) = sum_j B_j * (2*y/H)**(2*j)`` on ``self.y``.

        Uses ``self.lam`` and the first ``N`` Wang coefficients. Returns a
        1-D float array aligned with ``self.y``.
        """
        B = np.asarray(self._wang_coefficients(self.lam, N), dtype=float)
        eta = (2*self.y)/H
        # The series is even in eta, so evaluate it as a polynomial in eta**2.
        return np.polynomial.polynomial.polyval(eta**2, B)

    def coeffs_Li(self, x, N):
        """Return the sum of the first ``N`` Li coefficients at ``x``.

        This equals the Li series evaluated at ``xi = 1``, which
        ``YField_Li`` maps to ``y = 0``. Its roots in ``x`` are used as the
        eigenvalue. ``x`` may be a SymPy symbol or a number.
        """
        return sum(self._li_coefficients(x, N), 0.0)

    def YField_Li(self, H, N=50):
        """Evaluate ``Y(y) = sum_j B_j * (2*y/H + 1)**j`` on ``self.y``.

        Uses ``self.lam`` and the first ``N`` Li coefficients. Returns a
        1-D float array aligned with ``self.y``.
        """
        B = np.asarray(self._li_coefficients(self.lam, N), dtype=float)
        xi = (2*self.y)/H + 1
        return np.polynomial.polynomial.polyval(xi, B)
