avans06 commited on
Commit
314497f
·
1 Parent(s): 31305bd

Optimize Gallery UI and enable Ctrl+V image pasting

Browse files

- Remove custom upload/delete buttons in favor of native gr.Gallery features.
- Add support for pasting images (Ctrl+V) directly into the gallery.
- Refactor code to remove unused functions and simplify the layout.

- Improve console logging: Group execution times by image to reduce redundancy and improve readability.
- Optimize ZIP handling: Conditionally create face-related ZIPs only when face restoration is enabled. Added logic to remove empty ZIP files (<= 22 bytes) before returning results.
- Fix filename bug: Reset self.modelInUse at the start of inference to prevent model name suffix accumulation across multiple runs.

Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +162 -116
  3. requirements.txt +1 -1
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 📈
4
  colorFrom: blue
5
  colorTo: gray
6
  sdk: gradio
7
- sdk_version: 5.44.1
8
  app_file: app.py
9
  pinned: true
10
  license: apache-2.0
 
4
  colorFrom: blue
5
  colorTo: gray
6
  sdk: gradio
7
+ sdk_version: 5.50.0
8
  app_file: app.py
9
  pinned: true
10
  license: apache-2.0
app.py CHANGED
@@ -795,6 +795,8 @@ class Upscale:
795
  if not gallery or (not face_restoration and not upscale_model):
796
  raise ValueError("Invalid parameter setting")
797
 
 
 
798
  gallery_len = len(gallery)
799
  print(face_restoration, upscale_model, scale, f"gallery length: {gallery_len}")
800
 
@@ -826,14 +828,20 @@ class Upscale:
826
  files = []
827
  # Create zip files for each output type
828
  unique_id = str(int(time.time())) # Use timestamp for uniqueness
 
 
829
  zip_cropf_path = f"output/{unique_id}_cropped_faces{self.modelInUse}.zip"
830
- zipf_cropf = zipfile.ZipFile(zip_cropf_path, 'w', zipfile.ZIP_DEFLATED)
831
  zip_restoref_path = f"output/{unique_id}_restored_faces{self.modelInUse}.zip"
832
- zipf_restoref = zipfile.ZipFile(zip_restoref_path, 'w', zipfile.ZIP_DEFLATED)
833
  zip_cmp_path = f"output/{unique_id}_cmp{self.modelInUse}.zip"
834
- zipf_cmp = zipfile.ZipFile(zip_cmp_path, 'w', zipfile.ZIP_DEFLATED)
835
  zip_restore_path = f"output/{unique_id}_restored_images{self.modelInUse}.zip"
836
- zipf_restore = zipfile.ZipFile(zip_restore_path, 'w', zipfile.ZIP_DEFLATED)
 
 
 
 
 
 
 
837
 
838
  is_auto_split_upscale = True
839
  # Dictionary to track counters for each filename
@@ -869,7 +877,7 @@ class Upscale:
869
  bg_upsample_img, _ = auto_split_upscale(img_cv2, self.realesrganer.enhance, self.scale) if is_auto_split_upscale else self.realesrganer.enhance(img_cv2, outscale=self.scale)
870
  current_progress += progressRatio/progressTotal;
871
  progress(current_progress, desc=f"Image {gallery_idx:02d}: Background upscaling...")
872
- timer.checkpoint(f"Image {gallery_idx:02d}: Background upscale section")
873
 
874
  if face_restoration and self.face_enhancer:
875
  cropped_faces, restored_aligned, bg_upsample_img = self.face_enhancer.enhance(img_cv2, has_aligned=False, only_center_face=face_detection_only_center, paste_back=True, bg_upsample_img=bg_upsample_img, eye_dist_threshold=face_detection_threshold)
@@ -879,45 +887,60 @@ class Upscale:
879
  # save cropped face
880
  save_crop_path = f"output/{basename}_{idx:02d}_cropped_faces{self.modelInUse}.png"
