-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmodels.py
46 lines (37 loc) · 1.38 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torchvision.models as models
import numpy, h5py
from inpainting.datasets import *
from inpainting.models import *
class NIMA_vgg16(nn.Module):
    """Neural IMage Assessment (NIMA) model by Google, on a VGG16 backbone.

    The convolutional trunk of an ImageNet-pretrained torchvision VGG16 is used
    as a feature extractor. Its flattened output feeds a dropout + linear head,
    and a softmax turns the logits into a probability distribution over
    ``num_classes`` bins.

    Args:
        num_classes: Number of bins in the predicted distribution (presumably
            one per AVA rating 1..10 -- confirm against the training code).
    """

    def __init__(self, num_classes=10):
        super(NIMA_vgg16, self).__init__()
        # NOTE(review): ``pretrained=True`` is deprecated in torchvision >= 0.13
        # in favour of ``weights=...``. It is kept for compatibility with the
        # torchvision version this file targets. It downloads the ImageNet
        # weights on first use.
        base_model = models.vgg16(pretrained=True)
        # Only the convolutional trunk is kept; VGG16's own avgpool and fully
        # connected classifier are discarded.
        self.features = base_model.features
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.75),
            # 25088 = 512 * 7 * 7: the trunk's flattened output for a 224x224
            # input (five 2x max-pools: 224 / 32 = 7). Other input sizes will
            # not match this layer.
            nn.Linear(in_features=25088, out_features=num_classes),
            # Explicit dim=1 normalises over the class axis of the
            # (batch, num_classes) logits. This is the axis the deprecated
            # implicit form picked for 2-D input, minus the warning.
            nn.Softmax(dim=1))

    def forward(self, x):
        """Return per-sample class probabilities for a batch of images.

        Args:
            x: Image batch; presumably (batch, 3, 224, 224), normalised as
                torchvision's ImageNet models expect -- confirm with caller.

        Returns:
            Tensor of shape (batch, num_classes) whose rows sum to 1.
        """
        out = self.features(x)
        # Flatten everything except the batch dimension for the linear layer.
        out = out.view(out.size(0), -1)
        return self.classifier(out)
class inpainting_D_AVA(nn.Module):
    """NIMA-style rating head on top of a pretrained inpainting discriminator.

    A ``Discriminator`` from the project's ``inpainting`` package is built and
    loaded with pretrained weights. Its whole forward pass is then used as the
    feature extractor. The flattened output feeds the same dropout + linear +
    softmax head that ``NIMA_vgg16`` uses.

    Args:
        num_classes: Number of bins in the predicted distribution.
        weights_path: Path to the pretrained discriminator state dict. The
            default is the original relative path, which is resolved against
            the current working directory.
    """

    def __init__(self, num_classes=10,
                 weights_path="inpainting-pretrained-weights/inpainting-FPP-discriminator.pkl"):
        super(inpainting_D_AVA, self).__init__()
        # channels=3 presumably means RGB input -- confirm with Discriminator.
        discriminator = Discriminator(channels=3)
        # map_location="cpu" lets a checkpoint saved from GPU load on a CPU-only
        # machine. load_state_dict then copies the tensors into the module's own
        # parameters, so the module's device is unaffected.
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints.
        state_dict = torch.load(weights_path, map_location="cpu")
        discriminator.load_state_dict(state_dict)
        self.features = discriminator
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.75),
            # NOTE(review): 784 must equal the number of elements per sample
            # in the Discriminator's flattened output. That depends on its
            # architecture and on the input size -- confirm.
            nn.Linear(in_features=784, out_features=num_classes),
            # Explicit dim=1 normalises over the class axis of the
            # (batch, num_classes) logits (same axis the implicit form chose).
            nn.Softmax(dim=1))

    def forward(self, x):
        """Return per-sample class probabilities for a batch of images.

        Args:
            x: Image batch accepted by ``Discriminator``; shape not visible
                here -- confirm with the inpainting package.

        Returns:
            Tensor of shape (batch, num_classes) whose rows sum to 1.
        """
        out = self.features(x)
        # Flatten everything except the batch dimension for the linear layer.
        out = out.view(out.size(0), -1)
        return self.classifier(out)