-
Notifications
You must be signed in to change notification settings - Fork 68
feat: add ml.llm.Claude3TextGenerator model #901
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -61,6 +61,17 @@ | |
| _GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT, | ||
| ) | ||
|
|
||
# Endpoint names of the individual Claude 3 models. Each value is passed
# verbatim as the "endpoint" option when the remote model is created.
_CLAUDE_3_SONNET_ENDPOINT = "claude-3-sonnet"
_CLAUDE_3_HAIKU_ENDPOINT = "claude-3-haiku"
_CLAUDE_3_5_SONNET_ENDPOINT = "claude-3-5-sonnet"
_CLAUDE_3_OPUS_ENDPOINT = "claude-3-opus"

# Every endpoint accepted for Claude 3 text generation; a requested model name
# is validated against this tuple.
_CLAUDE_3_ENDPOINTS = (
    _CLAUDE_3_SONNET_ENDPOINT,
    _CLAUDE_3_HAIKU_ENDPOINT,
    _CLAUDE_3_5_SONNET_ENDPOINT,
    _CLAUDE_3_OPUS_ENDPOINT,
)
|
|
||
|
|
||
| _ML_GENERATE_TEXT_STATUS = "ml_generate_text_status" | ||
| _ML_EMBED_TEXT_STATUS = "ml_embed_text_status" | ||
|
|
@@ -1020,3 +1031,225 @@ def to_gbq(self, model_name: str, replace: bool = False) -> GeminiTextGenerator: | |
|
|
||
| new_model = self._bqml_model.copy(model_name, replace) | ||
| return new_model.session.read_gbq_model(model_name) | ||
|
|
||
|
|
||
| @log_adapter.class_logger | ||
| class Claude3TextGenerator(base.BaseEstimator): | ||
| """Claude3 text generator LLM model. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looks like "Consumer Procurement Entitlement Manager Identity and Access Management (IAM) role" is an additional requirement https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-partner-models#set-permissions, we should document this in the class docstring and after the release in the reference docs https://cloud.google.com/bigquery/docs/use-bigquery-dataframes#remote-models
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Updated. |
||
|
|
||
| Go to Google Cloud Console -> Vertex AI -> Model Garden page to enable the models before use. Must have the Consumer Procurement Entitlement Manager Identity and Access Management (IAM) role to enable the models. | ||
| https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-partner-models#grant-permissions | ||
|
|
||
| .. note:: | ||
|
|
||
| This product or feature is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the | ||
| Service Specific Terms(https://cloud.google.com/terms/service-terms#1). Pre-GA products and features are available "as is" | ||
| and might have limited support. For more information, see the launch stage descriptions | ||
| (https://cloud.google.com/products#product-launch-stages). | ||
|
|
||
|
|
||
| .. note:: | ||
|
|
||
| The models are only available in specific regions. Check https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude#regions for details. | ||
|
|
||
| Args: | ||
| model_name (str, Default to "claude-3-sonnet"): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [nit] "Defaults to ..."?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We use "Default to" in all the APIs... Keeping it for now. |
||
| The model for natural language tasks. Possible values are "claude-3-sonnet", "claude-3-haiku", "claude-3-5-sonnet" and "claude-3-opus". | ||
| "claude-3-sonnet" is Anthropic's dependable combination of skills and speed. It is engineered to be dependable for scaled AI deployments across a variety of use cases. | ||
| "claude-3-haiku" is Anthropic's fastest, most compact vision and text model for near-instant responses to simple queries, meant for seamless AI experiences mimicking human interactions. | ||
| "claude-3-5-sonnet" is Anthropic's most powerful AI model and maintains the speed and cost of Claude 3 Sonnet, which is a mid-tier model. | ||
| "claude-3-opus" is Anthropic's second-most powerful AI model, with strong performance on highly complex tasks. | ||
| https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude#available-claude-models | ||
| Default to "claude-3-sonnet". | ||
| session (bigframes.Session or None): | ||
| BQ session to create the model. If None, use the global default session. | ||
| connection_name (str or None): | ||
| Connection to connect with remote service. str of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>. | ||
| If None, use default connection in session context. BigQuery DataFrame will try to create the connection and attach | ||
| permission if the connection isn't fully set up. | ||
| """ | ||
|
|
||
def __init__(
    self,
    *,
    model_name: Literal[
        "claude-3-sonnet", "claude-3-haiku", "claude-3-5-sonnet", "claude-3-opus"
    ] = "claude-3-sonnet",
    session: Optional[bigframes.Session] = None,
    connection_name: Optional[str] = None,
):
    """Configure the estimator and create its backing BQML remote model.

    Args:
        model_name: Claude 3 endpoint to use.
        session: Session used to create the model; the global session is
            used when this is None (or otherwise falsy).
        connection_name: Connection to the remote service; the session's
            own connection is used when this is None or empty.
    """
    self.model_name = model_name

    active_session = session or bpd.get_global_session()
    self.session = active_session
    self._bq_connection_manager = active_session.bqconnectionmanager

    # An explicitly passed connection wins over the session-level one. The
    # chosen name is then resolved against the session's project/location.
    requested_connection = connection_name or active_session._bq_connection
    self.connection_name = clients.resolve_full_bq_connection_name(
        requested_connection,
        default_project=active_session._project,
        default_location=active_session._location,
    )

    self._bqml_model_factory = globals.bqml_model_factory()
    self._bqml_model: core.BqmlModel = self._create_bqml_model()
|
|
||
def _create_bqml_model(self):
    """Validate the configuration and create the BQML remote model.

    All checks that need no remote call run first, so an invalid
    configuration is rejected before any connection is created.

    Returns:
        The object returned by the factory's ``create_remote_model``.

    Raises:
        ValueError: If no connection name is available, if ``model_name`` is
            not one of the supported Claude 3 endpoints, or if the connection
            name does not have exactly three dot-separated parts.
    """
    if not self.connection_name:
        raise ValueError(
            "Must provide connection_name, either in constructor or through session options."
        )

    # Reject an unsupported model before touching the connection: creating the
    # connection is a side effect that is pointless if the model is invalid.
    if self.model_name not in _CLAUDE_3_ENDPOINTS:
        raise ValueError(
            f"Model name {self.model_name} is not supported. We only support {', '.join(_CLAUDE_3_ENDPOINTS)}."
        )

    # Parse and create connection if needed.
    if self._bq_connection_manager:
        connection_name_parts = self.connection_name.split(".")
        if len(connection_name_parts) != 3:
            raise ValueError(
                f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."
            )
        self._bq_connection_manager.create_bq_connection(
            project_id=connection_name_parts[0],
            location=connection_name_parts[1],
            connection_id=connection_name_parts[2],
            iam_role="aiplatform.user",
        )

    options = {
        "endpoint": self.model_name,
    }

    return self._bqml_model_factory.create_remote_model(
        session=self.session, connection_name=self.connection_name, options=options
    )
|
|
||
@classmethod
def _from_bq(
    cls, session: bigframes.Session, bq_model: bigquery.Model
) -> Claude3TextGenerator:
    """Rebuild an estimator from an existing BigQuery remote model.

    Args:
        session: Session the restored estimator is bound to.
        bq_model: The BigQuery model; it must carry ``remoteModelInfo`` with
            both an ``endpoint`` and a ``connection`` entry.

    Returns:
        A new instance whose underlying BQML model wraps ``bq_model``.
    """
    assert bq_model.model_type == "MODEL_TYPE_UNSPECIFIED"
    assert "remoteModelInfo" in bq_model._properties
    remote_info = bq_model._properties["remoteModelInfo"]
    assert "endpoint" in remote_info
    assert "connection" in remote_info

    # Only the text after the last "/" of the stored endpoint is used as the
    # model name.
    endpoint_path = remote_info["endpoint"]
    connection = remote_info["connection"]
    endpoint_name = endpoint_path.split("/")[-1]

    init_params = utils.retrieve_params_from_bq_model(
        cls, bq_model, _BQML_PARAMS_MAPPING
    )

    restored = cls(
        **init_params,
        session=session,
        model_name=endpoint_name,
        connection_name=connection,
    )
    # Replace the model created by the constructor with a wrapper around the
    # model that was passed in.
    restored._bqml_model = core.BqmlModel(session, bq_model)
    return restored
|
|
||
@property
def _bqml_options(self) -> dict:
    """Return the model options in the form they are set for BQML."""
    return {"data_split_method": "NO_SPLIT"}