881
  self.imwriteUTF8(save_crop_path, cropped_face)
882
- zipf_cropf.write(save_crop_path, arcname=os.path.basename(save_crop_path))
 
883
  # save restored face
884
  save_restore_path = f"output/{basename}_{idx:02d}_restored_faces{self.modelInUse}.png"
885
  self.imwriteUTF8(save_restore_path, restored_face)
886
- zipf_restoref.write(save_restore_path, arcname=os.path.basename(save_restore_path))
 
887
  # save comparison image
888
  save_cmp_path = f"output/{basename}_{idx:02d}_cmp{self.modelInUse}.png"
889
  cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
890
  self.imwriteUTF8(save_cmp_path, cmp_img)
891
- zipf_cmp.write(save_cmp_path, arcname=os.path.basename(save_cmp_path))
 
892
 
893
  files.append(save_crop_path)
894
  files.append(save_restore_path)
895
  files.append(save_cmp_path)
896
  current_progress += progressRatio/progressTotal;
897
  progress(current_progress, desc=f"Image {gallery_idx:02d}: Face enhancement...")
898
- timer.checkpoint(f"Image {gallery_idx:02d}: Face enhancer section")
899
 
900
  restored_img = bg_upsample_img
901
- timer.report() # Report time for this image
 
902
 
903
  # Handle cases where image processing might still result in None
904
  if restored_img is None:
905
  print(f"Warning: Processing resulted in no image for '{img_path}'. Skipping output.")
906
  continue
 
 
 
907
 
908
  # Determine the file extension for the output image based on user preference and image properties.
909
  if save_as_png:
910
- # Force PNG output for the best quality, as requested by the user.
911
  final_extension = ".png"
912
  else:
913
  # Use original logic: PNG for images with an alpha channel (RGBA), otherwise use the original extension or default to jpg.
914
  final_extension = ".png" if img_mode == "RGBA" else (extension if extension else ".jpg")
 
915
  save_path = f"output/{basename}{self.modelInUse}{final_extension}"
 
 
916
  self.imwriteUTF8(save_path, restored_img)
917
- zipf_restore.write(save_path, arcname=os.path.basename(save_path))
 
918
 
 
919
  restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
920
  files.append(save_path)
 
 
 
 
921
  except RuntimeError as error:
922
  print(f"Runtime Error while processing image {gallery_idx} ({img_path or 'unknown path'}): {error}")
923
  print(traceback.format_exc())
@@ -929,11 +952,31 @@ class Upscale:
929
 
930
  progress(1, desc=f"Processing complete.")
931
  timer.report_all() # Print all recorded times for the whole batch
932
- # Close zip files
933
- zipf_cropf.close()
934
- zipf_restoref.close()
935
- zipf_cmp.close()
936
- zipf_restore.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
937
  except Exception as error:
938
  print(f"Global exception occurred: {error}")
939
  print(traceback.format_exc())
@@ -946,7 +989,7 @@ class Upscale:
946
  torch.cuda.empty_cache()
947
  gc.collect()
948
 
949
- return files, [zip_cropf_path, zip_restoref_path, zip_cmp_path, zip_restore_path] if face_restoration else [zip_restore_path]
950
 
951
 
952
  def find_max_numbers(self, state_dict, findkeys):
@@ -1012,36 +1055,38 @@ class Timer:
1012
  now = time.perf_counter()
1013
  self.checkpoints.append((label, now))
1014
 
1015
- def report(self, is_clear_checkpoints = True):
1016
  # Determine the max label width for alignment
 
 
 
 
 
 
 
 
 
 
 
1017
  max_label_length = max(len(label) for label, _ in self.checkpoints)
1018
 
1019
  prev_time = self.checkpoints[0][1]
1020
  for label, curr_time in self.checkpoints[1:]:
1021
  elapsed = curr_time - prev_time
1022
- print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
1023
  prev_time = curr_time
1024
 
