#!/usr/bin/env python3
"""
main.py
Aimless Transition Ensemble Sampling and Analysis (ATESA)
This script handles the primary loop of building and submitting jobs in independent Threads, using the methods thereof
to execute various interfaced/abstracted commands.
"""
import sys
import os
import shutil
import pickle
import pytraj
import copy
import time
import glob
import psutil
import warnings
import itertools
from atesa import configure
from atesa import factory
from atesa import process
from atesa import interpret
from atesa import utilities
from atesa import information_error
from atesa import resampling
from multiprocess import Pool, Manager, get_context
class Thread(object):
    """
    Object representing a series of simulations and containing the relevant information to define its current state.

    Threads represent the level on which ATESA is parallelized. This flexible object is used for every type of job
    performed by ATESA.

    The constructor takes no arguments. Every method that needs run configuration receives the settings namespace
    explicitly and forwards the call to the job type, MD engine, or batch system object that ``factory`` builds from
    it, so a Thread itself only stores state.

    Attributes
    ----------
    topology : str
        Filename containing the topology file.
    jobids : list
        Job IDs associated with the present step of this thread.
    terminated : bool
        Whether the thread has reached a termination criterion.
    current_type : list
        Job types for the present step of this thread.
    current_name : list
        Job names corresponding to the job types.
    current_results : list
        Results of each job, if applicable.
    name : str
        Name of the current step.
    suffix : int
        Index of the current step.
    total_moves : int
        Running total of "moves" attributable to this thread.
    accept_moves : int
        Running total of "accepted" "moves", as defined by JobType.update_results.
    status : str
        Tag for the current status of the thread.
    skip_update : bool
        Used by restart to resubmit jobs as they were rather than doing the next step.

    Notes
    -----
    ``history`` is deliberately not created here; init_threads attaches it through
    ``jobtype.update_history(..., initialize=True)``.
    """

    def __init__(self):
        self.topology = ''              # filename containing topology file
        self.jobids = []                # list of jobids associated with the present step of this thread
        self.terminated = False         # boolean indicating whether the thread has reached a termination criterion
        self.current_type = []          # list of job types for the present step of this thread
        self.current_name = []          # list of job names corresponding to the job types
        self.current_results = []       # results of each job, if applicable
        self.name = ''                  # name of current step
        self.suffix = 0                 # index of current step
        self.total_moves = 0            # running total of "moves" attributable to this thread
        self.accept_moves = 0           # running total of "accepted" "moves", as defined by JobType.update_results
        self.status = 'fresh thread'    # tag for current status of a thread
        self.skip_update = False        # used by restart to resubmit jobs as they were rather than doing the next step

    def process(self, running, settings):
        """Forward to ``process.process`` for this thread and return its result (callers rebind ``running`` to it)."""
        # Inside a method body the bare name ``process`` resolves to the imported module, not to this method.
        return process.process(self, running, settings)

    def interpret(self, allthreads, running, settings):
        """Forward to ``interpret.interpret`` for this thread and return its result unchanged."""
        return interpret.interpret(self, allthreads, running, settings)

    def gatekeeper(self, settings):
        """Return the result of the gatekeeper of the job type named by ``settings.job_type`` for this thread."""
        jobtype = factory.jobtype_factory(settings.job_type)
        return jobtype.gatekeeper(self, settings)

    def get_next_step(self, settings):
        """Ask the job type for the next step, store it in ``current_type``/``current_name``, and return both."""
        jobtype = factory.jobtype_factory(settings.job_type)
        self.current_type, self.current_name = jobtype.get_next_step(self, settings)
        return self.current_type, self.current_name

    def get_batch_template(self, type, settings):
        """Return the job type's batch template for the given job ``type``."""
        # ``type`` shadows the builtin, but it is part of the public signature and is therefore kept.
        jobtype = factory.jobtype_factory(settings.job_type)
        return jobtype.get_batch_template(self, type, settings)

    def get_frame(self, traj, frame, settings):
        """Return the result of the ``get_frame`` call of the MD engine named by ``settings.md_engine``."""
        mdengine = factory.mdengine_factory(settings.md_engine)
        return mdengine.get_frame(traj, frame, settings)

    def get_status(self, job_index, settings):
        """Return the batch system's status for the job stored at ``self.jobids[job_index]``."""
        batchsystem = factory.batchsystem_factory(settings.batch_system)
        return batchsystem.get_status(self.jobids[job_index], settings)

    def cancel_job(self, job_index, settings):
        """Ask the batch system to cancel the job stored at ``self.jobids[job_index]``; returns None."""
        batchsystem = factory.batchsystem_factory(settings.batch_system)
        batchsystem.cancel_job(self.jobids[job_index], settings)
