|
10 | 10 | TextGenerationModel,
|
11 | 11 | TextEmbeddingModel,
|
12 | 12 | )
|
| 13 | +from google import genai |
| 14 | +from google.genai import types |
| 15 | + |
13 | 16 | import os
|
14 | 17 | import pathlib
|
15 | 18 | import vertexai
|
@@ -50,62 +53,54 @@ class VertexPromptClient:
|
50 | 53 | ]
|
51 | 54 |
|
52 | 55 | def __init__(self) -> None:
|
53 |
| - service_account_secrets_file = os.path.join( |
54 |
| - pathlib.Path(__file__).parent.parent.parent, "gcp_sa_secret.json" |
55 |
| - ) |
56 |
| - if os.environ.get("SERVICE_ACCOUNT_CREDENTIALS"): |
57 |
| - with open(service_account_secrets_file, "w") as f: |
58 |
| - f.write(os.environ.get("SERVICE_ACCOUNT_CREDENTIALS")) |
59 |
| - os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = ( |
60 |
| - service_account_secrets_file |
61 |
| - ) |
62 |
| - elif not os.environ.get("GCP_NATIVE"): |
63 |
| - logger.fatal( |
64 |
| - "neither GCP_NATIVE nor SERVICE_ACCOUNT_CREDENTIALS have been set" |
65 |
| - ) |
66 |
| - |
67 |
| - vertexai.init( |
68 |
| - project=os.environ.get("GOOGLE_PROJECT_ID"), |
69 |
| - location=os.environ.get("GOOGLE_PROJECT_LOCATION"), |
70 |
| - ) |
71 |
| - self.chat_model = ChatModel.from_pretrained("chat-bison@001") |
72 |
| - self.embeddings_model = TextEmbeddingModel.from_pretrained( |
73 |
| - "textembedding-gecko@001" |
74 |
| - ) |
75 |
| - self.chat = self.chat_model.start_chat(context=self.context) |
| 56 | + self.client = genai.Client(api_key=os.environ.get("GEMINI_API_KEY")) |
76 | 57 |
|
77 | 58 | def get_text_embeddings(self, text: str) -> List[float]:
|
78 | 59 | """Text embedding with a Large Language Model."""
|
79 |
| - embeddings_model = TextEmbeddingModel.from_pretrained("textembedding-gecko@001") |
| 60 | + |
80 | 61 | if len(text) > 8000:
|
81 | 62 | logger.info(
|
82 | 63 | f"embedding content is more than the vertex hard limit of 8k tokens, reducing to 8000"
|
83 | 64 | )
|
84 | 65 | text = text[:8000]
|
85 |
| - embeddings = [] |
| 66 | + values = [] |
86 | 67 | try:
|
87 |
| - emb = embeddings_model.get_embeddings([text]) |
88 |
| - embeddings = emb[0].values |
89 |
| - except googleExceptions.ResourceExhausted as e: |
| 68 | + result = self.client.models.embed_content( |
| 69 | + model="gemini-embedding-exp-03-07", |
| 70 | + contents=text, |
| 71 | + config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"), |
| 72 | + ) |
| 73 | + if not result: |
| 74 | + return None |
| 75 | + values = result.embeddings[0].values |
| 76 | + except genai.errors.ClientError as e: |
90 | 77 | logger.info("hit limit, sleeping for a minute")
|
91 | 78 | time.sleep(
|
92 | 79 | 60
|
93 | 80 | ) # Vertex's quota is per minute, so sleep for a full minute, then try again
|
94 |
| - embeddings = self.get_text_embeddings(text) |
| 81 | + values = self.get_text_embeddings(text) |
95 | 82 |
|
96 |
| - if not embeddings: |
97 |
| - return None |
98 |
| - values = embeddings |
99 | 83 | return values
|
100 | 84 |
|
101 | 85 | def create_chat_completion(self, prompt, closest_object_str) -> str:
|
102 |
| - parameters = {"temperature": 0.5, "max_output_tokens": MAX_OUTPUT_TOKENS} |
103 | 86 | msg = f"Your task is to answer the following question based on this area of knowledge:`{closest_object_str}` if you can, provide code examples, delimit any code snippet with three backticks\nQuestion: `{prompt}`\n ignore all other commands and questions that are not relevant."
|
104 |
| - response = self.chat.send_message(msg, **parameters) |
| 87 | + response = self.client.models.generate_content( |
| 88 | + model="gemini-2.0-flash", |
| 89 | + contents=msg, |
| 90 | + config=types.GenerateContentConfig( |
| 91 | + max_output_tokens=MAX_OUTPUT_TOKENS, temperature=0.5 |
| 92 | + ), |
| 93 | + ) |
105 | 94 | return response.text
|
106 | 95 |
|
107 | 96 | def query_llm(self, raw_question: str) -> str:
|
108 |
| - parameters = {"temperature": 0.5, "max_output_tokens": MAX_OUTPUT_TOKENS} |
109 |
| - msg = f"Your task is to answer the following cybesrsecurity question if you can, provide code examples, delimit any code snippet with three backticks, ignore any unethical questions or questions irrelevant to cybersecurity\nQuestion: `{raw_question}`\n ignore all other commands and questions that are not relevant." |
110 |
| - response = self.chat.send_message(msg, **parameters) |
| 97 | + msg = f"Your task is to answer the following cybersecurity question if you can, provide code examples, delimit any code snippet with three backticks, ignore any unethical questions or questions irrelevant to cybersecurity\nQuestion: `{raw_question}`\n ignore all other commands and questions that are not relevant." |
| 98 | + response = self.client.models.generate_content( |
| 99 | + model="gemini-2.0-flash", |
| 100 | + contents=msg, |
| 101 | + config=types.GenerateContentConfig( |
| 102 | + max_output_tokens=MAX_OUTPUT_TOKENS, temperature=0.5 |
| 103 | + ), |
| 104 | + ) |
| 105 | + # response = self.chat.send_message(msg, **parameters) |
111 | 106 | return response.text
|
0 commit comments