forked from coleifer/peewee
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpwiz.py
executable file
·184 lines (150 loc) · 6.18 KB
/
pwiz.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
#!/usr/bin/env python
import datetime
from optparse import OptionParser
import sys
from peewee import *
from peewee import print_
from peewee import __version__ as peewee_version
from playhouse.reflection import *
# Preamble written at the top of every generated models module.
#
# The three %s placeholders are filled, in order, with the database class
# name, the database name and the repr() of the connection keyword arguments.
# The text is emitted verbatim as Python source, so the indentation inside the
# string is significant: class bodies must be indented or the generated module
# will not compile.
TEMPLATE = """from peewee import *

database = %s('%s', **%s)

class UnknownField(object):
    pass

class BaseModel(Model):
    class Meta:
        database = database
"""
# Engine names accepted on the command line, grouped by the peewee database
# class each of them selects.
DATABASE_ALIASES = {
    MySQLDatabase: ['mysql', 'mysqldb'],
    PostgresqlDatabase: ['postgres', 'postgresql'],
    SqliteDatabase: ['sqlite', 'sqlite3'],
}

# Reverse lookup derived from DATABASE_ALIASES: engine name -> database class.
DATABASE_MAP = dict(
    (alias, database_class)
    for database_class, aliases in DATABASE_ALIASES.items()
    for alias in aliases)
def make_introspector(database_type, database_name, **kwargs):
    """Build an ``Introspector`` for the requested database.

    :param database_type: engine name; must be a key of ``DATABASE_MAP``.
    :param database_name: name passed as the first argument to the database
        class.
    :param kwargs: connection keyword arguments forwarded to the database
        class. A ``schema`` entry, if present, is removed from them and
        handed to ``Introspector.from_database`` instead.
    :returns: the result of ``Introspector.from_database``.

    Reports an error through ``err`` and exits with status 1 when
    ``database_type`` is not a known engine name.
    """
    if database_type not in DATABASE_MAP:
        known_engines = ', '.join(DATABASE_MAP.keys())
        err('Unrecognized database, must be one of: %s' % known_engines)
        sys.exit(1)

    database_class = DATABASE_MAP[database_type]

    # The schema is not a connection parameter, so it must not reach the
    # database constructor.
    schema = kwargs.pop('schema', None)
    connection = database_class(database_name, **kwargs)
    return Introspector.from_database(connection, schema=schema)
def print_models(introspector, tables=None, preserve_order=False):
    """Print peewee model definitions for the introspected database.

    Emits the module preamble (``TEMPLATE``) followed by one model class per
    table. A table referenced through a foreign key is printed before the
    table that references it.

    :param introspector: object providing ``introspect()``,
        ``get_database_class()``, ``get_database_name()``,
        ``get_database_kwargs()``, ``pk_classes`` and ``schema``.
    :param tables: optional collection of table names. When given and
        non-empty, only these tables are used as starting points; tables they
        reference through foreign keys are still printed.
    :param preserve_order: when false, columns are sorted by column name;
        when true they are emitted in the order yielded by the introspected
        column mapping.
    """
    database = introspector.introspect()

    print_(TEMPLATE % (
        introspector.get_database_class().__name__,
        introspector.get_database_name(),
        repr(introspector.get_database_kwargs())))

    def _print_table(table, seen, accum=None):
        # ``accum`` is the chain of tables currently being expanded on the way
        # down to ``table``; ``seen`` is the set of tables already handled.
        accum = accum or []
        foreign_keys = database.foreign_keys[table]
        for foreign_key in foreign_keys:
            dest = foreign_key.dest_table

            # In the event the destination table has already been pushed
            # for printing, then we have a reference cycle.
            if dest in accum and table not in accum:
                print_('# Possible reference cycle: %s' % dest)

            # If this is not a self-referential foreign key, and we have
            # not already processed the destination table, do so now.
            if dest not in seen and dest not in accum:
                seen.add(dest)
                if dest != table:
                    _print_table(dest, seen, accum + [table])

        print_('class %s(BaseModel):' % database.model_names[table])
        columns = database.columns[table].items()
        if not preserve_order:
            columns = sorted(columns)
        primary_keys = database.primary_keys[table]
        for name, column in columns:
            # A lone primary key column named "id" whose field class is one of
            # the introspector's pk classes is not written out.
            skip = all([
                name in primary_keys,
                name == 'id',
                len(primary_keys) == 1,
                column.field_class in introspector.pk_classes])
            if skip:
                continue
            if column.primary_key and len(primary_keys) > 1:
                # If we have a CompositeKey, then we do not want to explicitly
                # mark the columns as being primary keys.
                column.primary_key = False

            # Everything below the class statement is emitted as Python
            # source: class members need one indentation level (4 spaces) and
            # the Meta body needs two (8 spaces), otherwise the generated
            # module does not compile.
            print_('    %s' % column.get_field())

        print_('')
        print_('    class Meta:')
        print_('        db_table = \'%s\'' % table)
        if introspector.schema:
            print_('        schema = \'%s\'' % introspector.schema)
        if len(primary_keys) > 1:
            pk_field_names = sorted([
                field.name for col, field in columns
                if col in primary_keys])
            pk_list = ', '.join("'%s'" % pk for pk in pk_field_names)
            print_('        primary_key = CompositeKey(%s)' % pk_list)
        print_('')

        seen.add(table)

    seen = set()
    for table in sorted(database.model_names.keys()):
        if table not in seen:
            if not tables or table in tables:
                _print_table(table, seen)
