import tkinter as tk
from tkinter import ttk, filedialog, messagebox
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.backends.backend_tkagg import NavigationToolbar2Tk
from matplotlib.figure import Figure
import numpy as np
import threading
import queue
import os

import time
from datetime import datetime, timedelta

import TDS_Material
import TDS_Sim
from ModelEnsemble import ModelEnsemble
from Model_Parameters import Model_Parameters
from ExpDataProcessing import ExpDataProcessing
from UnitConverter import UnitConverter

def run_thermal_desorption_analysis(params, result_queue):
    """
    Run the ML thermal-desorption analysis and report results through a queue.

    Intended to run in a background thread. Every outcome is posted to
    ``result_queue`` as a dict whose ``'status'`` is one of:

    * ``'progress'`` -- intermediate update with ``'progress'`` (percent),
      ``'message'`` and ``'use_advanced'``;
    * ``'stopped'``  -- ``params['stop_flag']`` was set;
    * ``'error'``    -- the analysis failed (``'message'`` says why,
      ``'progress'`` is 100);
    * ``'success'``  -- the final result, carrying every object/array the GUI
      needs for plotting in the main thread.

    Exceptions raised while reading ``params`` or running the analysis are not
    propagated; they are posted as an ``'error'`` (or ``'stopped'``) message so
    the consumer always receives a terminal status.

    Args:
        params: dict of inputs. Required keys: ``'material'``, ``'test'``,
            ``'numerical'`` (read: ``'NumTraining'``, ``'NumVerification'``,
            ``'n_cpu_cores'``), ``'ExpName'``, ``'trap_model'``, ``'exp_file'``,
            ``'training_parameters'`` (read: ``'ParameterSet'``, ``'Traps'``,
            ``'Concentrations'``, ``'MaxTraps'``, ``'Regenerate_Training'``,
            ``'Regenerate_Data'``) and ``'HD_Trap'``. Optional:
            ``'stop_flag'`` (object with ``is_set()``, e.g. a
            ``threading.Event``) and ``'exp_units'`` (dict with
            ``'temp_units'``, ``'y_type'``, ``'y_units'``).
        result_queue: object with a ``put()`` method receiving the status dicts.
    """
    # Bound before the ``try`` so the ``except`` handler can always read it,
    # even when ``params`` itself is malformed. Previously a failure in
    # ``params.get`` raised UnboundLocalError in the handler, and the GUI
    # never received a terminal status.
    stop_flag = None

    def stop_requested():
        """Post a 'stopped' status and return True if a stop was requested."""
        if stop_flag and stop_flag.is_set():
            result_queue.put({'status': 'stopped'})
            return True
        return False

    def report_progress(percent, message):
        """Post an intermediate progress update."""
        result_queue.put({'status': 'progress', 'progress': percent,
                          'message': message, 'use_advanced': True})

    def as_bool(value):
        """Interpret a GUI flag such as "True", "true " or True as a bool."""
        return str(value).strip().lower() == "true"

    try:
        stop_flag = params.get('stop_flag')

        material_params = params['material']
        test_params = params['test']
        numerical_params = params['numerical']
        ExpName = params['ExpName']
        trap_model = params['trap_model']
        exp_file = params['exp_file']
        training_parameters = params['training_parameters']
        HD_Trap = params['HD_Trap']

        if stop_requested():
            return

        if not exp_file:
            result_queue.put({'status': 'error', 'message': 'No experimental data file provided.', 'progress': 100})
            return

        # Step 1: Initialize objects
        report_progress(10, "Initialising parameters...")
        if stop_requested():
            return

        Material = TDS_Material.TDS_Material(ExpName,
                                             material_param=material_params,
                                             test_param=test_params,
                                             numerical_param=numerical_params,
                                             HD_Trap_param=HD_Trap,
                                             trap_model=trap_model)
        HyperParameters = Model_Parameters(ParameterSet=training_parameters['ParameterSet'])
        if stop_requested():
            return

        # Step 2: Set additional parameters
        Traps = training_parameters['Traps']
        Concentrations = training_parameters['Concentrations']
        MaxTraps = training_parameters['MaxTraps']
        Regenerate_Training = as_bool(training_parameters['Regenerate_Training'])
        Regenerate_Data = as_bool(training_parameters['Regenerate_Data'])
        if stop_requested():
            return

        # Step 3: Create and train the Model Ensemble. If this raises after a
        # stop request, the outer handler reports 'stopped' instead of 'error'.
        report_progress(20, "Generating data and training models...")
        if stop_requested():
            return

        Model = ModelEnsemble(Material, Traps, MaxTraps, Concentrations, HyperParameters,
                              numerical_params['NumTraining'], Regenerate_Data,
                              Regenerate_Training, numerical_params['n_cpu_cores'])
        if stop_requested():
            return

        # Step 4: Run verification on freshly simulated curves
        report_progress(60, "Running model verification...")
        if stop_requested():
            return

        TDS_Curves_Ver, Actual_Traps, Actual_Concentrations, Actual_Energies, TDS_Temp_Ver = TDS_Sim.SimDataSet(
            Material, numerical_params['NumVerification'], MaxTraps, Traps, Concentrations,
            numerical_params['n_cpu_cores'])
        if stop_requested():
            return

        Predicted_Traps_Ver, Predicted_Concentrations_Ver, Predicted_Energies_Ver = Model.predict(TDS_Curves_Ver)
        if stop_requested():
            return

        # Step 5: Process experimental data
        report_progress(80, "Processing experimental data...")
        if stop_requested():
            return

        # Units of the experimental file; the default is used when the caller
        # does not pass them.
        exp_units = params.get('exp_units', {
            'temp_units': 'K',
            'y_type': 'flux',
            'y_units': 'mol/m\u00b2s'
        })

        Exp_Processed_Data = ExpDataProcessing(
            file_name=exp_file,
            temp_units=exp_units['temp_units'],
            y_units=exp_units['y_units'],
            y_type=exp_units['y_type'],
            material=Material,
            hyperparameters=HyperParameters)

        Exp_TDS_Curve = Exp_Processed_Data.TDS_Curve
        Exp_Temp = Exp_Processed_Data.Temperature
        Exp_Flux = Exp_Processed_Data.Flux
        if stop_requested():
            return

        report_progress(90, "Making predictions...")
        Exp_Predicted_Traps, Exp_Predicted_Concentrations, Exp_Predicted_Energies = Model.predict(Exp_TDS_Curve)
        if stop_requested():
            return

        # Return all data needed for plotting in the main thread
        result_queue.put({
            'status': 'success',
            'progress': 100,
            'message': 'ML Analysis completed successfully!',
            'Material': Material,
            'Model': Model,
            'Predicted_Traps_Ver': Predicted_Traps_Ver,
            'Predicted_Concentrations_Ver': Predicted_Concentrations_Ver,
            'Predicted_Energies_Ver': Predicted_Energies_Ver,
            'TDS_Curves_Ver': TDS_Curves_Ver,
            'TDS_Temp_Ver': TDS_Temp_Ver,
            'Actual_Traps': Actual_Traps,
            'Actual_Concentrations': Actual_Concentrations,
            'Actual_Energies': Actual_Energies,
            'Exp_Temp': Exp_Temp,
            'Exp_Flux': Exp_Flux,
            'Predicted_Traps': Exp_Predicted_Traps,
            'Predicted_Concentrations': Exp_Predicted_Concentrations,
            'Predicted_Energies': Exp_Predicted_Energies
        })

    except Exception as e:
        # A failure raised after a stop request is reported as a stop, not an error.
        if stop_flag and stop_flag.is_set():
            result_queue.put({'status': 'stopped'})
        else:
            result_queue.put({'status': 'error', 'message': f"An error occurred during ML analysis: {str(e)}", 'progress': 100})

class ScrolledFrame(tk.Frame):
    """A ``tk.Frame`` whose interior scrolls vertically.

    Put child widgets inside ``self.scrollable_frame``. That inner frame is
    embedded as a window item in ``self.canvas`` and stretched to the canvas
    width. The canvas scroll region is recomputed whenever the inner frame is
    reconfigured, and ``self.scrollbar`` drives the canvas's vertical view.
    """

    def __init__(self, parent, *args, **kwargs):
        super().__init__(parent, *args, **kwargs)

        canvas = tk.Canvas(self, borderwidth=0, highlightthickness=0)
        self.canvas = canvas
        self.scrollbar = tk.Scrollbar(self, orient="vertical", command=canvas.yview)
        self.scrollable_frame = tk.Frame(canvas)

        # Keep the scroll region in step with the interior's size.
        self.scrollable_frame.bind("<Configure>", self._sync_scroll_region)

        self.canvas_window = canvas.create_window(
            (0, 0), window=self.scrollable_frame, anchor="nw")
        canvas.configure(yscrollcommand=self.scrollbar.set)

        # Stretch the interior horizontally whenever the canvas is resized.
        canvas.bind('<Configure>', self._configure_canvas)

        canvas.pack(side="left", fill="both", expand=True)
        self.scrollbar.pack(side="right", fill="y")

    def _sync_scroll_region(self, _event=None):
        """Set the canvas scroll region to the bounding box of all its items."""
        self.canvas.configure(scrollregion=self.canvas.bbox("all"))

    def _configure_canvas(self, event):
        """Resize the embedded interior window to the canvas's new width."""
        self.canvas.itemconfig(self.canvas_window, width=event.width)
        
