千家信息网

springboot配置多数据源后mybatis拦截器失效怎么办

发表于:2025-02-07 作者:千家信息网编辑
千家信息网最后更新 2025年02月07日,这篇文章给大家分享的是有关springboot配置多数据源后mybatis拦截器失效怎么办的内容。小编觉得挺实用的,因此分享给大家做个参考,一起跟随小编过来看看吧。配置文件是通过springcloud
千家信息网最后更新 2025年02月07日springboot配置多数据源后mybatis拦截器失效怎么办

这篇文章给大家分享的是有关springboot配置多数据源后mybatis拦截器失效怎么办的内容。小编觉得挺实用的,因此分享给大家做个参考,一起跟随小编过来看看吧。

配置文件是通过springcloudconfig远程分布式配置。采用阿里Druid数据源。并支持一主多从的读写分离。分页组件通过拦截器拦截带有page后缀的方法名,动态的设置total总数。

1. 解析配置文件初始化数据源

/**
 * Declares the master data source, the slave data source(s) and the list of slaves.
 *
 * <p>All data sources are created with the implementation class named by the
 * {@code spring.datasource.type} property; their remaining settings are bound from the
 * prefixes given on each {@code @ConfigurationProperties} annotation.
 */
@Configuration
public class DataSourceConfiguration {

    /** Data source implementation class, read from {@code spring.datasource.type}. */
    @Value("${spring.datasource.type}")
    private Class<? extends DataSource> dataSourceType;

    /**
     * Creates the master (write) data source, bound to the {@code spring.datasource} prefix.
     *
     * @return the master data source
     */
    @Bean(name = "masterDataSource", destroyMethod = "close")
    @Primary
    @ConfigurationProperties(prefix = "spring.datasource")
    public DataSource masterDataSource() {
        return DataSourceBuilder.create().type(dataSourceType).build();
    }

    /**
     * Creates the first slave (read) data source, bound to the {@code spring.slave0} prefix.
     *
     * @return the slave data source
     */
    @Bean(name = "slaveDataSource0")
    @ConfigurationProperties(prefix = "spring.slave0")
    public DataSource slaveDataSource0() {
        return DataSourceBuilder.create().type(dataSourceType).build();
    }

    /**
     * Collects every slave data source into one list.
     *
     * @return a mutable list holding the slave data sources
     */
    @Bean(name = "slaveDataSources")
    public List<DataSource> slaveDataSources() {
        List<DataSource> slaveDataSources = new ArrayList<>();
        slaveDataSources.add(slaveDataSource0());
        return slaveDataSources;
    }
}

2. 定义数据源枚举类型

/**
 * Kinds of data source known to the application: the master and the slave.
 *
 * <p>Each constant carries a {@code type} string and a {@code name} string; both can be
 * read and reassigned through the accessors below.
 */
public enum DataSourceType {

    master("master", "master"),
    slave("slave", "slave");

    private String type;
    private String name;

    DataSourceType(String typeKey, String displayName) {
        type = typeKey;
        name = displayName;
    }

    /** @return the type string of this constant */
    public String getType() {
        return this.type;
    }

    /** @param type the new type string */
    public void setType(String type) {
        this.type = type;
    }

    /** @return the name string of this constant */
    public String getName() {
        return this.name;
    }

    /** @param name the new name string */
    public void setName(String name) {
        this.name = name;
    }
}

3. ThreadLocal保存数据源类型

/**
 * Holds the data source type selected for the current thread.
 *
 * <p>The value stored is the {@code type} string of a {@link DataSourceType} constant.
 */
public class DataSourceContextHolder {

    // Typed as ThreadLocal<String> so getJdbcType() can return the value without a cast.
    private static final ThreadLocal<String> local = new ThreadLocal<>();

    /** @return the underlying thread-local holder */
    public static ThreadLocal<String> getLocal() {
        return local;
    }

    /** Marks the current thread as using the slave data source type. */
    public static void slave() {
        local.set(DataSourceType.slave.getType());
    }

    /** Marks the current thread as using the master data source type. */
    public static void master() {
        local.set(DataSourceType.master.getType());
    }

    /** @return the type set for the current thread, or {@code null} if none was set */
    public static String getJdbcType() {
        return local.get();
    }

