8 changes: 8 additions & 0 deletions QEfficient/cloud/infer.py
@@ -102,6 +102,7 @@ def main(
full_batch_size: Optional[int] = None,
prompt_len: int = 32,
ctx_len: int = 128,
comp_ctx_lengths: Optional[List[int]] = None,
Contributor:
Add this to the function's docstring as well.

generation_len: Optional[int] = None,
mxfp6: bool = False,
mxint8: bool = False,
@@ -165,6 +166,7 @@ def main(
cache_dir=cache_dir,
hf_token=hf_token,
full_batch_size=full_batch_size,
comp_ctx_lengths=comp_ctx_lengths,
local_model_dir=local_model_dir,
trust_remote_code=trust_remote_code,
)
@@ -260,6 +262,12 @@ def main(
"--prompt-len", "--prompt_len", default=32, type=int, help="Sequence length for text generation."
)
parser.add_argument("--ctx-len", "--ctx_len", default=128, type=int, help="Context length for text generation.")
parser.add_argument(
"--comp-ctx-lengths",
"--comp_ctx_lengths",
type=lambda comp_ctx_lengths: [int(x) for x in comp_ctx_lengths.strip("[]").split(",")],
help="Compute context lengths (CCL) for text generation, comma-separated, e.g. [512,1024,2048]",
)
parser.add_argument(
"--mxfp6",
"--mxfp6_matmul",
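For reference, the `type=` converter above accepts either a bracketed or a bare comma-separated value. A minimal sketch of the parsing behaviour (standalone, not part of the PR):

```python
# Equivalent of the lambda passed as type= above.
parse_ccl = lambda s: [int(x) for x in s.strip("[]").split(",")]

print(parse_ccl("[512,1024,2048]"))  # [512, 1024, 2048]
print(parse_ccl("512,1024,2048"))    # [512, 1024, 2048]
```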
16 changes: 11 additions & 5 deletions QEfficient/customop/ctx_scatter_gather.py
@@ -115,8 +115,14 @@ def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> tor


@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
def CtxGather(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT:
ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[3], axes=[0]))
def CtxGather(
data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32
) -> onnxscript.FLOAT:
# Create a shape tensor based on comp_ctx_len
shape_tensor = ops.Concat(ops.Shape(data)[:2], ops.Reshape(comp_ctx_len, [1]), axis=0)

# Directly use the shape tensor without validation
ctx_indices = ops.Expand(ctx_indices, shape_tensor)
ctx_indices = ops.Unsqueeze(ctx_indices, [-1])
return ops.GatherND(data, ctx_indices, batch_dims=2)

@@ -127,7 +133,7 @@ class CtxGatherFunc(torch.autograd.Function):
"""

@staticmethod
def forward(data: torch.Tensor, ctx_indices: torch.Tensor):
def forward(data: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1)
head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
return data[batch_indices, head_indices, ctx_indices]
@@ -137,5 +143,5 @@ def setup_context(ctx, inputs, outputs):
pass

@staticmethod
def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value:
return g.onnxscript_op(CtxGather, data, ctx_indices).setTypeAs(data)
def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int) -> torch.Value:
return g.onnxscript_op(CtxGather, data, ctx_indices, comp_ctx_len).setTypeAs(data)
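The eager forward is unchanged apart from the extra argument; only the ONNX script consumes `comp_ctx_len`, using it to size the Expand shape so the gather covers just the first CCL cache positions. A small sketch of the gather it performs (illustrative shapes; `ctx_indices` is assumed to be pre-sliced to CCL by the cache, as in cache_utils.py below):

```python
import torch

# Eager equivalent of CtxGatherFunc.forward above (not the library API).
def ctx_gather_reference(data: torch.Tensor, ctx_indices: torch.Tensor) -> torch.Tensor:
    batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1)
    head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
    return data[batch_indices, head_indices, ctx_indices]

data = torch.randn(2, 4, 128, 64)   # [batch, heads, ctx_len, head_dim]
comp_ctx_len = 32
ctx_indices = torch.arange(comp_ctx_len)[None, None, :].expand(2, 4, comp_ctx_len)

out = ctx_gather_reference(data, ctx_indices)
assert out.shape == (2, 4, comp_ctx_len, 64)
```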
18 changes: 12 additions & 6 deletions QEfficient/customop/ctx_scatter_gather_cb.py
@@ -97,16 +97,20 @@ def symbolic(

@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
def CtxGatherCB(
data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32
data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32
) -> onnxscript.FLOAT:
batch_size = ops.Gather(ops.Shape(batch_index), [0])
num_heads = ops.Gather(ops.Shape(data), [1])
ctx_len = ops.Gather(ops.Shape(data), [2])
# Use the compute-context-length (CCL) instead of the full context length, so both this gather and the subsequent attention computation run over CCL positions only.
ctx_len = ops.Reshape(comp_ctx_len, [1])

# Expanded shape to create indices
zero = ops.Constant(value_ints=[0])
one = ops.Constant(value_ints=[1])
exp_shape = ops.Concat(batch_size, num_heads, ctx_len, one, axis=0)
# exp_shape = ops.Concat(batch_size, num_heads, ctx_len, one, axis=0)
exp_shape = ops.Concat(
ops.Reshape(batch_size, [1]), ops.Reshape(num_heads, [1]), ops.Reshape(ctx_len, [1]), one, axis=0
)

# Create indices
batch_idx = ops.Expand(ops.Unsqueeze(batch_index, [2, 3]), exp_shape)
@@ -119,7 +123,7 @@ def CtxGatherCB(

class CtxGatherFuncCB(torch.autograd.Function):
@staticmethod
def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor):
def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
batch_indices = batch_index.view(-1, 1, 1)
head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
return data[batch_indices, head_indices, ctx_indices]
@@ -129,8 +133,10 @@ def setup_context(ctx, inputs, outputs):
pass

@staticmethod
def symbolic(g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value) -> torch.Value:
return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices).setTypeAs(data)
def symbolic(
g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int
) -> torch.Value:
return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices, comp_ctx_len).setTypeAs(data)


@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
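The extra `Reshape` calls presumably guarantee that every `Concat` input is rank-1 (`Gather(Shape(x), [0])` already yields a 1-element tensor, while `comp_ctx_len` arrives as a scalar input). A sketch of the resulting shape arithmetic with illustrative values:

```python
import numpy as np

# exp_shape = [batch_size, num_heads, comp_ctx_len, 1], assembled from
# 1-element tensors and then used to Expand the gather indices.
batch_size, num_heads, comp_ctx_len = 2, 4, 32
exp_shape = np.concatenate([
    np.reshape(batch_size, [1]),
    np.reshape(num_heads, [1]),
    np.reshape(comp_ctx_len, [1]),
    np.array([1]),
]).astype(np.int64)

print(exp_shape)  # [ 2  4 32  1]
```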
73 changes: 72 additions & 1 deletion QEfficient/generation/text_generation_inference.py
@@ -316,6 +316,7 @@ def cloud_ai_100_exec_kv(
prompts_txt_file_path: Optional[str] = None,
device_id: Optional[List[int]] = None,
generation_len: Optional[int] = None,
comp_ctx_lengths: Optional[List[int]] = None,
enable_debug_logs: bool = False,
stream: bool = True,
write_io_dir: Optional[str] = None,
@@ -368,6 +369,7 @@ def cloud_ai_100_exec_kv(
qpc_path=qpc_path,
device_id=device_id,
ctx_len=ctx_len,
comp_ctx_lengths=comp_ctx_lengths,
enable_debug_logs=enable_debug_logs,
write_io_dir=write_io_dir,
full_batch_size=full_batch_size,
@@ -407,12 +409,14 @@ def __init__(
qpc_path: str,
full_batch_size: Optional[int] = None,
ctx_len: Optional[int] = None,
comp_ctx_lengths: Optional[List[int]] = None,
device_id: Optional[List[int]] = None,
enable_debug_logs: bool = False,
write_io_dir: Optional[str] = None,
is_tlm: Optional[int] = None,
) -> None:
self._ctx_len = ctx_len
self.comp_ctx_lengths = comp_ctx_lengths
self._write_io_dir = write_io_dir
self.is_tlm = is_tlm

@@ -724,6 +728,11 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
batch_lora_ids = [self._prompt_to_lora_id_mapping_prefill.popleft() for i in range(self.batch_size)]
inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1)

if self.comp_ctx_lengths is not None:
inputs["comp_ctx_lengths"] = np.random.rand(self.comp_ctx_lengths[0])
buffers = {"comp_ctx_len_out": np.zeros(1)}
self._session.set_buffers(buffers)

for i in range(num_chunks):
chunk_inputs = inputs.copy()
chunk_inputs["input_ids"] = inputs["input_ids"][
@@ -741,6 +750,18 @@
generation_len,
)

def initialize_ccl(self, decode_inputs):
max_ccl_id = len(self.comp_ctx_lengths) - 1
max_position_id = np.max(decode_inputs["position_ids"])
ccl_id = 1
for i in range(1, len(self.comp_ctx_lengths)):
if max_position_id < self.comp_ctx_lengths[i]:
ccl_id = i
break
buffers = {"comp_ctx_len_out": np.zeros(1)}

return buffers, ccl_id, max_ccl_id

def run_continuous_batching_decode(self, prompt_queue, generation_len):
"""
Runs continuous batching decode for the given prompt queue and generation length.
@@ -771,6 +792,12 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
# Prepare decode inputs inputs.
decode_inputs = self.prepare_decode_inputs()

if self.comp_ctx_lengths is not None:
list_of_comp_ctx_lengths = [np.zeros(length) for length in self.comp_ctx_lengths]
buffers, ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs)
decode_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths[ccl_id]
self._session.set_buffers(buffers)

while prompt_queue or current_decode_ongoing.any():
outputs = self._session.run(decode_inputs)

@@ -808,6 +835,19 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
batch_id_map[decode_batch_id]
]

if self.comp_ctx_lengths is not None:
# Recalculate ccl_id based on position_ids
# Determine the maximum value of position_ids across all batch elements
max_position_id = np.max(decode_inputs["position_ids"])

# Update ccl_id and comp_ctx_lengths based on the maximum position id
ccl_id = 1
for i in range(1, len(self.comp_ctx_lengths)):
if max_position_id < self.comp_ctx_lengths[i]:
ccl_id = i
break
decode_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths[ccl_id]

else:
current_decode_ongoing[decode_batch_id] = False
else:
@@ -818,6 +858,12 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
next_token_id[decode_batch_id, -1]
)

if self.comp_ctx_lengths is not None:
# Update ccl_id and comp_ctx_lengths based on the maximum position id
if decode_inputs["position_ids"][decode_batch_id, -1] >= self.comp_ctx_lengths[ccl_id] - 1:
ccl_id = min(ccl_id + 1, max_ccl_id)
decode_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths[ccl_id]

generated_id_current_index[decode_batch_id] += 1

return decode_pause_time
@@ -842,7 +888,21 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform
self._session.set_buffers({"logits": logits_out_placeholder})
finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id
num_token = 0

if self.comp_ctx_lengths is not None:
list_of_comp_ctx_lengths = [np.zeros(length) for length in self.comp_ctx_lengths]
buffers, ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs)
decode_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths[ccl_id]
self._session.set_buffers(buffers)

cache_index = np.max(decode_inputs["position_ids"])
for num_token in range(1, generation_len):
if self.comp_ctx_lengths is not None:
if cache_index >= self.comp_ctx_lengths[ccl_id] - 1:
ccl_id = min(ccl_id + 1, max_ccl_id)
decode_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths[ccl_id]

if streamer:
streamer.put(decode_inputs["input_ids"][0])
outputs = self._session.run(decode_inputs)
@@ -854,6 +914,7 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform
# Prepare inputs for next iteration
decode_inputs["input_ids"] = outputs["logits"].argmax(2)
decode_inputs["position_ids"][:, -1] += 1
cache_index += 1
self.generated_ids[:, num_token] = decode_inputs["input_ids"][:, -1]
finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id

@@ -901,17 +962,27 @@ def __init__(
qpc_path: str,
full_batch_size: Optional[int] = None,
ctx_len: Optional[int] = None,
comp_ctx_lengths: Optional[List[int]] = None,
device_id: Optional[List[int]] = None,
enable_debug_logs: bool = False,
write_io_dir: Optional[str] = None,
is_tlm: bool = False,
) -> None:
self._qaic_model = QEffTextGenerationBase(
tokenizer, qpc_path, full_batch_size, ctx_len, device_id, enable_debug_logs, write_io_dir, is_tlm
tokenizer,
qpc_path,
full_batch_size,
ctx_len,
comp_ctx_lengths,
device_id,
enable_debug_logs,
write_io_dir,
is_tlm,
)
self._full_batch_size = self._qaic_model.full_batch_size
self._tokenizer = self._qaic_model.tokenizer
self._ctx_len = ctx_len
self.comp_ctx_lengths = comp_ctx_lengths
self._perf_metrics = None
self._prompt_queue = None
self._text_streamer = None
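The CCL bookkeeping above follows one pattern: `initialize_ccl` picks the first bucket (from index 1) whose compute context length still covers the largest position id, and the decode loops bump the bucket one step at a time once position ids reach the current bucket's boundary, while `run_prefill` sizes its input buffer from `comp_ctx_lengths[0]`. A standalone sketch of that selection logic, with illustrative bucket values:

```python
comp_ctx_lengths = [512, 1024, 2048, 4096]   # example CCL buckets
max_ccl_id = len(comp_ctx_lengths) - 1

def select_ccl_id(max_position_id: int) -> int:
    # Mirrors initialize_ccl: first bucket (from index 1) that covers the position.
    ccl_id = 1
    for i in range(1, len(comp_ctx_lengths)):
        if max_position_id < comp_ctx_lengths[i]:
            ccl_id = i
            break
    return ccl_id

ccl_id = select_ccl_id(600)      # -> 1, since 600 < 1024
print(comp_ctx_lengths[ccl_id])  # 1024

# During decode, the bucket only ever grows, one step at a time:
position_id = 1023
if position_id >= comp_ctx_lengths[ccl_id] - 1:
    ccl_id = min(ccl_id + 1, max_ccl_id)
print(comp_ctx_lengths[ccl_id])  # 2048
```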
36 changes: 23 additions & 13 deletions QEfficient/transformers/cache_utils.py
@@ -91,6 +91,8 @@ def read_only(self, layer_idx, cache_kwargs):
k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
position_ids = cache_kwargs.get("position_ids")
batch_index = cache_kwargs.get("batch_index", None)
comp_ctx_len = cache_kwargs.get("CCL")

ctx_len = k_out.shape[2]
ctx_indices = torch.arange(ctx_len)[None, None, ...]
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
@@ -101,15 +103,19 @@
else:
invalid_idx_value = 0

ctx_indices = ctx_indices[:, :, :comp_ctx_len]
invalid_mask = ctx_indices > gather_limit

invalid_mask = invalid_mask[:, :, :comp_ctx_len]

ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)

if batch_index is not None:
k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices)
v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices)
k_out = CtxGatherFuncCB.apply(self.key_cache[layer_idx], batch_index, ctx_indices, comp_ctx_len)
v_out = CtxGatherFuncCB.apply(self.value_cache[layer_idx], batch_index, ctx_indices, comp_ctx_len)
else:
k_out = CtxGatherFunc.apply(k_out, ctx_indices)
v_out = CtxGatherFunc.apply(v_out, ctx_indices)

k_out = CtxGatherFunc.apply(self.key_cache[layer_idx], ctx_indices, comp_ctx_len)
v_out = CtxGatherFunc.apply(self.value_cache[layer_idx], ctx_indices, comp_ctx_len)
v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
return k_out, v_out

@@ -144,6 +150,7 @@ def update(
else:
position_ids = cache_kwargs.get("position_ids")
batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value form the kwargs
comp_ctx_len = cache_kwargs.get("CCL")

# Scatter
if batch_index is not None:
@@ -163,26 +170,29 @@
self.value_cache[layer_idx], position_ids, value_states
)

k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]

# Gather
ctx_len = k_out.shape[2]
ctx_len = self.key_cache[layer_idx].shape[2]
ctx_indices = torch.arange(ctx_len)[None, None, ...]
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
invalid_mask = ctx_indices > gather_limit

if torch.onnx.is_in_onnx_export():
invalid_idx_value = torch.iinfo(torch.int32).max
else:
invalid_idx_value = 0

ctx_indices = ctx_indices[:, :, :comp_ctx_len]
invalid_mask = ctx_indices > gather_limit

invalid_mask = invalid_mask[:, :, :comp_ctx_len]

ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)

if batch_index is not None:
k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices)
v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices)
k_out = CtxGatherFuncCB.apply(self.key_cache[layer_idx], batch_index, ctx_indices, comp_ctx_len)
v_out = CtxGatherFuncCB.apply(self.value_cache[layer_idx], batch_index, ctx_indices, comp_ctx_len)
else:
k_out = CtxGatherFunc.apply(k_out, ctx_indices)
v_out = CtxGatherFunc.apply(v_out, ctx_indices)
k_out = CtxGatherFunc.apply(self.key_cache[layer_idx], ctx_indices, comp_ctx_len)
v_out = CtxGatherFunc.apply(self.value_cache[layer_idx], ctx_indices, comp_ctx_len)
v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)

return k_out, v_out
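Both changed cache paths (`read_only` and `update`) share the same pattern: slice the gather indices down to the compute context length, then redirect any index beyond the most recently written position (to int32-max during ONNX export, otherwise to 0) and zero the corresponding values. A standalone sketch of that pattern with illustrative shapes (not the library API):

```python
import torch

ctx_len, comp_ctx_len = 128, 32
position_ids = torch.tensor([[10, 11, 12]])                    # [batch, seq]

ctx_indices = torch.arange(ctx_len)[None, None, :]             # [1, 1, ctx_len]
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)  # [1, 1, 1]
invalid_idx_value = 0                                          # int32-max when exporting to ONNX

ctx_indices = ctx_indices[:, :, :comp_ctx_len]                 # keep only CCL positions
invalid_mask = ctx_indices > gather_limit                      # cache slots not written yet
ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)

print(ctx_indices.shape)         # torch.Size([1, 1, 32])
print(ctx_indices[0, 0, 10:16])  # tensor([10, 11, 12,  0,  0,  0])
```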