RESTful 风格代码库
以drf代码为参考基线, 路由命名参考flask文档的说明: https://flask.palletsprojects.com/en/2.3.x/views/#method-dispatching-and-apis
技术栈:
完整版: 在bilibili-api各框架下examples案例的project目录下
项目工程结构
project/app ; go-zero创建的api服务
model ; gorm模型
scripts/ddl ; sql脚本
gencode/configuration.go ; 生成代码脚本
tests ; 接口测试
go.mod
模型
from django.db import models
# 从自动生成的模型文件导入模型
from .gen import Author
# Models declared directly in this module (Author comes from the generated module)
# =========================================
# Story: a piece of writing that belongs to one Author.
class Story(models.Model):
    title = models.CharField(max_length=200)
    content = models.TextField()
    # auto_now_add: filled in once, when the row is first created.
    created_at = models.DateTimeField(auto_now_add=True)
    # auto_now: refreshed on every save().
    updated_at = models.DateTimeField(auto_now=True)
    # Deleting an Author deletes that author's stories as well.
    author = models.ForeignKey(to=Author, on_delete=models.CASCADE)

    def __str__(self):
        # The title doubles as the object's display label.
        return self.title
# This is an auto-generated Django model module.
# You'll have to do the following manually to clean this up:
# * Rearrange models' order
# * Make sure each model has one field with primary_key=True
# * Make sure each ForeignKey and OneToOneField has `on_delete` set to the desired behavior
# * Remove `managed = False` lines if you wish to allow Django to create, modify, and delete the table
# Feel free to rename the models, but don't rename db_table values or field names.
from django.db import models
# Author table model, derived from the generated schema (table name fixed in Meta).
class Author(models.Model):
    first_name = models.CharField(max_length=50)
    last_name = models.CharField(max_length=50)
    # Optional profile fields: nullable in the DB and optional in forms.
    birth_date = models.DateField(null=True, blank=True)
    nationality = models.CharField(max_length=100, null=True, blank=True)
    biography = models.TextField(null=True, blank=True)
    # Timestamps are plain nullable columns here; Django does not fill them in.
    created_at = models.DateTimeField(null=True, blank=True)
    updated_at = models.DateTimeField(null=True, blank=True)

    class Meta:
        # Django creates/migrates this table itself.
        managed = True
        db_table = 'author'
from datetime import datetime
from sqlalchemy import String, TEXT, func, Table, MetaData, Column, Integer, Date, DateTime
from sqlalchemy.orm import Mapped, mapped_column, relationship
from typing_extensions import Annotated
from .exts import db
# Reusable column types for `Mapped[...]` declarations.
# created_at: the database fills it in when the row is inserted.
created_at = Annotated[datetime, mapped_column(nullable=False, server_default=func.now())]
# update_at: the database fills it in on insert, and `onupdate` makes the ORM
# write NOW() on every UPDATE issued through SQLAlchemy. Without `onupdate` the
# value would stay at its insert-time default forever.
update_at = Annotated[
    datetime,
    mapped_column(nullable=False, server_default=func.now(), onupdate=func.now()),
]
class Author(db.Model):
    """Author table; column definitions mirror the `author` DDL (author.sql)."""

    __tablename__ = 'author'

    id = Column('id', Integer, primary_key=True, autoincrement=True)
    # NOT NULL, as in the DDL (`VARCHAR(50) NOT NULL`) and the other models.
    first_name = Column('first_name', String(50), nullable=False)
    last_name = Column('last_name', String(50), nullable=False)
    birth_date = Column('birth_date', Date, nullable=True)
    # VARCHAR(100), matching the DDL and the Django/gorm models.
    nationality = Column('nationality', String(100), nullable=True)
    biography = Column('biography', TEXT, nullable=True)
    # Creation time, set by the database on insert.
    created_at = Column('created_at', DateTime, nullable=False, server_default=func.now())
    # Last update time. `server_onupdate` would only tell the ORM that the
    # database changes this value: it emits no DDL and sets nothing itself.
    # `onupdate` makes SQLAlchemy write NOW() on every ORM UPDATE, so the
    # timestamp is actually maintained. `server_default` matches the DDL default.
    updated_at = Column(
        'updated_at',
        DateTime,
        nullable=True,
        server_default=func.now(),
        onupdate=func.now(),
    )

    # One-to-many: an author has many stories (pairs with Story.author).
    stories: Mapped[list["Story"]] = relationship(back_populates="author")
