#!/usr/bin/env python3

import os, zlib

class Extractor:
    def extract(self, romid):
        """Load the ROM image for ``romid`` and extract all of its assets.

        The ROM is read from ``pd.<romid>.z64`` (relative to the current
        working directory) into ``self.rom``. The compressed segment at the
        'gamedata' offset is inflated into ``self.gamedata`` before the
        individual extraction steps run.

        Args:
            romid: Key into ``vals`` selecting the ROM version, e.g.
                ``'ntsc-final'``.

        Raises:
            OSError: If the ROM file cannot be opened or read.
            KeyError: If ``romid`` has no entry in ``vals``.
        """
        self.romid = romid

        filename = 'pd.%s.z64' % romid

        # A context manager releases the handle even if read() raises.
        with open(filename, 'rb') as fd:
            self.rom = fd.read()

        self.gamedata = self.decompress(self.rom[self.val('gamedata'):])
        self.extract_all()

    def extract_all(self):
        """Run every extraction step once, in a fixed order."""
        steps = (
            self.extract_audio,
            self.extract_files,
            self.extract_fonts,
            self.extract_gamedata,
            self.extract_textures,
            self.extract_ucodes,
            self.extract_mpconfigs,
            self.extract_mpstrings,
        )
        for step in steps:
            step()

    #
    # Audio
    #

    def extract_audio(self):
        """Write the audio banks and each individual music sequence.

        The sfx ctl, sfx tbl, music ctl, music tbl and sequence table are
        read back to back starting at the 'sfxctl' offset; every segment
        size is hard-coded. The sequence table is written whole and then
        split into one inflated file per sequence.
        """
        rom = self.rom

        # Each segment starts where the previous one ends.
        sfx_ctl_at = self.val('sfxctl')
        sfx_tbl_at = sfx_ctl_at + 0x2fb80
        music_ctl_at = sfx_tbl_at + 0x4c2160
        music_tbl_at = music_ctl_at + 0xa060
        seq_at = music_tbl_at + 0x17c070

        self.write('audio/sfx.ctl', rom[sfx_ctl_at:sfx_tbl_at])
        self.write('audio/sfx.tbl', rom[sfx_tbl_at:music_ctl_at])
        self.write('audio/music.ctl', rom[music_ctl_at:music_tbl_at])
        self.write('audio/music.tbl', rom[music_tbl_at:seq_at])

        # The sequence table length differs for the ntsc-beta ROM.
        if self.romid == 'ntsc-beta':
            seq_len = 0x563b0
        else:
            seq_len = 0x563a0

        seq_blob = rom[seq_at:seq_at + seq_len]
        self.write('audio/sequences.bin', seq_blob)

        # The first two bytes of the table hold the big-endian sequence count.
        total = int.from_bytes(seq_blob[:2], 'big')
        for number in range(total):
            inflated = self.extract_sequence(seq_blob, number)
            self.write('audio/sequences/%03d.seq' % number, inflated)

    def extract_sequence(self, sequences, index):
        """Return the inflated data of sequence number ``index``.

        Table entries are 8 bytes each and start at byte 4 of ``sequences``.
        The first four bytes of an entry hold the big-endian offset of the
        compressed sequence, measured from the start of ``sequences``.
        """
        entry = 4 + 8 * index
        start = int.from_bytes(sequences[entry:entry + 4], 'big')
        compressed = sequences[start:]
        return self.decompress(compressed)

    #
    # Files
    #

    def extract_files(self):
        """Write every file in the ROM file table, plus inflated copies.

        Offsets come from ``get_file_offsets``. Entry 0 is skipped and the
        final offset is only used as the end of the last file and as the
        address of the name table. Each file is written as stored under
        ``files/<name>``. Files that start with the 0x1173 compression header
        are also inflated, and the inflated data is written to a subdirectory
        chosen from the file name.
        """
        offsets = self.get_file_offsets()
        names = self.get_file_names(offsets[-1])

        # First letter of the file name -> directory for the inflated copy.
        subdirs = {
            'C': 'chrs',
            'G': 'guns',
            'L': 'lang',
            'P': 'props',
            'U': 'setup',
        }

        # An explicit range replaces the bare ``except:`` that used to end
        # the loop, so real errors are no longer swallowed.
        for index in range(1, len(offsets) - 1):
            content = self.rom[offsets[index]:offsets[index + 1]]
            name = names[index]

            # If the content is zipped then the trailing bytes might be
            # padding. Unzip it to find out how much was actually consumed.
            unzipped = None
            if content[0:2] == b'\x11\x73':
                unzipped, unused = self.decompressandgetunused(content)
                if unused:
                    content = content[:-len(unused)]

            self.write('files/%s' % name, content)

            # Nothing was inflated (empty or uncompressed file), so there is
            # no inflated copy to write.
            if unzipped is None:
                continue

            unzippedname = name[:-1] if name.endswith('Z') else name

            if name.endswith('tilesZ'):
                self.write('files/bgdata/%s.bin' % unzippedname, unzipped)

            subdir = subdirs.get(name[:1])
            if subdir is not None:
                self.write('files/%s/%s.bin' % (subdir, unzippedname), unzipped)

    def get_file_offsets(self):
        """Return the list of file offsets read from the game data segment.

        Big-endian 32-bit values are read starting at the 'files' offset of
        ``self.gamedata``. Reading stops at the first zero value that is not
        the very first entry; that terminating zero is not included.
        """
        table = self.gamedata
        cursor = self.val('files')
        found = []
        while True:
            value = int.from_bytes(table[cursor:cursor + 4], 'big')
            cursor += 4
            # A leading zero is a real entry; any later zero ends the table.
            if found and value == 0:
                break
            found.append(value)
        return found

    def get_file_names(self, tableaddr):
        """Return the file names listed in the name table at ``tableaddr``.

        The table is a run of big-endian 32-bit offsets in ``self.rom``, each
        relative to ``tableaddr`` and pointing at a NUL-terminated name. A
        zero offset ends the table unless it is the first entry.
        """
        collected = []
        cursor = tableaddr
        while True:
            relative = int.from_bytes(self.rom[cursor:cursor + 4], 'big')
            cursor += 4
            if collected and relative == 0:
                break
            collected.append(self.read_name(tableaddr + relative))
        return collected

    def read_name(self, address):
        """Return the NUL-terminated UTF-8 string stored at ``address``.

        Args:
            address: Offset into ``self.rom`` of the first character.

        Returns:
            The decoded string, without the terminating zero byte.

        Raises:
            ValueError: If there is no zero byte at or after ``address``.
            UnicodeDecodeError: If the bytes are not valid UTF-8.
        """
        # Search in place: slicing ``self.rom[address:]`` would copy the rest
        # of the ROM for every single name that is looked up.
        end = self.rom.index(0, address)
        return str(self.rom[address:end], 'utf-8')

    #
    # MpConfigs
    #

    def extract_mpconfigs(self):
        """Write the 0x68 * 44 bytes found at the 'mpconfigs' ROM offset."""
        # NOTE(review): presumably 44 records of 0x68 bytes each - confirm.
        size = 0x68 * 44
        start = self.val('mpconfigs')
        self.write('ucode/mpconfigs.bin', self.rom[start:start + size])

    def extract_mpstrings(self):
        """Extract the seven string blocks, one per language letter.

        The position of each letter in 'EJPGFSI' is the block index passed
        to ``extract_mpstrings_lang``.
        """
        for index, lang in enumerate('EJPGFSI'):
            self.extract_mpstrings_lang(index, lang)

    def extract_mpstrings_lang(self, index, lang):
        """Write string block ``index`` to ``ucode/mpstrings<lang>.bin``.

        The 0x3700-byte blocks are read back to back from the ROM, beginning
        0x68 * 44 bytes after the 'mpconfigs' offset.
        """
        block_size = 0x3700
        first_block = self.val('mpconfigs') + 0x68 * 44
        begin = first_block + block_size * index
        payload = self.rom[begin:begin + block_size]
        self.write('ucode/mpstrings%s.bin' % lang, payload)

    #
    # Fonts
    #

    def extract_fonts(self):
        """Placeholder: font extraction is not implemented, so this does nothing."""
        return None

    #
    # Game data
    #

    def extract_gamedata(self):
        """Write the already-inflated game data segment to ucode/gamedata.bin."""
        payload = self.gamedata
        self.write('ucode/gamedata.bin', payload)

    #
    # Textures
    #

    def extract_textures(self):
        """Write each texture listed in the texture table to its own file.

        The table sits ``datalen`` bytes after the 'textures' ROM offset and
        is walked in 8-byte steps. Bytes 1-3 of the current entry give the
        start of a texture and bytes 1-3 of the next entry give its end, both
        relative to the 'textures' offset. The walk ends when bytes 4-7 of
        the next entry are non-zero, or when the table runs past the end of
        the ROM.
        """
        base = self.val('textures')
        # The texture data is a different size in the jap-final ROM.
        datalen = 0x294960 if self.romid == 'jap-final' else 0x291d60
        tablepos = base + datalen
        index = 0
        while True:
            # Reads past the end of the ROM decode as 0, which would never
            # match the terminator check below and would loop forever writing
            # empty files. Stop once the 16 bytes needed are not all there.
            if tablepos + 16 > len(self.rom):
                return
            start = int.from_bytes(self.rom[tablepos+1:tablepos+4], 'big')
            end = int.from_bytes(self.rom[tablepos+9:tablepos+12], 'big')
            if int.from_bytes(self.rom[tablepos+12:tablepos+16], 'big') != 0:
                return
            texturedata = self.rom[base+start:base+end]
            self.write('textures/%04x.bin' % index, texturedata)
            index += 1
            tablepos += 8

    #
    # Ucodes
    #

    def extract_ucodes(self):
        """Write the code segments that sit at fixed ROM addresses.

        rspboot, boot and inflate are copied as stored; lib is inflated from
        the compressed data starting at 0x3050. The game segment is handled
        by ``extract_ucode_game``.
        """
        rom = self.rom

        rspboot = rom[0x40:0x1000]
        self.write('ucode/rspboot.bin', rspboot)

        boot = rom[0x1000:0x3050]
        self.write('ucode/boot.bin', boot)

        lib = self.decompress(rom[0x3050:])
        self.write('ucode/lib.bin', lib)

        inflate = rom[0x4e850:0x4fc40]
        self.write('ucode/inflate.bin', inflate)

        self.extract_ucode_game()

    def extract_ucode_game(self):
        """Inflate the chunked game segment and write it to ucode/game.bin.

        A table of big-endian 32-bit values begins at the 'game' ROM offset.
        Each value, plus the table address plus 2, locates one compressed
        chunk. Chunks are inflated in table order and concatenated. The walk
        stops at a chunk whose first two bytes are zero, or after a chunk
        that inflates to anything other than 0x1000 bytes.
        """
        table = self.val('game')
        cursor = table
        pieces = []

        while True:
            entry = int.from_bytes(self.rom[cursor:cursor + 4], 'big')
            # NOTE(review): the reason for the extra 2 bytes is not visible
            # here - confirm against the ROM layout.
            chunk_at = table + entry + 2

            marker = int.from_bytes(self.rom[chunk_at:chunk_at + 2], 'big')
            if marker == 0:
                break

            # At most 0x1000 compressed bytes are handed to the inflater.
            chunk = self.decompress(self.rom[chunk_at:chunk_at + 0x1000])
            pieces.append(chunk)

            # A short chunk is the last one.
            if len(chunk) != 0x1000:
                break
            cursor += 4

        self.write('ucode/game.bin', b''.join(pieces))

    #
    # Misc functions
    #

    def decompress(self, buffer):
        """Inflate a buffer that begins with the 0x1173 compression header.

        The first five bytes are a header: the two-byte magic 0x1173 followed
        by three bytes that are skipped here (presumably the uncompressed
        size - TODO confirm). The rest is inflated as a raw deflate stream.

        Args:
            buffer: Bytes starting at the compression header.

        Returns:
            The inflated bytes.

        Raises:
            ValueError: If ``buffer`` does not start with the 0x1173 magic.
            zlib.error: If the deflate stream is invalid or truncated.
        """
        header = int.from_bytes(buffer[0:2], 'big')
        # Raise explicitly: an ``assert`` is stripped when Python runs with -O.
        if header != 0x1173:
            raise ValueError('expected 0x1173 header, found 0x%04x' % header)
        # wbits=-15 selects a raw deflate stream with no zlib/gzip wrapper.
        return zlib.decompress(buffer[5:], wbits=-15)

    def decompressandgetunused(self, buffer):
        """Inflate a 0x1173-headed buffer and report the bytes left over.

        Works like ``decompress`` but also returns whatever follows the end
        of the deflate stream, so the caller can tell how long the compressed
        data really was.

        Args:
            buffer: Bytes starting at the five-byte compression header.

        Returns:
            A ``(inflated, unused)`` tuple, where ``unused`` is the trailing
            part of ``buffer`` that the inflater did not consume.

        Raises:
            ValueError: If ``buffer`` does not start with the 0x1173 magic.
            zlib.error: If the deflate stream is invalid.
        """
        header = int.from_bytes(buffer[0:2], 'big')
        # Raise explicitly: an ``assert`` is stripped when Python runs with -O.
        if header != 0x1173:
            raise ValueError('expected 0x1173 header, found 0x%04x' % header)
        # wbits=-15 selects a raw deflate stream with no zlib/gzip wrapper.
        inflater = zlib.decompressobj(wbits=-15)
        inflated = inflater.decompress(buffer[5:])
        return (inflated, inflater.unused_data)

    def write(self, filename, data):
        """Write ``data`` to ``extracted/<romid>/<filename>``.

        Missing parent directories are created. An existing file is
        overwritten.

        Args:
            filename: Path relative to the per-ROM output directory.
            data: Bytes to write.

        Raises:
            OSError: If a directory or the file cannot be created or written.
        """
        filename = 'extracted/%s/%s' % (self.romid, filename)

        # exist_ok avoids the check-then-create race of testing for the
        # directory before making it.
        os.makedirs(os.path.dirname(filename), exist_ok=True)

        # The context manager closes the file even if the write fails.
        with open(filename, 'wb') as fd:
            fd.write(data)

    def val(self, name):
        """Return the ``name`` entry of the offset table for the current ROM."""
        table = self.vals[self.romid]
        return table[name]

    # Per-ROM-version offsets, keyed first by ROM id (the value passed to
    # extract()) and then by name; looked up through val().
    #
    # 'files' is used as an index into the decompressed game data (see
    # get_file_offsets). Every other entry is used as an offset into the raw
    # ROM image.
    vals = {
        'ntsc-final': {
            'game': 0x4fc40,
            'files': 0x28080,
            'gamedata': 0x39850,
            'mpconfigs': 0x7d0a40,
            'sfxctl': 0x80a250,
            'textures': 0x01d65f40,
        },
        'ntsc-1.0': {
            'game': 0x4fc40,
            'files': 0x28080,
            'gamedata': 0x39850,
            'mpconfigs': 0x7d0a40,
            'sfxctl': 0x80a250,
            'textures': 0x01d65f40,
        },
        'ntsc-beta': {
            'game': 0x43c40,
            'files': 0x29160,
            'gamedata': 0x30850,
            'mpconfigs': 0x785130,
            'sfxctl': 0x7be940,
            'textures': 0x01d12fe0,
        },
        'pal-final': {
            'game': 0x4fc40,
            'files': 0x28910,
            'gamedata': 0x39850,
            'mpconfigs': 0x7bc240,
            'sfxctl': 0x7f87e0,
            'textures': 0x01d5ca20,
        },
        'pal-beta': {
            'game': 0x4fc40,
            'files': 0x29b90,
            'gamedata': 0x39850,
            'mpconfigs': 0x7bc240,
            'sfxctl': 0x7f87e0,
            'textures': 0x01d5bb50,
        },
        'jap-final': {
            'game': 0x4fc40,
            'files': 0x28800,
            'gamedata': 0x39850,
            'mpconfigs': 0x7c00d0,
            'sfxctl': 0x7fc670,
            'textures': 0x01d61f90,
        },
    }

# Script entry point. The ROM version comes from the ROMID environment
# variable; a missing variable raises KeyError. These lines run whenever the
# module is executed or imported.
extractor = Extractor()
extractor.extract(os.environ['ROMID'])

