Source code for compliance.runners

# -*- mode:python; coding:utf-8 -*-
# Copyright (c) 2020 IBM Corp. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compliance automation flow management module."""

import inspect
import json
import os
import re
import sys
import time
import unittest
from collections import defaultdict

from compliance.check import ComplianceCheck
from compliance.config import get_config
from compliance.controls import ControlDescriptor
from compliance.fetch import ComplianceFetcher
from compliance.fix import Fixer
from compliance.locker import Locker
from compliance.notify import get_notifiers
from compliance.report import ReportBuilder
from compliance.utils.exceptions import LockerPushError
from compliance.utils.path import (
    CHECK_PREFIX, FETCH_PREFIX, get_toplevel_dirpath, load_evidences_modules
)


class _BaseRunner(object):
    """Base class for fetcher and check processing."""

    def __init__(self, opts, extra_opts):
        """
        Store the command line options shared by fetcher and check runs.

        :param opts: arguments provided from the command line.
        :param extra_opts: additional arguments provided from the command
          line (directories, files, or other test references).
        """
        self.opts = opts
        self.extra_opts = extra_opts
        # Test module import/discovery error messages gathered by subclasses.
        self.load_errors = set()

    def __enter__(self):
        """Initialize the framework configuration and return the runner."""
        self.init_config()
        return self

    def __exit__(self, typ, val, traceback):
        """Nothing to clean up at this level; exceptions are not suppressed."""

    def init_config(self):
        """Initialize the framework configuration."""
        self._load_compliance_config()
        self._init_dirs()
        # Fetchers and checks read the configuration as a class attribute.
        ComplianceFetcher.config = self.config
        ComplianceCheck.config = self.config

    def init_locker(self, ttl_tolerance=0):
        """
        Initialize the framework locker.

        :param ttl_tolerance: Evidence TTL tolerance in seconds
        """
        self.locker = self._create_the_locker(ttl_tolerance)
        self.locker.init()
        self.locker.logger_init_msgs()
        for path in self.opts.force:
            self.locker.forced_evidence.append(path)
        # Fetchers and checks access the locker as a class attribute.
        ComplianceFetcher.locker = self.locker
        ComplianceCheck.locker = self.locker

    def get_test_candidates(self, suite):
        """
        Provide the test cases from a test suite.

        Nested suites are flattened recursively.

        :param suite: a TestSuite object

        :returns: generator iterable of test cases
        """
        for suite_test in suite:
            if unittest.suite._isnotsuite(suite_test):
                yield suite_test
                continue
            yield from self.get_test_candidates(suite_test)

    def _load_compliance_config(self):
        """
        Load the compliance configuration and attach the credentials path.

        :raises ValueError: if the credentials file does not exist.
        """
        creds_path = os.path.expanduser(self.opts.creds_path)
        if not os.path.isfile(creds_path):
            raise ValueError(f'{creds_path} does not exist.')
        self.config = get_config()
        self.config.creds_path = creds_path
        self.config.load(self.opts.compliance_config)

    def _init_dirs(self):
        """
        Resolve the directories to process and load their evidence modules.

        :raises ValueError: if none of the resolved directories is valid.
        """
        self.dirs = set()
        for p in self.extra_opts:
            if os.path.isdir(p):
                # May resolve to None when the directory is outside any tree
                # with a "controls.json"; such entries are dropped below.
                self.dirs.add(get_toplevel_dirpath(os.path.abspath(p)))
            elif os.path.isfile(p):
                dirpath = get_toplevel_dirpath(os.path.abspath(p))
                if dirpath is None:
                    continue
                self.dirs.add(dirpath)
            else:
                # Neither a directory nor a file (e.g. a dotted test name):
                # work relative to the current directory.
                self.dirs.add('.')

        if not self.dirs:
            self.dirs.add(os.path.abspath('.'))

        at_least_one_valid_dir = any(
            get_toplevel_dirpath(d) is not None
            for d in self.dirs if d is not None
        )
        if not at_least_one_valid_dir:
            raise ValueError(
                'None of the paths provided are valid directories.  '
                'Please provide at least one directory containing a '
                '"controls.json" file.'
            )
        # Unresolvable directories must not reach module loading or the
        # test discovery that later iterates over self.dirs.
        self.dirs.discard(None)
        for d in self.dirs:
            load_evidences_modules(d)

    def _create_the_locker(self, ttl_tolerance):
        """
        Build the evidence locker matching the requested evidence mode.

        :param ttl_tolerance: Evidence TTL tolerance in seconds
        :returns: a Locker object
        :raises ValueError: if a non-local mode is used without a repo URL.
        """
        dirname = self.config.get('locker.dirname')
        mode = self.opts.evidence

        gitconfig = self.config.get('locker.gitconfig')
        if mode == 'local':
            return Locker(
                name=dirname, ttl_tolerance=ttl_tolerance, gitconfig=gitconfig
            )
        repo_url = self.config.get('locker.repo_url')
        if repo_url is None:
            raise ValueError(f'Evidence mode "{mode}" requires a URL.')
        return Locker(
            name=dirname,
            repo_url=repo_url,
            creds=self.config.creds,
            # Only "full-remote" pushes evidence back to the remote repo.
            do_push=(mode == 'full-remote'),
            ttl_tolerance=ttl_tolerance,
            gitconfig=gitconfig
        )


