diff --git a/src/main/java/core/packetproxy/DuplexManager.java b/src/main/java/core/packetproxy/DuplexManager.java index dfb721bf..55019dd6 100644 --- a/src/main/java/core/packetproxy/DuplexManager.java +++ b/src/main/java/core/packetproxy/DuplexManager.java @@ -59,6 +59,10 @@ public Duplex getDuplex(int hash) { return duplex_list.get(hash); } + public void removeDuplex(int hash) { + duplex_list.remove(hash); + } + public boolean has(int hash) { return (duplex_list.get(hash) == null) ? false : true; } diff --git a/src/main/java/core/packetproxy/ProxyUDPForward.java b/src/main/java/core/packetproxy/ProxyUDPForward.java index 57e1619f..52b0bd4c 100644 --- a/src/main/java/core/packetproxy/ProxyUDPForward.java +++ b/src/main/java/core/packetproxy/ProxyUDPForward.java @@ -19,6 +19,9 @@ import static packetproxy.util.Logging.log; import java.net.InetSocketAddress; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.Map; import packetproxy.common.Endpoint; import packetproxy.common.UDPServerSocket; import packetproxy.common.UDPSocketEndpoint; @@ -26,8 +29,23 @@ public class ProxyUDPForward extends Proxy { + private static final int MAX_ACTIVE_CONNECTIONS = 256; private ListenPort listen_info; private UDPServerSocket listen_socket; + private final Map activeConnections = new LinkedHashMap<>(); + private volatile boolean closed = false; + + private static class ActiveConnection { + private final DuplexAsync duplex; + private final UDPSocketEndpoint serverEndpoint; + private final int duplexHash; + + ActiveConnection(DuplexAsync duplex, UDPSocketEndpoint serverEndpoint, int duplexHash) { + this.duplex = duplex; + this.serverEndpoint = serverEndpoint; + this.duplexHash = duplexHash; + } + } public ProxyUDPForward(ListenPort listen_info) throws Exception { this.listen_info = listen_info; @@ -38,26 +56,102 @@ public ProxyUDPForward(ListenPort listen_info) throws Exception { public void run() { try { - while (true) { + while (!closed) { + try { + + Endpoint client_endpoint = listen_socket.accept(); + log("accept"); + + InetSocketAddress clientAddr = client_endpoint.getAddress(); + InetSocketAddress serverAddr = listen_info.getServer().getAddress(); + UDPSocketEndpoint server_endpoint = new UDPSocketEndpoint(serverAddr); - Endpoint client_endpoint = listen_socket.accept(); - log("accept"); + DuplexAsync duplex = DuplexFactory.createDuplexAsync(client_endpoint, server_endpoint, + listen_info.getServer().getEncoder()); + duplex.start(); + int duplexHash = DuplexManager.getInstance().registerDuplex(duplex); - InetSocketAddress serverAddr = listen_info.getServer().getAddress(); - UDPSocketEndpoint server_endpoint = new UDPSocketEndpoint(serverAddr); + closeConnectionIfExists(clientAddr); + activeConnections.put(clientAddr, new ActiveConnection(duplex, server_endpoint, duplexHash)); + evictIfOverLimit(); + } catch (Exception e) { - DuplexAsync duplex = DuplexFactory.createDuplexAsync(client_endpoint, server_endpoint, - listen_info.getServer().getEncoder()); - duplex.start(); - DuplexManager.getInstance().registerDuplex(duplex); + if (!closed) { + errWithStackTrace(e); + } + } } } catch (Exception e) { errWithStackTrace(e); + } finally { + closeAllConnections(); } } public void close() throws Exception { + closed = true; + closeAllConnections(); listen_socket.close(); } + + private void evictIfOverLimit() { + while (activeConnections.size() > MAX_ACTIVE_CONNECTIONS) { + Iterator> i = activeConnections.entrySet().iterator(); + if (!i.hasNext()) { + + return; + } + Map.Entry oldest = i.next(); + i.remove(); + closeConnection(oldest.getKey(), oldest.getValue()); + } + } + + private void closeConnectionIfExists(InetSocketAddress clientAddr) { + ActiveConnection oldConnection = activeConnections.remove(clientAddr); + if (oldConnection != null) { + + closeConnection(clientAddr, oldConnection); + } + } + + private void closeConnection(InetSocketAddress clientAddr, ActiveConnection connection) { + try { + + connection.duplex.close(); + } catch (Exception e) { + + errWithStackTrace(e); + } + try { + + connection.serverEndpoint.close(); + } catch (Exception e) { + + errWithStackTrace(e); + } + try { + + DuplexManager.getInstance().removeDuplex(connection.duplexHash); + } catch (Exception e) { + + errWithStackTrace(e); + } + try { + + listen_socket.removeConnection(clientAddr); + } catch (Exception e) { + + errWithStackTrace(e); + } + } + + private void closeAllConnections() { + for (Map.Entry entry : activeConnections.entrySet()) { + + closeConnection(entry.getKey(), entry.getValue()); + } + activeConnections.clear(); + } } diff --git a/src/main/java/core/packetproxy/common/UDPConn.java b/src/main/java/core/packetproxy/common/UDPConn.java index 5fd97bd8..2fe12efc 100644 --- a/src/main/java/core/packetproxy/common/UDPConn.java +++ b/src/main/java/core/packetproxy/common/UDPConn.java @@ -23,16 +23,23 @@ import java.util.concurrent.Callable; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.Future; import org.apache.commons.io.output.ByteArrayOutputStream; public class UDPConn { private PipeEndpoint pipe; private InetSocketAddress addr; + private final ExecutorService receiveExecutor; + private Future recvTaskFuture; + private volatile boolean closed; public UDPConn(InetSocketAddress addr) throws Exception { this.addr = addr; this.pipe = new PipeEndpoint(addr); + this.receiveExecutor = Executors.newSingleThreadExecutor(); + this.recvTaskFuture = null; + this.closed = false; } public void put(byte[] data, int offset, int length) throws Exception { @@ -49,24 +56,57 @@ public void put(byte[] data) throws Exception { } public void getAutomatically(final BlockingQueue queue) throws Exception { - ExecutorService executor = Executors.newSingleThreadExecutor(); + if (closed) { + + throw new IllegalStateException("UDPConn is already closed"); + } Callable recvTask = new Callable() { public Void call() throws Exception { - while (true) { + while (!closed) { InputStream is = pipe.getRawEndpoint().getInputStream(); byte[] buf = new byte[4096]; int len = is.read(buf); + if (len < 0) { + + return null; + } DatagramPacket recvPacket = new DatagramPacket(buf, len, addr); queue.put(recvPacket); } + return null; } }; - executor.submit(recvTask); + recvTaskFuture = receiveExecutor.submit(recvTask); } public Endpoint getEndpoint() throws Exception { return pipe.getProxyRawEndpoint(); } + + public void close() throws Exception { + if (closed) { + + return; + } + closed = true; + if (recvTaskFuture != null) { + + recvTaskFuture.cancel(true); + } + closeQuietly(pipe.getRawEndpoint().getInputStream()); + closeQuietly(pipe.getRawEndpoint().getOutputStream()); + closeQuietly(pipe.getProxyRawEndpoint().getInputStream()); + closeQuietly(pipe.getProxyRawEndpoint().getOutputStream()); + receiveExecutor.shutdownNow(); + } + + private void closeQuietly(AutoCloseable closeable) { + try { + + closeable.close(); + } catch (Exception ignored) { + } + } } diff --git a/src/main/java/core/packetproxy/common/UDPConnManager.java b/src/main/java/core/packetproxy/common/UDPConnManager.java index b479e505..c7d3f7cc 100644 --- a/src/main/java/core/packetproxy/common/UDPConnManager.java +++ b/src/main/java/core/packetproxy/common/UDPConnManager.java @@ -35,18 +35,29 @@ public UDPConnManager() { } public Endpoint accept() throws Exception { - InetSocketAddress addr = acceptedQueue.take(); - return connList.get(addr).getEndpoint(); + while (true) { + InetSocketAddress addr = acceptedQueue.take(); + synchronized (this) { + UDPConn conn = connList.get(addr); + if (conn != null) { + + return conn.getEndpoint(); + } + } + } } public void put(DatagramPacket packet) throws Exception { InetSocketAddress addr = new InetSocketAddress(packet.getAddress(), packet.getPort()); - UDPConn conn = this.query(addr); - if (conn == null) { + UDPConn conn; + synchronized (this) { + conn = this.query(addr); + if (conn == null) { - conn = this.create(addr); - conn.getAutomatically(recvQueue); - acceptedQueue.put(addr); + conn = this.create(addr); + conn.getAutomatically(recvQueue); + acceptedQueue.put(addr); + } } conn.put(packet.getData(), 0, packet.getLength()); } @@ -64,4 +75,20 @@ private UDPConn create(InetSocketAddress key) throws Exception { connList.put(key, conn); return conn; } + + public synchronized void remove(InetSocketAddress key) throws Exception { + UDPConn conn = connList.remove(key); + if (conn != null) { + + conn.close(); + } + } + + public synchronized void closeAll() throws Exception { + for (UDPConn conn : connList.values()) { + + conn.close(); + } + connList.clear(); + } } diff --git a/src/main/java/core/packetproxy/common/UDPServerSocket.java b/src/main/java/core/packetproxy/common/UDPServerSocket.java index a76a3cd3..2dc428c1 100644 --- a/src/main/java/core/packetproxy/common/UDPServerSocket.java +++ b/src/main/java/core/packetproxy/common/UDPServerSocket.java @@ -17,6 +17,7 @@ import java.net.DatagramPacket; import java.net.DatagramSocket; +import java.net.InetSocketAddress; import java.util.concurrent.Callable; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -34,12 +35,17 @@ public UDPServerSocket(int port) throws Exception { public void close() throws Exception { socket.close(); + connManager.closeAll(); } public Endpoint accept() throws Exception { return connManager.accept(); } + public void removeConnection(InetSocketAddress addr) throws Exception { + connManager.remove(addr); + } + private void createRecvLoop() throws Exception { ExecutorService executor = Executors.newFixedThreadPool(2); Callable recvTask = new Callable() { diff --git a/src/main/java/core/packetproxy/common/UDPSocketEndpoint.java b/src/main/java/core/packetproxy/common/UDPSocketEndpoint.java index 63bab73b..abc8aa2e 100644 --- a/src/main/java/core/packetproxy/common/UDPSocketEndpoint.java +++ b/src/main/java/core/packetproxy/common/UDPSocketEndpoint.java @@ -30,12 +30,16 @@ public class UDPSocketEndpoint implements Endpoint { private InetSocketAddress serverAddr; private PipeEndpoint pipe; private static int BUFSIZE = 4096; + private final ExecutorService executor; + private volatile boolean closed; public UDPSocketEndpoint(InetSocketAddress addr) throws Exception { socket = new DatagramSocket(); socket.connect(addr); serverAddr = addr; pipe = new PipeEndpoint(addr); + executor = Executors.newFixedThreadPool(2); + closed = false; loop(); } @@ -55,24 +59,28 @@ public OutputStream getOutputStream() throws Exception { } private void loop() { - ExecutorService executor = Executors.newFixedThreadPool(2); Callable sendTask = new Callable() { public Void call() throws Exception { - while (true) { + while (!closed) { InputStream is = pipe.getRawEndpoint().getInputStream(); byte[] input_data = new byte[BUFSIZE]; int len = is.read(input_data); + if (len < 0) { + + return null; + } DatagramPacket sendPacket = new DatagramPacket(input_data, 0, len, serverAddr); socket.send(sendPacket); } + return null; } }; Callable recvTask = new Callable() { public Void call() throws Exception { - while (true) { + while (!closed) { byte[] buf = new byte[BUFSIZE]; DatagramPacket recvPacket = new DatagramPacket(buf, BUFSIZE); @@ -81,12 +89,43 @@ public Void call() throws Exception { os.write(recvPacket.getData(), 0, recvPacket.getLength()); os.flush(); } + return null; } }; executor.submit(sendTask); executor.submit(recvTask); } + public void close() { + if (closed) { + + return; + } + closed = true; + socket.close(); + executor.shutdownNow(); + try { + + pipe.getRawEndpoint().getInputStream().close(); + } catch (Exception ignored) { + } + try { + + pipe.getRawEndpoint().getOutputStream().close(); + } catch (Exception ignored) { + } + try { + + pipe.getProxyRawEndpoint().getInputStream().close(); + } catch (Exception ignored) { + } + try { + + pipe.getProxyRawEndpoint().getOutputStream().close(); + } catch (Exception ignored) { + } + } + @Override public int getLocalPort() { return socket.getLocalPort(); diff --git a/src/test/java/packetproxy/common/UDPConnTest.java b/src/test/java/packetproxy/common/UDPConnTest.java new file mode 100644 index 00000000..8fec432e --- /dev/null +++ b/src/test/java/packetproxy/common/UDPConnTest.java @@ -0,0 +1,106 @@ +/* + * Copyright 2026 DeNA Co., Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package packetproxy.common; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +import java.net.DatagramPacket; +import java.net.InetSocketAddress; +import java.nio.charset.StandardCharsets; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.Test; + +public class UDPConnTest { + + @Test + public void testCloseStopsAutomaticallyCreatedThread() throws Exception { + var existingThreads = new HashSet<>(Thread.getAllStackTraces().keySet()); + var conn = new UDPConn(new InetSocketAddress("127.0.0.1", 20000)); + var queue = new LinkedBlockingQueue(); + try { + + conn.getAutomatically(queue); + var payload = "request".getBytes(StandardCharsets.US_ASCII); + var endpointOutput = conn.getEndpoint().getOutputStream(); + endpointOutput.write(payload); + endpointOutput.flush(); + + var packet = queue.poll(1, TimeUnit.SECONDS); + assertNotNull(packet); + assertEquals("request", new String(packet.getData(), 0, packet.getLength(), StandardCharsets.US_ASCII)); + + var connThreads = waitForNewThreads(existingThreads); + assertFalse(connThreads.isEmpty()); + + conn.close(); + joinThreads(connThreads); + assertFalse(hasAliveThreads(connThreads)); + } finally { + + conn.close(); + } + } + + private Set waitForNewThreads(Set existingThreads) throws Exception { + long deadlineMillis = System.currentTimeMillis() + 1000; + Set newThreads = Set.of(); + while (System.currentTimeMillis() < deadlineMillis) { + + newThreads = getNewThreads(existingThreads); + if (!newThreads.isEmpty()) { + + return newThreads; + } + Thread.sleep(10); + } + return newThreads; + } + + private Set getNewThreads(Set existingThreads) { + var currentThreads = Thread.getAllStackTraces().keySet(); + var newThreads = new HashSet(); + for (var thread : currentThreads) { + + if (!existingThreads.contains(thread)) { + newThreads.add(thread); + } + } + return newThreads; + } + + private void joinThreads(Set threads) throws Exception { + for (var thread : threads) { + + thread.join(1000); + } + } + + private boolean hasAliveThreads(Set threads) { + for (var thread : threads) { + + if (thread.isAlive()) { + + return true; + } + } + return false; + } +}