Skip to content

Commit

Permalink
condition and supervise
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidtronix committed Jan 13, 2025
1 parent d65c3e9 commit 0a32e0b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
17 changes: 16 additions & 1 deletion ml4h/models/diffusion_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,12 +922,27 @@ def plot_reconstructions(
plt.axis("off")
plt.tight_layout()
now_string = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')
figure_path = os.path.join(prefix, f'diffusion_image_generations_{now_string}{IMAGE_EXT}')
figure_path = os.path.join(prefix, f'diffusion_image_reconstructions_{now_string}{IMAGE_EXT}')
if not os.path.exists(os.path.dirname(figure_path)):
os.makedirs(os.path.dirname(figure_path))
plt.savefig(figure_path, bbox_inches="tight")
plt.close()
plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0), dpi=300)
for row in range(num_rows):
for col in range(num_cols):
index = row * num_cols + col
plt.subplot(num_rows, num_cols, index + 1)
plt.imshow(images[index], cmap='gray')
plt.axis("off")
plt.tight_layout()
now_string = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')
figure_path = os.path.join(prefix, f'input_images_{now_string}{IMAGE_EXT}')
if not os.path.exists(os.path.dirname(figure_path)):
os.makedirs(os.path.dirname(figure_path))
plt.savefig(figure_path, bbox_inches="tight")
plt.close()


def control_plot_images(
self, control_batch, epoch=None, logs=None, num_rows=2, num_cols=8, reseed=None,
renoise=None,
Expand Down
2 changes: 0 additions & 2 deletions ml4h/models/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,8 +376,6 @@ def train_diffusion_control_model(args, supervised=False):
data, labels, paths = big_batch_from_minibatch_generator(generate_test, 1)

model.plot_reconstructions((data, labels), prefix=f'{args.output_folder}/{args.id}/')
images = data[args.tensor_maps_in[0].input_name()]
predictions_to_pngs(images, args.tensor_maps_in, args.tensor_maps_in, data, labels, paths, '{args.output_folder}/{args.id}/')
interpolate_controlled_generations(model, args.tensor_maps_out, args.tensor_maps_out[0], args.batch_size,
f'{args.output_folder}/{args.id}/')
if model.input_map.axes() == 2:
Expand Down

0 comments on commit 0a32e0b

Please sign in to comment.