Merge pull request #62 from minrk/shutdown-state

adjustments to Spawner.stop
This commit is contained in:
Min RK
2014-10-14 11:56:24 -07:00
5 changed files with 125 additions and 75 deletions

View File

@@ -448,20 +448,25 @@ class JupyterHubApp(Application):
for user in db.query(orm.User): for user in db.query(orm.User):
if not user.state: if not user.state:
# without spawner state, server isn't valid
user.server = None
user_summaries.append(_user_summary(user)) user_summaries.append(_user_summary(user))
continue continue
self.log.debug("Loading state for %s from db", user.name) self.log.debug("Loading state for %s from db", user.name)
spawner = self.spawner_class.fromJSON(user.state, user=user, hub=self.hub, config=self.config) user.spawner = spawner = self.spawner_class(
user=user, hub=self.hub, config=self.config,
)
status = run_sync(spawner.poll) status = run_sync(spawner.poll)
if status is None: if status is None:
self.log.info("User %s still running", user.name) self.log.info("%s still running", user.name)
user.spawner = spawner
spawner.add_poll_callback(user_stopped, user) spawner.add_poll_callback(user_stopped, user)
spawner.start_polling() spawner.start_polling()
else: else:
self.log.warn("Failed to load state for %s, assuming server is not running.", user.name) # user not running. This is expected if server is None,
# not running, state is invalid # but indicates the user's server died while the Hub wasn't running
user.state = {} # if user.server is defined.
log = self.log.warn if user.server else self.log.debug
log("%s not running.", user.name)
user.server = None user.server = None
user_summaries.append(_user_summary(user)) user_summaries.append(_user_summary(user))
@@ -508,7 +513,8 @@ class JupyterHubApp(Application):
'--api-port', str(self.proxy.api_server.port), '--api-port', str(self.proxy.api_server.port),
'--default-target', self.hub.server.host, '--default-target', self.hub.server.host,
] ]
if self.log_level == logging.DEBUG: if False:
# if self.log_level == logging.DEBUG:
cmd.extend(['--log-level', 'debug']) cmd.extend(['--log-level', 'debug'])
if self.ssl_key: if self.ssl_key:
cmd.extend(['--ssl-key', self.ssl_key]) cmd.extend(['--ssl-key', self.ssl_key])

View File

@@ -306,13 +306,18 @@ class User(Base):
db.add(api_token) db.add(api_token)
db.commit() db.commit()
spawner = self.spawner = spawner_class( spawner = self.spawner = spawner_class(
config=config, config=config,
user=self, user=self,
hub=hub, hub=hub,
api_token=api_token.token, api_token=api_token.token,
) )
# we are starting a new server, make sure it doesn't restore state
spawner.clear_state()
yield spawner.start() yield spawner.start()
spawner.start_polling()
# store state # store state
self.state = spawner.get_state() self.state = spawner.get_state()
@@ -324,14 +329,19 @@ class User(Base):
@gen.coroutine @gen.coroutine
def stop(self): def stop(self):
"""Stop the user's spawner""" """Stop the user's spawner
and cleanup after it.
"""
if self.spawner is None: if self.spawner is None:
return return
self.spawner.stop_polling()
status = yield self.spawner.poll() status = yield self.spawner.poll()
if status is None: if status is None:
yield self.spawner.stop() yield self.spawner.stop()
self.state = {} self.spawner.clear_state()
self.spawner = None self.state = self.spawner.get_state()
self.last_activity = datetime.utcnow()
self.server = None self.server = None
inspect(self).session.commit() inspect(self).session.commit()

View File

