ECO: Large Language Model Unlearning via Embedding-Corrupted Prompts

University of California, Santa Cruz
NeurIPS 2024
[Teaser figure]

Using embedding-corrupted prompts to maintain an unlearned state on an LLM subject to unlearning. We first employ a classifier to identify whether the incoming prompt falls within the scope of the unlearning target. We then construct embedding-corrupted prompts by selectively corrupting dimensions of the tokens' embeddings; the corruption parameter is learned offline via zeroth-order optimization. The unlearned state is imposed during inference and requires no updates to the original model's weights.

Abstract

Large language models (LLMs) have advanced to encompass extensive knowledge across diverse domains. Yet controlling what a large language model should not know is important for ensuring alignment and thus safe use. However, accurately and efficiently unlearning knowledge from an LLM remains challenging due to the potential collateral damage caused by the fuzzy boundary between retention and forgetting, and the large computational requirements for optimization across state-of-the-art models with hundreds of billions of parameters. In this work, we present Embedding-COrrupted (ECO) Prompts, a lightweight unlearning framework for large language models that addresses both challenges: knowledge entanglement and unlearning efficiency. Instead of relying on the LLM itself to unlearn, we enforce an unlearned state during inference by employing a prompt classifier to identify and safeguard prompts to forget. We learn corruptions added to prompt embeddings via zeroth-order optimization toward the unlearning objective offline, and corrupt prompts flagged by the classifier during inference. We find that these embedding-corrupted prompts not only lead to desirable outputs that satisfy the unlearning objective but also closely approximate the output of a model that has never been trained on the data intended for forgetting. Through extensive experiments on unlearning, we demonstrate the superiority of our method in achieving promising unlearning at nearly zero side effects in general domains and in domains closely related to the unlearned ones. Additionally, we highlight the scalability of our method to 100 LLMs ranging from 0.5B to 236B parameters, incurring no additional cost as the number of parameters increases. We have made our code publicly available at link.

Contributions

  • Embedding-COrrupted (ECO) Prompts, a novel and lightweight LLM unlearning method that enforces an unlearned state over an intact LLM.
  • Instead of optimizing an unlearning objective over the model weights, carefully corrupted prompts lead to behavior that resembles that of a model which has never seen the data intended to be forgotten, across multiple tasks and metrics.
  • Superior performance in both retaining and forgetting, with virtually zero side effects and no additional cost when scaling to larger models.
  • To the best of our knowledge, we are the first to demonstrate universally effective and efficient unlearning for 100 LLMs and up to 236B parameters.


Method Overview

Our method consists of two steps: 1) train a prompt classifier to predict whether an incoming prompt falls within the scope of unlearning, and 2) corrupt the prompt in the embedding space if the classifier makes a positive prediction (i.e., the prompt should be forgotten).

Enforcing Retaining and Forgetting via a Classifier

We first train a prompt classifier to explicitly identify whether a prompt falls within the scope of unlearning. For any incoming prompt \( \mathbf{x} \), the prompt classifier \( C \) takes in \( \mathbf{x} \) and returns \( p_C(f \mid \mathbf{x}) = 1 - p_C(r \mid \mathbf{x}) \), the probability that the prompt is in the scope of forgetting. As with any classifier, if \( p_C(f \mid \mathbf{x}) > p_C(r \mid \mathbf{x}) \), we consider \( \mathbf{x} \) to contain the unlearning concept that our LLM is supposed to forget. Formally, given a positive prediction, \( p_C(f \mid \mathbf{x}) > p_C(r \mid \mathbf{x}) \), we replace the original input \( \mathbf{x} \) with a corrupted version \( \tilde{\mathbf{x}} \); otherwise, the original \( \mathbf{x} \) is passed to the LLM:

\[ \mathbf{x} \leftarrow \begin{cases} \tilde{\mathbf{x}} & \text{if } p_C(f \mid \mathbf{x}) > p_C(r \mid \mathbf{x}), \\ \mathbf{x} & \text{otherwise.} \end{cases} \]
Simple additional thresholding or conformal prediction is employed to reduce the false positive/negative rates, as sketched below.
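As a concrete illustration, the routing rule above can be written in a few lines of Python. This is a minimal sketch rather than the released implementation: the probability p_forget, the corruption function corrupt, and the threshold argument are stand-ins for the trained prompt classifier's output, the learned Corrupt(·; σ) described below, and the thresholding/conformal cutoff.

import torch

def route_prompt(prompt_embeddings: torch.Tensor,
                 p_forget: float,
                 corrupt,
                 sigma: float,
                 threshold: float = 0.5) -> torch.Tensor:
    """Return the embeddings actually fed to the LLM.

    prompt_embeddings: (T, d) token embeddings of the incoming prompt.
    p_forget: classifier probability p_C(f | x) that the prompt is in scope.
    corrupt: corruption function Corrupt(e; sigma), learned offline.
    threshold: 0.5 recovers p_C(f | x) > p_C(r | x); a stricter cutoff
        (e.g., from conformal prediction) trades false positives for negatives.
    """
    if p_forget > threshold:              # within the unlearning scope: corrupt
        return corrupt(prompt_embeddings, sigma)
    return prompt_embeddings              # retain: pass the prompt through unchanged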

Embedding-COrrupted Prompts

Instead of modifying \( \mathbf{x} \) in the token space, we corrupt it in the embedding space. Let \( \mathbf{x} = \{x_1, x_2, \dots, x_{T}\} \) be a prompt of \( T \) tokens and \( \mathbf{e} = \{e_1, e_2, \dots, e_T\} \) be the corresponding embedding vectors. Let \( \mathcal{E} \) be the space of token embeddings, where each embedding vector is produced by an embedding function \( E: \mathcal{X} \rightarrow \mathbb{R}^d \). We use \( \sigma \in \mathcal{S} \) (where \( \mathcal{S} \subset \mathbb{R} \)) to denote the corruption strength, which parameterizes the corruption function. Formally, for a single prompt \( \mathbf{x} \) mapped to the embeddings \( \mathbf{e} = E(\mathbf{x}) = \{e_1, e_2, \dots, e_T\} \), a corruption function \( \texttt{Corrupt}: \mathcal{E} \times \mathcal{S} \rightarrow \mathcal{E} \), parameterized by \( \sigma \), produces the embedding-corrupted prompt

\[ \tilde{\mathbf{e}} = \texttt{Corrupt}(\mathbf{e}; \sigma) = \{\tilde{e}_1, \tilde{e}_2, \dots, \tilde{e}_T\}. \]
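The two corruption variants evaluated in the experiments, ECO-RN (random noise) and ECO-ZO (zero-out), can be sketched as follows. This is a minimal sketch under stated assumptions: the set of corrupted dimensions is chosen here as a fixed random subset, and for ECO-ZO the strength σ is interpreted as the fraction of selected dimensions zeroed out; neither choice is necessarily the paper's exact rule.

import torch

def corrupt_random_noise(e: torch.Tensor, sigma: float, dims: torch.Tensor) -> torch.Tensor:
    """ECO-RN: add zero-mean Gaussian noise of strength sigma to selected dims."""
    e_tilde = e.clone()
    e_tilde[:, dims] = e_tilde[:, dims] + sigma * torch.randn_like(e_tilde[:, dims])
    return e_tilde

def corrupt_zero_out(e: torch.Tensor, sigma: float, dims: torch.Tensor) -> torch.Tensor:
    """ECO-ZO: zero out a sigma-fraction of the selected dims (assumption)."""
    e_tilde = e.clone()
    k = int(sigma * dims.numel())          # number of dimensions to zero out
    e_tilde[:, dims[:k]] = 0.0
    return e_tilde

# Usage: corrupt a 12-token prompt with 4096-dimensional embeddings.
e = torch.randn(12, 4096)                  # e = E(x), shape (T, d)
dims = torch.randperm(4096)[:1024]         # hypothetical subset of corrupted dimensions
e_rn = corrupt_random_noise(e, sigma=1.0, dims=dims)
e_zo = corrupt_zero_out(e, sigma=0.5, dims=dims)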

Let \( \tilde{h}: \mathcal{E} \times \Theta \rightarrow \mathcal{Y} \) be the original LLM \( h \) taking input embeddings instead of input tokens (i.e., \( h \) with its input embedding layer detached). Our objective is to pick a good \( \sigma^* \) such that the following modified unlearning objective is satisfied:

\[ \frac{\mathbb{E}\left[m_i \left(\tilde{h}\left(\texttt{Corrupt}(\mathbf{e}; \sigma^*); \theta_o \right)\right)\right]}{\hat{v}_r} \approx 1, \quad \forall m_i \in \mathcal{M}. \]

Here, \( \hat{v}_r \) is a surrogate for the true \( \mathbb{E}[m_i(\tilde{h}(\mathbf{e}; \theta_r))] \), which cannot be computed because the retained model \( \theta_r \) is not available, and \( \mathcal{M} \) is a set of metrics relevant to unlearning. For example, when the metric is multiple-choice accuracy on hazardous knowledge, random-guessing accuracy is a natural choice of \( \hat{v}_r \).

Optimizing Toward An Optimal Corruption Strength

We aim to learn a \(\sigma^*\) such that the metric gap between the unlearned model and the retained model is minimized:

\[ d(\tilde{\mathbf{e}}, \theta_{o}, \hat{v}_r, \mathcal{M}) = \frac{1}{|\mathcal{M}|} \sum_{i} \Big| \underbrace{m_i(\tilde{h}(\tilde{\mathbf{e}}; \theta_{o}))}_{\text{unlearned metric value}} - \underbrace{\hat{v}_r}_{\text{surrogate retain metric value}} \Big| \]
\[ \sigma^* = \arg \min_{\sigma} d\left(\texttt{Corrupt}(\mathbf{e}; \sigma), \theta_{o}, \hat{v}_r, \mathcal{M}\right) \]
Finally, an optimal \(\sigma^*\) is obtained by zeroth order optimization via finite difference approximation.
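A minimal sketch of this offline search, assuming a scalar σ and a symmetric finite-difference gradient estimate. The helper metric_gap is a hypothetical callable that evaluates d(Corrupt(e; σ), θ_o, v̂_r, M) over the forget prompts; the step size, perturbation, and iteration count are illustrative rather than the paper's settings.

def optimize_sigma(metric_gap, sigma_init: float = 1.0,
                   lr: float = 0.1, eps: float = 0.05,
                   n_steps: int = 50) -> float:
    """Zeroth-order search for sigma* minimizing the metric gap d(...)."""
    sigma = sigma_init
    for _ in range(n_steps):
        # Symmetric finite-difference estimate of the derivative of the gap
        # w.r.t. sigma; only forward evaluations of the corrupted model are
        # needed, no backpropagation through the LLM.
        grad_est = (metric_gap(sigma + eps) - metric_gap(sigma - eps)) / (2 * eps)
        sigma = max(0.0, sigma - lr * grad_est)   # keep the corruption strength non-negative
    return sigma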


Unlearning Fictitious Authors

[Figure: TOFU unlearning results]
Model utility versus forget quality (p-value) on three forget set sizes of the TOFU dataset after unlearning, shown for two models: Phi-1.5 (top) and Llama-2-7B-Chat (bottom). For GA, GD, KL, PO, and the prompting baseline, forget quality is either too low or comes at the cost of a substantial decrease in model utility. Negative preference optimization (NPO) variants strike a good balance in some cases, but the trade-off in model utility is still non-trivial. ECO-RN (random noise) and ECO-ZO (zero-out) achieve a distribution almost identical to the retained model's while sacrificing no model utility.

Unlearning (Hazardous) Knowledge

[Table: WMDP and MMLU accuracy after unlearning]
Multiple-choice accuracy of five LLMs on the WMDP benchmark (forget) and the full MMLU (retain) after unlearning. ECO achieves accuracy close to random guessing on all subsets of the WMDP benchmark (as desired), with no decrease in accuracy on MMLU. Other baselines either struggle to forget or incur a substantial decrease in MMLU accuracy.
[Figure: MMLU subset unlearning]
Multiple-choice accuracy of Zephyr-7B after unlearning, on three MMLU subsets and the corresponding retain sets. ECO achieves both perfect retention and perfect unlearning on all subsets.
[Figure: model scale versus WMDP/MMLU performance]
Number of parameters of the model subject to unlearning versus average performance on the WMDP benchmark and MMLU subsets, across 100 LLMs. The figure visualizes forget-set accuracy as a function of model scale.

Unlearning Copyrighted Content

[Table: copyrighted content unlearning]
Comparison of our method and the baseline methods to the retained model on two copyrighted content unlearning tasks. The results are obtained by unlearning OLMo-7B models fine-tuned on the relevant corpus. ECO consistently maintains high similarity to the retained model (measured by the average similarity gap, ASG) and generates meaningful, diverse outputs (reflected by perplexity (PPL) and unique token ratio), while incurring no performance loss on utility.

BibTeX


        @article{liu2024eco,
          title={Large Language Model Unlearning via Embedding-Corrupted Prompts},
          author={Liu, Chris Yuhao and Wang, Yaxuan and Flanigan, Jeffrey and Liu, Yang},
          journal={arXiv preprint arXiv:2406.07933},
          year={2024}
        }