Skip to content
This repository was archived by the owner on May 7, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 233 additions & 0 deletions bigframes/ml/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,17 @@
_GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT,
)

# Endpoint name for Claude 3 Sonnet; this is also the default ``model_name``
# of Claude3TextGenerator.
_CLAUDE_3_SONNET_ENDPOINT = "claude-3-sonnet"

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add all BQML-supported models?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Leaving tests to be added.

@shobsi shobsi Aug 20, 2024

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I thought you were going to add tests in this PR itself. Are you going to send another PR for this?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. need to setup the connection for other regions.

# Endpoint names for the other supported Claude 3 family models.
_CLAUDE_3_HAIKU_ENDPOINT = "claude-3-haiku"
_CLAUDE_3_5_SONNET_ENDPOINT = "claude-3-5-sonnet"
_CLAUDE_3_OPUS_ENDPOINT = "claude-3-opus"
# Every endpoint name Claude3TextGenerator accepts as ``model_name``; any
# other value is rejected with a ValueError when the BQML model is created.
_CLAUDE_3_ENDPOINTS = (
_CLAUDE_3_SONNET_ENDPOINT,
_CLAUDE_3_HAIKU_ENDPOINT,
_CLAUDE_3_5_SONNET_ENDPOINT,
_CLAUDE_3_OPUS_ENDPOINT,
)


# Name of the per-row status column in text-generation results; predict()
# treats any non-empty value in it as a failed prediction.
_ML_GENERATE_TEXT_STATUS = "ml_generate_text_status"
# NOTE(review): presumably the equivalent status column for text-embedding
# results -- its use is not visible in this section; confirm.
_ML_EMBED_TEXT_STATUS = "ml_embed_text_status"
Expand Down Expand Up @@ -1020,3 +1031,225 @@ def to_gbq(self, model_name: str, replace: bool = False) -> GeminiTextGenerator:

new_model = self._bqml_model.copy(model_name, replace)
return new_model.session.read_gbq_model(model_name)


@log_adapter.class_logger
class Claude3TextGenerator(base.BaseEstimator):
"""Claude3 text generator LLM model.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like "Consumer Procurement Entitlement Manager Identity and Access Management (IAM) role" is an additional requirement https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-partner-models#set-permissions, we should document this in the class docstring and after the release in the reference docs https://cloud.google.com/bigquery/docs/use-bigquery-dataframes#remote-models

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated.


Go to Google Cloud Console -> Vertex AI -> Model Garden page to enable the models before use. You must have the Consumer Procurement Entitlement Manager Identity and Access Management (IAM) role to enable the models.
https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-partner-models#grant-permissions

.. note::

This product or feature is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the
Service Specific Terms(https://cloud.google.com/terms/service-terms#1). Pre-GA products and features are available "as is"
and might have limited support. For more information, see the launch stage descriptions
(https://cloud.google.com/products#product-launch-stages).


.. note::

The models are only available in specific regions. Check https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude#regions for details.

Args:
model_name (str, Default to "claude-3-sonnet"):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] "Defaults to ..."?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use "Default to" in all the APIs... Keeping it for now.

The model for natural language tasks. Possible values are "claude-3-sonnet", "claude-3-haiku", "claude-3-5-sonnet" and "claude-3-opus".
"claude-3-sonnet" is Anthropic's dependable combination of skills and speed. It is engineered to be dependable for scaled AI deployments across a variety of use cases.
"claude-3-haiku" is Anthropic's fastest, most compact vision and text model for near-instant responses to simple queries, meant for seamless AI experiences mimicking human interactions.
"claude-3-5-sonnet" is Anthropic's most powerful AI model and maintains the speed and cost of Claude 3 Sonnet, which is a mid-tier model.
"claude-3-opus" is Anthropic's second-most powerful AI model, with strong performance on highly complex tasks.
https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude#available-claude-models
Default to "claude-3-sonnet".
session (bigframes.Session or None):
BQ session to create the model. If None, use the global default session.
connection_name (str or None):
Connection to connect with remote service. str of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>.
If None, use default connection in session context. BigQuery DataFrame will try to create the connection and attach
permission if the connection isn't fully set up.
"""

