基于flask的restful-api后端笔记

最近使用flask写了一个前后端分离的小项目,用到的新东西不多。在这里总结下几个有些意思的小点。包括restful api的用户登录、速度限制、还有数据库的ORM使用。

基于token的用户认证

restful api是无状态的,服务端不保存任何访问的状态信息。很多提供api服务的接口会给用户一个永久的token,每次访问带上token即可,而这里的token就相于是密码了。但是作为web服务的后端,每次访问携带用户名密码显然是不可取的。
一个比较通用的方法是首次访问提供用户名密码,换取一个具有有效期的token,然后后续的访问带上这个token。token放在请求头中,可以是自定义的一个请求头,比如TOKEN;也可以是已有的,比如基础认证头Authorization。Miguel的REST-auth就是用现有的基础认证模块实现的token认证,以下实现参照REST-auth中的方法。

首先是定义数据库模型的User类,实现加密和校验方法,passlib和werkzeug.security库都可以实现。这里用的是passlib的custom_app_context实现的基于sha256算法的加密方式。

from passlib.apps import custom_app_context as pwd_context

class User(db.Document):
    ...
    username = db.StringField(required=True,  unique=True)
    password_hash = db.StringField(required=True)

    def hash_password(self, password):
        self.password_hash = pwd_context.encrypt(password)

    def verify_password(self, password):
        return pwd_context.verify(password, self.password_hash)

至于token,需要由一个用户的唯一标识和当前时间经过加密生成,这正是itsdangerous库中实现的功能:

from itsdangerous import (TimedJSONWebSignatureSerializer as Serializer, BadSignature, SignatureExpired)

class User(db.Document):
    ...
    def generate_auth_token(self, expiration=3600):
        s = Serializer(current_app.config['SECRET_KEY'], expires_in=expiration)
        return s.dumps({'id': self.id})

    @staticmethod
    def verify_auth_token(token):
        s = Serializer(current_app.config['SECRET_KEY'])
        try:
            data = s.loads(token)
        except SignatureExpired:
            return None  # valid token, but expired
        except BadSignature:
            return None  # invalid token
        user = User.objects(id=data['id']).first()
        return user

接下来要做的就是拿用户名、密码换取token,并在后续访问中验证token。因为这里选择将token放在Authorization中,所以直接使用flask_httpauth库提供的基础验证模块。

from flask_httpauth import HTTPBasicAuth
auth = HTTPBasicAuth()

在需要认证的视图函数前加上@auth.login_required装饰器就可以实现认证了。不过还需要我们自己实现一个验证回调函数:

@auth.verify_password
def verify_password(username_or_token, password):
    # first try to authenticate by token
    user = User.verify_auth_token(username_or_token)
    if not user:
        # try to authenticate with username/password
        user = User.objects(username=username_or_token).first()
        if not user or not user.verify_password(password):
            return False
    g.user = user
    return True

登录时,将用户名密码放入Authorization中,以后只需要把token放到用户名的位置,密码留空。校验时,先试试校验token,不行再校验用户名、密码。如果通过校验,将返回的用户信息放入全局变量g中。我们再定义一个登录的视图函数:

from flask import jsonify
from datetime import datetime

@app.route('/login') 
@auth.login_required
def login():
    duration = 60*60
    return jsonify({
        "username": g.user.username,
        "email": g.user.email,
        "token": g.user.generate_auth_token(duration).decode('ascii'),
        "duration": duration,
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    })

还有一个问题,使用flask_httpauth,校验失败会返回401,浏览器会直接弹出基础认证的输入框。作为api,这当然是我们不想要的。所以重载错误处理函数,返回403.

@auth.error_handler
def unauthorized():
    return make_response(jsonify({'error': 'Unauthorized access'}), 403)

最后,每次得到的token会在设置的duration到期后失效,那时又会强制让用户输入用户名、密码。那么如何动态地延长token的有效期,或者是动态地更换token,我还没想好如何实现。。。

速度限制

一些api资源不能无限制地提供给用户,需要在一定时间范围内限制用户的使用次数。flask的作者Armin Ronacher写了一段关于限速的snippets.

import time
from functools import update_wrapper
from flask import request, g
import redis

redis = redis.Redis(host='localhost', port=6379, db=0)

class RateLimit(object):
    expiration_window = 10  # 额外的延时,抵消提交到redis的网络延迟

    def __init__(self, key_prefix, limit, per, send_x_headers):
        self.reset = (int(time.time()) // per) * per + per
        self.key = key_prefix + str(self.reset)
        self.limit = limit
        self.per = per
        self.send_x_headers = send_x_headers

        """
        使用pipeline提交操作
        about redis.pipeline: https://redis.io/topics/pipelining
        """
        p = redis.pipeline()
        p.incr(self.key)
        p.expireat(self.key, self.reset + self.expiration_window)
        self.current = min(p.execute()[0], limit)

    remaining = property(lambda x: x.limit - x.current)   # 剩余查询次数
    over_limit = property(lambda x: x.current >= x.limit)  # 是否超过限制

def get_view_rate_limit():
    return getattr(g, '_view_rate_limit', None)

def on_over_limit(limit):
    return 'You hit the rate limit', 400

def ratelimit(limit, per=300, send_x_headers=True, over_limit=on_over_limit,
              scope_func=lambda: g.user.username, key_func=lambda: request.endpoint):
    def decorator(f):
        def rate_limited(*args, **kwargs):
            key = 'rate-limit/%s/%s/' % (key_func(), scope_func())   # redis中的键名 默认使用路由端点和用户名
            rlimit = RateLimit(key, limit, per, send_x_headers)
            g._view_rate_limit = rlimit
            if over_limit is not None and rlimit.over_limit:
                return over_limit(rlimit)
            return f(*args, **kwargs)
        return update_wrapper(rate_limited, f)
    return decorator

这里的关键是使用redis记录用户的访问次数,将路由端点和用户名作为redis的key值。在reset过期时间之内访问到路由端点就会把对应的value+1,然后再记录剩余次数remaining和是否超限over_limit。
加入限速后,可以用如下几个返回头提示用户:

  • X-RateLimit-Remaining: 当前时间段剩余次数
  • X-RateLimit-Limit: 当前时间段允许的请求书数
  • X-RateLimit-Reset: 下次重置次数的时间

使用after_request钩子为响应添加返回头:

@app.after_request
def inject_x_rate_headers(response):
    limit = get_view_rate_limit()
    if limit and limit.send_x_headers:
        h = response.headers
        h.add('X-RateLimit-Remaining', str(limit.remaining))
        h.add('X-RateLimit-Limit', str(limit.limit))
        h.add('X-RateLimit-Reset', str(limit.reset))
    return response

注意如果要使用g中的user变量,先用login_required装饰器将user导入g,再使用ratelimit装饰器

@app.route('/rate-limited')
@auth.login_required
@ratelimit(limit=300, per=60 * 15)
def index():
    return '<h1>This is a rate limited response</h1>'

ORM的使用

之前用ORM时,只是在类中定义好各个数据库的字段,而逻辑方法,都是放到试图函数中去实现。这次需要根据用户的输入,去指定的表(collection)中查询内容,还有把采集的数据存储在不同的表中。如果再一个个判断就会显得太麻烦了。
这里用到的办法是Mixin类,在Mixin类中定义一些每个数据表都共有的方法和字段。让每个ORM类都去多重继承这个Mixin类。

import json
from datetime import datetime
from .. import db

class DataMixin(object):
    created_at = db.DateTimeField(default=datetime.now())

    @classmethod
    def create_or_update(cls, dataset):
        ...
    
    def to_dict(self):
        _dict = json.loads(self.to_json())
        _dict.pop('_id')
        return _dict        

在DataMixin类中定义好所有表都有的created_at字段,再添加两个共有的方法:保存或更新和to_dict。

class DataMixin(object):
    ...
    
    @classmethod
    def _serach_by_id(cls, keyword):
        return = cls.objects(id=keyword).first()

    @classmethod
    def _search_by_textfield(cls, _type, keyword):
        """
        需要子类实现
        根据字段模糊匹配
        """
        return

    @classmethod
    def search(cls, _type, keyword):
        if _type == 'id':
            return cls._serach_by_id(keyword)
        else:
            return cls._search_by_textfield(_type, keyword)

搜索时,需要指定搜索的数据表、字段、关键字,其中id字段是所有表共有的,所以把它放到Mixin类中实现,而其他字段每个表各不相同,放到自己的类中去实现:

class A(db.Document, DataMixin):
    ...

    @classmethod
    def _search_by_textfield(cls, _type, keyword):
        ...

再用一个str2cls函数将传入的数据表关键字转换成对应的ORM类对象。

def str2cls(str_name):
    if str_name == 'A':
        return A
    elif str_name == 'B':
        return B

这时多表的搜索就变得简单了:

@web_bp.route('/data', methods=['POST'])
def search_data():
    table, _type, keyword, per_page, page = get_post_params()
    Document = str2cls(table)
        data = Document.search(_type, keyword)
        paginate_data = Document.search(_type, keyword).paginate(per_page=per_page, page=page)
        data_list = [item.to_dict() for item in paginate_data.items]
        return jsonify({
            "data": data_list,
            "pages": paginate_data.pages,
            "page": paginate_data.page,
            "per_page": paginate_data.per_page,
            "total": paginate_data.total,
            "prev_num": paginate_data.prev_num,
            "next_num": paginate_data.next_num
        })

最后,还需要一个接口,根据参数将采集的数据放到各个表中保存:

@worker_bp.route('/data', methods=['POST'])
def save_data():
    table, params = get_post_params()
    Document = str2cls(table)
    dataset = Document()
    data_dict = {k.strip('data_'): v for k, v in params.iteritems() if k.startswith('data_')}
    for key, value in data_dict.iteritems():
        setattr(dataset, key, value)
    Document.create_or_update(dataset)

将需要保存的数据参数以data_前缀命名,就可以通过上面的几行代码实现多表存储。

Comments
Write a Comment
  • 422886557 reply

    目前这个项目并发量能到多少呢?我之前的项目也做过后端API开发

    • Melw00d reply

      @422886557 使用的gunicorn+nginx,因为是个小项目,所以没有太大的并发压力。数据库的操作也都很简单,所以可以参考gunicorn+nginx的并发能力,这里有篇文章可供参考 https://www.jianshu.com/p/2e713b54df0c?utm_campaign=maleskine&utm_content=note&utm_medium=seo_notes&utm_source=recommendation