Skip to content

Commit eea8df6

Browse files
committed
spirv-opt: Fix crashes in ConvertToHalfPass due to ID overflow
This pass was crashing when the modules ID bound was reached, causing type creation to fail. This would result in null pointers being passed to the `Vector` and `Matrix` constructors, leading to a segmentation fault. This commit fixes the issue by: - Adding null checks in `FloatVectorType` and `FloatMatrixType` to handle cases where type creation fails. - Using the `status_` member variable to propagate the failure up the call stack, ensuring that the pass fails gracefully.
1 parent 7d5a3d7 commit eea8df6

File tree

4 files changed

+131
-9
lines changed

4 files changed

+131
-9
lines changed

source/opt/amd_ext_to_khr.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,21 +239,25 @@ bool ReplaceSwizzleInvocations(IRContext* ctx, Instruction* inst,
239239
// This gives the offset in the group of 4 of this invocation.
240240
Instruction* quad_idx = ir_builder.AddBinaryOp(
241241
uint_type_id, spv::Op::OpBitwiseAnd, id->result_id(), quad_mask);
242+
if (quad_idx == nullptr) return false;
242243

243244
// Get the invocation id of the first invocation in the group of 4.
244245
Instruction* quad_ldr =
245246
ir_builder.AddBinaryOp(uint_type_id, spv::Op::OpBitwiseXor,
246247
id->result_id(), quad_idx->result_id());
248+
if (quad_ldr == nullptr) return false;
247249

248250
// Get the offset of the target invocation from the offset vector.
249251
Instruction* my_offset =
250252
ir_builder.AddBinaryOp(uint_type_id, spv::Op::OpVectorExtractDynamic,
251253
offset_id, quad_idx->result_id());
254+
if (my_offset == nullptr) return false;
252255

253256
// Determine the index of the invocation to read from.
254257
Instruction* target_inv =
255258
ir_builder.AddBinaryOp(uint_type_id, spv::Op::OpIAdd,
256259
quad_ldr->result_id(), my_offset->result_id());
260+
if (target_inv == nullptr) return false;
257261

258262
// Do the group operations
259263
uint32_t uint_max_id = ir_builder.GetUintConstantId(0xFFFFFFFF);
@@ -364,13 +368,17 @@ bool ReplaceSwizzleInvocationsMasked(
364368
uint32_t mask_extended = ir_builder.GetUintConstantId(0xFFFFFFE0);
365369
Instruction* and_mask = ir_builder.AddBinaryOp(
366370
uint_type_id, spv::Op::OpBitwiseOr, uint_x, mask_extended);
371+
if (and_mask == nullptr) return false;
367372
Instruction* and_result =
368373
ir_builder.AddBinaryOp(uint_type_id, spv::Op::OpBitwiseAnd,
369374
id->result_id(), and_mask->result_id());
375+
if (and_result == nullptr) return false;
370376
Instruction* or_result = ir_builder.AddBinaryOp(
371377
uint_type_id, spv::Op::OpBitwiseOr, and_result->result_id(), uint_y);
378+
if (or_result == nullptr) return false;
372379
Instruction* target_inv = ir_builder.AddBinaryOp(
373380
uint_type_id, spv::Op::OpBitwiseXor, or_result->result_id(), uint_z);
381+
if (target_inv == nullptr) return false;
374382

375383
// Do the group operations
376384
uint32_t uint_max_id = ir_builder.GetUintConstantId(0xFFFFFFFF);
@@ -442,6 +450,7 @@ bool ReplaceWriteInvocation(IRContext* ctx, Instruction* inst,
442450
Instruction* cmp =
443451
ir_builder.AddBinaryOp(bool_type_id, spv::Op::OpIEqual, t->result_id(),
444452
inst->GetSingleWordInOperand(4));
453+
if (cmp == nullptr) return false;
445454

446455
// Build a select.
447456
inst->SetOpcode(spv::Op::OpSelect);
@@ -517,6 +526,7 @@ bool ReplaceMbcnt(IRContext* context, Instruction* inst,
517526
Instruction* t =
518527
ir_builder.AddBinaryOp(mask_inst->type_id(), spv::Op::OpBitwiseAnd,
519528
bitcast->result_id(), mask_id);
529+
if (t == nullptr) return false;
520530

521531
inst->SetOpcode(spv::Op::OpBitCount);
522532
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {t->result_id()}}});
@@ -621,10 +631,13 @@ bool ReplaceCubeFaceCoord(IRContext* ctx, Instruction* inst,
621631
// Find which values are negative. Used in later computations.
622632
Instruction* is_z_neg = ir_builder.AddBinaryOp(
623633
bool_id, spv::Op::OpFOrdLessThan, z->result_id(), f0_const_id);
634+
if (is_z_neg == nullptr) return false;
624635
Instruction* is_y_neg = ir_builder.AddBinaryOp(
625636
bool_id, spv::Op::OpFOrdLessThan, y->result_id(), f0_const_id);
637+
if (is_y_neg == nullptr) return false;
626638
Instruction* is_x_neg = ir_builder.AddBinaryOp(
627639
bool_id, spv::Op::OpFOrdLessThan, x->result_id(), f0_const_id);
640+
if (is_x_neg == nullptr) return false;
628641

629642
// Compute cubema
630643
Instruction* amax_x_y = ir_builder.AddNaryExtendedInstruction(
@@ -635,19 +648,23 @@ bool ReplaceCubeFaceCoord(IRContext* ctx, Instruction* inst,
635648
{az->result_id(), amax_x_y->result_id()});
636649
Instruction* cubema = ir_builder.AddBinaryOp(float_type_id, spv::Op::OpFMul,
637650
f2_const_id, amax->result_id());
651+
if (cubema == nullptr) return false;
638652

639653
// Do the comparisons needed for computing cubesc and cubetc.
640654
Instruction* is_z_max =
641655
ir_builder.AddBinaryOp(bool_id, spv::Op::OpFOrdGreaterThanEqual,
642656
az->result_id(), amax_x_y->result_id());
657+
if (is_z_max == nullptr) return false;
643658
Instruction* not_is_z_max = ir_builder.AddUnaryOp(
644659
bool_id, spv::Op::OpLogicalNot, is_z_max->result_id());
645660
Instruction* y_gr_x =
646661
ir_builder.AddBinaryOp(bool_id, spv::Op::OpFOrdGreaterThanEqual,
647662
ay->result_id(), ax->result_id());
663+
if (y_gr_x == nullptr) return false;
648664
Instruction* is_y_max =
649665
ir_builder.AddBinaryOp(bool_id, spv::Op::OpLogicalAnd,
650666
not_is_z_max->result_id(), y_gr_x->result_id());
667+
if (is_y_max == nullptr) return false;
651668

652669
// Select the correct value for cubesc.
653670
Instruction* cubesc_case_1 = ir_builder.AddSelect(
@@ -675,6 +692,7 @@ bool ReplaceCubeFaceCoord(IRContext* ctx, Instruction* inst,
675692
v2_float_type_id, {cubema->result_id(), cubema->result_id()});
676693
Instruction* div = ir_builder.AddBinaryOp(
677694
v2_float_type_id, spv::Op::OpFDiv, cube->result_id(), denom->result_id());
695+
if (div == nullptr) return false;
678696

679697
// Get the final result by adding 0.5 to |div|.
680698
inst->SetOpcode(spv::Op::OpFAdd);
@@ -761,10 +779,13 @@ bool ReplaceCubeFaceIndex(IRContext* ctx, Instruction* inst,
761779
// Find which values are negative. Used in later computations.
762780
Instruction* is_z_neg = ir_builder.AddBinaryOp(
763781
bool_id, spv::Op::OpFOrdLessThan, z->result_id(), f0_const_id);
782+
if (is_z_neg == nullptr) return false;
764783
Instruction* is_y_neg = ir_builder.AddBinaryOp(
765784
bool_id, spv::Op::OpFOrdLessThan, y->result_id(), f0_const_id);
785+
if (is_y_neg == nullptr) return false;
766786
Instruction* is_x_neg = ir_builder.AddBinaryOp(
767787
bool_id, spv::Op::OpFOrdLessThan, x->result_id(), f0_const_id);
788+
if (is_x_neg == nullptr) return false;
768789

769790
// Find the max value.
770791
Instruction* amax_x_y = ir_builder.AddNaryExtendedInstruction(
@@ -773,9 +794,11 @@ bool ReplaceCubeFaceIndex(IRContext* ctx, Instruction* inst,
773794
Instruction* is_z_max =
774795
ir_builder.AddBinaryOp(bool_id, spv::Op::OpFOrdGreaterThanEqual,
775796
az->result_id(), amax_x_y->result_id());
797+
if (is_z_max == nullptr) return false;
776798
Instruction* y_gr_x =
777799
ir_builder.AddBinaryOp(bool_id, spv::Op::OpFOrdGreaterThanEqual,
778800
ay->result_id(), ax->result_id());
801+
if (y_gr_x == nullptr) return false;
779802

780803
// Get the value for each case.
781804
Instruction* case_z = ir_builder.AddSelect(

source/opt/convert_to_half_pass.cpp

Lines changed: 102 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ analysis::Type* ConvertToHalfPass::FloatScalarType(uint32_t width) {
7575
analysis::Type* ConvertToHalfPass::FloatVectorType(uint32_t v_len,
7676
uint32_t width) {
7777
analysis::Type* reg_float_ty = FloatScalarType(width);
78+
if (reg_float_ty == nullptr) {
79+
return nullptr;
80+
}
7881
analysis::Vector vec_ty(reg_float_ty, v_len);
7982
return context()->get_type_mgr()->GetRegisteredType(&vec_ty);
8083
}
@@ -85,6 +88,9 @@ analysis::Type* ConvertToHalfPass::FloatMatrixType(uint32_t v_cnt,
8588
Instruction* vty_inst = get_def_use_mgr()->GetDef(vty_id);
8689
uint32_t v_len = vty_inst->GetSingleWordInOperand(1);
8790
analysis::Type* reg_vec_ty = FloatVectorType(v_len, width);
91+
if (reg_vec_ty == nullptr) {
92+
return nullptr;
93+
}
8894
analysis::Matrix mat_ty(reg_vec_ty, v_cnt);
8995
return context()->get_type_mgr()->GetRegisteredType(&mat_ty);
9096
}
@@ -99,6 +105,9 @@ uint32_t ConvertToHalfPass::EquivFloatTypeId(uint32_t ty_id, uint32_t width) {
99105
reg_equiv_ty = FloatVectorType(ty_inst->GetSingleWordInOperand(1), width);
100106
else // spv::Op::OpTypeFloat
101107
reg_equiv_ty = FloatScalarType(width);
108+
if (reg_equiv_ty == nullptr) {
109+
return 0;
110+
}
102111
return context()->get_type_mgr()->GetTypeInstruction(reg_equiv_ty);
103112
}
104113

@@ -107,6 +116,10 @@ void ConvertToHalfPass::GenConvert(uint32_t* val_idp, uint32_t width,
107116
Instruction* val_inst = get_def_use_mgr()->GetDef(*val_idp);
108117
uint32_t ty_id = val_inst->type_id();
109118
uint32_t nty_id = EquivFloatTypeId(ty_id, width);
119+
if (nty_id == 0) {
120+
status_ = Status::Failure;
121+
return;
122+
}
110123
if (nty_id == ty_id) return;
111124
Instruction* cvt_inst;
112125
InstructionBuilder builder(
@@ -116,6 +129,10 @@ void ConvertToHalfPass::GenConvert(uint32_t* val_idp, uint32_t width,
116129
cvt_inst = builder.AddNullaryOp(nty_id, spv::Op::OpUndef);
117130
else
118131
cvt_inst = builder.AddUnaryOp(nty_id, spv::Op::OpFConvert, *val_idp);
132+
if (cvt_inst == nullptr) {
133+
status_ = Status::Failure;
134+
return;
135+
}
119136
*val_idp = cvt_inst->result_id();
120137
}
121138

@@ -137,22 +154,43 @@ bool ConvertToHalfPass::MatConvertCleanup(Instruction* inst) {
137154
uint32_t orig_width = (cty_inst->GetSingleWordInOperand(0) == 16) ? 32 : 16;
138155
uint32_t orig_mat_id = inst->GetSingleWordInOperand(0);
139156
uint32_t orig_vty_id = EquivFloatTypeId(vty_id, orig_width);
157+
if (orig_vty_id == 0) {
158+
status_ = Status::Failure;
159+
return false;
160+
}
140161
std::vector<Operand> opnds = {};
141162
for (uint32_t vidx = 0; vidx < v_cnt; ++vidx) {
142163
Instruction* ext_inst = builder.AddIdLiteralOp(
143164
orig_vty_id, spv::Op::OpCompositeExtract, orig_mat_id, vidx);
165+
if (ext_inst == nullptr) {
166+
status_ = Status::Failure;
167+
return false;
168+
}
144169
Instruction* cvt_inst =
145170
builder.AddUnaryOp(vty_id, spv::Op::OpFConvert, ext_inst->result_id());
171+
if (cvt_inst == nullptr) {
172+
status_ = Status::Failure;
173+
return false;
174+
}
146175
opnds.push_back({SPV_OPERAND_TYPE_ID, {cvt_inst->result_id()}});
147176
}
148177
uint32_t mat_id = TakeNextId();
178+
if (mat_id == 0) {
179+
status_ = Status::Failure;
180+
return false;
181+
}
149182
std::unique_ptr<Instruction> mat_inst(new Instruction(
150183
context(), spv::Op::OpCompositeConstruct, mty_id, mat_id, opnds));
151184
(void)builder.AddInstruction(std::move(mat_inst));
152185
context()->ReplaceAllUsesWith(inst->result_id(), mat_id);
153186
// Turn original instruction into copy so it is valid.
187+
uint32_t new_type_id = EquivFloatTypeId(mty_id, orig_width);
188+
if (new_type_id == 0) {
189+
status_ = Status::Failure;
190+
return false;
191+
}
154192
inst->SetOpcode(spv::Op::OpCopyObject);
155-
inst->SetResultType(EquivFloatTypeId(mty_id, orig_width));
193+
inst->SetResultType(new_type_id);
156194
get_def_use_mgr()->AnalyzeInstUse(inst);
157195
return true;
158196
}
@@ -187,13 +225,24 @@ bool ConvertToHalfPass::GenHalfArith(Instruction* inst) {
187225
// Convert all float32 based operands to float16 equivalent and change
188226
// instruction type to float16 equivalent.
189227
inst->ForEachInId([&inst, &modified, this](uint32_t* idp) {
228+
if (status_ == Status::Failure) {
229+
return;
230+
}
190231
Instruction* op_inst = get_def_use_mgr()->GetDef(*idp);
191232
if (!IsFloat(op_inst, 32)) return;
192233
GenConvert(idp, 16, inst);
193234
modified = true;
194235
});
236+
if (status_ == Status::Failure) {
237+
return false;
238+
}
195239
if (IsFloat(inst, 32)) {
196-
inst->SetResultType(EquivFloatTypeId(inst->type_id(), 16));
240+
uint32_t new_type_id = EquivFloatTypeId(inst->type_id(), 16);
241+
if (new_type_id == 0) {
242+
status_ = Status::Failure;
243+
return false;
244+
}
245+
inst->SetResultType(new_type_id);
197246
converted_ids_.insert(inst->result_id());
198247
modified = true;
199248
}
@@ -211,6 +260,9 @@ bool ConvertToHalfPass::ProcessPhi(Instruction* inst, uint32_t from_width,
211260
bool modified = false;
212261
inst->ForEachInId([&ocnt, &prev_idp, &from_width, &to_width, &modified,
213262
this](uint32_t* idp) {
263+
if (status_ == Status::Failure) {
264+
return;
265+
}
214266
if (ocnt % 2 == 0) {
215267
prev_idp = idp;
216268
} else {
@@ -230,8 +282,16 @@ bool ConvertToHalfPass::ProcessPhi(Instruction* inst, uint32_t from_width,
230282
}
231283
++ocnt;
232284
});
285+
if (status_ == Status::Failure) {
286+
return false;
287+
}
233288
if (to_width == 16u) {
234-
inst->SetResultType(EquivFloatTypeId(inst->type_id(), 16u));
289+
uint32_t new_type_id = EquivFloatTypeId(inst->type_id(), 16u);
290+
if (new_type_id == 0) {
291+
status_ = Status::Failure;
292+
return false;
293+
}
294+
inst->SetResultType(new_type_id);
235295
converted_ids_.insert(inst->result_id());
236296
modified = true;
237297
}
@@ -242,7 +302,12 @@ bool ConvertToHalfPass::ProcessPhi(Instruction* inst, uint32_t from_width,
242302
bool ConvertToHalfPass::ProcessConvert(Instruction* inst) {
243303
// If float32 and relaxed, change to float16 convert
244304
if (IsFloat(inst, 32) && IsRelaxed(inst->result_id())) {
245-
inst->SetResultType(EquivFloatTypeId(inst->type_id(), 16));
305+
uint32_t new_type_id = EquivFloatTypeId(inst->type_id(), 16);
306+
if (new_type_id == 0) {
307+
status_ = Status::Failure;
308+
return false;
309+
}
310+
inst->SetResultType(new_type_id);
246311
get_def_use_mgr()->AnalyzeInstUse(inst);
247312
converted_ids_.insert(inst->result_id());
248313
}
@@ -255,7 +320,7 @@ bool ConvertToHalfPass::ProcessConvert(Instruction* inst) {
255320
Instruction* val_inst = get_def_use_mgr()->GetDef(val_id);
256321
if (inst->type_id() == val_inst->type_id())
257322
inst->SetOpcode(spv::Op::OpCopyObject);
258-
return true; // modified
323+
return true;
259324
}
260325

261326
bool ConvertToHalfPass::ProcessImageRef(Instruction* inst) {
@@ -265,6 +330,9 @@ bool ConvertToHalfPass::ProcessImageRef(Instruction* inst) {
265330
uint32_t dref_id = inst->GetSingleWordInOperand(kImageSampleDrefIdInIdx);
266331
if (converted_ids_.count(dref_id) > 0) {
267332
GenConvert(&dref_id, 32, inst);
333+
if (status_ == Status::Failure) {
334+
return false;
335+
}
268336
inst->SetInOperand(kImageSampleDrefIdInIdx, {dref_id});
269337
get_def_use_mgr()->AnalyzeInstUse(inst);
270338
modified = true;
@@ -279,11 +347,17 @@ bool ConvertToHalfPass::ProcessDefault(Instruction* inst) {
279347
if (inst->opcode() == spv::Op::OpPhi) return ProcessPhi(inst, 16u, 32u);
280348
bool modified = false;
281349
inst->ForEachInId([&inst, &modified, this](uint32_t* idp) {
350+
if (status_ == Status::Failure) {
351+
return;
352+
}
282353
if (converted_ids_.count(*idp) == 0) return;
283354
uint32_t old_id = *idp;
284355
GenConvert(idp, 32, inst);
285356
if (*idp != old_id) modified = true;
286357
});
358+
if (status_ == Status::Failure) {
359+
return false;
360+
}
287361
if (modified) get_def_use_mgr()->AnalyzeInstUse(inst);
288362
return modified;
289363
}
@@ -370,19 +444,38 @@ bool ConvertToHalfPass::ProcessFunction(Function* func) {
370444
});
371445
// Replace invalid converts of matrix into equivalent vector extracts,
372446
// converts and finally a composite construct
447+
bool ok = true;
373448
cfg()->ForEachBlockInReversePostOrder(
374-
func->entry().get(), [&modified, this](BasicBlock* bb) {
375-
for (auto ii = bb->begin(); ii != bb->end(); ++ii)
376-
modified |= MatConvertCleanup(&*ii);
449+
func->entry().get(), [&modified, &ok, this](BasicBlock* bb) {
450+
if (!ok) {
451+
return;
452+
}
453+
for (auto ii = bb->begin(); ii != bb->end(); ++ii) {
454+
bool Mmodified = MatConvertCleanup(&*ii);
455+
if (status_ == Status::Failure) {
456+
ok = false;
457+
break;
458+
}
459+
modified |= Mmodified;
460+
}
377461
});
462+
463+
if (!ok) {
464+
return false;
465+
}
378466
return modified;
379467
}
380468

381469
Pass::Status ConvertToHalfPass::ProcessImpl() {
470+
status_ = Status::SuccessWithoutChange;
382471
Pass::ProcessFunction pfn = [this](Function* fp) {
383472
return ProcessFunction(fp);
384473
};
385474
bool modified = context()->ProcessReachableCallTree(pfn);
475+
if (status_ == Status::Failure) {
476+
return status_;
477+
}
478+
386479
// If modified, make sure module has Float16 capability
387480
if (modified) context()->AddCapability(spv::Capability::Float16);
388481
// Remove all RelaxedPrecision decorations from instructions and globals
@@ -514,4 +607,4 @@ void ConvertToHalfPass::Initialize() {
514607
}
515608

516609
} // namespace opt
517-
} // namespace spvtools
610+
} // namespace spvtools

source/opt/convert_to_half_pass.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,9 @@ class ConvertToHalfPass : public Pass {
130130
}
131131
};
132132

133+
// The status of the pass.
134+
Pass::Status status_;
135+
133136
// Set of core operations to be processed
134137
std::unordered_set<spv::Op, hasher> target_ops_core_;
135138

0 commit comments

Comments
 (0)