Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[multikey] Add multikey support #26

Merged
merged 4 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

All notable changes to the Aptos Python SDK will be captured in this file. This changelog is written by hand for now.

## Unreleased
- Add Multikey support for Python, with an example
- Deprecate and remove non-BCS transaction submission

## 0.8.6
- add client for graphql indexer service with light demo in coin transfer
- add mypy to ignore missing types for graphql and ecdsa
Expand Down
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ examples:
poetry run python -m examples.simulate_transfer_coin
poetry run python -m examples.transfer_coin
poetry run python -m examples.transfer_two_by_two
poetry run python -m examples.multikey

examples_cli:
poetry run python -m examples.hello_blockchain
Expand Down
2 changes: 2 additions & 0 deletions aptos_sdk/account_address.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,8 @@ def from_key(key: asymmetric_crypto.PublicKey) -> AccountAddress:
hasher.update(AuthKeyScheme.MultiEd25519)
elif isinstance(key, asymmetric_crypto_wrapper.PublicKey):
hasher.update(AuthKeyScheme.SingleKey)
elif isinstance(key, asymmetric_crypto_wrapper.MultiPublicKey):
hasher.update(AuthKeyScheme.MultiKey)
else:
raise Exception("Unsupported asymmetric_crypto.PublicKey key type.")

Expand Down
137 changes: 136 additions & 1 deletion aptos_sdk/asymmetric_crypto_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from __future__ import annotations

from typing import List, Tuple, cast

from . import asymmetric_crypto, ed25519, secp256k1_ecdsa
from .bcs import Deserializer, Serializer

Expand All @@ -29,7 +31,10 @@ def to_crypto_bytes(self) -> bytes:
return ser.output()

def verify(self, data: bytes, signature: asymmetric_crypto.Signature) -> bool:
return self.public_key.verify(data, signature)
# Convert signature to the original signature
sig = cast(Signature, signature)

return self.public_key.verify(data, sig.signature)

@staticmethod
def deserialize(deserializer: Deserializer) -> PublicKey:
Expand Down Expand Up @@ -85,3 +90,133 @@ def deserialize(deserializer: Deserializer) -> Signature:
def serialize(self, serializer: Serializer):
serializer.uleb128(self.variant)
serializer.struct(self.signature)


class MultiPublicKey(asymmetric_crypto.PublicKey):
gregnazario marked this conversation as resolved.
Show resolved Hide resolved
keys: List[PublicKey]
threshold: int

MIN_KEYS = 2
MAX_KEYS = 32
MIN_THRESHOLD = 1

def __init__(self, keys: List[asymmetric_crypto.PublicKey], threshold: int):
assert (
self.MIN_KEYS <= len(keys) <= self.MAX_KEYS
), f"Must have between {self.MIN_KEYS} and {self.MAX_KEYS} keys."
assert (
self.MIN_THRESHOLD <= threshold <= len(keys)
), f"Threshold must be between {self.MIN_THRESHOLD} and {len(keys)}."

# Ensure keys are wrapped
self.keys = []
for key in keys:
if isinstance(key, PublicKey):
self.keys.append(key)
else:
self.keys.append(PublicKey(key))

self.threshold = threshold

def __str__(self) -> str:
return f"{self.threshold}-of-{len(self.keys)} Multi key"

def verify(self, data: bytes, signature: asymmetric_crypto.Signature) -> bool:
try:
total_sig = cast(MultiSignature, signature)
assert self.threshold <= len(
total_sig.signatures
), f"Insufficient signatures, {self.threshold} > {len(total_sig.signatures)}"

for idx, signature in total_sig.signatures:
assert (
len(self.keys) > idx
), f"Signature index exceeds available keys {len(self.keys)} < {idx}"
assert self.keys[idx].verify(
data, signature
), "Unable to verify signature"

except Exception:
Copy link
Contributor

@fishronsage fishronsage Aug 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we shouldn't catch the exception here, otherwise the assert above will have no effect (don't know what happened)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I copied this from MultiEd25519, it's definitely not a great pattern to use. Will follow up with a PR afterwards to move this and MultiEd25519 into a two step process:

  1. One function to verify the signature and throw exceptions
  2. One function that wraps it and returns true or false

