refactor: SimpleSpringJdbcTemplateSqlProcessor => SimpleSpringSqlProcessor

This commit is contained in:
Echo009 2021-03-13 14:49:42 +08:00
parent 86b584be2f
commit 93158ba19b
4 changed files with 133 additions and 100 deletions

View File

@ -2,30 +2,32 @@ package tech.powerjob.official.processors.impl.sql;
import com.github.kfcfans.powerjob.worker.core.processor.TaskContext;
import com.google.common.collect.Maps;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.Statement;
import java.util.Map;
/**
* 简单 Spring SQL 处理器强依赖于 Spring jdbc 只能用 Spring Bean 的方式加载
* 简单 Spring SQL 处理器目前只能用 Spring Bean 的方式加载
* 直接忽略 SQL 执行的返回值
*
* 注意 :
* 默认情况下没有过滤任何 SQL
* 建议生产环境一定要使用 {@link SimpleSpringJdbcTemplateSqlProcessor#registerSqlValidator} 方法注册至少一个校验器拦截非法 SQL
* 建议生产环境一定要使用 {@link SimpleSpringSqlProcessor#registerSqlValidator} 方法注册至少一个校验器拦截非法 SQL
*
* 默认情况下会直接执行参数中的 SQL
* 可以通过添加 {@link SimpleSpringJdbcTemplateSqlProcessor.SqlParser} 来实现定制 SQL 解析逻辑的需求比如 宏变量替换参数替换等
* 可以通过添加 {@link SimpleSpringSqlProcessor.SqlParser} 来实现定制 SQL 解析逻辑的需求比如 宏变量替换参数替换等
*
* @author Echo009
* @since 2021/3/10
*/
@Slf4j
public class SimpleSpringJdbcTemplateSqlProcessor extends AbstractSqlProcessor {
public class SimpleSpringSqlProcessor extends AbstractSqlProcessor {
/**
* 默认的数据源名称
*/
@ -40,7 +42,7 @@ public class SimpleSpringJdbcTemplateSqlProcessor extends AbstractSqlProcessor {
*
* @param defaultDataSource 默认数据源
*/
public SimpleSpringJdbcTemplateSqlProcessor(DataSource defaultDataSource) {
public SimpleSpringSqlProcessor(DataSource defaultDataSource) {
dataSourceMap = Maps.newConcurrentMap();
registerDataSource(DEFAULT_DATASOURCE_NAME, defaultDataSource);
}
@ -63,17 +65,33 @@ public class SimpleSpringJdbcTemplateSqlProcessor extends AbstractSqlProcessor {
}
/**
* 执行 SQL
* 执行 SQL忽略返回值
*
* @param sqlParams SQL processor 参数信息
* @param taskContext 任务上下文
*/
@Override
void executeSql(SqlParams sqlParams, TaskContext taskContext) {
JdbcTemplate jdbcTemplate = new JdbcTemplate(dataSourceMap.get(sqlParams.getDataSourceName()));
jdbcTemplate.setSkipResultsProcessing(true);
jdbcTemplate.setQueryTimeout(sqlParams.getTimeout() == null ? DEFAULT_TIMEOUT : sqlParams.getTimeout());
jdbcTemplate.execute(sqlParams.getSql());
@SneakyThrows
@SuppressWarnings({"squid:S1181"})
protected void executeSql(SqlParams sqlParams, TaskContext taskContext) {
DataSource currentDataSource = dataSourceMap.get(sqlParams.getDataSourceName());
boolean originAutoCommitFlag ;
try (Connection connection = currentDataSource.getConnection()) {
originAutoCommitFlag = connection.getAutoCommit();
connection.setAutoCommit(false);
try (Statement statement = connection.createStatement()) {
statement.setQueryTimeout(sqlParams.getTimeout() == null ? DEFAULT_TIMEOUT : sqlParams.getTimeout());
statement.execute(sqlParams.getSql());
connection.commit();
} catch (Throwable e) {
connection.rollback();
// rethrow
throw e;
} finally {
// reset
connection.setAutoCommit(originAutoCommitFlag);
}
}
}
/**

View File

@ -1,81 +0,0 @@
package tech.powerjob.official.processors.impl;
import com.alibaba.fastjson.JSON;
import com.github.kfcfans.powerjob.worker.core.processor.ProcessResult;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
import tech.powerjob.official.processors.TestUtils;
import tech.powerjob.official.processors.impl.sql.SimpleSpringJdbcTemplateSqlProcessor;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
/**
* @author Echo009
* @since 2021/3/11
*/
@Slf4j
class SimpleSpringJdbcTemplateSqlProcessorTest {
private static SimpleSpringJdbcTemplateSqlProcessor simpleSpringJdbcTemplateSqlProcessor;
@BeforeAll
static void initSqlProcessor() {
EmbeddedDatabaseBuilder builder = new EmbeddedDatabaseBuilder();
EmbeddedDatabase database = builder.setType(EmbeddedDatabaseType.H2)
.addScript("classpath:db_init.sql")
.build();
simpleSpringJdbcTemplateSqlProcessor = new SimpleSpringJdbcTemplateSqlProcessor(database);
// do nothing
simpleSpringJdbcTemplateSqlProcessor.registerSqlValidator("fakeSqlValidator", (sql) -> true);
// 排除掉包含 drop SQL
simpleSpringJdbcTemplateSqlProcessor.registerSqlValidator("interceptDropValidator", (sql) -> sql.matches("^(?i)((?!drop).)*$"));
// do nothing
simpleSpringJdbcTemplateSqlProcessor.setSqlParser((sql, taskContext) -> sql);
log.info("init sql processor successfully!");
}
@Test
void testSqlValidator() {
SimpleSpringJdbcTemplateSqlProcessor.SqlParams sqlParams = new SimpleSpringJdbcTemplateSqlProcessor.SqlParams();
sqlParams.setSql("drop table test_table");
// 校验不通过
assertThrows(IllegalArgumentException.class, () -> simpleSpringJdbcTemplateSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams))));
}
@Test
void testIncorrectDataSourceName() {
SimpleSpringJdbcTemplateSqlProcessor.SqlParams sqlParams = constructSqlParam("create table task_info (a varchar(255), b varchar(255), c varchar(255))");
sqlParams.setDataSourceName("(๑•̀ㅂ•́)و✧");
// 数据源名称非法
assertThrows(IllegalArgumentException.class, () -> simpleSpringJdbcTemplateSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams))));
}
@Test
void testExecDDL() {
SimpleSpringJdbcTemplateSqlProcessor.SqlParams sqlParams = constructSqlParam("create table power_job (a varchar(255), b varchar(255), c varchar(255))");
ProcessResult processResult = simpleSpringJdbcTemplateSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams)));
assertTrue(processResult.isSuccess());
}
@Test
void testExecSQL() {
SimpleSpringJdbcTemplateSqlProcessor.SqlParams sqlParams = constructSqlParam("insert into test_table (id, content) values (0, 'Fight for a better tomorrow')");
ProcessResult processResult = simpleSpringJdbcTemplateSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams)));
assertTrue(processResult.isSuccess());
}
static SimpleSpringJdbcTemplateSqlProcessor.SqlParams constructSqlParam(String sql){
SimpleSpringJdbcTemplateSqlProcessor.SqlParams sqlParams = new SimpleSpringJdbcTemplateSqlProcessor.SqlParams();
sqlParams.setSql(sql);
return sqlParams;
}
}

