Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

render rollout merge #87

Merged
merged 26 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
18f3b7c
render rollout merge
Naveen-Raj-M Sep 9, 2024
e22406a
Update config.yaml
Naveen-Raj-M Sep 24, 2024
23bfe0c
Merge branch 'v2' of https://github.com/geoelements/gns into v2
Naveen-Raj-M Oct 1, 2024
8d5d7b4
Added unit test for render-rollout merge
Naveen-Raj-M Oct 1, 2024
2f9ab01
Merge branch 'v2' of https://github.com/Naveen-Raj-M/gns into v2
Naveen-Raj-M Oct 1, 2024
5146e57
deleted debugging fixtures
Naveen-Raj-M Oct 2, 2024
a115cb0
update config
Naveen-Raj-M Oct 4, 2024
2bcf615
add test for VTK rendering
Naveen-Raj-M Oct 4, 2024
cdf6f45
bug fix for NoneType material property
Naveen-Raj-M Oct 4, 2024
f9a2a9e
remove test_rendering and temp directory
Naveen-Raj-M Oct 4, 2024
57fad31
add test for vtk rendering
Naveen-Raj-M Oct 5, 2024
b387599
modify config for render_rollout merge
Naveen-Raj-M Oct 5, 2024
db3aec3
update to merge render-rollout
Naveen-Raj-M Oct 5, 2024
74b1e43
set default mode to gif
Naveen-Raj-M Oct 5, 2024
e08fcf5
rewrite 'rendering' function in an extensible way
Naveen-Raj-M Oct 10, 2024
3b88f65
improve readability and consistency
Naveen-Raj-M Oct 10, 2024
9fbafbb
update rendering options
Naveen-Raj-M Oct 10, 2024
81bc3f0
run black
Oct 11, 2024
f82c674
minor fix on viewpoint_rotation type
Oct 12, 2024
a861ec4
improve logging and reformat with black
Oct 12, 2024
7110352
refactor: move n_files function to a separate count_n_files.py in uti…
Oct 12, 2024
a86e58f
rename count_n_files.py to file_utils.py
Naveen-Raj-M Oct 13, 2024
b72770a
minor fix on module import
Naveen-Raj-M Oct 14, 2024
94c8519
run black
Naveen-Raj-M Oct 20, 2024
037b020
minor fix on raising error
Naveen-Raj-M Oct 22, 2024
ef41a87
add package for reading vtk files
Naveen-Raj-M Oct 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,18 @@ python3 -m gns.train mode="train" training.resume=True
python3 -m meshnet.train mode="train" training.resume=True
```

> Rollout prediction
> Rollout prediction and render
```shell
# For particulate domain,
python3 -m gns.train mode="rollout"
# For mesh-based domain,
python3 -m meshnet.train mode="rollout"
```
To choose not to render after rollout prediction, add option `rendering.mode=null`.

By default the renderer writes `.gif` file.

In particulate domain, the renderer also writes `.vtu` files to visualize in ParaView.
Naveen-Raj-M marked this conversation as resolved.
Show resolved Hide resolved
Naveen-Raj-M marked this conversation as resolved.
Show resolved Hide resolved

> Render
```shell
Expand All @@ -50,8 +55,6 @@ python3 -m gns.render_rollout --output_mode="gif" --rollout_dir="<path-containin
python3 -m gns.render --rollout_dir="<path-containing-rollout-file>" --rollout_name="<name-of-rollout-file>"
```

In particulate domain, the renderer also writes `.vtu` files to visualize in ParaView.

![Sand rollout](docs/img/rollout_0.gif)
> GNS prediction of Sand rollout after training for 2 million steps.

Expand Down Expand Up @@ -118,8 +121,16 @@ hardware:
# Logging configuration
logging:
tensorboard_dir: logs/
```

# Rendering configuration
rendering:
mode: gif

gif:
step_stride: 3
vertical_camera_angle: 20
viewpoint_rotation: 0.3
change_yz: False
</details>


Expand Down Expand Up @@ -180,7 +191,6 @@ The total number of training steps to execute before stopping.
**nsave_steps (Integer)**

Interval at which the model and training state are saved.

</details>

## Datasets
Expand Down
10 changes: 10 additions & 0 deletions config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,13 @@ hardware:
# Logging configuration
logging:
tensorboard_dir: logs/

# Rendering configuration
rendering:
mode: gif
kks32 marked this conversation as resolved.
Show resolved Hide resolved

gif:
step_stride: 3
vertical_camera_angle: 20
viewpoint_rotation: 0.3
change_yz: False
14 changes: 13 additions & 1 deletion gns/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,17 @@ class HardwareConfig:
class LoggingConfig:
tensorboard_dir: str = "logs/"

@dataclass
class GifConfig:
step_stride: int = 3
vertical_camera_angle: int = 20
viewpoint_rotation: int = 0.3
change_yz: bool = False

