From ea4c0e7966f6b495c5c2f4b60bd0e1f267060135 Mon Sep 17 00:00:00 2001 From: Damien Whitten Date: Tue, 10 Jun 2025 13:54:15 -0700 Subject: [PATCH] Idempotency --- go.mod | 2 +- go.sum | 4 +- internal/integration/shadow_test.go | 5 +- internal/pgstore/pgmigrate/builder.go | 12 +- internal/pgstore/pgmigrate/psm.go | 6 +- psm/builder.go | 9 - psm/event.go | 155 +++++++++++- psm/gen_interfaces.go | 6 +- psm/run.go | 80 ++---- psm/statemachine.go | 337 +++----------------------- psm/storage.go | 227 +++++++++++++++++ psm/table_map.go | 11 +- 12 files changed, 460 insertions(+), 394 deletions(-) create mode 100644 psm/storage.go diff --git a/go.mod b/go.mod index 80ab3c6..7ef3c84 100644 --- a/go.mod +++ b/go.mod @@ -12,7 +12,7 @@ require ( github.com/lib/pq v1.10.9 github.com/pentops/flowtest v0.0.0-20250521181823-71b0be743b08 github.com/pentops/golib v0.0.0-20250326060930-8c83d58ddb63 - github.com/pentops/j5 v0.0.0-20250605002250-2add77d73a52 + github.com/pentops/j5 v0.0.0-20250610001046-f1c2162a4508 github.com/pentops/log.go v0.0.16 github.com/pentops/o5-messaging v0.0.0-20250520213617-fba07334e9aa github.com/pentops/pgtest.go v0.0.0-20241223222214-7638cc50e15b diff --git a/go.sum b/go.sum index 17cc02c..abd35ed 100644 --- a/go.sum +++ b/go.sum @@ -77,8 +77,8 @@ github.com/pentops/flowtest v0.0.0-20250521181823-71b0be743b08 h1:Xeip/GxtvcAGFF github.com/pentops/flowtest v0.0.0-20250521181823-71b0be743b08/go.mod h1:vNp8crAKcH0f/sZU9frkmQLUeDsTIgMqV14kQtkAqC0= github.com/pentops/golib v0.0.0-20250326060930-8c83d58ddb63 h1:s5qtWT2/s79gy/wm3/bwvKYLK6u2AkW05JiLPqxraP0= github.com/pentops/golib v0.0.0-20250326060930-8c83d58ddb63/go.mod h1:I58JIVvL1/nP4CEHGKGbBhvWIEA9mVkGeoviemaqanU= -github.com/pentops/j5 v0.0.0-20250605002250-2add77d73a52 h1:natrugFwp8KWyN1W1SHphF+IokRkU+Yf+k/akHk2V1A= -github.com/pentops/j5 v0.0.0-20250605002250-2add77d73a52/go.mod h1:DZbBKepsGataOEtfB8AjkRiejRtLGQcBejTUYJK5wlY= +github.com/pentops/j5 v0.0.0-20250610001046-f1c2162a4508 h1:S80wU/ls85v5UpdFOsUbYwMISyN0DmFvowIpDWUNF7s= +github.com/pentops/j5 v0.0.0-20250610001046-f1c2162a4508/go.mod h1:DZbBKepsGataOEtfB8AjkRiejRtLGQcBejTUYJK5wlY= github.com/pentops/log.go v0.0.16 h1:oxCuHSBOBPjfUVSXyOSEEdYUwytysj4T29/7T2FBp9Q= github.com/pentops/log.go v0.0.16/go.mod h1:yR34x8aMlvhdGvqgIU4+0MiLjJTKt0vpcgUnVN2nZV4= github.com/pentops/o5-messaging v0.0.0-20250520213617-fba07334e9aa h1:Sdnc9mrRSefBbbrwmpq/31ABuXBwtch2KGd68ORJS44= diff --git a/internal/integration/shadow_test.go b/internal/integration/shadow_test.go index fa077d3..358fbc4 100644 --- a/internal/integration/shadow_test.go +++ b/internal/integration/shadow_test.go @@ -7,6 +7,7 @@ import ( "github.com/google/uuid" "github.com/pentops/flowtest" "github.com/pentops/j5/gen/j5/state/v1/psm_j5pb" + "github.com/pentops/j5/lib/id62" "github.com/pentops/protostate/internal/testproto/gen/test/v1/test_pb" "github.com/pentops/protostate/internal/testproto/gen/test/v1/test_spb" "google.golang.org/protobuf/types/known/timestamppb" @@ -22,7 +23,7 @@ func TestStateMachineShadow(t *testing.T) { events := []*test_pb.FooEvent{{ Metadata: &psm_j5pb.EventMetadata{ - EventId: uuid.NewString(), + EventId: id62.NewString(), Sequence: 1, Cause: &psm_j5pb.Cause{ Type: &psm_j5pb.Cause_ExternalEvent{ @@ -47,7 +48,7 @@ func TestStateMachineShadow(t *testing.T) { }, }, { Metadata: &psm_j5pb.EventMetadata{ - EventId: uuid.NewString(), + EventId: id62.NewString(), Sequence: 1, Cause: &psm_j5pb.Cause{ Type: &psm_j5pb.Cause_ExternalEvent{ diff --git a/internal/pgstore/pgmigrate/builder.go b/internal/pgstore/pgmigrate/builder.go index c2c6d73..55e3e55 100644 --- a/internal/pgstore/pgmigrate/builder.go +++ b/internal/pgstore/pgmigrate/builder.go @@ -21,7 +21,6 @@ func (t *CreateTableBuilder) Column(name string, typ ColumnType, options ...Colu column := &column{ name: name, typeName: typ, - flags: []string{}, } for _, opt := range options { opt(column) @@ -59,9 +58,9 @@ type column struct { primaryKey bool // Multi Primary Key is possible notNull bool + unique bool typeName ColumnType - flags []string } type ColumnOption func(*column) @@ -76,6 +75,10 @@ func NotNull(c *column) { c.notNull = true } +func Unique(c *column) { + c.unique = true +} + type ColumnType string const ( @@ -102,6 +105,9 @@ func (t *CreateTableBuilder) Build() (*Table, error) { if col.notNull { column.Flags = append(column.Flags, "NOT NULL") } + if col.unique { + column.Flags = append(column.Flags, "UNIQUE") + } table.Columns = append(table.Columns, column) } @@ -138,7 +144,7 @@ type ForeignKey struct { } func (tt *Table) DownSQL() (string, error) { - return fmt.Sprintf("DROP TABLE %s;", tt.Name), nil + return fmt.Sprintf("DROP TABLE IF EXISTS %s;", tt.Name), nil } func (table *Table) ToSQL() (string, error) { diff --git a/internal/pgstore/pgmigrate/psm.go b/internal/pgstore/pgmigrate/psm.go index b5aa5a3..eb9f06f 100644 --- a/internal/pgstore/pgmigrate/psm.go +++ b/internal/pgstore/pgmigrate/psm.go @@ -224,7 +224,7 @@ func BuildPSMTables(spec psm.QueryTableSpec) (*Table, *Table, error) { stateTable := CreateTable(spec.State.TableName) eventTable := CreateTable(spec.Event.TableName). - Column(spec.Event.ID.ColumnName, uuidType, PrimaryKey) + Column(spec.Event.ID.ColumnName, id62Type, PrimaryKey) eventForeignKey := eventTable.ForeignKey("state", spec.State.TableName) for _, key := range spec.KeyColumns { @@ -252,7 +252,9 @@ func BuildPSMTables(spec psm.QueryTableSpec) (*Table, *Table, error) { stateTable.Column(spec.State.Root.ColumnName, jsonbType, NotNull) - eventTable.Column(spec.Event.Timestamp.ColumnName, timestamptzType, NotNull). + eventTable. + Column(spec.Event.Timestamp.ColumnName, timestamptzType, NotNull). + Column(spec.Event.IdempotencyHash.ColumnName, textType, NotNull, Unique). Column(spec.Event.Sequence.ColumnName, intType, NotNull). Column(spec.Event.Root.ColumnName, jsonbType, NotNull). Column(spec.Event.StateSnapshot.ColumnName, jsonbType, NotNull) diff --git a/psm/builder.go b/psm/builder.go index 9bf3ff9..37563dd 100644 --- a/psm/builder.go +++ b/psm/builder.go @@ -27,15 +27,6 @@ type StateMachineConfig[ tableName *string } -// DEPRECATED: This does nothing. -type SystemActor any - -// DEPRECATED: This does nothing. -func (smc *StateMachineConfig[K, S, ST, SD, E, IE]) SystemActor(systemActor SystemActor) *StateMachineConfig[K, S, ST, SD, E, IE] { - //smc.systemActor = systemActor - return smc -} - func (smc *StateMachineConfig[K, S, ST, SD, E, IE]) TableMap(tableMap *TableMap) *StateMachineConfig[K, S, ST, SD, E, IE] { smc.tableMap = tableMap return smc diff --git a/psm/event.go b/psm/event.go index 61a559a..d8e97ab 100644 --- a/psm/event.go +++ b/psm/event.go @@ -1,12 +1,17 @@ package psm import ( + "crypto/sha1" "fmt" + "math/big" + "strings" "time" "github.com/pentops/j5/gen/j5/auth/v1/auth_j5pb" "github.com/pentops/j5/gen/j5/messaging/v1/messaging_j5pb" "github.com/pentops/j5/gen/j5/state/v1/psm_j5pb" + "github.com/pentops/j5/lib/id62" + "google.golang.org/protobuf/types/known/timestamppb" ) type EventSpec[ @@ -20,9 +25,6 @@ type EventSpec[ // Keys must be set, to identify the state machine. Keys K - // EventID is optional and will be set by the state machine if empty - EventID string - // The inner PSM Event type. Must be set for incoming events. Event IE @@ -81,14 +83,151 @@ func (es *EventSpec[K, S, ST, SD, E, IE]) validateAndPrepare() error { // check that the cause type is supported. switch es.Cause.Type.(type) { - case *psm_j5pb.Cause_PsmEvent, - *psm_j5pb.Cause_Command, - *psm_j5pb.Cause_ExternalEvent, - *psm_j5pb.Cause_Message: - // All OK + case *psm_j5pb.Cause_PsmEvent: + case *psm_j5pb.Cause_Command: + case *psm_j5pb.Cause_ExternalEvent: + case *psm_j5pb.Cause_Message: + default: return fmt.Errorf("EventSpec.Cause.Source must be set") } return nil } + +func base62String(id []byte) string { + var i big.Int + i.SetBytes(id) + str := i.Text(62) + return fmt.Sprintf("%022s", str) +} + +// hashString generates a sha1 hash from the input strings. +func hashString(input ...string) (string, error) { + if len(input) == 0 { + return "", fmt.Errorf("hashString requires at least one input string") + } + fullInput := strings.Join(input, "") + sum := sha1.Sum([]byte(fullInput)) + return base62String(sum[:]), nil +} + +func eventIdempotencyKey(event looseEvent) (string, error) { + eventFullType := string(event.ProtoReflect().Descriptor().FullName()) + + switch cause := event.PSMMetadata().Cause.Type.(type) { + case *psm_j5pb.Cause_PsmEvent: + return hashString( + eventFullType, + cause.PsmEvent.EventId, + ) + + case *psm_j5pb.Cause_Command: + providedKey := cause.Command.IdempotencyKey + if providedKey == "" { + return event.PSMMetadata().GetEventId(), nil + } + return hashString( + eventFullType, + cause.Command.Actor.Claim.TenantId, + providedKey, + ) + + case *psm_j5pb.Cause_ExternalEvent: + if cause.ExternalEvent.ExternalId != nil { + return hashString( + eventFullType, + *cause.ExternalEvent.ExternalId, + ) + } else { + return event.PSMMetadata().GetEventId(), nil + } + + case *psm_j5pb.Cause_Message: + return hashString( + eventFullType, + cause.Message.MessageId, + ) + + case *psm_j5pb.Cause_Init: + return event.PSMMetadata().GetEventId(), nil + + default: + return "", fmt.Errorf("EventSpec.Cause.Source must be set") + } +} + +type preparedEvent[ + K IKeyset, + S IState[K, ST, SD], + ST IStatusEnum, + SD IStateData, + E IEvent[K, S, ST, SD, IE], + IE IInnerEvent, +] struct { + event E + state S + idempotencyKey string +} + +func prepareFollowEvent[ + K IKeyset, + S IState[K, ST, SD], + ST IStatusEnum, + SD IStateData, + E IEvent[K, S, ST, SD, IE], + IE IInnerEvent, +](event E, state S) (built preparedEvent[K, S, ST, SD, E, IE], err error) { + idempotencyKey, err := eventIdempotencyKey(event) + if err != nil { + return built, err + } + return preparedEvent[K, S, ST, SD, E, IE]{ + event: event, + state: state, + idempotencyKey: idempotencyKey, + }, nil +} + +func (es *EventSpec[K, S, ST, SD, E, IE]) buildWrapper(state S) (built preparedEvent[K, S, ST, SD, E, IE], err error) { + + evt := (*new(E)).ProtoReflect().New().Interface().(E) + if err := evt.SetPSMEvent(es.Event); err != nil { + return built, fmt.Errorf("set event: %w", err) + } + evt.SetPSMKeys(es.Keys) + + eventMeta := evt.PSMMetadata() + eventMeta.EventId = id62.NewString() + eventMeta.Timestamp = timestamppb.Now() + eventMeta.Cause = es.Cause + + incrementEventSequence(state, eventMeta) + + built.event = evt + built.state = state + + idempotencyKey, err := eventIdempotencyKey(evt) + if err != nil { + return built, err + } + built.idempotencyKey = idempotencyKey + + return + +} + +func incrementEventSequence[K IKeyset, S IState[K, ST, SD], ST IStatusEnum, SD IStateData](state S, eventMeta *psm_j5pb.EventMetadata) { + stateMeta := state.PSMMetadata() + + eventMeta.Sequence = 0 + if state.GetStatus() == 0 { + eventMeta.Sequence = 0 + stateMeta.CreatedAt = eventMeta.Timestamp + stateMeta.UpdatedAt = eventMeta.Timestamp + } else { + eventMeta.Sequence = stateMeta.LastSequence + 1 + stateMeta.LastSequence = eventMeta.Sequence + stateMeta.UpdatedAt = eventMeta.Timestamp + } +} diff --git a/psm/gen_interfaces.go b/psm/gen_interfaces.go index 8a09979..58cc49d 100644 --- a/psm/gen_interfaces.go +++ b/psm/gen_interfaces.go @@ -87,11 +87,15 @@ type IEvent[ SD IStateData, Inner any, ] interface { - proto.Message + looseEvent UnwrapPSMEvent() Inner SetPSMEvent(Inner) error PSMKeys() K SetPSMKeys(K) +} + +type looseEvent interface { + proto.Message PSMMetadata() *psm_j5pb.EventMetadata PSMIsSet() bool } diff --git a/psm/run.go b/psm/run.go index fd88a15..77cd087 100644 --- a/psm/run.go +++ b/psm/run.go @@ -2,7 +2,6 @@ package psm import ( "context" - "errors" "fmt" "github.com/pentops/log.go/log" @@ -22,12 +21,13 @@ const ( func (sm *StateMachine[K, S, ST, SD, E, IE]) runEvent( ctx context.Context, tx sqrlx.Transaction, - state S, - event E, + bb preparedEvent[K, S, ST, SD, E, IE], captureState captureStateType, ) (*S, error) { + event := bb.event + state := bb.state - if err := sm.validateEvent(event); err != nil { + if err := sm.validator.Validate(event); err != nil { return nil, fmt.Errorf("validating event %s: %w", event.ProtoReflect().Descriptor().FullName(), err) } @@ -52,7 +52,7 @@ func (sm *StateMachine[K, S, ST, SD, E, IE]) runEvent( return nil, fmt.Errorf("run transition: %w", err) } - err = sm.storeAfterMutation(ctx, tx, state, event) + err = sm.storeStateAndEvent(ctx, tx, bb) if err != nil { return nil, fmt.Errorf("after transition from %s on %s: %w", transition.fromStatus.ShortString(), @@ -89,7 +89,6 @@ func (sm *StateMachine[K, S, ST, SD, E, IE]) runEvent( return nil, fmt.Errorf("side effect outbox: %w", err) } } else { - err = outbox.DefaultSender.SendDelayed(ctx, tx, se.delay, se.msg) if err != nil { return nil, fmt.Errorf("delayed side effect outbox: %w", err) @@ -97,22 +96,6 @@ func (sm *StateMachine[K, S, ST, SD, E, IE]) runEvent( } } - chain := []*EventSpec[K, S, ST, SD, E, IE]{} - for _, chained := range baton.chainEvents { - derived, err := sm.deriveEvent(event, chained) - if err != nil { - return nil, fmt.Errorf("derive chained: %w", err) - } - chain = append(chain, derived) - } - - if err := sm.eventsMustBeUnique(ctx, tx, chain...); err != nil { - if errors.Is(err, ErrDuplicateEventID) { - return nil, ErrDuplicateChainedEventID - } - return nil, err - } - log.Info(ctx, "Event Complete") var captureIntermediate = captureNoState @@ -120,15 +103,20 @@ func (sm *StateMachine[K, S, ST, SD, E, IE]) runEvent( captureIntermediate = captureFinalState } - for _, chainedEvent := range chain { - prepared, err := sm.prepareEvent(state, chainedEvent) + for _, chained := range baton.chainEvents { + derived, err := sm.deriveEvent(event, chained) + if err != nil { + return nil, fmt.Errorf("derive chained: %w", err) + } + + prepared, err := derived.buildWrapper(state) if err != nil { return nil, fmt.Errorf("prepare event: %w", err) } - stateAfterRun, err := sm.runEvent(ctx, tx, state, prepared, captureIntermediate) + stateAfterRun, err := sm.runEvent(ctx, tx, prepared, captureIntermediate) if err != nil { - return nil, fmt.Errorf("chained event: %s: %w", chainedEvent.Event.PSMEventKey(), err) + return nil, fmt.Errorf("chained event: %s: %w", derived.Event.PSMEventKey(), err) } if captureState == captureFinalState { returnState = stateAfterRun @@ -138,15 +126,15 @@ func (sm *StateMachine[K, S, ST, SD, E, IE]) runEvent( return returnState, nil } -func (sm *StateMachine[K, S, ST, SD, E, IE]) followEvent(ctx context.Context, tx sqrlx.Transaction, state S, event E) error { +func (sm *StateMachine[K, S, ST, SD, E, IE]) followEvent(ctx context.Context, tx sqrlx.Transaction, se preparedEvent[K, S, ST, SD, E, IE]) error { - typeKey := event.UnwrapPSMEvent().PSMEventKey() - statusBefore := state.GetStatus() + typeKey := se.event.UnwrapPSMEvent().PSMEventKey() + statusBefore := se.state.GetStatus() ctx = log.WithFields(ctx, map[string]any{ - "stateMachine": state.PSMKeys().PSMFullName(), + "stateMachine": se.state.PSMKeys().PSMFullName(), "transition": map[string]any{ - "eventId": event.PSMMetadata().EventId, + "eventId": se.event.PSMMetadata().EventId, "from": statusBefore.ShortString(), "event": typeKey, }, @@ -157,11 +145,11 @@ func (sm *StateMachine[K, S, ST, SD, E, IE]) followEvent(ctx context.Context, tx return err } - if err := transition.runMutations(ctx, state, event); err != nil { + if err := transition.runMutations(ctx, se.state, se.event); err != nil { return fmt.Errorf("run transition: %w", err) } - err = sm.storeAfterMutation(ctx, tx, state, event) + err = sm.storeStateAndEvent(ctx, tx, se) if err != nil { return fmt.Errorf("after transition from %s on %s: %w", transition.fromStatus.ShortString(), @@ -169,33 +157,9 @@ func (sm *StateMachine[K, S, ST, SD, E, IE]) followEvent(ctx context.Context, tx err) } - if err := transition.runFollowerHooks(ctx, tx, state, event); err != nil { + if err := transition.runFollowerHooks(ctx, tx, se.state, se.event); err != nil { return fmt.Errorf("run transition hooks: %w", err) } return nil } - -func (sm *StateMachine[K, S, ST, SD, E, IE]) storeAfterMutation( - ctx context.Context, - tx sqrlx.Transaction, - state S, - event E, -) error { - - if state.GetStatus() == 0 { - return fmt.Errorf("state machine transitioned to zero status") - } - - err := sm.validator.Validate(state) - if err != nil { - return err - } - - if err := sm.store(ctx, tx, state, event); err != nil { - return err - } - - return nil - -} diff --git a/psm/statemachine.go b/psm/statemachine.go index eb2a3c2..fcf5ec9 100644 --- a/psm/statemachine.go +++ b/psm/statemachine.go @@ -8,15 +8,8 @@ import ( "time" "buf.build/go/protovalidate" - sq "github.com/elgris/sqrl" - "github.com/google/uuid" "github.com/pentops/j5/gen/j5/state/v1/psm_j5pb" - "github.com/pentops/log.go/log" - "github.com/pentops/protostate/internal/dbconvert" "github.com/pentops/sqrlx.go/sqrlx" - "google.golang.org/protobuf/encoding/protojson" - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/types/known/timestamppb" ) var ErrDuplicateEventID = errors.New("duplicate event ID") @@ -78,11 +71,17 @@ func NewStateMachine[ return nil, err } + validator, err := protovalidate.New() + if err != nil { + panic("failed to initialize validator: " + err.Error()) + } + return &StateMachine[K, S, ST, SD, E, IE]{ keyValueFunc: cb.keyValues, initialStateFunc: cb.initialStateFunc, tableMap: cb.tableMap, //SystemActor: cb.systemActor, + validator: validator, }, nil } @@ -203,48 +202,6 @@ func (sm *DBStateMachine[K, S, ST, SD, E, IE]) FollowEvents(ctx context.Context, return nil } -func (sm *StateMachine[K, S, ST, SD, E, IE]) getCurrentState(ctx context.Context, tx sqrlx.Transaction, keys K) (S, error) { - state := (*new(S)).ProtoReflect().New().Interface().(S) - - selectQuery := sq. - Select(sm.tableMap.State.Root.ColumnName). - From(sm.tableMap.State.TableName) - - allKeys, err := sm.keyValues(keys) - if err != nil { - return state, err - } - for _, key := range allKeys.values { - if !key.Primary { - continue - } - selectQuery = selectQuery.Where(sq.Eq{key.ColumnName: key.value}) - } - - var stateJSON []byte - err = tx.SelectRow(ctx, selectQuery).Scan(&stateJSON) - if errors.Is(err, sql.ErrNoRows) { - state.SetPSMKeys(proto.Clone(keys).(K)) - - if len(allKeys.missingRequired) > 0 { - return state, fmt.Errorf("missing required key(s) %v in initial event", allKeys.missingRequired) - } - - // OK, leave empty state alone - return state, nil - } - if err != nil { - qq, _, _ := selectQuery.ToSql() - return state, fmt.Errorf("selecting current state (%s): %w", qq, err) - } - - if err := protojson.Unmarshal(stateJSON, state); err != nil { - return state, err - } - - return state, nil -} - type keyValues struct { values []keyValue @@ -308,152 +265,31 @@ func (sm *StateMachine[K, S, ST, SD, E, IE]) keyValues(keysMessage K) (*keyValue }, nil } - -func (sm *StateMachine[K, S, ST, SD, E, IE]) store( - ctx context.Context, - tx sqrlx.Transaction, - state S, - event E, -) error { - - stateDBValue, err := dbconvert.MarshalProto(state) - if err != nil { - return fmt.Errorf("state field: %w", err) - } - - eventDBValue, err := dbconvert.MarshalProto(event) - if err != nil { - return fmt.Errorf("event field: %w", err) - } - - // TODO: This does not change during transitions, so should be calculated - // early and once. - keyValues, err := sm.keyValues(state.PSMKeys()) - if err != nil { - return fmt.Errorf("key fields: %w", err) - } - if len(keyValues.missingRequired) > 0 { - return fmt.Errorf("missing required key(s) %v in store", keyValues.missingRequired) - } - - eventMeta := event.PSMMetadata() - - upsertStateQuery := sqrlx.Upsert(sm.tableMap.State.TableName) - - insertValues := []any{} - insertColumns := []string{} - - insertEventQuery := sq.Insert(sm.tableMap.Event.TableName) - - insertColumns = append(insertColumns, sm.tableMap.Event.ID.ColumnName) - insertValues = append(insertValues, eventMeta.EventId) - - for _, key := range keyValues.values { - if key.Primary { - upsertStateQuery.Key(key.ColumnName, key.value) - } else { - upsertStateQuery.Set(key.ColumnName, key.value) - } - - insertColumns = append(insertColumns, key.ColumnName) - insertValues = append(insertValues, key.value) - } - - insertColumns = append(insertColumns, - sm.tableMap.Event.Timestamp.ColumnName, - sm.tableMap.Event.Sequence.ColumnName, - sm.tableMap.Event.Root.ColumnName, - sm.tableMap.Event.StateSnapshot.ColumnName, - ) - insertValues = append(insertValues, - eventMeta.Timestamp.AsTime(), - eventMeta.Sequence, - eventDBValue, - stateDBValue, - ) - insertEventQuery.Columns(insertColumns...).Values(insertValues...) - - upsertStateQuery.Set(sm.tableMap.State.Root.ColumnName, stateDBValue) - - _, err = tx.Insert(ctx, upsertStateQuery) - if err != nil { - log.WithFields(ctx, map[string]any{ - "keys": keyValues, - "error": err.Error(), - }).Error("failed to upsert state") - return fmt.Errorf("upsert state: %w", err) - } - - _, err = tx.Insert(ctx, insertEventQuery) - if err != nil { - log.WithFields(ctx, map[string]any{ - "keys": keyValues, - "error": err.Error(), - }).Error("failed to insert event") - return fmt.Errorf("insert event: %w", err) - } - - return nil -} - -func (sm *StateMachine[K, S, ST, SD, E, IE]) eventQuery(eventID string) *sq.SelectBuilder { - - selectQuery := sq. - Select(sm.tableMap.Event.Root.ColumnName). - From(sm.tableMap.Event.TableName). - Where(sq.Eq{sm.tableMap.Event.ID.ColumnName: eventID}) - - return selectQuery -} - -// followEventDeduplicate is similar to the firstEventUniqueCheck, but it -// compares the entire event including metadata, as this is not designed to -// handle consumer idempotency. -func (sm *StateMachine[K, S, ST, SD, E, IE]) followEventDeduplicate(ctx context.Context, tx sqrlx.Transaction, event E) (bool, error) { - selectQuery := sm.eventQuery(event.PSMMetadata().EventId) - - var eventData, stateData []byte - err := tx.SelectRow(ctx, selectQuery).Scan(&eventData, &stateData) - - if errors.Is(err, sql.ErrNoRows) { - return false, nil - } - if err != nil { - return false, fmt.Errorf("selecting event for deduplication: %w", err) - } - - existing := (*new(E)).ProtoReflect().New().Interface().(E) - if err := protojson.Unmarshal(eventData, existing); err != nil { - return true, fmt.Errorf("unmarshalling event: %w", err) - } - - if !proto.Equal(existing, event) { - return true, fmt.Errorf("event %s already exists with different data", existing.PSMMetadata().EventId) - } - - return true, nil -} - func (sm *StateMachine[K, S, ST, SD, E, IE]) followEvents(ctx context.Context, tx sqrlx.Transaction, events []E) error { for _, event := range events { - if err := sm.validateEvent(event); err != nil { + if err := sm.validator.Validate(event); err != nil { return fmt.Errorf("validating event %s: %w", event.ProtoReflect().Descriptor().FullName(), err) } - exists, err := sm.followEventDeduplicate(ctx, tx, event) + state, err := sm.getCurrentState(ctx, tx, event.PSMKeys()) if err != nil { return err } - if exists { - continue + + wrapped, err := prepareFollowEvent[K, S, ST, SD, E, IE](event, state) + if err != nil { + return err } - state, err := sm.getCurrentState(ctx, tx, event.PSMKeys()) + exists, err := sm.followEventDeduplicate(ctx, tx, wrapped) if err != nil { return err } + if exists { + continue + } if state.GetStatus() == 0 { newState, err := sm.runInitialEvent(ctx, tx, state) @@ -463,9 +299,9 @@ func (sm *StateMachine[K, S, ST, SD, E, IE]) followEvents(ctx context.Context, t state = newState } - sm.nextStateEvent(state, event.PSMMetadata()) + incrementEventSequence(state, event.PSMMetadata()) - err = sm.followEvent(ctx, tx, state, event) + err = sm.followEvent(ctx, tx, wrapped) if err != nil { return fmt.Errorf("run event %s (%s): %w", event.PSMMetadata().EventId, event.UnwrapPSMEvent().PSMEventKey(), err) } @@ -487,9 +323,8 @@ func (sm *StateMachine[K, S, ST, SD, E, IE]) runInitialEvent(ctx context.Context } eventSpec := &EventSpec[K, S, ST, SD, E, IE]{ - Keys: keys, - EventID: uuid.NewString(), - Event: innerEvent, + Keys: keys, + Event: innerEvent, Cause: &psm_j5pb.Cause{ Type: &psm_j5pb.Cause_Init{ Init: &psm_j5pb.InitCause{}, @@ -497,13 +332,13 @@ func (sm *StateMachine[K, S, ST, SD, E, IE]) runInitialEvent(ctx context.Context }, } - prepared, err := sm.prepareEvent(state, eventSpec) + prepared, err := eventSpec.buildWrapper(state) if err != nil { return state, fmt.Errorf("prepare event: %w", err) } // RunEvent modifies state in place - returnState, err := sm.runEvent(ctx, tx, state, prepared, captureFinalState) + returnState, err := sm.runEvent(ctx, tx, prepared, captureFinalState) if err != nil { return state, fmt.Errorf("input event %s: %w", eventSpec.Event.PSMEventKey(), err) } @@ -514,19 +349,10 @@ func (sm *StateMachine[K, S, ST, SD, E, IE]) runInitialEvent(ctx context.Context func (sm *StateMachine[K, S, ST, SD, E, IE]) runTx(ctx context.Context, tx sqrlx.Transaction, outerEvent *EventSpec[K, S, ST, SD, E, IE]) (S, error) { if err := outerEvent.validateAndPrepare(); err != nil { + // type of 'S' is *State, we can't return nil, error return situations suck. return *new(S), fmt.Errorf("event %s: %w", outerEvent.Event.ProtoReflect().Descriptor().FullName(), err) } - if outerEvent.EventID == "" { - outerEvent.EventID = uuid.NewString() - } - - if existingState, didExist, err := sm.firstEventUniqueCheck(ctx, tx, outerEvent.EventID, outerEvent.Event); err != nil { - return existingState, err - } else if didExist { - return existingState, nil - } - state, err := sm.getCurrentState(ctx, tx, outerEvent.Keys) if err != nil { return state, err @@ -549,14 +375,18 @@ func (sm *StateMachine[K, S, ST, SD, E, IE]) runTx(ctx context.Context, tx sqrlx outerEvent.Keys = state.PSMKeys() } - prepared, err := sm.prepareEvent(state, outerEvent) + prepared, err := outerEvent.buildWrapper(state) if err != nil { return state, fmt.Errorf("prepare event: %w", err) } - // RunEvent modifies state in place - returnState, err := sm.runEvent(ctx, tx, state, prepared, captureInitialState) // return the state after the first transition + if existingState, didExist, err := sm.causeEventIdempotency(ctx, tx, prepared); err != nil { + return existingState, err + } else if didExist { + return existingState, nil + } + returnState, err := sm.runEvent(ctx, tx, prepared, captureInitialState) // return the state after the first transition if err != nil { return state, fmt.Errorf("input event %s: %w", outerEvent.Event.PSMEventKey(), err) } @@ -564,116 +394,11 @@ func (sm *StateMachine[K, S, ST, SD, E, IE]) runTx(ctx context.Context, tx sqrlx return *returnState, nil } -// firstEventUniqueCheck checks if the event ID for the outer triggering event -// is unique in the event table. If not, it checks if the event is a repeat -// processing of the same event, and returns the state after the initial -// transition. -func (sm *StateMachine[K, S, ST, SD, E, IE]) firstEventUniqueCheck(ctx context.Context, tx sqrlx.Transaction, eventID string, data IE) (S, bool, error) { - var s S - selectQuery := sm.eventQuery(eventID) - - selectQuery.Column(sm.tableMap.Event.StateSnapshot.ColumnName) - - var eventData, stateData []byte - err := tx.SelectRow(ctx, selectQuery).Scan(&eventData, &stateData) - if errors.Is(err, sql.ErrNoRows) { - return s, false, nil - } - if err != nil { - return s, false, fmt.Errorf("selecting event: %w", err) - } - - existing := (*new(E)).ProtoReflect().New().Interface().(E) - - if err := protojson.Unmarshal(eventData, existing); err != nil { - return s, false, fmt.Errorf("unmarshalling event: %w", err) - } - - if !proto.Equal(existing.UnwrapPSMEvent(), data) { - return s, false, ErrDuplicateEventID - } - - state := (*new(S)).ProtoReflect().New() - if err := protojson.Unmarshal(stateData, state.Interface()); err != nil { - return s, false, fmt.Errorf("unmarshalling state: %w", err) - } - - return state.Interface().(S), true, nil -} - -func (sm *StateMachine[K, S, ST, SD, E, IE]) eventsMustBeUnique(ctx context.Context, tx sqrlx.Transaction, events ...*EventSpec[K, S, ST, SD, E, IE]) error { - for _, event := range events { - if event.EventID == "" { - continue // UUID Gen Later - } - selectQuery := sm.eventQuery(event.EventID) - - var data []byte - err := tx.SelectRow(ctx, selectQuery).Scan(&data) - if errors.Is(err, sql.ErrNoRows) { - continue - } - if err != nil { - return fmt.Errorf("selecting event: %w", err) - } - return ErrDuplicateEventID - } - return nil - -} - -func (sm *StateMachine[K, S, ST, SD, E, IE]) validateEvent(event E) error { - if sm.validator == nil { - v, err := protovalidate.New() - if err != nil { - fmt.Println("failed to initialize validator:", err) - } - sm.validator = v - } - - return sm.validator.Validate(event) -} - -func (sm *StateMachine[K, S, ST, SD, E, IE]) prepareEvent(state S, spec *EventSpec[K, S, ST, SD, E, IE]) (built E, err error) { - - built = (*new(E)).ProtoReflect().New().Interface().(E) - if err := built.SetPSMEvent(spec.Event); err != nil { - return built, fmt.Errorf("set event: %w", err) - } - built.SetPSMKeys(spec.Keys) - - eventMeta := built.PSMMetadata() - eventMeta.EventId = spec.EventID - eventMeta.Timestamp = timestamppb.Now() - eventMeta.Cause = spec.Cause - - sm.nextStateEvent(state, eventMeta) - - return - -} - -func (sm *StateMachine[K, S, ST, SD, E, IE]) nextStateEvent(state S, eventMeta *psm_j5pb.EventMetadata) { - stateMeta := state.PSMMetadata() - - eventMeta.Sequence = 0 - if state.GetStatus() == 0 { - eventMeta.Sequence = 0 - stateMeta.CreatedAt = eventMeta.Timestamp - stateMeta.UpdatedAt = eventMeta.Timestamp - } else { - eventMeta.Sequence = stateMeta.LastSequence + 1 - stateMeta.LastSequence = eventMeta.Sequence - stateMeta.UpdatedAt = eventMeta.Timestamp - } -} - func (sm *StateMachine[K, S, ST, SD, E, IE]) transitionFromLink(ctx context.Context, tx sqrlx.Transaction, cause *psm_j5pb.Cause, keys K, innerEvent IE) error { // nolint: unused // Used when the state machine is implementing LinkDestination event := &EventSpec[K, S, ST, SD, E, IE]{ Keys: keys, Timestamp: time.Now(), Event: innerEvent, - EventID: uuid.NewString(), Cause: cause, } @@ -688,14 +413,12 @@ func (sm *StateMachine[K, S, ST, SD, E, IE]) transitionFromLink(ctx context.Cont func (sm *StateMachine[K, S, ST, SD, E, IE]) deriveEvent(cause E, chained IE) (evt *EventSpec[K, S, ST, SD, E, IE], err error) { causeMetadata := cause.PSMMetadata() - eventID := uuid.NewString() psmKeys := cause.PSMKeys() eventOut := &EventSpec[K, S, ST, SD, E, IE]{ Keys: psmKeys, Timestamp: time.Now(), Event: chained, - EventID: eventID, Cause: &psm_j5pb.Cause{ Type: &psm_j5pb.Cause_PsmEvent{ PsmEvent: &psm_j5pb.PSMEventCause{ diff --git a/psm/storage.go b/psm/storage.go new file mode 100644 index 0000000..b42f2ae --- /dev/null +++ b/psm/storage.go @@ -0,0 +1,227 @@ +package psm + +import ( + "context" + "database/sql" + "errors" + "fmt" + + sq "github.com/elgris/sqrl" + "github.com/pentops/log.go/log" + "github.com/pentops/protostate/internal/dbconvert" + "github.com/pentops/sqrlx.go/sqrlx" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" +) + +// causeEventIdempotency checks if the event ID for the outer triggering event +// is unique in the event table. If not, it checks if the event is a repeat +// processing of the same event, and returns the state after the initial +// transition. +func (sm *StateMachine[K, S, ST, SD, E, IE]) causeEventIdempotency(ctx context.Context, tx sqrlx.Transaction, se preparedEvent[K, S, ST, SD, E, IE]) (S, bool, error) { + var s S + selectQuery := sq. + Select(sm.tableMap.Event.Root.ColumnName). + From(sm.tableMap.Event.TableName). + Where(sq.Eq{sm.tableMap.Event.IdempotencyHash.ColumnName: se.idempotencyKey}) + + selectQuery.Column(sm.tableMap.Event.StateSnapshot.ColumnName) + + var eventData, stateData []byte + err := tx.SelectRow(ctx, selectQuery).Scan(&eventData, &stateData) + if errors.Is(err, sql.ErrNoRows) { + return s, false, nil + } + if err != nil { + return s, false, fmt.Errorf("selecting event: %w", err) + } + + existing := (*new(E)).ProtoReflect().New().Interface().(E) + + if err := protojson.Unmarshal(eventData, existing); err != nil { + return s, false, fmt.Errorf("unmarshalling event: %w", err) + } + + if !proto.Equal(existing.UnwrapPSMEvent(), se.event.UnwrapPSMEvent()) { + return s, false, ErrDuplicateEventID + } + + state := (*new(S)).ProtoReflect().New() + if err := protojson.Unmarshal(stateData, state.Interface()); err != nil { + return s, false, fmt.Errorf("unmarshalling state: %w", err) + } + + return state.Interface().(S), true, nil +} + +// followEventDeduplicate is similar to the firstEventUniqueCheck, but it +// compares the entire event including metadata, as this is not designed to +// handle consumer idempotency. +func (sm *StateMachine[K, S, ST, SD, E, IE]) followEventDeduplicate(ctx context.Context, tx sqrlx.Transaction, se preparedEvent[K, S, ST, SD, E, IE]) (bool, error) { + selectQuery := sq. + Select(sm.tableMap.Event.Root.ColumnName). + From(sm.tableMap.Event.TableName). + Where(sq.Eq{sm.tableMap.Event.ID.ColumnName: se.event.PSMMetadata().EventId}) + + var eventData, stateData []byte + err := tx.SelectRow(ctx, selectQuery).Scan(&eventData, &stateData) + + if errors.Is(err, sql.ErrNoRows) { + return false, nil + } + if err != nil { + return false, fmt.Errorf("selecting event for deduplication: %w", err) + } + + existing := (*new(E)).ProtoReflect().New().Interface().(E) + if err := protojson.Unmarshal(eventData, existing); err != nil { + return true, fmt.Errorf("unmarshalling event: %w", err) + } + + if !proto.Equal(existing, se.event) { + return true, fmt.Errorf("event %s already exists with different data", existing.PSMMetadata().EventId) + } + + return true, nil +} + +func (sm *StateMachine[K, S, ST, SD, E, IE]) storeStateAndEvent( + ctx context.Context, + tx sqrlx.Transaction, + evt preparedEvent[K, S, ST, SD, E, IE], +) error { + state := evt.state + event := evt.event + + if state.GetStatus() == 0 { + return fmt.Errorf("state machine transitioned to zero status") + } + + err := sm.validator.Validate(state) + if err != nil { + return err + } + + stateDBValue, err := dbconvert.MarshalProto(state) + if err != nil { + return fmt.Errorf("state field: %w", err) + } + + eventDBValue, err := dbconvert.MarshalProto(event) + if err != nil { + return fmt.Errorf("event field: %w", err) + } + + // TODO: This does not change during transitions, so should be calculated + // early and once. + keyValues, err := sm.keyValues(state.PSMKeys()) + if err != nil { + return fmt.Errorf("key fields: %w", err) + } + if len(keyValues.missingRequired) > 0 { + return fmt.Errorf("missing required key(s) %v in store", keyValues.missingRequired) + } + + eventMeta := event.PSMMetadata() + + upsertStateQuery := sqrlx.Upsert(sm.tableMap.State.TableName) + + insertValues := []any{} + insertColumns := []string{} + + insertEventQuery := sq.Insert(sm.tableMap.Event.TableName) + + insertColumns = append(insertColumns, sm.tableMap.Event.ID.ColumnName) + insertValues = append(insertValues, eventMeta.EventId) + + for _, key := range keyValues.values { + if key.Primary { + upsertStateQuery.Key(key.ColumnName, key.value) + } else { + upsertStateQuery.Set(key.ColumnName, key.value) + } + + insertColumns = append(insertColumns, key.ColumnName) + insertValues = append(insertValues, key.value) + } + + insertColumns = append(insertColumns, + sm.tableMap.Event.IdempotencyHash.ColumnName, + sm.tableMap.Event.Timestamp.ColumnName, + sm.tableMap.Event.Sequence.ColumnName, + sm.tableMap.Event.Root.ColumnName, + sm.tableMap.Event.StateSnapshot.ColumnName, + ) + insertValues = append(insertValues, + evt.idempotencyKey, + eventMeta.Timestamp.AsTime(), + eventMeta.Sequence, + eventDBValue, + stateDBValue, + ) + insertEventQuery.Columns(insertColumns...).Values(insertValues...) + + upsertStateQuery.Set(sm.tableMap.State.Root.ColumnName, stateDBValue) + + _, err = tx.Insert(ctx, upsertStateQuery) + if err != nil { + log.WithFields(ctx, map[string]any{ + "keys": keyValues, + "error": err.Error(), + }).Error("failed to upsert state") + return fmt.Errorf("upsert state: %w", err) + } + + _, err = tx.Insert(ctx, insertEventQuery) + if err != nil { + log.WithFields(ctx, map[string]any{ + "keys": keyValues, + "error": err.Error(), + }).Error("failed to insert event") + return fmt.Errorf("insert event: %w", err) + } + + return nil +} + +func (sm *StateMachine[K, S, ST, SD, E, IE]) getCurrentState(ctx context.Context, tx sqrlx.Transaction, keys K) (S, error) { + state := (*new(S)).ProtoReflect().New().Interface().(S) + + selectQuery := sq. + Select(sm.tableMap.State.Root.ColumnName). + From(sm.tableMap.State.TableName) + + allKeys, err := sm.keyValues(keys) + if err != nil { + return state, err + } + for _, key := range allKeys.values { + if !key.Primary { + continue + } + selectQuery = selectQuery.Where(sq.Eq{key.ColumnName: key.value}) + } + + var stateJSON []byte + err = tx.SelectRow(ctx, selectQuery).Scan(&stateJSON) + if errors.Is(err, sql.ErrNoRows) { + state.SetPSMKeys(proto.Clone(keys).(K)) + + if len(allKeys.missingRequired) > 0 { + return state, fmt.Errorf("missing required key(s) %v in initial event", allKeys.missingRequired) + } + + // OK, leave empty state alone + return state, nil + } + if err != nil { + qq, _, _ := selectQuery.ToSql() + return state, fmt.Errorf("selecting current state (%s): %w", qq, err) + } + + if err := protojson.Unmarshal(stateJSON, state); err != nil { + return state, err + } + + return state, nil +} diff --git a/psm/table_map.go b/psm/table_map.go index 09d23b3..e2a0f2b 100644 --- a/psm/table_map.go +++ b/psm/table_map.go @@ -40,6 +40,9 @@ func (tm *TableMap) Validate() error { if tm.Event.ID == nil { return fmt.Errorf("missing Event.Data in TableMap") } + if tm.Event.IdempotencyHash == nil { + return fmt.Errorf("missing Event.IdempotencyHash in TableMap") + } if tm.Event.Timestamp == nil { return fmt.Errorf("missing Event.Timestamp in TableMap") } @@ -66,9 +69,12 @@ type EventTableSpec struct { Root *FieldSpec // a UUID holding the primary key of the event - // TODO: Multi-column ID for Events? ID *FieldSpec + // Stores a globally unique itempotency key, hashed appropriately for + // unique-per-tenant at entry + IdempotencyHash *FieldSpec + // timestamptz The time of the event Timestamp *FieldSpec @@ -156,6 +162,9 @@ func buildDefaultTableMap(keyMessage protoreflect.MessageDescriptor) (*TableMap, ColumnName: "id", //PathFromRoot: psm.PathSpec{string(ss.EventMetadataField.Name()), "event_id"}, }, + IdempotencyHash: &FieldSpec{ + ColumnName: "idempotency", + }, Timestamp: &FieldSpec{ ColumnName: "timestamp", //PathFromRoot: psm.PathSpec{string(ss.EventMetadataField.Name()), "timestamp"},