Skip to content

Commit f456ffc

Browse files
committed
Add parameterised tests for UInt256
Fix bugs found by tests Signed-off-by: Simon Dudley <[email protected]>
1 parent af9125a commit f456ffc

File tree

2 files changed

+349
-6
lines changed

2 files changed

+349
-6
lines changed

evm/src/main/java/org/hyperledger/besu/evm/UInt256.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -403,11 +403,11 @@ public int hashCode() {
403403
* @return Shifted UInt256 value.
404404
*/
405405
public UInt256 shiftLeft(final int shift) {
406-
if (shift >= length * 32) return ZERO;
406+
if (shift >= 256) return ZERO;
407407
if (shift < 0) return shiftRight(-shift);
408408
if (shift == 0 || isZero()) return this;
409409
int nDiffBits = shift - numberOfLeadingZeros(this.limbs, this.length);
410-
int size = this.length + (nDiffBits + 31) / 32;
410+
int size = Math.min(this.length + (nDiffBits + 31) / 32, N_LIMBS);
411411
int[] shifted = new int[size];
412412
shiftLeftInto(shifted, this.limbs, shift);
413413
return new UInt256(shifted, size);
@@ -576,7 +576,7 @@ private static void shiftLeftInto(final int[] result, final int[] x, final int s
576576
int limbShift = shift / 32;
577577
int bitShift = shift % 32;
578578
int nLimbs = Math.min(x.length, result.length - limbShift);
579-
if (shift > 32 * nLimbs) return;
579+
if (limbShift >= result.length) return;
580580
if (bitShift == 0) {
581581
System.arraycopy(x, 0, result, limbShift, nLimbs);
582582
return;
@@ -588,15 +588,15 @@ private static void shiftLeftInto(final int[] result, final int[] x, final int s
588588
result[j] = (x[i] << bitShift) | carry;
589589
carry = x[i] >>> (32 - bitShift);
590590
}
591-
if (carry != 0) result[j] = carry; // last carry
591+
if (carry != 0 && j < result.length) result[j] = carry; // last carry
592592
}
593593

594594
private static void shiftRightInto(final int[] result, final int[] x, final int shift) {
595595
int limbShift = shift / 32;
596596
int bitShift = shift % 32;
597597
int nLimbs = Math.min(x.length - limbShift, result.length);
598598

599-
if (shift > 32 * nLimbs) return;
599+
if (limbShift >= x.length) return;
600600
if (bitShift == 0) {
601601
System.arraycopy(x, limbShift, result, 0, nLimbs);
602602
return;
@@ -686,7 +686,7 @@ private static int[] knuthRemainder(final int[] dividend, final int[] modulus) {
686686
int n = modulus.length - limbShift;
687687
if (n == 0) return new int[0];
688688
if (n == 1) {
689-
if (dividend.length == 1) return (new int[] {dividend[0] % modulus[0]});
689+
if (dividend.length == 1) return (new int[] {Integer.remainderUnsigned(dividend[0], modulus[0])});
690690
long d = modulus[0] & MASK_L;
691691
long rem = 0;
692692
// Process from most significant limb downwards
Lines changed: 343 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,343 @@
1+
/*
2+
* Copyright contributors to Besu.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
5+
* the License. You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
10+
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
11+
* specific language governing permissions and limitations under the License.
12+
*
13+
* SPDX-License-Identifier: Apache-2.0
14+
*/
15+
package org.hyperledger.besu.evm;
16+
17+
import static org.assertj.core.api.Assertions.assertThat;
18+
19+
import java.math.BigInteger;
20+
import java.util.ArrayList;
21+
import java.util.List;
22+
import java.util.Random;
23+
import java.util.stream.Stream;
24+
25+
import org.junit.jupiter.params.ParameterizedTest;
26+
import org.junit.jupiter.params.provider.Arguments;
27+
import org.junit.jupiter.params.provider.MethodSource;
28+
29+
public class UInt256ParameterisedTest {
30+
31+
// Test constants
32+
private static final BigInteger TWO_TO_64 = BigInteger.TWO.pow(64);
33+
private static final BigInteger TWO_TO_128 = BigInteger.TWO.pow(128);
34+
private static final BigInteger TWO_TO_192 = BigInteger.TWO.pow(192);
35+
private static final BigInteger TWO_TO_256 = BigInteger.TWO.pow(256);
36+
private static final BigInteger UINT256_MAX = TWO_TO_256.subtract(BigInteger.ONE);
37+
38+
private static final int RANDOM_TEST_COUNT = 6;
39+
40+
// region Test Data Providers
41+
42+
/**
43+
* Provides unary test cases (single BigInteger values).
44+
*/
45+
static Stream<BigInteger> provideUnaryTestCases() {
46+
List<BigInteger> cases = new ArrayList<>();
47+
48+
// Basic values
49+
cases.add(BigInteger.ZERO);
50+
cases.add(BigInteger.ONE);
51+
cases.add(BigInteger.TWO);
52+
cases.add(BigInteger.valueOf(3));
53+
54+
// Boundary values
55+
cases.add(BigInteger.valueOf(Short.MAX_VALUE));
56+
cases.add(BigInteger.valueOf(0xFFFF - 1)); // UnsignedShort.MAX_VALUE - 1
57+
cases.add(BigInteger.valueOf(0xFFFF)); // UnsignedShort.MAX_VALUE
58+
cases.add(BigInteger.valueOf(0xFFFF + 1)); // UnsignedShort.MAX_VALUE + 1
59+
cases.add(BigInteger.valueOf(Integer.MAX_VALUE));
60+
cases.add(BigInteger.valueOf(0xFFFFFFFFL - 1)); // UnsignedInteger.MAX_VALUE - 1
61+
cases.add(BigInteger.valueOf(0xFFFFFFFFL)); // UnsignedInteger.MAX_VALUE
62+
cases.add(BigInteger.valueOf(0xFFFFFFFFL + 1)); // UnsignedInteger.MAX_VALUE + 1
63+
cases.add(BigInteger.valueOf(Long.MAX_VALUE));
64+
65+
// Large values
66+
cases.add(new BigInteger("FFFFFFFFFFFFFFFE", 16)); // UnsignedLong.MAX_VALUE - 1
67+
cases.add(new BigInteger("FFFFFFFFFFFFFFFF", 16)); // UnsignedLong.MAX_VALUE
68+
cases.add(new BigInteger("080000000000000008000000000000001", 16));
69+
cases.add(TWO_TO_64);
70+
cases.add(TWO_TO_128);
71+
cases.add(TWO_TO_192);
72+
cases.add(TWO_TO_128.subtract(BigInteger.ONE)); // UInt128Max
73+
cases.add(TWO_TO_192.subtract(BigInteger.ONE)); // UInt192Max
74+
cases.add(UINT256_MAX);
75+
76+
// Add random values
77+
cases.addAll(generateRandomUnsigned(RANDOM_TEST_COUNT));
78+
79+
return cases.stream();
80+
}
81+
82+
/**
83+
* Provides binary test cases (pairs of BigInteger values).
84+
*/
85+
static Stream<Arguments> provideBinaryTestCases() {
86+
List<BigInteger> unary = provideUnaryTestCases().toList();
87+
List<Arguments> binary = new ArrayList<>();
88+
89+
for (BigInteger a : unary) {
90+
for (BigInteger b : unary) {
91+
binary.add(Arguments.of(a, b));
92+
}
93+
}
94+
95+
return binary.stream();
96+
}
97+
98+
/**
99+
* Provides ternary test cases (triples of BigInteger values).
100+
*/
101+
static Stream<Arguments> provideTernaryTestCases() {
102+
List<Arguments> binary = provideBinaryTestCases().toList();
103+
List<BigInteger> unary = provideUnaryTestCases().toList();
104+
List<Arguments> ternary = new ArrayList<>();
105+
106+
for (Arguments binArgs : binary) {
107+
BigInteger a = (BigInteger) binArgs.get()[0];
108+
BigInteger b = (BigInteger) binArgs.get()[1];
109+
for (BigInteger c : unary) {
110+
ternary.add(Arguments.of(a, b, c));
111+
}
112+
}
113+
114+
return ternary.stream();
115+
}
116+
117+
/**
118+
* Provides shift test cases (BigInteger value and int shift amount).
119+
*/
120+
static Stream<Arguments> provideShiftTestCases() {
121+
List<BigInteger> unary = provideUnaryTestCases().toList();
122+
List<Arguments> shifts = new ArrayList<>();
123+
124+
for (BigInteger value : unary) {
125+
for (int shift = 0; shift <= 256; shift++) {
126+
shifts.add(Arguments.of(value, shift));
127+
}
128+
}
129+
130+
return shifts.stream();
131+
}
132+
133+
// endregion
134+
135+
// region Helper Methods
136+
137+
/**
138+
* Converts BigInteger to UInt256, wrapping to 256-bit range.
139+
*/
140+
private static UInt256 toUInt256(final BigInteger value) {
141+
BigInteger wrapped = value.mod(TWO_TO_256);
142+
return fromBigInteger(wrapped);
143+
}
144+
145+
/**
146+
* Create UInt256 from BigInteger.
147+
*
148+
* @param value BigInteger value to convert (must be non-negative and <= 2^256-1)
149+
* @return UInt256 representation of the BigInteger value.
150+
* @throws IllegalArgumentException if value is negative or exceeds 256 bits
151+
*/
152+
private static UInt256 fromBigInteger(final java.math.BigInteger value) {
153+
if (value.signum() < 0) {
154+
throw new IllegalArgumentException("UInt256 cannot represent negative values");
155+
}
156+
if (value.bitLength() > 256) {
157+
throw new IllegalArgumentException("Value exceeds 256 bits");
158+
}
159+
if (value.equals(java.math.BigInteger.ZERO)) return UInt256.ZERO;
160+
161+
byte[] bytes = value.toByteArray();
162+
// Remove sign byte if present
163+
int offset = 0;
164+
if (bytes.length > 32 || (bytes.length > 0 && bytes[0] == 0)) {
165+
offset = bytes.length - 32;
166+
if (offset < 0) {
167+
// Need to pad with zeros
168+
byte[] padded = new byte[32];
169+
System.arraycopy(bytes, 0, padded, 32 - bytes.length, bytes.length);
170+
return UInt256.fromBytesBE(padded);
171+
}
172+
}
173+
174+
return UInt256.fromBytesBE(java.util.Arrays.copyOfRange(bytes, offset, bytes.length));
175+
}
176+
177+
private static BigInteger toBigInteger(final UInt256 value) {
178+
if (value.isZero()) return java.math.BigInteger.ZERO;
179+
byte[] bytes = value.toBytesBE();
180+
return new java.math.BigInteger(1, bytes);
181+
}
182+
183+
/**
184+
* Generates random unsigned 256-bit BigInteger values.
185+
*/
186+
private static List<BigInteger> generateRandomUnsigned(final int count) {
187+
List<BigInteger> randoms = new ArrayList<>();
188+
Random rand = new Random(12345);
189+
byte[] data = new byte[32];
190+
191+
for (int i = 0; i < count; i++) {
192+
rand.nextBytes(data);
193+
data[data.length - 1] &= 0x7F; // Clear sign bit to ensure positive
194+
randoms.add(new BigInteger(1, data));
195+
}
196+
197+
return randoms;
198+
}
199+
200+
/**
201+
* Wraps result to 256-bit range.
202+
*/
203+
private static BigInteger wrap256(final BigInteger value) {
204+
return value.mod(TWO_TO_256);
205+
}
206+
207+
// endregion
208+
209+
// region Arithmetic Operations Tests
210+
211+
@ParameterizedTest
212+
@MethodSource("provideBinaryTestCases")
213+
void testAdd(final BigInteger a, final BigInteger b) {
214+
BigInteger expected = wrap256(a.add(b));
215+
216+
UInt256 uint256a = toUInt256(a);
217+
UInt256 uint256b = toUInt256(b);
218+
UInt256 result = uint256a.add(uint256b);
219+
220+
assertThat(toBigInteger(result)).isEqualTo(expected);
221+
}
222+
223+
@ParameterizedTest
224+
@MethodSource("provideBinaryTestCases")
225+
void testMul(final BigInteger a, final BigInteger b) {
226+
BigInteger expected = wrap256(a.multiply(b));
227+
228+
UInt256 uint256a = toUInt256(a);
229+
UInt256 uint256b = toUInt256(b);
230+
UInt256 result = uint256a.mul(uint256b);
231+
232+
assertThat(toBigInteger(result)).isEqualTo(expected);
233+
}
234+
235+
@ParameterizedTest
236+
@MethodSource("provideBinaryTestCases")
237+
void testMod(final BigInteger a, final BigInteger b) {
238+
if (b.equals(BigInteger.ZERO)) {
239+
return; // Skip division by zero
240+
}
241+
242+
BigInteger expected = wrap256(a.mod(b));
243+
244+
UInt256 uint256a = toUInt256(a);
245+
UInt256 uint256b = toUInt256(b);
246+
UInt256 result = uint256a.mod(uint256b);
247+
248+
assertThat(toBigInteger(result)).isEqualTo(expected);
249+
}
250+
251+
@ParameterizedTest
252+
@MethodSource("provideTernaryTestCases")
253+
void testAddMod(final BigInteger a, final BigInteger b, final BigInteger m) {
254+
if (m.equals(BigInteger.ZERO)) {
255+
return; // Skip division by zero
256+
}
257+
258+
BigInteger expected = a.add(b).mod(m);
259+
expected = wrap256(expected);
260+
261+
UInt256 uint256a = toUInt256(a);
262+
UInt256 uint256b = toUInt256(b);
263+
UInt256 uint256m = toUInt256(m);
264+
265+
UInt256 result = uint256a.addMod(uint256b, uint256m);
266+
assertThat(toBigInteger(result)).isEqualTo(expected);
267+
}
268+
269+
@ParameterizedTest
270+
@MethodSource("provideTernaryTestCases")
271+
void testMulMod(final BigInteger a, final BigInteger b, final BigInteger m) {
272+
if (m.equals(BigInteger.ZERO)) {
273+
return; // Skip division by zero
274+
}
275+
276+
BigInteger expected = a.multiply(b).mod(m);
277+
expected = wrap256(expected);
278+
279+
UInt256 uint256a = toUInt256(a);
280+
UInt256 uint256b = toUInt256(b);
281+
UInt256 uint256m = toUInt256(m);
282+
283+
UInt256 result = uint256a.mulMod(uint256b, uint256m);
284+
assertThat(toBigInteger(result)).isEqualTo(expected);
285+
}
286+
287+
// endregion
288+
289+
// region Bitwise Operations Tests
290+
291+
@ParameterizedTest
292+
@MethodSource("provideShiftTestCases")
293+
void testLeftShift(final BigInteger value, final int shift) {
294+
BigInteger expected = wrap256(value.shiftLeft(shift));
295+
296+
UInt256 uint256 = toUInt256(value);
297+
UInt256 result = uint256.shiftLeft(shift);
298+
299+
assertThat(toBigInteger(result)).isEqualTo(expected);
300+
}
301+
302+
@ParameterizedTest
303+
@MethodSource("provideShiftTestCases")
304+
void testRightShift(final BigInteger value, final int shift) {
305+
BigInteger expected = wrap256(value.shiftRight(shift));
306+
307+
UInt256 uint256 = toUInt256(value);
308+
UInt256 result = uint256.shiftRight(shift);
309+
310+
assertThat(toBigInteger(result)).isEqualTo(expected);
311+
}
312+
313+
// endregion
314+
315+
// region Comparison Tests
316+
317+
@ParameterizedTest
318+
@MethodSource("provideBinaryTestCases")
319+
void testComparisons(final BigInteger a, final BigInteger b) {
320+
UInt256 uint256a = toUInt256(a);
321+
UInt256 uint256b = toUInt256(b);
322+
323+
assertThat(UInt256.compare(uint256a, uint256b) < 0).isEqualTo(a.compareTo(b) < 0);
324+
assertThat(UInt256.compare(uint256a, uint256b) <= 0).isEqualTo(a.compareTo(b) <= 0);
325+
assertThat(UInt256.compare(uint256a, uint256b) > 0).isEqualTo(a.compareTo(b) > 0);
326+
assertThat(UInt256.compare(uint256a, uint256b) >= 0).isEqualTo(a.compareTo(b) >= 0);
327+
assertThat(uint256a.equals(uint256b)).isEqualTo(a.equals(b));
328+
}
329+
330+
// endregion
331+
332+
// region Conversion Tests
333+
334+
@ParameterizedTest
335+
@MethodSource("provideUnaryTestCases")
336+
void testToBigIntegerAndBack(final BigInteger value) {
337+
UInt256 uint256 = toUInt256(value);
338+
BigInteger result = toBigInteger(uint256);
339+
assertThat(result).isEqualTo(value);
340+
}
341+
342+
// endregion
343+
}

0 commit comments

Comments
 (0)