Spring Boot自定义JpaRepository基类

摘要

  • Spring Boot 基于 JPA 的数据操作非常方便：只需继承 JpaRepository 就能获得强大的数据操作能力。但有时我们需要进行更复杂的操作，比如批量插入与更新、执行复杂的 SQL 等，此时就需要对 JpaRepository 进行一些扩展。
  • @Query 注解（配合 nativeQuery = true）也可以直接执行原生 SQL，但它有一些局限：只有 select * 查询整行时才能直接封装为实体对象，只查询部分字段时就只能封装为 Object[] 或 Map。如果希望 @Query 查询部分字段时也能直接转换为对象，可以参阅下一篇《Spring Boot的@Query注解》。

自定义JpaRepository接口

  • 包含一些常用操作,可以按需扩展
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
package com.example.jpa;

import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.JpaSpecificationExecutor;
import org.springframework.data.repository.NoRepositoryBean;

import javax.persistence.EntityManager;
import java.io.Serializable;
import java.util.List;
import java.util.Map;

/**
 * Repository base interface that adds native-SQL helpers, native-SQL pagination and batched
 * writes on top of the standard Spring Data JPA repository methods.
 * <p>
 * Annotated with {@link NoRepositoryBean} so that this interface itself does not take part in
 * Spring Data's repository proxying; only the concrete business interfaces extending it do.
 *
 * @param <T>  entity type
 * @param <ID> identifier type
 */
@NoRepositoryBean
public interface BaseJpaRepository<T, ID extends Serializable>
        extends JpaRepository<T, ID>, JpaSpecificationExecutor<T>, Serializable {

    /** @return the EntityManager backing this repository */
    EntityManager getEntityManager();

    /** Runs a JPQL/HQL query and returns its result list. */
    <E> List<E> findByHql(String hql);

    // ---------------------------------------------------------------- native SQL, rows as Map

    /** Runs a native SQL query without parameters; each row is a column-alias to value map. */
    List<Map> findBySql(String sql);

    /** Runs a native SQL query, binding {@code params} positionally (1-based). */
    List<Map> findBySql(String sql, Object[] params);

    /** Runs a native SQL query, binding {@code params} by parameter name. */
    List<Map> findBySql(String sql, Map<String, Object> params);

    /** First row of a parameterless native SQL query as a map, or {@code null} when there is none. */
    Map findBySqlFirst(String sql);

    /** First row of a native SQL query with positional parameters, or {@code null}. */
    Map findBySqlFirst(String sql, Object[] params);

    /** First row of a native SQL query with named parameters, or {@code null}. */
    Map findBySqlFirst(String sql, Map<String, Object> params);

    // ---------------------------------------------------------------- native SQL, rows as objects

    /**
     * Runs a native SQL query and converts each row to {@code clazz}.
     * {@code basic == true} means the result is a basic (single-column) value.
     */
    <E> List<E> findBySql(String sql, Class clazz, boolean basic);

    /** As {@link #findBySql(String, Class, boolean)} with positional parameters. */
    <E> List<E> findBySql(String sql, Class clazz, boolean basic, Object[] params);

    /** As {@link #findBySql(String, Class, boolean)} with named parameters. */
    <E> List<E> findBySql(String sql, Class clazz, boolean basic, Map<String, Object> params);

    /**
     * First row of a native SQL query converted to {@code clazz}.
     * {@code basic == true} means the result is a basic (single-column) value.
     */
    <E> E findBySqlFirst(String sql, Class clazz, boolean basic);

    /** As {@link #findBySqlFirst(String, Class, boolean)} with positional parameters. */
    <E> E findBySqlFirst(String sql, Class clazz, boolean basic, Object[] params);

    /** As {@link #findBySqlFirst(String, Class, boolean)} with named parameters. */
    <E> E findBySqlFirst(String sql, Class clazz, boolean basic, Map<String, Object> params);

    // ---------------------------------------------------------------- native SQL pagination

    /** Paged native SQL query; the total is computed by a derived count query. */
    <E> Page<E> findPageBySql(String sql, Pageable pageable, Class clazz, boolean basic);

    /** Paged native SQL query with an explicit count query. */
    <E> Page<E> findPageBySql(String sql, String countSql, Pageable pageable, Class clazz, boolean basic);

    /** Paged native SQL query with positional parameters. */
    <E> Page<E> findPageBySql(String sql, Pageable pageable, Class clazz, boolean basic, Object[] params);

    /** Paged native SQL query with an explicit count query and positional parameters. */
    <E> Page<E> findPageBySql(String sql, String countSql, Pageable pageable, Class clazz, boolean basic, Object[] params);

    /** Paged native SQL query with named parameters. */
    <E> Page<E> findPageBySql(String sql, Pageable pageable, Class clazz, boolean basic, Map<String, Object> params);

    /** Paged native SQL query with an explicit count query and named parameters. */
    <E> Page<E> findPageBySql(String sql, String countSql, Pageable pageable, Class clazz, boolean basic, Map<String, Object> params);

    // ---------------------------------------------------------------- entity helpers

    /** Looks an entity up by id, returning {@code null} instead of an empty Optional. */
    T findByIdNew(ID id);

    /** Batch insert. */
    <S extends T> Iterable<S> batchSave(Iterable<S> iterable);

    /** Batch update. */
    <S extends T> Iterable<S> batchUpdate(Iterable<S> iterable);
}

