refactor: SqlProcessor => SimpleSpringJdbcTemplateSqlProcessor 💡

This commit is contained in:
Echo009 2021-03-12 14:37:40 +08:00
parent 30d0d7d338
commit 1d67e97b45
6 changed files with 253 additions and 166 deletions

View File

@ -1,4 +1,4 @@
package tech.powerjob.official.processors.impl;
package tech.powerjob.official.processors.impl.sql;
import com.alibaba.fastjson.JSON;
import com.github.kfcfans.powerjob.worker.core.processor.ProcessResult;
@ -7,81 +7,55 @@ import com.github.kfcfans.powerjob.worker.log.OmsLogger;
import com.google.common.collect.Maps;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.util.Assert;
import org.springframework.util.StopWatch;
import org.springframework.util.StringUtils;
import tech.powerjob.official.processors.CommonBasicProcessor;
import tech.powerjob.official.processors.util.CommonUtils;
import javax.sql.DataSource;
import java.util.Map;
import java.util.function.Predicate;
/**
* SQL 处理器只能用 Spring Bean 的方式加载
* 注意 : 默认情况下没有过滤任何 SQL
* 建议生产环境一定要使用 {@link SqlProcessor#registerSqlValidator} 方法注册至少一个校验器拦截非法 SQL
* SQL Processor
*
* 默认情况下会直接执行参数中的 SQL
* 可以通过添加 {@link SqlProcessor.SqlParser} 来实现定制 SQL 解析逻辑的需求比如 宏变量替换参数替换等
* 处理流程
* * * 解析参数 => 校验参数 => 解析 SQL => 校验 SQL => 执行 SQL
*
* 可以通过 {@link AbstractSqlProcessor#registerSqlValidator} 方法注册 SQL 校验器拦截非法 SQL
* 可以通过指定 {@link AbstractSqlProcessor.SqlParser} 来实现定制 SQL 解析逻辑的需求比如 宏变量替换参数替换等
*
* @author Echo009
* @since 2021/3/10
* @since 2021/3/12
*/
@Slf4j
public class SqlProcessor extends CommonBasicProcessor {
/**
* 默认的数据源名称
*/
private static final String DEFAULT_DATASOURCE_NAME = "default";
public abstract class AbstractSqlProcessor extends CommonBasicProcessor {
/**
* 默认超时时间
*/
private static final int DEFAULT_TIMEOUT = 60;
/**
* name => data source
*/
private final Map<String, DataSource> dataSourceMap;
protected static final int DEFAULT_TIMEOUT = 60;
/**
* name => SQL validator
* 注意
* - 返回 true 表示验证通过
* - 返回 false 表示 SQL 非法将被拒绝执行
*/
private final Map<String, Predicate<String>> sqlValidatorMap;
protected final Map<String, Predicate<String>> sqlValidatorMap = Maps.newConcurrentMap();
/**
* 自定义 SQL 解析器
*/
private SqlParser sqlParser;
/**
* 指定默认的数据源
*
* @param defaultDataSource 默认数据源
*/
public SqlProcessor(DataSource defaultDataSource) {
dataSourceMap = Maps.newConcurrentMap();
sqlValidatorMap = Maps.newConcurrentMap();
registerDataSource(DEFAULT_DATASOURCE_NAME, defaultDataSource);
}
protected SqlParser sqlParser;
@Override
protected ProcessResult process0(TaskContext taskContext) {
public ProcessResult process0(TaskContext taskContext) {
OmsLogger omsLogger = taskContext.getOmsLogger();
SqlParams sqlParams = JSON.parseObject(CommonUtils.parseParams(taskContext), SqlParams.class);
// 检查数据源
if (StringUtils.isEmpty(sqlParams.getDataSourceName())) {
sqlParams.setDataSourceName(DEFAULT_DATASOURCE_NAME);
omsLogger.info("current data source name is empty, use the default data source");
}
DataSource dataSource = dataSourceMap.computeIfAbsent(sqlParams.getDataSourceName(), dataSourceName -> {
throw new IllegalArgumentException("can't find data source with name " + dataSourceName);
});
StopWatch stopWatch = new StopWatch("SQL Processor");
// 解析参数
SqlParams sqlParams = extractParams(taskContext);
omsLogger.info("[AbstractSqlProcessor-{}]origin sql params: {}", taskContext.getInstanceId(), JSON.toJSON(sqlParams));
// 校验参数
validateParams(sqlParams);
StopWatch stopWatch = new StopWatch("SQL Processor");
// 解析
stopWatch.start("Parse SQL");
if (sqlParser != null) {
@ -96,17 +70,43 @@ public class SqlProcessor extends CommonBasicProcessor {
// 执行
stopWatch.start("Execute SQL");
JdbcTemplate jdbcTemplate = new JdbcTemplate(dataSource);
jdbcTemplate.setSkipResultsProcessing(true);
jdbcTemplate.setQueryTimeout(sqlParams.getTimeout() == null ? DEFAULT_TIMEOUT : sqlParams.getTimeout());
omsLogger.info("start to execute sql: {}", sqlParams.getSql());
jdbcTemplate.execute(sqlParams.getSql());
omsLogger.info("[AbstractSqlProcessor-{}]final sql params: {}", taskContext.getInstanceId(), JSON.toJSON(sqlParams));
executeSql(sqlParams, taskContext);
stopWatch.stop();
omsLogger.info(stopWatch.prettyPrint());
String message = String.format("execute successfully, used time: %s millisecond", stopWatch.getTotalTimeMillis());
return new ProcessResult(true, message);
}
/**
* 执行 SQL
*
* @param sqlParams SQL processor 参数信息
* @param taskContext 任务上下文
*/
abstract void executeSql(SqlParams sqlParams, TaskContext taskContext);
/**
* 提取参数信息
*
* @param taskContext 任务上下文
* @return SqlParams
*/
protected SqlParams extractParams(TaskContext taskContext) {
return JSON.parseObject(CommonUtils.parseParams(taskContext), SqlParams.class);
}
/**
* 校验参数如果校验不通过直接抛异常
*
* @param sqlParams SQL 参数信息
*/
protected void validateParams(SqlParams sqlParams) {
// do nothing
}
/**
* 设置 SQL 验证器
*
@ -116,6 +116,7 @@ public class SqlProcessor extends CommonBasicProcessor {
this.sqlParser = sqlParser;
}
/**
* 注册一个 SQL 验证器
*
@ -124,33 +125,7 @@ public class SqlProcessor extends CommonBasicProcessor {
*/
public void registerSqlValidator(String validatorName, Predicate<String> sqlValidator) {
sqlValidatorMap.put(validatorName, sqlValidator);
log.info("[SqlProcessor]register sql validator({})' successfully.", validatorName);
}
/**
* 注册数据源
*
* @param dataSourceName 数据源名称
* @param dataSource 数据源
*/
public void registerDataSource(String dataSourceName, DataSource dataSource) {
Assert.notNull(dataSourceName, "DataSource name must not be null");
Assert.notNull(dataSource, "DataSource must not be null");
dataSourceMap.put(dataSourceName, dataSource);
log.info("[SqlProcessor]register data source({})' successfully.", dataSourceName);
}
/**
* 移除数据源
*
* @param dataSourceName 数据源名称
*/
public void removeDataSource(String dataSourceName) {
DataSource remove = dataSourceMap.remove(dataSourceName);
if (remove != null) {
log.warn("[SqlProcessor]remove data source({})' successfully.", dataSourceName);
}
log.info("register sql validator({})' successfully.", validatorName);
}
@ -169,6 +144,7 @@ public class SqlProcessor extends CommonBasicProcessor {
}
}
@Data
public static class SqlParams {
/**
@ -183,9 +159,15 @@ public class SqlProcessor extends CommonBasicProcessor {
* 超时时间
*/
private Integer timeout;
/**
* jdbc url
* 具体格式可参考 https://www.baeldung.com/java-jdbc-url-format
*/
private String jdbcUrl;
}
@FunctionalInterface
public interface SqlParser {
/**
@ -198,4 +180,5 @@ public class SqlProcessor extends CommonBasicProcessor {
String parse(String sql, TaskContext taskContext);
}
}

View File

@ -0,0 +1,103 @@
package tech.powerjob.official.processors.impl.sql;
import com.github.kfcfans.powerjob.worker.core.processor.TaskContext;
import com.google.common.collect.Maps;
import lombok.extern.slf4j.Slf4j;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import javax.sql.DataSource;
import java.util.Map;
/**
* 简单 Spring SQL 处理器强依赖于 Spring jdbc 只能用 Spring Bean 的方式加载
* 直接忽略 SQL 执行的返回值
*
* 注意 :
* 默认情况下没有过滤任何 SQL
* 建议生产环境一定要使用 {@link SimpleSpringJdbcTemplateSqlProcessor#registerSqlValidator} 方法注册至少一个校验器拦截非法 SQL
*
* 默认情况下会直接执行参数中的 SQL
* 可以通过添加 {@link SimpleSpringJdbcTemplateSqlProcessor.SqlParser} 来实现定制 SQL 解析逻辑的需求比如 宏变量替换参数替换等
*
* @author Echo009
* @since 2021/3/10
*/
@Slf4j
public class SimpleSpringJdbcTemplateSqlProcessor extends AbstractSqlProcessor {
/**
* 默认的数据源名称
*/
private static final String DEFAULT_DATASOURCE_NAME = "default";
/**
* name => data source
*/
private final Map<String, DataSource> dataSourceMap;
/**
* 指定默认的数据源
*
* @param defaultDataSource 默认数据源
*/
public SimpleSpringJdbcTemplateSqlProcessor(DataSource defaultDataSource) {
dataSourceMap = Maps.newConcurrentMap();
registerDataSource(DEFAULT_DATASOURCE_NAME, defaultDataSource);
}
/**
* 校验参数如果校验不通过直接抛异常
*
* @param sqlParams SQL 参数信息
*/
@Override
protected void validateParams(SqlParams sqlParams) {
// 检查数据源
if (StringUtils.isEmpty(sqlParams.getDataSourceName())) {
// use the default data source when current data source name is empty
sqlParams.setDataSourceName(DEFAULT_DATASOURCE_NAME);
}
dataSourceMap.computeIfAbsent(sqlParams.getDataSourceName(), dataSourceName -> {
throw new IllegalArgumentException("can't find data source with name " + dataSourceName);
});
}
/**
* 执行 SQL
*
* @param sqlParams SQL processor 参数信息
* @param taskContext 任务上下文
*/
@Override
void executeSql(SqlParams sqlParams, TaskContext taskContext) {
JdbcTemplate jdbcTemplate = new JdbcTemplate(dataSourceMap.get(sqlParams.getDataSourceName()));
jdbcTemplate.setSkipResultsProcessing(true);
jdbcTemplate.setQueryTimeout(sqlParams.getTimeout() == null ? DEFAULT_TIMEOUT : sqlParams.getTimeout());
jdbcTemplate.execute(sqlParams.getSql());
}
/**
* 注册数据源
*
* @param dataSourceName 数据源名称
* @param dataSource 数据源
*/
public void registerDataSource(String dataSourceName, DataSource dataSource) {
Assert.notNull(dataSourceName, "DataSource name must not be null");
Assert.notNull(dataSource, "DataSource must not be null");
dataSourceMap.put(dataSourceName, dataSource);
log.info("register data source({})' successfully.", dataSourceName);
}
/**
* 移除数据源
*
* @param dataSourceName 数据源名称
*/
public void removeDataSource(String dataSourceName) {
DataSource remove = dataSourceMap.remove(dataSourceName);
if (remove != null) {
log.warn("remove data source({})' successfully.", dataSourceName);
}
}
}

View File

@ -0,0 +1,81 @@
package tech.powerjob.official.processors.impl;
import com.alibaba.fastjson.JSON;
import com.github.kfcfans.powerjob.worker.core.processor.ProcessResult;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
import tech.powerjob.official.processors.TestUtils;
import tech.powerjob.official.processors.impl.sql.SimpleSpringJdbcTemplateSqlProcessor;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
/**
* @author Echo009
* @since 2021/3/11
*/
@Slf4j
class SimpleSpringJdbcTemplateSqlProcessorTest {
private static SimpleSpringJdbcTemplateSqlProcessor simpleSpringJdbcTemplateSqlProcessor;
@BeforeAll
static void initSqlProcessor() {
EmbeddedDatabaseBuilder builder = new EmbeddedDatabaseBuilder();
EmbeddedDatabase database = builder.setType(EmbeddedDatabaseType.H2)
.addScript("classpath:db_init.sql")
.build();
simpleSpringJdbcTemplateSqlProcessor = new SimpleSpringJdbcTemplateSqlProcessor(database);
// do nothing
simpleSpringJdbcTemplateSqlProcessor.registerSqlValidator("fakeSqlValidator", (sql) -> true);
// 排除掉包含 drop SQL
simpleSpringJdbcTemplateSqlProcessor.registerSqlValidator("interceptDropValidator", (sql) -> sql.matches("^(?i)((?!drop).)*$"));
// do nothing
simpleSpringJdbcTemplateSqlProcessor.setSqlParser((sql, taskContext) -> sql);
log.info("init sql processor successfully!");
}
@Test
void testSqlValidator() {
SimpleSpringJdbcTemplateSqlProcessor.SqlParams sqlParams = new SimpleSpringJdbcTemplateSqlProcessor.SqlParams();
sqlParams.setSql("drop table test_table");
// 校验不通过
assertThrows(IllegalArgumentException.class, () -> simpleSpringJdbcTemplateSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams))));
}
@Test
void testIncorrectDataSourceName() {
SimpleSpringJdbcTemplateSqlProcessor.SqlParams sqlParams = constructSqlParam("create table task_info (a varchar(255), b varchar(255), c varchar(255))");
sqlParams.setDataSourceName("(๑•̀ㅂ•́)و✧");
// 数据源名称非法
assertThrows(IllegalArgumentException.class, () -> simpleSpringJdbcTemplateSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams))));
}
@Test
void testExecDDL() {
SimpleSpringJdbcTemplateSqlProcessor.SqlParams sqlParams = constructSqlParam("create table power_job (a varchar(255), b varchar(255), c varchar(255))");
ProcessResult processResult = simpleSpringJdbcTemplateSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams)));
assertTrue(processResult.isSuccess());
}
@Test
void testExecSQL() {
SimpleSpringJdbcTemplateSqlProcessor.SqlParams sqlParams = constructSqlParam("insert into test_table (id, content) values (0, 'Fight for a better tomorrow')");
ProcessResult processResult = simpleSpringJdbcTemplateSqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams)));
assertTrue(processResult.isSuccess());
}
static SimpleSpringJdbcTemplateSqlProcessor.SqlParams constructSqlParam(String sql){
SimpleSpringJdbcTemplateSqlProcessor.SqlParams sqlParams = new SimpleSpringJdbcTemplateSqlProcessor.SqlParams();
sqlParams.setSql(sql);
return sqlParams;
}
}

