diff --git a/.golangci.yml b/.golangci.yml index e253d15..55d9658 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -115,6 +115,7 @@ formatters: - gofmt - gofumpt - goimports + - golines exclusions: generated: lax paths: diff --git a/go.mod b/go.mod index 4a3c25e..34269f8 100644 --- a/go.mod +++ b/go.mod @@ -6,17 +6,50 @@ require ( github.com/alicebob/miniredis/v2 v2.35.0 github.com/google/go-cmp v0.7.0 github.com/google/uuid v1.6.0 + github.com/gotd/contrib v0.21.0 + github.com/gotd/td v0.128.0 github.com/redis/go-redis/v9 v9.11.0 github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.10.0 + golang.org/x/net v0.42.0 ) require ( + github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/coder/websocket v1.8.13 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/dlclark/regexp2 v1.11.5 // indirect + github.com/fatih/color v1.18.0 // indirect + github.com/ghodss/yaml v1.0.0 // indirect + github.com/go-faster/errors v0.7.1 // indirect + github.com/go-faster/jx v1.1.0 // indirect + github.com/go-faster/xor v1.0.0 // indirect + github.com/go-faster/yaml v0.4.6 // indirect + github.com/gotd/ige v0.2.2 // indirect + github.com/gotd/neo v0.1.5 // indirect + github.com/klauspost/compress v1.18.0 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/ogen-go/ogen v1.12.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/segmentio/asm v1.2.0 // indirect github.com/yuin/gopher-lua v1.1.1 // indirect + go.opentelemetry.io/otel v1.37.0 // indirect + go.opentelemetry.io/otel/metric v1.37.0 // indirect + go.opentelemetry.io/otel/trace v1.37.0 // indirect + go.uber.org/atomic v1.11.0 // indirect + go.uber.org/multierr v1.11.0 // indirect + go.uber.org/zap v1.27.0 // indirect + golang.org/x/crypto v0.40.0 // indirect + golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 // indirect + golang.org/x/mod v0.26.0 // indirect + golang.org/x/sync v0.16.0 // indirect golang.org/x/sys v0.34.0 // indirect + golang.org/x/text v0.27.0 // indirect + golang.org/x/tools v0.35.0 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + rsc.io/qr v0.2.0 // indirect ) diff --git a/go.sum b/go.sum index 28b6e77..ecb752d 100644 --- a/go.sum +++ b/go.sum @@ -4,39 +4,120 @@ github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= +github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= -github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= -github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= +github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE= +github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ= +github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= +github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= +github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/go-faster/errors v0.7.1 h1:MkJTnDoEdi9pDabt1dpWf7AA8/BaSYZqibYyhZ20AYg= +github.com/go-faster/errors v0.7.1/go.mod h1:5ySTjWFiphBs07IKuiL69nxdfd5+fzh1u7FPGZP2quo= +github.com/go-faster/jx v1.1.0 h1:ZsW3wD+snOdmTDy9eIVgQdjUpXRRV4rqW8NS3t+20bg= +github.com/go-faster/jx v1.1.0/go.mod h1:vKDNikrKoyUmpzaJ0OkIkRQClNHFX/nF3dnTJZb3skg= +github.com/go-faster/xor v0.3.0/go.mod h1:x5CaDY9UKErKzqfRfFZdfu+OSTfoZny3w5Ak7UxcipQ= +github.com/go-faster/xor v1.0.0 h1:2o8vTOgErSGHP3/7XwA5ib1FTtUsNtwCoLLBjl31X38= +github.com/go-faster/xor v1.0.0/go.mod h1:x5CaDY9UKErKzqfRfFZdfu+OSTfoZny3w5Ak7UxcipQ= +github.com/go-faster/yaml v0.4.6 h1:lOK/EhI04gCpPgPhgt0bChS6bvw7G3WwI8xxVe0sw9I= +github.com/go-faster/yaml v0.4.6/go.mod h1:390dRIvV4zbnO7qC9FGo6YYutc+wyyUSHBgbXL52eXk= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gotd/contrib v0.21.0 h1:4Fj05jnyBE84toXZl7mVTvt7f732n5uglvztyG6nTr4= +github.com/gotd/contrib v0.21.0/go.mod h1:ENoUh75IhHGxfz/puVJg8BU4ZF89yrL6Q47TyoNqFYo= +github.com/gotd/ige v0.2.2 h1:XQ9dJZwBfDnOGSTxKXBGP4gMud3Qku2ekScRjDWWfEk= +github.com/gotd/ige v0.2.2/go.mod h1:tuCRb+Y5Y3eNTo3ypIfNpQ4MFjrnONiL2jN2AKZXmb0= +github.com/gotd/neo v0.1.5 h1:oj0iQfMbGClP8xI59x7fE/uHoTJD7NZH9oV1WNuPukQ= +github.com/gotd/neo v0.1.5/go.mod h1:9A2a4bn9zL6FADufBdt7tZt+WMhvZoc5gWXihOPoiBQ= +github.com/gotd/td v0.128.0 h1:OI0KyKwARNO4X+czb26+FLKXASFTWuHpgPs7Yaqm04o= +github.com/gotd/td v0.128.0/go.mod h1:rSekFfPYj5UEFky5EYnadT0WRU3DGoR4PFEMugk77uI= +github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= +github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/ogen-go/ogen v1.12.0 h1:JMkn957i9/IPaSehqpblviy6Uao3eqQ+eVKUn4LM9pg= +github.com/ogen-go/ogen v1.12.0/go.mod h1:RL25amedfhq5xKTUuPBPn6nhYU59CWaVWYJ8YIjNHs0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/redis/go-redis/v9 v9.11.0 h1:E3S08Gl/nJNn5vkxd2i78wZxWAPNZgUNTp8WIJUAiIs= github.com/redis/go-redis/v9 v9.11.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys= +github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= -golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8 h1:0A+M6Uqn+Eje4kHMK80dtF3JCXC4ykBgQG4Fe06QRhQ= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ= +go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I= +go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE= +go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= +go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= +go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= +go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= +golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= +golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY= +golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= +golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= +golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg= +golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= +golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= +golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= +golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= +golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= +golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0= +golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +nhooyr.io/websocket v1.8.17 h1:KEVeLJkUywCKVsnLIDlD/5gtayKp8VoCkksHCGGfT9Y= +nhooyr.io/websocket v1.8.17/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c= +rsc.io/qr v0.2.0 h1:6vBLea5/NRMVTz8V66gipeLycZMl/+UlFmk8DvqQ6WY= +rsc.io/qr v0.2.0/go.mod h1:IF+uZjkb9fqyeF/4tlBoynqmQxUoPfWEKh921coOuXs= diff --git a/yatgclient/mtprotoproxy.go b/yatgclient/mtprotoproxy.go new file mode 100644 index 0000000..ac0782a --- /dev/null +++ b/yatgclient/mtprotoproxy.go @@ -0,0 +1,197 @@ +package yatgclient + +import ( + "encoding/hex" + "fmt" + "net" + "net/http" + "net/url" + "strconv" + + "github.com/YaCodeDev/GoYaCodeDevUtils/yaerrors" + "github.com/YaCodeDev/GoYaCodeDevUtils/yalogger" + + "github.com/gotd/td/telegram/dcs" + "github.com/gotd/td/tg" +) + +// MTProto proxy helper +type MTProto struct { + Host string + Port uint16 + Secret string +} + +// NewMTProtoWithParseURL is a helper that allocates MTProto and calls ParseURL. +// +// Example: +// +// mtproto, _ := yatgclient.NewMTProtoWithParseURL("https://t.me/proxy?server=1.2.3.4&port=443&secret=abcdef", log) +func NewMTProtoWithParseURL(url string, log yalogger.Logger) (*MTProto, yaerrors.Error) { + mtproto := MTProto{} + + if err := mtproto.ParseURL(url, log); err != nil { + return nil, err.WrapWithLog("failed to create new mtproto proxy with url", log) + } + + return &mtproto, nil +} + +// String assembles a `t.me/proxy` share link from the struct fields. +// +// Example: +// +// m := yatgclient.MTProto{Host: "1.2.3.4", Port: 443, Secret: "abcdef"} +// link := m.String() // https://t.me/proxy?server=1.2.3.4&port=443&secret=abcdef +func (m *MTProto) String() string { + return fmt.Sprintf( + "https://t.me/proxy?server=%s&port=%d&secret=%s", + m.Host, m.Port, m.Secret, + ) +} + +// GetFullAddress returns the `host:port` pair suitable for dialing. +// +// Example: +// +// addr := m.GetFullAddress() // "1.2.3.4:443" +func (m *MTProto) GetFullAddress() string { + return net.JoinHostPort(m.Host, strconv.Itoa(int(m.Port))) +} + +// ParseURL populates the struct from a t.me/proxy share link. +// +// Supported formats: +// +// https://t.me/proxy?server=&port=&secret= +// +// Example: +// +// var m yatgclient.MTProto +// _ = m.ParseURL("https://t.me/proxy?server=1.2.3.4&port=443&secret=abcdef", log) +func (m *MTProto) ParseURL(proxyURL string, log yalogger.Logger) yaerrors.Error { + u, err := url.Parse(proxyURL) + if err != nil { + return yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + err, + "failed to parse url for mtproto", + log, + ) + } + + const ( + queryHost = "server" + queryPort = "port" + querySecret = "secret" + ) + + host := u.Query().Get(queryHost) + if len(host) == 0 { + return yaerrors.FromStringWithLog( + http.StatusInternalServerError, + "failed to get host query", + log, + ) + } + + port := u.Query().Get(queryPort) + if len(port) == 0 { + return yaerrors.FromStringWithLog( + http.StatusInternalServerError, + "failed to get port query", + log, + ) + } + + secret := u.Query().Get(querySecret) + + if len(secret) == 0 { + return yaerrors.FromStringWithLog( + http.StatusInternalServerError, + "failed to get secret query", + log, + ) + } + + portInt, err := strconv.Atoi(port) + if err != nil { + return yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + err, + "failed to parse port for mtproto", + log, + ) + } + + if portInt <= 0 || portInt > 65535 { + return yaerrors.FromStringWithLog( + http.StatusInternalServerError, + fmt.Sprintf("proxy port %d out of range 1–65535", portInt), + log, + ) + } + + m.Host = host + m.Secret = secret + m.Port = uint16(portInt) + + return nil +} + +// GetResolver builds a gotd `dcs.Resolver` backed by an MTProxy. +// +// Example: +// +// resolver, _ := mtproto.GetResolver(log) +func (m *MTProto) GetResolver(log yalogger.Logger) (dcs.Resolver, yaerrors.Error) { + if len(m.Host) == 0 { + return nil, yaerrors.FromStringWithLog( + http.StatusInternalServerError, + "empty host tag in mtproto", + log, + ) + } + + if m.Port == 0 { + return nil, yaerrors.FromStringWithLog( + http.StatusInternalServerError, + "proxy port equal zero", + log, + ) + } + + secret, err := hex.DecodeString(m.Secret) + if err != nil { + return nil, yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + err, + "failed to decode string as hex bytes", + log, + ) + } + + proxy, err := dcs.MTProxy(m.GetFullAddress(), secret, dcs.MTProxyOptions{}) + if err != nil { + return nil, yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + err, + "failed to create mtproto resolver", + log, + ) + } + + return proxy, nil +} + +// GetInputClientProxy converts the struct into tg.InputClientProxy from gotd +// +// Example: +// +// inputClientProxy := m.GetInputClientProxy() +func (m *MTProto) GetInputClientProxy() tg.InputClientProxy { + return tg.InputClientProxy{ + Address: m.Host, + Port: int(m.Port), + } +} diff --git a/yatgclient/socks5proxy.go b/yatgclient/socks5proxy.go new file mode 100644 index 0000000..ff3dd01 --- /dev/null +++ b/yatgclient/socks5proxy.go @@ -0,0 +1,181 @@ +package yatgclient + +import ( + "fmt" + "net" + "net/http" + "net/url" + "strconv" + + "github.com/YaCodeDev/GoYaCodeDevUtils/yaerrors" + "github.com/YaCodeDev/GoYaCodeDevUtils/yalogger" + + "github.com/gotd/td/telegram/dcs" + "golang.org/x/net/proxy" +) + +// SOCKS5 helper +type SOCKS5 struct { + Host string + Port uint16 + Username *string + Password *string +} + +// NewSOCKS5WithParseURL parses a socks5:// URL into a SOCKS5 struct(socks5://username:password@host:port) +// +// Example: +// +// p, _ := yatgclient.NewSOCKS5WithParseURL("socks5://user:pass@1.2.3.4:1080", log) +func NewSOCKS5WithParseURL(url string, log yalogger.Logger) (*SOCKS5, yaerrors.Error) { + socks5 := SOCKS5{} + + if err := socks5.ParseURL(url, log); err != nil { + return nil, err.WrapWithLog("failed to create new socks5 proxy with url", log) + } + + return &socks5, nil +} + +// String returns socks5://… representation. +func (s *SOCKS5) String() string { + hostPort := s.GetFullAddress() + + if s.Username != nil && s.Password != nil { + return fmt.Sprintf("socks5://%s:%s@%s", *s.Username, *s.Password, hostPort) + } + + return "socks5://" + hostPort +} + +// GetFullAddress returns host:port. +func (s *SOCKS5) GetFullAddress() string { + return net.JoinHostPort(s.Host, strconv.Itoa(int(s.Port))) +} + +// GetAuth converts embedded creds into *proxy.Auth. +func (s *SOCKS5) GetAuth() *proxy.Auth { + if s.Username == nil || s.Password == nil { + return nil + } + + return &proxy.Auth{User: *s.Username, Password: *s.Password} +} + +// ParseURL fills the struct from a socks5:// URL. +// +// Example: +// +// var socks5 yatgclient.SOCKS5 +// _ = socks5.ParseURL("socks5://1.2.3.4:1080", log) +func (s *SOCKS5) ParseURL(proxyURL string, log yalogger.Logger) yaerrors.Error { + u, err := url.Parse(proxyURL) + if err != nil { + return yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + err, + "failed to parse proxy url", + log, + ) + } + + switch u.Scheme { + case "socks5", "socks5h": + default: + return yaerrors.FromStringWithLog( + http.StatusInternalServerError, + fmt.Sprintf("unsupported proxy scheme %q (want socks5/socks5h)", u.Scheme), + log, + ) + } + + s.Host = u.Hostname() + + portStr := u.Port() + if portStr == "" { + log.Warn("proxy port not specified, using default 1080") + + portStr = "1080" + } + + portInt, err := strconv.Atoi(portStr) + if err != nil { + return yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + err, + "invalid proxy port", + log, + ) + } + + if portInt <= 0 || portInt > 65535 { + return yaerrors.FromStringWithLog( + http.StatusInternalServerError, + fmt.Sprintf("proxy port %d out of range 1–65535", portInt), + log, + ) + } + + s.Port = uint16(portInt) + + s.Username, s.Password = nil, nil + + if u.User != nil { + user := u.User.Username() + + s.Username = &user + if pass, ok := u.User.Password(); ok { + s.Password = &pass + } + } + + return nil +} + +// GetContextDialer converts SOCKS5 config into proxy.ContextDialer. +// +// Example: +// +// dialer, _ := socks5.GetContextDialer(log) +func (s *SOCKS5) GetContextDialer(log yalogger.Logger) (proxy.ContextDialer, yaerrors.Error) { + socks5, err := proxy.SOCKS5("tcp", s.GetFullAddress(), s.GetAuth(), proxy.Direct) + if err != nil { + return nil, yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + err, + "failed to create SOCKS5 proxy", + log, + ) + } + + contextDialer, ok := socks5.(proxy.ContextDialer) + if !ok { + return nil, yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + err, + "failed to cast proxy to ContextDialer", + log, + ) + } + + return contextDialer, nil +} + +// GetResolver returns a DC resolver using the SOCKS5 dialer. +// +// Example: +// +// resolver, _ := socks5.GetResolver(log) +func (s *SOCKS5) GetResolver(log yalogger.Logger) (dcs.Resolver, yaerrors.Error) { + dialer, err := s.GetContextDialer(log) + if err != nil { + return nil, yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + err, + "failed to get context dialer", + log, + ) + } + + return dcs.Plain(dcs.PlainOptions{Dial: dialer.DialContext}), nil +} diff --git a/yatgclient/yatgclient.go b/yatgclient/yatgclient.go new file mode 100644 index 0000000..9ec0c79 --- /dev/null +++ b/yatgclient/yatgclient.go @@ -0,0 +1,199 @@ +// Package yatgclient provides a thin convenience wrapper around gotd’s +// telegram.Client adding: +// - background‑connect helper with graceful shutdown +// - automatic bot‑token authorisation +// - updates.Manager wiring to yatgstorage (pts/qts/etc.) +// - SOCKS5 and MTProto proxy helpers (URL ↔ struct, dialer/resolver utilities) +package yatgclient + +import ( + "context" + "errors" + "net/http" + + "github.com/YaCodeDev/GoYaCodeDevUtils/yaerrors" + "github.com/YaCodeDev/GoYaCodeDevUtils/yalogger" + "github.com/YaCodeDev/GoYaCodeDevUtils/yatgstorage" + + "github.com/gotd/contrib/bg" + "github.com/gotd/td/telegram" + "github.com/gotd/td/telegram/updates" + "github.com/gotd/td/tgerr" +) + +// Client wrapper +type Client struct { + *telegram.Client + entityID int64 + log yalogger.Logger +} + +// Options to create a Client. +type ClientOptions struct { + AppID int + AppHash string + EntityID int64 + TelegramOptions telegram.Options +} + +// NewClient constructs a wrapper around gotd’s *telegram.Client. +// +// Example: +// +// cli := yatgclient.NewClient(yatgclient.ClientOptions{ +// AppID: 12345, AppHash: "abcd", EntityID: 42, +// TelegramOptions: telegram.Options{}, +// }, log) +func NewClient(options ClientOptions, log yalogger.Logger) *Client { + client := telegram.NewClient(options.AppID, options.AppHash, options.TelegramOptions) + + return &Client{ + Client: client, + entityID: options.EntityID, + log: log, + } +} + +// BackgroundConnect dials Telegram in a goroutine and stops automatically when +// ctx is cancelled. +// +// Example: +// +// _ = cli.BackgroundConnect(ctx) +func (c *Client) BackgroundConnect(ctx context.Context) yaerrors.Error { + stop, err := bg.Connect(c, bg.WithContext(ctx)) + if err != nil { + return yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + err, + "failed to connect background client", + c.log, + ) + } + + go func() { + <-ctx.Done() + + if err := stop(); err != nil { + c.log.Errorf("Failed to stop telegram client connection: %v", err) + } + }() + + return nil +} + +// BotAuthorization ensures the client is authorised via botToken. +// +// Example: +// +// _ = cli.BotAuthorization(ctx, "123:ABC") +func (c *Client) BotAuthorization(ctx context.Context, botToken string) yaerrors.Error { + status, err := c.Auth().Status(ctx) + if err != nil { + return yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + err, + "failed to check status bot authorization", + c.log, + ) + } + + if !status.Authorized { + if _, err := c.Auth().Bot(ctx, botToken); err != nil { + tgerr := &tgerr.Error{} + if errors.As(err, &tgerr) { + c.log.Errorf("%s", tgerr.Error()) + } else { + c.log.Errorf("%v", err) + } + + return yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + err, + "failed to bot authorization", + c.log, + ) + } + } + + return nil +} + +// EntityError couples a processing error with the bot entityID. +// Used by RunUpdatesManager for multi‑bot setups. +type EntityError struct { + Err yaerrors.Error + EntityID int64 +} + +// RunUpdatesManager starts an updates.Manager in the background and returns a +// channel where any fatal error is sent. +// +// Example: +// +// errs := client.RunUpdatesManager(ctx, gaps, updates.AuthOptions{}, nil) +// if err := <-errs; err.Err != nil { log.Fatalf("%v", err.Err) } +func (c *Client) RunUpdatesManager( + ctx context.Context, + gaps *updates.Manager, + options updates.AuthOptions, + channel *chan EntityError, +) <-chan EntityError { + if channel == nil { + c := make(chan EntityError) + channel = &c + } + + c.log.Debug("Fetching self...") + + user, err := c.Self(ctx) + if err != nil { + go func() { + *channel <- EntityError{ + Err: yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + err, + "failed to get self updates manager", + c.log, + ), + EntityID: c.entityID, + } + }() + + return *channel + } + + c.log.Debug("Running updates manager...") + + go func() { + if err = gaps.Run(ctx, c.API(), user.ID, options); err != nil { + *channel <- EntityError{ + Err: yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + err, + "failed to run updates manager", + c.log, + ), + EntityID: c.entityID, + } + } + }() + + c.log.Debug("Updates manager started...") + + return *channel +} + +// NewUpdateManagerWithYaStorage creates an updates.Manager pre‑wired to a +// yatgstorage implementation. +// +// Example: +// +// gaps := yatgclient.NewUpdateManagerWithYaStorage(storage) +func NewUpdateManagerWithYaStorage(storage yatgstorage.IStorage) *updates.Manager { + return updates.New(updates.Config{ + Handler: storage.AccessHashSaveHandler(), + Storage: storage.TelegramStorageCompatible(), + AccessHasher: storage.TelegramAccessHasherCompatible(), + }) +} diff --git a/yatgclient/yatgclient_test.go b/yatgclient/yatgclient_test.go new file mode 100644 index 0000000..4715dd2 --- /dev/null +++ b/yatgclient/yatgclient_test.go @@ -0,0 +1,111 @@ +package yatgclient_test + +import ( + "fmt" + "net" + "strconv" + "testing" + + "github.com/YaCodeDev/GoYaCodeDevUtils/yalogger" + "github.com/YaCodeDev/GoYaCodeDevUtils/yatgclient" + "github.com/gotd/td/tg" + "github.com/stretchr/testify/assert" + "golang.org/x/net/proxy" +) + +func TestSOCKS5_Works(t *testing.T) { + const ( + username = "skalse" + password = "lingvistka_sonya_echkere" + host = "yahost" + port = 8081 + ) + + url := fmt.Sprintf( + "socks5://%s:%s@%s", + username, + password, + net.JoinHostPort(host, strconv.Itoa(port)), + ) + + log := yalogger.NewBaseLogger(nil).NewLogger() + + socks5, _ := yatgclient.NewSOCKS5WithParseURL(url, log) + + t.Run("Username correct", func(t *testing.T) { + assert.Equal(t, username, *socks5.Username) + }) + + t.Run("Password correct", func(t *testing.T) { + assert.Equal(t, password, *socks5.Password) + }) + + t.Run("Host correct", func(t *testing.T) { + assert.Equal(t, host, socks5.Host) + }) + + t.Run("Port correct", func(t *testing.T) { + assert.Equal(t, uint16(port), socks5.Port) + }) + + t.Run("URL correct", func(t *testing.T) { + assert.Equal(t, url, socks5.String()) + }) + + t.Run("Get Full Address works", func(t *testing.T) { + expected := fmt.Sprintf("%s:%d", host, port) + + assert.Equal(t, expected, socks5.GetFullAddress()) + }) + + t.Run("Get Full Address works", func(t *testing.T) { + expected := proxy.Auth{User: username, Password: password} + + assert.Equal(t, expected, *socks5.GetAuth()) + }) +} + +func TestMTProto_Works(t *testing.T) { + const ( + secret = "https://open.spotify.com/track/1e1JKLEDKP7hEQzJfNAgPl?si=0dea7a7e6162462e" + host = "ya_playboy_carti" + port = 1847 + ) + + url := fmt.Sprintf("https://t.me/proxy?server=%s&port=%d&secret=%s", host, port, secret) + + log := yalogger.NewBaseLogger(nil).NewLogger() + + mtproto, _ := yatgclient.NewMTProtoWithParseURL(url, log) + + t.Run("Secret correct", func(t *testing.T) { + assert.Equal(t, secret, mtproto.Secret) + }) + + t.Run("Host correct", func(t *testing.T) { + assert.Equal(t, host, mtproto.Host) + }) + + t.Run("Port correct", func(t *testing.T) { + assert.Equal(t, uint16(port), mtproto.Port) + }) + + t.Run("Get Full Address works", func(t *testing.T) { + expected := fmt.Sprintf("%s:%d", host, port) + + assert.Equal(t, expected, mtproto.GetFullAddress()) + }) + + t.Run("Get Input Client Proxy works", func(t *testing.T) { + expected := tg.InputClientProxy{ + Address: host, + Port: port, + } + + assert.Equal(t, expected, mtproto.GetInputClientProxy()) + }) + + t.Run("URL correct", func(t *testing.T) { + assert.Equal(t, url, mtproto.String()) + }) +} diff --git a/yatgstorage/errors.go b/yatgstorage/errors.go new file mode 100644 index 0000000..a4809c1 --- /dev/null +++ b/yatgstorage/errors.go @@ -0,0 +1,25 @@ +package yatgstorage + +import "errors" + +var ( + ErrFailedToSetState = errors.New("failed to set telegram bot state") + ErrFailedToSetQts = errors.New("failed to set telegram bot qts") + ErrFailedToSetPts = errors.New("failed to set telegram bot pts") + ErrFailedToSetDate = errors.New("failed to set telegram bot date") + ErrFailedToSetSeq = errors.New("failed to set telegram bot seq") + ErrFailedToSetDateSeq = errors.New("failed to set telegram bot date and seq") + ErrFailedToGetState = errors.New("failed to get telegram bot state") + ErrFailedToUnmarshalState = errors.New("failed to unmarshal telegram bot state") + ErrFailedToSetChannelPts = errors.New("failed to set channel pts") + ErrFailedToGetChannelPts = errors.New("failed to get channel pts") + ErrFailedToUnmarshalChannelPts = errors.New("failed to unmarshal channel pts") + ErrFailedToSetChannelAccessHash = errors.New("failed to set channel access hash") + ErrFailedToGetChannelAccessHash = errors.New("failed to get channel access hash") + ErrFailedToUnmarshalChannelAccessHash = errors.New("failed to unmarshal channel access hash") + ErrFailedToParsePtsAsInt = errors.New("failed to parse pts as int") + ErrFailedToParseIDAsInt = errors.New("failed to parse id as int") + ErrFailedToParseAccessHashAsInt64 = errors.New("failed to parse access hash as int64") + ErrFailedToGetAllChannelPts = errors.New("failed to get all channel pts") + ErrFromCalledActionOfChannel = errors.New("error from called action of channel") +) diff --git a/yatgstorage/yatgstorage.go b/yatgstorage/yatgstorage.go new file mode 100644 index 0000000..adb2d1b --- /dev/null +++ b/yatgstorage/yatgstorage.go @@ -0,0 +1,939 @@ +// Package yatgstorage implements a Redis‑backed persistence layer for +// Telegram updates.Manager state (pts/qts/seq/date) plus channel/user +// access‑hash bookkeeping. +// +// The storage is fully compatible with gotd/td’s updates.Manager via +// TelegramStorageCompatible and TelegramAccessHasherCompatible adapters. +// +// # Layout in Redis +// +// - bot-state: – RedisJSON root with pts/qts/seq/date +// - bot-channel-pts: – HSET = +// - bot-channel-access-hash: – HSET = +// - bot-user-access-hash: – HSET = +// +// All JSON operations use redisjson (ReJSON v2). For high‑throughput +// production systems you can point s.cache at a sharded cluster without +// changing this code. +package yatgstorage + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "strconv" + + "github.com/YaCodeDev/GoYaCodeDevUtils/yacache" + "github.com/YaCodeDev/GoYaCodeDevUtils/yaerrors" + "github.com/YaCodeDev/GoYaCodeDevUtils/yalogger" + "github.com/gotd/td/telegram" + "github.com/gotd/td/telegram/updates" + "github.com/gotd/td/tg" + "github.com/redis/go-redis/v9" +) + +const ( + // BasePathRedisJSON is the JSON root (“$”) used by ReJSON. + BasePathRedisJSON = "$" + // PtsPathRedisJSON is $.Pts in bot‑state JSON. + PtsPathRedisJSON = BasePathRedisJSON + ".Pts" + // QtsPathRedisJSON is $.Qts in bot‑state JSON. + QtsPathRedisJSON = BasePathRedisJSON + ".Qts" + // DatePathRedisJSON is $.Date in bot‑state JSON. + DatePathRedisJSON = BasePathRedisJSON + ".Date" + // SeqPathRedisJSON is $.Seq in bot‑state JSON. + SeqPathRedisJSON = BasePathRedisJSON + ".Seq" + + // AccessHashFieldRedisHSet is the field name for access‑hash in HSET buckets. + AccessHashFieldRedisHSet = "AccessHash" + // PtsFieldRedisHSet is the field name for pts in HSET buckets. + PtsFieldRedisHSet = "Pts" + + // Structured‑logging keys. + LoggerEntityID = "entity_id" + LoggerEntityKey = "entity_key" + LoggerUserID = "user_id" + LoggerChannelID = "channel_id" +) + +// IStorage exposes the behaviour required by your application **and** the +// gotd/td updates.Manager. Code in higher layers (handlers, services, unit +// tests) should depend on this interface rather than *Storage so that you can +// swap the implementation (e.g. in‑memory fake). All methods return a +// yaerrors.Error – a thin wrapper around the standard error enriched with an +// HTTP status and structured‑log context. +// +// Example: +// +// var stg IStorage = yatgstorage.NewStorage(cache, dispatcher, 123, log) +// if err := stg.SetPts(ctx, 123, 456); err != nil { +// log.Fatalf("failed: %v", err) +// } +type IStorage interface { + // Ping checks the backend yacache health. + Ping(ctx context.Context) yaerrors.Error + + // Bot‑wide state getters / setters. ‘found==false’ means “no key yet”. + GetState(ctx context.Context, entityID int64) (updates.State, bool, yaerrors.Error) + SetState(ctx context.Context, entityID int64, state updates.State) yaerrors.Error + SetPts(ctx context.Context, entityID int64, pts int) yaerrors.Error + SetQts(ctx context.Context, entityID int64, qts int) yaerrors.Error + SetDate(ctx context.Context, entityID int64, date int) yaerrors.Error + SetSeq(ctx context.Context, entityID int64, seq int) yaerrors.Error + SetDateSeq(ctx context.Context, entityID int64, date, seq int) yaerrors.Error + + // Per‑channel pts bookkeeping. + SetChannelPts(ctx context.Context, entityID, channelID int64, pts int) yaerrors.Error + GetChannelPts(ctx context.Context, entityID, channelID int64) (int, bool, yaerrors.Error) + ForEachChannels( + ctx context.Context, + entityID int64, + action func(ctx context.Context, channelID int64, pts int) error, + ) yaerrors.Error + + // Channel access‑hash bookkeeping. + SetChannelAccessHash(ctx context.Context, entityID, channelID, accessHash int64) yaerrors.Error + GetChannelAccessHash( + ctx context.Context, + entityID, channelID int64, + ) (int64, bool, yaerrors.Error) + + // Update‑pipeline helper: returns a handler that stores access‑hashes + // from any incoming updates before forwarding to the real handler. + AccessHashSaveHandler() HandlerFunc + + // User access‑hash bookkeeping. + SetUserAccessHash(ctx context.Context, userID int64, accessHash int64) + GetUserAccessHash(ctx context.Context, userID int64) (int64, yaerrors.Error) + + // gotd adapters + TelegramStorageCompatible() updates.StateStorage + TelegramAccessHasherCompatible() updates.ChannelAccessHasher +} + +// Storage is the production implementation backed by a yacache.Cache[*redis.Client]. +// +// It embeds a telegram.UpdateHandler (your own dispatcher) so we can inject a +// middle layer that persists access‑hashes before letting updates propagate. +// The zero value is **not** valid – use NewStorage. +// +// Example: +// +// stg := yatgstorage.NewStorage(cache, dispatcher, 42, log) +// _ = stg +// +// A single Storage instance should be used per bot (entityID). +// The struct keeps an internal map to cache “I have already created the base +// JSON object” flags for performance. +// +// Because methods are safe for concurrent use (they only rely on redis, which +// is thread‑safe), you may share *Storage between goroutines. +type Storage struct { + cache yacache.Cache[*redis.Client] + handler telegram.UpdateHandler + entityID int64 + stateKeys map[string]struct{} + log yalogger.Logger +} + +// NewStorage wires all dependencies and returns a ready‑to‑use *Storage. +// +// - cache – any yacache implementation; production code passes a Redis +// client, tests may pass yacache.NewMock. +// - handler – your app’s dispatcher (implements telegram.UpdateHandler). +// - entityID – unique bot identifier used to namespace all Redis keys. +// - log – structured logger. +// +// Example: +// +// stg := yatgstorage.NewStorage(cache, dispatcher, 123456, log) +// if err := stg.Ping(ctx); err != nil { +// log.Fatalf("redis down: %v", err) +// } +func NewStorage( + cache yacache.Cache[*redis.Client], + handler telegram.UpdateHandler, + entityID int64, + log yalogger.Logger, +) *Storage { + return &Storage{ + cache: cache, + handler: handler, + entityID: entityID, + stateKeys: map[string]struct{}{}, + log: log, + } +} + +// Ping checks that the yacache backend is operational. +// +// Example: +// +// if err := stg.Ping(ctx); err != nil { +// log.Errorf("storage unhealthy: %v", err) +// } +func (s *Storage) Ping(ctx context.Context) yaerrors.Error { + return s.cache.Ping(ctx) +} + +// TelegramStorageCompatible returns an adapter implementing updates.StateStorage +// so that gotd/td’s updates.Manager can persist pts/qts/seq/date directly into +// Redis. +// +// Example: +// +// manager := updates.New(updates.Config{Handler: handler, Storage: stg.TelegramStorageCompatible()}) +func (s *Storage) TelegramStorageCompatible() updates.StateStorage { + return &telegramStorage{ + storage: s, + } +} + +// TelegramAccessHasherCompatible returns an adapter implementing +// updates.ChannelAccessHasher so that updates.Manager can resolve channel +// access hashes via Redis. +func (s *Storage) TelegramAccessHasherCompatible() updates.ChannelAccessHasher { + return &telegramHasher{ + storage: s, + } +} + +// GetState retrieves the bot‑global State record (pts/qts/seq/date). +// found==false indicates the record does not exist yet. +// +// Example: +// +// state, ok, err := stg.GetState(ctx, botID) +// if err != nil { log.Fatal(err) } +// if ok { fmt.Printf("pts=%d", st.Pts) } +func (s *Storage) GetState( + ctx context.Context, + entityID int64, +) (updates.State, bool, yaerrors.Error) { + key := getBotStateKey(entityID) + + log := s.initBaseFieldsLog("Fetching entity state", key) + + if err := s.safetyBaseStateJSON(ctx, key, log); err != nil { + return updates.State{}, false, err.WrapWithLog("failed to get entity state", log) + } + + data, yaerr := s.cache.Raw().JSONGet(ctx, key).Result() + if yaerr != nil { + return updates.State{}, false, nil + } + + var state updates.State + + err := json.Unmarshal([]byte(data), &state) + if err != nil { + return state, false, nil + } + + log.Debug("Entity state fetched") + + return state, true, nil +} + +// SetState stores the full updates.State. +// +// Example: +// +// err := stg.SetState(ctx, botID, updates.State{Pts: 10}) +// if err != nil { log.Fatal(err) } +func (s *Storage) SetState( + ctx context.Context, + entityID int64, + state updates.State, +) yaerrors.Error { + key := getBotStateKey(entityID) + + log := s.initBaseFieldsLog("Setting entity state", key).WithField(LoggerEntityID, entityID) + + if err := s.cache.Raw().JSONSet(ctx, key, BasePathRedisJSON, state).Err(); err != nil { + return yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + errors.Join(err, ErrFailedToSetState), + "failed to set entity json", + log, + ) + } + + log.Debug("Entity state set") + + return nil +} + +// SetPts updates only $.Pts inside the stored state. +// +// Example: +// +// _ = stg.SetPts(ctx, botID, 123) +func (s *Storage) SetPts(ctx context.Context, entityID int64, pts int) yaerrors.Error { + key := getBotStateKey(entityID) + + log := s. + initBaseFieldsLog("Setting pts in entity state", key). + WithField(LoggerEntityID, entityID) + + if err := s.safetyBaseStateJSON(ctx, key, log); err != nil { + return err.WrapWithLog("failed to set entity state pts", log) + } + + if err := s.cache.Raw().JSONSet(ctx, key, PtsPathRedisJSON, pts).Err(); err != nil { + return yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + errors.Join(err, ErrFailedToSetPts), + "failed to set entity state pts", + log, + ) + } + + log.Debug("Have set pts in entity state") + + return nil +} + +// SetQts writes $.Qts only. +// +// Example: +// +// _ = stg.SetQts(ctx, botID, 77) +func (s *Storage) SetQts(ctx context.Context, entityID int64, qts int) yaerrors.Error { + key := getBotStateKey(entityID) + + log := s. + initBaseFieldsLog("Setting qts in entity state", key). + WithField(LoggerEntityID, entityID) + + if err := s.safetyBaseStateJSON(ctx, key, log); err != nil { + return err.WrapWithLog("failed to set entity state qts", log) + } + + if err := s.cache.Raw().JSONSet(ctx, key, QtsPathRedisJSON, qts).Err(); err != nil { + return yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + errors.Join(err, ErrFailedToSetQts), + "failed to set entity state qts", + log, + ) + } + + log.Debug("Have set qts in entity state") + + return nil +} + +// SetDate writes $.Date only. +// +// Example: +// +// _ = stg.SetDate(ctx, botID, int(time.Now().Unix())) +func (s *Storage) SetDate(ctx context.Context, entityID int64, date int) yaerrors.Error { + key := getBotStateKey(entityID) + + log := s. + initBaseFieldsLog("Setting date in state", key). + WithField(LoggerEntityID, entityID) + + if err := s.safetyBaseStateJSON(ctx, key, log); err != nil { + return err.WrapWithLog("failed to set entity state date", log) + } + + if err := s.cache.Raw().JSONSet(ctx, key, DatePathRedisJSON, date).Err(); err != nil { + return yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + errors.Join(err, ErrFailedToSetDate), + "failed to set entity state date", + log, + ) + } + + log.Debug("Have set date in entity state") + + return nil +} + +// SetSeq writes $.Seq only. +// +// Example: +// +// _ = stg.SetSeq(ctx, botID, 5) +func (s *Storage) SetSeq(ctx context.Context, entityID int64, seq int) yaerrors.Error { + key := getBotStateKey(entityID) + + log := s. + initBaseFieldsLog("Setting seq in state", key). + WithField(LoggerEntityID, entityID) + + if err := s.safetyBaseStateJSON(ctx, key, log); err != nil { + return err.WrapWithLog("failed to set entity state seq", log) + } + + if err := s.cache.Raw().JSONSet(ctx, key, SeqPathRedisJSON, seq).Err(); err != nil { + return yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + errors.Join(err, ErrFailedToSetSeq), + "failed to set entity state seq", + log, + ) + } + + log.Debug("Have set seq in entity state") + + return nil +} + +// SetDateSeq atomically writes $.Date and $.Seq. +// +// Example: +// +// _ = stg.SetDateSeq(ctx, botID, int(time.Now().Unix()), 9) +func (s *Storage) SetDateSeq(ctx context.Context, entityID int64, date, seq int) yaerrors.Error { + key := getBotStateKey(entityID) + + log := s. + initBaseFieldsLog("Setting date and seq in state", key). + WithField(LoggerEntityID, entityID) + + if err := s.safetyBaseStateJSON(ctx, key, log); err != nil { + return err.WrapWithLog("failed to set entity state date and seq", log) + } + + if err := s.cache.Raw(). + JSONMSet(ctx, key, DatePathRedisJSON, date, key, SeqPathRedisJSON, seq).Err(); err != nil { + return yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + errors.Join(err, ErrFailedToSetDateSeq), + "failed to set entity state date and seq", + log, + ) + } + + log.Debug("Have set date and seq in state") + + return nil +} + +// SetChannelPts stores channel pts value. +// +// Example: +// +// _ = stg.SetChannelPts(ctx, botID, chID, 120) +func (s *Storage) SetChannelPts( + ctx context.Context, + entityID, channelID int64, + pts int, +) yaerrors.Error { + key := getChannelPtsKey(entityID) + + log := s. + initBaseFieldsLog("Setting channel pts", key). + WithField(LoggerEntityID, entityID). + WithField(LoggerChannelID, channelID) + + if err := s.cache.Raw(). + HSet(ctx, key, strconv.FormatInt(channelID, 10), pts).Err(); err != nil { + return yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + errors.Join(err, ErrFailedToSetChannelPts), + "failed to set channel pts", + log, + ) + } + + log.Debug("Have set channel pts") + + return nil +} + +// GetChannelPts returns pts for a channel. +// +// Example: +// +// pts, ok, _ := stg.GetChannelPts(ctx, botID, chID) +func (s *Storage) GetChannelPts( + ctx context.Context, + entityID, channelID int64, +) (int, bool, yaerrors.Error) { + key := getChannelPtsKey(entityID) + + log := s. + initBaseFieldsLog("Fetching channel pts", key). + WithField(LoggerUserID, entityID). + WithField(LoggerChannelID, channelID) + + data, yaerr := s.cache.HGet(ctx, key, strconv.FormatInt(channelID, 10)) + if yaerr != nil { + return 0, false, yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + errors.Join(yaerr, ErrFailedToGetChannelPts), + "failed to get channel pts", + log, + ) + } + + res, err := strconv.ParseInt(data, 10, 0) + if err != nil { + return 0, false, yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + errors.Join(err, ErrFailedToParsePtsAsInt), + "failed to get channel pts", + log, + ) + } + + log.Debug("Fetched channel pts") + + return int(res), true, nil +} + +// ForEachChannels iterates over all channels. +// +// Example: +// +// _ = stg.ForEachChannels(ctx, botID, func(ctx context.Context, id int64, pts int) error { +// fmt.Println(id, pts); return nil +// }) +func (s *Storage) ForEachChannels( + ctx context.Context, + entityID int64, + action func(ctx context.Context, channelID int64, pts int) error, +) yaerrors.Error { + key := getChannelPtsKey(entityID) + + log := s.initBaseFieldsLog("Start action for each channels", key). + WithField(LoggerUserID, entityID) + + channels, err := s.cache.HGetAll(ctx, key) + if err != nil { + return yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + errors.Join(err, ErrFailedToGetAllChannelPts), + "failed to get all channels", + log, + ) + } + + for c := range channels { + id, err := strconv.ParseInt(c, 10, 64) + if err != nil { + return yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + errors.Join(err, ErrFailedToParseIDAsInt), + "failed to parse id as int", + log, + ) + } + + childLog := log.WithField(LoggerChannelID, id) + + pts, err := strconv.ParseInt(channels[c], 10, 0) + if err != nil { + return yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + errors.Join(err, ErrFailedToParsePtsAsInt), + "failed to parse pts as int", + log, + ) + } + + if err := action(ctx, id, int(pts)); err != nil { + childLog.Errorf("%v", err) + + return yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + errors.Join(err, ErrFromCalledActionOfChannel), + "failed to action of channel", + log, + ) + } + } + + log.Debug("Action manipulated for each channels") + + return nil +} + +// SetChannelAccessHash saves a channel access‑hash. +// +// Example: +// +// _ = stg.SetChannelAccessHash(ctx, botID, chID, hash) +func (s *Storage) SetChannelAccessHash( + ctx context.Context, + entityID, channelID, accessHash int64, +) yaerrors.Error { + key := getChannelAccessHashKey(entityID) + + log := s. + initBaseFieldsLog("Setting channel access hash for channel", key). + WithField(LoggerEntityID, entityID). + WithField(LoggerChannelID, channelID) + + if err := s.cache.Raw(). + HSet(ctx, key, strconv.FormatInt(channelID, 10), accessHash).Err(); err != nil { + return yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + errors.Join(err, ErrFailedToSetChannelAccessHash), + "failed to set channel access hash", + log, + ) + } + + log.Debug("Have set channel access hash") + + return nil +} + +// GetChannelAccessHash retrieves a saved access‑hash. +// +// Example: +// +// hash, found, _ := stg.GetChannelAccessHash(ctx, botID, chID) +func (s *Storage) GetChannelAccessHash( + ctx context.Context, + entityID, channelID int64, +) (int64, bool, yaerrors.Error) { + key := getChannelAccessHashKey(entityID) + + log := s. + initBaseFieldsLog("Fetching channel access hash", key). + WithField(LoggerEntityID, entityID). + WithField(LoggerChannelID, channelID) + + data, err := s.cache.Raw(). + HGet(ctx, key, strconv.FormatInt(channelID, 10)).Result() + if err != nil { + return 0, false, yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + errors.Join(err, ErrFailedToGetChannelAccessHash), + "failed to get channel access hash", + log, + ) + } + + res, err := strconv.ParseInt(data, 10, 64) + if err != nil { + return 0, false, yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + errors.Join(err, ErrFailedToParseAccessHashAsInt64), + "failed to parse channel access hash as int64", + log, + ) + } + + log.Debug("Fetched channel access hash") + + return res, true, nil +} + +// HandlerFunc adapts a plain function into a gotd `telegram.UpdateHandler`. +// +// Example: +// +// h := yatgstorage.HandlerFunc(func(ctx context.Context, u tg.UpdatesClass) error { +// fmt.Println("update received") +// return nil +// }) +// _ = h.Handle(ctx, &tg.Updates{}) +type HandlerFunc func(ctx context.Context, updates tg.UpdatesClass) error + +// Handle implements telegram.UpdateHandler by delegating to the underlying +// function. +// +// Example: +// +// _ = HandlerFunc(func(ctx context.Context, u tg.UpdatesClass) error { return nil }).Handle(ctx, &tg.Updates{}) +func (h HandlerFunc) Handle(ctx context.Context, updates tg.UpdatesClass) error { + return h(ctx, updates) +} + +// AccessHashSaveHandler returns middleware that intercepts Updates{,Combined}, +// saves every user’s AccessHash to Redis via SetUserAccessHash, then forwards +// the update to the real dispatcher. +// +// Example: +// +// clientOpts.UpdateHandler = storage.AccessHashSaveHandler() +func (s *Storage) AccessHashSaveHandler() HandlerFunc { + return HandlerFunc(func(ctx context.Context, updates tg.UpdatesClass) error { + switch update := updates.(type) { + case *tg.Updates: + for _, user := range update.MapUsers().NotEmptyToMap() { + if err := s.SetUserAccessHash(ctx, user.ID, user.AccessHash); err != nil { + s.log.Errorf("Failed to save user(%d) access hash(%d)", user.ID, user.AccessHash) + } + } + case *tg.UpdatesCombined: + for _, user := range update.MapUsers().NotEmptyToMap() { + if err := s.SetUserAccessHash(ctx, user.ID, user.AccessHash); err != nil { + s.log.Errorf("Failed to save user(%d) access hash(%d)", user.ID, user.AccessHash) + } + } + } + + return s.handler.Handle(ctx, updates) + }) +} + +// SetUserAccessHash persists a user access‑hash unless the ID equals the +// special @Channel_Bot placeholder. +// +// Example: +// +// _ = stg.SetUserAccessHash(ctx, 12345, 67890) +func (s *Storage) SetUserAccessHash( + ctx context.Context, + userID int64, + accessHash int64, +) yaerrors.Error { + const botChannelID = 136817688 // Ignore channel placeholder (@Channel_Bot - in Telegram) + + if userID != botChannelID { + key := getUserAccessHashKey(s.entityID) + + log := s.initBaseFieldsLog("Saving access hash", key).WithField(LoggerUserID, userID) + + if err := s.cache.Raw(). + HSet(ctx, key, strconv.FormatInt(userID, 10), accessHash).Err(); err != nil { + return yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + err, + "failed to save user access hash", + log, + ) + } + + log.Debugf("Saved user access hash") + } + + return nil +} + +// GetUserAccessHash retrieves a user’s access‑hash. +// +// Example: +// +// hash, foundErr := stg.GetUserAccessHash(ctx, 12345) +func (s *Storage) GetUserAccessHash(ctx context.Context, userID int64) (int64, yaerrors.Error) { + key := getUserAccessHashKey(s.entityID) + + log := s.initBaseFieldsLog("fetching user access hash", key).WithField(LoggerUserID, userID) + + hash, err := s.cache.Raw().HGet(ctx, key, strconv.FormatInt(userID, 10)).Result() + if err != nil { + return 0, yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + err, + "failed to fetch user access hash", + log, + ) + } + + res, err := strconv.ParseInt(hash, 10, 64) + if err != nil { + return 0, yaerrors.FromErrorWithLog( + http.StatusBadRequest, + err, + ErrFailedToParseAccessHashAsInt64.Error(), + log, + ) + } + + log.Debugf("Fetched user access hash") + + return res, nil +} + +// initBaseFieldsLog attaches standard fields (entityID, redisKey) and issues a +// debug message. +// +// Example: +// +// l := stg.initBaseFieldsLog("doing work", "redis:key") +func (s *Storage) initBaseFieldsLog( + entryText string, + botKey string, +) yalogger.Logger { + log := s.log.WithField(LoggerEntityID, s.entityID).WithField(LoggerEntityKey, botKey) + + log.Debugf("%s", entryText) + + return log +} + +// safetyBaseStateJSON lazily creates an empty JSON object at key "$" if absent +// to guarantee follow‑up JSONSet operations succeed. +// +// Example: +// +// _ = stg.safetyBaseStateJSON(ctx, "bot-state:1", log) +func (s *Storage) safetyBaseStateJSON( + ctx context.Context, + key string, + log yalogger.Logger, +) yaerrors.Error { + if _, ok := s.stateKeys[key]; !ok { + if res, err := s.cache.Raw().JSONGet(ctx, key, BasePathRedisJSON).Result(); err != nil || + len(res) == 0 { + if err := s.cache.Raw().JSONSet(ctx, key, BasePathRedisJSON, updates.State{}).Err(); err != nil { + return yaerrors.FromErrorWithLog( + http.StatusInternalServerError, + errors.Join(err, ErrFailedToSetState), + "failed to create safety base root entity state", + log, + ) + } + } + + s.stateKeys[key] = struct{}{} + } + + return nil +} + +// getUserAccessHashKey forms the HSET key for user access‑hashes. +// +// Example: +// +// k := getUserAccessHashKey(42) // "bot-user-access-hash:42" +func getUserAccessHashKey(entityID int64) string { + return fmt.Sprintf("bot-user-access-hash:%d", entityID) +} + +// getBotStateKey forms the RedisJSON key for bot global state. +// +// Example: +// +// k := getBotStateKey(42) // "bot-state:42" +func getBotStateKey(entityID int64) string { + return fmt.Sprintf("bot-state:%d", entityID) +} + +// // getChannelAccessHashKey forms the HSET key for channel access‑hashes. +// +// Example: +// +// k := getChannelAccessHashKey(42) // "bot-channel-access-hash:42" +func getChannelAccessHashKey(entityID int64) string { + return fmt.Sprintf("bot-channel-access-hash:%d", entityID) +} + +// getChannelPtsKey forms the HSET key for channel pts. +// +// Example: +// +// k := getChannelPtsKey(42) // "bot-channel-pts:42" +func getChannelPtsKey(entityID int64) string { + return fmt.Sprintf("bot-channel-pts:%d", entityID) +} + +// Implementation native `gotd` iterface storage +type telegramStorage struct { + storage *Storage +} + +// GetState proxies Storage.GetState. +// +// Example: +// +// st, found, _ := stg.TelegramStorageCompatible().GetState(ctx, botID) +func (t *telegramStorage) GetState( + ctx context.Context, + userID int64, +) (state updates.State, found bool, err error) { + return t.storage.GetState(ctx, userID) +} + +// SetState proxies Storage.SetState. +// +// Example: +// +// _ = stg.TelegramStorageCompatible().SetState(ctx, botID, updates.State{Pts: 1}) +func (t *telegramStorage) SetState(ctx context.Context, userID int64, state updates.State) error { + return t.storage.SetState(ctx, userID, state) +} + +// SetPts proxies Storage.SetPts. +func (t *telegramStorage) SetPts(ctx context.Context, userID int64, pts int) error { + return t.storage.SetPts(ctx, userID, pts) +} + +// SetQts proxies Storage.SetQts. +func (t *telegramStorage) SetQts(ctx context.Context, userID int64, qts int) error { + return t.storage.SetQts(ctx, userID, qts) +} + +// SetDate proxies Storage.SetDate. +func (t *telegramStorage) SetDate(ctx context.Context, userID int64, date int) error { + return t.storage.SetDate(ctx, userID, date) +} + +// SetSeq proxies Storage.SetSeq. +func (t *telegramStorage) SetSeq(ctx context.Context, userID int64, seq int) error { + return t.storage.SetSeq(ctx, userID, seq) +} + +// SetDateSeq proxies Storage.SetDateSeq. +func (t *telegramStorage) SetDateSeq(ctx context.Context, userID int64, date, seq int) error { + return t.storage.SetDateSeq(ctx, userID, date, seq) +} + +// GetChannelPts proxies Storage.GetChannelPts. +func (t *telegramStorage) GetChannelPts( + ctx context.Context, + userID, channelID int64, +) (pts int, found bool, err error) { + return t.storage.GetChannelPts(ctx, userID, channelID) +} + +// SetChannelPts proxies Storage.SetChannelPts. +func (t *telegramStorage) SetChannelPts( + ctx context.Context, + userID, channelID int64, + pts int, +) error { + return t.storage.SetChannelPts(ctx, userID, channelID, pts) +} + +// SetChannelPts proxies Storage.ForEachChannels. +func (t *telegramStorage) ForEachChannels( + ctx context.Context, + userID int64, + f func(ctx context.Context, channelID int64, pts int) error, +) error { + return t.storage.ForEachChannels(ctx, userID, f) +} + +// Implementation native `gotd` interface hasher +type telegramHasher struct { + storage *Storage +} + +// SetChannelAccessHash proxies Storage.SetChannelAccessHash. +// +// Example: +// +// _ = stg.TelegramAccessHasherCompatible().SetChannelAccessHash(ctx, botID, chID, hash) +func (t *telegramHasher) SetChannelAccessHash( + ctx context.Context, + userID, channelID, accessHash int64, +) error { + return t.storage.SetChannelAccessHash(ctx, userID, channelID, accessHash) +} + +// GetChannelAccessHash proxies Storage.GetChannelAccessHash. +// +// Example: +// +// hash, found, _ := stg.TelegramAccessHasherCompatible().GetChannelAccessHash(ctx, botID, chID) +func (t *telegramHasher) GetChannelAccessHash( + ctx context.Context, + userID, + channelID int64, +) (accessHash int64, found bool, err error) { + return t.storage.GetChannelAccessHash(ctx, userID, channelID) +} diff --git a/yatgstorage/yatgstorage_test.go b/yatgstorage/yatgstorage_test.go new file mode 100644 index 0000000..1601fb7 --- /dev/null +++ b/yatgstorage/yatgstorage_test.go @@ -0,0 +1,124 @@ +package yatgstorage_test + +import ( + "context" + "testing" + + "github.com/YaCodeDev/GoYaCodeDevUtils/yacache" + "github.com/YaCodeDev/GoYaCodeDevUtils/yalogger" + "github.com/YaCodeDev/GoYaCodeDevUtils/yatgstorage" + "github.com/alicebob/miniredis/v2" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func setupTestRedis(t *testing.T) (*redis.Client, func()) { + mr, err := miniredis.Run() + + require.NoError(t, err) + + client := redis.NewClient(&redis.Options{ + Addr: mr.Addr(), + }) + + cleanup := func() { + client.Close() + mr.Close() + } + + return client, cleanup +} + +func TestStorage_CreateWorks(t *testing.T) { + client, cleanup := setupTestRedis(t) + + defer cleanup() + + if err := yatgstorage. + NewStorage(yacache.NewCache(client), nil, 0, yalogger.NewBaseLogger(nil).NewLogger()). + Ping(context.Background()); err != nil { + t.Fatalf("Failed to create tg storage") + } +} + +func TestStorageChannel_WorkFlowWorks(t *testing.T) { + const ( + entityID = 1111 + channelID = 1111 + ) + + ctx := context.Background() + + client, cleanup := setupTestRedis(t) + log := yalogger.NewBaseLogger(nil).NewLogger() + + defer cleanup() + + storage := yatgstorage. + NewStorage(yacache.NewCache(client), nil, 1001, log) + + t.Run("Set and Get channel pts - works", func(t *testing.T) { + const expected = 1000 + + _ = storage.SetChannelPts(ctx, entityID, channelID, expected) + + result, _, _ := storage.GetChannelPts(ctx, entityID, channelID) + + assert.Equal(t, expected, result) + }) + + t.Run("For each channels iterate - works", func(t *testing.T) { + const entityChildID = 9 + + channelIDs := []int64{1, 2, 3, 4, 5, 6, 7} + + for _, v := range channelIDs { + _ = storage.SetChannelPts(ctx, entityChildID, v, int(v)*2) + } + + _ = storage.ForEachChannels( + ctx, + entityChildID, + func(_ context.Context, channelID int64, pts int) error { + assert.Equal(t, int(channelID)*2, pts) + + return nil + }, + ) + }) + + t.Run("Set and Get channel access hash - works", func(t *testing.T) { + expected := int64(100) + + _ = storage.SetChannelAccessHash(ctx, entityID, channelID, expected) + + result, _, _ := storage.GetChannelAccessHash(ctx, entityID, channelID) + + assert.Equal(t, expected, result) + }) +} + +func TestStorageUser_WorkFlowWorks(t *testing.T) { + ctx := context.Background() + + client, cleanup := setupTestRedis(t) + log := yalogger.NewBaseLogger(nil).NewLogger() + + defer cleanup() + + storage := yatgstorage. + NewStorage(yacache.NewCache(client), nil, 1001, log) + + t.Run("Set and Get user access hash - works", func(t *testing.T) { + const userID = 2222 + + expected := int64(200) + + _ = storage.SetUserAccessHash(ctx, userID, expected) + + result, _ := storage.GetUserAccessHash(ctx, userID) + + assert.Equal(t, expected, result) + }) +} diff --git a/yathreadsafeset/utils.go b/yathreadsafeset/utils.go new file mode 100644 index 0000000..6449888 --- /dev/null +++ b/yathreadsafeset/utils.go @@ -0,0 +1,8 @@ +package yathreadsafeset + +// safetyCheck ensures that the internal set is initialized before any operations are performed. +func (m *ThreadSafeSet[K]) safetyCheck() { + if m.data == nil { + m.data = make(map[K]struct{}) + } +} diff --git a/yathreadsafeset/yathreadsafeset.go b/yathreadsafeset/yathreadsafeset.go new file mode 100644 index 0000000..342e619 --- /dev/null +++ b/yathreadsafeset/yathreadsafeset.go @@ -0,0 +1,518 @@ +package yathreadsafeset + +import ( + "encoding/json" + "fmt" + "maps" + "sync" +) + +// ThreadSafeSet is a generic set implementation that supports concurrent read and write operations safely. +type ThreadSafeSet[K comparable] struct { + data map[K]struct{} + mu sync.RWMutex +} + +// NewThreadSafeSet returns a new instance of a thread-safe set with initialized internal storage. +// +// Example usage: +// +// set := threadsafeset.NewThreadSafeSet[string]() +func NewThreadSafeSet[K comparable]() *ThreadSafeSet[K] { + return &ThreadSafeSet[K]{ + data: make(map[K]struct{}), + } +} + +// Clear removes all values from the set, resetting its internal state. +// +// Example usage: +// +// set := threadsafeset.NewThreadSafeSet[string]() +// set.Set("value1") +// fmt.Println(set.String()) // Outputs: ["value1"] +// set.Clear() +// fmt.Println(set.String()) // Outputs: [] +func (m *ThreadSafeSet[K]) Clear() { + m.safetyCheck() + m.mu.Lock() + m.data = make(map[K]struct{}) + m.mu.Unlock() +} + +// Copy returns a new copy of the current set's content to avoid concurrency issues. +// +// Example usage: +// +// set := threadsafeset.NewThreadSafeSet[string]() +// set.Set("value1") +// copySet := set.Copy() +// set.Delete("value1") +// fmt.Println(copySet.String()) // Outputs: ["value1"] +func (m *ThreadSafeSet[K]) Copy() *ThreadSafeSet[K] { + m.safetyCheck() + m.mu.RLock() + + copySet := NewThreadSafeSet[K]() + maps.Copy(copySet.data, m.data) + + m.mu.RUnlock() + + return copySet +} + +// Delete removes the specified value from the set if it exists. +// +// Example usage: +// +// set := threadsafeset.NewThreadSafeSet[string]() +// set.Set("value1") +// set.Delete("value1") // Removes "value1" from the set +func (m *ThreadSafeSet[K]) Delete(value K) { + m.safetyCheck() + m.mu.Lock() + delete(m.data, value) + m.mu.Unlock() +} + +// Has checks whether a given value exists in the set. +// +// Example usage: +// +// set := threadsafeset.NewThreadSafeSet[string]() +// set.Set("value1") +// fmt.Println(set.Has("value1")) // Outputs: true +// fmt.Println(set.Has("value2")) // Outputs: false +func (m *ThreadSafeSet[K]) Has(value K) bool { + m.safetyCheck() + m.mu.RLock() + _, exists := m.data[value] + m.mu.RUnlock() + + return exists +} + +// Iterate iterates over the set and calls the given function for each value. +// +// DEADLOCK: During iteration, it is forbidden to modify the set (add or remove values), +// failing to do so will result in a deadlock. +// +// Example usage: +// +// set := threadsafeset.NewThreadSafeSet[string]() +// set.Set("value1") +// set.Set("value2") +// set.Iterate(func(value string) { +// fmt.Println(value) // Outputs: value1, value2 +// }) +func (m *ThreadSafeSet[K]) Iterate(fn func(K)) { + m.safetyCheck() + + m.mu.RLock() + defer m.mu.RUnlock() + + for k := range m.data { + fn(k) + } +} + +// IterateOnCopy iterates over a copy of the set to avoid holding locks during iteration. +// +// DEADLOCK: During iteration, it is forbidden to modify the set (add or remove values), +// failing to do so will result in a deadlock. +// +// Example usage: +// +// set := threadsafeset.NewThreadSafeSet[string]() +// set.Set("value1") +// +// set.IterateOnCopy(func(value string) { +// fmt.Println(value) // Outputs: value1 +// time.Sleep(1 * time.Second) // Assume time-consuming processing +// }) +func (m *ThreadSafeSet[K]) IterateOnCopy(fn func(K)) { + for k := range m.CopyRaw() { + fn(k) + } +} + +// IterateWithBreak iterates through the set until the callback returns false, then breaks. +// +// DEADLOCK: During iteration, it is forbidden to modify the set (add or remove values), +// failing to do so will result in a deadlock. +// +// Example usage: +// +// set := threadsafeset.NewThreadSafeSet[string]() +// set.Set("value1") +// +// set.IterateWithBreak(func(value string) bool { +// fmt.Println(value) // Outputs: value1 +// return true // Continue iteration +// }) +func (m *ThreadSafeSet[K]) IterateWithBreak(fn func(K) bool) { + m.safetyCheck() + + m.mu.RLock() + defer m.mu.RUnlock() + + for k := range m.data { + if !fn(k) { + break + } + } +} + +// Length returns the total number of values in the set. +// +// Example usage: +// +// set := threadsafeset.NewThreadSafeSet[string]() +// fmt.Println(set.Length()) // Outputs: 0 +// set.Set("value1") +// fmt.Println(set.Length()) // Outputs: 1 +func (m *ThreadSafeSet[K]) Length() int { + m.safetyCheck() + m.mu.RLock() + length := len(m.data) + m.mu.RUnlock() + + return length +} + +// MarshalJSON provides a custom JSON marshaling implementation for the thread-safe set. +// +// Example usage: +// +// set := threadsafeset.NewThreadSafeSet[string]() +// set.Set("value1") +// jsonData, err := json.Marshal(set) +// +// if err != nil { +// // handle error +// } +// +// fmt.Println(string(jsonData)) // Outputs: ["value1"] +func (m *ThreadSafeSet[K]) MarshalJSON() ([]byte, error) { + m.safetyCheck() + m.mu.RLock() + + data, err := json.Marshal(m.Values()) + if err != nil { + return nil, fmt.Errorf("failed to marshal set: %w", err) + } + + m.mu.RUnlock() + + return data, nil +} + +// Pop removes and returns a boolean indicating if the value was found. +// +// Example usage: +// +// set := threadsafeset.NewThreadSafeSet[string]() +// set.Set("value1") +// fmt.Println(set.String()) // Outputs: ["value1"] +// popped := set.Pop("value1") // Removes "value1" from the set +// fmt.Println(popped) // Outputs: true +// fmt.Println(set.String()) // Outputs: [] +func (m *ThreadSafeSet[K]) Pop(value K) bool { + m.safetyCheck() + m.mu.Lock() + + _, ok := m.data[value] + if ok { + delete(m.data, value) + } + + m.mu.Unlock() + + return ok +} + +// Set adds a value to the set. +// +// Example usage: +// +// set := threadsafeset.NewThreadSafeSet[string]() +// set.Set("value1") // Adds "value1" to the set +// fmt.Println(set.String()) // Outputs: ["value1"] +func (m *ThreadSafeSet[K]) Set(value K) { + m.safetyCheck() + m.mu.Lock() + m.data[value] = struct{}{} + m.mu.Unlock() +} + +// ImportFromMap imports values from a map into the set. +// +// Example usage: +// +// set := threadsafeset.NewThreadSafeSet[string]() +// src := map[string]struct{}{"value1": {}, "value2": {}} +// set.ImportFromMap(src) +// fmt.Println(set.String()) // Outputs: ["value1", "value2"] +func (m *ThreadSafeSet[K]) ImportFromMap(src map[K]struct{}) { + m.safetyCheck() + m.mu.Lock() + + for k := range src { + m.data[k] = struct{}{} + } + + m.mu.Unlock() +} + +func (m *ThreadSafeSet[K]) CopyRaw() map[K]struct{} { + m.safetyCheck() + m.mu.RLock() + + copySet := make(map[K]struct{}, len(m.data)) + maps.Copy(copySet, m.data) + + m.mu.RUnlock() + + return copySet +} + +// String returns a pretty-printed JSON string representation of the set. +// +// Example usage: +// +// set := threadsafeset.NewThreadSafeSet[string]() +// set.Set("value1") +// fmt.Println(set.String()) // Outputs: ["value1"] +func (m *ThreadSafeSet[K]) String() string { + m.safetyCheck() + m.mu.RLock() + + b, err := json.MarshalIndent(m.Values(), "", " ") + if err != nil { + return "" + } + + m.mu.RUnlock() + + return string(b) +} + +// Values returns a slice of all values stored in the set. +// +// Example usage: +// +// set := threadsafeset.NewThreadSafeSet[string]() +// set.Set("value1") +// values := set.Values() +// fmt.Println(values) // Outputs: ["value1"] +func (m *ThreadSafeSet[K]) Values() []K { + m.safetyCheck() + m.mu.RLock() + + values := make([]K, 0, len(m.data)) + for k := range m.data { + values = append(values, k) + } + + m.mu.RUnlock() + + return values +} + +// Intersect returns a slice of values that are present in both the set and the provided slice. +// +// Example usage: +// +// set := threadsafeset.NewThreadSafeSet[string]() +// set.Set("value1") +// set.Set("value2") +// other := threadsafeset.NewThreadSafeSet[string]() +// other.Set("value2") +// intersection := set.Intersect(other) +// fmt.Println(intersection.String()) // Outputs: ["value2"] +func (m *ThreadSafeSet[K]) Intersect(other *ThreadSafeSet[K]) *ThreadSafeSet[K] { + m.safetyCheck() + other.safetyCheck() + m.mu.RLock() + other.mu.RLock() + + intersection := NewThreadSafeSet[K]() + + for k := range m.data { + if other.Has(k) { + intersection.Set(k) + } + } + + m.mu.RUnlock() + other.mu.RUnlock() + + return intersection +} + +// DeleteMultiple removes multiple values from the set. +// +// Example usage: +// +// set := threadsafeset.NewThreadSafeSet[string]() +// set.Set("value1") +// set.Set("value2") +// set.DeleteMultiple([]string{"value1", "value2"}) +func (m *ThreadSafeSet[K]) DeleteMultiple(values []K) { + m.safetyCheck() + m.mu.Lock() + + for _, v := range values { + delete(m.data, v) + } + + m.mu.Unlock() +} + +// IsEmpty checks if the set is empty. +// +// Example usage: +// +// set := threadsafeset.NewThreadSafeSet[string]() +// fmt.Println(set.IsEmpty()) // Outputs: true +// set.Set("value1") +// fmt.Println(set.IsEmpty()) // Outputs: false +func (m *ThreadSafeSet[K]) IsEmpty() bool { + return m.Length() == 0 +} + +// IsEqual checks if the current set is equal to another set. +// Two sets are considered equal if they contain the same elements. +// +// Example usage: +// +// set1 := threadsafeset.NewThreadSafeSet[string]() +// set2 := threadsafeset.NewThreadSafeSet[string]() +// set1.Set("value1") +// set2.Set("value1") +// fmt.Println(set1.IsEqual(set2)) // Outputs: true +// set2.Set("value2") +// fmt.Println(set1.IsEqual(set2)) // Outputs: false +func (m *ThreadSafeSet[K]) IsEqual(other *ThreadSafeSet[K]) bool { + m.safetyCheck() + other.safetyCheck() + + if m.Length() != other.Length() { + return false + } + + m.mu.RLock() + other.mu.RLock() + + for k := range m.data { + if !other.Has(k) { + return false + } + } + + m.mu.RUnlock() + other.mu.RUnlock() + + return true +} + +// Union returns a new set containing elements that are in either the current set or the other set. +// +// Example usage: +// +// set1 := threadsafeset.NewThreadSafeSet[string]() +// set2 := threadsafeset.NewThreadSafeSet[string]() +// set1.Set("value1") +// set2.Set("value2") +// result := set1.Union(set2) +// fmt.Println(result.String()) // Outputs: ["value1", "value2"] +func (m *ThreadSafeSet[K]) Union(other *ThreadSafeSet[K]) *ThreadSafeSet[K] { + m.safetyCheck() + other.safetyCheck() + + result := NewThreadSafeSet[K]() + + m.mu.RLock() + other.mu.RLock() + + for k := range m.data { + result.Set(k) + } + + for k := range other.data { + result.Set(k) + } + + m.mu.RUnlock() + other.mu.RUnlock() + + return result +} + +// Difference returns a new set containing elements that are in the current set but not in the other set. +// +// Example usage: +// +// set1 := threadsafeset.NewThreadSafeSet[string]() +// set2 := threadsafeset.NewThreadSafeSet[string]() +// set1.Set("value1") +// set2.Set("value2") +// result := set1.Difference(set2) +// fmt.Println(result.String()) // Outputs: ["value1"] +func (m *ThreadSafeSet[K]) Difference(other *ThreadSafeSet[K]) *ThreadSafeSet[K] { + m.safetyCheck() + other.safetyCheck() + + result := NewThreadSafeSet[K]() + + m.mu.RLock() + other.mu.RLock() + + for k := range m.data { + if !other.Has(k) { + result.Set(k) + } + } + + m.mu.RUnlock() + other.mu.RUnlock() + + return result +} + +// SymmetricDifference returns a new set containing elements that are in either set but not in both. +// +// Example usage: +// +// set1 := threadsafeset.NewThreadSafeSet[string]() +// set2 := threadsafeset.NewThreadSafeSet[string]() +// set1.Set("value1") +// set2.Set("value2") +// result := set1.SymmetricDifference(set2) +// fmt.Println(result.String()) // Outputs: ["value1", "value2"] +func (m *ThreadSafeSet[K]) SymmetricDifference(other *ThreadSafeSet[K]) *ThreadSafeSet[K] { + m.safetyCheck() + other.safetyCheck() + + result := NewThreadSafeSet[K]() + + m.mu.RLock() + other.mu.RLock() + + for k := range m.data { + if !other.Has(k) { + result.Set(k) + } + } + + for k := range other.data { + if !m.Has(k) { + result.Set(k) + } + } + + m.mu.RUnlock() + other.mu.RUnlock() + + return result +} diff --git a/yathreadsafeset/yathreadsafeset_test.go b/yathreadsafeset/yathreadsafeset_test.go new file mode 100644 index 0000000..0712c5d --- /dev/null +++ b/yathreadsafeset/yathreadsafeset_test.go @@ -0,0 +1,498 @@ +package yathreadsafeset_test + +import ( + "encoding/json" + "math/rand" + "reflect" + "slices" + "strings" + "sync" + "testing" + + "github.com/YaCodeDev/GoYaCodeDevUtils/yathreadsafeset" +) + +func TestThreadSafeSet_BasicOps(t *testing.T) { + set := yathreadsafeset.NewThreadSafeSet[string]() + + set.Set("a") + set.Set("b") + + if !set.Has("a") || !set.Has("b") { + t.Fatalf("Set or Has failed") + } + + if set.Has("c") { + t.Fatalf("Has returned true for missing element") + } + + if set.Length() != 2 { + t.Fatalf("Length failed, got %d", set.Length()) + } + + set.Delete("a") + + if set.Has("a") { + t.Fatalf("Delete failed") + } + + set.Delete("b") + + if !set.IsEmpty() { + t.Fatalf("IsEmpty failed after delete") + } + + set.Set("z") + + if !set.Pop("z") { + t.Fatalf("Pop failed") + } + + if set.Pop("z") { + t.Fatalf("Pop should fail for non-existent element") + } + + set.Set("x") + set.Clear() + + if !set.IsEmpty() { + t.Fatalf("Clear failed, set should be empty") + } +} + +func TestThreadSafeSet_Iterate(t *testing.T) { + set := yathreadsafeset.NewThreadSafeSet[int]() + + vals := map[int]struct{}{1: {}, 2: {}, 3: {}} + for k := range vals { + set.Set(k) + } + + collected := map[int]struct{}{} + set.Iterate(func(x int) { + collected[x] = struct{}{} + }) + + if !reflect.DeepEqual(collected, vals) { + t.Fatalf("Iterate did not visit all values, got: %+v", collected) + } +} + +func TestThreadSafeSet_IterateOnCopy(t *testing.T) { + set := yathreadsafeset.NewThreadSafeSet[int]() + for i := 1; i <= 5; i++ { + set.Set(i) + } + + var ( + mu sync.Mutex + visited []int + ) + + set.IterateOnCopy(func(x int) { + mu.Lock() + + visited = append(visited, x) + + mu.Unlock() + }) + + want := []int{1, 2, 3, 4, 5} + for _, v := range want { + found := slices.Contains(visited, v) + if !found { + t.Fatalf("IterateOnCopy missed %d", v) + } + } +} + +func TestThreadSafeSet_IterateWithBreak(t *testing.T) { + set := yathreadsafeset.NewThreadSafeSet[int]() + for i := 1; i <= 5; i++ { + set.Set(i) + } + + var cnt int + set.IterateWithBreak(func(_ int) bool { + cnt++ + + return cnt < 3 + }) + + if cnt != 3 { + t.Fatalf("IterateWithBreak did not break after 3") + } +} + +func TestThreadSafeSet_ImportFromMap(t *testing.T) { + set := yathreadsafeset.NewThreadSafeSet[string]() + src := map[string]struct{}{"foo": {}, "bar": {}} + set.ImportFromMap(src) + + if !set.Has("foo") || !set.Has("bar") { + t.Fatalf("ImportFromMap failed") + } +} + +func TestThreadSafeSet_CopyRaw(t *testing.T) { + set := yathreadsafeset.NewThreadSafeSet[string]() + set.Set("a") + set.Set("b") + + m := set.CopyRaw() + if len(m) != 2 || m["a"] != struct{}{} || m["b"] != struct{}{} { + t.Fatalf("CopyRaw failed") + } +} + +func TestThreadSafeSet_StringAndMarshalJSON(t *testing.T) { + set := yathreadsafeset.NewThreadSafeSet[string]() + set.Set("foo") + set.Set("bar") + + s := set.String() + if !strings.Contains(s, "foo") || !strings.Contains(s, "bar") || + strings.Contains(s, "") { + t.Fatalf("String() failed: %q", s) + } + + data, err := json.Marshal(set) + if err != nil { + t.Fatalf("MarshalJSON failed: %v", err) + } + + if !strings.Contains(string(data), "foo") || !strings.Contains(string(data), "bar") { + t.Fatalf("MarshalJSON output wrong: %q", string(data)) + } +} + +func TestThreadSafeSet_Values(t *testing.T) { + set := yathreadsafeset.NewThreadSafeSet[int]() + + vals := []int{7, 8, 9} + for _, v := range vals { + set.Set(v) + } + + got := set.Values() + for _, v := range vals { + found := slices.Contains(got, v) + if !found { + t.Fatalf("Values() missing %d", v) + } + } +} + +func TestThreadSafeSet_DeleteMultiple(t *testing.T) { + set := yathreadsafeset.NewThreadSafeSet[string]() + set.Set("x") + set.Set("y") + set.Set("z") + set.DeleteMultiple([]string{"x", "z"}) + + if set.Has("x") || set.Has("z") || !set.Has("y") { + t.Fatalf("DeleteMultiple failed") + } +} + +func TestThreadSafeSet_IsEqual(t *testing.T) { + a := yathreadsafeset.NewThreadSafeSet[int]() + + b := yathreadsafeset.NewThreadSafeSet[int]() + if !a.IsEqual(b) { + t.Fatalf("Empty sets should be equal") + } + + a.Set(1) + + if a.IsEqual(b) { + t.Fatalf("Should not be equal after add") + } + + b.Set(1) + + if !a.IsEqual(b) { + t.Fatalf("Sets with same content should be equal") + } +} + +func TestThreadSafeSet_Concurrency(t *testing.T) { + set := yathreadsafeset.NewThreadSafeSet[int]() + + var wg sync.WaitGroup + + n := 1000 + + for i := range n { + wg.Add(1) + + go func(x int) { + set.Set(x) + wg.Done() + }(i) + } + + wg.Wait() + + if set.Length() != n { + t.Fatalf("Concurrency Set failed, got %d", set.Length()) + } + + for i := range n { + wg.Add(1) + + go func(x int) { + set.Delete(x) + wg.Done() + }(i) + } + + wg.Wait() + + if !set.IsEmpty() { + t.Fatalf("Concurrency Delete failed") + } +} + +func TestThreadSafeSet_IsEmpty(t *testing.T) { + set := yathreadsafeset.NewThreadSafeSet[string]() + if !set.IsEmpty() { + t.Fatalf("Empty set must be IsEmpty()") + } + + set.Set("abc") + + if set.IsEmpty() { + t.Fatalf("Non-empty set is not empty") + } + + set.Clear() + + if !set.IsEmpty() { + t.Fatalf("IsEmpty after Clear() should be true") + } +} + +func TestThreadSafeSet_TypeParamSupport(t *testing.T) { + type custom struct{ v int } + + set := yathreadsafeset.NewThreadSafeSet[custom]() + val := custom{42} + set.Set(val) + + if !set.Has(val) { + t.Fatalf("Set/Has failed for custom type") + } +} + +func TestThreadSafeSet_MarshalUnmarshal(t *testing.T) { + set := yathreadsafeset.NewThreadSafeSet[string]() + set.Set("one") + set.Set("two") + + b, err := json.Marshal(set) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + + var vals []string + if err := json.Unmarshal(b, &vals); err != nil { + t.Fatalf("Unmarshal failed: %v", err) + } + + if len(vals) != 2 || (vals[0] != "one" && vals[1] != "one") { + t.Fatalf("Marshal/Unmarshal output wrong: %+v", vals) + } +} + +func NewThreadSafeSetFromSlice[T comparable](slice []T) *yathreadsafeset.ThreadSafeSet[T] { + s := yathreadsafeset.NewThreadSafeSet[T]() + for _, v := range slice { + s.Set(v) + } + + return s +} + +func TestThreadSafeSet_Stress(_ *testing.T) { + const ( + goroutines = 64 + opsPerG = 5000 + ) + + set := yathreadsafeset.NewThreadSafeSet[int]() + + var wg sync.WaitGroup + + for range goroutines { + wg.Add(1) + + go func() { + for range opsPerG { + op := rand.Intn(4) + val := rand.Intn(1000) + + switch op { + case 0: + set.Set(val) + case 1: + set.Delete(val) + case 2: + set.Has(val) + case 3: + set.Length() + } + } + + wg.Done() + }() + } + + wg.Wait() +} + +func TestThreadSafeSet_Copy(t *testing.T) { + set := yathreadsafeset.NewThreadSafeSet[string]() + set.Set("a") + set.Set("b") + + copySet := set.Copy() + + if !copySet.Has("a") || !copySet.Has("b") { + t.Fatalf("Copy failed, missing elements") + } + + copySet.Delete("a") + + if !set.Has("a") { + t.Fatalf("Original set should not be affected by copy modification") + } + + if copySet.Has("b") { + t.Logf("Copy still has 'b': %v", copySet) + } else { + t.Fatalf("Copy should still have 'b'") + } +} + +func TestThreadSafeSet_TestSafety(t *testing.T) { + set := yathreadsafeset.ThreadSafeSet[int]{} + + set.Set(1) + + if !set.Has(1) { + t.Fatalf("Set/Has failed for single element") + } +} + +func TestThreadSafeSet_Intersect(t *testing.T) { + setA := yathreadsafeset.NewThreadSafeSet[int]() + setB := yathreadsafeset.NewThreadSafeSet[int]() + + for i := 1; i <= 5; i++ { + setA.Set(i) + } + + for i := 3; i <= 7; i++ { + setB.Set(i) + } + + intersection := setA.Intersect(setB) + + expected := []int{3, 4, 5} + for _, v := range expected { + if !intersection.Has(v) { + t.Fatalf("Intersection missing %d", v) + } + } + + if intersection.Length() != len(expected) { + t.Fatalf( + "Intersection length mismatch, got %d, want %d", + intersection.Length(), + len(expected), + ) + } +} + +func TestThreadSafeSet_Union(t *testing.T) { + setA := yathreadsafeset.NewThreadSafeSet[int]() + setB := yathreadsafeset.NewThreadSafeSet[int]() + + for i := 1; i <= 5; i++ { + setA.Set(i) + } + + for i := 4; i <= 8; i++ { + setB.Set(i) + } + + union := setA.Union(setB) + + expected := []int{1, 2, 3, 4, 5, 6, 7, 8} + for _, v := range expected { + if !union.Has(v) { + t.Fatalf("Union missing %d", v) + } + } + + if union.Length() != len(expected) { + t.Fatalf("Union length mismatch, got %d, want %d", union.Length(), len(expected)) + } +} + +func TestThreadSafeSet_Difference(t *testing.T) { + setA := yathreadsafeset.NewThreadSafeSet[int]() + setB := yathreadsafeset.NewThreadSafeSet[int]() + + for i := 1; i <= 5; i++ { + setA.Set(i) + } + + for i := 4; i <= 8; i++ { + setB.Set(i) + } + + diff := setA.Difference(setB) + + expected := []int{1, 2, 3} + for _, v := range expected { + if !diff.Has(v) { + t.Fatalf("Difference missing %d", v) + } + } + + if diff.Length() != len(expected) { + t.Fatalf("Difference length mismatch, got %d, want %d", diff.Length(), len(expected)) + } +} + +func TestThreadSafeSet_SymmetricDifference(t *testing.T) { + setA := yathreadsafeset.NewThreadSafeSet[int]() + setB := yathreadsafeset.NewThreadSafeSet[int]() + + for i := 1; i <= 5; i++ { + setA.Set(i) + } + + for i := 4; i <= 8; i++ { + setB.Set(i) + } + + diff := setA.SymmetricDifference(setB) + + expected := []int{1, 2, 3, 6, 7, 8} + for _, v := range expected { + if !diff.Has(v) { + t.Fatalf("SymmetricDifference missing %d", v) + } + } + + if diff.Length() != len(expected) { + t.Fatalf( + "SymmetricDifference length mismatch, got %d, want %d", + diff.Length(), + len(expected), + ) + } +}