Skip to content

Commit 8e1c6f8

Browse files
lhrotkmslhrotk
andauthored
feat: Change Fabric Cog Service Token to Support Billing (#2291)
* change fabric cogservice token to support billing * change mwc token * rename --------- Co-authored-by: cruise <[email protected]>
1 parent 6854b5f commit 8e1c6f8

File tree

9 files changed

+41
-104
lines changed

9 files changed

+41
-104
lines changed

cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/CognitiveServiceBase.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ package com.microsoft.azure.synapse.ml.services
66
import com.microsoft.azure.synapse.ml.codegen.Wrappable
77
import com.microsoft.azure.synapse.ml.core.contracts.HasOutputCol
88
import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions
9-
import com.microsoft.azure.synapse.ml.fabric.{FabricClient, TokenLibrary}
9+
import com.microsoft.azure.synapse.ml.fabric.FabricClient
1010
import com.microsoft.azure.synapse.ml.io.http._
1111
import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging
1212
import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails
@@ -330,7 +330,7 @@ trait HasCognitiveServiceInput extends HasURL with HasSubscriptionKey with HasAA
330330
val providedCustomAuthHeader = getValueOpt(row, CustomAuthHeader)
331331
if (providedCustomAuthHeader .isEmpty && PlatformDetails.runningOnFabric()) {
332332
logInfo("Using Default AAD Token On Fabric")
333-
Option(TokenLibrary.getAuthHeader)
333+
Option(FabricClient.getCognitiveMWCTokenAuthHeader)
334334
} else {
335335
providedCustomAuthHeader
336336
}

cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAI.scala

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
package com.microsoft.azure.synapse.ml.services.openai
55

66
import com.microsoft.azure.synapse.ml.codegen.GenerationUtils
7-
import com.microsoft.azure.synapse.ml.fabric.{FabricClient, OpenAIFabricSetting, OpenAITokenLibrary}
7+
import com.microsoft.azure.synapse.ml.fabric.{FabricClient, OpenAIFabricSetting}
88
import com.microsoft.azure.synapse.ml.logging.common.PlatformDetails
99
import com.microsoft.azure.synapse.ml.param.ServiceParam
1010
import com.microsoft.azure.synapse.ml.services._
@@ -277,18 +277,6 @@ trait HasOpenAITextParams extends HasOpenAISharedParams {
277277
}
278278
}
279279

280-
trait HasOpenAICognitiveServiceInput extends HasCognitiveServiceInput {
281-
override protected def getCustomAuthHeader(row: Row): Option[String] = {
282-
val providedCustomHeader = getValueOpt(row, CustomAuthHeader)
283-
if (providedCustomHeader.isEmpty && PlatformDetails.runningOnFabric()) {
284-
logInfo("Using Default OpenAI Token On Fabric")
285-
Option(OpenAITokenLibrary.getAuthHeader)
286-
} else {
287-
providedCustomHeader
288-
}
289-
}
290-
}
291-
292280
abstract class OpenAIServicesBase(override val uid: String) extends CognitiveServicesBase(uid: String)
293281
with HasOpenAISharedParams with OpenAIFabricSetting {
294282
setDefault(timeout -> 360.0)

cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIChatCompletion.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,9 @@ package com.microsoft.azure.synapse.ml.services.openai
55

66
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
77
import com.microsoft.azure.synapse.ml.param.AnyJsonFormat.anyFormat
8-
import com.microsoft.azure.synapse.ml.services.HasInternalJsonOutputParser
8+
import com.microsoft.azure.synapse.ml.services.{HasCognitiveServiceInput, HasInternalJsonOutputParser}
99
import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity}
1010
import org.apache.spark.ml.ComplexParamsReadable
11-
import org.apache.spark.ml.param.Param
1211
import org.apache.spark.ml.util._
1312
import org.apache.spark.sql.Row
1413
import org.apache.spark.sql.types._
@@ -20,7 +19,7 @@ import scala.language.existentials
2019
object OpenAIChatCompletion extends ComplexParamsReadable[OpenAIChatCompletion]
2120

