Skip to content

Commit

Permalink
fix(pool): fix pool max ops per sender during replacement
Browse files Browse the repository at this point in the history
  • Loading branch information
dancoombs committed Jan 5, 2025
1 parent ee1de4f commit 7fc60a7
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 31 deletions.
31 changes: 31 additions & 0 deletions crates/pool/src/mempool/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,10 @@ where
})
}

pub(crate) fn all_operations(&self) -> impl Iterator<Item = Arc<PoolOperation>> + '_ {
self.by_hash.values().map(|o| o.po.clone())
}

/// Does maintenance on the pool.
///
/// 1) Removes all operations using the given entity, returning the hashes of the removed operations.
Expand Down Expand Up @@ -827,6 +831,33 @@ mod tests {
check_map_entry(pool.best.iter().nth(2), Some(&ops[0]));
}

#[test]
fn add_operations() {
let mut pool = pool();
let addr_a = Address::random();
let addr_b = Address::random();
let ops = vec![
create_op(addr_a, 0, 1),
create_op(addr_a, 1, 2),
create_op(addr_b, 0, 3),
];

let mut hashes = HashSet::new();
for op in ops.iter() {
hashes.insert(pool.add_operation(op.clone(), 0).unwrap());
}

let all = pool
.all_operations()
.map(|op| {
op.uo
.hash(pool.config.entry_point, pool.config.chain_spec.id)
})
.collect::<HashSet<_>>();

assert_eq!(all, hashes);
}