def _read_lines(path):
    """Return every line of the text file at *path*, closing the file afterwards."""
    with open(path, 'r') as handle:
        return handle.readlines()


def _refresh_information_error(settings):
    """Re-run resampling and the information error calculation when info_err.out is stale or malformed."""
    workdir = settings.working_directory
    info_err_path = workdir + '/info_err.out'
    if not os.path.exists(info_err_path):
        return
    info_err_lines = _read_lines(info_err_path)     # read once; every check below uses this snapshot
    if not info_err_lines:
        return

    # Resample completely if there's been a change in the number of definitions of CVs, or in the settings
    # for, information_error_max_dims or information_error_lmax_string
    wrong_length = False
    for data_length in [str(line.split()[0]) for line in info_err_lines]:
        # Only the first line is needed, so do not read the whole file. An empty file yields zero CVs and is
        # therefore treated as a mismatch rather than raising IndexError.
        with open(workdir + '/as_decorr_' + data_length + '.out', 'r') as handle:
            first_line = handle.readline()
        num_cvs = len(first_line.replace('A <- ', '').replace('B <- ', '').split())
        if settings.include_qdot:
            num_cvs = num_cvs / 2   # true division on purpose: an odd count can never equal len(settings.cvs)
        if not num_cvs == len(settings.cvs):
            wrong_length = True
    if (settings.previous_cvs and not settings.previous_cvs == settings.cvs) or \
            (not settings.previous_information_error_max_dims == settings.information_error_max_dims) or \
            (not settings.previous_information_error_lmax_string == settings.information_error_lmax_string) or \
            wrong_length:
        utilities.resample(settings, partial=False)
        information_error.main()

    # NOTE(review): the two checks below are meant not to fire after a resample above, but they inspect the lines
    # read before that resample, not the regenerated file. Confirm whether info_err.out should be re-read here.

    # Resample if info_err.out is improperly formatted (will not run if resample called above)
    if any(len(line.split(' ')) != 2 for line in info_err_lines):
        utilities.resample(settings, partial=True)
        information_error.main()

    # Resample if info_err.out is missing lines (will not run if resample called above)
    len_data = len(_read_lines(workdir + '/as_raw.out'))
    last_info_err = info_err_lines[-1].split(' ')[0]
    last_breakpoint = len_data - (len_data % settings.information_error_freq)
    if last_breakpoint > 0 and not int(last_info_err) == int(last_breakpoint):
        utilities.resample(settings, partial=True)
        information_error.main()


def _restore_dynamic_seed_counts(allthreads, settings):
    """Decrement, with a floor of zero, the empty-window count of the window each restarted thread occupies."""
    for thread in allthreads:
        for window_index, bounds in enumerate(settings.eps_bounds):
            if bounds == thread.history.bounds:
                # decrement empty window count in this window; minimum value 0
                settings.eps_empty_windows[window_index] = max(settings.eps_empty_windows[window_index] - 1, 0)
                break   # only the first matching window is adjusted


def _load_restart_threads(settings):
    """Unpickle restart.pkl from the working directory and apply the restart-specific adjustments."""
    # NOTE: pickle executes arbitrary code on load; restart.pkl must come from a trusted ATESA run.
    with open(settings.working_directory + '/restart.pkl', 'rb') as handle:
        allthreads = pickle.load(handle)

    for thread in allthreads:
        if not thread.current_type == []:
            thread.skip_update = True   # resubmit the pending step as it was instead of advancing

    if settings.restart_terminated_threads:
        for thread in allthreads:
            thread.terminated = False

    if settings.job_type == 'aimless_shooting' and settings.information_error_checking:
        _refresh_information_error(settings)

    if settings.job_type == 'equilibrium_path_sampling' and settings.eps_dynamic_seed:
        _restore_dynamic_seed_counts(allthreads, settings)     # handle dynamic seeding restart behavior

    return allthreads