class Story(db.Model):
    """A story row; every story references one Author via author_id."""

    __tablename__ = 'story'

    id: Mapped[int] = mapped_column(primary_key=True)
    title: Mapped[str] = mapped_column(String(200))
    content: Mapped[str] = mapped_column(TEXT)

    # Column settings for these two come from the module-level Annotated aliases.
    created_at: Mapped[created_at]
    update_at: Mapped[update_at]

    # Many-to-one link back to Author; pairs with Author.stories.
    author_id: Mapped[int] = mapped_column(db.ForeignKey('author.id'))
    author: Mapped["Author"] = relationship(back_populates="stories")
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy.orm import DeclarativeBase
# TODO 在企业项目可以考虑用flask_migrate代替手工create_all
from flask_migrate import Migrate
class Base(DeclarativeBase):
    """Declarative base shared by every model registered on ``db``."""


# Flask-SQLAlchemy extension; models subclass ``db.Model`` (built on Base).
db = SQLAlchemy(model_class=Base)
# Flask-Migrate extension; bound to the app and ``db`` in init_exts().
migrate = Migrate()


def init_exts(app):
    """Bind the module-level extensions to a Flask application.

    Args:
        app: The Flask application to initialise ``db`` and ``migrate`` on.
    """
    db.init_app(app)
    migrate.init_app(app, db)
集成gorm库
model/story.go
package model
import "gorm.io/gorm"
// Story is the hand-written story model. It embeds gorm.Model, which supplies
// the ID field used by the API logic (story.ID).
type Story struct {
	gorm.Model
	// NOTE(review): the Django and SQLAlchemy Story models in this document use
	// a 200-character title; confirm that 128 is intended here.
	Title   string `gorm:"size:128"`
	Content string
	// Foreign key to Author.ID; int32 to match the generated Author.ID type.
	AuthorID int32
	// The Author model is generated by the gen library (author.gen.go).
	Author Author
}
model/author.gen.go
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
// Code generated by gorm.io/gen. DO NOT EDIT.
package model

import (
	"time"
)

const TableNameAuthor = "author"

// Author mapped from table <author>
// NOTE(review): this file is regenerated by gen; hand edits (including these
// comments) will be overwritten.
type Author struct {
	ID          int32      `gorm:"column:id;type:int;primaryKey;autoIncrement:true" json:"id"`
	FirstName   string     `gorm:"column:first_name;type:varchar(50);not null" json:"first_name"`
	LastName    string     `gorm:"column:last_name;type:varchar(50);not null" json:"last_name"`
	BirthDate   *time.Time `gorm:"column:birth_date;type:date" json:"birth_date"`
	Nationality *string    `gorm:"column:nationality;type:varchar(100)" json:"nationality"`
	Biography   *string    `gorm:"column:biography;type:text" json:"biography"`
	CreatedAt   *time.Time `gorm:"column:created_at;type:timestamp;default:CURRENT_TIMESTAMP" json:"created_at"`
	UpdatedAt   *time.Time `gorm:"column:updated_at;type:timestamp;default:CURRENT_TIMESTAMP" json:"updated_at"`
	// Has-many side of the relation: Story.AuthorID references Author.ID.
	Stories []Story `gorm:"foreignKey:AuthorID;references:ID" json:"stories"`
}