def print_header(cmd_line, introspector):
    """Print a comment header describing how the output was generated.

    :param cmd_line: the command-line arguments, already joined into a string.
    :param introspector: object whose ``get_database_name()`` supplies the
        database name shown in the header.
    """
    generated_at = datetime.datetime.now()

    print_('# Code generated by:')
    print_('# python -m pwiz %s' % cmd_line)

    date_text = generated_at.strftime('%B %d, %Y %I:%M%p')
    print_('# Date: %s' % date_text)

    database_name = introspector.get_database_name()
    print_('# Database: %s' % database_name)

    print_('# Peewee version: %s' % peewee_version)
    print_('')
def err(msg):
    """Write ``msg`` to standard error wrapped in ANSI escape codes.

    The message is preceded by ``ESC[91m``, followed by ``ESC[0m`` and a
    newline, and the stream is flushed afterwards.
    """
    decorated = '\033[91m%s\033[0m\n' % msg
    stream = sys.stderr
    stream.write(decorated)
    stream.flush()
def get_option_parser():
    """Create the ``OptionParser`` describing the command-line interface.

    :returns: a parser with the connection options (host, port, user,
        password), the engine choice (restricted to the keys of
        ``DATABASE_MAP``, default ``postgresql``), the schema and table
        filters, and the ``--info`` / ``--preserve-order`` flags.
    """
    parser = OptionParser(usage='usage: %prog [options] database_name')

    # Connection parameters.
    parser.add_option('-H', '--host', dest='host')
    parser.add_option('-p', '--port', dest='port', type='int')
    parser.add_option('-u', '--user', dest='user')
    parser.add_option('-P', '--password', dest='password')

    # Only the engine names known to DATABASE_MAP are accepted.
    parser.add_option(
        '-e', '--engine',
        dest='engine',
        default='postgresql',
        choices=sorted(DATABASE_MAP),
        help='Database type, e.g. sqlite, mysql or postgresql. '
             'Default is "postgresql".')

    # What to generate.
    parser.add_option('-s', '--schema', dest='schema')
    parser.add_option(
        '-t', '--tables',
        dest='tables',
        help='Only generate the specified tables. '
             'Multiple table names should be separated by commas.')

    # Output tweaks.
    parser.add_option(
        '-i', '--info',
        dest='info',
        action='store_true',
        help='Add database information and other metadata to top of '
             'the generated file.')
    parser.add_option(
        '-o', '--preserve-order',
        action='store_true',
        dest='preserve_order',
        help='Model definition column ordering matches source table.')

    return parser
def get_connect_kwargs(options):
    """Collect the connection settings that were actually supplied.

    :param options: object exposing ``host``, ``port``, ``user``,
        ``password`` and ``schema`` attributes.
    :returns: dict containing only those of the five settings whose value is
        truthy.
    """
    connect_kwargs = {}
    for option_name in ('host', 'port', 'user', 'password', 'schema'):
        value = getattr(options, option_name)
        if value:
            connect_kwargs[option_name] = value
    return connect_kwargs
if __name__ == '__main__':
    # Script entry point: parse the command line, introspect the database
    # and print the generated models to standard output.
    raw_argv = sys.argv

    parser = get_option_parser()
    options, args = parser.parse_args()

    if options.preserve_order:
        # Probe for OrderedDict up front; per the error message below it is
        # only available on Python >= 2.7.
        try:
            from collections import OrderedDict
        except ImportError:
            err('Preserve order requires Python >= 2.7.')
            sys.exit(1)

    if not args:
        err('Missing required parameter "database"')
        parser.print_help()
        sys.exit(1)

    # The database name is the last positional argument.
    database = args[-1]
    connect = get_connect_kwargs(options)

    # Split the comma-separated table list, discarding blank entries.
    if options.tables:
        tables = [name for name in
                  (part.strip() for part in options.tables.split(','))
                  if name]
    else:
        tables = None

    introspector = make_introspector(options.engine, database, **connect)

    if options.info:
        cmd_line = ' '.join(raw_argv[1:])
        print_header(cmd_line, introspector)

    print_models(introspector, tables, preserve_order=options.preserve_order)