【Django】DRF源码分析之三大认证

纸上得来终觉浅,绝知此事要躬行。

前言

之前在【Django】DRF源码分析之五大模块文章中没有讲到认证模块,本章就主要来谈谈认证模块中的三大认证,首先我们先回顾一下DRF请求的流程:

  1. 前台发送请求,后台接受,进行urls.py中的url匹配,执行对应类视图调用as_view()方法
from django.conf.urls import url
from . import views

urlpatterns = [
    url(r‘^v1/users/$‘, views.User.as_view())
]
  1. 之后在APIView中调用父类as_view(),并且在闭包中调用了dispatch()方法,该方法调用的APIView类中的(该类重写了父类)
def dispatch(self, request, *args, **kwargs):
        
        ......
        # 请求模块和解析模块
        request = self.initialize_request(request, *args, **kwargs)
        
        ......

        try:
            # 三大认证模块
            self.initial(request, *args, **kwargs)

            ......

            # 响应模块
            response = handler(request, *args, **kwargs)

        except Exception as exc:
            # 异常模块
            response = self.handle_exception(exc)
        # 渲染模块
        self.response = self.finalize_response(request, response, *args, **kwargs)
        return self.response
  1. 先执行self.initialize_request(request, *args, **kwargs)请求模块,此步骤是rest_framework对request进行了扩展封装和兼容
def initialize_request(self, request, *args, **kwargs):
        """
        Returns the initial request object.
        """
        parser_context = self.get_parser_context(request)

        return Request(
            request,
            parsers=self.get_parsers(),
            authenticators=self.get_authenticators(),
            negotiator=self.get_content_negotiator(),
            parser_context=parser_context
        )
  1. 请求模块走完之后,接下来就是认证模块,执行self.initial(request, *args, **kwargs),点击源码进入。
def initial(self, request, *args, **kwargs):
        ......

        # 认证组件
        self.perform_authentication(request)

        # 权限组件
        self.check_permissions(request)

        # 限流组件
        self.check_throttles(request)

目前请求走到这里,就是我们今天需要讨论的认证模块,分为三个部分,以此来从源码进行剖析。

认证组件

源码分析

首先执行的就是self.perform_authentication(request)(认证组件),点击源码查看:

def perform_authentication(self, request):
        
        request.user

我们发现该方法只有一行代码,没有返回值,也没有赋值,也不能继续点击进入(可能会出现一堆的东西),但是我们的目的就是找认证组件认证方法,所以猜想这句话就是调用方法,是不是很可能被@property装饰了,还是通过request对象调用的,所以我们就rest_framework/request.py下面找request类(因为之前的请求模块对原生request进行了扩展就是使用的该类)中的user方法,发现源码如下:

@property
    def user(self):
        """
        Returns the user associated with the current request, as authenticated
        by the authentication classes provided to the request.
        """
        if not hasattr(self, ‘_user‘):
            with wrap_attributeerrors():
                # 没用户,认证用户
                self._authenticate()
        # 有用户,直接返回
        return self._user

发现对于认证的函数只调用了self._authenticate(),我们继续点击进入,源码分析如下图:

【Django】DRF源码分析之三大认证

结合上图我们需要分析出下面几个问题:

  1. self.authenticators是啥?
  2. authenticate(self)方法执行了什么玩意?
  3. self._not_authenticated()方法做了啥?

先解决第一个问题(self.authenticators是啥?)我们直接点击进去发现跑到了request类的__init__方法肯定不对,我们往回找,他是request的属性,记得之前请求模块对request进行了扩展,回去发现在APIView类的下面有self.initialize_request(request, *args, **kwargs)方法中有下面的代码:

Request(
    request,
    parsers=self.get_parsers(),
    authenticators=self.get_authenticators(),
    negotiator=self.get_content_negotiator(),
    parser_context=parser_context
)

发现传入了authenticators,他等于self.get_authenticators()的返回值,所以我们去查找self.get_authenticators()的源码:

def get_authenticators(self):
        """
        Instantiates and returns the list of authenticators that this view can use.
        """
        return [auth() for auth in self.authentication_classes]

发现结果是一个列表推导式,所以上图中可以进行遍历,而且列表中装的也都是对象,我们就去看看到底是什么类的对象,点击查看authentication_classes,发现是通过authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES配置的,所以我们可以去api_settings查找,结果代码如下:

‘DEFAULT_AUTHENTICATION_CLASSES‘: [
        ‘rest_framework.authentication.SessionAuthentication‘,
        ‘rest_framework.authentication.BasicAuthentication‘
    ],

默认写了两个类,ok,目前我们已经知道第一个问题(self.authenticators是啥?),他默认就是这两个类的对象列表,下面就是解决第二个问题(authenticate(self)方法执行了什么玩意?),该方法是这两个类的方法,所以我们去这两个默认的类去查看,源码查看文件rest_framework/authentication.py,下面是UML图以及类继承图:

【Django】DRF源码分析之三大认证

【Django】DRF源码分析之三大认证

由此我们发现BaseAuthentication是其他类的父类,而且每个类都有authenticate方法,我们首先查看一下BasicAuthentication类中的authenticate方法实现:

def get_authorization_header(request):
    """
    Return request‘s ‘Authorization:‘ header, as a bytestring.

    Hide some test client ickyness where the header can be unicode.
    """
    auth = request.META.get(‘HTTP_AUTHORIZATION‘, b‘‘)
    if isinstance(auth, str):
        # Work around django test client oddness
        auth = auth.encode(HTTP_HEADER_ENCODING)
    return auth

class BasicAuthentication(BaseAuthentication):
    """
    HTTP Basic authentication against username/password.
    """
    www_authenticate_realm = ‘api‘

    def authenticate(self, request):
        """
        Returns a `User` if a correct username and password have been supplied
        using HTTP Basic authentication.  Otherwise returns `None`.
        """
        auth = get_authorization_header(request).split() # 第一步:从请求头获取token信息按照空格分割

        if not auth or auth[0].lower() != b‘basic‘: # 第二步:判断我们的值格式:“basic xxxxxxxx”,就是有两段,中间空格隔开
            return None
     
        # 校验分割长度是不是等于2
        if len(auth) == 1: 
            msg = _(‘Invalid basic header. No credentials provided.‘)
            raise exceptions.AuthenticationFailed(msg)
        elif len(auth) > 2:
            msg = _(‘Invalid basic header. Credentials string should not contain spaces.‘)
            raise exceptions.AuthenticationFailed(msg)

        # 把token值按照一定规则解密
        try:
            auth_parts = base64.b64decode(auth[1]).decode(HTTP_HEADER_ENCODING).partition(‘:‘)
        except (TypeError, UnicodeDecodeError, binascii.Error):
            msg = _(‘Invalid basic header. Credentials not correctly base64 encoded.‘)
            raise exceptions.AuthenticationFailed(msg)
        
        userid, password = auth_parts[0], auth_parts[2]
        return self.authenticate_credentials(userid, password, request)

    def authenticate_credentials(self, userid, password, request=None):
        """
        Authenticate the userid and password against username and password
        with optional request for context.
        """
        credentials = {
            get_user_model().USERNAME_FIELD: userid,
            ‘password‘: password
        }
        user = authenticate(request=request, **credentials)

        if user is None:
            raise exceptions.AuthenticationFailed(_(‘Invalid username/password.‘))

        if not user.is_active:
            raise exceptions.AuthenticationFailed(_(‘User inactive or deleted.‘))

        return (user, None)

    def authenticate_header(self, request):
        return ‘Basic realm="%s"‘ % self.www_authenticate_realm

分析过程:
1. 调用get_authorization_header从请求头获取,Authorization 的值,一般就是token信息,并且按照空格分割
2. 分割完成,判断我们的第一部分是不是basic
3. 校验分割长度是不是等于2
4. 把token值按照一定规则解密,分配
5. 调用self.authenticate_credentials(userid, password, request),可以看成通过解密的信息查询用户,最终返回元祖类型的数据

目第二个问题也已经解决,得知返回的结果是一个(user,None)的元祖,然后把元祖信息拆分给request.userrequest.auth。如果其中任意一个地方发生异常都会调用self._not_authenticated(),下面我们就来看看第三个问题(self._not_authenticated()方法做了啥?)

def _not_authenticated(self):
        """
        Set authenticator, user & authtoken representing an unauthenticated request.

        Defaults are None, AnonymousUser & None.
        """
        self._authenticator = None

        if api_settings.UNAUTHENTICATED_USER:
            self.user = api_settings.UNAUTHENTICATED_USER()
        else:
            self.user = None

        if api_settings.UNAUTHENTICATED_TOKEN:
            self.auth = api_settings.UNAUTHENTICATED_TOKEN()
        else:
            self.auth = None