1025
  if is_clear_checkpoints:
1026
  self.checkpoints.clear()
1027
- self.checkpoint() # Store checkpoints
1028
 
1029
  def report_all(self):
1030
  """Print all recorded checkpoints and total execution time with aligned formatting."""
1031
  print("\n> Execution Time Report:")
1032
-
1033
- # Determine the max label width for alignment
1034
- max_label_length = max(len(label) for label, _ in self.checkpoints) if len(self.checkpoints) > 0 else 0
1035
-
1036
- prev_time = self.start_time
1037
- for label, curr_time in self.checkpoints[1:]:
1038
- elapsed = curr_time - prev_time
1039
- print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
1040
- prev_time = curr_time
1041
-
1042
- total_time = self.checkpoints[-1][1] - self.start_time
1043
- print(f"{'Total Execution Time'.ljust(max_label_length)}: {total_time:.3f} seconds\n")
1044
-
1045
  self.checkpoints.clear()
1046
 
1047
  def restart(self):
@@ -1080,73 +1125,6 @@ def limit_gallery(gallery):
1080
  """
1081
  return gallery[:input_images_limit] if input_images_limit > 0 and gallery else gallery
1082
 
1083
- def append_gallery(gallery: list, image: str):
1084
- """
1085
- Append a single image to the gallery while respecting input_images_limit.
1086
-
1087
- Parameters:
1088
- - gallery (list): Existing list of images. If None, initializes an empty list.
1089
- - image (str): The image to be added. If None or empty, no action is taken.
1090
-
1091
- Returns:
1092
- - list: Updated gallery.
1093
- """
1094
- if gallery is None:
1095
- gallery = []
1096
- if not image:
1097
- return gallery, None
1098
-
1099
- if input_images_limit == -1 or len(gallery) < input_images_limit:
1100
- gallery.append(image)
1101
-
1102
- return gallery, None
1103
-
1104
-
1105
- def extend_gallery(gallery: list, images):
1106
- """
1107
- Extend the gallery with new images while respecting the input_images_limit.
1108
-
1109
- Parameters:
1110
- - gallery (list): Existing list of images. If None, initializes an empty list.
1111
- - images (list): New images to be added. If None, defaults to an empty list.
1112
-
1113
- Returns:
1114
- - list: Updated gallery with the new images added.
1115
- """
1116
- if gallery is None:
1117
- gallery = []
1118
- if not images:
1119
- return gallery
1120
-
1121
- # Add new images to the gallery
1122
- gallery.extend(images)
1123
-
1124
- # Trim gallery to the specified limit, if applicable
1125
- if input_images_limit > 0:
1126
- gallery = gallery[:input_images_limit]
1127
-
1128
- return gallery
1129
-
1130
- def remove_image_from_gallery(gallery: list, selected_image: str):
1131
- """
1132
- Removes a selected image from the gallery if it exists.
1133
-
1134
- Args:
1135
- gallery (list): The current list of images in the gallery.
1136
- selected_image (str): The image to be removed, represented as a string
1137
- that needs to be parsed into a tuple.
1138
-
1139
- Returns:
1140
- list: The updated gallery after removing the selected image.
1141
- """
1142
- if not gallery or not selected_image:
1143
- return gallery
1144
-
1145
- selected_image = ast.literal_eval(selected_image) # Use ast.literal_eval to parse text into a tuple in remove_image_from_gallery.
1146
- # Remove the selected image from the gallery
1147
- if selected_image in gallery:
1148
- gallery.remove(selected_image)
1149
- return gallery
1150
 
1151
  def main():
1152
  if torch.cuda.is_available():
@@ -1159,6 +1137,7 @@ def main():
1159
  # https://github.com/CompVis/stable-diffusion/issues/69#issuecomment-1260722801
1160
  torch.backends.cudnn.enabled = True
1161
  torch.backends.cudnn.benchmark = True
 
1162
  # Ensure the target directory exists
1163
  os.makedirs('output', exist_ok=True)
1164
 
@@ -1182,6 +1161,74 @@ def main():
1182
  text-align: left !important;
1183
  width: 55.5% !important;
1184
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1185
  """
