-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
65 lines (39 loc) · 1.8 KB
/
Copy pathutils.py
File metadata and controls
65 lines (39 loc) · 1.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from nnsight import LanguageModel
import nnsight
from rich.table import Table
from rich.console import Console
import torch
def load_model(model_name, remote=True):
    """Build an nnsight ``LanguageModel`` for *model_name*.

    Args:
        model_name: Model identifier passed straight to ``LanguageModel``.
        remote: When true, the model is expected to be served by NDIF and
            its availability is verified first. When false, the model is
            built with ``dispatch=True`` and NDIF is not consulted.

    Returns:
        The constructed ``LanguageModel`` instance.

    Raises:
        ValueError: If ``remote`` is true and NDIF does not report the
            model as running (the NDIF status is printed first).
    """
    if model_name == "openai-community/gpt2-xl":
        # GPT-2 XL is always built with dispatch=True, whatever `remote` is.
        # The rename map aliases "transformer" -> "model" and "h" -> "layers";
        # presumably so GPT-2 exposes the same module paths as the other
        # models used with these helpers -- confirm against callers.
        # NOTE(review): the message printed below still follows `remote`
        # for this branch, even though dispatch is unconditional here.
        model = LanguageModel(
            model_name,
            rename={"transformer": "model", "h": "layers"},
            device_map="auto",
            dispatch=True,
        )
    else:
        # Only remote execution depends on NDIF hosting the model; a local
        # load must not fail (or need the service) because of NDIF status.
        if remote and not nnsight.is_model_running(model_name):
            print(nnsight.ndif_status())
            raise ValueError(f"Model {model_name} is not currently available on NDIF")
        model = LanguageModel(model_name, device_map="auto", dispatch=not remote)

    print(f"Loading model {model_name} | {'Remote' if remote else 'Local'} Execution Mode")
    return model
def show_token_positions(str_tokens, title):
    """Print a headerless two-row table of tokens and their positions.

    The first row holds the token strings (rendered bold) and the second
    row holds each token's zero-based index, so a position can be read off
    directly beneath its token.

    Args:
        str_tokens: Sequence of token strings; must support ``len``.
        title: Title shown for the table.

    Returns:
        ``str_tokens``, unchanged.
    """
    position_labels = list(map(str, range(len(str_tokens))))

    grid = Table(title=title, show_header=False)
    for cells in (str_tokens, position_labels):
        grid.add_row(*cells)
    # Row 0 is the token row; make it stand out from the index row.
    grid.rows[0].style = "bold"

    Console().print(grid)
    return str_tokens
def show_patch_pattern(patch_position, str_tokens, title):
    """Print the token/position table with one column highlighted.

    Renders the same two rows as a plain token table (bold tokens on top,
    zero-based indices below) and gives the column at ``patch_position`` an
    ``"on purple"`` style.

    Args:
        patch_position: Index of the column to highlight; it is used to
            index the table's columns directly.
        str_tokens: Sequence of token strings; must support ``len``.
        title: Title shown for the table.
    """
    position_labels = list(map(str, range(len(str_tokens))))

    grid = Table(title=title, show_header=False)
    for cells in (str_tokens, position_labels):
        grid.add_row(*cells)
    grid.rows[0].style = "bold"
    # Highlight through the column style rather than per-cell markup, so the
    # token and its index are both marked.
    grid.columns[patch_position].style = "on purple"

    Console().print(grid)
def tokenize_prompt(prompt, model, show=False, title=None):
    """Split *prompt* into per-token strings using the model's tokenizer.

    The prompt is encoded with ``model.tokenizer.encode`` and every
    resulting id is decoded on its own, giving one string per token.

    Args:
        prompt: Text to tokenize.
        model: Object exposing a ``tokenizer`` with ``encode`` and ``decode``.
        show: When true, also print the tokens via ``show_token_positions``.
        title: Table title used when ``show`` is true.

    Returns:
        List of decoded token strings, in encoding order.
    """
    tokenizer = model.tokenizer
    token_ids = tokenizer.encode(prompt)
    str_tokens = [tokenizer.decode(token_id) for token_id in token_ids]

    if not show:
        return str_tokens

    show_token_positions(str_tokens, title)
    return str_tokens
def get_token_id(token_string, model):
    """Return the id of the first token that *token_string* encodes to.

    The string is encoded with ``add_special_tokens=False``. Only the first
    id is returned, so any further tokens produced by a multi-token string
    are ignored. An empty encoding raises ``IndexError``.

    Args:
        token_string: Text expected to correspond to a single token.
        model: Object exposing a ``tokenizer`` with an ``encode`` method.

    Returns:
        The first element of the encoded id sequence.
    """
    encoded_ids = model.tokenizer.encode(token_string, add_special_tokens=False)
    first_id = encoded_ids[0]
    return first_id