源码其实很简单,就是给self.userself.auth赋值,其实就相当于给request.userrequest.auth赋值。其中api_settings.UNAUTHENTICATED_USER()表示的是一个匿名用户也可以理解为游客,而api_settings.UNAUTHENTICATED_TOKEN(),默认值为None。可以在api_settings中查看‘UNAUTHENTICATED_USER‘: ‘django.contrib.auth.models.AnonymousUser‘‘UNAUTHENTICATED_TOKEN‘: None,

整个认证的过程分析完成,我们可以知道大致流程就是:

  • 未携带认证信息的用户访问(游客或匿名用户),返回None赋值给经过_not_authenticated方法,赋值为request.user和request.auth
  • 携带认证信息的用户访问,返回(user,None)赋值给request.user和request.auth
  • 携带错误认证信息或者认证信息失效的用户,抛出异常调用_not_authenticated方法,赋值给request.user和request.auth

也即是说我们可以通过在类视图的request对象直接获取当前访问的用户,判断他是登录用户还是游客。

自定义认证类

通过源码的分析,我们可以知道实现自定义认证类必要条件,继承BaseAuthentication,然后实现authenticate方法,至于验证的逻辑可以结合业务编写,最终返回(user,auth)的元祖

#继承BaseAuthentication
class MyAuthentication(BaseAuthentication):
    def authenticate(self, request):   #重写authenticate方法
        # 1. 从请求的META获取token信息
        # 2. 判断信息是否合法或这不存在
          # 2.1 不存在:表示游客,返回None
          # 2.2 存在但是错误:非法用户,抛出异常
          # 2.3 存在且正确:返回 (用户, 认证信息)

        return (user,None)
  • 全局配置(settings.py)
REST_FRAMEWORK = {
    # 认证类配置
    ‘DEFAULT_AUTHENTICATION_CLASSES‘: [
        ‘rest_framework.authentication.SessionAuthentication‘,
        ‘rest_framework.authentication.BasicAuthentication‘,
        ‘xxxx.xxxxxx.MyAuthentication‘
        # eg:‘utils.authentications.MyAuthentication‘
    ]
}
  • 局部配置(CBV)
def xxxx(APIView):
    authentication_classes = (MyAuthentication,SessionAuthentication,BasicAuthentication)
    ......
    def get():
        ......

权限组件

源码分析

经过认证组件之后我们知道request对象中保存这当前请求的用户,下面执行self.check_permissions(request)方法,点击进入源码

def check_permissions(self, request):
        """
        Check if the request should be permitted.
        Raises an appropriate exception if the request is not permitted.
        """
        for permission in self.get_permissions():
            if not permission.has_permission(request, self):
                self.permission_denied(
                    request, message=getattr(permission, ‘message‘, None)
                )

看到这个过程简直是似曾相识,对,他和认证组件一个设计模式,通过self.get_permissions()获取权限类的对象列表,然后遍历,源码:

def get_permissions(self):
        """
        Instantiates and returns the list of permissions that this view requires.
        """
        return [permission() for permission in self.permission_classes]

发现同样是一个列表推导式,查看源码发现他和认证组件就是放在一起,接着点击self.permission_classes,同样发现permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES也是由api_settings配置得到,从rest_framework/settings.py文件找到得到默认配置。

‘DEFAULT_PERMISSION_CLASSES‘: [
        ‘rest_framework.permissions.AllowAny‘,
    ],

默认配置了一个AllowAny,此时程序开始遍历权限类的对象,执行has_permission方法的到返回值,如果为False表示他没有权限,继续执行self.permission_denied()方法直接抛出异常,为True则遍历下一个,直到全部为True,遍历结束,什么也不做就表示拥有配置的所有权限。

所以首先,让我们去了解has_permission到底做了什么?以及系统默认包含了那些认证类,通过rest_framework/permissions.py可以查看所有的权限类

【Django】DRF源码分析之三大认证

【Django】DRF源码分析之三大认证

发现大致分为以下几个类,BasePermission类是其他类的父类,而且每个类都实现了has_permission方法。

class BasePermission(metaclass=BasePermissionMetaclass):
    """
    A base class from which all permission classes should inherit.
    """

    def has_permission(self, request, view):
        """
        Return `True` if permission is granted, `False` otherwise.
        """
        return True

    def has_object_permission(self, request, view, obj):
        """
        Return `True` if permission is granted, `False` otherwise.
        """
        return True


class AllowAny(BasePermission):
    """
    Allow any access.
    This isn‘t strictly required, since you could use an empty
    permission_classes list, but it‘s useful because it makes the intention
    more explicit.
    """

    def has_permission(self, request, view):
        return True


class IsAuthenticated(BasePermission):
    """
    Allows access only to authenticated users.
    """

    def has_permission(self, request, view):
        return bool(request.user and request.user.is_authenticated)


class IsAdminUser(BasePermission):
    """
    Allows access only to admin users.
    """

    def has_permission(self, request, view):
        return bool(request.user and request.user.is_staff)


class IsAuthenticatedOrReadOnly(BasePermission):
    """
    The request is authenticated as a user, or is a read-only request.
    """

    def has_permission(self, request, view):
        return bool(
            request.method in SAFE_METHODS or
            request.user and
            request.user.is_authenticated
        )

接下来分别解释一个每个类:

  • AllowAny:直接返回True,任何用户拥有权限
  • IsAuthenticated:必须是认证信息通过的用户
  • IsAdminUser:必须是认证信息通过的用户且is_staff为True的用户,数据库保存的结果为1
  • IsAuthenticatedOrReadOnly:表示通过认证用户拥有权限或者游客以及认证失败的用户只能有SAFE_METHODS属性内定义的请求方法,默认为SAFE_METHODS = (‘GET‘, ‘HEAD‘, ‘OPTIONS‘)

权限组件相对过程比较简单,因为他是建立在认证组件基础之上,下面就让我们自定义权限组件。

自定义权限组件

通过源码的分析,我们可以同样也知道实现自定义权限类必要条件,继承BasePermission,然后实现has_permission方法,最终通过判断返回True或False。

from rest_framework.permissions import BasePermission

class MyPermission(BasePermission):
    def has_permission(self, request, view):
        # 判断逻辑xxxxxxx
        # 返回True或False
        return True or Flase
  • 全局配置(settings.py)
REST_FRAMEWORK = {
    #  权限类配置
    ‘DEFAULT_PERMISSION_CLASSES‘: [
        ‘utils.permissions.MyPermission‘,
    ],
}
  • 局部配置(CBV)
def xxxx(APIView):
    permission_classes = (MyPermission,)
    .....
    def get():
        ....

限流组件

源码分析

前面的认证和权限组件处理完成之后接下来就是限流组件,代码运行到self.check_throttles(request),点击查看源码:

def check_throttles(self, request):
        """
        Check if request should be throttled.
        Raises an appropriate exception if the request is throttled.
        """
        throttle_durations = []
        for throttle in self.get_throttles():
            if not throttle.allow_request(request, self):
                throttle_durations.append(throttle.wait())

        if throttle_durations:
            # Filter out `None` values which may happen in case of config / rate
            # changes, see #1438
            durations = [
                duration for duration in throttle_durations
                if duration is not None
            ]

            duration = max(durations, default=None)
            self.throttled(request, duration)

首先定义了一个throttle_durations空列表,之后又是循环遍历self.get_throttles(),可以想象他和认证组件、权限组件应该是一个样子,返回限流类对象列表,源码如下:

def get_throttles(self):
        """
        Instantiates and returns the list of throttles that this view uses.
        """
        return [throttle() for throttle in self.throttle_classes]

果然是一个列表推导式,保存的是限流类对象,同样我们也会想到它应该也是通过api_settings配置,点击self.throttle_classes查看throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES,继续查看默认配置信息。

‘DEFAULT_THROTTLE_CLASSES‘: [],

发现结果是一个空列表,意思也就是,默认并没有采用任何一个类来限制用户请求频率。通过认证的类定义在rest_frameworks/authentication.py和权限类定义在rest_framework/permission.py,我们应该在rest_framework下面查找类似throttle类,即rest_frameworks/throttling.py,并且通过源码应该不难发现他们应该都实现了allow_request方法。

【Django】DRF源码分析之三大认证

【Django】DRF源码分析之三大认证

通过类的继承关系我们发现

相关推荐