AB498 commited on
Commit
8a48ded
·
1 Parent(s): 1a8240b
Files changed (2) hide show
  1. README.md +10 -1
  2. app.py +41 -21
README.md CHANGED
@@ -14,4 +14,13 @@ license: mit
14
  short_description: 'codebert-base-mlm: a fill-in-middle/masked language model'
15
  ---
16
 
17
- An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
 
 
 
 
 
 
 
 
 
 
14
  short_description: 'codebert-base-mlm: a fill-in-middle/masked language model'
15
  ---
16
 
17
+ A CodeBERT Masked Language Model demo using [Gradio](https://gradio.app) and [Transformers](https://huggingface.co/docs/transformers). This app predicts masked tokens in code snippets.
18
+
19
+ ## Usage
20
+
21
+ Enter code with `<mask>` tokens where you want predictions:
22
+ - `def <mask>(x, y): return x + y`
23
+ - `import <mask>`
24
+ - `for i in <mask>(10):`
25
+
26
+ The model will suggest the most likely tokens to fill in the mask.
app.py CHANGED
@@ -7,43 +7,65 @@ model_name = "microsoft/codebert-base-mlm"
7
  tokenizer = RobertaTokenizer.from_pretrained(model_name)
8
  model = RobertaForMaskedLM.from_pretrained(model_name)
9
 
10
- def predict_masked_code(code_with_mask, top_k=5):
11
  """
12
  Predict the masked token in code.
13
  Use <mask> to indicate where to predict.
 
 
 
 
 
 
 
14
  """
15
  try:
16
  # Replace <mask> with the tokenizer's mask token
17
- code_with_mask = code_with_mask.replace("<mask>", tokenizer.mask_token)
18
 
19
  # Tokenize input
20
- inputs = tokenizer(code_with_mask, return_tensors="pt")
21
 
22
  # Find the position of the mask token
23
  mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
24
 
25
  if len(mask_token_index) == 0:
26
- return "Error: No <mask> token found in the input. Please include <mask> where you want predictions."
 
 
 
27
 
28
  # Get predictions
29
  with torch.no_grad():
30
  outputs = model(**inputs)
31
- predictions = outputs.logits
32
 
33
  # Get top-k predictions for the mask token
34
- mask_token_logits = predictions[0, mask_token_index, :]
35
- top_tokens = torch.topk(mask_token_logits, top_k, dim=1)
36
 
37
- results = []
38
- for i, (token_id, score) in enumerate(zip(top_tokens.indices[0].tolist(), top_tokens.values[0].tolist())):
39
  predicted_token = tokenizer.decode([token_id])
40
- filled_code = code_with_mask.replace(tokenizer.mask_token, predicted_token)
41
- results.append(f"{i+1}. {predicted_token} (score: {score:.2f})\n Code: {filled_code}")
 
 
 
 
 
 
42
 
43
- return "\n\n".join(results)
 
 
 
44
 
45
  except Exception as e:
46
- return f"Error: {str(e)}"
 
 
 
47
 
48
  # Create Gradio interface
49
  with gr.Blocks(title="CodeBERT Masked Language Model") as demo:
@@ -69,7 +91,7 @@ with gr.Blocks(title="CodeBERT Masked Language Model") as demo:
69
  lines=5,
70
  value="def <mask>(x, y):\n return x + y"
71
  )
72
- top_k_slider = gr.Slider(
73
  minimum=1,
74
  maximum=10,
75
  value=5,
@@ -79,10 +101,8 @@ with gr.Blocks(title="CodeBERT Masked Language Model") as demo:
79
  predict_btn = gr.Button("Predict", variant="primary")
80
 
81
  with gr.Column():
82
- output = gr.Textbox(
83
- label="Predictions",
84
- lines=15,
85
- interactive=False
86
  )
87
 
88
  # Examples
@@ -95,12 +115,12 @@ with gr.Blocks(title="CodeBERT Masked Language Model") as demo:
95
  ["if x <mask> 0:", 5],
96
  ["class <mask>:", 5],
97
  ],