1186
 
1187
  upscale = Upscale()
@@ -1209,12 +1256,14 @@ def main():
1209
  with gr.Row():
1210
  with gr.Column(variant="panel"):
1211
  submit = gr.Button(value="Submit", variant="primary", size="lg")
1212
- # Create an Image component for uploading images
1213
- input_image = gr.Image(label="Upload an Image or clicking paste from clipboard button", type="filepath", format="png", height=150)
1214
- with gr.Row():
1215
- upload_button = gr.UploadButton("Upload multiple images", file_types=["image"], file_count="multiple", size="sm")
1216
- remove_button = gr.Button("Remove Selected Image", size="sm")
1217
- input_gallery = gr.Gallery(columns=5, rows=5, show_share_button=False, interactive=True, height="500px", label="Gallery that displaying a grid of images" + (f"(The online environment image limit is {input_images_limit})" if input_images_limit > 0 else ""))
 
 
1218
  face_model = gr.Dropdown([None]+list(face_models.keys()), type="value", value='GFPGANv1.4.pth', label='Face Restoration version', info="Face Restoration and RealESR can be freely combined in different ways, or one can be set to \"None\" to use only the other model. Face Restoration is primarily used for face restoration in real-life images, while RealESR serves as a background restoration model.")
1219
  upscale_model = gr.Dropdown([None]+list(typed_upscale_models.keys()), type="value", value='SRVGG, realesr-general-x4v3.pth', label='UpScale version')
1220
  upscale_scale = gr.Number(label="Rescaling factor", value=4)
@@ -1225,17 +1274,11 @@ def main():
1225
  # Add a checkbox to always save the output as a PNG file for the best quality.
1226
  save_as_png = gr.Checkbox(label="Always save output as PNG", value=True, info="If enabled, all output images will be saved in PNG format to ensure the best quality. If disabled, the format will be determined automatically (PNG for images with transparency, otherwise JPG).")
1227
 
1228
- # Define the event listener to add the uploaded image to the gallery
1229
- input_image.change(append_gallery, inputs=[input_gallery, input_image], outputs=[input_gallery, input_image])
1230
- # When the upload button is clicked, add the new images to the gallery
1231
- upload_button.upload(extend_gallery, inputs=[input_gallery, upload_button], outputs=input_gallery)
1232
  # Event to update the selected image when an image is clicked in the gallery
1233
  selected_image = gr.Textbox(label="Selected Image", visible=False)
1234
  input_gallery.select(get_selection_from_gallery, inputs=None, outputs=selected_image)
1235
  # Trigger update when gallery changes
1236
  input_gallery.change(limit_gallery, input_gallery, input_gallery)
1237
- # Event to remove a selected image from the gallery
1238
- remove_button.click(remove_image_from_gallery, inputs=[input_gallery, selected_image], outputs=input_gallery)
1239
 
1240
  with gr.Row():
1241
  clear = gr.ClearButton(
@@ -1257,8 +1300,8 @@ def main():
1257
  # Generate output array
1258
  output_arr = []
1259
  for file_name in example_list:
1260
- output_arr.append([file_name,])
1261
- gr.Examples(output_arr, inputs=[input_image,], examples_per_page=20)
1262
  with gr.Row(variant="panel"):
1263
  # Convert to Markdown table
1264
  header = "| Face Model Name | Info | Download URL |\n|------------|------|--------------|"
@@ -1288,9 +1331,12 @@ def main():
1288
  ],
1289
  outputs=[gallerys, outputs],
1290
  )
 
 
 
1291
 
1292
  demo.queue(default_concurrency_limit=1)
1293
- demo.launch(inbrowser=True)
1294
 
1295
 
1296
  if __name__ == "__main__":
 
