remote-block-device-backup/remote-block-device-backup-.../src/aes128_xts128_stream.rs

141 lines
4.6 KiB
Rust

use aes::cipher::generic_array::GenericArray;
use aes::Aes128;
use aes::NewBlockCipher;
use std::any::type_name;
use std::fmt;
use std::io;
use std::net::Shutdown;
use std::net::TcpStream;
use xts_mode::{get_tweak_default, Xts128};
/// Keystream-encrypting wrapper around a mutably borrowed [`TcpStream`].
///
/// Outgoing bytes are XORed with the `series_w` keystream before they are sent, and
/// incoming bytes are XORed with the `series_r` keystream after they are received.
/// Each series tracks its own position independently.
///
/// `#[derive(new)]` generates `new(stream, series_r, series_w)`, taking the fields in
/// declaration order. Do not reorder the fields without updating callers.
#[derive(new, Debug)]
pub struct Aes128Xts128Stream<'a> {
/// Underlying socket, borrowed for the lifetime of the wrapper.
pub stream: &'a mut TcpStream,
/// Keystream applied to bytes received in `io::Read::read`.
pub series_r: Aes128Xts128StreamSeries,
/// Keystream applied to bytes sent in `io::Write::write`.
pub series_w: Aes128Xts128StreamSeries,
}
impl Aes128Xts128Stream<'_> {
    /// Shuts down the read half, the write half, or both halves of the wrapped connection.
    ///
    /// This is a direct delegation to [`TcpStream::shutdown`]; neither keystream is touched.
    ///
    /// # Errors
    /// Returns whatever I/O error the underlying socket reports.
    pub fn shutdown(&mut self, how: Shutdown) -> io::Result<()> {
        self.stream.shutdown(how)
    }
}
impl io::Read for Aes128Xts128Stream<'_> {
    /// Receives ciphertext from the socket and decrypts it in place.
    ///
    /// Exactly as many keystream bytes are drawn as were received, so the read keystream
    /// stays aligned with the byte position of the incoming stream. If the socket read
    /// fails, no keystream is consumed.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let received = self.stream.read(buf)?;
        let keystream = self.series_r.generate(received);
        // `keystream` has exactly `received` bytes, so only the filled prefix of `buf` is touched.
        for (byte, key_byte) in buf.iter_mut().zip(keystream) {
            *byte ^= key_byte;
        }
        Ok(received)
    }
}
impl io::Write for Aes128Xts128Stream<'_> {
    /// Encrypts `buf` with the write keystream and sends it over the socket.
    ///
    /// The keystream is advanced by exactly the number of bytes the socket accepted:
    /// - on a short write, the unsent tail is rewound;
    /// - on an error, the keystream is restored to its state before the call.
    ///
    /// This lets a retry (for example `write_all` retrying after `ErrorKind::Interrupted`)
    /// encrypt from the correct position instead of silently desynchronising the stream.
    ///
    /// # Errors
    /// Returns the underlying socket's error unchanged.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        // Snapshot the keystream position so a failed write can be undone exactly.
        let saved_sector_count = self.series_w.sector_count;
        let saved_remainder = self.series_w.remainder.clone();

        let mut ciphertext = buf.to_vec();
        let keystream = self.series_w.generate(ciphertext.len());
        for (byte, key_byte) in ciphertext.iter_mut().zip(keystream) {
            *byte ^= key_byte;
        }

        match self.stream.write(&ciphertext) {
            Ok(written) => {
                // The caller will resubmit the unsent tail, so give its keystream bytes back.
                self.series_w
                    .rewind(ciphertext.len().saturating_sub(written));
                Ok(written)
            }
            Err(e) => {
                // Nothing was sent: put the keystream back exactly where it was.
                self.series_w.sector_count = saved_sector_count;
                self.series_w.remainder = saved_remainder;
                Err(e)
            }
        }
    }

    /// Flushes the underlying socket. No keystream state is buffered here.
    fn flush(&mut self) -> io::Result<()> {
        self.stream.flush()
    }
}
impl<'a> From<(&'a mut TcpStream, [u8; 32])> for Aes128Xts128Stream<'a> {
fn from((stream, key): (&'a mut TcpStream, [u8; 32])) -> Self {
Self::new(stream, key.into(), key.into())
}
}
/// Keystream generator: each 16-byte keystream block is the XTS encryption of 16 zero
/// bytes under tweak `get_tweak_default(i)`, where `i` is the block index.
///
/// Bytes of the most recently generated block that have not been handed out yet are
/// buffered in `remainder`.
///
/// `#[derive(new)]` generates `new(xts, remainder, sector_count)`, taking the fields in
/// declaration order. Do not reorder the fields without updating callers.
#[derive(new)]
pub struct Aes128Xts128StreamSeries {
/// XTS cipher producing the keystream. `From<[u8; 32]>` builds it from the two 16-byte
/// halves of the key.
pub xts: Xts128<Aes128>,
/// Unconsumed tail of the last generated 16-byte block (empty when block-aligned).
pub remainder: Vec<u8>,
/// Number of 16-byte blocks generated so far; also the tweak index of the next block.
pub sector_count: u128,
}
impl fmt::Debug for Aes128Xts128StreamSeries {
    /// Prints the keystream position (`remainder` and `sector_count`).
    ///
    /// The `xts` cipher is not printed, so key material never appears in debug output;
    /// `finish_non_exhaustive` renders the omission as `..`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        let mut builder = f.debug_struct(type_name::<Self>());
        builder.field("remainder", &self.remainder);
        builder.field("sector_count", &self.sector_count);
        builder.finish_non_exhaustive()
    }
}
impl Aes128Xts128StreamSeries {
    /// Returns the next `size` keystream bytes and advances the position by `size`.
    ///
    /// Bytes left over from a previously generated block are handed out first, then whole
    /// 16-byte blocks. Finally, a partial block is generated: its prefix is returned and its
    /// unused tail is kept in `remainder` for the next call.
    fn generate(&mut self, size: usize) -> Vec<u8> {
        let mut pattern = Vec::with_capacity(size);

        // Leftover bytes from the previous call come first, in order.
        let from_remainder = size.min(self.remainder.len());
        pattern.extend(self.remainder.drain(..from_remainder));

        // Whole blocks.
        while size - pattern.len() >= 16 {
            pattern.extend_from_slice(&self.generate_next_sector());
        }

        // Partial trailing block. If anything is still missing, the remainder was fully
        // drained above, so it is empty and can take the new block's unused tail.
        let missing = size - pattern.len();
        if missing > 0 {
            let sector = self.generate_next_sector();
            pattern.extend_from_slice(&sector[..missing]);
            self.remainder.extend_from_slice(&sector[missing..]);
        }

        pattern
    }

