implement base indexes logic

This commit is contained in:
Andrea Franz
2024-09-10 01:08:53 +02:00
parent 6b86d96142
commit 7e15f37eba
3 changed files with 116 additions and 8 deletions

View File

@@ -12,7 +12,7 @@ contract StakeManagerScript is Script {
function run() public {
vm.startBroadcast();
stakeManager = new StakeManager();
// stakeManager = new StakeManager();
vm.stopBroadcast();
}

View File

@@ -1,14 +1,122 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.26;
contract StakeManager {
uint256 public number;
import {IERC20} from "@openzeppelin/contracts/token/ERC20/IERC20.sol";
import {ReentrancyGuard} from "@openzeppelin/contracts/utils/ReentrancyGuard.sol";
function setNumber(uint256 newNumber) public {
number = newNumber;
contract StakeManager is ReentrancyGuard {
error StakingManager__AmountCannotBeZero();
error StakingManager__TransferFailed();
error StakingManager__InsufficientBalance();
IERC20 public immutable stakingToken;
IERC20 public immutable rewardToken;
uint256 public constant SCALE_FACTOR = 1e18;
uint256 public totalStaked;
uint256 public rewardIndex;
uint256 public accountedRewards;
struct UserInfo {
uint256 stakedBalance;
uint256 userRewardIndex;
}
function increment() public {
number++;
mapping(address => UserInfo) public users;
constructor(address _stakingToken, address _rewardToken) {
stakingToken = IERC20(_stakingToken);
rewardToken = IERC20(_rewardToken);
}
function stake(uint256 amount) external nonReentrant {
if (amount == 0) {
revert StakingManager__AmountCannotBeZero();
}
updateRewardIndex();
UserInfo storage user = users[msg.sender];
uint256 userRewards = calculateUserRewards(msg.sender);
if (userRewards > 0) {
safeRewardTransfer(msg.sender, userRewards);
}
bool success = stakingToken.transferFrom(msg.sender, address(this), amount);
if (!success) {
revert StakingManager__TransferFailed();
}
user.stakedBalance += amount;
totalStaked += amount;
user.userRewardIndex = rewardIndex;
}
function unstake(uint256 amount) external nonReentrant {
UserInfo storage user = users[msg.sender];
if (amount > user.stakedBalance) {
revert StakingManager__InsufficientBalance();
}
updateRewardIndex();
uint256 userRewards = calculateUserRewards(msg.sender);
if (userRewards > 0) {
safeRewardTransfer(msg.sender, userRewards);
}
user.stakedBalance -= amount;
totalStaked -= amount;
bool success = stakingToken.transfer(msg.sender, amount);
if (!success) {
revert StakingManager__TransferFailed();
}
user.userRewardIndex = rewardIndex;
}
function updateRewardIndex() public {
if (totalStaked == 0) {
return;
}
uint256 rewardBalance = rewardToken.balanceOf(address(this));
uint256 newRewards = rewardBalance > accountedRewards ? rewardBalance - accountedRewards : 0;
if (newRewards > 0) {
rewardIndex += (newRewards * SCALE_FACTOR) / totalStaked;
accountedRewards += newRewards;
}
}
function getStakedBalance(address userAddress) public view returns (uint256) {
return users[userAddress].stakedBalance;
}
function getPendingRewards(address userAddress) public view returns (uint256) {
return calculateUserRewards(userAddress);
}
function calculateUserRewards(address userAddress) public view returns (uint256) {
UserInfo storage user = users[userAddress];
return (user.stakedBalance * (rewardIndex - user.userRewardIndex)) / SCALE_FACTOR;
}
function safeRewardTransfer(address to, uint256 amount) internal {
uint256 rewardBalance = rewardToken.balanceOf(address(this));
// If amount is higher than the contract's balance (for rounding error), transfer the balance.
if (amount > rewardBalance) {
bool success = rewardToken.transfer(to, rewardBalance);
if (!success) {
revert StakingManager__TransferFailed();
}
} else {
bool success = rewardToken.transfer(to, amount);
if (!success) {
revert StakingManager__TransferFailed();
}
}
}
}

View File

@@ -8,6 +8,6 @@ contract StakeManagerTest is Test {
StakeManager public stakeManager;
function setUp() public {
stakeManager = new StakeManager();
// stakeManager = new StakeManager();
}
}