def init_threads(settings):
    """
    Initialize all the Thread objects called for by the user input file.

    In the case where settings.restart == True, this involves unpickling restart.pkl; otherwise, brand new objects are
    produced in accordance with settings.job_type (aimless_shooting, committor_analysis, equilibrium_path_sampling, or
    isee).

    On a fresh start this also replaces settings.topology with its bare filename and copies the topology file into
    the working directory.

    Parameters
    ----------
    settings : argparse.Namespace
        Settings namespace object

    Returns
    -------
    allthreads : list
        List of all Thread objects, including those for which no further tasks are scheduled.
    """
    if settings.restart:
        return _load_restart_threads(settings)

    # If not restart:
    allthreads = []
    jobtype = factory.jobtype_factory(settings.job_type)

    # Set topology properly even if it's given as a path
    og_prmtop = settings.topology
    settings.topology = os.path.basename(og_prmtop)
    try:
        shutil.copy(og_prmtop, settings.working_directory + '/' + settings.topology)
    except shutil.SameFileError:
        pass    # topology already lives in the working directory; nothing to copy

    for file in jobtype.get_initial_coordinates(settings):
        file = os.path.basename(file)   # drop path to file from filename
        thread = Thread()               # initialize the thread object
        thread.topology = settings.topology
        jobtype.update_history(thread, settings, initialize=True, inpcrd=file)    # initialize thread.history
        thread.name = file + '_' + str(thread.suffix)
        allthreads.append(thread)

    return allthreads
def handle_loop_exception(attempted_rescue, running, settings):
    """
    Handle attempted rescue of main loop after encountering an exception, or cancellation of jobs if rescue fails.

    The rescue path is not implemented yet (see the commented-out draft below), so this function currently always
    cancels the jobs of every thread in `running` and then raises.

    Parameters
    ----------
    attempted_rescue : bool
        True if rescue has already been attempted and this function is being called again. Skips attempting rescue again
        and simply cancels all running jobs. Currently unused because the rescue path is disabled.
    running : list
        List of Thread objects that are currently running. These are the threads that will be canceled if the ATESA run
        cannot be rescued.
    settings : argparse.Namespace
        Settings namespace object

    Returns
    -------
    None

    Raises
    ------
    RuntimeError
        Always, once cancellation of every job has been attempted.
    """
    # class UnableToRescueException(Exception):
    #     """ Custom exception for closing out ATESA after an unsuccessful rescue attempt """
    #     pass

    # # todo: finish implementing and then uncomment this
    # if not attempted_rescue:
    #     print('Attempting to remove offending thread(s) and rescue the operation...')
    #     verify_outcome_str, verify_outcome_int = verify_threads.main('restart.pkl')
    #     print(verify_outcome_str)
    #
    #     if verify_outcome_int == 1:   # broken thread removed, continue with attempted rescue
    #         # First, set rescue_running equal to running with deleted threads removed
    #         remaining_threads = pickle.load(open('restart.pkl', 'rb'))
    #         rescue_running = [thread for thread in running if thread in remaining_threads]
    #
    #         # If rescue_running == running, removed threads weren't running so rescue fails. Otherwise...
    #         if not rescue_running == running:
    #             # Then, cancel jobs belonging to deleted threads
    #             deleted_threads = list(set(running) - set(rescue_running))
    #             for thread in deleted_threads:
    #                 try:
    #                     for job_index in range(thread.jobids):
    #                         thread.cancel_job(job_index, settings)
    #                 except Exception as little_e:
    #                     print('Encountered exception while attempting to cancel a job: ' + str(little_e) +
    #                           '\nIgnoring and continuing...')
    #
    #             # Finally, resubmit main() with rescue_running list
    #             main(settings, rescue_running=rescue_running)
    #             return None

    # This code reached if return statement above is not
    print('\nCancelling currently running batch jobs belonging to this process in order to '
          'preserve resources.')

    def report(error):
        # Cancellation is best-effort: report the failure and keep going so the other jobs still get cancelled.
        print('\nEncountered exception while attempting to cancel a job: ' + str(error) +
              '\nIgnoring and continuing...')

    for thread in running:
        try:
            job_count = len(thread.jobids)
        except Exception as little_e:
            report(little_e)
            continue
        for job_index in range(job_count):
            # Guard each job separately so that one failed cancellation does not leave the rest of this
            # thread's jobs running.
            try:
                thread.cancel_job(job_index, settings)
            except Exception as little_e:
                report(little_e)

    raise RuntimeError('Job cancellation complete, ATESA is now shutting down.')
