Commit

Add count_tokens function (#93)
* Add count_tokens function

* Add support for the sabia-2 family

* Fix medium/small tokenizers

* Update readme

* Update README.md

Co-authored-by: Rodrigo Nogueira <[email protected]>

---------

Co-authored-by: Rodrigo Nogueira <[email protected]>
hugoabonizio and rodrigo-f-nogueira committed Jul 30, 2024
1 parent 8ae3189 commit 3841dc9
Showing 5 changed files with 112 additions and 9 deletions.
10 changes: 4 additions & 6 deletions README.md
@@ -209,21 +209,19 @@ For tasks with only one correct answer, as in the example above, it is recommended
For generating diverse or long texts, it is recommended to use `do_sample=True` and `temperature=0.7`. The higher the temperature, the more diverse the generated texts, but the greater the chance that the model will "hallucinate" and produce nonsensical text. The lower the temperature, the more conservative the response, but at the risk of generating repetitive text.
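As an aside, a minimal sketch of passing these sampling parameters through the client — assuming `MariTalk.generate` forwards `do_sample` and `temperature` to the API and returns a dict with an `answer` field (the key below is a placeholder):

```python
from maritalk import MariTalk

model = MariTalk(key="<sua chave>", model="sabia-3")  # placeholder API key

# Higher temperature -> more diverse (but riskier) generations.
response = model.generate(
    "Escreva um haicai sobre o mar.",
    do_sample=True,
    temperature=0.7,
)
print(response["answer"])
```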

## How do I know how many tokens will be charged?
-To know in advance how much your requests will cost, use the MariTalk models' tokenizers, available on HuggingFace, to get the number of tokens in a given prompt.
+To know in advance how much your requests will cost, use the `count_tokens` function to get the number of tokens in a given prompt.

Usage example:
```python
-import transformers
-tokenizer = transformers.AutoTokenizer.from_pretrained("maritaca-ai/sabia-2-tokenizer-medium")
+from maritalk import count_tokens

prompt = "Com quantos paus se faz uma canoa?"

-tokens = tokenizer.encode(prompt)
+total_tokens = count_tokens(prompt, model="sabia-3")

-print(f'O prompt "{prompt}" contém {len(tokens)} tokens.')
+print(f'O prompt "{prompt}" contém {total_tokens} tokens.')
```

-Note that the Sabiá-2 Small and Medium tokenizers are different.
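Since `count_tokens` also accepts a list of strings and a `model` argument (see `maritalk/tokenizer.py` below), here is a short batch sketch; the printed counts follow the docstring's examples and may vary across tokenizer versions:

```python
from maritalk import count_tokens

prompts = ["Olá, mundo!", "Como vai você?"]

# One token count per prompt; pick the model whose pricing you want to estimate.
print(count_tokens(prompts, model="sabia-3"))        # e.g. [5, 4]
print(count_tokens(prompts, model="sabia-2-small"))  # e.g. [6, 7] — Small and Medium differ
```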

# Web Chat

3 changes: 2 additions & 1 deletion maritalk/__init__.py
@@ -4,5 +4,6 @@

from .resources.api import MariTalk
from .resources.local import MariTalkLocal
+from .tokenizer import count_tokens

-__all__ = ["MariTalk", "MariTalkLocal"]
+__all__ = ["MariTalk", "MariTalkLocal", "count_tokens"]
98 changes: 98 additions & 0 deletions maritalk/tokenizer.py
@@ -0,0 +1,98 @@
import base64
import tiktoken
from typing import Union, List
from transformers import AutoTokenizer, PreTrainedTokenizerFast
from .tokenizer_model import data as tokenizer_data

_encoder = None
_tokenizer_small = None
_tokenizer_medium = None


def _get_encoder() -> tiktoken.Encoding:
    global _encoder
    if _encoder is None:
        # The bundled vocabulary is stored as "<base64 token> <rank>" lines;
        # decode each token back to raw bytes for tiktoken.
        mergeable_ranks = {
            base64.b64decode(token): int(rank)
            for token, rank in (
                line.split() for line in tokenizer_data.splitlines() if line
            )
        }
        _encoder = tiktoken.Encoding(
            "sabia-3",
            pat_str=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
            mergeable_ranks=mergeable_ranks,
            special_tokens={},
        )
    return _encoder


def _get_tokenizer(version: str = "medium") -> PreTrainedTokenizerFast:
    global _tokenizer_small, _tokenizer_medium
    if version == "small":
        if _tokenizer_small is None:
            # Fetched from the HuggingFace Hub on first use, then cached.
            _tokenizer_small = AutoTokenizer.from_pretrained(
                "maritaca-ai/sabia-2-tokenizer-small"
            )
        return _tokenizer_small
    elif version == "medium":
        if _tokenizer_medium is None:
            _tokenizer_medium = AutoTokenizer.from_pretrained(
                "maritaca-ai/sabia-2-tokenizer-medium"
            )
        return _tokenizer_medium
    else:
        raise ValueError("Version must be 'small' or 'medium'")


def count_tokens(
    text: Union[str, List[str]],
    model: str = "sabia-3",
) -> Union[int, List[int]]:
    """
    Counts the number of tokens in the given string or list of strings.

    Args:
        text (Union[str, List[str]]): The input text or a list of texts to be tokenized.
        model (str): The model whose tokenizer should be used; one of
            "sabia-3", "sabia-2-medium", or "sabia-2-small".

    Returns:
        Union[int, List[int]]: The number of tokens in the input text if a single string
        is provided, or a list of token counts for each string in the input list if a
        list of strings is provided.

    Examples:
        >>> count_tokens("Olá, mundo!")
        5
        >>> count_tokens(["Olá, mundo!", "Como vai você?"])
        [5, 4]
        >>> count_tokens(["Olá, mundo!", "Como vai você?"], model='sabia-2-small')
        [6, 7]
    """
    if model.startswith("sabia-3"):
        encoder = _get_encoder()
        encode = encoder.encode
        encode_batch = encoder.encode_batch
    elif model.startswith("sabia-2-small"):
        tokenizer = _get_tokenizer("small")
        encode = tokenizer.encode
        encode_batch = lambda texts: tokenizer(texts)["input_ids"]
    elif model.startswith("sabia-2-medium"):
        tokenizer = _get_tokenizer("medium")
        encode = tokenizer.encode
        encode_batch = lambda texts: tokenizer(texts)["input_ids"]
    else:
        raise ValueError(
            "Model must be one of the following: sabia-3, sabia-2-medium, sabia-2-small"
        )

    if isinstance(text, str):
        return len(encode(text))
    elif isinstance(text, list):
        return [len(ids) for ids in encode_batch(text)]
    else:
        raise TypeError("Input must be either a string or a list of strings")


__all__ = ["count_tokens"]
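As a quick sanity check of the module above — a minimal sketch; `_get_encoder` is an internal helper, so this is illustrative only:

```python
from maritalk import count_tokens
from maritalk.tokenizer import _get_encoder  # internal helper; illustrative use only

# The sabia-3 encoder is a tiktoken Encoding, so token ids round-trip back to text.
enc = _get_encoder()
ids = enc.encode("Olá, mundo!")
assert enc.decode(ids) == "Olá, mundo!"

# The public entry point reports the same count as the underlying encoder.
assert count_tokens("Olá, mundo!", model="sabia-3") == len(ids)
```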
4 changes: 4 additions & 0 deletions maritalk/tokenizer_model.py

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions pyproject.toml
@@ -4,9 +4,9 @@ build-backend = "setuptools.build_meta"

[project]
name = "maritalk"
-version = "0.2.5"
+version = "0.2.6"
authors = [
-    { name="Maritaca AI", email="info@maritaca.ai" },
+    { name="Maritaca AI", email="suporte@maritaca.ai" },
]
description = "Client library for the MariTalk API"
readme = "README.md"
@@ -21,6 +21,8 @@ dependencies = [
"requests",
"tqdm",
"httpx",
"tiktoken>=0.7,<0.8",
"transformers"
]

[tool.setuptools.packages]
