[WIP] Add support for Together.ai provider #956

Open · wants to merge 1 commit into base: main
[WIP] Add support for Together.ai provider
nihit committed Mar 5, 2025
commit c16041fb6a4d2975201dfc5bf9ee1d985599d70c
2 changes: 2 additions & 0 deletions src/autolabel/models/__init__.py
@@ -10,16 +10,17 @@

logger = logging.getLogger(__name__)

from autolabel.models.anthropic import AnthropicLLM
from autolabel.models.cohere import CohereLLM
from autolabel.models.google import GoogleLLM
from autolabel.models.hf_pipeline import HFPipelineLLM
from autolabel.models.hf_pipeline_vision import HFPipelineMultimodal
from autolabel.models.mistral import MistralLLM
from autolabel.models.openai import OpenAILLM
from autolabel.models.openai_vision import OpenAIVisionLLM
from autolabel.models.vllm import VLLMModel
from autolabel.models.azure_openai import AzureOpenAILLM
from autolabel.models.together import TogetherLLM

Check failure on line 23 (GitHub Actions / lint): src/autolabel/models/__init__.py:13:1: I001 Import block is un-sorted or un-formatted

MODEL_REGISTRY = {
    ModelProvider.OPENAI: OpenAILLM,
@@ -32,30 +33,31 @@
    ModelProvider.GOOGLE: GoogleLLM,
    ModelProvider.VLLM: VLLMModel,
    ModelProvider.AZURE_OPENAI: AzureOpenAILLM,
    ModelProvider.TOGETHER: TogetherLLM,
}


def register_model(name, model_cls):

Check failure on line 40 (GitHub Actions / lint): src/autolabel/models/__init__.py:40:5: ANN201 Missing return type annotation for public function `register_model`
Check failure on line 40 (GitHub Actions / lint): src/autolabel/models/__init__.py:40:20: ANN001 Missing type annotation for function argument `name`
Check failure on line 40 (GitHub Actions / lint): src/autolabel/models/__init__.py:40:26: ANN001 Missing type annotation for function argument `model_cls`
"""Register Model class"""
MODEL_REGISTRY[name] = model_cls


class ModelFactory:

"""The ModelFactory class is used to create a BaseModel object from the given AutoLabelConfig configuration."""

Check failure on line 47 (GitHub Actions / lint): src/autolabel/models/__init__.py:47:89: E501 Line too long (115 > 88)

    @staticmethod
    def from_config(

Check failure on line 50 (GitHub Actions / lint): src/autolabel/models/__init__.py:50:9: D417 Missing argument description in the docstring for `from_config`: `tokenizer`
        config: AutolabelConfig,
        cache: BaseCache = None,
        tokenizer: AutoTokenizer = None,
    ) -> BaseModel:
        """
        Returns a BaseModel object configured with the settings found in the provided AutolabelConfig.

Check failure on line 56 (GitHub Actions / lint): src/autolabel/models/__init__.py:56:89: E501 Line too long (102 > 88)

        Args:
            config: AutolabelConfig object containing project settings
            cache: cache allows for saving results in between labeling runs for future use

Check failure on line 60 (GitHub Actions / lint): src/autolabel/models/__init__.py:60:89: E501 Line too long (90 > 88)
        Returns:
            model: a fully configured BaseModel object

@@ -66,17 +68,17 @@
            model_obj = model_cls(config=config, cache=cache, tokenizer=tokenizer)
            # The below ensures that users base custom models off of BaseModel
            # when creating/registering them.
            assert isinstance(

Check failure on line 71 (GitHub Actions / lint): src/autolabel/models/__init__.py:71:13: S101 Use of `assert` detected
                model_obj,
                BaseModel,
            ), f"{model_obj} should inherit from autolabel.models.BaseModel"
        except KeyError as e:
            # We should never get here as the config should have already
            # been validated by the pydantic model.
            logger.error(
                f"{config.provider()} is not in the list of supported providers: \
                    {list(ModelProvider.__members__.keys())}",
            )

Check failure on line 81 (GitHub Actions / lint): src/autolabel/models/__init__.py:78:13: TRY400 Use `logging.exception` instead of `logging.error`
            raise e

        return model_obj
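
Reviewer note: to make the wiring above concrete, here is a minimal sketch of how the new provider would flow through `ModelFactory.from_config`. The config keys (`task_name`, `task_type`, `model`, `prompt`) follow the usual AutolabelConfig layout but are illustrative, not taken from this diff, and `TOGETHER_API_KEY` must be set for `TogetherLLM.__init__` to succeed.

import os

from autolabel.configs import AutolabelConfig
from autolabel.models import ModelFactory

# TogetherLLM's constructor checks this environment variable up front.
os.environ["TOGETHER_API_KEY"] = "<your-key>"  # placeholder

# Illustrative config; field names are assumptions based on the existing
# AutolabelConfig schema, not something this PR changes.
config = AutolabelConfig(
    {
        "task_name": "ToxicCommentClassification",
        "task_type": "classification",
        "model": {
            "provider": "together",
            "name": "meta-llama/Llama-3.3-70B-Instruct-Turbo",
        },
        "prompt": {
            "task_guidelines": "Classify the comment as toxic or not toxic.",
            "labels": ["toxic", "not toxic"],
        },
    }
)

# MODEL_REGISTRY maps ModelProvider.TOGETHER -> TogetherLLM, so the factory
# returns a TogetherLLM instance here.
model = ModelFactory.from_config(config)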
136 changes: 136 additions & 0 deletions src/autolabel/models/together.py
@@ -0,0 +1,136 @@
import logging
import os
from time import time
from typing import Dict, List, Optional

from langchain.schema import Generation
from transformers import AutoTokenizer

from autolabel.cache import BaseCache
from autolabel.configs import AutolabelConfig
from autolabel.models import BaseModel
from autolabel.schema import ErrorType, LabelingError, RefuelLLMResult

logger = logging.getLogger(__name__)


class TogetherLLM(BaseModel):
    DEFAULT_MODEL = "meta-llama/Llama-3.3-70B-Instruct-Turbo"
    DEFAULT_PARAMS = {
        "max_tokens": 1000,
        "temperature": 0.0,
        "request_timeout": 30,
        "logprobs": 1,
    }

    # Reference: https://www.together.ai/pricing
    COST_PER_PROMPT_TOKEN = {
        "meta-llama/Llama-3.3-70B-Instruct-Turbo": (0.88 / 1_000_000)
    }
    COST_PER_COMPLETION_TOKEN = {
        "meta-llama/Llama-3.3-70B-Instruct-Turbo": (0.88 / 1_000_000)
    }

    def __init__(
        self,
        config: AutolabelConfig,
        cache: BaseCache = None,
        tokenizer: Optional[AutoTokenizer] = None,
    ) -> None:
        super().__init__(config, cache, tokenizer)

        try:
            from together import Together
        except ImportError:
            raise ImportError(
                "together is required to use the Together.ai LLM. Please install it with the following command: pip install 'refuel-autolabel[together]'"
            )

        if os.getenv("TOGETHER_API_KEY") is None:
            raise ValueError("TOGETHER_API_KEY environment variable not set")

        # populate model name
        self.model_name = config.model_name() or self.DEFAULT_MODEL
        # populate model params
        model_params = config.model_params()
        self.model_params = {**self.DEFAULT_PARAMS, **model_params}

        self.client = Together()

    def _label(self, prompts: List[str], output_schema: Dict) -> RefuelLLMResult:
        generations = []
        errors = []
        latencies = []

        for prompt in prompts:
            try:
                start_time = time()

                # Format the prompt based on the model
                messages = [{"role": "user", "content": prompt}]

                # Call the Together.ai API
                response = self.client.chat.completions.create(
                    messages=messages,
                    model=self.model_name,
                    **self.model_params,
                )

                end_time = time()

                # Extract the generated text
                generated_text = response.choices[0].message.content

                # TODO: logprobs

                generations.append([Generation(text=generated_text)])
                errors.append(None)
                latencies.append(end_time - start_time)
            except Exception as e:
                logger.error(f"Error generating label: {e}")
                generations.append([Generation(text="")])
                errors.append(
                    LabelingError(
                        error_type=ErrorType.LLM_PROVIDER_ERROR,
                        error_message=str(e),
                    )
                )
                latencies.append(0)

        return RefuelLLMResult(
            generations=generations,
            errors=errors,
            latencies=latencies,
        )

    async def _alabel(self, prompts: List[str], output_schema: Dict) -> RefuelLLMResult:
        # For now, we'll use the synchronous implementation
        return self._label(prompts, output_schema)

    def get_cost(self, prompt: str, label: Optional[str] = "") -> float:
        num_prompt_tokens = self.get_num_tokens(prompt)

        if label:
            num_completion_tokens = self.get_num_tokens(label)
        else:
            num_completion_tokens = self.model_params["max_tokens"]

        # Get the cost per token for this model; default to 0 for models
        # without a listed price so the arithmetic below never sees None.
        cost_per_prompt_token = self.COST_PER_PROMPT_TOKEN.get(
            self.model_name, 0
        )
        cost_per_completion_token = self.COST_PER_COMPLETION_TOKEN.get(
            self.model_name, 0
        )

        return (num_prompt_tokens * cost_per_prompt_token) + (
            num_completion_tokens * cost_per_completion_token
        )

    def returns_token_probs(self) -> bool:
        return False

    def get_num_tokens(self, prompt: str) -> int:
        if not prompt:
            return 0
        # Rough heuristic (~4 characters per token) since no tokenizer is loaded here.
        return len(prompt) // 4
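
As a sanity check on the pricing constants and the character-based token heuristic above, a back-of-the-envelope cost for one labeled example, using only the defaults in this file (the 2,000-character prompt is an illustrative number, not from the diff):

# Assumes the TogetherLLM defaults above: $0.88 per 1M tokens for both
# prompt and completion, max_tokens = 1000, and ~4 characters per token.
cost_per_token = 0.88 / 1_000_000

prompt_chars = 2_000
prompt_tokens = prompt_chars // 4        # get_num_tokens heuristic -> 500
completion_tokens = 1_000                # DEFAULT_PARAMS["max_tokens"] when no label is passed

estimated_cost = (prompt_tokens + completion_tokens) * cost_per_token
print(f"${estimated_cost:.6f}")          # $0.001320, roughly a tenth of a cent per example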
1 change: 1 addition & 0 deletions src/autolabel/schema.py
@@ -26,6 +26,7 @@ class ModelProvider(str, Enum):
    TGI = "tgi"
    VLLM = "vllm"
    AZURE_OPENAI = "azure_openai"
    TOGETHER = "together"


class TaskType(str, Enum):
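
With the enum value in place, an end-to-end run would look roughly like the sketch below. `LabelingAgent` and `AutolabelDataset` are the existing autolabel entry points (not touched by this PR), "data.csv" is a placeholder dataset, and `config` is the illustrative Together.ai config sketched after the `__init__.py` diff above.

from autolabel import AutolabelDataset, LabelingAgent

# `config` is the illustrative Together.ai AutolabelConfig shown earlier;
# the dataset columns must match what that config expects.
agent = LabelingAgent(config)
dataset = AutolabelDataset("data.csv", config=config)

agent.plan(dataset)   # dry run: example prompts plus a cost estimate via TogetherLLM.get_cost
agent.run(dataset)    # labels the dataset with meta-llama/Llama-3.3-70B-Instruct-Turbo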