myBatis学习笔记(10)——使用拦截器实现分页查询
1. Page
package com.sm.model;
import java.util.List;
public class Page<T> {
    public static final int DEFAULT_PAGE_SIZE = 20;
    protected int pageNo = 1; // 当前页, 默觉得第1页
    protected int pageSize = DEFAULT_PAGE_SIZE; // 每页记录数
    protected long totalRecord = -1; // 总记录数, 默觉得-1, 表示须要查询
    protected int totalPage = -1; // 总页数, 默觉得-1, 表示须要计算
    protected List<T> results; // 当前页记录List形式
    public int getPageNo() {
        return pageNo;
    }
    public void setPageNo(int pageNo) {
        this.pageNo = pageNo;
    }
    public int getPageSize() {
        return pageSize;
    }
    public void setPageSize(int pageSize) {
        this.pageSize = pageSize;
        computeTotalPage();
    }
    public long getTotalRecord() {
        return totalRecord;
    }
    public int getTotalPage() {
        return totalPage;
    }
    public void setTotalRecord(long totalRecord) {
        this.totalRecord = totalRecord;
        computeTotalPage();
    }
    protected void computeTotalPage() {
        if (getPageSize() > 0 && getTotalRecord() > -1) {
            this.totalPage = (int) (getTotalRecord() % getPageSize() == 0 ? getTotalRecord() / getPageSize() : getTotalRecord() / getPageSize() + 1);
        }
    }
    public List<T> getResults() {
        return results;
    }
    public void setResults(List<T> results) {
        this.results = results;
    }
    @Override
    public String toString() {
        StringBuilder builder = new StringBuilder().append("Page [pageNo=").append(pageNo).append(", pageSize=").append(pageSize)
                .append(", totalRecord=").append(totalRecord < 0 ?
"null" : totalRecord).append(", totalPage=")
                .append(totalPage < 0 ? "null" : totalPage).append(", results=").append(results == null ?
"null" : results).append("]");
        return builder.toString();
    }
}
2. 实现拦截器
package com.sm.model;
import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.parameter.DefaultParameterHandler;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.RoutingStatementHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@Intercepts({ @Signature(method = "prepare", type = StatementHandler.class, args = { Connection.class }),
        @Signature(method = "query", type = Executor.class, args = { MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class }) })
