ColettoG commited on
Commit
2e069a9
·
1 Parent(s): 0eaa0c0

add: better tool call for new gemini

Browse files
Files changed (1) hide show
  1. src/agents/crypto_data/tools.py +19 -1
src/agents/crypto_data/tools.py CHANGED
@@ -1,6 +1,7 @@
1
  import logging
2
  import requests
3
  import json
 
4
  from sklearn.feature_extraction.text import TfidfVectorizer
5
  from sklearn.metrics.pairwise import cosine_similarity
6
  from src.agents.crypto_data.config import Config
@@ -385,6 +386,18 @@ def get_coin_market_cap_tool(coin_name: str) -> str:
385
  return Config.API_ERROR_MESSAGE
386
 
387
 
 
 
 
 
 
 
 
 
 
 
 
 
388
  def get_tools() -> list[Tool]:
389
  """
390
  Build and return the list of LangChain Tools for use in an agent.
@@ -396,6 +409,7 @@ def get_tools() -> list[Tool]:
396
  Tool(
397
  name="get_coin_price",
398
  func=get_coin_price_tool,
 
399
  description=(
400
  "Use this to get the current USD price of a cryptocurrency. "
401
  "Input should be the coin name (e.g. 'bitcoin')."
@@ -404,6 +418,7 @@ def get_tools() -> list[Tool]:
404
  Tool(
405
  name="get_nft_floor_price",
406
  func=get_nft_floor_price_tool,
 
407
  description=(
408
  "Fetch the floor price of an NFT collection in USD. "
409
  "Input should be the NFT name or slug."
@@ -412,6 +427,7 @@ def get_tools() -> list[Tool]:
412
  Tool(
413
  name="get_protocol_tvl",
414
  func=get_protocol_total_value_locked_tool,
 
415
  description=(
416
  "Returns the Total Value Locked (TVL) of a DeFi protocol. "
417
  "Input is the protocol name."
@@ -420,6 +436,7 @@ def get_tools() -> list[Tool]:
420
  Tool(
421
  name="get_fully_diluted_valuation",
422
  func=get_fully_diluted_valuation_tool,
 
423
  description=(
424
  "Get a coin's fully diluted valuation in USD. "
425
  "Input the coin's name."
@@ -428,9 +445,10 @@ def get_tools() -> list[Tool]:
428
  Tool(
429
  name="get_market_cap",
430
  func=get_coin_market_cap_tool,
 
431
  description=(
432
  "Retrieve the market capitalization of a coin in USD. "
433
  "Input is the coin's name."
434
  ),
435
  ),
436
- ]
 
1
  import logging
2
  import requests
3
  import json
4
+ from pydantic import BaseModel, Field
5
  from sklearn.feature_extraction.text import TfidfVectorizer
6
  from sklearn.metrics.pairwise import cosine_similarity
7
  from src.agents.crypto_data.config import Config
 
386
  return Config.API_ERROR_MESSAGE
387
 
388
 
389
+ class _CoinNameArgs(BaseModel):
390
+ coin_name: str = Field(..., description="Name of the cryptocurrency to look up.")
391
+
392
+
393
+ class _NFTNameArgs(BaseModel):
394
+ nft_name: str = Field(..., description="Name or slug of the NFT collection.")
395
+
396
+
397
+ class _ProtocolNameArgs(BaseModel):
398
+ protocol_name: str = Field(..., description="Name of the DeFi protocol.")
399
+
400
+
401
  def get_tools() -> list[Tool]:
402
  """
403
  Build and return the list of LangChain Tools for use in an agent.
 
409
  Tool(
410
  name="get_coin_price",
411
  func=get_coin_price_tool,
412
+ args_schema=_CoinNameArgs,
413
  description=(
414
  "Use this to get the current USD price of a cryptocurrency. "
415
  "Input should be the coin name (e.g. 'bitcoin')."
 
418
  Tool(
419
  name="get_nft_floor_price",
420
  func=get_nft_floor_price_tool,
421
+ args_schema=_NFTNameArgs,
422
  description=(
423
  "Fetch the floor price of an NFT collection in USD. "
424
  "Input should be the NFT name or slug."
 
427
  Tool(
428
  name="get_protocol_tvl",
429
  func=get_protocol_total_value_locked_tool,
430
+ args_schema=_ProtocolNameArgs,
431
  description=(
432
  "Returns the Total Value Locked (TVL) of a DeFi protocol. "
433
  "Input is the protocol name."
 
436
  Tool(
437
  name="get_fully_diluted_valuation",
438
  func=get_fully_diluted_valuation_tool,
439
+ args_schema=_CoinNameArgs,
440
  description=(
441
  "Get a coin's fully diluted valuation in USD. "
442
  "Input the coin's name."
 
445
  Tool(
446
  name="get_market_cap",
447
  func=get_coin_market_cap_tool,
448
+ args_schema=_CoinNameArgs,
449
  description=(
450
  "Retrieve the market capitalization of a coin in USD. "
451
  "Input is the coin's name."
452
  ),
453
  ),
454
+ ]