Hello,
I'm trying to extract a mesh from a checkpoint trained with standard 3DGS (30k or 10k iterations), but the results are not as good as when I train with MILo directly. I was able to get an RGB mesh, but it doesn't preserve the structures well, even though the splat file itself is very good (it was captured with a high-resolution camera setup). Can you please check whether the following script is correct?
import torch
import os
import sys
import numpy as np
import trimesh
from argparse import ArgumentParser

# Add paths to submodules
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(BASE_DIR)
sys.path.append(os.path.join(BASE_DIR, 'submodules'))

from scene import Scene, GaussianModel
from arguments import ModelParams, PipelineParams, get_combined_args
from gaussian_renderer.radegs import render_radegs
from regularization.sdf.depth_fusion import evaluate_mesh_colors_all_vertices

# Import the MILo Functional API
from functional import (
    sample_gaussians_on_surface,
    extract_gaussian_pivots,
    compute_initial_sdf_values,
    compute_delaunay_triangulation,
    extract_mesh,
)


def sanitize_gaussians(gaussians):
    """
    Ensures Gaussian parameters are valid to prevent CUDA crashes.
    """
    print("[INFO] Sanitizing Gaussian parameters...")

    # 1. NaN/Inf checks and fixes
    with torch.no_grad():
        # XYZ
        if torch.isnan(gaussians._xyz).any() or torch.isinf(gaussians._xyz).any():
            print("[WARN] Found NaNs/Infs in XYZ. Replacing with 0.")
            gaussians._xyz.data = torch.nan_to_num(gaussians._xyz.data, nan=0.0, posinf=0.0, neginf=0.0)

        # Scaling
        # Clamp huge scales that cause bounding box errors in the rasterizer.
        # Scaling is stored in log-space in the model.
        max_scale = 100.0
        scales = gaussians.get_scaling
        if (scales > max_scale).any():
            print(f"[WARN] Found scales > {max_scale}. Clamping.")
            max_log_scale = np.log(max_scale)
            gaussians._scaling.data = torch.clamp(gaussians._scaling.data, max=max_log_scale)

        # 2. Normalize rotations
        # Invalid quaternions cause NaN rotation matrices -> CUDA crash
        gaussians._rotation.data = torch.nn.functional.normalize(gaussians._rotation.data, dim=-1)

        # 3. Opacity
        if torch.isnan(gaussians._opacity).any():
            gaussians._opacity.data = torch.nan_to_num(gaussians._opacity.data, nan=0.0)

    print("[INFO] Sanitization complete.")


def export_standard_3dgs_to_mesh(dataset, iteration, pipeline, args):
    with torch.no_grad():
        # 1. Load the Gaussian model (standard 3DGS)
        print("[INFO] Loading Gaussian Splatting checkpoint...")
        # Force mip_filter=False and learn_occupancy=False for standard checkpoints
        gaussians = GaussianModel(dataset.sh_degree, use_mip_filter=False, learn_occupancy=False)

        # Initialize the Scene (loads training cameras).
        # This loads the PLY from the given iteration automatically.
        scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
        print(f"[INFO] Loaded {gaussians.get_xyz.shape[0]} Gaussians.")

        # --- FIX 1: Sanitize data ---
        sanitize_gaussians(gaussians)
        # ----------------------------

        # Setup basics
        train_cameras = scene.getTrainCameras()
        bg_color = torch.tensor([0., 0., 0.], device="cuda")

        # 2. Define the render function wrapper
        def render_func(view):
            render_pkg = render_radegs(
                viewpoint_camera=view,
                pc=gaussians,
                pipe=pipeline,
                bg_color=bg_color,
                kernel_size=0.0,
                scaling_modifier=1.0,
                require_coord=False,
                require_depth=True
            )
            return {
                "render": render_pkg["render"],
                "depth": render_pkg["median_depth"],
            }

        # 3. Get parameters (force contiguous memory)
        # --- FIX 2: Ensure contiguous memory layout for CUDA ---
        means = gaussians.get_xyz.contiguous()
        scales = gaussians.get_scaling.contiguous()
        rotations = gaussians.get_rotation.contiguous()
        opacities = gaussians.get_opacity.contiguous()

        # 4. Sample Gaussians on the surface
        print("[INFO] Selecting Gaussians for triangulation...")
        if means.shape[0] == 0:
            print("[ERROR] No Gaussians loaded!")
            return

        # --- FIX 3: Skip complex sampling if requested ---
        if args.simple_sampling:
            print("[INFO] Using SIMPLE sampling (opacity based) to avoid CUDA crashes.")
            # Sort by opacity, descending.
            # We assume high-opacity Gaussians define the surface better than low-opacity "fog".
            indices = torch.argsort(opacities.squeeze(), descending=True)
            # Take the top N
            n_samples = min(args.n_max_samples, indices.shape[0])
            surface_gaussians_idx = indices[:n_samples]
        else:
            # Try the complex sampling
            try:
                print("[INFO] Attempting importance sampling (may crash on cleaned PLYs)...")
                surface_gaussians_idx = sample_gaussians_on_surface(
                    views=train_cameras,
                    means=means,
                    scales=scales,
                    rotations=rotations,
                    opacities=opacities,
                    n_max_samples=args.n_max_samples,
                    scene_type=args.imp_metric,
                )
            except RuntimeError as e:
                print(f"\n[FATAL ERROR] CUDA crash during sampling: {e}")
                print("Please restart this script with the flag: --simple_sampling")
                sys.exit(1)
        print(f"[INFO] Selected {len(surface_gaussians_idx)} Gaussians.")

        # 5. Compute the Delaunay triangulation
        print("[INFO] Computing Delaunay triangulation...")
        delaunay_tets = compute_delaunay_triangulation(
            means=means,
            scales=scales,
            rotations=rotations,
            gaussian_idx=surface_gaussians_idx,
        )

        # 6. Compute SDF via depth fusion
        print("[INFO] Fusing depth maps to compute SDF...")
        # Note: if render_radegs crashes here, we have deeper issues with the SH data.
        initial_pivots_sdf = compute_initial_sdf_values(
            views=train_cameras,
            render_func=render_func,
            means=means,
            scales=scales,
            rotations=rotations,
            method='depth_fusion',
            gaussian_idx=surface_gaussians_idx,
        )

        # 7. Extract the mesh
        print("[INFO] Extracting mesh via Marching Tetrahedra...")
        mesh_object = extract_mesh(
            delaunay_tets=delaunay_tets,
            pivots_sdf=initial_pivots_sdf,
            means=means,
            scales=scales,
            rotations=rotations,
            gaussian_idx=surface_gaussians_idx,
            filter_large_edges=True
        )

        # 8. Colorize the mesh
        print("[INFO] Computing vertex colors...")
        vert_colors = evaluate_mesh_colors_all_vertices(
            views=train_cameras,
            mesh=mesh_object,
            masks=None,
            use_scalable_renderer=True,
        )

        # 9. Save
        output_ply = os.path.join(dataset.model_path, f"mesh_extracted_{iteration}.ply")
        print(f"[INFO] Saving mesh to {output_ply}...")
        trimesh_object = trimesh.Trimesh(
            vertices=mesh_object.verts.cpu().numpy(),
            faces=mesh_object.faces.cpu().numpy(),
            vertex_colors=vert_colors.cpu().numpy(),
            process=False
        )
        trimesh_object.export(output_ply)
        print("[INFO] Done.")


if __name__ == "__main__":
    parser = ArgumentParser(description="Export Standard 3DGS to Mesh")
    model = ModelParams(parser, sentinel=True)
    pipeline = PipelineParams(parser)
    parser.add_argument("--iteration", default=-1, type=int)
    parser.add_argument("--imp_metric", default='outdoor', type=str, choices=['indoor', 'outdoor'])
    parser.add_argument("--n_max_samples", default=600_000, type=int)
    parser.add_argument("--simple_sampling", action="store_true",
                        help="Skip importance sampling and use simple opacity sorting to avoid CUDA crashes.")
    args = get_combined_args(parser)

    # Initialize system
    torch.cuda.set_device(torch.device("cuda:0"))

    export_standard_3dgs_to_mesh(
        model.extract(args),
        args.iteration,
        pipeline.extract(args),
        args
    )
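For reference, I invoke the script roughly like this (the script name and model path are placeholders, and I'm assuming ModelParams exposes the usual -m / --model_path flag as in the standard 3DGS codebase), adding --simple_sampling only when importance sampling crashes:

python export_3dgs_to_mesh.py -m /path/to/3dgs_output --iteration 30000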