diff --git a/README.md b/README.md index d5a3770..a9a0232 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,7 @@ Before running AutoCommit, it's advisable to set a few environment variables - `OPENAI_URL`: Override openai api eg: azure openai (Optional; Default: openai url) - `OPENAI_API_KEY`: The API key for the GPT-4 model (🚨 **Required**). - `OPENAI_MODEL`: Specify a different language model 🔄 (Optional; Default: `gpt-4`). -- `FINE_TUNE_PARAMS`: Additional parameters for fine-tuning the model output ⚙️ (Optional; Default: `{}`). +- `FINE_TUNE_PARAMS`: Additional parameters for fine-tuning the model output ⚙️ (Optional; Default: `{}`). Supports JSON format with parameters like `temperature`, `max_tokens`, `top_p`, `frequency_penalty`, `presence_penalty`. Add these environment variables by appending them to your `.bashrc`, `.zshrc`, or other shell configuration files 📄: @@ -67,7 +67,7 @@ export FINE_TUNE_PARAMS='{"temperature": 0.7}' Or, you can set them inline before running the AutoCommit command 🖱️: ```bash -OPENAI_URL=your-openai-api-key-here OPENAI_MODEL=gpt-4 FINE_TUNE_PARAMS='{"temperature": 0.7}' git auto-commit +OPENAI_URL=https://api.openai.com/v1 OPENAI_API_KEY=your-openai-api-key-here OPENAI_MODEL=gpt-4 FINE_TUNE_PARAMS='{"temperature": 0.7}' git auto-commit ``` ### Complete Install 📦 @@ -75,7 +75,7 @@ OPENAI_URL=your-openai-api-key-here OPENAI_MODEL=gpt-4 FINE_TUNE_PARAMS='{"tempe For an end-to-end installation experience, execute 👇: ```bash -bash <(curl -s https://raw.githubusercontent.com/ghcli/commit/main/install.sh) +bash <(curl -s https://raw.githubusercontent.com/ghcli/gh-commit/main/install.sh) ``` This comprehensive script accomplishes the following 📋: diff --git a/go.mod b/go.mod index 213bb4b..72e6123 100644 --- a/go.mod +++ b/go.mod @@ -1,32 +1,17 @@ -module github.com/megamanics/gh-commit +// Forked from github.com/megamanics/gh-commit +module github.com/ellisvalentiner/gh-commit-src go 1.21 require ( github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.2.0 github.com/Azure/azure-sdk-for-go/sdk/azcore v1.6.1 - github.com/cli/go-gh/v2 v2.4.0 github.com/joho/godotenv v1.3.0 github.com/sashabaranov/go-openai v1.15.3 ) require ( github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0 // indirect - github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect - github.com/cli/safeexec v1.0.1 // indirect - github.com/cli/shurcooL-graphql v0.0.4 // indirect - github.com/henvic/httpretty v0.1.2 // indirect - github.com/kr/pretty v0.3.1 // indirect - github.com/lucasb-eyer/go-colorful v1.2.0 // indirect - github.com/mattn/go-isatty v0.0.19 // indirect - github.com/mattn/go-runewidth v0.0.15 // indirect - github.com/muesli/termenv v0.15.2 // indirect - github.com/rivo/uniseg v0.4.4 // indirect - github.com/thlib/go-timezone-local v0.0.0-20210907160436-ef149e42d28e // indirect golang.org/x/net v0.17.0 // indirect - golang.org/x/sys v0.13.0 // indirect - golang.org/x/term v0.13.0 // indirect golang.org/x/text v0.13.0 // indirect - golang.org/x/tools v0.13.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 329e81d..c91573b 100644 --- a/go.sum +++ b/go.sum @@ -8,15 +8,6 @@ github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0 h1:sXr+ck84g/ZlZUOZiNELInm github.com/Azure/azure-sdk-for-go/sdk/internal v1.3.0/go.mod h1:okt5dMMTOFjX/aovMlrjvvXoPMBVSPzk9185BT0+eZM= github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0 h1:OBhqkivkhkMqLPymWEppkm7vgPQY2XsHoEkaMQ0AdZY= github.com/AzureAD/microsoft-authentication-library-for-go v1.0.0/go.mod h1:kgDmCTgBzIEPFElEF+FK0SdjAor06dRq2Go927dnQ6o= -github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= -github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= -github.com/cli/go-gh/v2 v2.4.0 h1:6j3YxA8uJVOL4lBWjqDmMiAQNnJ2fiZagCuEmQXl+pU= -github.com/cli/go-gh/v2 v2.4.0/go.mod h1:h3salfqqooVpzKmHp6aUdeNx62UmxQRpLbagFSHTJGQ= -github.com/cli/safeexec v1.0.1 h1:e/C79PbXF4yYTN/wauC4tviMxEV13BwljGj0N9j+N00= -github.com/cli/safeexec v1.0.1/go.mod h1:Z/D4tTN8Vs5gXYHDCbaM1S/anmEDnJb1iW0+EJ5zx3Q= -github.com/cli/shurcooL-graphql v0.0.4 h1:6MogPnQJLjKkaXPyGqPRXOI2qCsQdqNfUY1QSJu2GuY= -github.com/cli/shurcooL-graphql v0.0.4/go.mod h1:3waN4u02FiZivIV+p1y4d0Jo1jc6BViMA73C+sZo2fk= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dnaeon/go-vcr v1.2.0 h1:zHCHvJYTMh1N7xnV7zf1m1GPBF9Ad0Jk/whtQ1663qI= @@ -25,86 +16,26 @@ github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOW github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw= -github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542/go.mod h1:Ow0tF8D4Kplbc8s8sSb3V2oUCygFHVp8gC3Dn6U4MNI= -github.com/henvic/httpretty v0.1.2 h1:EQo556sO0xeXAjP10eB+BZARMuvkdGqtfeS4Ntjvkiw= -github.com/henvic/httpretty v0.1.2/go.mod h1:ViEsly7wgdugYtymX54pYp6Vv2wqZmNHayJ6q8tlKCc= github.com/joho/godotenv v1.3.0 h1:Zjp+RcGpHhGlrMbJzXTrZZPrWj+1vfm90La1wgB6Bhc= github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= -github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= -github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= -github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= -github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= -github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= -github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= -github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= -github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= -github.com/muesli/termenv v0.15.2 h1:GohcuySI0QmI3wN8Ok9PtKGkgkFIk7y6Vpb5PvrY+Wo= -github.com/muesli/termenv v0.15.2/go.mod h1:Epx+iuz8sNs7mNKhxzH4fWXGNpZwUaJKRS1noLXviQ8= github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU= github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI= -github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= -github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= -github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= -github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= -github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/sashabaranov/go-openai v1.15.3 h1:rzoNK9n+Cak+PM6OQ9puxDmFllxfnVea9StlmhglXqA= github.com/sashabaranov/go-openai v1.15.3/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/thlib/go-timezone-local v0.0.0-20210907160436-ef149e42d28e h1:BuzhfgfWQbX0dWzYzT1zsORLnHRv3bcRcsaUk0VmXA8= -github.com/thlib/go-timezone-local v0.0.0-20210907160436-ef149e42d28e/go.mod h1:/Tnicc6m/lsJE0irFMA0LfIwTBo4QP7A8IfyIv4zZKI= -github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= -golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210831042530-f4d43177bf5e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.13.0 h1:bb+I9cTfFazGW51MZqBVmZy7+JEJMouUHTUSKVQLBek= -golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= -golang.org/x/tools v0.13.0 h1:Iey4qkscZuv0VvIt8E0neZjtPVQFSc870HQ448QgEmQ= -golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/h2non/gock.v1 v1.1.2 h1:jBbHXgGBK/AoPVfJh5x4r/WxIrElvbLel8TCZkkZJoY= -gopkg.in/h2non/gock.v1 v1.1.2/go.mod h1:n7UGz/ckNChHiK05rDoiC4MYSunEC/lyaUm2WWaDva0= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/main.go b/main.go index 580221f..37dc752 100755 --- a/main.go +++ b/main.go @@ -48,12 +48,16 @@ func main() { if flag.NFlag() == 0 { diff, err := getGitDiff() + if err != nil { + fmt.Printf("Error: %v\n", err) + return + } completionResponse, err := getChatCompletionResponse(getDiffPrompt(diff)) - completionResponse = formatResponse(completionResponse) if err != nil { fmt.Printf("Error: %v\n", err) return } + completionResponse = formatResponse(completionResponse) fmt.Println(completionResponse) } } \ No newline at end of file diff --git a/main_test.go b/main_test.go index f57abc5..e2f1ff4 100644 --- a/main_test.go +++ b/main_test.go @@ -19,15 +19,20 @@ func TestStatsFlag(t *testing.T) { } } -func Test_main(t *testing.T) { - tests := []struct { - name string - }{ - // TODO: Add test cases. +func TestAskFlag(t *testing.T) { + flagSet := flag.NewFlagSet("TestAskFlag", flag.ContinueOnError) + ask := flagSet.String("ask", "", "ask a question") + + err := flagSet.Parse([]string{"-ask", "test question"}) + if err != nil { + t.Fatal("Error parsing flags:", err) } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - main() - }) + + if *ask != "test question" { + t.Errorf("Expected ask flag to be 'test question', but got '%s'", *ask) } } + +// Note: Testing main() directly is difficult as it has side effects +// In a production environment, you'd refactor main() to be more testable +// by extracting the logic into separate functions diff --git a/util.go b/util.go index 05cee9a..98fcaee 100644 --- a/util.go +++ b/util.go @@ -2,41 +2,57 @@ package main import ( "context" + "encoding/json" "fmt" - "github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" - "github.com/cli/go-gh/v2/pkg/api" - "github.com/joho/godotenv" - openai "github.com/sashabaranov/go-openai" "math" "os" "os/exec" "strconv" "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/joho/godotenv" + openai "github.com/sashabaranov/go-openai" ) const MaxDiffLength = 30000 // set to 30k since large model has maximum context length is 32768 tokens. +func isGitRepository() bool { + cmd := exec.Command("git", "rev-parse", "--git-dir") + err := cmd.Run() + return err == nil +} + func getGitDiff() (string, error) { + if !isGitRepository() { + return "", fmt.Errorf("not a git repository") + } cmd := exec.Command("git", "diff") output, err := cmd.Output() + if err != nil { + return "", fmt.Errorf("error running git diff: %v", err) + } diff := strings.TrimSpace(string(output)) if diff == "" { cmd = exec.Command("git", "diff", "--staged") output, err = cmd.Output() if err != nil { - return "", err + return "", fmt.Errorf("error running git diff --staged: %v", err) } diff = strings.TrimSpace(string(output)) } + if diff == "" { + return "", fmt.Errorf("no changes detected. Please stage or make changes before generating a commit message") + } runes := []rune(diff) size := len(runes) if size > MaxDiffLength { runes = runes[:MaxDiffLength] return string(runes), fmt.Errorf("the total length was %d and only first 30k were used", size) } - return string(output), nil + return diff, nil } func calculateTimeSaved(numCommits int, wordCount int) float64 { @@ -48,6 +64,9 @@ func calculateTimeSaved(numCommits int, wordCount int) float64 { } func getCommitStats() (int, int, error) { + if !isGitRepository() { + return 0, 0, fmt.Errorf("not a git repository") + } cmd := exec.Command("git", "log", "--oneline") stdout, err := cmd.StdoutPipe() if err != nil { @@ -107,11 +126,15 @@ func getPrompt(message string) []azopenai.ChatMessage { func getChatCompletionResponse(messages []azopenai.ChatMessage) (string, error) { err := godotenv.Load() if err != nil { - fmt.Errorf(".env file not found: %v", err) + // .env file is optional, so we ignore the error + _ = err + } + apiKey := os.Getenv("OPENAI_API_KEY") + if apiKey == "" { + return "", fmt.Errorf("OPENAI_API_KEY environment variable is not set. Please export OPENAI_API_KEY=") } - keyCredential, err := azopenai.NewKeyCredential(os.Getenv("OPENAI_API_KEY")) + keyCredential, err := azopenai.NewKeyCredential(apiKey) if err != nil { - fmt.Errorf("export OPENAI_API_KEY= #execute this in your terminal and try again") return "", fmt.Errorf("error creating Azure OpenAI client: %v", err) } url := os.Getenv("OPENAI_URL") @@ -139,12 +162,38 @@ func getChatCompletionResponse(messages []azopenai.ChatMessage) (string, error) model = openai.GPT4 } + options := azopenai.ChatCompletionsOptions{ + Messages: messages, + Deployment: model, + } + + // Parse FINE_TUNE_PARAMS if provided + fineTuneParams := os.Getenv("FINE_TUNE_PARAMS") + if fineTuneParams != "" { + var params map[string]interface{} + if err := json.Unmarshal([]byte(fineTuneParams), ¶ms); err == nil { + // Apply common parameters + if temp, ok := params["temperature"].(float64); ok { + options.Temperature = to.Ptr(float32(temp)) + } + if maxTokens, ok := params["max_tokens"].(float64); ok { + options.MaxTokens = to.Ptr(int32(maxTokens)) + } + if topP, ok := params["top_p"].(float64); ok { + options.TopP = to.Ptr(float32(topP)) + } + if frequencyPenalty, ok := params["frequency_penalty"].(float64); ok { + options.FrequencyPenalty = to.Ptr(float32(frequencyPenalty)) + } + if presencePenalty, ok := params["presence_penalty"].(float64); ok { + options.PresencePenalty = to.Ptr(float32(presencePenalty)) + } + } + } + resp, err := client.GetChatCompletions( context.Background(), - azopenai.ChatCompletionsOptions{ - Messages: messages, - Deployment: model, - }, + options, nil, ) @@ -152,28 +201,19 @@ func getChatCompletionResponse(messages []azopenai.ChatMessage) (string, error) return "", fmt.Errorf("Completion error: %v", err) } - //for _, choice := range resp.Choices { - // fmt.Fprintf(os.Stderr, "Content[%d]: %s\n", *choice.Index, *choice.Message.Content) - //} - - return *resp.Choices[0].Message.Content, nil -} - -func getUserName() { - client, err := api.DefaultRESTClient() - if err != nil { - fmt.Println(err) - return + if len(resp.Choices) == 0 { + return "", fmt.Errorf("API returned no choices in response") } - response := struct{ Login string }{} - err = client.Get("user", &response) - if err != nil { - fmt.Println(err) - return + + if resp.Choices[0].Message.Content == nil { + return "", fmt.Errorf("API returned empty content in response") } + + return *resp.Choices[0].Message.Content, nil } -var patterns = []string{"```bash", "```plaintext","```diff", "```", "```python", "```javascript", "```go", "```java", "```csharp", "```ruby", "```php", "```html", "```css", "```json", "```xml", "```yaml", "```md", "```markdown", "```sql", "```shell", "```powershell", "```dockerfile", "```makefile", "```ini", "```apacheconf", "```nginx", "```git", "```vim", "```vimscrip"} +// Patterns ordered from most specific to least specific so language-specific markers are removed before generic ``` +var patterns = []string{"```bash", "```plaintext", "```diff", "```python", "```javascript", "```go", "```java", "```csharp", "```ruby", "```php", "```html", "```css", "```json", "```xml", "```yaml", "```md", "```markdown", "```sql", "```shell", "```powershell", "```dockerfile", "```makefile", "```ini", "```apacheconf", "```nginx", "```git", "```vim", "```vimscrip", "```"} func formatResponse(response string) string { for _, pattern := range patterns { diff --git a/util_test.go b/util_test.go index df2ecc8..7f15683 100644 --- a/util_test.go +++ b/util_test.go @@ -1,31 +1,43 @@ package main import ( - "reflect" + "os" + "strings" "testing" "github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" ) +func Test_isGitRepository(t *testing.T) { + // This test will pass if run in a git repository, fail otherwise + // We can't easily mock this without setting up a test git repo + result := isGitRepository() + // Just verify the function doesn't panic + _ = result +} + func Test_getGitDiff(t *testing.T) { - tests := []struct { - name string - want string - wantErr bool - }{ - // TODO: Add test cases. + // This test requires a git repository + // We'll test that it handles errors gracefully when not in a git repo + if !isGitRepository() { + _, err := getGitDiff() + if err == nil { + t.Error("getGitDiff() should return error when not in git repository") + } + if err != nil && !strings.Contains(err.Error(), "not a git repository") { + t.Errorf("getGitDiff() error should mention 'not a git repository', got: %v", err) + } + return } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := getGitDiff() - if (err != nil) != tt.wantErr { - t.Errorf("getGitDiff() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("getGitDiff() = %v, want %v", got, tt.want) - } - }) + + // If in git repo, test that it doesn't panic + _, err := getGitDiff() + if err != nil { + // Empty diff is expected in some cases + if !strings.Contains(err.Error(), "no changes detected") { + t.Logf("getGitDiff() returned error: %v", err) + } } } @@ -39,7 +51,31 @@ func Test_calculateTimeSaved(t *testing.T) { args args want float64 }{ - // TODO: Add test cases. + { + name: "Zero commits and words", + args: args{numCommits: 0, wordCount: 0}, + want: 0.0, + }, + { + name: "100 words", + args: args{numCommits: 1, wordCount: 100}, + want: 0.0, // 100 / 40 / 60 = 0.0416... rounded to 0.0 + }, + { + name: "2400 words (1 hour)", + args: args{numCommits: 10, wordCount: 2400}, + want: 1.0, // 2400 / 40 / 60 = 1.0 + }, + { + name: "4800 words (2 hours)", + args: args{numCommits: 20, wordCount: 4800}, + want: 2.0, // 4800 / 40 / 60 = 2.0 + }, + { + name: "573 words (from README example)", + args: args{numCommits: 29, wordCount: 573}, + want: 0.2, // 573 / 40 / 60 = 0.23875 rounded to 0.2 + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -51,28 +87,20 @@ func Test_calculateTimeSaved(t *testing.T) { } func Test_getCommitStats(t *testing.T) { - tests := []struct { - name string - want int - want1 int - wantErr bool - }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, got1, err := getCommitStats() - if (err != nil) != tt.wantErr { - t.Errorf("getCommitStats() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("getCommitStats() got = %v, want %v", got, tt.want) - } - if got1 != tt.want1 { - t.Errorf("getCommitStats() got1 = %v, want %v", got1, tt.want1) - } - }) + // This test requires a git repository and may not work in all environments + // We'll test that it doesn't panic and handles errors gracefully + got, got1, err := getCommitStats() + if err != nil { + // If not in git repo, error is expected + if !isGitRepository() { + return // Expected error + } + t.Logf("getCommitStats() returned error: %v", err) + } else { + // If successful, verify we got non-negative values + if got < 0 || got1 < 0 { + t.Errorf("getCommitStats() returned negative values: got = %v, got1 = %v", got, got1) + } } } @@ -83,19 +111,58 @@ func Test_getDiffPrompt(t *testing.T) { tests := []struct { name string args args - want []azopenai.ChatMessage + want int // number of messages }{ - // TODO: Add test cases. + { + name: "Empty diff", + args: args{diff: ""}, + want: 3, // system, user, system + }, + { + name: "Simple diff", + args: args{diff: "+func test() {}"}, + want: 3, + }, + { + name: "With PROMPT_OVERRIDE", + args: args{diff: "test"}, + want: 3, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := getDiffPrompt(tt.args.diff); !reflect.DeepEqual(got, tt.want) { - t.Errorf("getDiffPrompt() = %v, want %v", got, tt.want) + got := getDiffPrompt(tt.args.diff) + if len(got) != tt.want { + t.Errorf("getDiffPrompt() message count = %v, want %v", len(got), tt.want) + } + // Verify structure + if len(got) >= 1 && got[0].Role == nil || *got[0].Role != azopenai.ChatRoleSystem { + t.Errorf("getDiffPrompt() first message should be system role") + } + if len(got) >= 2 && got[1].Role == nil || *got[1].Role != azopenai.ChatRoleUser { + t.Errorf("getDiffPrompt() second message should be user role") } }) } } +func Test_getDiffPrompt_WithPromptOverride(t *testing.T) { + // Set PROMPT_OVERRIDE + originalPrompt := os.Getenv("PROMPT_OVERRIDE") + defer os.Setenv("PROMPT_OVERRIDE", originalPrompt) + + customPrompt := "Custom prompt for testing" + os.Setenv("PROMPT_OVERRIDE", customPrompt) + + messages := getDiffPrompt("test diff") + if len(messages) < 1 { + t.Fatal("Expected at least one message") + } + if messages[0].Content == nil || *messages[0].Content != customPrompt { + t.Errorf("Expected custom prompt, got %v", messages[0].Content) + } +} + func Test_getPrompt(t *testing.T) { type args struct { message string @@ -105,58 +172,67 @@ func Test_getPrompt(t *testing.T) { args args want []azopenai.ChatMessage }{ - // TODO: Add test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := getPrompt(tt.args.message); !reflect.DeepEqual(got, tt.want) { - t.Errorf("getPrompt() = %v, want %v", got, tt.want) - } - }) - } -} - -func Test_getChatCompletionResponse(t *testing.T) { - type args struct { - messages []azopenai.ChatMessage - } - tests := []struct { - name string - args args - want string - wantErr bool - }{ - // TODO: Add test cases. + { + name: "Empty message", + args: args{message: ""}, + want: []azopenai.ChatMessage{ + {Role: to.Ptr(azopenai.ChatRoleSystem), Content: to.Ptr("")}, + }, + }, + { + name: "Simple message", + args: args{message: "Test message"}, + want: []azopenai.ChatMessage{ + {Role: to.Ptr(azopenai.ChatRoleSystem), Content: to.Ptr("Test message")}, + }, + }, + { + name: "Long message", + args: args{message: "This is a longer test message with multiple words"}, + want: []azopenai.ChatMessage{ + {Role: to.Ptr(azopenai.ChatRoleSystem), Content: to.Ptr("This is a longer test message with multiple words")}, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := getChatCompletionResponse(tt.args.messages) - if (err != nil) != tt.wantErr { - t.Errorf("getChatCompletionResponse() error = %v, wantErr %v", err, tt.wantErr) + got := getPrompt(tt.args.message) + if len(got) != len(tt.want) { + t.Errorf("getPrompt() message count = %v, want %v", len(got), len(tt.want)) return } - if got != tt.want { - t.Errorf("getChatCompletionResponse() = %v, want %v", got, tt.want) + if got[0].Role == nil || *got[0].Role != *tt.want[0].Role { + t.Errorf("getPrompt() role = %v, want %v", got[0].Role, tt.want[0].Role) + } + if got[0].Content == nil || *got[0].Content != *tt.want[0].Content { + t.Errorf("getPrompt() content = %v, want %v", got[0].Content, tt.want[0].Content) } }) } } -func Test_getUserName(t *testing.T) { - tests := []struct { - name string - }{ - { - name: "Test_getUserName", - }, +func Test_getChatCompletionResponse_MissingAPIKey(t *testing.T) { + // Save original API key + originalKey := os.Getenv("OPENAI_API_KEY") + defer os.Setenv("OPENAI_API_KEY", originalKey) + + // Unset API key + os.Unsetenv("OPENAI_API_KEY") + + messages := []azopenai.ChatMessage{ + {Role: to.Ptr(azopenai.ChatRoleSystem), Content: to.Ptr("test")}, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - getUserName() - }) + + _, err := getChatCompletionResponse(messages) + if err == nil { + t.Error("getChatCompletionResponse() should return error when OPENAI_API_KEY is not set") + } + if err != nil && !strings.Contains(err.Error(), "OPENAI_API_KEY") { + t.Errorf("getChatCompletionResponse() error should mention OPENAI_API_KEY, got: %v", err) } } + func Test_formatResponse(t *testing.T) { type args struct { response string @@ -180,6 +256,41 @@ func Test_formatResponse(t *testing.T) { }, want: "Hello World", }, + { + name: "Test with go code block", + args: args{ + response: "```go\npackage main\n```", + }, + want: "\npackage main\n", + }, + { + name: "Test with python code block", + args: args{ + response: "```python\ndef hello():\n pass\n```", + }, + want: "\ndef hello():\n pass\n", + }, + { + name: "Test with no code block", + args: args{ + response: "Hello World", + }, + want: "Hello World", + }, + { + name: "Test with markdown code block", + args: args{ + response: "```markdown\n# Title\n```", + }, + want: "\n# Title\n", + }, + { + name: "Test with plaintext", + args: args{ + response: "```plaintext\nText here\n```", + }, + want: "\nText here\n", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {