1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
// Copyright 2021 The Matrix.org Foundation C.I.C.
// Copyright 2021 Damir Jelić
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

mod key;

use aes::{
    cipher::{
        block_padding::{Pkcs7, UnpadError},
        BlockDecryptMut, BlockEncryptMut, KeyIvInit,
    },
    Aes256,
};
use hmac::{digest::MacError, Hmac, Mac as MacT};
use key::CipherKeys;
use sha2::Sha256;
use thiserror::Error;

/// AES-256 in CBC mode, encrypting direction.
type Aes256CbcEnc = cbc::Encryptor<Aes256>;
/// AES-256 in CBC mode, decrypting direction.
type Aes256CbcDec = cbc::Decryptor<Aes256>;
/// HMAC instantiated with SHA-256 as the underlying hash function.
type HmacSha256 = Hmac<Sha256>;

/// A full-length, 32-byte message authentication code.
pub struct Mac([u8; 32]);

impl Mac {
    /// Number of leading MAC bytes kept by [`Mac::truncate`].
    pub const TRUNCATED_LEN: usize = 8;

    /// Returns a copy of the first [`Mac::TRUNCATED_LEN`] bytes of the MAC.
    pub fn truncate(&self) -> [u8; Self::TRUNCATED_LEN] {
        let mut prefix = [0u8; Self::TRUNCATED_LEN];

        // `zip` stops as soon as `prefix` is exhausted, so only the leading
        // bytes of the full MAC are copied over.
        for (dst, src) in prefix.iter_mut().zip(self.0.iter()) {
            *dst = *src;
        }

        prefix
    }
}

/// Errors that can occur while verifying and decrypting a ciphertext.
#[derive(Debug, Error)]
pub enum DecryptionError {
    /// The decrypted data did not end in valid PKCS#7 padding; wraps the
    /// [`UnpadError`] returned by the CBC decryptor.
    #[error("Failed decrypting, invalid padding")]
    InvalidPadding(#[from] UnpadError),
    /// The MAC attached to the ciphertext did not verify; wraps the
    /// [`MacError`] returned by the HMAC check.
    #[error("The MAC of the ciphertext didn't pass validation {0}")]
    Mac(#[from] MacError),
    /// The input was too short to hold a truncated MAC plus at least one byte
    /// of ciphertext.
    // NOTE(review): this variant is constructed by `Cipher::decrypt_pickle`;
    // presumably the `allow` covers builds where that code path is compiled
    // out -- confirm, and drop the attribute if it is no longer needed.
    #[allow(dead_code)]
    #[error("The ciphertext didn't contain a valid MAC")]
    MacMissing,
}

/// Performs AES-256-CBC encryption/decryption and HMAC-SHA-256 authentication
/// using a single set of [`CipherKeys`].
pub struct Cipher {
    /// Provides the AES key, the IV and the MAC key used by the methods of
    /// this type.
    keys: CipherKeys,
}

impl Cipher {
    pub fn new(key: &[u8; 32]) -> Self {
        let keys = CipherKeys::new(key);

        Self { keys }
    }

    pub fn new_megolm(&key: &[u8; 128]) -> Self {
        let keys = CipherKeys::new_megolm(&key);

        Self { keys }
    }

    #[cfg(feature = "libolm-compat")]
    pub fn new_pickle(key: &[u8]) -> Self {
        let keys = CipherKeys::new_pickle(key);

        Self { keys }
    }

    fn get_hmac(&self) -> HmacSha256 {
        // We don't use HmacSha256::new() here because it expects a 64-byte
        // large HMAC key while the Olm spec defines a 32-byte one instead.
        //
        // https://gitlab.matrix.org/matrix-org/olm/-/blob/master/docs/olm.md#version-1
        HmacSha256::new_from_slice(self.keys.mac_key()).expect("Invalid HMAC key size")
    }

    pub fn encrypt(&self, plaintext: &[u8]) -> Vec<u8> {
        let cipher = Aes256CbcEnc::new(self.keys.aes_key(), self.keys.iv());
        cipher.encrypt_padded_vec_mut::<Pkcs7>(plaintext)
    }

    pub fn mac(&self, message: &[u8]) -> Mac {
        let mut hmac = self.get_hmac();
        hmac.update(message);

        let mac_bytes = hmac.finalize().into_bytes();

        let mut mac = [0u8; 32];
        mac.copy_from_slice(&mac_bytes);

        Mac(mac)
    }

    pub fn decrypt(&self, ciphertext: &[u8]) -> Result<Vec<u8>, UnpadError> {
        let cipher = Aes256CbcDec::new(self.keys.aes_key(), self.keys.iv());
        cipher.decrypt_padded_vec_mut::<Pkcs7>(ciphertext)
    }

    pub fn decrypt_pickle(&self, ciphertext: &[u8]) -> Result<Vec<u8>, DecryptionError> {
        if ciphertext.len() < Mac::TRUNCATED_LEN + 1 {
            Err(DecryptionError::MacMissing)
        } else {
            let (ciphertext, mac) = ciphertext.split_at(ciphertext.len() - Mac::TRUNCATED_LEN);
            self.verify_mac(ciphertext, mac)?;

            Ok(self.decrypt(ciphertext)?)
        }
    }

    pub fn encrypt_pickle(&self, plaintext: &[u8]) -> Vec<u8> {
        let mut ciphertext = self.encrypt(plaintext);
        let mac = self.mac(&ciphertext);

        ciphertext.extend(mac.truncate());

        ciphertext
    }

    pub fn verify_mac(&self, message: &[u8], tag: &[u8]) -> Result<(), MacError> {
        let mut hmac = self.get_hmac();

        hmac.update(message);
        hmac.verify_truncated_left(tag)
    }
}