
Enable Logprobs in MLC Batch Serving #82

Conversation


@zxybazh zxybazh commented Nov 22, 2023

This PR enables the logprobs option in the MLC server, following vLLM's example and OpenAI's API.

Example query:

{
  "model": "codellama-7b-fp16",
  "messages": [
    {
      "role": "user",
      "content": "Implement merge sort in python"
    }
  ],
  "logprobs": true,
  "top_logprobs": 3,
  "stream": false,
  "stop": "\n",
  "max_tokens": 10,
  "temperature": 0
}

Example response:

{
  "id": "cmpl-c920b4ef876542369e700e15d6e9d88e",
  "object": "chat.completion",
  "created": 1706194000,
  "model": "codellama-7b-fp16",
  "choices": [
    {
      "index": 0,
      "message": {
        "role": "assistant",
        "content": "  Sure! Here is an implementation of the merge"
      },
      "logprobs": {
        "content": [
          {
            "token": "",
            "logprob": -3.933898824470816e-06,
            "bytes": null,
            "top_logprobs": [
              {
                "token": "",
                "logprob": -3.933898824470816e-06,
                "bytes": null
              },
              {
                "token": "▁Sure",
                "logprob": -13.156253814697266,
                "bytes": null
              },
              {
                "token": "▁▁",
                "logprob": -13.171878814697266,
                "bytes": null
              }
            ]
          },
          {
            "token": "Sure",
            "logprob": -0.046100322157144547,
            "bytes": null,
            "top_logprobs": [
              {
                "token": "▁Sure",
                "logprob": -0.046100322157144547,
                "bytes": null
              },
              {
                "token": "▁Here",
                "logprob": -3.171100378036499,
                "bytes": null
              },
              {
                "token": "▁Mer",
                "logprob": -6.45235013961792,
                "bytes": null
              }
            ]
          },
          {
            "token": "!",
            "logprob": -0.15342634916305542,
            "bytes": null,
            "top_logprobs": [
              {
                "token": "!",
                "logprob": -0.15342634916305542,
                "bytes": null
              },
              {
                "token": ",",
                "logprob": -1.9503014087677002,
                "bytes": null
              },
              {
                "token": "?",
                "logprob": -13.551863670349121,
                "bytes": null
              }
            ]
          },
          {
            "token": "Here",
            "logprob": -0.00019905969384126365,
            "bytes": null,
            "top_logprobs": [
              {
                "token": "▁Here",
                "logprob": -0.00019905969384126365,
                "bytes": null
              },
              {
                "token": "▁Mer",
                "logprob": -9.062699317932129,
                "bytes": null
              },
              {
                "token": "Here",
                "logprob": -10.062699317932129,
                "bytes": null
              }
            ]
          },
          {
            "token": "is",
            "logprob": -0.0050763762556016445,
            "bytes": null,
            "top_logprobs": [
              {
                "token": "▁is",
                "logprob": -0.0050763762556016445,
                "bytes": null
              },
              {
                "token": "'",
                "logprob": -5.2863264083862305,
                "bytes": null
              },
              {
                "token": "▁are",
                "logprob": -13.06757640838623,
                "bytes": null
              }
            ]
          },
          {
            "token": "an",
            "logprob": -0.03826115280389786,
            "bytes": null,
            "top_logprobs": [
              {
                "token": "▁an",
                "logprob": -0.03826115280389786,
                "bytes": null
              },
              {
                "token": "▁a",
                "logprob": -3.3038861751556396,
                "bytes": null
              },
              {
                "token": "▁the",
                "logprob": -7.3507609367370605,
                "bytes": null
              }
            ]
          },
          {
            "token": "implementation",
            "logprob": -0.5718472003936768,
            "bytes": null,
            "top_logprobs": [
              {
                "token": "▁implementation",
                "logprob": -0.5718472003936768,
                "bytes": null
              },
              {
                "token": "▁example",
                "logprob": -0.8374722003936768,
                "bytes": null
              },
              {
                "token": "▁outline",
                "logprob": -7.298409461975098,
                "bytes": null
              }
            ]
          },
          {
            "token": "of",
            "logprob": -3.576278118089249e-07,
            "bytes": null,
            "top_logprobs": [
              {
                "token": "▁of",
                "logprob": -3.576278118089249e-07,
                "bytes": null
              },
              {
                "token": "▁in",
                "logprob": -14.796875,
                "bytes": null
              },
              {
                "token": "of",
                "logprob": -17.359375,
                "bytes": null
              }
            ]
          },
          {
            "token": "the",
            "logprob": -0.01982644945383072,
            "bytes": null,
            "top_logprobs": [
              {
                "token": "▁the",
                "logprob": -0.01982644945383072,
                "bytes": null
              },
              {
                "token": "▁merge",
                "logprob": -3.9729514122009277,
                "bytes": null
              },
              {
                "token": "▁Mer",
                "logprob": -7.660451412200928,
                "bytes": null
              }
            ]
          },
          {
            "token": "merge",
            "logprob": -0.026704560965299606,
            "bytes": null,
            "top_logprobs": [
              {
                "token": "▁merge",
                "logprob": -0.026704560965299606,
                "bytes": null
              },
              {
                "token": "▁Mer",
                "logprob": -3.9485795497894287,
                "bytes": null
              },
              {
                "token": "▁mer",
                "logprob": -5.214204788208008,
                "bytes": null
              }
            ]
          }
        ]
      },
      "finish_reason": "length"
    }
  ],
  "usage": {
    "prompt_tokens": 15,
    "total_tokens": 25,
    "completion_tokens": 10
  }
}
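For reference, a minimal client-side sketch of the same query in Python (the base URL and route below are assumptions for a locally running OpenAI-compatible server, not something this PR fixes):

import requests

# Assumed local endpoint; adjust host, port, and route to your mlc-serve deployment.
URL = "http://127.0.0.1:8000/v1/chat/completions"

payload = {
    "model": "codellama-7b-fp16",
    "messages": [{"role": "user", "content": "Implement merge sort in python"}],
    "logprobs": True,
    "top_logprobs": 3,
    "stream": False,
    "stop": "\n",
    "max_tokens": 10,
    "temperature": 0,
}

resp = requests.post(URL, json=payload, timeout=60)
resp.raise_for_status()
choice = resp.json()["choices"][0]

# Each entry of logprobs.content describes one sampled token and its top alternatives.
for entry in choice["logprobs"]["content"]:
    alts = ", ".join(f"{alt['token']!r}: {alt['logprob']:.3f}" for alt in entry["top_logprobs"])
    print(f"{entry['token']!r} ({entry['logprob']:.3f}) | top: {alts}")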

Ready for review. CC @sunggg @masahi

@zxybazh zxybazh force-pushed the feature/2023-11-22/enable-mlc-server-logprobs branch from 56f5b41 to 0994bd8 on November 22, 2023 20:41
@zxybazh zxybazh marked this pull request as ready for review November 22, 2023 23:16

masahi commented Nov 22, 2023

The OpenAI API doesn't specify how logprobs should be calculated and returned. I think it's better to wait until their new logprob API is released.

Besides, the logprob-related logic in vLLM looks very complicated and is scattered across their codebase. Does this change implement the same logic?


zxybazh commented Nov 23, 2023

I also find vLLM's implementation quite complicated, so I only referred to their example as linked in the PR description. Given that OpenAI's API is not revealed yet, I don't have a strong opinion on whether to integrate it right now. It's more of a use-case-based decision, so I would like to see what other folks think 👀.


masahi commented Nov 23, 2023

I think we can merge tentative logprob support now as long as

  • We pay no perf cost when logprob is not requested
  • Logprob calculation logic and return format are easy to modify upon the new OpenAI API release


sunggg commented Nov 23, 2023

Although the OpenAI spec does not expose logprob yet, in my understanding we need this for @vvchernov's accuracy testing work. Let's get his feedback and coordinate with his PR #69.


sunggg commented Nov 23, 2023

@zxybazh, seems like you might need to rebase. See the conflicts

Conflicting files
serve/mlc_serve/engine/async_connector.py
serve/mlc_serve/engine/base.py
serve/mlc_serve/engine/staging_engine_worker.py

@zxybazh zxybazh force-pushed the feature/2023-11-22/enable-mlc-server-logprobs branch from 3be4501 to 5957ae8 on November 23, 2023 05:43

vvchernov commented Nov 27, 2023

Hello @sunggg and @zxybazh! Thanks for calling me in! Good job! I think I'll need to spend some more time on it, since I'm not so familiar with this code. I've added some suggestions to fix. But first I want to ask about and discuss some things related to logprobs and why we need them, to clarify the understanding on both sides.

@vvchernov

But before that: I noticed in your example that the first top token is an empty string (""). I ran into the same thing when testing logprobs on the mlc-llm side and avoided it by removing all processing related to the system prompt. I'm not sure this behavior is correct. Where did you get the example? Is it actual output measured after the feature was added?

@vvchernov

Logprobs are a powerful LLM tool and are used in many scenarios:

  • Different logit mechanics like beam search
  • LangChain uses them in its retrieval approach (see FLARE)
  • Data scientists use logprobs to detect artificial text inside hybrid text (paper)
  • Accuracy benchmarks based on the loglikelihood approach, which uses them
  • and so on

Because of this, I'm not sure that "logprob will not be frequently requested" in the near future.
I agree that it would be good to separate the pipelines with and without logprobs, but I don't see a good way to implement that short of building two topologies for the same model, with and without logprob calculation, which is not ideal.
I'm also not sure it should be hidden behind a debug mode, since it can quickly become a production feature.

A brief summary of the loglikelihood approach: a request consists of context and continuation strings. The context is a question or part of a sentence. Usually there are four continuations (answers to the question or continuations of the context), one of which is assumed to be correct. We force the model to continue the context along the continuation tokens (not to generate its own continuation) and sum their logprobs (the sum is always negative). Then the maximum of the four sums is found; if it belongs to the correct continuation, the model's answer is counted as correct. (See the sketch just below.)
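A minimal, purely illustrative sketch of that scoring (the per-position logprob maps are assumed to come from a single prefill pass over context + continuation[:-1]; all names and numbers here are made up):

from typing import Dict, Sequence


def loglikelihood_score(
    continuation_token_ids: Sequence[int],
    per_position_logprobs: Sequence[Dict[int, float]],
) -> float:
    """Sum the logprob of each continuation token at its forced position."""
    assert len(continuation_token_ids) == len(per_position_logprobs)
    return sum(
        per_position_logprobs[pos][tok]
        for pos, tok in enumerate(continuation_token_ids)
    )


# Toy numbers: two candidate continuations of two tokens each.
candidates = {
    "correct": ([11, 12], [{11: -0.1, 99: -3.0}, {12: -0.2, 98: -2.5}]),
    "wrong": ([21, 22], [{21: -1.5, 11: -0.1}, {22: -2.0, 12: -0.2}]),
}
scores = {name: loglikelihood_score(*args) for name, args in candidates.items()}
best = max(scores, key=scores.get)  # the answer counts as correct if best == "correct"
print(scores, "->", best)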
I use the prefill step to get the logits in one pass, feeding context tokens + continuation tokens minus the last token as model input. We still need to remember the index of the last token to compute its logprob from the logits the model generates.
I've implemented logprobs calculation as a separate pass on the mlc-llm side. I used the TextSynth API instead of the OpenAI one: it takes context and continuation strings and returns the logprob sum, which is convenient for loglikelihood accuracy measurement. If we decide to settle on the OpenAI API for logprob (just now deprecated), I'll switch to it, but for that I need clarification on the following points:

  • Are logprobs calculated here only for a token as it is generated at a decode step, or does the prefill step also generate logprobs for all input tokens?
  • What should I pass as input? How should model generation be stopped? I only need logprobs for the specified text; I do not need a continuation generated beyond it.
  • I'm mindful of the last token. If I cut it from context + continuation and the model generates a new one, the logprob of the specified token may not be among the top-5 logprobs.

@masahi @zxybazh @sunggg any thoughts?

cc @binarybana

@zxybazh zxybazh force-pushed the feature/2023-11-22/enable-mlc-server-logprobs branch 2 times, most recently from d62d756 to e232862 on November 27, 2023 21:12

zxybazh commented Nov 28, 2023

Thanks @vvchernov for the detailed response! I agree logprobs could be very useful. From an implementation perspective, I think it's possible to make logprob generation less intrusive, since it may not always be required. The challenging part to me is when there are different logprob requests in the same batch (some without logprobs, some with different top-k values); that would be hard to handle without losing performance. In that case we may need to separate the requests and build two topologies, as you said.

For your first comment on the "empty string": I'm not sure yet how it was generated, but it likely contains some escape characters that the current decoder cannot decode directly, which makes it look like an empty string. The example was generated by running a llama-7b-chat-hf-q0f16 model via mlc serve and querying it with the curl request. Can you please elaborate on "result of measurements after the feature was added"? I'm not sure whether this case falls into that category, but I'm happy to discuss more.

Are logprobs calculated here only for a token as it is generated at a decode step, or does the prefill step also generate logprobs for all input tokens?

The logprobs generated here are for each token during the decode step; the exact definition of logprob could vary, as OpenAI doesn't have a clear rule for that right now. We can try generating logprobs during the prefill step if that's what we need.

What should I pass as input? How should model generation be stopped? I only need logprobs for the specified text; I do not need a continuation generated beyond it.

There are many sampling parameters to control when model generation should stop; the most frequently used one is length, i.e. the max token number. For now the input is a prompt, just like our regular completion use case.

I'm mindful of the last token. If I cut it from context + continuation and the model generates a new one, the logprob of the specified token may not be among the top-5 logprobs.

Right now I'm producing logprobs for decode steps only. Could you share an example where the new token is not included in the top-5 logprobs? If that's the case, we can still output the logprob of the selected token (the specified token).

@vvchernov

Hello @zxybazh! Thank you for the quick response!

I agree logprobs could be very useful. From an implementation perspective, I think it's possible to make logprob generation less intrusive, since it may not always be required. The challenging part to me is when there are different logprob requests in the same batch (some without logprobs, some with different top-k values); that would be hard to handle without losing performance. In that case we may need to separate the requests and build two topologies, as you said.

My general point was not that logprobs are a valuable tool (of course they are), but that they may become a basic need for clients. That's why I asked @binarybana to join the discussion and share his point of view. In either case, separating requests with/without logprobs and processing them with two topologies looks like a solution.

@vvchernov

For your first comment on the "empty string": I'm not sure yet how it was generated, but it likely contains some escape characters that the current decoder cannot decode directly, which makes it look like an empty string. The example was generated by running a llama-7b-chat-hf-q0f16 model via mlc serve and querying it with the curl request. Can you please elaborate on "result of measurements after the feature was added"? I'm not sure whether this case falls into that category, but I'm happy to discuss more.

A little more detail on my implementation and testing of loglikelihood calculation on the mlc-llm side. I compared results between the original HF llama2-7b and llama-7b-chat-hf-q0f16 from mlc-llm on specific samples. A fairly large gap was observed consistently, for two reasons: 1. The system prompt. Even when I set an empty prompt, some gap remained; extra tokens (like [INST], [/INST]) were still being added. The gap disappeared when I commented out the SystemPromptProcessing method (I used the mlc-chat pipeline for initial tests). 2. I observed that tokenizer.decode(context + continuation) != tokenizer.decode(context) + tokenizer.decode(continuation). I consistently saw an incorrect first token for the continuation, which decodes to an empty string. Once I cut it, the gap became minimal.
The Python tokenizer does not have the problem from point 2, and it looks like the same issue shows up in your example. Because of this, I'm not confident we are getting the same logprob results.

@vvchernov

We can try generating logprobs during the prefill step if that's what we need.

For me, right now that is all we need. The decode step is not needed for loglikelihood calculation.


sunggg commented Nov 28, 2023

Hi, guys. Thank you for the fruitful discussion!
So, it seems like there are several variants/implementations of the logprob feature.

What are the differences among these options? My general take is: why not follow the OpenAI spec, given that it is a standard? OpenAI is working on bringing logprob from their legacy completion API into their new chat API (see related discussion). In the meantime, do you guys think we can follow the logprob in the completion API (see link)?

We can try generating logprobs during the prefill step if that's what we need.

For me, right now that is all we need. The decode step is not needed for loglikelihood calculation.

If we only need it for prefill, I think the performance penalty might be marginal. @vvchernov, it seems like prefill suffices for your current need, but do you think you may need logprob for the decoding steps in the future? I'm asking because what @zxybazh implemented is based on vLLM, and the vLLM community probably has their reasons for that implementation.

Also, I believe we can have a more efficient discussion if we have performance data. So, for the three options below, can we clarify what we would need eventually, so that @zxybazh can benchmark the performance implications?

  • S1. prefill only
  • S2. decode only
  • S3. both prefill and decode

@vvchernov

Hello @sunggg!

  1. Currently we need to calculate logprobs for the loglikelihood approach, to evaluate models on popular tasks like MMLU, BigBench, HellaSwag and so on and to check our models (including quantized ones) against public leaderboards. This task is narrower than general logprob calculation in the deprecated OpenAI style.
  2. Nevertheless, as I said, logprob is a powerful tool and it may soon become very popular. That alone may be a reason to sacrifice some performance for the sake of logprob calculation. But right now we have time to think about how to do it better.
  3. I thought a bit about my upcoming work and the strategies related to logprob calculation. I think our solutions are independent enough, and I plan to implement logprob in the TextSynth API style here. In parallel, I hope the OpenAI-style logprob calculation will also be developed, and possibly, once we see the new OpenAI API, our solutions can be joined.

@vvchernov

Hello @masahi @zxybazh @sunggg! After some offline discussion and analysis of the OpenAI-style implementation, I've prepared a summary of the options I see for implementing loglikelihood support here.
A reminder of what we need for loglikelihood:

  1. The input is context and continuation strings.
  2. The continuation can be far from what the LLM would output given only the context as input.
  3. The output should be the sum of the logprobs of the continuation tokens.

What is the problem with the current PR? If only the context string is used as input, then at each decode step (1) the logprob of the corresponding continuation token should be output, and it may not be among the top-5 logprobs; (2) the next token must be the corresponding continuation token, not the one the LLM "thinks" it should be; and (3) the process should stop when the continuation tokens are exhausted.

I looked at how lm-evaluation-harness used the deprecated OpenAI API for loglikelihood task evaluation.

  1. In the request, context + continuation are used as the input prompt, temperature is 0 for a deterministic answer, and max_tokens=0 stops the process immediately after the prefill step (link):
response = oa_completion(
    engine=self.engine,
    prompt=inps,
    echo=True,
    max_tokens=0,
    temperature=0.0,
    logprobs=10,
)

One more note: the completion endpoint is used, so the cache is reset on each request.
2. We get all logprobs (for all input tokens) but sum only the tail corresponding to the continuation (link):

logprobs = response["logprobs"]["token_logprobs"]
continuation_logprobs = sum(logprobs[ctxlen:])

So the main point of this text: the prefill step should be used for the loglikelihood approach. The weak spot is that all logprobs would have to be returned. I see two options for doing this: (1) two topologies of the same model are used; one returns only the last set of logits, with no performance penalty (the current state), while the other returns all logits, which are then processed on the CPU as needed. But Jason does not prefer this scenario. (2) A single topology is used, with a condition and a flag to compute all logits only when needed. I expect a slight performance penalty due to the logits projection before output (e.g. for Llama 2) and, if I'm not mistaken, the GPU computes both branches of the condition.

One more thing: if we plan to support a speculative decoding scenario, we will need to do the same, namely get the set of last logprobs for a given continuation.

@zxybazh, my plan now is to start from your branch and add some fixes plus support for returning all logits at the prefill step. I will submit my changes against your branch, or we can decide how to move them here.

P.S. Jason told me that @tqchen knows how to calculate perplexity without a performance penalty. Tianqi, if your understanding differs from what I've explained here, please share it with us.

cc @binarybana


masahi commented Dec 18, 2023

The new OpenAI API for logprobs is out:
https://platform.openai.com/docs/api-reference/chat/create#chat-create-logprobs

@vvchernov

Hello @masahi @zxybazh @sunggg! I've upstreamed the OpenAI API in this PR; it also updates the current branch. I suggest you review it, merge it into @zxybazh's branch, and then merge this PR.
Since OpenAI supports logprobs only for new (generated) tokens, it cannot be used directly for loglikelihood calculations. I decided to implement that in parallel as a PoC; after a successful test we can continue discussing how it can be integrated with minimal performance reduction.

cc @binarybana


zxybazh commented Dec 19, 2023

Thanks @vvchernov for the quick update after OpenAI's new API went out. I've merged his PR into mine and will do a rebase to sync with the current branch.


top_greedy_logprob, top_greedy = torch.topk(logprobs, k=5, dim=-1, largest=True, sorted=True)
# Convert to numpy
res_greedy_logprob = res_greedy_logprob.cpu().numpy()

@vvchernov vvchernov Nov 27, 2023


This looks like a performance bottleneck. Could we keep working with torch tensors, or do the copy asynchronously? Another option, suggested by Masa, is to calculate logprobs only when needed. (See the sketch below.)
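A rough sketch of the "only when needed" option (function and variable names are illustrative, not the PR's actual helpers):

import torch


def maybe_top_logprobs(logits: torch.Tensor, need_logprobs: bool, k: int = 5):
    """Return top-k (logprobs, token_ids) as numpy arrays, or None.

    Gating on need_logprobs skips both the log-softmax/top-k work and the
    device-to-host copy when the request did not ask for logprobs.
    """
    if not need_logprobs:
        return None
    logprobs = torch.log_softmax(logits, dim=-1)
    top_vals, top_ids = torch.topk(logprobs, k=k, dim=-1, largest=True, sorted=True)
    # The .cpu() copy is the expensive part; it now runs only for logprob requests.
    return top_vals.cpu().numpy(), top_ids.cpu().numpy()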

@zxybazh zxybazh force-pushed the feature/2023-11-22/enable-mlc-server-logprobs branch from b64db01 to 9b053e8 on December 19, 2023 23:15
@vvchernov

Hello @zxybazh! I've prepared one more patch for your branch. It should fix the CI problems.

@vvchernov

I think the PR description should be updated, since the logprob response format was changed.


sunggg commented Jan 31, 2024

Great write-up, @zxybazh! To me, there is no obvious reason not to try detokenize_incrementally, since it is an already-working solution for such context-based detokenization. Can we try this quickly? Again, there is a follow-up PR waiting for this, so I'd like us to speed things up. Happy to help if needed.

@vvchernov

Hello guys! The latest benchmark measurements show a very small performance reduction when logprobs are not requested (-0.25%), which is within the run-to-run deviation range. Requests with logprobs are noticeably slower. To avoid mixing requests with and without logprobs, the logical next step would be to separate them (e.g. create a LogprobRequest); that would also finally resolve the issue for the case when logprobs are not requested. (A rough sketch of the split is below.)
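A rough sketch of that separation (the Request type and its logprobs attribute are hypothetical names used only for illustration):

from typing import List, Tuple


def split_by_logprobs(requests: List[object]) -> Tuple[List[object], List[object]]:
    """Partition a batch so the logprob-free fast path is never slowed down."""
    with_logprobs = [r for r in requests if getattr(r, "logprobs", False)]
    without_logprobs = [r for r in requests if not getattr(r, "logprobs", False)]
    return with_logprobs, without_logprobs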

cc @sunggg @masahi


@masahi masahi left a comment


OK, merging now, but please follow up on my latest comments.

for info in logprob_infos:
    if info is not None:
        check = True
        break

Just return logprob_infos here. No need for the check variable.


@masahi masahi Jan 31, 2024


Actually you don't need this function at all. You can fuse get_raw_logprob_infos and check_logprob_infos by returning None from the former after doing the same check done by the latter.


First of all, returning logprob_infos as None instead of a list of Nones let me reduce the performance penalty from 2% to 0.25% for the case when logprobs are not requested. That was the last mile in this task, so I do need this check. In a bit more detail: pydantic and dataclass are slow enough for init and field filling (of course, still on the order of 100 ns) compared to standard types like dict that it starts to matter for very intensive token generation with 7B models.
Unfortunately, I can't simply embed check_logprob_infos into get_raw_logprob_infos (a good idea otherwise), because the latter iterates not over the full list but over the part corresponding to the greedy or random indices. So it does not work for the mixed case where random and greedy requests are processed together, as benchmark_throughput does by default. (See the sketch below for the structure.)
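A rough sketch of the structure described above (types simplified to Any; the real code uses RawLogprobsInfo, and the names here are only illustrative):

from typing import Any, List, Optional


def fill_logprob_infos(
    logprob_infos: List[Optional[Any]],
    indices: List[int],
    infos: List[Optional[Any]],
) -> None:
    """Write infos computed for one subset (greedy or random indices) into the shared list."""
    for dst, info in zip(indices, infos):
        logprob_infos[dst] = info


def check_logprob_infos(
    logprob_infos: List[Optional[Any]],
) -> Optional[List[Optional[Any]]]:
    """Collapse an all-None list to None so the no-logprobs hot path needs only an `is None` check."""
    if any(info is not None for info in logprob_infos):
        return logprob_infos
    return None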

) -> Optional[RawLogprobsInfos]:
    if logprob_infos is None or logprob_infos[i] is None:
        return None
    return [logprob_infos[i]]

No need for this to be a method of this class. Move it to model_common.py.


Sorry, but I can't do that either. It was done that way earlier, but it led to a performance reduction. In the common case, logprob_infos is either 1. None, the case when logprobs are not requested (this trick is what avoids the performance penalty there), or 2. a list of None or logprob_info entries.
When we create a TextGenerationResult we need only one element from the list, which means we have to check that it is a list (not None) and extract the element; that part cannot be moved to the model_common.py side.
My only doubt is that we could perhaps store a single RawLogprobsInfo in TextGenerationResult instead of a one-element list, which would slightly simplify the code. But I've seen that generated_tokens is a list, with some comments about speculative decoding, so I decided the logprobs info should correspond to each token in that list (i.e. it should also be a list).


I see what you mean. This function doesn't touch self at all, but its implementation is valid only when used by this class. So it cannot be moved to model_common.py which is supposed to be a collection of general utilities.

But soon I'm adding a PyTorch-based implementation of the model, in a file pt_model.py. I don't want to repeat this class there. Assuming that get_logprob_infos is called only by the TVM or by the PT model, I think it is ok to put it inside model_common.py.


@masahi masahi merged commit 2b3fcf0 into octoml:batch-serving Jan 31, 2024
1 check passed
@vvchernov vvchernov deleted the feature/2023-11-22/enable-mlc-server-logprobs branch February 1, 2024 06:16
@vvchernov vvchernov restored the feature/2023-11-22/enable-mlc-server-logprobs branch February 1, 2024 06:16
@zxybazh zxybazh mentioned this pull request Feb 1, 2024
masahi pushed a commit that referenced this pull request Feb 1, 2024
logits[ind],
token_ids[ind],
top_logprobs,
)

Just to double-check that this is correct: it seems like i and ind are switched to me.

logprob_infos[ind] = get_raw_logprob_info(
            logits[i],
            token_ids[i],
            top_logprobs,


I've double-checked. It is correct; moreover, it would have failed with an index-out-of-bounds error in the benchmark test if it were not.
