feat: SQL processor

This commit is contained in:
Echo009 2021-03-11 11:40:25 +08:00
parent f4d7459f63
commit 87ed304737
8 changed files with 365 additions and 3 deletions

View File

@ -21,6 +21,8 @@
<junit.version>5.6.1</junit.version>
<logback.version>1.2.3</logback.version>
<powerjob.worker.version>3.4.6</powerjob.worker.version>
<spring.jdbc.version>5.2.9.RELEASE</spring.jdbc.version>
<h2.db.version>1.4.200</h2.db.version>
<!-- 全部 shade 化,避免依赖冲突 -->
<fastjson.version>1.2.68</fastjson.version>
@ -65,6 +67,13 @@
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-jdbc</artifactId>
<version>${spring.jdbc.version}</version>
<scope>provided</scope>
</dependency>
<!-- Junit tests -->
<dependency>
<groupId>org.junit.jupiter</groupId>
@ -80,6 +89,14 @@
<version>${logback.version}</version>
<scope>test</scope>
</dependency>
<!-- h2 database -->
<dependency>
<groupId>com.h2database</groupId>
<artifactId>h2</artifactId>
<version>${h2.db.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>
@ -104,6 +121,7 @@
<shadedPattern>shade.powerjob.org</shadedPattern>
<excludes>
<exclude>org.slf4j.*</exclude>
<exclude>org.springframework.*</exclude>
</excludes>
</relocation>
<relocation>

View File

@ -0,0 +1,201 @@
package tech.powerjob.official.processors.impl;
import com.alibaba.fastjson.JSON;
import com.github.kfcfans.powerjob.worker.core.processor.ProcessResult;
import com.github.kfcfans.powerjob.worker.core.processor.TaskContext;
import com.github.kfcfans.powerjob.worker.log.OmsLogger;
import com.google.common.collect.Maps;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.util.Assert;
import org.springframework.util.StopWatch;
import org.springframework.util.StringUtils;
import tech.powerjob.official.processors.CommonBasicProcessor;
import tech.powerjob.official.processors.util.CommonUtils;
import javax.sql.DataSource;
import java.util.Map;
import java.util.function.Predicate;
/**
* SQL 处理器只能用 Spring Bean 的方式加载
* 注意 : 默认情况下没有过滤任何 SQL
* 建议生产环境一定要使用 {@link SqlProcessor#registerSqlValidator} 方法注册至少一个校验器拦截非法 SQL
*
* 默认情况下会直接执行参数中的 SQL
* 可以通过添加 {@link SqlProcessor.SqlParser} 来实现定制 SQL 解析逻辑的需求比如 宏变量替换参数替换等
*
* @author Echo009
* @since 2021/3/10
*/
@Slf4j
public class SqlProcessor extends CommonBasicProcessor {
/**
* 默认的数据源名称
*/
private static final String DEFAULT_DATASOURCE_NAME = "default";
/**
* 默认超时时间
*/
private static final int DEFAULT_TIMEOUT = 60;
/**
* name => data source
*/
private final Map<String, DataSource> dataSourceMap;
/**
* name => SQL validator
* 注意
* - 返回 true 表示验证通过
* - 返回 false 表示 SQL 非法将被拒绝执行
*/
private final Map<String, Predicate<String>> sqlValidatorMap;
/**
* 自定义 SQL 解析器
*/
private SqlParser sqlParser;
/**
* 指定默认的数据源
*
* @param defaultDataSource 默认数据源
*/
public SqlProcessor(DataSource defaultDataSource) {
dataSourceMap = Maps.newConcurrentMap();
sqlValidatorMap = Maps.newConcurrentMap();
registerDataSource(DEFAULT_DATASOURCE_NAME, defaultDataSource);
}
@Override
protected ProcessResult process0(TaskContext taskContext) {
OmsLogger omsLogger = taskContext.getOmsLogger();
SqlParams sqlParams = JSON.parseObject(CommonUtils.parseParams(taskContext), SqlParams.class);
// 检查数据源
if (StringUtils.isEmpty(sqlParams.getDataSourceName())) {
sqlParams.setDataSourceName(DEFAULT_DATASOURCE_NAME);
omsLogger.info("current data source name is empty, use the default data source");
}
DataSource dataSource = dataSourceMap.computeIfAbsent(sqlParams.getDataSourceName(), dataSourceName -> {
throw new IllegalArgumentException("can't find data source with name " + dataSourceName);
});
StopWatch stopWatch = new StopWatch("SQL Processor");
// 解析
stopWatch.start("Parse SQL");
if (sqlParser != null) {
sqlParams.setSql(sqlParser.parse(sqlParams.getSql(), taskContext));
}
stopWatch.stop();
// 校验 SQL
stopWatch.start("Validate SQL");
validateSql(sqlParams.getSql());
stopWatch.stop();
// 执行
stopWatch.start("Execute SQL");
JdbcTemplate jdbcTemplate = new JdbcTemplate(dataSource);
jdbcTemplate.setSkipResultsProcessing(true);
jdbcTemplate.setQueryTimeout(sqlParams.getTimeout() == null ? DEFAULT_TIMEOUT : sqlParams.getTimeout());
omsLogger.info("start to execute sql: {}", sqlParams.getSql());
jdbcTemplate.execute(sqlParams.getSql());
stopWatch.stop();
omsLogger.info(stopWatch.prettyPrint());
String message = String.format("execute successfully, used time: %s millisecond", stopWatch.getTotalTimeMillis());
return new ProcessResult(true, message);
}
/**
* 设置 SQL 验证器
*
* @param sqlParser SQL 解析器
*/
public void setSqlParser(SqlParser sqlParser) {
this.sqlParser = sqlParser;
}
/**
* 注册一个 SQL 验证器
*
* @param validatorName 验证器名称
* @param sqlValidator 验证器
*/
public void registerSqlValidator(String validatorName, Predicate<String> sqlValidator) {
sqlValidatorMap.put(validatorName, sqlValidator);
log.info("[SqlProcessor]register sql validator({})' successfully.", validatorName);
}
/**
* 注册数据源
*
* @param dataSourceName 数据源名称
* @param dataSource 数据源
*/
public void registerDataSource(String dataSourceName, DataSource dataSource) {
Assert.notNull(dataSourceName, "DataSource name must not be null");
Assert.notNull(dataSource, "DataSource must not be null");
dataSourceMap.put(dataSourceName, dataSource);
log.info("[SqlProcessor]register data source({})' successfully.", dataSourceName);
}
/**
* 移除数据源
*
* @param dataSourceName 数据源名称
*/
public void removeDataSource(String dataSourceName) {
DataSource remove = dataSourceMap.remove(dataSourceName);
if (remove != null) {
log.warn("[SqlProcessor]remove data source({})' successfully.", dataSourceName);
}
}
/**
* 校验 SQL 合法性
*/
private void validateSql(String sql) {
if (sqlValidatorMap.isEmpty()) {
return;
}
for (Map.Entry<String, Predicate<String>> entry : sqlValidatorMap.entrySet()) {
Predicate<String> validator = entry.getValue();
if (!validator.test(sql)) {
throw new IllegalArgumentException("illegal sql, can't pass the validation of " + entry.getKey());
}
}
}
@Data
public static class SqlParams {
/**
* 数据源名称
*/
private String dataSourceName;
/**
* 需要执行的 SQL
*/
private String sql;
/**
* 超时时间
*/
private Integer timeout;
}
@FunctionalInterface
public interface SqlParser {
/**
* 自定义 SQL 解析逻辑
*
* @param sql 原始 SQL 语句
* @param taskContext 任务上下文
* @return 解析后的 SQL
*/
String parse(String sql, TaskContext taskContext);
}
}

View File

@ -1,8 +1,8 @@
package tech.powerjob.official.processors;
import com.github.kfcfans.powerjob.worker.core.processor.TaskContext;
import com.github.kfcfans.powerjob.worker.core.processor.WorkflowContext;
import com.github.kfcfans.powerjob.worker.log.impl.OmsLocalLogger;
import com.github.kfcfans.powerjob.worker.log.impl.OmsServerLogger;
import java.util.concurrent.ThreadLocalRandom;
@ -25,7 +25,7 @@ public class TestUtils {
taskContext.setTaskId("0.0");
taskContext.setTaskName("TEST_TASK");
taskContext.setOmsLogger(new OmsLocalLogger());
taskContext.setWorkflowContext(new WorkflowContext(null, null));
return taskContext;
}
}

