Add checkpoint converter for our own format
relativityhd committed Dec 3, 2024
1 parent ef22911 commit c93a07f
Showing 2 changed files with 55 additions and 1 deletion.
3 changes: 2 additions & 1 deletion darts/src/darts/cli.py
@@ -15,7 +15,7 @@
    run_native_sentinel2_pipeline,
    run_native_sentinel2_pipeline_fast,
)
-from darts.training import preprocess_s2_train_data, train_smp
+from darts.training import convert_lightning_checkpoint, preprocess_s2_train_data, train_smp
from darts.utils.config import ConfigParser
from darts.utils.logging import add_logging_handlers, setup_logging

@@ -71,6 +71,7 @@ def env_info():

app.command(group=train_group)(preprocess_s2_train_data)
app.command(group=train_group)(train_smp)
+app.command(group=train_group)(convert_lightning_checkpoint)


# Custom wrapper for the create_arcticdem_vrt function, which dodges the loading of all the heavy modules
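
For reference, the converter registered above can also be called directly from Python rather than through the CLI. A minimal sketch, with illustrative paths and checkpoint name (not part of this commit):

```python
from pathlib import Path

from darts.training import convert_lightning_checkpoint

# All arguments are keyword-only; paths and name are placeholders.
convert_lightning_checkpoint(
    lightning_checkpoint=Path("lightning_logs/version_0/checkpoints/last.ckpt"),
    out_directory=Path("models"),
    checkpoint_name="smp-sentinel2-v1",
)
```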
53 changes: 53 additions & 0 deletions darts/src/darts/training.py
@@ -2,6 +2,7 @@

import logging
import multiprocessing as mp
+from datetime import datetime
from math import ceil, sqrt
from pathlib import Path
from typing import Literal
@@ -280,3 +281,55 @@ def train_smp(
        check_val_every_n_epoch=3,
    )
    trainer.fit(model, datamodule)

+    # TODO: save with own config etc.
+    # Add timestamp
+
+
+def convert_lightning_checkpoint(
+    *,
+    lightning_checkpoint: Path,
+    out_directory: Path,
+    checkpoint_name: str,
+    framework: str = "smp",
+):
+    """Convert a lightning checkpoint to our own format.
+
+    The final checkpoint will contain the model configuration and the state dict.
+    It will be saved to:
+
+    ```python
+    out_directory / f"{checkpoint_name}_{formatted_date}.ckpt"
+    ```
+
+    Args:
+        lightning_checkpoint (Path): Path to the lightning checkpoint.
+        out_directory (Path): Output directory for the converted checkpoint.
+        checkpoint_name (str): A unique name for the new checkpoint.
+        framework (str, optional): The framework used for the model. Defaults to "smp".
+
+    """
+    import torch
+
+    logger.debug(f"Loading checkpoint from {lightning_checkpoint.resolve()}")
+    lckpt = torch.load(lightning_checkpoint, weights_only=False)
+
+    now = datetime.now()
+    formatted_date = now.strftime("%Y-%m-%d")
+    config = lckpt["hyper_parameters"]
+    config["time"] = formatted_date
+    config["name"] = checkpoint_name
+    config["model_framework"] = framework
+
+    own_ckpt = {
+        "config": config,
+        "statedict": lckpt["state_dict"],
+    }
+
+    out_directory.mkdir(exist_ok=True, parents=True)
+
+    out_checkpoint = out_directory / f"{checkpoint_name}_{formatted_date}.ckpt"
+
+    torch.save(own_ckpt, out_checkpoint)
+
+    logger.info(f"Saved converted checkpoint to {out_checkpoint.resolve()}")