自定义JpaRepository接口的实现类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
package com.example.jpa;


import com.example.context.ApplicationContextProvider;
import org.hibernate.query.internal.NativeQueryImpl;
import org.hibernate.transform.Transformers;
import org.springframework.data.domain.Page;
import org.springframework.data.domain.PageImpl;
import org.springframework.data.domain.Pageable;
import org.springframework.data.domain.Sort;
import org.springframework.data.jpa.repository.support.JpaEntityInformation;
import org.springframework.data.jpa.repository.support.SimpleJpaRepository;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.StringUtils;

import javax.persistence.EntityManager;
import javax.persistence.Query;
import java.io.Serializable;
import java.math.BigInteger;
import java.util.*;
import java.util.function.Consumer;
import java.util.regex.Pattern;
import java.util.stream.Stream;

/**
 * Base implementation for every repository that is registered with
 * {@code @EnableJpaRepositories(repositoryBaseClass = BaseJpaRepositoryImpl.class)}.
 * <p>
 * On top of {@link SimpleJpaRepository} it adds:
 * <ul>
 *   <li>native SQL queries that return rows as maps, or as objects converted by {@link JpaUtil};</li>
 *   <li>native SQL pagination;</li>
 *   <li>batched persist/merge.</li>
 * </ul>
 * <p>
 * NOTE(review): this file imports {@code ApplicationContextProvider} from {@code com.example.context},
 * while the provider class shown alongside it is declared in {@code com.example.common.support}.
 * Confirm the package.
 *
 * @param <T>  entity type
 * @param <ID> identifier type
 */
public class BaseJpaRepositoryImpl<T, ID extends Serializable> extends SimpleJpaRepository<T, ID> implements BaseJpaRepository<T, ID> {

    /**
     * Number of entities persisted/merged before the persistence context is flushed and cleared.
     * This bounds the growth of the first-level cache. It does NOT commit the transaction.
     * <p>
     * NOTE(review): whether the flushed statements are sent as JDBC batches presumably depends on
     * Hibernate's {@code hibernate.jdbc.batch_size} setting. Confirm it is configured.
     */
    private static final int BATCH_SIZE = 500;

    /**
     * Sort properties are concatenated into the SQL text. Only plain, optionally dot-qualified,
     * identifiers are therefore accepted.
     */
    private static final Pattern SAFE_SORT_PROPERTY = Pattern.compile("[A-Za-z_][A-Za-z0-9_.]*");

    /** EntityManager supplied through the constructor (also handed to the superclass). */
    private final EntityManager entityManager;

    public BaseJpaRepositoryImpl(JpaEntityInformation<T, ID> entityInformation, EntityManager entityManager) {
        super(entityInformation, entityManager);
        this.entityManager = entityManager;
    }

    @Override
    public EntityManager getEntityManager() {
        return entityManager;
    }

    @Override
    @SuppressWarnings("unchecked") // the element type of a JPQL result is chosen by the caller
    public <E> List<E> findByHql(String hql) {
        return (List<E>) entityManager.createQuery(hql).getResultList();
    }

    // ------------------------------------------------------------------ native SQL, rows as Map

    @Override
    public List<Map> findBySql(String sql) {
        return findBySql(sql, Collections.emptyMap());
    }

    @Override
    public List<Map> findBySql(String sql, Object[] params) {
        return toMapList(nativeQuery(sql, positional(params)));
    }

    @Override
    public List<Map> findBySql(String sql, Map<String, Object> params) {
        return toMapList(nativeQuery(sql, named(params)));
    }