View File

@ -0,0 +1,96 @@
package tech.powerjob.official.processors.impl;
import com.alibaba.fastjson.JSON;
import com.github.kfcfans.powerjob.worker.core.processor.ProcessResult;
import lombok.extern.slf4j.Slf4j;
import org.h2.jdbc.JdbcSQLIntegrityConstraintViolationException;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
import tech.powerjob.official.processors.TestUtils;
import tech.powerjob.official.processors.impl.sql.SimpleSpringSqlProcessor;
import static org.junit.jupiter.api.Assertions.*;
/**
* @author Echo009
* @since 2021/3/11
*/
@Slf4j
class SimpleSpringSqlProcessorTest {
private static SimpleSpringSqlProcessor simpleSpringSqlProcessor;
@BeforeAll
static void initSqlProcessor() {
EmbeddedDatabaseBuilder builder = new EmbeddedDatabaseBuilder();
EmbeddedDatabase database = builder.setType(EmbeddedDatabaseType.H2)
.addScript("classpath:db_init.sql")
.build();
simpleSpringSqlProcessor = new SimpleSpringSqlProcessor(database);
// do nothing
simpleSpringSqlProcessor.registerSqlValidator("fakeSqlValidator", (sql) -> true);
// 排除掉包含 drop SQL
simpleSpringSqlProcessor.registerSqlValidator("interceptDropValidator", (sql) -> sql.matches("^(?i)((?!drop).)*$"));
// do nothing
simpleSpringSqlProcessor.setSqlParser((sql, taskContext) -> sql);
log.info("init sql processor successfully!");
}
@Test
void testSqlValidator() {
SimpleSpringSqlProcessor.SqlParams sqlParams = new SimpleSpringSqlProcessor.SqlParams();
sqlParams.setSql("drop table test_table");
// 校验不通过
assertThrows(IllegalArgumentException.class, () -> simpleSpringSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams))));
}
@Test
void testIncorrectDataSourceName() {
SimpleSpringSqlProcessor.SqlParams sqlParams = constructSqlParam("create table task_info (a varchar(255), b varchar(255), c varchar(255))");
sqlParams.setDataSourceName("(๑•̀ㅂ•́)و✧");
// 数据源名称非法
assertThrows(IllegalArgumentException.class, () -> simpleSpringSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams))));
}
@Test
void testExecDDL() {
SimpleSpringSqlProcessor.SqlParams sqlParams = constructSqlParam("create table power_job (a varchar(255), b varchar(255), c varchar(255))");
ProcessResult processResult = simpleSpringSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams)));
assertTrue(processResult.isSuccess());
}
@Test
void testExecSQL() {
SimpleSpringSqlProcessor.SqlParams sqlParams1 = constructSqlParam("insert into test_table (id, content) values (0, 'Fight for a better tomorrow')");
ProcessResult processResult1 = simpleSpringSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams1)));
assertTrue(processResult1.isSuccess());
assertThrows(JdbcSQLIntegrityConstraintViolationException.class, () -> simpleSpringSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams1))));
// 第二条会失败回滚
SimpleSpringSqlProcessor.SqlParams sqlParams2 = constructSqlParam("insert into test_table (id, content) values (1, '?');insert into test_table (id, content) values (0, 'Fight for a better tomorrow')");
assertThrows(JdbcSQLIntegrityConstraintViolationException.class, () -> simpleSpringSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams2))));
// 上方回滚这里就能成功插入
SimpleSpringSqlProcessor.SqlParams sqlParams3 = constructSqlParam("insert into test_table (id, content) values (1, '?')");
ProcessResult processResult3 = simpleSpringSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams3)));
assertTrue(processResult3.isSuccess());
SimpleSpringSqlProcessor.SqlParams sqlParams4 = constructSqlParam("insert into test_table (id, content) values (2, '?');insert into test_table (id, content) values (3, '?')");
ProcessResult processResult4 = simpleSpringSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams4)));
assertTrue(processResult4.isSuccess());
}
static SimpleSpringSqlProcessor.SqlParams constructSqlParam(String sql){
SimpleSpringSqlProcessor.SqlParams sqlParams = new SimpleSpringSqlProcessor.SqlParams();
sqlParams.setSql(sql);
return sqlParams;
}
}

