diff --git a/README.md b/README.md index 8a5a8630..d6274702 100644 --- a/README.md +++ b/README.md @@ -140,6 +140,8 @@ RosettaDB provides a comprehensive set of commands to cover various aspects of d - **[dbt](docs/markdowns/dbt.md)**: Generate dbt models for analytics workflows. - **[generate](docs/markdowns/generate.md)**: Generate Spark code for data transfers (Python or Scala). - **[query](docs/markdowns/query.md)**: Explore and query your data using AI-driven capabilities. +- **[sql](docs/markdowns/sql.md)**: Explore your data using pure SQL syntax. + ## Copyright and License Information Unless otherwise specified, all content, including all source code files and documentation files in this repository are: diff --git a/cli/src/main/java/com/adaptivescale/rosetta/cli/Cli.java b/cli/src/main/java/com/adaptivescale/rosetta/cli/Cli.java index 9d10e714..d2a3f408 100644 --- a/cli/src/main/java/com/adaptivescale/rosetta/cli/Cli.java +++ b/cli/src/main/java/com/adaptivescale/rosetta/cli/Cli.java @@ -42,6 +42,7 @@ import picocli.CommandLine; import queryhelper.pojo.GenericResponse; import queryhelper.service.AIService; +import queryhelper.service.QueryService; import java.io.BufferedReader; import java.io.File; @@ -749,6 +750,33 @@ private void query(@CommandLine.Option(names = {"-s", "--source"}, required = tr log.info(response.getMessage()); } + @CommandLine.Command(name = "sql", description = "Write SQL for you Schema", mixinStandardHelpOptions = true) + private void sql(@CommandLine.Option(names = {"-s", "--source"}, required = true) String sourceName, + @CommandLine.Option(names = {"-q", "--query"}, required = true) String query, + @CommandLine.Option(names = {"-l", "--limit"}, required = false, defaultValue = "200") Integer showRowLimit, + @CommandLine.Option(names = {"--no-limit"}, required = false, defaultValue = "false") Boolean noRowLimit, + @CommandLine.Option(names = {"--output"}, required = false) Path output + ) + throws Exception { + requireConfig(config); + + 
Connection source = getSourceConnection(sourceName); + + Path sourceWorkspace = Paths.get("./", sourceName); + + Path dataDirectory = output != null ? output : sourceWorkspace.resolve(DEFAULT_OUTPUT_DIRECTORY); + Path outputFile = dataDirectory.getFileName().toString().contains(".") ? dataDirectory.getFileName() : null; + + dataDirectory = dataDirectory.getFileName().toString().contains(".") ? dataDirectory.getParent() != null ? dataDirectory.getParent() : Paths.get(".") : dataDirectory; + if (!dataDirectory.toFile().exists()) { + Files.createDirectories(dataDirectory); + } + + // If `noRowLimit` is true, set the row limit to 0 (no limit), otherwise use the value of `showRowLimit` + GenericResponse response = QueryService.executeQuery(query, source, noRowLimit ? 0 : showRowLimit, dataDirectory, outputFile); + log.info(response.getMessage()); + } + @CommandLine.Command(name = "validate", description = "Validate Connection", mixinStandardHelpOptions = true) private void validate(@CommandLine.Option(names = {"-s", "--source"}, required = true) String sourceName) throws Exception { diff --git a/cli/src/test/java/integration/SqlserverIntegrationTest.java b/cli/src/test/java/integration/SqlserverIntegrationTest.java index 06d21494..41bb086f 100644 --- a/cli/src/test/java/integration/SqlserverIntegrationTest.java +++ b/cli/src/test/java/integration/SqlserverIntegrationTest.java @@ -35,7 +35,7 @@ @TestMethodOrder(MethodOrderer.OrderAnnotation.class) public class SqlserverIntegrationTest { - private static String IMAGE = "fabricioveronez/northwind-database"; + private static String IMAGE = "mcr.microsoft.com/mssql/server:2022-latest"; private static String USERNAME = "SA"; private static String PASSWORD = "123abcD!"; private static String DATABASE = "Northwind"; @@ -169,10 +169,11 @@ public class SqlserverIntegrationTest { @Rule - public static MSSQLServerContainer mssqlserver = new MSSQLServerContainer() + public static MSSQLServerContainer mssqlserver = new 
MSSQLServerContainer(IMAGE) .acceptLicense() .withPassword(PASSWORD); + @BeforeAll public static void beforeAll() { mssqlserver.start(); diff --git a/docs/markdowns/sql.md b/docs/markdowns/sql.md new file mode 100644 index 00000000..40cfa467 --- /dev/null +++ b/docs/markdowns/sql.md @@ -0,0 +1,47 @@ +### Command: sql +The sql command allows the user to run SQL queries directly against the connected database of their choice. + rosetta [-c, --config CONFIG_FILE] sql [-h, --help] [-s, --source CONNECTION_NAME] [-q, --query "SQL Query Code"] [-l, --limit ROW_LIMIT] [--no-limit] [--output "Output DIRECTORY or FILE"] + +Parameter | Description +--- | --- +-h, --help | Show the help message and exit. +-c, --config CONFIG_FILE | YAML config file. If none is supplied it will use main.conf in the current directory if it exists. +-s, --source CONNECTION_NAME | The source connection is used to specify which models and connection to use. +-q, --query "SQL Query Code" | Specify the query you want to run in your connected DB. +-l, --limit Response Row limit (Optional) | Limits the number of rows in the generated CSV file. If not specified, the default limit is set to 200 rows. +--no-limit (Optional) | Specifies that there should be no limit on the number of rows in the generated CSV file. 
+ + + +***Example*** (Query) +``` + rosetta sql -s mysql -q "select * from basic_library.authors;" +``` +***CSV Output Example*** +```CSV +surname,name,authorid +Howells,William Dean,1 +Brown,Frederic,2 +London,Jack,3 +Blaisdell,Albert,4 +Butler,Ellis,5 +Machen,Arthur,6 +Lucretius,Titus,7 +Tagore,Rabindranath,8 +Asimov,Isaac,9 +Dickens,Charles,10 +Emerson,Ralph Waldo,11 +Canfield,Dorothy,12 +Boccaccio,Givoanni,13 +Orwell,George,14 +Ovid,Publius,15 +Stevenson,Robert Louis,16 +Woolf,Virginia,17 +Eliot,George,18 +Edwards,Amelia B.,19 +Dostoevsky,Fyodor,20 +Dickinson,Emily,21 +Ferber,Edna,22 + +``` + diff --git a/queryhelper/src/main/java/queryhelper/service/AIService.java b/queryhelper/src/main/java/queryhelper/service/AIService.java index f0e1f412..87c9b678 100644 --- a/queryhelper/src/main/java/queryhelper/service/AIService.java +++ b/queryhelper/src/main/java/queryhelper/service/AIService.java @@ -1,9 +1,6 @@ package queryhelper.service; -import com.adaptivescale.rosetta.common.DriverManagerDriverProvider; -import com.adaptivescale.rosetta.common.JDBCUtils; import com.adaptivescale.rosetta.common.models.input.Connection; -import com.adataptivescale.rosetta.source.common.QueryHelper; import com.google.gson.Gson; import com.google.gson.JsonSyntaxException; import dev.langchain4j.model.openai.OpenAiChatModel; @@ -13,19 +10,15 @@ import queryhelper.pojo.QueryDataResponse; import queryhelper.pojo.QueryRequest; import queryhelper.utils.ErrorUtils; -import queryhelper.utils.FileUtils; import queryhelper.utils.PromptUtils; -import java.io.BufferedReader; -import java.io.FileReader; -import java.io.IOException; import java.nio.file.Path; -import java.sql.Driver; -import java.sql.SQLException; -import java.sql.Statement; -import java.text.SimpleDateFormat; import java.util.*; +import static queryhelper.utils.FileUtils.createCSVFile; +import static queryhelper.utils.FileUtils.generateTablePreview; +import static queryhelper.utils.QueryUtils.executeQueryAndGetRecords; + public 
class AIService { private final static String AI_MODEL = "gpt-3.5-turbo"; @@ -71,20 +64,6 @@ public static GenericResponse generateQuery(String userQueryRequest, String apiK return response; } - private static List> executeQueryAndGetRecords(String query, Connection source, Integer showRowLimit) { - try { - DriverManagerDriverProvider driverManagerDriverProvider = new DriverManagerDriverProvider(); - Driver driver = driverManagerDriverProvider.getDriver(source); - Properties properties = JDBCUtils.setJDBCAuth(source); - java.sql.Connection jdbcConnection = driver.connect(source.getUrl(), properties); - Statement statement = jdbcConnection.createStatement(); - statement.setMaxRows(showRowLimit); - List> select = QueryHelper.select(statement, query); - return select; - } catch (SQLException e) { - throw new RuntimeException(e); - } - } public static boolean isSelectStatement(String query) { boolean isSelectStatement = true; @@ -99,27 +78,6 @@ public static boolean isSelectStatement(String query) { return isSelectStatement; } - private static String createCSVFile(QueryDataResponse queryDataResponse, String csvFileName, Path dataDirectory, Path outputFileName) { - try { - if (outputFileName == null) { - String timestamp = new SimpleDateFormat("yyyyMMdd_HHmmss").format(new Date()); - String fileName = csvFileName.replaceAll("\\s+", "_") + "_" + timestamp + ".csv"; - Path csvFilePath = dataDirectory.resolve(fileName); - FileUtils.convertToCSV(csvFilePath.toString(), queryDataResponse.getRecords()); - - return csvFilePath.toString(); - } - - Path csvFilePath = dataDirectory.resolve(outputFileName.toString()); - FileUtils.convertToCSV(csvFilePath.toString(), queryDataResponse.getRecords()); - return csvFilePath.toString(); - - } catch (Exception e) { - GenericResponse genericResponse = ErrorUtils.csvFileError(e); - throw new RuntimeException(genericResponse.getMessage()); - } - } - public static String generateAIOutput(String apiKey, String aiModel, QueryRequest 
queryRequest, Connection source, String databaseDDL) { Gson gson = new Gson(); String aiOutputStr; @@ -151,56 +109,4 @@ public static String generateAIOutput(String apiKey, String aiModel, QueryReques return query; } - private static String generateTablePreview(String csvFile, int rowLimit) { - List rows = new ArrayList<>(); - try (BufferedReader reader = new BufferedReader(new FileReader(csvFile))) { - String line; - int rowCount = 0; - while ((line = reader.readLine()) != null && rowCount < rowLimit) { - String[] columns = line.split(","); - rows.add(columns); - rowCount++; - } - } catch (IOException e) { - throw new RuntimeException("Error reading CSV file", e); - } - - if (rows.isEmpty()) { - return "No data available to display."; - } - int maxColumns = rows.stream().mapToInt(row -> row.length).max().orElse(0); - int[] columnWidths = new int[maxColumns]; - for (String[] row : rows) { - for (int i = 0; i < row.length; i++) { - columnWidths[i] = Math.max(columnWidths[i], row[i].length()); - } - } - StringBuilder table = new StringBuilder(); - String rowSeparator = buildRowSeparator(columnWidths); - - table.append(rowSeparator); - for (String[] row : rows) { - table.append("|"); - for (int i = 0; i < maxColumns; i++) { - String cell = (i < row.length) ? 
row[i] : ""; - table.append(" ").append(String.format("%-" + columnWidths[i] + "s", cell)).append(" |"); - } - table.append("\n").append(rowSeparator); - } - - return table.toString(); - } - - private static String buildRowSeparator(int[] columnWidths) { - StringBuilder separator = new StringBuilder("+"); - for (int width : columnWidths) { - for (int i = 0; i < width + 2; i++) { - separator.append("-"); - } - separator.append("+"); - } - separator.append("\n"); - return separator.toString(); - } - } \ No newline at end of file diff --git a/queryhelper/src/main/java/queryhelper/service/QueryService.java b/queryhelper/src/main/java/queryhelper/service/QueryService.java new file mode 100644 index 00000000..40df480c --- /dev/null +++ b/queryhelper/src/main/java/queryhelper/service/QueryService.java @@ -0,0 +1,46 @@ +package queryhelper.service; + +import com.adaptivescale.rosetta.common.models.input.Connection; +import queryhelper.pojo.GenericResponse; +import queryhelper.pojo.QueryDataResponse; + +import java.nio.file.Path; +import java.util.List; +import java.util.Map; + + +import static queryhelper.utils.FileUtils.createCSVFile; +import static queryhelper.utils.FileUtils.generateTablePreview; +import static queryhelper.utils.QueryUtils.executeQueryAndGetRecords; + +public class QueryService { + + public static GenericResponse executeQuery(String query, Connection source, Integer showRowLimit, Path dataDirectory, Path outputFileName) { + + GenericResponse response = new GenericResponse(); + QueryDataResponse data = new QueryDataResponse(); + + + List> records = executeQueryAndGetRecords(query, source, showRowLimit); + data.setRecords(records); + + response.setData(data); + response.setStatusCode(200); + + String csvFile = createCSVFile(data, query, dataDirectory, outputFileName); + + String table = generateTablePreview(csvFile, 15); + + response.setMessage( + query + "\n" + + "Your response is saved to a CSV file named '" + csvFile + "'!" 
+ "\n" + + "Table Output:" + "\n" + + table + + "..." + "\n" + + "Total rows: " + data.getRecords().size() + ); + + return response; + } +} + diff --git a/queryhelper/src/main/java/queryhelper/utils/FileUtils.java b/queryhelper/src/main/java/queryhelper/utils/FileUtils.java index 565efd45..1abdeebe 100644 --- a/queryhelper/src/main/java/queryhelper/utils/FileUtils.java +++ b/queryhelper/src/main/java/queryhelper/utils/FileUtils.java @@ -1,7 +1,12 @@ package queryhelper.utils; +import queryhelper.pojo.GenericResponse; +import queryhelper.pojo.QueryDataResponse; + import java.io.*; import java.nio.charset.StandardCharsets; +import java.nio.file.Path; +import java.text.SimpleDateFormat; import java.util.*; import java.util.stream.Collectors; import java.io.FileWriter; @@ -20,6 +25,28 @@ public static String readJsonFile() { } }
+    /**
+     * Writes the records of a query response to a CSV file and returns that file's path.
+     *
+     * <p>When {@code outputFileName} is {@code null}, the file name is generated from
+     * {@code csvFileName} followed by a {@code yyyyMMdd_HHmmss} timestamp and the
+     * {@code .csv} extension. Otherwise {@code outputFileName} is used unchanged. In both
+     * cases the name is resolved against {@code dataDirectory}.
+     *
+     * @param queryDataResponse response whose records are written to the file
+     * @param csvFileName       free text used as the base of a generated file name; callers
+     *                          such as {@code QueryService} pass the raw SQL statement here
+     * @param dataDirectory     directory the file name is resolved against
+     * @param outputFileName    explicit file name, or {@code null} to generate one
+     * @return the path of the written CSV file, as a string
+     * @throws RuntimeException with the message produced by {@code ErrorUtils.csvFileError}
+     *                          and the original exception as its cause, when anything fails
+     */
+    public static String createCSVFile(QueryDataResponse queryDataResponse, String csvFileName, Path dataDirectory, Path outputFileName) {
+        try {
+            String fileName;
+            if (outputFileName != null) {
+                fileName = outputFileName.toString();
+            } else {
+                // The base text can be a whole SQL statement, so replacing whitespace alone is
+                // not enough: a '/' would be treated as a directory separator by resolve(),
+                // and characters such as '*', '"' or ':' are not portable in file names.
+                String baseName = csvFileName.replaceAll("[^A-Za-z0-9._-]+", "_");
+                // Long statements would otherwise exceed common file-name length limits.
+                final int maxBaseNameLength = 100;
+                if (baseName.length() > maxBaseNameLength) {
+                    baseName = baseName.substring(0, maxBaseNameLength);
+                }
+                String timestamp = new SimpleDateFormat("yyyyMMdd_HHmmss").format(new Date());
+                fileName = baseName + "_" + timestamp + ".csv";
+            }
+
+            Path csvFilePath = dataDirectory.resolve(fileName);
+            FileUtils.convertToCSV(csvFilePath.toString(), queryDataResponse.getRecords());
+            return csvFilePath.toString();
+        } catch (Exception e) {
+            GenericResponse genericResponse = ErrorUtils.csvFileError(e);
+            // Keep the original exception as the cause so the stack trace is not lost.
+            throw new RuntimeException(genericResponse.getMessage(), e);
+        }
+    }
+
+
 public static void convertToCSV(String fileName, List> list) { try (FileWriter csvWriter = new FileWriter(fileName)) { if (!list.isEmpty()) { @@ -41,4 +68,56 @@ public static void convertToCSV(String fileName, List> list) } } + public static String generateTablePreview(String csvFile, int 
rowLimit) {
+        // Renders at most rowLimit lines of csvFile as an ASCII table. Cells come from a
+        // plain split on ',', so a quoted field that itself contains a comma is shown as
+        // several columns.
+        List<String[]> rows = new ArrayList<>();
+        try (BufferedReader reader = new BufferedReader(new FileReader(csvFile))) {
+            String line;
+            while (rows.size() < rowLimit && (line = reader.readLine()) != null) {
+                rows.add(line.split(","));
+            }
+        } catch (IOException e) {
+            throw new RuntimeException("Error reading CSV file", e);
+        }
+
+        if (rows.isEmpty()) {
+            return "No data available to display.";
+        }
+
+        // Each column is as wide as its longest cell among the previewed rows.
+        int maxColumns = rows.stream().mapToInt(row -> row.length).max().orElse(0);
+        int[] columnWidths = new int[maxColumns];
+        for (String[] row : rows) {
+            for (int i = 0; i < row.length; i++) {
+                columnWidths[i] = Math.max(columnWidths[i], row[i].length());
+            }
+        }
+
+        String rowSeparator = buildRowSeparator(columnWidths);
+        StringBuilder table = new StringBuilder(rowSeparator);
+        for (String[] row : rows) {
+            table.append("|");
+            for (int i = 0; i < maxColumns; i++) {
+                // Rows shorter than the widest row are completed with empty cells.
+                String cell = (i < row.length) ? row[i] : "";
+                // Pad by hand rather than with String.format("%-" + width + "s", ...): a
+                // column that is empty in every previewed row has width 0, and the resulting
+                // "%-0s" pattern is rejected by java.util.Formatter with an exception.
+                table.append(" ").append(cell);
+                for (int pad = cell.length(); pad < columnWidths[i]; pad++) {
+                    table.append(' ');
+                }
+                table.append(" |");
+            }
+            table.append("\n").append(rowSeparator);
+        }
+
+        return table.toString();
+    }
+
+    /**
+     * Builds the horizontal rule placed between table rows, e.g. {@code +-----+----+}.
+     *
+     * @param columnWidths width of each column, excluding the one-space padding on each side
+     * @return the rule, terminated by a newline
+     */
+    public static String buildRowSeparator(int[] columnWidths) {
+        StringBuilder separator = new StringBuilder("+");
+        for (int width : columnWidths) {
+            // width + 2 accounts for the space on each side of the cell text.
+            for (int i = 0; i < width + 2; i++) {
+                separator.append("-");
+            }
+            separator.append("+");
+        }
+        separator.append("\n");
+        return separator.toString();
+    }
+
 } diff --git a/queryhelper/src/main/java/queryhelper/utils/QueryUtils.java b/queryhelper/src/main/java/queryhelper/utils/QueryUtils.java new file mode 100644 index 00000000..8e7c5276 --- /dev/null +++ b/queryhelper/src/main/java/queryhelper/utils/QueryUtils.java @@ -0,0 +1,30 @@ +package queryhelper.utils; + +import com.adaptivescale.rosetta.common.DriverManagerDriverProvider; +import com.adaptivescale.rosetta.common.JDBCUtils; +import com.adaptivescale.rosetta.common.models.input.Connection; 
+import com.adataptivescale.rosetta.source.common.QueryHelper;
+
+import java.sql.Driver;
+import java.sql.SQLException;
+import java.sql.Statement;
+import java.util.List;
+import java.util.Map;
+import java.util.Properties;
+
+/**
+ * JDBC helper used by the query services to run a statement against a configured source.
+ */
+public class QueryUtils {
+    /**
+     * Runs {@code query} against {@code source} and returns whatever
+     * {@code QueryHelper.select} produces for it.
+     *
+     * <p>The JDBC connection and statement are opened for this call only and are closed
+     * before the method returns.
+     * NOTE(review): this assumes {@code QueryHelper.select} fully reads the result set
+     * before returning its list — confirm against that helper.
+     *
+     * @param query        SQL text to execute
+     * @param source       connection settings (driver, URL, credentials) of the target database
+     * @param showRowLimit maximum number of rows to fetch; {@code 0} or {@code null} means no limit
+     * @return the records returned by {@code QueryHelper.select}
+     * @throws RuntimeException wrapping the {@link SQLException} raised while connecting or querying
+     */
+    public static List> executeQueryAndGetRecords(String query, Connection source, Integer showRowLimit) {
+        try {
+            DriverManagerDriverProvider driverProvider = new DriverManagerDriverProvider();
+            Driver driver = driverProvider.getDriver(source);
+            Properties properties = JDBCUtils.setJDBCAuth(source);
+            try (java.sql.Connection jdbcConnection = driver.connect(source.getUrl(), properties)) {
+                // Driver.connect returns null (rather than throwing) when the driver does not
+                // handle the URL. The URL is left out of the message because it may embed
+                // credentials.
+                if (jdbcConnection == null) {
+                    throw new SQLException("The JDBC driver did not accept the connection URL");
+                }
+                try (Statement statement = jdbcConnection.createStatement()) {
+                    // setMaxRows(0) means "no limit"; a missing limit is treated the same way.
+                    statement.setMaxRows(showRowLimit == null ? 0 : showRowLimit);
+                    return QueryHelper.select(statement, query);
+                }
+            }
+        } catch (SQLException e) {
+            throw new RuntimeException(e);
+        }
+    }
+}