return False
return True

@staticmethod
def from_crypto_bytes(indata: bytes) -> MultiPublicKey:
deserializer = Deserializer(indata)
return deserializer.struct(MultiPublicKey)

def to_crypto_bytes(self) -> bytes:
serializer = Serializer()
serializer.struct(self)
return serializer.output()

@staticmethod
def deserialize(deserializer: Deserializer) -> MultiPublicKey:
keys = deserializer.sequence(PublicKey.deserialize)
threshold = deserializer.u8()
return MultiPublicKey(keys, threshold)

def serialize(self, serializer: Serializer):
serializer.sequence(self.keys, Serializer.struct)
serializer.u8(self.threshold)


class MultiSignature(asymmetric_crypto.Signature):
signatures: List[Tuple[int, Signature]]
BITMAP_NUM_OF_BYTES: int = 4

def __init__(self, signatures: List[Tuple[int, asymmetric_crypto.Signature]]):
# Sort first to ensure no issues in order
# signatures.sort(key=lambda x: x[0])
self.signatures = []
for index, signature in signatures:
assert (
index < self.BITMAP_NUM_OF_BYTES * 8
), "bitmap value exceeds maximum value"
if isinstance(signature, Signature):
self.signatures.append((index, signature))
else:
self.signatures.append((index, Signature(signature)))

def __eq__(self, other: object):
if not isinstance(other, MultiSignature):
return NotImplemented
return self.signatures == other.signatures

def __str__(self) -> str:
return f"{self.signatures}"

@staticmethod
def deserialize(deserializer: Deserializer) -> MultiSignature:
signatures = deserializer.sequence(Signature.deserialize)
deserializer.uleb128()
bitmap = deserializer.u32()
num_bits = MultiSignature.BITMAP_NUM_OF_BYTES * 8
sig_index = 0
indexed_signatures = []

for i in range(0, num_bits):
has_signature = (bitmap & index_to_bitmap_value(i)) != 0
if has_signature:
indexed_signatures.append((i, signatures[sig_index]))
sig_index += 1

return MultiSignature(signatures)

def serialize(self, serializer: Serializer):
actual_sigs = []
bitmap = 0

for i, signature in self.signatures:
bitmap |= index_to_bitmap_value(i)
actual_sigs.append(signature)

serializer.sequence(actual_sigs, Serializer.struct)
serializer.uleb128(self.BITMAP_NUM_OF_BYTES)
serializer.u32(bitmap)