    /// Produces the keystream block for tweak index `sector_count` and advances to the next.
    fn generate_next_sector(&mut self) -> [u8; 16] {
        let mut sector = [0u8; 16];
        self.xts
            .encrypt_sector(&mut sector, get_tweak_default(self.sector_count));
        self.sector_count += 1;
        sector
    }

    /// Returns the number of keystream bytes handed out so far.
    ///
    /// `sector_count` blocks have been generated, and the last `remainder.len()` bytes of
    /// the most recent one have not been handed out yet. The computation stays in `u128`,
    /// so long streams cannot truncate or overflow on 32-bit targets.
    fn calculate_current_byte(&self) -> u128 {
        self.sector_count * 16 - self.remainder.len() as u128
    }

    /// Moves the keystream position back by `by` bytes so those bytes are produced again.
    ///
    /// # Panics
    /// Panics if `by` exceeds the number of keystream bytes handed out so far.
    fn rewind(&mut self, by: usize) {
        if by == 0 {
            return;
        }
        // `usize` -> `u128` is a lossless widening.
        let target = self
            .calculate_current_byte()
            .checked_sub(by as u128)
            .expect("cannot rewind the keystream past its start");

        // Restart at the block containing `target`. Buffered bytes belong to the old
        // position and must be discarded, otherwise `generate` would hand them out first.
        self.sector_count = target / 16;
        self.remainder.clear();
        // Consume the part of that block before `target`; its tail lands in `remainder`.
        self.generate((target % 16) as usize);

        assert_eq!(
            target,
            self.calculate_current_byte(),
            "Programming error while rewinding; contact developer."
        );
    }
}
impl From<[u8; 32]> for Aes128Xts128StreamSeries {
    /// Builds a fresh keystream positioned at block 0.
    ///
    /// The first 16 bytes of `key` form the first AES-128 key passed to `Xts128::new`, and
    /// the last 16 bytes form the second.
    fn from(key: [u8; 32]) -> Self {
        let (first_half, second_half) = key.split_at(16);
        let make_cipher = |half: &[u8]| Aes128::new(GenericArray::from_slice(half));
        let xts = Xts128::<Aes128>::new(make_cipher(first_half), make_cipher(second_half));
        Self::new(xts, Vec::new(), 0)
    }
}