public class PageInterceptor implements Interceptor {
    private static final Logger log = LoggerFactory.getLogger(PageInterceptor.class);
    public static final String MYSQL = "mysql";
    public static final String ORACLE = "oracle";
    protected String databaseType;// 数据库类型。不同的数据库有不同的分页方法
    protected ThreadLocal<Page> pageThreadLocal = new ThreadLocal<Page>();
    public String getDatabaseType() {
        return databaseType;
    }
    public void setDatabaseType(String databaseType) {
        if (!databaseType.equalsIgnoreCase(MYSQL) && !databaseType.equalsIgnoreCase(ORACLE)) {
            throw new PageNotSupportException("Page not support for the type of database, database type [" + databaseType + "]");
        }
        this.databaseType = databaseType;
    }
    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }
    @Override
    public void setProperties(Properties properties) {
        String databaseType = properties.getProperty("databaseType");
        if (databaseType != null) {
            setDatabaseType(databaseType);
        }
    }
    @Override
    @SuppressWarnings({ "unchecked", "rawtypes" })
    public Object intercept(Invocation invocation) throws Throwable {
        if (invocation.getTarget() instanceof StatementHandler) {// 控制SQL和查询总数的地方
            Page page = pageThreadLocal.get();
            if (page == null) { //不是分页查询
                return invocation.proceed();
            }
            RoutingStatementHandler handler = (RoutingStatementHandler) invocation.getTarget();
            StatementHandler delegate = (StatementHandler) ReflectUtil.getFieldValue(handler, "delegate");
            BoundSql boundSql = delegate.getBoundSql();
            Connection connection = (Connection) invocation.getArgs()[0];
            prepareAndCheckDatabaseType(connection); // 准备数据库类型
            if (page.getTotalPage() > -1) {
                if (log.isTraceEnabled()) {
                    log.trace("已经设置了总页数, 不须要再查询总数.");
                }
            } else {
                Object parameterObj = boundSql.getParameterObject();
                MappedStatement mappedStatement = (MappedStatement) ReflectUtil.getFieldValue(delegate, "mappedStatement");
                queryTotalRecord(page, parameterObj, mappedStatement, connection);
            }
            String sql = boundSql.getSql();
            String pageSql = buildPageSql(page, sql);
            if (log.isDebugEnabled()) {
                log.debug("分页时, 生成分页pageSql: " + pageSql);
            }
            ReflectUtil.setFieldValue(boundSql, "sql", pageSql);
            return invocation.proceed();
        } else { // 查询结果的地方
            // 获取是否有分页Page对象
            Page<?> page = findPageObject(invocation.getArgs()[1]);
            if (page == null) {
                if (log.isTraceEnabled()) {
                    log.trace("没有Page对象作为參数, 不是分页查询.");
                }
                return invocation.proceed();
            } else {
                if (log.isTraceEnabled()) {
                    log.trace("检測到分页Page对象, 使用分页查询.");
                }
            }
            //设置真正的parameterObj
            invocation.getArgs()[1] = extractRealParameterObject(invocation.getArgs()[1]);
            pageThreadLocal.set(page);
            try {
                Object resultObj = invocation.proceed(); // Executor.query(..)
                if (resultObj instanceof List) {
                    /* @SuppressWarnings({ "unchecked", "rawtypes" }) */
                    page.setResults((List) resultObj);
                }
                return resultObj;
            } finally {
                pageThreadLocal.remove();
            }
        }
    }
    protected Page<?> findPageObject(Object parameterObj) {
        if (parameterObj instanceof Page<?>) {
            return (Page<?>) parameterObj;
        } else if (parameterObj instanceof Map) {
            for (Object val : ((Map<?
, ?
>) parameterObj).values()) {
                if (val instanceof Page<?>) {
                    return (Page<?
>) val;
                }
            }
        }
        return null;
    }
    /**
     * <pre>
     * 把真正的參数对象解析出来
     * Spring会自己主动封装对个參数对象为Map<String, Object>对象
     * 对于通过@Param指定key值參数我们不做处理,由于XML文件须要该KEY值
     * 而对于没有@Param指定时,Spring会使用0,1作为主键
     * 对于没有@Param指定名称的參数,一般XML文件会直接对真正的參数对象解析。
     * 此时解析出真正的參数作为根对象
     * </pre>
     * @author jundong.xu_C
     * @param parameterObj
     * @return
     */
    protected Object extractRealParameterObject(Object parameterObj) {
        if (parameterObj instanceof Map<?, ?>) {
            Map<?, ?
> parameterMap = (Map<?, ?>) parameterObj;
            if (parameterMap.size() == 2) {
                boolean springMapWithNoParamName = true;
                for (Object key : parameterMap.keySet()) {
                    if (!(key instanceof String)) {
                        springMapWithNoParamName = false;
                        break;
                    }
                    String keyStr = (String) key;
                    if (!"0".equals(keyStr) && !"1".equals(keyStr)) {
                        springMapWithNoParamName = false;
                        break;
                    }
                }
                if (springMapWithNoParamName) {
                    for (Object value : parameterMap.values()) {
                        if (!(value instanceof Page<?
>)) {
                            return value;
                        }
                    }
                }
            }
        }
        return parameterObj;
    }
    protected void prepareAndCheckDatabaseType(Connection connection) throws SQLException {
        if (databaseType == null) {
            String productName = connection.getMetaData().getDatabaseProductName();
            if (log.isTraceEnabled()) {
                log.trace("Database productName: " + productName);
            }
            productName = productName.toLowerCase();
            if (productName.indexOf(MYSQL) != -1) {
                databaseType = MYSQL;
            } else if (productName.indexOf(ORACLE) != -1) {
                databaseType = ORACLE;
            } else {
                throw new PageNotSupportException("Page not support for the type of database, database product name [" + productName + "]");
            }
            if (log.isInfoEnabled()) {
                log.info("自己主动检測到的数据库类型为: " + databaseType);
            }
        }
    }
    /**
     * <pre>
     * 生成分页SQL
     * </pre>
     *
     * @author jundong.xu_C
     * @param page
     * @param sql
     * @return
     */
    protected String buildPageSql(Page<?> page, String sql) {
        if (MYSQL.equalsIgnoreCase(databaseType)) {
            return buildMysqlPageSql(page, sql);
        } else if (ORACLE.equalsIgnoreCase(databaseType)) {
            return buildOraclePageSql(page, sql);
        }
        return sql;
    }
    /**
     * <pre>
     * 生成Mysql分页查询SQL
     * </pre>
     *
     * @author jundong.xu_C
     * @param page
     * @param sql
     * @return
     */
    protected String buildMysqlPageSql(Page<?> page, String sql) {
        // 计算第一条记录的位置,Mysql中记录的位置是从0開始的。
        int offset = (page.getPageNo() - 1) * page.getPageSize();
        return new StringBuilder(sql).append(" limit ").append(offset).append(",").append(page.getPageSize()).toString();
    }
    /**
     * <pre>
     * 生成Oracle分页查询SQL
     * </pre>
     *
     * @author jundong.xu_C
     * @param page
     * @param sql
     * @return
     */
    protected String buildOraclePageSql(Page<?> page, String sql) {
        // 计算第一条记录的位置。Oracle分页是通过rownum进行的。而rownum是从1開始的
        int offset = (page.getPageNo() - 1) * page.getPageSize() + 1;
        StringBuilder sb = new StringBuilder(sql);
        sb.insert(0, "select u.*, rownum r from (").append(") u where rownum < ").append(offset + page.getPageSize());
        sb.insert(0, "select * from (").append(") where r >= ").append(offset);
        return sb.toString();
    }
    /**
     * <pre>
     * 查询总数
     * </pre>
     *
     * @author jundong.xu_C
     * @param page
     * @param parameterObject
     * @param mappedStatement
     * @param connection
     * @throws SQLException
     */
    protected void queryTotalRecord(Page<?> page, Object parameterObject, MappedStatement mappedStatement, Connection connection) throws SQLException {
        BoundSql boundSql = mappedStatement.getBoundSql(page);
        String sql = boundSql.getSql();
        String countSql = this.buildCountSql(sql);
        if (log.isDebugEnabled()) {
            log.debug("分页时, 生成countSql: " + countSql);
        }
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql, parameterMappings, parameterObject);
        ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, parameterObject, countBoundSql);
        PreparedStatement pstmt = null;
        ResultSet rs = null;
        try {
            pstmt = connection.prepareStatement(countSql);
            parameterHandler.setParameters(pstmt);
            rs = pstmt.executeQuery();
            if (rs.next()) {
                long totalRecord = rs.getLong(1);
                page.setTotalRecord(totalRecord);
            }
        } finally {
            if (rs != null)
                try {
                    rs.close();
                } catch (Exception e) {
                    if (log.isWarnEnabled()) {
                        log.warn("关闭ResultSet时异常.", e);
                    }
                }
            if (pstmt != null)
                try {
                    pstmt.close();
                } catch (Exception e) {
                    if (log.isWarnEnabled()) {
                        log.warn("关闭PreparedStatement时异常.", e);
                    }
                }
        }
    }
    /**
     * 依据原Sql语句获取相应的查询总记录数的Sql语句
     *
     * @param sql
     * @return
     */
    protected String buildCountSql(String sql) {
        int index = sql.indexOf("from");
        return "select count(*) " + sql.substring(index);
    }
    /**
     * 利用反射进行操作的一个工具类
     *
     */
    private static class ReflectUtil {
        /**
         * 利用反射获取指定对象的指定属性
         *
         * @param obj 目标对象
         * @param fieldName 目标属性
         * @return 目标属性的值
         */
        public static Object getFieldValue(Object obj, String fieldName) {
            Object result = null;
            Field field = ReflectUtil.getField(obj, fieldName);
            if (field != null) {
                field.setAccessible(true);
                try {
                    result = field.get(obj);
                } catch (IllegalArgumentException e) {
                    // TODO Auto-generated catch block
                    e.printStackTrace();
                } catch (IllegalAccessException e) {
                    // TODO Auto-generated catch block
                    e.printStackTrace();
                }
            }
            return result;
        }
        /**
         * 利用反射获取指定对象里面的指定属性
         *
         * @param obj 目标对象
         * @param fieldName 目标属性
         * @return 目标字段
         */
        private static Field getField(Object obj, String fieldName) {
            Field field = null;
            for (Class<?> clazz = obj.getClass(); clazz != Object.class; clazz = clazz.getSuperclass()) {
                try {
                    field = clazz.getDeclaredField(fieldName);
                    break;
                } catch (NoSuchFieldException e) {
                    // 杩欓噷涓嶇敤鍋氬鐞嗭紝瀛愮被娌℃湁璇ュ瓧娈靛彲鑳藉搴旂殑鐖剁被鏈夛紝閮芥病鏈夊氨杩斿洖null銆�
                }
            }
            return field;
        }
        /**
         * 利用反射设置指定对象的指定属性为指定的值
         *
         * @param obj 目标对象
         * @param fieldName 目标属性
         * @param fieldValue 目标值
         */
        public static void setFieldValue(Object obj, String fieldName, String fieldValue) {
            Field field = ReflectUtil.getField(obj, fieldName);
            if (field != null) {
                try {
                    field.setAccessible(true);
                    field.set(obj, fieldValue);
                } catch (IllegalArgumentException e) {
                    // TODO Auto-generated catch block
                    e.printStackTrace();
                } catch (IllegalAccessException e) {
                    // TODO Auto-generated catch block
                    e.printStackTrace();
                }
            }
        }
    }
    public static class PageNotSupportException extends RuntimeException {
        public PageNotSupportException() {
            super();
        }
        public PageNotSupportException(String message, Throwable cause) {
            super(message, cause);
        }
        public PageNotSupportException(String message) {
            super(message);
        }
        public PageNotSupportException(Throwable cause) {
            super(cause);
        }
    }
}
3. spring配置文件(mybatis已和spring整合)
    <!-- 配置mybatis的sqlSessionFactory -->
    <bean id="sqlSessionFactoryBean" class="org.mybatis.spring.SqlSessionFactoryBean">
        <property name="dataSource" ref="dataSource"></property>
        <!-- 配置了typeAliasesPackage之后,在映射文件里,这个包下的实体类能够不写全名 -->
        <property name="typeAliasesPackage" value="com.sm.model"></property>
        <!-- 配置映射映射文件的位置 -->
        <property name="mapperLocations" value="classpath:resources/mapper/*.xml"></property>
        <property name="plugins">
            <!-- 分页拦截器 -->
            <bean class="com.sm.model.PageInterceptor"></bean>
        </property>
    </bean>
