#!/usr/bin/env python3

import os
import sys

import common
import lkmc.import_path
import path_properties
import signal
import thread_pool

class Main(common.TestCliFunction):
    """
    Test runner CLI built on common.TestCliFunction.

    Unless the caller overrides it through the 'defaults' keyword argument,
    the 'mode' default is 'userland'.
    """

    def __init__(self, *args, **kwargs):
        """
        Supply the default description and the default 'mode' when the caller
        did not provide them, forward everything to the parent constructor,
        and register the positional 'tests' argument.
        """
        kwargs.setdefault('description', '''\
Test userland executables in user mode, or baremetal executables in full system
depending on the value of the --mode option. See also:

* https://cirosantilli.com/linux-kernel-module-cheat#user-mode-tests
* https://cirosantilli.com/linux-kernel-module-cheat#baremetal-tests
* https://cirosantilli.com/linux-kernel-module-cheat#userland-setup-getting-started-natively
''')
        # A caller-supplied 'defaults' mapping is reused as is; only a missing
        # 'mode' entry is filled in.
        defaults = kwargs.setdefault('defaults', {})
        if 'mode' not in defaults:
            defaults['mode'] = 'userland'
        super().__init__(*args, **kwargs)
        tests_help = '''\
If given, run only the given tests. Otherwise, run all tests.
'''
        self.add_argument('tests', nargs='*', help=tests_help)

    def setup_one(self):
        """
        Replace env['tests'] with the result of resolving it against the
        baremetal and userland source directories.
        """
        env = self.env
        env['tests'] = self.resolve_targets(
            [env['baremetal_source_dir'], env['userland_source_dir']],
            env['tests'],
        )

    def timed_main(self):
        """
        Walk every entry of env['tests'] and submit one job to the thread pool
        for each file whose extension is in env['build_in_exts'] and whose
        path properties say it should be tested.

        Returns whatever _handle_thread_pool_errors returns for the pool.
        """
        common_run_args = self.get_common_args()
        env = self.env
        # Walked directories are made relative to the root directory by
        # dropping the root prefix plus the one separator that follows it.
        root_prefix_len = len(env['root_dir']) + 1
        with thread_pool.ThreadPool(
            self.run_test,
            handle_output=self.handle_output_function,
            nthreads=env['nproc'],
            thread_id_arg='thread_id',
            submit_raise_exit=env['quit_on_fail'],
        ) as pool:
            for target in env['tests']:
                for dirpath, _dirnames, filenames in self.sh.walk(target):
                    dir_abs = os.path.abspath(dirpath)
                    dir_rel_root = dir_abs[root_prefix_len:]
                    for filename in filenames:
                        extension = os.path.splitext(filename)[1]
                        if extension not in env['build_in_exts']:
                            continue
                        rel_root = os.path.join(dir_rel_root, filename)
                        props = path_properties.get(rel_root)
                        if not props.should_be_tested(env):
                            continue

                        # Arguments handed to the run object: the common ones,
                        # then the executable under the key named after the
                        # mode, then the per-path overrides.
                        test_run_args = common_run_args.copy()
                        test_run_args[env['mode']] = [
                            os.path.relpath(
                                os.path.join(dir_abs, filename),
                                os.getcwd()
                            )
                        ]
                        test_run_args.update(props['test_run_args'])
                        stdin_name = props['test_stdin_data']
                        if stdin_name is not None:
                            test_run_args['stdin_file'] = os.path.join(
                                dir_abs, 'test_data', stdin_name + '.i'
                            )

                        expected_exit_status = props['exit_status']
                        run_obj = lkmc.import_path.import_path_main('run')
                        test_id = '{} {}'.format(env['mode'], rel_root)
                        if (
                            props['qemu_unimplemented_instruction']
                            and env['emulator'] == 'qemu'
                            and env['mode'] == 'userland'
                        ):
                            expected_exit_status = -signal.Signals.SIGILL.value
                        # An explicit expected signal takes precedence over
                        # both values computed above.
                        received_signal = props['signal_received']
                        if received_signal is not None:
                            if env['mode'] == 'baremetal':
                                expected_exit_status = 128 + received_signal.value
                            elif env['mode'] == 'userland':
                                # Python subprocess reports a signal as the
                                # negated signal number, unlike Bash's
                                # 128 + signal rule.
                                expected_exit_status = -received_signal.value

                        pool.submit({
                            'expected_exit_status': expected_exit_status,
                            'run_args': test_run_args,
                            'run_obj': run_obj,
                            'test_id': test_id,
                        })
        return self._handle_thread_pool_errors(pool)

# Script entry point: build the CLI object and run it only when this file is
# executed directly, not when it is imported.
if __name__ == '__main__':
    main_cli = Main()
    main_cli.cli()