View File

@ -0,0 +1,80 @@
package tech.powerjob.official.processors.impl;
import com.alibaba.fastjson.JSON;
import com.github.kfcfans.powerjob.worker.core.processor.ProcessResult;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
import tech.powerjob.official.processors.TestUtils;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
/**
* @author Echo009
* @since 2021/3/11
*/
@Slf4j
class SqlProcessorTest {
private static SqlProcessor sqlProcessor;
@BeforeAll
static void initSqlProcessor() {
EmbeddedDatabaseBuilder builder = new EmbeddedDatabaseBuilder();
EmbeddedDatabase database = builder.setType(EmbeddedDatabaseType.H2)
.addScript("classpath:db_init.sql")
.build();
sqlProcessor = new SqlProcessor(database);
// do nothing
sqlProcessor.registerSqlValidator("fakeSqlValidator", (sql) -> true);
// 排除掉包含 drop SQL
sqlProcessor.registerSqlValidator("interceptDropValidator", (sql) -> sql.matches("^(?i)((?!drop).)*$"));
// do nothing
sqlProcessor.setSqlParser((sql, taskContext) -> sql);
log.info("init sql processor successfully!");
}
@Test
void testSqlValidator() {
SqlProcessor.SqlParams sqlParams = new SqlProcessor.SqlParams();
sqlParams.setSql("drop table test_table");
// 校验不通过
assertThrows(IllegalArgumentException.class, () -> sqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams))));
}
@Test
void testIncorrectDataSourceName() {
SqlProcessor.SqlParams sqlParams = constructSqlParam("create table task_info (a varchar(255), b varchar(255), c varchar(255))");
sqlParams.setDataSourceName("(๑•̀ㅂ•́)و✧");
// 数据源名称非法
assertThrows(IllegalArgumentException.class, () -> sqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams))));
}
@Test
void testExecDDL() {
SqlProcessor.SqlParams sqlParams = constructSqlParam("create table power_job (a varchar(255), b varchar(255), c varchar(255))");
ProcessResult processResult = sqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams)));
assertTrue(processResult.isSuccess());
}
@Test
void testExecSQL() {
SqlProcessor.SqlParams sqlParams = constructSqlParam("insert into test_table (id, content) values (0, 'Fight for a better tomorrow')");
ProcessResult processResult = sqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams)));
assertTrue(processResult.isSuccess());
}
static SqlProcessor.SqlParams constructSqlParam(String sql){
SqlProcessor.SqlParams sqlParams = new SqlProcessor.SqlParams();
sqlParams.setSql(sql);
return sqlParams;
}
}

