# Source code for grader_service.main

# Copyright (c) 2022, TU Wien
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import asyncio
import inspect
import logging
import os
import secrets
import shutil
import signal
import subprocess
import sys

import tornado
import uvloop as uvloop
from jupyterhub.log import log_request
from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker
from tornado.httpserver import HTTPServer
from traitlets import (
    Bool,
    Dict,
    Enum,
    HasTraits,
    Instance,
    Int,
    List,
    TraitError,
    Type,
    Unicode,
    config,
    default,
    observe,
    validate,
)
from traitlets import log as traitlets_log

from grader_service import __version__
from grader_service.auth.auth import Authenticator

# run __init__.py to register handlers
from grader_service.auth.dummy import DummyAuthenticator
from grader_service.autograding.celery.app import CeleryApp
from grader_service.handlers.base_handler import RequestHandlerConfig
from grader_service.handlers.static import CacheControlStaticFilesHandler
from grader_service.oauth2 import handlers as oauth_handlers
from grader_service.oauth2.provider import make_provider
from grader_service.orm import Lecture, Role, User
from grader_service.orm.base import DeleteState
from grader_service.orm.lecture import LectureState
from grader_service.orm.takepart import Scope
from grader_service.plugins.lti import LTISyncGrades
from grader_service.registry import HandlerPathRegistry
from grader_service.server import GraderServer
from grader_service.utils import url_path_join


def get_session_maker(url) -> scoped_session:
    """Create an engine for ``url`` and return a scoped session registry.

    :param url: database URL handed to ``sqlalchemy.create_engine``.
    :return: a ``scoped_session`` wrapping a ``sessionmaker`` bound to the
        newly created engine.
    """
    session_factory = sessionmaker(bind=create_engine(url))
    return scoped_session(session_factory)