4. mapper.xml
<select id="getUsers" resultType="User" parameterType="Map">
    select * from user where username=#{user.username}
</select>
6. DAO
List<User> getUsers(Map map);
7. 測试
Page page = new Page();
//配置分页參数
page.setPageNo(1);
page.setPageSize(3);
//条件查询,传參
User user = new User();
user.setUsername("2");
Map map = new HashMap<>();
map.put("user", user);
map.put("page", page);
List<User> list = userDAO.getUsers(map);
System.out.println(list);
System.out.println(page);
8. 总结
上面的分页拦截器,拷下来直接用就好了。假设想了解实现原理。能够看慕课网的视频通过自己主动回复机器人学Mybatis—加强版
myBatis学习笔记(10)——使用拦截器实现分页查询的更多相关文章
- mybatis学习笔记(10)-一对一查询
		
mybatis学习笔记(10)-一对一查询 标签: mybatis mybatis学习笔记10-一对一查询 resultType实现 resultMap实现 resultType和resultMap实 ...
 - SpringBoot学习笔记:自定义拦截器
		
SpringBoot学习笔记:自定义拦截器 快速开始 拦截器类似于过滤器,但是拦截器提供更精细的的控制能力,它可以在一个请求过程中的两个节点进行拦截: 在请求发送到Controller之前 在响应发送 ...
 - 【Struts2学习笔记-6--】Struts2之拦截器
		