View File

@ -1,80 +0,0 @@
package tech.powerjob.official.processors.impl;
import com.alibaba.fastjson.JSON;
import com.github.kfcfans.powerjob.worker.core.processor.ProcessResult;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
import tech.powerjob.official.processors.TestUtils;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
/**
* @author Echo009
* @since 2021/3/11
*/
@Slf4j
class SqlProcessorTest {
private static SqlProcessor sqlProcessor;
@BeforeAll
static void initSqlProcessor() {
EmbeddedDatabaseBuilder builder = new EmbeddedDatabaseBuilder();
EmbeddedDatabase database = builder.setType(EmbeddedDatabaseType.H2)
.addScript("classpath:db_init.sql")
.build();
sqlProcessor = new SqlProcessor(database);
// do nothing
sqlProcessor.registerSqlValidator("fakeSqlValidator", (sql) -> true);
// 排除掉包含 drop SQL
sqlProcessor.registerSqlValidator("interceptDropValidator", (sql) -> sql.matches("^(?i)((?!drop).)*$"));
// do nothing
sqlProcessor.setSqlParser((sql, taskContext) -> sql);
log.info("init sql processor successfully!");
}
@Test
void testSqlValidator() {
SqlProcessor.SqlParams sqlParams = new SqlProcessor.SqlParams();
sqlParams.setSql("drop table test_table");
// 校验不通过
assertThrows(IllegalArgumentException.class, () -> sqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams))));
}
@Test
void testIncorrectDataSourceName() {
SqlProcessor.SqlParams sqlParams = constructSqlParam("create table task_info (a varchar(255), b varchar(255), c varchar(255))");
sqlParams.setDataSourceName("(๑•̀ㅂ•́)و✧");
// 数据源名称非法
assertThrows(IllegalArgumentException.class, () -> sqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams))));
}
@Test
void testExecDDL() {
SqlProcessor.SqlParams sqlParams = constructSqlParam("create table power_job (a varchar(255), b varchar(255), c varchar(255))");
ProcessResult processResult = sqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams)));
assertTrue(processResult.isSuccess());
}
@Test
void testExecSQL() {
SqlProcessor.SqlParams sqlParams = constructSqlParam("insert into test_table (id, content) values (0, 'Fight for a better tomorrow')");
ProcessResult processResult = sqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams)));
assertTrue(processResult.isSuccess());
}
static SqlProcessor.SqlParams constructSqlParam(String sql){
SqlProcessor.SqlParams sqlParams = new SqlProcessor.SqlParams();
sqlParams.setSql(sql);
return sqlParams;
}
}

