Skip to content

3DGS Functional API with MiLO #29

@Jaykumaran

Description

@Jaykumaran

Hello,

I'm trying to integrate a checkpoint from a standard 3DGS training run (30k or 10k iterations); however, the results are not as good as when I train with MiLO directly. Although I was able to obtain an RGB mesh, it did not preserve the structures well, even though the splat file itself is excellent (it was captured with a high-resolution camera setup). Could you please check whether the following script is correct?

import torch
import os
import sys
import numpy as np
import trimesh
from argparse import ArgumentParser

# Add paths to submodules
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(BASE_DIR)
sys.path.append(os.path.join(BASE_DIR, 'submodules'))

from scene import Scene, GaussianModel
from arguments import ModelParams, PipelineParams, get_combined_args
from gaussian_renderer.radegs import render_radegs
from regularization.sdf.depth_fusion import evaluate_mesh_colors_all_vertices

# Import the MILo Functional API
from functional import (
    sample_gaussians_on_surface,
    extract_gaussian_pivots,
    compute_initial_sdf_values,
    compute_delaunay_triangulation,
    extract_mesh,
)

def sanitize_gaussians(gaussians):
    """
    Ensure Gaussian parameters are valid to prevent CUDA crashes.

    Repairs, in place:
      * NaN/Inf positions   -> replaced with 0
      * oversized scales    -> clamped (in log-space) to log(100)
      * invalid quaternions -> NaN/Inf components zeroed, zero-norm
        quaternions reset to the identity rotation, then renormalized
      * NaN/Inf opacities   -> replaced with 0

    Args:
        gaussians: model exposing ``_xyz``, ``_scaling``, ``_rotation``,
            ``_opacity`` raw tensors and a ``get_scaling`` (activated) view.
    """
    print("[INFO] Sanitizing Gaussian parameters...")

    with torch.no_grad():
        # 1. Positions: NaNs/Infs break the rasterizer's bounding-box pass.
        if torch.isnan(gaussians._xyz).any() or torch.isinf(gaussians._xyz).any():
            print("[WARN] Found NaNs/Infs in XYZ. Replacing with 0.")
            gaussians._xyz.data = torch.nan_to_num(gaussians._xyz.data, nan=0.0, posinf=0.0, neginf=0.0)

        # 2. Scaling: clamp huge scales that cause bounding-box errors.
        #    Scaling is stored in log-space, so compare the activated values
        #    but clamp the raw (log) parameters.
        max_scale = 100.0
        scales = gaussians.get_scaling
        if (scales > max_scale).any():
            print(f"[WARN] Found scales > {max_scale}. Clamping.")
            max_log_scale = np.log(max_scale)
            gaussians._scaling.data = torch.clamp(gaussians._scaling.data, max=max_log_scale)

        # 3. Rotations: invalid quaternions yield NaN rotation matrices -> CUDA crash.
        #    BUG FIX: normalize() propagates NaNs and maps an all-zero quaternion
        #    to the zero vector (still not a valid rotation). Zero out bad
        #    components first, reset degenerate quaternions to the identity
        #    rotation [1, 0, 0, 0], and only then renormalize.
        rot = torch.nan_to_num(gaussians._rotation.data, nan=0.0, posinf=0.0, neginf=0.0)
        zero_norm = rot.norm(dim=-1) == 0
        if zero_norm.any():
            print(f"[WARN] Found {int(zero_norm.sum())} degenerate quaternions. Resetting to identity.")
            identity = torch.zeros_like(rot[zero_norm])
            identity[:, 0] = 1.0
            rot[zero_norm] = identity
        gaussians._rotation.data = torch.nn.functional.normalize(rot, dim=-1)

        # 4. Opacity: also catch Infs, not just NaNs.
        if torch.isnan(gaussians._opacity).any() or torch.isinf(gaussians._opacity).any():
            gaussians._opacity.data = torch.nan_to_num(gaussians._opacity.data, nan=0.0, posinf=0.0, neginf=0.0)

    print("[INFO] Sanitization complete.")

