Skip to content

Commit 3a9cae6

Browse files
committed
added byte fallback
Signed-off-by: Prabod Rathnayaka <[email protected]>
1 parent 25580a9 commit 3a9cae6

File tree

1 file changed

+17
-18
lines changed

1 file changed

+17
-18
lines changed

src/main/scala/com/johnsnowlabs/nlp/annotators/tokenizer/bpe/Phi3VisionTokenizer.scala

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import com.johnsnowlabs.nlp.annotators.common.IndexedToken
2121
import java.nio.charset.Charset
2222
import scala.collection.mutable.ListBuffer
2323
import scala.util.matching.Regex
24+
import scala.collection.mutable
2425

2526
class Phi3VisionTokenizer(
2627
merges: Map[(String, String), Int],
@@ -89,24 +90,22 @@ class Phi3VisionTokenizer(
8990
}
9091

9192
def decodeTokens(tokens: Array[Int]): String = {
92-
var text = tokens
93-
.map(token => decoderVocab(token))
94-
.filter(x => !specialTokens.contains(x))
95-
.mkString("")
96-
97-
text = text.replaceAll("", " ").trim()
98-
99-
text =
100-
try {
101-
val bytes =
102-
text.map(x => unicodeToByteMapping(x.toString)).map(x => x.toByte).toArray
103-
new String(bytes, Charset.forName("UTF-8"))
104-
} catch {
105-
case e: Exception =>
106-
{}
107-
// Do nothing, just return the text
108-
text
93+
val decoded = new mutable.StringBuilder()
94+
tokens.foreach { token =>
95+
{
96+
val decodedToken = decoderVocab(token)
97+
if (!specialTokens.contains(decodedToken)) {
98+
if (decodedToken.startsWith("<0x") && decodedToken.endsWith(">")) {
99+
val strippedHex = decodedToken.replaceAll("<0x|>", "")
100+
val byteValue = Integer.parseInt(strippedHex, 16)
101+
decoded.append(byteValue.toChar)
102+
} else {
103+
decoded.append(decodedToken)
104+
}
105+
}
109106
}
110-
text
107+
108+
}
109+
decoded.toString().replaceAll(decoderVocab(29871), " ").trim()
111110
}
112111
}

0 commit comments

Comments
 (0)