RESTful 风格代码库

以drf代码为参考基线, 路由命名参考flask文档的说明: https://flask.palletsprojects.com/en/2.3.x/views/#method-dispatching-and-apis

技术栈: Django + Django REST framework (DRF); 列表查询示例使用 FastAPI + SQLAlchemy

完整版: 在bilibili-api各框架下examples案例的project目录下

项目工程结构 (注: 以下为 go-zero 版本示例的目录结构)

project/app ; go-zero创建的api服务
        model ; gorm模型
        scripts/ddl ; sql脚本
                gencode/configuration.go ; 生成代码脚本
        tests  ; 接口测试
        go.mod

模型

from django.db import models

# 从自动生成的模型文件导入模型
from .gen import Author


# Models declared directly in code (as opposed to the generated ones imported from .gen)
# =========================================
class Story(models.Model):
    # NOTE: keep the field declaration order stable; it determines the column
    # order in generated migrations and the field order of a ModelSerializer
    # that uses fields='__all__'.

    # Story title, at most 200 characters.
    title = models.CharField(max_length=200)
    # Story body (unbounded text).
    content = models.TextField()
    # Set automatically once, when the row is first created (auto_now_add=True).
    created_at = models.DateTimeField(auto_now_add=True)
    # Set automatically every time the instance is saved (auto_now=True).
    updated_at = models.DateTimeField(auto_now=True)
    # Owning author; deleting the Author also deletes their stories (CASCADE).
    author = models.ForeignKey(Author, on_delete=models.CASCADE)

    def __str__(self):
        # String representation of a story: its title.
        return self.title

视图

application/views.py

"""ModelViewSet的源码解析"""
from rest_framework.response import Response
from rest_framework import status
from rest_framework.settings import api_settings
from rest_framework.viewsets import GenericViewSet, ModelViewSet

from .models import Story
from .serializers import StorySerializer


# Fully automatic CRUD view
# =========================================================
# DRF supplies every CRUD action; the app developer only declares which
# queryset and which serializer to use.
# ModelViewSet: https://www.django-rest-framework.org/api-guide/viewsets/#modelviewset
class StoryViewSet(ModelViewSet):
    # Serializer used both to validate input and to render output.
    serializer_class = StorySerializer
    # Base queryset; because no basename is passed when registering, the router
    # derives the "story" URL-name prefix from this queryset's model.
    queryset = Story.objects.all()
    # TODO: add authentication and permission classes when time allows.


# 半自动CRUD视图
# =========================================================
# 此阶段CRUD代码仍然由框架生成
# 应用开发手工组装CRUD的各个混入视图
from rest_framework.mixins import CreateModelMixin, UpdateModelMixin, DestroyModelMixin, RetrieveModelMixin, ListModelMixin


# GenericViewSet: https://www.django-rest-framework.org/api-guide/viewsets/#genericviewset
# Combining the five model mixins with GenericViewSet (in this exact order)
# gives the same action set as ModelViewSet:
# create / retrieve / update / partial_update / destroy / list.
class StoryV2ViewSet(
    CreateModelMixin,
    RetrieveModelMixin,
    UpdateModelMixin,
    DestroyModelMixin,
    ListModelMixin,
    GenericViewSet,
):
    # Same serializer and queryset as the fully automatic StoryViewSet.
    serializer_class = StorySerializer
    queryset = Story.objects.all()


# Hand-written CRUD view
# =========================================================
# Every CRUD action is implemented by hand; each method group below
# reproduces the source of the DRF mixin named in its section header.

