Skip to content

Commit

Permalink
Temporary commit: an issue exists where data_id is getting garbage collected in the main thread
Browse files Browse the repository at this point in the history
  • Loading branch information
arunjose696 committed Mar 30, 2023
1 parent d39bba5 commit 92abdcc
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 90 deletions.
26 changes: 14 additions & 12 deletions unidist/core/backends/mpi/core/controller/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
garbage_collector,
)
from unidist.core.backends.mpi.core.controller.common import push_data, RoundRobin

from .api import workQueue

class ActorMethod:
"""
Expand Down Expand Up @@ -40,9 +40,13 @@ def __call__(self, *args, num_returns=1, **kwargs):
unwrapped_args = [common.unwrap_data_ids(arg) for arg in args]
unwrapped_kwargs = {k: common.unwrap_data_ids(v) for k, v in kwargs.items()}

push_data(self._actor._owner_rank, unwrapped_args)
push_data(self._actor._owner_rank, unwrapped_kwargs)

#push_data(self._actor._owner_rank, unwrapped_args)
#push_data(self._actor._owner_rank, unwrapped_kwargs)
from unidist.core.backends.common.data_id import DataID
breakpoint()
workQueue.put([1,(self._actor._owner_rank, [DataID("dadada")])])
workQueue.put([1,(self._actor._owner_rank, unwrapped_args)])
workQueue.put([1,(self._actor._owner_rank, unwrapped_kwargs)])
operation_type = common.Operation.ACTOR_EXECUTE
operation_data = {
"task": self._method_name,
Expand All @@ -51,14 +55,12 @@ def __call__(self, *args, num_returns=1, **kwargs):
"output": common.master_data_ids_to_base(output_id),
"handler": self._actor._handler_id.base_data_id(),
}
async_operations = AsyncOperations.get_instance()
h_list, _ = communication.isend_complex_operation(
communication.MPIState.get_instance().comm,
operation_type,
operation_data,
self._actor._owner_rank,
)
async_operations.extend(h_list)

workQueue.put([2,(communication.MPIState.get_instance().comm,
operation_type,
operation_data,
self._actor._owner_rank,)])

return output_id


Expand Down
124 changes: 46 additions & 78 deletions unidist/core/backends/mpi/core/controller/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from datetime import datetime
try:
import mpi4py
from mpi4py import futures
except ImportError:
raise ImportError(
"Missing dependency 'mpi4py'. Use pip or conda to install it."
Expand Down Expand Up @@ -69,31 +70,22 @@ def run(self):
print ("Exiting " + self.name)


class ThreadSafeDict(dict):
    """
    Dictionary that carries its own lock and acts as a context manager.

    Entering a ``with`` block on the instance acquires the lock and
    leaving the block releases it, including when the block raises.
    Dictionary operations performed outside of a ``with`` block are
    not guarded by the lock.

    Parameters
    ----------
    *p_arg : iterable
        Positional arguments forwarded to ``dict``.
    **n_arg : dict
        Keyword arguments forwarded to ``dict``.
    """

    def __init__(self, *p_arg, **n_arg):
        super().__init__(*p_arg, **n_arg)
        # Non-reentrant lock: nesting ``with`` on the same instance
        # from a single thread blocks forever.
        self._lock = threading.Lock()

    def __enter__(self):
        """
        Acquire the lock and return the dictionary itself.

        Returns
        -------
        ThreadSafeDict
            This instance, so it can be bound with ``with d as locked``.
        """
        self._lock.acquire()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """
        Release the lock.

        Nothing is returned, so an exception raised inside the
        ``with`` block keeps propagating.
        """
        self._lock.release()
getRequests=ThreadSafeDict()
def process_data(threadName, q):
global getRequests
while not exitFlag:
queueLock.acquire()
#print("queue size is {} , time ={}".format(workQueue.qsize(),datetime.fromtimestamp(time.time())))
if not workQueue.empty():
flag, data = q.get()
flag, future, data = q.get()
print("order========================================================== {}".format(flag))

queueLock.release()
if flag==1:
dest_rank, value = data
push_data(dest_rank, value)
print("pushed data=========Data id======================================= {}".format(value))
if flag==2:
comm, operation_type, operation_data, dest_rank = data
async_operations = AsyncOperations.get_instance()
Expand All @@ -104,9 +96,18 @@ def process_data(threadName, q):
dest_rank,
)
async_operations.extend(h_list)
if flag == 3:
if flag == 3:
data_id = data
print("order=========Data id======================================= {}".format(data_id))
getQueue.put(request_worker_data(data_id))
if flag == 4:
function, args = data
function(*args)
if flag == 5:
function, args = data
result = function(*args)
future.set_result(result)


#print ("%s processing %s" % (threadName, value))
else:
Expand All @@ -116,43 +117,11 @@ def process_data(threadName, q):

# Names of the worker threads; init() starts one myThread per entry.
threadList = ["Thread-1"]
# Lock acquired by process_data and get_impl around their queue accesses.
queueLock = threading.Lock()
# NOTE(review): workQueue and getQueue are each bound twice below. The
# Queue(10) bindings look like the pre-change lines of this diff; the
# Queue(0) bindings that follow replace them, and a maxsize of 0 makes
# the queues unbounded, so put() no longer blocks when ten items are pending.
# workQueue carries [flag, future, data] work items consumed by process_data.
# getQueue carries values fetched by request_worker_data back to get_impl.
workQueue = queue.Queue(10)
getQueue = queue.Queue(10)
workQueue = queue.Queue(0)
getQueue = queue.Queue(0)
# Threads started so far; init() only spawns workers while this is empty.
threads = []



# Notify threads it's time to exit


# Wait for all threads to complete
# for t in threads:
# t.join()
# print ("Exiting Main Thread")

def test(comm):
    """
    Send one trivial task over `comm` as a debugging aid.

    Builds an operation whose task is a lambda returning 1, with no
    arguments and a fixed output ID of "data_id_dada", and sends it
    asynchronously with operation type 1 and a final argument of 2
    (which appears to be the destination rank - confirm against
    ``communication.isend_complex_operation``). The returned handlers
    are registered with the shared ``AsyncOperations`` instance.

    Parameters
    ----------
    comm : object
        Communicator handed to ``communication.isend_complex_operation``.
    """
    result_id = common.MasterDataID("data_id_dada", garbage_collector)
    payload = {
        "task": lambda: 1,
        "args": [],
        "kwargs": {},
        "output": result_id,
    }
    pending = AsyncOperations.get_instance()
    handlers, _ = communication.isend_complex_operation(comm, 1, payload, 2)
    pending.extend(handlers)




def init():
"""
Initialize MPI processes.
Expand All @@ -179,9 +148,8 @@ def init():

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

