rustjr-account-management/tests/unit/concurrency_tests.rs
tangweijie 137126c335 feat: 添加安全配置、API文档和错误码体系
- 添加JWT/加密/速率限制安全配置
- 为所有API添加OpenAPI文档注解
- 建立统一的6位错误码体系
- 实现账务原子更新(乐观锁重试机制)
- 添加Swagger UI和请求ID中间件

Ref: #安全配置 #API文档 #错误处理
2026-01-06 10:28:35 +08:00

514 lines
17 KiB
Rust
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! 并发控制单元测试
//!
//! 测试场景:
//! - 乐观锁版本冲突检测
//! - 原子更新重试机制
//! - 并发扣款安全性
//!
//! 测试策略:
//! 1. 使用内存模拟存储验证并发控制逻辑
//! 2. 模拟并发场景验证重试机制
use std::sync::Arc;
use tokio::sync::Mutex;
use rust_decimal::Decimal;
use rust_decimal_macros::dec;
use rustjr::domain::ledger::entity::{AccountBalance, DeductionResult};
use rustjr::domain::ledger::repository::{AccountBalanceRepository};
use rustjr::domain::account::AccountType;
use rustjr::error::AppError;
/// In-memory stand-in for the balance repository, used to exercise the
/// concurrency-control (optimistic version check) logic without a database.
struct MockBalanceRepository {
    /// Account balance rows, held behind a shared tokio async mutex.
    store: Arc<Mutex<Vec<AccountBalance>>>,
}
impl MockBalanceRepository {
    /// Creates a repository whose store starts out empty.
    fn new() -> Self {
        let store = Arc::new(Mutex::new(Vec::new()));
        Self { store }
    }