    @Override
    public Map findBySqlFirst(String sql) {
        return findBySqlFirst(sql, Collections.emptyMap());
    }

    @Override
    public Map findBySqlFirst(String sql, Object[] params) {
        return firstAsMap(nativeQuery(sql, positional(params)));
    }

    @Override
    public Map findBySqlFirst(String sql, Map<String, Object> params) {
        return firstAsMap(nativeQuery(sql, named(params)));
    }

    // ------------------------------------------------------------------ native SQL, rows as objects

    @Override
    public <E> List<E> findBySql(String sql, Class clazz, boolean basic) {
        return findBySql(sql, clazz, basic, Collections.emptyMap());
    }

    @Override
    public <E> List<E> findBySql(String sql, Class clazz, boolean basic, Object[] params) {
        return getJpaUtil().mapListToObjectList(findBySql(sql, params), clazz, basic);
    }

    @Override
    public <E> List<E> findBySql(String sql, Class clazz, boolean basic, Map<String, Object> params) {
        return getJpaUtil().mapListToObjectList(findBySql(sql, params), clazz, basic);
    }

    @Override
    public <E> E findBySqlFirst(String sql, Class clazz, boolean basic) {
        return findBySqlFirst(sql, clazz, basic, Collections.emptyMap());
    }

    @Override
    public <E> E findBySqlFirst(String sql, Class clazz, boolean basic, Object[] params) {
        return getJpaUtil().mapToObject(findBySqlFirst(sql, params), clazz, basic);
    }

    @Override
    public <E> E findBySqlFirst(String sql, Class clazz, boolean basic, Map<String, Object> params) {
        return getJpaUtil().mapToObject(findBySqlFirst(sql, params), clazz, basic);
    }

    // ------------------------------------------------------------------ native SQL pagination

    @Override
    public <E> Page<E> findPageBySql(String sql, Pageable pageable, Class clazz, boolean basic) {
        return queryPage(sql, null, pageable, clazz, basic, named(null));
    }

    @Override
    public <E> Page<E> findPageBySql(String sql, String countSql, Pageable pageable, Class clazz, boolean basic) {
        return queryPage(sql, countSql, pageable, clazz, basic, named(null));
    }

    @Override
    public <E> Page<E> findPageBySql(String sql, Pageable pageable, Class clazz, boolean basic, Object[] params) {
        return queryPage(sql, null, pageable, clazz, basic, positional(params));
    }

    @Override
    public <E> Page<E> findPageBySql(String sql, String countSql, Pageable pageable, Class clazz, boolean basic, Object[] params) {
        return queryPage(sql, countSql, pageable, clazz, basic, positional(params));
    }

    @Override
    public <E> Page<E> findPageBySql(String sql, Pageable pageable, Class clazz, boolean basic, Map<String, Object> params) {
        return queryPage(sql, null, pageable, clazz, basic, named(params));
    }

    @Override
    public <E> Page<E> findPageBySql(String sql, String countSql, Pageable pageable, Class clazz, boolean basic, Map<String, Object> params) {
        return queryPage(sql, countSql, pageable, clazz, basic, named(params));
    }

    // ------------------------------------------------------------------ entity helpers

    @Override
    public T findByIdNew(ID id) {
        return findById(id).orElse(null);
    }

    /** Persists every entity, flushing and clearing the persistence context every {@link #BATCH_SIZE}. */
    @Override
    @Transactional
    public <S extends T> Iterable<S> batchSave(Iterable<S> iterable) {
        return applyInBatches(iterable, entityManager::persist);
    }

    /**
     * Merges every entity, flushing and clearing every {@link #BATCH_SIZE}.
     * The returned iterable is the input; the managed copies produced by {@code merge} are not returned.
     */
    @Override
    @Transactional
    public <S extends T> Iterable<S> batchUpdate(Iterable<S> iterable) {
        return applyInBatches(iterable, entityManager::merge);
    }

    // ------------------------------------------------------------------ private helpers