@@ -14,7 +14,7 @@ from tornado.ioloop import IOLoop, PeriodicCallback
from IPython.config import LoggingConfigurable from IPython.config import LoggingConfigurable
from IPython.utils.traitlets import ( from IPython.utils.traitlets import (
Any, Bool, Dict, Enum, Instance, Integer, List, Unicode, Any, Bool, Dict, Enum, Instance, Integer, Float, List, Unicode,
) )
from .utils import random_port from .utils import random_port
@@ -80,15 +80,10 @@ class Spawner(LoggingConfigurable):
help="""The command used for starting notebooks.""" help="""The command used for starting notebooks."""
) )
@classmethod def __init__(self, **kwargs):
def fromJSON(cls, state, **kwargs): super(Spawner, self).__init__(**kwargs)
"""Create a new instance, and load its JSON state if self.user.state:
self.load_state(self.user.state)
state will be a dict, loaded from JSON in the database.
"""
inst = cls(**kwargs)
inst.load_state(state)
return inst
def load_state(self, state): def load_state(self, state):
"""load state from the database """load state from the database
@@ -96,18 +91,21 @@ class Spawner(LoggingConfigurable):
This is the extensible part of state This is the extensible part of state
Override in a subclass if there is state to load. Override in a subclass if there is state to load.
Should call `super`.
See Also See Also
-------- --------
get_state get_state, clear_state
""" """
pass if 'api_token' in state:
self.api_token = state['api_token']
def get_state(self): def get_state(self):
"""store the state necessary for load_state """store the state necessary for load_state
A black box of extra state for custom spawners A black box of extra state for custom spawners.
Should call `super`.
Returns Returns
------- -------
@@ -115,7 +113,19 @@ class Spawner(LoggingConfigurable):
state: dict state: dict
a JSONable dict of state a JSONable dict of state
""" """
return dict(api_token=self.api_token) state = {}
if self.api_token:
state['api_token'] = self.api_token
return state
def clear_state(self):
"""clear any state that should be cleared when the process stops
State that should be preserved across server instances should not be cleared.
Subclasses should call super, to ensure that state is properly cleared.
"""
self.api_token = ''
def get_args(self): def get_args(self):
"""Return the arguments to be passed after self.cmd""" """Return the arguments to be passed after self.cmd"""
@@ -201,6 +211,18 @@ class Spawner(LoggingConfigurable):
for callback in self._callbacks: for callback in self._callbacks:
add_callback(callback) add_callback(callback)
death_interval = Float(0.1)
@gen.coroutine
def wait_for_death(self, timeout=10):
"""wait for the process to die, up to timeout seconds"""
loop = IOLoop.current()
for i in range(int(timeout / self.death_interval)):
status = yield self.poll()
if status is not None:
break
else:
yield gen.Task(loop.add_timeout, loop.time() + self.death_interval)
def set_user_setuid(username): def set_user_setuid(username):
"""return a preexec_fn for setting the user (via setuid) of a spawned process""" """return a preexec_fn for setting the user (via setuid) of a spawned process"""
@@ -251,7 +273,7 @@ class LocalProcessSpawner(Spawner):
) )
proc = Instance(Popen) proc = Instance(Popen)
pid = Integer() pid = Integer(0)
sudo_args = List(['-n'], config=True, sudo_args = List(['-n'], config=True,
help="""arguments to be passed to sudo (in addition to -u [username]) help="""arguments to be passed to sudo (in addition to -u [username])
@@ -277,14 +299,23 @@ class LocalProcessSpawner(Spawner):
raise ValueError("This should be impossible") raise ValueError("This should be impossible")
def load_state(self, state): def load_state(self, state):
"""load pid from state"""
super(LocalProcessSpawner, self).load_state(state) super(LocalProcessSpawner, self).load_state(state)
if 'pid' in state:
self.pid = state['pid'] self.pid = state['pid']
def get_state(self): def get_state(self):
"""add pid to state"""
state = super(LocalProcessSpawner, self).get_state() state = super(LocalProcessSpawner, self).get_state()
if self.pid:
state['pid'] = self.pid state['pid'] = self.pid
return state return state
def clear_state(self):
"""clear pid state"""
super(LocalProcessSpawner, self).clear_state()
self.pid = 0
def sudo_cmd(self, user): def sudo_cmd(self, user):
return ['sudo', '-u', user.name] + self.sudo_args return ['sudo', '-u', user.name] + self.sudo_args
@@ -311,39 +342,48 @@ class LocalProcessSpawner(Spawner):
preexec_fn=self.make_preexec_fn(self.user.name), preexec_fn=self.make_preexec_fn(self.user.name),
) )
self.pid = self.proc.pid self.pid = self.proc.pid
self.start_polling()
@gen.coroutine @gen.coroutine
def poll(self): def poll(self):
"""Poll the process""" """Poll the process"""
# if we started the process, poll with Popen # if we started the process, poll with Popen
if self.proc is not None: if self.proc is not None:
raise gen.Return(self.proc.poll()) status = self.proc.poll()
if status is not None:
# clear state if the process is done
self.clear_state()
raise gen.Return(status)
# if we resumed from stored state, # if we resumed from stored state,
# we don't have the Popen handle anymore # we don't have the Popen handle anymore, so rely on self.pid
if not self.pid:
# no pid, not running
self.clear_state()
raise gen.Return(0)
# send signal 0 to check if PID exists
# this doesn't work on Windows, but that's okay because we don't support Windows. # this doesn't work on Windows, but that's okay because we don't support Windows.
try: alive = self._signal(0)
os.kill(self.pid, 0) if not alive:
except OSError as e: self.clear_state()
if e.errno == errno.ESRCH:
# no such process, return exitcode == 0, since we don't know the exit status
raise gen.Return(0) raise gen.Return(0)
else: else:
# None indicates the process is running
raise gen.Return(None) raise gen.Return(None)
@gen.coroutine def _signal(self, sig):
def _wait_for_death(self, timeout=10): """send a signal, and ignore ERSCH because it just means it already died
"""wait for the process to die, up to timeout seconds"""
for i in range(int(timeout * 10)): returns bool for whether the process existed to receive the signal.
status = yield self.poll() """
if status is not None: try:
break os.kill(self.pid, sig)
except OSError as e:
if e.errno == errno.ESRCH:
return False # process is gone
else: else:
loop = IOLoop.current() raise
yield gen.Task(loop.add_timeout, loop.time() + 0.1) return True # process exists
@gen.coroutine @gen.coroutine
def stop(self, now=False): def stop(self, now=False):
@@ -351,39 +391,29 @@ class LocalProcessSpawner(Spawner):
if `now`, skip waiting for clean shutdown if `now`, skip waiting for clean shutdown
""" """
self.stop_polling()
if not now: if not now:
# SIGINT to request clean shutdown status = yield self.poll()
self.log.debug("Interrupting %i", self.pid) if status is not None:
try:
os.kill(self.pid, signal.SIGINT)
except OSError as e:
if e.errno == errno.ESRCH:
return return
self.log.debug("Interrupting %i", self.pid)
yield self._wait_for_death(self.INTERRUPT_TIMEOUT) self._signal(signal.SIGINT)
yield self.wait_for_death(self.INTERRUPT_TIMEOUT)
# clean shutdown failed, use TERM # clean shutdown failed, use TERM
status = yield self.poll() status = yield self.poll()
if status is None: if status is not None:
self.log.debug("Terminating %i", self.pid)
try:
os.kill(self.pid, signal.SIGTERM)
except OSError as e:
if e.errno == errno.ESRCH:
return return
yield self._wait_for_death(self.TERM_TIMEOUT) self.log.debug("Terminating %i", self.pid)
self._signal(signal.SIGTERM)
yield self.wait_for_death(self.TERM_TIMEOUT)
# TERM failed, use KILL # TERM failed, use KILL
status = yield self.poll() status = yield self.poll()
if status is None: if status is not None:
self.log.debug("Killing %i", self.pid)
try:
os.kill(self.pid, signal.SIGKILL)
except OSError as e:
if e.errno == errno.ESRCH:
return return
yield self._wait_for_death(self.KILL_TIMEOUT) self.log.debug("Killing %i", self.pid)
self._signal(signal.SIGKILL)
yield self.wait_for_death(self.KILL_TIMEOUT)
status = yield self.poll() status = yield self.poll()
if status is None: if status is None:

