from pyomo.environ import *
from pyomo.dae import *

def preprocessing(V_Tray, eps_Tray, rho, M, V_in):
    """Derive the liquid volume of a tray and convert volumes to molar amounts.

    Parameters
    ----------
    V_Tray : float
        Total volume of the tray.
    eps_Tray : float
        Porosity of the tray (fraction of the volume available to liquid).
    rho : float
        Density of the liquid.
    M : float
        Molecular weight of the liquid.
    V_in : float
        Volume of liquid fed to the tray.

    Returns
    -------
    tuple
        ``(V, n, N_in)``: the liquid volume ``eps_Tray * V_Tray``, the amount
        of liquid held in that volume, and the amount corresponding to
        ``V_in``; both amounts are computed as ``volume / M * rho``.
    """
    def to_amount(volume):
        # Kept in the exact operation order ``volume / M * rho`` so results
        # are bit-identical to the straightforward formula.
        return volume / M * rho

    liquid_volume = eps_Tray * V_Tray
    return liquid_volume, to_amount(liquid_volume), to_amount(V_in)

def adsorption(V_ad, X_REE_in, N_in, n_S, V_S, k_ad, B_inf, M_S, rho_S, *,
               t_end=100.0, numpoints=100000):
    """Simulate REE adsorption onto mycelium binding sites in a flow-through tray.

    Builds a Pyomo.DAE model in which solvent enters at ``N_in`` carrying an
    REE load ``X_REE_in``, leaves at the same rate (``eq1``), and REE binds to
    free sites at rate ``k_ad * n_REE / V_S * n_BS``. The model is integrated
    with the CasADi-backed ``Simulator`` and the trajectories are cut off once
    the cumulative product volume ``V_prod`` exceeds ``V_ad``.

    Parameters
    ----------
    V_ad : float
        Product volume at which the adsorption step is stopped. Must be >= 0.
    X_REE_in : float
        REE load of the incoming solvent.
    N_in : float
        Solvent flow rate into (and, by ``eq1``, out of) the tray.
    n_S : float
        Amount of solvent in the tray; ``X_REE = n_REE / n_S``.
    V_S : float
        Volume of liquid in the tray.
    k_ad : float
        Kinetic parameter of adsorption.
    B_inf : float
        Maximum binding capacity; all sites are unoccupied at t = 0.
    M_S, rho_S : float
        Molecular weight and density of the solvent, used to convert the
        outflow ``N_out`` to a volumetric rate for ``V_prod``.
    t_end : float, optional
        End of the simulated time horizon (default 100.0).
    numpoints : int, optional
        Number of output points requested from the simulator (default 100000).

    Returns
    -------
    tsim : array-like
        Time points of the truncated trajectory.
    vartable : dict
        Maps each simulated variable name (e.g. ``"n_REE"``, ``"V_prod"``) to
        its profile at the times in ``tsim``.

    Raises
    ------
    ValueError
        If ``V_ad`` is negative or NaN.
    """
    import warnings

    # Reject NaN / negative targets up front: no sample would satisfy
    # V_prod <= V_ad, and the cut-off slice below would wrap around to return
    # (almost) the whole horizon instead of an empty trajectory.
    if not V_ad >= 0:
        raise ValueError(f"V_ad must be a non-negative number, got {V_ad!r}")

    m = ConcreteModel()

    m.t = ContinuousSet(bounds=(0.0, t_end)) # Time

    ### Parameters
    ## Input
    m.N_in = Param(initialize=N_in) # Flow Rate of Solvent into Tray
    m.X_REE_in = Param(initialize=X_REE_in) # Incoming Load of REE on Solvent
    m.n_S = Param(initialize=n_S) # Amount of Solvent in Tray

    ## Geometry
    m.V_S = Param(initialize=V_S) # Volume of Liquid in Tray

    ## Kinetics
    m.k_ad = Param(initialize=k_ad) # Kinetic Parameter of Adsorption
    m.B_inf = Param(initialize=B_inf) # Maximum Binding Capacity

    ## Solvent Attributes
    m.M_S = Param(initialize=M_S) # Molecular Weight of Solution
    m.rho_S = Param(initialize=rho_S) # Density of Solution

    ### Variables
    m.n_REE = Var(m.t) # Amount of REE in Tray
    m.N_out = Var(m.t) # Flow Rate of Solvent out of Tray
    m.X_REE = Var(m.t) # Load of REE on Solvent

    m.n_BS = Var(m.t, bounds=(0,B_inf)) # Number of Unoccupied Binding Sites on Mycelium
    m.n_B = Var(m.t) # Number of Occupied Binding Sites on Mycelium

    m.V_prod = Var(m.t) # Cumulative volume of product leaving the tray
    m.n_prod = Var(m.t) # Cumulative amount of REE leaving the tray

    ### Differentials
    m.dn_REEdt = DerivativeVar(m.n_REE, wrt=m.t)
    m.dn_BSdt = DerivativeVar(m.n_BS, wrt=m.t)
    m.dn_Bdt = DerivativeVar(m.n_B, wrt=m.t)
    m.dV_proddt = DerivativeVar(m.V_prod, wrt=m.t)
    m.dn_proddt = DerivativeVar(m.n_prod, wrt=m.t)


    ### Initial Conditions
    m.n_REE[0].fix(0)
    m.n_BS[0].fix(m.B_inf)
    m.n_B[0].fix(0)
    m.V_prod[0].fix(0)
    m.n_prod[0].fix(0)

    ### Differential Equations
    def _diffeq1(m, t):
        return m.dn_REEdt[t] == m.N_in * m.X_REE_in - m.N_out[t] * m.X_REE[t] - m.k_ad * m.n_REE[t] / m.V_S * m.n_BS[t]
    m.diffeq1 = Constraint(m.t, rule=_diffeq1)

    def _diffeq2(m, t):
        return m.dn_Bdt[t] == m.k_ad * m.n_REE[t] / m.V_S * m.n_BS[t]
    m.diffeq2 = Constraint(m.t, rule=_diffeq2)

    def _diffeq3(m, t):
        return m.dV_proddt[t] == m.N_out[t] * m.M_S / m.rho_S
    m.diffeq3 = Constraint(m.t, rule=_diffeq3)

    def _diffeq4(m, t):
        return m.dn_proddt[t] == m.N_out[t] * m.X_REE[t]
    m.diffeq4 = Constraint(m.t, rule=_diffeq4)

    ### Algebraic Equations
    def _eq1(m, t):
        return 0 == m.N_in - m.N_out[t]
    m.eq1 = Constraint(m.t, rule=_eq1)

    def _eq2(m, t):
        return m.X_REE[t] == m.n_REE[t] / m.n_S
    m.eq2 = Constraint(m.t, rule=_eq2)

    def _eq3(m, t):
        return m.n_BS[t] == m.B_inf - m.n_B[t]
    m.eq3 = Constraint(m.t, rule=_eq3)

    ### Simulator and Output
    sim = Simulator(m, package='casadi')

    tsim, profiles = sim.simulate(numpoints=numpoints, integrator='collocation')

    # Label the profile columns: differential variables first, then the
    # 'time-varying' variables, in the order reported by the Simulator.
    # str(v)[:-5] drops a 5-character index suffix (presumably "[{t}]" --
    # TODO confirm) to leave the bare component name.
    names = [str(v)[:-5] for v in sim.get_variable_order()]
    names += [str(v)[:-5] for v in sim.get_variable_order('time-varying')]
    vartable = {name: profiles[:, col] for col, name in enumerate(names)}

    # For positive N_in, M_S and rho_S, V_prod increases monotonically, so the
    # samples with V_prod <= V_ad form a prefix of length vol_index.
    v_prod = vartable["V_prod"]
    vol_index = sum(1 for v in v_prod if v <= V_ad)
    if vol_index == len(v_prod):
        warnings.warn(
            f"V_prod did not exceed V_ad={V_ad} before t_end={t_end}; "
            "the returned trajectory stops at the end of the time horizon.",
            RuntimeWarning,
            stacklevel=2,
        )

    # NOTE(review): this keeps samples 1 .. vol_index-2, i.e. it drops both the
    # t = 0 sample and the last sample with V_prod <= V_ad; confirm that
    # excluding that final in-range sample is intended. max(..., 0) keeps the
    # stop index from going negative and wrapping around.
    stop = max(vol_index - 1, 0)
    vartable = {name: profile[1:stop] for name, profile in vartable.items()}
    tsim = tsim[1:stop]

    return tsim, vartable


