From 029bcf1324836219da39155f11eb205f750627c5 Mon Sep 17 00:00:00 2001 From: xlei user Date: Fri, 16 Jan 2026 06:42:04 +0000 Subject: [PATCH] add pixio as external encoder --- bash_scripts/benchmark/dense_2_view/pixio.sh | 28 ++ configs/model/encoder/pixio.yaml | 10 + configs/model/pixio.yaml | 18 ++ mapanything/models/__init__.py | 3 + .../models/external/{ => pixio}/__init__.py | 0 .../models/external/pixio/layers/__init__.py | 0 .../models/external/pixio/layers/attention.py | 177 ++++++++++++ .../models/external/pixio/layers/drop_path.py | 50 ++++ .../models/external/pixio/layers/helpers.py | 25 ++ .../external/pixio/layers/layerscale.py | 41 +++ .../models/external/pixio/layers/mlp.py | 58 ++++ .../external/pixio/layers/patch_embed.py | 53 ++++ mapanything/models/external/pixio/pixio.py | 272 ++++++++++++++++++ mapanything/models/mapanything/model.py | 16 +- scripts/demo_local_weight.py | 9 +- 15 files changed, 758 insertions(+), 2 deletions(-) create mode 100644 bash_scripts/benchmark/dense_2_view/pixio.sh create mode 100644 configs/model/encoder/pixio.yaml create mode 100644 configs/model/pixio.yaml rename mapanything/models/external/{ => pixio}/__init__.py (100%) create mode 100644 mapanything/models/external/pixio/layers/__init__.py create mode 100644 mapanything/models/external/pixio/layers/attention.py create mode 100644 mapanything/models/external/pixio/layers/drop_path.py create mode 100644 mapanything/models/external/pixio/layers/helpers.py create mode 100644 mapanything/models/external/pixio/layers/layerscale.py create mode 100644 mapanything/models/external/pixio/layers/mlp.py create mode 100644 mapanything/models/external/pixio/layers/patch_embed.py create mode 100644 mapanything/models/external/pixio/pixio.py diff --git a/bash_scripts/benchmark/dense_2_view/pixio.sh b/bash_scripts/benchmark/dense_2_view/pixio.sh new file mode 100644 index 00000000..7afb05d3 --- /dev/null +++ b/bash_scripts/benchmark/dense_2_view/pixio.sh @@ -0,0 +1,28 @@ +#!/bin/bash + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+
+PRETRAINED_CHECKPOINT_PATH=$1
+RESULT_DIR=$2
+
+export HYDRA_FULL_ERROR=1
+
+echo "Running with task=images_only"
+
+python3 \
+    benchmarking/dense_n_view/benchmark.py \
+    machine=aws \
+    dataset=benchmark_512_eth3d_snpp_tav2 \
+    dataset.num_workers=12 \
+    dataset.num_views=2 \
+    batch_size=10 \
+    model=pixio \
+    model/task=images_only \
+    model.pretrained=${PRETRAINED_CHECKPOINT_PATH} \
+    hydra.run.dir="${RESULT_DIR}/pixio_2v_images_only"
+
+echo "Finished running with task=images_only"
+
diff --git a/configs/model/encoder/pixio.yaml b/configs/model/encoder/pixio.yaml
new file mode 100644
index 00000000..87bddf4f
--- /dev/null
+++ b/configs/model/encoder/pixio.yaml
@@ -0,0 +1,10 @@
+# External encoder module
+module: "mapanything.models.external.pixio.pixio"
+# Class name
+class_name: "PixioEncoder"
+# Hugging Face model name
+hf_model_name: "facebook/pixio-vith16"
+# Flag to indicate whether to use gradient checkpointing for encoder
+gradient_checkpointing: True
+# Flag to indicate whether model class uses torch hub
+uses_torch_hub: False
diff --git a/configs/model/pixio.yaml b/configs/model/pixio.yaml
new file mode 100644
index 00000000..1ce7fe6c
--- /dev/null
+++ b/configs/model/pixio.yaml
@@ -0,0 +1,18 @@
+defaults:
+  - default
+  - encoder: pixio
+  - info_sharing: aat_ifr_24_layers
+  - pred_head: dpt_pose_scale
+  - task: images_only
+
+# String for model factory
+model_str: "pixio"
+# Model config
+model_config:
+  name: "pixio"
+  encoder_config: ${model.encoder}
+  info_sharing_config: ${model.info_sharing}
+  pred_head_config: ${model.pred_head}
+  geometric_input_config: ${model.task}
+# Image Normalization Type
+data_norm_type: "dinov2"
diff --git a/mapanything/models/__init__.py b/mapanything/models/__init__.py
index 527707ca..8d6dc014 100644
--- a/mapanything/models/__init__.py
+++ b/mapanything/models/__init__.py
@@ -198,6 +198,9 @@ def init_model_from_config(
         "class_name": "VGGTWrapper",
     },
     # Add other model classes here