View File

@@ -4,6 +4,7 @@
# Distributed under the terms of the Modified BSD License. # Distributed under the terms of the Modified BSD License.
import getpass import getpass
import logging
from pytest import fixture from pytest import fixture
from tornado import ioloop from tornado import ioloop
@@ -45,7 +46,7 @@ def io_loop():
@fixture(scope='module') @fixture(scope='module')
def app(request): def app(request):
app = MockHubApp() app = MockHubApp.instance(log_level=logging.DEBUG)
app.start([]) app.start([])
request.addfinalizer(app.stop) request.addfinalizer(app.stop)
return app return app

View File

@@ -155,6 +155,7 @@ def test_spawn(app, io_loop):
user = add_user(db, name=name) user = add_user(db, name=name)
r = api_request(app, 'users', name, 'server', method='post') r = api_request(app, 'users', name, 'server', method='post')
assert r.status_code == 201 assert r.status_code == 201
assert 'pid' in user.state
assert user.spawner is not None assert user.spawner is not None
status = io_loop.run_sync(user.spawner.poll) status = io_loop.run_sync(user.spawner.poll)
assert status is None assert status is None
@@ -173,5 +174,7 @@ def test_spawn(app, io_loop):
r = api_request(app, 'users', name, 'server', method='delete') r = api_request(app, 'users', name, 'server', method='delete')
assert r.status_code == 204 assert r.status_code == 204
assert user.spawner is None assert 'pid' not in user.state
status = io_loop.run_sync(user.spawner.poll)
assert status == 0