def index_to_bitmap_value(i: int) -> int:
bit = i % 8
byte = i // 8
return (128 >> bit) << (byte * 8)
71 changes: 8 additions & 63 deletions aptos_sdk/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,49 +518,6 @@ async def submit_and_wait_for_bcs_transaction(
await self.wait_for_transaction(txn_hash)
return await self.transaction_by_hash(txn_hash)

async def submit_transaction(self, sender: Account, payload: Dict[str, Any]) -> str:
"""
1) Generates a transaction request
2) submits that to produce a raw transaction
3) signs the raw transaction
4) submits the signed transaction
"""

txn_request = {
"sender": f"{sender.address()}",
"sequence_number": str(
await self.account_sequence_number(sender.address())
),
"max_gas_amount": str(self.client_config.max_gas_amount),
"gas_unit_price": str(self.client_config.gas_unit_price),
"expiration_timestamp_secs": str(
int(time.time()) + self.client_config.expiration_ttl
),
"payload": payload,
}

response = await self.client.post(
f"{self.base_url}/transactions/encode_submission", json=txn_request
)
if response.status_code >= 400:
raise ApiError(response.text, response.status_code)

to_sign = bytes.fromhex(response.json()[2:])
signature = sender.sign(to_sign)
txn_request["signature"] = {
"type": "ed25519_signature",
"public_key": f"{sender.public_key()}",
"signature": f"{signature}",
}

headers = {"Content-Type": "application/json"}
response = await self.client.post(
f"{self.base_url}/transactions", headers=headers, json=txn_request
)
if response.status_code >= 400:
raise ApiError(response.text, response.status_code)
return response.json()["hash"]

async def transaction_pending(self, txn_hash: str) -> bool:
response = await self._get(endpoint=f"transactions/by_hash/{txn_hash}")
# TODO(@davidiw): consider raising a different error here, since this is an ambiguous state
Expand Down Expand Up @@ -716,17 +673,22 @@ async def create_multi_agent_bcs_transaction(

async def create_bcs_transaction(
self,
sender: Account,
sender: Account | AccountAddress,
payload: TransactionPayload,
sequence_number: Optional[int] = None,
) -> RawTransaction:
if isinstance(sender, Account):
sender_address = sender.address()
else:
sender_address = sender

sequence_number = (
sequence_number
if sequence_number is not None
else await self.account_sequence_number(sender.address())
else await self.account_sequence_number(sender_address)
)
return RawTransaction(
sender.address(),
sender_address,
sequence_number,
payload,
self.client_config.max_gas_amount,
Expand All @@ -751,23 +713,6 @@ async def create_bcs_signed_transaction(
# Transaction wrappers
#

async def transfer(
self, sender: Account, recipient: AccountAddress, amount: int
) -> str:
"""Transfer a given coin amount from a given Account to the recipient's account address.
Returns the sequence number of the transaction used to transfer."""

payload = {
"type": "entry_function_payload",
"function": "0x1::aptos_account::transfer",
"type_arguments": [],
"arguments": [
f"{recipient}",
str(amount),
],
}
return await self.submit_transaction(sender, payload)

# :!:>bcs_transfer
async def bcs_transfer(
self,
Expand Down
28 changes: 28 additions & 0 deletions aptos_sdk/authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ def __init__(self, authenticator: typing.Any):
self.variant = AccountAuthenticator.MULTI_ED25519
elif isinstance(authenticator, SingleKeyAuthenticator):
self.variant = AccountAuthenticator.SINGLE_KEY
elif isinstance(authenticator, MultiKeyAuthenticator):
self.variant = AccountAuthenticator.MULTI_KEY
else:
raise Exception("Invalid type")
self.authenticator = authenticator
Expand Down Expand Up @@ -360,3 +362,29 @@ def deserialize(deserializer: Deserializer) -> SingleKeyAuthenticator:
def serialize(self, serializer: Serializer):
serializer.struct(self.public_key)
serializer.struct(self.signature)


class MultiKeyAuthenticator:
public_key: asymmetric_crypto_wrapper.MultiPublicKey
signature: asymmetric_crypto_wrapper.MultiSignature

def __init__(
self,
public_key: asymmetric_crypto_wrapper.MultiPublicKey,
signature: asymmetric_crypto_wrapper.MultiSignature,
):
self.public_key = public_key
self.signature = signature

def verify(self, data: bytes) -> bool:
return self.public_key.verify(data, self.signature)

@staticmethod
def deserialize(deserializer: Deserializer) -> MultiKeyAuthenticator:
public_key = deserializer.struct(asymmetric_crypto_wrapper.MultiPublicKey)
signature = deserializer.struct(asymmetric_crypto_wrapper.MultiSignature)
return MultiKeyAuthenticator(public_key, signature)

def serialize(self, serializer: Serializer):
serializer.struct(self.public_key)
serializer.struct(self.signature)
4 changes: 2 additions & 2 deletions examples/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
# :!:>section_1
FAUCET_URL = os.getenv(
"APTOS_FAUCET_URL",
"https://faucet.devnet.aptoslabs.com",
"http://localhost:8081",
)
INDEXER_URL = os.getenv(
"APTOS_INDEXER_URL",
"https://api.devnet.aptoslabs.com/v1/graphql",
)
NODE_URL = os.getenv("APTOS_NODE_URL", "https://api.devnet.aptoslabs.com/v1")
NODE_URL = os.getenv("APTOS_NODE_URL", "http://localhost:8080/v1")
# <:!:section_1
Loading
Loading