    /**
     * Runs one page of {@code sql} and its total count.
     * <p>
     * The sort order is taken from {@code pageable} and is appended only when both hold:
     * {@code sql} has no {@code order by} of its own, and the sort is non-empty. (The original code
     * produced an invalid trailing {@code order by} for unsorted pageables.)
     * <p>
     * When {@code countSql} is blank, the count query wraps the un-ordered {@code sql}.
     * The same parameters are bound to the count query as to the page query, so a custom
     * {@code countSql} must declare the same parameters.
     *
     * @throws IllegalArgumentException if a sort property is not a plain identifier
     */
    private <E> Page<E> queryPage(String sql, String countSql, Pageable pageable, Class clazz, boolean basic,
                                  Consumer<Query> binder) {
        final Query pageQuery = nativeQuery(withOrderBy(sql, pageable.getSort()), binder);
        pageQuery.setFirstResult(Math.toIntExact(pageable.getOffset()));
        pageQuery.setMaxResults(pageable.getPageSize());
        final List<E> content = getJpaUtil().mapListToObjectList(toMapList(pageQuery), clazz, basic);

        final String effectiveCountSql = StringUtils.hasText(countSql)
                ? countSql
                : "select count(*) from ( " + sql + " ) a";
        // The count type varies by dialect (BigInteger, Long, BigDecimal...), so read it as Number.
        final Number total = (Number) nativeQuery(effectiveCountSql, binder).getSingleResult();

        return new PageImpl<>(content, pageable, total.longValue());
    }

    /** Appends "order by p1 DIR,p2 DIR" from {@code sort} unless it is empty or sql already orders. */
    private static String withOrderBy(String sql, Sort sort) {
        if (sort == null || sort.isUnsorted() || sql.toLowerCase(Locale.ROOT).contains("order by")) {
            return sql;
        }
        StringJoiner orderBy = new StringJoiner(",", sql + " order by ", "");
        for (Sort.Order order : sort) {
            orderBy.add(checkedSortProperty(order.getProperty()) + " " + order.getDirection().name());
        }
        return orderBy.toString();
    }

    /** Rejects sort properties that are not plain identifiers, since they are spliced into SQL text. */
    private static String checkedSortProperty(String property) {
        if (property == null || !SAFE_SORT_PROPERTY.matcher(property).matches()) {
            throw new IllegalArgumentException("Invalid sort property: " + property);
        }
        return property;
    }

    /** Creates a native query for {@code sql} and applies {@code binder} to it. */
    private Query nativeQuery(String sql, Consumer<Query> binder) {
        Query query = entityManager.createNativeQuery(sql);
        binder.accept(query);
        return query;
    }

    /** Binder that sets {@code params[i]} as positional parameter {@code i + 1}; null means none. */
    private static Consumer<Query> positional(Object[] params) {
        return query -> {
            if (params != null) {
                for (int i = 0; i < params.length; i++) {
                    query.setParameter(i + 1, params[i]);
                }
            }
        };
    }

    /** Binder that sets every entry of {@code params} as a named parameter; null means none. */
    private static Consumer<Query> named(Map<String, Object> params) {
        return query -> {
            if (params != null) {
                params.forEach((name, value) -> query.setParameter(name, value));
            }
        };
    }

    /** Executes {@code query} and returns each row as a column-alias to value map. */
    @SuppressWarnings("unchecked") // Hibernate's raw result list holds maps via ALIAS_TO_ENTITY_MAP
    private static List<Map> toMapList(Query query) {
        return query.unwrap(NativeQueryImpl.class)
                .setResultTransformer(Transformers.ALIAS_TO_ENTITY_MAP)
                .getResultList();
    }

    /**
     * Returns the first row of {@code query} as a map, or {@code null} when there is none.
     * The result stream is closed so the underlying scrollable cursor is released.
     */
    private static Map firstAsMap(Query query) {
        try (Stream<?> rows = query.unwrap(NativeQueryImpl.class)
                .setResultTransformer(Transformers.ALIAS_TO_ENTITY_MAP)
                .stream()) {
            return (Map) rows.findFirst().orElse(null);
        }
    }

    /**
     * Applies {@code operation} to each entity.
     * The persistence context is flushed and cleared after every {@link #BATCH_SIZE} entities and
     * once more for any remainder.
     */
    private <S extends T> Iterable<S> applyInBatches(Iterable<S> entities, Consumer<? super S> operation) {
        int processed = 0;
        for (S entity : entities) {
            operation.accept(entity);
            processed++;
            if (processed % BATCH_SIZE == 0) {
                flushAndClear();
            }
        }
        if (processed % BATCH_SIZE != 0) {
            flushAndClear();
        }
        return entities;
    }

    private void flushAndClear() {
        entityManager.flush();
        entityManager.clear();
    }

    /** Looks up the {@code jpaUtil} bean from the static Spring context holder. */
    private JpaUtil getJpaUtil() {
        return (JpaUtil) ApplicationContextProvider.getBean("jpaUtil");
    }
}