def _dump_settings(settings):
    """Pickle a copy of *settings*, minus its ``env`` attribute, to settings.pkl in the working directory."""
    temp_settings = copy.deepcopy(settings)     # initialize temporary copy of settings to modify
    # env attribute is not picklable (update: maybe no longer true, but doesn't matter); tolerate its absence
    temp_settings.__dict__.pop('env', None)
    with open(settings.working_directory + '/settings.pkl', 'wb') as handle:
        pickle.dump(temp_settings, handle, protocol=2)


def _prepare_working_directory(settings):
    """Create, reuse, or wipe-and-recreate the working directory according to the overwrite/restart settings."""
    workdir = settings.working_directory

    if not os.path.exists(workdir):
        if settings.restart:
            raise RuntimeError('Working directory ' + workdir + ' does not yet exist, but '
                               'restart = True.')
        os.mkdir(workdir)
        return

    if settings.restart:
        return  # restarting inside an existing directory is the expected case

    if settings.overwrite:
        if os.path.exists(workdir + '/cvs.txt'):    # a kludge to avoid removing cvs.txt
            if os.path.exists('ATESA_TEMP_CVS.txt'):
                raise RuntimeError('tried to create temporary file ATESA_TEMP_CVS.txt in directory: ' +
                                   os.getcwd() + ', but it already exists. Please move, delete, or rename it.')
            shutil.move(workdir + '/cvs.txt', 'ATESA_TEMP_CVS.txt')
        shutil.rmtree(workdir)
        os.mkdir(workdir)
        if os.path.exists('ATESA_TEMP_CVS.txt'):    # continuation of aforementioned kludge
            shutil.move('ATESA_TEMP_CVS.txt', workdir + '/cvs.txt')
        return

    if glob.glob(workdir + '/*') == [workdir + '/cvs.txt']:
        # Occurs when restart = False, overwrite = False, and auto_cvs is used
        return

    raise RuntimeError('Working directory ' + workdir + ' already exists, but overwrite '
                       '= False and restart = False. Either change one of these two settings or choose a '
                       'different working directory.')


def _record_previous_settings(settings):
    """Copy selected values from an existing settings.pkl onto *settings* as ``previous_*`` attributes."""
    settings_path = settings.working_directory + '/settings.pkl'
    if not os.path.exists(settings_path):
        return
    # NOTE: pickle executes arbitrary code on load; settings.pkl must come from a trusted ATESA run.
    with open(settings_path, 'rb') as handle:
        previous_settings = pickle.load(handle)
    settings.previous_cvs = previous_settings.cvs
    # These two may be missing from the stored settings, in which case they are silently skipped.
    for name in ('information_error_max_dims', 'information_error_lmax_string'):
        try:
            setattr(settings, 'previous_' + name, getattr(previous_settings, name))
        except AttributeError:
            pass


def _available_cores():
    """Return the number of cores usable by this process, or 1 where os.sched_getaffinity is unavailable."""
    try:
        return len(os.sched_getaffinity(0))
    except AttributeError:  # os.sched_getaffinity raises AttributeError on non-UNIX systems.
        return 1


