-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathactivation_patching.py
More file actions
41 lines (31 loc) · 1.63 KB
/
Copy pathactivation_patching.py
File metadata and controls
41 lines (31 loc) · 1.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
def activation_patching(
    source_prompt,
    target_prompt,
    source_patch,
    target_patch,
    model
):
    """Patch per-layer activations from a source prompt into a target prompt.

    The function performs two traced passes inside a single ``model.session``:

    1. Trace ``source_prompt`` and, for every entry of ``model.model.layers``,
       capture the layer output at index ``source_patch`` (batch element 0).
       The ``lm_head`` output for the source prompt is recorded as well.
    2. Trace ``target_prompt`` once unmodified (the "clean" run) and then
       once per layer; in the i-th of those invocations the output of layer
       ``i`` at index ``target_patch`` is overwritten with the activation
       captured from the same layer in step 1, and the resulting ``lm_head``
       output is recorded.

    Args:
        source_prompt: Prompt passed to ``tracer.invoke`` for the source run.
        target_prompt: Prompt passed to ``tracer.invoke`` for the clean and
            patched target runs.
        source_patch: Index applied to the second axis of each layer output
            when reading the source activation (presumably a token
            position -- confirm with callers).
        target_patch: Index applied to the second axis of each layer output
            when writing the patched activation (presumably a token
            position -- confirm with callers).
        model: Object exposing ``session``, ``trace``, ``dispatched``,
            ``model.layers`` and ``lm_head``. NOTE(review): this looks like
            an nnsight language-model wrapper -- confirm.

    Returns:
        A 3-tuple ``(source_logits, target_logits_clean,
        target_logits_patched)``. The first two are ``lm_head.output[0][-1]``
        for the respective prompt, detached and moved to CPU. The third is a
        dict keyed by the layer index as a string (``"0"``, ``"1"``, ...)
        holding the same quantity for each patched run.
    """
    # No gradients are needed; only forward activations and logits are read.
    with torch.no_grad():
        # A session wraps both traces so that values captured in the first
        # trace can be used in the second. ``remote`` is set to True exactly
        # when ``model.dispatched`` is falsy.
        with model.session(remote=not model.dispatched):
            # ---- Pass 1: collect source activations --------------------
            with model.trace() as tracer:
                source_activations = list()
                with tracer.invoke(source_prompt)as invoker:
                    for i, layer in enumerate(model.model.layers):
                        # A layer's output may be a tuple (its first element
                        # is used) or a bare tensor; handle both cases.
                        if isinstance(layer.output, tuple):
                            source_activations.append(layer.output[0][0, source_patch, :].clone().detach())
                        else:
                            source_activations.append(layer.output[0, source_patch, :].clone().detach())
                    # First batch element, last index along the next axis
                    # (assumes lm_head output is (batch, seq, vocab), i.e.
                    # the final-position logits -- TODO confirm).
                    source_logits = model.lm_head.output[0][-1].detach().cpu().save()
            # ---- Pass 2: clean run plus one patched run per layer ------
            with model.trace() as tracer:
                # Unmodified target run, used as the baseline.
                with tracer.invoke(target_prompt):
                    target_logits_clean = model.lm_head.output[0][-1].detach().cpu().save()
                # NOTE(review): ``.save()`` is called on a plain ``dict``;
                # presumably the tracing library makes ``.save()`` available
                # on arbitrary objects inside a trace -- confirm against the
                # installed version.
                target_logits_patched = dict().save()
                for i in range(len(model.model.layers)):
                    # Each layer gets its own invocation of the target
                    # prompt, so exactly one layer is patched per run.
                    with tracer.invoke(target_prompt):
                        # Overwrite the target index of this layer's output
                        # with the activation captured from the source run.
                        if isinstance(model.model.layers[i].output, tuple):
                            model.model.layers[i].output[0][0, target_patch, :] = source_activations[i]
                        else:
                            model.model.layers[i].output[0, target_patch, :] = source_activations[i]
                        # Keys are stringified layer indices.
                        target_logits_patched[str(i)] = model.lm_head.output[0][-1].detach().cpu()
    return source_logits, target_logits_clean, target_logits_patched