@dataclass
class RenderingConfig:
mode: Optional[str] = field(default='gif')
gif: GifConfig = field(default_factory=GifConfig)

@dataclass
class Config:
Expand All @@ -62,8 +73,9 @@ class Config:
training: TrainingConfig = field(default_factory=TrainingConfig)
hardware: HardwareConfig = field(default_factory=HardwareConfig)
logging: LoggingConfig = field(default_factory=LoggingConfig)
rendering: RenderingConfig = field(default_factory=RenderingConfig)


# Hydra configuration
cs = ConfigStore.instance()
cs.store(name="base_config", node=Config)
cs.store(name="base_config", node=Config)
kks32 marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 1 addition & 1 deletion gns/render_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def write_vtk(self):
}

# Check if material property exists and add it to data if it does
if "material_property" in self.rollout_data:
if "material_property" in self.rollout_data and self.rollout_data['material_property'] is not None:
material_property = self.rollout_data["material_property"]
data["material_property"] = material_property

Expand Down
21 changes: 19 additions & 2 deletions gns/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from gns import reading_utils
from gns import particle_data_loader as pdl
from gns import distribute
from gns import render_rollout
from gns.args import Config

Stats = collections.namedtuple("Stats", ["mean", "std"])
Expand Down Expand Up @@ -200,14 +201,30 @@ def predict(device: str, cfg: DictConfig):
example_rollout["metadata"] = metadata
example_rollout["loss"] = loss.mean()
filename = f"{cfg.output.filename}_ex{example_i}.pkl"
filename = os.path.join(cfg.output.path, filename)
filename_render = f"{cfg.output.filename}_ex{example_i}"
filename = os.path.join(cfg.output.path, filename_render)
with open(filename, "wb") as f:
pickle.dump(example_rollout, f)
if cfg.rendering.mode:
rendering(cfg.output.path, filename_render, cfg)

print(
"Mean loss on rollout prediction: {}".format(torch.mean(torch.cat(eval_loss)))
)

def rendering(input_dir, input_name, cfg: DictConfig):
kks32 marked this conversation as resolved.
Show resolved Hide resolved
render = render_rollout.Render(input_dir, input_name)
kks32 marked this conversation as resolved.
Show resolved Hide resolved

if cfg.rendering.mode == "gif":
render.render_gif_animation(
point_size=1,
timestep_stride=cfg.rendering.gif.step_stride,
vertical_camera_angle=cfg.rendering.gif.vertical_camera_angle,
viewpoint_rotation=cfg.rendering.gif.viewpoint_rotation,
change_yz=cfg.rendering.gif.change_yz,
)
elif cfg.rendering.mode == "vtk":
render.write_vtk()

def optimizer_to(optim, device):
for param in optim.state.values():
Expand Down Expand Up @@ -850,4 +867,4 @@ def main(cfg: Config):


if __name__ == "__main__":
main()
main()
190 changes: 190 additions & 0 deletions test/test_vtk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
import pytest
Naveen-Raj-M marked this conversation as resolved.
Show resolved Hide resolved
import numpy as np
import pickle
import os
import tempfile
import shutil
import pyvista as pv
from gns.train import rendering
from omegaconf import DictConfig

@pytest.fixture
def cfg_vtk():
"""
Fixture for VTK configuration.

Returns:
DictConfig: Configuration dictionary for VTK rendering mode.
"""
return DictConfig({
'rendering': {
'mode': 'vtk'
}
})

@pytest.fixture
def temp_dir():
"""
Fixture for creating and cleaning up a temporary directory.

Yields:
str: Path to the temporary directory.
"""
directory = tempfile.mkdtemp()
yield directory
shutil.rmtree(directory)

@pytest.fixture
def dummy_pkl_data(temp_dir):
"""
kks32 marked this conversation as resolved.
Show resolved Hide resolved
Fixture for generating dummy pickle data for testing.

Args:
temp_dir (str): Path to the temporary directory.

Returns:
tuple: Path to the temporary directory and the pickle file name.
"""
# Define parameters for simulation
n_timesteps = 2
n_particles = 3
dim = 2
n_init_pos = 2

# Generate random predictions and ground truth positions
predictions = np.random.rand(n_timesteps, n_particles, dim)
ground_truth_positions = np.random.randn(n_timesteps, n_particles, dim)
loss = (predictions - ground_truth_positions)**2

# Rollout dictionary to store all relevant information
dummy_rollout = {
"initial_positions": np.random.rand(n_init_pos, n_particles, dim),
"predicted_rollout": predictions,
"ground_truth_rollout": ground_truth_positions,
"particle_types": np.full(n_particles, 5)
}

# Metadata for the simulation
metadata = {
"bounds": [[0.0, 1.0], [0.0, 1.0]]
}