class GraderService(config.Application):
    # Traitlets application that wires together configuration, logging, the
    # database session factory, the OAuth provider and the Tornado HTTP server
    # of the grader service.
    #
    # Typical entry point is ``GraderService.launch_instance(argv)``, which runs
    # ``initialize`` followed by ``start`` on the current Tornado IO loop.

    name = "grader-service"
    version = __version__

    description = """Starts the grader service, which can be used to create and distribute assignments,
    collect submissions and grade them.
    """

    examples = """
    generate default config file:
        grader-service --generate-config -f /etc/grader/grader_service_config.py

    spawn the grader service:
        grader-service -f /etc/grader/grader_service_config.py
    """

    generate_config = Bool(False, help="Generate config file based on defaults.").tag(config=True)

    service_host = Unicode(
        os.getenv("GRADER_HOST", "0.0.0.0"), help="The host address of the service"
    ).tag(config=True)

    service_port = Int(
        int(os.getenv("GRADER_PORT", "4010")), help="The port the service runs on"
    ).tag(config=True)

    reuse_port = Bool(
        False, help="Whether to allow for the specified service port to be reused."
    ).tag(config=True)

    grader_service_dir = Unicode(os.getenv("GRADER_SERVICE_DIRECTORY"), allow_none=False).tag(
        config=True
    )

    db_url = Unicode(allow_none=False).tag(config=True)

    # Set by ``init_oauth``; stays ``None`` until the service is started.
    oauth_provider = None

    @default("db_url")
    def _default_db_url(self):
        """Return ``GRADER_DB_URL`` or a SQLite URL inside the service directory."""
        db_path = os.path.join(self.grader_service_dir, "grader.db")
        service_dir_url = f"sqlite:///{db_path}"
        return os.getenv("GRADER_DB_URL", service_dir_url)

    # The fallback secret is generated once, when the class body is evaluated.
    grader_cookie_secret = Unicode(
        default_value=os.getenv("GRADER_COOKIE_SECRET", secrets.token_hex(nbytes=32)),
        allow_none=False,
    ).tag(config=True)

    # Both limits are passed to the Tornado ``HTTPServer`` in ``start``.
    max_body_size = Int(104857600, help="Sets the max body size in bytes, default to 100mb").tag(
        config=True
    )
    max_buffer_size = Int(
        104857600, help="Sets the max buffer size in bytes, default to 100mb"
    ).tag(config=True)

    service_git_username = Unicode("grader-service", allow_none=False).tag(config=True)
    service_git_email = Unicode("", allow_none=False).tag(config=True)

    config_file = Unicode("grader_service_config.py", help="The config file to load").tag(
        config=True
    )

    base_url_path = Unicode("/services/grader/", help="base url path", allow_none=False).tag(
        config=True
    )

    authenticator_class = Type(
        default_value=DummyAuthenticator,
        klass=Authenticator,
        allow_none=False,
        config=True,
        help="""
        The authenticator class to use for authentication.
        Default is DummyAuthenticator, which does not require a password by default.
        You can set this to your own authenticator class, which should inherit from Authenticator.
        """,
    )

    authenticator = Instance(klass=Authenticator)

    # TODO make configurable
    oauth_token_expires_in = int(1 * 24 * 3600)

    load_roles = Dict(
        List(),
        help="""
        Dict of `'<lecture-code>': List[{'members': List[str], 'role': str}]` entries to load at startup.

        Example::

            c.GraderService.load_roles = {
                'lecture1': [
                    {
                        'members': ['student1', 'student2'],
                        'role': 'student'
                    },
                    {
                        'members': ['instructor1', 'instructor2'],
                        'role': 'instructor'
                    }
                ],
            }
        """,
    ).tag(config=True)

    oauth_clients = List(
        Dict(),
        default_value=[],
        help="""
        List of OAuth clients `[{'client_id': '<client_id>', 'client_secret': '<client_secret>', 'redirect_uri': '<redirect_uri>'}]` to register for the provider.

        Example::

            c.GraderService.oauth_clients = [{
                'client_id': 'hub',
                'client_secret': 'hub',
                'redirect_uri': 'http://localhost:8080/hub/oauth_callback'
            }]
        """,
    ).tag(config=True)

    @default("authenticator")
    def _authenticator_default(self):
        """Instantiate the configured authenticator class with this app as parent."""
        return self.authenticator_class(parent=self)

    @validate("config_file")
    def _validate_config_file(self, proposal):
        """Exit with status 1 if the config file is missing (unless generating one)."""
        if not os.path.isfile(proposal.value) and not self.generate_config:
            print(
                "ERROR: Failed to find specified config file: {}".format(proposal.value),
                file=sys.stderr,
            )
            sys.exit(1)
        return proposal.value

    flags = {
        "debug": (
            {"Application": {"log_level": logging.DEBUG}},
            "Set log-level to debug, for the most verbose logging.",
        ),
        "show-config": (
            {"Application": {"show_config": True}},
            "Show the application's configuration (human-readable format)",
        ),
        "show-config-json": (
            {"Application": {"show_config_json": True}},
            "Show the application's configuration (json format)",
        ),
        "generate-config": (
            {"GraderService": {"generate_config": True}},
            "generate default config file",
        ),
    }

    aliases = {
        "log-level": "Application.log_level",
        "f": "GraderService.config_file",
        "config": "GraderService.config_file",
    }

    log_level = Enum(
        [
            0,
            10,
            20,
            30,
            40,
            50,
            "CRITICAL",
            "FATAL",
            "ERROR",
            "WARNING",
            "WARN",
            "INFO",
            "DEBUG",
            "NOTSET",
        ],
        "INFO",
    ).tag(config=True)

    def setup_loggers(self, log_level: str):  # pragma: no cover
        """Handles application, Tornado, and SQLAlchemy logging configuration.

        Every configured logger gets a stdout ``StreamHandler`` using the same
        Tornado ``LogFormatter``.

        :param log_level: level applied to the root, Tornado, oauthlib and
            traitlets loggers (the SQLAlchemy logger is pinned to ``WARN``).
        """
        stream_handler = logging.StreamHandler
        root_logger = logging.getLogger()
        root_logger.setLevel(log_level)

        fmt = "%(color)s%(levelname)-8s %(asctime)s %(module)-13s |%(end_color)s %(message)s"
        formatter = tornado.log.LogFormatter(fmt=fmt, color=True, datefmt=None)

        for log in ("access", "application", "general"):
            logger = logging.getLogger("tornado.{}".format(log))
            if len(logger.handlers) > 0:
                logger.removeHandler(logger.handlers[0])
            logger.setLevel(log_level)
            handler = stream_handler(stream=sys.stdout)
            handler.setFormatter(formatter)
            logger.addHandler(handler)

        sql_logger = logging.getLogger("sqlalchemy")
        sql_logger.propagate = False
        sql_logger.setLevel("WARN")
        sql_handler = stream_handler(stream=sys.stdout)
        sql_handler.setLevel("WARN")
        sql_handler.setFormatter(formatter)
        sql_logger.addHandler(sql_handler)

        oauth_log = logging.getLogger("oauthlib")
        oauth_handler = stream_handler(stream=sys.stdout)
        oauth_handler.setFormatter(formatter)
        oauth_log.setLevel(log_level)
        oauth_log.addHandler(oauth_handler)

        traitlet_logger = traitlets_log.get_logger()
        # Guard against a logger without handlers; indexing [0] would raise.
        if traitlet_logger.handlers:
            traitlet_logger.removeHandler(traitlet_logger.handlers[0])
        traitlet_logger.setLevel(log_level)
        traitlets_handler = stream_handler(stream=sys.stdout)
        traitlets_handler.setFormatter(formatter)
        traitlet_logger.addHandler(traitlets_handler)

    def write_config_file(self):
        """Write a default config file to ``self.config_file``.

        Exits the application if the target directory does not exist or the
        file is already present. The generated text covers every ``HasTraits``
        subclass found in this module's namespace.
        """
        config_path = os.path.abspath(self.config_file)
        self.log.info(f"Writing config file {config_path}")
        config_file_dir = os.path.dirname(config_path)
        if not os.path.isdir(config_file_dir):
            self.exit(
                f"The directory to write the config file has to exist. {config_file_dir} not found"
            )
        if os.path.isfile(config_path):
            self.exit(f"Config file {config_path} already exists!")

        members = inspect.getmembers(
            sys.modules[__name__], lambda x: inspect.isclass(x) and issubclass(x, HasTraits)
        )
        config_classes = [x[1] for x in members]
        config_text = self.generate_config_file(classes=config_classes)
        if isinstance(config_text, bytes):
            config_text = config_text.decode("utf8")
        print("Generating config: %s" % self.config_file)
        with open(self.config_file, mode="w") as f:
            f.write(config_text)

    def initialize(self, argv, *args, **kwargs):
        """Parse the command line, load the config file and prepare the service.

        Sets up logging, the database session maker and the predefined roles.

        :param argv: command line arguments passed to ``parse_command_line``.
        :raises RuntimeError: if Python is older than 3.9 or no ``git``
            executable is found on the PATH.
        """
        self.log.info("Starting Initialization...")
        self.log.info("Loading config file...")
        super().initialize(*args, **kwargs)
        self.parse_command_line(argv)
        self.load_config_file(self.config_file)
        self.setup_loggers(self.log_level)

        self.session_maker = get_session_maker(self.db_url)
        self.init_roles()

        # use uvloop instead of default asyncio loop
        # asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
        self._start_future = asyncio.Future()

        # Tuple comparison: checking major and minor separately would wrongly
        # reject a future major version whose minor is below 9.
        if sys.version_info < (3, 9):
            msg = "Grader Service needs Python version 3.9 or above to run!"
            raise RuntimeError(msg)
        if shutil.which("git") is None:
            msg = "No git executable found! Git is necessary to run Grader Service!"
            raise RuntimeError(msg)

    def set_config(self):
        """Pass config to singletons."""
        RequestHandlerConfig.config = self.config
        LTISyncGrades.config = self.config
        CeleryApp.instance(config=self.config)

    async def cleanup(self):
        """Hook awaited during shutdown; currently does nothing."""
        pass

    def init_oauth(self):
        """Create the OAuth provider and register the configured clients."""
        engine = create_engine(self.db_url)
        session = sessionmaker(engine)
        self.oauth_provider = make_provider(
            session,
            url_prefix=url_path_join(self.base_url_path, "api/oauth2"),
            login_url=url_path_join(self.base_url_path, "login"),
            token_expires_in=self.oauth_token_expires_in,
        )

        for client in self.oauth_clients:
            self.oauth_provider.add_client(
                client["client_id"], client["client_secret"], client["redirect_uri"], ["identify"]
            )

    def init_roles(self):
        """Load predefined groups into the database.

        Missing lectures and users named in ``load_roles`` are created. The
        first time a user is encountered, all of that user's existing roles
        are deleted before the configured ones are added.

        :raises ValueError: if a role name is not a member of ``Scope``.
        """
        with self.session_maker() as db:
            users_loaded = set()
            for lecture_code, role_list in self.load_roles.items():
                for role_dict in role_list:
                    role = role_dict.get("role")
                    users = role_dict.get("members", [])

                    lecture = db.query(Lecture).filter(Lecture.code == lecture_code).one_or_none()
                    # create lecture if no lecture with that name exists yet
                    # (code is set in create)
                    if lecture is None:
                        self.log.info(f"Adding new lecture with lecture_code {lecture_code}")
                        lecture = Lecture()
                        lecture.code = lecture_code
                        lecture.name = lecture_code
                        lecture.state = LectureState.active
                        lecture.deleted = DeleteState.active
                        db.add(lecture)
                        db.commit()

                    for username in users:
                        user = db.query(User).filter(User.name == username).one_or_none()
                        if user is None:
                            self.log.info(f"Adding new user with username {username}")
                            user = User()
                            user.name = username
                            user.display_name = username
                            db.add(user)
                            db.commit()

                        # delete all roles of users the first time a new role is added for the user
                        if user.name not in users_loaded:
                            db.query(Role).filter(Role.username == user.name).delete()
                            users_loaded.add(user.name)

                        try:
                            db.add(Role(username=user.name, lectid=lecture.id, role=Scope[role]))
                        except KeyError as err:
                            self.log.error(f"Invalid role name: {role}")
                            raise ValueError(f"Invalid role name: {role}") from err
            db.commit()

    async def start(self):
        """Start the HTTP server, or write the config file and exit.

        When ``generate_config`` is set, only the config file is written and
        the application exits with status 0. Otherwise the environment is
        checked, OAuth is initialised, all handlers are collected, the Tornado
        ``HTTPServer`` starts listening and SIGTERM/SIGINT handlers are
        installed.
        """
        self.log.info(f"Config File: {os.path.abspath(self.config_file)}")

        if self.generate_config:
            self.write_config_file()
            self.exit(0)

        self.log.info("Starting Grader Service...")
        self.io_loop = tornado.ioloop.IOLoop.current()

        self._setup_environment()
        self.init_oauth()

        # pass config
        self.set_config()

        handlers = HandlerPathRegistry.handler_list(self.base_url_path)
        self.log.info(handlers)

        # Add the handlers of the authenticator
        auth_handlers = self.authenticator.get_handlers(self.base_url_path)
        handlers.extend(auth_handlers)
        self.log.info(
            f"Registered authentication handlers for {self.authenticator.__class__.__name__}: {[n for n, _ in auth_handlers]}"
        )

        oauth_provider_handlers = oauth_handlers.get_oauth_default_handlers(self.base_url_path)
        handlers.extend(oauth_provider_handlers)
        self.log.info(f"Registered OAuth handlers: {[n for n, _ in oauth_provider_handlers]}")

        # start the webserver
        self.http_server: HTTPServer = HTTPServer(
            GraderServer(
                grader_service_dir=self.grader_service_dir,
                base_url=self.base_url_path,
                authenticator=self.authenticator,
                handlers=handlers,
                oauth_provider=self.oauth_provider,
                cookie_secret=self.grader_cookie_secret,  # generate new cookie secret at startup
                config=self.config,
                session_maker=self.session_maker,
                parent=self,
                login_url=self.authenticator.login_url(self.base_url_path),
                logout_url=self.authenticator.logout_url(self.base_url_path),
                static_url_prefix=url_path_join(self.base_url_path, "/static/"),
                static_handler_class=CacheControlStaticFilesHandler,
                log_function=log_request,
            ),
            # ssl_options=ssl_context,
            max_buffer_size=self.max_buffer_size,
            max_body_size=self.max_body_size,
            xheaders=True,
        )
        self.log.info(f"Service directory - {self.grader_service_dir}")
        self.http_server.listen(
            self.service_port, address=self.service_host, reuse_port=self.reuse_port
        )

        event_loop = asyncio.get_event_loop()
        for s in (signal.SIGTERM, signal.SIGINT):
            # Bind the signal as a default argument. A plain closure is
            # late-bound, so every handler would see the last loop value
            # (SIGINT) regardless of which signal actually arrived.
            event_loop.add_signal_handler(
                s, lambda sig=s: asyncio.ensure_future(self.shutdown_cancel_tasks(sig))
            )

        self.log.info(f"Grader service running at {self.service_host}:{self.service_port}")

        # finish start
        self._start_future.set_result(None)

    @staticmethod
    def _git_config_value(key: str) -> str:
        """Return the decoded, stripped stdout of ``git config <key>``.

        The exit status of the git command is ignored.
        """
        result = subprocess.run(["git", "config", key], check=False, capture_output=True)
        return result.stdout.decode().strip()

    def _setup_environment(self):
        """Create the ``git`` directory and verify the required git config.

        :raises RuntimeError: if ``init.defaultBranch`` is not ``main`` or
            ``user.name`` / ``user.email`` yield an empty value.
        """
        git_dir = os.path.join(self.grader_service_dir, "git")
        if not os.path.exists(git_dir):
            os.mkdir(git_dir)

        # check if git config exists so that git commits don't fail
        if self._git_config_value("init.defaultBranch") != "main":
            raise RuntimeError("Git default branch has to be set to 'main'!")
        if self._git_config_value("user.name") == "":
            raise RuntimeError("Git user.name has to be set!")
        if self._git_config_value("user.email") == "":
            raise RuntimeError("Git user.email has to be set!")

    async def shutdown_cancel_tasks(self, sig):
        """Cancel all other tasks of the event loop and initiate cleanup.

        :param sig: the received signal; its ``name`` attribute is logged.
        """
        self.log.critical("Received signal %s, initiating shutdown...", sig.name)

        tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]

        if tasks:
            self.log.debug("Cancelling pending tasks")
            for t in tasks:
                t.cancel()

            try:
                await asyncio.wait(tasks)
            except asyncio.CancelledError:
                self.log.debug("Caught Task CancelledError. Ignoring")
            except StopAsyncIteration:
                msg = "Caught StopAsyncIteration Exception"
                self.log.error(msg, exc_info=True)

        for t in list(asyncio.all_tasks()):
            self.log.debug("Task status: %s", t)
        await self.cleanup()
        asyncio.get_event_loop().stop()

    async def launch_instance_async(self, argv=None):
        """Run ``initialize`` and ``start``; log any exception and exit with 1."""
        try:
            self.initialize(argv)
            await self.start()
        except Exception as e:
            self.log.exception(e)
            self.exit(1)

    @classmethod
    def launch_instance(cls, argv=None):
        """Create the singleton instance and run it on the current IO loop.

        :param argv: command line arguments forwarded to ``initialize``.
        """
        self = cls.instance()
        loop = tornado.ioloop.IOLoop.current()
        task = asyncio.ensure_future(self.launch_instance_async(argv))
        try:
            loop.start()
        except KeyboardInterrupt:
            print("\nInterrupted")
        finally:
            if task.done():
                # re-raise exceptions in launch_instance_async
                task.result()
            loop.stop()

    @validate("grader_service_dir")
    def _validate_service_dir(self, proposal):
        """Require an absolute path and create the directory if it is missing.

        :raises TraitError: if the proposed path is not absolute.
        """
        path: str = proposal["value"]
        if not os.path.isabs(path):
            raise TraitError("The path is not absolute")
        if not os.path.isdir(path):
            os.mkdir(path, mode=0o700)
        return path

    @observe("grader_service_dir")
    def _observe_service_dir(self, change):
        """Create the ``git`` subdirectory whenever the service directory changes."""
        path = change["new"]
        git_path = os.path.join(path, "git")
        if not os.path.isdir(git_path):
            os.mkdir(git_path, mode=0o700)