first attempt at algo

This commit is contained in:
Anna 2022-09-04 21:48:43 -04:00
parent dc71a2e780
commit 581c16a8b2
2 changed files with 70 additions and 3 deletions

View File

@ -1,4 +1,5 @@
#![feature(let_chains)]
#![feature(drain_filter)]
use std::collections::HashMap;
use std::sync::Arc;

View File

@ -1,6 +1,9 @@
use std::collections::HashMap;
use std::sync::Arc;
use anyhow::Context;
use rand::distributions::WeightedIndex;
use rand::Rng;
use warp::{Filter, Rejection, Reply};
use warp::filters::BoxedFilter;
@ -13,15 +16,17 @@ pub fn get_location(state: Arc<State>) -> BoxedFilter<(impl Reply, )> {
.and(warp::path("messages"))
.and(warp::path::param())
.and(warp::path::end())
.and(warp::query::<HashMap<String, String>>())
.and(super::get_id(Arc::clone(&state)))
.and_then(move |location: u32, id: i64| logic(Arc::clone(&state), id, location))
.and_then(move |location: u32, query: HashMap<String, String>, id: i64| logic(Arc::clone(&state), id, location, query))
.boxed()
}
async fn logic(state: Arc<State>, id: i64, location: u32) -> Result<impl Reply, Rejection> {
async fn logic(state: Arc<State>, id: i64, location: u32, query: HashMap<String, String>) -> Result<impl Reply, Rejection> {
// TODO: when we're not just returning all results, make sure own messages are always present
let filter = query.contains_key("filter");
let location = location as i64;
let messages = sqlx::query_as!(
let mut messages = sqlx::query_as!(
RetrievedMessage,
// language=sqlite
r#"
@ -47,5 +52,66 @@ async fn logic(state: Arc<State>, id: i64, location: u32) -> Result<impl Reply,
.context("could not get messages from database")
.map_err(AnyhowRejection)
.map_err(warp::reject::custom)?;
if filter {
filter_messages(&mut messages);
}
Ok(warp::reply::json(&messages))
}
/// Randomly thins out crowded clusters of messages, in place.
///
/// For every message this counts the *other* messages (those with a different `id`) whose
/// squared Euclidean distance over `x`/`y`/`z` is below 7.5 squared:
///
/// * A message with at most two such neighbours is always kept.
/// * Any other message gets a weight of `max(positive_votes - negative_votes, 0) / neighbours`,
///   truncated to an integer and raised to at least 1. It is then kept with probability
///   `weight / max_weight`, where `max_weight` is the largest weight among the crowded messages.
///
/// If no message is crowded, `messages` is left untouched. The relative order of the surviving
/// messages is preserved.
///
/// The neighbour count compares every pair of messages, so this is O(n^2) in `messages.len()`.
fn filter_messages(messages: &mut Vec<RetrievedMessage>) {
    // One entry per message, in the same order as `messages`:
    // `None` means "always keep", `Some(weight)` means "keep with chance weight / max_weight".
    let weights: Vec<Option<u32>> = messages
        .iter()
        .map(|a| {
            let nearby = messages
                .iter()
                .filter(|b| a.id != b.id)
                .filter(|b| {
                    let distance = (a.x - b.x).powi(2)
                        + (a.y - b.y).powi(2)
                        + (a.z - b.z).powi(2);
                    // 7.5 squared; comparing squared distances avoids a sqrt per pair
                    distance < 56.25
                })
                .count();

            if nearby <= 2 {
                // always include groups of three or fewer
                return None;
            }

            let score = (a.positive_votes - a.negative_votes).max(0);
            let raw_weight = score as f32 * (1.0 / nearby as f32);
            // A float-to-int `as` cast truncates toward zero and saturates, so this can never
            // wrap around; the floor of 1 gives every crowded message a nonzero chance.
            Some((raw_weight as u32).max(1))
        })
        .collect();

    // Nothing is crowded: keep everything.
    let max_weight = match weights.iter().flatten().copied().max() {
        Some(max) => max,
        None => return,
    };

    let mut rng = rand::thread_rng();
    let mut weights = weights.into_iter();
    // `retain` visits each element exactly once in the original order, so the weight iterator
    // stays in lockstep with the messages.
    messages.retain(|_| match weights.next().flatten() {
        // weight / max_weight chance of being kept; 1 <= weight <= max_weight always holds here,
        // which is what `gen_ratio` requires.
        Some(weight) => rng.gen_ratio(weight, max_weight),
        None => true,
    });
}