Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import static org.junit.jupiter.api.Assertions.assertTrue;

import dev.dbos.transact.Constants;
import dev.dbos.transact.database.SystemDatabase;
import dev.dbos.transact.migrations.MigrationManager;

import java.io.PrintWriter;
Expand All @@ -17,6 +18,8 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import picocli.CommandLine;

@Timeout(value = 2, unit = TimeUnit.MINUTES)
Expand Down Expand Up @@ -72,18 +75,32 @@ public void migrate_twice() throws Exception {
assertTrue(checkTable(Constants.DB_SCHEMA, "workflow_status"));
}

@Test
public void migrate_custom_schema() throws Exception {

@ParameterizedTest
@ValueSource(strings = {"invalid\"schema", "invalid'schema"})
void testRunMigrations_fails_invalid_schema(String schema) throws Exception {
assertFalse(checkConnection());

var schema = "C\"$+0m'";
var cmd = new CommandLine(new DBOSCommand());
var sw = new StringWriter();
cmd.setOut(new PrintWriter(sw));

var exitCode =
cmd.execute("migrate", "-D=" + db_url, "-U=" + db_user, "--schema", "%s".formatted(schema));
assertEquals(1, exitCode);
}

@ParameterizedTest
@ValueSource(strings = {"F8nny_sCHem@-n@m3", "embedded\0null"})
public void migrate_custom_schema(String schema) throws Exception {

assertFalse(checkConnection());

var cmd = new CommandLine(new DBOSCommand());
var sw = new StringWriter();
cmd.setOut(new PrintWriter(sw));

var exitCode = cmd.execute("migrate", "-D=" + db_url, "-U=" + db_user, "--schema=" + schema);
var exitCode =
cmd.execute("migrate", "-D=" + db_url, "-U=" + db_user, "--schema", "%s".formatted(schema));
assertEquals(0, exitCode);

assertTrue(checkConnection());
Expand All @@ -105,7 +122,7 @@ static boolean checkTable(String schema, String table) throws SQLException {
"SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_schema = ? AND table_name = ?)";
try (var conn = DriverManager.getConnection(db_url, db_user, db_password);
var stmt = conn.prepareStatement(sql)) {
stmt.setString(1, schema);
stmt.setString(1, SystemDatabase.sanitizeSchema(schema));
stmt.setString(2, table);
try (var rs = stmt.executeQuery()) {
if (rs.next()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ void send(
// Insert notification
final String sql =
"""
INSERT INTO %s.notifications (destination_uuid, topic, message) VALUES (?, ?, ?)
INSERT INTO "%s".notifications (destination_uuid, topic, message) VALUES (?, ?, ?)
"""
.formatted(this.schema);

Expand Down Expand Up @@ -162,7 +162,7 @@ Object recv(
try (Connection conn = dataSource.getConnection()) {
final String sql =
"""
SELECT topic FROM %s.notifications WHERE destination_uuid = ? AND topic = ?
SELECT topic FROM "%s".notifications WHERE destination_uuid = ? AND topic = ?
"""
.formatted(this.schema);

Expand Down Expand Up @@ -214,12 +214,12 @@ Object recv(
"""
WITH oldest_entry AS (
SELECT destination_uuid, topic, message, created_at_epoch_ms
FROM %1$s.notifications
FROM "%1$s".notifications
WHERE destination_uuid = ? AND topic = ?
ORDER BY created_at_epoch_ms ASC
LIMIT 1
)
DELETE FROM %1$s.notifications
DELETE FROM "%1$s".notifications
WHERE destination_uuid = (SELECT destination_uuid FROM oldest_entry)
AND topic = (SELECT topic FROM oldest_entry)
AND created_at_epoch_ms = (SELECT created_at_epoch_ms FROM oldest_entry)
Expand Down Expand Up @@ -263,7 +263,7 @@ private void setEvent(
throws SQLException {
final String eventSql =
"""
INSERT INTO %s.workflow_events (workflow_uuid, key, value)
INSERT INTO "%s".workflow_events (workflow_uuid, key, value)
VALUES (?, ?, ?)
ON CONFLICT (workflow_uuid, key)
DO UPDATE SET value = EXCLUDED.value
Expand All @@ -279,7 +279,7 @@ ON CONFLICT (workflow_uuid, key)

final String eventHistorySql =
"""
INSERT INTO %s.workflow_events_history (workflow_uuid, function_id, key, value)
INSERT INTO "%s".workflow_events_history (workflow_uuid, function_id, key, value)
VALUES (?, ?, ?, ?)
ON CONFLICT (workflow_uuid, key, function_id)
DO UPDATE SET value = EXCLUDED.value
Expand Down Expand Up @@ -382,7 +382,7 @@ Object getEvent(
Object value = null;
final String sql =
"""
SELECT value FROM %s.workflow_events WHERE workflow_uuid = ? AND key = ?
SELECT value FROM "%s".workflow_events WHERE workflow_uuid = ? AND key = ?
"""
.formatted(this.schema);

Expand Down
12 changes: 6 additions & 6 deletions transact/src/main/java/dev/dbos/transact/database/QueuesDAO.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ List<String> getAndStartQueuedWorkflows(
var limiterQuery =
"""
SELECT COUNT(*)
FROM %s.workflow_status
FROM "%s".workflow_status
WHERE queue_name = ?
AND status != ?
AND started_at_epoch_ms > ?
Expand Down Expand Up @@ -100,7 +100,7 @@ SELECT COUNT(*)
String pendingQuery =
"""
SELECT executor_id, COUNT(*) as task_count
FROM %s.workflow_status
FROM "%s".workflow_status
WHERE queue_name = ? AND status = ?
"""
.formatted(this.schema);
Expand Down Expand Up @@ -170,7 +170,7 @@ SELECT executor_id, COUNT(*) as task_count
var query =
"""
SELECT workflow_uuid
FROM %s.workflow_status
FROM "%s".workflow_status
WHERE queue_name = ?
AND status = ?
AND (application_version = ? OR application_version IS NULL)
Expand Down Expand Up @@ -226,7 +226,7 @@ SELECT executor_id, COUNT(*) as task_count
List<String> updatedWorkflowIds = new ArrayList<>();
String updateQuery =
"""
UPDATE %s.workflow_status
UPDATE "%s".workflow_status
SET status = ?,
application_version = ?,
executor_id = ?,
Expand Down Expand Up @@ -273,7 +273,7 @@ boolean clearQueueAssignment(String workflowId) throws SQLException {

final String sql =
"""
UPDATE %s.workflow_status
UPDATE "%s".workflow_status
SET started_at_epoch_ms = NULL, status = ?
WHERE workflow_uuid = ? AND queue_name IS NOT NULL AND status = ?
"""
Expand All @@ -294,7 +294,7 @@ List<String> getQueuePartitions(String queueName) throws SQLException {
final String sql =
"""
SELECT DISTINCT queue_partition_key
FROM %s.workflow_status
FROM "%s".workflow_status
WHERE queue_name = ?
AND status = ?
AND queue_partition_key IS NOT NULL
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ static void recordStepResultTxn(
Objects.requireNonNull(schema);
String sql =
"""
INSERT INTO %s.operation_outputs
INSERT INTO "%s".operation_outputs
(workflow_uuid, function_id, function_name, output, error, child_workflow_id, started_at_epoch_ms, completed_at_epoch_ms)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT DO NOTHING RETURNING completed_at_epoch_ms
Expand Down Expand Up @@ -132,7 +132,7 @@ static StepResult checkStepExecutionTxn(
Objects.requireNonNull(schema);
final String sql =
"""
SELECT status FROM %s.workflow_status WHERE workflow_uuid = ?
SELECT status FROM "%s".workflow_status WHERE workflow_uuid = ?
"""
.formatted(schema);

Expand All @@ -158,7 +158,7 @@ static StepResult checkStepExecutionTxn(
String operationOutputSql =
"""
SELECT output, error, function_name
FROM %s.operation_outputs
FROM "%s".operation_outputs
WHERE workflow_uuid = ? AND function_id = ?
"""
.formatted(schema);
Expand Down Expand Up @@ -203,7 +203,7 @@ List<StepInfo> listWorkflowSteps(Connection connection, String workflowId) throw
final String sql =
"""
SELECT function_id, function_name, output, error, child_workflow_id, started_at_epoch_ms, completed_at_epoch_ms
FROM %s.operation_outputs
FROM "%s".operation_outputs
WHERE workflow_uuid = ?
ORDER BY function_id;
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,7 @@ public class SystemDatabase implements AutoCloseable {
private static final Logger logger = LoggerFactory.getLogger(SystemDatabase.class);

public static String sanitizeSchema(String schema) {
schema =
Objects.requireNonNullElse(schema, Constants.DB_SCHEMA)
.replace("\0", "")
.replace("\"", "\"\"");
return "\"%s\"".formatted(schema);
return Objects.requireNonNullElse(schema, Constants.DB_SCHEMA).replace("\0", "");
}

private final DataSource dataSource;
Expand All @@ -55,7 +51,12 @@ public static String sanitizeSchema(String schema) {
private final NotificationService notificationService;

private SystemDatabase(DataSource dataSource, String schema, boolean created) {
this.schema = sanitizeSchema(schema);
schema = sanitizeSchema(schema);
if (schema.contains("'") || schema.contains("\"")) {
throw new IllegalArgumentException("Schema name must not contain single or double quotes");
}

this.schema = schema;
this.dataSource = dataSource;
this.created = created;

Expand Down Expand Up @@ -425,7 +426,7 @@ public Optional<ExternalState> getExternalState(String service, String workflowN
() -> {
final String sql =
"""
SELECT value, update_seq, update_time FROM %s.event_dispatch_kv WHERE service_name = ? AND workflow_fn_name = ? AND key = ?
SELECT value, update_seq, update_time FROM "%s".event_dispatch_kv WHERE service_name = ? AND workflow_fn_name = ? AND key = ?
"""
.formatted(this.schema);

Expand Down Expand Up @@ -456,7 +457,7 @@ public ExternalState upsertExternalState(ExternalState state) {
() -> {
final var sql =
"""
INSERT INTO %s.event_dispatch_kv (
INSERT INTO "%s".event_dispatch_kv (
service_name, workflow_fn_name, key, value, update_time, update_seq)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT (service_name, workflow_fn_name, key)
Expand Down Expand Up @@ -509,15 +510,15 @@ public List<MetricData> getMetrics(Instant startTime, Instant endTime) {
final var wfSQL =
"""
SELECT name, COUNT(workflow_uuid) as count
FROM %s.workflow_status
FROM "%s".workflow_status
WHERE created_at >= ? AND created_at < ?
GROUP BY name
"""
.formatted(this.schema);
final var stepSQL =
"""
SELECT function_name, COUNT(*) as count
FROM %s.operation_outputs
FROM "%s".operation_outputs
WHERE completed_at_epoch_ms >= ? AND completed_at_epoch_ms < ?
GROUP BY function_name
"""
Expand Down Expand Up @@ -559,7 +560,7 @@ private String getCheckpointName(Connection conn, String workflowId, int functio
var sql =
"""
SELECT function_name
FROM %s.operation_outputs
FROM "%s".operation_outputs
WHERE workflow_uuid = ? AND function_id = ?
"""
.formatted(this.schema);
Expand Down Expand Up @@ -613,7 +614,7 @@ public void deleteWorkflows(String... workflowIds) {

var sql =
"""
DELETE FROM %s.workflow_status
DELETE FROM "%s".workflow_status
WHERE workflow_uuid = ANY(?);
"""
.formatted(this.schema);
Expand Down Expand Up @@ -642,7 +643,7 @@ List<String> getWorkflowChildrenInternal(String workflowId) throws SQLException
var sql =
"""
SELECT child_workflow_id
FROM %s.operation_outputs
FROM "%s".operation_outputs
WHERE workflow_uuid = ? AND child_workflow_id IS NOT NULL
"""
.formatted(this.schema);
Expand Down Expand Up @@ -673,7 +674,7 @@ List<WorkflowEvent> listWorkflowEvents(Connection conn, String workflowId) throw
var sql =
"""
SELECT key, value
FROM %s.workflow_events
FROM "%s".workflow_events
WHERE workflow_uuid = ?
"""
.formatted(this.schema);
Expand All @@ -697,7 +698,7 @@ List<WorkflowEventHistory> listWorkflowEventHistory(Connection conn, String work
var sql =
"""
SELECT key, value, function_id
FROM %s.workflow_events_history
FROM "%s".workflow_events_history
WHERE workflow_uuid = ?
"""
.formatted(this.schema);
Expand All @@ -721,7 +722,7 @@ List<WorkflowStream> listWorkflowStreams(Connection conn, String workflowId) thr
var sql =
"""
SELECT key, value, "offset", function_id
FROM %s.streams
FROM "%s".streams
WHERE workflow_uuid = ?
"""
.formatted(this.schema);
Expand Down Expand Up @@ -771,7 +772,7 @@ public List<ExportedWorkflow> exportWorkflow(String workflowId, boolean exportCh
public void importWorkflow(List<ExportedWorkflow> workflows) {
var wfSQL =
"""
INSERT INTO %s.workflow_status (
INSERT INTO "%s".workflow_status (
workflow_uuid, status,
name, class_name, config_name,
authenticated_user, assumed_role, authenticated_roles,
Expand All @@ -789,7 +790,7 @@ public void importWorkflow(List<ExportedWorkflow> workflows) {

var stepSQL =
"""
INSERT INTO %s.operation_outputs (
INSERT INTO "%s".operation_outputs (
workflow_uuid, function_id, function_name,
output, error, child_workflow_id,
started_at_epoch_ms, completed_at_epoch_ms
Expand All @@ -801,7 +802,7 @@ public void importWorkflow(List<ExportedWorkflow> workflows) {

var eventSQL =
"""
INSERT INTO %s.workflow_events (
INSERT INTO "%s".workflow_events (
workflow_uuid, key, value
) VALUES (
?, ?, ?
Expand All @@ -811,7 +812,7 @@ public void importWorkflow(List<ExportedWorkflow> workflows) {

var eventHistorySQL =
"""
INSERT INTO %s.workflow_events_history (
INSERT INTO "%s".workflow_events_history (
workflow_uuid, key, value, function_id
) VALUES (
?, ?, ?, ?
Expand All @@ -821,7 +822,7 @@ public void importWorkflow(List<ExportedWorkflow> workflows) {

var streamsSQL =
"""
INSERT INTO %s.streams (
INSERT INTO "%s".streams (
workflow_uuid, key, value, function_id, offset
) VALUES (
?, ?, ?, ?, ?
Expand Down
Loading