# SimpleRouter (see its source) already maps HTTP methods to viewset actions:
# list route
# .as_view({'get': 'list',
#           'post': 'create',
#           })
# detail route
# .as_view({'get': 'retrieve',
#           'put': 'update',
#           'patch': 'partial_update',
#           'delete': 'destroy',
#          })
# So implementing list, create, retrieve, update, partial_update and destroy
# is all that is needed to get full CRUD behaviour.
# mixins: https://www.django-rest-framework.org/api-guide/generic-views/#mixins
class StoryV3ViewSet(GenericViewSet):
    # Consumed by GenericViewSet's get_queryset() / get_serializer(), which
    # every action below calls.
    queryset = Story.objects.all()
    serializer_class = StorySerializer

    # Source of CreateModelMixin
    # =============================================================
    # Handles POST on the list route.
    def create(self, request, *args, **kwargs):
        # Bind the request body to a new serializer and validate it; with
        # raise_exception=True invalid input raises a ValidationError (which
        # DRF renders as a 400 response) instead of returning False.
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        self.perform_create(serializer)
        # serializer.data is read after perform_create(), so it reflects the
        # saved instance; the Location header is derived from it.
        headers = self.get_success_headers(serializer.data)
        return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)

    # Hook called by create() to persist validated data; override to
    # customize how new instances are saved.
    def perform_create(self, serializer):
        serializer.save()

    # Build a Location header from the URL field named by
    # api_settings.URL_FIELD_NAME; return no headers when the serialized
    # output lacks that field (KeyError) or is not subscriptable (TypeError).
    def get_success_headers(self, data):
        try:
            return {'Location': str(data[api_settings.URL_FIELD_NAME])}
        except (TypeError, KeyError):
            return {}

    # Source of UpdateModelMixin
    # ========================================================================
    # Handles PUT on the detail route; PATCH reaches here via partial_update().
    def update(self, request, *args, **kwargs):
        # partial_update() sets kwargs['partial'] = True; a plain PUT
        # defaults to a full (non-partial) update.
        partial = kwargs.pop('partial', False)
        # 1. get_object() looks the instance up via get_object_or_404,
        #    raising a 404 if it does not exist.
        instance = self.get_object()
        # 2. Deserialize the request data against the instance and validate it.
        serializer = self.get_serializer(instance, data=request.data, partial=partial)
        serializer.is_valid(raise_exception=True)
        # 3. Save to the database.
        self.perform_update(serializer)

        if getattr(instance, '_prefetched_objects_cache', None):
            # If 'prefetch_related' has been applied to a queryset, we need to
            # forcibly invalidate the prefetch cache on the instance.
            # This must run before serializer.data is read below so the
            # response does not serve stale prefetched relations.
            instance._prefetched_objects_cache = {}

        return Response(serializer.data)

    # Hook called by update() to persist validated changes.
    def perform_update(self, serializer):
        serializer.save()

    # Handles PATCH on the detail route: same as update() but with
    # partial=True, so fields missing from the request body are not required.
    def partial_update(self, request, *args, **kwargs):
        kwargs['partial'] = True
        return self.update(request, *args, **kwargs)

    # Source of DestroyModelMixin
    # ========================================================================
    # Handles DELETE on the detail route; responds 204 with an empty body.
    def destroy(self, request, *args, **kwargs):
        instance = self.get_object()
        self.perform_destroy(instance)
        return Response(status=status.HTTP_204_NO_CONTENT)

    # Hook called by destroy() to delete the instance.
    def perform_destroy(self, instance):
        instance.delete()

    # Source of RetrieveModelMixin
    # ===========================================
    # Handles GET on the detail route: serialize a single instance.
    def retrieve(self, request, *args, **kwargs):
        instance = self.get_object()
        serializer = self.get_serializer(instance)
        return Response(serializer.data)

    # Source of ListModelMixin
    # ===========================================
    # Handles GET on the list route.
    def list(self, request, *args, **kwargs):
        # Apply the configured filter backends to the base queryset.
        queryset = self.filter_queryset(self.get_queryset())

        # paginate_queryset() returns None when pagination is not in effect;
        # otherwise respond with a paginated envelope for the current page.
        page = self.paginate_queryset(queryset)
        if page is not None:
            serializer = self.get_serializer(page, many=True)
            return self.get_paginated_response(serializer.data)

        # No pagination: serialize the whole filtered queryset.
        serializer = self.get_serializer(queryset, many=True)
        return Response(serializer.data)

列表查询

简单模型查询

app/api/routes/users.py

from sqlalchemy import func, select

@router.get(
    "/",
    dependencies=[Depends(get_current_active_superuser)],
    response_model=UsersPublic,
)
def read_users(session: SessionDep, skip: int = 0, limit: int = 100) -> Any:
    """Retrieve a page of users together with the total user count.

    Access is guarded by the ``get_current_active_superuser`` route dependency.

    Args:
        session: Database session used for both queries.
        skip: Number of rows to skip (SQL OFFSET).
        limit: Maximum number of rows to return (SQL LIMIT).

    Returns:
        ``UsersPublic`` whose ``data`` is the requested page and whose
        ``count`` is the total number of users, independent of skip/limit.
    """
    count_statement = select(func.count()).select_from(User)
    count = session.execute(count_statement).scalar_one()

    # OFFSET/LIMIT without ORDER BY yields rows in an unspecified order, so
    # consecutive pages could overlap or skip users. Order by id for stable
    # pagination. NOTE(review): assumes User has a unique ``id`` column.
    statement = select(User).order_by(User.id).offset(skip).limit(limit)
    users = session.execute(statement).scalars().all()

    return UsersPublic(data=users, count=count)

只查询需要的字段: 在字段较多的模型中可以减少读取和传输的数据量, 从而提升查询性能

app/api/routes/items.py

from sqlalchemy import func, select
from sqlalchemy.orm import load_only

