From 6edd3b8035503bf8e9dadb6df057b3bf03eaee82 Mon Sep 17 00:00:00 2001 From: Axel Dahlberg Date: Mon, 23 Dec 2024 21:45:22 -0800 Subject: [PATCH 1/2] feat: support multiple targets --- README.md | 3 ++- lua/ts-node-action/init.lua | 35 +++++++++++++++++++++++++++++++++-- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 042a578..1bf3dbb 100644 --- a/README.md +++ b/README.md @@ -182,8 +182,9 @@ Boolean value. If `true`, will run `=` operator on new buffer text. Requires #### `target` -TSNode. If present, this node will be used as the target for replacement instead +TSNode or list of TSNodes. If present, this node will be used as the target for replacement instead of the node under your cursor. +If list of nodes their combined range will be used for replacement. Note that in this case if the target nodes specified are not next to each other, any thing in between will also be replaced. Here's a simplified example of how a node-action function gets called: diff --git a/lua/ts-node-action/init.lua b/lua/ts-node-action/init.lua index 1dc7f5b..f4b8d5c 100644 --- a/lua/ts-node-action/init.lua +++ b/lua/ts-node-action/init.lua @@ -1,15 +1,46 @@ local M = {} +--- @private +--- @param targets TSNode[] +--- @return integer start_row +--- @return integer start_col +--- @return integer end_row +--- @return integer end_col +local function combined_range(targets) + local start_row, start_col, end_row, end_col + for _, target in ipairs(targets) do + local sr, sc, er, ec = target:range() + if start_row == nil or sr < start_row then + start_row = sr + end + if start_col == nil or sc < start_col then + start_col = sc + end + if end_row == nil or er > end_row then + end_row = er + end + if end_col == nil or ec > end_col then + end_col = ec + end + end + return start_row, start_col, end_row, end_col +end + --- @private --- @param replacement string|table ---- @param opts { cursor: { col: number, row: number }, callback: function, format: boolean, target: TSNode } +--- @param opts { cursor: { col: number, row: number }, callback: function, format: boolean, target: TSNode | TSNode[] } --- All opts fields are optional local function replace_node(node, replacement, opts) if type(replacement) ~= "table" then replacement = { replacement } end - local start_row, start_col, end_row, end_col = (opts.target or node):range() + local start_row, start_col, end_row, end_col + if vim.islist(opts.target) then + start_row, start_col, end_row, end_col = combined_range(opts.target) + else + start_row, start_col, end_row, end_col = (opts.target or node):range() + end vim.api.nvim_buf_set_text( vim.api.nvim_get_current_buf(), start_row, From 82a267c282f432f17f058896a1cf11c2768dbbf8 Mon Sep 17 00:00:00 2001 From: Axel Dahlberg Date: Wed, 25 Dec 2024 01:22:23 -0800 Subject: [PATCH 2/2] feat: add actions to expand comprehensions/inline for loops --- lua/ts-node-action/actions/ft/init.lua | 0 lua/ts-node-action/actions/ft/python.lua | 181 +++++++++++++++++++++++ 2 files changed, 181 insertions(+) create mode 100644 lua/ts-node-action/actions/ft/init.lua create mode 100644 lua/ts-node-action/actions/ft/python.lua diff --git a/lua/ts-node-action/actions/ft/init.lua b/lua/ts-node-action/actions/ft/init.lua new file mode 100644 index 0000000..e69de29 diff --git a/lua/ts-node-action/actions/ft/python.lua b/lua/ts-node-action/actions/ft/python.lua new file mode 100644 index 0000000..5ef365b --- /dev/null +++ b/lua/ts-node-action/actions/ft/python.lua @@ -0,0 +1,181 @@ +--- @param node TSNode +--- @return string +local get_node_text = function(node) + return vim.treesitter.get_node_text(node, 0) +end + +--- @param node TSNode +--- @param name string +--- @return TSNode +local get_field = function(node, name) + local fields = node:field(name) + if #fields ~= 1 then + error(string.format("not exactly one field with name='%s'", name)) + end + return fields[1] +end + +--- @param node TSNode +--- @return string? +--- @return TSNode? +local get_assignment = function(node) + if node:type() ~= "assignment" then + return + end + return get_node_text(get_field(node, "left")), get_field(node, "right") +end + +--- @param node TSNode +--- @return string? +local get_new_type = function(node) + if node:type() == "list" and node:named_child_count() == 0 then + return "list" + elseif ( + node:type() == "call" + and get_node_text(get_field(node, "function")) == "set" + and get_field(node, "arguments"):child_count() == 2) + then + return "set" + elseif node:type() == "dictionary" and node:named_child_count() == 0 then + return "dictionary" + end +end + +--- @param node TSNode +--- @return TSNode? +local get_single_node_body = function(node) + if node:child_count() ~= 1 then + return + end + return node:child(0):child(0) +end + +--- @param name string +--- @param node TSNode +--- @return string? +local get_dict_key_pair = function(name, node) + if node:type() ~= "assignment" then + return + end + local left = get_field(node, "left") + if left:type() ~= "subscript" then + return + end + if name ~= get_node_text(get_field(left, "value")) then + return + end + return string.format("%s: %s", get_node_text(get_field(left, "subscript")), get_node_text(get_field(node, "right"))) +end + +--- @param append string +--- @param name string +--- @param node TSNode +--- @return string? +local get_append_to_value = function(append, name, node) + if node:type() ~= "call" then + return + end + local func = get_field(node, "function") + if func:named_child_count() == 0 then + return + end + if name ~= get_node_text(get_field(func, "object")) then + return + end + if get_node_text(get_field(func, "attribute")) ~= append then + return + end + return get_node_text(get_field(node, "arguments"):named_child(0)) +end + +--- @param typ string +--- @param name string +--- @param node TSNode +--- @return string? +local get_body = function(typ, name, node) + if typ == "list" then + return get_append_to_value("append", name, node) + elseif typ == "set" then + return get_append_to_value("add", name, node) + elseif typ == "dictionary" then + return get_dict_key_pair(name, node) + end +end + +--- @param opts {new: string, make_for_body: fun(name: string, body: TSNode): string} +--- @return fun(node: TSNode): string[]?, table? +local comprehension = function(opts) + return function(node) + local parent = node:parent() + local name = get_assignment(parent) + if not name then + return + end + -- TODO support if there are more or if clauses + if node:named_child_count() > 2 then + return + end + local for_clause = get_node_text(node:named_child(1)) + local for_body = opts.make_for_body(name, get_field(node, "body")) + return vim.split(string.format("%s = %s\n%s:\n%s", name, opts.new, for_clause, for_body), "\n"), { + format = true, + target = parent, + cursor = {row = 1, col = 0}, + } + end +end + +return { + expand_list_comprehension = comprehension({ + new = "[]", + make_for_body = function(name, body) return string.format("%s.append(%s)", name, get_node_text(body)) end, + }), + expand_set_comprehension = comprehension({ + new = "set()", + make_for_body = function(name, body) return string.format("%s.add(%s)", name, get_node_text(body)) end, + }), + expand_dictionary_comprehension = comprehension({ + new = "{}", + make_for_body = function(name, body) return string.format( + "%s[%s] = %s", + name, + get_node_text(get_field(body, "key")), + get_node_text(get_field(body, "value")) + ) end, + }), + inline_for_statement = function(node) + local previous = node:prev_sibling():child(0) + -- TODO support nested loops, look up until assignment + if previous:type() ~= "assignment" then + return + end + local name, value = get_assignment(previous) + if not name then + return + end + local typ = get_new_type(value) + if not typ then + return + end + local for_variable = get_node_text(get_field(node, "left")) + local for_range = get_node_text(get_field(node, "right")) + local statement = get_single_node_body(get_field(node, "body")) + if not statement then + return + end + local body = get_body(typ, name, statement) + if not body then + return + end + local templates = { + list = "%s = [%s for %s in %s]", + set = "%s = {%s for %s in %s}", + dictionary = "%s = {%s for %s in %s}" + } + return vim.split(string.format(templates[typ], name, body, for_variable, for_range), "\n"), { + format = true, + cursor = {row = 0, col = #name + 3}, + target = {previous, node}, + } + end, +}