From 461ac85c25ee2292f3ac2724209aec23b1b11211 Mon Sep 17 00:00:00 2001 From: CurtinComputing <44148295+CurtinComputing@users.noreply.github.com> Date: Tue, 16 Aug 2022 22:04:47 +0800 Subject: [PATCH 1/9] Update model_utils.py --- nerfies/model_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nerfies/model_utils.py b/nerfies/model_utils.py index 48a3b64..780eb0f 100644 --- a/nerfies/model_utils.py +++ b/nerfies/model_utils.py @@ -104,7 +104,7 @@ def volumetric_rendering(rgb, last_sample_z = 1e10 if sample_at_infinity else 1e-19 dists = jnp.concatenate([ z_vals[..., 1:] - z_vals[..., :-1], - jnp.broadcast_to([last_sample_z], z_vals[..., :1].shape) + jnp.broadcast_to(jnp.array([last_sample_z]), z_vals[..., :1].shape) ], -1) dists = dists * jnp.linalg.norm(dirs[..., None, :], axis=-1) alpha = 1.0 - jnp.exp(-sigma * dists) From a419d330685d257c54b8e4c019312b551a362ff9 Mon Sep 17 00:00:00 2001 From: CurtinComputing <44148295+CurtinComputing@users.noreply.github.com> Date: Tue, 16 Aug 2022 22:05:57 +0800 Subject: [PATCH 2/9] Update configs.py --- nerfies/configs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nerfies/configs.py b/nerfies/configs.py index f349804..4aacabb 100644 --- a/nerfies/configs.py +++ b/nerfies/configs.py @@ -16,7 +16,7 @@ from typing import Any, Mapping, Optional, Tuple import dataclasses -from flax import nn +from flax import line as nn import gin import immutabledict From d83d5a7fb9f2654e1629c6bfee6b3ab1e019ddfa Mon Sep 17 00:00:00 2001 From: CurtinComputing <44148295+CurtinComputing@users.noreply.github.com> Date: Tue, 16 Aug 2022 22:07:15 +0800 Subject: [PATCH 3/9] Update configs.py --- nerfies/configs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nerfies/configs.py b/nerfies/configs.py index 4aacabb..188c90e 100644 --- a/nerfies/configs.py +++ b/nerfies/configs.py @@ -16,7 +16,7 @@ from typing import Any, Mapping, Optional, Tuple import dataclasses -from flax import line as nn +from flax import linen as nn import gin import immutabledict From 58295dc8cf4b2e8161a83467ddbcc35792a0a7e4 Mon Sep 17 00:00:00 2001 From: CurtinComputing <44148295+CurtinComputing@users.noreply.github.com> Date: Sun, 21 Aug 2022 04:35:13 +0800 Subject: [PATCH 4/9] Update utils.py --- nerfies/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nerfies/utils.py b/nerfies/utils.py index b4543f5..d4beefe 100644 --- a/nerfies/utils.py +++ b/nerfies/utils.py @@ -377,7 +377,7 @@ def strided_subset(sequence, count): def tree_collate(list_of_pytrees): """Collates a list of pytrees with the same structure.""" - return tree_util.tree_multimap(lambda *x: np.stack(x), *list_of_pytrees) + return tree_util.tree_map(lambda *x: np.stack(x), *list_of_pytrees) @contextlib.contextmanager From d7fea4542dbc2ccddb509c8febb61cd820ba5460 Mon Sep 17 00:00:00 2001 From: CurtinComputing <44148295+CurtinComputing@users.noreply.github.com> Date: Tue, 5 Dec 2023 00:01:05 +0800 Subject: [PATCH 5/9] Update schedules.py --- nerfies/schedules.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nerfies/schedules.py b/nerfies/schedules.py index e73da25..a51c43d 100644 --- a/nerfies/schedules.py +++ b/nerfies/schedules.py @@ -15,6 +15,7 @@ """Annealing Schedules.""" import abc import collections +from collections.abc import MutableMapping import copy import math from typing import Any, Iterable, List, Tuple, Union From 7e38ed5efdb2f40c4825737fbc0f05fce0c721bb Mon Sep 17 00:00:00 2001 From: CurtinComputing <44148295+CurtinComputing@users.noreply.github.com> Date: Tue, 5 Dec 2023 00:03:06 +0800 Subject: [PATCH 6/9] Update schedules.py --- nerfies/schedules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nerfies/schedules.py b/nerfies/schedules.py index a51c43d..449fbaa 100644 --- a/nerfies/schedules.py +++ b/nerfies/schedules.py @@ -15,7 +15,7 @@ """Annealing Schedules.""" import abc import collections -from collections.abc import MutableMapping +from collections.abc import Mapping import copy import math from typing import Any, Iterable, List, Tuple, Union From b7ed6e4bb7f59e9b33ac28611c336efa391266a9 Mon Sep 17 00:00:00 2001 From: CurtinComputing <44148295+CurtinComputing@users.noreply.github.com> Date: Tue, 5 Dec 2023 11:55:55 +0800 Subject: [PATCH 7/9] Update schedules.py --- nerfies/schedules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nerfies/schedules.py b/nerfies/schedules.py index 449fbaa..3db7233 100644 --- a/nerfies/schedules.py +++ b/nerfies/schedules.py @@ -39,7 +39,7 @@ def from_config(schedule): return schedule if isinstance(schedule, Tuple) or isinstance(schedule, List): return from_tuple(schedule) - if isinstance(schedule, collections.Mapping): + if isinstance(schedule, collections.abc.Mapping): return from_dict(schedule) raise ValueError(f'Unknown type {type(schedule)}.') From 108b3e335ee83f5d695801647f26dd955ca3ae00 Mon Sep 17 00:00:00 2001 From: CurtinComputing <44148295+CurtinComputing@users.noreply.github.com> Date: Tue, 5 Dec 2023 12:04:44 +0800 Subject: [PATCH 8/9] Update schedules.py --- nerfies/schedules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nerfies/schedules.py b/nerfies/schedules.py index 3db7233..1e5bc79 100644 --- a/nerfies/schedules.py +++ b/nerfies/schedules.py @@ -39,7 +39,7 @@ def from_config(schedule): return schedule if isinstance(schedule, Tuple) or isinstance(schedule, List): return from_tuple(schedule) - if isinstance(schedule, collections.abc.Mapping): + if isinstance(schedule, Mapping): return from_dict(schedule) raise ValueError(f'Unknown type {type(schedule)}.') From c0a5f04aff0826bfc57eeec2cda4230a8dbb7de4 Mon Sep 17 00:00:00 2001 From: CurtinComputing <44148295+CurtinComputing@users.noreply.github.com> Date: Tue, 5 Dec 2023 12:32:50 +0800 Subject: [PATCH 9/9] Update evaluation.py --- nerfies/evaluation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nerfies/evaluation.py b/nerfies/evaluation.py index eed8a0c..a86ab68 100644 --- a/nerfies/evaluation.py +++ b/nerfies/evaluation.py @@ -92,7 +92,7 @@ def render_image( ret_map = jax_utils.unreplicate(model_out[ret_key]) ret_map = jax.tree_map(lambda x: utils.unshard(x, padding), ret_map) ret_maps.append(ret_map) - ret_map = jax.tree_multimap(lambda *x: jnp.concatenate(x, axis=0), *ret_maps) + ret_map = jax.tree_map(lambda *x: jnp.concatenate(x, axis=0), *ret_maps) logging.info('Rendering took %.04s', time.time() - start_time) out = {} for key, value in ret_map.items():