def desorption(V_de, n_REE_bound, N_in, n_D, V_D, k_de, B_inf, M_D, rho_D, *,
               t_end=100.0, numpoints=100000):
    """Simulate release of bound REE from mycelium into a flow-through fluid.

    Builds a Pyomo.DAE model in which a dissolution fluid flows through the
    tray at ``N_in`` (outflow equals inflow via ``eq1``) and occupied binding
    sites release REE at rate ``k_de * n_B``. The model is integrated with the
    CasADi-backed ``Simulator`` and the trajectories are cut off once the
    cumulative product volume ``V_prod`` exceeds ``V_de``.

    Parameters
    ----------
    V_de : float
        Product volume at which the desorption step is stopped. Must be >= 0.
    n_REE_bound : float
        Number of occupied binding sites at t = 0; the remaining
        ``B_inf - n_REE_bound`` sites start unoccupied.
    N_in : float
        Flow rate of dissolution fluid into (and out of) the tray.
    n_D : float
        Amount of dissolution fluid in the tray; ``X_REE = n_REE / n_D``.
    V_D : float
        Volume of liquid in the tray (stored on the model, not used by any
        of the equations below).
    k_de : float
        Desorption rate constant.
    B_inf : float
        Maximum binding capacity.
    M_D, rho_D : float
        Molecular weight and density of the dissolution fluid, used to convert
        ``N_out`` to a volumetric rate for ``V_prod``.
    t_end : float, optional
        End of the simulated time horizon (default 100.0).
    numpoints : int, optional
        Number of output points requested from the simulator (default 100000).

    Returns
    -------
    tsim : array-like
        Time points of the truncated trajectory.
    vartable : dict
        Maps each simulated variable name (e.g. ``"n_REE"``, ``"V_prod"``) to
        its profile at the times in ``tsim``.

    Raises
    ------
    ValueError
        If ``V_de`` is negative or NaN.
    """
    import warnings

    # Reject NaN / negative targets up front: no sample would satisfy
    # V_prod <= V_de, and the cut-off slice below would wrap around to return
    # (almost) the whole horizon instead of an empty trajectory.
    if not V_de >= 0:
        raise ValueError(f"V_de must be a non-negative number, got {V_de!r}")

    m = ConcreteModel()

    m.t = ContinuousSet(bounds=(0.0, t_end)) # Time

    ### Parameters
    ## Input
    m.N_in = Param(initialize=N_in) # Flow Rate of Solvent into Tray
    m.n_D = Param(initialize=n_D) # Amount of Solvent in Tray

    ## Geometry
    m.V_D = Param(initialize=V_D) # Volume of Liquid in Tray

    ## Kinetics
    m.k_de = Param(initialize=k_de) # Kinetic Parameter of Desorption
    m.B_inf = Param(initialize=B_inf) # Maximum Binding Capacity

    ## Solvent Attributes
    m.M_D = Param(initialize=M_D) # Molecular Weight of the Dissolution Fluid
    m.rho_D = Param(initialize=rho_D) # Density of the Dissolution Fluid


    ### Variables
    m.n_REE = Var(m.t) # Amount of REE in Tray
    m.N_out = Var(m.t) # Flow Rate of Solvent out of Tray
    m.X_REE = Var(m.t) # Load of REE on Solvent

    m.n_BS = Var(m.t) # Number of Unoccupied Binding Sites on Mycelium
    m.n_B = Var(m.t) # Number of Occupied Binding Sites on Mycelium

    m.V_prod = Var(m.t) # Cumulative volume of product leaving the tray
    m.n_prod = Var(m.t) # Cumulative amount of REE leaving the tray

    ### Differentials
    m.dn_REEdt = DerivativeVar(m.n_REE, wrt=m.t)
    m.dn_BSdt = DerivativeVar(m.n_BS, wrt=m.t)
    m.dn_Bdt = DerivativeVar(m.n_B, wrt=m.t)
    m.dV_proddt = DerivativeVar(m.V_prod, wrt=m.t)
    m.dn_proddt = DerivativeVar(m.n_prod, wrt=m.t)


    ### Initial Conditions
    m.n_REE[0].fix(0)
    m.n_BS[0].fix(m.B_inf - n_REE_bound)
    m.n_B[0].fix(n_REE_bound)
    m.V_prod[0].fix(0)
    m.n_prod[0].fix(0)

    ### Differential Equations
    def _diffeq1(m, t):
        return m.dn_REEdt[t] == - m.N_out[t] * m.X_REE[t] + m.k_de * m.n_B[t]
    m.diffeq1 = Constraint(m.t, rule=_diffeq1)

    def _diffeq2(m, t):
        return m.dn_Bdt[t] == - m.k_de * m.n_B[t]
    m.diffeq2 = Constraint(m.t, rule=_diffeq2)

    def _diffeq3(m, t):
        return m.dV_proddt[t] == m.N_out[t] * m.M_D / m.rho_D
    m.diffeq3 = Constraint(m.t, rule=_diffeq3)

    def _diffeq4(m, t):
        return m.dn_proddt[t] == m.N_out[t] * m.X_REE[t]
    m.diffeq4 = Constraint(m.t, rule=_diffeq4)

    ### Algebraic Equations
    def _eq1(m, t):
        return 0 == m.N_in - m.N_out[t]
    m.eq1 = Constraint(m.t, rule=_eq1)

    def _eq2(m, t):
        return m.X_REE[t] == m.n_REE[t] / m.n_D
    m.eq2 = Constraint(m.t, rule=_eq2)

    def _eq3(m, t):
        return m.n_BS[t] == m.B_inf - m.n_B[t]
    m.eq3 = Constraint(m.t, rule=_eq3)

    ### Simulator and Output
    sim = Simulator(m, package='casadi')

    tsim, profiles = sim.simulate(numpoints=numpoints, integrator='collocation')

    # Label the profile columns: differential variables first, then the
    # 'time-varying' variables, in the order reported by the Simulator.
    # str(v)[:-5] drops a 5-character index suffix (presumably "[{t}]" --
    # TODO confirm) to leave the bare component name.
    names = [str(v)[:-5] for v in sim.get_variable_order()]
    names += [str(v)[:-5] for v in sim.get_variable_order('time-varying')]
    vartable = {name: profiles[:, col] for col, name in enumerate(names)}

    # For positive N_in, M_D and rho_D, V_prod increases monotonically, so the
    # samples with V_prod <= V_de form a prefix of length vol_index.
    v_prod = vartable["V_prod"]
    vol_index = sum(1 for v in v_prod if v <= V_de)
    if vol_index == len(v_prod):
        warnings.warn(
            f"V_prod did not exceed V_de={V_de} before t_end={t_end}; "
            "the returned trajectory stops at the end of the time horizon.",
            RuntimeWarning,
            stacklevel=2,
        )

    # NOTE(review): this keeps samples 1 .. vol_index-2, i.e. it drops both the
    # t = 0 sample and the last sample with V_prod <= V_de; confirm that
    # excluding that final in-range sample is intended. max(..., 0) keeps the
    # stop index from going negative and wrapping around.
    stop = max(vol_index - 1, 0)
    vartable = {name: profile[1:stop] for name, profile in vartable.items()}
    tsim = tsim[1:stop]

    return tsim, vartable

def growth(duration, X_bio_0, V_Tray, eps_Tray, mu_max, K, rho_Sub, Y_XS, *,
           numpoints=10000):
    """Simulate Monod-type fungal growth on the solid substrate of a tray.

    Biomass grows as ``dX_bio/dt = mu * X_bio`` and consumes substrate as
    ``dS/dt = -mu * X_bio / Y_XS``, with growth rate
    ``mu = mu_max * S / (S + (1 - eps_Tray) * V_Tray * K)``. The substrate
    initially fills the solid fraction of the tray:
    ``S(0) = (1 - eps_Tray) * V_Tray * rho_Sub``.

    Parameters
    ----------
    duration : float
        Length of the simulated time horizon, starting at t = 0.
    X_bio_0 : float
        Initial amount of biomass.
    V_Tray : float
        Volume of the tray.
    eps_Tray : float
        Average porosity of the tray.
    mu_max : float
        Maximum growth rate of the fungus.
    K : float
        Monod constant, scaled by the solid volume ``(1 - eps_Tray) * V_Tray``.
    rho_Sub : float
        Density of the substrate.
    Y_XS : float
        Substrate-to-biomass yield.
    numpoints : int, optional
        Number of output points requested from the simulator (default 10000).

    Returns
    -------
    tsim : array-like
        Simulated time points.
    vartable : dict
        Maps each simulated variable name (e.g. ``"X_bio"``, ``"S"``) to its
        profile at the times in ``tsim``.
    """
    m = ConcreteModel()

    m.t = ContinuousSet(bounds=(0.0, duration)) # Time

    ### Parameters
    ## Geometry
    m.V_Tray = Param(initialize=V_Tray) # Volume of Tray
    m.eps_Tray = Param(initialize=eps_Tray) # Average Porosity of Tray

    ## Kinetics
    m.mu_max = Param(initialize=mu_max) # Maximum Growth Rate of Fungus
    m.K = Param(initialize=K) # Monod Constant of Fungus
    m.Y_XS = Param(initialize=Y_XS) # Substrate to Biomass Conversion Rate

    ## Substrate Attributes
    m.rho_Sub = Param(initialize=rho_Sub) # Density of Substrate


    ### Variables
    m.X_bio = Var(m.t) # Amount of Biomass in Tray
    m.S = Var(m.t, bounds=(0,m.V_Tray*rho_Sub)) # Amount of Substrate in Tray
    m.mu = Var(m.t) # Current Growth Rate

    ### Differentials
    m.dX_biodt = DerivativeVar(m.X_bio, wrt=m.t)
    m.dSdt = DerivativeVar(m.S, wrt=m.t)

    ### Initial Conditions
    m.X_bio[0].fix(X_bio_0)
    # Substrate occupies the non-porous fraction of the tray.
    m.S[0].fix((1-m.eps_Tray)*m.V_Tray*m.rho_Sub)

    ### Differential Equations
    def _diffeq1(m, t):
        return m.dX_biodt[t] == m.mu[t] * m.X_bio[t]
    m.diffeq1 = Constraint(m.t, rule=_diffeq1)

    def _diffeq2(m, t):
        return m.dSdt[t] == - 1/m.Y_XS * m.mu[t] * m.X_bio[t]
    m.diffeq2 = Constraint(m.t, rule=_diffeq2)


    ### Algebraic Equations
    def _eq1(m, t):
        return m.mu[t] == m.mu_max * m.S[t] / (m.S[t] + (1-m.eps_Tray)*m.V_Tray*m.K)
    m.eq1 = Constraint(m.t, rule=_eq1)

    # NOTE(review): pyomo.dae's Simulator appears to integrate only equality
    # constraints; confirm whether this inequality (and the bounds on S) are
    # actually enforced during simulation.
    def _ineq1(m, t):
        return m.S[t] >= 0
    m.ineq1 = Constraint(m.t, rule=_ineq1)

    ### Simulator and Output
    sim = Simulator(m, package='casadi')

    tsim, profiles = sim.simulate(numpoints=numpoints, integrator='collocation')

    # Label the profile columns: differential variables first, then the
    # 'time-varying' variables, in the order reported by the Simulator.
    # str(v)[:-5] drops a 5-character index suffix (presumably "[{t}]" --
    # TODO confirm) to leave the bare component name.
    names = [str(v)[:-5] for v in sim.get_variable_order()]
    names += [str(v)[:-5] for v in sim.get_variable_order('time-varying')]
    vartable = {name: profiles[:, col] for col, name in enumerate(names)}

    return tsim, vartable