diff --git a/src/implementation/TimeWeightedVotingPower.sol b/src/implementation/TimeWeightedVotingPower.sol index 1419ac6..f084aca 100644 --- a/src/implementation/TimeWeightedVotingPower.sol +++ b/src/implementation/TimeWeightedVotingPower.sol @@ -9,12 +9,19 @@ import {Ownable} from "@solady/contracts/auth/Ownable.sol"; /// @title TimeWeightedVotingPower /// @author BreadKit -/// @notice Lossless time-weighted voting power calculation using the breadchain pattern -/// @dev Walks the token's ERC20Votes checkpoint array to compute the exact -/// area-under-the-curve of delegated votes over the current cycle, then -/// divides by the period length to produce a time-weighted average. -/// The lookback window is derived from the cycle module's cycle length. -/// Every balance change is accounted for — no sampling or approximation. +/// @notice Time-weighted voting power with optional per-interval quadratic scaling +/// @dev By default computes a lossless time-weighted average over the cycle window. +/// When a non-zero scalingPeriod is set, applies a quadratic penalty to each +/// checkpoint interval where duration < scalingPeriod: +/// +/// - intervalLength >= scalingPeriod: contribution = value * intervalLength (no penalty) +/// - intervalLength < scalingPeriod: contribution = value * intervalLength^2 / scalingPeriod +/// +/// This makes flash-loan attacks progressively more expensive even when the +/// total period is long — an attacker holding tokens for only 1 block with +/// scalingPeriod=10 gets ~100x less voting power than the nominal amount. +/// +/// Every balance change in the ERC20Votes checkpoint array is fully accounted for. contract TimeWeightedVotingPower is IVotingPowerStrategy, Ownable { // ============ Errors ============ @@ -30,31 +37,59 @@ contract TimeWeightedVotingPower is IVotingPowerStrategy, Ownable { /// @notice Thrown when end block is in the future error FuturePeriod(); + /// @notice Thrown when scaling period exceeds the maximum allowed value + error ScalingPeriodTooLarge(); + + /// @notice Maximum allowed scaling period (~30 days in blocks at 12s/block) + uint256 public constant MAX_SCALING_PERIOD = 365 days / 12; + // ============ Immutable Storage ============ /// @notice The ERC20Votes token used for voting power calculation - IVotesCheckpoints public immutable votingToken; + IVotesCheckpoints public immutable VOTING_TOKEN; /// @notice The cycle module for period tracking and lookback derivation - ICycleModule public immutable cycleModule; + ICycleModule public immutable CYCLE_MODULE; + + /// @notice Scaling period in blocks for the quadratic flash-loan penalty. + /// @dev When 0 (default), no penalty is applied (classic time-weighted average). + /// When set to a non-zero value, intervals shorter than this period receive + /// a quadratic penalty: contribution = value * duration^2 / scalingPeriod. + /// Example: scalingPeriod=100, attacker holds 1000 ETH for 1 block → + /// 10 ETH effective power (100x reduction). + /// WARNING: The owner can change scalingPeriod mid-cycle. A timelock or + /// governance delay is recommended for production deployments. + uint256 public scalingPeriod; + + // ============ Events ============ + + /// @notice Emitted when the scaling period is updated + /// @param oldPeriod The previous scaling period value + /// @param newPeriod The new scaling period value + event ScalingPeriodUpdated(uint256 oldPeriod, uint256 newPeriod); + + // ============ Constructor ============ /// @notice Constructs the time-weighted voting power strategy /// @dev Reverts if either token or cycle module address is zero /// @param _votingToken The ERC20Votes token with checkpoint support /// @param _cycleModule The cycle module for period tracking - constructor(IVotesCheckpoints _votingToken, ICycleModule _cycleModule) { + /// @param _scalingPeriod Initial scaling period in blocks (0 = disabled) + constructor(IVotesCheckpoints _votingToken, ICycleModule _cycleModule, uint256 _scalingPeriod) { if (address(_votingToken) == address(0)) revert InvalidToken(); if (address(_cycleModule) == address(0)) revert InvalidCycleModule(); + if (_scalingPeriod > MAX_SCALING_PERIOD) revert ScalingPeriodTooLarge(); - votingToken = _votingToken; - cycleModule = _cycleModule; + VOTING_TOKEN = _votingToken; + CYCLE_MODULE = _cycleModule; + scalingPeriod = _scalingPeriod; _initializeOwner(msg.sender); } /// @inheritdoc IVotingPowerStrategy function getCurrentVotingPower(address account) external view override returns (uint256) { - uint256 cycleStart = cycleModule.lastCycleStartBlock(); + uint256 cycleStart = CYCLE_MODULE.lastCycleStartBlock(); uint256 periodEnd = block.number; uint256 periodStart = cycleStart; @@ -82,12 +117,46 @@ contract TimeWeightedVotingPower is IVotingPowerStrategy, Ownable { return _calculateTimeWeightedPower(account, startBlock, endBlock); } + // ============ Admin Functions ============ + + /// @notice Sets the scaling period for flash-loan quadratic penalty + /// @dev Only callable by owner. Setting to 0 disables the penalty (classic TWAV). + /// @param _scalingPeriod New scaling period in blocks + function setScalingPeriod(uint256 _scalingPeriod) external onlyOwner { + if (_scalingPeriod > MAX_SCALING_PERIOD) revert ScalingPeriodTooLarge(); + uint256 old = scalingPeriod; + scalingPeriod = _scalingPeriod; + emit ScalingPeriodUpdated(old, _scalingPeriod); + } + + // ============ Internal Functions ============ + + /// @dev Applies the quadratic scaling penalty to a checkpoint interval. + /// When scalingPeriod is 0, returns area unchanged (no penalty). + /// When intervalLength >= scalingPeriod, returns area unchanged (fully vested). + /// When intervalLength < scalingPeriod, returns area * intervalLength / scalingPeriod. + /// Since area = value * intervalLength, the effective formula is: + /// value * intervalLength^2 / scalingPeriod (quadratic in duration). + /// @param area The raw area contribution (value * intervalLength) + /// @param intervalLength The duration of this checkpoint interval in blocks + /// @return The scaled area contribution + function _applyScalingPenalty(uint256 area, uint256 intervalLength) internal view returns (uint256) { + if (scalingPeriod == 0 || intervalLength >= scalingPeriod) { + return area; + } + return (area * intervalLength) / scalingPeriod; + } + /// @dev Walks the token's checkpoint array in reverse to compute the exact /// integral of (delegated votes * blocks held) over [start, end), then /// divides by the period length to produce the time-weighted average. /// This is the breadchain pattern — every balance change is accounted for. + /// + /// When scalingPeriod is non-zero, each interval shorter than scalingPeriod + /// receives a quadratic penalty (see _applyScalingPenalty), making flash-loan + /// attacks progressively more expensive. function _calculateTimeWeightedPower(address account, uint256 start, uint256 end) internal view returns (uint256) { - uint32 numCkpts = votingToken.numCheckpoints(account); + uint32 numCkpts = VOTING_TOKEN.numCheckpoints(account); if (numCkpts == 0) return 0; uint256 periodLength = end - start; @@ -95,21 +164,26 @@ contract TimeWeightedVotingPower is IVotingPowerStrategy, Ownable { uint256 upperBound = end; for (uint32 i = numCkpts; i > 0; i--) { - Checkpoints.Checkpoint208 memory ckpt = votingToken.checkpoints(account, i - 1); + Checkpoints.Checkpoint208 memory ckpt = VOTING_TOKEN.checkpoints(account, i - 1); uint256 key = uint256(ckpt._key); uint256 value = uint256(ckpt._value); // Checkpoint is at or after the period end — skip it if (key >= end) continue; + uint256 intervalLength; if (key <= start) { // Checkpoint predates the period — its value covers [start, upperBound) - totalArea += value * (upperBound - start); + intervalLength = upperBound - start; + uint256 area = value * intervalLength; + totalArea += _applyScalingPenalty(area, intervalLength); break; } // Checkpoint is within (start, end) — its value covers [key, upperBound) - totalArea += value * (upperBound - key); + intervalLength = upperBound - key; + uint256 contribution = value * intervalLength; + totalArea += _applyScalingPenalty(contribution, intervalLength); upperBound = key; } diff --git a/test/TimeWeightedVotingPower.t.sol b/test/TimeWeightedVotingPower.t.sol index a907572..a663a1e 100644 --- a/test/TimeWeightedVotingPower.t.sol +++ b/test/TimeWeightedVotingPower.t.sol @@ -56,24 +56,24 @@ contract TimeWeightedVotingPowerTest is Test { vm.roll(1); cycleModule.initialize(CYCLE_LENGTH); - strategy = new TimeWeightedVotingPower(IVotesCheckpoints(address(token)), ICycleModule(address(cycleModule))); + strategy = new TimeWeightedVotingPower(IVotesCheckpoints(address(token)), ICycleModule(address(cycleModule)), 0); } // ============ Constructor Tests ============ function testConstructorSetsState() public view { - assertEq(address(strategy.votingToken()), address(token)); - assertEq(address(strategy.cycleModule()), address(cycleModule)); + assertEq(address(strategy.VOTING_TOKEN()), address(token)); + assertEq(address(strategy.CYCLE_MODULE()), address(cycleModule)); } function testConstructorRevertsInvalidToken() public { vm.expectRevert(TimeWeightedVotingPower.InvalidToken.selector); - new TimeWeightedVotingPower(IVotesCheckpoints(address(0)), ICycleModule(address(cycleModule))); + new TimeWeightedVotingPower(IVotesCheckpoints(address(0)), ICycleModule(address(cycleModule)), 0); } function testConstructorRevertsInvalidCycleModule() public { vm.expectRevert(TimeWeightedVotingPower.InvalidCycleModule.selector); - new TimeWeightedVotingPower(IVotesCheckpoints(address(token)), ICycleModule(address(0))); + new TimeWeightedVotingPower(IVotesCheckpoints(address(token)), ICycleModule(address(0)), 0); } // ============ Lossless Calculation Tests ============ @@ -358,4 +358,161 @@ contract TimeWeightedVotingPowerTest is Test { // With exact checkpoint walking and only 1 checkpoint, gas should be very low assertLt(gasUsed, 30_000, "Gas should be very low with few checkpoints"); } + + // ============ scalingPeriod Feature Tests (Issue #91) ============ + + // Test 1: scalingPeriod=0 disables the quadratic penalty entirely + function testScalingPeriodDisabledWhenZero() public { + // strategy was created with scalingPeriod=0 in setUp + assertEq(strategy.scalingPeriod(), 0); + + vm.roll(10); + token.mint(user1, 100 ether); + + // Advance to block 500 + vm.roll(500); + // Mint more tokens (this checkpoint is outside our query period) + token.mint(user1, 50 ether); + + vm.roll(501); + + // Query period [10, 14): user held 100 ether the whole interval + // With scalingPeriod=0, no penalty. intervalLength=4, area=400 ether, avg=100 ether + uint256 power = strategy.getVotingPowerForPeriod(user1, 10, 14); + + // Baseline: 100 ether held for entire 4-block period, no penalty + uint256 baseline = 100 ether; + assertEq(power, baseline, "scalingPeriod=0 should apply no penalty"); + } + + // Test 2: Quadratic penalty reduces voting power for short intervals + function testQuadraticPenaltyApplied() public { + // Create a new strategy with scalingPeriod=100 + TimeWeightedVotingPower scaledStrategy = + new TimeWeightedVotingPower(IVotesCheckpoints(address(token)), ICycleModule(address(cycleModule)), 100); + + vm.roll(10); + token.mint(user1, 100 ether); + + // Advance to block 50 (intervalLength = 40, which is < scalingPeriod=100) + vm.roll(50); + + // Query period [10, 50): 40 blocks + // Checkpoint at block 10 predates/equals start => intervalLength = upperBound - start = 50-10 = 40 + // area = 100 ether * 40 = 4000 ether + // scaled = (4000 ether * 40) / 100 = 1600 ether + // avg = 1600 ether / 40 = 40 ether + uint256 power = scaledStrategy.getVotingPowerForPeriod(user1, 10, 50); + assertEq(power, 40 ether, "Quadratic penalty should reduce power for short intervals"); + + // Without penalty the power would be 100 ether — confirm reduction + uint256 unpenalizedPower = strategy.getVotingPowerForPeriod(user1, 10, 50); + assertEq(unpenalizedPower, 100 ether, "Unpenalized strategy should return full balance"); + assertLt(power, unpenalizedPower, "Penalty should reduce voting power"); + } + + // Test 3: No penalty when intervalLength exactly equals scalingPeriod + function testNoPenaltyWhenIntervalExceedsScalingPeriod() public { + // scalingPeriod=100, intervalLength=100 => factor = 100/100 = 1.0 => no reduction + TimeWeightedVotingPower scaledStrategy = + new TimeWeightedVotingPower(IVotesCheckpoints(address(token)), ICycleModule(address(cycleModule)), 100); + + vm.roll(10); + token.mint(user1, 100 ether); + + // Advance to block 110 (intervalLength = 100 == scalingPeriod) + vm.roll(110); + + // area = 100 ether * 100 = 10000 ether + // intervalLength (100) >= scalingPeriod (100) => no penalty + // avg = 10000 ether / 100 = 100 ether + uint256 power = scaledStrategy.getVotingPowerForPeriod(user1, 10, 110); + assertEq(power, 100 ether, "No penalty when intervalLength == scalingPeriod"); + } + + // Test 4: No penalty when intervalLength strictly exceeds scalingPeriod + function testScalingPeriodOnlyAffectsShortIntervals() public { + // scalingPeriod=100, intervalLength=150 > scalingPeriod => no penalty + TimeWeightedVotingPower scaledStrategy = + new TimeWeightedVotingPower(IVotesCheckpoints(address(token)), ICycleModule(address(cycleModule)), 100); + + vm.roll(10); + token.mint(user1, 100 ether); + + // Advance to block 160 (intervalLength = 150 > scalingPeriod=100) + vm.roll(160); + + // area = 100 ether * 150 = 15000 ether + // intervalLength (150) >= scalingPeriod (100) => no penalty + // avg = 15000 ether / 150 = 100 ether + uint256 power = scaledStrategy.getVotingPowerForPeriod(user1, 10, 160); + assertEq(power, 100 ether, "No penalty when intervalLength > scalingPeriod"); + } + + // Test 5: Owner can update scalingPeriod + function testSetScalingPeriodByOwner() public { + // strategy created with scalingPeriod=0 in setUp, owner = address(this) + assertEq(strategy.scalingPeriod(), 0); + + strategy.setScalingPeriod(100); + + assertEq(strategy.scalingPeriod(), 100, "scalingPeriod should be updated to 100"); + } + + // Test 6: setScalingPeriod emits ScalingPeriodUpdated event + function testSetScalingPeriodEmitsEvent() public { + assertEq(strategy.scalingPeriod(), 0); + + vm.expectEmit(true, true, true, true, address(strategy)); + emit TimeWeightedVotingPower.ScalingPeriodUpdated(0, 100); + + strategy.setScalingPeriod(100); + } + + // Test 7: Non-owner cannot set scalingPeriod + function testSetScalingPeriodRevertsNonOwner() public { + address nonOwner = address(0xBEEF1234); + vm.prank(nonOwner); + vm.expectRevert(abi.encodeWithSignature("Unauthorized()")); + strategy.setScalingPeriod(100); + } + + // Test 7b: setScalingPeriod reverts if value exceeds MAX_SCALING_PERIOD + function testSetScalingPeriodRevertsExceedsMax() public { + vm.expectRevert(abi.encodeWithSignature("ScalingPeriodTooLarge()")); + strategy.setScalingPeriod(type(uint64).max); + } + + // Test 8: Flash loan attack is mitigated by quadratic scaling + function testFlashLoanMitigatedByScaling() public { + // scalingPeriod=100: attacker holding tokens for 1 block gets 100x less power + TimeWeightedVotingPower scaledStrategy = + new TimeWeightedVotingPower(IVotesCheckpoints(address(token)), ICycleModule(address(cycleModule)), 100); + + // Attacker mints 1000 ether at block 500 + vm.roll(500); + token.mint(user2, 1000 ether); + + // Move 1 block so checkpoint is visible + vm.roll(501); + + // Query the 1-block window [500, 501) + // intervalLength = 1 < scalingPeriod=100 + // area = 1000 ether * 1 = 1000 ether + // scaled = (1000 ether * 1) / 100 = 10 ether + // avg = 10 ether / 1 = 10 ether + uint256 scaledPower = scaledStrategy.getVotingPowerForPeriod(user2, 500, 501); + assertEq(scaledPower, 10 ether, "Flash loan with scalingPeriod=100 should give 10 ether not 1000 ether"); + + // Without scaling, the same 1-block window returns the full balance + uint256 unscaledPower = strategy.getVotingPowerForPeriod(user2, 500, 501); + assertEq(unscaledPower, 1000 ether, "Without scaling, full balance is returned"); + + // Confirm the 100x reduction + assertEq( + unscaledPower / scaledPower, + 100, + "Scaling should give 100x reduction for 1-block hold with scalingPeriod=100" + ); + } }