|
|
||
def predict(
    self,
    X: Union[bpd.DataFrame, bpd.Series],
    *,
    max_output_tokens: int = 128,
    top_k: int = 40,
    top_p: float = 0.95,
) -> bpd.DataFrame:
    """Predict the result from input DataFrame.

    Args:
        X (bigframes.dataframe.DataFrame or bigframes.series.Series):
            Input DataFrame or Series, which contains only one column of prompts.
            Prompts can include preamble, questions, suggestions, instructions, or examples.

        max_output_tokens (int, default 128):
            Maximum number of tokens that can be generated in the response. Specify a lower value for shorter responses and a higher value for longer responses.
            A token may be smaller than a word. A token is approximately four characters. 100 tokens correspond to roughly 60-80 words.
            Default 128. Possible values are in the range [1, 4096].

        top_k (int, default 40):
            Top-k changes how the model selects tokens for output. A top-k of 1 means the selected token is the most probable among all tokens
            in the model's vocabulary (also called greedy decoding), while a top-k of 3 means that the next token is selected from among the 3 most probable tokens (using temperature).
            For each token selection step, the top K tokens with the highest probabilities are sampled. Then tokens are further filtered based on topP with the final token selected using temperature sampling.
            Specify a lower value for less random responses and a higher value for more random responses.
            Default 40. Possible values [1, 40].

        top_p (float, default 0.95):
            Top-p changes how the model selects tokens for output. Tokens are selected from most K (see topK parameter) probable to least until the sum of their probabilities equals the top-p value.
            For example, if tokens A, B, and C have a probability of 0.3, 0.2, and 0.1 and the top-p value is 0.5, then the model will select either A or B as the next token (using temperature)
            and not consider C at all.
            Specify a lower value for less random responses and a higher value for more random responses.
            Default 0.95. Possible values [0.0, 1.0].

    Returns:
        bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.

    Raises:
        ValueError: If ``max_output_tokens``, ``top_k`` or ``top_p`` is outside
            its documented range, or if ``X`` has more than one column.
    """

    # Params reference: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models
    # All arguments are validated up front so a bad value fails before any
    # work is done on X.
    if max_output_tokens not in range(1, 4097):
        raise ValueError(
            f"max_output_tokens must be [1, 4096], but is {max_output_tokens}."
        )

    if top_k not in range(1, 41):
        raise ValueError(f"top_k must be [1, 40], but is {top_k}.")

    if top_p < 0.0 or top_p > 1.0:
        raise ValueError(f"top_p must be [0.0, 1.0], but is {top_p}.")

    (X,) = utils.convert_to_dataframe(X)

    if len(X.columns) != 1:
        raise ValueError(
            f"Only support one column as input. {constants.FEEDBACK_LINK}"
        )

    # BQML identified the column by name
    col_label = cast(blocks.Label, X.columns[0])
    X = X.rename(columns={col_label: "prompt"})

    options = {
        "max_output_tokens": max_output_tokens,
        "top_k": top_k,
        "top_p": top_p,
        "flatten_json_output": True,
    }

    df = self._bqml_model.generate_text(X, options)

    # A non-empty status marks a failed row; warn instead of raising so the
    # rows that succeeded are still returned to the caller.
    if (df[_ML_GENERATE_TEXT_STATUS] != "").any():
        warnings.warn(
            f"Some predictions failed. Check column {_ML_GENERATE_TEXT_STATUS} for detailed status. You may want to filter the failed rows and retry.",
            RuntimeWarning,
        )

    return df
|
|
||
def to_gbq(self, model_name: str, replace: bool = False) -> Claude3TextGenerator:
    """Save the model to BigQuery.

    Args:
        model_name (str):
            The name of the model.
        replace (bool, default False):
            Determine whether to replace if the model already exists. Default to False.

    Returns:
        Claude3TextGenerator: Saved model.
    """
    # Copy the underlying BQML model under the new name, then load it back
    # through the copy's own session.
    saved_copy = self._bqml_model.copy(model_name, replace)
    copy_session = saved_copy.session
    return copy_session.read_gbq_model(model_name)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -145,6 +145,16 @@ def session() -> Generator[bigframes.Session, None, None]: | |
| session.close() # close generated session at cleanup time | ||
|
|
||
|
|
||
| @pytest.fixture(scope="session") | ||
| def session_us_east5() -> Generator[bigframes.Session, None, None]: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did we use it anywhere?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. forgot to remove in this PR, but will use in PR of adding tests. |
||
| context = bigframes.BigQueryOptions( | ||
| location="us-east5", | ||
| ) | ||
| session = bigframes.Session(context=context) | ||
| yield session | ||
| session.close() # close generated session at cleanup time | ||
|
|
||
|
|
||
| @pytest.fixture(scope="session") | ||
| def session_load() -> Generator[bigframes.Session, None, None]: | ||
| context = bigframes.BigQueryOptions(location="US", project="bigframes-load-testing") | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add all BQML-supported models?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. Leaving tests to be added.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, I thought you were going to add tests in this PR itself. Are you going to send another PR for this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. need to setup the connection for other regions.