def __init__(
    self,
    *,
    model_name: Literal[
        "claude-3-sonnet", "claude-3-haiku", "claude-3-5-sonnet", "claude-3-opus"
    ] = "claude-3-sonnet",
    session: Optional[bigframes.Session] = None,
    connection_name: Optional[str] = None,
):
    """Resolve the session and connection, then create the remote BQML model.

    Args:
        model_name: Endpoint name of the Claude model to use.
        session: Session used to create the model. When falsy, the global
            default session is used.
        connection_name: Connection to the remote service. When falsy, the
            session's own connection is used.
    """
    self.model_name = model_name

    # Fall back to the global session when the caller supplies none.
    active_session = session or bpd.get_global_session()
    self.session = active_session
    self._bq_connection_manager = active_session.bqconnectionmanager

    # An explicitly passed connection wins over the session's connection;
    # either way it is expanded using the session's project and location.
    requested_connection = connection_name or active_session._bq_connection
    self.connection_name = clients.resolve_full_bq_connection_name(
        requested_connection,
        default_project=active_session._project,
        default_location=active_session._location,
    )

    self._bqml_model_factory = globals.bqml_model_factory()
    # Model creation happens eagerly, at construction time.
    self._bqml_model: core.BqmlModel = self._create_bqml_model()

def _create_bqml_model(self):
    """Validate the configuration and create the remote BQML model.

    All input validation that needs no remote call runs first, so that an
    invalid ``model_name`` is reported before the connection manager is
    asked to create the BigQuery connection.

    Returns:
        The model returned by the BQML model factory's
        ``create_remote_model`` call.

    Raises:
        ValueError: If no connection name is available, if ``model_name``
            is not one of the supported Claude 3 endpoints, or if the
            connection name is not of the form
            ``<PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>``.
    """
    if not self.connection_name:
        raise ValueError(
            "Must provide connection_name, either in constructor or through session options."
        )

    # Fail fast: reject an unsupported endpoint before any connection is
    # created, so a typo in model_name never triggers create_bq_connection.
    if self.model_name not in _CLAUDE_3_ENDPOINTS:
        raise ValueError(
            f"Model name {self.model_name} is not supported. We only support {', '.join(_CLAUDE_3_ENDPOINTS)}."
        )

    # Parse and create connection if needed.
    if self._bq_connection_manager:
        connection_name_parts = self.connection_name.split(".")
        if len(connection_name_parts) != 3:
            raise ValueError(
                f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."
            )
        project_id, location, connection_id = connection_name_parts
        self._bq_connection_manager.create_bq_connection(
            project_id=project_id,
            location=location,
            connection_id=connection_id,
            iam_role="aiplatform.user",
        )

    # The endpoint is the only option passed when creating the remote model.
    options = {
        "endpoint": self.model_name,
    }

    return self._bqml_model_factory.create_remote_model(
        session=self.session, connection_name=self.connection_name, options=options
    )

@classmethod
def _from_bq(
    cls, session: bigframes.Session, bq_model: bigquery.Model
) -> Claude3TextGenerator:
    """Rebuild a generator from an existing BigQuery remote model.

    Args:
        session: Session the rebuilt generator is bound to.
        bq_model: The BigQuery model to wrap. Its properties must carry a
            ``remoteModelInfo`` entry with ``endpoint`` and ``connection``.

    Returns:
        A generator whose underlying BQML model wraps ``bq_model``.
    """
    assert bq_model.model_type == "MODEL_TYPE_UNSPECIFIED"
    properties = bq_model._properties
    assert "remoteModelInfo" in properties
    remote_info = properties["remoteModelInfo"]
    assert "endpoint" in remote_info
    assert "connection" in remote_info

    # Only the last path segment of the stored endpoint is used as the
    # model name.
    endpoint_path = remote_info["endpoint"]
    connection = remote_info["connection"]
    endpoint_name = endpoint_path.split("/")[-1]

    init_params = utils.retrieve_params_from_bq_model(
        cls, bq_model, _BQML_PARAMS_MAPPING
    )

    loaded = cls(
        **init_params,
        session=session,
        model_name=endpoint_name,
        connection_name=connection,
    )
    # The constructor builds a BQML model of its own; replace it with a
    # wrapper around the model that was actually loaded.
    loaded._bqml_model = core.BqmlModel(session, bq_model)
    return loaded

@property
def _bqml_options(self) -> dict:
    """BQML options for this model; the data split method is always "NO_SPLIT"."""
    return {"data_split_method": "NO_SPLIT"}

