diff --git a/www/orm.py b/www/orm.py index 93ea5b4..3ec939a 100644 --- a/www/orm.py +++ b/www/orm.py @@ -7,8 +7,8 @@ import aiomysql -def log(sql, args=()): - logging.info('SQL: %s' % sql) +def log(sql, args=[]): + logging.info('SQL: [%s] args: %s' % (sql, args)) async def create_pool(loop, **kw): logging.info('create database connection pool...') @@ -40,7 +40,7 @@ async def select(sql, args, size=None): return rs async def execute(sql, args, autocommit=True): - log(sql) + log(sql, args) async with __pool.get() as conn: if not autocommit: await conn.begin() @@ -56,12 +56,6 @@ async def execute(sql, args, autocommit=True): raise return affected -def create_args_string(num): - L = [] - for n in range(num): - L.append('?') - return ', '.join(L) - class Field(object): def __init__(self, name, column_type, primary_key, default): @@ -103,34 +97,32 @@ class ModelMetaclass(type): def __new__(cls, name, bases, attrs): if name=='Model': return type.__new__(cls, name, bases, attrs) - tableName = attrs.get('__table__', None) or name + tableName = attrs.get('__table__', name) logging.info('found model: %s (table: %s)' % (name, tableName)) mappings = dict() - fields = [] + escaped_fields = [] primaryKey = None - for k, v in attrs.items(): + for k, v in attrs.copy().items(): if isinstance(v, Field): logging.info(' found mapping: %s ==> %s' % (k, v)) - mappings[k] = v + mappings[k] = attrs.pop(k) if v.primary_key: # 找到主键: if primaryKey: raise StandardError('Duplicate primary key for field: %s' % k) primaryKey = k else: - fields.append(k) + escaped_fields.append(k) if not primaryKey: raise StandardError('Primary key not found.') - for k in mappings.keys(): - attrs.pop(k) - escaped_fields = list(map(lambda f: '`%s`' % f, fields)) + attrs['__mappings__'] = mappings # 保存属性和列的映射关系 attrs['__table__'] = tableName attrs['__primary_key__'] = primaryKey # 主键属性名 - attrs['__fields__'] = fields # 除主键外的属性名 - attrs['__select__'] = 'select `%s`, %s from `%s`' % (primaryKey, ', '.join(escaped_fields), tableName) - attrs['__insert__'] = 'insert into `%s` (%s, `%s`) values (%s)' % (tableName, ', '.join(escaped_fields), primaryKey, create_args_string(len(escaped_fields) + 1)) - attrs['__update__'] = 'update `%s` set %s where `%s`=?' % (tableName, ', '.join(map(lambda f: '`%s`=?' % (mappings.get(f).name or f), fields)), primaryKey) + attrs['__fields__'] = escaped_fields + [primaryKey] # 全部属性名,主键一定在是最后 + attrs['__select__'] = 'select * from `%s`' % (tableName) + attrs['__insert__'] = 'insert into `%s` (%s) values (%s)' % (tableName, ', '.join('`%s`' % f for f in attrs['__fields__']), ', '.join('?' * len(mappings))) + attrs['__update__'] = 'update `%s` set %s where `%s`=?' % (tableName, ', '.join('`%s`=?' % f for f in escaped_fields), primaryKey) attrs['__delete__'] = 'delete from `%s` where `%s`=?' % (tableName, primaryKey) return type.__new__(cls, name, bases, attrs) @@ -148,9 +140,6 @@ def __getattr__(self, key): def __setattr__(self, key, value): self[key] = value - def getValue(self, key): - return getattr(self, key, None) - def getValueOrDefault(self, key): value = getattr(self, key, None) if value is None: @@ -210,20 +199,18 @@ async def find(cls, pk): async def save(self): args = list(map(self.getValueOrDefault, self.__fields__)) - args.append(self.getValueOrDefault(self.__primary_key__)) rows = await execute(self.__insert__, args) if rows != 1: logging.warn('failed to insert record: affected rows: %s' % rows) async def update(self): - args = list(map(self.getValue, self.__fields__)) - args.append(self.getValue(self.__primary_key__)) + args = list(map(self.get, self.__fields__)) rows = await execute(self.__update__, args) if rows != 1: logging.warn('failed to update by primary key: affected rows: %s' % rows) async def remove(self): - args = [self.getValue(self.__primary_key__)] + args = [self.get(self.__primary_key__)] rows = await execute(self.__delete__, args) if rows != 1: logging.warn('failed to remove by primary key: affected rows: %s' % rows)