Morgan Funtowicz committed
Commit 38fa9fc · 1 Parent(s): 69894ec

feat(embeddings): expose some more to Python and return corresponding embedding (with copy for now)

Files changed (1): handler.py (+31 -12)
handler.py CHANGED
@@ -1,4 +1,5 @@
 import platform
+from typing import Union, Sequence, Sized
 
 import torch
 from loguru import logger
@@ -7,6 +8,23 @@ from hfendpoints.openai.embeddings import Embedding, EmbeddingEndpoint, Embeddin
 from sentence_transformers import SentenceTransformer
 
 from hfendpoints import EndpointConfig, Handler, __version__
+from torch.backends.mkldnn import VERBOSE_ON_CREATION, VERBOSE_OFF
+
+
+
+def get_usage(tokens: Union[Sized, Sequence[Sized]], is_batched: bool) -> Usage:
+    """
+    Compute the number of processed tokens and return it as a Usage object matching the OpenAI specification.
+    :param tokens: List or nested list of tokens
+    :param is_batched: Flag indicating whether the original request contained batched inputs
+    :return: Usage object matching the OpenAI specification
+    """
+    if is_batched:
+        num_tokens = sum(len(document) for document in tokens)
+    else:
+        num_tokens = len(tokens)
+
+    return Usage(prompt_tokens=num_tokens, total_tokens=num_tokens)
 
 
 class SentenceTransformerHandler(Handler):
@@ -32,21 +50,22 @@ class SentenceTransformerHandler(Handler):
         else:
             self._model = torch.compile(self._model)
 
-    @torch.compile
-    def forward(self, documents: str):
-        # TODO: Ask Tom how to do this better without tokenizing twice?
-        tokens = self._model.tokenize(documents)
-        vectors = self._model.encode(documents, output_value="sentence_embedding", normalize_embeddings=True)
-
-        return tokens, vectors
-
     async def __call__(self, request: EmbeddingRequest, ctx: Context) -> EmbeddingResponse:
-        with torch.backends.mkldnn.verbose(torch.backends.mkldnn.VERBOSE_ON_CREATION):
+        with torch.backends.mkldnn.verbose(VERBOSE_ON_CREATION if self._config.is_debug else VERBOSE_OFF):
             with torch.inference_mode(), torch.amp.autocast("cpu", dtype=torch.float32):
+                tokens = self._model.tokenize(request.input)
                 vectors = self._model.encode(request.input)
-                embedding = Embedding(index=0, embedding=vectors.tolist())
-                usage = Usage(prompt_tokens=len(request.input), total_tokens=len(request.input))
-                return EmbeddingResponse(model=self._model_name, embeddings=[embedding], usage=usage)
+
+                embeddings = [None] * len(request)
+                if not request.is_batched:
+                    embeddings[0] = Embedding(index=0, embedding=vectors.tolist())
+                else:
+                    for index, vector in enumerate(vectors.tolist()):
+                        embedding = Embedding(index=index, embedding=vector)
+                        embeddings[index] = embedding
+
+                usage = get_usage(tokens, request.is_batched)
+                return EmbeddingResponse(model=self._model_name, embeddings=embeddings, usage=usage)
 
 
 def entrypoint():
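
The token accounting in `get_usage` sums per-document lengths for batched inputs and takes a single length otherwise. A self-contained sketch of that logic, where the `Usage` class is a hypothetical stand-in for hfendpoints' type (only the two fields used here) and the token lists are made-up examples:

```python
from typing import Sequence, Sized, Union


class Usage:
    # Hypothetical stand-in for hfendpoints' Usage type.
    def __init__(self, prompt_tokens: int, total_tokens: int):
        self.prompt_tokens = prompt_tokens
        self.total_tokens = total_tokens


def get_usage(tokens: Union[Sized, Sequence[Sized]], is_batched: bool) -> Usage:
    if is_batched:
        # One token sequence per document: sum their lengths.
        num_tokens = sum(len(document) for document in tokens)
    else:
        # A single request carries one flat token sequence.
        num_tokens = len(tokens)
    return Usage(prompt_tokens=num_tokens, total_tokens=num_tokens)


# Single input: 4 tokens.
assert get_usage([101, 7592, 2088, 102], is_batched=False).total_tokens == 4
# Batched input: 4 + 3 = 7 tokens across two documents.
assert get_usage([[101, 7592, 2088, 102], [101, 2088, 102]], is_batched=True).total_tokens == 7
```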
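
For the response itself, the handler maps each encoded row to an OpenAI-style `Embedding` carrying its index in the batch, falling back to a single entry at index 0 for non-batched requests. A minimal sketch of that assembly, assuming a hypothetical `Embedding` dataclass and `build_embeddings` helper in place of the inline code:

```python
from dataclasses import dataclass
from typing import List

import numpy as np


@dataclass
class Embedding:
    # Hypothetical stand-in mirroring the fields used by the handler.
    index: int
    embedding: List[float]


def build_embeddings(vectors: np.ndarray, is_batched: bool) -> List[Embedding]:
    # encode() yields a 1-D array for a single input and a 2-D array for a
    # batch, so tolist() is either one flat list or a list of lists (a copy).
    values = vectors.tolist()
    if not is_batched:
        return [Embedding(index=0, embedding=values)]
    return [Embedding(index=i, embedding=v) for i, v in enumerate(values)]


single = build_embeddings(np.zeros(4, dtype=np.float32), is_batched=False)
assert len(single) == 1 and single[0].index == 0

batch = build_embeddings(np.zeros((2, 4), dtype=np.float32), is_batched=True)
assert [e.index for e in batch] == [0, 1]
```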