// TableName Author's table name
func (*Author) TableName() string {
	return TableNameAuthor
}
author.sql
-- Author table (referenced by story.author_id).
CREATE TABLE IF NOT EXISTS author (
    id          INT AUTO_INCREMENT PRIMARY KEY,
    first_name  VARCHAR(50)  NOT NULL,
    last_name   VARCHAR(50)  NOT NULL,
    birth_date  DATE,
    nationality VARCHAR(100),
    biography   TEXT,
    -- filled in by MySQL on insert
    created_at  TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    -- filled in on insert and refreshed by MySQL on every UPDATE
    updated_at  TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
);
视图
application/views.py
"""ModelViewSet的源码解析"""
from rest_framework.response import Response
from rest_framework import status
from rest_framework.settings import api_settings
from rest_framework.viewsets import GenericViewSet, ModelViewSet
from .models import Story
from .serializers import StorySerializer
# Fully automatic CRUD view
# =========================================================
# The framework generates every CRUD action; the application developer only
# supplies the queryset (model) and the serializer class.
# ModelViewSet: https://www.django-rest-framework.org/api-guide/viewsets/#modelviewset
class StoryViewSet(ModelViewSet):
    serializer_class = StorySerializer
    queryset = Story.objects.all()
    # TODO: add authentication and permission examples when time permits.
# Semi-automatic CRUD view
# =========================================================
# The CRUD code is still provided by the framework; the developer assembles
# the individual mixins by hand.
from rest_framework.mixins import (
    CreateModelMixin,
    DestroyModelMixin,
    ListModelMixin,
    RetrieveModelMixin,
    UpdateModelMixin,
)


# GenericViewSet: https://www.django-rest-framework.org/api-guide/viewsets/#genericviewset
class StoryV2ViewSet(
    CreateModelMixin,    # POST on the list route        -> create
    RetrieveModelMixin,  # GET on the detail route       -> retrieve
    UpdateModelMixin,    # PUT/PATCH on the detail route -> update / partial_update
    DestroyModelMixin,   # DELETE on the detail route    -> destroy
    ListModelMixin,      # GET on the list route         -> list
    GenericViewSet,
):
    serializer_class = StorySerializer
    queryset = Story.objects.all()
# Hand-written CRUD view
# =========================================================
# Implements the concrete CRUD logic by hand.
# Per the source of routers.SimpleRouter, the router maps HTTP methods to
# viewset actions automatically:
# List route
# .as_view({'get': 'list',
#           'post': 'create',
#           })
# Detail route
# .as_view({'get': 'retrieve',
#           'put': 'update',
#           'patch': 'partial_update',
#           'delete': 'destroy',
#           })
# So implementing list, create, retrieve, update, partial_update and destroy
# is all that is needed to get full CRUD behaviour.
# mixins: https://www.django-rest-framework.org/api-guide/generic-views/#mixins
class StoryV3ViewSet(GenericViewSet):
    # Same queryset/serializer as the other Story viewsets; the action methods
    # below reproduce the corresponding DRF mixin methods.
    queryset = Story.objects.all()
    serializer_class = StorySerializer

    # Source of CreateModelMixin
    # =============================================================
    # Handles POST on the list route.
    def create(self, request, *args, **kwargs):
        # Build a serializer from the request body and validate it
        # (raise_exception=True raises instead of returning False).
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        self.perform_create(serializer)
        headers = self.get_success_headers(serializer.data)
        return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)

    # Hook: persist the validated serializer.
    def perform_create(self, serializer):
        serializer.save()

    # Build a Location header from the serialized URL field when present;
    # otherwise return no extra headers.
    def get_success_headers(self, data):
        try:
            return {'Location': str(data[api_settings.URL_FIELD_NAME])}
        except (TypeError, KeyError):
            return {}

    # Source of UpdateModelMixin
    # ========================================================================
    # Handles PUT on the detail route (PATCH arrives via partial_update).
    def update(self, request, *args, **kwargs):
        partial = kwargs.pop('partial', False)
        # 1. get_object() calls get_object_or_404; raises a 404 if the instance does not exist.
        instance = self.get_object()
        # 2. Deserialize the request data against the instance and validate it.
        serializer = self.get_serializer(instance, data=request.data, partial=partial)
        serializer.is_valid(raise_exception=True)
        # 3. Save to the database.
        self.perform_update(serializer)
        if getattr(instance, '_prefetched_objects_cache', None):
            # If 'prefetch_related' has been applied to a queryset, we need to
            # forcibly invalidate the prefetch cache on the instance.
            instance._prefetched_objects_cache = {}
        return Response(serializer.data)

    # Hook: persist the validated serializer.
    def perform_update(self, serializer):
        serializer.save()

    # Handles PATCH: same as update() but with partial validation.
    def partial_update(self, request, *args, **kwargs):
        kwargs['partial'] = True
        return self.update(request, *args, **kwargs)

    # Source of DestroyModelMixin
    # ========================================================================
    # Handles DELETE on the detail route; responds 204 with no body.
    def destroy(self, request, *args, **kwargs):
        instance = self.get_object()
        self.perform_destroy(instance)
        return Response(status=status.HTTP_204_NO_CONTENT)

    # Hook: delete the model instance.
    def perform_destroy(self, instance):
        instance.delete()

    # Source of RetrieveModelMixin
    # ===========================================
    # Handles GET on the detail route.
    def retrieve(self, request, *args, **kwargs):
        instance = self.get_object()
        serializer = self.get_serializer(instance)
        return Response(serializer.data)

    # Source of ListModelMixin
    # ===========================================
    # Handles GET on the list route; paginates when a paginator is configured.
    def list(self, request, *args, **kwargs):
        queryset = self.filter_queryset(self.get_queryset())
        page = self.paginate_queryset(queryset)
        if page is not None:
            serializer = self.get_serializer(page, many=True)
            return self.get_paginated_response(serializer.data)
        serializer = self.get_serializer(queryset, many=True)
        return Response(serializer.data)