795
  if not gallery or (not face_restoration and not upscale_model):
796
  raise ValueError("Invalid parameter setting")
797
 
798
+ self.modelInUse = ""
799
+
800
  gallery_len = len(gallery)
801
  print(face_restoration, upscale_model, scale, f"gallery length: {gallery_len}")
802
 
 
828
  files = []
829
  # Create zip files for each output type
830
  unique_id = str(int(time.time())) # Use timestamp for uniqueness
831
+
832
+ # Define zip file paths
833
  zip_cropf_path = f"output/{unique_id}_cropped_faces{self.modelInUse}.zip"
 
834
  zip_restoref_path = f"output/{unique_id}_restored_faces{self.modelInUse}.zip"
 
835
  zip_cmp_path = f"output/{unique_id}_cmp{self.modelInUse}.zip"
 
836
  zip_restore_path = f"output/{unique_id}_restored_images{self.modelInUse}.zip"
837
+
838
+ # Initialize Zip Objects conditionally
839
+ # Only create face-related zips if face restoration is actually enabled
840
+ zipf_cropf = zipfile.ZipFile(zip_cropf_path, 'w', zipfile.ZIP_DEFLATED) if face_restoration else None
841
+ zipf_restoref = zipfile.ZipFile(zip_restoref_path, 'w', zipfile.ZIP_DEFLATED) if face_restoration else None
842
+ zipf_cmp = zipfile.ZipFile(zip_cmp_path, 'w', zipfile.ZIP_DEFLATED) if face_restoration else None
843
+ # Always attempt to create the main restoration zip
844
+ zipf_restore = zipfile.ZipFile(zip_restore_path, 'w', zipfile.ZIP_DEFLATED)
845
 
846
  is_auto_split_upscale = True
847
  # Dictionary to track counters for each filename
 
877
  bg_upsample_img, _ = auto_split_upscale(img_cv2, self.realesrganer.enhance, self.scale) if is_auto_split_upscale else self.realesrganer.enhance(img_cv2, outscale=self.scale)
878
  current_progress += progressRatio/progressTotal;
879
  progress(current_progress, desc=f"Image {gallery_idx:02d}: Background upscaling...")
880
+ timer.checkpoint("Background upscale")
881
 
882
  if face_restoration and self.face_enhancer:
883
  cropped_faces, restored_aligned, bg_upsample_img = self.face_enhancer.enhance(img_cv2, has_aligned=False, only_center_face=face_detection_only_center, paste_back=True, bg_upsample_img=bg_upsample_img, eye_dist_threshold=face_detection_threshold)
 
887
  # save cropped face
888
  save_crop_path = f"output/{basename}_{idx:02d}_cropped_faces{self.modelInUse}.png"
889
  self.imwriteUTF8(save_crop_path, cropped_face)
890
+ if zipf_cropf:
891
+ zipf_cropf.write(save_crop_path, arcname=os.path.basename(save_crop_path))
892
  # save restored face
893
  save_restore_path = f"output/{basename}_{idx:02d}_restored_faces{self.modelInUse}.png"
894
  self.imwriteUTF8(save_restore_path, restored_face)
895
+ if zipf_restoref:
896
+ zipf_restoref.write(save_restore_path, arcname=os.path.basename(save_restore_path))
897
  # save comparison image
898
  save_cmp_path = f"output/{basename}_{idx:02d}_cmp{self.modelInUse}.png"
899
  cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
900
  self.imwriteUTF8(save_cmp_path, cmp_img)
901
+ if zipf_cmp:
902
+ zipf_cmp.write(save_cmp_path, arcname=os.path.basename(save_cmp_path))
903
 
904
  files.append(save_crop_path)
905
  files.append(save_restore_path)
906
  files.append(save_cmp_path)
907
  current_progress += progressRatio/progressTotal;
908
  progress(current_progress, desc=f"Image {gallery_idx:02d}: Face enhancement...")
