Skip to content

Commit

Permalink
Cache models locally.
Browse files Browse the repository at this point in the history
Load models from GitHub (as fallback).
  • Loading branch information
hendriks73 committed Oct 12, 2020
1 parent 4a2efd8 commit 07b04d2
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 43 deletions.
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ Changes
- Officially support Python 3.7.
- Enabled GitHub actions for packaging and testing.
- Added Pypi workflow.
- Cache models locally.
- Load models from GitHub.

0.0.4:
- Added support for DeepTemp, DeepSquare, and ShallowTemp models.
Expand Down
68 changes: 35 additions & 33 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,39 +19,41 @@
scripts = glob.glob('bin/*')

# define the models to be included in the PyPI package
package_data = ['models/cnn.h5',
'models/fcn.h5',
'models/ismir2018.h5',
'models/fma2018.h5',
'models/fma2018-meter.h5',
'models/dt_maz_m_fold0.h5',
'models/dt_maz_m_fold1.h5',
'models/dt_maz_m_fold2.h5',
'models/dt_maz_m_fold3.h5',
'models/dt_maz_m_fold4.h5',
'models/dt_maz_v_fold0.h5',
'models/dt_maz_v_fold1.h5',
'models/dt_maz_v_fold2.h5',
'models/dt_maz_v_fold3.h5',
'models/dt_maz_v_fold4.h5',
'models/deepsquare_k1.h5',
'models/deepsquare_k2.h5',
'models/deepsquare_k4.h5',
'models/deepsquare_k8.h5',
'models/deepsquare_k16.h5',
'models/deepsquare_k24.h5',
'models/deeptemp_k2.h5',
'models/deeptemp_k4.h5',
'models/deeptemp_k8.h5',
'models/deeptemp_k16.h5',
'models/deeptemp_k24.h5',
'models/shallowtemp_k1.h5',
'models/shallowtemp_k2.h5',
'models/shallowtemp_k4.h5',
'models/shallowtemp_k6.h5',
'models/shallowtemp_k8.h5',
'models/shallowtemp_k12.h5',
]
# do not package some large models, to stay below PyPI 100mb threshold
package_data = [
'models/cnn.h5',
'models/fcn.h5',
'models/ismir2018.h5',
# 'models/fma2018.h5',
# 'models/fma2018-meter.h5',
'models/dt_maz_m_fold0.h5',
'models/dt_maz_m_fold1.h5',
'models/dt_maz_m_fold2.h5',
'models/dt_maz_m_fold3.h5',
'models/dt_maz_m_fold4.h5',
'models/dt_maz_v_fold0.h5',
'models/dt_maz_v_fold1.h5',
'models/dt_maz_v_fold2.h5',
'models/dt_maz_v_fold3.h5',
'models/dt_maz_v_fold4.h5',
'models/deepsquare_k1.h5',
'models/deepsquare_k2.h5',
'models/deepsquare_k4.h5',
'models/deepsquare_k8.h5',
# 'models/deepsquare_k16.h5',
# 'models/deepsquare_k24.h5',
'models/deeptemp_k2.h5',
'models/deeptemp_k4.h5',
'models/deeptemp_k8.h5',
# 'models/deeptemp_k16.h5',
# 'models/deeptemp_k24.h5',
'models/shallowtemp_k1.h5',
'models/shallowtemp_k2.h5',
'models/shallowtemp_k4.h5',
'models/shallowtemp_k6.h5',
# 'models/shallowtemp_k8.h5',
# 'models/shallowtemp_k12.h5',
]

# requirements
with open('requirements.txt', 'r') as fh:
Expand Down
48 changes: 39 additions & 9 deletions tempocnn/classifier.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# encoding: utf-8

import logging
import os
import pkgutil
import sys
import tempfile
import urllib.request
from pathlib import Path
from urllib.error import HTTPError

import numpy as np
from tensorflow.python.keras.models import load_model
Expand Down Expand Up @@ -83,10 +85,7 @@ def __init__(self, model_name='fcn'):
print('Failed to find a model named \'{}\'. Please check the model name.'.format(model_name),
file=sys.stderr)
raise e
try:
self.model = load_model(file)
finally:
os.remove(file)
self.model = load_model(file)

def estimate(self, data):
"""
Expand Down Expand Up @@ -273,8 +272,39 @@ def _to_model_resource(model_name):


def _extract_from_package(resource):
# check local cache
cache_path = Path(Path.home(), '.tempocnn', resource)
if cache_path.exists():
return str(cache_path)

# ensure cache path exists
cache_path.parent.mkdir(parents=True, exist_ok=True)

data = pkgutil.get_data('tempocnn', resource)
with tempfile.NamedTemporaryFile(prefix='model', suffix='.h5', delete=False) as f:
if not data:
data = _load_model_from_github(resource)

# write to cache
with open(cache_path, 'wb') as f:
f.write(data)
name = f.name
return name

return str(cache_path)


def _load_model_from_github(resource):
url = f"https://raw.githubusercontent.com/hendriks73/tempo-cnn/main/tempocnn/{resource}"
logging.info(f"Attempting to download model file from main branch {url}")
try:
response = urllib.request.urlopen(url)
return response.read()
except HTTPError as e:
# fall back to dev branch
try:
url = f"https://raw.githubusercontent.com/hendriks73/tempo-cnn/dev/tempocnn/{resource}"
logging.info(f"Attempting to download model file from dev branch {url}")
response = urllib.request.urlopen(url)
return response.read()
except Exception:
pass

raise FileNotFoundError(f"Failed to download model from {url}: {type(e).__name__}: {e}")
2 changes: 1 addition & 1 deletion tempocnn/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.0.4"
__version__ = "0.0.5.dev0"

0 comments on commit 07b04d2

Please sign in to comment.