wbrooks committed on
Commit
bccb1fa
·
1 Parent(s): b23d55c

making sure all the correct arguments are there in the function calls

Browse files
Files changed (1) hide show
  1. src/tfidf_search.py +3 -3
src/tfidf_search.py CHANGED
@@ -55,7 +55,7 @@ def query_worker(query: str, rownames: list[str], fasttext_model: fasttext.FastT
55
 
56
 
57
 
58
- def query_factory(rownames: list[str], dtm_svd: NDArray[np.float64], dtm_svd_mat: NDArray[np.float64], idf: NDArray[np.float64], vocab_norm: NDArray[np.float64], concentration: float = 10) -> Callable[[str], pl.DataFrame]:
59
  """
60
  Create a function that will compare query text to the documents in the corpus.
61
 
@@ -73,7 +73,7 @@ def query_factory(rownames: list[str], dtm_svd: NDArray[np.float64], dtm_svd_mat
73
  Returns:
74
  polars.DataFrame: Results sorted so that the best matches (according to column `score-tfidf`) are listed first.
75
  """
76
- return query_worker(query, rownames, dtm_svd, dtm_svd_mat, vocab_norm, concentration)
77
 
78
  return do_query
79
 
@@ -120,4 +120,4 @@ def create_tfidf_search_function(dtm_df_path: str, vectorizer_path: str, model_n
120
  dtm_svd = TruncatedSVD(n_components=300)
121
  X_svd = dtm_svd.fit_transform(doc_term_mat)
122
 
123
- return query_factory(rownames = filenames, dtm_svd = dtm_svd, dtm_svd_mat = X_svd, vocab_norm=vocab_norm, idf = my_idf, concentration = 30)
 
55
 
56
 
57
 
58
+ def query_factory(rownames: list[str], fasttext_model: fasttext.FastText._FastText, idf: NDArray[np.float64], dtm_svd: NDArray[np.float64], dtm_svd_mat: NDArray[np.float64], vocab_norm: NDArray[np.float64], concentration: float = 10) -> Callable[[str], pl.DataFrame]:
59
  """
60
  Create a function that will compare query text to the documents in the corpus.
61
 
 
73
  Returns:
74
  polars.DataFrame: Results sorted so that the best matches (according to column `score-tfidf`) are listed first.
75
  """
76
+ return query_worker(query, rownames, fasttext_model, idf, dtm_svd, dtm_svd_mat, vocab_norm, concentration)
77
 
78
  return do_query
79
 
 
120
  dtm_svd = TruncatedSVD(n_components=300)
121
  X_svd = dtm_svd.fit_transform(doc_term_mat)
122
 
123
+ return query_factory(rownames = filenames, fasttext_model = fasttext_model, idf = my_idf, dtm_svd = dtm_svd, dtm_svd_mat = X_svd, vocab_norm=vocab_norm, concentration = 30)