909
+ timer.checkpoint("Face enhancement")
910
 
911
  restored_img = bg_upsample_img
912
+ # Report time for this image with a Header
913
+ timer.report(header=f"[Image {gallery_idx:02d} Stats]")
914
 
915
  # Handle cases where image processing might still result in None
916
  if restored_img is None:
917
  print(f"Warning: Processing resulted in no image for '{img_path}'. Skipping output.")
918
  continue
919
+
920
+ # Record the timestamp before I/O starts
921
+ timer.checkpoint("I/O preparation")
922
 
923
  # Determine the file extension for the output image based on user preference and image properties.
924
  if save_as_png:
 
925
  final_extension = ".png"
926
  else:
927
  # Use original logic: PNG for images with an alpha channel (RGBA), otherwise use the original extension or default to jpg.
928
  final_extension = ".png" if img_mode == "RGBA" else (extension if extension else ".jpg")
929
+
930
  save_path = f"output/{basename}{self.modelInUse}{final_extension}"
931
+
932
+ # Execute saving
933
  self.imwriteUTF8(save_path, restored_img)
934
+ if zipf_restore:
935
+ zipf_restore.write(save_path, arcname=os.path.basename(save_path))
936
 
937
+ # Color conversion
938
  restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
939
  files.append(save_path)
940
+
941
+ # --- Record saving end time and report ---
942
+ timer.checkpoint("File Saving (IO/CPU)")
943
+ timer.report(header=f"[Image {gallery_idx:02d} I/O]")
944
  except RuntimeError as error:
945
  print(f"Runtime Error while processing image {gallery_idx} ({img_path or 'unknown path'}): {error}")
946
  print(traceback.format_exc())
 
952
 
953
  progress(1, desc=f"Processing complete.")
954
  timer.report_all() # Print all recorded times for the whole batch
955
+
956
+ # Close zip files and clean up empty ones
957
+ final_zip_paths = []
958
+ zips_to_process = [
959
+ (zipf_cropf, zip_cropf_path),
960
+ (zipf_restoref, zip_restoref_path),
961
+ (zipf_cmp, zip_cmp_path),
962
+ (zipf_restore, zip_restore_path)
963
+ ]
964
+
965
+ for zf, path in zips_to_process:
966
+ if zf:
967
+ zf.close()
968
+ # Check if the zip file actually contains files.
969
+ # A standard empty zip file is 22 bytes (End of central directory record).
970
+ # If it's empty, we delete it and don't return it.
971
+ if os.path.exists(path):
972
+ if os.path.getsize(path) > 22:
973
+ final_zip_paths.append(path)
974
+ else:
975
+ try:
976
+ os.remove(path)
977
+ except OSError:
978
+ pass
979
+
980
  except Exception as error:
981
  print(f"Global exception occurred: {error}")
982
  print(traceback.format_exc())
 
989
  torch.cuda.empty_cache()
990
  gc.collect()
991
 
992
+ return files, final_zip_paths
993
 
994
 
995
  def find_max_numbers(self, state_dict, findkeys):
 
1055
  now = time.perf_counter()
1056
  self.checkpoints.append((label, now))
1057
 
1058
+ def report(self, header=None, is_clear_checkpoints=True):
1059
  # Determine the max label width for alignment
1060
+ # If there are no checkpoints to report, skip
1061
+ if len(self.checkpoints) <= 1:
1062
+ return
1063
+
1064
+ # Print header if provided (e.g., "[Image 00]")
1065
+ if header:
1066
+ print(header)
1067
+ indent = " " # Indent detailed logs if header exists
1068
+ else:
1069
+ indent = ""
1070
+
1071
  max_label_length = max(len(label) for label, _ in self.checkpoints)
1072
 
1073
  prev_time = self.checkpoints[0][1]
1074
  for label, curr_time in self.checkpoints[1:]:
1075
  elapsed = curr_time - prev_time