from flask import jsonify, request, make_response
from flask.views import MethodView
from marshmallow import ValidationError
from app.models import Author
from app.exts import db
from app.schemas import AuthorSchema
class AuthorView(MethodView):
    """Collection endpoint for authors; POST creates a new author."""

    # Build the view (and its schema) once and reuse it for every request.
    init_every_request = False

    def __init__(self):
        self.schema = AuthorSchema(session=db.session)

    def post(self):
        """Validate the JSON body, insert the loaded Author, respond 201."""
        payload = request.get_json()
        try:
            loaded = self.schema.load(payload)
        except ValidationError as exc:
            # Validation failed: return the per-field error messages with 422.
            return exc.messages, 422

        db.session.add(loaded)
        db.session.commit()
        return make_response(jsonify(msg='success'), 201)
class AuthorDetailView(MethodView):
    """Single-author endpoint: GET returns the author, DELETE removes it.

    Both methods respond 404 (via ``db.one_or_404``) when no author has the
    given id.
    """

    # Build the view (and its schema) once instead of on every request,
    # the same way AuthorView does.
    init_every_request = False

    def __init__(self):
        self.schema = AuthorSchema(session=db.session)

    def get(self, id):
        """Return the author with primary key ``id`` serialized as JSON."""
        author = db.one_or_404(db.select(Author).filter_by(id=id))
        # Reuse the view's schema rather than constructing a second one per call.
        return jsonify(self.schema.dump(author))

    def delete(self, id):
        """Delete the author with primary key ``id``."""
        author = db.one_or_404(db.select(Author).filter_by(id=id))
        db.session.delete(author)
        db.session.commit()
        return jsonify(msg='success')
