diff --git a/.circleci/config.yml b/.circleci/config.yml index 6de0ba2..8f2cca3 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -63,6 +63,14 @@ jobs: path: ~/verified-vyper-contracts - run: ./test.sh ecdsa + biginttest: + working_directory: ~/verified-vyper-contracts/tests + docker: + - image: python:3.6.7 + steps: + - checkout: + path: ~/verified-vyper-contracts + - run: ./test.sh bigint workflows: version: 2 @@ -117,3 +125,10 @@ workflows: - /ecdsa\/.*/ - /all\/.*/ - master + - biginttest: + filters: + branches: + only: + - /bigint\/.*/ + - /all\/.*/ + - master diff --git a/contracts/bigint/BigInt.vy b/contracts/bigint/BigInt.vy new file mode 100644 index 0000000..7960b7e --- /dev/null +++ b/contracts/bigint/BigInt.vy @@ -0,0 +1,211 @@ +# @dev RSA Accumulator +# @author Ryuya Nakamura (@nrryuya) +# Based on The Matter team's work: +# https://github.com/matterinc/RSAAccumulator/blob/master/contracts/RSAAccumulator.sol + +### CONSTANTS ### +M_LIST_LENGTH: constant(int128) = 8 +M_BYTE_COUNT: constant(int128) = 32 * M_LIST_LENGTH +# For now, the same lengths are used for the simplicity of impelementation. +BASE_BYTE_COUNT: constant(int128) = M_BYTE_COUNT +E_BYTE_COUNT: constant(int128) = M_BYTE_COUNT +# Lenth in bytes32 representation +M_BYTE_COUNT_BYTES32: constant(bytes32) = convert(M_BYTE_COUNT, bytes32) +BASE_BYTE_COUNT_BYTES32: constant(bytes32) = convert(BASE_BYTE_COUNT, bytes32) +E_BYTE_COUNT_BYTES32: constant(bytes32) = convert(BASE_BYTE_COUNT, bytes32) + +PRECOMPILED_BIGMODEXP: constant(address) = 0x0000000000000000000000000000000000000005 +BIGMODEXP_RES_SIZE: constant(int128) = 32 * 3 + M_BYTE_COUNT + BASE_BYTE_COUNT + E_BYTE_COUNT + +### BIG INTEGER ARITHMETIC FUNCTIONS ### + +@private +def _bigModExp(_base: uint256[M_LIST_LENGTH], _e: uint256[M_LIST_LENGTH], _m: uint256[M_LIST_LENGTH]) -> uint256[M_LIST_LENGTH]: + # convert UInt256 list to bytes (inlined for code size reduction) + tmp: bytes32[M_LIST_LENGTH] + for i in range(M_LIST_LENGTH): + tmp[i] = convert(_base[i], bytes32) + base: bytes[M_BYTE_COUNT] = concat(tmp[0], tmp[1], tmp[2], tmp[3], tmp[4], tmp[5], tmp[6], tmp[7]) + + # convert UInt256 list to bytes (inlined for code size reduction) + for i in range(M_LIST_LENGTH): + tmp[i] = convert(_e[i], bytes32) + exponent: bytes[M_BYTE_COUNT] = concat(tmp[0], tmp[1], tmp[2], tmp[3], tmp[4], tmp[5], tmp[6], tmp[7]) + + # convert UInt256 list to bytes (inlined for code size reduction) + for i in range(M_LIST_LENGTH): + tmp[i] = convert(_m[i], bytes32) + modulus: bytes[M_BYTE_COUNT] = concat(tmp[0], tmp[1], tmp[2], tmp[3], tmp[4], tmp[5], tmp[6], tmp[7]) + + # ref. https://eips.ethereum.org/EIPS/eip-198 + data: bytes[BIGMODEXP_RES_SIZE] = concat( + BASE_BYTE_COUNT_BYTES32, E_BYTE_COUNT_BYTES32, M_BYTE_COUNT_BYTES32, base, exponent, modulus) + # NOTE: raw_call doesn't support static call for now. + res: bytes[M_BYTE_COUNT] = raw_call(PRECOMPILED_BIGMODEXP, data, outsize=256, gas=2000) + + # convert bytes array to UInt256 list (inlined for code size reduction) + out: uint256[M_LIST_LENGTH] + for i in range(M_LIST_LENGTH): + out[i] = convert(extract32(res, i * 32, type=bytes32), uint256) + return out + + +@private +@constant +def _wrappingSub(_a: uint256[M_LIST_LENGTH], _b: uint256[M_LIST_LENGTH]) -> uint256[M_LIST_LENGTH]: + """ + Assumes _a > _b, otherwise returns _a - _b + 2 ** (256 * M_LIST_LENGTH)(finishes with borrow = True) + Assumes _a - _b < _m, otherwise the output is larger or equal to _m + """ + borrow: bool = False + limb: uint256 = 0 + o: uint256[M_LIST_LENGTH] + for i in range(M_LIST_LENGTH): + j: int128 = M_LIST_LENGTH - 1 - i + limb = _a[j] + if borrow: + if limb == 0: + borrow = True + o[j] = MAX_UINT256 - _b[j] + continue + else: + limb -= 1 + if limb < _b[j]: + borrow = True + # 2 ** 256 - diff + o[j] = MAX_UINT256 - (_b[j] - limb) + 1 + else: + borrow = False + o[j] = limb - _b[j] + return o + + +@private +@constant +def _wrappingAdd(_a: uint256[M_LIST_LENGTH], _b: uint256[M_LIST_LENGTH]) -> uint256[M_LIST_LENGTH]: + """ + Assumes _a + _b < _m, otherwise the output is larger or equal to _m + """ + carry: bool = False + limb: uint256 = 0 + subaddition: uint256 = 0 + o: uint256[M_LIST_LENGTH] + for i in range(M_LIST_LENGTH): + j: int128 = M_LIST_LENGTH - 1 - i + limb = _a[j] + if carry: + if limb == MAX_UINT256: # NOTE: The original seems wrong here. + carry = True + o[j] = _b[j] + continue + else: + limb += 1 + if limb > MAX_UINT256 - _b[j]: + carry = True + o[j] = limb - (MAX_UINT256 - _b[j] + 1) + else: + carry = False + o[j] = limb + _b[j] + return o + + +@private +@constant +def _modularSub(_a: uint256[M_LIST_LENGTH], _b: uint256[M_LIST_LENGTH], _m: uint256[M_LIST_LENGTH]) -> uint256[M_LIST_LENGTH]: + """ + Assumes _a - _b < _m, otherwise the output is larger or equal to _m + Assumes _b - _a < _m, otherwise returns _m - (_b - _a) + 2 ** (256 * M_LIST_LENGTH) when _a < _b + """ + # Comparison (inlined for code size reduction) + comparison: int128 + for i in range(M_LIST_LENGTH): + if _a[i] > _b[i]: + comparison = 1 + elif _a[i] < _b[i]: + comparison = -1 + else: + comparison = 0 + + if comparison == 0: # _a = _b + o: uint256[M_LIST_LENGTH] + return o + elif comparison == 1: # _a > _b + return self._wrappingSub(_a, _b) + else: # _a < _b + tmp: uint256[M_LIST_LENGTH] = self._wrappingSub(_b, _a) + return self._wrappingSub(_m, tmp) + + +@private +@constant +def _modularAdd(_a: uint256[M_LIST_LENGTH], _b: uint256[M_LIST_LENGTH], _m: uint256[M_LIST_LENGTH]) -> uint256[M_LIST_LENGTH]: + """ + Assumes _a <= _m, otherwise space = _m - _a + 2 ** (256 * M_LIST_LENGTH) + Assumes _a + _b <= 2 * _m? (otherwise _b - space >= _m) when space < _b + """ + # See how much "space" has left before an overflow + space: uint256[M_LIST_LENGTH] = self._wrappingSub(_m, _a) + + # Comparison (inlined for code size reduction) + comparison: int128 + for i in range(M_LIST_LENGTH): + if space[i] > _b[i]: + comparison = 1 + elif space[i] < _b[i]: + comparison = -1 + else: + comparison = 0 + + if comparison == 0: # space = _b + o: uint256[M_LIST_LENGTH] + return o + elif comparison == 1: # space > _b + return self._wrappingAdd(_a, _b) + else: # space < _b + return self._wrappingSub(_b, space) + + +### PUBLIC FUNCTIONS ### + +@public +def modularExp(_base: uint256[M_LIST_LENGTH], _e: uint256, _m: uint256[M_LIST_LENGTH]) -> uint256[M_LIST_LENGTH]: + e: uint256[M_LIST_LENGTH] + e[M_LIST_LENGTH - 1] = _e + return self._bigModExp(_base, e, _m) + + +@public +def modularExpVariableLength(_base: uint256[M_LIST_LENGTH], _e: uint256[M_LIST_LENGTH], _m: uint256[M_LIST_LENGTH]) -> uint256[M_LIST_LENGTH]: + return self._bigModExp(_base, _e, _m) + + +# 4ab = (a + b)**2 - (a - b)**2 +@public +def modularMul4(_a: uint256[M_LIST_LENGTH], _b: uint256[M_LIST_LENGTH], _m: uint256[M_LIST_LENGTH]) -> uint256[M_LIST_LENGTH]: + two: uint256[M_LIST_LENGTH] + two[M_LIST_LENGTH - 1] = 2 + aPlusB: uint256[M_LIST_LENGTH] = self._bigModExp(self._modularAdd(_a, _b, _m), two, _m) + aMinusB: uint256[M_LIST_LENGTH] = self._bigModExp(self._modularSub(_a, _b, _m), two, _m) + return self._modularSub(aPlusB, aMinusB, _m) + + +# 4a = (a + a) + (a + a) +@public +@constant +def modularMulBy4(_a: uint256[M_LIST_LENGTH], _m: uint256[M_LIST_LENGTH]) -> uint256[M_LIST_LENGTH]: + t: uint256[M_LIST_LENGTH] = self._modularAdd(_a, _a, _m) + return self._modularAdd(t, t, _m) + + +# NOTE: modularAdd and modularSub are commented out here due to the code size issue (EIP170). +# When you use them, remove other public functions instead to reduce the code size. +# @public +# @constant +# def modularAdd(_a: uint256[M_LIST_LENGTH], _b: uint256[M_LIST_LENGTH], _m: uint256[M_LIST_LENGTH]) -> uint256[M_LIST_LENGTH]: +# return self._modularAdd(_a, _b, _m) + + +# @public +# @constant +# def modularSub(_a: uint256[M_LIST_LENGTH], _b: uint256[M_LIST_LENGTH], _m: uint256[M_LIST_LENGTH]) -> uint256[M_LIST_LENGTH]: +# return self._modularSub(_a, _b, _m) diff --git a/tests/bigint/__init__.py b/tests/bigint/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/bigint/test_bigint.py b/tests/bigint/test_bigint.py new file mode 100644 index 0000000..1586bd0 --- /dev/null +++ b/tests/bigint/test_bigint.py @@ -0,0 +1,73 @@ +import pytest + +# RSA-2048 (https://en.wikipedia.org/wiki/RSA_numbers#RSA-2048) +M = 25195908475657893494027183240048398571429282126204032027777137836043662020707595556264018525880784406918290641249515082189298559149176184502808489120072844992687392807287776735971418347270261896375014971824691165077613379859095700097330459748808428401797429100642458691817195118746121515172654632282216869987549182422433637259085141865462043576798423387184774447920739934236584823824281198163815010674810451660377306056201619676256133844143603833904414952634432190114657544454178424020924616515723350778707749817125772467962926386356373289912154831438167899885040445364023527381951378636564391212010397122822120720357 +M_LIST_LENGTH = 8 + + +def int_to_list(inp): + """ + e.g. int_to_list(2**256) = [0, 0, 0, 0, 0, 0, 1, 0] + """ + hex_str = format(inp, '0512x') + return [int(hex_str[64 * i: 64 * (i + 1)], 16) for i in range(M_LIST_LENGTH)] + + +def list_to_int(inp): + out = 0 + for i in range(M_LIST_LENGTH): + out += 2 ** (256 * i) * inp[M_LIST_LENGTH - 1 - i] + return out + + +M_LIST = int_to_list(M) + + +@pytest.fixture +def c(get_contract, w3): + with open('../contracts/bigint/BigInt.vy') as f: + code = f.read() + c = get_contract(code) + return c + + +@pytest.fixture +def c2(get_contract, w3): + """ + BigInt.vy with modularAdd and modularSub + """ + with open('../contracts/bigint/BigInt.vy') as f: + code = f.read() + + EXP_CODE = """@public +def modularExp(_base: uint256[M_LIST_LENGTH], _e: uint256, _m: uint256[M_LIST_LENGTH]) -> uint256[M_LIST_LENGTH]: + e: uint256[M_LIST_LENGTH] + e[M_LIST_LENGTH - 1] = _e + return self._bigModExp(_base, e, _m) + + +@public +def modularExpVariableLength(_base: uint256[M_LIST_LENGTH], _e: uint256[M_LIST_LENGTH], _m: uint256[M_LIST_LENGTH]) -> uint256[M_LIST_LENGTH]: + return self._bigModExp(_base, _e, _m)""" + + ADD_AND_SUB_CODE = """@public +@constant +def modularAdd(_a: uint256[M_LIST_LENGTH], _b: uint256[M_LIST_LENGTH], _m: uint256[M_LIST_LENGTH]) -> uint256[M_LIST_LENGTH]: + return self._modularAdd(_a, _b, _m) + + +@public +@constant +def modularSub(_a: uint256[M_LIST_LENGTH], _b: uint256[M_LIST_LENGTH], _m: uint256[M_LIST_LENGTH]) -> uint256[M_LIST_LENGTH]: + return self._modularSub(_a, _b, _m)""" + code = code.replace(EXP_CODE, ADD_AND_SUB_CODE) + c = get_contract(code) + return c + + +# def test_modularSub(c2): +# assert list_to_int(c2.modularSub(int_to_list(1), int_to_list(1), M_LIST)) == 0 + + +# def test_modularAdd(c2): +# assert list_to_int(c2.modularAdd(int_to_list(1), int_to_list(1), M_LIST)) == 1 + 1