Created by: klshuster
Patch Description
This patch adds the following to sequence generation:
Repetition Penalties
Inspired by https://beta.openai.com/docs/api-reference/engines/retrieve:
mu[j] -> mu[j] - c[j] * alpha_frequency - float(c[j] > 0) * alpha_presence
Where:
mu[j] is the logit of the j-th token
c[j] is how often that token was sampled prior to the current position
float(c[j] > 0) is 1 if c[j] > 0 and 0 otherwise
alpha_frequency is the frequency penalty coefficient
alpha_presence is the presence penalty coefficient
These are implemented both as penalties for tokens within the current generation and as penalties for tokens within the incoming source context.
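As a rough illustration of the update rule above, here is a minimal sketch in plain Python. The function name `apply_repetition_penalties` and its signature are hypothetical and are not taken from this patch; the sketch only mirrors the formula, with `prior_tokens` standing in for whichever token history (current generation or source context) is being penalized.

```python
from collections import Counter
from typing import List, Sequence


def apply_repetition_penalties(
    logits: Sequence[float],
    prior_tokens: Sequence[int],
    alpha_frequency: float = 0.0,
    alpha_presence: float = 0.0,
) -> List[float]:
    """Hypothetical helper: mu[j] -> mu[j] - c[j] * alpha_frequency - float(c[j] > 0) * alpha_presence."""
    # c[j]: how often token j appears in the history being penalized.
    counts = Counter(prior_tokens)
    penalized = list(logits)
    for token_id, count in counts.items():
        # Only tokens that appear in the history have c[j] > 0, so only they change.
        penalized[token_id] -= count * alpha_frequency + float(count > 0) * alpha_presence
    return penalized


# Token 0 was seen twice, token 2 once, token 1 never.
print(apply_repetition_penalties([1.0, 2.0, 3.0], [0, 0, 2], alpha_frequency=0.5, alpha_presence=0.5))
```

Under this sketch, penalizing both the generation and the source context would amount to calling the helper once per history; whether the patch shares coefficients between the two is not assumed here.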
Factual Nucleus
Factual nucleus is a decoding method that decays the nucleus sampling p value over time according to a constant factor lambda_decay. The p value resets when encountering a full stop, and omega_bound is a lower bound on how far p can decay.
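The decay schedule described above can be sketched as follows. This is an illustrative sketch only: the function name `factual_nucleus_p_schedule`, the `p_init` parameter, the default values, and the string-based full-stop check are all assumptions, not the code in this patch.

```python
from typing import List, Sequence


def factual_nucleus_p_schedule(
    tokens: Sequence[str],
    p_init: float = 0.9,
    lambda_decay: float = 0.9,
    omega_bound: float = 0.3,
) -> List[float]:
    """Hypothetical helper: return the nucleus p used to sample each token."""
    schedule = []
    p = p_init
    for token in tokens:
        schedule.append(p)
        if token.strip() == ".":
            # Assumed full-stop check: reset p at the start of the next sentence.
            p = p_init
        else:
            # Decay by the constant factor, but never below the lower bound.
            p = max(omega_bound, p * lambda_decay)
    return schedule


print(factual_nucleus_p_schedule([" Hello", ",", " it", ".", " How"], p_init=1.0, lambda_decay=0.5, omega_bound=0.3))
```

In a real decoder the reset would more likely be keyed on token ids than on decoded strings; the string comparison is used here only to keep the sketch self-contained.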
Testing steps
I've tested primarily locally, by running interactive_hosted with the BlenderBot 30B model. The logic is also the same as what is running in the BB3 demo:
$ curl -i http://host:6010/completions -d '{"prompt": ["Person 1: Hey, how is it going?\nPerson 2:"], "min_tokens": 32, "max_tokens": 32, "alpha_presence": 0.5, "alpha_frequency": 0.5, "lambda_decay": 0.9}' -H "Content-Type: application/json" -H "Authorization: 'Bearer chatbot'"
.
.
.
{"choices":[{"logprobs":{"finish_reason":"length","text_offset":[0,6,7,10,13,19,24,25,29,32,38,39,41,43,51,58,61,64,73,74,78,84,88,89,96,99,102,109,110,115,121,122],"token_logprobs":[-3.0406370162963867,-0.7952971458435059,-1.7070069313049316,-0.293349027633667,-0.11447523534297943,-1.0152369737625122,-0.3831516206264496,-0.7548114657402039,-2.685699462890625,-1.1253763437271118,-0.06859419494867325,-1.3319696187973022,-1.6205519437789917,-5.452319145202637,-4.7300543785095215,-0.904114305973053,-1.0224359035491943,-1.132882833480835,-0.5936547517776489,-2.4546852111816406,-0.7166130542755127,-0.06965463608503342,-0.013585402630269527,-11.175405502319336,-0.5433281660079956,-1.0400110483169556,-2.77441668510437,-2.9750161170959473,-3.781144618988037,-1.0245062112808228,-1.8328840732574463,-1.66330885887146],"tokens":[" Hello",","," it"," is"," going"," well","."," How"," is"," yours","?"," I","'m"," playing"," around"," on"," my"," computer","."," How"," about"," you","?"," Logged"," in"," to"," reddit",","," like"," usual","!"," H"],"top_logprobs":null},"text":" Hello, it is going well. How is yours? I'm playing around on my computer. How about you? Logged in to reddit, like usual! H"}],"created":1660758120,"id":"ece3b56d-a5fb-4adc-83a9-63945519d713","model":"/checkpoint/kshuster/projects/bb3/bb3_30B/reshard_checkpoint1_mp8/","object":"text_completion"}
Are there other tests that y'all think I should add? Where's the proper testing location for the API?