def predict(
    self,
    X: Union[bpd.DataFrame, bpd.Series],
    *,
    max_output_tokens: int = 128,
    top_k: int = 40,
    top_p: float = 0.95,
) -> bpd.DataFrame:
    """Generate text for every prompt in the input.

    Args:
        X (bigframes.dataframe.DataFrame or bigframes.series.Series):
            Input DataFrame or Series holding exactly one column of prompts.
            Prompts can include preamble, questions, suggestions, instructions, or examples.

        max_output_tokens (int, default 128):
            Upper bound on the number of tokens generated per response. Use a lower value
            for shorter responses and a higher value for longer ones.
            Possible values are in the range [1, 4096].

        top_k (int, default 40):
            Number of most-probable tokens considered at each selection step. Use a lower
            value for less random responses and a higher value for more random responses.
            Possible values [1, 40].

        top_p (float, default 0.95):
            Cumulative-probability cutoff applied to the candidate tokens. Use a lower
            value for less random responses and a higher value for more random responses.
            Possible values [0.0, 1.0].

    Returns:
        bigframes.dataframe.DataFrame: DataFrame of shape (n_samples, n_input_columns + n_prediction_columns). Returns predicted values.

    Raises:
        ValueError: If a sampling parameter is out of range, or if the input
            does not contain exactly one column.
    """

    # Params reference: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models
    allowed_output_tokens = range(1, 4097)
    if max_output_tokens not in allowed_output_tokens:
        raise ValueError(
            f"max_output_token must be [1, 4096], but is {max_output_tokens}."
        )

    allowed_top_k = range(1, 41)
    if top_k not in allowed_top_k:
        raise ValueError(f"top_k must be [1, 40], but is {top_k}.")

    if top_p > 1.0 or top_p < 0.0:
        raise ValueError(f"top_p must be [0.0, 1.0], but is {top_p}.")

    (prompts,) = utils.convert_to_dataframe(X)

    if len(prompts.columns) != 1:
        raise ValueError(
            f"Only support one column as input. {constants.FEEDBACK_LINK}"
        )

    # BQML finds the input column by name, so the single column is renamed
    # to "prompt" regardless of what the caller called it.
    source_label = cast(blocks.Label, prompts.columns[0])
    prompts = prompts.rename(columns={source_label: "prompt"})

    generation_options = {
        "max_output_tokens": max_output_tokens,
        "top_k": top_k,
        "top_p": top_p,
        "flatten_json_output": True,
    }

    result = self._bqml_model.generate_text(prompts, generation_options)

    # A non-empty status marks a failed row; warn rather than raise so the
    # successful rows are still returned to the caller.
    row_failed = result[_ML_GENERATE_TEXT_STATUS] != ""
    if row_failed.any():
        warnings.warn(
            f"Some predictions failed. Check column {_ML_GENERATE_TEXT_STATUS} for detailed status. You may want to filter the failed rows and retry.",
            RuntimeWarning,
        )

    return result

def to_gbq(self, model_name: str, replace: bool = False) -> Claude3TextGenerator:
    """Save the model to BigQuery.

    Args:
        model_name (str):
            The name of the model.
        replace (bool, default False):
            Determine whether to replace if the model already exists. Default to False.

    Returns:
        Claude3TextGenerator: Saved model."""

    # Copy the underlying BQML model under the new name, then read it back
    # through the copy's session.
    saved_copy = self._bqml_model.copy(model_name, replace)
    owning_session = saved_copy.session
    return owning_session.read_gbq_model(model_name)
