-
Notifications
You must be signed in to change notification settings - Fork 195
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Utility plot fixes #65
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Approving, but we need to relocate to make these more general purpose (eventually).
@@ -176,7 +176,7 @@ def int_to_bool(value): | |||
|
|||
|
|||
# Utitlity: plot | |||
def plot_preds(trainer, dset, plot_dir, num_plots=10, plot_prefix="valid", channel=-1): | |||
def plot_preds(trainer, dset, plot_dir, num_plots=10, plot_prefix="valid", channel=-1, truncate_history=True): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Eventually these utilities should be moved out of model-specific folders.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Have these utilities been moved out to some other location? The below code in ttm_tutorial.py is failing
from tsfm_public.models.tinytimemixer.utils import (
count_parameters,
plot_preds,
)
Here are the errors.
ImportError: cannot import name 'count_parameters' from 'tsfm_public.models.tinytimemixer.utils' (/content/tsfm/tsfm_public/models/tinytimemixer/utils/init.py)
ImportError: cannot import name 'plot_preds' from 'tsfm_public.models.tinytimemixer.utils' (/content/tsfm/tsfm_public/models/tinytimemixer/utils/init.py)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@SatishGune They were moved -- let me double check.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
plot_preds
is now plot_predictions
. It can be imported with from tsfm_public.toolkit.visualization import plot_predictions
. The signature of plot_predictions
is now:
def plot_predictions(
model: torch.nn.Module,
dset: torch.utils.data.Dataset,
plot_dir: str = None,
num_plots: int = 10,
plot_prefix: str = "valid",
channel: int = -1,
truncate_history: bool = True,
)
It takes a model as its first argument instead of a trainer. The ttm_tutorial.ipynb
has been updated to reflect these changes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @wgifford . Much appreciated.
This PR fixes a bug in the plotting utility function
plot_preds
in the filetsfm_public/models/tinytimemixer/utils/ttm_utils.py
related to truncation ofcontext_length
for better visualization. It also makes the truncation optional (default: True). The PR also re-runs all the notebooks which use theplot_preds
function, and deletes some unused imports.