#[test]
fn best_ties() {
let mut pool = pool();
Expand Down
27 changes: 18 additions & 9 deletions crates/pool/src/mempool/reputation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ impl AddressReputation {
self.state.write().add_seen(address);
}

pub(crate) fn remove_seen(&self, address: Address, value: u64) {
self.state.write().remove_seen(address, value);
pub(crate) fn dec_seen(&self, address: Address) {
self.state.write().dec_seen(address);
}

pub(crate) fn handle_urep_030_penalty(&self, address: Address) {
Expand All @@ -132,6 +132,10 @@ impl AddressReputation {
self.state.write().handle_srep_050_penalty(address);
}

pub(crate) fn handle_erep_015_amendment(&self, address: Address, value: u64) {
self.state.write().handle_erep_015_amendment(address, value);
}

pub(crate) fn dump_reputation(&self) -> Vec<Reputation> {
self.state.read().dump_reputation()
}
Expand All @@ -140,8 +144,8 @@ impl AddressReputation {
self.state.write().add_included(address);
}

pub(crate) fn remove_included(&self, address: Address) {
self.state.write().remove_included(address);
pub(crate) fn dec_included(&self, address: Address) {
self.state.write().dec_included(address);
}

pub(crate) fn set_reputation(&self, address: Address, ops_seen: u64, ops_included: u64) {
Expand Down Expand Up @@ -222,9 +226,9 @@ impl AddressReputationInner {
count.ops_seen += 1;
}

fn remove_seen(&mut self, address: Address, value: u64) {
fn dec_seen(&mut self, address: Address) {
let count = self.counts.entry(address).or_default();
count.ops_seen = count.ops_seen.saturating_sub(value);
count.ops_seen = count.ops_seen.saturating_sub(1);
}

fn handle_urep_030_penalty(&mut self, address: Address) {
Expand All @@ -238,6 +242,11 @@ impl AddressReputationInner {
count.ops_seen = self.params.bundle_invalidation_ops_seen_staked_penalty;
}

pub(crate) fn handle_erep_015_amendment(&mut self, address: Address, value: u64) {
let count = self.counts.entry(address).or_default();
count.ops_seen = count.ops_seen.saturating_sub(value);
}

fn dump_reputation(&self) -> Vec<Reputation> {
self.counts
.iter()
Expand All @@ -254,7 +263,7 @@ impl AddressReputationInner {
count.ops_included += 1;
}

fn remove_included(&mut self, address: Address) {
fn dec_included(&mut self, address: Address) {
let count = self.counts.entry(address).or_default();
count.ops_included = count.ops_included.saturating_sub(1)
}
Expand Down Expand Up @@ -325,8 +334,8 @@ mod tests {
assert_eq!(counts.ops_seen, 1000);
assert_eq!(counts.ops_included, 1000);

reputation.remove_seen(addr, 1);
reputation.remove_included(addr);
reputation.dec_seen(addr);
reputation.dec_included(addr);
let counts = reputation.counts.get(&addr).unwrap();
assert_eq!(counts.ops_seen, 999);
assert_eq!(counts.ops_included, 999);
Expand Down
118 changes: 96 additions & 22 deletions crates/pool/src/mempool/uo_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ where

if let Some(po) = pool_op {
for entity_addr in po.entities().map(|e| e.address).unique() {
self.reputation.remove_included(entity_addr);
self.reputation.dec_included(entity_addr);
}

unmined_op_count += 1;
Expand Down Expand Up @@ -468,9 +468,6 @@ where
origin: OperationOrigin,
op: UserOperationVariant,
) -> MempoolResult<B256> {
// TODO(danc) aggregator reputation is not implemented
// TODO(danc) catch ops with aggregators prior to simulation and reject

// Check reputation of entities in involved in the operation
// If throttled, entity can have THROTTLED_ENTITY_MEMPOOL_COUNT inflight operation at a time, else reject
// If banned, reject
Expand Down Expand Up @@ -521,6 +518,8 @@ where
// Check if op is already known or replacing another, and if so, ensure its fees are high enough
// do this before simulation to save resources
let replacement = self.state.read().pool.check_replacement(&op)?;
let to_replace = replacement.and_then(|r| self.state.read().pool.get_operation_by_hash(r));

// Check if op violates the STO-040 spec rule
self.state.read().pool.check_multiple_roles_violation(&op)?;

Expand Down Expand Up @@ -587,6 +586,7 @@ where
{
let state = self.state.read();
if !pool_op.account_is_staked
&& to_replace.is_none()
&& state.pool.address_count(&pool_op.uo.sender())
>= self.config.same_sender_mempool_count
{
Expand All @@ -599,9 +599,16 @@ where
// Check unstaked non-sender entity counts in the mempool
for entity in pool_op
.unstaked_entities()
.unique()
.filter(|e| e.address != pool_op.entity_infos.sender.address())
{
let ops_allowed = self.reputation.get_ops_allowed(entity.address);
let mut ops_allowed = self.reputation.get_ops_allowed(entity.address);
if let Some(to_replace) = &to_replace {
if to_replace.entities().contains(&entity) {
ops_allowed += 1;
}
}

if state.pool.address_count(&entity.address) >= ops_allowed as usize {
return Err(MempoolError::MaxOperationsReached(
ops_allowed as usize,
Expand All @@ -628,17 +635,20 @@ where
// once the operation has been added to the pool
self.paymaster.add_or_update_balance(&pool_op).await?;

// Update reputation
if replacement.is_none() {
pool_op.entities().unique().for_each(|e| {
self.reputation.add_seen(e.address);
if self.reputation.status(e.address) == ReputationStatus::Throttled {
self.throttle_entity(e);
} else if self.reputation.status(e.address) == ReputationStatus::Banned {
self.remove_entity(e);
}
// Update reputation, handling replacement if needed
if let Some(to_replace) = to_replace {
to_replace.entities().unique().for_each(|e| {
self.reputation.dec_seen(e.address);
});
}
pool_op.entities().unique().for_each(|e| {
self.reputation.add_seen(e.address);
if self.reputation.status(e.address) == ReputationStatus::Throttled {
self.throttle_entity(e);
} else if self.reputation.status(e.address) == ReputationStatus::Banned {
self.remove_entity(e);
}
});

// Emit event
let op_hash = pool_op
Expand Down Expand Up @@ -735,17 +745,19 @@ where
entity.is_paymaster(),
"Attempted to add EREP-015 paymaster amendment for non-paymaster entity"
);
assert!(
update.value.is_some(),
"PaymasterOpsSeenDecrement must carry an explicit decrement value"
self.reputation.handle_erep_015_amendment(
entity.address,
update
.value
.expect("PaymasterOpsSeenDecrement must carry an explicit decrement value"),
);
self.reputation
.remove_seen(entity.address, update.value.unwrap());
}
}

if self.reputation.status(entity.address) == ReputationStatus::Banned {
self.remove_entity(entity);
} else if self.reputation.status(entity.address) == ReputationStatus::Throttled {
self.throttle_entity(entity);
}
}

Expand Down Expand Up @@ -787,7 +799,7 @@ where
}

fn all_operations(&self, max: usize) -> Vec<Arc<PoolOperation>> {
self.state.read().pool.best_operations().take(max).collect()
self.state.read().pool.all_operations().take(max).collect()
}

fn get_user_operation_by_hash(&self, hash: B256) -> Option<Arc<PoolOperation>> {
Expand Down Expand Up @@ -1003,6 +1015,25 @@ mod tests {
assert_eq!(pool.best_operations(3, 0).unwrap(), vec![]);
}

#[tokio::test]
async fn all_operations() {
let ops = vec![
create_op(Address::random(), 0, 3, None),
create_op(Address::random(), 0, 2, None),
create_op(Address::random(), 0, 1, None),
];
let uos = ops.iter().map(|op| op.op.clone()).collect::<Vec<_>>();
let pool = create_pool(ops);

for op in &uos {
let _ = pool
.add_operation(OperationOrigin::Local, op.clone())
.await
.unwrap();
}
check_ops_unordered(&pool.all_operations(16), &uos, pool.config.entry_point);
}

#[tokio::test]
async fn chain_update_mine() {
let paymaster = Address::random();
Expand Down Expand Up @@ -1288,7 +1319,7 @@ mod tests {
}

check_ops(
pool.all_operations(4),
pool.best_operations(4, 0).unwrap(),
vec![
uos[0].clone(),
uos[1].clone(),
Expand Down Expand Up @@ -1338,7 +1369,7 @@ mod tests {
.await
.unwrap();
check_ops(
pool.all_operations(4),
pool.best_operations(4, 0).unwrap(),
vec![
uos[1].clone(),
uos[2].clone(),
Expand Down Expand Up @@ -1680,6 +1711,33 @@ mod tests {
.is_err());
}

#[tokio::test]
async fn test_replacement_max_ops_for_unstaked_sender() {
let mut ops = vec![];
let addr = Address::random();
for i in 0..4 {
ops.push(create_op(addr, i, 1, None))
}
// replacement op for first op
ops.push(create_op(addr, 0, 2, None));

let pool = create_pool(ops.clone());

for op in ops.iter().take(4) {
pool.add_operation(OperationOrigin::Local, op.op.clone())
.await
.unwrap();
}

pool.add_operation(OperationOrigin::Local, ops[4].op.clone())
.await
.unwrap();

let uos = ops.into_iter().skip(1).map(|op| op.op).collect::<Vec<_>>();

check_ops_unordered(&pool.all_operations(16), &uos, pool.config.entry_point);
}

#[tokio::test]
async fn test_best_staked() {
let address = Address::random();
Expand Down Expand Up @@ -2043,4 +2101,20 @@ mod tests {
assert_eq!(actual.uo, expected);
}
}

fn check_ops_unordered(
actual: &[Arc<PoolOperation>],
expected: &[UserOperationVariant],
entry_point: Address,
) {
let actual_hashes = actual
.iter()
.map(|op| op.uo.hash(entry_point, 0))
.collect::<HashSet<_>>();
let expected_hashes = expected
.iter()
.map(|op| op.hash(entry_point, 0))
.collect::<HashSet<_>>();
assert_eq!(actual_hashes, expected_hashes);
}
}

0 comments on commit 7fc60a7

Please sign in to comment.