Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/models/config/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::collections::HashMap;
use serde::{Serialize, Deserialize};
use twilight_model::id::Id;
use twilight_model::id::marker::{ApplicationMarker, GuildMarker};
Expand Down
92 changes: 75 additions & 17 deletions src/server/guild/editing.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use std::collections::{BTreeSet, HashMap};
use std::collections::HashMap;
use std::sync::Arc;
use json_patch::Patch;
use mongodb::bson::doc;
use mongodb::bson::oid::ObjectId;
use serde_json::{Map, Value};
use serde::Serialize;
use tokio::sync::{Mutex, RwLock};
use tracing::{error, warn};
use twilight_model::id::Id;
Expand All @@ -11,18 +12,22 @@ use crate::context::Context;
use crate::models::config::GuildConfig;
use crate::server::guild::ws::{Connection, OutboundAction, OutboundMessage};

#[derive(Clone, Debug, Serialize)]
pub struct Change {
pub author_id: Id<UserMarker>,
pub changes: Patch
}

struct GuildEditingState {
pub connections: Vec<Arc<Connection>>,
pub changes: Value,
pub edited_by: BTreeSet<Id<UserMarker>>
pub changes: Vec<Change>,
}

impl Default for GuildEditingState {
fn default() -> Self {
Self {
connections: vec![],
changes: Value::Object(Map::new()),
edited_by: Default::default(),
changes: vec![],
}
}
}
Expand Down Expand Up @@ -59,14 +64,16 @@ impl GuildsEditing {

pub async fn marge_changes(
&self,
author: Id<UserMarker>,
author_id: Id<UserMarker>,
guild_id: Id<GuildMarker>,
changes: Value
changes: Patch
) -> Option<()> {
let guild = self.get_guild(guild_id).await?;
let mut guild_lock = guild.lock().await;
json_patch::merge(&mut guild_lock.changes, &changes);
guild_lock.edited_by.insert(author);
guild_lock.changes.push(Change {
author_id,
changes
});
Some(())
}

Expand All @@ -75,7 +82,55 @@ impl GuildsEditing {
list_lock.get(&guild_id).cloned()
}

pub async fn broadcast_changes(&self, context: &Arc<Context>, guild_id: Id<GuildMarker>) -> Option<()> {
pub async fn broadcast_users(&self, guild_id: Id<GuildMarker>) -> Option<()> {
let guild = self.get_guild(guild_id).await?;
let guild_lock = guild.lock().await;

let users = guild_lock.connections
.iter().map(|connection| connection.user_id)
.collect::<Vec<Id<UserMarker>>>();

for connection in &guild_lock.connections {
let _ = connection.tx.send(OutboundAction::Message(OutboundMessage::OverwriteUsers(users.to_owned())));
}

Some(())
}

pub async fn broadcast_change(
&self, guild_id: Id<GuildMarker>, author_id: Id<UserMarker>, changes: Patch
) -> Option<()> {
let guild = self.get_guild(guild_id).await?;
let guild_lock = guild.lock().await;

for connection in &guild_lock.connections {
let _ = connection.tx.send(OutboundAction::Message(OutboundMessage::PushChange(Change {
author_id,
changes: changes.to_owned()
})));
}

Some(())
}

pub async fn get_initialization_data(&self, context: &Arc<Context>, guild_id: Id<GuildMarker>)
-> Option<(GuildConfig, Vec<Change>, Vec<Id<UserMarker>>)> {
let config = context.mongodb
.get_config(guild_id)
.await
.inspect_err(|error| error!(name: "mongodb error", ?error))
.ok()?;

let guild = self.get_guild(guild_id).await?;
let guild_lock = guild.lock().await;
let users = guild_lock.connections
.iter().map(|connection| connection.user_id)
.collect::<Vec<Id<UserMarker>>>();

Some((config.to_owned(), guild_lock.changes.to_owned(), users))
}

pub async fn broadcast_config_overwrite(&self, context: &Arc<Context>, guild_id: Id<GuildMarker>) -> Option<()> {
let config = context.mongodb
.get_config(guild_id)
.await
Expand All @@ -89,10 +144,9 @@ impl GuildsEditing {
.collect::<Vec<Id<UserMarker>>>();

for connection in &guild_lock.connections {
let _ = connection.tx.send(OutboundAction::Message(OutboundMessage::UpdateConfigurationData {
let _ = connection.tx.send(OutboundAction::Message(OutboundMessage::OverwriteConfigurationData {
saved_config: config.to_owned(),
changes: guild_lock.changes.to_owned(),
users: users.to_owned(),
}));
}

Expand All @@ -112,9 +166,14 @@ impl GuildsEditing {
let mut new_config = serde_json::to_value(config)
.inspect_err(|error| error!(name: "cannot convert guild config to value", ?error))
.ok()?;
json_patch::merge(&mut new_config, &guild_lock.changes);
for patch in &guild_lock.changes {
json_patch::patch(&mut new_config, &patch.changes)
.inspect_err(|error| error!(name: "error applying patch to guild config", ?patch, ?error))
.ok()?;
}

let new_config: GuildConfig = serde_json::from_value(new_config)
.inspect_err(|error| error!(name: "cannot marge edits with guild config", ?error))
.inspect_err(|error| error!(name: "cannot serialize config after applying patches", ?error))
.ok()?;

if new_config.guild_id != guild_id {
Expand All @@ -138,8 +197,7 @@ impl GuildsEditing {
.ok()?;
context.mongodb.configs_cache.remove(&guild_id);

guild_lock.changes = Value::Object(Map::new());
guild_lock.edited_by.clear();
guild_lock.changes = vec![];

Some(())
}
Expand Down
60 changes: 37 additions & 23 deletions src/server/guild/ws.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use std::borrow::Cow;
use std::sync::Arc;
use futures_util::{SinkExt, StreamExt};
use json_patch::Patch;
use mongodb::bson::oid::ObjectId;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::mpsc::UnboundedSender;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{error, info};
use tracing::{error, info, warn};
use twilight_model::id::Id;
use twilight_model::id::marker::UserMarker;
use twilight_model::user::CurrentUserGuild;
Expand All @@ -15,7 +15,7 @@ use crate::context::Context;
use crate::database::redis::PartialGuild;
use crate::models::config::GuildConfig;
use crate::ok_or_return;
use crate::server::guild::editing::GuildsEditing;
use crate::server::guild::editing::{Change, GuildsEditing};
use crate::server::session::AuthorizationInformation;

macro_rules! close {
Expand All @@ -28,14 +28,17 @@ macro_rules! unwrap_or_close_and_return {
($target: expr, $tx: expr, $reason: expr) => {
match $target {
Ok(value) => value,
Err(_) => {
close!($tx, $reason);
Err(err) => {
let reason = $reason;
tracing::warn!(name: "connection closed due to error", ?err, ?reason);
close!($tx, reason);
return
}
}
};
}

#[derive(Debug)]
pub enum CloseReason {
MessageIsNotString,
CannotParseJSON,
Expand Down Expand Up @@ -97,26 +100,33 @@ pub async fn handle_connection(
Message::close_with(reason.code(), reason.text())
).await;
guilds_editing.remove_connection(guild_id, session_id).await;
guilds_editing.broadcast_changes(&context, guild_id).await;
guilds_editing.broadcast_users(guild_id).await;
}
}
}
let _ = ws_tx.close().await;
});


guilds_editing.broadcast_users(guild_id).await;

guilds_editing.add_connection(guild_id, Connection {
user_id: info.user.id,
session_id,
tx: tx.to_owned(),
}).await;

let _ = tx.send(OutboundAction::Message(OutboundMessage::Initialization {
cached: ok_or_return!(context.redis.get_guild(guild.id).await, Ok),
oauth2: guild.to_owned(),
session_id
}));

guilds_editing.broadcast_changes(&context, guild_id).await;
if let Some((saved_config, changes, users)) =
guilds_editing.get_initialization_data(&context, guild_id).await {
let _ = tx.send(OutboundAction::Message(OutboundMessage::Initialization {
cached: ok_or_return!(context.redis.get_guild(guild.id).await, Ok),
oauth2: guild.to_owned(),
saved_config,
changes,
users,
session_id
}));
}

while let Some(result) = ws_rx.next().await {
let message = match result {
Expand All @@ -136,12 +146,12 @@ pub async fn handle_connection(
}

guilds_editing.remove_connection(guild_id, session_id).await;
guilds_editing.broadcast_changes(&context, guild_id).await;
guilds_editing.broadcast_users(guild_id).await;
}
#[derive(Debug, Deserialize)]
#[serde(tag = "action", content = "data")]
enum InboundMessage {
GuildConfigUpdate(Value),
GuildConfigUpdate(Patch),
ApplyChanges
}

Expand All @@ -151,13 +161,17 @@ pub enum OutboundMessage {
Initialization {
oauth2: CurrentUserGuild,
cached: PartialGuild,
session_id: ObjectId
},
UpdateConfigurationData {
session_id: ObjectId,
saved_config: GuildConfig,
changes: Value,
changes: Vec<Change>,
users: Vec<Id<UserMarker>>
}
},
OverwriteConfigurationData {
saved_config: GuildConfig,
changes: Vec<Change>
},
OverwriteUsers(Vec<Id<UserMarker>>),
PushChange(Change)
}

pub enum OutboundAction {
Expand All @@ -183,8 +197,8 @@ async fn on_message(

match message {
InboundMessage::GuildConfigUpdate(changes) => {
let _ = guilds_editing.marge_changes(info.user.id, guild.id, changes).await;
let _ = guilds_editing.broadcast_changes(&context, guild.id).await;
let _ = guilds_editing.marge_changes(info.user.id, guild.id, changes.to_owned()).await;
let _ = guilds_editing.broadcast_change(guild.id, info.user.id, changes).await;
}
InboundMessage::ApplyChanges => {
info!(
Expand All @@ -193,7 +207,7 @@ async fn on_message(
guild_id = %guild.id
);
guilds_editing.apply_changes(&context, guild.id).await;
let _ = guilds_editing.broadcast_changes(&context, guild.id).await;
let _ = guilds_editing.broadcast_config_overwrite(&context, guild.id).await;
let _ = context.redis.announce_config_update(guild.id).await
.inspect_err(|error| {
error!(name: "error sending guild_id to redis update announcer", ?error, %guild.id)
Expand Down