import numpy as np
import matplotlib.pyplot as plt

# Module-level placeholder lists. process_file() builds its own local x/y
# lists, so nothing in this script ever fills these in.
x = []
y = []

def process_file(file_path, title, time, quantity, savepath, *,
                 header_lines=18, close=True):
    """Read a two-column data file and save a red line plot of it.

    The first ``header_lines`` lines of the file are skipped unconditionally.
    After that, a line is parsed as an ``(x, y)`` float pair only if it has
    exactly two whitespace-separated fields. Every other line is ignored,
    including comment/directive lines whose first field starts with ``#``
    or ``@``.

    Args:
        file_path: Path of the input data file.
        title: Plot title.
        time: Label for the x axis.
        quantity: Label for the y axis.
        savepath: Output image path. Missing parent directories are created.
        header_lines: Number of leading lines to skip unconditionally
            (default 18, the value this script has always used).
        close: Close the figure after saving (default True), so repeated
            calls do not accumulate open figures. Pass False to keep the
            figure open, e.g. for a later ``plt.show()``.

    Raises:
        OSError: If ``file_path`` cannot be read or ``savepath`` cannot be
            written.
        ValueError: If a two-field data line contains a non-numeric field.
    """
    from itertools import islice
    from pathlib import Path

    x, y = [], []
    with open(file_path, 'r') as f:
        for line in islice(f, header_lines, None):
            cols = line.split()
            # Header/comment lines that survive the fixed cut-off would
            # otherwise crash float(); skip them, like any non-2-field line.
            if len(cols) != 2 or cols[0].startswith(("#", "@")):
                continue
            x.append(float(cols[0]))
            y.append(float(cols[1]))

    # savefig does not create directories; make sure the target folder exists.
    Path(savepath).parent.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=(15, 10))
    try:
        ax.set_title(title)
        ax.set_xlabel(time)
        ax.set_ylabel(quantity)
        ax.plot(x, y, c='r')
        ax.grid(True)
        fig.savefig(savepath)
    finally:
        # Release the figure even if saving fails, unless the caller wants it.
        if close:
            plt.close(fig)
    

# One entry per plot, in the order they are produced:
# (input data file, plot title, x-axis label, y-axis label, output image).
_PLOT_JOBS = (
    ("data/rmsd.xvg", "RMSD", "Time (ns)", "RMSD (nm)", "plots/RMSD.png"),
    ("data/rmsd_xtal.xvg", "RMSD_xtal", "Time (ns)", "RMSD_xtal (nm)",
     "plots/RMSD_xtal.png"),
    ("data/density.xvg", "Density", "Time (ps)", "Density (kg/m^3)",
     "plots/density.png"),
    ("data/potential.xvg", "Potential", "Time (ps)", "Potential (kJ/mol)",
     "plots/potential.png"),
    ("data/temperature.xvg", "Temperature", "Time (ps)", "Temperature (K)",
     "plots/temperature.png"),
)

for _job in _PLOT_JOBS:
    process_file(*_job)
# plt.show()