// Copyright (c) 2020 Can Boluk and contributors of the VTIL Project   
// All rights reserved.   
//    
// Redistribution and use in source and binary forms, with or without   
// modification, are permitted provided that the following conditions are met: 
//    
// 1. Redistributions of source code must retain the above copyright notice,   
//    this list of conditions and the following disclaimer.   
// 2. Redistributions in binary form must reproduce the above copyright   
//    notice, this list of conditions and the following disclaimer in the   
//    documentation and/or other materials provided with the distribution.   
// 3. Neither the name of VTIL Project nor the names of its contributors
//    may be used to endorse or promote products derived from this software 
//    without specific prior written permission.   
//    
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE   
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE  
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE   
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR   
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF   
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS   
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN   
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)   
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE  
// POSSIBILITY OF SUCH DAMAGE.        
//
#include "bblock_extension_pass.hpp"
#include <vtil/query>

namespace vtil::optimizer
{
	// Implement the pass.
	//
	// Starting at `blk`, repeatedly absorbs the single successor into the current block while
	// the two form a straight line:
	// - blk has exactly one successor, which is not blk itself;
	// - that successor has exactly one predecessor;
	// - blk ends in a virtual branch.
	// It then recurses into the successors of the resulting (extended) block.
	//
	// - blk:    Block to start extending from.
	// - xblock: Must be true; local (single-block) invocations are a no-op.
	//
	// Returns the number of successor blocks that were merged into a predecessor and deleted.
	//
	size_t bblock_extension_pass::pass( basic_block* blk, bool xblock )
	{
		// Skip if local optimization or if already visited.
		//
		if ( !xblock || visit_list.contains( blk ) )
			return 0;
		visit_list.insert( blk );

		// While we can form an extended basic block:
		//
		// - The stream must be non-empty before we may inspect .back().
		// - The routine's entry point must never be absorbed: it is referenced by the routine
		//   and would be left dangling once deleted.
		// - A block that was already visited may still be executing higher up the recursion
		//   stack (iterating its own .next), so deleting it here would be a use-after-free.
		//
		size_t counter = 0;
		while ( !blk->stream.empty() &&
				blk->next.size() == 1 &&
				blk->next[ 0 ] != blk &&
				blk->next[ 0 ]->prev.size() == 1 &&
				blk->next[ 0 ] != blk->owner->entry_point &&
				!visit_list.contains( blk->next[ 0 ] ) &&
				blk->stream.back().base->is_branching_virt() )
		{
			// Pop the branching instruction; control now falls through into the successor.
			//
			blk->stream.pop_back();

			// For each instruction in the destination:
			//
			basic_block* blk_next = blk->next[ 0 ];
			for ( instruction& ins : *blk_next )
			{
				// Rebase temporaries by the current maximum so they cannot collide with the
				// temporaries already used in blk.
				//
				for ( operand& op : ins.operands )
					if ( op.is_register() && op.reg().is_local() )
						op.reg().local_id += blk->last_temporary_index;

				// If inherited stack instance:
				//
				if ( ins.sp_index == 0 )
				{
					// Shift stack offset by current offset.
					//
					ins.sp_offset += blk->sp_offset;

					// If memory operation with a stack-pointer base, shift its displacement too.
					// NOTE(review): this relies on memory_location() yielding references into
					// `ins` (e.g. a pair of references); a by-value result would lose the update.
					//
					if ( ins.base->accesses_memory() )
					{
						auto [base, offset] = ins.memory_location();
						if ( base.is_stack_pointer() )
							offset += blk->sp_offset;
					}
				}

				// Shift stack indexes by current maximum and move the instruction to the current block.
				//
				ins.sp_index += blk->sp_index;
				blk->stream.push_back( std::move( ins ) );
			}

			// Merge block states.
			//
			if ( blk_next->sp_index == 0 )
				blk->sp_offset += blk_next->sp_offset;
			else
				blk->sp_offset = blk_next->sp_offset;
			blk->sp_index += blk_next->sp_index;
			blk->last_temporary_index += blk_next->last_temporary_index;
			blk->next = blk_next->next;

			// Fix the .prev links so the absorbed block's successors now point back at blk.
			//
			for ( basic_block* dst : blk_next->next )
				for ( basic_block*& src : dst->prev )
					if ( src == blk_next )
						src = blk;

			// Acquire the routine lock.
			//
			std::lock_guard _g( blk->owner->mutex );

			// Enumerate both forwards and backwards caches.
			//
			for ( auto& cache : blk->owner->path_cache )
			{
				// Enumerate:
				// std::map<const basic_block*, std::map<const basic_block*, std::set<const basic_block*>>>
				//
				for ( auto it = cache.begin(); it != cache.end(); )
				{
					// If entry key references deleted block, erase it and continue.
					//
					if ( it->first == blk_next )
					{
						it = cache.erase( it );
						continue;
					}

					// Enumerate:
					// std::map<const basic_block*, std::set<const basic_block*>>
					//
					for ( auto it2 = it->second.begin(); it2 != it->second.end(); )
					{
						// If entry key references deleted block, erase it and continue.
						//
						if ( it2->first == blk_next )
						{
							it2 = it->second.erase( it2 );
							continue;
						}

						// Remove any references from set.
						//
						it2->second.erase( blk_next );
						++it2;
					}
					++it;
				}
			}

			// Delete the target block and increment counter.
			//
			blk->owner->explored_blocks.erase( blk_next->entry_vip );
			delete blk_next;
			counter++;
		}

		// Recurse into destinations:
		//
		for ( auto* dst : blk->next )
			counter += pass( dst, true );
		return counter;
	}
	// Cross-block entry point: runs the recursive block extender over the whole routine,
	// beginning at its entry block, and returns the total number of blocks merged.
	//
	size_t bblock_extension_pass::xpass( routine* rtn )
	{
		const auto entry = rtn->entry_point;
		const bool cross_block = true;
		return pass( entry, cross_block );
	}
};
