diff --git a/nerfies/configs.py b/nerfies/configs.py index f349804..188c90e 100644 --- a/nerfies/configs.py +++ b/nerfies/configs.py @@ -16,7 +16,7 @@ from typing import Any, Mapping, Optional, Tuple import dataclasses -from flax import nn +from flax import linen as nn import gin import immutabledict diff --git a/nerfies/evaluation.py b/nerfies/evaluation.py index eed8a0c..a86ab68 100644 --- a/nerfies/evaluation.py +++ b/nerfies/evaluation.py @@ -92,7 +92,7 @@ def render_image( ret_map = jax_utils.unreplicate(model_out[ret_key]) ret_map = jax.tree_map(lambda x: utils.unshard(x, padding), ret_map) ret_maps.append(ret_map) - ret_map = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *ret_maps) + ret_map = jax.tree_map(lambda *x: jnp.concatenate(x, axis=0), *ret_maps) logging.info('Rendering took %.04s', time.time() - start_time) out = {} for key, value in ret_map.items(): diff --git a/nerfies/model_utils.py b/nerfies/model_utils.py index 48a3b64..780eb0f 100644 --- a/nerfies/model_utils.py +++ b/nerfies/model_utils.py @@ -104,7 +104,7 @@ def volumetric_rendering(rgb, last_sample_z = 1e10 if sample_at_infinity else 1e-19 dists = jnp.concatenate([ z_vals[..., 1:] - z_vals[..., :-1], - jnp.broadcast_to([last_sample_z], z_vals[..., :1].shape) + jnp.broadcast_to(jnp.array([last_sample_z]), z_vals[..., :1].shape) ], -1) dists = dists * jnp.linalg.norm(dirs[..., None, :], axis=-1) alpha = 1.0 - jnp.exp(-sigma * dists) diff --git a/nerfies/schedules.py b/nerfies/schedules.py index e73da25..1e5bc79 100644 --- a/nerfies/schedules.py +++ b/nerfies/schedules.py @@ -15,6 +15,7 @@ """Annealing Schedules.""" import abc import collections +from collections.abc import Mapping import copy import math from typing import Any, Iterable, List, Tuple, Union @@ -38,7 +39,7 @@ def from_config(schedule): return schedule if isinstance(schedule, Tuple) or isinstance(schedule, List): return from_tuple(schedule) - if isinstance(schedule, collections.Mapping): + if isinstance(schedule, Mapping): return from_dict(schedule) raise ValueError(f'Unknown type {type(schedule)}.') diff --git a/nerfies/utils.py b/nerfies/utils.py index b4543f5..d4beefe 100644 --- a/nerfies/utils.py +++ b/nerfies/utils.py @@ -377,7 +377,7 @@ def strided_subset(sequence, count): def tree_collate(list_of_pytrees): """Collates a list of pytrees with the same structure.""" - return tree_util.tree_multimap(lambda *x: np.stack(x), *list_of_pytrees) + return tree_util.tree_map(lambda *x: np.stack(x), *list_of_pytrees) @contextlib.contextmanager