View File

@ -0,0 +1,7 @@
create table test_table
(
id bigint(20) primary key,
content varchar(255),
gmt_create datetime default now(),
gmt_modified datetime default now()
);

View File

@ -16,6 +16,7 @@
<springboot.version>2.2.6.RELEASE</springboot.version>
<powerjob.worker.starter.version>3.4.6</powerjob.worker.starter.version>
<fastjson.version>1.2.68</fastjson.version>
<powerjob.official.processors.version>1.0.1</powerjob.official.processors.version>
<!-- 部署时跳过该module -->
<maven.deploy.skip>true</maven.deploy.skip>
@ -52,6 +53,11 @@
<artifactId>fastjson</artifactId>
<version>${fastjson.version}</version>
</dependency>
<dependency>
<groupId>com.github.kfcfans</groupId>
<artifactId>powerjob-official-processors</artifactId>
<version>${powerjob.official.processors.version}</version>
</dependency>
</dependencies>

View File

@ -0,0 +1,51 @@
package com.github.kfcfans.powerjob.samples.config;
import com.github.kfcfans.powerjob.worker.common.utils.OmsWorkerFileUtils;
import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource;
import org.h2.Driver;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.DependsOn;
import tech.powerjob.official.processors.impl.SqlProcessor;
import javax.sql.DataSource;
/**
* @author Echo009
* @since 2021/3/10
*/
@Configuration
public class SqlProcessorConfiguration {
@Bean
@DependsOn({"initPowerJob"})
public DataSource sqlProcessorDataSource() {
String jdbcUrl = String.format("jdbc:h2:file:%spowerjob_sql_processor_db;DB_CLOSE_DELAY=-1;DATABASE_TO_UPPER=false", OmsWorkerFileUtils.getH2WorkDir());
HikariConfig config = new HikariConfig();
config.setDriverClassName(Driver.class.getName());
config.setJdbcUrl(jdbcUrl);
config.setAutoCommit(true);
// 池中最小空闲连接数量
config.setMinimumIdle(1);
// 池中最大连接数量
config.setMaximumPoolSize(10);
return new HikariDataSource(config);
}
@Bean
public SqlProcessor sqlProcessor(@Qualifier("sqlProcessorDataSource") DataSource dataSource) {
SqlProcessor sqlProcessor = new SqlProcessor(dataSource);
// do nothing
sqlProcessor.registerSqlValidator("fakeSqlValidator", (sql) -> true);
// 排除掉包含 drop SQL
sqlProcessor.registerSqlValidator("interceptDropValidator", (sql) -> sql.matches("^(?i)((?!drop).)*$"));
// do nothing
sqlProcessor.setSqlParser((sql, taskContext) -> sql);
return sqlProcessor;
}
}

View File

@ -1,6 +1,5 @@
package com.github.kfcfans.powerjob.worker.persistence;
import com.github.kfcfans.powerjob.worker.OhMyWorker;
import com.github.kfcfans.powerjob.worker.common.constants.StoreStrategy;
import com.github.kfcfans.powerjob.worker.common.utils.OmsWorkerFileUtils;
import com.zaxxer.hikari.HikariConfig;