Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

render rollout merge #87

Merged
merged 26 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
18f3b7c
render rollout merge
Naveen-Raj-M Sep 9, 2024
e22406a
Update config.yaml
Naveen-Raj-M Sep 24, 2024
23bfe0c
Merge branch 'v2' of https://github.com/geoelements/gns into v2
Naveen-Raj-M Oct 1, 2024
8d5d7b4
Added unit test for render-rollout merge
Naveen-Raj-M Oct 1, 2024
2f9ab01
Merge branch 'v2' of https://github.com/Naveen-Raj-M/gns into v2
Naveen-Raj-M Oct 1, 2024
5146e57
deleted debugging fixtures
Naveen-Raj-M Oct 2, 2024
a115cb0
update config
Naveen-Raj-M Oct 4, 2024
2bcf615
add test for VTK rendering
Naveen-Raj-M Oct 4, 2024
cdf6f45
bug fix for NoneType material property
Naveen-Raj-M Oct 4, 2024
f9a2a9e
remove test_rendering and temp directory
Naveen-Raj-M Oct 4, 2024
57fad31
add test for vtk rendering
Naveen-Raj-M Oct 5, 2024
b387599
modify config for render_rollout merge
Naveen-Raj-M Oct 5, 2024
db3aec3
update to merge render-rollout
Naveen-Raj-M Oct 5, 2024
74b1e43
set default mode to gif
Naveen-Raj-M Oct 5, 2024
e08fcf5
rewrite 'rendering' function in an extensible way
Naveen-Raj-M Oct 10, 2024
3b88f65
improve readability and consistency
Naveen-Raj-M Oct 10, 2024
9fbafbb
update rendering options
Naveen-Raj-M Oct 10, 2024
81bc3f0
run black
Oct 11, 2024
f82c674
minor fix on viewpoint_rotation type
Oct 12, 2024
a861ec4
improve logging and reformat with black
Oct 12, 2024
7110352
refactor: move n_files function to a separate count_n_files.py in uti…
Oct 12, 2024
a86e58f
rename count_n_files.py to file_utils.py
Naveen-Raj-M Oct 13, 2024
b72770a
minor fix on module import
Naveen-Raj-M Oct 14, 2024
94c8519
run black
Naveen-Raj-M Oct 20, 2024
037b020
minor fix on raising error
Naveen-Raj-M Oct 22, 2024
ef41a87
add package for reading vtk files
Naveen-Raj-M Oct 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,10 @@ hardware:
# Logging configuration
logging:
tensorboard_dir: logs/

# Rendering configuration
rendering:
render: True
Naveen-Raj-M marked this conversation as resolved.
Show resolved Hide resolved
mode: gif
kks32 marked this conversation as resolved.
Show resolved Hide resolved
step_stride: 3
change_yz: False
Naveen-Raj-M marked this conversation as resolved.
Show resolved Hide resolved
16 changes: 16 additions & 0 deletions gns/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from gns import reading_utils
from gns import particle_data_loader as pdl
from gns import distribute
from gns import render_rollout
from gns.args import Config

Stats = collections.namedtuple("Stats", ["mean", "std"])
Expand Down Expand Up @@ -200,9 +201,24 @@ def predict(device: str, cfg: DictConfig):
example_rollout["metadata"] = metadata
example_rollout["loss"] = loss.mean()
filename = f"{cfg.output.filename}_ex{example_i}.pkl"
filename_render = f"{cfg.output.filename}_ex{example_i}"
filename = os.path.join(cfg.output.path, filename)
with open(filename, "wb") as f:
pickle.dump(example_rollout, f)
if cfg.rendering.render:
render = render_rollout.Render(
input_dir=cfg.output.path, input_name=filename_render
)
if cfg.rendering.mode == "gif":
render.render_gif_animation(
point_size=1,
timestep_stride=cfg.rendering.step_stride,
vertical_camera_angle=20,
viewpoint_rotation=0.3,
change_yz=cfg.rendering.change_yz,
)
elif cfg.rendering.mode == "vtk":
render.write_vtk()

print(
"Mean loss on rollout prediction: {}".format(torch.mean(torch.cat(eval_loss)))
Expand Down