Spaces:
Running
Running
Make sure all the correct arguments are passed in the function calls
Browse files
- src/tfidf_search.py +3 -3
src/tfidf_search.py
CHANGED
|
@@ -55,7 +55,7 @@ def query_worker(query: str, rownames: list[str], fasttext_model: fasttext.FastT
|
|
| 55 |
|
| 56 |
|
| 57 |
|
| 58 |
-
def query_factory(rownames: list[str],
|
| 59 |
"""
|
| 60 |
Create a function that will compare query text to the documents in the corpus.
|
| 61 |
|
|
@@ -73,7 +73,7 @@ def query_factory(rownames: list[str], dtm_svd: NDArray[np.float64], dtm_svd_mat
|
|
| 73 |
Returns:
|
| 74 |
polars.DataFrame: Results sorted so that the best matches (according to column `score-tfidf`) are listed first.
|
| 75 |
"""
|
| 76 |
-
return query_worker(query, rownames, dtm_svd, dtm_svd_mat, vocab_norm, concentration)
|
| 77 |
|
| 78 |
return do_query
|
| 79 |
|
|
@@ -120,4 +120,4 @@ def create_tfidf_search_function(dtm_df_path: str, vectorizer_path: str, model_n
|
|
| 120 |
dtm_svd = TruncatedSVD(n_components=300)
|
| 121 |
X_svd = dtm_svd.fit_transform(doc_term_mat)
|
| 122 |
|
| 123 |
-
return query_factory(rownames = filenames, dtm_svd = dtm_svd, dtm_svd_mat = X_svd, vocab_norm=vocab_norm,
|
|
|
|
| 55 |
|
| 56 |
|
| 57 |
|
| 58 |
+
def query_factory(rownames: list[str], fasttext_model: fasttext.FastText._FastText, idf: NDArray[np.float64], dtm_svd: NDArray[np.float64], dtm_svd_mat: NDArray[np.float64], vocab_norm: NDArray[np.float64], concentration: float = 10) -> Callable[[str], pl.DataFrame]:
|
| 59 |
"""
|
| 60 |
Create a function that will compare query text to the documents in the corpus.
|
| 61 |
|
|
|
|
| 73 |
Returns:
|
| 74 |
polars.DataFrame: Results sorted so that the best matches (according to column `score-tfidf`) are listed first.
|
| 75 |
"""
|
| 76 |
+
return query_worker(query, rownames, fasttext_model, idf, dtm_svd, dtm_svd_mat, vocab_norm, concentration)
|
| 77 |
|
| 78 |
return do_query
|
| 79 |
|
|
|
|
| 120 |
dtm_svd = TruncatedSVD(n_components=300)
|
| 121 |
X_svd = dtm_svd.fit_transform(doc_term_mat)
|
| 122 |
|
| 123 |
+
return query_factory(rownames = filenames, fasttext_model = fasttext_model, idf = my_idf, dtm_svd = dtm_svd, dtm_svd_mat = X_svd, vocab_norm=vocab_norm, concentration = 30)
|