98
- inputs=[code_input, top_k_slider],
99
  )
100
 
101
  predict_btn.click(
102
- fn=predict_masked_code,
103
- inputs=[code_input, top_k_slider],
104
  outputs=output
105
  )
106
 
 
7
  tokenizer = RobertaTokenizer.from_pretrained(model_name)
8
  model = RobertaForMaskedLM.from_pretrained(model_name)
9
 
10
def predict(code, num_predictions=5):
    """
    Predict masked tokens in a code snippet with CodeBERT.

    Use <mask> in *code* to indicate where predictions are wanted.

    Args:
        code: Code snippet containing one or more <mask> tokens.
        num_predictions: Number of top candidates to return per rank.

    Returns:
        A dict (rendered as JSON by the Gradio UI): either an "error" key
        with an empty "predictions" list, or "original_code" plus a
        "predictions" list whose entries carry "rank", "token", "score",
        and "completed_code".
    """
    try:
        # Gradio sliders may deliver floats; torch.topk requires an int k.
        num_predictions = int(num_predictions)

        # Replace the user-facing <mask> marker with the tokenizer's own mask token.
        code_input = code.replace("<mask>", tokenizer.mask_token)

        # Tokenize input
        inputs = tokenizer(code_input, return_tensors="pt")

        # Sequence positions of every mask token in the input.
        mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]

        if len(mask_token_index) == 0:
            return {
                "error": "No <mask> token found in the input. Please include <mask> where you want predictions.",
                "predictions": []
            }

        # Get predictions
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits

        # (num_masks, vocab_size) logits restricted to the mask positions.
        mask_token_logits = logits[0, mask_token_index, :]
        top_tokens = torch.topk(mask_token_logits, num_predictions, dim=1)

        num_masks = mask_token_logits.shape[0]
        predictions = []
        for i in range(num_predictions):
            # BUG FIX: the previous version read only the first mask's
            # candidates (indices[0]) and str.replace() pasted that same
            # token into every mask. Give each mask its own rank-i candidate.
            tokens = [tokenizer.decode([top_tokens.indices[m][i].item()])
                      for m in range(num_masks)]
            scores = [float(top_tokens.values[m][i]) for m in range(num_masks)]

            # Fill masks left to right, one occurrence per candidate.
            completed_code = code_input
            for token in tokens:
                completed_code = completed_code.replace(tokenizer.mask_token, token, 1)

            predictions.append({
                "rank": i + 1,
                # Single-mask output is unchanged; multiple masks are comma-joined.
                "token": tokens[0] if num_masks == 1 else ", ".join(tokens),
                # Mean logit across masks (identical to the raw score for one mask).
                "score": round(sum(scores) / num_masks, 4),
                "completed_code": completed_code
            })

        return {
            "original_code": code,
            "predictions": predictions
        }

    except Exception as e:
        # Boundary handler: surface any failure to the UI instead of crashing.
        return {
            "error": str(e),
            "predictions": []
        }
69
 
70
  # Create Gradio interface
71
  with gr.Blocks(title="CodeBERT Masked Language Model") as demo:
 
91
  lines=5,
92
  value="def <mask>(x, y):\n return x + y"
93
  )
94
+ num_predictions_slider = gr.Slider(
95
  minimum=1,
96
  maximum=10,
97
  value=5,
 
101
  predict_btn = gr.Button("Predict", variant="primary")
102
 
103
  with gr.Column():
104
+ output = gr.JSON(
105
+ label="Predictions"
 
 
106
  )
107
 
108
  # Examples
 
115
  ["if x <mask> 0:", 5],
116
  ["class <mask>:", 5],
117
  ],
118
+ inputs=[code_input, num_predictions_slider],
119
  )
120
 
121
  predict_btn.click(
122
+ fn=predict,
123
+ inputs=[code_input, num_predictions_slider],
124
  outputs=output
125
  )
126