+    "pixio": {
+        "class": MapAnything,
+    },
 }
diff --git a/mapanything/models/external/__init__.py b/mapanything/models/external/pixio/__init__.py
similarity index 100%
rename from mapanything/models/external/__init__.py
rename to mapanything/models/external/pixio/__init__.py
diff --git a/mapanything/models/external/pixio/layers/__init__.py b/mapanything/models/external/pixio/layers/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/mapanything/models/external/pixio/layers/attention.py b/mapanything/models/external/pixio/layers/attention.py
new file mode 100644
index 00000000..f9023703
--- /dev/null
+++ b/mapanything/models/external/pixio/layers/attention.py
@@ -0,0 +1,177 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# timm: https://github.com/rwightman/pytorch-image-models
+# --------------------------------------------------------
+
+from typing import Optional, Type
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from .drop_path import DropPath
+from .layerscale import LayerScale
+from .mlp import Mlp
+
+
+class SelfAttention(nn.Module):
+    """Standard Multi-head Self Attention module with QKV projection.
+
+    This module implements the standard multi-head attention mechanism used in transformers.
+ It supports both the fused attention implementation (scaled_dot_product_attention) for + efficiency when available, and a manual implementation otherwise. The module includes + options for QK normalization, attention dropout, and projection dropout. + """ + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + scale_norm: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + norm_layer: Optional[Type[nn.Module]] = None, + ) -> None: + """Initialize the Attention module. + + Args: + dim: Input dimension of the token embeddings + num_heads: Number of attention heads + qkv_bias: Whether to use bias in the query, key, value projections + qk_norm: Whether to apply normalization to query and key vectors + proj_bias: Whether to use bias in the output projection + attn_drop: Dropout rate applied to the attention weights + proj_drop: Dropout rate applied after the output projection + norm_layer: Normalization layer constructor for QK normalization if enabled + """ + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + if qk_norm or scale_norm: + assert norm_layer is not None, ( + "norm_layer must be provided if qk_norm or scale_norm is True" + ) + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.norm = norm_layer(dim) if scale_norm else nn.Identity() + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + B, L, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, L, 3, self.num_heads, self.head_dim) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + x = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attn_mask, + dropout_p=self.attn_drop.p if self.training else 0.0, + ) + + x = x.transpose(1, 2).reshape(B, L, C) + x = self.norm(x) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SelfAttentionBlock(nn.Module): + """Transformer block with pre-normalization.""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + qk_norm: bool = False, + scale_attn_norm: bool = False, + scale_mlp_norm: bool = False, + proj_bias: bool = True, + proj_drop: float = 0.0, + attn_drop: float = 0.0, + init_values: Optional[float] = None, + drop_path: float = 0.0, + act_layer: Type[nn.Module] = nn.GELU, + norm_layer: Type[nn.Module] = nn.LayerNorm, + mlp_layer: Type[nn.Module] = Mlp, + ) -> None: + """Initialize Block. + + Args: + dim: Number of input channels. + num_heads: Number of attention heads. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: If True, add a learnable bias to query, key, value. + qk_norm: If True, apply normalization to query and key. + proj_bias: If True, add bias to output projection. + proj_drop: Projection dropout rate. + attn_drop: Attention dropout rate. + init_values: Initial values for layer scale. + drop_path: Stochastic depth rate. + act_layer: Activation layer. + norm_layer: Normalization layer. + mlp_layer: MLP layer. 
+ """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = SelfAttention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + scale_norm=scale_attn_norm, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + ) + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = mlp_layer( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + norm_layer=norm_layer if scale_mlp_norm else None, + bias=proj_bias, + drop=proj_drop, + ) + self.ls2 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward( + self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), attn_mask=attn_mask))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x diff --git a/mapanything/models/external/pixio/layers/drop_path.py b/mapanything/models/external/pixio/layers/drop_path.py new file mode 100644 index 00000000..509e5b9b --- /dev/null +++ b/mapanything/models/external/pixio/layers/drop_path.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# timm: https://github.com/rwightman/pytorch-image-models +# -------------------------------------------------------- + +from torch import nn + + +def drop_path( + x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True +): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. 
+ + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f"drop_prob={round(self.drop_prob, 3):0.3f}" diff --git a/mapanything/models/external/pixio/layers/helpers.py b/mapanything/models/external/pixio/layers/helpers.py new file mode 100644 index 00000000..ef6feeeb --- /dev/null +++ b/mapanything/models/external/pixio/layers/helpers.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# timm: https://github.com/rwightman/pytorch-image-models +# -------------------------------------------------------- + +import collections.abc +from itertools import repeat + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + + return parse + + +to_2tuple = _ntuple(2) diff --git a/mapanything/models/external/pixio/layers/layerscale.py b/mapanything/models/external/pixio/layers/layerscale.py new file mode 100644 index 00000000..887887de --- /dev/null +++ b/mapanything/models/external/pixio/layers/layerscale.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# timm: https://github.com/rwightman/pytorch-image-models +# -------------------------------------------------------- + +import torch +from torch import nn + + +class LayerScale(nn.Module): + """Layer scale module. + + References: + - https://arxiv.org/abs/2103.17239 + """ + + def __init__( + self, + dim: int, + init_values: float = 1e-5, + inplace: bool = False, + ) -> None: + """Initialize LayerScale module. + + Args: + dim: Dimension. + init_values: Initial value for scaling. + inplace: If True, perform inplace operations. + """ + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply layer scaling.""" + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/mapanything/models/external/pixio/layers/mlp.py b/mapanything/models/external/pixio/layers/mlp.py new file mode 100644 index 00000000..4a4dc513 --- /dev/null +++ b/mapanything/models/external/pixio/layers/mlp.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# -------------------------------------------------------- +# References: +# timm: https://github.com/rwightman/pytorch-image-models +# -------------------------------------------------------- + +from functools import partial + +from torch import nn + +from .helpers import to_2tuple + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks + + NOTE: When use_conv=True, expects 2D NCHW tensors, otherwise N*C expected. + """ + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = ( + norm_layer(hidden_features) if norm_layer is not None else nn.Identity() + ) + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.norm(x) + x = self.fc2(x) + x = self.drop2(x) + return x diff --git a/mapanything/models/external/pixio/layers/patch_embed.py b/mapanything/models/external/pixio/layers/patch_embed.py new file mode 100644 index 00000000..e4939904 --- /dev/null +++ b/mapanything/models/external/pixio/layers/patch_embed.py @@ -0,0 +1,53 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# -------------------------------------------------------- +# References: +# timm: https://github.com/rwightman/pytorch-image-models +# -------------------------------------------------------- + +from typing import Callable, Optional, Tuple, Union + +from torch import nn + +from .helpers import to_2tuple + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding""" + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + bias: bool = True, + ): + super().__init__() + self.patch_size = to_2tuple(patch_size) + self.img_size, self.grid_size, self.num_patches = self._init_img_size(img_size) + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias + ) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def _init_img_size(self, img_size: Union[int, Tuple[int, int]]): + assert self.patch_size + if img_size is None: + return None, None, None + img_size = to_2tuple(img_size) + grid_size = tuple([s // p for s, p in zip(img_size, self.patch_size)]) + num_patches = grid_size[0] * grid_size[1] + return img_size, grid_size, num_patches + + def forward(self, x): + x = self.proj(x) + x = x.flatten(2).transpose(1, 2) + + x = self.norm(x) + return x diff --git a/mapanything/models/external/pixio/pixio.py b/mapanything/models/external/pixio/pixio.py new file mode 100644 index 00000000..06cfecee --- /dev/null +++ b/mapanything/models/external/pixio/pixio.py @@ -0,0 +1,272 @@ +from collections import namedtuple +from functools import partial +from typing import Callable, Type, Union + +import torch +import torch.nn as nn +from huggingface_hub import hf_hub_download, list_repo_files +from torch.utils.checkpoint import checkpoint + +from .layers.attention import SelfAttentionBlock +from .layers.mlp import Mlp +from .layers.patch_embed import PatchEmbed + +torch.backends.cuda.enable_flash_sdp(True) +torch.backends.cuda.enable_mem_efficient_sdp(True) +torch.backends.cuda.enable_math_sdp(False) + + +class PixioEncoder(nn.Module): + def __init__( + self, + img_size: int = 256, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 1280, + depth: int = 32, + num_heads: int = 16, + mlp_ratio: float = 4.0, + n_cls_tokens: int = 8, + norm_layer: Union[Type[nn.Module], Callable[..., nn.Module]] = partial( + nn.LayerNorm, eps=1e-6 + ), + hf_model_name: str = "facebook/pixio-vith16", + gradient_checkpointing: bool = True, + ): + """ + Pixio ViT Encoder. 
+ """ + super().__init__() + + self.n_cls_tokens = n_cls_tokens + + self.patch_size = patch_size + + self.enc_embed_dim = embed_dim + + self.patch_embed = PatchEmbed( + img_size, patch_size, in_chans, self.enc_embed_dim + ) + + self.cls_token = nn.Parameter(torch.zeros(1, n_cls_tokens, self.enc_embed_dim)) + + self.pos_embed = nn.Parameter( + torch.zeros( + 1, self.patch_embed.num_patches + n_cls_tokens, self.enc_embed_dim + ) + ) + + self.blocks = nn.ModuleList( + [ + SelfAttentionBlock( + self.enc_embed_dim, + num_heads, + mlp_ratio, + qkv_bias=True, + norm_layer=norm_layer, + mlp_layer=Mlp, + ) + for _ in range(depth) + ] + ) + + self.norm = norm_layer(self.enc_embed_dim) + + ckpt_path = self.get_pth_file(repo_id=hf_model_name) + print(f"Loading pretrained Pixio Encoder from {ckpt_path} ...") + ckpt = torch.load(ckpt_path, weights_only=False) + print(self.load_state_dict(ckpt, strict=False)) + + if gradient_checkpointing: + for i in range(len(self.blocks)): + self.blocks[i] = self.wrap_module_with_gradient_checkpointing( + self.blocks[i] + ) + + def wrap_module_with_gradient_checkpointing(self, module: nn.Module): + class _CheckpointingWrapper(nn.Module): + def __init__(self, inner): + super().__init__() + self.inner = inner + + def forward(self, *args, **kwargs): + return checkpoint( + self.inner.forward, *args, use_reentrant=False, **kwargs + ) + + return _CheckpointingWrapper(module) + + def _interpolate_pos_emb(self, x): + """ + Interpolate the positional embeddings to match the input x. + """ + assert x.shape[-2] % self.patch_embed.patch_size[0] == 0, ( + f"height {x.shape[-2]} must be divisible by patch size {self.patch_embed.patch_size[0]}" + ) + assert x.shape[-1] % self.patch_embed.patch_size[1] == 0, ( + f"width {x.shape[-1]} must be divisible by patch size {self.patch_embed.patch_size[1]}" + ) + + H = x.shape[-2] // self.patch_embed.patch_size[0] + W = x.shape[-1] // self.patch_embed.patch_size[1] + + cls_pos_embed = self.pos_embed[:, : self.n_cls_tokens] + patch_pos_embed = self.pos_embed[:, self.n_cls_tokens :] + + pt_size = int(patch_pos_embed.shape[1] ** 0.5) + + if pt_size == H == W: + return self.pos_embed + + patch_pos_embed = patch_pos_embed.reshape(1, pt_size, pt_size, -1).permute( + 0, 3, 1, 2 + ) + patch_pos_embed = torch.nn.functional.interpolate( + patch_pos_embed, size=(H, W), mode="bicubic", align_corners=False + ) + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, H * W, -1) + + new_pos_embed = torch.cat((cls_pos_embed, patch_pos_embed), dim=1) + + return new_pos_embed + + def forward(self, encoder_input): + assert isinstance(encoder_input.image, torch.Tensor), ( + "Input must be a torch.Tensor" + ) + assert encoder_input.image.ndim == 4, "Input must be of shape (B, C, H, W)" + _, channels, height, width = encoder_input.image.shape + assert channels == 3, "Input must have 3 channels" + assert height % self.patch_size == 0 and width % self.patch_size == 0, ( + f"Input shape must be divisible by patch size: {self.patch_size}" + ) + + pos_embed = self._interpolate_pos_emb(encoder_input.image) + + x = self.patch_embed(encoder_input.image) + + x = x + pos_embed[:, self.n_cls_tokens :, :] + + cls_token = self.cls_token + pos_embed[:, : self.n_cls_tokens, :] + cls_tokens = cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + layers = list(range(len(self.blocks))) + for i, blk in enumerate(self.blocks): + x = blk(x) + + if i in layers: + x_norm = self.norm(x) + features = x_norm[:, self.n_cls_tokens :] + + features = 
features.permute(0, 2, 1) + + features = features.reshape( + -1, self.enc_embed_dim, height // self.patch_size, width // self.patch_size + ).contiguous() + + return namedtuple("res", ["features"])(features=features) + + def get_pth_file(self, repo_id: str) -> str: + files = list_repo_files(repo_id) + pth_files = [f for f in files if f.endswith(".pth")] + if not pth_files: + raise FileNotFoundError(f"No .pth file found in {repo_id}") + if len(pth_files) > 1: + raise ValueError(f"Multiple .pth files found: {pth_files}") + return hf_hub_download(repo_id=repo_id, filename=pth_files[0]) + + +def pixio_vitb16(pretrained=None): + model = PixioEncoder( + img_size=256, + patch_size=16, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + n_cls_tokens=8, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + ) + + if pretrained: + state_dict = torch.load(pretrained, map_location="cpu", weights_only=False) + model.load_state_dict(state_dict) + + return model + + +def pixio_vitl16(pretrained=None): + model = PixioEncoder( + img_size=256, + patch_size=16, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + n_cls_tokens=8, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + ) + + if pretrained: + state_dict = torch.load(pretrained, map_location="cpu", weights_only=False) + model.load_state_dict(state_dict) + + return model + + +def pixio_vith16(pretrained=None): + model = PixioEncoder( + img_size=256, + patch_size=16, + embed_dim=1280, + depth=32, + num_heads=16, + mlp_ratio=4, + n_cls_tokens=8, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + ) + + if pretrained: + state_dict = torch.load(pretrained, map_location="cpu", weights_only=False) + model.load_state_dict(state_dict) + + return model + + +def pixio_vit1b16(pretrained=None): + model = PixioEncoder( + img_size=256, + patch_size=16, + embed_dim=1536, + depth=48, + num_heads=24, + mlp_ratio=4, + n_cls_tokens=8, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + ) + + if pretrained: + state_dict = torch.load(pretrained, map_location="cpu", weights_only=False) + model.load_state_dict(state_dict) + + return model + + +def pixio_vit5b16(pretrained=None): + model = PixioEncoder( + img_size=256, + patch_size=16, + embed_dim=3072, + depth=48, + num_heads=32, + mlp_ratio=4, + n_cls_tokens=8, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + ) + + if pretrained: + state_dict = torch.load(pretrained, map_location="cpu", weights_only=False) + model.load_state_dict(state_dict) + + return model diff --git a/mapanything/models/mapanything/model.py b/mapanything/models/mapanything/model.py index 644f02b7..b941da6a 100644 --- a/mapanything/models/mapanything/model.py +++ b/mapanything/models/mapanything/model.py @@ -7,6 +7,7 @@ MapAnything model class defined using UniCeption modules. 
""" +import importlib import warnings from functools import partial from typing import Any, Callable, Dict, List, Tuple, Type, Union @@ -163,7 +164,7 @@ def __init__( # Create a copy of the config before deleting the key to preserve it for serialization encoder_config_copy = self.encoder_config.copy() del encoder_config_copy["uses_torch_hub"] - self.encoder = encoder_factory(**encoder_config_copy) + self.encoder = self._initialize_image_encoder(encoder_config_copy) # Initialize the encoder for ray directions ray_dirs_encoder_config = self.geometric_input_config["ray_dirs_encoder_config"] @@ -240,6 +241,19 @@ def device(self) -> torch.device: def dtype(self) -> torch.dtype: return next(self.parameters()).dtype + def _initialize_image_encoder(self, encoder_config): + """Initialize image encoder using importlib if module/class_name provided, else use uniception.""" + if "module" in encoder_config and "class_name" in encoder_config: + module = importlib.import_module(encoder_config["module"]) + encoder_class = getattr(module, encoder_config["class_name"]) + config_copy = { + k: v + for k, v in encoder_config.items() + if k not in ("module", "class_name") + } + return encoder_class(**config_copy) + return encoder_factory(**encoder_config) + def _initialize_info_sharing(self, info_sharing_config): """ Initialize the information sharing module based on the configuration. diff --git a/scripts/demo_local_weight.py b/scripts/demo_local_weight.py index f774586f..fb96cfb6 100644 --- a/scripts/demo_local_weight.py +++ b/scripts/demo_local_weight.py @@ -97,7 +97,14 @@ def main() -> None: print("Successfully loaded pretrained weights") print(f"Loading images from: {args.image_folder}") - views = load_images(args.image_folder) + if "patch_size" in args.local_config and "resolution" in args.local_config: + views = load_images( + args.image_folder, + patch_size=args.local_config["patch_size"], + resolution_set=args.local_config["resolution"], + ) + else: + views = load_images(args.image_folder) if len(views) == 0: raise ValueError(f"No images found in {args.image_folder}") print(f"Loaded {len(views)} views")