dummy_rollout['metadata'] = metadata
dummy_rollout['loss'] = loss.mean()
pkl_file_name = "test_input_file.pkl"
pkl_file_path = os.path.join(temp_dir, pkl_file_name)
with open(pkl_file_path, "wb") as f:
pickle.dump(dummy_rollout, f)
temp_dir = temp_dir + '/'
pkl_file_name = "test_input_file"

return temp_dir, pkl_file_name
Naveen-Raj-M marked this conversation as resolved.
Show resolved Hide resolved

def n_files(dir, extension):
Naveen-Raj-M marked this conversation as resolved.
Show resolved Hide resolved
"""
Count the number of files with a specific extension in a directory.

Args:
dir (str): Directory path.
extension (str): File extension to count.

Returns:
int: Number of files with the specified extension.
"""
files = os.listdir(dir)
each_file = []

for file in files:
if file.endswith(extension):
each_file.append(file)

return len(each_file)

def test_rendering_vtk(dummy_pkl_data, cfg_vtk):
"""
Test the VTK rendering function.

Args:
dummy_pkl_data (tuple): Tuple containing the path to the temporary directory and the pickle file name.
cfg_vtk (DictConfig): Configuration dictionary for VTK rendering mode.
"""
input_dir, input_file = dummy_pkl_data

rendering(input_dir, input_file, cfg_vtk)

# Define paths for the generated VTK files
vtk_path_gns = os.path.join(input_dir, f"{input_file}_vtk-GNS")
vtk_path_reality = os.path.join(input_dir, f"{input_file}_vtk-Reality")

with open(f"{input_dir}{input_file}.pkl", "rb") as file:
rollout = pickle.load(file)

# Concatenate initial positions and rollout positions
positions_gns = np.concatenate(
[rollout["initial_positions"], rollout["predicted_rollout"]],
axis=0,
)
positions_reality = np.concatenate(
[rollout["initial_positions"], rollout["ground_truth_rollout"]],
axis=0,
)

# Count the number of .vtu and .vtr files in the VTK directories
n_vtu_files_gns = n_files(vtk_path_gns, 'vtu')
n_vtu_files_reality = n_files(vtk_path_reality, 'vtu')
n_vtr_files_gns = n_files(vtk_path_gns, 'vtr')
n_vtr_files_reality = n_files(vtk_path_reality, 'vtr')

# Assert that the number of .vtu and .vtr files matches the expected count
assert n_vtu_files_gns == (positions_gns.shape[0])
assert n_vtu_files_reality == (positions_reality.shape[0])
assert n_vtr_files_gns == (positions_gns.shape[0])
assert n_vtr_files_reality == (positions_reality.shape[0])

# Verify the contents of the generated VTK files at each time step for GNS
for time_step in range(positions_gns.shape[0]):
vtu = os.path.join(vtk_path_gns, f"points{time_step}.vtu")
vtu_object = pv.read(vtu)
displacement = vtu_object['displacement']
particle_type = vtu_object['particle_type']
color_map = vtu_object['color']

assert np.all(displacement == np.linalg.norm(positions_gns[0] - positions_gns[time_step], axis=1))
assert np.all(particle_type == rollout['particle_types'])
assert np.all(color_map == rollout['particle_types'])

vtr = os.path.join(vtk_path_gns, f"boundary{time_step}.vtr")
vtr_object = pv.read(vtr)

bounds = vtr_object.bounds

xmin, xmax, ymin, ymax, zmin, zmax = bounds

assert xmin == rollout['metadata']['bounds'][0][0]
assert xmax == rollout['metadata']['bounds'][0][1]
assert ymin == rollout['metadata']['bounds'][1][0]
assert ymax == rollout['metadata']['bounds'][1][1]

# Verify the contents of the generated VTK files at each time step for reality
for time_step in range(positions_reality.shape[0]):
vtu = os.path.join(vtk_path_reality, f"points{time_step}.vtu")
vtu_object = pv.read(vtu)
displacement = vtu_object['displacement']
particle_type = vtu_object['particle_type']
color_map = vtu_object['color']

assert np.all(displacement == np.linalg.norm(positions_reality[0] - positions_reality[time_step], axis=1))
assert np.all(particle_type == rollout['particle_types'])
assert np.all(color_map == rollout['particle_types'])

vtr = os.path.join(vtk_path_reality, f"boundary{time_step}.vtr")
vtr_object = pv.read(vtr)

bounds = vtr_object.bounds

xmin, xmax, ymin, ymax, zmin, zmax = bounds

assert xmin == rollout['metadata']['bounds'][0][0]
assert xmax == rollout['metadata']['bounds'][0][1]
assert ymin == rollout['metadata']['bounds'][1][0]
assert ymax == rollout['metadata']['bounds'][1][1]