1076
+ print(f"{indent}{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
1077
  prev_time = curr_time
1078
 
1079
  if is_clear_checkpoints:
1080
  self.checkpoints.clear()
1081
+ self.checkpoint("Loop Start/Reset") # Reset start point
1082
 
1083
  def report_all(self):
1084
  """Print all recorded checkpoints and total execution time with aligned formatting."""
1085
  print("\n> Execution Time Report:")
1086
+ # Use current time (perf_counter) as the end point, instead of the last checkpoint
1087
+ end_time = time.perf_counter()
1088
+ total_time = end_time - self.start_time
1089
+ print(f"Total Execution Time: {total_time:.3f} seconds\n")
 
 
 
 
 
 
 
 
 
1090
  self.checkpoints.clear()
1091
 
1092
  def restart(self):
 
1125
  """
1126
  return gallery[:input_images_limit] if input_images_limit > 0 and gallery else gallery
1127
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1128
 
1129
  def main():
1130
  if torch.cuda.is_available():
 
1137
  # https://github.com/CompVis/stable-diffusion/issues/69#issuecomment-1260722801
1138
  torch.backends.cudnn.enabled = True
1139
  torch.backends.cudnn.benchmark = True
1140
+ print(f"CUDA is available. Using GPU: {torch.cuda.get_device_name(0)}")
1141
  # Ensure the target directory exists
1142
  os.makedirs('output', exist_ok=True)
1143
 
 
1161
  text-align: left !important;
1162
  width: 55.5% !important;
1163
  }
1164
+
1165
+ /* Make the Dropdown options display more compactly */
1166
+ .tag-dropdown span.svelte-1f354aw {
1167
+ font-family: monospace;
1168
+ }
1169
+ /* Add hover effect to Gallery to indicate it is an interactive area */
1170
+ #input_gallery:hover {
1171
+ border-color: var(--color-accent) !important;
1172
+ box-shadow: 0 0 10px rgba(0,0,0,0.1);
1173
+ }
1174
+ """
1175
+
1176
+ # JavaScript to handle Ctrl+V paste for MULTIPLE files ONLY when hovering over the gallery
1177
+ paste_js = """
1178
+ function initPaste() {
1179
+ document.addEventListener('paste', function(e) {
1180
+ // 1. First find the Gallery component
1181
+ const gallery = document.getElementById('input_gallery');
1182
+ if (!gallery) return;
1183
+
1184
+ // 2. Check if mouse is hovering over the Gallery
1185
+ // If mouse is not over the gallery, ignore this paste event
1186
+ if (!gallery.matches(':hover')) {
1187
+ return;
1188
+ }
1189
+
1190
+ const clipboardData = e.clipboardData || e.originalEvent.clipboardData;
1191
+ if (!clipboardData) return;
1192
+
1193
+ const items = clipboardData.items;
1194
+ const files = [];
1195
+
1196
+ // 3. Check clipboard content
1197
+ for (let i = 0; i < items.length; i++) {
1198
+ if (items[i].kind === 'file' && items[i].type.startsWith('image/')) {
1199
+ files.push(items[i].getAsFile());
1200
+ }
1201
+ }
1202
+
1203
+ // 4. Check file list (Copied files from OS)
1204
+ if (files.length === 0 && clipboardData.files.length > 0) {
1205
+ for (let i = 0; i < clipboardData.files.length; i++) {
1206
+ if (clipboardData.files[i].type.startsWith('image/')) {
1207
+ files.push(clipboardData.files[i]);
1208
+ }
1209
+ }
1210
+ }
1211
+
1212
+ if (files.length === 0) return;
1213
+
1214
+ // 5. Execute upload logic
1215
+ // Find input inside the gallery component
1216
+ const uploadInput = gallery.querySelector('input[type="file"]');
1217
+
1218
+ if (uploadInput) {
1219
+ e.preventDefault();
1220
+ e.stopPropagation();
1221
+
1222
+ const dataTransfer = new DataTransfer();
1223
+ files.forEach(file => dataTransfer.items.add(file));
1224
+
1225
+ uploadInput.files = dataTransfer.files;
1226
+
1227
+ // Trigger Gradio update
1228
+ uploadInput.dispatchEvent(new Event('change', { bubbles: true }));
1229
+ }
1230
+ });
1231
+ }
1232
  """
1233
 
1234
  upscale = Upscale()
 
1256
  with gr.Row():
1257
  with gr.Column(variant="panel"):
1258
  submit = gr.Button(value="Submit", variant="primary", size="lg")
1259
+ input_gallery = gr.Gallery(columns=5, rows=5, interactive=True, height=500, label="Gallery that displaying a grid of images" + (f"(The online environment image limit is {input_images_limit})" if input_images_limit > 0 else ""), elem_id="input_gallery")
1260
+ gr.Markdown(
1261
+ """
1262
+ <div style="text-align: right; font-size: 0.9em; color: gray;">
1263
+ 💡 Tip: Hover over the gallery and press <b>Ctrl+V</b> to paste images.
1264
+ </div>
1265
+ """
1266
+ )
1267
  face_model = gr.Dropdown([None]+list(face_models.keys()), type="value", value='GFPGANv1.4.pth', label='Face Restoration version', info="Face Restoration and RealESR can be freely combined in different ways, or one can be set to \"None\" to use only the other model. Face Restoration is primarily used for face restoration in real-life images, while RealESR serves as a background restoration model.")
1268
  upscale_model = gr.Dropdown([None]+list(typed_upscale_models.keys()), type="value", value='SRVGG, realesr-general-x4v3.pth', label='UpScale version')
1269
  upscale_scale = gr.Number(label="Rescaling factor", value=4)
 
1274
  # Add a checkbox to always save the output as a PNG file for the best quality.
1275
  save_as_png = gr.Checkbox(label="Always save output as PNG", value=True, info="If enabled, all output images will be saved in PNG format to ensure the best quality. If disabled, the format will be determined automatically (PNG for images with transparency, otherwise JPG).")
1276
 
 
 
 
 
1277
  # Event to update the selected image when an image is clicked in the gallery
1278
  selected_image = gr.Textbox(label="Selected Image", visible=False)
1279
  input_gallery.select(get_selection_from_gallery, inputs=None, outputs=selected_image)
1280
  # Trigger update when gallery changes
1281
  input_gallery.change(limit_gallery, input_gallery, input_gallery)
 
 
1282
 
1283
  with gr.Row():
1284
  clear = gr.ClearButton(
 
1300
  # Generate output array
1301
  output_arr = []
1302
  for file_name in example_list:
1303
+ output_arr.append([[file_name],])
1304
+ gr.Examples(output_arr, inputs=[input_gallery,], examples_per_page=20)
1305
  with gr.Row(variant="panel"):
1306
  # Convert to Markdown table
1307
  header = "| Face Model Name | Info | Download URL |\n|------------|------|--------------|"
 
1331
  ],
1332
  outputs=[gallerys, outputs],
1333
  )
1334
+
1335
+ # Load the JavaScript
1336
+ demo.load(None, None, None, js=paste_js)
1337
 
1338
  demo.queue(default_concurrency_limit=1)
1339
+ demo.launch(inbrowser=True) # , css = css
1340
 
1341
 
1342
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -1,6 +1,6 @@
1
  --extra-index-url https://download.pytorch.org/whl/cu128
2
 
3
- gradio
4
 
5
  basicsr @ git+https://github.com/avan06/BasicSR
6
  facexlib @ git+https://github.com/avan06/facexlib
 
1
  --extra-index-url https://download.pytorch.org/whl/cu128
2
 
3
+ gradio==5.50.0
4
 
5
  basicsr @ git+https://github.com/avan06/BasicSR
6
  facexlib @ git+https://github.com/avan06/facexlib