5 changes: 5 additions & 0 deletions bigframes/ml/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@
llm._GEMINI_PRO_ENDPOINT: llm.GeminiTextGenerator,
llm._GEMINI_1P5_PRO_PREVIEW_ENDPOINT: llm.GeminiTextGenerator,
llm._GEMINI_1P5_PRO_FLASH_PREVIEW_ENDPOINT: llm.GeminiTextGenerator,
llm._CLAUDE_3_HAIKU_ENDPOINT: llm.Claude3TextGenerator,
llm._CLAUDE_3_SONNET_ENDPOINT: llm.Claude3TextGenerator,
llm._CLAUDE_3_5_SONNET_ENDPOINT: llm.Claude3TextGenerator,
llm._CLAUDE_3_OPUS_ENDPOINT: llm.Claude3TextGenerator,
llm._TEXT_EMBEDDING_004_ENDPOINT: llm.TextEmbeddingGenerator,
llm._TEXT_MULTILINGUAL_EMBEDDING_002_ENDPOINT: llm.TextEmbeddingGenerator,
}
Expand All @@ -86,6 +90,7 @@ def from_bq(
imported.XGBoostModel,
llm.PaLM2TextGenerator,
llm.PaLM2TextEmbeddingGenerator,
llm.Claude3TextGenerator,
llm.TextEmbeddingGenerator,
pipeline.Pipeline,
compose.ColumnTransformer,
Expand Down
2 changes: 2 additions & 0 deletions docs/templates/toc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@
uid: bigframes.ml.llm.PaLM2TextGenerator
- name: PaLM2TextEmbeddingGenerator
uid: bigframes.ml.llm.PaLM2TextEmbeddingGenerator
- name: Claude3TextGenerator
uid: bigframes.ml.llm.Claude3TextGenerator
name: llm
- items:
- name: metrics
Expand Down
10 changes: 10 additions & 0 deletions tests/system/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,16 @@ def session() -> Generator[bigframes.Session, None, None]:
session.close() # close generated session at cleanup time


@pytest.fixture(scope="session")
def session_us_east5() -> Generator[bigframes.Session, None, None]:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did we use it anywhere?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

forgot to remove in this PR, but will use in PR of adding tests.

context = bigframes.BigQueryOptions(
location="us-east5",
)
session = bigframes.Session(context=context)
yield session
session.close() # close generated session at cleanup time


@pytest.fixture(scope="session")
def session_load() -> Generator[bigframes.Session, None, None]:
context = bigframes.BigQueryOptions(location="US", project="bigframes-load-testing")
Expand Down
65 changes: 63 additions & 2 deletions tests/system/small/ml/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from tests.system import utils


def test_create_text_generator_model(
def test_create_load_text_generator_model(
palm2_text_generator_model, dataset_id, bq_connection
):
# Model creation doesn't return error
Expand All @@ -34,7 +34,7 @@ def test_create_text_generator_model(
assert reloaded_model.connection_name == bq_connection


def test_create_text_generator_32k_model(
def test_create_load_text_generator_32k_model(
palm2_text_generator_32k_model, dataset_id, bq_connection
):
# Model creation doesn't return error
Expand Down Expand Up @@ -405,6 +405,67 @@ def test_gemini_text_generator_predict_with_params_success(
assert all(series.str.len() > 20)


# TODO(garrettwu): add tests for claude3.5 sonnet and claude3 opus as they are only available in other regions.
@pytest.mark.parametrize(
    "model_name",
    ("claude-3-sonnet", "claude-3-haiku"),
)
def test_claude3_text_generator_create_load(
    dataset_id, model_name, session, bq_connection
):
    """Create a Claude 3 generator, save it to BigQuery and reload it."""
    generator = llm.Claude3TextGenerator(
        model_name=model_name, connection_name=bq_connection, session=session
    )
    assert generator is not None
    assert generator._bqml_model is not None

    # Round-trip through BigQuery to ensure the configuration was kept.
    saved_model_name = f"{dataset_id}.temp_text_model"
    restored = generator.to_gbq(saved_model_name, replace=True)
    assert saved_model_name == restored._bqml_model.model_name
    assert restored.connection_name == bq_connection
    assert restored.model_name == model_name


@pytest.mark.parametrize(
    "model_name",
    ("claude-3-sonnet", "claude-3-haiku"),
)
@pytest.mark.flaky(retries=2)
def test_claude3_text_generator_predict_default_params_success(
    llm_text_df, model_name, session, bq_connection
):
    """Predict with default parameters and check the shape and text length."""
    generator = llm.Claude3TextGenerator(
        model_name=model_name, connection_name=bq_connection, session=session
    )
    predictions = generator.predict(llm_text_df)
    result = predictions.to_pandas()

    assert result.shape == (3, 3)
    assert "ml_generate_text_llm_result" in result.columns
    generated_text = result["ml_generate_text_llm_result"]
    # Every generated answer is expected to be longer than 20 characters.
    assert all(generated_text.str.len() > 20)


@pytest.mark.parametrize(
    "model_name",
    ("claude-3-sonnet", "claude-3-haiku"),
)
@pytest.mark.flaky(retries=2)
def test_claude3_text_generator_predict_with_params_success(
    llm_text_df, model_name, session, bq_connection
):
    """Predict with explicit sampling parameters and check the output."""
    generator = llm.Claude3TextGenerator(
        model_name=model_name, connection_name=bq_connection, session=session
    )
    predictions = generator.predict(
        llm_text_df, max_output_tokens=100, top_k=20, top_p=0.5
    )
    result = predictions.to_pandas()

    assert result.shape == (3, 3)
    assert "ml_generate_text_llm_result" in result.columns
    generated_text = result["ml_generate_text_llm_result"]
    # Every generated answer is expected to be longer than 20 characters.
    assert all(generated_text.str.len() > 20)


@pytest.mark.flaky(retries=2)
def test_llm_palm_score(llm_fine_tune_df_default_index):
model = llm.PaLM2TextGenerator(model_name="text-bison")
Expand Down