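# Writing a Table Sharding Plugin by Hand with MyBatis
## Preface
In large internet applications, query performance drops noticeably once a single table grows past tens of millions of rows. The usual remedy is a sharding strategy that spreads the data across several tables. MyBatis, one of the most widely used persistence frameworks in the Java world, has a plugin mechanism that can implement sharding logic cleanly. This article walks through writing a MyBatis table-sharding plugin from scratch.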
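## 1. Overview of Table Sharding
### 1.1 What is table sharding
Table sharding splits the data of one large table, according to some rule (such as user ID or time), across several smaller tables that share the same structure. These tables can live in the same database or be spread across different database servers.
### 1.2 Common sharding strategies
1. **Horizontal sharding**: split by rows; each table stores a subset of the rows
2. **Vertical sharding**: split by columns; each table stores a subset of the columns
3. **Hash sharding**: hash the sharding key and take it modulo the table count to pick the table
4. **Range sharding**: split by time range or ID range
5. **Directory sharding**: maintain an explicit mapping from sharding key to table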
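To make the difference between the key-based strategies concrete, here is a minimal sketch that routes one order id to a physical table in three ways. The table prefix `t_order`, the four-table count, the range size of 1,000 and the directory contents are all made-up values for illustration, and `ShardRouting` is a hypothetical helper class.

```java
import java.util.Map;

// Illustrative routing functions for hash, range and directory sharding.
// All names and numbers (t_order, 4 tables, range size 1000, directory entries) are hypothetical.
class ShardRouting {
    // Hash sharding: key hash modulo table count
    static String byHash(long orderId, int tableNum) {
        return "t_order_" + Math.abs(Long.hashCode(orderId) % tableNum);
    }

    // Range sharding: each table holds a contiguous block of ids
    static String byRange(long orderId, long rangeSize) {
        return "t_order_" + (orderId / rangeSize);
    }

    // Directory sharding: an explicit key -> table lookup, with a fallback table
    static String byDirectory(long orderId, Map<Long, String> directory) {
        return directory.getOrDefault(orderId, "t_order_default");
    }

    public static void main(String[] args) {
        long id = 2024L;
        System.out.println(byHash(id, 4));                                  // t_order_0
        System.out.println(byRange(id, 1000));                              // t_order_2
        System.out.println(byDirectory(id, Map.of(2024L, "t_order_vip")));  // t_order_vip
    }
}
```

Hash routing spreads keys evenly but makes range scans touch every table; range routing keeps neighbouring ids together; a directory is the most flexible but needs its mapping stored and maintained somewhere.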
### 1.3 The MyBatis plugin mechanism
MyBatis has a plugin mechanism that lets you intercept method calls on four core objects and run your own logic before and after them:
- `Executor`
- `StatementHandler`
- `ParameterHandler`
- `ResultSetHandler`
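MyBatis implements plugins by wrapping these objects in JDK dynamic proxies (`Plugin.wrap` builds a `java.lang.reflect.Proxy`). The self-contained sketch below imitates that idea without MyBatis: a hypothetical `SqlPreparer` interface stands in for `StatementHandler`, and the proxy rewrites the SQL argument before the real call runs.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;

// Hypothetical stand-in for StatementHandler: it just "prepares" a SQL string
interface SqlPreparer {
    String prepare(String sql);
}

class ProxyInterceptionDemo {
    // Wraps a target so every prepare() call sees rewritten SQL, similar in spirit to Plugin.wrap
    static SqlPreparer wrap(SqlPreparer target, String logicTable, String actualTable) {
        InvocationHandler handler = (proxy, method, args) -> {
            if (method.getName().equals("prepare")) {
                // "Before" logic: rewrite the argument (naive replace, for illustration only)
                args[0] = ((String) args[0]).replace(logicTable, actualTable);
            }
            // Delegate to the real object, as Invocation.proceed() does in MyBatis
            return method.invoke(target, args);
        };
        return (SqlPreparer) Proxy.newProxyInstance(
                SqlPreparer.class.getClassLoader(), new Class<?>[]{SqlPreparer.class}, handler);
    }

    public static void main(String[] args) {
        SqlPreparer real = sql -> "prepared: " + sql;
        SqlPreparer proxied = wrap(real, "t_order", "t_order_1");
        System.out.println(proxied.prepare("SELECT * FROM t_order"));
        // prepared: SELECT * FROM t_order_1
    }
}
```

The caller only ever sees the `SqlPreparer` interface, which is why MyBatis can only intercept methods declared on the four interfaces listed above.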
## 2. Plugin Design
### 2.1 Overall architecture
The **Mybatis Sharding Plugin** is built from four components:
- Sharding strategy interface (`ShardingStrategy`)
- Sharding annotation (`@Sharding`)
- SQL rewriter (`SqlRewriter`)
- Sharding context (`ShardingContext`)
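The architecture lists a `ShardingContext`, which none of the later code defines. One common way to build such a context is a `ThreadLocal` holder that carries routing hints for the current request. The sketch below is an assumption about what it could look like (the `forceTable`/`forcedTable`/`clear` methods are hypothetical), not part of the plugin developed later.

```java
// Hypothetical ShardingContext: a per-thread holder for routing hints.
final class ShardingContext {
    private static final ThreadLocal<String> FORCED_TABLE = new ThreadLocal<>();

    private ShardingContext() { }

    // Force the current thread's statements to a specific physical table
    static void forceTable(String actualTable) {
        FORCED_TABLE.set(actualTable);
    }

    // Returns the forced table, or null if routing should use the sharding key
    static String forcedTable() {
        return FORCED_TABLE.get();
    }

    // Call in a finally block: pooled threads would otherwise keep the hint for the next request
    static void clear() {
        FORCED_TABLE.remove();
    }
}
```

Typical use would be `ShardingContext.forceTable("t_order_3"); try { ... } finally { ShardingContext.clear(); }`, with the interceptor checking `forcedTable()` before falling back to the annotation.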
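### 2.2 Core features
1. **Table name replacement**: dynamically replace the table name in the SQL based on the sharding key
2. **Parameter parsing**: extract the sharding key value from the parameter object
3. **Result merging**: merge the results of queries that span several tables
4. **Transaction support**: keep sharded operations transactionally consistent
### 2.3 Technical challenges
- Parsing and rewriting SQL
- Strategies for extracting the sharding key value
- Sharding batch operations
- Coordinating distributed transactions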
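The first challenge, SQL rewriting, shows up even in trivial cases. The sketch below compares plain `String.replace` with a whole-word, case-insensitive regex on a hypothetical statement that also joins a `t_order_item` table.

```java
import java.util.regex.Matcher;
import java.util.regex.Pattern;

class TableNameRewriteDemo {
    // Naive: replaces every occurrence of the substring, including inside other identifiers
    static String naive(String sql, String logicTable, String actualTable) {
        return sql.replace(logicTable, actualTable);
    }

    // Whole-word, case-insensitive match: identifiers that merely start with the name are untouched
    static String wordBoundary(String sql, String logicTable, String actualTable) {
        Pattern p = Pattern.compile("\\b" + Pattern.quote(logicTable) + "\\b", Pattern.CASE_INSENSITIVE);
        return p.matcher(sql).replaceAll(Matcher.quoteReplacement(actualTable));
    }

    public static void main(String[] args) {
        String sql = "SELECT o.* FROM t_order o JOIN t_order_item i ON o.order_id = i.order_id";
        System.out.println(naive(sql, "t_order", "t_order_1"));
        // SELECT o.* FROM t_order_1 o JOIN t_order_1_item i ON o.order_id = i.order_id
        System.out.println(wordBoundary(sql, "t_order", "t_order_1"));
        // SELECT o.* FROM t_order_1 o JOIN t_order_item i ON o.order_id = i.order_id
    }
}
```

Even the word-boundary version still rewrites a string literal or alias that happens to equal the table name; only a real SQL parser avoids that.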
## 3. Step-by-Step Implementation
### 3.1 Create the Maven project
```xml
<dependencies>
    <dependency>
        <groupId>org.mybatis</groupId>
        <artifactId>mybatis</artifactId>
        <version>3.5.6</version>
    </dependency>
    <!-- Other dependencies... -->
</dependencies>
```
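### 3.2 Define the `@Sharding` annotation
The annotation marks the mapper methods that need routing and carries the routing parameters.
```java
import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Documented
@Retention(RetentionPolicy.RUNTIME)   // RUNTIME so the plugin can read it via reflection
@Target({ElementType.TYPE, ElementType.METHOD})
public @interface Sharding {
    // Logical table name used in the SQL, e.g. t_order
    String logicTable();
    // Name of the sharding key property in the parameter object
    String shardingKey();
    // Number of physical tables
    int tableNum() default 2;
    // Strategy that maps a key value to a physical table
    Class<? extends ShardingStrategy> strategy();
}
```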
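Because the annotation has `RUNTIME` retention, it is visible through reflection, which is how an interceptor can find it. The standalone sketch below uses a hypothetical, trimmed-down `@Route` annotation and a hypothetical `DemoMapper` interface (so it compiles without the other classes) to show a lookup by method name.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;

// Hypothetical, trimmed-down stand-in for @Sharding
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@interface Route {
    String logicTable();
    String shardingKey();
    int tableNum() default 2;
}

interface DemoMapper {
    @Route(logicTable = "t_order", shardingKey = "orderId", tableNum = 4)
    Object selectById(Long orderId);

    Object countAll();   // not annotated: should be left alone
}

class AnnotationLookupDemo {
    // Returns the @Route on the named method, or null if the method or annotation is absent
    static Route lookup(Class<?> mapper, String methodName) {
        for (Method m : mapper.getMethods()) {
            if (m.getName().equals(methodName)) {
                return m.getAnnotation(Route.class);
            }
        }
        return null;
    }

    public static void main(String[] args) {
        Route r = lookup(DemoMapper.class, "selectById");
        System.out.println(r.logicTable() + " / " + r.shardingKey() + " / " + r.tableNum());
        // t_order / orderId / 4
        System.out.println(lookup(DemoMapper.class, "countAll"));   // null
    }
}
```

With `RetentionPolicy.CLASS` (the default) the same lookup would silently return `null`, so every statement would bypass sharding.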
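### 3.3 Define the sharding strategy interface
```java
public interface ShardingStrategy {
    /**
     * Computes the physical table name.
     * @param logicTable    logical table name
     * @param shardingValue value of the sharding key
     * @param tableNum      number of physical tables
     * @return the physical table name
     */
    String getActualTableName(String logicTable, Object shardingValue, int tableNum);
}
```
### 3.4 Implement hash sharding
Hash sharding takes the key's `hashCode()` modulo the table count, so indexes run from `0` to `tableNum - 1`. The key type needs a stable, value-based `hashCode()` (as `Long` and `String` have); a key that falls back to the identity hash code would send equal values to different tables.
```java
public class HashShardingStrategy implements ShardingStrategy {
    @Override
    public String getActualTableName(String logicTable, Object shardingValue, int tableNum) {
        if (shardingValue == null) {
            throw new IllegalArgumentException("Sharding value must not be null for table " + logicTable);
        }
        int hash = shardingValue.hashCode();
        // Remainder first, then abs: the result is always in [0, tableNum)
        int index = Math.abs(hash % tableNum);
        return logicTable + "_" + index;
    }
}
```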
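A quick sanity check for hash routing is to count where a run of keys lands. For a small positive `Long`, `hashCode()` equals the value itself, so ids 1000 to 1009 cycle through four tables. The sketch re-implements the same `Math.abs(hash % tableNum)` rule in a standalone, hypothetical `HashDistributionDemo` class.

```java
import java.util.Map;
import java.util.TreeMap;

class HashDistributionDemo {
    // Same rule as HashShardingStrategy: abs(hashCode % tableNum)
    static int tableIndex(Object key, int tableNum) {
        return Math.abs(key.hashCode() % tableNum);
    }

    // Counts how many ids in [fromId, toIdInclusive] map to each table index
    static Map<Integer, Integer> distribution(long fromId, long toIdInclusive, int tableNum) {
        Map<Integer, Integer> counts = new TreeMap<>();
        for (long id = fromId; id <= toIdInclusive; id++) {
            counts.merge(tableIndex(id, tableNum), 1, Integer::sum);
        }
        return counts;
    }

    public static void main(String[] args) {
        System.out.println(distribution(1000, 1009, 4));         // {0=3, 1=3, 2=2, 3=2}
        System.out.println("t_order_" + tableIndex(12345L, 4));  // t_order_1
    }
}
```

Running the same count over a sample of real keys is a cheap way to spot skew before choosing `tableNum`.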
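### 3.5 Implement range sharding
The concrete range calculation depends on business rules. The version below assumes numeric ids split into fixed-size blocks; `RANGE_SIZE` is a hypothetical value to adjust.
```java
public class RangeShardingStrategy implements ShardingStrategy {
    // Hypothetical block size: ids [0, RANGE_SIZE) -> _0, [RANGE_SIZE, 2 * RANGE_SIZE) -> _1, ...
    private static final long RANGE_SIZE = 10_000_000L;

    @Override
    public String getActualTableName(String logicTable, Object shardingValue, int tableNum) {
        // Sketch handles numeric id ranges only; time ranges need their own mapping
        if (!(shardingValue instanceof Number)) {
            throw new IllegalArgumentException("Range strategy requires a numeric sharding value");
        }
        long value = ((Number) shardingValue).longValue();
        long index = value / RANGE_SIZE;
        if (value < 0 || index >= tableNum) {
            throw new IllegalArgumentException("Sharding value " + value + " falls outside the configured ranges");
        }
        return logicTable + "_" + index;
    }
}
```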
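Range sharding by time, listed in section 1.2, is often done as one table per calendar month. The sketch below assumes a hypothetical `yyyyMM` suffix (for example `t_order_202401`); this naming scheme is not defined by the plugin. Because months are open-ended, a strategy built this way would not use the `tableNum` parameter.

```java
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;

// Hypothetical monthly range routing: one physical table per calendar month
class MonthlyTableRouter {
    private static final DateTimeFormatter SUFFIX = DateTimeFormatter.ofPattern("yyyyMM");

    static String tableFor(String logicTable, Object shardingValue) {
        LocalDate date;
        if (shardingValue instanceof LocalDate) {
            date = (LocalDate) shardingValue;
        } else if (shardingValue instanceof LocalDateTime) {
            date = ((LocalDateTime) shardingValue).toLocalDate();
        } else {
            throw new IllegalArgumentException("Monthly routing needs a LocalDate or LocalDateTime");
        }
        return logicTable + "_" + date.format(SUFFIX);
    }

    public static void main(String[] args) {
        System.out.println(tableFor("t_order", LocalDate.of(2024, 1, 31)));          // t_order_202401
        System.out.println(tableFor("t_order", LocalDateTime.of(2024, 2, 1, 0, 0))); // t_order_202402
    }
}
```

Monthly tables make archiving old data easy, but a query over a date range has to be expanded to every month it covers.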
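### 3.6 Write the core interceptor
```java
import java.lang.reflect.Method;
import java.sql.Connection;
import java.util.Map;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;

@Intercepts({
    @Signature(type = StatementHandler.class,
               method = "prepare",
               args = {Connection.class, Integer.class})
})
public class ShardingPlugin implements Interceptor {
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler handler = (StatementHandler) invocation.getTarget();
        // Target is a RoutingStatementHandler whose "delegate" field holds the real handler
        // (assumes no other plugin has already wrapped it in a proxy)
        MetaObject metaObject = SystemMetaObject.forObject(handler);
        // Mapper interface and method from the statement id, e.g. com.example.OrderMapper.insert
        MappedStatement mappedStatement =
                (MappedStatement) metaObject.getValue("delegate.mappedStatement");
        String mapperId = mappedStatement.getId();
        String className = mapperId.substring(0, mapperId.lastIndexOf('.'));
        String methodName = mapperId.substring(mapperId.lastIndexOf('.') + 1);
        // Only statements whose mapper method carries @Sharding are rewritten
        Method method = findMethod(Class.forName(className), methodName);
        Sharding sharding = (method == null) ? null : method.getAnnotation(Sharding.class);
        if (sharding == null) {
            return invocation.proceed();
        }
        BoundSql boundSql = (BoundSql) metaObject.getValue("delegate.boundSql");
        String originalSql = boundSql.getSql();
        Object parameterObject = boundSql.getParameterObject();
        String newSql;
        if (parameterObject instanceof Map && ((Map<?, ?>) parameterObject).containsKey("list")) {
            // A single List argument arrives wrapped in a map under "list": route it as a batch
            newSql = handleBatchSql(originalSql, sharding, ((Map<?, ?>) parameterObject).get("list"));
        } else {
            Object shardingValue = ShardingValueResolver.resolveShardingValue(
                    parameterObject, sharding.shardingKey());
            newSql = rewriteSql(originalSql, sharding, shardingValue);
        }
        metaObject.setValue("delegate.boundSql.sql", newSql);
        return invocation.proceed();
    }

    // Statement ids hold only the method name, so matching by name is enough.
    // Returns null for generated ids such as "insert!selectKey".
    private Method findMethod(Class<?> clazz, String methodName) {
        for (Method m : clazz.getMethods()) {
            if (m.getName().equals(methodName)) {
                return m;
            }
        }
        return null;
    }

    private String rewriteSql(String sql, Sharding sharding, Object shardingValue)
            throws ReflectiveOperationException {
        ShardingStrategy strategy = sharding.strategy().getDeclaredConstructor().newInstance();
        String actualTable = strategy.getActualTableName(
                sharding.logicTable(), shardingValue, sharding.tableNum());
        return SqlRewriter.rewriteTableName(sql, sharding.logicTable(), actualTable);
    }
    // handleBatchSql(...) is covered in the batch operations section.
}
```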
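`intercept()` runs before every statement and repeats `Class.forName` plus a scan of the mapper's methods each time. One optional refinement, sketched here as an assumption rather than part of the plugin, is to cache the lookup per statement id in a `ConcurrentHashMap`. `AnnotationCache` is a hypothetical class; `Optional` is used because `ConcurrentHashMap` cannot store `null`, and most statements have no annotation.

```java
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical cache: statement id -> annotation found on the mapper method (or empty)
class AnnotationCache<A extends Annotation> {
    private final Class<A> annotationType;
    private final Map<String, Optional<A>> cache = new ConcurrentHashMap<>();

    AnnotationCache(Class<A> annotationType) {
        this.annotationType = annotationType;
    }

    Optional<A> get(String statementId) {
        return cache.computeIfAbsent(statementId, this::load);
    }

    // Splits "com.example.OrderMapper.insert" into class and method, then reflects once
    private Optional<A> load(String statementId) {
        int dot = statementId.lastIndexOf('.');
        if (dot < 0) {
            return Optional.empty();
        }
        String className = statementId.substring(0, dot);
        String methodName = statementId.substring(dot + 1);
        try {
            for (Method m : Class.forName(className).getMethods()) {
                if (m.getName().equals(methodName)) {
                    return Optional.ofNullable(m.getAnnotation(annotationType));
                }
            }
        } catch (ClassNotFoundException e) {
            // Namespace not backed by a class: treat the statement as unannotated
        }
        return Optional.empty();
    }
}
```

Inside the plugin this would replace the per-call lookup with something like `SHARDING_CACHE.get(mappedStatement.getId()).orElse(null)`, where `SHARDING_CACHE` is a hypothetical static `AnnotationCache<Sharding>` field.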
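### 3.7 Write the SQL rewriter
```java
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class SqlRewriter {
    // Replaces the logical table name as a whole word, case-insensitively.
    // This does not parse SQL: a column, alias or string literal with the same name is also
    // replaced, so a SQL parser (e.g. JSqlParser) is safer for complex statements.
    public static String rewriteTableName(String sql, String logicTable, String actualTable) {
        String regex = "(?i)\\b" + Pattern.quote(logicTable) + "\\b";
        return sql.replaceAll(regex, Matcher.quoteReplacement(actualTable));
    }

    public static String rewriteInsertSql(String sql, String logicTable, String actualTable) {
        // Hook for INSERT-specific handling; currently identical to rewriteTableName
        return rewriteTableName(sql, logicTable, actualTable);
    }
    // Other SQL rewrite methods...
}
```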
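### 3.8 Resolve the sharding key value
```java
import java.lang.reflect.Field;
import java.util.Map;

public class ShardingValueResolver {
    public static Object resolveShardingValue(Object parameterObject, String shardingKey) {
        if (parameterObject == null) {
            return null;
        }
        // @Param and multi-parameter methods arrive as a Map. MyBatis' ParamMap throws
        // on missing keys, so check containsKey before calling get.
        if (parameterObject instanceof Map) {
            Map<?, ?> map = (Map<?, ?>) parameterObject;
            return map.containsKey(shardingKey) ? map.get(shardingKey) : null;
        }
        // Otherwise read the field by reflection, walking up the class hierarchy
        for (Class<?> c = parameterObject.getClass(); c != null && c != Object.class; c = c.getSuperclass()) {
            try {
                Field field = c.getDeclaredField(shardingKey);
                field.setAccessible(true);
                return field.get(parameterObject);
            } catch (NoSuchFieldException e) {
                // Not declared here: try the superclass
            } catch (IllegalAccessException e) {
                throw new IllegalStateException("Failed to resolve sharding value", e);
            }
        }
        throw new IllegalArgumentException(
                "No field '" + shardingKey + "' on " + parameterObject.getClass().getName());
    }
}
```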
### 3.9 Handle batch operations
For a batch insert, every element must route to the same physical table; otherwise the statement is rejected. The method is meant to be called from `intercept()` with the list argument of a batch statement.
```java
// Belongs in ShardingPlugin (requires java.util.Collection to be imported there)
private String handleBatchSql(String originalSql, Sharding sharding, Object parameterObject)
        throws ReflectiveOperationException {
    if (!(parameterObject instanceof Collection)) {
        return originalSql;
    }
    Collection<?> collection = (Collection<?>) parameterObject;
    if (collection.isEmpty()) {
        return originalSql;
    }
    // Create the strategy once instead of once per element
    ShardingStrategy strategy = sharding.strategy().getDeclaredConstructor().newInstance();
    // The first element decides the target table
    Object firstValue = ShardingValueResolver.resolveShardingValue(
            collection.iterator().next(), sharding.shardingKey());
    String actualTable = strategy.getActualTableName(sharding.logicTable(), firstValue, sharding.tableNum());
    // Verify that all elements belong to the same table
    for (Object item : collection) {
        Object currentValue = ShardingValueResolver.resolveShardingValue(item, sharding.shardingKey());
        String currentTable = strategy.getActualTableName(
                sharding.logicTable(), currentValue, sharding.tableNum());
        if (!currentTable.equals(actualTable)) {
            throw new IllegalArgumentException("Batch operation must be in same sharding table");
        }
    }
    return SqlRewriter.rewriteTableName(originalSql, sharding.logicTable(), actualTable);
}
```
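Rejecting a cross-table batch is the simplest rule. A possible alternative, sketched here under the assumption that the caller can run one `INSERT` per group, is to split the batch by target table first. `BatchSplitter` is hypothetical, and the `Function` passed in stands for any routing rule, such as the hash rule above.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

class BatchSplitter {
    // Groups items by physical table, keeping the input order inside each group
    static <T> Map<String, List<T>> groupByTable(List<T> items, Function<T, String> router) {
        Map<String, List<T>> groups = new LinkedHashMap<>();
        for (T item : items) {
            groups.computeIfAbsent(router.apply(item), k -> new ArrayList<>()).add(item);
        }
        return groups;
    }

    public static void main(String[] args) {
        List<Long> orderIds = List.of(1000L, 1001L, 1002L, 1003L, 1004L, 1005L);
        Map<String, List<Long>> groups =
                groupByTable(orderIds, id -> "t_order_" + Math.abs(id.hashCode() % 4));
        System.out.println(groups);
        // {t_order_0=[1000, 1004], t_order_1=[1001, 1005], t_order_2=[1002], t_order_3=[1003]}
    }
}
```

Each group then becomes its own statement, so the batch is no longer a single atomic statement and needs the transaction handling discussed next.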
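### 3.10 Merge results
A second interceptor on `ResultSetHandler.handleResultSets` is the hook for merging. Each call sees the result of a single statement, so combining rows from several physical tables also requires running the query once per table; the two checks below are placeholders.
```java
import java.sql.Statement;
import java.util.List;
import org.apache.ibatis.executor.resultset.ResultSetHandler;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;

@Intercepts({
    @Signature(type = ResultSetHandler.class,
               method = "handleResultSets",
               args = {Statement.class})
})
public class ShardingResultMergePlugin implements Interceptor {
    @Override
    @SuppressWarnings("unchecked")
    public Object intercept(Invocation invocation) throws Throwable {
        // Results of the current statement
        List<Object> results = (List<Object>) invocation.proceed();
        // Merge only for sharded queries that span several tables
        if (isShardingQuery() && isMultiTableQuery()) {
            return mergeResults(results);
        }
        return results;
    }

    // Placeholders: how these are decided (e.g. via a ShardingContext) is left open
    private boolean isShardingQuery() { return false; }
    private boolean isMultiTableQuery() { return false; }

    private List<Object> mergeResults(List<Object> results) {
        // Implement merging (sorting, paging, aggregation) here; returned unchanged for now
        return results;
    }
}
```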
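When the same query has run once per physical table, merging for `ORDER BY ... LIMIT n` can exploit the fact that each table's list is already sorted: a k-way merge with a `PriorityQueue` produces the global top n without sorting everything. This is a standalone sketch; `ShardResultMerger`, the per-table lists and the comparator are assumed inputs, not part of the plugin.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;

class ShardResultMerger {
    // One cursor per table: the current head row plus the rest of that table's rows
    private static final class Cursor<T> {
        T head;
        final Iterator<T> rest;
        Cursor(T head, Iterator<T> rest) { this.head = head; this.rest = rest; }
    }

    // Merges per-table lists that are each sorted by 'order', returning at most 'limit' rows
    static <T> List<T> mergeSorted(List<List<T>> perTable, Comparator<? super T> order, int limit) {
        PriorityQueue<Cursor<T>> heap = new PriorityQueue<>((a, b) -> order.compare(a.head, b.head));
        for (List<T> rows : perTable) {
            Iterator<T> it = rows.iterator();
            if (it.hasNext()) {
                heap.add(new Cursor<>(it.next(), it));
            }
        }
        List<T> merged = new ArrayList<>();
        while (!heap.isEmpty() && merged.size() < limit) {
            Cursor<T> c = heap.poll();
            merged.add(c.head);
            if (c.rest.hasNext()) {      // advance this table's cursor and put it back
                c.head = c.rest.next();
                heap.add(c);
            }
        }
        return merged;
    }

    public static void main(String[] args) {
        List<List<Integer>> perTable = List.of(List.of(1, 4, 9), List.of(2, 3, 10), List.of(5));
        System.out.println(mergeSorted(perTable, Integer::compare, 4));  // [1, 2, 3, 4]
    }
}
```

For this to be correct, each table must itself return its first `offset + n` rows, since the global top n can all come from a single table.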
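### 3.11 Handle transactions
When shards live in different data sources, each needs its own connection. The manager below keeps one connection per data source in a `ThreadLocal` and commits them one by one, which is best effort rather than a true distributed transaction (that would need XA/2PC or a similar protocol).
```java
import java.sql.Connection;
import java.sql.SQLException;
import java.util.HashMap;
import java.util.Map;
import javax.sql.DataSource;

public class ShardingTransactionManager {
    private final Map<String, DataSource> shardingDataSources;
    private final ThreadLocal<Map<String, Connection>> connectionHolder = new ThreadLocal<>();

    public ShardingTransactionManager(Map<String, DataSource> shardingDataSources) {
        this.shardingDataSources = shardingDataSources;
    }
    public void beginTransaction() throws SQLException {
        // One connection per shard data source, with auto-commit off
        Map<String, Connection> connections = new HashMap<>();
        connectionHolder.set(connections);   // registered first so rollback() can clean up on failure
        for (Map.Entry<String, DataSource> entry : shardingDataSources.entrySet()) {
            Connection conn = entry.getValue().getConnection();
            connections.put(entry.getKey(), conn);
            conn.setAutoCommit(false);
        }
    }
    // Commits one connection at a time: if a later commit fails, earlier ones are already permanent
    public void commit() {
        try {
            for (Connection conn : connectionHolder.get().values()) {
                conn.commit();
            }
        } catch (SQLException e) {
            rollback();
            throw new RuntimeException(e);
        } finally {
            closeConnections();
        }
    }
    public void rollback() {
        Map<String, Connection> connections = connectionHolder.get();
        if (connections != null) {
            for (Connection conn : connections.values()) {
                try { conn.rollback(); } catch (SQLException ignored) { /* keep rolling back the rest */ }
            }
        }
        closeConnections();
    }
    private void closeConnections() {
        Map<String, Connection> connections = connectionHolder.get();
        if (connections != null) {
            for (Connection conn : connections.values()) {
                try { conn.close(); } catch (SQLException ignored) { /* best effort */ }
            }
        }
        connectionHolder.remove();
    }
}
```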
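The weak point of committing shard by shard can be shown without a database. In the sketch below, hypothetical in-memory shards log what happens and mirror the commit-then-roll-back-all shape of `commit()` above. When the second shard's commit fails, the first has already committed, so on a real JDBC connection its later rollback has nothing left to undo.

```java
import java.util.ArrayList;
import java.util.List;

class SequentialCommitDemo {
    // Hypothetical in-memory shard that logs its state changes
    static final class FakeShard {
        final String name;
        final boolean failOnCommit;
        final List<String> log;

        FakeShard(String name, boolean failOnCommit, List<String> log) {
            this.name = name;
            this.failOnCommit = failOnCommit;
            this.log = log;
        }

        void commit() {
            if (failOnCommit) {
                throw new IllegalStateException(name + " commit failed");
            }
            log.add(name + " committed");
        }

        void rollback() {
            log.add(name + " rolled back");
        }
    }

    // Commit shards in order; on the first failure, roll back all of them
    static List<String> run(List<Boolean> failures) {
        List<String> log = new ArrayList<>();
        List<FakeShard> shards = new ArrayList<>();
        for (int i = 0; i < failures.size(); i++) {
            shards.add(new FakeShard("ds" + i, failures.get(i), log));
        }
        try {
            for (FakeShard s : shards) s.commit();
        } catch (IllegalStateException e) {
            for (FakeShard s : shards) s.rollback();
        }
        return log;
    }

    public static void main(String[] args) {
        System.out.println(run(List.of(false, true)));
        // [ds0 committed, ds0 rolled back, ds1 rolled back]
    }
}
```

The log shows `ds0` committing before the failure is noticed, which is exactly the inconsistency window a two-phase protocol is designed to close.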
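## 4. Usage Example
### 4.1 Entity class
```java
import java.math.BigDecimal;

// Plain entity. ShardingPlugin reads @Sharding from mapper methods only,
// so the routing rules are declared on OrderMapper rather than here.
public class Order {
    private Long orderId;
    private String userId;
    private BigDecimal amount;
    // getters/setters...
}
```
### 4.2 Mapper interface
```java
import java.util.List;
import org.apache.ibatis.annotations.Insert;
import org.apache.ibatis.annotations.Param;
import org.apache.ibatis.annotations.Select;

public interface OrderMapper {
    @Sharding(logicTable = "t_order", shardingKey = "orderId", tableNum = 4, strategy = HashShardingStrategy.class)
    @Insert("INSERT INTO t_order(order_id, user_id, amount) VALUES(#{orderId}, #{userId}, #{amount})")
    int insert(Order order);

    @Sharding(logicTable = "t_order", shardingKey = "orderId", tableNum = 4, strategy = HashShardingStrategy.class)
    @Insert("<script>INSERT INTO t_order(order_id, user_id, amount) VALUES "
            + "<foreach collection='list' item='o' separator=','>(#{o.orderId}, #{o.userId}, #{o.amount})</foreach></script>")
    int batchInsert(List<Order> orders);

    @Sharding(logicTable = "t_order", shardingKey = "orderId", tableNum = 4, strategy = HashShardingStrategy.class)
    @Select("SELECT * FROM t_order WHERE order_id = #{orderId}")
    Order selectById(@Param("orderId") Long orderId);

    // Caution: rows are placed by orderId, so routing by userId only finds rows whose orderId
    // and userId happen to hash to the same table. A query on a non-sharding key normally has
    // to run against every t_order_N table and merge the results.
    @Sharding(logicTable = "t_order", shardingKey = "userId", tableNum = 4, strategy = HashShardingStrategy.class)
    @Select("SELECT * FROM t_order WHERE user_id = #{userId}")
    List<Order> selectByUserId(@Param("userId") String userId);
}
```
### 4.3 Spring configuration
```java
import javax.sql.DataSource;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.session.SqlSessionFactory;
import org.mybatis.spring.SqlSessionFactoryBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

@Configuration
public class MybatisConfig {
    @Bean
    public ShardingPlugin shardingPlugin() {
        return new ShardingPlugin();
    }

    @Bean
    public SqlSessionFactory sqlSessionFactory(DataSource dataSource) throws Exception {
        SqlSessionFactoryBean factoryBean = new SqlSessionFactoryBean();
        factoryBean.setDataSource(dataSource);
        // Map user_id -> userId so that SELECT * fills the entity's fields
        org.apache.ibatis.session.Configuration configuration = new org.apache.ibatis.session.Configuration();
        configuration.setMapUnderscoreToCamelCase(true);
        factoryBean.setConfiguration(configuration);
        // Register the interceptor so MyBatis wraps StatementHandler with it
        factoryBean.setPlugins(new Interceptor[]{shardingPlugin()});
        return factoryBean.getObject();
    }
}
```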
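## 5. Testing
### 5.1 Unit tests
```java
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import org.junit.jupiter.api.Test;

public class ShardingPluginTest {
    private OrderMapper orderMapper;   // injected by the test setup (not shown), e.g. a Spring test context

    @Test
    public void testInsertSharding() {
        Order order = new Order();
        order.setOrderId(12345L);
        order.setUserId("user1");
        order.setAmount(new BigDecimal("100.00"));
        orderMapper.insert(order);
        // insert and selectById both route by orderId, so the row is read back from the same table
        Order result = orderMapper.selectById(12345L);
        assertNotNull(result);
        assertEquals("user1", result.getUserId());
    }

    @Test
    public void testBatchInsert() {
        List<Order> orders = new ArrayList<>();
        for (int i = 0; i < 10; i++) {
            Order order = new Order();
            order.setOrderId(1000L + i);
            order.setUserId("user" + (i % 2));
            orders.add(order);
        }
        // orderIds 1000-1009 hash to different tables, so the batch must be rejected.
        // MyBatis wraps exceptions thrown by plugins, so check the root cause.
        Exception ex = assertThrows(Exception.class, () -> orderMapper.batchInsert(orders));
        Throwable root = ex;
        while (root.getCause() != null) {
            root = root.getCause();
        }
        assertTrue(root instanceof IllegalArgumentException);
    }
}
```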
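### 5.2 Performance test
```java
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import org.junit.jupiter.api.Test;

public class PerformanceTest {
    private OrderMapper orderMapper;   // injected by the test setup (not shown)

    @Test
    public void testShardingPerformance() {
        // Prepare 100,000 test orders
        List<Order> testData = prepareTestData(100000);
        // Insert performance: one statement per row, no JDBC batching
        long start = System.nanoTime();
        for (Order order : testData) {
            orderMapper.insert(order);
        }
        long durationMs = (System.nanoTime() - start) / 1_000_000;
        System.out.println("Insert 100000 records took: " + durationMs + "ms");
        // Query performance: point lookups by the sharding key
        start = System.nanoTime();
        for (int i = 0; i < 1000; i++) {
            orderMapper.selectById(testData.get(i).getOrderId());
        }
        durationMs = (System.nanoTime() - start) / 1_000_000;
        System.out.println("Query 1000 records took: " + durationMs + "ms");
    }

    // Hypothetical data generator: ids offset to avoid clashing with ids used in other tests
    private List<Order> prepareTestData(int count) {
        List<Order> orders = new ArrayList<>(count);
        for (int i = 0; i < count; i++) {
            Order order = new Order();
            order.setOrderId(2_000_000L + i);
            order.setUserId("user" + (i % 2));
            order.setAmount(new BigDecimal("100.00"));
            orders.add(order);
        }
        return orders;
    }
}
```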
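## 6. Comparison with Sharding-JDBC
Sharding-JDBC (now ShardingSphere-JDBC, part of Apache ShardingSphere) is a mature off-the-shelf alternative:

| Aspect | Custom plugin | Sharding-JDBC |
|---|---|---|
| Learning cost | High | Low |
| Flexibility | Very high | High |
| Feature completeness | Must be built yourself | Complete |
| Performance | Depends on the implementation | Well optimized |
| Maintenance cost | High | Low |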
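A custom plugin is a good fit when:
- the sharding requirements are unusual
- deep customization is needed
- you want fewer third-party dependencies

Sharding-JDBC is a good fit when:
- you need standard sharding features quickly
- you need solid transaction support
- the team has limited capacity to build and maintain its own middleware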
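## 7. Summary
This article built a working MyBatis table-sharding plugin step by step: from basic sharding strategies to batch support, and from core SQL rewriting to performance testing. Result merging and multi-datasource transactions are only outlined and need further work before production use. The goal is to help readers understand the MyBatis plugin mechanism and table sharding well enough to apply them in real projects.

Best-practice recommendations:
1. In simple scenarios, prefer a mature database and table sharding framework.
2. For heavily customized scenarios, consider a home-grown plugin.
3. Test and verify thoroughly.
4. Set up comprehensive monitoring.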