package logic
import (
"context"
"errors"
"project/model"
"project/app/internal/svc"
"project/app/internal/types"
"github.com/zeromicro/go-zero/core/logx"
)
type CreateStoryLogic struct {
logx.Logger
ctx context.Context
svcCtx *svc.ServiceContext
}
func NewCreateStoryLogic(ctx context.Context, svcCtx *svc.ServiceContext) *CreateStoryLogic {
return &CreateStoryLogic{
Logger: logx.WithContext(ctx),
ctx: ctx,
svcCtx: svcCtx,
}
}
func (l *CreateStoryLogic) CreateStory(req *types.CreateOrUpdateStoryReq) (resp *types.CreateOrUpdateStoryResp, err error) {
table := l.svcCtx.Model.Story
author := l.svcCtx.Model.Author
// 查找关联用户
// ==========================================
a, err := author.WithContext(l.ctx).Where(author.ID.Eq(int32(req.UserId))).First()
if err != nil {
return nil, errors.New("用户不存在: " + err.Error())
}
story := model.Story{
Title: req.Title,
Content: req.Content,
AuthorID: a.ID,
}
// 写入数据
// ===========================================
err = table.WithContext(l.ctx).Create(&story)
if err != nil {
return nil, errors.New("创建故事失败: " + err.Error())
}
// 返回数据
// ==========================================
resp = &types.CreateOrUpdateStoryResp{
Id: int64(story.ID),
Title: story.Title,
UserId: int64(a.ID),
UserName: a.FirstName + a.LastName,
}
return
}
列表查询
简单模型查询
app/api/routes/users.py
from sqlalchemy import func, select
@router.get(
    "/",
    dependencies=[Depends(get_current_active_superuser)],
    response_model=UsersPublic,
)
def read_users(session: SessionDep, skip: int = 0, limit: int = 100) -> Any:
    """Retrieve users."""
    # Total number of users, independent of skip/limit.
    count_statement = select(func.count()).select_from(User)
    count = session.execute(count_statement).scalar_one()
    # OFFSET/LIMIT without ORDER BY returns rows in an unspecified order, so
    # consecutive pages could overlap or miss users; order by id for stable paging.
    statement = select(User).order_by(User.id).offset(skip).limit(limit)
    users = session.execute(statement).scalars().all()
    return UsersPublic(data=users, count=count)
指定返回字段,在多字段模型下能提高查询性能
app/api/routes/items.py
from sqlalchemy import func, select
from sqlalchemy.orm import load_only
@router.get("/", response_model=ItemsPublic)
def read_items(
    session: SessionDep, current_user: CurrentUser, skip: int = 0, limit: int = 100
) -> Any:
    """获取项目列表"""
    # Both branches load only the columns the response needs (id, owner_id,
    # title, description), which is cheaper on wide tables. They deliberately
    # demonstrate two different techniques, so `items` holds different row
    # types per branch; it is left unannotated instead of being declared twice
    # with conflicting types (`.all()` returns a Sequence in both cases).
    if current_user.is_superuser:
        # Superusers see every item.
        count_statement = select(func.count()).select_from(Item)
        # Approach 1: select explicit columns; `.mappings()` yields dict-like
        # rows (column name -> value) instead of Item instances.
        statement = (
            select(
                Item.id,
                Item.owner_id,
                Item.title,
                Item.description,
            )
            # Without ORDER BY, OFFSET/LIMIT pages are not stable.
            .order_by(Item.id)
            .offset(skip)
            .limit(limit)
        )
        items = session.execute(statement).mappings().all()
    else:
        # Other users only see the items they own.
        count_statement = (
            select(func.count())
            .select_from(Item)
            .where(Item.owner_id == current_user.id)
        )
        # Approach 2: load Item instances with only these columns populated.
        # https://docs.sqlalchemy.org/en/20/orm/queryguide/columns.html#using-load-only-to-reduce-loaded-columns
        statement = (
            select(Item)
            .options(load_only(
                Item.id,
                Item.owner_id,
                Item.title,
                Item.description,
            ))
            .where(Item.owner_id == current_user.id)
            .order_by(Item.id)
            .offset(skip)
            .limit(limit)
        )
        items = session.execute(statement).scalars().all()
    # Total number of visible items, independent of skip/limit.
    count = session.execute(count_statement).scalar_one()
    return ItemsPublic(data=items, count=count)