2221
class OpenAIChatCompletion(override val uid: String) extends OpenAIServicesBase(uid)
23-
with HasOpenAITextParams with HasMessagesInput with HasOpenAICognitiveServiceInput
22+
with HasOpenAITextParams with HasMessagesInput with HasCognitiveServiceInput
2423
with HasInternalJsonOutputParser with SynapseMLLogging {
2524
logClass(FeatureNames.AiServices.OpenAI)
2625

cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAICompletion.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ package com.microsoft.azure.synapse.ml.services.openai
55

66
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
77
import com.microsoft.azure.synapse.ml.param.AnyJsonFormat.anyFormat
8-
import com.microsoft.azure.synapse.ml.services.HasInternalJsonOutputParser
8+
import com.microsoft.azure.synapse.ml.services.{HasCognitiveServiceInput, HasInternalJsonOutputParser}
99
import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity}
1010
import org.apache.spark.ml.ComplexParamsReadable
1111
import org.apache.spark.ml.util._
@@ -19,7 +19,7 @@ import scala.language.existentials
1919
object OpenAICompletion extends ComplexParamsReadable[OpenAICompletion]
2020

2121
class OpenAICompletion(override val uid: String) extends OpenAIServicesBase(uid)
22-
with HasOpenAITextParams with HasPromptInputs with HasOpenAICognitiveServiceInput
22+
with HasOpenAITextParams with HasPromptInputs with HasCognitiveServiceInput
2323
with HasInternalJsonOutputParser with SynapseMLLogging {
2424
logClass(FeatureNames.AiServices.OpenAI)
2525

cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIEmbedding.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import com.microsoft.azure.synapse.ml.param.AnyJsonFormat.anyFormat
77
import com.microsoft.azure.synapse.ml.io.http.JSONOutputParser
88
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
99
import com.microsoft.azure.synapse.ml.param.ServiceParam
10+
import com.microsoft.azure.synapse.ml.services.HasCognitiveServiceInput
1011
import org.apache.http.entity.{AbstractHttpEntity, ContentType, StringEntity}
1112
import org.apache.spark.ml.ComplexParamsReadable
1213
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
@@ -22,7 +23,7 @@ import scala.language.existentials
2223
object OpenAIEmbedding extends ComplexParamsReadable[OpenAIEmbedding]
2324

2425
class OpenAIEmbedding (override val uid: String) extends OpenAIServicesBase(uid)
25-
with HasOpenAIEmbeddingParams with HasOpenAICognitiveServiceInput with SynapseMLLogging {
26+
with HasOpenAIEmbeddingParams with HasCognitiveServiceInput with SynapseMLLogging {
2627
logClass(FeatureNames.AiServices.OpenAI)
2728

2829
def this() = this(Identifiable.randomUID("OpenAIEmbedding"))

cognitive/src/main/scala/com/microsoft/azure/synapse/ml/services/openai/OpenAIPrompt.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,16 @@
33

44
package com.microsoft.azure.synapse.ml.services.openai
55

6-
import com.microsoft.azure.synapse.ml.services._
76
import com.microsoft.azure.synapse.ml.core.contracts.HasOutputCol
87
import com.microsoft.azure.synapse.ml.core.spark.Functions
98
import com.microsoft.azure.synapse.ml.io.http.{ConcurrencyParams, HasErrorCol, HasURL}
109
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
1110
import com.microsoft.azure.synapse.ml.param.StringStringMapParam
11+
import com.microsoft.azure.synapse.ml.services._
1212
import org.apache.http.entity.AbstractHttpEntity
1313
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
1414
import org.apache.spark.ml.util.Identifiable
1515
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer}
16-
import org.apache.spark.sql.Row.unapplySeq
1716
import org.apache.spark.sql.catalyst.encoders.RowEncoder
1817
import org.apache.spark.sql.functions.udf
1918
import org.apache.spark.sql.types.{DataType, StructType}
@@ -28,7 +27,7 @@ class OpenAIPrompt(override val uid: String) extends Transformer
2827
with HasErrorCol with HasOutputCol
2928
with HasURL with HasCustomCogServiceDomain with ConcurrencyParams
3029
with HasSubscriptionKey with HasAADToken with HasCustomAuthHeader
31-
with HasOpenAICognitiveServiceInput
30+
with HasCognitiveServiceInput
3231
with ComplexParamsWritable with SynapseMLLogging {
3332

3433
logClass(FeatureNames.AiServices.OpenAI)

core/src/main/scala/com/microsoft/azure/synapse/ml/fabric/FabricClient.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ object FabricClient extends RESTUtils {
122122

123123
private def getHeaders: Map[String, String] = {
124124
Map(
125-
"Authorization" -> s"Bearer ${TokenLibrary.getAccessToken}",
125+
"Authorization" -> s"${getMLWorkloadAADAuthHeader}",
126126
"RequestId" -> UUID.randomUUID().toString,
127127
"Content-Type" -> "application/json",
128128
"x-ms-workload-resource-moniker" -> UUID.randomUUID().toString
@@ -143,4 +143,10 @@ object FabricClient extends RESTUtils {
143143
def usagePost(url: String, body: String): JsValue = {
144144
usagePost(url, body, getHeaders);
145145
}
146+
147+
def getMLWorkloadAADAuthHeader: String = TokenLibrary.getMLWorkloadAADAuthHeader
148+
149+
def getCognitiveMWCTokenAuthHeader: String = {
150+
TokenLibrary.getCognitiveMwcTokenAuthHeader(WorkspaceID.getOrElse(""), ArtifactID.getOrElse(""))
151+
}
146152
}

core/src/main/scala/com/microsoft/azure/synapse/ml/fabric/OpenAITokenLibrary.scala

Lines changed: 0 additions & 72 deletions
This file was deleted.

core/src/main/scala/com/microsoft/azure/synapse/ml/fabric/TokenLibrary.scala

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,7 @@ package com.microsoft.azure.synapse.ml.fabric
66
import scala.reflect.runtime.currentMirror
77
import scala.reflect.runtime.universe._
88

9-
trait AuthHeaderProvider {
10-
def getAuthHeader: String
11-
}
12-
13-
object TokenLibrary extends AuthHeaderProvider {
9+
object TokenLibrary {
1410
def getAccessToken: String = {
1511
val objectName = "com.microsoft.azure.trident.tokenlibrary.TokenLibrary"
1612
val mirror = currentMirror
@@ -27,9 +23,29 @@ object TokenLibrary extends AuthHeaderProvider {
2723
}
2824
}.getOrElse(throw new NoSuchMethodException(s"Method $methodName with argument type $argType not found"))
2925
val methodMirror = mirror.reflect(obj).reflectMethod(selectedMethodSymbol.asMethod)
30-
methodMirror("pbi").asInstanceOf[String]
26+
methodMirror("ml").asInstanceOf[String]
3127
}
3228

29+
def getSparkMwcToken(workspaceId: String, artifactId: String): String = {
30+
val objectName = "com.microsoft.azure.trident.tokenlibrary.TokenLibrary"
31+
val mirror = currentMirror
32+
val module = mirror.staticModule(objectName)
33+
val obj = mirror.reflectModule(module).instance
34+
val objType = mirror.reflect(obj).symbol.toType
35+
val methodName = "getMwcToken"
36+
val methodSymbols = objType.decl(TermName(methodName)).asTerm.alternatives
37+
val argTypes = List(typeOf[String], typeOf[String], typeOf[Integer], typeOf[String])
38+
val selectedMethodSymbol = methodSymbols.find { m =>
39+
m.asMethod.paramLists.flatten.map(_.typeSignature).zip(argTypes).forall { case (a, b) => a =:= b }
40+
}.getOrElse(throw new NoSuchMethodException(s"Method $methodName with argument type not found"))
41+
val methodMirror = mirror.reflect(obj).reflectMethod(selectedMethodSymbol.asMethod)
42+
methodMirror(workspaceId, artifactId, 2, "SparkCore")
43+
.asInstanceOf[String]
44+
}
45+
46+
47+
def getMLWorkloadAADAuthHeader: String = "Bearer " + getAccessToken
3348

34-
def getAuthHeader: String = "Bearer " + getAccessToken
49+
def getCognitiveMwcTokenAuthHeader(workspaceId: String, artifactId: String): String = "MwcToken " +
50+
getSparkMwcToken(workspaceId, artifactId)
3551
}

0 commit comments

Comments
 (0)