class FetchMode(_BaseRunner):
    """The fetcher process flow."""

    def __enter__(self):
        """Initialize fetcher mode processing."""
        super(FetchMode, self).__enter__()
        self.init_locker(self.config.get('locker.ttl_tolerance', 0))
        return self

    def __exit__(self, typ, val, traceback):
        """Handle post fetcher test execution processing."""
        super(FetchMode, self).__exit__(typ, val, traceback)
        # make sure that all added evidence are committed
        self.locker.checkin()
        # Only push if fetchers are run separately from checks,
        # otherwise push occurs after check processing is complete.
        if not self.opts.check:
            try:
                self.locker.push()
            except LockerPushError as lpe:
                self.locker.logger.error(str(lpe))

    @staticmethod
    def _load_dotted_paths(path):
        """Read a JSON file (expected: a list of dotted paths) into a set."""
        with open(path, encoding='utf-8') as fh:
            return set(json.load(fh))

    def get_fetchers(self):
        """
        Provide all compliance framework fetcher classes.

        Fetchers are discovered in every directory of ``self.dirs``.  Import
        failures of fetcher modules are recorded in ``self.load_errors``.
        When include and/or exclude files are provided, only fetchers whose
        ``module.ClassName`` path is included and not excluded are returned.

        :returns: a set of fetcher classes
        """
        fetchers = set()
        for loc in self.dirs:
            tl = unittest.TestLoader()
            tl.testMethodPrefix = FETCH_PREFIX
            candidates = self.get_test_candidates(
                tl.discover(loc, f'{FETCH_PREFIX}*.py')
            )
            for candidate in candidates:
                if issubclass(candidate.__class__, ComplianceFetcher):
                    fetchers.add(candidate.__class__)
            for load_err in tl.errors:
                locate = re.search(
                    '^Failed to import test module: (.+?)\n.*?', load_err
                )
                if locate is None:
                    continue
                # Only report import failures of fetcher modules.
                if locate.group(1).split('.')[-1].startswith(FETCH_PREFIX):
                    self.load_errors.add(load_err)
        if not (self.opts.include or self.opts.exclude):
            return fetchers
        include = {f'{f.__module__}.{f.__name__}' for f in fetchers}
        if self.opts.include:
            include = self._load_dotted_paths(self.opts.include)
        exclude = set()
        if self.opts.exclude:
            exclude = self._load_dotted_paths(self.opts.exclude)
        include -= exclude
        return {
            f for f in fetchers if f'{f.__module__}.{f.__name__}' in include
        }

    def run_fetchers(self, reruns=None):
        """
        Execute fetchers.

        :param reruns: A list of fetchers in dot notation to rerun

        :returns: Success (True) if no errors other than dependency unavailable
        """
        loader = unittest.TestLoader()
        loader.testMethodPrefix = FETCH_PREFIX
        fetchers = unittest.TestSuite()
        if reruns is None:
            # Non-directory extra arguments name specific fetchers to run.
            fetcher_overrides = [
                fo for fo in self.extra_opts if not os.path.isdir(fo)
            ]
            if fetcher_overrides:
                fetchers.addTests(loader.loadTestsFromNames(fetcher_overrides))
            else:
                for fetcher in self.get_fetchers():
                    fetchers.addTests(loader.loadTestsFromTestCase(fetcher))
        else:
            self.config.dependency_rerun = True
            self.locker.reset_depenency_rerun()
            fetchers.addTests(loader.loadTestsFromNames(reruns))
        runner = unittest.TextTestRunner(
            verbosity=self.opts.verbose, resultclass=ComplianceBaseResult
        )
        # Only the last "Traceback" section of each error text is inspected,
        # i.e. the final exception when exceptions are chained.
        return all(
            'DependencyUnavailableError' in tb.split('Traceback')[-1]
            for (_, tb) in runner.run(fetchers).errors
        )
