基于flask的restful-api后端笔记
最近使用flask写了一个前后端分离的小项目,用到的新东西不多。在这里总结下几个有些意思的小点。包括restful api的用户登录、速度限制、还有数据库的ORM使用。
基于token的用户认证
restful api是无状态的,服务端不保存任何访问的状态信息。很多提供api服务的接口会给用户一个永久的token,每次访问带上token即可,而这里的token就相于是密码了。但是作为web服务的后端,每次访问携带用户名密码显然是不可取的。
一个比较通用的方法是首次访问提供用户名密码,换取一个具有有效期的token,然后后续的访问带上这个token。token放在请求头中,可以是自定义的一个请求头,比如TOKEN;也可以是已有的,比如基础认证头Authorization。Miguel的REST-auth就是用现有的基础认证模块实现的token认证,以下实现参照REST-auth中的方法。
首先是定义数据库模型的User类,实现加密和校验方法,passlib和werkzeug.security库都可以实现。这里用的是passlib的custom_app_context实现的基于sha256算法的加密方式。
from passlib.apps import custom_app_context as pwd_context
class User(db.Document):
...
username = db.StringField(required=True, unique=True)
password_hash = db.StringField(required=True)
def hash_password(self, password):
self.password_hash = pwd_context.encrypt(password)
def verify_password(self, password):
return pwd_context.verify(password, self.password_hash)
至于token,需要由一个用户的唯一标识和当前时间经过加密生成,这正是itsdangerous库中实现的功能:
from itsdangerous import (TimedJSONWebSignatureSerializer as Serializer, BadSignature, SignatureExpired)
class User(db.Document):
...
def generate_auth_token(self, expiration=3600):
s = Serializer(current_app.config['SECRET_KEY'], expires_in=expiration)
return s.dumps({'id': self.id})
@staticmethod
def verify_auth_token(token):
s = Serializer(current_app.config['SECRET_KEY'])
try:
data = s.loads(token)
except SignatureExpired:
return None # valid token, but expired
except BadSignature:
return None # invalid token
user = User.objects(id=data['id']).first()
return user
接下来要做的就是拿用户名、密码换取token,并在后续访问中验证token。因为这里选择将token放在Authorization中,所以直接使用flask_httpauth库提供的基础验证模块。
from flask_httpauth import HTTPBasicAuth
auth = HTTPBasicAuth()
在需要认证的视图函数前加上@auth.login_required
装饰器就可以实现认证了。不过还需要我们自己实现一个验证回调函数:
@auth.verify_password
def verify_password(username_or_token, password):
# first try to authenticate by token
user = User.verify_auth_token(username_or_token)
if not user:
# try to authenticate with username/password
user = User.objects(username=username_or_token).first()
if not user or not user.verify_password(password):
return False
g.user = user
return True
登录时,将用户名密码放入Authorization中,以后只需要把token放到用户名的位置,密码留空。校验时,先试试校验token,不行再校验用户名、密码。如果通过校验,将返回的用户信息放入全局变量g
中。我们再定义一个登录的视图函数:
from flask import jsonify
from datetime import datetime
@app.route('/login')
@auth.login_required
def login():
duration = 60*60
return jsonify({
"username": g.user.username,
"email": g.user.email,
"token": g.user.generate_auth_token(duration).decode('ascii'),
"duration": duration,
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
})
还有一个问题,使用flask_httpauth,校验失败会返回401,浏览器会直接弹出基础认证的输入框。作为api,这当然是我们不想要的。所以重载错误处理函数,返回403.
@auth.error_handler
def unauthorized():
return make_response(jsonify({'error': 'Unauthorized access'}), 403)
最后,每次得到的token会在设置的duration到期后失效,那时又会强制让用户输入用户名、密码。那么如何动态地延长token的有效期,或者是动态地更换token,我还没想好如何实现。。。
速度限制
一些api资源不能无限制地提供给用户,需要在一定时间范围内限制用户的使用次数。flask的作者Armin Ronacher写了一段关于限速的snippets.
import time
from functools import update_wrapper
from flask import request, g
import redis
redis = redis.Redis(host='localhost', port=6379, db=0)
class RateLimit(object):
expiration_window = 10 # 额外的延时,抵消提交到redis的网络延迟
def __init__(self, key_prefix, limit, per, send_x_headers):
self.reset = (int(time.time()) // per) * per + per
self.key = key_prefix + str(self.reset)
self.limit = limit
self.per = per
self.send_x_headers = send_x_headers
"""
使用pipeline提交操作
about redis.pipeline: https://redis.io/topics/pipelining
"""
p = redis.pipeline()
p.incr(self.key)
p.expireat(self.key, self.reset + self.expiration_window)
self.current = min(p.execute()[0], limit)
remaining = property(lambda x: x.limit - x.current) # 剩余查询次数
over_limit = property(lambda x: x.current >= x.limit) # 是否超过限制
def get_view_rate_limit():
return getattr(g, '_view_rate_limit', None)
def on_over_limit(limit):
return 'You hit the rate limit', 400
def ratelimit(limit, per=300, send_x_headers=True, over_limit=on_over_limit,
scope_func=lambda: g.user.username, key_func=lambda: request.endpoint):
def decorator(f):
def rate_limited(*args, **kwargs):
key = 'rate-limit/%s/%s/' % (key_func(), scope_func()) # redis中的键名 默认使用路由端点和用户名
rlimit = RateLimit(key, limit, per, send_x_headers)
g._view_rate_limit = rlimit
if over_limit is not None and rlimit.over_limit:
return over_limit(rlimit)
return f(*args, **kwargs)
return update_wrapper(rate_limited, f)
return decorator
这里的关键是使用redis记录用户的访问次数,将路由端点和用户名作为redis的key值。在reset过期时间之内访问到路由端点就会把对应的value+1,然后再记录剩余次数remaining和是否超限over_limit。
加入限速后,可以用如下几个返回头提示用户:
X-RateLimit-Remaining
: 当前时间段剩余次数X-RateLimit-Limit
: 当前时间段允许的请求书数X-RateLimit-Reset
: 下次重置次数的时间
使用after_request钩子为响应添加返回头:
@app.after_request
def inject_x_rate_headers(response):
limit = get_view_rate_limit()
if limit and limit.send_x_headers:
h = response.headers
h.add('X-RateLimit-Remaining', str(limit.remaining))
h.add('X-RateLimit-Limit', str(limit.limit))
h.add('X-RateLimit-Reset', str(limit.reset))
return response
注意如果要使用g中的user变量,先用login_required装饰器将user导入g,再使用ratelimit装饰器
@app.route('/rate-limited')
@auth.login_required
@ratelimit(limit=300, per=60 * 15)
def index():
return '<h1>This is a rate limited response</h1>'
ORM的使用
之前用ORM时,只是在类中定义好各个数据库的字段,而逻辑方法,都是放到试图函数中去实现。这次需要根据用户的输入,去指定的表(collection)中查询内容,还有把采集的数据存储在不同的表中。如果再一个个判断就会显得太麻烦了。
这里用到的办法是Mixin类,在Mixin类中定义一些每个数据表都共有的方法和字段。让每个ORM类都去多重继承这个Mixin类。
import json
from datetime import datetime
from .. import db
class DataMixin(object):
created_at = db.DateTimeField(default=datetime.now())
@classmethod
def create_or_update(cls, dataset):
...
def to_dict(self):
_dict = json.loads(self.to_json())
_dict.pop('_id')
return _dict
在DataMixin类中定义好所有表都有的created_at字段,再添加两个共有的方法:保存或更新和to_dict。
class DataMixin(object):
...
@classmethod
def _serach_by_id(cls, keyword):
return = cls.objects(id=keyword).first()
@classmethod
def _search_by_textfield(cls, _type, keyword):
"""
需要子类实现
根据字段模糊匹配
"""
return
@classmethod
def search(cls, _type, keyword):
if _type == 'id':
return cls._serach_by_id(keyword)
else:
return cls._search_by_textfield(_type, keyword)
搜索时,需要指定搜索的数据表、字段、关键字,其中id字段是所有表共有的,所以把它放到Mixin类中实现,而其他字段每个表各不相同,放到自己的类中去实现:
class A(db.Document, DataMixin):
...
@classmethod
def _search_by_textfield(cls, _type, keyword):
...
再用一个str2cls函数将传入的数据表关键字转换成对应的ORM类对象。
def str2cls(str_name):
if str_name == 'A':
return A
elif str_name == 'B':
return B
这时多表的搜索就变得简单了:
@web_bp.route('/data', methods=['POST'])
def search_data():
table, _type, keyword, per_page, page = get_post_params()
Document = str2cls(table)
data = Document.search(_type, keyword)
paginate_data = Document.search(_type, keyword).paginate(per_page=per_page, page=page)
data_list = [item.to_dict() for item in paginate_data.items]
return jsonify({
"data": data_list,
"pages": paginate_data.pages,
"page": paginate_data.page,
"per_page": paginate_data.per_page,
"total": paginate_data.total,
"prev_num": paginate_data.prev_num,
"next_num": paginate_data.next_num
})
最后,还需要一个接口,根据参数将采集的数据放到各个表中保存:
@worker_bp.route('/data', methods=['POST'])
def save_data():
table, params = get_post_params()
Document = str2cls(table)
dataset = Document()
data_dict = {k.strip('data_'): v for k, v in params.iteritems() if k.startswith('data_')}
for key, value in data_dict.iteritems():
setattr(dataset, key, value)
Document.create_or_update(dataset)
将需要保存的数据参数以data_前缀命名,就可以通过上面的几行代码实现多表存储。
目前这个项目并发量能到多少呢?我之前的项目也做过后端API开发
@422886557 使用的gunicorn+nginx,因为是个小项目,所以没有太大的并发压力。数据库的操作也都很简单,所以可以参考gunicorn+nginx的并发能力,这里有篇文章可供参考 https://www.jianshu.com/p/2e713b54df0c?utm_campaign=maleskine&utm_content=note&utm_medium=seo_notes&utm_source=recommendation