View File

@ -8,7 +8,7 @@ import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.DependsOn;
import tech.powerjob.official.processors.impl.sql.SimpleSpringJdbcTemplateSqlProcessor;
import tech.powerjob.official.processors.impl.sql.SimpleSpringSqlProcessor;
import javax.sql.DataSource;
@ -37,15 +37,15 @@ public class SqlProcessorConfiguration {
@Bean
public SimpleSpringJdbcTemplateSqlProcessor springSqlProcessor(@Qualifier("sqlProcessorDataSource") DataSource dataSource) {
SimpleSpringJdbcTemplateSqlProcessor simpleSpringJdbcTemplateSqlProcessor = new SimpleSpringJdbcTemplateSqlProcessor(dataSource);
public SimpleSpringSqlProcessor simpleSpringSqlProcessor(@Qualifier("sqlProcessorDataSource") DataSource dataSource) {
SimpleSpringSqlProcessor simpleSpringSqlProcessor = new SimpleSpringSqlProcessor(dataSource);
// do nothing
simpleSpringJdbcTemplateSqlProcessor.registerSqlValidator("fakeSqlValidator", (sql) -> true);
simpleSpringSqlProcessor.registerSqlValidator("fakeSqlValidator", sql -> true);
// 排除掉包含 drop SQL
simpleSpringJdbcTemplateSqlProcessor.registerSqlValidator("interceptDropValidator", (sql) -> sql.matches("^(?i)((?!drop).)*$"));
simpleSpringSqlProcessor.registerSqlValidator("interceptDropValidator", sql -> sql.matches("^(?i)((?!drop).)*$"));
// do nothing
simpleSpringJdbcTemplateSqlProcessor.setSqlParser((sql, taskContext) -> sql);
return simpleSpringJdbcTemplateSqlProcessor;
simpleSpringSqlProcessor.setSqlParser((sql, taskContext) -> sql);
return simpleSpringSqlProcessor;
}
}