diff --git a/contracts/terminus/ERC1155WithTerminusStorage.sol b/contracts/terminus/ERC1155WithTerminusStorage.sol index 7728299..151722e 100644 --- a/contracts/terminus/ERC1155WithTerminusStorage.sol +++ b/contracts/terminus/ERC1155WithTerminusStorage.sol @@ -140,11 +140,12 @@ contract ERC1155WithTerminusStorage is view returns (bool) { - LibTerminus.TerminusStorage storage ts = LibTerminus.terminusStorage(); - if (operator == ts.poolController[poolID]) { - return true; - } - return false; + return LibTerminus._isApprovedForPool(poolID, operator); + } + + function approveForPool(uint256 poolID, address operator) external { + LibTerminus.enforcePoolIsController(poolID, _msgSender()); + LibTerminus._approveForPool(poolID, operator); } /** @@ -428,11 +429,6 @@ contract ERC1155WithTerminusStorage is address operator = _msgSender(); - require( - operator == from || isApprovedForPool(id, operator), - "ERC1155WithTerminusStorage: _burn -- caller is neither owner nor approved" - ); - _beforeTokenTransfer( operator, from, @@ -478,22 +474,13 @@ contract ERC1155WithTerminusStorage is address operator = _msgSender(); - bool approvedForPools = true; - LibTerminus.TerminusStorage storage ts = LibTerminus.terminusStorage(); for (uint256 i = 0; i < ids.length; i++) { require( ts.poolBurnable[ids[i]], "ERC1155WithTerminusStorage: _burnBatch -- pool is not burnable" ); - if (!isApprovedForPool(ids[i], operator)) { - approvedForPools = false; - } } - require( - from == _msgSender() || approvedForPools, - "ERC1155WithTerminusStorage: _burnBatch -- caller is neither owner nor approved" - ); _beforeTokenTransfer(operator, from, address(0), ids, amounts, ""); diff --git a/contracts/terminus/LibTerminus.sol b/contracts/terminus/LibTerminus.sol index 1b611ac..e6ff041 100644 --- a/contracts/terminus/LibTerminus.sol +++ b/contracts/terminus/LibTerminus.sol @@ -34,6 +34,7 @@ library LibTerminus { mapping(uint256 => bool) poolNotTransferable; mapping(uint256 => bool) poolBurnable; mapping(address => mapping(address => bool)) globalOperatorApprovals; + mapping(uint256 => mapping(address => bool)) globalPoolOperatorApprovals; } function terminusStorage() @@ -101,4 +102,23 @@ library LibTerminus { "LibTerminus: Must be pool controller" ); } + + function _isApprovedForPool(uint256 poolID, address operator) + internal + view + returns (bool) + { + LibTerminus.TerminusStorage storage ts = LibTerminus.terminusStorage(); + if (operator == ts.poolController[poolID]) { + return true; + } else if (ts.globalPoolOperatorApprovals[poolID][operator]) { + return true; + } + return false; + } + + function _approveForPool(uint256 poolID, address operator) internal { + LibTerminus.TerminusStorage storage ts = LibTerminus.terminusStorage(); + ts.globalPoolOperatorApprovals[poolID][operator] = true; + } } diff --git a/contracts/terminus/TerminusFacet.sol b/contracts/terminus/TerminusFacet.sol index 4bdb417..f6087a2 100644 --- a/contracts/terminus/TerminusFacet.sol +++ b/contracts/terminus/TerminusFacet.sol @@ -28,6 +28,52 @@ import "../diamond/libraries/LibDiamond.sol"; contract TerminusFacet is ERC1155WithTerminusStorage { constructor() {} + event PoolMintBatch( + uint256 indexed id, + address indexed operator, + address from, + address[] toAddresses, + uint256[] amounts + ); + + function poolMintBatch( + uint256 id, + address[] memory toAddresses, + uint256[] memory amounts + ) public { + address operator = _msgSender(); + LibTerminus.enforcePoolIsController(id, operator); + require( + toAddresses.length == amounts.length, + "TerminusFacet: _poolMintBatch -- toAddresses and amounts length mismatch" + ); + + LibTerminus.TerminusStorage storage ts = LibTerminus.terminusStorage(); + + uint256 i = 0; + uint256 totalAmount = 0; + + for (i = 0; i < toAddresses.length; i++) { + address to = toAddresses[i]; + uint256 amount = amounts[i]; + require( + to != address(0), + "TerminusFacet: _poolMintBatch -- cannot mint to zero address" + ); + totalAmount += amount; + ts.poolBalances[id][to] += amount; + emit TransferSingle(operator, address(0), to, id, amount); + } + + require( + ts.poolSupply[id] + totalAmount <= ts.poolCapacity[id], + "TerminusFacet: _poolMintBatch -- Minted tokens would exceed pool capacity" + ); + ts.poolSupply[id] += totalAmount; + + emit PoolMintBatch(id, operator, address(0), toAddresses, amounts); + } + function terminusController() external view returns (address) { return LibTerminus.terminusStorage().controller; } @@ -180,6 +226,11 @@ contract TerminusFacet is ERC1155WithTerminusStorage { uint256 poolID, uint256 amount ) external { + address operator = _msgSender(); + require( + operator == from || isApprovedForPool(poolID, operator), + "TerminusFacet: burn -- caller is neither owner nor approved" + ); _burn(from, poolID, amount); } } diff --git a/dao/TerminusFacet.py b/dao/TerminusFacet.py index 7cedded..f02707a 100644 --- a/dao/TerminusFacet.py +++ b/dao/TerminusFacet.py @@ -90,6 +90,12 @@ class TerminusFacet: if self.contract is None: raise Exception("contract has not been instantiated") + def approve_for_pool( + self, pool_id: int, operator: ChecksumAddress, transaction_config + ) -> Any: + self.assert_contract_is_instantiated() + return self.contract.approveForPool(pool_id, operator, transaction_config) + def balance_of(self, account: ChecksumAddress, id: int) -> Any: self.assert_contract_is_instantiated() return self.contract.balanceOf.call(account, id) @@ -156,6 +162,14 @@ class TerminusFacet: self.assert_contract_is_instantiated() return self.contract.poolBasePrice.call() + def pool_mint_batch( + self, id: int, to_addresses: List, amounts: List, transaction_config + ) -> Any: + self.assert_contract_is_instantiated() + return self.contract.poolMintBatch( + id, to_addresses, amounts, transaction_config + ) + def safe_batch_transfer_from( self, from_: ChecksumAddress, @@ -290,6 +304,18 @@ def handle_deploy(args: argparse.Namespace) -> None: print(result) +def handle_approve_for_pool(args: argparse.Namespace) -> None: + network.connect(args.network) + contract = TerminusFacet(args.address) + transaction_config = get_transaction_config(args) + result = contract.approve_for_pool( + pool_id=args.pool_id, + operator=args.operator, + transaction_config=transaction_config, + ) + print(result) + + def handle_balance_of(args: argparse.Namespace) -> None: network.connect(args.network) contract = TerminusFacet(args.address) @@ -396,6 +422,19 @@ def handle_pool_base_price(args: argparse.Namespace) -> None: print(result) +def handle_pool_mint_batch(args: argparse.Namespace) -> None: + network.connect(args.network) + contract = TerminusFacet(args.address) + transaction_config = get_transaction_config(args) + result = contract.pool_mint_batch( + id=args.id, + to_addresses=args.to_addresses, + amounts=args.amounts, + transaction_config=transaction_config, + ) + print(result) + + def handle_safe_batch_transfer_from(args: argparse.Namespace) -> None: network.connect(args.network) contract = TerminusFacet(args.address) @@ -540,6 +579,16 @@ def generate_cli() -> argparse.ArgumentParser: add_default_arguments(deploy_parser, True) deploy_parser.set_defaults(func=handle_deploy) + approve_for_pool_parser = subcommands.add_parser("approve-for-pool") + add_default_arguments(approve_for_pool_parser, True) + approve_for_pool_parser.add_argument( + "--pool-id", required=True, help="Type: uint256", type=int + ) + approve_for_pool_parser.add_argument( + "--operator", required=True, help="Type: address" + ) + approve_for_pool_parser.set_defaults(func=handle_approve_for_pool) + balance_of_parser = subcommands.add_parser("balance-of") add_default_arguments(balance_of_parser, False) balance_of_parser.add_argument("--account", required=True, help="Type: address") @@ -640,6 +689,19 @@ def generate_cli() -> argparse.ArgumentParser: add_default_arguments(pool_base_price_parser, False) pool_base_price_parser.set_defaults(func=handle_pool_base_price) + pool_mint_batch_parser = subcommands.add_parser("pool-mint-batch") + add_default_arguments(pool_mint_batch_parser, True) + pool_mint_batch_parser.add_argument( + "--id", required=True, help="Type: uint256", type=int + ) + pool_mint_batch_parser.add_argument( + "--to-addresses", required=True, help="Type: address[]", nargs="+" + ) + pool_mint_batch_parser.add_argument( + "--amounts", required=True, help="Type: uint256[]", nargs="+" + ) + pool_mint_batch_parser.set_defaults(func=handle_pool_mint_batch) + safe_batch_transfer_from_parser = subcommands.add_parser("safe-batch-transfer-from") add_default_arguments(safe_batch_transfer_from_parser, True) safe_batch_transfer_from_parser.add_argument( diff --git a/dao/test_terminus.py b/dao/test_terminus.py index cfa66bb..f49df06 100644 --- a/dao/test_terminus.py +++ b/dao/test_terminus.py @@ -1,3 +1,4 @@ +from typing import List import unittest from brownie import accounts @@ -196,6 +197,66 @@ class TestPoolOperations(TerminusTestCase): supply = self.diamond_terminus.terminus_pool_supply(pool_id) self.assertEqual(supply, 0) + def test_pool_mint_batch(self): + pool_id = self.diamond_terminus.total_pools() + target_accounts = [account.address for account in accounts[:5]] + target_amounts = [1 for _ in accounts[:5]] + num_accounts = len(accounts[:5]) + initial_pool_supply = self.diamond_terminus.terminus_pool_supply(pool_id) + initial_balances: List[int] = [] + for account in accounts[:5]: + initial_balances.append( + self.diamond_terminus.balance_of(account.address, pool_id) + ) + self.diamond_terminus.pool_mint_batch( + pool_id, target_accounts, target_amounts, {"from": accounts[1]} + ) + final_pool_supply = self.diamond_terminus.terminus_pool_supply(pool_id) + self.assertEqual(final_pool_supply, initial_pool_supply + num_accounts) + for i, account in enumerate(accounts[:5]): + final_balance = self.diamond_terminus.balance_of(account.address, pool_id) + self.assertEqual(final_balance, initial_balances[i] + 1) + + def test_pool_mint_batch_as_contract_controller_not_pool_controller(self): + pool_id = self.diamond_terminus.total_pools() + target_accounts = [account.address for account in accounts[:5]] + target_amounts = [1 for _ in accounts[:5]] + initial_pool_supply = self.diamond_terminus.terminus_pool_supply(pool_id) + initial_balances: List[int] = [] + for account in accounts[:5]: + initial_balances.append( + self.diamond_terminus.balance_of(account.address, pool_id) + ) + with self.assertRaises(Exception): + self.diamond_terminus.pool_mint_batch( + pool_id, target_accounts, target_amounts, {"from": accounts[0]} + ) + final_pool_supply = self.diamond_terminus.terminus_pool_supply(pool_id) + self.assertEqual(final_pool_supply, initial_pool_supply) + for i, account in enumerate(accounts[:5]): + final_balance = self.diamond_terminus.balance_of(account.address, pool_id) + self.assertEqual(final_balance, initial_balances[i]) + + def test_pool_mint_batch_as_unauthorized_third_party(self): + pool_id = self.diamond_terminus.total_pools() + target_accounts = [account.address for account in accounts[:5]] + target_amounts = [1 for _ in accounts[:5]] + initial_pool_supply = self.diamond_terminus.terminus_pool_supply(pool_id) + initial_balances: List[int] = [] + for account in accounts[:5]: + initial_balances.append( + self.diamond_terminus.balance_of(account.address, pool_id) + ) + with self.assertRaises(Exception): + self.diamond_terminus.pool_mint_batch( + pool_id, target_accounts, target_amounts, {"from": accounts[2]} + ) + final_pool_supply = self.diamond_terminus.terminus_pool_supply(pool_id) + self.assertEqual(final_pool_supply, initial_pool_supply) + for i, account in enumerate(accounts[:5]): + final_balance = self.diamond_terminus.balance_of(account.address, pool_id) + self.assertEqual(final_balance, initial_balances[i]) + def test_transfer(self): pool_id = self.diamond_terminus.total_pools() self.diamond_terminus.mint(accounts[2], pool_id, 1, b"", {"from": accounts[1]}) @@ -287,6 +348,39 @@ class TestPoolOperations(TerminusTestCase): self.assertEqual(final_sender_balance, initial_sender_balance) self.assertEqual(final_receiver_balance, initial_receiver_balance) + def test_transfer_as_authorized_recipient(self): + pool_id = self.diamond_terminus.total_pools() + self.diamond_terminus.mint(accounts[2], pool_id, 1, b"", {"from": accounts[1]}) + + initial_sender_balance = self.diamond_terminus.balance_of( + accounts[2].address, pool_id + ) + initial_receiver_balance = self.diamond_terminus.balance_of( + accounts[3].address, pool_id + ) + + self.diamond_terminus.approve_for_pool( + pool_id, accounts[3].address, {"from": accounts[1]} + ) + self.diamond_terminus.safe_transfer_from( + accounts[2].address, + accounts[3].address, + pool_id, + 1, + b"", + {"from": accounts[3]}, + ) + + final_sender_balance = self.diamond_terminus.balance_of( + accounts[2].address, pool_id + ) + final_receiver_balance = self.diamond_terminus.balance_of( + accounts[3].address, pool_id + ) + + self.assertEqual(final_sender_balance, initial_sender_balance - 1) + self.assertEqual(final_receiver_balance, initial_receiver_balance + 1) + def test_transfer_as_unauthorized_unrelated_party(self): pool_id = self.diamond_terminus.total_pools() self.diamond_terminus.mint(accounts[2], pool_id, 1, b"", {"from": accounts[1]}) @@ -318,6 +412,39 @@ class TestPoolOperations(TerminusTestCase): self.assertEqual(final_sender_balance, initial_sender_balance) self.assertEqual(final_receiver_balance, initial_receiver_balance) + def test_transfer_as_authorized_unrelated_party(self): + pool_id = self.diamond_terminus.total_pools() + self.diamond_terminus.mint(accounts[2], pool_id, 1, b"", {"from": accounts[1]}) + + initial_sender_balance = self.diamond_terminus.balance_of( + accounts[2].address, pool_id + ) + initial_receiver_balance = self.diamond_terminus.balance_of( + accounts[3].address, pool_id + ) + + self.diamond_terminus.approve_for_pool( + pool_id, accounts[4].address, {"from": accounts[1]} + ) + self.diamond_terminus.safe_transfer_from( + accounts[2].address, + accounts[3].address, + pool_id, + 1, + b"", + {"from": accounts[4]}, + ) + + final_sender_balance = self.diamond_terminus.balance_of( + accounts[2].address, pool_id + ) + final_receiver_balance = self.diamond_terminus.balance_of( + accounts[3].address, pool_id + ) + + self.assertEqual(final_sender_balance, initial_sender_balance - 1) + self.assertEqual(final_receiver_balance, initial_receiver_balance + 1) + def test_burn_fails_as_token_owner(self): pool_id = self.diamond_terminus.total_pools() self.diamond_terminus.mint(accounts[2], pool_id, 1, b"", {"from": accounts[1]}) @@ -378,6 +505,29 @@ class TestPoolOperations(TerminusTestCase): self.assertEqual(final_pool_supply, initial_pool_supply) self.assertEqual(final_owner_balance, initial_owner_balance) + def test_burn_fails_as_authorized_third_party(self): + pool_id = self.diamond_terminus.total_pools() + self.diamond_terminus.mint(accounts[2], pool_id, 1, b"", {"from": accounts[1]}) + + initial_pool_supply = self.diamond_terminus.terminus_pool_supply(pool_id) + initial_owner_balance = self.diamond_terminus.balance_of( + accounts[2].address, pool_id + ) + self.diamond_terminus.approve_for_pool( + pool_id, accounts[3].address, {"from": accounts[1]} + ) + with self.assertRaises(Exception): + self.diamond_terminus.burn( + accounts[2].address, pool_id, 1, {"from": accounts[3]} + ) + + final_pool_supply = self.diamond_terminus.terminus_pool_supply(pool_id) + final_owner_balance = self.diamond_terminus.balance_of( + accounts[2].address, pool_id + ) + self.assertEqual(final_pool_supply, initial_pool_supply) + self.assertEqual(final_owner_balance, initial_owner_balance) + class TestCreatePoolV1(TestPoolOperations): def setUp(self): @@ -455,6 +605,29 @@ class TestCreatePoolV1(TestPoolOperations): self.assertEqual(final_pool_supply, initial_pool_supply - 1) self.assertEqual(final_owner_balance, initial_owner_balance - 1) + def test_burnable_pool_burn_as_authorized_third_party(self): + self.diamond_terminus.create_pool_v1(10, True, True, {"from": accounts[1]}) + pool_id = self.diamond_terminus.total_pools() + self.diamond_terminus.mint(accounts[2], pool_id, 1, b"", {"from": accounts[1]}) + + initial_pool_supply = self.diamond_terminus.terminus_pool_supply(pool_id) + initial_owner_balance = self.diamond_terminus.balance_of( + accounts[2].address, pool_id + ) + self.diamond_terminus.approve_for_pool( + pool_id, accounts[3].address, {"from": accounts[1]} + ) + self.diamond_terminus.burn( + accounts[2].address, pool_id, 1, {"from": accounts[3]} + ) + + final_pool_supply = self.diamond_terminus.terminus_pool_supply(pool_id) + final_owner_balance = self.diamond_terminus.balance_of( + accounts[2].address, pool_id + ) + self.assertEqual(final_pool_supply, initial_pool_supply - 1) + self.assertEqual(final_owner_balance, initial_owner_balance - 1) + def test_burnable_pool_burn_as_unauthorized_third_party(self): self.diamond_terminus.create_pool_v1(10, True, True, {"from": accounts[1]}) pool_id = self.diamond_terminus.total_pools()