def export_standard_3dgs_to_mesh(dataset, iteration, pipeline, args):
    """
    Load a standard 3DGS checkpoint and extract a colored mesh via the
    MILo functional API (Delaunay triangulation + marching tetrahedra).

    Args:
        dataset: ModelParams-extracted args (provides sh_degree, model_path, ...).
        iteration: checkpoint iteration to load (-1 selects the latest).
        pipeline: PipelineParams-extracted args, forwarded to the renderer.
        args: full CLI namespace (n_max_samples, imp_metric, simple_sampling).

    Side effects:
        Writes ``mesh_extracted_<iteration>.ply`` into ``dataset.model_path``.
    """
    with torch.no_grad():
        # 1. Load the Gaussian Model (Standard 3DGS)
        print(f"[INFO] Loading Gaussian Splatting checkpoint...")
        # Force mip_filter=False and learn_occupancy=False for standard checkpoints
        gaussians = GaussianModel(dataset.sh_degree, use_mip_filter=False, learn_occupancy=False)

        # Initialize Scene (Loads Training Cameras)
        # This will load the PLY from iteration automatically
        scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)

        print(f"[INFO] Loaded {gaussians.get_xyz.shape[0]} Gaussians.")

        # --- FIX 1: Sanitize Data ---
        sanitize_gaussians(gaussians)
        # ----------------------------

        # Setup basics
        train_cameras = scene.getTrainCameras()
        bg_color = torch.tensor([0., 0., 0.], device="cuda")

        # 2. Define the Render Function Wrapper
        # Adapts render_radegs to the {"render", "depth"} dict the depth-fusion
        # step expects; median depth is used as the per-view depth map.
        def render_func(view):
            render_pkg = render_radegs(
                viewpoint_camera=view,
                pc=gaussians,
                pipe=pipeline,
                bg_color=bg_color,
                kernel_size=0.0,
                scaling_modifier=1.0,
                require_coord=False,
                require_depth=True
            )
            return {
                "render": render_pkg["render"],
                "depth": render_pkg["median_depth"],
            }

        # 3. Get Parameters (Force contiguous memory)
        # --- FIX 2: Ensure contiguous memory layout for CUDA ---
        means = gaussians.get_xyz.contiguous()
        scales = gaussians.get_scaling.contiguous()
        rotations = gaussians.get_rotation.contiguous()
        opacities = gaussians.get_opacity.contiguous()

        # 4. Sample Gaussians on Surface
        print("[INFO] Selecting Gaussians for triangulation...")

        if means.shape[0] == 0:
            print("[ERROR] No Gaussians loaded!")
            return

        # --- FIX 3: Skip Complex Sampling if requested ---
        if args.simple_sampling:
            print("[INFO] Using SIMPLE sampling (Opacity based) to avoid CUDA crashes.")
            # Sort by opacity descending
            # We assume high opacity gaussians define the surface better than low opacity "fog"
            # BUG FIX: use view(-1) instead of squeeze() — squeeze() produces a
            # 0-d tensor when exactly one Gaussian is present, breaking argsort.
            indices = torch.argsort(opacities.view(-1), descending=True)
            # Take top N
            n_samples = min(args.n_max_samples, indices.shape[0])
            surface_gaussians_idx = indices[:n_samples]
        else:
            # Try the complex sampling
            try:
                print("[INFO] Attempting Importance Sampling (may crash on cleaned PLYs)...")
                surface_gaussians_idx = sample_gaussians_on_surface(
                    views=train_cameras,
                    means=means,
                    scales=scales,
                    rotations=rotations,
                    opacities=opacities,
                    n_max_samples=args.n_max_samples,
                    scene_type=args.imp_metric,
                )
            except RuntimeError as e:
                print(f"\n[FATAL ERROR] CUDA Crash during sampling: {e}")
                print("Please restart this script with the flag: --simple_sampling")
                sys.exit(1)

        print(f"[INFO] Selected {len(surface_gaussians_idx)} Gaussians.")

        # 5. Compute Delaunay Triangulation over the selected Gaussians' pivots
        print("[INFO] Computing Delaunay Triangulation...")
        delaunay_tets = compute_delaunay_triangulation(
            means=means,
            scales=scales,
            rotations=rotations,
            gaussian_idx=surface_gaussians_idx,
        )

        # 6. Compute SDF via Depth Fusion
        print("[INFO] Fusing depth maps to compute SDF...")
        # Note: If render_radegs crashes here, we have deeper issues with the SH data.
        initial_pivots_sdf = compute_initial_sdf_values(
            views=train_cameras,
            render_func=render_func,
            means=means,
            scales=scales,
            rotations=rotations,
            method='depth_fusion',
            gaussian_idx=surface_gaussians_idx,
        )

        # 7. Extract the Mesh (marching tetrahedra on the SDF over the tets)
        print("[INFO] Extracting mesh via Marching Tetrahedra...")
        mesh_object = extract_mesh(
            delaunay_tets=delaunay_tets,
            pivots_sdf=initial_pivots_sdf,
            means=means,
            scales=scales,
            rotations=rotations,
            gaussian_idx=surface_gaussians_idx,
            filter_large_edges=True
        )

        # 8. Colorize the Mesh by re-rendering from the training views
        print("[INFO] Computing vertex colors...")
        vert_colors = evaluate_mesh_colors_all_vertices(
            views=train_cameras,
            mesh=mesh_object,
            masks=None,
            use_scalable_renderer=True,
        )

        # 9. Save as a vertex-colored PLY next to the checkpoint
        output_ply = os.path.join(dataset.model_path, f"mesh_extracted_{iteration}.ply")
        print(f"[INFO] Saving mesh to {output_ply}...")

        # process=False keeps vertex/face ordering intact (no merging/cleanup).
        trimesh_object = trimesh.Trimesh(
            vertices=mesh_object.verts.cpu().numpy(),
            faces=mesh_object.faces.cpu().numpy(),
            vertex_colors=vert_colors.cpu().numpy(),
            process=False
        )
        trimesh_object.export(output_ply)
        print("[INFO] Done.")

if __name__ == "__main__":
    parser = ArgumentParser(description="Export Standard 3DGS to Mesh")
    model = ModelParams(parser, sentinel=True)
    pipeline = PipelineParams(parser)

    parser.add_argument("--iteration", default=-1, type=int)
    parser.add_argument("--imp_metric", default='outdoor', type=str, choices=['indoor', 'outdoor'])
    parser.add_argument("--n_max_samples", default=600_000, type=int)
    parser.add_argument("--simple_sampling", action="store_true", help="Skip importance sampling and use simple opacity sorting to avoid CUDA crashes.")

    # BUG FIX: everything from here down was dedented out of the __main__
    # guard, so argument parsing and the export ran on *import* as well.
    # Keep the whole entry point inside the guard.
    args = get_combined_args(parser)

    # Initialize system: pin all work to the first CUDA device.
    torch.cuda.set_device(torch.device("cuda:0"))

    export_standard_3dgs_to_mesh(
        model.extract(args),
        args.iteration,
        pipeline.extract(args),
        args,
    )

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions