wbrooks commited on
Commit
bd1c23b
·
1 Parent(s): edfee12

the factory has to actually return the function

Browse files
Files changed (1) hide show
  1. src/embeddings_search.py +7 -4
src/embeddings_search.py CHANGED
@@ -50,6 +50,7 @@ def sbert_query_factory(corpus_embeddings_df: pl.DataFrame, model: SentenceTrans
50
  Returns:
51
  Callable[[str], pl.DataFrame]: Function to compare the query string to the corpus and return results sorted by the cosine similarity.
52
  """
 
53
 
54
  def do_sbert_query(query: str) -> pl.DataFrame:
55
  """
@@ -61,8 +62,9 @@ def sbert_query_factory(corpus_embeddings_df: pl.DataFrame, model: SentenceTrans
61
  Returns:
62
  polars.DataFrame: Corpus documents ranked by their match to the query.
63
  """
64
- search_fun = sbert_query(query, corpus_embeddings_df, model)
65
- return search_fun
 
66
 
67
 
68
  def load_embeddings_dfs(embeddings_dir: str = "block-embeddings") -> pl.DataFrame:
@@ -125,12 +127,13 @@ def create_embeddings_search_function_from_embeddings_df(model_name: str, embedd
125
  Callable[[str], pl.DataFrame]: Function to compare the query string to the corpus and return results sorted by the cosine similarity.
126
 
127
  """
 
128
  # Instantiate the sentence-transformer model:
129
  sentence_model = SentenceTransformer(model_name).to(device = device)
130
-
131
  # import the embeddings CSVs
132
  block_embeddings_df = pl.read_parquet(embeddings_df_path)
133
-
134
  # call the factory to make the search function and return it
135
  return sbert_query_factory(corpus_embeddings_df = block_embeddings_df, model = sentence_model)
136
 
 
50
  Returns:
51
  Callable[[str], pl.DataFrame]: Function to compare the query string to the corpus and return results sorted by the cosine similarity.
52
  """
53
+ print("starting factory")
54
 
55
  def do_sbert_query(query: str) -> pl.DataFrame:
56
  """
 
62
  Returns:
63
  polars.DataFrame: Corpus documents ranked by their match to the query.
64
  """
65
+ return sbert_query(query, corpus_embeddings_df, model)
66
+
67
+ return do_sbert_query
68
 
69
 
70
  def load_embeddings_dfs(embeddings_dir: str = "block-embeddings") -> pl.DataFrame:
 
127
  Callable[[str], pl.DataFrame]: Function to compare the query string to the corpus and return results sorted by the cosine similarity.
128
 
129
  """
130
+ print("starting to build embeddings search")
131
  # Instantiate the sentence-transformer model:
132
  sentence_model = SentenceTransformer(model_name).to(device = device)
133
+ print("instantiated sentence-transformers model")
134
  # import the embeddings CSVs
135
  block_embeddings_df = pl.read_parquet(embeddings_df_path)
136
+ print("read the embeddings to a data frame")
137
  # call the factory to make the search function and return it
138
  return sbert_query_factory(corpus_embeddings_df = block_embeddings_df, model = sentence_model)
139