import os

import matplotlib.pyplot as plt
import numpy as np

from hotwheels_core.utils import *
from hotwheels_core.wrap import *
from hotwheels_core.soas import *
from hotwheels_PM import *
from hotwheels_integrate import *
from hotwheels_io import *

#
# Step 1: Configure components
# This stage configures components without allocating resources.
# Configurations are passed to constructors to compile the underlying C libraries.
#

mpi = MPI().init()  # Initialize MPI
mym = MyMalloc(alloc_bytes=int(2e9))  # Configure memory allocator with 2 GB
p = SoA(maxpart=int(1e5), mem=mym)  # Configure P to hold up to 1e5 particles
soas = SoAs(p)  # Add P to a multi-type SoA container

# Conversion factor from Gyr to internal (code) time units.
# G = 43007.1 is the gravitational constant expressed in kpc, km/s and
# 1e10 Msun (the GADGET unit system). Assuming hotwheels uses the same units
# (TODO confirm), one code time unit is kpc / (km/s) = 3.086e16 s ~= 0.978 Gyr.
# A duration in Gyr is therefore converted by multiplying with
# (seconds per Gyr) / (seconds per code unit) ~= 1.022. The reciprocal
# (~0.979) would make the run stop at ~0.957 Gyr instead of 1 Gyr.
SECONDS_PER_CODE_TIME = 3.086e16  # 1 kpc in km divided by 1 km/s
SECONDS_PER_GYR = 1e9 * 3600 * 24 * 365
gyr_to_cu = SECONDS_PER_GYR / SECONDS_PER_CODE_TIME

# Fixed time-step integrator running from t = 0 to t = 1 Gyr (in code units).
ts = FixedTimeStep(
    soas,
    G=43007.1,  # Gravitational constant, presumably in kpc (km/s)^2 / 1e10 Msun
    t_from=0.,
    t_to=1. * gyr_to_cu,
    MPI=mpi,
)
# Configure the initial conditions: an NFW profile with scale radius r_s=100
# and density parameter rho_0=1e-6 (presumably in code units — confirm).
# r_max_f=10. presumably truncates sampling at r_max_f * r_s — TODO confirm.
ic = NFWIC(r_s=100., rho_0=1e-6, r_max_f=10.)
# Configure a refined PM solver with 8 grids in total
# (presumably one base grid plus 7 stacked high-resolution regions — confirm).
pm = SuperHiResPM(  # wrapper around the PM C library
    soas=soas,
    mem=mym,
    TS=ts,  # used to attach the gravkick callback
    MPI=mpi,
    pmgrid=128,  # presumably cells per dimension of each grid — TODO confirm
    grids=8,  # number of grids to instantiate
    dt_displacement_factor=0.25  # factor for DtDisplacement
)
# Collect all configured components into one build rooted at the current directory.
build = make.Build('./', mpi, pm, ts, mym, *soas.values())
# Generate the SoA headers (plus user C code, since generate_user_c=True) for every component.
headers = OnTheFly(build.build_name, *build.components, generate_user_c=True)

# Only the master rank writes the headers and compiles the C libraries.
# NOTE(review): the other ranks continue immediately, and nothing visible here
# makes them wait for compilation to finish before build.enter() loads the
# compiled objects. Confirm that build.enter() synchronizes ranks, or add an
# MPI barrier after this block.
if mpi.rank == 0:  # Master rank handles compilation
    headers.write()
    build.compile()

#
# Step 2: Allocate resources
# Entering each context manager allocates/initializes its component; leaving
# the `with` block releases them in reverse order.
#

with (
    Panic(Build=build) as panic,  # Attach panic handler
    Timer(Build=build) as timer,  # Attach timer handler
    build.enter(debug=mpi.rank == 0),  # Parse compiled objects; debug only on rank 0
    mpi.enter(pm),  # Initialize MPI in the PM module
    mym.enter(*build.components),  # Allocate the 2 GB memory pool
    p,  # Allocate particle data structure in MyMalloc
    ic.enter(p, mpi.ranks, p.get_maxpart(), ts.G),  # Sample NFW profile
    pm,  # Initialize PM and compute first accelerations
    ts  # Compute DriftTables if needed
):

    #
    # Step 3: Main simulation loop (kick-drift-kick)
    #
    while ts.time < ts.time_end:
        ts.find_timesteps()  # Determine timesteps
        ts.do_first_halfstep_kick()  # First kick (includes drift/kick callbacks)
        ts.drift()  # Update particle positions
        pm.compute_accelerations()  # Recompute accelerations at the new positions
        ts.do_second_halfstep_kick()  # Second kick

        # Every 10 steps, plot an x-y histogram of the particle positions on
        # the master rank. NOTE(review): p presumably holds only this rank's
        # particles, so with several ranks the plot shows rank 0's share only.
        if mpi.rank == 0 and ts.steps % 10 == 0:
            fig, ax = plt.subplots(1)
            ax.hist2d(p['pos'][:, 0], p['pos'][:, 1], bins=128)
            ax.set_aspect('equal')
            fig.savefig(f'snap{ts.steps}_rank{mpi.rank}.png', bbox_inches='tight', dpi=200)
            plt.close(fig)  # free the figure so memory does not grow across snapshots

# Report completion once instead of once per MPI rank.
if mpi.rank == 0:
    print('Simulation finished')
