diff --git a/rubix-spi/src/main/java/com/qubole/rubix/spi/BookKeeperFactory.java b/rubix-spi/src/main/java/com/qubole/rubix/spi/BookKeeperFactory.java index b81c5e9d..084da4aa 100644 --- a/rubix-spi/src/main/java/com/qubole/rubix/spi/BookKeeperFactory.java +++ b/rubix-spi/src/main/java/com/qubole/rubix/spi/BookKeeperFactory.java @@ -72,7 +72,7 @@ public RetryingPooledBookkeeperClient createBookKeeperClient(String host, Config } else { Poolable obj; - obj = pool.borrowObject(host, conf); + obj = pool.borrowObject(host); RetryingPooledBookkeeperClient retryingBookkeeperClient = new RetryingPooledBookkeeperClient(obj, host, conf); return retryingBookkeeperClient; } diff --git a/rubix-spi/src/main/java/com/qubole/rubix/spi/CacheConfig.java b/rubix-spi/src/main/java/com/qubole/rubix/spi/CacheConfig.java index 7aed3430..1c4e2907 100644 --- a/rubix-spi/src/main/java/com/qubole/rubix/spi/CacheConfig.java +++ b/rubix-spi/src/main/java/com/qubole/rubix/spi/CacheConfig.java @@ -12,11 +12,8 @@ */ package com.qubole.rubix.spi; -import com.google.common.collect.ImmutableList; import org.apache.hadoop.conf.Configuration; -import java.util.List; - import static com.qubole.rubix.spi.utils.DataSizeUnits.MEGABYTES; /** @@ -52,6 +49,7 @@ public class CacheConfig private static final String KEY_POOL_MIN_SIZE = "rubix.pool.size.min"; private static final String KEY_POOL_DELTA_SIZE = "rubix.pool.delta.size"; private static final String KEY_POOL_MAX_WAIT_TIMEOUT = "rubix.pool.wait.timeout"; + private static final String KEY_POOL_SCAVENGER_INTERVAL = "rubix.pool.scavenger.interval"; private static final String KEY_DATA_CACHE_EXPIRY_AFTER_WRITE = "rubix.cache.expiration.after-write"; private static final String KEY_DATA_CACHE_DIR_PREFIX = "rubix.cache.dirprefix.list"; private static final String KEY_DATA_CACHE_DIR_SUFFIX = "rubix.cache.dirsuffix"; @@ -310,7 +308,7 @@ public static int getTransportPoolMaxWait(Configuration conf) public static int getScavengeInterval(Configuration conf) { - return conf.getInt(KEY_POOL_MAX_WAIT_TIMEOUT, DEFAULT_SCAVENGE_INTERVAL); + return conf.getInt(KEY_POOL_SCAVENGER_INTERVAL, DEFAULT_SCAVENGE_INTERVAL); } public static int get(Configuration conf) diff --git a/rubix-spi/src/main/java/com/qubole/rubix/spi/DataTransferClientFactory.java b/rubix-spi/src/main/java/com/qubole/rubix/spi/DataTransferClientFactory.java index d77b55b1..c1da557a 100644 --- a/rubix-spi/src/main/java/com/qubole/rubix/spi/DataTransferClientFactory.java +++ b/rubix-spi/src/main/java/com/qubole/rubix/spi/DataTransferClientFactory.java @@ -20,7 +20,6 @@ import org.apache.hadoop.conf.Configuration; import java.io.Closeable; -import java.io.IOException; import java.nio.channels.SocketChannel; import java.util.concurrent.atomic.AtomicBoolean; @@ -53,7 +52,7 @@ public static DataTransferClient getClient(String host, Configuration conf) } } } - Poolable socketChannelPoolable = pool.borrowObject(host, conf); + Poolable socketChannelPoolable = pool.borrowObject(host); return new DataTransferClient(socketChannelPoolable); } diff --git a/rubix-spi/src/main/java/com/qubole/rubix/spi/RetryingPooledThriftClient.java b/rubix-spi/src/main/java/com/qubole/rubix/spi/RetryingPooledThriftClient.java index df2a045b..e55cac88 100644 --- a/rubix-spi/src/main/java/com/qubole/rubix/spi/RetryingPooledThriftClient.java +++ b/rubix-spi/src/main/java/com/qubole/rubix/spi/RetryingPooledThriftClient.java @@ -12,6 +12,7 @@ */ package com.qubole.rubix.spi; +import com.google.common.annotations.VisibleForTesting; import com.qubole.rubix.spi.fop.ObjectPool; import com.qubole.rubix.spi.fop.Poolable; import org.apache.commons.logging.Log; @@ -76,7 +77,7 @@ protected V retryConnection(Callable callable) // unset transportPoolable so that close() doesnt return it again to pool if borrowObject hits an exception transportPoolable = null; - transportPoolable = objectPool.borrowObject(host, conf); + transportPoolable = objectPool.borrowObject(host); updateClient(transportPoolable); } } @@ -84,6 +85,12 @@ protected V retryConnection(Callable callable) throw new TException(); } + @VisibleForTesting + public Poolable getTransportPoolable() + { + return transportPoolable; + } + @Override public void close() { diff --git a/rubix-spi/src/main/java/com/qubole/rubix/spi/fop/ObjectFactory.java b/rubix-spi/src/main/java/com/qubole/rubix/spi/fop/ObjectFactory.java index 11abb208..a1455ebe 100755 --- a/rubix-spi/src/main/java/com/qubole/rubix/spi/fop/ObjectFactory.java +++ b/rubix-spi/src/main/java/com/qubole/rubix/spi/fop/ObjectFactory.java @@ -21,7 +21,8 @@ */ public interface ObjectFactory { - T create(String host, int socketTimeout, int connectTimeout); + T create(String host, int socketTimeout, int connectTimeout) + throws Exception; void destroy(T t); diff --git a/rubix-spi/src/main/java/com/qubole/rubix/spi/fop/ObjectPool.java b/rubix-spi/src/main/java/com/qubole/rubix/spi/fop/ObjectPool.java index d40788dd..85158d4b 100755 --- a/rubix-spi/src/main/java/com/qubole/rubix/spi/fop/ObjectPool.java +++ b/rubix-spi/src/main/java/com/qubole/rubix/spi/fop/ObjectPool.java @@ -19,14 +19,12 @@ import com.google.common.util.concurrent.AbstractScheduledService; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.apache.hadoop.conf.Configuration; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; import static java.lang.Thread.currentThread; @@ -67,7 +65,7 @@ protected BlockingQueue> createBlockingQueue(PoolConfig poolConfig) return new ArrayBlockingQueue<>(poolConfig.getMaxSize()); } - public Poolable borrowObject(String host, Configuration conf) + public Poolable borrowObject(String host) { if (!hostToPoolMap.containsKey(host)) { synchronized (hostToPoolMap) { @@ -77,22 +75,21 @@ public Poolable borrowObject(String host, Configuration conf) } } log.debug(this.name + " : Borrowing object for partition: " + host); - for (int i = 0; i < 3; i++) { // try at most three times - Poolable result = getObject(false, host); - if (factory.validate(result.getObject())) { - return result; - } - else { - this.hostToPoolMap.get(host).decreaseObject(result); - } + Poolable result = getObject(host); + if (result == null) { + throw new RuntimeException("Unable to find a free object from connection pool: " + this.name); + } + else if (!factory.validate(result.getObject())) { + this.hostToPoolMap.get(host).decreaseObject(result); + throw new RuntimeException("Cannot find a valid object from connection pool: " + this.name); } - throw new RuntimeException("Cannot find a valid object"); + return result; } - private Poolable getObject(boolean blocking, String host) + private Poolable getObject(String host) { ObjectPoolPartition subPool = this.hostToPoolMap.get(host); - return subPool.getObject(blocking); + return subPool.getObject(); } public void returnObject(Poolable obj) @@ -105,7 +102,7 @@ public int getSize() { int size = 0; for (ObjectPoolPartition subPool : hostToPoolMap.values()) { - size += subPool.getTotalCount(); + size += subPool.getAliveObjectCount(); } return size; } diff --git a/rubix-spi/src/main/java/com/qubole/rubix/spi/fop/ObjectPoolPartition.java b/rubix-spi/src/main/java/com/qubole/rubix/spi/fop/ObjectPoolPartition.java index 267b914e..edf50315 100755 --- a/rubix-spi/src/main/java/com/qubole/rubix/spi/fop/ObjectPoolPartition.java +++ b/rubix-spi/src/main/java/com/qubole/rubix/spi/fop/ObjectPoolPartition.java @@ -20,8 +20,10 @@ import org.apache.commons.logging.LogFactory; import java.util.concurrent.BlockingQueue; +import java.util.concurrent.Semaphore; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import static com.google.common.base.Preconditions.checkState; @@ -35,10 +37,11 @@ public class ObjectPoolPartition private final PoolConfig config; private final BlockingQueue> objectQueue; private final ObjectFactory objectFactory; - private int totalCount; private final String host; private final int socketTimeout; private final int connectTimeout; + private final Semaphore takeSemaphore; + private final AtomicInteger aliveObjectCount; public ObjectPoolPartition(ObjectPool pool, PoolConfig config, ObjectFactory objectFactory, BlockingQueue> queue, String host, String name) @@ -50,115 +53,41 @@ public ObjectPoolPartition(ObjectPool pool, PoolConfig config, this.host = host; this.socketTimeout = config.getSocketTimeoutMilliseconds(); this.connectTimeout = config.getConnectTimeoutMilliseconds(); - this.totalCount = 0; + this.aliveObjectCount = new AtomicInteger(); this.log = new CustomLogger(name, host); - for (int i = 0; i < config.getMinSize(); i++) { - T object = objectFactory.create(host, socketTimeout, connectTimeout); - if (object != null) { - objectQueue.add(new Poolable<>(object, pool, host)); - totalCount++; - } - } - } - - public void returnObject(Poolable object) - { - if (!objectFactory.validate(object.getObject())) { - log.debug(String.format("Invalid object...removing: %s ", object)); - decreaseObject(object); - // Compensate for the removed object. Needed to prevent endless wait when in parallel a borrowObject is called - increaseObjects(1, false); - return; - } - - log.debug(String.format("Returning object: %s to queue. Queue size: %d", object, objectQueue.size())); - if (!objectQueue.offer(object)) { - log.warn("Created more objects than configured. Created=" + totalCount + " QueueSize=" + objectQueue.size()); - decreaseObject(object); - } - } - - public Poolable getObject(boolean blocking) - { - if (objectQueue.size() == 0) { - // increase objects and return one, it will return null if pool reaches max size or if object creation fails - Poolable object = increaseObjects(this.config.getDelta(), true); - - if (object != null) { - return object; - } - - if (totalCount == 0) { - // Could not create objects, this is mostly due to connection timeouts hence no point blocking as there is not other producer of sockets - throw new RuntimeException("Could not add connections to pool"); - } - // else wait for a connection to get free - } - - Poolable freeObject; + this.takeSemaphore = new Semaphore(config.getMaxSize(), true); try { - if (blocking) { - freeObject = objectQueue.take(); - } - else { - freeObject = objectQueue.poll(config.getMaxWaitMilliseconds(), TimeUnit.MILLISECONDS); - if (freeObject == null) { - throw new RuntimeException("Cannot get a free object from the pool"); - } + for (int i = 0; i < config.getMinSize(); i++) { + T object = objectFactory.create(host, socketTimeout, connectTimeout); + objectQueue.add(new Poolable<>(object, pool, host)); + aliveObjectCount.incrementAndGet(); } } - catch (InterruptedException e) { - throw new RuntimeException(e); // will never happen + catch (Exception e) { + // skipping logging the exception as factories are already logging. } - - freeObject.setLastAccessTs(System.currentTimeMillis()); - return freeObject; } - private synchronized Poolable increaseObjects(int delta, boolean returnObject) + public void returnObject(Poolable object) { - int oldCount = totalCount; - if (delta + totalCount > config.getMaxSize()) { - delta = config.getMaxSize() - totalCount; - } - - Poolable objectToReturn = null; try { - for (int i = 0; i < delta; i++) { - T object = objectFactory.create(host, socketTimeout, connectTimeout); - if (object != null) { - // Do not put the first object on queue - // it will be returned to the caller to ensure it's request is satisfied first if object is requested - Poolable poolable = new Poolable<>(object, pool, host); - if (objectToReturn == null && returnObject) { - objectToReturn = poolable; - } - else { - objectQueue.put(poolable); - } - totalCount++; - } + if (!objectFactory.validate(object.getObject())) { + log.debug(String.format("Invalid object...removing: %s ", object)); + decreaseObject(object); + return; } - if (delta > 0 && (totalCount - oldCount) == 0) { - log.warn(String.format("Could not increase pool size. Pool state: totalCount=%d queueSize=%d delta=%d", totalCount, objectQueue.size(), delta)); - } - else { - log.debug(String.format("Increased pool size by %d, to new size: %d, current queue size: %d, delta: %d", - totalCount - oldCount, totalCount, objectQueue.size(), delta)); + log.debug(String.format("Returning object: %s to queue. Queue size: %d", object, objectQueue.size())); + if (!objectQueue.offer(object)) { + String errorLog = "Created more objects than configured. Created=" + aliveObjectCount + " QueueSize=" + objectQueue.size(); + log.warn(errorLog); + decreaseObject(object); + throw new RuntimeException(errorLog); } } - catch (Exception e) { - log.warn(String.format("Unable to increase pool size. Pool state: totalCount=%d queueSize=%d delta=%d", totalCount, objectQueue.size(), delta), e); - // objectToReturn is not on the queue hence untracked, clean it up before forwarding exception - if (objectToReturn != null) { - objectFactory.destroy(objectToReturn.getObject()); - objectToReturn.destroy(); - } - throw new RuntimeException(e); + finally { + takeSemaphore.release(); } - - return objectToReturn; } public boolean decreaseObject(Poolable obj) @@ -167,27 +96,72 @@ public boolean decreaseObject(Poolable obj) checkState(obj.getHost().equals(this.host), "Call to free object of wrong partition, current partition=%s requested partition = %s", this.host, obj.getHost()); - objectRemoved(); log.debug("Decreasing pool size object: " + obj); objectFactory.destroy(obj.getObject()); + aliveObjectCount.decrementAndGet(); obj.destroy(); return true; } - private synchronized void objectRemoved() + public Poolable getObject() { - totalCount--; + Poolable object; + try { + if (!takeSemaphore.tryAcquire(config.getMaxWaitMilliseconds(), TimeUnit.MILLISECONDS)) { + // Not able to acquire semaphore in the given timeout, return null + return null; + } + } + catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return null; + } + + try { + object = tryGetObject(); + object.setLastAccessTs(System.currentTimeMillis()); + } + catch (Exception e) { + takeSemaphore.release(); + throw new RuntimeException("Cannot get a free object from the pool", e); + } + return object; + } + + private Poolable tryGetObject() throws Exception + { + Poolable poolable = objectQueue.poll(); + if (poolable == null) + { + try { + T object = objectFactory.create(host, socketTimeout, connectTimeout); + poolable = new Poolable<>(object, pool, host); + aliveObjectCount.incrementAndGet(); + log.debug(String.format("Added a connection, Pool state: totalCount: %s, queueSize: %d", aliveObjectCount, + objectQueue.size())); + } + catch (Exception e) { + log.warn(String.format("Unable create a connection. Pool state: totalCount=%s queueSize=%d", aliveObjectCount, + objectQueue.size()), e); + if (poolable != null) { + objectFactory.destroy(poolable.getObject()); + poolable.destroy(); + } + throw e; + } + } + return poolable; } - public synchronized int getTotalCount() + public int getAliveObjectCount() { - return totalCount; + return aliveObjectCount.get(); } // set the scavenge interval carefully public void scavenge() throws InterruptedException { - int delta = this.totalCount - config.getMinSize(); + int delta = this.aliveObjectCount.get() - config.getMinSize(); if (delta <= 0) { log.debug("Scavenge for delta <= 0, Skipping !!!"); return; @@ -225,7 +199,7 @@ public void scavenge() throws InterruptedException public synchronized int shutdown() { int removed = 0; - while (this.totalCount > 0) { + while (this.aliveObjectCount.get() > 0) { Poolable obj = objectQueue.poll(); if (obj != null) { decreaseObject(obj); diff --git a/rubix-spi/src/main/java/com/qubole/rubix/spi/fop/SocketChannelObjectFactory.java b/rubix-spi/src/main/java/com/qubole/rubix/spi/fop/SocketChannelObjectFactory.java index 17cc427f..cb94e32a 100644 --- a/rubix-spi/src/main/java/com/qubole/rubix/spi/fop/SocketChannelObjectFactory.java +++ b/rubix-spi/src/main/java/com/qubole/rubix/spi/fop/SocketChannelObjectFactory.java @@ -37,9 +37,10 @@ public SocketChannelObjectFactory(int port) @Override public SocketChannel create(String host, int socketTimeout, int connectTimeout) + throws IOException { SocketAddress sad = new InetSocketAddress(host, this.port); - SocketChannel socket = null; + SocketChannel socket; try { socket = SocketChannel.open(); socket.socket().setSoTimeout(socketTimeout); @@ -49,6 +50,7 @@ public SocketChannel create(String host, int socketTimeout, int connectTimeout) } catch (IOException e) { log.warn(LDS_POOL + " : Unable to open connection to host " + host, e); + throw e; } return socket; } diff --git a/rubix-spi/src/main/java/com/qubole/rubix/spi/fop/SocketObjectFactory.java b/rubix-spi/src/main/java/com/qubole/rubix/spi/fop/SocketObjectFactory.java index 609f7673..4c69c82e 100644 --- a/rubix-spi/src/main/java/com/qubole/rubix/spi/fop/SocketObjectFactory.java +++ b/rubix-spi/src/main/java/com/qubole/rubix/spi/fop/SocketObjectFactory.java @@ -36,16 +36,17 @@ public SocketObjectFactory(int port) @Override public TSocket create(String host, int socketTimeout, int connectTimeout) + throws TTransportException { log.debug(BKS_POOL + " : Opening connection to host: " + host); - TSocket socket = null; + TSocket socket; try { socket = new TSocket(host, port, socketTimeout, connectTimeout); socket.open(); } catch (TTransportException e) { - socket = null; log.warn("Unable to open connection to host " + host, e); + throw e; } return socket; } diff --git a/rubix-spi/src/test/java/com/qubole/rubix/spi/TestBookKeeperFactory.java b/rubix-spi/src/test/java/com/qubole/rubix/spi/TestBookKeeperFactory.java index 368c968d..6de0673f 100644 --- a/rubix-spi/src/test/java/com/qubole/rubix/spi/TestBookKeeperFactory.java +++ b/rubix-spi/src/test/java/com/qubole/rubix/spi/TestBookKeeperFactory.java @@ -12,6 +12,7 @@ */ package com.qubole.rubix.spi; +import com.qubole.rubix.spi.fop.Poolable; import com.qubole.rubix.spi.thrift.BookKeeperService; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -21,6 +22,7 @@ import org.apache.thrift.server.TThreadPoolServer; import org.apache.thrift.transport.TServerSocket; import org.apache.thrift.transport.TServerTransport; +import org.apache.thrift.transport.TTransport; import org.apache.thrift.transport.TTransportException; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -30,7 +32,9 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; +import static org.testng.Assert.fail; public class TestBookKeeperFactory { @@ -146,6 +150,43 @@ public void testCreateBookKeeperClient_startDelay_unableToConnect_socketTimeout( client.isBookKeeperAlive(); // should throw expected exception due to socket timeout } + @Test + public void testConnectionPoolSemaphoreLogic() throws TException, InterruptedException + { + final int connectTimeout = 500; + final int socketTimeout = 500; + + // Create a connection pool of size = 1 + conf.setInt("rubix.pool.size.max", 1); + server = startMockServer(true, NO_DELAY, NO_DELAY); + + RetryingPooledBookkeeperClient bookKeeperClient = createTestBookKeeperClient(socketTimeout, connectTimeout, 3); + assertTrue(bookKeeperClient.isBookKeeperAlive(), "Unable to connect to bookkeeper"); + + Poolable transportPoolable = bookKeeperClient.getTransportPoolable(); + + try { + bookKeeperFactory.createBookKeeperClient("localhost", conf); + } + catch (Exception e) { + assertEquals(e.getMessage(), "Unable to find a free object from connection pool: bks-pool"); + + // close the client which should have added back the free connection the pool. + bookKeeperClient.close(); + + bookKeeperClient = bookKeeperFactory.createBookKeeperClient("localhost", conf); + assertTrue(bookKeeperClient.isBookKeeperAlive(), "Unable to connect to bookkeeper"); + + // Verify that the pool return the same connection instead of creating the new one. + assertEquals(transportPoolable.getObject(), bookKeeperClient.getTransportPoolable().getObject(), "Same connection should be reused from the pool"); + return; + } + finally { + stopMockServer(); + } + fail("Expected exception to be thrown while creating bookkeeper client"); + } + private MockBookKeeperServer startMockServer(boolean waitForStart, int startDelay, int aliveCallDelay) throws InterruptedException { MockBookKeeperServer server = new MockBookKeeperServer(startDelay, aliveCallDelay); diff --git a/rubix-spi/src/test/java/com/qubole/rubix/spi/client/TestPoolingClient.java b/rubix-spi/src/test/java/com/qubole/rubix/spi/client/TestPoolingClient.java index 29b798c8..7d5acc32 100644 --- a/rubix-spi/src/test/java/com/qubole/rubix/spi/client/TestPoolingClient.java +++ b/rubix-spi/src/test/java/com/qubole/rubix/spi/client/TestPoolingClient.java @@ -114,7 +114,7 @@ private RetryingPooledThriftTestClient getClient(int retries) retries, conf, "localhost", - pool.borrowObject("localhost", conf)); + pool.borrowObject("localhost")); } private static void startServerAsync(final Configuration conf)