app/api/internal/logic/liststorieslogic.go
// ListStories returns one page of stories (optionally filtered by a title
// keyword) together with the total number of matching stories.
func (l *ListStoriesLogic) ListStories(req *types.ListStoriesReq) (resp *types.ListStoriesResp, err error) {
	resp = new(types.ListStoriesResp)
	q := l.svcCtx.Model.Story

	// Optional fuzzy match on the title.
	var conds []gen.Condition
	if req.KeyWord != "" {
		keyword := "%" + req.KeyWord + "%"
		conds = append(conds, q.Title.Like(keyword))
	}

	// Total number of matching rows (ignores Skip/Limit).
	count, err := q.WithContext(l.ctx).Where(conds...).Count()
	if err != nil {
		return nil, errors.New("查询失败: " + err.Error())
	}

	// Fetch the requested page, preloading each story's author. Without an
	// ORDER BY the database may return rows in any order, so Skip/Limit pages
	// could overlap or miss rows; order by primary key for stable paging.
	result, err := q.WithContext(l.ctx).
		Preload(q.Author).
		Where(conds...).
		Order(q.ID).
		Offset(req.Skip).
		Limit(req.Limit).
		Find()
	if err != nil {
		return nil, errors.New("获取故事列表失败: " + err.Error())
	}

	// Build the response.
	list := make([]types.RetrieveStoryResp, len(result))
	for i, story := range result {
		list[i] = types.RetrieveStoryResp{
			Id:       story.ID,
			Title:    story.Title,
			UserId:   int64(story.AuthorID),
			UserName: story.Author.FirstName + story.Author.LastName,
		}
	}
	resp.Data = list
	resp.Total = int(count)
	return
}
路由
project/router.py
from rest_framework import routers
from application.views import StoryViewSet, StoryV2ViewSet, StoryV3ViewSet
# https://www.django-rest-framework.org/api-guide/routers/#usage
router = routers.SimpleRouter()

# Each registration produces URL names usable with reverse():
#   <basename>-list:   GET (collection), POST
#   <basename>-detail: GET (single), PATCH, PUT, DELETE
# v1 passes no basename, so DRF derives it from the queryset model ("story");
# v2/v3 need explicit, distinct basenames.
router.register(prefix='v1/stories', viewset=StoryViewSet)
router.register(prefix='v2/stories', viewset=StoryV2ViewSet, basename='storyV2')
router.register(prefix='v3/stories', viewset=StoryV3ViewSet, basename='storyV3')
from flask import Blueprint
from app.views import AuthorView, AuthorDetailView
# Public blueprint; endpoints are addressed as "public.<endpoint>" in url_for().
blue = Blueprint('public', __name__)

# Collection route: /author/
blue.add_url_rule(
    '/author/',
    endpoint='author',
    view_func=AuthorView.as_view('author'),
)
# Item route: /author/<id>; `id` is converted to int and passed to the view methods.
blue.add_url_rule(
    '/author/<int:id>',
    endpoint='author-detail',
    view_func=AuthorDetailView.as_view('author-detail'),
)
go-zero 版本无需手写路由, goctl 会自动生成
测试
application/tests.py
from django.urls import reverse
from rest_framework.test import APITestCase
from application.models import Author
class StoryTests(APITestCase):
    """API tests shared by the three Story viewset variants."""

    @classmethod
    def setUpTestData(cls) -> None:
        # Author's name fields are first_name/last_name; it has no `name` or
        # `age` field, so passing those would raise TypeError here.
        cls.author = Author.objects.create(id=1, first_name='Beatles', last_name='Band')

    def test_create_story(self):
        """
        Ensure we can create a new story.
        """
        data = {'title': 'Test Story',
                'content': 'This is a test story.',
                'author': self.author.pk}
        for name, view_name in [
            ('全自动视图', 'story-list'),
            ('半自动视图', 'storyV2-list'),
            ('全手动视图', 'storyV3-list'),
        ]:
            with self.subTest(name=name):
                # Resolve inside the subTest so a bad URL name is reported
                # against the variant it belongs to.
                url = reverse(view_name)
                response = self.client.post(url, data, format='json')
                self.assertEqual(response.status_code, 201)
                self.assertEqual(response.data['title'], 'Test Story')
                self.assertEqual(response.data['author'], self.author.pk)
