Skip to content

Commit

Permalink
Re-jigged algorithm to compute TWAB of periods
Browse files Browse the repository at this point in the history
  • Loading branch information
asselstine committed Aug 9, 2023
1 parent fdc8380 commit 8d97a9e
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 62 deletions.
19 changes: 19 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,25 @@ For example:
- If a user held 100 tokens for 1 week, then their average balance over that time was 100.
- If instead they held 100 for half of the week and then 200 for the second half, then their average for the week would be 150.

## Notes

We want to compute the TWAB for a given period.

Period TWAB = (cumulativeAtEnd - cumulativeAtStart) / (deltaTime)

cumulativeAtStart is based on previous period observation.
deltaTime is the period length
cumulativeAtEnd can be computed using the TWAB's value.

period TWAB Observation = (cumulativeBalanceAtStart, avgBalance, timestamp)

Compute TWAB between (startTime, endTime)

We compute the start observation of the period containing startTime (balanceAtStart, avgBalance, startPeriodTimestamp)
We compute the start observation of the period containing endTime (balanceAtStart, avgBalance, startPeriodTimestamp)

Then we extrapolate the two observations, and subtract them to compute TWAB.

## Development

### Installation
Expand Down
10 changes: 7 additions & 3 deletions src/libraries/ObservationLib.sol
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ library ObservationLib {
* @param _target Timestamp at which we are searching the Observation.
* @param _cardinality Cardinality of the circular buffer we are searching through.
* @param _time Timestamp at which we perform the binary search.
* @return beforeOrAtIndex Index of Observation recorded before, or at, the target.
* @return beforeOrAt Observation recorded before, or at, the target.
* @return afterOrAtIndex Index of Observation recorded at, or after, the target.
* @return afterOrAt Observation recorded at, or after, the target.
*/
function binarySearch(
Expand All @@ -59,7 +61,7 @@ library ObservationLib {
uint32 _target,
uint16 _cardinality,
uint32 _time
) internal view returns (Observation memory beforeOrAt, Observation memory afterOrAt) {
) internal view returns (uint32 beforeOrAtIndex, Observation memory beforeOrAt, uint32 afterOrAtIndex, Observation memory afterOrAt) {
uint256 leftSide = _oldestObservationIndex;
uint256 rightSide = _newestObservationIndex < leftSide
? leftSide + _cardinality - 1
Expand All @@ -71,7 +73,8 @@ library ObservationLib {
// After each iteration, we narrow down the search to the left or the right side while still starting our search in the middle.
currentIndex = (leftSide + rightSide) / 2;

beforeOrAt = _observations[uint16(RingBufferLib.wrap(currentIndex, _cardinality))];
beforeOrAtIndex = uint16(RingBufferLib.wrap(currentIndex, _cardinality));
beforeOrAt = _observations[beforeOrAtIndex];
uint32 beforeOrAtTimestamp = beforeOrAt.timestamp;

// We've landed on an uninitialized timestamp, keep searching higher (more recently).
Expand All @@ -80,7 +83,8 @@ library ObservationLib {
continue;
}

afterOrAt = _observations[uint16(RingBufferLib.nextIndex(currentIndex, _cardinality))];
afterOrAtIndex = uint16(RingBufferLib.nextIndex(currentIndex, _cardinality));
afterOrAt = _observations[afterOrAtIndex];

bool targetAfterOrAt = beforeOrAtTimestamp.lte(_target, _time);

Expand Down
143 changes: 93 additions & 50 deletions src/libraries/TwabLib.sol
Original file line number Diff line number Diff line change
Expand Up @@ -289,12 +289,7 @@ library TwabLib {
if (_targetTime >= overwritePeriodStartTime) {
revert TimestampNotFinalized(_targetTime, overwritePeriodStartTime);
}
ObservationLib.Observation memory prevOrAtObservation = _getPreviousOrAtObservation(
PERIOD_OFFSET,
_observations,
_accountDetails,
_targetTime
);
ObservationLib.Observation memory periodObs = computePeriodStartObservation(PERIOD_LENGTH, PERIOD_OFFSET, _observations, _accountDetails, _targetTime);
return prevOrAtObservation.balance;
}

Expand Down Expand Up @@ -326,38 +321,89 @@ library TwabLib {
revert InvalidTimeRange(_startTime, _endTime);
}

ObservationLib.Observation memory startObservation = _getPreviousOrAtObservation(
PERIOD_OFFSET,
_observations,
_accountDetails,
_startTime
);

if (_endTime == _startTime) {
return startObservation.balance;
if (_accountDetails.cardinality == 0) {
return 0;
}

uint32 periodEndTime = getPeriodEndTimeWithTimestamp(PERIOD_LENGTH, PERIOD_OFFSET, _endTime);
ObservationLib.Observation memory startPeriodObs = computePeriodStartObservation(PERIOD_LENGTH, PERIOD_OFFSET, _observations, _accountDetails, _startTime);
ObservationLib.Observation memory endPeriodObs = computePeriodStartObservation(PERIOD_LENGTH, PERIOD_OFFSET, _observations, _accountDetails, _endTime);

// recalculate for start
startPeriodObs = _calculateTemporaryObservation(startPeriodObs, _startTime);
endPeriodObs = _calculateTemporaryObservation(endPeriodObs, _endTime);

ObservationLib.Observation memory endObservation = _getPreviousOrAtObservation(
return (endPeriodObs.cumulativeBalance - startPeriodObs.cumulativeBalance) /
(_endTime - _startTime);
}

function computePeriodStartObservation(
uint32 PERIOD_LENGTH,
uint32 PERIOD_OFFSET,
ObservationLib.Observation[MAX_CARDINALITY] storage _observations,
AccountDetails memory _accountDetails,
uint32 _timestamp
) internal view returns (ObservationLib.Observation memory) {
// console2.log("computePeriodStartObservation 1", _timestamp);
uint32 period = _getTimestampPeriod(PERIOD_LENGTH, PERIOD_OFFSET, _timestamp);
uint32 periodEnd = getPeriodEndTime(PERIOD_LENGTH, PERIOD_OFFSET, period);
uint32 periodStart = getPeriodStartTime(PERIOD_LENGTH, PERIOD_OFFSET, period);

// console2.log("computePeriodStartObservation 2", _timestamp);

(uint32 periodObsIndex, ObservationLib.Observation memory periodObs) = _getPreviousOrAtObservation(
PERIOD_OFFSET,
_observations,
_accountDetails,
periodEndTime
periodEnd
);

if (startObservation.timestamp != _endTime) {
startObservation = _calculateTemporaryObservation(startObservation, _startTime);
ObservationLib.Observation memory prevPeriodObs;

// console2.log("computePeriodStartObservation 3", _timestamp);

if (periodObs.timestamp < periodStart) {
// the observation is for both
prevPeriodObs = periodObs;
} else {

// if not first period and has more observations
if (period > 1 && _accountDetails.cardinality > 1) {
// console2.log("computePeriodStartObservation 4", _timestamp);
uint32 prevPeriodObsIndex = uint32(RingBufferLib.prevIndex(periodObsIndex, _accountDetails.cardinality));
prevPeriodObs = _observations[prevPeriodObsIndex];
require(prevPeriodObs.timestamp.lt(periodObs.timestamp, uint32(block.timestamp)), "insufficient history");
} else {
prevPeriodObs = ObservationLib.Observation({
balance: 0,
cumulativeBalance: 0,
timestamp: PERIOD_OFFSET
});
}

}

if (endObservation.timestamp != _endTime) {
endObservation = _calculateTemporaryObservation(endObservation, _endTime);
// console2.log("computePeriodStartObservation 5", _timestamp);

// recalculate for start
prevPeriodObs = _calculateTemporaryObservation(prevPeriodObs, periodStart);

// console2.log("computePeriodStartObservation 6", _timestamp);

if (periodObs.timestamp != periodEnd) {
periodObs = _calculateTemporaryObservation(periodObs, periodEnd);
}

// Difference in amount / time
return
(endObservation.cumulativeBalance - startObservation.cumulativeBalance) /
(_endTime - _startTime);
// console2.log("computePeriodStartObservation 7", _timestamp);

uint96 twab = uint96(
(periodObs.cumulativeBalance - prevPeriodObs.cumulativeBalance) / PERIOD_LENGTH
);

return ObservationLib.Observation({
balance: twab,
timestamp: periodStart,
cumulativeBalance: prevPeriodObs.cumulativeBalance
});
}

/**
Expand All @@ -371,6 +417,8 @@ library TwabLib {
ObservationLib.Observation memory _prevObservation,
uint32 _time
) private view returns (ObservationLib.Observation memory) {
// console2.log("_calculateTemporaryObservation _time", _time);
// console2.log("_calculateTemporaryObservation _prevObservation.timestamp", _prevObservation.timestamp);
return
ObservationLib.Observation({
balance: _prevObservation.balance,
Expand Down Expand Up @@ -473,19 +521,11 @@ library TwabLib {
ObservationLib.Observation memory _observation,
uint32 _timestamp
) private view returns (uint128 cumulativeBalance) {
if (_timestamp < _observation.timestamp) {
// if before, then linearly extrapolate backwards
return
_observation.cumulativeBalance -
uint128(_observation.balance) *
(_observation.timestamp.checkedSub(_timestamp, uint32(block.timestamp)));
} else {
// if after, then linearly extrapolate forwards
return
_observation.cumulativeBalance +
uint128(_observation.balance) *
(_timestamp.checkedSub(_observation.timestamp, uint32(block.timestamp)));
}
// linearly extrapolate forwards
return
_observation.cumulativeBalance +
uint128(_observation.balance) *
(_timestamp.checkedSub(_observation.timestamp, uint32(block.timestamp)));
}

/**
Expand Down Expand Up @@ -584,14 +624,15 @@ library TwabLib {
* @param _observations The circular buffer of observations
* @param _accountDetails The account details to query with
* @param _targetTime The timestamp to look up
* @return prevOrAtIndex The index of the observation
* @return prevOrAtObservation The observation
*/
function getPreviousOrAtObservation(
uint32 PERIOD_OFFSET,
ObservationLib.Observation[MAX_CARDINALITY] storage _observations,
AccountDetails memory _accountDetails,
uint32 _targetTime
) internal view returns (ObservationLib.Observation memory prevOrAtObservation) {
) internal view returns (uint32 prevOrAtIndex, ObservationLib.Observation memory prevOrAtObservation) {
return _getPreviousOrAtObservation(PERIOD_OFFSET, _observations, _accountDetails, _targetTime);
}

Expand All @@ -601,51 +642,53 @@ library TwabLib {
* @param _observations The circular buffer of observations
* @param _accountDetails The account details to query with
* @param _targetTime The timestamp to look up
* @return prevOrAtIndex The index of the observation
* @return prevOrAtObservation The observation
*/
function _getPreviousOrAtObservation(
uint32 PERIOD_OFFSET,
ObservationLib.Observation[MAX_CARDINALITY] storage _observations,
AccountDetails memory _accountDetails,
uint32 _targetTime
) private view returns (ObservationLib.Observation memory prevOrAtObservation) {
) private view returns (uint32 prevOrAtIndex, ObservationLib.Observation memory prevOrAtObservation) {
uint32 currentTime = uint32(block.timestamp);

uint16 oldestTwabIndex;
uint16 newestTwabIndex;

// If there are no observations, return a zeroed observation
if (_accountDetails.cardinality == 0) {
return
ObservationLib.Observation({ cumulativeBalance: 0, balance: 0, timestamp: PERIOD_OFFSET });
return (0, ObservationLib.Observation({ cumulativeBalance: 0, balance: 0, timestamp: PERIOD_OFFSET }));
}

// Find the newest observation and check if the target time is AFTER it
(newestTwabIndex, prevOrAtObservation) = getNewestObservation(_observations, _accountDetails);
if (_targetTime >= prevOrAtObservation.timestamp) {
return prevOrAtObservation;
return (newestTwabIndex, prevOrAtObservation);
}

// If there is only 1 observation and it's after the target, then return zero
if (_accountDetails.cardinality == 1) {
return
ObservationLib.Observation({
(0, ObservationLib.Observation({
cumulativeBalance: 0,
balance: 0,
timestamp: PERIOD_OFFSET
});
}));
}

// Find the oldest Observation and check if the target time is BEFORE it
(oldestTwabIndex, prevOrAtObservation) = getOldestObservation(_observations, _accountDetails);
if (_targetTime < prevOrAtObservation.timestamp) {
return
ObservationLib.Observation({ cumulativeBalance: 0, balance: 0, timestamp: PERIOD_OFFSET });
(0, ObservationLib.Observation({ cumulativeBalance: 0, balance: 0, timestamp: PERIOD_OFFSET }));
}

ObservationLib.Observation memory afterOrAtObservation;
uint32 prevOrAtIndex;
uint32 afterOrAtIndex;
// Otherwise, we perform a binarySearch to find the observation before or at the timestamp
(prevOrAtObservation, afterOrAtObservation) = ObservationLib.binarySearch(
(prevOrAtIndex, prevOrAtObservation, afterOrAtIndex, afterOrAtObservation) = ObservationLib.binarySearch(
_observations,
newestTwabIndex,
oldestTwabIndex,
Expand All @@ -656,10 +699,10 @@ library TwabLib {

// If the afterOrAt is at, we can skip a temporary Observation computation by returning it here
if (afterOrAtObservation.timestamp == _targetTime) {
return afterOrAtObservation;
return (afterOrAtIndex, afterOrAtObservation);
}

return prevOrAtObservation;
return (prevOrAtIndex, prevOrAtObservation);
}

/**
Expand Down
36 changes: 28 additions & 8 deletions test/TwabController.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -277,14 +277,38 @@ contract TwabControllerTest is BaseTest {
vm.warp(PERIOD_OFFSET);
twabController.mint(alice, 1000e18);

vm.warp(PERIOD_OFFSET + PERIOD_LENGTH/2 + 10);
// added 1000 halfway through second period, so second period avg is 1500
vm.warp(PERIOD_OFFSET + PERIOD_LENGTH + PERIOD_LENGTH/2);
twabController.mint(bob, 1000e18);

vm.warp(PERIOD_OFFSET + PERIOD_LENGTH*2);
// warp far enough into future to finalize
vm.warp(PERIOD_OFFSET + PERIOD_LENGTH*5);

uint32 startTime = PERIOD_OFFSET;
uint32 endTime = startTime + PERIOD_LENGTH;

// assertEq(
// twabController.getTwabBetween(address(this), alice, startTime, endTime),
// twabController.getTotalSupplyTwabBetween(address(this), startTime, endTime),
// "alice time is correct"
// );

// assertEq(
// twabController.getTotalSupplyTwabBetween(address(this), startTime, startTime + PERIOD_LENGTH*2),
// (1000e18 + 1500e18) / 2,
// "avg over two periods"
// );

// // half of first period at 1000, half of second period at 1500
// assertApproxEqAbs(
// twabController.getTotalSupplyTwabBetween(address(this), startTime + PERIOD_LENGTH/2, startTime + PERIOD_LENGTH + PERIOD_LENGTH/2),
// 1250e18,
// 1e16
// );

assertEq(
twabController.getTwabBetween(address(this), alice, PERIOD_OFFSET, PERIOD_OFFSET + PERIOD_LENGTH/2),
twabController.getTotalSupplyTwabBetween(address(this), PERIOD_OFFSET, PERIOD_OFFSET + PERIOD_LENGTH/2)
twabController.getTwabBetween(address(this), bob, startTime + PERIOD_LENGTH/2, startTime + PERIOD_LENGTH + PERIOD_LENGTH/2),
250e18
);
}

Expand Down Expand Up @@ -1047,10 +1071,6 @@ contract TwabControllerTest is BaseTest {
assertEq(manipulatedDrawBalance, actualDrawBalance);
}

function testCardinalityWrap() public {

}

/* ============ getTimestampPeriod ============ */

function testGetTimestampPeriod() public {
Expand Down
2 changes: 2 additions & 0 deletions test/mocks/ObservationLibMock.sol
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ contract ObservationLibMock {
uint32 _time
) external view returns (ObservationLib.Observation memory, ObservationLib.Observation memory) {
(
,
ObservationLib.Observation memory beforeOrAt,
,
ObservationLib.Observation memory afterOrAt
) = ObservationLib.binarySearch(
observations,
Expand Down
2 changes: 1 addition & 1 deletion test/mocks/TwabLibMock.sol
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ contract TwabLibMock {
function getPreviousOrAtObservation(
uint32 _targetTime
) external view returns (ObservationLib.Observation memory) {
ObservationLib.Observation memory prevOrAtObservation = TwabLib.getPreviousOrAtObservation(
(,ObservationLib.Observation memory prevOrAtObservation) = TwabLib.getPreviousOrAtObservation(
PERIOD_OFFSET,
account.observations,
account.details,
Expand Down

0 comments on commit 8d97a9e

Please sign in to comment.