openfree commited on
Commit
ea36a1d
·
verified ·
1 Parent(s): dcfd1fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +363 -177
app.py CHANGED
@@ -1,6 +1,6 @@
1
  """
2
- DNA-Diffusion Gradio Application - Fixed for Hugging Face Spaces
3
- Properly implements @spaces.GPU decorator for ZeroGPU environment
4
  """
5
 
6
  import gradio as gr
@@ -33,7 +33,7 @@ except ImportError:
33
  return func
34
  return decorator
35
 
36
- # Genetic code table
37
  CODON_TABLE = {
38
  'TTT': 'F', 'TTC': 'F', 'TTA': 'L', 'TTG': 'L',
39
  'TCT': 'S', 'TCC': 'S', 'TCA': 'S', 'TCG': 'S',
@@ -53,28 +53,21 @@ CODON_TABLE = {
53
  'GGT': 'G', 'GGC': 'G', 'GGA': 'G', 'GGG': 'G'
54
  }
55
 
56
- # Mock DNA model for demonstration
57
  class DNAModel:
58
  """Mock DNA generation model"""
59
  def generate(self, cell_type: str, num_sequences: int = 1, length: int = 200, guidance_scale: float = 1.0):
60
- """Generate DNA sequences with cell-type specific patterns"""
61
  sequences = []
62
 
63
  for _ in range(num_sequences):
64
- # Base generation
65
  sequence = ''.join(random.choice(['A', 'T', 'C', 'G']) for _ in range(length))
66
 
67
- # Add cell-type specific patterns
68
  if cell_type == "K562":
69
- # Add GC-rich regions for K562
70
  for i in range(0, length-8, 50):
71
  sequence = sequence[:i] + 'GCGCGCGC' + sequence[i+8:]
72
  elif cell_type == "GM12878":
73
- # Add AT-rich regions for GM12878
74
  for i in range(10, length-8, 60):
75
  sequence = sequence[:i] + 'ATATATAT' + sequence[i+8:]
76
  elif cell_type == "HepG2":
77
- # Add mixed patterns for HepG2
78
  for i in range(20, length-12, 70):
79
  sequence = sequence[:i] + 'GCGATCGATCGC' + sequence[i+12:]
80
 
@@ -82,13 +75,9 @@ class DNAModel:
82
 
83
  return sequences[0] if num_sequences == 1 else sequences
84
 
85
- # Initialize model
86
  model = DNAModel()
87
 
88
- # Main application class
89
  class DNADiffusionApp:
90
- """Main application with GPU-accelerated functions"""
91
-
92
  def __init__(self):
93
  self.current_sequence = ""
94
  self.current_analysis = {}
@@ -96,17 +85,12 @@ class DNADiffusionApp:
96
 
97
  @spaces.GPU(duration=60)
98
  def generate_with_gpu(self, cell_type: str, guidance_scale: float = 1.0, use_enhanced: bool = True):
99
- """GPU-accelerated sequence generation"""
100
  logger.info(f"Generating sequence on GPU for cell type: {cell_type}")
101
 
102
  try:
103
- # Simulate GPU computation time
104
- time.sleep(2) # Simulated GPU processing
105
-
106
- # Generate sequence
107
  sequence = model.generate(cell_type, length=200, guidance_scale=guidance_scale)
108
 
109
- # If enhanced mode, do additional processing
110
  if use_enhanced:
111
  sequence = self.enhance_sequence(sequence, cell_type)
112
 
@@ -120,16 +104,13 @@ class DNADiffusionApp:
120
  raise
121
 
122
  def enhance_sequence(self, sequence: str, cell_type: str) -> str:
123
- """Enhance sequence with additional patterns"""
124
- # Add enhancer sequences based on cell type
125
  enhancers = {
126
- "K562": "GGGACTTTCC", # NF-κB binding site
127
- "GM12878": "TGACGTCA", # CREB binding site
128
- "HepG2": "TGTTGGTGG" # HNF4 binding site
129
  }
130
 
131
  if cell_type in enhancers:
132
- # Insert enhancer at a reasonable position
133
  pos = len(sequence) // 4
134
  enhancer = enhancers[cell_type]
135
  sequence = sequence[:pos] + enhancer + sequence[pos+len(enhancer):]
@@ -137,21 +118,17 @@ class DNADiffusionApp:
137
  return sequence
138
 
139
  def analyze_sequence(self, sequence: str) -> Dict[str, Any]:
140
- """Analyze DNA sequence properties"""
141
  if not sequence:
142
  return {}
143
 
144
- # GC content
145
  gc_count = sequence.count('G') + sequence.count('C')
146
  gc_content = (gc_count / len(sequence)) * 100
147
 
148
- # Melting temperature
149
  if len(sequence) < 14:
150
  tm = 4 * (sequence.count('G') + sequence.count('C')) + 2 * (sequence.count('A') + sequence.count('T'))
151
  else:
152
  tm = 81.5 + 0.41 * gc_content - 675 / len(sequence)
153
 
154
- # Find restriction sites
155
  restriction_sites = {}
156
  enzymes = {
157
  'EcoRI': 'GAATTC',
@@ -169,7 +146,6 @@ class DNADiffusionApp:
169
  if positions:
170
  restriction_sites[enzyme] = positions
171
 
172
- # Translate to protein
173
  protein = self.translate_to_protein(sequence)
174
 
175
  return {
@@ -182,53 +158,311 @@ class DNADiffusionApp:
182
  }
183
 
184
  def translate_to_protein(self, dna_sequence: str) -> str:
185
- """Translate DNA to protein sequence"""
186
  protein = []
187
 
188
- # Find start codon
189
  start_pos = dna_sequence.find('ATG')
190
  if start_pos == -1:
191
  start_pos = 0
192
 
193
- # Translate from start position
194
  for i in range(start_pos, len(dna_sequence) - 2, 3):
195
  codon = dna_sequence[i:i+3]
196
  if len(codon) == 3:
197
  amino_acid = CODON_TABLE.get(codon, 'X')
198
- if amino_acid == '*': # Stop codon
199
  break
200
  protein.append(amino_acid)
201
 
202
  return ''.join(protein)
203
-
204
- @spaces.GPU(duration=30)
205
- def batch_generate(self, cell_types: List[str], count: int = 5):
206
- """Generate multiple sequences with GPU acceleration"""
207
- logger.info(f"Batch generating {count} sequences")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
 
209
- results = []
210
- for i in range(count):
211
- cell_type = cell_types[i % len(cell_types)]
212
- sequence = model.generate(cell_type)
213
- analysis = self.analyze_sequence(sequence)
214
-
215
- results.append({
216
- 'id': f'SEQ_{i+1:03d}',
217
- 'cell_type': cell_type,
218
- 'sequence': sequence,
219
- 'analysis': analysis
220
- })
221
 
222
- return results
 
 
 
 
223
 
224
- # Create app instance
225
- app = DNADiffusionApp()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
- # Create the Gradio interface
228
  def create_demo():
229
- """Create the Gradio interface"""
230
 
231
- # Custom CSS
232
  css = """
233
  .gradio-container {
234
  font-family: 'Arial', sans-serif;
@@ -239,27 +473,28 @@ def create_demo():
239
  padding: 10px;
240
  border-radius: 5px;
241
  }
 
 
 
 
242
  """
243
 
244
- with gr.Blocks(css=css, title="DNA-Diffusion") as demo:
245
- # Header
246
  gr.Markdown(
247
  """
248
- # 🧬 DNA-Diffusion: AI-Powered Sequence Generation
249
 
250
- Generate cell-type specific DNA sequences using advanced AI models.
251
- This app uses GPU acceleration for optimal performance.
252
  """
253
  )
254
 
255
- # GPU status indicator
256
  gpu_status = gr.Markdown(
257
  f"🖥️ **GPU Status**: {'✅ Available' if SPACES_AVAILABLE else '❌ Not Available (CPU mode)'}"
258
  )
259
 
260
  with gr.Tabs():
261
- # Tab 1: Single Sequence Generation
262
- with gr.TabItem("🎯 Single Generation"):
263
  with gr.Row():
264
  with gr.Column(scale=1):
265
  cell_type = gr.Dropdown(
@@ -301,18 +536,27 @@ def create_demo():
301
  analysis_output = gr.JSON(
302
  label="Sequence Analysis"
303
  )
304
-
305
- # Examples
306
- gr.Examples(
307
- examples=[
308
- ["K562", 1.0, True],
309
- ["GM12878", 5.0, True],
310
- ["HepG2", 3.0, False]
311
- ],
312
- inputs=[cell_type, guidance_scale, enhanced_mode]
313
- )
314
 
315
- # Tab 2: Batch Generation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
316
  with gr.TabItem("📦 Batch Generation"):
317
  with gr.Row():
318
  with gr.Column():
@@ -340,28 +584,7 @@ def create_demo():
340
  headers=["ID", "Cell Type", "Length", "GC%", "Tm(°C)"],
341
  label="Batch Results"
342
  )
343
-
344
- download_btn = gr.Button("💾 Download Results")
345
-
346
- # Tab 3: Analysis Tools
347
- with gr.TabItem("🔬 Analysis Tools"):
348
- gr.Markdown("### Paste or generate a sequence to analyze")
349
-
350
- analysis_input = gr.Textbox(
351
- label="DNA Sequence",
352
- placeholder="Paste your DNA sequence here...",
353
- lines=3
354
- )
355
-
356
- analyze_btn = gr.Button("🔍 Analyze")
357
-
358
- with gr.Row():
359
- gc_plot = gr.Plot(label="GC Content Distribution")
360
- restriction_plot = gr.Plot(label="Restriction Sites")
361
-
362
- detailed_analysis = gr.JSON(label="Detailed Analysis")
363
 
364
- # Status bar
365
  status_text = gr.Textbox(
366
  label="Status",
367
  value="Ready to generate sequences...",
@@ -370,14 +593,10 @@ def create_demo():
370
 
371
  # Event handlers
372
  def generate_single(cell_type, guidance_scale, enhanced):
373
- """Handle single sequence generation"""
374
  try:
375
  status_text.value = "🔄 Generating sequence on GPU..."
376
 
377
- # Generate sequence using GPU
378
  sequence = app.generate_with_gpu(cell_type, guidance_scale, enhanced)
379
-
380
- # Analyze the sequence
381
  analysis = app.analyze_sequence(sequence)
382
 
383
  status_text.value = f"✅ Successfully generated sequence for {cell_type}"
@@ -388,76 +607,49 @@ def create_demo():
388
  logger.error(error_msg)
389
  return "", {}, error_msg
390
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
391
  def generate_batch(cell_types, count):
392
- """Handle batch generation"""
393
  if not cell_types:
394
  return None, "❌ Please select at least one cell type"
395
 
396
  try:
397
  status_text.value = "🔄 Generating batch on GPU..."
398
 
399
- # Generate batch
400
- results = app.batch_generate(cell_types, count)
401
-
402
- # Format for dataframe
403
- df_data = []
404
- for r in results:
405
- df_data.append([
406
- r['id'],
407
- r['cell_type'],
408
- r['analysis']['length'],
409
- r['analysis']['gc_content'],
410
- r['analysis']['melting_temp']
411
  ])
412
 
413
  status_text.value = f"✅ Generated {len(results)} sequences"
414
- return df_data, status_text.value
415
 
416
  except Exception as e:
417
  error_msg = f"❌ Error: {str(e)}"
418
  logger.error(error_msg)
419
  return None, error_msg
420
 
421
- def analyze_sequence(sequence):
422
- """Analyze a given sequence"""
423
- if not sequence:
424
- return None, None, {}
425
-
426
- analysis = app.analyze_sequence(sequence.upper())
427
-
428
- # Create GC content plot
429
- import matplotlib.pyplot as plt
430
-
431
- fig1, ax1 = plt.subplots(figsize=(6, 4))
432
- gc = analysis['gc_content']
433
- at = 100 - gc
434
- ax1.bar(['GC%', 'AT%'], [gc, at], color=['#4CAF50', '#2196F3'])
435
- ax1.set_ylabel('Percentage')
436
- ax1.set_title('Nucleotide Composition')
437
-
438
- # Create restriction site plot
439
- fig2, ax2 = plt.subplots(figsize=(8, 3))
440
- if analysis['restriction_sites']:
441
- y_pos = 0
442
- colors = plt.cm.tab10(range(len(analysis['restriction_sites'])))
443
-
444
- for (enzyme, positions), color in zip(analysis['restriction_sites'].items(), colors):
445
- for pos in positions:
446
- ax2.vlines(pos, y_pos - 0.4, y_pos + 0.4, color=color, linewidth=3)
447
- ax2.text(-10, y_pos, enzyme, ha='right', va='center')
448
- y_pos += 1
449
-
450
- ax2.set_xlim(-50, len(sequence))
451
- ax2.set_ylim(-0.5, len(analysis['restriction_sites']) - 0.5)
452
- ax2.set_xlabel('Position (bp)')
453
- ax2.set_title('Restriction Enzyme Sites')
454
- ax2.grid(axis='x', alpha=0.3)
455
- else:
456
- ax2.text(0.5, 0.5, 'No restriction sites found',
457
- ha='center', va='center', transform=ax2.transAxes)
458
-
459
- return fig1, fig2, analysis
460
-
461
  # Connect event handlers
462
  generate_btn.click(
463
  fn=generate_single,
@@ -465,50 +657,44 @@ def create_demo():
465
  outputs=[sequence_output, analysis_output, status_text]
466
  )
467
 
468
- batch_generate_btn.click(
469
- fn=generate_batch,
470
- inputs=[batch_cell_types, batch_count],
471
- outputs=[batch_output, status_text]
472
- )
473
-
474
- analyze_btn.click(
475
- fn=analyze_sequence,
476
- inputs=[analysis_input],
477
- outputs=[gc_plot, restriction_plot, detailed_analysis]
478
  )
479
 
480
- # Auto-analyze generated sequences
481
  sequence_output.change(
482
- fn=lambda seq: app.analyze_sequence(seq) if seq else {},
483
  inputs=[sequence_output],
484
- outputs=[analysis_output]
 
 
 
 
 
 
485
  )
486
 
487
  return demo
488
 
489
- # Main launch function
490
  if __name__ == "__main__":
491
- # Print startup info
492
  logger.info("=" * 50)
493
- logger.info("DNA-Diffusion App Starting...")
494
  logger.info(f"GPU Available: {SPACES_AVAILABLE}")
495
  logger.info(f"Environment: {'Hugging Face Spaces' if os.getenv('SPACE_ID') else 'Local'}")
496
  logger.info("=" * 50)
497
 
498
- # Create and launch the demo
499
  demo = create_demo()
500
 
501
- # Launch with appropriate settings
502
  if os.getenv("SPACE_ID"):
503
- # Running on Hugging Face Spaces
504
  demo.launch(
505
  server_name="0.0.0.0",
506
  server_port=7860,
507
- share=False, # Don't set share=True on Spaces
508
  show_error=True
509
  )
510
  else:
511
- # Local development
512
  demo.launch(
513
  share=True,
514
  show_error=True,
 
1
  """
2
+ DNA-Diffusion Gradio Application with Integrated 3D Viewer
3
+ Combines Gradio interface with HTML-based 3D molecular visualization
4
  """
5
 
6
  import gradio as gr
 
33
  return func
34
  return decorator
35
 
36
+ # DNA Model and genetic code table (same as before)
37
  CODON_TABLE = {
38
  'TTT': 'F', 'TTC': 'F', 'TTA': 'L', 'TTG': 'L',
39
  'TCT': 'S', 'TCC': 'S', 'TCA': 'S', 'TCG': 'S',
 
53
  'GGT': 'G', 'GGC': 'G', 'GGA': 'G', 'GGG': 'G'
54
  }
55
 
 
56
  class DNAModel:
57
  """Mock DNA generation model"""
58
  def generate(self, cell_type: str, num_sequences: int = 1, length: int = 200, guidance_scale: float = 1.0):
 
59
  sequences = []
60
 
61
  for _ in range(num_sequences):
 
62
  sequence = ''.join(random.choice(['A', 'T', 'C', 'G']) for _ in range(length))
63
 
 
64
  if cell_type == "K562":
 
65
  for i in range(0, length-8, 50):
66
  sequence = sequence[:i] + 'GCGCGCGC' + sequence[i+8:]
67
  elif cell_type == "GM12878":
 
68
  for i in range(10, length-8, 60):
69
  sequence = sequence[:i] + 'ATATATAT' + sequence[i+8:]
70
  elif cell_type == "HepG2":
 
71
  for i in range(20, length-12, 70):
72
  sequence = sequence[:i] + 'GCGATCGATCGC' + sequence[i+12:]
73
 
 
75
 
76
  return sequences[0] if num_sequences == 1 else sequences
77
 
 
78
  model = DNAModel()
79
 
 
80
  class DNADiffusionApp:
 
 
81
  def __init__(self):
82
  self.current_sequence = ""
83
  self.current_analysis = {}
 
85
 
86
  @spaces.GPU(duration=60)
87
  def generate_with_gpu(self, cell_type: str, guidance_scale: float = 1.0, use_enhanced: bool = True):
 
88
  logger.info(f"Generating sequence on GPU for cell type: {cell_type}")
89
 
90
  try:
91
+ time.sleep(2)
 
 
 
92
  sequence = model.generate(cell_type, length=200, guidance_scale=guidance_scale)
93
 
 
94
  if use_enhanced:
95
  sequence = self.enhance_sequence(sequence, cell_type)
96
 
 
104
  raise
105
 
106
  def enhance_sequence(self, sequence: str, cell_type: str) -> str:
 
 
107
  enhancers = {
108
+ "K562": "GGGACTTTCC",
109
+ "GM12878": "TGACGTCA",
110
+ "HepG2": "TGTTGGTGG"
111
  }
112
 
113
  if cell_type in enhancers:
 
114
  pos = len(sequence) // 4
115
  enhancer = enhancers[cell_type]
116
  sequence = sequence[:pos] + enhancer + sequence[pos+len(enhancer):]
 
118
  return sequence
119
 
120
  def analyze_sequence(self, sequence: str) -> Dict[str, Any]:
 
121
  if not sequence:
122
  return {}
123
 
 
124
  gc_count = sequence.count('G') + sequence.count('C')
125
  gc_content = (gc_count / len(sequence)) * 100
126
 
 
127
  if len(sequence) < 14:
128
  tm = 4 * (sequence.count('G') + sequence.count('C')) + 2 * (sequence.count('A') + sequence.count('T'))
129
  else:
130
  tm = 81.5 + 0.41 * gc_content - 675 / len(sequence)
131
 
 
132
  restriction_sites = {}
133
  enzymes = {
134
  'EcoRI': 'GAATTC',
 
146
  if positions:
147
  restriction_sites[enzyme] = positions
148
 
 
149
  protein = self.translate_to_protein(sequence)
150
 
151
  return {
 
158
  }
159
 
160
  def translate_to_protein(self, dna_sequence: str) -> str:
 
161
  protein = []
162
 
 
163
  start_pos = dna_sequence.find('ATG')
164
  if start_pos == -1:
165
  start_pos = 0
166
 
 
167
  for i in range(start_pos, len(dna_sequence) - 2, 3):
168
  codon = dna_sequence[i:i+3]
169
  if len(codon) == 3:
170
  amino_acid = CODON_TABLE.get(codon, 'X')
171
+ if amino_acid == '*':
172
  break
173
  protein.append(amino_acid)
174
 
175
  return ''.join(protein)
176
+
177
+ app = DNADiffusionApp()
178
+
179
+ # HTML for 3D Viewer
180
+ HTML_3D_VIEWER = """
181
+ <!DOCTYPE html>
182
+ <html>
183
+ <head>
184
+ <meta charset="UTF-8">
185
+ <style>
186
+ body {
187
+ margin: 0;
188
+ padding: 0;
189
+ background: #000;
190
+ font-family: Arial, sans-serif;
191
+ color: #fff;
192
+ height: 100vh;
193
+ overflow: hidden;
194
+ }
195
+ #viewer-container {
196
+ width: 100%;
197
+ height: 100%;
198
+ position: relative;
199
+ }
200
+ #canvas3d {
201
+ width: 100%;
202
+ height: 100%;
203
+ }
204
+ .controls-panel {
205
+ position: absolute;
206
+ top: 20px;
207
+ right: 20px;
208
+ background: rgba(0,0,0,0.8);
209
+ padding: 20px;
210
+ border-radius: 10px;
211
+ border: 2px solid #00ff88;
212
+ max-width: 300px;
213
+ }
214
+ .controls-panel h3 {
215
+ color: #00ff88;
216
+ margin-top: 0;
217
+ }
218
+ .control-btn {
219
+ background: #00ff88;
220
+ color: #000;
221
+ border: none;
222
+ padding: 8px 16px;
223
+ margin: 5px;
224
+ border-radius: 5px;
225
+ cursor: pointer;
226
+ font-weight: bold;
227
+ }
228
+ .control-btn:hover {
229
+ background: #00cc66;
230
+ }
231
+ .info-display {
232
+ position: absolute;
233
+ bottom: 20px;
234
+ left: 20px;
235
+ background: rgba(0,0,0,0.8);
236
+ padding: 15px;
237
+ border-radius: 8px;
238
+ border: 1px solid #0088ff;
239
+ }
240
+ #sequence-display {
241
+ font-family: monospace;
242
+ color: #00ff88;
243
+ word-break: break-all;
244
+ margin-top: 10px;
245
+ }
246
+ </style>
247
+ </head>
248
+ <body>
249
+ <div id="viewer-container">
250
+ <canvas id="canvas3d"></canvas>
251
 
252
+ <div class="controls-panel">
253
+ <h3>3D View Controls</h3>
254
+ <button class="control-btn" onclick="setViewMode('cartoon')">Cartoon</button>
255
+ <button class="control-btn" onclick="setViewMode('stick')">Stick</button>
256
+ <button class="control-btn" onclick="setViewMode('sphere')">Sphere</button>
257
+ <button class="control-btn" onclick="toggleRotation()">Toggle Rotation</button>
258
+ <button class="control-btn" onclick="resetView()">Reset View</button>
259
+ </div>
 
 
 
 
260
 
261
+ <div class="info-display">
262
+ <strong>Current Sequence:</strong>
263
+ <div id="sequence-display">No sequence loaded</div>
264
+ </div>
265
+ </div>
266
 
267
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/three.js/r128/three.min.js"></script>
268
+ <script>
269
+ let scene, camera, renderer;
270
+ let moleculeGroup;
271
+ let currentSequence = '';
272
+ let autoRotate = true;
273
+ let viewMode = 'cartoon';
274
+
275
+ function init() {
276
+ scene = new THREE.Scene();
277
+ scene.background = new THREE.Color(0x000000);
278
+
279
+ camera = new THREE.PerspectiveCamera(75, window.innerWidth / window.innerHeight, 0.1, 1000);
280
+ camera.position.z = 50;
281
+
282
+ renderer = new THREE.WebGLRenderer({ canvas: document.getElementById('canvas3d'), antialias: true });
283
+ renderer.setSize(window.innerWidth, window.innerHeight);
284
+
285
+ const ambientLight = new THREE.AmbientLight(0x404040, 1.5);
286
+ scene.add(ambientLight);
287
+
288
+ const directionalLight = new THREE.DirectionalLight(0xffffff, 1);
289
+ directionalLight.position.set(50, 50, 50);
290
+ scene.add(directionalLight);
291
+
292
+ moleculeGroup = new THREE.Group();
293
+ scene.add(moleculeGroup);
294
+
295
+ animate();
296
+
297
+ // Listen for sequence updates from parent
298
+ window.addEventListener('message', function(e) {
299
+ if (e.data.type === 'updateSequence') {
300
+ updateSequence(e.data.sequence);
301
+ }
302
+ });
303
+ }
304
+
305
+ function updateSequence(sequence) {
306
+ currentSequence = sequence;
307
+ document.getElementById('sequence-display').textContent = sequence;
308
+ generateDNAStructure(sequence);
309
+ }
310
+
311
+ function generateDNAStructure(sequence) {
312
+ moleculeGroup.clear();
313
+
314
+ const radius = 10;
315
+ const rise = 3.4;
316
+ const basesPerTurn = 10;
317
+ const anglePerBase = (2 * Math.PI) / basesPerTurn;
318
+
319
+ // Create double helix
320
+ const curve1Points = [];
321
+ const curve2Points = [];
322
+
323
+ for (let i = 0; i < sequence.length; i++) {
324
+ const angle = i * anglePerBase;
325
+ const height = i * rise / basesPerTurn;
326
+
327
+ curve1Points.push(new THREE.Vector3(
328
+ radius * Math.cos(angle),
329
+ height,
330
+ radius * Math.sin(angle)
331
+ ));
332
+
333
+ curve2Points.push(new THREE.Vector3(
334
+ radius * Math.cos(angle + Math.PI),
335
+ height,
336
+ radius * Math.sin(angle + Math.PI)
337
+ ));
338
+ }
339
+
340
+ // Create backbone curves
341
+ const curve1 = new THREE.CatmullRomCurve3(curve1Points);
342
+ const curve2 = new THREE.CatmullRomCurve3(curve2Points);
343
+
344
+ const tubeGeometry1 = new THREE.TubeGeometry(curve1, 100, 0.5, 8, false);
345
+ const tubeGeometry2 = new THREE.TubeGeometry(curve2, 100, 0.5, 8, false);
346
+
347
+ const material1 = new THREE.MeshPhongMaterial({ color: 0xff0000 });
348
+ const material2 = new THREE.MeshPhongMaterial({ color: 0x0000ff });
349
+
350
+ moleculeGroup.add(new THREE.Mesh(tubeGeometry1, material1));
351
+ moleculeGroup.add(new THREE.Mesh(tubeGeometry2, material2));
352
+
353
+ // Add base pairs
354
+ for (let i = 0; i < Math.min(sequence.length, 50); i++) {
355
+ const p1 = curve1Points[i];
356
+ const p2 = curve2Points[i];
357
+
358
+ const direction = new THREE.Vector3().subVectors(p2, p1);
359
+ const distance = direction.length();
360
+ direction.normalize();
361
+
362
+ const geometry = new THREE.CylinderGeometry(0.3, 0.3, distance, 8);
363
+ const material = new THREE.MeshPhongMaterial({
364
+ color: getBaseColor(sequence[i])
365
+ });
366
+
367
+ const cylinder = new THREE.Mesh(geometry, material);
368
+ cylinder.position.copy(p1).add(direction.multiplyScalar(distance / 2));
369
+ cylinder.quaternion.setFromUnitVectors(new THREE.Vector3(0, 1, 0), direction);
370
+
371
+ moleculeGroup.add(cylinder);
372
+ }
373
+
374
+ // Center the molecule
375
+ const box = new THREE.Box3().setFromObject(moleculeGroup);
376
+ const center = box.getCenter(new THREE.Vector3());
377
+ moleculeGroup.position.sub(center);
378
+ }
379
+
380
+ function getBaseColor(base) {
381
+ const colors = {
382
+ 'A': 0xff0000,
383
+ 'T': 0x00ff00,
384
+ 'G': 0x0000ff,
385
+ 'C': 0xffff00
386
+ };
387
+ return colors[base] || 0xffffff;
388
+ }
389
+
390
+ function setViewMode(mode) {
391
+ viewMode = mode;
392
+ if (currentSequence) {
393
+ generateDNAStructure(currentSequence);
394
+ }
395
+ }
396
+
397
+ function toggleRotation() {
398
+ autoRotate = !autoRotate;
399
+ }
400
+
401
+ function resetView() {
402
+ camera.position.set(0, 0, 50);
403
+ moleculeGroup.rotation.set(0, 0, 0);
404
+ }
405
+
406
+ function animate() {
407
+ requestAnimationFrame(animate);
408
+
409
+ if (autoRotate) {
410
+ moleculeGroup.rotation.y += 0.01;
411
+ }
412
+
413
+ renderer.render(scene, camera);
414
+ }
415
+
416
+ // Mouse controls
417
+ let isDragging = false;
418
+ let previousMousePosition = { x: 0, y: 0 };
419
+
420
+ document.addEventListener('mousedown', function(e) {
421
+ isDragging = true;
422
+ });
423
+
424
+ document.addEventListener('mouseup', function(e) {
425
+ isDragging = false;
426
+ });
427
+
428
+ document.addEventListener('mousemove', function(e) {
429
+ if (isDragging) {
430
+ const deltaMove = {
431
+ x: e.clientX - previousMousePosition.x,
432
+ y: e.clientY - previousMousePosition.y
433
+ };
434
+
435
+ moleculeGroup.rotation.y += deltaMove.x * 0.01;
436
+ moleculeGroup.rotation.x += deltaMove.y * 0.01;
437
+ }
438
+
439
+ previousMousePosition = {
440
+ x: e.clientX,
441
+ y: e.clientY
442
+ };
443
+ });
444
+
445
+ document.addEventListener('wheel', function(e) {
446
+ camera.position.z += e.deltaY * 0.1;
447
+ camera.position.z = Math.max(10, Math.min(200, camera.position.z));
448
+ });
449
+
450
+ window.addEventListener('resize', function() {
451
+ camera.aspect = window.innerWidth / window.innerHeight;
452
+ camera.updateProjectionMatrix();
453
+ renderer.setSize(window.innerWidth, window.innerHeight);
454
+ });
455
+
456
+ // Initialize
457
+ init();
458
+ </script>
459
+ </body>
460
+ </html>
461
+ """
462
 
 
463
  def create_demo():
464
+ """Create the Gradio interface with integrated 3D viewer"""
465
 
 
466
  css = """
467
  .gradio-container {
468
  font-family: 'Arial', sans-serif;
 
473
  padding: 10px;
474
  border-radius: 5px;
475
  }
476
+ iframe {
477
+ border: none;
478
+ border-radius: 10px;
479
+ }
480
  """
481
 
482
+ with gr.Blocks(css=css, title="DNA-Diffusion with 3D Viewer") as demo:
 
483
  gr.Markdown(
484
  """
485
+ # 🧬 DNA-Diffusion: AI-Powered Sequence Generation with 3D Visualization
486
 
487
+ Generate cell-type specific DNA sequences and visualize them in 3D!
 
488
  """
489
  )
490
 
 
491
  gpu_status = gr.Markdown(
492
  f"🖥️ **GPU Status**: {'✅ Available' if SPACES_AVAILABLE else '❌ Not Available (CPU mode)'}"
493
  )
494
 
495
  with gr.Tabs():
496
+ # Tab 1: Sequence Generation
497
+ with gr.TabItem("🎯 Generate Sequence"):
498
  with gr.Row():
499
  with gr.Column(scale=1):
500
  cell_type = gr.Dropdown(
 
536
  analysis_output = gr.JSON(
537
  label="Sequence Analysis"
538
  )
 
 
 
 
 
 
 
 
 
 
539
 
540
+ # Tab 2: 3D Visualization
541
+ with gr.TabItem("🔬 3D Structure"):
542
+ with gr.Row():
543
+ with gr.Column():
544
+ gr.Markdown("### 3D DNA Structure Visualization")
545
+ gr.Markdown("The 3D viewer shows the double helix structure of your generated DNA sequence.")
546
+
547
+ # HTML component for 3D viewer
548
+ viewer_html = gr.HTML(
549
+ value=f'<iframe src="data:text/html;charset=utf-8,{HTML_3D_VIEWER.replace("#", "%23")}" width="100%" height="600px"></iframe>',
550
+ label="3D Molecular Viewer"
551
+ )
552
+
553
+ # Button to update 3D view
554
+ update_3d_btn = gr.Button(
555
+ "🔄 Update 3D View with Current Sequence",
556
+ variant="secondary"
557
+ )
558
+
559
+ # Tab 3: Batch Generation
560
  with gr.TabItem("📦 Batch Generation"):
561
  with gr.Row():
562
  with gr.Column():
 
584
  headers=["ID", "Cell Type", "Length", "GC%", "Tm(°C)"],
585
  label="Batch Results"
586
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
587
 
 
588
  status_text = gr.Textbox(
589
  label="Status",
590
  value="Ready to generate sequences...",
 
593
 
594
  # Event handlers
595
  def generate_single(cell_type, guidance_scale, enhanced):
 
596
  try:
597
  status_text.value = "🔄 Generating sequence on GPU..."
598
 
 
599
  sequence = app.generate_with_gpu(cell_type, guidance_scale, enhanced)
 
 
600
  analysis = app.analyze_sequence(sequence)
601
 
602
  status_text.value = f"✅ Successfully generated sequence for {cell_type}"
 
607
  logger.error(error_msg)
608
  return "", {}, error_msg
609
 
610
+ def update_3d_viewer(sequence):
611
+ if not sequence:
612
+ return gr.HTML.update()
613
+
614
+ # Create HTML with embedded sequence data
615
+ html_with_sequence = HTML_3D_VIEWER.replace(
616
+ "window.addEventListener('message'",
617
+ f"updateSequence('{sequence}');\n window.addEventListener('message'"
618
+ )
619
+
620
+ return gr.HTML.update(
621
+ value=f'<iframe src="data:text/html;charset=utf-8,{html_with_sequence.replace("#", "%23")}" width="100%" height="600px"></iframe>'
622
+ )
623
+
624
  def generate_batch(cell_types, count):
 
625
  if not cell_types:
626
  return None, "❌ Please select at least one cell type"
627
 
628
  try:
629
  status_text.value = "🔄 Generating batch on GPU..."
630
 
631
+ results = []
632
+ for i in range(count):
633
+ cell_type = cell_types[i % len(cell_types)]
634
+ sequence = app.generate_with_gpu(cell_type)
635
+ analysis = app.analyze_sequence(sequence)
636
+
637
+ results.append([
638
+ f'SEQ_{i+1:03d}',
639
+ cell_type,
640
+ analysis['length'],
641
+ analysis['gc_content'],
642
+ analysis['melting_temp']
643
  ])
644
 
645
  status_text.value = f"✅ Generated {len(results)} sequences"
646
+ return results, status_text.value
647
 
648
  except Exception as e:
649
  error_msg = f"❌ Error: {str(e)}"
650
  logger.error(error_msg)
651
  return None, error_msg
652
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
653
  # Connect event handlers
654
  generate_btn.click(
655
  fn=generate_single,
 
657
  outputs=[sequence_output, analysis_output, status_text]
658
  )
659
 
660
+ update_3d_btn.click(
661
+ fn=update_3d_viewer,
662
+ inputs=[sequence_output],
663
+ outputs=[viewer_html]
 
 
 
 
 
 
664
  )
665
 
666
+ # Auto-update 3D viewer when sequence is generated
667
  sequence_output.change(
668
+ fn=update_3d_viewer,
669
  inputs=[sequence_output],
670
+ outputs=[viewer_html]
671
+ )
672
+
673
+ batch_generate_btn.click(
674
+ fn=generate_batch,
675
+ inputs=[batch_cell_types, batch_count],
676
+ outputs=[batch_output, status_text]
677
  )
678
 
679
  return demo
680
 
 
681
  if __name__ == "__main__":
 
682
  logger.info("=" * 50)
683
+ logger.info("DNA-Diffusion App with 3D Viewer Starting...")
684
  logger.info(f"GPU Available: {SPACES_AVAILABLE}")
685
  logger.info(f"Environment: {'Hugging Face Spaces' if os.getenv('SPACE_ID') else 'Local'}")
686
  logger.info("=" * 50)
687
 
 
688
  demo = create_demo()
689
 
 
690
  if os.getenv("SPACE_ID"):
 
691
  demo.launch(
692
  server_name="0.0.0.0",
693
  server_port=7860,
694
+ share=False,
695
  show_error=True
696
  )
697
  else:
 
698
  demo.launch(
699
  share=True,
700
  show_error=True,