parent_comm = MPI.Comm.Get_parent()
if not threads:
if not threads and rank == 0 and parent_comm == MPI.COMM_NULL:
for tName in threadList:
thread = myThread(1, tName, workQueue)
thread.start()
Expand All @@ -191,19 +159,6 @@ def init():
if rank == 0 and parent_comm == MPI.COMM_NULL:
# Create new threads



# Fill the queue


# Wait for queue to empty



# Wait for all threads to complete



if IsMpiSpawnWorkers.get():
nprocs_to_spawn = CpuCount.get() + 1 # +1 for monitor process
args = ["-c"]
Expand Down Expand Up @@ -355,8 +310,12 @@ def put(data):
unidist.core.backends.mpi.core.common.MasterDataID
An ID of an object in object storage.
"""
data_id = object_store.generate_data_id(garbage_collector)
object_store.put(data_id, data)
#data_id = object_store.generate_data_id(garbage_collector)
data_id = futures.Future()
workQueue.put([5,data_id, [ object_store.generate_data_id,[garbage_collector]] ])
data_id = data_id.result()
workQueue.put([4, None, [object_store.put,[data_id, data]] ])
#object_store.put(data_id, data)

logger.debug("PUT {} id".format(data_id._id))

Expand All @@ -379,25 +338,28 @@ def get(data_ids):
"""

def get_impl(data_id):
global getRequests
global getRequests,workQueue
if object_store.contains(data_id):
value = object_store.get(data_id)
else:
queueLock.acquire()
print(workQueue.qsize())
workQueue.put([3,(data_id)])
print("size {} data_id={}".format(workQueue.qsize(), data_id))
workQueue.put([3, None, (data_id)])

queueLock.release()


#raise ValueError(getQueue.qsize())
while True:
queueLock.acquire()

if not getQueue.empty():
value = getQueue.get()
queueLock.release()
break
queueLock.release()

print("got {}".format( data_id))
print("value {}".format( value))



Expand Down Expand Up @@ -518,10 +480,13 @@ def submit(task, *args, num_returns=1, **kwargs):

dest_rank = RoundRobin.get_instance().schedule_rank()

output_ids = object_store.generate_output_data_id(
dest_rank, garbage_collector, num_returns
)

# output_ids = object_store.generate_output_data_id(
# dest_rank, garbage_collector, num_returns
# )
global workQueue
output_ids = futures.Future()
workQueue.put([5,output_ids, [ object_store.generate_output_data_id,[dest_rank, garbage_collector, num_returns]] ])
output_ids = output_ids.result()
logger.debug("REMOTE OPERATION")
logger.debug(
"REMOTE args to {} rank: {}".format(
Expand All @@ -539,16 +504,19 @@ def submit(task, *args, num_returns=1, **kwargs):


# Fill the queue
from unidist.core.backends.common.data_id import DataID
# if DataID("rank_0_id_11") in unwrapped_args or DataID("rank_0_id_11") in unwrapped_kwargs:
# breakpoint()


global workQueue
queueLock.acquire()
print(workQueue.qsize())
workQueue.put([1,(dest_rank, unwrapped_args)])
workQueue.put([1,None, (dest_rank, unwrapped_args)])

print(unwrapped_args, time.time())
# push_data(dest_rank, unwrapped_args)
workQueue.put([1, None, (dest_rank, unwrapped_args)])

print(workQueue.qsize())

# push_data(dest_rank, unwrapped_args)
push_data(dest_rank, unwrapped_kwargs)

operation_type = common.Operation.EXECUTE
operation_data = {
Expand All @@ -557,7 +525,7 @@ def submit(task, *args, num_returns=1, **kwargs):
"kwargs": unwrapped_kwargs,
"output": common.master_data_ids_to_base(output_ids),
}
workQueue.put([2,(communication.MPIState.get_instance().comm,
workQueue.put([2, None, (communication.MPIState.get_instance().comm,
operation_type,
operation_data,
dest_rank,)])
Expand Down

0 comments on commit 92abdcc

Please sign in to comment.