28 changes: 28 additions & 0 deletions bash_scripts/benchmark/dense_2_view/pixio.sh
@@ -0,0 +1,28 @@
#!/bin/bash

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

PRETRAINED_CHECKPOINT_PATH=$1
RESULT_DIR=$2

export HYDRA_FULL_ERROR=1

echo "Running with task=images_only"

python3 \
benchmarking/dense_n_view/benchmark.py \
machine=aws \
dataset=benchmark_512_eth3d_snpp_tav2 \
dataset.num_workers=12 \
dataset.num_views=2 \
batch_size=10 \
model=pixio \
model/task=images_only \
model.pretrained=${PRETRAINED_CHECKPOINT_PATH} \
hydra.run.dir="${RESULT_DIR}/mapa_24v_images_only"

echo "Finished running with task=$task"

10 changes: 10 additions & 0 deletions configs/model/encoder/pixio.yaml
@@ -0,0 +1,10 @@
# External encoder module
module: "mapanything.models.external.pixio.pixio"
# Class name
class_name: "PixioEncoder"
# huggingface model name
hf_model_name: "facebook/pixio-vith16"
# Flag to indicate whether to use gradient checkpointing for encoder
gradient_checkpointing: True
# Flag to indicate whether model class uses torch hub
uses_torch_hub: False
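
For context, a config like this can be resolved into an encoder class by importing the module at its dotted path and looking the class up by name. The helper below is only an illustrative sketch of that pattern, not the repository's actual factory code:

import importlib

def resolve_encoder_class(module_path: str, class_name: str):
    # Values mirror the YAML above, e.g. module_path="mapanything.models.external.pixio.pixio"
    # and class_name="PixioEncoder"; the remaining fields (hf_model_name, etc.) are
    # presumably forwarded to the encoder's constructor.
    module = importlib.import_module(module_path)
    return getattr(module, class_name)
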
18 changes: 18 additions & 0 deletions configs/model/pixio.yaml
@@ -0,0 +1,18 @@
defaults:
- default
- encoder: pixio
- info_sharing: aat_ifr_24_layers
- pred_head: dpt_pose_scale
- task: images_only

# String for model factory
model_str: "pixio"
# Model config
model_config:
name: "pixio"
encoder_config: ${model.encoder}
info_sharing_config: ${model.info_sharing}
pred_head_config: ${model.pred_head}
geometric_input_config: ${model.task}
# Image Normalization Type
data_norm_type: "dinov2"
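
As a rough illustration (not part of the diff), the raw YAML can be inspected with OmegaConf; note that the Hydra defaults list and the ${model.*} interpolations only resolve once this file is composed inside the full benchmark/training config tree:

from omegaconf import OmegaConf

# Load the file standalone and read the fields that need no interpolation.
cfg = OmegaConf.load("configs/model/pixio.yaml")
print(cfg.model_str)                    # "pixio"
print(cfg.model_config.data_norm_type)  # "dinov2"
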
3 changes: 3 additions & 0 deletions mapanything/models/__init__.py
@@ -198,6 +198,9 @@ def init_model_from_config(
"class_name": "VGGTWrapper",
},
# Add other model classes here
"pixio": {
"class": MapAnything,
},
}


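The new entry follows the pattern of its neighbours: the model string maps to a dict holding the class to instantiate. A self-contained sketch of that lookup is below; the registry's real variable name inside init_model_from_config is not visible in this diff, so model_registry is a placeholder:

from mapanything.models import MapAnything

# Placeholder registry mirroring the mapping added above.
model_registry = {"pixio": {"class": MapAnything}}
model_class = model_registry["pixio"]["class"]  # -> MapAnything
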
Empty file.
177 changes: 177 additions & 0 deletions mapanything/models/external/pixio/layers/attention.py
@@ -0,0 +1,177 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models
# --------------------------------------------------------

from typing import Optional, Type

import torch
import torch.nn.functional as F
from torch import nn

from .drop_path import DropPath
from .layerscale import LayerScale
from .mlp import Mlp


class SelfAttention(nn.Module):
"""Standard Multi-head Self Attention module with QKV projection.

This module implements the standard multi-head attention mechanism used in transformers.
It relies on the fused scaled_dot_product_attention implementation and includes options
for QK normalization, output normalization, attention dropout, and projection dropout.
"""

def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
qk_norm: bool = False,
scale_norm: bool = False,
proj_bias: bool = True,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
norm_layer: Optional[Type[nn.Module]] = None,
) -> None:
"""Initialize the Attention module.

Args:
dim: Input dimension of the token embeddings
num_heads: Number of attention heads
qkv_bias: Whether to use bias in the query, key, value projections
qk_norm: Whether to apply normalization to query and key vectors
scale_norm: Whether to apply normalization to the attention output before the output projection
proj_bias: Whether to use bias in the output projection
attn_drop: Dropout rate applied to the attention weights
proj_drop: Dropout rate applied after the output projection
norm_layer: Normalization layer constructor for QK normalization if enabled
"""
super().__init__()
assert dim % num_heads == 0, "dim should be divisible by num_heads"
if qk_norm or scale_norm:
assert norm_layer is not None, (
"norm_layer must be provided if qk_norm or scale_norm is True"
)
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim**-0.5

self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.norm = norm_layer(dim) if scale_norm else nn.Identity()
self.proj = nn.Linear(dim, dim, bias=proj_bias)
self.proj_drop = nn.Dropout(proj_drop)

def forward(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
B, L, C = x.shape
qkv = (
self.qkv(x)
.reshape(B, L, 3, self.num_heads, self.head_dim)
.permute(2, 0, 3, 1, 4)
)
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)

x = F.scaled_dot_product_attention(
q,
k,
v,
attn_mask=attn_mask,
dropout_p=self.attn_drop.p if self.training else 0.0,
)

x = x.transpose(1, 2).reshape(B, L, C)
x = self.norm(x)
x = self.proj(x)
x = self.proj_drop(x)
return x


class SelfAttentionBlock(nn.Module):
"""Transformer block with pre-normalization."""

def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = False,
qk_norm: bool = False,
scale_attn_norm: bool = False,
scale_mlp_norm: bool = False,
proj_bias: bool = True,
proj_drop: float = 0.0,
attn_drop: float = 0.0,
init_values: Optional[float] = None,
drop_path: float = 0.0,
act_layer: Type[nn.Module] = nn.GELU,
norm_layer: Type[nn.Module] = nn.LayerNorm,
mlp_layer: Type[nn.Module] = Mlp,
) -> None:
"""Initialize Block.

Args:
dim: Number of input channels.
num_heads: Number of attention heads.
mlp_ratio: Ratio of mlp hidden dim to embedding dim.
qkv_bias: If True, add a learnable bias to query, key, value.
qk_norm: If True, apply normalization to query and key.
scale_attn_norm: If True, apply normalization to the attention output.
scale_mlp_norm: If True, apply normalization inside the MLP.
proj_bias: If True, add bias to output projection.
proj_drop: Projection dropout rate.
attn_drop: Attention dropout rate.
init_values: Initial values for layer scale.
drop_path: Stochastic depth rate.
act_layer: Activation layer.
norm_layer: Normalization layer.
mlp_layer: MLP layer.
"""
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = SelfAttention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
scale_norm=scale_attn_norm,
proj_bias=proj_bias,
attn_drop=attn_drop,
proj_drop=proj_drop,
norm_layer=norm_layer,
)
self.ls1 = (
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
)
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

self.norm2 = norm_layer(dim)
self.mlp = mlp_layer(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
norm_layer=norm_layer if scale_mlp_norm else None,
bias=proj_bias,
drop=proj_drop,
)
self.ls2 = (
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
)
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

def forward(
self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), attn_mask=attn_mask)))
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
return x
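
To make the tensor shapes concrete, here is a short usage sketch (not part of the PR; the import path assumes the module layout introduced above):

import torch

from mapanything.models.external.pixio.layers.attention import SelfAttentionBlock

# One ViT-style pre-norm block: 768-dim tokens, 12 heads, biased QKV projection.
block = SelfAttentionBlock(dim=768, num_heads=12, qkv_bias=True)
tokens = torch.randn(2, 196, 768)  # (batch, sequence length, embedding dim)
out = block(tokens)                # residual output, same shape (2, 196, 768)
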
50 changes: 50 additions & 0 deletions mapanything/models/external/pixio/layers/drop_path.py
@@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models
# --------------------------------------------------------

from torch import nn


def drop_path(
x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.

"""
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (
x.ndim - 1
) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and scale_by_keep:
random_tensor.div_(keep_prob)
return x * random_tensor


class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep

def forward(self, x):
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

def extra_repr(self):
return f"drop_prob={round(self.drop_prob, 3):0.3f}"
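
As a quick behavioural note (illustrative, not part of the PR): in training mode each sample's residual branch is dropped with probability drop_prob and the survivors are rescaled by 1/keep_prob so the expected activation is unchanged; in eval mode the layer is the identity.

import torch

from mapanything.models.external.pixio.layers.drop_path import DropPath

layer = DropPath(drop_prob=0.25)
x = torch.ones(8, 4, 16)

layer.train()
y = layer(x)  # per sample: all zeros, or all values scaled to 1 / 0.75

layer.eval()
z = layer(x)  # identity: returned unchanged
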
25 changes: 25 additions & 0 deletions mapanything/models/external/pixio/layers/helpers.py
@@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models
# --------------------------------------------------------

import collections.abc
from itertools import repeat


# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return tuple(x)
return tuple(repeat(x, n))

return parse


to_2tuple = _ntuple(2)
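
For reference (not part of the PR), the helper mirrors timm's behaviour: a scalar is repeated n times, while an existing iterable is passed through as a tuple.

from mapanything.models.external.pixio.layers.helpers import to_2tuple

to_2tuple(16)        # -> (16, 16), e.g. a square patch size
to_2tuple((14, 16))  # -> (14, 16), iterables are kept as given
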
41 changes: 41 additions & 0 deletions mapanything/models/external/pixio/layers/layerscale.py
@@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models
# --------------------------------------------------------

import torch
from torch import nn


class LayerScale(nn.Module):
"""Layer scale module.

References:
- https://arxiv.org/abs/2103.17239
"""

def __init__(
self,
dim: int,
init_values: float = 1e-5,
inplace: bool = False,
) -> None:
"""Initialize LayerScale module.

Args:
dim: Dimension.
init_values: Initial value for scaling.
inplace: If True, perform inplace operations.
"""
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply layer scaling."""
return x.mul_(self.gamma) if self.inplace else x * self.gamma
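
A short usage sketch (not part of the PR): gamma is a learned per-channel parameter initialised to init_values, so the module initially scales its input down by that factor.

import torch

from mapanything.models.external.pixio.layers.layerscale import LayerScale

ls = LayerScale(dim=768, init_values=1e-5)
x = torch.randn(2, 196, 768)
y = ls(x)  # elementwise x * gamma, gamma.shape == (768,) broadcast over batch and tokens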