简单拦截器的使用 拦截器最基本的使用: 拦截方法的拦截器 拦截器的执行顺序 拦截结果的监听器-相当于 后拦截器 执行顺序: 覆盖拦截器栈里特定拦截器的参数 使用拦截器完成-权限控制 主要完成两个功能: ...
 - Struts2学习笔记(十)——自定义拦截器
		
Struts2拦截器是基于AOP思想实现的,而AOP的实现是基于动态代理.Struts2拦截器会在访问某个Action之前或者之后进行拦截,并且Struts2拦截器是可插拔的:Struts2拦截器栈就 ...
 - struts2框架学习笔记6:拦截器
		
拦截器是Struts2实现功能的核心部分 拦截器的创建: 第一种: package interceptor; import com.opensymphony.xwork2.ActionInvocati ...
 - SpringMVC学习笔记三:拦截器
		
一:拦截器工作原理 类比Struts2的拦截器,通过拦截器可以实现在调用controller的方法前.后进行一些操作. 二:拦截器实现 1:实现拦截器类 实现HandlerInterceptor 接口 ...
 - Struts2学习笔记04 之 拦截器
		
一.创建拦截器组件 1. 创建一个类,实现Interceptor接口,并实现intercept方法 2.注册拦截器 3.引用拦截器 二.拦截器栈 预置拦截器: 默认引用拦截器 拦截器调用顺序: Fil ...
 - Android学习笔记(10).布局管理器
		
