SmartHeal commited on
Commit
93d7a1f
·
verified ·
1 Parent(s): 9f0e612

Update src/ai_processor.py

Browse files
Files changed (1) hide show
  1. src/ai_processor.py +3 -50
src/ai_processor.py CHANGED
@@ -255,57 +255,10 @@ def load_yolo_model():
255
  model = YOLO(YOLO_MODEL_PATH)
256
  return model
257
 
258
- def load_segmentation_model(path: Optional[str] = None):
259
- """
260
- Robust loader for legacy .h5 models across TF/Keras versions.
261
- Uses global SEG_MODEL_PATH by default.
262
- """
263
- import ast
264
  import tensorflow as tf
265
- tf.config.set_visible_devices([], "GPU")
266
- model_path = path or SEG_MODEL_PATH
267
-
268
- # Attempt 1: tf.keras with safe_mode=False
269
- try:
270
- m = tf.keras.models.load_model(model_path, compile=False, safe_mode=False)
271
- logging.info("✅ Segmentation model loaded (tf.keras, safe_mode=False).")
272
- return m
273
- except Exception as e1:
274
- logging.warning(f"tf.keras load (safe_mode=False) failed: {e1}")
275
-
276
- # Attempt 2: patched InputLayer (drop legacy args; coerce string shapes)
277
- try:
278
- from tensorflow.keras.layers import InputLayer as _KInputLayer
279
- def _InputLayerPatched(*args, **kwargs):
280
- kwargs.pop("batch_shape", None)
281
- kwargs.pop("batch_input_shape", None)
282
- if "shape" in kwargs and isinstance(kwargs["shape"], str):
283
- try:
284
- kwargs["shape"] = tuple(ast.literal_eval(kwargs["shape"]))
285
- except Exception:
286
- kwargs.pop("shape", None)
287
- return _KInputLayer(**kwargs)
288
- m = tf.keras.models.load_model(
289
- model_path,
290
- compile=False,
291
- custom_objects={"InputLayer": _InputLayerPatched},
292
- safe_mode=False,
293
- )
294
- logging.info("✅ Segmentation model loaded (patched InputLayer).")
295
- return m
296
- except Exception as e2:
297
- logging.warning(f"Patched InputLayer load failed: {e2}")
298
-
299
- # Attempt 3: keras 2 shim (tf_keras) if present
300
- try:
301
- import tf_keras
302
- m = tf_keras.models.load_model(model_path, compile=False)
303
- logging.info("✅ Segmentation model loaded (tf_keras compat).")
304
- return m
305
- except Exception as e3:
306
- logging.warning(f"tf_keras load failed or not installed: {e3}")
307
-
308
- raise RuntimeError("Segmentation model could not be loaded; please convert/resave the model.")
309
 
310
  def load_classification_pipeline():
311
  pipe = _import_hf_cls()
 
255
  model = YOLO(YOLO_MODEL_PATH)
256
  return model
257
 
258
+ def load_segmentation_model():
 
 
 
 
 
259
  import tensorflow as tf
260
+ load_model = _import_tf_loader()
261
+ return load_model(SEG_MODEL_PATH, compile=False, custom_objects={'InputLayer': tf.keras.layers.InputLayer})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
 
263
  def load_classification_pipeline():
264
  pipe = _import_hf_cls()