Data Loading and Processing Flow
Loading Data in the Dataloader
For the data preparation process, refer to the usage tutorial.
This section describes in detail how data is packed in the Dataloader.
When use_packed_dataset is set to True in config.py, the packed data is constructed by the build_pack function. The construction logic is illustrated below:
Assuming micro_bsz is set to 2 and SEQ_LEN is set to 8, the input data format is as follows:
[2323, 442, 252, 341]
[233, 3442, 322, 31, 2514, 49731, 51]
[4326, 427, 465, 22, 314, 9725, 346, 1343]
[24, 2562, 5, 25, 356]
The value of packed_length is micro_bsz * SEQ_LEN, which is 16. The inputs are concatenated in order into segments of packed_length tokens. If appending a sentence would exceed packed_length, the sentence is split: the part that fits completes the current segment, and the remainder becomes the beginning of the next segment. After all the text has been packed, the last segment is padded with 0 up to packed_length. The packed data is as follows:
[2323, 442, 252, 341, 233, 3442, 322, 31, 2514, 49731, 51, 4326, 427, 465, 22, 314]
[9725, 346, 1343, 24, 2562, 5, 25, 356, 0, 0, 0, 0, 0, 0, 0, 0]
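The label of each input consists of its tokens from the second to the last, followed by -100 in the last position; padding positions are also labeled -100. Because labels are assigned per input, an input that is split across two packed rows keeps its next-token targets (for example, the label of 314 at the end of the first row is 9725). After packing, the corresponding labels are as follows: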
[442, 252, 341, -100, 3442, 322, 31, 2514, 49731, 51, -100, 427, 465, 22, 314, 9725]
[346, 1343, -100, 2562, 5, 25, 356, -100, -100, -100, -100, -100, -100, -100, -100, -100]
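The packing and labeling rule above can be reproduced with a short pure-Python sketch. It is not the actual build_pack code: build_pack_sketch and shift_labels are hypothetical names, tensors are modelled as plain lists, and building labels per input before concatenation is an assumption that happens to reproduce the example values.

```python
from typing import List, Tuple


def shift_labels(tokens: List[int]) -> List[int]:
    """Next-token targets for one input: drop the first token, append -100."""
    return tokens[1:] + [-100] if tokens else []


def build_pack_sketch(
    samples: List[List[int]], micro_bsz: int, seq_len: int
) -> Tuple[List[List[int]], List[List[int]]]:
    """Hypothetical re-implementation of the packing rule described above."""
    packed_length = micro_bsz * seq_len
    flat_tokens: List[int] = []
    flat_labels: List[int] = []
    for tokens in samples:
        # Labels are built per input *before* concatenation, so an input that
        # is split across two rows keeps its correct next-token targets.
        flat_tokens.extend(tokens)
        flat_labels.extend(shift_labels(tokens))

    inputs: List[List[int]] = []
    labels: List[List[int]] = []
    # Cut the concatenated stream into rows of packed_length tokens;
    # the remainder of a split input simply starts the next row.
    for start in range(0, len(flat_tokens), packed_length):
        row = flat_tokens[start:start + packed_length]
        row_labels = flat_labels[start:start + packed_length]
        pad = packed_length - len(row)  # only the last row can be short
        inputs.append(row + [0] * pad)
        labels.append(row_labels + [-100] * pad)
    return inputs, labels


samples = [
    [2323, 442, 252, 341],
    [233, 3442, 322, 31, 2514, 49731, 51],
    [4326, 427, 465, 22, 314, 9725, 346, 1343],
    [24, 2562, 5, 25, 356],
]
inputs, labels = build_pack_sketch(samples, micro_bsz=2, seq_len=8)
for row, row_labels in zip(inputs, labels):
    print(row)
    print(row_labels)
```

Run on the four inputs above, it prints the two packed rows and the two label rows listed in this section.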
When use_packed_dataset is set to False in config.py, the packed data is constructed by the build_unpack function. The construction logic is illustrated below:
Assuming micro_bsz is set to 2 and SEQ_LEN is set to 8, the input data format is as follows:
[2323, 442, 252, 341]
[233, 3442, 322, 31, 2514, 49731, 51]
[4326, 427, 465, 22, 314, 9725, 346, 1343]
[24, 2562, 5, 25, 356, 3145, 246, 25, 1451, 67, 73, 541, 265]
[4524, 2465, 562, 67, 26, 265, 21, 256, 145, 1345]
[34, 14]
Here, packing follows three rules:
1. The number of sub-sentences in a pack must not exceed micro_bsz. Even if the total length of the pack is still less than micro_bsz * SEQ_LEN, the next sub-sentence is not added to it; the shortfall is padded with 0.
2. The length of a single sub-sentence must not exceed SEQ_LEN. Any excess is truncated and discarded, and the truncated sub-sentence is then packed together with the next one.
3. If the total length of a pack exceeds micro_bsz * SEQ_LEN, the excess is truncated and discarded.
Following the aforementioned rules, the format of the data after packing is as follows:
[2323, 442, 252, 341, 233, 3442, 322, 31, 2514, 49731, 51, 0, 0, 0, 0, 0]
[4326, 427, 465, 22, 314, 9725, 346, 1343, 24, 2562, 5, 25, 356, 3145, 246, 25]
[4524, 2465, 562, 67, 26, 265, 21, 256, 34, 14, 0, 0, 0, 0, 0, 0]
The corresponding labels are as follows:
[442, 252, 341, -100, 3442, 322, 31, 2514, 49731, 51, -100, -100, -100, -100, -100, -100]
[427, 465, 22, 314, 9725, 346, 1343, -100, 2562, 5, 25, 356, 3145, 246, 25, -100]
[2465, 562, 67, 26, 265, 21, 256, -100, 14, -100, -100, -100, -100, -100, -100, -100]
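A similar sketch for the three build_unpack rules follows. build_unpack_sketch is a hypothetical name, and the code only mirrors the rules and the example above rather than the library implementation.

```python
from typing import List, Tuple


def shift_labels(tokens: List[int]) -> List[int]:
    """Next-token targets for one sub-sentence: drop the first token, append -100."""
    return tokens[1:] + [-100] if tokens else []


def build_unpack_sketch(
    samples: List[List[int]], micro_bsz: int, seq_len: int
) -> Tuple[List[List[int]], List[List[int]]]:
    """Hypothetical re-implementation of the three packing rules above."""
    packed_length = micro_bsz * seq_len
    inputs: List[List[int]] = []
    labels: List[List[int]] = []

    def append_pack(ids: List[int], lbl: List[int]) -> None:
        # Rule 3: never exceed packed_length; then pad the shortfall.
        ids, lbl = ids[:packed_length], lbl[:packed_length]
        pad = packed_length - len(ids)
        inputs.append(ids + [0] * pad)
        labels.append(lbl + [-100] * pad)

    row: List[int] = []
    row_labels: List[int] = []
    count = 0
    for tokens in samples:
        tokens = tokens[:seq_len]          # Rule 2: drop tokens beyond SEQ_LEN
        if count == micro_bsz:             # Rule 1: at most micro_bsz per pack
            append_pack(row, row_labels)
            row, row_labels, count = [], [], 0
        row = row + tokens
        row_labels = row_labels + shift_labels(tokens)  # labels of the truncated piece
        count += 1
    if count:
        append_pack(row, row_labels)
    return inputs, labels


samples = [
    [2323, 442, 252, 341],
    [233, 3442, 322, 31, 2514, 49731, 51],
    [4326, 427, 465, 22, 314, 9725, 346, 1343],
    [24, 2562, 5, 25, 356, 3145, 246, 25, 1451, 67, 73, 541, 265],
    [4524, 2465, 562, 67, 26, 265, 21, 256, 145, 1345],
    [34, 14],
]
inputs, labels = build_unpack_sketch(samples, micro_bsz=2, seq_len=8)
for row in inputs:
    print(row)
```

Note that the label of the last kept token of a truncated sub-sentence is -100 rather than the discarded next token, which matches the label rows above.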
Note: If use_packed_dataset is not set, it defaults to True. Training is generally performed with use_packed_dataset set to True to improve training efficiency and accuracy.
Fetching Data from the Dataloader
After the data has been constructed in the Dataloader as described above, data is fetched sequentially from the dataloader at each forward pass. The process of fetching and processing this data is described below.
Data Fetching
batch_data, actual_batch_size = engine.load_batch(data_iter)
Here, batch_data is a list with two elements: the first is a dict named data, and the second is a torch.Tensor named label.
The first element, data, is a dictionary with three fields: input_ids, cu_seqlens, and indexes. Their types and shapes are as follows:
batch_data[0]['input_ids'] -> torch.Size([micro_num, micro_bsz * SEQ_LEN]), the token IDs of the input sentences after tokenization
batch_data[0]['cu_seqlens'] -> a list of length micro_num whose elements are torch.Tensor, holding the boundary indices of the sub-sentences concatenated into each packed sequence of length micro_bsz * SEQ_LEN
batch_data[0]['indexes'] -> torch.Size([micro_num, micro_bsz * SEQ_LEN]), the position index of each token in input_ids, counting up from 0
The shape of the second element label is:
batch_data[1] -> torch.Size([micro_num, micro_bsz * SEQ_LEN])
micro_num is configured in config.py and sets the number of gradient accumulation steps: after micro_num consecutive forward and backward passes, the accumulated gradients are used for one parameter update. The packed length micro_bsz * SEQ_LEN is obtained by concatenating multiple inputs into a single sequence of that length, which improves training efficiency.
For example, assume micro_num is 2, micro_bsz is 2, and SEQ_LEN is 8:
batch_data[0]['input_ids']:
tensor([[ 2323, 442, 252, 341, 233, 3442, 322, 31, 2514, 49731, 51, 4326, 427, 465, 22, 314],
[ 9725, 346, 1343, 24, 2562, 5, 25, 356, 0, 0, 0, 0, 0, 0, 0, 0]])
In this example, the first micro-batch consists of sub-sentences of lengths 4, 7, and 5, and the second consists of sub-sentences of lengths 3 and 5 followed by 8 padding tokens. Then:
batch_data[0]['cu_seqlens']:
tensor([[ 0, 4, 11, 16],
[ 0, 3, 8, 16]])
The difference between each pair of adjacent values is the length of the corresponding segment.
batch_data[0]['indexes']:
tensor([[ 0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4],
[ 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 6, 7]])
Each value is the position of a token within its sub-sentence. If a row ends with padding, the padding is treated as its own segment: its indexes start again from 0 and increase up to the end of the row.
batch_data[1]:
tensor([[ 442, 252, 341, -100, 3442, 322, 31, 2514, 49731, 51, -100, 427, 465, 22, 314, 9725],
[ 346, 1343, -100, 2562, 5, 25, 356, -100, -100, -100, -100, -100, -100, -100, -100, -100]])
This is the label tensor corresponding to input_ids above.
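The cu_seqlens and indexes rows above can be derived from the sub-sentence lengths alone. In this sketch, positions_from_lengths is a hypothetical helper, and treating trailing padding as an extra segment is inferred from the second row of the example.

```python
from itertools import accumulate
from typing import List, Tuple


def positions_from_lengths(lengths: List[int], packed_length: int) -> Tuple[List[int], List[int]]:
    """Build cu_seqlens and indexes for one packed row (hypothetical helper)."""
    segments = list(lengths)
    pad = packed_length - sum(segments)
    if pad > 0:
        # Trailing padding is treated as one more segment, as in the example.
        segments.append(pad)
    cu_seqlens = [0, *accumulate(segments)]             # running segment boundaries
    indexes = [i for n in segments for i in range(n)]   # position restarts per segment
    return cu_seqlens, indexes


print(positions_from_lengths([4, 7, 5], 16))
print(positions_from_lengths([3, 5], 16))
```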
Data Processing
_data, _label = self._load_accum_batch(data, label)
Within _load_accum_batch, the _load_micro_batch function first reduces the first dimension of data and label, micro_num, to 1. By advancing an offset, the data of each micro-batch is retrieved in turn.
Next, the data is further processed by the registered data_process_func.
When use_packed_dataset is set to True in the config.py, the process within the data_process_func is as follows:
The packed_data_normalizer function removes the leading size-1 dimension from data['indexes'] and data['cu_seqlens']. It also computes the maximum sub-sentence length from the values in data['cu_seqlens'] and records it in data['max_seqlen'].
Continuing the example above, if the first micro-batch is loaded, data and label after processing by _load_accum_batch are as follows:
data['input_ids']:
tensor([[ 2323, 442, 252, 341, 233, 3442, 322, 31, 2514, 49731, 51, 4326, 427, 465, 22, 314]])
data['cu_seqlens']:
tensor([ 0, 4, 11, 16])
data['indexes']:
tensor([ 0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4])
data['max_seqlen']:
7
label:
tensor([[ 442, 252, 341, -100, 3442, 322, 31, 2514, 49731, 51, -100, 427, 465, 22, 314, 9725]])
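The two steps of _load_accum_batch described above (micro-batch slicing, then packed_data_normalizer) are sketched below with nested lists standing in for tensors. load_micro_batch_sketch and packed_data_normalizer_sketch are hypothetical stand-ins, not the library functions, and slicing one row per micro-batch (bsz_stride=1) is an assumption based on the shapes shown.

```python
from typing import Any, Dict, List, Tuple

Batch = Dict[str, Any]


def load_micro_batch_sketch(
    data: Batch, label: List[Any], offset: int, bsz_stride: int = 1
) -> Tuple[Batch, List[Any]]:
    """Take rows [offset, offset + bsz_stride) of every field (hypothetical helper)."""
    micro_data = {key: value[offset:offset + bsz_stride] for key, value in data.items()}
    return micro_data, label[offset:offset + bsz_stride]


def packed_data_normalizer_sketch(data: Batch) -> Batch:
    """Squeeze the size-1 leading dim of cu_seqlens/indexes and add max_seqlen."""
    out = dict(data)  # input_ids keeps its [1, packed_length] shape
    out["cu_seqlens"] = data["cu_seqlens"][0]
    out["indexes"] = data["indexes"][0]
    cu = out["cu_seqlens"]
    out["max_seqlen"] = max(end - start for start, end in zip(cu, cu[1:]))
    return out


data = {
    "input_ids": [[2323, 442, 252, 341, 233, 3442, 322, 31, 2514, 49731, 51, 4326, 427, 465, 22, 314],
                  [9725, 346, 1343, 24, 2562, 5, 25, 356] + [0] * 8],
    "cu_seqlens": [[0, 4, 11, 16], [0, 3, 8, 16]],
    "indexes": [[0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4],
                [0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 5, 6, 7]],
}
label = [[442, 252, 341, -100, 3442, 322, 31, 2514, 49731, 51, -100, 427, 465, 22, 314, 9725],
         [346, 1343, -100, 2562, 5, 25, 356] + [-100] * 9]

offset = 0
micro_batches = []
for _ in range(len(label)):  # one iteration per micro-batch (micro_num)
    micro_data, micro_label = load_micro_batch_sketch(data, label, offset)
    offset += 1  # advance to the next micro-batch
    micro_batches.append((packed_data_normalizer_sketch(micro_data), micro_label))

first_data, first_label = micro_batches[0]
print(first_data["cu_seqlens"], first_data["max_seqlen"])
```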
If the tp parallel mode is set to “isp”, and the tp size (i.e., the sequence parallel size) is greater than 1, then the split_data_sequence_parallel function will be registered within the data_process_func to split the data along the sequence dimension.
Assuming the tp size is 2, splitting data['input_ids'], data['indexes'], and label from above with split_data_sequence_parallel gives the following result:
Data in tp rank0
data['input_ids']:
tensor([[ 2323, 442, 252, 341, 233, 3442, 322, 31]])
data['indexes']:
tensor([ 0, 1, 2, 3, 0, 1, 2, 3])
label:
tensor([[ 442, 252, 341, -100, 3442, 322, 31, 2514]])
Data in tp rank1
data['input_ids']:
tensor([[ 2514, 49731, 51, 4326, 427, 465, 22, 314]])
data['indexes']:
tensor([ 4, 5, 6, 0, 1, 2, 3, 4])
label:
tensor([[ 49731, 51, -100, 427, 465, 22, 314, 9725]])
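The split in this example keeps one contiguous chunk of the sequence dimension per rank. The hypothetical split_last_dim_sketch below reproduces the values above under two assumptions: the sequence length is divisible by the tp size, and rank r keeps chunk r. It leaves cu_seqlens and max_seqlen untouched.

```python
from typing import Any, List


def split_last_dim_sketch(x: List[Any], rank: int, world_size: int) -> List[Any]:
    """Keep the rank-th contiguous chunk of the last (sequence) dimension."""
    if x and isinstance(x[0], list):
        # Nested list: recurse until we reach the innermost (sequence) dimension.
        return [split_last_dim_sketch(row, rank, world_size) for row in x]
    if len(x) % world_size != 0:
        raise ValueError("sequence length must be divisible by the tp size")
    chunk = len(x) // world_size
    return x[rank * chunk:(rank + 1) * chunk]


input_ids = [[2323, 442, 252, 341, 233, 3442, 322, 31, 2514, 49731, 51, 4326, 427, 465, 22, 314]]
indexes = [0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4]
label = [[442, 252, 341, -100, 3442, 322, 31, 2514, 49731, 51, -100, 427, 465, 22, 314, 9725]]

for rank in range(2):
    print(f"tp rank{rank}:",
          split_last_dim_sketch(input_ids, rank, 2),
          split_last_dim_sketch(indexes, rank, 2),
          split_last_dim_sketch(label, rank, 2))
```

Because it always splits the innermost dimension, the same function also covers the unpacked [micro_bsz, SEQ_LEN] tensors discussed later in this section.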
When use_packed_dataset is set to False in the config.py, the process within the data_process_func is as follows:
The unpack_data function restores data["input_ids"] and label to the unpacked format and removes the "cu_seqlens" and "indexes" fields from data.
After unpacking, data["input_ids"] and label have the shape torch.Size([micro_bsz, SEQ_LEN]).
Continuing the example, assume micro_bsz is 2, SEQ_LEN is 8, and the input data is as follows:
[2323, 442, 252, 341]
[233, 3442, 322, 31, 2514, 49731, 51]
The packed data format is as follows:
[2323, 442, 252, 341, 233, 3442, 322, 31, 2514, 49731, 51, 0, 0, 0, 0, 0]
The corresponding labels are as follows:
[442, 252, 341, -100, 3442, 322, 31, 2514, 49731, 51, -100, -100, -100, -100, -100, -100]
After processing with unpack_data, data["input_ids"] and label are as follows:
data["input_ids"]:
tensor([[2323, 442, 252, 341, 0, 0, 0, 0],
[233, 3442, 322, 31, 2514, 49731, 51, 0]])
label:
tensor([[442, 252, 341, -100, -100, -100, -100, -100],
[3442, 322, 31, 2514, 49731, 51, -100, -100]])
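The unpacking step can be sketched as scattering each cu_seqlens segment into its own row. unpack_data_sketch is hypothetical, and cu_seqlens = [0, 4, 11, 16] for the packed row above is an assumption, since the example does not list it; segments beyond micro_bsz (here, the trailing padding) are ignored.

```python
from typing import List, Tuple


def unpack_data_sketch(
    input_ids: List[int], label: List[int], cu_seqlens: List[int], micro_bsz: int, seq_len: int
) -> Tuple[List[List[int]], List[List[int]]]:
    """Scatter each packed sub-sentence into its own row of length seq_len."""
    out_ids = [[0] * seq_len for _ in range(micro_bsz)]        # pad value 0 for inputs
    out_label = [[-100] * seq_len for _ in range(micro_bsz)]   # pad value -100 for labels
    # Keep at most micro_bsz segments; a trailing padding segment is dropped.
    segments = list(zip(cu_seqlens, cu_seqlens[1:]))[:micro_bsz]
    for row, (start, end) in enumerate(segments):
        n = min(end - start, seq_len)
        out_ids[row][:n] = input_ids[start:start + n]
        out_label[row][:n] = label[start:start + n]
    return out_ids, out_label


packed_ids = [2323, 442, 252, 341, 233, 3442, 322, 31, 2514, 49731, 51, 0, 0, 0, 0, 0]
packed_label = [442, 252, 341, -100, 3442, 322, 31, 2514, 49731, 51, -100, -100, -100, -100, -100, -100]
ids, lbl = unpack_data_sketch(packed_ids, packed_label, cu_seqlens=[0, 4, 11, 16], micro_bsz=2, seq_len=8)
print(ids)
print(lbl)
```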
If the tp parallel mode is set to “isp”, and the tp size (i.e., the sequence parallel size) is greater than 1, then the split_data_sequence_parallel function will be registered within the data_process_func to split the data along the sequence dimension.
Assuming the tp size is 2, splitting data["input_ids"] and label from above gives the following result:
Data in tp rank0
data["input_ids"]:
tensor([[2323, 442, 252, 341],
[233, 3442, 322, 31]])
label:
tensor([[442, 252, 341, -100],
[3442, 322, 31, 2514]])
Data in tp rank1
data["input_ids"]:
tensor([[0, 0, 0, 0],
[2514, 49731, 51, 0]])
label:
tensor([[-100, -100, -100, -100],
[49731, 51, -100, -100]])
Data Format During the Forward Pass
Taking the internlm2 model as an example, this section details how the model weights are partitioned under the different parallel modes and how data flows through the model during the forward pass.
First, weight partitioning under the different parallel modes is introduced.
Weight Partitioning in ISP Parallel Mode
For the principles of ISP parallelism, refer to: https://internevo.readthedocs.io/en/latest/parallel.html#internlm-tensor-parallel
In the internlm2 model, the parameters whose weights are partitioned are “wqkv”, “wo”, “w1”, “w2”, “w3”, and “output”. They are partitioned using the new_linear function.
Assuming the configuration file sets the weight parallelism size to wp_size, the model structure and weights after initialization are as follows:
InternLM2(
  (tok_embeddings): Embedding1D()
  (layers): ModuleList(
    (0): InternLM2Decoder(
      (attention): GQA(
        (rotary_emb): RotaryEmbedding()
        (wqkv): ColumnParallelLinear(in_features=hidden_size, out_features=(hidden_size + 2 * hidden_size // num_attention_heads * num_kv_attention_heads) // wp_size, bias=True)
        (inner_attn): DistributedAttention(
          (local_attn): SelfAttention(
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (inner_cross_attn): DistributedAttention(
          (local_attn): CrossAttention(
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (wo): ColumnParallelLinear(in_features=hidden_size, out_features=hidden_size // wp_size, bias=True)
      )
      (dropout1): Dropout(p=0.0, inplace=False)
      (dropout2): Dropout(p=0.0, inplace=False)
      (attention_norm): _RMSNorm(torch.Size([hidden_size]), eps=1e-05, )
      (ffn_norm): _RMSNorm(torch.Size([hidden_size]), eps=1e-05, )
      (feed_forward): FeedForward(
        (w1): ColumnParallelLinear(in_features=hidden_size, out_features=(multiple_of * ((int(hidden_size * mlp_ratio) + multiple_of - 1) // multiple_of)) // wp_size, bias=False)
        (w2): ColumnParallelLinear(in_features=multiple_of * ((int(hidden_size * mlp_ratio) + multiple_of - 1) // multiple_of), out_features=hidden_size // wp_size, bias=False)
        (w3): ColumnParallelLinear(in_features=hidden_size, out_features=(multiple_of * ((int(hidden_size * mlp_ratio) + multiple_of - 1) // multiple_of)) // wp_size, bias=False)
      )
    )
  )
  (norm): _RMSNorm(torch.Size([hidden_size]), eps=1e-05, )
  (output): ScaleColumnParallelLinear(in_features=hidden_size, out_features=vocab_size // wp_size, bias=False)
)
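To make the out_features expressions above concrete, the helper below evaluates them for one rank. The function name and the configuration values (hidden_size=4096, 32 attention heads, 8 KV heads, mlp_ratio=8/3, multiple_of=256, vocab_size=92544, wp_size=8) are illustrative assumptions, not a published model config. head_dim is written as hidden_size // num_attention_heads, which matches the wqkv expression above whenever hidden_size is divisible by num_attention_heads.

```python
from typing import Dict


def isp_local_out_features(
    hidden_size: int,
    num_attention_heads: int,
    num_kv_attention_heads: int,
    mlp_ratio: float,
    multiple_of: int,
    vocab_size: int,
    wp_size: int,
) -> Dict[str, int]:
    """Per-rank out_features of each column-partitioned layer (hypothetical helper)."""
    head_dim = hidden_size // num_attention_heads
    qkv_dim = hidden_size + 2 * head_dim * num_kv_attention_heads
    # FFN width rounded up to a multiple of multiple_of, as in the w1/w2/w3 lines.
    ffn_dim = multiple_of * ((int(hidden_size * mlp_ratio) + multiple_of - 1) // multiple_of)
    return {
        "wqkv": qkv_dim // wp_size,
        "wo": hidden_size // wp_size,
        "w1": ffn_dim // wp_size,
        "w3": ffn_dim // wp_size,
        "w2": hidden_size // wp_size,
        "output": vocab_size // wp_size,
    }


shapes = isp_local_out_features(
    hidden_size=4096, num_attention_heads=32, num_kv_attention_heads=8,
    mlp_ratio=8 / 3, multiple_of=256, vocab_size=92544, wp_size=8,
)
print(shapes)
```

For these values the full FFN width is 11008, so each rank holds 1376 output columns of w1 and w3.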
Weight Partitioning in MTP/MSP/FSP Parallel Mode
For the principles of MTP/MSP/FSP parallelism, refer to: https://internevo.readthedocs.io/en/latest/parallel.html#internlm-tensor-parallel
Compared with the ISP parallel mode, the MSP parallel mode partitions the weights of the same parameters, but in a different way: in ISP mode, all partitioned parameters are split column-wise, whereas in MSP mode the “wo” and “w2” parameters are split row-wise.
Assuming the configuration file sets the tensor parallel size to tp_size, the initialized model structure and weights are essentially the same as those listed for ISP, with wp_size in ISP mode corresponding to tp_size in MSP mode. The differing “wo” and “w2” parameters are as follows:
InternLM2(
  (tok_embeddings): Embedding1D()
  (layers): ModuleList(
    (0): InternLM2Decoder(
      (attention): GQA(
        ......
        (wo): RowParallelLinear(in_features=hidden_size // tp_size, out_features=hidden_size, bias=False)
      )
      ......
      (feed_forward): FeedForward(
        ......
        (w2): RowParallelLinear(in_features=(multiple_of * ((int(hidden_size * mlp_ratio) + multiple_of - 1) // multiple_of)) // tp_size, out_features=hidden_size, bias=False)
        ......
      )
    )
  )
  ......
)
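The practical difference between column-wise and row-wise partitioning is what each rank must communicate afterwards. The single-process sketch below uses plain integer matrices, with weights stored as [in_features, out_features] (the transpose of how torch.nn.Linear stores them); all function names are hypothetical. Column shards of the output are simply concatenated, while row-parallel partial outputs must be summed, which is the reduction performed by the Reduce-Scatter that follows “wo” and “w2” in MSP/FSP mode.

```python
from typing import List

Matrix = List[List[int]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product a @ b."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def column_slice(m: Matrix, lo: int, hi: int) -> Matrix:
    """Columns lo..hi-1 of m."""
    return [row[lo:hi] for row in m]


def column_parallel(x: Matrix, w: Matrix, tp_size: int) -> Matrix:
    """Split w by output columns; each rank owns a block of output columns."""
    c = len(w[0]) // tp_size
    shards = [matmul(x, column_slice(w, r * c, (r + 1) * c)) for r in range(tp_size)]
    # Concatenating the shards along the last dim rebuilds the full output.
    return [[v for shard in shards for v in shard[i]] for i in range(len(x))]


def row_parallel(x: Matrix, w: Matrix, tp_size: int) -> Matrix:
    """Split w by input rows (and x by its last dim); each rank gets a partial sum."""
    c = len(w) // tp_size
    partials = [matmul(column_slice(x, r * c, (r + 1) * c), w[r * c:(r + 1) * c])
                for r in range(tp_size)]
    # Summing the partial outputs elementwise rebuilds the full output.
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]


x = [[1, 2, 3, 4], [5, 6, 7, 8]]                                # [tokens, in_features]
w = [[1, 0, 2, 1], [0, 1, 1, 0], [3, 1, 0, 2], [1, 2, 1, 1]]    # [in_features, out_features]
full = matmul(x, w)
print(full, column_parallel(x, w, 2), row_parallel(x, w, 2))
```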
Forward Procedure
In the internlm2 model, the overall forward process is shown in the figure below:
The following describes how the data changes during the forward pass shown in the figure above under each parallel mode.
Data Procedure in ISP parallelism
Assume the configuration file sets the tensor parallel size to sp_size (in ISP mode, the tensor parallel size equals the sequence parallel size). The change in data dimensions at each step of the computation is described below.
tok_embeddings Calculation Procedure
During the embedding computation process, the seq_len dimension of the data is partitioned.
Input and weight:
input_ids:
torch.Size([1, (micro_bsz * seq_len) // sp_size])
self.tok_embeddings.weight:
torch.Size([vocab_size, hidden_size // wp_size])
Output:
hidden_states:
torch.Size([1, (micro_bsz * seq_len) // sp_size, hidden_size])
attention Calculation Procedure
qkv preparation
qkv = self.wqkv(x)
In this step, a weight_hook performs an All-Gather on the wqkv weight, which was partitioned by weight parallelism, so the last dimension of the output qkv equals the full (all-gathered) out_features of self.wqkv.
Note: All weights that have been partitioned by the new_linear function and used in subsequent computations will undergo an All-Gather operation via weight_hook during the forward process.
qkv:
torch.Size([1, (micro_bsz * seq_len) // sp_size, hidden_size + 2 * hidden_size // num_attention_heads * num_kv_attention_heads])
qkv is then reshaped to [batch_size, seq_len, num_head, group_size, head_dim], where num_head is num_kv_attention_heads and group_size is num_attention_heads // num_kv_attention_heads + 2, and q (query), k (key), and v (value) are extracted from it:
qkv:
torch.Size([1, (micro_bsz * seq_len) // sp_size, num_kv_attention_heads, num_attention_heads // num_kv_attention_heads + 2, hidden_size // num_attention_heads])
q:
torch.Size([1, (micro_bsz * seq_len) // sp_size, num_attention_heads, hidden_size // num_attention_heads]) # take the first num_attention_heads // num_kv_attention_heads entries along the 4th dimension of qkv, then merge the 3rd and 4th dimensions
k:
torch.Size([1, (micro_bsz * seq_len) // sp_size, num_kv_attention_heads, hidden_size // num_attention_heads]) # take the second-to-last entry along the 4th dimension of qkv
v:
torch.Size([1, (micro_bsz * seq_len) // sp_size, num_kv_attention_heads, hidden_size // num_attention_heads]) # take the last entry along the 4th dimension of qkv
k and v are then stacked into kv for the subsequent attention computation.
kv:
torch.Size([1, (micro_bsz * seq_len) // sp_size, 2, num_kv_attention_heads, hidden_size // num_attention_heads])
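The q/k/v extraction above depends on how the last dimension of qkv is laid out. Assuming it is viewed as [num_kv_attention_heads, group_size + 2, head_dim] as described, the hypothetical helper below lists which columns of that dimension end up in q, k, and v: within each KV group, the first group_size head slots go to q, the next one to k, and the last one to v.

```python
from typing import List, Tuple


def qkv_column_groups(
    num_attention_heads: int, num_kv_attention_heads: int, head_dim: int
) -> Tuple[List[int], List[int], List[int]]:
    """Columns of qkv's last dimension that feed q, k and v (hypothetical helper)."""
    group_size = num_attention_heads // num_kv_attention_heads
    slots = group_size + 2  # group_size query heads + 1 key head + 1 value head
    q_cols: List[int] = []
    k_cols: List[int] = []
    v_cols: List[int] = []
    for kv_head in range(num_kv_attention_heads):
        for slot in range(slots):
            start = (kv_head * slots + slot) * head_dim
            cols = list(range(start, start + head_dim))
            if slot < group_size:
                q_cols += cols   # first group_size slots -> q
            elif slot == group_size:
                k_cols += cols   # second-to-last slot -> k
            else:
                v_cols += cols   # last slot -> v
    return q_cols, k_cols, v_cols


q_cols, k_cols, v_cols = qkv_column_groups(num_attention_heads=4, num_kv_attention_heads=2, head_dim=1)
print(q_cols, k_cols, v_cols)
```

With 4 query heads, 2 KV heads and head_dim=1, q takes columns [0, 1, 4, 5], k takes [2, 6] and v takes [3, 7].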
calculate attention
The process of attention calculation is as follows:
context = self.inner_attn(q, kv)
Here, a dispatch mechanism determines whether q, k, and v are passed separately or packed together, and calls the corresponding forward function to compute attention.
Before calculating the attention, an AllToAll communication is used to scatter the num_head dimension of q and kv, and to gather the seq_len dimension.
q:
torch.Size([1, micro_bsz * seq_len, num_attention_heads // sp_size, hidden_size // num_attention_heads])
kv:
torch.Size([1, micro_bsz * seq_len, 2, num_kv_attention_heads // sp_size, hidden_size // num_attention_heads])
Then context = self.local_attn(q, kv) performs the attention computation. The resulting dimensions are:
context:
torch.Size([1, micro_bsz * seq_len, num_attention_heads // sp_size, hidden_size // num_attention_heads])
After the attention computation, another AllToAll gathers the num_head dimension of context and scatters its seq_len dimension, restoring the sequence split:
context:
torch.Size([1, (micro_bsz * seq_len) // sp_size, num_attention_heads, hidden_size // num_attention_heads])
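The two AllToAll exchanges can be simulated in a single process to see how data moves between ranks. In this sketch, each rank holds a [tokens][heads] nested list (head_dim is omitted), the function names are hypothetical, and a real run would use a collective such as torch.distributed.all_to_all instead of list slicing.

```python
from typing import Any, List

Shard = List[List[Any]]  # [tokens][heads]


def all_to_all_seq_to_head(shards: List[Shard]) -> List[Shard]:
    """Before attention: scatter heads, gather the sequence."""
    world = len(shards)
    h = len(shards[0][0]) // world
    # Rank r receives head block r from every rank, concatenated in sequence order.
    return [[row[r * h:(r + 1) * h] for shard in shards for row in shard] for r in range(world)]


def all_to_all_head_to_seq(shards: List[Shard]) -> List[Shard]:
    """After attention: gather heads, scatter the sequence (inverse of the above)."""
    world = len(shards)
    s = len(shards[0]) // world
    return [
        [[v for shard in shards for v in shard[t]] for t in range(r * s, (r + 1) * s)]
        for r in range(world)
    ]


seq_len, num_heads, sp_size = 4, 4, 2
full = [[(t, h) for h in range(num_heads)] for t in range(seq_len)]  # (token, head) tags
chunk = seq_len // sp_size
local = [full[r * chunk:(r + 1) * chunk] for r in range(sp_size)]   # sequence-split input

attn_in = all_to_all_seq_to_head(local)   # each rank: all tokens, num_heads // sp_size heads
restored = all_to_all_head_to_seq(attn_in)
print(attn_in[0])
```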
Output transformation
The attention output is then projected by “wo”, giving the following dimensions:
torch.Size([1, (micro_bsz * seq_len) // sp_size, hidden_size])
feed_forward Calculation Procedure
In the feed-forward network layer, linear transformations are applied to the output results using “w1”, “w2”, and “w3”. The results after the transformation are as follows:
w1_o = self.w1(x)
w3_o = self.w3(x)
out = self.w2(Silu(w1_o, w3_o))
w1_o:
torch.Size([1, (micro_bsz * seq_len) // sp_size, multiple_of * ((int(hidden_size * mlp_ratio) + multiple_of - 1) // multiple_of)])
w3_o:
torch.Size([1, (micro_bsz * seq_len) // sp_size, multiple_of * ((int(hidden_size * mlp_ratio) + multiple_of - 1) // multiple_of)])
out:
torch.Size([1, (micro_bsz * seq_len) // sp_size, hidden_size])
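The two-argument Silu(w1_o, w3_o) above combines the outputs of w1 and w3 before w2. Assuming it computes silu(w1_o) * w3_o elementwise (the SwiGLU gate), which is an assumption about its internals, a scalar-level sketch is:

```python
import math
from typing import List


def silu(x: float) -> float:
    """SiLU(x) = x * sigmoid(x), written to avoid overflow in math.exp."""
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def gated_silu(w1_o: List[float], w3_o: List[float]) -> List[float]:
    """Assumed behaviour of Silu(w1_o, w3_o): silu(w1_o) * w3_o elementwise."""
    return [silu(a) * b for a, b in zip(w1_o, w3_o)]


result = gated_silu([0.0, 1.0, -2.0], [5.0, 2.0, 3.0])
print(result)
```

In the layer itself, w2 is then applied to this gated product, so its input width is the FFN width and its output width is hidden_size.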
norm Calculation Procedure
The result dimensions remain unchanged after the computation through the normalization layer, and are as follows:
hidden_states:
torch.Size([1, (micro_bsz * seq_len) // sp_size, hidden_size])
output Calculation Procedure
Finally, the output layer projects the hidden states of the last layer to the vocabulary dimension for the final task, with the following result:
hidden_states:
torch.Size([1, (micro_bsz * seq_len) // sp_size, vocab_size])
Data Procedure in MTP/MSP/FSP parallelism
In MTP parallel mode, only tensor parallelism is used to partition the model weights; the seq_len dimension of the data is not split. In contrast, MSP and FSP additionally apply sequence parallelism to the data, with the sequence parallel size equal to the tensor parallel size; both share the same communication group.
tok_embeddings Calculation Procedure
During the computation process of the embedding, the embedding weights will be partitioned:
self.tok_embeddings.weight:
torch.Size([vocab_size, hidden_size // tp_size])
The input and output results of the MTP mode are as follows:
input_ids:
torch.Size([1, micro_bsz * seq_len])
hidden_states:
torch.Size([1, micro_bsz * seq_len, hidden_size])
The input and output results for MSP and FSP are as follows:
input_ids:
torch.Size([1, micro_bsz * seq_len])
hidden_states:
torch.Size([1, (micro_bsz * seq_len) // tp_size, hidden_size])
attention Calculation Procedure
Before the attention computation, in MSP/FSP mode, an All-Gather communication collects the data that was split by sequence parallelism. Therefore, the tensor dimensions throughout the attention computation are the same in all three modes (MTP/MSP/FSP).
After the attention computation, in MSP/FSP mode, the linear transformation in the wo layer is followed by a Reduce-Scatter communication, which sums the partial results of the row-wise partitioned linear layer and at the same time splits the result along the sequence dimension for sequence parallelism.
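The Reduce-Scatter after the row-parallel “wo” layer, and the All-Gather that restores the full sequence before the next layer that needs it, can be simulated with nested lists. reduce_scatter_seq and all_gather_seq are hypothetical names, and the sketch assumes the sequence length divides evenly by the tp size.

```python
from typing import List

Matrix = List[List[int]]  # [tokens][hidden]


def reduce_scatter_seq(partials: List[Matrix]) -> List[Matrix]:
    """Sum the ranks' partial outputs, then give each rank one sequence chunk."""
    world = len(partials)
    chunk = len(partials[0]) // world
    summed = [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]
    return [summed[r * chunk:(r + 1) * chunk] for r in range(world)]


def all_gather_seq(chunks: List[Matrix]) -> Matrix:
    """Concatenate the ranks' sequence chunks back into the full sequence."""
    return [row for chunk in chunks for row in chunk]


partials = [
    [[1, 2], [3, 4], [5, 6], [7, 8]],          # rank 0: partial output of wo
    [[10, 20], [30, 40], [50, 60], [70, 80]],  # rank 1: partial output of wo
]
local = reduce_scatter_seq(partials)  # each rank: (tokens // tp_size) rows of the sum
print(local)
print(all_gather_seq(local))
```

Chaining the two operations gives the same result as an All-Reduce, which is why the sequence-parallel layout can be kept between layers at the same communication volume.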
qkv preparation
qkv = self.wqkv(x)
The dimensions of the computed qkv are as follows:
qkv:
torch.Size([1, micro_bsz * seq_len, (hidden_size + 2 * hidden_size // num_attention_heads * num_kv_attention_heads) // tp_size])
qkv is then reshaped to [batch_size, seq_len, num_head, group_size, head_dim], and q, k, and v are extracted from it; tensor parallelism partitions the num_head dimension.
qkv:
torch.Size([1, micro_bsz * seq_len, num_kv_attention_heads // tp_size, num_attention_heads // num_kv_attention_heads + 2, hidden_size // num_attention_heads])
q:
torch.Size([1, micro_bsz * seq_len, num_attention_heads // tp_size, hidden_size // num_attention_heads]) # take the first num_attention_heads // num_kv_attention_heads entries along the 4th dimension of qkv, then merge the 3rd and 4th dimensions
k:
torch.Size([1, micro_bsz * seq_len, num_kv_attention_heads // tp_size, hidden_size // num_attention_heads]) # take the second-to-last entry along the 4th dimension of qkv
v:
torch.Size([1, micro_bsz * seq_len, num_kv_attention_heads // tp_size, hidden_size // num_attention_heads]) # take the last entry along the 4th dimension of qkv
k and v are then stacked into kv for the subsequent attention computation.
kv:
torch.Size([1, micro_bsz * seq_len, 2, num_kv_attention_heads // tp_size, hidden_size // num_attention_heads])
calculate attention
The process of attention calculation is as follows:
context = self.inner_attn(q, kv)
Here, the attention calculation is performed directly, without the need for AllToAll communication as required in the “ISP” mode.
The dimensions of the computation result are:
context:
torch.Size([1, micro_bsz * seq_len, num_attention_heads // tp_size, hidden_size // num_attention_heads])
Output transformation
The “wo” is called to transform the output results of the attention computation.
In the MTP parallel mode, the dimensions of the output results are as follows:
torch.Size([1, micro_bsz * seq_len, hidden_size])
In the MSP/FSP parallel modes, the dimensions of the output results are as follows:
torch.Size([1, (micro_bsz * seq_len) // tp_size, hidden_size])
feed_forward Calculation Procedure
In the feed-forward network layer, linear transformations are applied to the output results using “w1”, “w2”, and “w3”.
In the MSP/FSP parallel modes, an All-Gather communication is required before the linear transformation layers of w1 and w3. Therefore, the output dimensions are the same across the MTP/MSP/FSP parallel modes.
w1_o:
torch.Size([1, micro_bsz * seq_len, (multiple_of * ((int(hidden_size * mlp_ratio) + multiple_of - 1) // multiple_of)) // tp_size])
w3_o:
torch.Size([1, micro_bsz * seq_len, (multiple_of * ((int(hidden_size * mlp_ratio) + multiple_of - 1) // multiple_of)) // tp_size])
After the linear transformation through the “w2” layer, in the MSP/FSP parallel modes, a Reduce-Scatter communication is necessary.
In the MTP parallel mode, the dimensions of the output results are as follows:
out = self.w2(Silu(w1_o, w3_o))
out:
torch.Size([1, micro_bsz * seq_len, hidden_size])
In the MSP/FSP parallel modes, the dimensions of the output results are as follows:
out = self.w2(Silu(w1_o, w3_o))
out:
torch.Size([1, (micro_bsz * seq_len) // tp_size, hidden_size])
norm Calculation Procedure
The result dimensions remain unchanged after the computation through the norm layer.
In the MTP parallel mode, the dimensions of the output results are as follows:
hidden_states:
torch.Size([1, micro_bsz * seq_len, hidden_size])
In the MSP/FSP parallel modes, the dimensions of the output results are as follows:
hidden_states:
torch.Size([1, (micro_bsz * seq_len) // tp_size, hidden_size])
output Calculation Procedure
Finally, the output layer projects the hidden states of the last layer to the vocabulary dimension for the final task, with the following result:
hidden_states:
torch.Size([1, micro_bsz * seq_len, vocab_size])
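The shapes listed in this section can be collected in one place. The helper below is a hypothetical summary that simply encodes the documented shapes for each mode (for ISP, tp_size plays the role of sp_size, and the MSP/FSP output shape follows the single output shape listed above); it does not query a real model.

```python
from typing import Dict, Tuple


def forward_shapes(
    mode: str,
    micro_bsz: int,
    seq_len: int,
    tp_size: int,
    hidden_size: int,
    num_attention_heads: int,
    vocab_size: int,
) -> Dict[str, Tuple[int, ...]]:
    """Documented per-rank tensor shapes at each stage (hypothetical summary)."""
    if mode not in {"isp", "mtp", "msp", "fsp"}:
        raise ValueError(f"unknown parallel mode: {mode}")
    tokens = micro_bsz * seq_len
    head_dim = hidden_size // num_attention_heads
    # Between layers, only MTP keeps the full sequence on every rank.
    local = tokens if mode == "mtp" else tokens // tp_size
    # Per the shapes above, only ISP keeps the output layer sequence-split.
    out_tokens = local if mode == "isp" else tokens
    hidden = (1, local, hidden_size)
    return {
        "embedding_out": hidden,
        # Inside attention, every mode sees the full sequence and 1/tp of the heads.
        "attention_context": (1, tokens, num_attention_heads // tp_size, head_dim),
        "wo_out": hidden,
        "w2_out": hidden,
        "norm_out": hidden,
        "output": (1, out_tokens, vocab_size),
    }


for mode in ("isp", "mtp", "msp"):
    print(mode, forward_shapes(mode, micro_bsz=2, seq_len=8, tp_size=2,
                               hidden_size=64, num_attention_heads=4, vocab_size=100))
```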