import pytest
from flask import url_for
from app.models import Author
from app.exts import db
@pytest.fixture
def author1():
    """Insert the author with id=1 for the duration of one test."""
    author = Author(id=1, first_name="蝉时雨", last_name="paradox")
    db.session.add(author)
    db.session.commit()
    yield author
    # The test may already have deleted this row (e.g. via DELETE /author/1).
    # The original instance is then detached, and session.delete() on it would
    # raise, so look the row up again and delete it only if it still exists.
    existing = db.session.get(Author, 1)
    if existing is not None:
        db.session.delete(existing)
        db.session.commit()
class TestAuthorView:
    """Tests for the author collection endpoint (POST /author/)."""

    def setup_class(cls):
        # pytest passes the class object here; the URL is shared by all tests.
        cls.url = url_for(endpoint='public.author')

    def setup_method(self):
        # Create the tables before each test.
        db.create_all()

    def teardown_method(self):
        # Commit anything pending, then drop the tables again.
        db.session.commit()
        db.drop_all()

    def test_post(self, client):
        payload = {"first_name": "蝉时雨", "last_name": "paradox"}
        res = client.post(self.url, json=payload)
        assert res.status_code == 201
        assert res.json['msg'] == 'success'
class TestAuthorDetailView:
    """Tests for the single-author endpoint (GET/DELETE /author/<id>)."""

    def setup_method(self):
        # Create the tables before each test.
        db.create_all()

    def teardown_method(self):
        # Commit anything pending, then drop the tables again.
        db.session.commit()
        db.drop_all()

    def test_get(self, author1, client):
        res = client.get(url_for(endpoint='public.author-detail', id=1))
        expected = {'first_name': '蝉时雨', 'id': 1, 'last_name': 'paradox', 'stories': []}
        assert res.json == expected

    def test_delete(self, author1, client):
        res = client.delete(url_for(endpoint='public.author-detail', id=1))
        assert res.json == {'msg': 'success'}
if __name__ == '__main__':
    # Run only this file (plain pytest.main() collects from the current
    # directory) and propagate pytest's exit code to the shell.
    raise SystemExit(pytest.main([__file__]))
import pytest
from flask import Flask
from testcontainers.mysql import MySqlContainer
@pytest.fixture(scope="session")
def app(mysql):
    """Session-wide Flask app connected to the MySQL test container.

    The app context stays active for the whole test session, so tests can use
    ``db`` and ``url_for`` directly.
    """
    from app.urls import blue
    from app.exts import init_exts

    app = Flask(__name__)
    app.register_blueprint(blueprint=blue)
    config_object = {
        "TESTING": True,
        'SERVER_NAME': 'flask.app.unittest.com',
        # Flask-SQLAlchemy settings =====================
        "SQLALCHEMY_DATABASE_URI": mysql.get_connection_url(),
        'SQLALCHEMY_ECHO': False,  # True would log SQL statements and bound parameters
    }
    app.config.update(config_object)
    init_exts(app)
    # `with` already pushes the app context and pops it on exit. An extra
    # .push() would leave a second context on the stack that is never popped.
    with app.app_context():
        yield app
@pytest.fixture(scope='session')
def client(app):
    """Flask test client shared by the whole test session."""
    test_client = app.test_client()
    return test_client
@pytest.fixture(scope="session")
def mysql():
    """Start one MySQL 8.0.35 container for the session; it stops on exit."""
    with MySqlContainer('mysql:8.0.35') as container:
        yield container
app/internal/logic/liststorieslogic_test.go
package logic
import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
"go-zero-api/app/dal"
"go-zero-api/app/internal/svc"
"go-zero-api/app/internal/types"
"go-zero-api/app/testhelpers"
"go-zero-api/model"
"testing"
"time"
)
// ListStoriesSuite groups the ListStories logic tests. The shared database
// helpers (s.Init, s.DB, s.Ctx) come from the embedded testhelpers.DbSuite.
type ListStoriesSuite struct {
	testhelpers.DbSuite
	// logic under test, rebuilt for every test
	l *ListStoriesLogic
	// author inserted by SetupTest
	author *model.Author
}

