Spaces:
Sleeping
Sleeping
| import time | |
| import requests | |
| import json | |
| from volcenginesdkarkruntime import Ark | |
| from util.config_util import read_config as config | |
| from util.config_util import load_json | |
| from util import logger | |
| import volcenginesdkcore | |
| import volcenginesdkark | |
| from volcenginesdkcore.rest import ApiException | |
| from util.logger_util import log_decorate | |
class DouBaoService:
    """Wrapper around the Volcengine Ark client for chat-completion calls.

    The constructor reads ``./conf/config.json`` and keeps the
    ``"<model_name>ModelInfo"`` section as ``self.conf``. The methods below
    read the keys ``BASE_URL``, ``ACCESS_KEY``, ``SECRET_KEY`` and
    ``ENDPOINT_ID`` from that section.

    NOTE(review): the prompt constants used by the ``prd_*`` / ``keypoint_*``
    / ``case_*`` / ``cycle_*`` methods (``PRD2KP_SYS``, ``PRD_CASE_SYS``,
    ``KP2CASE_SYS``, ``CASE_AGG_SYS``, ``PRD_CASE_1``, ``PRD_CASE_2``,
    ``MORE_CASE_PROMPT``) are neither defined nor imported in the visible
    part of this file; they must exist at module level - confirm their source.
    """

    def __init__(self, model_name):
        """Load the config section for ``model_name`` and create the Ark client.

        Args:
            model_name: Prefix of the config section; the section looked up is
                ``f"{model_name}ModelInfo"``.
        """
        # Named ``settings`` so the module-level ``config`` import is not shadowed.
        settings = load_json('./conf/config.json')
        self.conf = settings[f"{model_name}ModelInfo"]
        self.client = self.init_client()
        # Per-instance overrides filled in by ``set_complete_args``.
        self._complete_args = {}

    def init_client(self):
        """Build an ``Ark`` client from the AK/SK pair and base URL in ``self.conf``."""
        base_url = self.conf["BASE_URL"]
        ak = self.conf["ACCESS_KEY"]
        sk = self.conf["SECRET_KEY"]
        # The original also carried a disabled variant that authenticated with
        # conf["API_KEY"] instead of the AK/SK pair.
        return Ark(ak=ak, sk=sk, base_url=base_url)

    def get_api_key(self):
        """Request an API key for the configured endpoint.

        Installs a default ``volcenginesdkcore.Configuration`` (AK/SK from
        ``self.conf``, region ``"cn-beijing"``) and calls ``ARKApi.get_api_key``
        with a 30-day duration for the ``ENDPOINT_ID`` resource.

        Returns:
            ``resp.api_key`` on success, or ``None`` when the call raises
            ``ApiException`` (the error is logged, not re-raised).
        """
        configuration = volcenginesdkcore.Configuration()
        configuration.ak = self.conf["ACCESS_KEY"]
        configuration.sk = self.conf["SECRET_KEY"]
        configuration.region = "cn-beijing"
        endpoint_id = self.conf["ENDPOINT_ID"]
        # ARKApi() below picks up this global default configuration.
        volcenginesdkcore.Configuration.set_default(configuration)
        api_instance = volcenginesdkark.ARKApi()
        get_api_key_request = volcenginesdkark.GetApiKeyRequest(
            duration_seconds=30 * 24 * 3600,  # 30 days
            resource_type="endpoint",
            resource_ids=[endpoint_id],
        )
        try:
            resp = api_instance.get_api_key(get_api_key_request)
        except ApiException as e:
            logger.error(f"Exception when calling api: {e}")
            # Best-effort: callers get None instead of an exception.
            return None
        return resp.api_key

    def set_complete_args(self, temperature=None, top_p=None, max_token=None):
        """Store sampling overrides; arguments left as ``None`` are not touched.

        Args:
            temperature: Stored under ``"temperature"``.
            top_p: Stored under ``"top_p"``.
            max_token: Stored under ``"max_tokens"`` (note the plural key).
        """
        if temperature is not None:
            self._complete_args["temperature"] = temperature
        if top_p is not None:
            self._complete_args["top_p"] = top_p
        if max_token is not None:
            self._complete_args["max_tokens"] = max_token

    def form_user_role(self, content):
        """Return a ``user`` chat message carrying ``content``."""
        return {"role": "user", "content": content}

    def form_sys_role(self, content):
        """Return a ``system`` chat message carrying ``content``."""
        return {"role": "system", "content": content}

    def form_assistant_role(self, content):
        """Return an ``assistant`` chat message carrying ``content``."""
        return {"role": "assistant", "content": content}

    def complete_args(self):
        """Return the keyword arguments for a completion request.

        Starts from the defaults ``temperature=0.01`` and ``top_p=0.7`` and
        layers anything stored by ``set_complete_args`` on top. Previously the
        stored overrides were never read, so ``set_complete_args`` had no effect.
        A fresh dict is returned on every call.
        """
        return {"temperature": 0.01, "top_p": 0.7, **self._complete_args}

    def chat_complete(self, messages):
        """Send ``messages`` to the configured endpoint and return the reply text.

        Args:
            messages: List of chat message dicts (``role`` / ``content``).

        Returns:
            ``completion.choices[0].message.content`` of the response.
        """
        endpoint_id = self.conf["ENDPOINT_ID"]
        completion = self.client.chat.completions.create(
            model=endpoint_id,
            messages=messages,
            # Must be *called*: unpacking the bound method itself with ``**``
            # raises TypeError because a method is not a mapping.
            **self.complete_args()
        )
        logger.info(f"complete doubao task, id: {completion.id}")
        return completion.choices[0].message.content

    def prd_to_keypoint(self, prd_content):
        """Run ``prd_content`` through the model with the ``PRD2KP_SYS`` system prompt."""
        messages = [
            self.form_sys_role(PRD2KP_SYS),
            self.form_user_role(prd_content),
        ]
        return self.chat_complete(messages)

    def prd_to_cases(self, prd_content, case_language="Chinese"):
        """Run ``prd_content`` with the ``PRD_CASE_SYS[case_language]`` system prompt."""
        messages = [
            self.form_sys_role(PRD_CASE_SYS[case_language]),
            self.form_user_role(prd_content),
        ]
        return self.chat_complete(messages)

    def keypoint_to_case(self, key_points):
        """Run ``key_points`` through the model with the ``KP2CASE_SYS`` system prompt."""
        messages = [
            self.form_sys_role(KP2CASE_SYS),
            self.form_user_role(key_points),
        ]
        return self.chat_complete(messages)

    def case_merge_together(self, case_suits):
        """Send several case suites to the model under the ``CASE_AGG_SYS`` prompt.

        Each suite is JSON-serialised (indent 4, non-ASCII kept) and placed in
        its own fenced ``json`` section with a numbered heading; the sections
        are concatenated into a single user message.

        Args:
            case_suits: Iterable of JSON-serialisable case suites.

        Returns:
            The model's reply text.
        """
        sections = []
        for i, case_suit in enumerate(case_suits):
            case_suit_expr = json.dumps(case_suit, indent=4, ensure_ascii=False)
            sections.append(
                f"来自初级测试工程师{i + 1}的测试用例:\n```json\n{case_suit_expr}\n```\n"
            )
        messages = [
            self.form_sys_role(CASE_AGG_SYS),
            self.form_user_role("".join(sections)),
        ]
        return self.chat_complete(messages)

    def cycle_more_case(self, prd_content, case_language="Chinese"):
        """Hold a multi-round conversation and collect the reply of every round.

        The conversation opens with the ``PRD_CASE_SYS`` system prompt and a
        user message that wraps ``prd_content`` between ``PRD_CASE_1`` and
        ``PRD_CASE_2``. For each entry of ``MORE_CASE_PROMPT[case_language]``
        one request is made; the reply is appended to the history as an
        assistant message so later rounds see it.

        Args:
            prd_content: Document text embedded in the opening user message.
            case_language: Key used to select the language-specific prompts.

        Returns:
            List of reply texts, one per round, in order.
        """
        opening = PRD_CASE_1[case_language] + prd_content + "\n" + PRD_CASE_2[case_language]
        messages = [
            self.form_sys_role(PRD_CASE_SYS[case_language]),
            self.form_user_role(opening),
        ]
        result = []
        # ``follow_up`` instead of ``sys`` to avoid shadowing the stdlib module name.
        for follow_up in MORE_CASE_PROMPT[case_language]:
            # A falsy entry means "ask with the history as it stands".
            # NOTE(review): assumes only this append is guarded by the check and
            # the request below runs every round - confirm against the original layout.
            if follow_up:
                messages.append(self.form_user_role(follow_up))
            reply = self.chat_complete(messages)
            result.append(reply)
            messages.append(self.form_assistant_role(reply))
            # Fixed 10 s pause between rounds.
            time.sleep(10)
        return result
if __name__ == "__main__":
    # Manual smoke test: one short chat round against the "DouBao128Pro" config.
    # Other ad-hoc checks used during development were get_api_key() and
    # cycle_more_case(prd_content, "English") on a downloaded PRD document.
    service = DouBaoService("DouBao128Pro")
    demo_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Introduce LLM shortly."},
    ]
    answer = service.chat_complete(messages=demo_messages)
    print(answer)