class CheckMode(_BaseRunner):
    """The check process flow."""

    def __init__(self, opts, extra_opts):
        """
        Construct and initialize the check mode context manager.

        :param opts: arguments provided from the command line.
        :param extra_opts: additional arguments provided from the command line.
        """
        super(CheckMode, self).__init__(opts, extra_opts)
        self.accreds = [a.strip() for a in opts.check.split(',')]
        # Backward compatibility to support ghe_issues option
        self.notifiers = [
            n.strip().replace('ghe_issues', 'gh_issues')
            for n in opts.notify.split(',')
        ]
        self.push_error = False

    def __enter__(self):
        """Initialize check mode processing."""
        super(CheckMode, self).__enter__()
        self.init_locker()
        return self

    def __exit__(self, typ, val, traceback):
        """Handle post check test execution processing."""
        super(CheckMode, self).__exit__(typ, val, traceback)
        try:
            self.build_reports()
            # When in full-remote mode, fixers only run if push was successful
            self.fix_failures()
        except LockerPushError as lpe:
            self.locker.logger.error(str(lpe))
            self.push_error = True
        # Notifiers run regardless of report/fix outcome.
        self.run_notifiers()

    def init_config(self):
        """Initialize the framework configuration for check execution."""
        super(CheckMode, self).init_config()
        self.results = None
        self.controls = ControlDescriptor(self.dirs)

    def get_checks(self):
        """
        Provide the appropriate compliance framework check classes.

        Only checks belonging to the requested accreditations are returned.
        Import failures of, and missing, expected checks are recorded in
        ``self.load_errors``.

        :returns: a set of check classes
        """
        checks = set()
        tests_found = set()
        for loc in self.dirs:
            tl = unittest.TestLoader()
            tl.testMethodPrefix = CHECK_PREFIX
            candidates = self.get_test_candidates(
                tl.discover(loc, f'{CHECK_PREFIX}*.py')
            )
            # A class yields one candidate per test method; visit each once.
            for test in dict.fromkeys(c.__class__ for c in candidates):
                path = f'{test.__module__}.{test.__name__}'
                tests_found.add(path)
                in_accred_grouping = self.controls.is_test_included(
                    path, self.accreds
                )
                if issubclass(test, ComplianceCheck) and in_accred_grouping:
                    # Record the names of the check methods on the class.
                    test.tests = [
                        method for method in dir(test)
                        if method.startswith(CHECK_PREFIX) and (
                            inspect.ismethod(getattr(test, method))
                            or inspect.isfunction(getattr(test, method))
                        )
                    ]
                    checks.add(test)
            for load_err in tl.errors:
                locate = re.search(
                    '^Failed to import test module: (.+?)\n.*?', load_err
                )
                if locate is None:
                    continue
                module = locate.group(1)
                for accred in self.accreds:
                    for check in self.controls.accred_checks[accred]:
                        # Match on a module boundary so that a failure in
                        # "pkg.test_a" is not pinned on "pkg.test_ab.Check".
                        if check == module or check.startswith(f'{module}.'):
                            self.load_errors.add(
                                f'Unable to load {check}\n\n{load_err}'
                            )
                            # Already reported; not "not found" as well.
                            tests_found.add(check)
        expected_checks = set()
        for accred, checks_in_accred in self.controls.accred_checks.items():
            if accred in self.accreds:
                expected_checks.update(checks_in_accred)
        for check_not_found in expected_checks - tests_found:
            self.load_errors.add(
                (
                    f'Unable to load {check_not_found}\n\n'
                    f'The check {check_not_found} was not found. '
                    'Please validate that the path provided is correct.'
                )
            )
        return checks

    def run_checks(self):
        """
        Execute checks.

        :returns: Success (True) if no errors encountered
        """
        loader = unittest.TestLoader()
        loader.testMethodPrefix = CHECK_PREFIX
        checks = unittest.TestSuite()
        for check in self.get_checks():
            checks.addTests(loader.loadTestsFromTestCase(check))
        runner = unittest.TextTestRunner(
            verbosity=self.opts.verbose, resultclass=ComplianceCheckResult
        )
        check_run = runner.run(checks)
        self.results = check_run.results
        # Check failures are compliance results, only errors mean no success.
        return not check_run.errors

    def fix_failures(self):
        """Fix failures if fixer methods are included in checks."""
        if self.opts.fix != 'off':
            fixer = Fixer(self.results, dry_run=(self.opts.fix == 'dry-run'))
            fixer.fix()

    def build_reports(self):
        """Generate reports based on check results."""
        builder = ReportBuilder(self.locker, self.results, self.controls)
        builder.build()

    def run_notifiers(self):
        """Execute all requested notifiers."""
        # Flush pending test output before notifiers write their own.
        sys.stdout.flush()
        sys.stderr.flush()
        notifiers = get_notifiers()
        for notifier_name in self.notifiers:
            notifier_args = [self.results, self.controls]
            if notifier_name == 'locker':
                notifier_args.append(self.locker)
            notifier = notifiers[notifier_name](
                *notifier_args, push_error=self.push_error
            )
            notifier.notify()