def main(settings, rescue_running=None):
    """
    Perform the primary loop of building, submitting, monitoring, and analyzing jobs.

    This function works via a loop of calls to thread.process and thread.interpret for each thread that hasn't
    terminated, until either the global termination criterion is met or all the individual threads have completed.

    Parameters
    ----------
    settings : argparse.Namespace
        Settings namespace object
    rescue_running : list or None
        List of threads passed in from handle_loop_exception, containing running threads. If given (and non-empty),
        setup is skipped and the function proceeds directly to the main loop. Defaults to None, which behaves like
        an empty list.

    Returns
    -------
    exit_message : str
        A message indicating the status of ATESA at the end of main

    Raises
    ------
    RuntimeError
        If the working directory state conflicts with the overwrite/restart settings.
    """
    if not rescue_running:
        # Implement resample
        if settings.job_type in ['aimless_shooting', 'committor_analysis', 'umbrella_sampling'] and settings.resample:
            # Store settings object in the working directory for compatibility with analysis/utility scripts
            if not settings.dont_dump:
                _dump_settings(settings)

            # Run resampling
            if settings.job_type == 'aimless_shooting':
                utilities.resample(settings, partial=False, full_cvs=settings.full_cvs,
                                   only_full_cvs=settings.only_full_cvs)
                if settings.information_error_checking:     # update info_err.out if called for by settings
                    information_error.main()
            elif settings.job_type == 'committor_analysis':
                resampling.resample_committor_analysis(settings)
            elif settings.job_type == 'umbrella_sampling':
                resampling.resample_umbrella_sampling(settings)
            return 'Resampling complete'

        # Make working directory if it does not exist, handling overwrite and restart as needed
        _prepare_working_directory(settings)

        # Remember the previous run's settings (for checking for need for resample later), then store the current
        # settings object in the working directory for compatibility with analysis/utility scripts
        _record_previous_settings(settings)
        if not settings.dont_dump:
            _dump_settings(settings)

        # Build or load threads
        allthreads = init_threads(settings)

        # Move runtime to working directory
        os.chdir(settings.working_directory)

        running = allthreads.copy()     # to be pruned later by thread.process()
    else:
        with open(settings.working_directory + '/restart.pkl', 'rb') as handle:
            allthreads = pickle.load(handle)
        running = rescue_running

    # Initialize threads with first process step
    try:
        if not rescue_running:  # if rescue_running, this step has already finished and we just want the while loop
            for thread in allthreads:
                running = thread.process(running, settings)
    except Exception:
        if settings.restart:
            print('The following error occurred while attempting to initialize threads from restart.pkl. It may be '
                  'corrupted.')
            # 'If you haven\'t already done so, consider running verify_threads.py to remove corrupted threads from this file.'
        raise

    # The core count is resolved on its own so that an AttributeError raised by the setup below is not mistaken
    # for "sched_getaffinity unavailable" and answered by silently starting a second, serial main loop.
    n_cores = _available_cores()
    if settings.job_type == 'aimless_shooting' and n_cores > 1:
        # Initialize Manager for shared data across processes; this is necessary because multiprocessing is being
        # retrofitted to code designed for serial processing, but it works!
        manager = Manager()

        # Setup Managed allthreads list
        managed_allthreads = []
        for thread in allthreads:
            managed_thread = Thread()
            managed_thread.history = manager.Namespace()
            # NOTE(review): this update copies thread.history over the Namespace assigned just above, so the next
            # line updates thread.history's own __dict__ with itself. Kept as-is; confirm the intent.
            managed_thread.__dict__.update(thread.__dict__)
            managed_thread.history.__dict__.update(thread.history.__dict__)
            managed_allthreads.append(managed_thread)
        allthreads = manager.list(managed_allthreads)

        # Setup Managed settings Namespace
        managed_settings = manager.Namespace()
        # Need to explicitly update every key because of how the Managed Namespace works. Updating
        # managed_settings.__dict__ doesn't work; setattr performs the same attribute assignment the proxy
        # intercepts, without building and exec-ing source text.
        for key, value in settings.__dict__.items():
            setattr(managed_settings, key, value)

        # Distribute processes among available core Pool; each process gets a single thread as its running list
        with get_context("spawn").Pool(n_cores) as p:
            p.starmap(main_loop, zip(itertools.repeat(managed_settings), itertools.repeat(allthreads),
                                     [[thread] for thread in allthreads]))
    else:
        main_loop(settings, allthreads, running)

    ## Deprecated thread pool
    # pool = ThreadPool(len(allthreads))
    # func = partial(main_loop, settings)
    # results = pool.map(func, [[thread] for thread in allthreads])

    jobtype = factory.jobtype_factory(settings.job_type)
    jobtype.cleanup(settings)

    return 'ATESA run exiting normally'