工具类JpaUtil

  • 其作用是借助 Jackson 的 ObjectMapper，将原生 SQL 查询得到的 Map 转换为指定类型的对象
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
package com.example.jpa;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * Converts rows returned by native SQL queries (column alias to value maps) into objects.
 * Conversion uses the Spring-managed Jackson {@link ObjectMapper}.
 */
@Component("jpaUtil")
public class JpaUtil {

    @Autowired
    ObjectMapper objectMapper;

    /**
     * Converts a {@code List<Map>} query result into a list of objects.
     * Map keys must match the target's property names, or the names given by {@code @JsonProperty}.
     *
     * @param mapList rows to convert; {@code null} yields an empty list
     * @param clazz   target type (unused when {@code basic} is true)
     * @param basic   true when each row holds a single basic value that should be returned as-is
     * @return one element per row, in row order
     * @throws IllegalArgumentException if a row cannot be converted to {@code clazz}
     *         (previously such rows were silently dropped after printing a stack trace)
     */
    public <E> List<E> mapListToObjectList(List<Map> mapList, Class clazz, boolean basic) {
        List<E> list = new ArrayList<>();
        if (mapList == null) {
            return list;
        }
        for (Map map : mapList) {
            list.add(convert(map, clazz, basic));
        }
        return list;
    }

    /**
     * Converts a single {@code Map} query row into an object.
     * Map keys must match the target's property names, or the names given by {@code @JsonProperty}.
     *
     * @param map   row to convert; {@code null} yields {@code null}
     * @param clazz target type (unused when {@code basic} is true)
     * @param basic true when the row holds a single basic value that should be returned as-is
     * @return the converted object, or {@code null}
     * @throws IllegalArgumentException if the row cannot be converted to {@code clazz}
     */
    public <E> E mapToObject(Map map, Class clazz, boolean basic) {
        if (map == null) {
            return null;
        }
        return convert(map, clazz, basic);
    }

    /** Converts one row: first column value for basic results, Jackson conversion otherwise. */
    @SuppressWarnings("unchecked") // the caller chooses E to match clazz / the column type
    private <E> E convert(Map map, Class clazz, boolean basic) {
        if (basic) {
            // Basic result: the query is expected to return a single column.
            // With several columns, the value returned depends on the map's iteration order.
            return map.isEmpty() ? null : (E) map.values().iterator().next();
        }
        try {
            return (E) objectMapper.convertValue(map, clazz);
        } catch (IllegalArgumentException e) {
            throw new IllegalArgumentException(
                    "Cannot convert row with columns " + map.keySet() + " to " + clazz.getName(), e);
        }
    }
}

ApplicationContextProvider

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
package com.example.common.support;

import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.i18n.LocaleContextHolder;
import org.springframework.stereotype.Component;

/**
 * Static holder that gives non-Spring-managed code access to the Spring {@link ApplicationContext}.
 * Spring calls {@link #setApplicationContext} when it initializes this component. Until then the
 * lookup helpers throw {@link IllegalStateException}.
 */
@Component
public class ApplicationContextProvider
        implements ApplicationContextAware {

    /** The application context; {@code null} until Spring has injected it. */
    private static ApplicationContext applicationContext;

    /**
     * Returns the application context.
     *
     * @return the context, or {@code null} if Spring has not injected it yet
     */
    public static ApplicationContext getApplicationContext() {
        return applicationContext;
    }

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        ApplicationContextProvider.applicationContext = applicationContext;
    }

    /**
     * Looks a bean up by name.
     *
     * @param name bean name
     * @return the bean instance
     * @throws IllegalStateException if the context has not been injected yet
     */
    public static Object getBean(String name) {
        return requireContext().getBean(name);
    }

    /**
     * Looks a bean up by type.
     *
     * @param clazz bean type
     * @param <T>   bean type
     * @return the bean instance
     * @throws IllegalStateException if the context has not been injected yet
     */
    public static <T> T getBean(Class<T> clazz) {
        return requireContext().getBean(clazz);
    }

    /**
     * Looks a bean up by name and type.
     *
     * @param name  bean name
     * @param clazz required bean type
     * @param <T>   bean type
     * @return the bean instance
     * @throws IllegalStateException if the context has not been injected yet
     */
    public static <T> T getBean(String name, Class<T> clazz) {
        return requireContext().getBean(name, clazz);
    }

    /**
     * Resolves a localized message for the current locale ({@link LocaleContextHolder}).
     *
     * @param code message code
     * @param args message arguments, may be {@code null}
     * @return the resolved message
     * @throws IllegalStateException if the context has not been injected yet
     */
    public static String getMessage(String code, Object[] args) {
        return requireContext().getMessage(code, args, LocaleContextHolder.getLocale());
    }

    /**
     * Resolves a localized message for the current locale, falling back to {@code defaultMessage}.
     *
     * @param code           message code
     * @param args           message arguments, may be {@code null}
     * @param defaultMessage message returned when the code cannot be resolved
     * @return the resolved or default message
     * @throws IllegalStateException if the context has not been injected yet
     */
    public static String getMessage(String code, Object[] args,
                                    String defaultMessage) {
        return requireContext().getMessage(code, args, defaultMessage,
                LocaleContextHolder.getLocale());
    }

    /** Returns the context or fails with a descriptive message instead of a bare NullPointerException. */
    private static ApplicationContext requireContext() {
        ApplicationContext context = applicationContext;
        if (context == null) {
            throw new IllegalStateException(
                    "ApplicationContext has not been injected into ApplicationContextProvider yet");
        }
        return context;
    }
}