// SetupSuite runs once before the suite and initialises the database helpers.
func (s *ListStoriesSuite) SetupSuite() {
	s.Init()
}
// SetupTest migrates the schema and seeds one author with two stories before
// each test.
func (s *ListStoriesSuite) SetupTest() {
	s.l = NewListStoriesLogic(s.Ctx, &svc.ServiceContext{
		Model: dal.Use(s.DB),
	})

	if err := s.DB.AutoMigrate(&model.Author{}, &model.Story{}); err != nil {
		s.FailNow(err.Error())
	}

	birthDate := time.Now()
	s.author = &model.Author{
		ID:        1,
		FirstName: "蝉",
		LastName:  "时雨",
		BirthDate: &birthDate,
	}
	authorQ := s.l.svcCtx.Model.Author
	if err := authorQ.WithContext(s.l.ctx).Create(s.author); err != nil {
		s.FailNow(err.Error())
	}

	seed := []*model.Story{
		{Title: "Test Story", Content: "This a test story.", AuthorID: s.author.ID},
		{Title: "测试故事", Content: "这是一个测试的故事。", AuthorID: s.author.ID},
	}
	storyQ := s.l.svcCtx.Model.Story
	if err := storyQ.WithContext(s.l.ctx).Create(seed...); err != nil {
		s.FailNow(err.Error())
	}
}

// TearDownTest drops the tables so every test starts from an empty schema.
func (s *ListStoriesSuite) TearDownTest() {
	if err := s.DB.Migrator().DropTable(&model.Author{}, &model.Story{}); err != nil {
		s.FailNow(err.Error())
	}
}
// TestListStories checks the total count and keyword filtering of ListStories
// against the two stories seeded in SetupTest.
func (s *ListStoriesSuite) TestListStories() {
	cases := []struct {
		Name    string
		KeyWord string
		f       func(ast *assert.Assertions, resp *types.ListStoriesResp)
	}{
		{
			Name:    "默认参数",
			KeyWord: "",
			f: func(ast *assert.Assertions, resp *types.ListStoriesResp) {
				ast.Equal(2, resp.Total)
			},
		},
		{
			Name:    "关键字搜索",
			KeyWord: "故事",
			f: func(ast *assert.Assertions, resp *types.ListStoriesResp) {
				ast.Equal(1, resp.Total)
				// Guard the index so a wrong result fails instead of panicking.
				if ast.Len(resp.Data, 1) {
					ast.Equal("测试故事", resp.Data[0].Title)
				}
			},
		},
	}
	for _, c := range cases {
		s.Run(c.Name, func() {
			resp, err := s.l.ListStories(&types.ListStoriesReq{
				KeyWord: c.KeyWord,
				Skip:    0,
				Limit:   100,
			})
			// Inside s.Run, s.Require()/s.Assert() are bound to the subtest,
			// so failures are reported against the right case. Stop on error:
			// resp would be nil and the checks below would panic.
			s.Require().NoError(err)
			if c.f != nil {
				c.f(s.Assert(), resp)
			}
		})
	}
}

// TestListStoriesSuite is the `go test` entry point for the suite.
func TestListStoriesSuite(t *testing.T) {
	suite.Run(t, new(ListStoriesSuite))
}
app/tests/api/routes/test_users.py
from fastapi.testclient import TestClient
def test_retrieve_users(
    client: TestClient, superuser_token_headers: dict[str, str], db: Session
) -> None:
    # Create two extra users so the listing is guaranteed to hold more than one.
    for _ in range(2):
        user_in = UserCreate(email=random_email(), password=random_lower_string())
        crud.create_user(session=db, user_create=user_in)

    r = client.get(f"{settings.API_V1_STR}/users/", headers=superuser_token_headers)
    body = r.json()
    assert len(body["data"]) > 1
    assert "count" in body
    assert all("email" in user for user in body["data"])