def _wait_for_process(pid):
    """Block, polling once a minute, until the process *pid* is no longer running, sleeping, or in disk sleep."""
    active_states = [psutil.STATUS_RUNNING, psutil.STATUS_SLEEPING, psutil.STATUS_DISK_SLEEP]
    while True:
        try:
            state = psutil.Process(pid).status()
        except (psutil.NoSuchProcess, ProcessLookupError):
            return  # the process is already gone
        if state not in active_states:
            return
        time.sleep(60)  # wait 1 minute before checking again


def main_loop(settings, allthreads, running):
    """
    Repeatedly interpret and re-process the running threads until termination or until none are left running.

    Each pass walks the threads in `running`. A thread that passes its gatekeeper is interpreted and, unless the
    global termination criterion was just met, processed again. When the criterion is met, the jobs of all
    still-running threads are cancelled and the loop ends. Any exception raised inside the loop is printed and
    handed to handle_loop_exception.

    Parameters
    ----------
    settings : argparse.Namespace
        Settings namespace object
    allthreads : list
        List of all Thread objects, passed through to thread.interpret
    running : list
        List of Thread objects that are currently running

    Returns
    -------
    None
    """
    termination_criterion = False
    attempted_rescue = False    # never set to True in this function; the rescue path is not implemented yet
    interpreted = []

    # Begin main loop
    # This whole thing is in a try-except block to handle cancellation of jobs when the code crashes in any way
    try:
        while (not termination_criterion) and running:
            # The for loop keeps iterating the list object that `running` named when the pass started, even
            # though `running` itself is rebound inside the body.
            for thread in running:
                if thread.gatekeeper(settings):
                    termination_criterion, running = thread.interpret(allthreads, running, settings)
                    if attempted_rescue:
                        interpreted.append(thread)
                    if termination_criterion:
                        for doomed in running:
                            for job_index in range(len(doomed.current_type)):
                                doomed.cancel_job(job_index, settings)
                        running = []
                        if not settings.pid == -1:  # finish up currently running resample_and_inferr, if appropriate
                            _wait_for_process(settings.pid)
                        break
                    running = thread.process(running, settings)
                else:   # todo: change this to only run if all running threads fail gatekeeper?
                    time.sleep(30)  # to prevent too-frequent calls to batch system by thread.gatekeeper

            if all(thread in interpreted for thread in running):
                attempted_rescue = False    # every thread has passed at an interpret step, so rescue was successful!
    except Exception as e:
        print(str(e))
        handle_loop_exception(attempted_rescue, running, settings)
def run_main():
    """
    Command-line entry point: obtain settings namespace, initialize threads, and move promptly into main.

    Reads the configuration file from sys.argv[1] and, optionally, the working directory from sys.argv[2] (an empty
    string is passed to configure.configure when it is omitted), then prints the exit message returned by main.

    Returns
    -------
    None

    Raises
    ------
    RuntimeError
        If no configuration file was given on the command line.
    """
    # Only the argv lookup is guarded, so an IndexError raised inside configure.configure itself is no longer
    # misreported as a missing command-line argument.
    try:
        config_file = sys.argv[1]
    except IndexError:
        raise RuntimeError('No configuration file specified. See documentation at atesa.readthedocs.io for '
                           'details.') from None

    try:
        working_directory = sys.argv[2]
    except IndexError:
        working_directory = ''

    settings = configure.configure(config_file, working_directory)
    exit_message = main(settings)
    print(exit_message)
# Script entry point: run the command-line driver only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    run_main()