View File

@ -8,7 +8,7 @@ import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.DependsOn;
import tech.powerjob.official.processors.impl.SqlProcessor;
import tech.powerjob.official.processors.impl.sql.SimpleSpringJdbcTemplateSqlProcessor;
import javax.sql.DataSource;
@ -37,15 +37,15 @@ public class SqlProcessorConfiguration {
@Bean
public SqlProcessor sqlProcessor(@Qualifier("sqlProcessorDataSource") DataSource dataSource) {
SqlProcessor sqlProcessor = new SqlProcessor(dataSource);
public SimpleSpringJdbcTemplateSqlProcessor springSqlProcessor(@Qualifier("sqlProcessorDataSource") DataSource dataSource) {
SimpleSpringJdbcTemplateSqlProcessor simpleSpringJdbcTemplateSqlProcessor = new SimpleSpringJdbcTemplateSqlProcessor(dataSource);
// do nothing
sqlProcessor.registerSqlValidator("fakeSqlValidator", (sql) -> true);
simpleSpringJdbcTemplateSqlProcessor.registerSqlValidator("fakeSqlValidator", (sql) -> true);
// 排除掉包含 drop SQL
sqlProcessor.registerSqlValidator("interceptDropValidator", (sql) -> sql.matches("^(?i)((?!drop).)*$"));
simpleSpringJdbcTemplateSqlProcessor.registerSqlValidator("interceptDropValidator", (sql) -> sql.matches("^(?i)((?!drop).)*$"));
// do nothing
sqlProcessor.setSqlParser((sql, taskContext) -> sql);
return sqlProcessor;
simpleSpringJdbcTemplateSqlProcessor.setSqlParser((sql, taskContext) -> sql);
return simpleSpringJdbcTemplateSqlProcessor;
}
}

View File

@ -13,11 +13,11 @@ public interface BasicProcessor {
/**
* 核心处理逻辑
* 可通过 {@link TaskContext#workflowContext} 获取工作流上下文
* 可通过 {@link TaskContext#getWorkflowContext()} 方法获取工作流上下文
*
* @param context 任务上下文可通过 jobParams instanceParams 分别获取控制台参数和OpenAPI传递的任务实例参数
* @return 处理结果msg有长度限制超长会被裁剪不允许返回 null
* @throws Exception 异常允许抛出异常但不推荐最好由业务开发者自己处理
* @throws Exception 异常允许抛出异常但不推荐最好由业务开发者自己处理
*/
ProcessResult process(TaskContext context) throws Exception;
}