mirror of
https://github.com/PowerJob/PowerJob.git
synced 2025-07-17 00:00:04 +08:00
feat: SQL processor
This commit is contained in:
parent
f4d7459f63
commit
87ed304737
@ -21,6 +21,8 @@
|
||||
<junit.version>5.6.1</junit.version>
|
||||
<logback.version>1.2.3</logback.version>
|
||||
<powerjob.worker.version>3.4.6</powerjob.worker.version>
|
||||
<spring.jdbc.version>5.2.9.RELEASE</spring.jdbc.version>
|
||||
<h2.db.version>1.4.200</h2.db.version>
|
||||
|
||||
<!-- 全部 shade 化,避免依赖冲突 -->
|
||||
<fastjson.version>1.2.68</fastjson.version>
|
||||
@ -65,6 +67,13 @@
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework</groupId>
|
||||
<artifactId>spring-jdbc</artifactId>
|
||||
<version>${spring.jdbc.version}</version>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
|
||||
<!-- Junit tests -->
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
@ -80,6 +89,14 @@
|
||||
<version>${logback.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<!-- h2 database -->
|
||||
<dependency>
|
||||
<groupId>com.h2database</groupId>
|
||||
<artifactId>h2</artifactId>
|
||||
<version>${h2.db.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
<build>
|
||||
@ -104,6 +121,7 @@
|
||||
<shadedPattern>shade.powerjob.org</shadedPattern>
|
||||
<excludes>
|
||||
<exclude>org.slf4j.*</exclude>
|
||||
<exclude>org.springframework.*</exclude>
|
||||
</excludes>
|
||||
</relocation>
|
||||
<relocation>
|
||||
|
@ -0,0 +1,201 @@
|
||||
package tech.powerjob.official.processors.impl;
|
||||
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.github.kfcfans.powerjob.worker.core.processor.ProcessResult;
|
||||
import com.github.kfcfans.powerjob.worker.core.processor.TaskContext;
|
||||
import com.github.kfcfans.powerjob.worker.log.OmsLogger;
|
||||
import com.google.common.collect.Maps;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.jdbc.core.JdbcTemplate;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.util.StopWatch;
|
||||
import org.springframework.util.StringUtils;
|
||||
import tech.powerjob.official.processors.CommonBasicProcessor;
|
||||
import tech.powerjob.official.processors.util.CommonUtils;
|
||||
|
||||
import javax.sql.DataSource;
|
||||
import java.util.Map;
|
||||
import java.util.function.Predicate;
|
||||
|
||||
/**
|
||||
* SQL 处理器,只能用 Spring Bean 的方式加载
|
||||
* 注意 : 默认情况下没有过滤任何 SQL
|
||||
* 建议生产环境一定要使用 {@link SqlProcessor#registerSqlValidator} 方法注册至少一个校验器拦截非法 SQL
|
||||
*
|
||||
* 默认情况下会直接执行参数中的 SQL
|
||||
* 可以通过添加 {@link SqlProcessor.SqlParser} 来实现定制 SQL 解析逻辑的需求(比如 宏变量替换,参数替换等)
|
||||
*
|
||||
* @author Echo009
|
||||
* @since 2021/3/10
|
||||
*/
|
||||
@Slf4j
|
||||
public class SqlProcessor extends CommonBasicProcessor {
|
||||
/**
|
||||
* 默认的数据源名称
|
||||
*/
|
||||
private static final String DEFAULT_DATASOURCE_NAME = "default";
|
||||
/**
|
||||
* 默认超时时间
|
||||
*/
|
||||
private static final int DEFAULT_TIMEOUT = 60;
|
||||
/**
|
||||
* name => data source
|
||||
*/
|
||||
private final Map<String, DataSource> dataSourceMap;
|
||||
/**
|
||||
* name => SQL validator
|
||||
* 注意 :
|
||||
* - 返回 true 表示验证通过
|
||||
* - 返回 false 表示 SQL 非法,将被拒绝执行
|
||||
*/
|
||||
private final Map<String, Predicate<String>> sqlValidatorMap;
|
||||
/**
|
||||
* 自定义 SQL 解析器
|
||||
*/
|
||||
private SqlParser sqlParser;
|
||||
|
||||
/**
|
||||
* 指定默认的数据源
|
||||
*
|
||||
* @param defaultDataSource 默认数据源
|
||||
*/
|
||||
public SqlProcessor(DataSource defaultDataSource) {
|
||||
dataSourceMap = Maps.newConcurrentMap();
|
||||
sqlValidatorMap = Maps.newConcurrentMap();
|
||||
registerDataSource(DEFAULT_DATASOURCE_NAME, defaultDataSource);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
protected ProcessResult process0(TaskContext taskContext) {
|
||||
|
||||
OmsLogger omsLogger = taskContext.getOmsLogger();
|
||||
SqlParams sqlParams = JSON.parseObject(CommonUtils.parseParams(taskContext), SqlParams.class);
|
||||
// 检查数据源
|
||||
if (StringUtils.isEmpty(sqlParams.getDataSourceName())) {
|
||||
sqlParams.setDataSourceName(DEFAULT_DATASOURCE_NAME);
|
||||
omsLogger.info("current data source name is empty, use the default data source");
|
||||
}
|
||||
DataSource dataSource = dataSourceMap.computeIfAbsent(sqlParams.getDataSourceName(), dataSourceName -> {
|
||||
throw new IllegalArgumentException("can't find data source with name " + dataSourceName);
|
||||
});
|
||||
StopWatch stopWatch = new StopWatch("SQL Processor");
|
||||
|
||||
// 解析
|
||||
stopWatch.start("Parse SQL");
|
||||
if (sqlParser != null) {
|
||||
sqlParams.setSql(sqlParser.parse(sqlParams.getSql(), taskContext));
|
||||
}
|
||||
stopWatch.stop();
|
||||
|
||||
// 校验 SQL
|
||||
stopWatch.start("Validate SQL");
|
||||
validateSql(sqlParams.getSql());
|
||||
stopWatch.stop();
|
||||
|
||||
// 执行
|
||||
stopWatch.start("Execute SQL");
|
||||
JdbcTemplate jdbcTemplate = new JdbcTemplate(dataSource);
|
||||
jdbcTemplate.setSkipResultsProcessing(true);
|
||||
jdbcTemplate.setQueryTimeout(sqlParams.getTimeout() == null ? DEFAULT_TIMEOUT : sqlParams.getTimeout());
|
||||
omsLogger.info("start to execute sql: {}", sqlParams.getSql());
|
||||
jdbcTemplate.execute(sqlParams.getSql());
|
||||
stopWatch.stop();
|
||||
omsLogger.info(stopWatch.prettyPrint());
|
||||
String message = String.format("execute successfully, used time: %s millisecond", stopWatch.getTotalTimeMillis());
|
||||
return new ProcessResult(true, message);
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置 SQL 验证器
|
||||
*
|
||||
* @param sqlParser SQL 解析器
|
||||
*/
|
||||
public void setSqlParser(SqlParser sqlParser) {
|
||||
this.sqlParser = sqlParser;
|
||||
}
|
||||
|
||||
/**
|
||||
* 注册一个 SQL 验证器
|
||||
*
|
||||
* @param validatorName 验证器名称
|
||||
* @param sqlValidator 验证器
|
||||
*/
|
||||
public void registerSqlValidator(String validatorName, Predicate<String> sqlValidator) {
|
||||
sqlValidatorMap.put(validatorName, sqlValidator);
|
||||
log.info("[SqlProcessor]register sql validator({})' successfully.", validatorName);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 注册数据源
|
||||
*
|
||||
* @param dataSourceName 数据源名称
|
||||
* @param dataSource 数据源
|
||||
*/
|
||||
public void registerDataSource(String dataSourceName, DataSource dataSource) {
|
||||
Assert.notNull(dataSourceName, "DataSource name must not be null");
|
||||
Assert.notNull(dataSource, "DataSource must not be null");
|
||||
dataSourceMap.put(dataSourceName, dataSource);
|
||||
log.info("[SqlProcessor]register data source({})' successfully.", dataSourceName);
|
||||
}
|
||||
|
||||
/**
|
||||
* 移除数据源
|
||||
*
|
||||
* @param dataSourceName 数据源名称
|
||||
*/
|
||||
public void removeDataSource(String dataSourceName) {
|
||||
DataSource remove = dataSourceMap.remove(dataSourceName);
|
||||
if (remove != null) {
|
||||
log.warn("[SqlProcessor]remove data source({})' successfully.", dataSourceName);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* 校验 SQL 合法性
|
||||
*/
|
||||
private void validateSql(String sql) {
|
||||
if (sqlValidatorMap.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
for (Map.Entry<String, Predicate<String>> entry : sqlValidatorMap.entrySet()) {
|
||||
Predicate<String> validator = entry.getValue();
|
||||
if (!validator.test(sql)) {
|
||||
throw new IllegalArgumentException("illegal sql, can't pass the validation of " + entry.getKey());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Data
|
||||
public static class SqlParams {
|
||||
/**
|
||||
* 数据源名称
|
||||
*/
|
||||
private String dataSourceName;
|
||||
/**
|
||||
* 需要执行的 SQL
|
||||
*/
|
||||
private String sql;
|
||||
/**
|
||||
* 超时时间
|
||||
*/
|
||||
private Integer timeout;
|
||||
|
||||
}
|
||||
|
||||
@FunctionalInterface
|
||||
public interface SqlParser {
|
||||
/**
|
||||
* 自定义 SQL 解析逻辑
|
||||
*
|
||||
* @param sql 原始 SQL 语句
|
||||
* @param taskContext 任务上下文
|
||||
* @return 解析后的 SQL
|
||||
*/
|
||||
String parse(String sql, TaskContext taskContext);
|
||||
}
|
||||
|
||||
}
|
@ -1,8 +1,8 @@
|
||||
package tech.powerjob.official.processors;
|
||||
|
||||
import com.github.kfcfans.powerjob.worker.core.processor.TaskContext;
|
||||
import com.github.kfcfans.powerjob.worker.core.processor.WorkflowContext;
|
||||
import com.github.kfcfans.powerjob.worker.log.impl.OmsLocalLogger;
|
||||
import com.github.kfcfans.powerjob.worker.log.impl.OmsServerLogger;
|
||||
|
||||
import java.util.concurrent.ThreadLocalRandom;
|
||||
|
||||
@ -25,7 +25,7 @@ public class TestUtils {
|
||||
taskContext.setTaskId("0.0");
|
||||
taskContext.setTaskName("TEST_TASK");
|
||||
taskContext.setOmsLogger(new OmsLocalLogger());
|
||||
|
||||
taskContext.setWorkflowContext(new WorkflowContext(null, null));
|
||||
return taskContext;
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,80 @@
|
||||
package tech.powerjob.official.processors.impl;
|
||||
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.github.kfcfans.powerjob.worker.core.processor.ProcessResult;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.junit.jupiter.api.BeforeAll;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
|
||||
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
|
||||
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
|
||||
import tech.powerjob.official.processors.TestUtils;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||
|
||||
/**
|
||||
* @author Echo009
|
||||
* @since 2021/3/11
|
||||
*/
|
||||
@Slf4j
|
||||
class SqlProcessorTest {
|
||||
|
||||
private static SqlProcessor sqlProcessor;
|
||||
|
||||
@BeforeAll
|
||||
static void initSqlProcessor() {
|
||||
|
||||
EmbeddedDatabaseBuilder builder = new EmbeddedDatabaseBuilder();
|
||||
EmbeddedDatabase database = builder.setType(EmbeddedDatabaseType.H2)
|
||||
.addScript("classpath:db_init.sql")
|
||||
.build();
|
||||
sqlProcessor = new SqlProcessor(database);
|
||||
// do nothing
|
||||
sqlProcessor.registerSqlValidator("fakeSqlValidator", (sql) -> true);
|
||||
// 排除掉包含 drop 的 SQL
|
||||
sqlProcessor.registerSqlValidator("interceptDropValidator", (sql) -> sql.matches("^(?i)((?!drop).)*$"));
|
||||
// do nothing
|
||||
sqlProcessor.setSqlParser((sql, taskContext) -> sql);
|
||||
log.info("init sql processor successfully!");
|
||||
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
void testSqlValidator() {
|
||||
SqlProcessor.SqlParams sqlParams = new SqlProcessor.SqlParams();
|
||||
sqlParams.setSql("drop table test_table");
|
||||
// 校验不通过
|
||||
assertThrows(IllegalArgumentException.class, () -> sqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams))));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testIncorrectDataSourceName() {
|
||||
SqlProcessor.SqlParams sqlParams = constructSqlParam("create table task_info (a varchar(255), b varchar(255), c varchar(255))");
|
||||
sqlParams.setDataSourceName("(๑•̀ㅂ•́)و✧");
|
||||
// 数据源名称非法
|
||||
assertThrows(IllegalArgumentException.class, () -> sqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams))));
|
||||
}
|
||||
|
||||
@Test
|
||||
void testExecDDL() {
|
||||
SqlProcessor.SqlParams sqlParams = constructSqlParam("create table power_job (a varchar(255), b varchar(255), c varchar(255))");
|
||||
ProcessResult processResult = sqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams)));
|
||||
assertTrue(processResult.isSuccess());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testExecSQL() {
|
||||
SqlProcessor.SqlParams sqlParams = constructSqlParam("insert into test_table (id, content) values (0, 'Fight for a better tomorrow')");
|
||||
ProcessResult processResult = sqlProcessor.process0(TestUtils.genTaskContext(JSON.toJSONString(sqlParams)));
|
||||
assertTrue(processResult.isSuccess());
|
||||
}
|
||||
|
||||
static SqlProcessor.SqlParams constructSqlParam(String sql){
|
||||
SqlProcessor.SqlParams sqlParams = new SqlProcessor.SqlParams();
|
||||
sqlParams.setSql(sql);
|
||||
return sqlParams;
|
||||
}
|
||||
|
||||
}
|
@ -0,0 +1,7 @@
|
||||
create table test_table
|
||||
(
|
||||
id bigint(20) primary key,
|
||||
content varchar(255),
|
||||
gmt_create datetime default now(),
|
||||
gmt_modified datetime default now()
|
||||
);
|
@ -16,6 +16,7 @@
|
||||
<springboot.version>2.2.6.RELEASE</springboot.version>
|
||||
<powerjob.worker.starter.version>3.4.6</powerjob.worker.starter.version>
|
||||
<fastjson.version>1.2.68</fastjson.version>
|
||||
<powerjob.official.processors.version>1.0.1</powerjob.official.processors.version>
|
||||
|
||||
<!-- 部署时跳过该module -->
|
||||
<maven.deploy.skip>true</maven.deploy.skip>
|
||||
@ -52,6 +53,11 @@
|
||||
<artifactId>fastjson</artifactId>
|
||||
<version>${fastjson.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.github.kfcfans</groupId>
|
||||
<artifactId>powerjob-official-processors</artifactId>
|
||||
<version>${powerjob.official.processors.version}</version>
|
||||
</dependency>
|
||||
|
||||
|
||||
</dependencies>
|
||||
|
@ -0,0 +1,51 @@
|
||||
package com.github.kfcfans.powerjob.samples.config;
|
||||
|
||||
import com.github.kfcfans.powerjob.worker.common.utils.OmsWorkerFileUtils;
|
||||
import com.zaxxer.hikari.HikariConfig;
|
||||
import com.zaxxer.hikari.HikariDataSource;
|
||||
import org.h2.Driver;
|
||||
import org.springframework.beans.factory.annotation.Qualifier;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.context.annotation.DependsOn;
|
||||
import tech.powerjob.official.processors.impl.SqlProcessor;
|
||||
|
||||
import javax.sql.DataSource;
|
||||
|
||||
/**
|
||||
* @author Echo009
|
||||
* @since 2021/3/10
|
||||
*/
|
||||
@Configuration
|
||||
public class SqlProcessorConfiguration {
|
||||
|
||||
|
||||
@Bean
|
||||
@DependsOn({"initPowerJob"})
|
||||
public DataSource sqlProcessorDataSource() {
|
||||
String jdbcUrl = String.format("jdbc:h2:file:%spowerjob_sql_processor_db;DB_CLOSE_DELAY=-1;DATABASE_TO_UPPER=false", OmsWorkerFileUtils.getH2WorkDir());
|
||||
HikariConfig config = new HikariConfig();
|
||||
config.setDriverClassName(Driver.class.getName());
|
||||
config.setJdbcUrl(jdbcUrl);
|
||||
config.setAutoCommit(true);
|
||||
// 池中最小空闲连接数量
|
||||
config.setMinimumIdle(1);
|
||||
// 池中最大连接数量
|
||||
config.setMaximumPoolSize(10);
|
||||
return new HikariDataSource(config);
|
||||
}
|
||||
|
||||
|
||||
@Bean
|
||||
public SqlProcessor sqlProcessor(@Qualifier("sqlProcessorDataSource") DataSource dataSource) {
|
||||
SqlProcessor sqlProcessor = new SqlProcessor(dataSource);
|
||||
// do nothing
|
||||
sqlProcessor.registerSqlValidator("fakeSqlValidator", (sql) -> true);
|
||||
// 排除掉包含 drop 的 SQL
|
||||
sqlProcessor.registerSqlValidator("interceptDropValidator", (sql) -> sql.matches("^(?i)((?!drop).)*$"));
|
||||
// do nothing
|
||||
sqlProcessor.setSqlParser((sql, taskContext) -> sql);
|
||||
return sqlProcessor;
|
||||
}
|
||||
|
||||
}
|
@ -1,6 +1,5 @@
|
||||
package com.github.kfcfans.powerjob.worker.persistence;
|
||||
|
||||
import com.github.kfcfans.powerjob.worker.OhMyWorker;
|
||||
import com.github.kfcfans.powerjob.worker.common.constants.StoreStrategy;
|
||||
import com.github.kfcfans.powerjob.worker.common.utils.OmsWorkerFileUtils;
|
||||
import com.zaxxer.hikari.HikariConfig;
|
||||
|
Loading…
x
Reference in New Issue
Block a user