    /// Replaces everything in the store with one fixture row for account
    /// `1001` / `AccountType::Virtual` at version 1.
    ///
    /// `personal` and `labor` seed the two sub-balances. The bank, system and
    /// available balances are all seeded with their sum. The frozen balance,
    /// frozen amount and transit amount start at zero.
    async fn init_test_data(&self, personal: Decimal, labor: Decimal) {
        let total = personal + labor;
        let fixture = AccountBalance {
            id: 1,
            account_id: 1001,
            account_type: AccountType::Virtual,
            personal_balance: personal,
            labor_balance: labor,
            frozen_balance: Decimal::ZERO,
            bank_balance: total,
            transit_amount: Decimal::ZERO,
            system_balance: total,
            available_balance: total,
            frozen_amount: Decimal::ZERO,
            version: 1,
            updated_at: chrono::Utc::now(),
        };
        *self.store.lock().await = vec![fixture];
    }
}
#[async_trait::async_trait]
impl AccountBalanceRepository for MockBalanceRepository {
async fn find_by_account(
&self,
account_id: i64,
account_type: AccountType,
) -> Result<Option<AccountBalance>, rustjr::error::AppError> {
let store = self.store.lock().await;
Ok(store.iter()
.find(|b| b.account_id == account_id && b.account_type == account_type)
.cloned())
}
async fn get_or_create(
&self,
account_id: i64,
account_type: AccountType,
) -> Result<AccountBalance, rustjr::error::AppError> {
if let Some(balance) = self.find_by_account(account_id, account_type).await? {
return Ok(balance);
}
Ok(AccountBalance::new(account_id, account_type))
}
async fn find_for_update(
&self,
account_id: i64,
account_type: AccountType,
) -> Result<Option<AccountBalance>, rustjr::error::AppError> {
// 模拟行级锁:直接返回当前值
self.find_by_account(account_id, account_type).await
}
async fn get_or_create_for_update(
&self,
account_id: i64,
account_type: AccountType,
) -> Result<AccountBalance, rustjr::error::AppError> {
self.get_or_create(account_id, account_type).await
}
async fn update(&self, balance: &AccountBalance) -> Result<(), rustjr::error::AppError> {
let mut store = self.store.lock().await;
if let Some(existing) = store.iter_mut().find(|b| b.id == balance.id) {
if existing.version != balance.version {
return Err(AppError::ConcurrentModification);
}
let mut new_balance = balance.clone();
new_balance.version += 1;
new_balance.updated_at = chrono::Utc::now();
*existing = new_balance;
} else {
let mut new_balance = balance.clone();
new_balance.version = 1;
new_balance.updated_at = chrono::Utc::now();
store.push(new_balance);
}
Ok(())
}
async fn atomic_update(
&self,
account_id: i64,
account_type: AccountType,
expected_version: i32,
update_fn: impl FnOnce(&mut AccountBalance) -> Result<(), rustjr::error::AppError>,
) -> Result<(), rustjr::error::AppError> {
let mut store = self.store.lock().await;
if let Some(existing) = store.iter_mut()
.find(|b| b.account_id == account_id && b.account_type == account_type)
{
if existing.version != expected_version {
return Err(AppError::ConcurrentModification);
}
update_fn(existing)?;
existing.version += 1;
existing.updated_at = chrono::Utc::now();
} else {
let mut new_balance = AccountBalance::new(account_id, account_type);
update_fn(&mut new_balance)?;
new_balance.version = 1;
new_balance.updated_at = chrono::Utc::now();
store.push(new_balance);
}
Ok(())
}
async fn batch_update(
&self,
balances: &[AccountBalance],
) -> Result<(), rustjr::error::AppError> {
for balance in balances {
self.update(balance).await?;
}
Ok(())
}
async fn freeze(
&self,
account_id: i64,
account_type: AccountType,
amount: Decimal,
) -> Result<(), rustjr::error::AppError> {
self.atomic_update(
account_id,
account_type,
0,
|balance| {
let available = balance.total_available();
if available < amount {
return Err(AppError::InsufficientBalance {
available,
required: amount,
});
}
let mut remaining = amount;
let from_personal = remaining.min(balance.personal_balance);
balance.personal_balance -= from_personal;
remaining -= from_personal;
if remaining > Decimal::ZERO {
let from_labor = remaining.min(balance.labor_balance);
balance.labor_balance -= from_labor;
}
balance.frozen_balance += amount;
balance.sync_legacy_fields();
Ok(())
},
).await
}
async fn unfreeze(
&self,
account_id: i64,
account_type: AccountType,
amount: Decimal,
) -> Result<(), rustjr::error::AppError> {
self.atomic_update(
account_id,
account_type,
0,
|balance| {
let unfreeze_amount = amount.min(balance.frozen_balance);
balance.frozen_balance -= unfreeze_amount;
balance.personal_balance += unfreeze_amount;
balance.sync_legacy_fields();
Ok(())
},
).await
}
}
// ==================== 乐观锁版本测试 ====================
/// An `atomic_update` carrying a stale version must be rejected with
/// `ConcurrentModification`, and the stored row must be left untouched.
#[tokio::test]
async fn test_optimistic_lock_version_mismatch() {
    let repo = MockBalanceRepository::new();
    repo.init_test_data(dec!(1000.00), dec!(500.00)).await;

    let result = repo
        .atomic_update(
            1001,
            AccountType::Virtual,
            999, // deliberately wrong: the fixture is seeded at version 1
            |balance| {
                balance.personal_balance -= dec!(100.00);
                Ok(())
            },
        )
        .await;
    assert!(
        matches!(result, Err(AppError::ConcurrentModification)),
        "expected ConcurrentModification, got {:?}",
        result
    );

    // The rejected update must not have been applied or bumped the version.
    let balance = repo
        .find_by_account(1001, AccountType::Virtual)
        .await
        .unwrap()
        .unwrap();
    assert_eq!(balance.personal_balance, dec!(1000.00));
    assert_eq!(balance.version, 1);
}
#[tokio::test]
async fn test_optimistic_lock_success() {
let repo = MockBalanceRepository::new();
repo.init_test_data(dec!(1000.00), dec!(500.00)).await;
// 正确的版本号
let result = repo.atomic_update(
1001,
AccountType::Virtual,
1, // 正确的版本号
|balance| {
balance.personal_balance -= dec!(100.00);
balance.bank_balance -= dec!(100.00);
balance.sync_legacy_fields();
Ok(())
},
).await;
assert!(result.is_ok());
// 验证余额已更新
let balance = repo.find_by_account(1001, AccountType::Virtual).await.unwrap().unwrap();
assert_eq!(balance.personal_balance, dec!(900.00));
assert_eq!(balance.version, 2); // 版本号递增
}
// ==================== 并发扣款测试 ====================
#[tokio::test]
async fn test_concurrent_deduction_safety() {
let repo = Arc::new(MockBalanceRepository::new());
repo.init_test_data(dec!(1000.00), dec!(500.00)).await;
let mut handles = Vec::new();
// 并发发起 10 次扣款,每次扣 100
for _ in 0..10 {
let repo = repo.clone();
handles.push(tokio::spawn(async move {
loop {
// 获取当前余额
let balance = repo.find_by_account(1001, AccountType::Virtual).await.unwrap().unwrap();
let available = balance.total_available();
if available < dec!(100.00) {
// 余额不足,跳过
break Ok::<_, AppError>(());
}
// 尝试原子更新
let result = repo.atomic_update(
1001,
AccountType::Virtual,
balance.version,
|b| {
if b.total_available() < dec!(100.00) {
return Err(AppError::InsufficientBalance {
available: b.total_available(),
required: dec!(100.00),
});
}
let mut remaining = dec!(100.00);
let from_personal = remaining.min(b.personal_balance);
b.personal_balance -= from_personal;
remaining -= from_personal;
if remaining > Decimal::ZERO {
let from_labor = remaining.min(b.labor_balance);
b.labor_balance -= from_labor;
}
b.bank_balance -= dec!(100.00);
b.sync_legacy_fields();
Ok(())
},
).await;
match result {
Ok(()) => break Ok(()),
Err(AppError::ConcurrentModification) => {
// 版本冲突,重试
continue;
}
Err(e) => break Err(e),
}
}
));
}
// 等待所有任务完成
let results: Vec<Result<(), AppError>> = futures::future::join_all(handles)
.await
.into_iter()
.map(|r| r.unwrap())
.collect();
// 统计成功次数
let success_count = results.iter().filter(|r| r.is_ok()).count();
// 验证最终余额
let final_balance = repo.find_by_account(1001, AccountType::Virtual).await.unwrap().unwrap();
// 应该有 15 次成功扣款(初始余额 1500每次扣 100
assert_eq!(success_count, 15);
assert_eq!(final_balance.personal_balance, Decimal::ZERO);
assert_eq!(final_balance.labor_balance, Decimal::ZERO);
assert_eq!(final_balance.bank_balance, Decimal::ZERO);
}
// ==================== 重试机制测试 ====================
#[tokio::test]
async fn test_retry_on_version_conflict() {
let repo = Arc::new(MockBalanceRepository::new());
repo.init_test_data(dec!(10000.00), dec!(0.00)).await;
let mut attempt_count = 0;
let retry_count = Arc::new(Mutex::new(0));
let mut handles = Vec::new();
// 并发发起 5 次更新,每次增加版本冲突的可能性
for i in 0..5 {
let repo = repo.clone();
let retry_count = retry_count.clone();
handles.push(tokio::spawn(async move {
for retry in 0..10 {
let balance = repo.find_by_account(1001, AccountType::Virtual).await.unwrap().unwrap();
// 每次更新增加版本冲突的可能性
let result = repo.atomic_update(
1001,
AccountType::Virtual,
balance.version,
|b| {
b.personal_balance += dec!(1.00);
b.bank_balance += dec!(1.00);
b.sync_legacy_fields();
Ok(())
},
).await;
match result {
Ok(()) => {
attempt_count += 1;
break;
}
Err(AppError::ConcurrentModification) => {
if retry == 9 {
*retry_count.lock().await += 1;
}
tokio::time::sleep(std::time::Duration::from_millis(1)).await;
}
Err(e) => {
panic!("Unexpected error: {:?}", e);
}
}
}
Ok::<_, AppError>(())
}));
}
// 等待所有任务完成
futures::future::join_all(handles).await;
// 所有更新都应该成功(通过重试)
let final_balance = repo.find_by_account(1001, AccountType::Virtual).await.unwrap().unwrap();
assert_eq!(final_balance.personal_balance, dec!(10005.00));
}
// ==================== 冻结解冻并发测试 ====================
#[tokio::test]
async fn test_concurrent_freeze_safety() {
let repo = Arc::new(MockBalanceRepository::new());
repo.init_test_data(dec!(1000.00), dec!(1000.00)).await; // 总共 2000
let mut handles = Vec::new();
// 并发发起 10 次冻结,每次冻结 100
for _ in 0..10 {
let repo = repo.clone();
handles.push(tokio::spawn(async move {
loop {
let balance = repo.find_by_account(1001, AccountType::Virtual).await.unwrap().unwrap();
let available = balance.total_available();
if available < dec!(100.00) {
break Ok::<_, AppError>(());
}
let result = repo.freeze(1001, AccountType::Virtual, dec!(100.00)).await;
match result {
Ok(()) => break Ok(()),
Err(AppError::ConcurrentModification) => {
tokio::time::sleep(std::time::Duration::from_millis(1)).await;
}
Err(e) => break Err(e),
}
}
}));
}
// 等待所有任务完成
let results: Vec<Result<(), AppError>> = futures::future::join_all(handles)
.await
.into_iter()
.map(|r| r.unwrap())
.collect();
let success_count = results.iter().filter(|r| r.is_ok()).count();
// 验证最终状态
let final_balance = repo.find_by_account(1001, AccountType::Virtual).await.unwrap().unwrap();
// 应该有 20 次成功冻结(初始余额 2000每次冻结 100
assert_eq!(success_count, 20);
assert_eq!(final_balance.personal_balance + final_balance.labor_balance, Decimal::ZERO);
assert_eq!(final_balance.frozen_balance, dec!(2000.00));
}
// ==================== 余额一致性并发测试 (balance consistency under contention) ====================
#[tokio::test]
async fn test_concurrent_balance_consistency() {
// 测试在极端并发场景下余额的一致性
let repo = Arc::new(MockBalanceRepository::new());
repo.init_test_data(dec!(1000000.00), dec!(0.00)).await; // 初始 100 万
let num_tasks = 100;
let amount_per_task = dec!(10000.00); // 每个任务扣 1 万
let mut handles = Vec::new();
for _ in 0..num_tasks {
let repo = repo.clone();
handles.push(tokio::spawn(async move {
loop {
let balance = repo.find_by_account(1001, AccountType::Virtual).await.unwrap().unwrap();
let result = repo.atomic_update(
1001,
AccountType::Virtual,
balance.version,
|b| {
if b.personal_balance < amount_per_task {
return Err(AppError::InsufficientBalance {
available: b.personal_balance,
required: amount_per_task,
});
}
b.personal_balance -= amount_per_task;
b.bank_balance -= amount_per_task;
b.sync_legacy_fields();
Ok(())
},
).await;
match result {
Ok(()) => break,
Err(AppError::ConcurrentModification) => {
tokio::time::sleep(std::time::Duration::from_micros(100)).await;
}
Err(AppError::InsufficientBalance { .. }) => break,
Err(e) => panic!("Unexpected error: {:?}", e),
}
}
}));
}
// 等待所有任务完成
futures::future::join_all(handles).await;
// 验证最终余额
let final_balance = repo.find_by_account(1001, AccountType::Virtual).await.unwrap().unwrap();
// 验证不变量
assert!(final_balance.validate_invariant().is_ok());
// 验证余额计算(应该有 100 万 - 100万 = 0
let total_used = dec!(1000000.00) - final_balance.personal_balance;
assert_eq!(final_balance.personal_balance, Decimal::ZERO);
assert_eq!(final_balance.bank_balance, Decimal::ZERO);
}