Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Utility plot fixes #65

Merged
merged 1 commit into from
Jun 11, 2024
Merged

Utility plot fixes #65

merged 1 commit into from
Jun 11, 2024

Conversation

ajati
Copy link
Collaborator

@ajati ajati commented Jun 10, 2024

This PR fixes a bug in the plotting utility function plot_preds in the file tsfm_public/models/tinytimemixer/utils/ttm_utils.py related to truncation of context_length for better visualization. It also makes the truncation optional (default: True). The PR also re-runs all the notebooks which use the plot_preds function, and deletes some unused imports.

@vijaye12
Copy link
Collaborator

Looks good @ajati Lets wait for @wgifford to review and merge it.

@wgifford wgifford requested review from wgifford June 11, 2024 16:14
Copy link
Collaborator

@wgifford wgifford left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approving, but we need to relocate to make these more general purpose (eventually).

@@ -176,7 +176,7 @@ def int_to_bool(value):


# Utitlity: plot
def plot_preds(trainer, dset, plot_dir, num_plots=10, plot_prefix="valid", channel=-1):
def plot_preds(trainer, dset, plot_dir, num_plots=10, plot_prefix="valid", channel=-1, truncate_history=True):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eventually these utilities should be moved out of model-specific folders.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have these utilities been moved out to some other location? The below code in ttm_tutorial.py is failing

from tsfm_public.models.tinytimemixer.utils import (
count_parameters,
plot_preds,
)

Here are the errors.
ImportError: cannot import name 'count_parameters' from 'tsfm_public.models.tinytimemixer.utils' (/content/tsfm/tsfm_public/models/tinytimemixer/utils/init.py)

ImportError: cannot import name 'plot_preds' from 'tsfm_public.models.tinytimemixer.utils' (/content/tsfm/tsfm_public/models/tinytimemixer/utils/init.py)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@SatishGune They were moved -- let me double check.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

plot_preds is now plot_predictions. It can be imported with from tsfm_public.toolkit.visualization import plot_predictions. The signature of plot_predictions is now:

def plot_predictions(
    model: torch.nn.Module,
    dset: torch.utils.data.Dataset,
    plot_dir: str = None,
    num_plots: int = 10,
    plot_prefix: str = "valid",
    channel: int = -1,
    truncate_history: bool = True,
)

It takes a model as its first argument instead of a trainer. The ttm_tutorial.ipynb has been updated to reflect these changes.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @wgifford . Much appreciated.

@wgifford wgifford merged commit b5bc3ba into ibm-granite:main Jun 11, 2024
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants