通过DRF的throttle设置API的访问速率

throttling功能为DRF内置功能我们无需安装第三方包可以直接使用。

查看官方文档:http://www.django-rest-framework.org/#api-guide

image-20180627230648107

配置

按照文档 我们需要在settings.py文件进行如下配置:

1
2
3
4
5
6
7
8
9
10
REST_FRAMEWORK = {
'DEFAULT_THROTTLE_CLASSES': (
'rest_framework.throttling.AnonRateThrottle',
'rest_framework.throttling.UserRateThrottle'
),
'DEFAULT_THROTTLE_RATES': {
'anon': '100/day', # 代表匿名用户每天访问100次
'user': '1000/day'
}
}

DEFAULT_THROTTLE_CLASSES是配置限速类,AnonRateThrottle是指匿名用户(未登录用户),UserRateThrottle是指登录用户。

查看源码得知,匿名用户的判断依据可知是通过IP地址,登录用户的判断依据是Token。两者的判断依据是不一致的。

DEFAULT_THROTTLE_RATES是配置限速规则,具体可以是时分秒等。

我们看下官方文档的描述:

The rate descriptions used in DEFAULT_THROTTLE_RATES may include second, minute, hour or day as the throttle period.

使用
1
2
3
4
5
6
7
8
9
10
11
12
13
from rest_framework.response import Response
from rest_framework.throttling import UserRateThrottle
from rest_framework.views import APIView

class ExampleView(APIView):
# 配置限速类 这是只配置了登录用户
throttle_classes = (UserRateThrottle,)

def get(self, request, format=None):
content = {
'status': 'request was permitted'
}
return Response(content)

除了APIView外,我们当然可以配置到ViewSet中,同样只需要配置throttle_classes在类中即可。

原理

我们查看下AnonRateThrottle的源码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class AnonRateThrottle(SimpleRateThrottle):
"""
Limits the rate of API calls that may be made by a anonymous users.

The IP address of the request will be used as the unique cache key.
"""
scope = 'anon'

def get_cache_key(self, request, view):
if request.user.is_authenticated():
return None # Only throttle unauthenticated requests.

return self.cache_format % {
'scope': self.scope,
'ident': self.get_ident(request)
}
# 我们看到类文档描述中说的,IP地址将会是唯一键作为判断依据

接着看继承的SimpleRateThrottle部分源码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
class SimpleRateThrottle(BaseThrottle):
"""
A simple cache implementation, that only requires `.get_cache_key()`
to be overridden.

The rate (requests / seconds) is set by a `throttle` attribute on the View
class. The attribute is a string of the form 'number_of_requests/period'.

Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')

Previous request information used for throttling is stored in the cache.
"""
# 从上面的话中我们得知是将请求信息保存在缓存中
cache = default_cache
timer = time.time
cache_format = 'throttle_%(scope)s_%(ident)s'
scope = None
THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES # 这个是获取到我们在配置文件的中的配置

def get_rate(self):
"""
Determine the string representation of the allowed request rate.
"""
if not getattr(self, 'scope', None):
msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
self.__class__.__name__)
raise ImproperlyConfigured(msg)

try:
return self.THROTTLE_RATES[self.scope]
except KeyError:
msg = "No default throttle rate set for '%s' scope" % self.scope
raise ImproperlyConfigured(msg)
# 上面这个函数主要是获得配置

def parse_rate(self, rate):
"""
Given the request rate string, return a two tuple of:
<allowed number of requests>, <period of time in seconds>
"""
if rate is None:
return (None, None)
num, period = rate.split('/')
num_requests = int(num)
duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
return (num_requests, duration)

# 上面函数解析配置

def allow_request(self, request, view):
"""
Implement the check to see if the request should be throttled.

On success calls `throttle_success`.
On failure calls `throttle_failure`.
"""
if self.rate is None:
return True

self.key = self.get_cache_key(request, view)
if self.key is None:
return True

self.history = self.cache.get(self.key, [])
self.now = self.timer()

# Drop any requests from the history which have now passed the
# throttle duration
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
if len(self.history) >= self.num_requests:
return self.throttle_failure()
return self.throttle_success()
# 这个函数是真正记录每个IP的请求次数

具体如何获得请求IP呢?我们看下BaseThrottle的源码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class BaseThrottle(object):
"""
Rate throttling of requests.
"""
def allow_request(self, request, view):
"""
Return `True` if the request should be allowed, `False` otherwise.
"""
raise NotImplementedError('.allow_request() must be overridden')

def get_ident(self, request):
"""
Identify the machine making the request by parsing HTTP_X_FORWARDED_FOR
if present and number of proxies is > 0. If not use all of
HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR.
"""
xff = request.META.get('HTTP_X_FORWARDED_FOR')
remote_addr = request.META.get('REMOTE_ADDR')
# 上面这句话就是获得请求IP的
num_proxies = api_settings.NUM_PROXIES

if num_proxies is not None:
if num_proxies == 0 or xff is None:
return remote_addr
addrs = xff.split(',')
client_addr = addrs[-min(num_proxies, len(addrs))]
return client_addr.strip()

return ''.join(xff.split()) if xff else remote_addr
知识就是财富
如果您觉得文章对您有帮助, 欢迎请我喝杯水!