class SimulationGUI:
    def __init__(self, root):
        """Build the "TDS Simulator" window inside ``root``.

        Sets the window title, size and minimum size, registers the
        "Bold.TLabelframe" ttk styles used by the panels, creates the shared
        state (entry registry, result queue, stop flag), builds the UI and
        starts polling for background-thread results.

        Args:
            root: the Tk window that hosts the GUI (presumably a ``tk.Tk``
                instance; confirm against the caller).
        """
        self.root = root
        self.root.title("TDS Simulator")
        self.root.geometry("1200x750")
        self.root.minsize(1200, 750)
        
        # Configure styles for bold frame labels (used via style="Bold.TLabelframe")
        style = ttk.Style()
        style.configure("Bold.TLabelframe.Label", font=("TkDefaultFont", 10, "bold"))
        style.configure("Bold.TLabelframe", labelanchor="n")
        
        # Configure grid weights for responsive layout: column 1 grows twice as fast as column 0
        self.root.grid_rowconfigure(0, weight=1)
        self.root.grid_columnconfigure(0, weight=1)
        self.root.grid_columnconfigure(1, weight=2)
        
        # Initialize trap frames list
        self.trap_frames = []
        
        # Storage for entry widgets to access values easily (key -> ttk.Entry)
        self.entry_widgets = {}
        
        # Queue for thread communication (the background worker posts status dicts here)
        self.result_queue = queue.Queue()
        self.analysis_thread = None
        
        # Add stop flag for ML analysis control
        self.stop_ml_flag = threading.Event()
        
        # Store experimental data for immediate plotting
        self.exp_data = None
        
        self.setup_ui()
        
        # Start checking for thread results
        # NOTE(review): this runs before unit_converter, ml_start_time and
        # progress_animation_id are assigned below; confirm that
        # check_thread_results() does not read them synchronously.
        self.check_thread_results()

        # Initialize unit converter with default values (mass density and
        # thickness match the Parameters tab defaults "7.8474" and "0.0063")
        # NOTE(review): UnitConverter is imported unconditionally at module top,
        # so a missing import fails at load time; this handler only catches a
        # NameError raised inside the constructor itself.
        try:
            self.unit_converter = UnitConverter(
                atw=1.008,           # Atomic weight of hydrogen
                mass_density=7.8474, # Default mass density
                thickness=0.0063     # Default thickness
            )
        except NameError:
            print("Warning: UnitConverter not available")
            self.unit_converter = None
            
        # ML progress bookkeeping; both start unset
        self.ml_start_time = None
        self.progress_animation_id = None
        
    def setup_ui(self):
        """Lay out the window: input tabs on the left, plot and output on the right."""
        container = ttk.Frame(self.root, padding="10")
        container.grid(row=0, column=0, columnspan=2, sticky="nsew")
        container.grid_rowconfigure(0, weight=1)
        # The right-hand column (weight 2) grows twice as fast as the left one.
        for column, weight in ((0, 1), (1, 2)):
            container.grid_columnconfigure(column, weight=weight)

        self.create_left_panel(container)   # notebook with the input tabs
        self.create_right_panel(container)  # plot and output area
        
    def create_left_panel(self, parent):
        """Build the left-hand notebook of input tabs inside ``parent``.

        Tabs, in order: "Simulation" (trap model choice and plot output
        options), "Parameters", "Hydrogen Traps" and "ML Data Fitting". Once
        all tabs exist, ``update_traps_tab`` is called for the initial model.
        """
        panel = ttk.Frame(parent)
        panel.grid(row=0, column=0, sticky="nsew", padx=(0, 10))
        panel.grid_rowconfigure(0, weight=1)
        panel.grid_columnconfigure(0, weight=1)

        self.notebook = ttk.Notebook(panel)
        self.notebook.grid(row=0, column=0, sticky="nsew")

        sim_tab = ttk.Frame(self.notebook, padding="10")
        self.notebook.add(sim_tab, text="Simulation")

        # --- Simulation model choice ------------------------------------
        model_box = ttk.LabelFrame(sim_tab, text="Simulation Model", padding="10", style="Bold.TLabelframe")
        model_box.pack(fill="x", pady=(0, 15))

        self.trap_model_var = tk.StringVar(value="Oriani")
        # (caption, variable value, vertical padding); every choice re-runs update_traps_tab.
        for caption, model_value, pad in (
            ("Oriani", "Oriani", 2),
            ("McNabb Foster", "McNabb", 2),
            ("No traps", "None", (2, 10)),
        ):
            ttk.Radiobutton(
                model_box, text=caption, variable=self.trap_model_var,
                value=model_value, command=self.update_traps_tab,
            ).pack(anchor="w", pady=pad)

        self.show_individual_var = tk.BooleanVar()
        ttk.Checkbutton(
            model_box, text="Show individual trap contributions",
            variable=self.show_individual_var,
        ).pack(anchor="w", pady=(0, 0))

        # --- Plot output options ----------------------------------------
        out_box = ttk.LabelFrame(sim_tab, text="Output", padding="10", style="Bold.TLabelframe")
        out_box.pack(fill="x", pady=(0, 15))

        self.graphical_output_combo = self._add_readonly_combo(
            out_box, "Graphical output:", ["Delta C", "Flux"], "Flux")
        self.x_axis_combo = self._add_readonly_combo(
            out_box, "x-axis:", ["Temperature", "Time (s)"], "Temperature")
        self.flux_unit_combo = self._add_readonly_combo(
            out_box, "Flux unit:", ["mol/m\u00b2s", "mol/cm\u00b2s", "wppm m/s"], "mol/m\u00b2s")
        self.delta_c_unit_combo = self._add_readonly_combo(
            out_box, "Delta C unit:", ["mol/m\u00b3s", "mol/cm\u00b3s", "wppm/s"], "mol/m\u00b3s")
        self.temp_unit_combo = self._add_readonly_combo(
            out_box, "Temperature unit:", ["K", "\u00b0C"], "K")

        self.grid_var = tk.BooleanVar(value=True)
        ttk.Checkbutton(out_box, text="Grid", variable=self.grid_var).pack(anchor="w", pady=(5, 0))

        # --- Remaining tabs ---------------------------------------------
        self.create_model_parameters_tab()
        self.create_traps_tab()
        self.create_ml_fitting_tab()

        # Populate the traps tab for the initially selected model.
        self.update_traps_tab()

    def _add_readonly_combo(self, parent, caption, choices, initial):
        """Pack a caption and a read-only, 18-wide combobox preset to ``initial``; return the combobox."""
        ttk.Label(parent, text=caption).pack(anchor="w")
        combo = ttk.Combobox(parent, values=choices, state="readonly", width=18)
        combo.set(initial)
        combo.pack(anchor="w", pady=(2, 10))
        return combo

    def create_model_parameters_tab(self):
        """Add the scrollable "Parameters" tab with test, numerical and material inputs.

        Each input is created via ``create_parameter_entry``, so it ends up in
        ``self.entry_widgets`` under its key.
        """
        tab = ttk.Frame(self.notebook, padding="10")
        self.notebook.add(tab, text="Parameters")

        scroller = ScrolledFrame(tab)
        scroller.pack(fill="both", expand=True)
        content = scroller.scrollable_frame

        # (section title, ((label, entry key, default text), ...)) in display order.
        sections = (
            ("Test Inputs", (
                ("Specimen thickness [m]:", "thickness", "0.0063"),
                ("Heating rate \u03d5 [K/s]:", "heating_rate", "0.055"),
                ("Resting time [s]:", "resting_time", "2700"),
                ("Minimum temperature [K]:", "min_temp", "293.15"),
                ("Maximum temperature [K]:", "max_temp", "873.15"),
            )),
            ("Numerical Inputs", (
                ("Number of temperature evaluations:", "ntp", "64"),
                ("Sample frequency:", "sample_freq", "10"),
            )),
            ("Material Inputs", (
                ("Activation energy for lattice diffusion E\u2081 [J/mol]:", "activation_energy", "5690"),
                ("Pre-exponential diffusion factor D\u2080 [m\u00b2/s]:", "diffusion_factor", "7.23e-8"),
                ("Molar mass [g/mol]:", "molar_mass", "55.847"),
                ("Mass density of material [g/cm\u00b3]:", "mass_density", "7.8474"),
                ("Density of lattice sites N\u2097 [mol/m\u00b3]:", "lattice_density", "8.47e5"),
                ("Initial H concentration in the lattice C\u2097\u2070 [mol/m\u00b3]:", "initial_conc", "0.06"),
            )),
        )

        for title, entries in sections:
            section = ttk.LabelFrame(content, text=title, padding="10", style="Bold.TLabelframe")
            section.pack(fill="x", pady=(0, 15))
            for label_text, key, default in entries:
                self.create_parameter_entry(section, label_text, key, default)

    def create_parameter_entry(self, parent, label_text, entry_name, default_value):
        """Pack a caption and a pre-filled 20-wide entry into ``parent``.

        The entry is registered as ``self.entry_widgets[entry_name]`` so its
        value can be read back later.
        """
        ttk.Label(parent, text=label_text).pack(anchor="w")
        entry = ttk.Entry(parent, width=20)
        self.entry_widgets[entry_name] = entry
        entry.insert(0, default_value)
        entry.pack(anchor="w", pady=(0, 5))

    def create_traps_tab(self):
        """Add the "Hydrogen Traps" tab (``self.traps_frame``) and fill it via ``create_traps_content``."""
        frame = ttk.Frame(self.notebook, padding="10")
        self.traps_frame = frame
        self.notebook.add(frame, text="Hydrogen Traps")

        self.create_traps_content()

    def create_ml_fitting_tab(self):
        """Add the "ML Data Fitting" tab.

        The experimental-data, numerical and training panels go into a
        scrollable area. The start/stop control panel is packed into the tab
        itself, below that area, so it does not scroll with the panels.
        """
        tab = ttk.Frame(self.notebook, padding="10")
        self.notebook.add(tab, text="ML Data Fitting")

        scroller = ScrolledFrame(tab)
        scroller.pack(fill="both", expand=True)
        content = scroller.scrollable_frame

        for build_panel in (self.create_ml_exp_data_panel,
                            self.create_ml_numerical_panel,
                            self.create_ml_parameters_panel):
            build_panel(content)

        self.create_ml_control_panel(tab)

    def create_ml_parameters_panel(self, parent):
        """Pack the "ML Model Training Parameters" group into ``parent``.

        Each field is a caption over a pre-filled 20-wide entry registered in
        ``self.entry_widgets`` under its key. All values are kept as text here.
        """
        group = ttk.LabelFrame(parent, text="ML Model Training Parameters", padding="10", style="Bold.TLabelframe")
        group.pack(fill="x", pady=(0, 15))

        specs = (
            ("CPU cores:", "ml_n_cpu_cores", "16"),
            ("Number training datapoints:", "ml_num_training", "50000"),
            ("Number verification datapoints:", "ml_num_verification", "500"),
            ("Hyperparameter set:", "ml_hp_set", "optimised"),
            ("Maximum traps:", "ml_max_traps", "4"),
            ("Traps:", "ml_traps", "Random"),
            ("Concentrations:", "ml_concentrations", "Random"),
            ("Regenerate data:", "ml_regenerate_data", "False"),
            ("Regenerate training:", "ml_regenerate_training", "False"),
        )

        for index, (caption, key, default) in enumerate(specs):
            # No extra space above the very first caption.
            caption_pad = (0, 0) if index == 0 else (5, 0)
            ttk.Label(group, text=caption).pack(anchor="w", pady=caption_pad)
            entry = ttk.Entry(group, width=20)
            self.entry_widgets[key] = entry
            entry.insert(0, default)
            entry.pack(anchor="w", pady=(2, 0))

    def create_ml_exp_data_panel(self, parent):
        """Pack the "Experimental Data Input" group for ML fitting into ``parent``.

        Contents, top to bottom:
        * the Excel-file path entry with Browse/Plot buttons;
        * three read-only selectors side by side: temperature unit, y-axis data
          type and y-axis unit (changing the data type calls ``_on_y_type_changed``);
        * the "Test case" name entry, registered as ``self.entry_widgets['ml_exp_name']``.
        """
        group = ttk.LabelFrame(parent, text="Experimental Data Input", padding="10", style="Bold.TLabelframe")
        group.pack(fill="x", pady=(0, 15))

        ttk.Label(group, text="Excel file:").pack(anchor="w")

        # Path entry; both buttons are packed from the right, so "Plot" sits left of "Browse".
        path_row = ttk.Frame(group)
        path_row.pack(fill="x", pady=(5, 10))
        self.ml_exp_file_entry = ttk.Entry(path_row, width=35)
        self.ml_exp_file_entry.pack(side="left", fill="x", expand=True)
        self.ml_browse_button = ttk.Button(path_row, text="Browse", command=self.browse_ml_file)
        self.ml_browse_button.pack(side="right", padx=(5, 0))
        self.ml_load_plot_button = ttk.Button(path_row, text="Plot", command=self.load_and_plot_ml_data)
        self.ml_load_plot_button.pack(side="right", padx=(5, 0))

        caption_row = ttk.Frame(group)
        caption_row.pack(fill="x", pady=(0, 5))
        ttk.Label(caption_row, text="Experimental data units:").pack(anchor="w")

        selectors = ttk.Frame(group)
        selectors.pack(fill="x", pady=(0, 10))
        self.exp_temp_unit_combo = self._add_exp_unit_column(
            selectors, "Temperature", ["K", "\u00b0C"], "K", width=10, padx=0)
        self.exp_y_type_combo = self._add_exp_unit_column(
            selectors, "Y-axis data type", ["Flux", "Delta C"], "Flux", width=10, padx=(20, 0))
        # Starts with the flux units, matching the default "Flux" data type.
        self.exp_y_unit_combo = self._add_exp_unit_column(
            selectors, "Units", ["mol/m\u00b2s", "mol/cm\u00b2s", "wppm m/s"], "mol/m\u00b2s",
            width=15, padx=(20, 0))

        self.exp_y_type_combo.bind('<<ComboboxSelected>>', self._on_y_type_changed)

        ttk.Label(group, text="Test case:").pack(anchor="w", pady=(5, 0))
        name_entry = ttk.Entry(group, width=20)
        self.entry_widgets['ml_exp_name'] = name_entry
        name_entry.insert(0, "Novak_200")
        name_entry.pack(anchor="w", pady=(5, 5))

    def _add_exp_unit_column(self, parent, caption, choices, initial, width, padx):
        """Pack one column (caption over a read-only combobox) into ``parent``; return the combobox."""
        column = ttk.Frame(parent)
        column.pack(side="left", fill="x", expand=True, padx=padx)
        ttk.Label(column, text=caption).pack(anchor="w")
        combo = ttk.Combobox(column, values=choices, state="readonly", width=width)
        combo.set(initial)
        combo.pack(anchor="w", pady=(2, 0))
        return combo
        
    def _on_y_type_changed(self, event=None):
        """Refresh the experimental-unit choices after the y-axis data type changes.

        For "Flux" or "Delta C", the unit combobox gets that quantity's unit
        list and its selection is reset to the first entry. Any other type
        leaves the unit combobox untouched.
        """
        unit_options = {
            "Flux": ["mol/m\u00b2s", "mol/cm\u00b2s", "wppm m/s"],
            "Delta C": ["mol/m\u00b3s", "mol/cm\u00b3s", "wppm/s"],
        }.get(self.exp_y_type_combo.get())

        if unit_options is None:
            return

        self.exp_y_unit_combo['values'] = unit_options
        self.exp_y_unit_combo.set(unit_options[0])

    def get_experimental_units(self):
        """Return the selected experimental-data units as a dict.

        ``'temp_units'`` and ``'y_units'`` are the combobox texts verbatim.
        ``'y_type'`` is the data-type text lower-cased with spaces turned into
        underscores (e.g. "Delta C" -> "delta_c").
        """
        temperature = self.exp_temp_unit_combo.get()
        data_type = self.exp_y_type_combo.get().lower().replace(' ', '_')
        data_unit = self.exp_y_unit_combo.get()
        return {'temp_units': temperature, 'y_type': data_type, 'y_units': data_unit}

    def create_ml_numerical_panel(self, parent):
        """Pack the "Additional Numerical Parameters" group into ``parent``.

        Holds the trap density/binding-energy bounds and the high-density-trap
        (HDT) fields. Each is a caption over a pre-filled 20-wide entry stored
        in ``self.entry_widgets`` under its key.
        """
        group = ttk.LabelFrame(parent, text="Additional Numerical Parameters", padding="10", style="Bold.TLabelframe")
        group.pack(fill="x", pady=(0, 15))

        specs = (
            ("Minimum difference in trap binding energies [J/mol]:", "ml_de_min", "10e3"),
            ("Density lower bound [mol/m\u00b3]:", "ml_n_range_min", "1e-1"),
            ("Density upper bound [mol/m\u00b3]:", "ml_n_range_max", "1e1"),
            ("Binding energy lower bound [J/mol]:", "ml_e_range_min", "50e3"),
            ("Binding energy upper bound [J/mol]:", "ml_e_range_max", "150e3"),
            ("High density trap:", "high_density_trap", "False"),
            ("(HDT) Density lower bound [mol/m\u00b3]:", "HDT_ml_n_range_min", "0"),
            ("(HDT) Density upper bound [mol/m\u00b3]:", "HDT_ml_n_range_max", "0"),
            ("(HDT) Binding energy lower bound [J/mol]:", "HDT_ml_e_range_min", "0"),
            ("(HDT) Binding energy upper bound [J/mol]:", "HDT_ml_e_range_max", "0"),
        )

        top_pad = 0  # the first caption sits flush; later ones get 5 px above
        for caption, key, default in specs:
            ttk.Label(group, text=caption).pack(anchor="w", pady=(top_pad, 0))
            top_pad = 5
            entry = ttk.Entry(group, width=20)
            self.entry_widgets[key] = entry
            entry.insert(0, default)
            entry.pack(anchor="w", pady=(2, 0))

    def create_ml_control_panel(self, parent):
        """Pack the ML fitting Start/Stop buttons, side by side, into ``parent``.

        "Start ML Data Fitting" calls ``run_ml_analysis``. "Stop" calls
        ``stop_ml_analysis`` and is created disabled.
        """
        bar = ttk.Frame(parent, padding="10")
        bar.pack(fill="x", pady=(0, 15))

        self.ml_button = ttk.Button(bar, text="Start ML Data Fitting",
                                    command=self.run_ml_analysis)
        self.stop_ml_button = ttk.Button(bar, text="Stop",
                                         command=self.stop_ml_analysis, state=tk.DISABLED)
        for button in (self.ml_button, self.stop_ml_button):
            button.pack(side="left", padx=(0, 5))

    def browse_ml_file(self):
        """Ask for an Excel file and write its path into the ML file entry.

        An empty dialog result (e.g. the user cancelled) leaves the entry as it was.
        """
        chosen = filedialog.askopenfilename(
            title="Select Experimental Data File for ML Fitting",
            filetypes=[("Excel files", "*.xlsx *.xls")],
        )
        if not chosen:
            return

        entry = self.ml_exp_file_entry
        entry.delete(0, tk.END)
        entry.insert(0, chosen)

    def load_and_plot_ml_data(self):
        """Plot the experimental data file named in the ML file entry.

        Surrounding whitespace in the entry is ignored. A blank entry shows a
        warning; a path that is not an existing regular file shows an error.
        Otherwise this delegates to ``load_and_plot_experimental_data`` with
        ``is_ml=True`` and ``clear_plot=False``, so existing simulation data on
        the plot is preserved.
        """
        file_path = self.ml_exp_file_entry.get().strip()
        if not file_path:
            messagebox.showwarning("No File Selected", "Please select an experimental data file first.")
            return

        # isfile rather than exists: a directory path would otherwise be passed
        # straight on to the data loader.
        if not os.path.isfile(file_path):
            messagebox.showerror("File Not Found", f"The file '{file_path}' does not exist.")
            return

        self.load_and_plot_experimental_data(file_path, is_ml=True, clear_plot=False)

    def create_traps_content(self):
        """Build the contents of the "Hydrogen Traps" tab inside ``self.traps_frame``.

        Layout, top to bottom:
        * a fixed row with the "Number of traps" selector (``self.selector_frame``,
          ``self.traps_combo`` bound to ``on_traps_count_changed``);
        * the vibration-frequency row (``self.vib_freq_frame``), created and then
          un-packed so it starts hidden;
        * a scrollable area: ``self.scrollable_frame`` embedded in ``self.canvas``
          with ``self.v_scrollbar`` to its right.
        """
        outer = ttk.Frame(self.traps_frame)
        outer.pack(fill="both", expand=True)

        # Trap-count selector; it stays above (outside) the scrolled region.
        self.selector_frame = ttk.Frame(outer)
        self.selector_frame.pack(fill="x", pady=(0, 10))
        ttk.Label(self.selector_frame, text="Number of traps:").pack(side="left")
        self.traps_count_var = tk.StringVar(value="6")
        self.traps_combo = ttk.Combobox(
            self.selector_frame,
            textvariable=self.traps_count_var,
            values=[str(count) for count in range(1, 7)],
            state="readonly",
            width=5,
        )
        self.traps_combo.pack(side="left", padx=(10, 0))
        self.traps_combo.bind("<<ComboboxSelected>>", self.on_traps_count_changed)

        # Vibration frequency row for the McNabb model. It is packed once (so it
        # keeps its position in the pack order) and then forgotten so it starts
        # hidden; presumably update_traps_tab re-packs it when needed.
        self.vib_freq_frame = ttk.Frame(outer)
        self.vib_freq_frame.pack(fill="x", pady=(0, 10))
        ttk.Label(self.vib_freq_frame, text="Vibration frequency [Hz]:").pack(side="left")
        self.vib_freq_entry = ttk.Entry(self.vib_freq_frame, width=15)
        self.vib_freq_entry.insert(0, "1e13")
        self.vib_freq_entry.pack(side="right")
        self.vib_freq_frame.pack_forget()

        # Scrollable region: canvas + scrollbar laid out with grid.
        area = ttk.Frame(outer)
        area.pack(fill="both", expand=True)
        area.grid_rowconfigure(0, weight=1)
        area.grid_columnconfigure(0, weight=1)

        canvas = tk.Canvas(area, highlightthickness=0)
        self.canvas = canvas
        self.v_scrollbar = ttk.Scrollbar(area, orient="vertical", command=canvas.yview)
        self.scrollable_frame = ttk.Frame(canvas)

        # Recompute the scroll region whenever the interior changes size.
        self.scrollable_frame.bind(
            "<Configure>",
            lambda e: self.canvas.configure(scrollregion=self.canvas.bbox("all")),
        )

        self.canvas_window = canvas.create_window((0, 0), window=self.scrollable_frame, anchor="nw")
        canvas.configure(yscrollcommand=self.v_scrollbar.set)
        canvas.bind('<Configure>', self._configure_canvas)
        # <MouseWheel> covers Windows/macOS; Button-4/5 are the X11 wheel buttons.
        for sequence in ("<MouseWheel>", "<Button-4>", "<Button-5>"):
            canvas.bind(sequence, self._on_mousewheel)

        canvas.grid(row=0, column=0, sticky="nsew")
        self.v_scrollbar.grid(row=0, column=1, sticky="ns")

        # Give the canvas keyboard focus whenever the pointer enters it, so that
        # wheel events delivered to the focus widget reach the handlers above.
        # NOTE(review): this can pull focus away from an Entry the user is
        # typing in; consider a wheel-binding scheme that does not call focus_set.
        canvas.bind("<Enter>", lambda e: self.canvas.focus_set())

    def _configure_canvas(self, event):
        """On a canvas resize, stretch the embedded traps frame to the new canvas width."""
        self.canvas.itemconfig(self.canvas_window, width=event.width)

    def _on_mousewheel(self, event):
        """Scroll the traps canvas by one unit per wheel event.

        Accepts X11 button events (``num`` 4 = up, 5 = down) and
        ``<MouseWheel>`` events (positive ``delta`` = up, negative = down).
        An event that matches neither direction is ignored.
        """
        if event.num == 4 or event.delta > 0:
            step = -1
        elif event.num == 5 or event.delta < 0:
            step = 1
        else:
            return
        self.canvas.yview_scroll(step, "units")
    
    def create_trap_section(self, parent, trap_num):
        """Create the input section for one trap, based on the selected trap model.

        The section is a labelled frame packed into ``parent``. It is appended to
        ``self.trap_frames`` (so it can later be enabled/disabled) and its entry
        widgets are stored in ``self.trap_entries[trap_num]``:

        * "Oriani": 'binding_energy' (ΔH) and 'nt' (trap site density).
        * "McNabb": 'et' (trapping energy; read-only, initialised from the
          lattice activation-energy field, or "5690" when that field is blank or
          unavailable), 'ed' (detrapping energy) and 'nt'.
        * Any other model: an empty frame and an empty entry dict.

        Args:
            parent: Container widget the trap frame is packed into.
            trap_num: 1-based trap index; shown in the frame title and used to
                stagger the default energy (20000 + 5000 * trap_num).
        """
        trap_frame = ttk.LabelFrame(parent, text=f"Trap {trap_num}", padding="10")
        trap_frame.pack(fill="x", pady=(0, 10), padx=5)

        # Keep a reference so update_trap_states() can enable/disable it.
        self.trap_frames.append(trap_frame)

        trap_entries = {}
        trap_model = self.trap_model_var.get()
        default_energy = f"{20000 + trap_num * 5000}"

        if trap_model == "Oriani":
            # Binding energy
            row1_frame = ttk.Frame(trap_frame)
            row1_frame.pack(fill="x", pady=(0, 10))
            ttk.Label(row1_frame, text="Binding energy ΔH [J/mol]").pack(side="left")
            trap_entries['binding_energy'] = ttk.Entry(row1_frame, width=15)
            trap_entries['binding_energy'].insert(0, default_energy)
            trap_entries['binding_energy'].pack(side="right")

            # Trap density
            row2_frame = ttk.Frame(trap_frame)
            row2_frame.pack(fill="x")
            ttk.Label(row2_frame, text="Density of trapping sites Nt [sites/m³]").pack(side="left")
            trap_entries['nt'] = ttk.Entry(row2_frame, width=15)
            trap_entries['nt'].insert(0, "1.5e25")
            trap_entries['nt'].pack(side="right")

        elif trap_model == "McNabb":
            # Trapping energy: mirrors the lattice activation energy and is
            # never user-editable.
            row1_frame = ttk.Frame(trap_frame)
            row1_frame.pack(fill="x", pady=(0, 5))
            ttk.Label(row1_frame, text="Act. E. for trapping Et [J/mol]").pack(side="left")

            try:
                lattice_energy = self.entry_widgets['activation_energy'].get().strip()
            except (AttributeError, KeyError, TypeError, tk.TclError):
                # Lattice field not created yet (or already destroyed).
                lattice_energy = ""

            et_entry = ttk.Entry(row1_frame, width=12, state="readonly")
            # Write the value with the entry temporarily in "normal" state,
            # then lock it again.
            et_entry.config(state="normal")
            et_entry.delete(0, tk.END)
            et_entry.insert(0, lattice_energy or "5690")
            et_entry.config(state="readonly")
            et_entry.pack(side="right")
            trap_entries['et'] = et_entry

            # Detrapping energy
            row2_frame = ttk.Frame(trap_frame)
            row2_frame.pack(fill="x", pady=(0, 5))
            ttk.Label(row2_frame, text="Act. E. for detrapping Ed [J/mol]").pack(side="left")
            trap_entries['ed'] = ttk.Entry(row2_frame, width=12)
            trap_entries['ed'].insert(0, default_energy)
            trap_entries['ed'].pack(side="right")

            # Trap density
            row3_frame = ttk.Frame(trap_frame)
            row3_frame.pack(fill="x", pady=(0, 5))
            ttk.Label(row3_frame, text="Density of trapping sites Nt [sites/m³]").pack(side="left")
            trap_entries['nt'] = ttk.Entry(row3_frame, width=15)
            trap_entries['nt'].insert(0, "1.5e25")
            trap_entries['nt'].pack(side="right")

        if not hasattr(self, 'trap_entries'):
            self.trap_entries = {}
        self.trap_entries[trap_num] = trap_entries
    
    
    def update_traps_tab(self):
        """Rebuild the hydrogen-traps tab for the currently selected trap model.

        Everything inside the scrollable frame is destroyed and the stored trap
        frames/entries are forgotten. Then, depending on ``trap_model_var``:

        * "None": hide the vibration-frequency and trap-selector frames, show an
          explanatory label and stop (no trap sections, no scroll refresh).
        * "McNabb": show both frames, build six trap sections, apply the active
          trap count and link the lattice activation energy to the Et fields.
        * anything else (Oriani): hide the vibration-frequency frame, show the
          selector, make the trap-count combobox selectable and build six trap
          sections.

        For the two trap models the canvas scroll region is refreshed last.
        """
        for child in self.scrollable_frame.winfo_children():
            child.destroy()

        self.trap_frames = []
        if hasattr(self, 'trap_entries'):
            self.trap_entries = {}

        model = self.trap_model_var.get()

        if model == "None":
            if hasattr(self, 'vib_freq_frame'):
                self.vib_freq_frame.pack_forget()
            if hasattr(self, 'selector_frame'):
                self.selector_frame.pack_forget()
            ttk.Label(self.scrollable_frame,
                      text="Simulation will run without hydrogen traps.",
                      font=("TkDefaultFont", 10, "italic")).pack(pady=20)
            return

        is_mcnabb = model == "McNabb"

        # Only McNabb-Foster needs the vibration-frequency input.
        if hasattr(self, 'vib_freq_frame'):
            if is_mcnabb:
                self.vib_freq_frame.pack(fill="x", pady=(0, 10))
            else:
                self.vib_freq_frame.pack_forget()

        # Both trap models use the trap-count selector.
        if hasattr(self, 'selector_frame'):
            self.selector_frame.pack(fill="x", pady=(0, 10))

        if not is_mcnabb and hasattr(self, 'traps_combo'):
            self.traps_combo.config(state="readonly")

        for trap_number in range(1, 7):
            self.create_trap_section(self.scrollable_frame, trap_number)

        self.update_trap_states()

        if is_mcnabb:
            self.bind_lattice_energy_update()

        # Recompute the scrollable area now that all sections exist.
        self.scrollable_frame.update_idletasks()
        self.canvas.configure(scrollregion=self.canvas.bbox("all"))
        
    def bind_lattice_energy_update(self):
        """Re-sync McNabb Et fields whenever the lattice activation energy is edited.

        Any existing ``<FocusOut>`` / ``<KeyRelease>`` bindings on the
        activation-energy entry are removed first, then both events are bound to
        ``update_mcnabb_trap_energies``. Does nothing if that entry is unknown.
        """
        if 'activation_energy' not in self.entry_widgets:
            return

        lattice_entry = self.entry_widgets['activation_energy']
        sequences = ('<FocusOut>', '<KeyRelease>')

        for sequence in sequences:
            lattice_entry.unbind(sequence)
        for sequence in sequences:
            lattice_entry.bind(sequence, self.update_mcnabb_trap_energies)

    def update_mcnabb_trap_energies(self, event=None):
        """Copy the lattice activation energy into every McNabb trap's Et field.

        Bound to the lattice activation-energy entry via
        ``bind_lattice_energy_update``. A blank lattice value is replaced by the
        default "5690". Each Et entry is switched to "normal" only while its text
        is rewritten and is always returned to "readonly", even if the rewrite
        fails. Does nothing unless the McNabb model is selected and trap entries
        exist, or if the lattice field cannot be read.

        Args:
            event: Tk event that triggered the update (unused).
        """
        if self.trap_model_var.get() != "McNabb":
            return
        if not hasattr(self, 'trap_entries'):
            return

        try:
            lattice_energy = self.entry_widgets['activation_energy'].get().strip()
        except (AttributeError, KeyError, TypeError, tk.TclError):
            # Lattice field missing or destroyed: nothing sensible to copy.
            return
        if not lattice_energy:
            lattice_energy = "5690"  # Default value

        for trap_data in self.trap_entries.values():
            et_entry = trap_data.get('et')
            if et_entry is None:
                continue
            try:
                et_entry.config(state="normal")
                try:
                    et_entry.delete(0, tk.END)
                    et_entry.insert(0, lattice_energy)
                finally:
                    # Et must never stay editable, even if the write failed.
                    et_entry.config(state="readonly")
            except tk.TclError:
                # Widget destroyed (e.g. tab being rebuilt); skip it.
                continue


    def on_traps_count_changed(self, event=None):
        """Re-apply the enabled/disabled trap states after the trap count changes.

        Args:
            event: Tk event passed by the combobox binding (unused).
        """
        return self.update_trap_states()

    def update_trap_states(self):
        """Enable the first N trap frames and disable (and grey out) the rest.

        N is read from ``self.traps_count_var``. If that value cannot be parsed
        as an int, all six traps are treated as active. Disabled frames get the
        "Disabled.TLabelframe" style. A ttk labelframe draws its title with the
        "<style>.Label" sub-style, so that sub-style is configured too; otherwise
        the title would not turn grey. Does nothing when there is no count
        variable or no trap frames.
        """
        if not hasattr(self, 'traps_count_var') or not self.trap_frames:
            return

        try:
            active_traps = int(self.traps_count_var.get())
        except (ValueError, tk.TclError):
            # StringVar with bad text -> ValueError; IntVar with bad text -> TclError.
            active_traps = 6

        # Configure the shared "disabled" style once, only if it is needed.
        disabled_style = None
        if active_traps < len(self.trap_frames):
            try:
                style = ttk.Style()
                style.configure("Disabled.TLabelframe", foreground="gray")
                style.configure("Disabled.TLabelframe.Label", foreground="gray")
                disabled_style = "Disabled.TLabelframe"
            except tk.TclError:
                disabled_style = None  # purely cosmetic; keep going without it

        for index, trap_frame in enumerate(self.trap_frames):
            if index < active_traps:
                self.set_widget_state(trap_frame, "normal")
                trap_frame.configure(style="TLabelframe")
            else:
                self.set_widget_state(trap_frame, "disabled")
                if disabled_style is not None:
                    try:
                        trap_frame.configure(style=disabled_style)
                    except tk.TclError:
                        pass  # cosmetic only

    def set_widget_state(self, widget, state):
        """Recursively apply ``state`` to ``widget`` and all of its descendants.

        Entries that are currently read-only (the McNabb Et fields) are never
        touched, whatever state is requested, and their subtree is not visited.
        Widgets without a ``state`` option (e.g. ttk frames) are skipped
        silently, but their children are still processed.

        Args:
            widget: Root tk/ttk widget of the subtree to update.
            state: Tk state string such as "normal" or "disabled".
        """
        try:
            # tk.Entry reports class "Entry"; ttk.Entry reports "TEntry".
            is_entry = widget.winfo_class() in ('Entry', 'TEntry')
        except (AttributeError, tk.TclError):
            is_entry = False

        if is_entry:
            try:
                # ttk returns a Tcl_Obj for 'state'; compare its string form.
                if str(widget.cget('state')) == 'readonly':
                    return  # Always keep Et entries readonly
            except tk.TclError:
                pass

        try:
            widget.configure(state=state)
        except tk.TclError:
            pass  # widget has no 'state' option

        for child in widget.winfo_children():
            self.set_widget_state(child, state)
            
    def create_right_panel(self, parent):
        """Build the right-hand side of the window inside ``parent`` (grid column 1).

        Rows of ``right_frame``:
          0 - control buttons (Run Simulation, Clear Plot, Clear All,
              Reset Default, Export Data), right-aligned;
          1 - matplotlib plot with navigation toolbar (grid weight 2);
          2 - "ML Data Fitting Progress" panel (progress bar, status, time);
          3 - "Output" text log with scrollbar (grid weight 1).

        Creates and stores on ``self``: the buttons, ``fig``/``ax``,
        ``plot_canvas``, ``toolbar``, ``ml_progress_var``/``ml_progress_bar``,
        ``ml_progress_status``, ``ml_time_label`` and ``output_text``.
        Widget creation/pack order determines the layout, so keep it.
        """
        right_frame = ttk.Frame(parent)
        right_frame.grid(row=0, column=1, sticky="nsew")
        # Plot row grows twice as fast as the output row when resizing.
        right_frame.grid_rowconfigure(1, weight=2)
        right_frame.grid_rowconfigure(3, weight=1)
        right_frame.grid_columnconfigure(0, weight=1)
        
        # Top buttons frame
        buttons_frame = ttk.Frame(right_frame)
        buttons_frame.grid(row=0, column=0, sticky="ew", pady=(0, 10))
        buttons_frame.grid_columnconfigure(0, weight=1)
        
        # Button container aligned to the right
        button_container = ttk.Frame(buttons_frame)
        button_container.pack(side="right")
        
        # Control buttons (packed left-to-right in creation order)
        self.run_button = ttk.Button(button_container, text="Run Simulation", 
                                command=self.run_simulation)
        self.run_button.pack(side="left", padx=(0, 5))
        
        self.clear_plot_button = ttk.Button(button_container, text="Clear Plot", 
                                   command=self.clear_plot_only)
        
        self.clear_plot_button.pack(side="left", padx=(0, 5))
        
        self.clear_button = ttk.Button(button_container, text="Clear All", 
                                    command=self.clear_inputs)
        self.clear_button.pack(side="left", padx=(0, 5))
        
        self.reset_button = ttk.Button(button_container, text="Reset Default", 
                                    command=self.reset_defaults)
        self.reset_button.pack(side="left", padx=(0, 5))
        
        # Packed to the right edge of the container, after the left-packed buttons.
        self.export_button = ttk.Button(button_container, text="Export Data", 
                                        command=self.export_plot_data)
        self.export_button.pack(side="right", padx=(5, 0))
        
        # Plot container frame
        plot_container = ttk.Frame(right_frame, padding="10")
        plot_container.grid(row=1, column=0, sticky="nsew", pady=(0, 10))
        
        # Initial 6x4 in @ 100 dpi figure with a single Temperature/Flux axes.
        self.fig = Figure(figsize=(6, 4), dpi=100)
        self.ax = self.fig.add_subplot(111)
        self.ax.set_xlabel("Temperature (K)")
        self.ax.set_ylabel("Flux (mol/m²s)")
        self.ax.grid(True, alpha=0.3)
        
        self.fig.tight_layout()
        
        # Embed the figure in a Tk canvas widget
        self.plot_canvas = FigureCanvasTkAgg(self.fig, plot_container)
        self.plot_canvas.draw()
        
        # Let the plot fill and grow with its container (aspect ratio is not locked)
        canvas_widget = self.plot_canvas.get_tk_widget()
        canvas_widget.pack(side="top", fill="both", expand=True)
        
        # Requested initial size in pixels; this is not a hard minimum, the
        # geometry manager may still shrink the widget.
        canvas_widget.configure(width=500, height=300)  # Requested size in pixels
        
        # Add navigation toolbar
        # NOTE(review): the canvas is packed before the toolbar, so when the
        # window shrinks the toolbar loses space first. matplotlib's Tk
        # embedding example packs the toolbar first - confirm the intended order.
        self.toolbar = NavigationToolbar2Tk(self.plot_canvas, plot_container)
        self.toolbar.update()
        self.toolbar.pack(side="bottom", fill="x")
        
        # ML Progress panel
        ml_progress_frame = ttk.LabelFrame(right_frame, text="ML Data Fitting Progress", padding="5")
        ml_progress_frame.grid(row=2, column=0, sticky="ew", pady=(0, 10))

        # Overall progress only (0-100)
        ttk.Label(ml_progress_frame, text="Progress:").pack(anchor="w")
        self.ml_progress_var = tk.DoubleVar()
        self.ml_progress_bar = ttk.Progressbar(ml_progress_frame, variable=self.ml_progress_var, 
                                            maximum=100, length=400)
        self.ml_progress_bar.pack(fill=tk.X, padx=5, pady=(2, 5))

        # Status text on the left, elapsed time on the right
        status_frame = ttk.Frame(ml_progress_frame)
        status_frame.pack(fill="x", pady=(5, 0))

        self.ml_progress_status = ttk.Label(status_frame, text="Ready for ML data fitting")
        self.ml_progress_status.pack(side="left")

        self.ml_time_label = ttk.Label(status_frame, text="")
        self.ml_time_label.pack(side="right")
        
        # Output frame
        output_frame = ttk.LabelFrame(right_frame, text="Output", padding="10")
        output_frame.grid(row=3, column=0, sticky="nsew")
        output_frame.grid_rowconfigure(0, weight=1)
        output_frame.grid_columnconfigure(0, weight=1)
        
        # Text widget with scrollbar
        self.output_text = tk.Text(output_frame, wrap=tk.WORD, height=8)
        scrollbar = ttk.Scrollbar(output_frame, orient="vertical", 
                                command=self.output_text.yview)
        self.output_text.configure(yscrollcommand=scrollbar.set)
        
        self.output_text.grid(row=0, column=0, sticky="nsew")
        scrollbar.grid(row=0, column=1, sticky="ns")
        
        # Add initial message
        self.output_text.insert("end", "Welcome to TDS Simulator with integrated ML data fitting!\n")
        self.output_text.insert("end", "• For standard simulation: Configure parameters in the Simulation, Parameters and Hydrogen Traps tabs and click 'Run Simulation'\n")
        self.output_text.insert("end", "• For ML analysis: Go to the ML Data Fitting tab, upload experimental data, configure additional parameters, and click 'Start ML Data Fitting'\n")
        # Disabled so the user cannot type into the log.
        self.output_text.config(state=tk.DISABLED)
        
    def update_timer_display(self):
        """Refresh the "Elapsed" label once per second while the ML thread runs.

        The method reschedules itself with ``root.after`` and keeps the id of the
        pending callback in ``self._ml_timer_after_id``. Any pending callback is
        cancelled whenever the method is entered. The timer is started from more
        than one place (analysis launch and the first progress update), and this
        keeps those starts sharing a single refresh loop instead of stacking
        duplicate 1 s loops. The loop ends once ``analysis_thread`` is no longer
        alive.
        """
        pending = getattr(self, '_ml_timer_after_id', None)
        if pending is not None:
            self._ml_timer_after_id = None
            try:
                # Harmless if this id already fired (that is how we got here).
                self.root.after_cancel(pending)
            except tk.TclError:
                pass

        start_time = getattr(self, 'ml_start_time', None)
        if start_time is not None and hasattr(self, 'ml_time_label'):
            elapsed_str = str(timedelta(seconds=int(time.time() - start_time)))
            self.ml_time_label.config(text=f"Elapsed: {elapsed_str}")

        thread = getattr(self, 'analysis_thread', None)
        if thread is not None and thread.is_alive():
            self._ml_timer_after_id = self.root.after(1000, self.update_timer_display)
                
        
    def update_ml_progress_advanced(self, overall_progress, status_text, stage_progress=None, stage_name=None):
        """Show overall ML progress, a status message and the elapsed time.

        On the first call of a run (``ml_start_time`` missing or None) the clock
        is started and ``update_timer_display`` is invoked to begin the periodic
        refresh. ``stage_progress`` and ``stage_name`` are accepted for
        compatibility but not used.

        Args:
            overall_progress: Value for the progress bar (bar maximum is 100).
            status_text: Text for the status label.
        """
        if getattr(self, 'ml_start_time', None) is None:
            self.ml_start_time = time.time()
            self.update_timer_display()

        self.ml_progress_var.set(overall_progress)
        self.ml_progress_status.config(text=status_text)

        started_at = getattr(self, 'ml_start_time', None)
        if started_at is None or not hasattr(self, 'ml_time_label'):
            return
        elapsed = timedelta(seconds=int(time.time() - started_at))
        self.ml_time_label.config(text=f"Elapsed: {str(elapsed)}")
        

    def check_thread_results(self):
        """Drain the worker's result queue and dispatch each message on the Tk thread.

        Messages are dicts with a 'status' key:
          * 'progress' - update the progress display (advanced variant when
            'use_advanced' is true; 'progress' defaults to 0) and log 'message';
          * 'success'  - forwarded to ``handle_ml_success``;
          * 'error'    - ``handle_ml_error`` with 'message' ("Unknown error"
            if absent);
          * 'stopped'  - ``handle_ml_stopped``.
        Unknown statuses are ignored. The method always reschedules itself
        (every 100 ms), even if a handler raises. Otherwise a single bad message
        would stop result polling for the rest of the session. The exception
        still propagates, so Tk reports it.
        """
        try:
            while True:
                try:
                    result = self.result_queue.get_nowait()
                except queue.Empty:
                    break

                status = result.get('status')
                if status == 'progress':
                    progress = result.get('progress', 0)
                    message = result.get('message', '')
                    if result.get('use_advanced', False):
                        self.update_ml_progress_advanced(progress, message)
                    else:
                        self.update_ml_progress(progress, message)
                    self.log_message(message)
                elif status == 'success':
                    self.handle_ml_success(result)
                elif status == 'error':
                    self.handle_ml_error(result.get('message', 'Unknown error'))
                elif status == 'stopped':
                    self.handle_ml_stopped()
        finally:
            # Keep the polling loop alive no matter what happened above.
            self.root.after(100, self.check_thread_results)

    def export_plot_data(self):
        """Export every line and scatter series on the main axes to a CSV file.

        Asks the user for a destination, then writes a '#'-prefixed metadata
        header (timestamp, plot title, axis labels, series list with point
        counts), followed by one ``<label>_X`` / ``<label>_Y`` column pair per
        series. Shorter series are padded with empty cells. The file is written
        as UTF-8 because axis/series labels may contain characters such as "²"
        or "Δ" that the platform default encoding (e.g. cp1252) cannot encode.
        Problems are reported through message boxes and the output log; nothing
        is raised to the caller.
        """
        try:
            lines = self.ax.get_lines()
            scatter_collections = self.ax.collections

            if not lines and not scatter_collections:
                messagebox.showwarning("No Data", "No plot data available to export.")
                return

            filename = filedialog.asksaveasfilename(
                defaultextension=".csv",
                filetypes=[("CSV files", "*.csv"), ("All files", "*.*")],
                title="Save plot data as..."
            )
            if not filename:
                return  # dialog cancelled

            import csv

            def display_label(artist, fallback):
                # Unlabelled matplotlib artists get internal names such as
                # "_child0" or "_nolegend_"; export a readable name instead.
                label = artist.get_label()
                if not label or str(label).startswith('_'):
                    return fallback
                return str(label)

            all_series = []

            for i, line in enumerate(lines):
                x_data = line.get_xdata()
                all_series.append({
                    'x_data': x_data,
                    'y_data': line.get_ydata(),
                    'label': display_label(line, f"Series_{i+1}"),
                    'type': 'line',
                    'points': len(x_data),
                })

            for i, collection in enumerate(scatter_collections):
                offsets = collection.get_offsets()
                if len(offsets) == 0:
                    continue
                all_series.append({
                    'x_data': offsets[:, 0],
                    'y_data': offsets[:, 1],
                    'label': display_label(collection, f"Scatter_{i+1}"),
                    'type': 'scatter',
                    'points': len(offsets),
                })

            if not all_series:
                messagebox.showwarning("No Data", "No exportable data found in the plot.")
                return

            x_label = self.ax.get_xlabel()
            y_label = self.ax.get_ylabel()
            plot_title = self.ax.get_title()

            with open(filename, 'w', newline='', encoding='utf-8') as csvfile:
                writer = csv.writer(csvfile)

                # Metadata header
                writer.writerow([f"# TDS Simulator Data Export - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"])
                writer.writerow([f"# Plot Title: {plot_title}"])
                writer.writerow([f"# X-axis: {x_label}"])
                writer.writerow([f"# Y-axis: {y_label}"])
                writer.writerow([f"# Number of series: {len(all_series)}"])
                writer.writerow([])

                writer.writerow(["# Series Information:"])
                for i, series in enumerate(all_series):
                    writer.writerow([f"# {i+1}. {series['label']} ({series['points']} points)"])
                writer.writerow([])

                # Column headers: one X/Y pair per series
                headers = []
                for series in all_series:
                    clean_label = series['label'].replace(',', '_').replace('\n', '_')
                    headers.extend([f"{clean_label}_X", f"{clean_label}_Y"])
                writer.writerow(headers)

                # Data rows, padding exhausted series with empty cells
                max_length = max(len(series['x_data']) for series in all_series)
                for row_idx in range(max_length):
                    row_data = []
                    for series in all_series:
                        if row_idx < len(series['x_data']):
                            row_data.extend([series['x_data'][row_idx], series['y_data'][row_idx]])
                        else:
                            row_data.extend(['', ''])
                    writer.writerow(row_data)

            self.log_message(f"Plot data exported to: {filename}")
            self.log_message(f"Export includes {len(all_series)} series with metadata")
            for series in all_series:
                self.log_message(f"  - {series['label']}: {series['points']} points ({series['type']})")

            messagebox.showinfo("Export Complete",
                                f"Data exported successfully\n"
                                f"Series: {len(all_series)}")

        except Exception as e:
            self.log_message(f"Error exporting data: {str(e)}")
            messagebox.showerror("Export Error", f"Failed to export data:\n{str(e)}")

    def update_unit_converter_with_material_props(self):
        """Push the GUI's mass density and thickness into ``self.unit_converter``.

        Creates the converter lazily if there is none yet. If construction fails,
        the converter is left as None and a warning is logged; callers in this
        class then fall back to unconverted data. Blank or missing fields use
        the defaults 7.8474 (mass density) and 0.0063 (thickness). Unparsable
        values log a warning and leave the converter unchanged.
        """
        if not getattr(self, 'unit_converter', None):
            try:
                self.unit_converter = UnitConverter()
            except Exception as e:  # best effort: the GUI still works without it
                self.unit_converter = None
                self.log_message(f"Warning: Could not create unit converter: {e}")
                return

        def field_value(key, default):
            # Missing or blank fields fall back to the default value.
            if key not in self.entry_widgets:
                return default
            text = self.entry_widgets[key].get()
            return float(text) if text.strip() else default

        try:
            mass_density = field_value('mass_density', 7.8474)
            thickness = field_value('thickness', 0.0063)

            # Only assign once both values parsed, so a bad field leaves both unchanged.
            self.unit_converter.mass_density = mass_density
            self.unit_converter.thickness = thickness

        except (ValueError, AttributeError) as e:
            self.log_message(f"Warning: Could not update unit converter: {e}")

    def get_display_units_and_labels(self):
        """Describe the units and axis labels currently selected for display.

        Returns:
            dict with keys 'temp_unit', 'y_unit', 'y_type' ('flux' or
            'delta_c'), 'temp_label' and 'y_label'. Without a unit converter the
            fixed K / mol/m²s flux defaults are returned. Otherwise the units are
            read from the Simulation-tab comboboxes and the labels come from
            ``unit_converter.get_labels``, with generic fallbacks for missing keys.
        """
        converter = self.unit_converter
        if not converter:
            return {
                'temp_unit': 'K',
                'y_unit': 'mol/m\u00b2s',
                'y_type': 'flux',
                'temp_label': 'Temperature (K)',
                'y_label': 'Flux (mol/m\u00b2s)'
            }

        temp_unit = self.temp_unit_combo.get()
        shows_flux = self.graphical_output_combo.get() == "Flux"
        y_type = 'flux' if shows_flux else 'delta_c'
        unit_combo = self.flux_unit_combo if shows_flux else self.delta_c_unit_combo
        y_unit = unit_combo.get()

        labels = converter.get_labels(temp_unit=temp_unit, y_unit=y_unit, y_type=y_type)
        temp_label = labels.get('temperature', f'Temperature ({temp_unit})')
        y_label = labels.get('y_axis', f'{y_type.title()} ({y_unit})')

        return dict(
            temp_unit=temp_unit,
            y_unit=y_unit,
            y_type=y_type,
            temp_label=temp_label,
            y_label=y_label,
        )

    def convert_data_for_plotting(self, temp_data, flux_data):
        """Convert standard-unit temperature/flux data into the selected display units.

        Both inputs are turned into NumPy arrays and converted with the unit
        converter's vectorised helpers. The y data becomes either flux in the
        chosen flux unit or, for any other output type, ΔC in the chosen ΔC unit.
        Without a converter, or if any step fails (a warning is logged), the
        inputs are returned unchanged.

        Returns:
            (temperatures, y_values) ready for plotting.
        """
        converter = self.unit_converter
        if not converter:
            return temp_data, flux_data

        try:
            info = self.get_display_units_and_labels()
            temperatures = np.asarray(temp_data)
            fluxes = np.asarray(flux_data)

            shown_temperatures = converter.temperature_from_standard(temperatures, info['temp_unit'])

            if info['y_type'] != 'flux':
                # ΔC path: flux -> standard ΔC -> requested ΔC unit
                standard_delta_c = converter.flux_to_delta_c(fluxes)
                shown_y = converter.delta_c_from_standard(standard_delta_c, info['y_unit'])
            else:
                shown_y = converter.flux_from_standard(fluxes, info['y_unit'])

            return shown_temperatures, shown_y

        except Exception as e:
            self.log_message(f"Warning: Unit conversion failed: {e}")
            return temp_data, flux_data

    def plot_experimental_data_only(self, Temperature, TDS_Curve, clear_plot=True):
        """Draw (or redraw) the experimental TDS curve as a black scatter series.

        Args:
            Temperature: Temperatures in standard units (converted for display
                when a unit converter is available).
            TDS_Curve: Measured flux in standard units (converted to the selected
                flux or ΔC display unit when a converter is available).
            clear_plot: If true, the axes are cleared first, unless they already
                hold a line whose label mentions simulation/trap/prediction
                keywords. In that case the lines are kept. The axis labels are
                also (re)set.

        Any earlier "Experimental" series is removed before the new scatter is
        added. A legend is shown once the axes hold more than one artist, and
        the canvas is redrawn once at the end. pyplot's interactive mode is
        switched off for the duration and restored afterwards.
        """
        simulation_keywords = ['simulation', 'trap', 'no traps', 'mcnabb', 'oriani', 'prediction']

        interactive_before = plt.isinteractive()
        plt.ioff()

        try:
            if clear_plot:
                keeps_simulation = False
                for line in self.ax.get_lines():
                    lowered = line.get_label().lower()
                    if any(word in lowered for word in simulation_keywords):
                        keeps_simulation = True
                        break
                if not keeps_simulation:
                    self.ax.clear()

            self.update_unit_converter_with_material_props()

            if not self.unit_converter:
                x_data, y_data = Temperature, TDS_Curve
                x_label, y_label = "Temperature (K)", "Flux (mol/m²/s)"
            else:
                try:
                    info = self.get_display_units_and_labels()
                    x_data = self.unit_converter.temperature_from_standard(Temperature, info['temp_unit'])

                    if info['y_type'] == "flux":
                        y_data = self.unit_converter.flux_from_standard(TDS_Curve, info['y_unit'])
                    elif info['y_type'] == "delta_c":
                        standard_delta_c = self.unit_converter.flux_to_delta_c(TDS_Curve)
                        y_data = self.unit_converter.delta_c_from_standard(standard_delta_c, info['y_unit'])

                    x_label = info['temp_label']
                    y_label = info['y_label']
                except Exception as e:
                    self.log_message(f"Warning: Unit conversion failed: {e}")
                    x_data, y_data = Temperature, TDS_Curve
                    x_label, y_label = "Temperature (K)", "Flux (mol/m²/s)"

            # Replace any previous experimental series with the new one.
            self._remove_plots_by_label("Experimental")
            self.ax.scatter(x_data, y_data, label="Experimental Data", color="black",
                            s=20, alpha=0.7, zorder=5)

            if clear_plot or not self.ax.get_xlabel():
                self.ax.set_xlabel(x_label)
            if clear_plot or not self.ax.get_ylabel():
                self.ax.set_ylabel(y_label)

            self.ax.grid(self.grid_var.get(), alpha=0.3)

            artist_count = len(self.ax.get_lines()) + len(self.ax.collections)
            if artist_count > 1:
                self.ax.legend(loc='best')

            self.fig.tight_layout()

        finally:
            if interactive_before:
                plt.ion()
            self.plot_canvas.draw()

    def run_ml_analysis(self):
        """Launch the ML data-fitting run in a daemon background thread.

        Clears the log and progress panel, clears the stop flag and swaps the
        Start/Stop button states. It then collects and displays the parameters
        and starts ``run_thermal_desorption_analysis``, which reports back
        through ``self.result_queue``. The elapsed-time clock is started as soon
        as the thread is running, for two reasons:
          * the display does not wait for the first progress message;
          * ``update_ml_progress_advanced`` only starts a timer loop when no
            start time is set, so it will not start a second loop next to the
            one scheduled here.

        Invalid input (``ValueError``) and any other launch failure restore the
        Start/Stop buttons and are reported via ``handle_ml_error``.
        """
        self.log_message("--- Running ML Fitting ---", clear=True)
        self.reset_ml_progress()

        # Reset stop flag and manage button states
        self.stop_ml_flag.clear()
        self.ml_button.config(state=tk.DISABLED)
        self.stop_ml_button.config(state=tk.NORMAL)

        try:
            self.update_unit_converter_with_material_props()

            params = self.collect_ml_parameters()
            # The worker checks this flag between steps to honour "Stop".
            params['stop_flag'] = self.stop_ml_flag
            self.display_ml_parameters(params)

            self.analysis_thread = threading.Thread(
                target=run_thermal_desorption_analysis,
                args=(params, self.result_queue),
                daemon=True
            )
            self.analysis_thread.start()
            # Safe to set after start(): queued results are only consumed on
            # this (Tk) thread, via check_thread_results.
            self.ml_start_time = time.time()

            # Start the real-time timer updates
            self.root.after(1000, self.update_timer_display)

        except Exception as e:
            # Nothing is running, so let the user start again.
            self.ml_button.config(state=tk.NORMAL)
            self.stop_ml_button.config(state=tk.DISABLED)
            if isinstance(e, ValueError):
                self.handle_ml_error(f"Invalid input value: {str(e)}")
            else:
                self.handle_ml_error(f"An unexpected error occurred: {str(e)}")
            
    def stop_ml_analysis(self):
        """Ask the running ML worker to stop at its next checkpoint.

        This only sets the shared stop flag; the worker checks the flag between
        steps and exits on its own. The Stop button is then disabled so the
        request is issued only once.
        """
        waiting_text = "Stop requested - waiting for safe stopping point..."
        self.stop_ml_flag.set()
        self.ml_progress_status.config(text=waiting_text)
        self.log_message("Stop requested. Waiting for current operation to complete safely...")

        # A single request is enough; block further clicks.
        self.stop_ml_button.config(state=tk.DISABLED)


    def reset_ml_progress(self):
        """Return the ML progress bar, status text, timer label and run state to idle."""
        self.ml_progress_var.set(0)
        self.ml_progress_status.config(text="Ready to start ML data fitting")

        # The elapsed-time label may not exist, so it is guarded with hasattr.
        if not hasattr(self, 'ml_time_label'):
            pass
        else:
            self.ml_time_label.config(text="")

        # Forget the previous run's start time and worker thread.
        self.ml_start_time = self.analysis_thread = None

    def update_ml_progress(self, progress_percent, status_text):
        """Show *progress_percent* on the ML progress bar and *status_text* below it."""
        progress_bar_var = self.ml_progress_var
        status_label = self.ml_progress_status
        progress_bar_var.set(progress_percent)
        status_label.config(text=status_text)

    def collect_ml_parameters(self):
        """Collect all GUI inputs needed for an ML fitting run.

        Blank entries (empty or whitespace-only) fall back to the defaults
        below. A non-blank entry that cannot be parsed raises ``ValueError``,
        which ``run_ml_analysis`` reports as an invalid input value.

        Returns:
            dict with keys 'material', 'test', 'numerical', 'trap_model',
            'exp_file', 'ExpName', 'training_parameters', 'HD_Trap' and
            'exp_units'. This is the ``params`` dict unpacked by
            ``run_thermal_desorption_analysis``.
        """
        widgets = self.entry_widgets

        def read(key, default, cast=float):
            """Return entry *key* converted with *cast*, or *default* if it is blank."""
            # Strip first: a whitespace-only entry is truthy, and float("  ") /
            # int("  ") would raise instead of using the default.
            text = widgets[key].get().strip()
            return cast(text) if text else default

        # Trap model: "McNabb" and "Oriani" are passed through unchanged.
        # Any other selection, or an unreadable selector, falls back to McNabb.
        try:
            selected_model = self.trap_model_var.get()
        except Exception:
            selected_model = None
        trap_model_backend = selected_model if selected_model in ("McNabb", "Oriani") else "McNabb"

        # Vibration (attempt) frequency is only read for the McNabb model.
        vib_freq = 1e13  # Default
        if trap_model_backend == "McNabb" and hasattr(self, 'vib_freq_entry'):
            try:
                vib_freq = float(self.vib_freq_entry.get().strip())
            except ValueError:
                pass  # blank/invalid entry: keep the default frequency

        # Basic material parameters shared with the simulation inputs.
        material = {
            'NL': read('lattice_density', 8.47e5),
            'E_Diff': read('activation_energy', 5690),
            'D0': read('diffusion_factor', 7.23e-8),
            'C0': read('initial_conc', 0.06),
            'TrapRate': vib_freq,  # Set vibration frequency here
            'MolMass': read('molar_mass', 55.847),
            'MassDensity': read('mass_density', 7.8474),
        }

        # Test parameters shared with the simulation inputs.
        test = {
            'tRest': read('resting_time', 2700),
            'HeatingRate': read('heating_rate', 0.055),
            'Thickness': read('thickness', 0.0063),
            'TMax': read('max_temp', 873.15),
            'TMin': read('min_temp', 293.15),
        }

        # Basic numerical parameters combined with the ML-specific ones.
        numerical = {
            'dEMin': read('ml_de_min', 10e3),
            'ntp': read('ntp', 64, int),
            'SampleFreq': read('sample_freq', 10, int),
            'NRange': [read('ml_n_range_min', 1e-1), read('ml_n_range_max', 1e1)],
            'ERange': [read('ml_e_range_min', 50e3), read('ml_e_range_max', 150e3)],
            'NumTraining': read('ml_num_training', 50000, int),
            'NumVerification': read('ml_num_verification', 500, int),
            'n_cpu_cores': read('ml_n_cpu_cores', 16, int),
        }

        training_parameters = {
            'Traps': read('ml_traps', "Random", str),
            'Concentrations': read('ml_concentrations', "Random", str),
            'MaxTraps': read('ml_max_traps', 4, int),
            'ParameterSet': read('ml_hp_set', "optimised", str),
            'Regenerate_Data': read('ml_regenerate_data', "False", str),
            'Regenerate_Training': read('ml_regenerate_training', "False", str),
        }

        # High-density trap settings; the flag is enabled only by the exact text "True".
        HD_Trap = {
            "HDT": widgets['high_density_trap'].get().strip() == 'True',
            'HDT_NRange': [read('HDT_ml_n_range_min', 0), read('HDT_ml_n_range_max', 0)],
            'HDT_ERange': [read('HDT_ml_e_range_min', 0), read('HDT_ml_e_range_max', 0)],
        }

        exp_units = self.get_experimental_units()

        return {
            'material': material,
            'test': test,
            'numerical': numerical,
            'trap_model': trap_model_backend,
            'exp_file': self.ml_exp_file_entry.get(),
            'ExpName': widgets['ml_exp_name'].get(),
            'training_parameters': training_parameters,
            'HD_Trap': HD_Trap,
            'exp_units': exp_units,
        }

    def display_ml_parameters(self, params):
        """Echo the collected ML inputs to the output log before fitting starts.

        Material and test values that are floats >= 1e4 are printed in
        scientific notation. Everything else is printed as-is.
        """
        def as_text(value):
            # Large floats are easier to read in scientific notation.
            if isinstance(value, float) and value >= 1e4:
                return f"{value:.2e}"
            return f"{value}"

        log = self.log_message
        log(f"ML Test case: {params['ExpName']}")
        log("--- ML Fitting Inputs ---")
        log(f"Experimental Data File: {params['exp_file']}")

        sections = (("\nMaterial parameters:", 'material'), ("\nTest parameters:", 'test'))
        for heading, section_key in sections:
            log(heading)
            for name, value in params[section_key].items():
                log(f"  {name}: {as_text(value)}")

        log("\nML Training parameters:")
        for name, value in params['training_parameters'].items():
            log(f"  {name}: {value}")

        log(f"\nTrap model: {params['trap_model']}")
        log("ML parameters collected! Starting fitting...")

    def load_and_plot_experimental_data(self, file_path, is_ml=False, clear_plot=True):
        """Load an experimental data file, process it and plot it.

        Args:
            file_path: path of the experimental data file.
            is_ml: if True, build the processing context from the ML inputs
                (``collect_ml_parameters``). Otherwise use the simulation
                inputs (``get_simulation_parameters``) plus fixed numerical
                settings.
            clear_plot: forwarded to ``plot_experimental_data_only``; False
                keeps the existing curves on the axes.

        Errors are reported via a message box and the log, not raised.
        The processed data object is stored in ``self.exp_data``.
        """
        try:
            # Create temporary parameters for processing
            temp_params = self.collect_ml_parameters() if is_ml else self.get_simulation_parameters()
            if temp_params is None:
                # get_simulation_parameters() can return None (run_simulation
                # checks for this too); indexing it would raise TypeError.
                self.log_message("Could not load experimental data: invalid simulation parameters.")
                return

            def param(*keys):
                """Return the value of the first of *keys* present in temp_params."""
                for key in keys:
                    if key in temp_params:
                        return temp_params[key]
                raise KeyError(keys[0])

            # Ensure ExpDataProcessing is available
            if 'ExpDataProcessing' not in globals():
                messagebox.showerror("Module Not Available", "ExpDataProcessing module is not available.")
                return

            # The ML dict uses 'material'/'test', while run_simulation reads the
            # simulation dict as 'material_param'/'test_param'; accept either.
            temp_material = TDS_Material.TDS_Material(
                temp_params['ExpName'] if is_ml else 'Simulation',
                material_param=param('material', 'material_param'),
                test_param=param('test', 'test_param'),
                numerical_param=temp_params['numerical'] if is_ml else {
                    'dEMin': 10e3, 'ntp': 64, 'SampleFreq': 10,
                    'NRange': [1e-1, 1e3], 'ERange': [20e3, 120e3]
                },
                HD_Trap_param=temp_params['HD_Trap'] if is_ml else None,
                trap_model=temp_params['trap_model'] if is_ml else self.get_trap_model_for_simulation()
            )

            temp_hyperparams = Model_Parameters(
                ParameterSet=temp_params['training_parameters']['ParameterSet'] if is_ml else "optimised"
            )

            # Units the experimental file is expressed in, as selected in the GUI.
            exp_units = self.get_experimental_units()
            temp_units = exp_units['temp_units']
            y_units = exp_units['y_units']
            y_type = exp_units['y_type']

            self.exp_data = ExpDataProcessing(
                file_name=file_path,
                temp_units=temp_units,
                y_units=y_units,
                y_type=y_type,
                material=temp_material,
                hyperparameters=temp_hyperparams
            )

            processed_temp = self.exp_data.Temperature
            processed_flux = self.exp_data.Flux

            # Plot with option to preserve existing plots
            self.plot_experimental_data_only(processed_temp, processed_flux, clear_plot=clear_plot)

            prefix = "ML " if is_ml else ""
            self.log_message(f"{prefix}Experimental data loaded and plotted from: {file_path}")
            self.log_message(f"Temperature range: {min(processed_temp):.1f} - {max(processed_temp):.1f} K")
            self.log_message(f"Max flux: {max(processed_flux):.2e} mol/m²s")

            if is_ml:
                self.log_message(f"Original units: Temperature ({temp_units}), Y-data type ({y_type}), Units ({y_units})")

        except Exception as e:
            messagebox.showerror("Error Loading Data", f"Could not load experimental data:\n{str(e)}")
            self.log_message(f"Error loading experimental data: {str(e)}")

    def get_trap_model_for_simulation(self):
        """Translate the trap-model selector value into the name used by TDS_Material.

        Returns:
            "McNabb" or "Oriani" when one of those is selected. Any other
            selection (such as "None") returns None.
        """
        selected = self.trap_model_var.get()
        return selected if selected in ("McNabb", "Oriani") else None

    def handle_ml_success(self, result):
        """Plot and report the results of a finished ML analysis.

        The plots are generated in the main thread. The method:

        1. Resets the Run/Stop buttons.
        2. Redraws the axes with the experimental data and the total TDS curve
           simulated from the predicted traps.
        3. Optionally draws each trap's individual contribution.
        4. Logs the predicted trap parameters.

        The progress bar only reaches 100% after plotting.

        Args:
            result: result dict from the worker. Reads 'Exp_Temp', 'Exp_Flux'
                and 'message'. When present, it also reads 'Predicted_Traps',
                'Predicted_Concentrations', 'Predicted_Energies' and
                'Material'.
        """
        # Reset button states first
        self.ml_button.config(state=tk.NORMAL)
        self.stop_ml_button.config(state=tk.DISABLED)

        # Clear the analysis thread reference to stop timer updates
        self.analysis_thread = None

        self.update_ml_progress_advanced(95, "Generating plots...")
        self.log_message("ML Analysis completed! Generating plots...")

        try:
            self.update_unit_converter_with_material_props()
            display_info = self.get_display_units_and_labels()

            def to_display_units(temps, flux):
                """Convert standard-unit temperature/flux data to the selected display units."""
                converter = self.unit_converter
                if not converter:
                    return temps, flux
                temp_out = converter.temperature_from_standard(temps, display_info['temp_unit'])
                if display_info['y_type'] == 'flux':
                    y_out = converter.flux_from_standard(flux, display_info['y_unit'])
                else:
                    delta_c_std = converter.flux_to_delta_c(flux)
                    y_out = converter.delta_c_from_standard(delta_c_std, display_info['y_unit'])
                return temp_out, y_out

            def simulate(material, N_list, E_list):
                """Charge, rest and desorb a sample with the given traps; return (T, J)."""
                sample = TDS_Sim.TDS_Sample(material, N_list, E_list, False)
                sample.Charge()
                sample.Rest()
                return sample.TDS()

            self.ax.clear()

            temp_plot, y_data_plot = to_display_units(result['Exp_Temp'], result['Exp_Flux'])
            if self.unit_converter:
                x_label = display_info['temp_label']
                y_label = display_info['y_label']
            else:
                x_label = 'Temperature (K)'
                y_label = 'Flux (mol/m²s)'

            self.ax.scatter(temp_plot, y_data_plot, color='black', label="Experimental Data",
                            s=20, alpha=0.7, zorder=5)

            self.update_ml_progress_advanced(97, "Generating prediction curves...")

            num_predicted_traps = result['Predicted_Traps'][0] if 'Predicted_Traps' in result else 0
            predicted_concentrations = result.get('Predicted_Concentrations', [])
            predicted_energies = result.get('Predicted_Energies', [])

            # Flatten the prediction rows into flat per-trap value lists.
            N_traps = [item.tolist() for row in predicted_concentrations for item in row]
            E_traps = [item.tolist() for row in predicted_energies for item in row]

            if num_predicted_traps > 0 and len(predicted_concentrations) > 0 and len(predicted_energies) > 0:
                Material = result['Material']

                # Use exactly one (N, E) pair per predicted trap. The flattened
                # lists may hold more values than num_predicted_traps, and
                # passing N and E lists of different lengths to TDS_Sample would
                # pair them inconsistently.
                n_used = min(num_predicted_traps, len(N_traps), len(E_traps))
                if n_used < num_predicted_traps:
                    self.log_message(
                        f"Warning: only {n_used} of {num_predicted_traps} predicted traps "
                        f"have parameters; plotting those."
                    )
                N_pred = N_traps[:n_used]

                if Material.TrapModel == TDS_Material.TRAPMODELS.MCNABB:
                    E_pred = [[Material.E_Diff, e] for e in E_traps[:n_used]]
                elif Material.TrapModel == TDS_Material.TRAPMODELS.ORIANI:
                    # Oriani energies are binding energies relative to the
                    # lattice site, so the diffusion energy is added.
                    E_pred = [[Material.E_Diff, e + Material.E_Diff] for e in E_traps[:n_used]]
                else:
                    E_pred = None

                if E_pred is None:
                    self.log_message(
                        f"Warning: Could not generate predicted TDS curve: "
                        f"unsupported trap model {Material.TrapModel}"
                    )
                elif n_used > 0:
                    try:
                        T_pred, J_pred = simulate(Material, N_pred, E_pred)
                        temp_pred_plot, y_pred_plot = to_display_units(T_pred, J_pred)
                        self.ax.plot(temp_pred_plot, y_pred_plot, '-', color='red',
                                     label=f"ML Prediction ({n_used} traps)",
                                     linewidth=2, zorder=4)

                        # Plot individual trap contributions if requested
                        if self.show_individual_var.get() and n_used > 1:
                            colors = ['blue', 'green', 'orange', 'purple', 'brown', 'pink', 'gray', 'olive']
                            for i in range(n_used):
                                try:
                                    T_single, J_single = simulate(Material, [N_pred[i]], [E_pred[i]])
                                    temp_single_plot, y_single_plot = to_display_units(T_single, J_single)
                                    # Same individual-trap style as the simulation plots.
                                    self.ax.plot(temp_single_plot, y_single_plot, '--',
                                                 color=colors[i % len(colors)],
                                                 label=f"Trap {i+1}", linewidth=1.5, alpha=0.8, zorder=3)
                                except Exception as trap_error:
                                    self.log_message(f"Warning: Could not plot trap {i+1} contribution: {trap_error}")

                    except Exception as pred_error:
                        self.log_message(f"Warning: Could not generate predicted TDS curve: {pred_error}")

            self.update_ml_progress_advanced(99, "Finalizing plot...")

            self.ax.set_xlabel(x_label)
            self.ax.set_ylabel(y_label)
            self.ax.grid(self.grid_var.get(), alpha=0.3)

            lines = self.ax.get_lines()
            scatter_collections = self.ax.collections
            if len(lines) > 1 or len(scatter_collections) > 0:
                self.ax.legend(loc='best')

            self.fig.tight_layout()
            self.plot_canvas.draw()

            self.display_ml_trap_results(result)

            self.update_ml_progress_advanced(100, "Complete!")
            self.log_message(f"\n{result['message']}")

        except Exception as plot_error:
            self.update_ml_progress_advanced(100, "Analysis complete (plotting error)")
            self.log_message(f"Plotting error: {plot_error}")
            # Best effort: at least show the experimental data.
            try:
                self.plot_experimental_data_only(result['Exp_Temp'], result['Exp_Flux'])
            except Exception as fallback_error:
                self.log_message(f"Could not plot experimental data: {fallback_error}")
            
    def display_ml_trap_results(self, result):
        """Log the ML-predicted trap parameters and the experimental data range.

        Args:
            result: result dict from the ML analysis. It reads:
                - 'Predicted_Traps': the first entry is the number of traps.
                - 'Predicted_Energies' and 'Predicted_Concentrations': the
                  first row holds the per-trap values.
                - 'Exp_Temp' and 'Exp_Flux': optional.

        This only reports results, so failures are logged as warnings
        instead of being raised.
        """
        try:
            predicted_traps = result.get('Predicted_Traps')
            # Use len() rather than truthiness. A NumPy array like array([0]) is
            # falsy, and bool() of a multi-element array raises ValueError.
            if predicted_traps is None or len(predicted_traps) == 0:
                return

            num_traps = int(predicted_traps[0])
            self.log_message(f"\n=== ML PREDICTION RESULTS ===")
            self.log_message(f"Number of predicted traps: {num_traps}")

            # Energy label follows the trap model currently selected in the GUI.
            energy_label = "De-trapping energy"
            energy_explanation = ""
            try:
                if self.trap_model_var.get() == "Oriani":
                    energy_label = "Binding energy"
                    energy_explanation = "(relative to lattice site)"
                else:  # McNabb
                    energy_explanation = "(activation energy for detrapping)"
            except Exception:
                pass  # selector unavailable: keep the generic label

            if 'Predicted_Energies' in result and 'Predicted_Concentrations' in result:
                predicted_energies = result['Predicted_Energies'][0]
                predicted_concentrations = result['Predicted_Concentrations'][0]

                for i in range(min(num_traps, len(predicted_energies), len(predicted_concentrations))):
                    energy = float(predicted_energies[i])
                    concentration = float(predicted_concentrations[i])

                    self.log_message(f"\nTrap {i+1}:")
                    self.log_message(f"  {energy_label}: {energy:.0f} J/mol {energy_explanation}")
                    self.log_message(f"  Trap density: {concentration:.2e} mol/m³")

                    # Convert to sites/m³ (x Avogadro) for comparison with GUI input
                    sites_per_m3 = concentration * 6.022e23
                    self.log_message(f"  Trap density: {sites_per_m3:.2e} sites/m³")

            # Experimental range is optional; skip it if the data is missing or empty.
            try:
                exp_temp = result['Exp_Temp']
                exp_flux = result['Exp_Flux']
                self.log_message(f"\nExperimental data range:")
                self.log_message(f"  Temperature: {min(exp_temp):.1f} - {max(exp_temp):.1f} K")
                self.log_message(f"  Max flux: {max(exp_flux):.2e} mol/m²s")
            except (KeyError, TypeError, ValueError):
                pass

        except Exception as e:
            self.log_message(f"Warning: Could not display ML trap parameters: {e}")
            
    def handle_ml_error(self, message):
        """Report a failed ML run and put the controls back into their idle state.

        Args:
            message: description of the failure, logged as "ML Error: <message>".
        """
        self.update_ml_progress_advanced(100, "ML Error occurred")
        self.log_message(f"ML Error: {message}")

        # Idle controls: Run enabled, Stop disabled.
        run_button = self.ml_button
        run_button.config(state=tk.NORMAL)
        stop_button = self.stop_ml_button
        stop_button.config(state=tk.DISABLED)

        # Dropping the worker reference stops the timer updates.
        self.analysis_thread = None

    def handle_ml_stopped(self):
        """Finish a user-cancelled ML run: report it and restore the idle controls."""
        self.update_ml_progress_advanced(100, "ML Analysis stopped by user")

        notices = (
            "\n=== ML ANALYSIS STOPPED ===",
            "Note: Some operations (like model training) may continue briefly in the background",
            "before fully stopping.",
        )
        for notice in notices:
            self.log_message(notice)

        # Run enabled again, Stop disabled.
        self.ml_button.config(state=tk.NORMAL)
        self.stop_ml_button.config(state=tk.DISABLED)

        # Dropping the worker reference stops the timer updates.
        self.analysis_thread = None

    def run_simulation(self, clear_plot=False):
        """Run a TDS simulation with the current GUI parameters and plot the result.

        Args:
            clear_plot: if True, clear the whole axes first. Otherwise only the
                simulation curves are removed, so experimental data stays.

        Errors are logged together with their traceback rather than raised.
        The Run button is re-enabled in every case.
        """
        try:
            self.log_message("Starting TDS simulation...")

            self.update_unit_converter_with_material_props()

            # Disable run button during simulation
            self.run_button.config(state=tk.DISABLED)

            params = self.get_simulation_parameters()
            if params is None:
                return

            # Check if required modules are available
            if 'TDS_Material' not in globals() or 'TDS_Sim' not in globals():
                messagebox.showerror("Modules Not Available", "Required simulation modules (TDS_Material, TDS_Sim) are not available.")
                return

            # Map the GUI selection to the trap-model name TDS_Material expects.
            trap_model = self.trap_model_var.get()
            if trap_model == "None":
                trap_model_enum = None
            elif trap_model == "McNabb":
                trap_model_enum = "McNabb"
            else:  # Oriani
                trap_model_enum = "Oriani"

            Material = TDS_Material.TDS_Material(
                params['exp_name'],
                params['material_param'],
                params['test_param'],
                params['numerical_param'],
                params['hd_trap_param'],
                trap_model_enum
            )

            self.log_message(f"Trap model: {trap_model}")

            def simulate(N_list, E_list):
                """Charge, rest and desorb a sample with the given traps; return (T, J)."""
                sample = TDS_Sim.TDS_Sample(Material, N_list, E_list, False)
                sample.Charge()
                sample.Rest()
                return sample.TDS()

            # Turn off interactive plotting so the canvas is redrawn only once.
            was_interactive = plt.isinteractive()
            plt.ioff()

            try:
                if clear_plot:
                    self.ax.clear()
                else:
                    self._remove_simulation_plots()

                self.update_unit_converter_with_material_props()
                display_info = self.get_display_units_and_labels()

                # get_trap_lists() returns (N, E) or (N, E, vib_freq).
                trap_lists = self.get_trap_lists()
                if len(trap_lists) == 3:
                    N, E_traps, vib_freq = trap_lists
                else:
                    N, E_traps = trap_lists[:2]
                    vib_freq = None

                if trap_model_enum is None:
                    self.log_message("Running simulation without traps...")
                    T, J = simulate([], [])
                    temp_plot, y_plot = self.convert_data_for_plotting(T, J)
                    self.ax.plot(temp_plot, y_plot, '-', color='blue',
                                 label="No Traps", linewidth=2, zorder=4)

                elif len(N) == 0 or len(E_traps) == 0:
                    self.log_message("Warning: No active traps found. Running without traps.")
                    T, J = simulate([], [])
                    temp_plot, y_plot = self.convert_data_for_plotting(T, J)
                    self.ax.plot(temp_plot, y_plot, '-', color='blue',
                                 label=f"{trap_model} (No Active Traps)", linewidth=2, zorder=4)

                else:
                    # Pair each trap energy with the lattice diffusion energy.
                    # Oriani energies are relative to the lattice site, so E_Diff is added.
                    if Material.TrapModel == TDS_Material.TRAPMODELS.MCNABB:
                        E = [[Material.E_Diff, e] for e in E_traps]
                    elif Material.TrapModel == TDS_Material.TRAPMODELS.ORIANI:
                        E = [[Material.E_Diff, e + Material.E_Diff] for e in E_traps]
                    else:
                        raise ValueError(f"Unsupported trap model: {Material.TrapModel}")

                    # Update Material object with vibration frequency if McNabb model
                    if trap_model == "McNabb" and vib_freq is not None:
                        try:
                            Material.TrapRate = vib_freq
                            self.log_message(f"Set vibration frequency to {vib_freq:.2e} Hz")
                        except AttributeError:
                            self.log_message("Warning: Could not set vibration frequency in Material object")

                    self.log_message(f"Running simulation with {len(N)} traps...")

                    T, J = simulate(N, E)
                    temp_plot, y_plot = self.convert_data_for_plotting(T, J)
                    self.ax.plot(temp_plot, y_plot, '-', color='blue',
                                 label=f"TDS Simulation ({trap_model})", linewidth=2, zorder=4)

                    # Plot individual trap contributions if requested
                    if self.show_individual_var.get() and len(N) > 1:
                        self.log_message("Computing individual trap contributions...")
                        colors = ['red', 'green', 'orange', 'purple', 'brown', 'pink']

                        for i, n_single in enumerate(N):
                            # Reuse the already-paired energy so single-trap
                            # runs match the total simulation exactly.
                            T_single, J_single = simulate([n_single], [E[i]])
                            temp_single_plot, y_single_plot = self.convert_data_for_plotting(T_single, J_single)
                            self.ax.plot(temp_single_plot, y_single_plot, '--', color=colors[i % len(colors)],
                                         label=f"Trap {i + 1}", linewidth=1.5, alpha=0.8, zorder=3)

                self.ax.set_xlabel(display_info['temp_label'])
                self.ax.set_ylabel(display_info['y_label'])
                self.ax.grid(self.grid_var.get(), alpha=0.3)

                lines = self.ax.get_lines()
                scatter_collections = self.ax.collections
                if len(lines) > 0 or len(scatter_collections) > 0:
                    self.ax.legend(loc='best')

                self.fig.tight_layout()

                self.log_message("Simulation completed successfully!")

            finally:
                # Restore interactive state and draw once
                if was_interactive:
                    plt.ion()
                self.plot_canvas.draw()

        except Exception as e:
            self.log_message(f"Error during simulation: {str(e)}")
            import traceback
            self.log_message(f"Traceback: {traceback.format_exc()}")
        finally:
            # Re-enable run button
            self.run_button.config(state=tk.NORMAL)
    
    def _remove_plots_by_label(self, label_keyword):
        """remove plots containing specific label keywords"""
        # Remove lines
        [line.remove() for line in self.ax.get_lines() 
        if label_keyword in line.get_label()]
        
        # Remove collections (scatter plots)
        [col.remove() for col in self.ax.collections 
        if label_keyword in col.get_label()]

    def _remove_simulation_plots(self):
        """remove simulation plots while keeping experimental data"""
        simulation_keywords = ['simulation', 'trap', 'no traps', 'mcnabb', 'oriani']
        
        [line.remove() for line in self.ax.get_lines()
        if any(keyword in line.get_label().lower() for keyword in simulation_keywords)]

    def get_simulation_parameters(self):
        """Build the parameter dictionaries for TDS_Material from the GUI.

        Numeric inputs are read from ``self.entry_widgets``. When a widget is
        missing or its stripped text does not parse, the hard-coded default is
        used and a warning is logged. TrapRate, dEMin, NRange, ERange and the
        experiment name are fixed values rather than GUI inputs.

        Returns:
            dict with keys ``'exp_name'``, ``'material_param'``,
            ``'test_param'``, ``'numerical_param'`` and ``'hd_trap_param'``
            (always None), or None if an unexpected error occurs.
        """
        try:
            def read_entry(widget_name, default_value, convert):
                # Parse one widget's text; log and fall back to the default on failure.
                try:
                    raw_text = self.entry_widgets[widget_name].get().strip()
                    return convert(raw_text)
                except (ValueError, KeyError):
                    self.log_message(f"Warning: Using default value for {widget_name}")
                    return default_value

            def as_float(widget_name, default_value):
                return read_entry(widget_name, default_value, float)

            def as_int(widget_name, default_value):
                return read_entry(widget_name, default_value, int)

            # Keyword order below fixes both the key order and the read/log order.
            material_param = dict(
                NL=as_float('lattice_density', 8.47e5),
                E_Diff=as_float('activation_energy', 5690),
                D0=as_float('diffusion_factor', 7.23e-8),
                C0=as_float('initial_conc', 0.06),
                TrapRate=1e13,  # fixed, not read from the GUI
                MolMass=as_float('molar_mass', 55.847),
                MassDensity=as_float('mass_density', 7.8474),
            )

            test_param = dict(
                tRest=as_float('resting_time', 2700),
                HeatingRate=as_float('heating_rate', 0.055),
                Thickness=as_float('thickness', 0.0063),
                TMax=as_float('max_temp', 873.15),
                TMin=as_float('min_temp', 293.15),
            )

            numerical_param = dict(
                dEMin=10e3,  # fixed
                ntp=as_int('ntp', 64),
                SampleFreq=as_int('sample_freq', 10),
                NRange=[1e-1, 1e3],  # fixed
                ERange=[20e3, 120e3],  # fixed
            )

            params = dict(
                exp_name='Simulation',  # fixed
                material_param=material_param,
                test_param=test_param,
                numerical_param=numerical_param,
                hd_trap_param=None,  # high-density trap parameters are not used here
            )

            self.log_message("Parameters extracted successfully")
            return params

        except Exception as e:
            self.log_message(f"Error extracting parameters: {str(e)}")
            return None

    def get_trap_lists(self):
        """Extract trap concentrations and energies from the Traps tab.

        Reads the first ``traps_count_var`` trap rows of ``self.trap_entries``.
        A row contributes only when both its concentration entry (``'nt'``)
        and its energy entry (``'binding_energy'`` for Oriani, ``'ed'`` for
        McNabb) exist and parse as floats. ``N[i]`` and ``E_traps[i]``
        therefore always describe the same trap.

        Returns:
            Always a 3-tuple ``(N, E_traps, vib_freq)``:
            ``N`` holds each ``nt`` value divided by 6.022e23 (presumably a
            sites/m³ -> mol/m³ conversion; confirm against TDS_Sim),
            ``E_traps`` the matching trap energies, and ``vib_freq`` the
            McNabb vibration frequency (1e13 if unset or invalid), or None
            for any other model.
        """
        avogadro = 6.022e23
        default_vib_freq = 1e13
        N = []  # trap concentrations
        E_traps = []  # trap energies

        trap_model = self.trap_model_var.get()

        if trap_model == "None" or not hasattr(self, 'trap_entries'):
            # Same 3-tuple shape as the normal path, so callers can always
            # unpack three values.
            return N, E_traps, (default_vib_freq if trap_model == "McNabb" else None)

        try:
            active_traps = int(self.traps_count_var.get())
        except ValueError:
            active_traps = 0

        # Vibration frequency is only meaningful for the McNabb model.
        vib_freq = default_vib_freq
        if trap_model == "McNabb" and hasattr(self, 'vib_freq_entry'):
            try:
                vib_freq = float(self.vib_freq_entry.get().strip())
            except ValueError:
                self.log_message("Warning: Invalid vibration frequency, using default 1e13")
                vib_freq = default_vib_freq

        # Which entry holds the trap energy for the selected model.
        energy_key = {"Oriani": "binding_energy", "McNabb": "ed"}.get(trap_model)

        for trap_num in range(1, active_traps + 1):
            if trap_num not in self.trap_entries:
                continue
            trap_data = self.trap_entries[trap_num]
            if 'nt' not in trap_data or energy_key not in trap_data:
                continue

            # Parse both values before appending either, so an invalid entry
            # cannot leave N and E_traps with different lengths.
            try:
                nt_value = float(trap_data['nt'].get().strip())
                energy = float(trap_data[energy_key].get().strip())
            except ValueError as e:
                self.log_message(f"Warning: Invalid value in trap {trap_num}, skipping: {e}")
                continue

            N.append(nt_value / avogadro)
            E_traps.append(energy)

        self.log_message(f"Extracted {len(N)} trap concentrations and {len(E_traps)} trap energies")
        if trap_model == "McNabb":
            self.log_message(f"Using vibration frequency: {vib_freq:.2e} Hz")
        return N, E_traps, vib_freq if trap_model == "McNabb" else None

    def finalize_plot(self):
        """Apply the shared legend, grid and layout styling, then redraw the canvas.

        A legend is added when the axes hold more than one line or any
        collection. With up to three lines the legend stays inside the axes
        (``loc='best'``); with more it is anchored outside the right edge.
        Grid visibility follows ``self.grid_var``.
        """
        line_count = len(self.ax.get_lines())
        has_collections = len(self.ax.collections) > 0

        if line_count > 1 or has_collections:
            if line_count <= 3:
                self.ax.legend(loc='best')
            else:
                self.ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

        # Matplotlib enables the grid whenever line properties such as alpha
        # are passed, even together with visible=False, so only pass alpha
        # when actually turning the grid on.
        if self.grid_var.get():
            self.ax.grid(True, alpha=0.3)
        else:
            self.ax.grid(False)

        self.fig.tight_layout()
        self.plot_canvas.draw()

    def log_message(self, message, clear=False):
        """Append ``message`` to the read-only output console.

        Args:
            message: text to append; a newline is added automatically.
            clear: when True, erase the console (and the message counter)
                before appending.

        The widget is unlocked only for the edit, then locked again. To limit
        redraw overhead, ``update_idletasks`` runs only on every fifth message
        or when the console was cleared.
        """
        console = self.output_text
        console.config(state=tk.NORMAL)
        if clear:
            console.delete(1.0, tk.END)
            self._log_counter = 0
        console.insert("end", f"{message}\n")
        console.see("end")
        console.config(state=tk.DISABLED)

        # Lazily create the counter on first use, then count this message.
        self._log_counter = getattr(self, '_log_counter', 0) + 1

        if clear or self._log_counter % 5 == 0:
            console.update_idletasks()
        
    def clear_plot_only(self):
        """Reset the axes to an empty default plot; input fields are untouched.

        The axes are cleared, default temperature/flux axis labels and a grid
        are applied, the layout is recomputed, and the canvas is redrawn
        before a confirmation is logged.
        """
        axes = self.ax
        axes.clear()

        # Default labelling for the blank plot.
        axes.set_xlabel("Temperature (K)")
        axes.set_ylabel("Flux (mol/m²s)")
        axes.grid(True, alpha=0.3)

        self.fig.tight_layout()
        self.plot_canvas.draw()
        self.log_message("Plot cleared.")

    def clear_inputs(self):
        """Blank every input field and return the GUI to its start-up selections.

        Clears the main parameter entries, the experimental file entry, all
        trap entries and the vibration-frequency entry. Resets the unit
        combos, trap model, checkboxes and trap count. Re-enables the ML
        button and disables its stop button, clears the stop flag, the log,
        the ML progress and the plot, and finally refreshes the traps tab.
        """
        def _blank(entry):
            # Stored objects may not all be editable entries (no ``delete``)
            # or Tk may reject the call; skip those instead of aborting.
            # Deliberately narrower than a bare ``except`` so that
            # KeyboardInterrupt/SystemExit and real bugs still propagate.
            try:
                entry.delete(0, tk.END)
            except (AttributeError, tk.TclError):
                pass

        # Main parameter entries
        for widget in self.entry_widgets.values():
            _blank(widget)

        # Experimental file entry
        self.ml_exp_file_entry.delete(0, tk.END)

        # Experimental data unit combos (only present if that tab was built)
        if hasattr(self, 'exp_temp_unit_combo'):
            self.exp_temp_unit_combo.set("K")
        if hasattr(self, 'exp_y_unit_combo'):
            self.exp_y_unit_combo.set("mol/m\u00b2s")
        if hasattr(self, 'exp_y_type_combo'):
            self.exp_y_type_combo.set("Flux")

        # Trap parameter entries
        if hasattr(self, 'trap_entries'):
            for trap_data in self.trap_entries.values():
                for entry in trap_data.values():
                    _blank(entry)

        if hasattr(self, 'vib_freq_entry'):
            self.vib_freq_entry.delete(0, tk.END)

        # Selections
        self.trap_model_var.set("Oriani")
        self.show_individual_var.set(False)
        self.grid_var.set(True)

        # Combo boxes
        self.graphical_output_combo.set("Flux")
        self.x_axis_combo.set("Temperature")
        self.flux_unit_combo.set("mol/m\u00b2s")
        self.delta_c_unit_combo.set("mol/m\u00b3s")
        self.temp_unit_combo.set("K")

        self.traps_count_var.set("6")

        # Button states and stop flag
        self.ml_button.config(state=tk.NORMAL)
        self.stop_ml_button.config(state=tk.DISABLED)
        self.stop_ml_flag.clear()

        # Clear the log output, ML progress and plot, then refresh the traps tab.
        self.log_message("Inputs cleared. Ready for new simulation...", clear=True)
        self.reset_ml_progress()
        self.clear_plot_only()
        self.update_traps_tab()

    def reset_defaults(self):
        """Reset all inputs, selections and trap entries to their default values.

        Writes the default text into every known entry widget and clears the
        experimental file entry. Resets the unit combos, trap model,
        checkboxes and trap count, then refreshes the traps tab. After that it
        fills the per-trap entries with defaults; the ``et`` entries are
        returned to readonly even if setting them fails. Finally it restores
        the ML button states, clears the stop flag and ML progress, and logs a
        confirmation.
        """
        def _set_entry(entry, value, readonly=False):
            # Replace an entry's text with ``value``. A readonly entry is
            # unlocked only for the edit and always re-locked afterwards.
            # Objects that are not editable entries, or that Tk rejects, are
            # skipped. Deliberately narrower than a bare ``except``.
            try:
                if readonly:
                    entry.config(state="normal")
                try:
                    entry.delete(0, tk.END)
                    entry.insert(0, value)
                finally:
                    if readonly:
                        entry.config(state="readonly")
            except (AttributeError, tk.TclError):
                pass

        # Default text for the main parameter widgets, keyed by widget name
        defaults = {
            'thickness': "0.0063",
            'heating_rate': "0.055",
            'resting_time': "2700",
            'min_temp': "293.15",
            'max_temp': "873.15",
            'ntp': "64",
            'sample_freq': "10",
            'activation_energy': "5690",
            'diffusion_factor': "7.23e-8",
            'molar_mass': "55.847",
            'mass_density': "7.8474",
            'lattice_density': "8.47e5",
            'initial_conc': "0.06",
            # ML defaults
            'ml_exp_name': "Novak_200",
            'ml_de_min': "10e3",
            'ml_n_range_min': "1e-1",
            'ml_n_range_max': "1e1",
            'ml_e_range_min': "50e3",
            'ml_e_range_max': "150e3",
            'high_density_trap': "False",
            'HDT_ml_n_range_min': "0",
            'HDT_ml_n_range_max': "0",
            'HDT_ml_e_range_min': "0",
            'HDT_ml_e_range_max': "0",
            'ml_n_cpu_cores': "16",
            'ml_num_training': "50000",
            'ml_num_verification': "500",
            'ml_hp_set': "optimised",
            'ml_max_traps': "4",
            'ml_traps': "Random",
            'ml_concentrations': "Random",
            'ml_regenerate_data': "False",
            'ml_regenerate_training': "False"
        }

        # Only widgets that actually exist are reset.
        for widget_name, default_value in defaults.items():
            if widget_name in self.entry_widgets:
                _set_entry(self.entry_widgets[widget_name], default_value)

        # Experimental file entry
        self.ml_exp_file_entry.delete(0, tk.END)

        # Experimental data unit combos (only present if that tab was built)
        if hasattr(self, 'exp_temp_unit_combo'):
            self.exp_temp_unit_combo.set("K")
        if hasattr(self, 'exp_y_unit_combo'):
            self.exp_y_unit_combo.set("mol/m\u00b2s")
        if hasattr(self, 'exp_y_type_combo'):
            self.exp_y_type_combo.set("Flux")

        # Selections
        self.trap_model_var.set("Oriani")
        self.show_individual_var.set(False)
        self.grid_var.set(True)

        # Combo boxes
        self.graphical_output_combo.set("Flux")
        self.x_axis_combo.set("Temperature")
        self.flux_unit_combo.set("mol/m\u00b2s")
        self.delta_c_unit_combo.set("mol/m\u00b3s")
        self.temp_unit_combo.set("K")

        # Trap count, then refresh the traps tab before filling its entries
        self.traps_count_var.set("6")
        self.update_traps_tab()

        if hasattr(self, 'vib_freq_entry'):
            self.vib_freq_entry.delete(0, tk.END)
            self.vib_freq_entry.insert(0, "1e13")

        if hasattr(self, 'trap_entries'):
            for trap_num, trap_data in self.trap_entries.items():
                # Energies step up by 5000 per trap, starting at 25000 for trap 1.
                default_energy = f"{20000 + trap_num * 5000}"
                per_trap_defaults = (
                    ('binding_energy', default_energy, False),
                    ('nt', "1.5e25", False),
                    ('ed', default_energy, False),
                    ('et', "5690", True),  # Et is display-only: keep it readonly
                )
                for key, value, readonly in per_trap_defaults:
                    if key in trap_data:
                        _set_entry(trap_data[key], value, readonly)

        # Button states and stop flag
        self.ml_button.config(state=tk.NORMAL)
        self.stop_ml_button.config(state=tk.DISABLED)
        self.stop_ml_flag.clear()

        self.reset_ml_progress()

        self.log_message("Reset to default values.")

def main():
    """Launch the TDS simulation GUI and block in the Tk event loop."""
    window = tk.Tk()
    # Keep a local reference to the GUI object while the event loop runs.
    gui = SimulationGUI(window)  # noqa: F841
    window.mainloop()


if __name__ == "__main__":
    main()