typos #1

Open · wants to merge 1 commit into base: master
4 changes: 2 additions & 2 deletions README.md
@@ -104,13 +104,13 @@ datasets
Once the dataset is ready, new models can be trained with the following commands. For example,

```
-python train.py --config configs/youtube-vos.jon --model sttn
+python train.py --config configs/youtube-vos.json --model sttn
```

<!-- ---------------------------------------------- -->
## Testing

-Testing is similar to [Completing Videos Using Pretrained Model](https://github.com/researchmm/STTN#completing_videos_using_rpetrained_model).
+Testing is similar to [Completing Videos Using Pretrained Model](https://github.com/researchmm/STTN#completing-videos-using-pretrained-model).

```
python test.py --video examples/schoolgirls_orig.mp4 --mask examples/schoolgirls --ckpt checkpoints/sttn.pth
1 change: 1 addition & 0 deletions environment.yml
@@ -120,6 +120,7 @@ dependencies:
- torch==1.1.0
- torchvision==0.3.0
- tornado==6.0.4
+- tqdm==4.49.0
- traitlets==4.3.3
- typing-extensions==3.7.4.2
- urllib3==1.25.9
16 changes: 8 additions & 8 deletions test.py
@@ -47,7 +47,7 @@
ToTorchFormatTensor()])


-# sample reference frames from the whole video
+# sample reference frames from the whole video
def get_ref_index(neighbor_ids, length):
ref_index = []
for i in range(0, length, ref_length):
@@ -56,12 +56,12 @@ def get_ref_index(neighbor_ids, length):
return ref_index


-# read frame-wise masks
+# read frame-wise masks
def read_mask(mpath):
masks = []
mnames = os.listdir(mpath)
mnames.sort()
-    for m in mnames:
+    for m in mnames:
m = Image.open(os.path.join(mpath, m))
m = m.resize((w, h), Image.NEAREST)
m = np.array(m.convert('L'))
@@ -72,7 +72,7 @@ def read_mask(mpath):
return masks
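The `read_mask` hunk above loads each mask image as grayscale (`m.convert('L')`) before the lines hidden by the collapsed hunk binarize it. A minimal sketch of that thresholding step, assuming any nonzero grayscale pixel counts as masked (the `binarize` helper and the `> 0` threshold are illustrative assumptions, not shown in the visible diff):

```python
import numpy as np

# Hypothetical helper mirroring the thresholding read_mask likely performs
# (assumption: any nonzero grayscale pixel counts as masked).
def binarize(gray):
    """Map a grayscale array to a binary {0, 1} mask."""
    return (np.asarray(gray) > 0).astype(np.uint8)

print(binarize([[0, 128], [255, 0]]).tolist())  # [[0, 1], [1, 0]]
```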


-# read frames from video
+# read frames from video
def read_frame_from_videos(vname):
frames = []
vidcap = cv2.VideoCapture(vname)
@@ -83,12 +83,12 @@ def read_frame_from_videos(vname):
frames.append(image.resize((w,h)))
success, image = vidcap.read()
count += 1
-    return frames
+    return frames


def main_worker():
-    # set up models
-    device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
+    # set up models
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = importlib.import_module('model.' + args.model)
model = net.InpaintGenerator().to(device)
model_path = args.ckpt
@@ -97,7 +97,7 @@ def main_worker():
print('loading from: {}'.format(args.ckpt))
model.eval()

-    # prepare datset, encode all frames into deep space
+    # prepare dataset, encode all frames into deep space
frames = read_frame_from_videos(args.video)
video_length = len(frames)
feats = _to_tensors(frames).unsqueeze(0)*2-1
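The `get_ref_index` hunks above only show the loop header; a self-contained sketch of the strided reference-frame sampling it performs, with the `ref_length` default and the neighbor-membership test filled in from context as assumptions:

```python
# Sketch of STTN-style reference-frame sampling: take every ref_length-th
# frame index, skipping indices already used as local neighbors.
# (Assumption: ref_length is the module-level stride defined in test.py.)
def get_ref_index(neighbor_ids, length, ref_length=10):
    ref_index = []
    for i in range(0, length, ref_length):
        if i not in neighbor_ids:
            ref_index.append(i)
    return ref_index

print(get_ref_index({0, 10, 20}, 50))  # [30, 40]
```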