class ComplianceBaseResult(unittest.TextTestResult):
    """
    Base Compliance result class.

    In verbose mode (``showAll``), each test is additionally reported with
    the wall-clock time it took to run.
    """

    def startTest(self, test):  # noqa: N802
        """Remember when ``test`` started, when running verbosely."""
        super().startTest(test)
        if not self.showAll:
            return
        self.start_time = time.perf_counter()

    def stopTest(self, test):  # noqa: N802
        """Write the elapsed time for ``test``, when running verbosely."""
        super().stopTest(test)
        if not self.showAll:
            return
        elapsed = time.perf_counter() - self.start_time
        label = self.getDescription(test)
        self.stream.write(f'{label} - ran in: ')
        self.stream.writeln(f'{elapsed:.3f}s')
        self.stream.flush()
class ComplianceCheckResult(ComplianceBaseResult):
    """
    Compliance check result class.

    Besides the standard unittest bookkeeping, outcomes are collected in
    ``self.results`` keyed by test id for downstream reports and notifiers.
    """

    def __init__(self, *args, **kwargs):
        """Construct the result with an empty per-test results mapping."""
        super().__init__(*args, **kwargs)
        self.results = defaultdict(dict)

    def addSuccess(self, test):  # noqa: N802
        """
        Record a passing test as ``pass``, or ``warn`` if it raised warnings.

        :param test: a ``unittest.TestCase`` object.
        """
        super().addSuccess(test)
        if test.warnings_count() == 0:
            outcome = 'pass'
        else:
            outcome = 'warn'
        self.record(test, outcome)

    def addError(self, test, err):  # noqa: N802
        """
        Record a test that raised an unexpected exception as ``error``.

        :param test: a ``unittest.TestCase`` object.
        :param err: a tuple of the form returned by sys.exc_info()
        """
        super().addError(test, err)
        self.record(test, 'error')

    def addFailure(self, test, err):  # noqa: N802
        """
        Record a test whose assertion failed as ``fail``.

        :param test: a ``unittest.TestCase`` object.
        :param err: a tuple of the form returned by sys.exc_info()
        """
        super().addFailure(test, err)
        self.record(test, 'fail')

    def record(self, test, status):
        """
        Store one result entry in the shape downstream consumers expect.

        :param test: a ``unittest.TestCase`` object.
        :param status: a string status (`pass`, `warn`, `fail`, or `error`)
        """
        entry = dict(
            status=status,
            timestamp=time.time(),
            test=ComplianceTestWrapper(test),
        )
        self.results[test.id()] = entry
class ComplianceTestWrapper:
    """
    Wrapper around a ``unittest.TestCase`` check object.

    The wrapped object is exposed as the ``test`` attribute to keep checks,
    reporting, and notifiers that rely on that layout backwards compatible.
    """

    def __init__(self, test):
        """
        Wrap the given test case.

        :param test: the ``unittest.TestCase`` object to wrap.
        """
        self.test = test