diff --git a/bigframes/ml/llm.py b/bigframes/ml/llm.py index 45634423c6..35bcf0a33c 100644 --- a/bigframes/ml/llm.py +++ b/bigframes/ml/llm.py @@ -61,6 +61,17 @@ _GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT, ) +_CLAUDE_3_SONNET_ENDPOINT = "claude-3-sonnet" +_CLAUDE_3_HAIKU_ENDPOINT = "claude-3-haiku" +_CLAUDE_3_5_SONNET_ENDPOINT = "claude-3-5-sonnet" +_CLAUDE_3_OPUS_ENDPOINT = "claude-3-opus" +_CLAUDE_3_ENDPOINTS = ( + _CLAUDE_3_SONNET_ENDPOINT, + _CLAUDE_3_HAIKU_ENDPOINT, + _CLAUDE_3_5_SONNET_ENDPOINT, + _CLAUDE_3_OPUS_ENDPOINT, +) + _ML_GENERATE_TEXT_STATUS = "ml_generate_text_status" _ML_EMBED_TEXT_STATUS = "ml_embed_text_status" @@ -1020,3 +1031,225 @@ def to_gbq(self, model_name: str, replace: bool = False) -> GeminiTextGenerator: new_model = self._bqml_model.copy(model_name, replace) return new_model.session.read_gbq_model(model_name) + + +@log_adapter.class_logger +class Claude3TextGenerator(base.BaseEstimator): + """Claude3 text generator LLM model. + + Go to Google Cloud Console -> Vertex AI -> Model Garden page to enabe the models before use. Must have the Consumer Procurement Entitlement Manager Identity and Access Management (IAM) role to enable the models. + https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-partner-models#grant-permissions + + .. note:: + + This product or feature is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the + Service Specific Terms(https://cloud.google.com/terms/service-terms#1). Pre-GA products and features are available "as is" + and might have limited support. For more information, see the launch stage descriptions + (https://cloud.google.com/products#product-launch-stages). + + + .. note:: + + The models only availabe in specific regions. Check https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude#regions for details. + + Args: + model_name (str, Default to "claude-3-sonnet"): + The model for natural language tasks. Possible values are "claude-3-sonnet", "claude-3-haiku", "claude-3-5-sonnet" and "claude-3-opus". + "claude-3-sonnet" is Anthropic's dependable combination of skills and speed. It is engineered to be dependable for scaled AI deployments across a variety of use cases. + "claude-3-haiku" is Anthropic's fastest, most compact vision and text model for near-instant responses to simple queries, meant for seamless AI experiences mimicking human interactions. + "claude-3-5-sonnet" is Anthropic's most powerful AI model and maintains the speed and cost of Claude 3 Sonnet, which is a mid-tier model. + "claude-3-opus" is Anthropic's second-most powerful AI model, with strong performance on highly complex tasks. + https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude#available-claude-models + Default to "claude-3-sonnet". + session (bigframes.Session or None): + BQ session to create the model. If None, use the global default session. + connection_name (str or None): + Connection to connect with remote service. str of the format ... + If None, use default connection in session context. BigQuery DataFrame will try to create the connection and attach + permission if the connection isn't fully set up. + """ + + def __init__( + self, + *, + model_name: Literal[ + "claude-3-sonnet", "claude-3-haiku", "claude-3-5-sonnet", "claude-3-opus" + ] = "claude-3-sonnet", + session: Optional[bigframes.Session] = None, + connection_name: Optional[str] = None, + ): + self.model_name = model_name + self.session = session or bpd.get_global_session() + self._bq_connection_manager = self.session.bqconnectionmanager + + connection_name = connection_name or self.session._bq_connection + self.connection_name = clients.resolve_full_bq_connection_name( + connection_name, + default_project=self.session._project, + default_location=self.session._location, + ) + + self._bqml_model_factory = globals.bqml_model_factory() + self._bqml_model: core.BqmlModel = self._create_bqml_model() + + def _create_bqml_model(self): + # Parse and create connection if needed. + if not self.connection_name: + raise ValueError( + "Must provide connection_name, either in constructor or through session options." + ) + + if self._bq_connection_manager: + connection_name_parts = self.connection_name.split(".") + if len(connection_name_parts) != 3: + raise ValueError( + f"connection_name must be of the format .., got {self.connection_name}." + ) + self._bq_connection_manager.create_bq_connection( + project_id=connection_name_parts[0], + location=connection_name_parts[1], + connection_id=connection_name_parts[2], + iam_role="aiplatform.user", + ) + + if self.model_name not in _CLAUDE_3_ENDPOINTS: + raise ValueError( + f"Model name {self.model_name} is not supported. We only support {', '.join(_CLAUDE_3_ENDPOINTS)}." + ) + + options = { + "endpoint": self.model_name, + } + + return self._bqml_model_factory.create_remote_model( + session=self.session, connection_name=self.connection_name, options=options + ) + + @classmethod + def _from_bq( + cls, session: bigframes.Session, bq_model: bigquery.Model + ) -> Claude3TextGenerator: + assert bq_model.model_type == "MODEL_TYPE_UNSPECIFIED" + assert "remoteModelInfo" in bq_model._properties + assert "endpoint" in bq_model._properties["remoteModelInfo"] + assert "connection" in bq_model._properties["remoteModelInfo"] + + # Parse the remote model endpoint + bqml_endpoint = bq_model._properties["remoteModelInfo"]["endpoint"] + model_connection = bq_model._properties["remoteModelInfo"]["connection"] + model_endpoint = bqml_endpoint.split("/")[-1] + + kwargs = utils.retrieve_params_from_bq_model( + cls, bq_model, _BQML_PARAMS_MAPPING + ) + + model = cls( + **kwargs, + session=session, + model_name=model_endpoint, + connection_name=model_connection, + ) + model._bqml_model = core.BqmlModel(session, bq_model) + return model + + @property + def _bqml_options(self) -> dict: + """The model options as they will be set for BQML""" + options = { + "data_split_method": "NO_SPLIT", + } + return options + + def predict( + self, + X: Union[bpd.DataFrame, bpd.Series], + *, + max_output_tokens: int = 128, + top_k: int = 40, + top_p: float = 0.95, + ) -> bpd.DataFrame: + """Predict the result from input DataFrame. + + Args: + X (bigframes.dataframe.DataFrame or bigframes.series.Series): + Input DataFrame or Series, which contains only one column of prompts. + Prompts can include preamble, questions, suggestions, instructions, or examples. + + max_output_tokens (int, default 128): + Maximum number of tokens that can be generated in the response. Specify a lower value for shorter responses and a higher value for longer responses. + A token may be smaller than a word. A token is approximately four characters. 100 tokens correspond to roughly 60-80 words. + Default 128. Possible values are in the range [1, 4096]. + + top_k (int, default 40): + Top-k changes how the model selects tokens for output. A top-k of 1 means the selected token is the most probable among all tokens + in the model's vocabulary (also called greedy decoding), while a top-k of 3 means that the next token is selected from among the 3 most probable tokens (using temperature). + For each token selection step, the top K tokens with the highest probabilities are sampled. Then tokens are further filtered based on topP with the final token selected using temperature sampling. + Specify a lower value for less random responses and a higher value for more random responses. + Default 40. Possible values [1, 40]. + + top_p (float, default 0.95):: + Top-p changes how the model selects tokens for output. Tokens are selected from most K (see topK parameter) probable to least until the sum of their probabilities equals the top-p value. + For example, if tokens A, B, and C have a probability of 0.3, 0.2, and 0.1 and the top-p value is 0.5, then the model will select either A or B as the next token (using temperature) + and not consider C at all. + Specify a lower value for less random responses and a higher value for more random responses. + Default 0.95. Possible values [0.0, 1.0]. + + + Returns: + bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values. + """ + + # Params reference: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models + if max_output_tokens not in range(1, 4097): + raise ValueError( + f"max_output_token must be [1, 4096], but is {max_output_tokens}." + ) + + if top_k not in range(1, 41): + raise ValueError(f"top_k must be [1, 40], but is {top_k}.") + + if top_p < 0.0 or top_p > 1.0: + raise ValueError(f"top_p must be [0.0, 1.0], but is {top_p}.") + + (X,) = utils.convert_to_dataframe(X) + + if len(X.columns) != 1: + raise ValueError( + f"Only support one column as input. {constants.FEEDBACK_LINK}" + ) + + # BQML identified the column by name + col_label = cast(blocks.Label, X.columns[0]) + X = X.rename(columns={col_label: "prompt"}) + + options = { + "max_output_tokens": max_output_tokens, + "top_k": top_k, + "top_p": top_p, + "flatten_json_output": True, + } + + df = self._bqml_model.generate_text(X, options) + + if (df[_ML_GENERATE_TEXT_STATUS] != "").any(): + warnings.warn( + f"Some predictions failed. Check column {_ML_GENERATE_TEXT_STATUS} for detailed status. You may want to filter the failed rows and retry.", + RuntimeWarning, + ) + + return df + + def to_gbq(self, model_name: str, replace: bool = False) -> Claude3TextGenerator: + """Save the model to BigQuery. + + Args: + model_name (str): + The name of the model. + replace (bool, default False): + Determine whether to replace if the model already exists. Default to False. + + Returns: + Claude3TextGenerator: Saved model.""" + + new_model = self._bqml_model.copy(model_name, replace) + return new_model.session.read_gbq_model(model_name) diff --git a/bigframes/ml/loader.py b/bigframes/ml/loader.py index bd01342152..7d75f4c65a 100644 --- a/bigframes/ml/loader.py +++ b/bigframes/ml/loader.py @@ -63,6 +63,10 @@ llm._GEMINI_PRO_ENDPOINT: llm.GeminiTextGenerator, llm._GEMINI_1P5_PRO_PREVIEW_ENDPOINT: llm.GeminiTextGenerator, llm._GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT: llm.GeminiTextGenerator, + llm._CLAUDE_3_HAIKU_ENDPOINT: llm.Claude3TextGenerator, + llm._CLAUDE_3_SONNET_ENDPOINT: llm.Claude3TextGenerator, + llm._CLAUDE_3_5_SONNET_ENDPOINT: llm.Claude3TextGenerator, + llm._CLAUDE_3_OPUS_ENDPOINT: llm.Claude3TextGenerator, llm._TEXT_EMBEDDING_004_ENDPOINT: llm.TextEmbeddingGenerator, llm._TEXT_MULTILINGUAL_EMBEDDING_002_ENDPOINT: llm.TextEmbeddingGenerator, } @@ -86,6 +90,7 @@ def from_bq( imported.XGBoostModel, llm.PaLM2TextGenerator, llm.PaLM2TextEmbeddingGenerator, + llm.Claude3TextGenerator, llm.TextEmbeddingGenerator, pipeline.Pipeline, compose.ColumnTransformer, diff --git a/docs/templates/toc.yml b/docs/templates/toc.yml index 736ffba286..bab4ad9aac 100644 --- a/docs/templates/toc.yml +++ b/docs/templates/toc.yml @@ -157,6 +157,8 @@ uid: bigframes.ml.llm.PaLM2TextGenerator - name: PaLM2TextEmbeddingGenerator uid: bigframes.ml.llm.PaLM2TextEmbeddingGenerator + - name: Claude3TextGenerator + uid: bigframes.ml.llm.Claude3TextGenerator name: llm - items: - name: metrics diff --git a/tests/system/conftest.py b/tests/system/conftest.py index 83c8baac39..05ff80dc33 100644 --- a/tests/system/conftest.py +++ b/tests/system/conftest.py @@ -145,6 +145,16 @@ def session() -> Generator[bigframes.Session, None, None]: session.close() # close generated session at cleanup time +@pytest.fixture(scope="session") +def session_us_east5() -> Generator[bigframes.Session, None, None]: + context = bigframes.BigQueryOptions( + location="us-east5", + ) + session = bigframes.Session(context=context) + yield session + session.close() # close generated session at cleanup time + + @pytest.fixture(scope="session") def session_load() -> Generator[bigframes.Session, None, None]: context = bigframes.BigQueryOptions(location="US", project="bigframes-load-testing") diff --git a/tests/system/small/ml/test_llm.py b/tests/system/small/ml/test_llm.py index c2f62096d0..1647eb879f 100644 --- a/tests/system/small/ml/test_llm.py +++ b/tests/system/small/ml/test_llm.py @@ -18,7 +18,7 @@ from tests.system import utils -def test_create_text_generator_model( +def test_create_load_text_generator_model( palm2_text_generator_model, dataset_id, bq_connection ): # Model creation doesn't return error @@ -34,7 +34,7 @@ def test_create_text_generator_model( assert reloaded_model.connection_name == bq_connection -def test_create_text_generator_32k_model( +def test_create_load_text_generator_32k_model( palm2_text_generator_32k_model, dataset_id, bq_connection ): # Model creation doesn't return error @@ -405,6 +405,67 @@ def test_gemini_text_generator_predict_with_params_success( assert all(series.str.len() > 20) +# TODO(garrettwu): add tests for claude3.5 sonnet and claude3 opus as they are only available in other regions. +@pytest.mark.parametrize( + "model_name", + ("claude-3-sonnet", "claude-3-haiku"), +) +def test_claude3_text_generator_create_load( + dataset_id, model_name, session, bq_connection +): + claude3_text_generator_model = llm.Claude3TextGenerator( + model_name=model_name, connection_name=bq_connection, session=session + ) + assert claude3_text_generator_model is not None + assert claude3_text_generator_model._bqml_model is not None + + # save, load to ensure configuration was kept + reloaded_model = claude3_text_generator_model.to_gbq( + f"{dataset_id}.temp_text_model", replace=True + ) + assert f"{dataset_id}.temp_text_model" == reloaded_model._bqml_model.model_name + assert reloaded_model.connection_name == bq_connection + assert reloaded_model.model_name == model_name + + +@pytest.mark.parametrize( + "model_name", + ("claude-3-sonnet", "claude-3-haiku"), +) +@pytest.mark.flaky(retries=2) +def test_claude3_text_generator_predict_default_params_success( + llm_text_df, model_name, session, bq_connection +): + claude3_text_generator_model = llm.Claude3TextGenerator( + model_name=model_name, connection_name=bq_connection, session=session + ) + df = claude3_text_generator_model.predict(llm_text_df).to_pandas() + assert df.shape == (3, 3) + assert "ml_generate_text_llm_result" in df.columns + series = df["ml_generate_text_llm_result"] + assert all(series.str.len() > 20) + + +@pytest.mark.parametrize( + "model_name", + ("claude-3-sonnet", "claude-3-haiku"), +) +@pytest.mark.flaky(retries=2) +def test_claude3_text_generator_predict_with_params_success( + llm_text_df, model_name, session, bq_connection +): + claude3_text_generator_model = llm.Claude3TextGenerator( + model_name=model_name, connection_name=bq_connection, session=session + ) + df = claude3_text_generator_model.predict( + llm_text_df, max_output_tokens=100, top_k=20, top_p=0.5 + ).to_pandas() + assert df.shape == (3, 3) + assert "ml_generate_text_llm_result" in df.columns + series = df["ml_generate_text_llm_result"] + assert all(series.str.len() > 20) + + @pytest.mark.flaky(retries=2) def test_llm_palm_score(llm_fine_tune_df_default_index): model = llm.PaLM2TextGenerator(model_name="text-bison")