动态切换数据库源码解析

动态切库可用于SaaS环境,多租户环境

所以浏览器的每次请求都有可能是不同租户,需要动态切换数据库来支持业务场景。
又所以每次请求都需要识别是哪个租户,这里我们用到了ThreadLocal,以此来保存线程的本地变量,携带上租户的一些信息。而租户的信息可以从Session或Token中获取,或者是url路径参数

准备好以上条件信息 我们开始秀吧

创建 DruidDynamicDataSource 继承自 AbstractRoutingDataSource

org.springframework.jdbc.datasource.lookup.AbstractRoutingDataSource

重写determineCurrentLookupKey方法即可 往下看具体操作

此类可以在执行数据库操作之前确定需要切换到哪个库

import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.springframework.jdbc.datasource.lookup.AbstractRoutingDataSource;

import javax.sql.DataSource;
import java.util.HashMap;
import java.util.Map;

/**
 * @author dadiwm321
 */
public class DruidDynamicDataSource extends AbstractRoutingDataSource {

    private Logger logger = LogManager.getLogger(getClass());

    //这里存放了每个租户数据源,如果没有则创建数据源 我们在配置中以为它put了默认数据源,看下面代码可知
    public Map<Object, Object> targetDataSources = new HashMap();

    private static final String mysql = "mysql";
    private static final String oracle = "oracle";

    @Override
    protected Object determineCurrentLookupKey() {
        //这里便是使用了ThreadLocal 实现每个线程独立保存本地副本,明确线程的租户信息
        String dataSourceName = DataContextHolder.getCurrentDataSourceName();
        if (dataSourceName == null) {
            logger.info("==================== change DB:<defaultDataSource> ==================== ");
            //如果为null则返回默认的数据库链接
            return "defaultDataSource";
        }
        Object obj = targetDataSources.get(dataSourceName);
        if (obj == null) {//没有则新建
            DataSource dataSource = null;
            DbInfo dbinfo = DataContextHolder.getCurrentDBInfo();
            if (dbinfo != null && StringUtils.isNotBlank(dbinfo.getDbName())) {
                String url;
                if (mysql.equals(dbinfo.getDbType())) {
                    url = "jdbc:mysql://" + dbinfo.getDomainName() + ":" + dbinfo.getPort() + "/" + dbinfo.getDbName() + "?useUnicode=true&characterEncoding=utf-8&allowMultiQueries=true&useSSL=false";

                } else if (oracle.equals(dbinfo.getDbType())) {
                    url = "jdbc:oracle:thin:@" + dbinfo.getDomainName() + ":" + dbinfo.getPort() + ":" + dbinfo.getDbName();
                    dbinfo.setUrl(url);
                } else {
                    logger.error("不支持的数据库类型:" + dbinfo.getDbType());
                    throw new RuntimeException("不支持的数据库类型:" + dbinfo.getDbType());
                }
                dbinfo.setUrl(url);

                dataSource = createDataSource(dbinfo);
            }

            if (null != dataSource) {
                targetDataSources.put(dataSourceName, dataSource);
                setTargetDataSources(targetDataSources);
                afterPropertiesSet();
                DataContextHolder.setDataSourceType(dataSourceName);
            }
        }
        logger.info("==================== change DB:<{}> ==================== ", dataSourceName);
        return DataContextHolder.getCurrentDataSourceName();//返回当前需要的数据源
    }

    public DruidDataSource createDataSource(DbInfo dbInfo) {
        DruidDataSource parent = (DruidDataSource) targetDataSources.get("defaultDataSource");
        DruidDataSource dataSource = new DruidDataSource();
        dataSource.setUrl(dbInfo.getUrl());
        dataSource.setUsername(dbInfo.getLoginName());
        dataSource.setPassword(dbInfo.getLoginPw());

        dataSource.setDriverClassName(dbInfo.getClassDriverName());

        dataSource.setMaxActive(parent.getMaxActive());
        dataSource.setMinIdle(parent.getMinIdle());
        dataSource.setInitialSize(parent.getInitialSize());
        dataSource.setMaxWait(parent.getMaxWait());
        dataSource.setTimeBetweenEvictionRunsMillis(parent.getTimeBetweenEvictionRunsMillis());
        dataSource.setMinEvictableIdleTimeMillis(parent.getMinEvictableIdleTimeMillis());
        dataSource.setValidationQuery(parent.getValidationQuery());
        dataSource.setBreakAfterAcquireFailure(parent.isBreakAfterAcquireFailure());
        dataSource.setConnectionErrorRetryAttempts(parent.getConnectionErrorRetryAttempts());
        dataSource.setTestWhileIdle(true);
        dataSource.setTestOnBorrow(false);
        dataSource.setTestOnReturn(false);
        dataSource.setDbType(dbInfo.getDbType());
        return dataSource;
    }
}

以上术语可能不太准确,希望大神不吝赐教