misc: make Multirecipient fail with 1 recipient

This commit is contained in:
vicnaum
2023-10-12 09:18:25 +02:00
parent e36c4f2b48
commit a4fd0b3d11
3 changed files with 37 additions and 51 deletions

View File

@@ -128,36 +128,27 @@ contract MultirecipientFeeCollectModule is BaseFeeCollectModule {
uint256 len = recipients.length;
// Check number of recipients is supported
if (len == 0) {
if (len < 2) {
revert Errors.InitParamsInvalid();
}
// Skip loop check if only 1 recipient in the array
if (len == 1) {
if (recipients[0].split != BPS_MAX) {
revert InvalidRecipientSplits();
// Check recipient splits sum to 10 000 BPS (100%)
uint256 totalSplits;
uint256 i;
while (i < len) {
if (recipients[i].split == 0) revert RecipientSplitCannotBeZero();
totalSplits += recipients[i].split;
// Store each recipient while looping - avoids extra gas costs in successful cases
_recipientsByPublicationByProfile[profileId][pubId].push(recipients[i]);
unchecked {
++i;
}
}
// If single recipient passes check above, store and return
_recipientsByPublicationByProfile[profileId][pubId].push(recipients[0]);
} else {
// Check recipient splits sum to 10 000 BPS (100%)
uint256 totalSplits;
for (uint256 i = 0; i < len; ) {
if (recipients[i].split == 0) revert RecipientSplitCannotBeZero();
totalSplits += recipients[i].split;
// Store each recipient while looping - avoids extra gas costs in successful cases
_recipientsByPublicationByProfile[profileId][pubId].push(recipients[i]);
unchecked {
++i;
}
}
if (totalSplits != BPS_MAX) {
revert InvalidRecipientSplits();
}
if (totalSplits != BPS_MAX) {
revert InvalidRecipientSplits();
}
}
@@ -176,26 +167,17 @@ contract MultirecipientFeeCollectModule is BaseFeeCollectModule {
][processCollectParams.publicationCollectedId];
uint256 len = recipients.length;
// If only 1 recipient, transfer full amount and skip split calculations
if (len == 1) {
IERC20(currency).safeTransferFrom(
processCollectParams.transactionExecutor,
recipients[0].recipient,
amount
);
} else {
uint256 i;
while (i < len) {
uint256 amountForRecipient = (amount * recipients[i].split) / BPS_MAX;
if (amountForRecipient != 0)
IERC20(currency).safeTransferFrom(
processCollectParams.transactionExecutor,
recipients[i].recipient,
amountForRecipient
);
unchecked {
++i;
}
uint256 i;
while (i < len) {
uint256 amountForRecipient = (amount * recipients[i].split) / BPS_MAX;
if (amountForRecipient != 0)
IERC20(currency).safeTransferFrom(
processCollectParams.transactionExecutor,
recipients[i].recipient,
amountForRecipient
);
unchecked {
++i;
}
}
}

View File

@@ -47,10 +47,14 @@ contract MultirecipientCollectModuleBase is BaseFeeCollectModuleBase {
multirecipientExampleInitData.referralFee = exampleInitData.referralFee;
multirecipientExampleInitData.followerOnly = exampleInitData.followerOnly;
multirecipientExampleInitData.endTimestamp = exampleInitData.endTimestamp;
if (multirecipientExampleInitData.recipients.length == 0)
if (multirecipientExampleInitData.recipients.length == 0) {
multirecipientExampleInitData.recipients.push(
RecipientData({recipient: exampleInitData.recipient, split: BPS_MAX})
RecipientData({recipient: exampleInitData.recipient, split: BPS_MAX / 2})
);
multirecipientExampleInitData.recipients.push(
RecipientData({recipient: exampleInitData.recipient, split: BPS_MAX / 2})
);
}
return abi.encode(multirecipientExampleInitData);
}

View File

@@ -65,7 +65,7 @@ contract MultirecipientCollectModule_Initialization is
);
}
function testCannotPostWithOneRecipientAndSplitNotEqualToBPS_MAX(
function testCannotPostWithOneRecipient(
uint256 profileId,
uint256 pubId,
address transactionExecutor,
@@ -76,12 +76,12 @@ contract MultirecipientCollectModule_Initialization is
vm.assume(pubId != 0);
vm.assume(transactionExecutor != address(0));
vm.assume(recipient != address(0));
split = uint16(bound(split, 0, BPS_MAX - 1));
split = uint16(bound(split, 0, BPS_MAX));
delete multirecipientExampleInitData.recipients;
multirecipientExampleInitData.recipients.push(RecipientData({recipient: recipient, split: split}));
vm.expectRevert(InvalidRecipientSplits.selector);
vm.expectRevert(ModuleErrors.InitParamsInvalid.selector);
vm.prank(collectPublicationAction);
IBaseFeeCollectModule(baseFeeCollectModule).initializePublicationCollectModule(
profileId,
@@ -430,7 +430,7 @@ contract MultirecipientCollectModule_FeeDistribution is MultirecipientCollectMod
delete multirecipientExampleInitData.recipients;
assertEq(multirecipientExampleInitData.recipients.length, 0);
numberOfRecipients = bound(numberOfRecipients, 1, MAX_RECIPIENTS);
numberOfRecipients = bound(numberOfRecipients, 2, MAX_RECIPIENTS);
// console.log('Number of recipients: %s', numberOfRecipients);
split1 = uint16(bound(split1, 1, BPS_MAX - numberOfRecipients));
split2 = uint16(bound(split2, 1, BPS_MAX - numberOfRecipients));