布局管理器的几个类都是ViewGroup派生的,用于管理组件的分布和大小,使用布局管理器能够非常好地解决屏幕适配问题. 布局管理器本身也是一个UI组件,布局管理器能够相互嵌套使用,以下是布局管理器的类 ...
 - Mybatis学习笔记10 - 动态sql之if判断
		
示例代码: 接口定义: package com.mybatis.dao; import com.mybatis.bean.Employee; import java.util.List; public ...
 
随机推荐
- POJ-3669-流星雨
			
这题的话,卡了有两个小时左右,首先更新地图的时候越界了,我们进行更新的时候,要判断一下是不是小于零了,越界就会Runtime Error. 然后bfs 的时候,我没有允许它搜出300以外的范围,然后就 ...
 - MYSQL数据库SQL语句集锦
			
*特别说明:FILED代表数据表字段,CONDITIONS代表where之后的条件,TABLENAME代表数据表名 []中括号内的内容代表 可有可无. 创建数据库 create database ...
 - Xshell 配色方案 Ubuntu Solarized_Dark isayme
			
前言 最近在用Ubuntu,发现它的配色方案挺好看的,所以查了下有没有大神做过Xshell的Ubuntu配色方案. 一看,果然还是有大佬做了这个的. 三套配色配置如下: 1. Ubuntu的Solar ...
 - perl学习之:localtime
			
Perl中localtime()函数以及sprintf (2011-4-25 19:39)localtime函数 localtime函数,根据它所在的上下文,可以用两种完全不同的方法来运行.在标量上下 ...
 - verilog behavioral modeling--branch statement
			
conditional statement case statement 1. conditional statement if(expression) statement_o ...
 - LeetCode(153) Find Minimum in Rotated Sorted Array
			
题目 Total Accepted: 65121 Total Submissions: 190974 Difficulty: Medium Suppose a sorted array is rota ...
 - PAT Basic 1067
			
1067 试密码 当你试图登录某个系统却忘了密码时,系统一般只会允许你尝试有限多次,当超出允许次数时,账号就会被锁死.本题就请你实现这个小功能. 输入格式: 输入在第一行给出一个密码(长度不超过 20 ...
 - js prototype 添加属性对象
			
在本例中,我们将展示如何使用 prototype 属性来向对象添加属性: <script type="text/javascript"> function employ ...
 - js根据银行卡号判断属于哪个银行,并返回银行缩写及银行卡类型
			
在做绑定银行卡,输入银行卡的时候,产品有这么一个需求,需要用户输入银行卡号的时候,显示对应的银行卡名称及简称.于是苦苦寻觅,终于找到了支付宝的开放API,银行卡校验接口 https://ccdca ...
 - 爬虫开发python工具包介绍 (3)
			
本文来自网易云社区 作者:王涛 :arg str url: URL to fetch :arg str method: HTTP method, e.g. " ...