diff --git a/go.mod b/go.mod index 34269f8..a523031 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,9 @@ require ( github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.10.0 golang.org/x/net v0.42.0 + gorm.io/driver/sqlite v1.6.0 + gorm.io/gorm v1.30.1 + modernc.org/sqlite v1.38.2 ) require ( @@ -21,6 +24,7 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dlclark/regexp2 v1.11.5 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect github.com/fatih/color v1.18.0 // indirect github.com/ghodss/yaml v1.0.0 // indirect github.com/go-faster/errors v0.7.1 // indirect @@ -29,11 +33,16 @@ require ( github.com/go-faster/yaml v0.4.6 // indirect github.com/gotd/ige v0.2.2 // indirect github.com/gotd/neo v0.1.5 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect github.com/klauspost/compress v1.18.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-sqlite3 v1.14.22 // indirect + github.com/ncruces/go-strftime v0.1.9 // indirect github.com/ogen-go/ogen v1.12.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/segmentio/asm v1.2.0 // indirect github.com/yuin/gopher-lua v1.1.1 // indirect go.opentelemetry.io/otel v1.37.0 // indirect @@ -43,7 +52,7 @@ require ( go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.27.0 // indirect golang.org/x/crypto v0.40.0 // indirect - golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 // indirect + golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b // indirect golang.org/x/mod v0.26.0 // indirect golang.org/x/sync v0.16.0 // indirect golang.org/x/sys v0.34.0 // indirect @@ -51,5 +60,8 @@ require ( golang.org/x/tools v0.35.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + modernc.org/libc v1.66.3 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.11.0 // indirect rsc.io/qr v0.2.0 // indirect ) diff --git a/go.sum b/go.sum index ecb752d..853b352 100644 --- a/go.sum +++ b/go.sum @@ -17,6 +17,8 @@ github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/r github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ= github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk= @@ -36,6 +38,8 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gotd/contrib v0.21.0 h1:4Fj05jnyBE84toXZl7mVTvt7f732n5uglvztyG6nTr4= @@ -46,6 +50,10 @@ github.com/gotd/neo v0.1.5 h1:oj0iQfMbGClP8xI59x7fE/uHoTJD7NZH9oV1WNuPukQ= github.com/gotd/neo v0.1.5/go.mod h1:9A2a4bn9zL6FADufBdt7tZt+WMhvZoc5gWXihOPoiBQ= github.com/gotd/td v0.128.0 h1:OI0KyKwARNO4X+czb26+FLKXASFTWuHpgPs7Yaqm04o= github.com/gotd/td v0.128.0/go.mod h1:rSekFfPYj5UEFky5EYnadT0WRU3DGoR4PFEMugk77uI= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= @@ -56,12 +64,18 @@ github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHP github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= +github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/ogen-go/ogen v1.12.0 h1:JMkn957i9/IPaSehqpblviy6Uao3eqQ+eVKUn4LM9pg= github.com/ogen-go/ogen v1.12.0/go.mod h1:RL25amedfhq5xKTUuPBPn6nhYU59CWaVWYJ8YIjNHs0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/redis/go-redis/v9 v9.11.0 h1:E3S08Gl/nJNn5vkxd2i78wZxWAPNZgUNTp8WIJUAiIs= github.com/redis/go-redis/v9 v9.11.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys= @@ -92,8 +106,8 @@ go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY= -golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= -golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= +golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o= +golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg= golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= @@ -117,6 +131,36 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ= +gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8= +gorm.io/gorm v1.30.1 h1:lSHg33jJTBxs2mgJRfRZeLDG+WZaHYCk3Wtfl6Ngzo4= +gorm.io/gorm v1.30.1/go.mod h1:8Z33v652h4//uMA76KjeDH8mJXPm1QNCYrMeatR0DOE= +modernc.org/cc/v4 v4.26.2 h1:991HMkLjJzYBIfha6ECZdjrIYz2/1ayr+FL8GN+CNzM= +modernc.org/cc/v4 v4.26.2/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= +modernc.org/ccgo/v4 v4.28.0 h1:rjznn6WWehKq7dG4JtLRKxb52Ecv8OUGah8+Z/SfpNU= +modernc.org/ccgo/v4 v4.28.0/go.mod h1:JygV3+9AV6SmPhDasu4JgquwU81XAKLd3OKTUDNOiKE= +modernc.org/fileutil v1.3.8 h1:qtzNm7ED75pd1C7WgAGcK4edm4fvhtBsEiI/0NQ54YM= +modernc.org/fileutil v1.3.8/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc= +modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= +modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= +modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= +modernc.org/libc v1.66.3 h1:cfCbjTUcdsKyyZZfEUKfoHcP3S0Wkvz3jgSzByEWVCQ= +modernc.org/libc v1.66.3/go.mod h1:XD9zO8kt59cANKvHPXpx7yS2ELPheAey0vjIuZOhOU8= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= +modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= +modernc.org/sqlite v1.38.2 h1:Aclu7+tgjgcQVShZqim41Bbw9Cho0y/7WzYptXqkEek= +modernc.org/sqlite v1.38.2/go.mod h1:cPTJYSlgg3Sfg046yBShXENNtPrWrDX8bsbAQBzgQ5E= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= nhooyr.io/websocket v1.8.17 h1:KEVeLJkUywCKVsnLIDlD/5gtayKp8VoCkksHCGGfT9Y= nhooyr.io/websocket v1.8.17/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c= rsc.io/qr v0.2.0 h1:6vBLea5/NRMVTz8V66gipeLycZMl/+UlFmk8DvqQ6WY= diff --git a/yatgclient/yatgclient.go b/yatgclient/yatgclient.go index 9ec0c79..9ae36fb 100644 --- a/yatgclient/yatgclient.go +++ b/yatgclient/yatgclient.go @@ -189,10 +189,14 @@ func (c *Client) RunUpdatesManager( // // Example: // -// gaps := yatgclient.NewUpdateManagerWithYaStorage(storage) -func NewUpdateManagerWithYaStorage(storage yatgstorage.IStorage) *updates.Manager { +// gaps := yatgclient.NewUpdateManagerWithYaStorage(entityID, handler, storage) +func NewUpdateManagerWithYaStorage( + entityID int64, + handler telegram.UpdateHandler, + storage yatgstorage.IStorage, +) *updates.Manager { return updates.New(updates.Config{ - Handler: storage.AccessHashSaveHandler(), + Handler: storage.AccessHashSaveHandler(entityID, handler), Storage: storage.TelegramStorageCompatible(), AccessHasher: storage.TelegramAccessHasherCompatible(), }) diff --git a/yatgstorage/session_storage.go b/yatgstorage/session_storage.go new file mode 100644 index 0000000..f408286 --- /dev/null +++ b/yatgstorage/session_storage.go @@ -0,0 +1,427 @@ +package yatgstorage + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "crypto/sha256" + "io" + "net/http" + "time" + + "github.com/YaCodeDev/GoYaCodeDevUtils/yaerrors" + "github.com/gotd/td/telegram" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +// ISessionStorage defines the methods for session management, including encryption, storage, and retrieval. +// It also provides compatibility with the Telegram session storage interface. +type ISessionStorage interface { + // LoadSession loads the session data from the repository and decrypts it. + // + // Example usage: + // + // sessionData, err := storage.LoadSession(ctx) + LoadSession(ctx context.Context) ([]byte, yaerrors.Error) + + // StoreSession stores the session data in the repository after encrypting it. + // + // Example usage: + // + // err := storage.StoreSession(ctx, sessionData) + StoreSession(ctx context.Context, data []byte) yaerrors.Error + + // TelegramSessionStorageCompatible provides compatibility with `gotd` session storage interface. + // + // Example usage: + // + // telegramStorage := storage.TelegramSessionStorageCompatible() + TelegramSessionStorageCompatible() telegram.SessionStorage +} + +// IEntitySessionStorageRepo defines the methods for storing and fetching encrypted authentication keys for a session. +// +// UpdateAuthKey: +// - This method allows updating or inserting an encrypted +// authentication key for a specific entity identified by `entityID`. +// +// FetchAuthKey: +// - This method retrieves the encrypted authentication key associated with the given `entityID`. +type IEntitySessionStorageRepo interface { + // UpdateAuthKey updates the encrypted authentication key for a specific entity. + // + // Example usage: + // + // repo.UpdateAuthKey(ctx, entityID, encryptedAuthKey) + UpdateAuthKey(ctx context.Context, entityID int64, encryptedAuthKey []byte) yaerrors.Error + + // FetchAuthKey retrieves the encrypted authentication key for a specific entity. + // + // Example usage: + // + // repo.FetchAuthKey(ctx, entityID) + FetchAuthKey(ctx context.Context, entityID int64) ([]byte, yaerrors.Error) +} + +// SessionStorage manages session data, including encryption and storage, using the provided repository. +type SessionStorage struct { + entityID int64 + aes AES + repo IEntitySessionStorageRepo +} + +// NewSessionStorage creates a new SessionStorage instance with an in-memory repository for session data storage. +// +// entityID: The ID of the entity (user, bot) whose session is being managed. +// secret: The secret key used for encrypting/decrypting session data. +// +// Returns a pointer to a new SessionStorage instance. +func NewSessionStorage(entityID int64, secret string) *SessionStorage { + return NewSessionStorageWithCustomRepo(entityID, secret, NewMemorySessionStorage(entityID)) +} + +// NewSessionStorageWithCustomRepo creates a SessionStorage instance with a custom repository for session data storage. +// +// entityID: The ID of the entity (user, bot) whose session is being managed. +// secret: The secret key used for encrypting/decrypting session data. +// repo: A custom repository implementing the IEntitySessionStorageRepo interface. +// +// Returns a pointer to a new SessionStorage instance. +func NewSessionStorageWithCustomRepo( + entityID int64, + secret string, + repo IEntitySessionStorageRepo, +) *SessionStorage { + return &SessionStorage{ + entityID: entityID, + aes: NewAES(secret), + repo: repo, + } +} + +// StoreSession encrypts the session data and stores it using the provided repository. +// +// ctx: The context for the operation. +// data: The session data to be encrypted and stored. +// +// Returns an error if encryption or storage fails. +// +// Example usage: +// +// err := sessionStorage.StoreSession(ctx, sessionData) +func (s *SessionStorage) StoreSession(ctx context.Context, data []byte) yaerrors.Error { + out, err := s.aes.Encrypt(data) + if err != nil { + return err.Wrap("failed to encrypt AES") + } + + if err = s.repo.UpdateAuthKey(ctx, s.entityID, out); err != nil { + return err.Wrap("failed to save updated session") + } + + return nil +} + +// LoadSession retrieves and decrypts the session data from the repository. +// +// ctx: The context for the operation. +// +// Returns the decrypted session data or nil if no session data exists, along with an error if decryption fails. +// +// Example usage: +// +// sessionData, err := sessionStorage.LoadSession(ctx) +func (s *SessionStorage) LoadSession(ctx context.Context) ([]byte, yaerrors.Error) { + session, err := s.repo.FetchAuthKey(ctx, s.entityID) + if err != nil { + return nil, err.Wrap("failed to fetch session") + } + + if len(session) == 0 { + return nil, nil + } + + out, err := s.aes.Decrypt(session) + if err != nil { + return nil, err.Wrap("failed to decrypt AES") + } + + return out, nil +} + +// TelegramSessionStorageCompatible provides a compatibility layer to work with Telegram's SessionStorage interface. +// +// Returns a SessionStorage-compatible implementation that works with gotd. +func (s *SessionStorage) TelegramSessionStorageCompatible() telegram.SessionStorage { + return &telegramSessionStorage{ + storage: s, + } +} + +// telegramSessionStorage is an implementation of the Telegram SessionStorage interface, +// which is used to store and load sessions in a way compatible with the gotd library. +type telegramSessionStorage struct { + storage *SessionStorage +} + +// StoreSession stores the session data using the SessionStorage's StoreSession method. +func (t *telegramSessionStorage) StoreSession(ctx context.Context, data []byte) error { + return t.storage.StoreSession(ctx, data) +} + +// LoadSession loads the session data using the SessionStorage's LoadSession method. +func (t *telegramSessionStorage) LoadSession(ctx context.Context) ([]byte, error) { + return t.storage.LoadSession(ctx) +} + +// YaTgClientSession is the database model for storing encrypted session data for a client. It holds the +// entity ID, encrypted authentication key, and the timestamp of when the session was last updated. +type YaTgClientSession struct { + EntityID int64 `gorm:"primaryKey;autoIncrement:false"` + EncryptedAuthKey []byte `gorm:"type:blob"` + UpdatedAt time.Time `gorm:"autoUpdatedAt"` +} + +// FieldEncryptedAuthKey is the field name used for storing the encrypted authentication key in the database. +const FieldEncryptedAuthKey = "encrypted_auth_key" + +// GormRepo is the repository that manages the session storage in a GORM-backed database. +type GormRepo struct { + poolDB *gorm.DB +} + +// NewGormSessionStorage creates a new GormRepo and runs the migrations for the YaTgClientSession model. +// +// poolDB: The GORM database connection. +// +// Returns a new instance of GormRepo and any errors encountered during migration. +func NewGormSessionStorage(poolDB *gorm.DB) (*GormRepo, yaerrors.Error) { + if err := poolDB.AutoMigrate(&YaTgClientSession{}); err != nil { + return nil, yaerrors.FromError( + http.StatusInternalServerError, + err, + "failed to make auto migrate", + ) + } + + return &GormRepo{poolDB: poolDB}, nil +} + +// UpdateAuthKey updates the encrypted authentication key for a specific entity in the database. +// +// ctx: The context for the operation. +// entityID: The ID of the entity whose session is being updated. +// encryptedAuthKey: The new encrypted authentication key. +// +// Returns an error if the update fails. +func (g *GormRepo) UpdateAuthKey( + ctx context.Context, + entityID int64, + encryptedAuthKey []byte, +) yaerrors.Error { + if err := g.poolDB.WithContext(ctx). + Clauses(clause.OnConflict{DoUpdates: clause.AssignmentColumns([]string{FieldEncryptedAuthKey})}). + Model(&YaTgClientSession{}). + Where(&YaTgClientSession{EntityID: entityID}). + Create(&YaTgClientSession{ + EntityID: entityID, + EncryptedAuthKey: encryptedAuthKey, + }).Error; err != nil { + return yaerrors.FromError( + http.StatusInternalServerError, + err, + "failed to update encrypted auth key", + ) + } + + return nil +} + +// FetchAuthKey retrieves the encrypted authentication key for a specific entity from the database. +// +// ctx: The context for the operation. +// entityID: The ID of the entity whose session is being fetched. +// +// Returns the encrypted authentication key or an error if the fetch operation fails. +func (g *GormRepo) FetchAuthKey( + ctx context.Context, + entityID int64, +) ([]byte, yaerrors.Error) { + var botSession YaTgClientSession + + if err := g.poolDB.WithContext(ctx). + Model(&YaTgClientSession{}). + Where(&YaTgClientSession{EntityID: entityID}). + Select(FieldEncryptedAuthKey). + Take(&botSession).Error; err != nil { + return nil, yaerrors.FromError( + http.StatusInternalServerError, + err, + "failed to fetch YaTgClientSession", + ) + } + + return botSession.EncryptedAuthKey, nil +} + +// MemoryRepo is an in-memory implementation of the IEntitySessionStorageRepo interface, +// used for testing or simple scenarios where persistence is not required. +type MemoryRepo struct { + Client YaTgClientSession +} + +// NewMemorySessionStorage initializes a new MemoryRepo instance for the given entityID. +// +// entityID: The ID of the entity whose session is being managed. +// +// Returns a new MemoryRepo instance. +func NewMemorySessionStorage(entityID int64) *MemoryRepo { + return &MemoryRepo{ + Client: YaTgClientSession{ + EntityID: entityID, + UpdatedAt: time.Now(), + }, + } +} + +// UpdateAuthKey updates the session's encrypted authentication key in memory. +// +// _ context.Context: The context for the operation (not used in this in-memory implementation). +// _ int64: The entityID (not used in this in-memory implementation). +// encryptedAuthKey: The encrypted authentication key to be stored. +// +// Returns nil after storing the key in memory. +func (m *MemoryRepo) UpdateAuthKey( + _ context.Context, + _ int64, + encryptedAuthKey []byte, +) yaerrors.Error { + m.Client.EncryptedAuthKey = encryptedAuthKey + m.Client.UpdatedAt = time.Now() + + return nil +} + +// FetchAuthKey fetches the encrypted authentication key from memory. +// +// _ context.Context: The context for the operation (not used in this in-memory implementation). +// _ int64: The entityID (not used in this in-memory implementation). +// +// Returns the encrypted authentication key stored in memory. +func (m *MemoryRepo) FetchAuthKey( + _ context.Context, + _ int64, +) ([]byte, yaerrors.Error) { + return m.Client.EncryptedAuthKey, nil +} + +// AES is a struct that holds the encryption key used for AES encryption and decryption. +// It provides methods to encrypt and decrypt data using AES (CTR mode). +type AES struct { + key []byte +} + +// NewAES creates a new AES instance with the given key. The key is used for encryption and decryption. +// +// key: The AES encryption key as a string. +// +// Returns an AES instance that can be used for encrypting and decrypting data. +func NewAES(key string) AES { + return AES{ + key: DeriveAESKey(key), + } +} + +// Encrypt encrypts data using AES encryption with the provided key (CTR mode). +// +// text: The data to be encrypted. +// +// Returns the encrypted data (ciphertext) and any errors encountered during the process. +// +// Example usage: +// +// encryptedData, err := aes.Encrypt(sessionData) +func (a *AES) Encrypt(text []byte) ([]byte, yaerrors.Error) { + block, err := aes.NewCipher(a.key) + if err != nil { + return nil, yaerrors.FromError( + http.StatusInternalServerError, + err, + "could not create new cipher", + ) + } + + cipherText := make([]byte, aes.BlockSize+len(text)) + + iv := cipherText[:aes.BlockSize] + if _, err = io.ReadFull(rand.Reader, iv); err != nil { + return nil, yaerrors.FromError( + http.StatusInternalServerError, + err, + "could not encrypt", + ) + } + + stream := cipher.NewCTR(block, iv) + stream.XORKeyStream(cipherText[aes.BlockSize:], text) + + return cipherText, nil +} + +// Decrypt decrypts data that was encrypted using AES encryption with the provided key (CTR mode). +// +// text: The encrypted data (ciphertext) to be decrypted. +// +// Returns the decrypted data and any errors encountered during the decryption process. +// +// Example usage: +// +// decryptedData, err := aes.Decrypt(encryptedData) +func (a *AES) Decrypt(text []byte) ([]byte, yaerrors.Error) { + if len(text) < aes.BlockSize { + return nil, yaerrors.FromString( + http.StatusInternalServerError, + "invalid text block size", + ) + } + + block, err := aes.NewCipher(a.key) + if err != nil { + return nil, yaerrors.FromError( + http.StatusInternalServerError, + err, + "could not create new cipher", + ) + } + + iv := text[:aes.BlockSize] + text = text[aes.BlockSize:] + + stream := cipher.NewCTR(block, iv) + stream.XORKeyStream(text, text) + + return text, nil +} + +// DeriveAESKey generates a 256-bit AES key from the provided input string using SHA-256 hashing. +// +// This function takes a string, hashes it using SHA-256, and returns the resulting 256-bit key +// that can be used for AES encryption (AES-256). The result is a 32-byte array, which is suitable +// for AES-256 encryption (256-bit key length). +// +// Parameters: +// - data (string): The input string used to derive the AES key. +// +// Returns: +// - []byte: A 256-bit AES key derived from the input string. +// +// Example usage: +// +// key := DeriveAESKey("my_secret_key") +func DeriveAESKey(data string) []byte { + sum := sha256.Sum256([]byte(data)) + + return sum[:] +} diff --git a/yatgstorage/session_storage_test.go b/yatgstorage/session_storage_test.go new file mode 100644 index 0000000..c1d74bb --- /dev/null +++ b/yatgstorage/session_storage_test.go @@ -0,0 +1,119 @@ +package yatgstorage_test + +import ( + "context" + "database/sql" + "testing" + + "github.com/YaCodeDev/GoYaCodeDevUtils/yatgstorage" + "github.com/stretchr/testify/assert" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + _ "modernc.org/sqlite" +) + +const ( + entityID = 1000 + secret = "123456789:ABCDFEG" + encryptedAuthKey = "stolyarovtop" +) + +func newMockDB(t *testing.T) *gorm.DB { + sqlDB, err := sql.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("failed to open sqlite in memory") + } + + poolDB, err := gorm.Open( + gorm.Dialector( + sqlite.Dialector{ + Conn: sqlDB, + DriverName: "sqlite", + }, + ), &gorm.Config{}) + if err != nil { + t.Fatalf("failed to connect to in-memory database: %v", err) + } + + return poolDB +} + +func TestSessionStorage_WorkflowWorks(t *testing.T) { + ctx := context.Background() + + storage := yatgstorage.NewSessionStorage(entityID, secret) + + _ = storage.StoreSession(ctx, []byte(encryptedAuthKey)) + + aes := yatgstorage.NewAES(secret) + + expected, _ := aes.Encrypt([]byte(encryptedAuthKey)) + + assert.NotEqual(t, []byte(encryptedAuthKey), expected) + + expected, _ = aes.Decrypt(expected) + + result, _ := storage.LoadSession(ctx) + + assert.Equal(t, expected, result) +} + +func TestAutoMigrate_Works(t *testing.T) { + poolDB := newMockDB(t) + + _, _ = yatgstorage.NewGormSessionStorage(poolDB) + + expected := true + + assert.Equal(t, expected, poolDB.Migrator().HasTable(&yatgstorage.YaTgClientSession{})) +} + +func TestGormSessionStorage_WorkflowWorks(t *testing.T) { + ctx := context.Background() + + poolDB := newMockDB(t) + + storage, _ := yatgstorage.NewGormSessionStorage(poolDB) + + _ = storage.UpdateAuthKey(ctx, entityID, []byte(encryptedAuthKey)) + + t.Run("Create works", func(t *testing.T) { + err := poolDB.Where(&yatgstorage.YaTgClientSession{ + EntityID: entityID, + EncryptedAuthKey: []byte(encryptedAuthKey), + }).Find(&yatgstorage.YaTgClientSession{}).Error + + assert.Equal(t, nil, err) + }) + + t.Run("Fetch works", func(t *testing.T) { + expected := yatgstorage.YaTgClientSession{} + + _ = poolDB.Where(&yatgstorage.YaTgClientSession{ + EntityID: entityID, + EncryptedAuthKey: []byte(encryptedAuthKey), + }).Find(&expected) + + result, _ := storage.FetchAuthKey(ctx, entityID) + + assert.Equal(t, expected.EncryptedAuthKey, result) + }) +} + +func TestMemorySessionStorage_WorkflowWorks(t *testing.T) { + ctx := context.Background() + + storage := yatgstorage.NewMemorySessionStorage(entityID) + + _ = storage.UpdateAuthKey(ctx, entityID, []byte(encryptedAuthKey)) + + t.Run("Create works", func(t *testing.T) { + assert.Equal(t, []byte(encryptedAuthKey), storage.Client.EncryptedAuthKey) + }) + + t.Run("Fetch works", func(t *testing.T) { + result, _ := storage.FetchAuthKey(ctx, entityID) + + assert.Equal(t, []byte(encryptedAuthKey), result) + }) +} diff --git a/yatgstorage/yatgstorage.go b/yatgstorage/yatgstorage.go index b12bd8c..a9f4984 100644 --- a/yatgstorage/yatgstorage.go +++ b/yatgstorage/yatgstorage.go @@ -28,6 +28,7 @@ import ( "github.com/YaCodeDev/GoYaCodeDevUtils/yacache" "github.com/YaCodeDev/GoYaCodeDevUtils/yaerrors" "github.com/YaCodeDev/GoYaCodeDevUtils/yalogger" + "github.com/YaCodeDev/GoYaCodeDevUtils/yathreadsafeset" "github.com/gotd/td/telegram" "github.com/gotd/td/telegram/updates" "github.com/gotd/td/tg" @@ -102,11 +103,16 @@ type IStorage interface { // Update‑pipeline helper: returns a handler that stores access‑hashes // from any incoming updates before forwarding to the real handler. - AccessHashSaveHandler() HandlerFunc + AccessHashSaveHandler(int64, telegram.UpdateHandler) HandlerFunc // User access‑hash bookkeeping. - SetUserAccessHash(ctx context.Context, userID int64, accessHash int64) yaerrors.Error - GetUserAccessHash(ctx context.Context, userID int64) (int64, yaerrors.Error) + SetUserAccessHash( + ctx context.Context, + entityID int64, + userID int64, + accessHash int64, + ) yaerrors.Error + GetUserAccessHash(ctx context.Context, entityID int64, userID int64) (int64, yaerrors.Error) // gotd adapters TelegramStorageCompatible() updates.StateStorage @@ -121,20 +127,14 @@ type IStorage interface { // // Example: // -// stg := yatgstorage.NewStorage(cache, dispatcher, 42, log) +// stg := yatgstorage.NewStorage(cache, log) // _ = stg // -// A single Storage instance should be used per bot (entityID). -// The struct keeps an internal map to cache “I have already created the base -// JSON object” flags for performance. -// // Because methods are safe for concurrent use (they only rely on redis, which // is thread‑safe), you may share *Storage between goroutines. type Storage struct { cache yacache.Cache[*redis.Client] - handler telegram.UpdateHandler - entityID int64 - stateKeys map[string]struct{} + stateKeys *yathreadsafeset.ThreadSafeSet[string] log yalogger.Logger } @@ -142,27 +142,21 @@ type Storage struct { // // - cache – any yacache implementation; production code passes a Redis // client, tests may pass yacache.NewMock. -// - handler – your app’s dispatcher (implements telegram.UpdateHandler). -// - entityID – unique bot identifier used to namespace all Redis keys. // - log – structured logger. // // Example: // -// stg := yatgstorage.NewStorage(cache, dispatcher, 123456, log) +// stg := yatgstorage.NewStorage(cache, log) // if err := stg.Ping(ctx); err != nil { // log.Fatalf("redis down: %v", err) // } func NewStorage( cache yacache.Cache[*redis.Client], - handler telegram.UpdateHandler, - entityID int64, log yalogger.Logger, ) *Storage { return &Storage{ cache: cache, - handler: handler, - entityID: entityID, - stateKeys: map[string]struct{}{}, + stateKeys: yathreadsafeset.NewThreadSafeSet[string](), log: log, } } @@ -214,7 +208,7 @@ func (s *Storage) GetState( ) (updates.State, bool, yaerrors.Error) { key := getBotStateKey(entityID) - log := s.initBaseFieldsLog("Fetching entity state", key) + log := s.initBaseFieldsLog("Fetching entity state", entityID, key) data, err := s.cache.Raw().JSONGet(ctx, key).Result() if err != nil { @@ -251,7 +245,8 @@ func (s *Storage) SetState( ) yaerrors.Error { key := getBotStateKey(entityID) - log := s.initBaseFieldsLog("Setting entity state", key).WithField(LoggerEntityID, entityID) + log := s.initBaseFieldsLog("Setting entity state", entityID, key). + WithField(LoggerEntityID, entityID) if err := s.cache.Raw().JSONSet(ctx, key, BasePathRedisJSON, state).Err(); err != nil { return yaerrors.FromErrorWithLog( @@ -276,7 +271,7 @@ func (s *Storage) SetPts(ctx context.Context, entityID int64, pts int) yaerrors. key := getBotStateKey(entityID) log := s. - initBaseFieldsLog("Setting pts in entity state", key). + initBaseFieldsLog("Setting pts in entity state", entityID, key). WithField(LoggerEntityID, entityID) if err := s.safetyBaseStateJSON(ctx, key, log); err != nil { @@ -292,7 +287,7 @@ func (s *Storage) SetPts(ctx context.Context, entityID int64, pts int) yaerrors. ) } - log.Debug("Have set pts in entity state") + log.Debug("Entity state set pts") return nil } @@ -306,7 +301,7 @@ func (s *Storage) SetQts(ctx context.Context, entityID int64, qts int) yaerrors. key := getBotStateKey(entityID) log := s. - initBaseFieldsLog("Setting qts in entity state", key). + initBaseFieldsLog("Setting qts in entity state", entityID, key). WithField(LoggerEntityID, entityID) if err := s.safetyBaseStateJSON(ctx, key, log); err != nil { @@ -322,7 +317,7 @@ func (s *Storage) SetQts(ctx context.Context, entityID int64, qts int) yaerrors. ) } - log.Debug("Have set qts in entity state") + log.Debug("Entity state set qts") return nil } @@ -336,7 +331,7 @@ func (s *Storage) SetDate(ctx context.Context, entityID int64, date int) yaerror key := getBotStateKey(entityID) log := s. - initBaseFieldsLog("Setting date in state", key). + initBaseFieldsLog("Setting date in state", entityID, key). WithField(LoggerEntityID, entityID) if err := s.safetyBaseStateJSON(ctx, key, log); err != nil { @@ -352,7 +347,7 @@ func (s *Storage) SetDate(ctx context.Context, entityID int64, date int) yaerror ) } - log.Debug("Have set date in entity state") + log.Debug("Entity state set date") return nil } @@ -366,7 +361,7 @@ func (s *Storage) SetSeq(ctx context.Context, entityID int64, seq int) yaerrors. key := getBotStateKey(entityID) log := s. - initBaseFieldsLog("Setting seq in state", key). + initBaseFieldsLog("Setting seq in state", entityID, key). WithField(LoggerEntityID, entityID) if err := s.safetyBaseStateJSON(ctx, key, log); err != nil { @@ -382,7 +377,7 @@ func (s *Storage) SetSeq(ctx context.Context, entityID int64, seq int) yaerrors. ) } - log.Debug("Have set seq in entity state") + log.Debug("Entity state set seq") return nil } @@ -396,7 +391,7 @@ func (s *Storage) SetDateSeq(ctx context.Context, entityID int64, date, seq int) key := getBotStateKey(entityID) log := s. - initBaseFieldsLog("Setting date and seq in state", key). + initBaseFieldsLog("Setting date and seq in state", entityID, key). WithField(LoggerEntityID, entityID) if err := s.safetyBaseStateJSON(ctx, key, log); err != nil { @@ -413,7 +408,7 @@ func (s *Storage) SetDateSeq(ctx context.Context, entityID int64, date, seq int) ) } - log.Debug("Have set date and seq in state") + log.Debug("Entity state set date and seq") return nil } @@ -431,7 +426,7 @@ func (s *Storage) SetChannelPts( key := getChannelPtsKey(entityID) log := s. - initBaseFieldsLog("Setting channel pts", key). + initBaseFieldsLog("Setting channel pts", entityID, key). WithField(LoggerEntityID, entityID). WithField(LoggerChannelID, channelID) @@ -445,7 +440,7 @@ func (s *Storage) SetChannelPts( ) } - log.Debug("Have set channel pts") + log.Debug("Channel pts set") return nil } @@ -462,7 +457,7 @@ func (s *Storage) GetChannelPts( key := getChannelPtsKey(entityID) log := s. - initBaseFieldsLog("Fetching channel pts", key). + initBaseFieldsLog("Fetching channel pts", entityID, key). WithField(LoggerUserID, entityID). WithField(LoggerChannelID, channelID) @@ -509,7 +504,7 @@ func (s *Storage) ForEachChannels( ) yaerrors.Error { key := getChannelPtsKey(entityID) - log := s.initBaseFieldsLog("Start action for each channels", key). + log := s.initBaseFieldsLog("Start action for each channels", entityID, key). WithField(LoggerUserID, entityID) channels, err := s.cache.HGetAll(ctx, key) @@ -574,7 +569,7 @@ func (s *Storage) SetChannelAccessHash( key := getChannelAccessHashKey(entityID) log := s. - initBaseFieldsLog("Setting channel access hash for channel", key). + initBaseFieldsLog("Setting channel access hash for channel", entityID, key). WithField(LoggerEntityID, entityID). WithField(LoggerChannelID, channelID) @@ -588,7 +583,7 @@ func (s *Storage) SetChannelAccessHash( ) } - log.Debug("Have set channel access hash") + log.Debug("Channel access hash set") return nil } @@ -605,7 +600,7 @@ func (s *Storage) GetChannelAccessHash( key := getChannelAccessHashKey(entityID) log := s. - initBaseFieldsLog("Fetching channel access hash", key). + initBaseFieldsLog("Fetching channel access hash", entityID, key). WithField(LoggerEntityID, entityID). WithField(LoggerChannelID, channelID) @@ -667,24 +662,27 @@ func (h HandlerFunc) Handle(ctx context.Context, updates tg.UpdatesClass) error // Example: // // clientOpts.UpdateHandler = storage.AccessHashSaveHandler() -func (s *Storage) AccessHashSaveHandler() HandlerFunc { +func (s *Storage) AccessHashSaveHandler( + entityID int64, + handler telegram.UpdateHandler, +) HandlerFunc { return HandlerFunc(func(ctx context.Context, updates tg.UpdatesClass) error { switch update := updates.(type) { case *tg.Updates: for _, user := range update.MapUsers().NotEmptyToMap() { - if err := s.SetUserAccessHash(ctx, user.ID, user.AccessHash); err != nil { + if err := s.SetUserAccessHash(ctx, entityID, user.ID, user.AccessHash); err != nil { s.log.Errorf("Failed to save user(%d) access hash(%d)", user.ID, user.AccessHash) } } case *tg.UpdatesCombined: for _, user := range update.MapUsers().NotEmptyToMap() { - if err := s.SetUserAccessHash(ctx, user.ID, user.AccessHash); err != nil { + if err := s.SetUserAccessHash(ctx, entityID, user.ID, user.AccessHash); err != nil { s.log.Errorf("Failed to save user(%d) access hash(%d)", user.ID, user.AccessHash) } } } - return s.handler.Handle(ctx, updates) + return handler.Handle(ctx, updates) }) } @@ -696,15 +694,17 @@ func (s *Storage) AccessHashSaveHandler() HandlerFunc { // _ = stg.SetUserAccessHash(ctx, 12345, 67890) func (s *Storage) SetUserAccessHash( ctx context.Context, + entityID int64, userID int64, accessHash int64, ) yaerrors.Error { const botChannelID = 136817688 // Ignore channel placeholder (@Channel_Bot - in Telegram) if userID != botChannelID { - key := getUserAccessHashKey(s.entityID) + key := getUserAccessHashKey(entityID) - log := s.initBaseFieldsLog("Saving access hash", key).WithField(LoggerUserID, userID) + log := s.initBaseFieldsLog("Saving access hash", entityID, key). + WithField(LoggerUserID, userID) if err := s.cache.Raw(). HSet(ctx, key, strconv.FormatInt(userID, 10), accessHash).Err(); err != nil { @@ -727,10 +727,15 @@ func (s *Storage) SetUserAccessHash( // Example: // // hash, foundErr := stg.GetUserAccessHash(ctx, 12345) -func (s *Storage) GetUserAccessHash(ctx context.Context, userID int64) (int64, yaerrors.Error) { - key := getUserAccessHashKey(s.entityID) +func (s *Storage) GetUserAccessHash( + ctx context.Context, + entityID int64, + userID int64, +) (int64, yaerrors.Error) { + key := getUserAccessHashKey(entityID) - log := s.initBaseFieldsLog("fetching user access hash", key).WithField(LoggerUserID, userID) + log := s.initBaseFieldsLog("fetching user access hash", entityID, key). + WithField(LoggerUserID, userID) hash, err := s.cache.Raw().HGet(ctx, key, strconv.FormatInt(userID, 10)).Result() if err != nil { @@ -765,9 +770,10 @@ func (s *Storage) GetUserAccessHash(ctx context.Context, userID int64) (int64, y // l := stg.initBaseFieldsLog("doing work", "redis:key") func (s *Storage) initBaseFieldsLog( entryText string, + entityID int64, botKey string, ) yalogger.Logger { - log := s.log.WithField(LoggerEntityID, s.entityID).WithField(LoggerEntityKey, botKey) + log := s.log.WithField(LoggerEntityID, entityID).WithField(LoggerEntityKey, botKey) log.Debugf("%s", entryText) @@ -785,7 +791,7 @@ func (s *Storage) safetyBaseStateJSON( key string, log yalogger.Logger, ) yaerrors.Error { - if _, ok := s.stateKeys[key]; !ok { + if !s.stateKeys.Has(key) { if res, err := s.cache.Raw().JSONGet(ctx, key, BasePathRedisJSON).Result(); err != nil || len(res) == 0 { if err := s.cache.Raw().JSONSet(ctx, key, BasePathRedisJSON, updates.State{}).Err(); err != nil { @@ -798,7 +804,7 @@ func (s *Storage) safetyBaseStateJSON( } } - s.stateKeys[key] = struct{}{} + s.stateKeys.Set(key) } return nil diff --git a/yatgstorage/yatgstorage_test.go b/yatgstorage/yatgstorage_test.go index 1601fb7..3b31a0e 100644 --- a/yatgstorage/yatgstorage_test.go +++ b/yatgstorage/yatgstorage_test.go @@ -36,7 +36,7 @@ func TestStorage_CreateWorks(t *testing.T) { defer cleanup() if err := yatgstorage. - NewStorage(yacache.NewCache(client), nil, 0, yalogger.NewBaseLogger(nil).NewLogger()). + NewStorage(yacache.NewCache(client), yalogger.NewBaseLogger(nil).NewLogger()). Ping(context.Background()); err != nil { t.Fatalf("Failed to create tg storage") } @@ -56,7 +56,7 @@ func TestStorageChannel_WorkFlowWorks(t *testing.T) { defer cleanup() storage := yatgstorage. - NewStorage(yacache.NewCache(client), nil, 1001, log) + NewStorage(yacache.NewCache(client), log) t.Run("Set and Get channel pts - works", func(t *testing.T) { const expected = 1000 @@ -108,16 +108,19 @@ func TestStorageUser_WorkFlowWorks(t *testing.T) { defer cleanup() storage := yatgstorage. - NewStorage(yacache.NewCache(client), nil, 1001, log) + NewStorage(yacache.NewCache(client), log) t.Run("Set and Get user access hash - works", func(t *testing.T) { - const userID = 2222 + const ( + entityID = 1000 + userID = 2222 + ) expected := int64(200) - _ = storage.SetUserAccessHash(ctx, userID, expected) + _ = storage.SetUserAccessHash(ctx, entityID, userID, expected) - result, _ := storage.GetUserAccessHash(ctx, userID) + result, _ := storage.GetUserAccessHash(ctx, entityID, userID) assert.Equal(t, expected, result) })