@router.get("/", response_model=ItemsPublic)
def read_items(
    session: SessionDep, current_user: CurrentUser, skip: int = 0, limit: int = 100
) -> Any:
    """获取项目列表 (Retrieve a page of items plus the total count).

    Superusers see every item; other users see only the items they own.
    The two branches also demonstrate two ways of fetching only some columns:

    1. ``select(col1, col2, ...)`` + ``.mappings()`` returns read-only,
       dict-like ``RowMapping`` rows instead of ORM objects.
    2. ``select(Item).options(load_only(...))`` returns ``Item`` instances
       whose non-listed columns are deferred (not loaded up front).

    Args:
        session: Database session used for both queries.
        current_user: The authenticated user; decides which items are visible.
        skip: Number of rows to skip (SQL OFFSET).
        limit: Maximum number of rows to return (SQL LIMIT).

    Returns:
        ``ItemsPublic`` whose ``data`` is the requested page and whose
        ``count`` is the total number of visible items, independent of
        skip/limit.
    """
    # Both page queries order by id: OFFSET/LIMIT without ORDER BY returns rows
    # in an unspecified order, so pages could overlap or skip items.
    # NOTE(review): assumes Item.id is unique (e.g. the primary key).
    if current_user.is_superuser:
        count_statement = select(func.count()).select_from(Item)
        # Approach 1: select explicit columns; each row is a RowMapping.
        statement = (
            select(Item.id, Item.owner_id, Item.title, Item.description)
            .order_by(Item.id)
            .offset(skip)
            .limit(limit)
        )
        items = session.execute(statement).mappings().all()
    else:
        count_statement = (
            select(func.count())
            .select_from(Item)
            .where(Item.owner_id == current_user.id)
        )
        # Approach 2: https://docs.sqlalchemy.org/en/20/orm/queryguide/columns.html#using-load-only-to-reduce-loaded-columns
        statement = (
            select(Item)
            .options(load_only(Item.id, Item.owner_id, Item.title, Item.description))
            .where(Item.owner_id == current_user.id)
            .order_by(Item.id)
            .offset(skip)
            .limit(limit)
        )
        items = session.execute(statement).scalars().all()

    count = session.execute(count_statement).scalar_one()
    return ItemsPublic(data=items, count=count)

路由

project/router.py

from rest_framework import routers

from application.views import StoryViewSet, StoryV2ViewSet, StoryV3ViewSet

# https://www.django-rest-framework.org/api-guide/routers/#usage
router = routers.SimpleRouter()

# Each registration produces two URL names that can be passed to reverse():
#   * <basename>-list:   GET (collection), POST
#   * <basename>-detail: GET (single item), PUT, PATCH, DELETE
# v1 omits basename, so it is derived from the queryset's model -> "story".
# v2/v3 use the same model and get explicit basenames ("storyV2", "storyV3")
# so their URL names do not collide with v1's.
router.register(prefix='v1/stories', viewset=StoryViewSet)
router.register(prefix='v2/stories', viewset=StoryV2ViewSet, basename='storyV2')
router.register(prefix='v3/stories', viewset=StoryV3ViewSet, basename='storyV3')


测试

application/tests.py

from django.urls import reverse
from rest_framework.test import APITestCase

from application.models import Author


class StoryTests(APITestCase):
    """API tests covering the three Story viewset variants (v1, v2, v3)."""

    @classmethod
    def setUpTestData(cls) -> None:
        # Class-level fixture: the author every created story points to.
        cls.author = Author(id=1, name='Beatles', age=18)
        cls.author.save()

    def test_create_story(self):
        """
        Ensure we can create a new story through every viewset variant,
        and that the story is actually persisted.
        """
        # Local import: the module-level import only brings in Author.
        from application.models import Story

        for name, view_name in [
            ('全自动视图', 'story-list'),
            ('半自动视图', 'storyV2-list'),
            ('全手动视图', 'storyV3-list')
        ]:
            with self.subTest(name=name):
                # Inside subTest so a missing route is reported per variant.
                url = reverse(view_name)
                data = {'title': 'Test Story',
                        'content': 'This is a test story.',
                        'author': self.author.pk}
                count_before = Story.objects.count()

                response = self.client.post(url, data, format='json')

                self.assertEqual(response.status_code, 201)
                self.assertEqual(response.data['title'], 'Test Story')
                self.assertEqual(response.data['author'], self.author.pk)
                # The response alone does not prove persistence: check that
                # exactly one row was written and that it holds the posted data.
                self.assertEqual(Story.objects.count(), count_before + 1)
                created = Story.objects.order_by('-pk').first()
                self.assertEqual(created.title, 'Test Story')
                self.assertEqual(created.content, 'This is a test story.')
                self.assertEqual(created.author_id, self.author.pk)

小技巧

create接口的测试思路参考drf的 example

  • 验证状态码

  • 用模型查询表中的记录数 (确认新增了一条); 或从响应中取出 id, 再用模型按该 id 查询记录

  • 获取表的第一条数据,验证名称

swagger