在配置类上添加 @EnableJpaRepositories(repositoryBaseClass = BaseJpaRepositoryImpl.class)，让该注解扫描到的所有 Repository 都以 BaseJpaRepositoryImpl 作为实现基类

1
2
3
4
5
/**
 * JPA configuration.
 * <ul>
 *   <li>Repositories under {@code com.example.demo} are created with {@link BaseJpaRepositoryImpl}
 *       as their base implementation class.</li>
 *   <li>Entities are scanned from the same package.</li>
 * </ul>
 */
@Configuration
@EntityScan(basePackages = "com.example.demo")
@EnableJpaRepositories(basePackages = "com.example.demo", repositoryBaseClass = BaseJpaRepositoryImpl.class)
public class JpaConfig {
    // Intentionally empty: all configuration is carried by the annotations above.
}

业务 Repository 接口继承 BaseJpaRepository

1
2
3
4
5
6
7
8
9
public interface CountryJpaRepository extends BaseJpaRepository<Country, Long> {

    /**
     * Pages through all countries.
     * Because the native query selects every column ({@code SELECT *}), each row can be converted
     * to a {@code Country}.
     */
    @Query(nativeQuery = true,
            value = "SELECT * FROM tbl_country",
            countQuery = "SELECT count(*) FROM tbl_country")
    Page<Country> findAll(Pageable pageable);

    /**
     * Counter-example: this native query selects only some columns ({@code id, name_zh}).
     * Converting such a partial row to {@code Country} fails.
     */
    @Query(nativeQuery = true, value = "SELECT id,name_zh FROM tbl_country WHERE name_zh = ?1")
    Country findByName(String name);
}

示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
@GetMapping("/")
public Map index() {
    Map<String, Object> result = new HashMap<>();

    // Total row count. Per the original note, the count(*) scalar arrives as a BigInteger.
    // NOTE(review): the numeric type is dialect/driver dependent; confirm for your database.
    final BigInteger total = countryJpaRepository.findBySqlFirst(
            "SELECT count(*) as count FROM tbl_country", BigInteger.class, true);
    result.put("size", total);

    // Paged native query over a partial column list. The rows are mapped to Country via the
    // column aliases. PageRequest page indexes are zero-based, so page 1 is the second page.
    Sort countrySort = Sort.by(Sort.Direction.DESC, "id")
            .and(Sort.by(Sort.Direction.ASC, "name_zh"));
    Pageable countryPage = PageRequest.of(1, 20, countrySort);
    final Page<Country> countries = countryJpaRepository.findPageBySql(
            "SELECT id,name_zh as nameZh ,name_en as nameEn FROM tbl_country",
            countryPage, Country.class, false);
    result.put("pages", countries);

    // Bulk insert of 1000 demo entities.
    List<DemoEntity> toInsert = new ArrayList<>();
    for (int n = 0; n < 1000; n++) {
        DemoEntity entity = new DemoEntity();
        entity.setName("测试" + n);
        entity.initInsert();
        toInsert.add(entity);
    }
    demoEntityJpaRepository.batchSave(toInsert);

    // Read one page of 100, ordered by name descending, then delete exactly those rows in one batch.
    Pageable demoPage = PageRequest.of(1, 100, Sort.by(Sort.Direction.DESC, "name"));
    final Page<DemoEntity> toDelete = demoEntityJpaRepository.findAll(demoPage);
    demoEntityJpaRepository.deleteAllInBatch(toDelete.getContent());

    return result;
}