Hello, I'm using the library on a TPU VM (v3-8). The software version is tpu-vm-base. Running `examples/gptj_serve.py` produces the following output:
Traceback (most recent call last):
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/connection.py", line 611, in connect
sock = self.retry.call_with_retry(
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/retry.py", line 46, in call_with_retry
return do()
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/connection.py", line 612, in <lambda>
lambda: self._connect(), lambda error: self.disconnect(error)
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/connection.py", line 677, in _connect
raise err
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/connection.py", line 665, in _connect
sock.connect(socket_address)
ConnectionRefusedError: [Errno 111] Connection refused
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/qj213/JAXSeq/examples/gptj_serve.py", line 98, in <module>
inference_server = InferenceServerMP(
File "/home/qj213/JAXSeq/src/utils/serve_queue.py", line 21, in __init__
self.Q = initalize_server(self, super().__getattribute__('r'), cache_cls, args, kwargs)
File "/home/qj213/JAXSeq/src/utils/serve_queue.py", line 79, in initalize_server
build_method(Config.init_message, r, Q)(self)
File "/home/qj213/JAXSeq/src/utils/serve_queue.py", line 44, in call_method
request_id = int(r.incr('request_id_counter'))
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/commands/core.py", line 1831, in incrby
return self.execute_command("INCRBY", name, amount)
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/client.py", line 1235, in execute_command
conn = self.connection or pool.get_connection(command_name, **options)
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/connection.py", line 1387, in get_connection
connection.connect()
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/connection.py", line 617, in connect
raise ConnectionError(self._error_message(e))
redis.exceptions.ConnectionError: Error 111 connecting to localhost:6379. Connection refused.
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/qj213/JAXSeq/examples/gptj_serve.py", line 98, in <module>
inference_server = InferenceServerMP(
File "/home/qj213/JAXSeq/src/utils/serve_queue.py", line 21, in __init__
self.Q = initalize_server(self, super().__getattribute__('r'), cache_cls, args, kwargs)
File "/home/qj213/JAXSeq/src/utils/serve_queue.py", line 79, in initalize_server
build_method(Config.init_message, r, Q)(self)
File "/home/qj213/JAXSeq/src/utils/serve_queue.py", line 44, in call_method
request_id = int(r.incr('request_id_counter'))
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/commands/core.py", line 1831, in incrby
return self.execute_command("INCRBY", name, amount)
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/client.py", line 1235, in execute_command
conn = self.connection or pool.get_connection(command_name, **options)
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/connection.py", line 1387, in get_connection
connection.connect()
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/site-packages/redis/connection.py", line 617, in connect
raise ConnectionError(self._error_message(e))
redis.exceptions.ConnectionError: Error 111 connecting to localhost:6379. Connection refused.
using mesh shape: (1, 8)
full mesh: [[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)
TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1)
TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0)
TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1)
TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0)
TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1)
TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0)
TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]]
current process index 0, in position [0, 0] of [1, 1]
tcmalloc: large alloc 24203542528 bytes == 0x560314c42000 @ 0x7f615c34c680 0x7f615c36d824 0x560289f3e53b 0x560289f7f0ba 0x56028a055a58 0x560289fb148d 0x560289e8b328 0x56028a06b66d 0x560289fb1825 0x560289f0f2da 0x560289fa6fe3 0x560289fa8709 0x560289f0e73d 0x560289fa7be4 0x560289f0e088 0x560289fa6fe3 0x560289fa7d24 0x560289f0e73d 0x560289fa6fe3 0x560289fa7d24 0x560289f92a2e 0x560289f9c429 0x560289f676ab 0x560289f56359 0x560289fe7e7a 0x560289fa7be4 0x560289f5630a 0x560289fe7e7a 0x560289fa7be4 0x560289f0f2da 0x560289fa6fe3
unmatches keys: set()
Process Process-2:
Traceback (most recent call last):
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/multiprocessing/managers.py", line 802, in _callmethod
conn = self._tls.connection
AttributeError: 'ForkAwareLocal' object has no attribute 'connection'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/qj213/JAXSeq/src/utils/serve_queue.py", line 60, in server_process
request_id, method, args, kwargs = Q.get()
File "<string>", line 2, in get
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/multiprocessing/managers.py", line 806, in _callmethod
self._connect()
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/multiprocessing/managers.py", line 793, in _connect
conn = self._Client(self._token.address, authkey=self._authkey)
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/multiprocessing/connection.py", line 507, in Client
c = SocketClient(address)
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/multiprocessing/connection.py", line 635, in SocketClient
s.connect(address)
ConnectionRefusedError: [Errno 111] Connection refused
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/multiprocessing/process.py", line 315, in _bootstrap
self.run()
File "/home/qj213/anaconda3/envs/JaxSeq/lib/python3.9/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/home/qj213/JAXSeq/src/utils/serve_queue.py", line 73, in server_process
raise Exception
Exception