Fengwu ghr RES module #118

Merged: 19 commits, Jul 4, 2024
2 changes: 1 addition & 1 deletion graph_weather/models/__init__.py
@@ -1,6 +1,6 @@
"""Models"""

from .fengwu_ghr.layers import ImageMetaModel, MetaModel
from .fengwu_ghr.layers import ImageMetaModel, MetaModel, WrapperImageModel, WrapperMetaModel
from .layers.assimilator_decoder import AssimilatorDecoder
from .layers.assimilator_encoder import AssimilatorEncoder
from .layers.decoder import Decoder
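The two new wrapper classes are re-exported at package level. A minimal import sketch (assuming the package is installed as `graph_weather`, following the `__init__.py` change above):

```python
# Hedged sketch: import paths follow the __init__.py change in this PR.
from graph_weather.models import (
    ImageMetaModel,
    MetaModel,
    WrapperImageModel,
    WrapperMetaModel,
)
```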
204 changes: 173 additions & 31 deletions graph_weather/models/fengwu_ghr/layers.py
@@ -88,67 +88,155 @@ def forward(self, x):
 
 
 class Transformer(nn.Module):
-    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
+    def __init__(
+        self, dim, depth, heads, dim_head, mlp_dim, res=False, image_size=None, scale_factor=None
+    ):
         super().__init__()
+        self.depth = depth
+        self.res = res
         self.norm = nn.LayerNorm(dim)
         self.layers = nn.ModuleList([])
-        for _ in range(depth):
+        self.res_layers = nn.ModuleList([])
+        for _ in range(self.depth):
             self.layers.append(
                 nn.ModuleList(
                     [Attention(dim, heads=heads, dim_head=dim_head), FeedForward(dim, mlp_dim)]
                 )
             )
+            if self.res:
+                assert (
+                    image_size is not None and scale_factor is not None
+                ), "If res=True, you must provide h, w and scale_factor"
+                h, w = pair(image_size)
+                s_h, s_w = pair(scale_factor)
+                self.res_layers.append(
+                    nn.ModuleList(
+                        [  # reshape to original shape window partition operation
+                            # (b s_h s_w) (h w) d -> b (s_h h) (s_w w) d -> (b h w) (s_h s_w) d
+                            Rearrange(
+                                "(b s_h s_w) (h w) d -> (b h w) (s_h s_w) d",
+                                h=h,
+                                w=w,
+                                s_h=s_h,
+                                s_w=s_w,
+                            ),
+                            # TODO ?????
+                            Attention(dim, heads=heads, dim_head=dim_head),
+                            # restore shape
+                            Rearrange(
+                                "(b h w) (s_h s_w) d -> (b s_h s_w) (h w) d",
+                                h=h,
+                                w=w,
+                                s_h=s_h,
+                                s_w=s_w,
+                            ),
+                        ]
+                    )
+                )
 
     def forward(self, x):
-        for attn, ff in self.layers:
+        for i in range(self.depth):
+            attn, ff = self.layers[i]
             x = attn(x) + x
             x = ff(x) + x
+            if self.res:
+                reshape, loc_attn, restore = self.res_layers[i]
+                x = reshape(x)
+                x = loc_attn(x) + x
+                x = restore(x)
         return self.norm(x)

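The `res` branch adds a second attention step per layer: the usual attention mixes the h*w patch tokens within each window, then the first `Rearrange` regroups the batch so the tokens at the same patch position across the s_h*s_w windows form one sequence, `Attention` mixes those, and the second `Rearrange` restores the layout. A standalone sketch of the regrouping (shapes are illustrative, not taken from the PR):

```python
# Minimal einops sketch of the window-partition trick used by the res branch.
import torch
from einops import rearrange

b, h, w, s_h, s_w, d = 2, 4, 4, 2, 2, 8
x = torch.randn(b * s_h * s_w, h * w, d)  # (b s_h s_w) (h w) d

# group by patch position: each sequence now spans the s_h * s_w windows
x = rearrange(x, "(b s_h s_w) (h w) d -> (b h w) (s_h s_w) d", h=h, w=w, s_h=s_h, s_w=s_w)
# ... the PR runs Attention here, mixing the same patch across windows ...
# restore the per-window token layout
x = rearrange(x, "(b h w) (s_h s_w) d -> (b s_h s_w) (h w) d", h=h, w=w, s_h=s_h, s_w=s_w)
print(x.shape)  # torch.Size([8, 16, 8])
```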

 class ImageMetaModel(nn.Module):
-    def __init__(self, *, image_size, patch_size, depth, heads, mlp_dim, channels=3, dim_head=64):
+    def __init__(
+        self,
+        *,
+        image_size,
+        patch_size,
+        depth,
+        heads,
+        mlp_dim,
+        channels,
+        dim_head,
+        res=False,
+        scale_factor=None,
+        **kwargs
+    ):
         super().__init__()
-        image_height, image_width = pair(image_size)
-        patch_height, patch_width = pair(patch_size)
+        # TODO this can probably be done better
+        self.image_size = image_size
+        self.patch_size = patch_size
+        self.depth = depth
+        self.heads = heads
+        self.mlp_dim = mlp_dim
+        self.channels = channels
+        self.dim_head = dim_head
+        self.res = res
+        self.scale_factor = scale_factor
+
+        self.image_height, self.image_width = pair(image_size)
+        self.patch_height, self.patch_width = pair(patch_size)
+        s_h, s_w = pair(scale_factor)
 
+        if res:
+            assert scale_factor is not None, "If res=True, you must provide scale_factor"
         assert (
-            image_height % patch_height == 0 and image_width % patch_width == 0
+            self.image_height % self.patch_height == 0 and self.image_width % self.patch_width == 0
         ), "Image dimensions must be divisible by the patch size."
 
-        patch_dim = channels * patch_height * patch_width
+        patch_dim = channels * self.patch_height * self.patch_width
         dim = patch_dim
         self.to_patch_embedding = nn.Sequential(
             Rearrange(
-                "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)", p_h=patch_height, p_w=patch_width
+                "b c (h p_h) (w p_w) -> b (h w) (p_h p_w c)",
+                p_h=self.patch_height,
+                p_w=self.patch_width,
             ),
             nn.LayerNorm(patch_dim),  # TODO Do we need this?
             nn.Linear(patch_dim, dim),  # TODO Do we need this?
             nn.LayerNorm(dim),  # TODO Do we need this?
         )
 
         self.pos_embedding = posemb_sincos_2d(
-            h=image_height // patch_height,
-            w=image_width // patch_width,
+            h=self.image_height // self.patch_height,
+            w=self.image_width // self.patch_width,
             dim=dim,
         )
 
-        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)
+        self.transformer = Transformer(
+            dim,
+            depth,
+            heads,
+            dim_head,
+            mlp_dim,
+            res=res,
+            image_size=(
+                self.image_height // self.patch_height,
+                self.image_width // self.patch_width,
+            ),
+            scale_factor=(s_h, s_w),
+        )
 
         self.reshaper = nn.Sequential(
             Rearrange(
                 "b (h w) (p_h p_w c) -> b c (h p_h) (w p_w)",
-                h=image_height // patch_height,
-                w=image_width // patch_width,
-                p_h=patch_height,
-                p_w=patch_width,
+                h=self.image_height // self.patch_height,
+                w=self.image_width // self.patch_width,
+                p_h=self.patch_height,
+                p_w=self.patch_width,
             )
         )
 
     def forward(self, x):
         device = x.device
         dtype = x.dtype
 
         x = self.to_patch_embedding(x)
         x += self.pos_embedding.to(device, dtype=dtype)
 
@@ -158,33 +246,49 @@ def forward(self, x):
         return x

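`channels` and `dim_head` are now required keyword arguments rather than defaults. A minimal construction sketch (the sizes are illustrative assumptions; `res` and `scale_factor` keep their defaults here):

```python
# Hedged sketch: arguments follow the new signature above; values are assumptions.
import torch
from graph_weather.models import ImageMetaModel

imm = ImageMetaModel(
    image_size=(32, 64),  # must be divisible by patch_size
    patch_size=(4, 4),
    depth=2,
    heads=4,
    mlp_dim=64,
    channels=3,
    dim_head=32,
)
x = torch.randn(1, 3, 32, 64)  # b c h w
y = imm(x)  # the reshaper maps patch tokens back to b c h w
```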

+class WrapperImageModel(nn.Module):
+    def __init__(self, image_meta_model: ImageMetaModel, scale_factor):
+        super().__init__()
+        s_h, s_w = pair(scale_factor)
+        self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=s_h, s_w=s_w)
+
+        imm_args = vars(image_meta_model)
+        imm_args.update({"res": True, "scale_factor": scale_factor})
+        self.image_meta_model = ImageMetaModel(**imm_args)
+        self.image_meta_model.load_state_dict(image_meta_model.state_dict(), strict=False)
+
+        self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", s_h=s_h, s_w=s_w)
+
+    def forward(self, x):
+        x = self.batcher(x)
+        x = self.image_meta_model(x)
+        x = self.debatcher(x)
+        return x

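`WrapperImageModel` rebuilds the inner model from `vars(image_meta_model)` (the new `**kwargs` on `ImageMetaModel.__init__` absorbs the extra entries of the module's `__dict__`), switches `res` on, and reloads the weights with `strict=False` since the new local-attention layers have no pretrained parameters. A usage sketch, reusing `imm` from the previous example:

```python
# Hedged sketch: scale_factor=2 splits a doubled-resolution input into 2x2
# windows, runs the res-enabled base model per window, then stitches them back.
scaled = WrapperImageModel(imm, scale_factor=2)
x_hi = torch.randn(1, 3, 64, 128)  # each spatial dim of the base model x2
y_hi = scaled(x_hi)  # same high-res shape out
```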

 class MetaModel(nn.Module):
     def __init__(
         self,
         lat_lons: list,
         *,
+        image_size,
         patch_size,
         depth,
         heads,
         mlp_dim,
-        image_size=(721, 1440),
-        channels=3,
+        channels,
         dim_head=64
     ):
         super().__init__()
-        self.image_size = pair(image_size)
+        self.i_h, self.i_w = pair(image_size)
 
         self.pos_x = torch.tensor(lat_lons)
         self.pos_y = torch.cartesian_prod(
-            (
-                torch.arange(-self.image_size[0] / 2, self.image_size[0] / 2, 1)
-                / self.image_size[0]
-                * 180
-            ).to(torch.long),
-            (torch.arange(0, self.image_size[1], 1) / self.image_size[1] * 360).to(torch.long),
+            (torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long),
+            (torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long),
         )
 
-        self.image_model = ImageMetaModel(
+        self.image_meta_model = ImageMetaModel(
             image_size=image_size,
             patch_size=patch_size,
             depth=depth,
@@ -199,12 +303,50 @@ def forward(self, x):

         x = rearrange(x, "b n c -> n (b c)")
         x = knn_interpolate(x, self.pos_x, self.pos_y)
-        x = rearrange(
-            x, "(w h) (b c) -> b c w h", b=b, c=c, w=self.image_size[0], h=self.image_size[1]
-        )
-        x = self.image_model(x)
+        x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w)
+        x = self.image_meta_model(x)
 
-        x = rearrange(x, "b c w h -> (w h) (b c)")
+        x = rearrange(x, "b c h w -> (h w) (b c)")
         x = knn_interpolate(x, self.pos_y, self.pos_x)
         x = rearrange(x, "n (b c) -> b n c", b=b, c=c)
         return x

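`MetaModel` now takes `image_size` as a required argument (it previously defaulted to (721, 1440)): node features are interpolated onto that lat/lon grid (`pos_y` enumerates the grid coordinates via `cartesian_prod`), passed through the inner `ImageMetaModel`, and interpolated back to the nodes. A sketch (the node list and sizes are illustrative; `knn_interpolate` comes from torch_geometric):

```python
# Hedged sketch: lat_lons gives one (lat, lon) per graph node.
lat_lons = [(lat, lon) for lat in range(-90, 90, 10) for lon in range(0, 360, 10)]
mm = MetaModel(
    lat_lons,
    image_size=(32, 64),
    patch_size=(4, 4),
    depth=2,
    heads=4,
    mlp_dim=64,
    channels=3,
)
feats = torch.randn(1, len(lat_lons), 3)  # b n c, with c matching channels
out = mm(feats)  # b n c again
```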

+class WrapperMetaModel(nn.Module):
+    def __init__(self, lat_lons: list, meta_model: MetaModel, scale_factor):
+        super().__init__()
+        s_h, s_w = pair(scale_factor)
+        self.i_h, self.i_w = meta_model.i_h * s_h, meta_model.i_w * s_w
+        self.pos_x = torch.tensor(lat_lons)
+        self.pos_y = torch.cartesian_prod(
+            (torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long),
+            (torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long),
+        )
+
+        self.batcher = Rearrange("b c (h s_h) (w s_w) -> (b s_h s_w) c h w", s_h=s_h, s_w=s_w)
+
+        imm_args = vars(meta_model.image_meta_model)
+        imm_args.update({"res": True, "scale_factor": scale_factor})
+        self.image_meta_model = ImageMetaModel(**imm_args)
+        self.image_meta_model.load_state_dict(
+            meta_model.image_meta_model.state_dict(), strict=False
+        )
+
+        self.debatcher = Rearrange("(b s_h s_w) c h w -> b c (h s_h) (w s_w)", s_h=s_h, s_w=s_w)
+
+    def forward(self, x):
+        b, n, c = x.shape
+
+        x = rearrange(x, "b n c -> n (b c)")
+        x = knn_interpolate(x, self.pos_x, self.pos_y)
+        x = rearrange(x, "(h w) (b c) -> b c h w", b=b, c=c, h=self.i_h, w=self.i_w)
+
+        x = self.batcher(x)
+        x = self.image_meta_model(x)
+        x = self.debatcher(x)
+
+        x = rearrange(x, "b c h w -> (h w) (b c)")
+        x = knn_interpolate(x, self.pos_y, self.pos_x)
+        x = rearrange(x, "n (b c) -> b n c", b=b, c=c)
+
+        return x
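`WrapperMetaModel` combines both wrappers' ideas: it builds `pos_y` on a grid `scale_factor` times larger than the base model's, and runs the window batcher around a res-enabled copy of the trained image model. Usage sketch, reusing `mm` and `lat_lons` from the previous example:

```python
# Hedged sketch: same node list, but the internal grid is 2x larger per dim.
wmm = WrapperMetaModel(lat_lons, mm, scale_factor=2)
out = wmm(torch.randn(1, len(lat_lons), 3))  # b n c -> b n c
```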