    /** Removes the value stored for the current thread. */
    public static void clearDataSource() {
        local.remove();
    }
}

4. 自定义sqlSessionProxy

并将数据源填充到DataSourceRoute

@Configuration@ConditionalOnClass({EnableTransactionManagement.class})@Import({DataSourceConfiguration.class})public class DataSourceSqlSessionFactory {    private Logger logger = Logger.getLogger(DataSourceSqlSessionFactory.class);    @Value("${spring.datasource.type}")    private Class dataSourceType;    @Value("${mybatis.mapper-locations}")    private String mapperLocations;    @Value("${mybatis.type-aliases-package}")    private String aliasesPackage;    @Value("${slave.datasource.number}")    private int dataSourceNumber;    @Resource(name = "masterDataSource")    private DataSource masterDataSource;    @Resource(name = "slaveDataSources")    private List slaveDataSources;    @Bean    @ConditionalOnMissingBean    public SqlSessionFactory sqlSessionFactory() throws Exception {        logger.info("======================= init sqlSessionFactory");        SqlSessionFactoryBean sqlSessionFactoryBean = new SqlSessionFactoryBean();        sqlSessionFactoryBean.setDataSource(roundRobinDataSourceProxy());        PathMatchingResourcePatternResolver resolver = new PathMatchingResourcePatternResolver();        sqlSessionFactoryBean.setMapperLocations(resolver.getResources(mapperLocations));        sqlSessionFactoryBean.setTypeAliasesPackage(aliasesPackage);        sqlSessionFactoryBean.getObject().getConfiguration().setMapUnderscoreToCamelCase(true);        return sqlSessionFactoryBean.getObject();    }    @Bean(name = "roundRobinDataSourceProxy")    public AbstractRoutingDataSource roundRobinDataSourceProxy() {        logger.info("======================= init robinDataSourceProxy");        DataSourceRoute proxy = new DataSourceRoute(dataSourceNumber);        Map targetDataSources = new HashMap();        targetDataSources.put(DataSourceType.master.getType(), masterDataSource);        if(null != slaveDataSources) {            for(int i=0; i

5. 自定义路由

public class DataSourceRoute extends AbstractRoutingDataSource {    private Logger logger = Logger.getLogger(DataSourceRoute.class);    private final int dataSourceNumber;        public DataSourceRoute(int dataSourceNumber) {        this.dataSourceNumber = dataSourceNumber;    }    @Override    protected Object determineCurrentLookupKey() {        String typeKey = DataSourceContextHolder.getJdbcType();        logger.info("==================== swtich dataSource:" + typeKey);        if (typeKey.equals(DataSourceType.master.getType())) {            return DataSourceType.master.getType();        }else{            //从数据源随机分配            Random random = new Random();            int slaveDsIndex = random.nextInt(dataSourceNumber);            return slaveDsIndex;        }    }}

6. 在dao层定义切面

/**
 * Aspect that selects the slave or master data source type before a mapper method runs,
 * based on the prefix of the method name.
 */
@Aspect
@Component
public class DataSourceAop {

    /** Pointcut for mapper methods whose names start with a read-style prefix. */
    private static final String READ_METHODS =
            "execution(* com.dbq.iot.mapper..*.get*(..)) || execution(* com.dbq.iot.mapper..*.isExist*(..)) "
            + "|| execution(* com.dbq.iot.mapper..*.select*(..)) || execution(* com.dbq.iot.mapper..*.count*(..)) "
            + "|| execution(* com.dbq.iot.mapper..*.list*(..)) || execution(* com.dbq.iot.mapper..*.query*(..))"
            + "|| execution(* com.dbq.iot.mapper..*.find*(..))|| execution(* com.dbq.iot.mapper..*.search*(..))";

    /** Pointcut for mapper methods whose names start with a write-style prefix. */
    private static final String WRITE_METHODS =
            "execution(* com.dbq.iot.mapper..*.add*(..)) || execution(* com.dbq.iot.mapper..*.del*(..))"
            + "||execution(* com.dbq.iot.mapper..*.upDate*(..)) || execution(* com.dbq.iot.mapper..*.insert*(..))"
            + "||execution(* com.dbq.iot.mapper..*.create*(..)) || execution(* com.dbq.iot.mapper..*.update*(..))"
            + "||execution(* com.dbq.iot.mapper..*.delete*(..)) || execution(* com.dbq.iot.mapper..*.remove*(..))"
            + "||execution(* com.dbq.iot.mapper..*.save*(..)) || execution(* com.dbq.iot.mapper..*.relieve*(..))"
            + "|| execution(* com.dbq.iot.mapper..*.edit*(..))";

    private Logger logger = Logger.getLogger(DataSourceAop.class);

    /**
     * Switches the current thread to the slave data source type before a read-style method.
     *
     * @param joinPoint the intercepted mapper call; only its method name is logged
     */
    @Before(READ_METHODS)
    public void setSlaveDataSourceType(JoinPoint joinPoint) {
        DataSourceContextHolder.slave();
        String methodName = joinPoint.getSignature().getName();
        logger.info("=========slave, method:" + methodName);
    }

    /**
     * Switches the current thread to the master data source type before a write-style method.
     *
     * @param joinPoint the intercepted mapper call; only its method name is logged
     */
    @Before(WRITE_METHODS)
    public void setMasterDataSourceType(JoinPoint joinPoint) {
        DataSourceContextHolder.master();
        String methodName = joinPoint.getSignature().getName();
        logger.info("=========master, method:" + methodName);
    }
}

7. 最后在写库增加事务管理

/**
 * Registers the transaction manager, backed by the master data source.
 */
@Configuration
@Import({DataSourceConfiguration.class})
public class DataSouceTranscation extends DataSourceTransactionManagerAutoConfiguration {

    private Logger logger = Logger.getLogger(DataSouceTranscation.class);

    /** The master data source, injected by bean name. */
    @Resource(name = "masterDataSource")
    private DataSource masterDataSource;

    /**
     * Builds the {@code transactionManager} bean on top of the master data source.
     *
     * <p>NOTE(review): the manager is bound to the master data source directly, not to the
     * routing proxy; confirm that sessions created through the proxy take part in these
     * transactions.
     *
     * @return the transaction manager
     */
    @Bean(name = "transactionManager")
    public DataSourceTransactionManager transactionManagers() {
        logger.info("===================== init transactionManager");
        DataSourceTransactionManager manager = new DataSourceTransactionManager(masterDataSource);
        return manager;
    }
}

8. 在配置文件中增加数据源配置

spring.datasource.name=writedb
spring.datasource.url=jdbc:mysql://192.168.0.1/master?useUnicode=true&characterEncoding=utf8&autoReconnect=true&failOverReadOnly=false
spring.datasource.username=root
spring.datasource.password=1234
spring.datasource.type=com.alibaba.druid.pool.DruidDataSource
spring.datasource.driver-class-name=com.mysql.jdbc.Driver
spring.datasource.filters=stat
spring.datasource.initialSize=20
spring.datasource.minIdle=20
spring.datasource.maxActive=200
spring.datasource.maxWait=60000
#从库的数量
slave.datasource.number=1
spring.slave0.name=readdb
spring.slave0.url=jdbc:mysql://192.168.0.2/slave?useUnicode=true&characterEncoding=utf8&autoReconnect=true&failOverReadOnly=false
spring.slave0.username=root
spring.slave0.password=1234
spring.slave0.type=com.alibaba.druid.pool.DruidDataSource
spring.slave0.driver-class-name=com.mysql.jdbc.Driver
spring.slave0.filters=stat
spring.slave0.initialSize=20
spring.slave0.minIdle=20
spring.slave0.maxActive=200
spring.slave0.maxWait=60000

这样就实现了在springcloud框架下的读写分离,并且支持多个从库的负载均衡(简单的通过随机分配,也有网友通过算法实现平均分配,具体做法是通过一个线程安全的自增长Integer类型,取余实现。个人觉得没大必要。如果有大神有更好的方法可以一起探讨。)

MyBatis分页配置可通过dao层的拦截器对特定方法进行拦截,拦截后添加自己的逻辑代码,比如计算total等,具体代码如下(参考了网友的代码,主要是通过@Intercepts注解):

@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})public class PageInterceptor implements Interceptor {    private static final Log logger = LogFactory.getLog(PageInterceptor.class);    private static final ObjectFactory DEFAULT_OBJECT_FACTORY = new DefaultObjectFactory();    private static final ObjectWrapperFactory DEFAULT_OBJECT_WRAPPER_FACTORY = new DefaultObjectWrapperFactory();    private static final ReflectorFactory DEFAULT_REFLECTOR_FACTORY = new DefaultReflectorFactory();    private static String defaultDialect = "mysql"; // 数据库类型(默认为mysql)    private static String defaultPageSqlId = ".*Page$"; // 需要拦截的ID(正则匹配)    private String dialect = ""; // 数据库类型(默认为mysql)    private String pageSqlId = ""; // 需要拦截的ID(正则匹配)    @Override    public Object intercept(Invocation invocation) throws Throwable {        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();        MetaObject metaStatementHandler = MetaObject.forObject(statementHandler, DEFAULT_OBJECT_FACTORY,                DEFAULT_OBJECT_WRAPPER_FACTORY,DEFAULT_REFLECTOR_FACTORY);        // 分离代理对象链(由于目标类可能被多个拦截器拦截,从而形成多次代理,通过下面的两次循环可以分离出最原始的的目标类)        while (metaStatementHandler.hasGetter("h")) {            Object object = metaStatementHandler.getValue("h");            metaStatementHandler = MetaObject.forObject(object, DEFAULT_OBJECT_FACTORY, DEFAULT_OBJECT_WRAPPER_FACTORY,DEFAULT_REFLECTOR_FACTORY);        }        // 分离最后一个代理对象的目标类        while (metaStatementHandler.hasGetter("target")) {            Object object = metaStatementHandler.getValue("target");            metaStatementHandler = MetaObject.forObject(object, DEFAULT_OBJECT_FACTORY, DEFAULT_OBJECT_WRAPPER_FACTORY,DEFAULT_REFLECTOR_FACTORY);        }        Configuration configuration = (Configuration) metaStatementHandler.getValue("delegate.configuration");        if (null == dialect || "".equals(dialect)) {            logger.warn("Property dialect is not 
setted,use default 'mysql' ");            dialect = defaultDialect;        }        if (null == pageSqlId || "".equals(pageSqlId)) {            logger.warn("Property pageSqlId is not setted,use default '.*Page$' ");            pageSqlId = defaultPageSqlId;        }        MappedStatement mappedStatement = (MappedStatement) metaStatementHandler.getValue("delegate.mappedStatement");        // 只重写需要分页的sql语句。通过MappedStatement的ID匹配,默认重写以Page结尾的MappedStatement的sql        if (mappedStatement.getId().matches(pageSqlId)) {            BoundSql boundSql = (BoundSql) metaStatementHandler.getValue("delegate.boundSql");            Object parameterObject = boundSql.getParameterObject();            if (parameterObject == null) {                throw new NullPointerException("parameterObject is null!");            } else {                PageParameter page = (PageParameter) metaStatementHandler                        .getValue("delegate.boundSql.parameterObject.page");                String sql = boundSql.getSql();                // 重写sql                String pageSql = buildPageSql(sql, page);                metaStatementHandler.setValue("delegate.boundSql.sql", pageSql);                metaStatementHandler.setValue("delegate.rowBounds.offset", RowBounds.NO_ROW_OFFSET);                metaStatementHandler.setValue("delegate.rowBounds.limit", RowBounds.NO_ROW_LIMIT);                Connection connection = (Connection) invocation.getArgs()[0];                // 重设分页参数里的总页数等                setPageParameter(sql, connection, mappedStatement, boundSql, page);            }        }        // 将执行权交给下一个拦截器        return invocation.proceed();    }    /**     * @param sql     * @param connection     * @param mappedStatement     * @param boundSql     * @param page     */    private void setPageParameter(String sql, Connection connection, MappedStatement mappedStatement,                                  BoundSql boundSql, PageParameter page) {        // 记录总记录数        String countSql = "select 
count(0) from (" + sql + ") as total";        PreparedStatement countStmt = null;        ResultSet rs = null;        try {            countStmt = connection.prepareStatement(countSql);            BoundSql countBS = new BoundSql(mappedStatement.getConfiguration(), countSql,                    boundSql.getParameterMappings(), boundSql.getParameterObject());            Field metaParamsField = ReflectUtil.getFieldByFieldName(boundSql, "metaParameters");            if (metaParamsField != null) {                try {                    MetaObject mo = (MetaObject) ReflectUtil.getValueByFieldName(boundSql, "metaParameters");                    ReflectUtil.setValueByFieldName(countBS, "metaParameters", mo);                } catch (SecurityException | NoSuchFieldException | IllegalArgumentException                        | IllegalAccessException e) {                    // TODO Auto-generated catch block                     logger.error("Ignore this exception", e);                }            }            Field additionalField = ReflectUtil.getFieldByFieldName(boundSql, "additionalParameters");            if (additionalField != null) {                try {                    Map map = (Map) ReflectUtil.getValueByFieldName(boundSql, "additionalParameters");                    ReflectUtil.setValueByFieldName(countBS, "additionalParameters", map);                } catch (SecurityException | NoSuchFieldException | IllegalArgumentException                        | IllegalAccessException e) {                    // TODO Auto-generated catch block                    logger.error("Ignore this exception", e);                }            }            setParameters(countStmt, mappedStatement, countBS, boundSql.getParameterObject());            rs = countStmt.executeQuery();            int totalCount = 0;            if (rs.next()) {                totalCount = rs.getInt(1);            }            page.setTotalCount(totalCount);            int totalPage = totalCount / page.getPageSize() 
+ ((totalCount % page.getPageSize() == 0) ? 0 : 1);            page.setTotalPage(totalPage);        } catch (SQLException e) {            logger.error("Ignore this exception", e);        } finally {            try {                if (rs != null){                    rs.close();                }            } catch (SQLException e) {                logger.error("Ignore this exception", e);            }            try {                if (countStmt != null){                    countStmt.close();                }            } catch (SQLException e) {                logger.error("Ignore this exception", e);            }        }    }    /**     * 对SQL参数(?)设值     *     * @param ps     * @param mappedStatement     * @param boundSql     * @param parameterObject     * @throws SQLException     */    private void setParameters(PreparedStatement ps, MappedStatement mappedStatement, BoundSql boundSql,                               Object parameterObject) throws SQLException {        ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, parameterObject, boundSql);        parameterHandler.setParameters(ps);    }    /**     * 根据数据库类型,生成特定的分页sql     *     * @param sql     * @param page     * @return     */    private String buildPageSql(String sql, PageParameter page) {        if (page != null) {            StringBuilder pageSql = new StringBuilder();            pageSql = buildPageSqlForMysql(sql,page);            return pageSql.toString();        } else {            return sql;        }    }    /**     * mysql的分页语句     *     * @param sql     * @param page     * @return String     */    public StringBuilder buildPageSqlForMysql(String sql, PageParameter page) {        StringBuilder pageSql = new StringBuilder(100);        String beginrow = String.valueOf((page.getCurrentPage() - 1) * page.getPageSize());        pageSql.append(sql);        pageSql.append(" limit " + beginrow + "," + page.getPageSize());        return pageSql;    }    @Override    public 
Object plugin(Object target) {        if (target instanceof StatementHandler) {            return Plugin.wrap(target, this);        } else {            return target;        }    }    @Override    public void setProperties(Properties properties) {    }}

这里碰到一个比较有趣的问题,就是sql如果是foreach参数,在拦截后无法注入。需要加入以下代码才可以(有的资料上只提到重置metaParameters)。

// Carry the "metaParameters" field of the original BoundSql over to the count BoundSql.
if (ReflectUtil.getFieldByFieldName(boundSql, "metaParameters") != null) {
    try {
        MetaObject metaParameters = (MetaObject) ReflectUtil.getValueByFieldName(boundSql, "metaParameters");
        ReflectUtil.setValueByFieldName(countBS, "metaParameters", metaParameters);
    } catch (SecurityException | NoSuchFieldException | IllegalArgumentException
            | IllegalAccessException e) {
        logger.error("Ignore this exception", e);
    }
}
// Do the same for "additionalParameters"; resetting metaParameters alone is not enough
// for foreach parameters.
if (ReflectUtil.getFieldByFieldName(boundSql, "additionalParameters") != null) {
    try {
        Map additionalParameters = (Map) ReflectUtil.getValueByFieldName(boundSql, "additionalParameters");
        ReflectUtil.setValueByFieldName(countBS, "additionalParameters", additionalParameters);
    } catch (SecurityException | NoSuchFieldException | IllegalArgumentException
            | IllegalAccessException e) {
        logger.error("Ignore this exception", e);
    }
}

读写分离倒是写好了,但是发现增加了mysql一主多从的读写分离后,此分页拦截器直接失效。

最后分析原因是因为,我们在做主从分离时,自定义了SqlSessionFactory,导致此拦截器没有注入。

在上面第4步中,DataSourceSqlSessionFactory中注入拦截器即可,具体代码如下

通过注解引入拦截器类:

// Import the interceptor class alongside the data source configuration.
@Import({DataSourceConfiguration.class,PageInterceptor.class})

注入拦截器

// The paging interceptor, injected so it can be handed to the SqlSessionFactoryBean.
@Autowired    private PageInterceptor pageInterceptor;

SqlSessionFactoryBean中设置拦截器

// Register the paging interceptor; this must happen before sqlSessionFactoryBean.getObject() is called.
sqlSessionFactoryBean.setPlugins(new Interceptor[]{pageInterceptor});

这里碰到一个坑,就是设置plugins时必须在sqlSessionFactoryBean.getObject()之前。

SqlSessionFactory在生成的时候就会获取plugins,并设置到Configuration中,如果在之后设置则不会注入。

可跟踪源码看到:

sqlSessionFactoryBean.getObject()
// Quoted library source: returns the cached factory, building it via afterPropertiesSet() on first use.
public SqlSessionFactory getObject() throws Exception {    if (this.sqlSessionFactory == null) {      afterPropertiesSet();    }    return this.sqlSessionFactory;}
// Quoted library source: validates the required properties, then builds and stores the factory
// through buildSqlSessionFactory().
public void afterPropertiesSet() throws Exception {    notNull(dataSource, "Property 'dataSource' is required");    notNull(sqlSessionFactoryBuilder, "Property 'sqlSessionFactoryBuilder' is required");    state((configuration == null && configLocation == null) || !(configuration != null && configLocation != null),              "Property 'configuration' and 'configLocation' can not specified with together");    this.sqlSessionFactory = buildSqlSessionFactory();  }

buildSqlSessionFactory()

// Excerpt from buildSqlSessionFactory(): every plugin set on the bean at this point is added
// to the Configuration as an interceptor.
if (!isEmpty(this.plugins)) {      for (Interceptor plugin : this.plugins) {        configuration.addInterceptor(plugin);        if (LOGGER.isDebugEnabled()) {          LOGGER.debug("Registered plugin: '" + plugin + "'");        }      }    }

最后贴上正确的配置代码(DataSourceSqlSessionFactory代码片段)

/**
 * Builds the SqlSessionFactory on top of the routing data source, with the paging
 * interceptor registered as a plugin.
 *
 * @return the configured SqlSessionFactory
 * @throws Exception if the mapper resources cannot be resolved or the factory cannot be built
 */
@Bean
@ConditionalOnMissingBean
public SqlSessionFactory sqlSessionFactory() throws Exception {
    logger.info("======================= init sqlSessionFactory");
    SqlSessionFactoryBean factoryBean = new SqlSessionFactoryBean();
    // Plugins are read while the factory is being built, so they are set before getObject().
    factoryBean.setPlugins(new Interceptor[]{pageInterceptor});
    factoryBean.setDataSource(roundRobinDataSourceProxy());
    PathMatchingResourcePatternResolver resourceResolver = new PathMatchingResourcePatternResolver();
    factoryBean.setMapperLocations(resourceResolver.getResources(mapperLocations));
    factoryBean.setTypeAliasesPackage(aliasesPackage);

    SqlSessionFactory factory = factoryBean.getObject();
    factory.getConfiguration().setMapUnderscoreToCamelCase(true);
    return factory;
}

感谢各位的阅读!关于"springboot配置多数据源后mybatis拦截器失效怎么办"这篇文章就分享到这里了,希望以上内容可以对大家有一定的帮助,让大家可以学到更多知识,如果觉得文章不错,可以把它分享出去让更多的人看到吧!

0