diff --git a/Cargo.toml b/Cargo.toml
index a181925..6f06387 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -8,7 +8,10 @@ autoexamples = true
 
 [dependencies]
 blake3 = { version = "1", features = ["traits-preview"] }
+crossbeam-channel = "0.5"
+crossbeam-utils = "0.8"
 flate2 = { version = "1", default-features = false, features = ["zlib-ng"] }
+num_cpus = "1"
 serde = { version = "1", features = ["derive"] }
 serde_json = "1"
 sqpack = { git = "https://git.anna.lgbt/ascclemens/sqpack-rs", features = ["read", "write"] }
@@ -17,11 +20,13 @@ thiserror = "1"
 zip = { version = "0.6", default-features = false, features = ["deflate-zlib"] }
 
 [dev-dependencies]
+anyhow = "1"
 blake3 = { version = "1", features = ["traits-preview"] }
 criterion = "0.4"
 data-encoding = "2"
 sha3 = "0.10"
 tempfile = "3"
+zip = { version = "0.6", default-features = false, features = ["deflate-zlib"] }
 
 [[bench]]
 name = "extract"
diff --git a/examples/deduplicate.rs b/examples/deduplicate.rs
new file mode 100644
index 0000000..182094e
--- /dev/null
+++ b/examples/deduplicate.rs
@@ -0,0 +1,87 @@
+use std::fs::File;
+use std::io::{Seek, Write};
+use std::path::Path;
+
+use zip::{CompressionMethod, ZipWriter};
+use zip::write::FileOptions;
+
+use ttmp::model::ManifestKind;
+use ttmp::mpd_encoder::{FileInfo, MpdEncoder};
+use ttmp::ttmp_extractor::TtmpExtractor;
+
+fn main() -> anyhow::Result<()> {
+  let path = std::env::args().skip(1).next().unwrap();
+  let file = File::open(&path)?;
+  let extractor = TtmpExtractor::new(file)?;
+
+  let files = extractor.all_files_sorted();
+  let mut zip = extractor.zip().borrow_mut();
+  let mut data = zip.by_name("TTMPD.mpd")?;
+
+  let mpd = tempfile::tempfile()?;
+  let mut encoder = MpdEncoder::new(mpd, extractor.manifest().clone(), None);
+  let mut staging = tempfile::tempfile()?;
+
+  let mut last_offset = None;
+  let mut last_hash: Option<Vec<u8>> = None;
+
+  for file in files {
+    let info = FileInfo {
+      group: file.group.map(ToOwned::to_owned),
+      option: file.option.map(ToOwned::to_owned),
+      game_path: file.file.full_path.clone(),
+    };
+
+    // handle deduped ttmps
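+    // (all_files_sorted returns entries ordered by offset into the data
+    // file, so entries sharing one blob are adjacent: when the offset
+    // repeats, register this game path against the hash we just wrote
+    // instead of extracting and compressing the same bytes again)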
+    if Some(file.file.mod_offset) == last_offset {
+      if let Some(hash) = &last_hash {
+        encoder.add_file_info(hash, info);
+      }
+
+      continue;
+    }
+
+    last_offset = Some(file.file.mod_offset);
+
+    staging.rewind()?;
+    staging.set_len(0)?;
+
+    TtmpExtractor::extract_one_into(&file, &mut data, &mut staging)?;
+    let size = staging.metadata()?.len() as usize;
+    staging.rewind()?;
+
+    last_hash = if info.game_path.ends_with(".mdl") {
+      encoder.add_model_file(info, size, &mut staging)?
+    } else if info.game_path.ends_with(".tex") || info.game_path.ends_with(".atex") {
+      encoder.add_texture_file(info, size, &mut staging)?
+    } else {
+      encoder.add_standard_file(info, size, &mut staging)?
+    }.into();
+  }
+
+  let (manifest, mut mpd) = encoder.finalize()?;
+  mpd.rewind()?;
+
+  let path = Path::new(&path);
+  let extension = path.extension().and_then(|s| s.to_str()).unwrap_or("ttmp2");
+  let new_path = path.with_extension(format!("deduplicated.{}", extension));
+
+  let new_file = File::create(new_path)?;
+  let mut zip = ZipWriter::new(new_file);
+
+  zip.start_file("TTMPL.mpl", FileOptions::default().compression_method(CompressionMethod::Deflated))?;
+  match manifest {
+    ManifestKind::V1(mods) => for mod_ in mods {
+      serde_json::to_writer(&mut zip, &mod_)?;
+      zip.write_all(b"\n")?;
+    }
+    ManifestKind::V2(pack) => serde_json::to_writer(&mut zip, &pack)?,
+  }
+
+  zip.start_file("TTMPD.mpd", FileOptions::default().compression_method(CompressionMethod::Stored))?;
+  std::io::copy(&mut mpd, &mut zip)?;
+
+  zip.finish()?;
+
+  Ok(())
+}
diff --git a/src/mpd_encoder.rs b/src/mpd_encoder.rs
index d5f67e9..b5ee09e 100644
--- a/src/mpd_encoder.rs
+++ b/src/mpd_encoder.rs
@@ -1,9 +1,11 @@
 use std::collections::{HashMap, HashSet};
 use std::fs::File;
-use std::io::{BufWriter, Read, Seek, SeekFrom, Write};
+use std::io::{BufWriter, Cursor, Read, Seek, SeekFrom, Write};
+use std::sync::{Arc, Condvar, Mutex};
 
 use blake3::Hasher as Blake3;
 use blake3::traits::digest::Digest;
+use crossbeam_channel::{Receiver, Sender};
 use flate2::Compression;
 use flate2::write::DeflateEncoder;
 use sqpack::{DatBlockHeader, DatStdFileBlockInfos, FileKind, LodBlock, ModelBlock, SqPackFileInfo, SqPackFileInfoHeader};
@@ -18,7 +20,11 @@ const ALIGN: usize = 128;
 pub struct MpdEncoder {
   pub manifest: ManifestKind,
   pub writer: BufWriter<File>,
+  pub compression_level: u32,
   hashes: HashMap<Vec<u8>, HashInfo>,
+  pub to_pool: Sender<(usize, Vec<u8>)>,
+  pub from_pool: Receiver<Result<(Vec<u8>, usize), std::io::Error>>,
+  pub current_chunk: Arc<(Condvar, Mutex<usize>)>,
 }
 
 #[derive(Hash, Eq, PartialEq)]
@@ -47,11 +53,65 @@ impl HashInfo {
 impl MpdEncoder {
   const BLOCK_SIZE: usize = 16_000;
 
-  pub fn new(writer: File, manifest: ManifestKind) -> Self {
+  pub fn new(writer: File, manifest: ManifestKind, num_threads: impl Into<Option<usize>>) -> Self {
+    Self::with_compression_level(writer, manifest, num_threads, 9)
+  }
+
+  pub fn with_compression_level(writer: File, manifest: ManifestKind, num_threads: impl Into<Option<usize>>, compression_level: u32) -> Self {
+    let num_threads = num_threads.into().unwrap_or_else(num_cpus::get);
+
+    let (to_pool_tx, to_pool_rx) = crossbeam_channel::bounded(0);
+    let (from_pool_tx, from_pool_rx) = crossbeam_channel::bounded(0);
+    let current_chunk = Arc::new((Condvar::new(), Mutex::new(0)));
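+    // both channels are zero-capacity rendezvous channels: every send
+    // blocks until the other side is ready to receive, so at most one
+    // chunk per worker is ever in flight. current_chunk holds the index
+    // of the chunk whose result should be written next; workers compress
+    // in parallel, then wait on the condvar for their turn, so results
+    // come back in submission order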
+
+    for _ in 0..num_threads {
+      let to_pool_rx = to_pool_rx.clone();
+      let from_pool_tx = from_pool_tx.clone();
+      let current_chunk = Arc::clone(&current_chunk);
+      std::thread::spawn(move || {
+        loop {
+          // receive a chunk of data
+          let (idx, data): (usize, Vec<u8>) = match to_pool_rx.recv() {
+            Ok(data) => data,
+            Err(_) => break,
+          };
+
+          // compress it in memory
+          let cursor = Cursor::new(Vec::with_capacity(data.len()));
+          let mut encoder = DeflateEncoder::new(cursor, Compression::new(compression_level));
+          let res = encoder.write_all(&data)
+            .and_then(|_| encoder.finish())
+            .map(|cursor| (cursor.into_inner(), data.len()));
+
+          // wait for this chunk's turn in the output order
+          let (cvar, lock) = &*current_chunk;
+          let mut current = lock.lock().unwrap();
+          while *current != idx {
+            current = cvar.wait(current).unwrap();
+          }
+
+          // drop the lock, lest we deadlock while blocking on send
+          drop(current);
+
+          // send back the result containing the compressed data
+          from_pool_tx.send(res).ok();
+        }
+      });
+    }
+
     Self {
       manifest,
       writer: BufWriter::new(writer),
+      compression_level,
       hashes: Default::default(),
+      to_pool: to_pool_tx,
+      from_pool: from_pool_rx,
+      current_chunk,
     }
   }
 
@@ -118,7 +178,7 @@ impl MpdEncoder {
     });
   }
 
-  pub fn add_texture_file(&mut self, file_info: FileInfo, size: usize, mut data: impl Read) -> Result<Vec<u8>> {
+  pub fn add_texture_file(&mut self, file_info: FileInfo, size: usize, mut data: impl Read + Send) -> Result<Vec<u8>> {
     #[derive(binrw::BinRead)]
     #[br(little)]
     struct RawTextureHeader {
@@ -278,7 +338,7 @@ impl MpdEncoder {
     ALIGN - (size % ALIGN)
   }
 
-  pub fn add_model_file(&mut self, file_info: FileInfo, size: usize, mut data: impl Read) -> Result<Vec<u8>> {
+  pub fn add_model_file(&mut self, file_info: FileInfo, size: usize, mut data: impl Read + Send) -> Result<Vec<u8>> {
     #[derive(binrw::BinRead)]
     #[br(little)]
     struct RawModelHeader {
@@ -463,7 +523,7 @@ impl MpdEncoder {
     Ok(hash.to_vec())
   }
 
-  fn write_lod(&mut self, lod: usize, lod_count: u8, offsets: &[u32], sizes: &[u32], mut data: impl Read, hasher: &mut impl Digest) -> Result<Vec<DatStdFileBlockInfos>> {
+  fn write_lod(&mut self, lod: usize, lod_count: u8, offsets: &[u32], sizes: &[u32], mut data: impl Read + Send, hasher: &mut (impl Digest + Send)) -> Result<Vec<DatStdFileBlockInfos>> {
     // only write out the lods we have
     if lod_count == 0 || lod > lod_count as usize - 1 {
       return Ok(Default::default());
@@ -494,7 +554,7 @@ impl MpdEncoder {
     num_blocks
   }
 
-  pub fn add_standard_file(&mut self, file_info: FileInfo, size: usize, data: impl Read) -> Result<Vec<u8>> {
+  pub fn add_standard_file(&mut self, file_info: FileInfo, size: usize, data: impl Read + Send) -> Result<Vec<u8>> {
     // store position before doing anything
     let pos = self.writer.stream_position().map_err(Error::Io)?;
 
@@ -572,87 +632,193 @@ impl MpdEncoder {
     Ok(bytes_to_pad)
   }
 
-  fn write_blocks(&mut self, mut data: impl Read, hasher: &mut impl Digest) -> Result<Vec<DatStdFileBlockInfos>> {
+  fn write_blocks(&mut self, mut data: impl Read + Send, hasher: &mut (impl Digest + Send)) -> Result<Vec<DatStdFileBlockInfos>> {
     let mut total_written = 0;
     let mut infos = Vec::new();
 
-    // read 16kb chunks and compress them
-    let mut buf = [0; Self::BLOCK_SIZE];
-    let mut buf_idx: usize = 0;
-    'outer: loop {
-      // read up to 16kb from the data stream
-      loop {
-        let size = data.read(&mut buf[buf_idx..]).map_err(Error::Io)?;
-        if size == 0 {
-          // end of file
-          if buf_idx == 0 {
-            break 'outer;
-          }
-
-          break;
-        }
-
-        buf_idx += size;
-      }
-
-      // update hasher
-      hasher.update(&buf[..buf_idx]);
-
-      let offset = total_written;
-      // get position before chunk
-      let before_header = self.writer.stream_position().map_err(Error::Io)?;
-
-      // make space for chunk header
-      self.writer.write_all(&vec![0; std::mem::size_of::<DatBlockHeader>()]).map_err(Error::Io)?;
-      total_written += std::mem::size_of::<DatBlockHeader>() as u64;
-
-      // write compressed chunk to writer
-      let mut encoder = DeflateEncoder::new(&mut self.writer, Compression::best());
-      encoder.write_all(&buf[..buf_idx]).map_err(Error::Io)?;
-      encoder.finish().map_err(Error::Io)?;
-
-      // calculate the size of compressed data
-      let after_data = self.writer.stream_position().map_err(Error::Io)?;
-      let mut compressed_size = after_data - before_header;
-      total_written += compressed_size;
-
-      // seek back to before header
-      self.writer.seek(SeekFrom::Start(before_header)).map_err(Error::Io)?;
-
-      // write chunk header
-      let header = DatBlockHeader {
-        size: std::mem::size_of::<DatBlockHeader>() as u32,
-        uncompressed_size: buf_idx as u32,
-        compressed_size: (compressed_size - std::mem::size_of::<DatBlockHeader>() as u64) as u32,
-        _unk_0: 0,
-      };
-      self.writer.write_le(&header).map_err(Error::BinRwWrite)?;
-
-      // seek past chunk
-      self.writer.seek(SeekFrom::Start(after_data)).map_err(Error::Io)?;
-
-      // pad to 128 bytes
-      let padded = self.align_to(ALIGN)?;
-
-      // add padding bytes to the compressed size because that's just
-      // how sqpack do
-      compressed_size += padded as u64;
-      // total_written += padded as u64;
-
-      infos.push(DatStdFileBlockInfos {
-        offset: offset as u32,
-        uncompressed_size: buf_idx as u16,
-        compressed_size: compressed_size as u16,
-      });
-
-      // end of file was reached
-      if buf_idx < Self::BLOCK_SIZE {
-        break;
-      }
-
-      buf_idx = 0;
-    }
-
-    Ok(infos)
+    // in order to make encoding faster, we have a threadpool waiting to do
+    // compression jobs
+
+    // we'll read out 16kb chunks here, then send them to the threadpool to
+    // be compressed
+
+    // the threadpool will send them back, and we'll write out the results
+    // in order
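+    //
+    // (the Send bounds added to data and hasher above are what make this
+    // possible: both are borrowed by the scoped reader thread below, so
+    // those borrows have to be able to cross a thread boundary)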
+
+    // set the current chunk back to 0
+    {
+      *self.current_chunk.1.lock().unwrap() = 0;
+    }
+
+    let (finished_tx, finished_rx) = crossbeam_channel::bounded(0);
+
+    // spawn a thread to read the data and send it to the pool
+    let infos = crossbeam_utils::thread::scope(|s| {
+      let handle = s.spawn(|_| {
+        // read 16kb chunks and send them off to be compressed
+        let mut chunk_idx = 0;
+        let mut buf = [0; Self::BLOCK_SIZE];
+        let mut buf_idx: usize = 0;
+        'outer: loop {
+          // read up to 16kb from the data stream
+          loop {
+            // (errors can't bubble out of this thread, so unwrap)
+            let size = data.read(&mut buf[buf_idx..]).map_err(Error::Io).unwrap();
+            if size == 0 {
+              // end of file
+              if buf_idx == 0 {
+                break 'outer;
+              }
+
+              break;
+            }
+
+            buf_idx += size;
+          }
+
+          // update hasher
+          hasher.update(&buf[..buf_idx]);
+
+          // send the data to be compressed
+          self.to_pool.send((chunk_idx, buf[..buf_idx].to_vec())).ok();
+          chunk_idx += 1;
+          buf_idx = 0;
+        }
+
+        // tell the writer how many chunks to expect in total
+        finished_tx.send(chunk_idx).ok();
+      });
+
+      let mut num_chunks = None;
+
+      // receive the compressed chunks
+      loop {
+        // once we know the chunk count, stop after the last chunk is written
+        if let Some(chunks) = num_chunks {
+          if *self.current_chunk.1.lock().unwrap() >= chunks {
+            break;
+          }
+        }
+
+        let (cvar, _) = &*self.current_chunk;
+        cvar.notify_all();
+
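+        // workers block on the condvar until current_chunk matches their
+        // index, so nudge them every pass; then either take the next
+        // in-order result from the pool, or learn the total chunk count
+        // from the reader thread once it hits end of file (the reader is
+        // the only one who knows where the stream ends)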
+        let (data, uncompressed_size) = crossbeam_channel::select! {
+          recv(self.from_pool) -> x => x,
+          recv(finished_rx) -> chunks => {
+            let chunks = chunks.unwrap();
+            num_chunks = Some(chunks);
+
+            continue;
+          }
+        }.unwrap().map_err(Error::Io)?;
+
+        let offset = total_written;
+        // get position before chunk
+        let before_header = self.writer.stream_position().map_err(Error::Io)?;
+
+        // make space for chunk header
+        self.writer.write_all(&vec![0; std::mem::size_of::<DatBlockHeader>()]).map_err(Error::Io)?;
+        total_written += std::mem::size_of::<DatBlockHeader>() as u64;
+
+        // write compressed chunk to writer
+        self.writer.write_all(&data).map_err(Error::Io)?;
+
+        // calculate the size of compressed data
+        let after_data = self.writer.stream_position().map_err(Error::Io)?;
+        let mut compressed_size = after_data - before_header;
+        total_written += compressed_size;
+
+        // seek back to before header
+        self.writer.seek(SeekFrom::Start(before_header)).map_err(Error::Io)?;
+
+        // write chunk header
+        let header = DatBlockHeader {
+          size: std::mem::size_of::<DatBlockHeader>() as u32,
+          uncompressed_size: uncompressed_size as u32,
+          compressed_size: (compressed_size - std::mem::size_of::<DatBlockHeader>() as u64) as u32,
+          _unk_0: 0,
+        };
+        self.writer.write_le(&header).map_err(Error::BinRwWrite)?;
+
+        // seek past chunk
+        self.writer.seek(SeekFrom::Start(after_data)).map_err(Error::Io)?;
+
+        // pad to 128 bytes
+        let padded = {
+          let current_pos = self.writer.stream_position().map_err(Error::Io)? as usize;
+          let bytes_to_pad = 128 - (current_pos % 128);
+          if bytes_to_pad > 0 {
+            let zeroes = std::iter::repeat(0)
+              .take(bytes_to_pad)
+              .collect::<Vec<u8>>();
+
+            // write padding bytes
+            self.writer.write_all(&zeroes).map_err(Error::Io)?;
+          }
+
+          bytes_to_pad
+        };
+
+        // add padding bytes to the compressed size because that's just
+        // how sqpack do
+        compressed_size += padded as u64;
+        // total_written += padded as u64;
+
+        infos.push(DatStdFileBlockInfos {
+          offset: offset as u32,
+          uncompressed_size: uncompressed_size as u16,
+          compressed_size: compressed_size as u16,
+        });
+
+        // end of file was reached
+        if uncompressed_size < Self::BLOCK_SIZE {
+          break;
+        }
+
+        // advance to the next chunk so its worker can send
+        let (_, lock) = &*self.current_chunk;
+        *lock.lock().unwrap() += 1;
+      }
+
+      // at this point, we no longer care about receiving the finished
+      // message, so drop the receiver. the scoped thread can hang waiting
+      // to send the message if this isn't done
+      drop(finished_rx);
+
+      handle.join().unwrap();
+
+      Ok(infos)
+    }).unwrap();
+
+    infos
   }
 }
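
Usage sketch (not part of the patch; `file` and `manifest` stand in for a
File and a ManifestKind you already have): the thread-count argument is
impl Into<Option<usize>>, so callers can pass None for one worker per
logical core, or an explicit count:

  use ttmp::mpd_encoder::MpdEncoder;

  // default deflate level (9), one compression worker per core
  let encoder = MpdEncoder::new(file, manifest, None);

  // or choose both explicitly, e.g. four workers at the faster level 3:
  // let encoder = MpdEncoder::with_compression_level(file, manifest, 4, 3);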