Skip to content

Commit f999624

Browse files
authored
Merge pull request KhronosGroup#6325 from s-perron/id_overflow12
spirv-opt: Fix crashes in ConvertToHalfPass due to ID overflow
2 parents 148a9fe + eea8df6 commit f999624

File tree

4 files changed

+131
-9
lines changed

4 files changed

+131
-9
lines changed

source/opt/amd_ext_to_khr.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,21 +240,25 @@ bool ReplaceSwizzleInvocations(IRContext* ctx, Instruction* inst,
240240
// This gives the offset in the group of 4 of this invocation.
241241
Instruction* quad_idx = ir_builder.AddBinaryOp(
242242
uint_type_id, spv::Op::OpBitwiseAnd, id->result_id(), quad_mask);
243+
if (quad_idx == nullptr) return false;
243244

244245
// Get the invocation id of the first invocation in the group of 4.
245246
Instruction* quad_ldr =
246247
ir_builder.AddBinaryOp(uint_type_id, spv::Op::OpBitwiseXor,
247248
id->result_id(), quad_idx->result_id());
249+
if (quad_ldr == nullptr) return false;
248250

249251
// Get the offset of the target invocation from the offset vector.
250252
Instruction* my_offset =
251253
ir_builder.AddBinaryOp(uint_type_id, spv::Op::OpVectorExtractDynamic,
252254
offset_id, quad_idx->result_id());
255+
if (my_offset == nullptr) return false;
253256

254257
// Determine the index of the invocation to read from.
255258
Instruction* target_inv =
256259
ir_builder.AddBinaryOp(uint_type_id, spv::Op::OpIAdd,
257260
quad_ldr->result_id(), my_offset->result_id());
261+
if (target_inv == nullptr) return false;
258262

259263
// Do the group operations
260264
uint32_t uint_max_id = ir_builder.GetUintConstantId(0xFFFFFFFF);
@@ -368,13 +372,17 @@ bool ReplaceSwizzleInvocationsMasked(
368372
uint32_t mask_extended = ir_builder.GetUintConstantId(0xFFFFFFE0);
369373
Instruction* and_mask = ir_builder.AddBinaryOp(
370374
uint_type_id, spv::Op::OpBitwiseOr, uint_x, mask_extended);
375+
if (and_mask == nullptr) return false;
371376
Instruction* and_result =
372377
ir_builder.AddBinaryOp(uint_type_id, spv::Op::OpBitwiseAnd,
373378
id->result_id(), and_mask->result_id());
379+
if (and_result == nullptr) return false;
374380
Instruction* or_result = ir_builder.AddBinaryOp(
375381
uint_type_id, spv::Op::OpBitwiseOr, and_result->result_id(), uint_y);
382+
if (or_result == nullptr) return false;
376383
Instruction* target_inv = ir_builder.AddBinaryOp(
377384
uint_type_id, spv::Op::OpBitwiseXor, or_result->result_id(), uint_z);
385+
if (target_inv == nullptr) return false;
378386

379387
// Do the group operations
380388
uint32_t uint_max_id = ir_builder.GetUintConstantId(0xFFFFFFFF);
@@ -449,6 +457,7 @@ bool ReplaceWriteInvocation(IRContext* ctx, Instruction* inst,
449457
Instruction* cmp =
450458
ir_builder.AddBinaryOp(bool_type_id, spv::Op::OpIEqual, t->result_id(),
451459
inst->GetSingleWordInOperand(4));
460+
if (cmp == nullptr) return false;
452461

453462
// Build a select.
454463
inst->SetOpcode(spv::Op::OpSelect);
@@ -524,6 +533,7 @@ bool ReplaceMbcnt(IRContext* context, Instruction* inst,
524533
Instruction* t =
525534
ir_builder.AddBinaryOp(mask_inst->type_id(), spv::Op::OpBitwiseAnd,
526535
bitcast->result_id(), mask_id);
536+
if (t == nullptr) return false;
527537

528538
inst->SetOpcode(spv::Op::OpBitCount);
529539
inst->SetInOperands({{SPV_OPERAND_TYPE_ID, {t->result_id()}}});
@@ -631,10 +641,13 @@ bool ReplaceCubeFaceCoord(IRContext* ctx, Instruction* inst,
631641
// Find which values are negative. Used in later computations.
632642
Instruction* is_z_neg = ir_builder.AddBinaryOp(
633643
bool_id, spv::Op::OpFOrdLessThan, z->result_id(), f0_const_id);
644+
if (is_z_neg == nullptr) return false;
634645
Instruction* is_y_neg = ir_builder.AddBinaryOp(
635646
bool_id, spv::Op::OpFOrdLessThan, y->result_id(), f0_const_id);
647+
if (is_y_neg == nullptr) return false;
636648
Instruction* is_x_neg = ir_builder.AddBinaryOp(
637649
bool_id, spv::Op::OpFOrdLessThan, x->result_id(), f0_const_id);
650+
if (is_x_neg == nullptr) return false;
638651

639652
// Compute cubema
640653
Instruction* amax_x_y = ir_builder.AddNaryExtendedInstruction(
@@ -645,19 +658,23 @@ bool ReplaceCubeFaceCoord(IRContext* ctx, Instruction* inst,
645658
{az->result_id(), amax_x_y->result_id()});
646659
Instruction* cubema = ir_builder.AddBinaryOp(float_type_id, spv::Op::OpFMul,
647660
f2_const_id, amax->result_id());
661+
if (cubema == nullptr) return false;
648662

649663
// Do the comparisons needed for computing cubesc and cubetc.
650664
Instruction* is_z_max =
651665
ir_builder.AddBinaryOp(bool_id, spv::Op::OpFOrdGreaterThanEqual,
652666
az->result_id(), amax_x_y->result_id());
667+
if (is_z_max == nullptr) return false;
653668
Instruction* not_is_z_max = ir_builder.AddUnaryOp(
654669
bool_id, spv::Op::OpLogicalNot, is_z_max->result_id());
655670
Instruction* y_gr_x =
656671
ir_builder.AddBinaryOp(bool_id, spv::Op::OpFOrdGreaterThanEqual,
657672
ay->result_id(), ax->result_id());
673+
if (y_gr_x == nullptr) return false;
658674
Instruction* is_y_max =
659675
ir_builder.AddBinaryOp(bool_id, spv::Op::OpLogicalAnd,
660676
not_is_z_max->result_id(), y_gr_x->result_id());
677+
if (is_y_max == nullptr) return false;
661678

662679
// Select the correct value for cubesc.
663680
// TODO(1841): Handle id overflow.
@@ -691,6 +708,7 @@ bool ReplaceCubeFaceCoord(IRContext* ctx, Instruction* inst,
691708
v2_float_type_id, {cubema->result_id(), cubema->result_id()});
692709
Instruction* div = ir_builder.AddBinaryOp(
693710
v2_float_type_id, spv::Op::OpFDiv, cube->result_id(), denom->result_id());
711+
if (div == nullptr) return false;
694712

695713
// Get the final result by adding 0.5 to |div|.
696714
inst->SetOpcode(spv::Op::OpFAdd);
@@ -780,10 +798,13 @@ bool ReplaceCubeFaceIndex(IRContext* ctx, Instruction* inst,
780798
// Find which values are negative. Used in later computations.
781799
Instruction* is_z_neg = ir_builder.AddBinaryOp(
782800
bool_id, spv::Op::OpFOrdLessThan, z->result_id(), f0_const_id);
801+
if (is_z_neg == nullptr) return false;
783802
Instruction* is_y_neg = ir_builder.AddBinaryOp(
784803
bool_id, spv::Op::OpFOrdLessThan, y->result_id(), f0_const_id);
804+
if (is_y_neg == nullptr) return false;
785805
Instruction* is_x_neg = ir_builder.AddBinaryOp(
786806
bool_id, spv::Op::OpFOrdLessThan, x->result_id(), f0_const_id);
807+
if (is_x_neg == nullptr) return false;
787808

788809
// Find the max value.
789810
Instruction* amax_x_y = ir_builder.AddNaryExtendedInstruction(
@@ -792,9 +813,11 @@ bool ReplaceCubeFaceIndex(IRContext* ctx, Instruction* inst,
792813
Instruction* is_z_max =
793814
ir_builder.AddBinaryOp(bool_id, spv::Op::OpFOrdGreaterThanEqual,
794815
az->result_id(), amax_x_y->result_id());
816+
if (is_z_max == nullptr) return false;
795817
Instruction* y_gr_x =
796818
ir_builder.AddBinaryOp(bool_id, spv::Op::OpFOrdGreaterThanEqual,
797819
ay->result_id(), ax->result_id());
820+
if (y_gr_x == nullptr) return false;
798821

799822
// Get the value for each case.
800823
// TODO(1841): Handle id overflow.

source/opt/convert_to_half_pass.cpp

Lines changed: 102 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ analysis::Type* ConvertToHalfPass::FloatScalarType(uint32_t width) {
7575
analysis::Type* ConvertToHalfPass::FloatVectorType(uint32_t v_len,
7676
uint32_t width) {
7777
analysis::Type* reg_float_ty = FloatScalarType(width);
78+
if (reg_float_ty == nullptr) {
79+
return nullptr;
80+
}
7881
analysis::Vector vec_ty(reg_float_ty, v_len);
7982
return context()->get_type_mgr()->GetRegisteredType(&vec_ty);
8083
}
@@ -85,6 +88,9 @@ analysis::Type* ConvertToHalfPass::FloatMatrixType(uint32_t v_cnt,
8588
Instruction* vty_inst = get_def_use_mgr()->GetDef(vty_id);
8689
uint32_t v_len = vty_inst->GetSingleWordInOperand(1);
8790
analysis::Type* reg_vec_ty = FloatVectorType(v_len, width);
91+
if (reg_vec_ty == nullptr) {
92+
return nullptr;
93+
}
8894
analysis::Matrix mat_ty(reg_vec_ty, v_cnt);
8995
return context()->get_type_mgr()->GetRegisteredType(&mat_ty);
9096
}
@@ -99,6 +105,9 @@ uint32_t ConvertToHalfPass::EquivFloatTypeId(uint32_t ty_id, uint32_t width) {
99105
reg_equiv_ty = FloatVectorType(ty_inst->GetSingleWordInOperand(1), width);
100106
else // spv::Op::OpTypeFloat
101107
reg_equiv_ty = FloatScalarType(width);
108+
if (reg_equiv_ty == nullptr) {
109+
return 0;
110+
}
102111
return context()->get_type_mgr()->GetTypeInstruction(reg_equiv_ty);
103112
}
104113

@@ -107,6 +116,10 @@ void ConvertToHalfPass::GenConvert(uint32_t* val_idp, uint32_t width,
107116
Instruction* val_inst = get_def_use_mgr()->GetDef(*val_idp);
108117
uint32_t ty_id = val_inst->type_id();
109118
uint32_t nty_id = EquivFloatTypeId(ty_id, width);
119+
if (nty_id == 0) {
120+
status_ = Status::Failure;
121+
return;
122+
}
110123
if (nty_id == ty_id) return;
111124
Instruction* cvt_inst;
112125
InstructionBuilder builder(
@@ -116,6 +129,10 @@ void ConvertToHalfPass::GenConvert(uint32_t* val_idp, uint32_t width,
116129
cvt_inst = builder.AddNullaryOp(nty_id, spv::Op::OpUndef);
117130
else
118131
cvt_inst = builder.AddUnaryOp(nty_id, spv::Op::OpFConvert, *val_idp);
132+
if (cvt_inst == nullptr) {
133+
status_ = Status::Failure;
134+
return;
135+
}
119136
*val_idp = cvt_inst->result_id();
120137
}
121138

@@ -137,22 +154,43 @@ bool ConvertToHalfPass::MatConvertCleanup(Instruction* inst) {
137154
uint32_t orig_width = (cty_inst->GetSingleWordInOperand(0) == 16) ? 32 : 16;
138155
uint32_t orig_mat_id = inst->GetSingleWordInOperand(0);
139156
uint32_t orig_vty_id = EquivFloatTypeId(vty_id, orig_width);
157+
if (orig_vty_id == 0) {
158+
status_ = Status::Failure;
159+
return false;
160+
}
140161
std::vector<Operand> opnds = {};
141162
for (uint32_t vidx = 0; vidx < v_cnt; ++vidx) {
142163
Instruction* ext_inst = builder.AddIdLiteralOp(
143164
orig_vty_id, spv::Op::OpCompositeExtract, orig_mat_id, vidx);
165+
if (ext_inst == nullptr) {
166+
status_ = Status::Failure;
167+
return false;
168+
}
144169
Instruction* cvt_inst =
145170
builder.AddUnaryOp(vty_id, spv::Op::OpFConvert, ext_inst->result_id());
171+
if (cvt_inst == nullptr) {
172+
status_ = Status::Failure;
173+
return false;
174+
}
146175
opnds.push_back({SPV_OPERAND_TYPE_ID, {cvt_inst->result_id()}});
147176
}
148177
uint32_t mat_id = TakeNextId();
178+
if (mat_id == 0) {
179+
status_ = Status::Failure;
180+
return false;
181+
}
149182
std::unique_ptr<Instruction> mat_inst(new Instruction(
150183
context(), spv::Op::OpCompositeConstruct, mty_id, mat_id, opnds));
151184
(void)builder.AddInstruction(std::move(mat_inst));
152185
context()->ReplaceAllUsesWith(inst->result_id(), mat_id);
153186
// Turn original instruction into copy so it is valid.
187+
uint32_t new_type_id = EquivFloatTypeId(mty_id, orig_width);
188+
if (new_type_id == 0) {
189+
status_ = Status::Failure;
190+
return false;
191+
}
154192
inst->SetOpcode(spv::Op::OpCopyObject);
155-
inst->SetResultType(EquivFloatTypeId(mty_id, orig_width));
193+
inst->SetResultType(new_type_id);
156194
get_def_use_mgr()->AnalyzeInstUse(inst);
157195
return true;
158196
}
@@ -187,13 +225,24 @@ bool ConvertToHalfPass::GenHalfArith(Instruction* inst) {
187225
// Convert all float32 based operands to float16 equivalent and change
188226
// instruction type to float16 equivalent.
189227
inst->ForEachInId([&inst, &modified, this](uint32_t* idp) {
228+
if (status_ == Status::Failure) {
229+
return;
230+
}
190231
Instruction* op_inst = get_def_use_mgr()->GetDef(*idp);
191232
if (!IsFloat(op_inst, 32)) return;
192233
GenConvert(idp, 16, inst);
193234
modified = true;
194235
});
236+
if (status_ == Status::Failure) {
237+
return false;
238+
}
195239
if (IsFloat(inst, 32)) {
196-
inst->SetResultType(EquivFloatTypeId(inst->type_id(), 16));
240+
uint32_t new_type_id = EquivFloatTypeId(inst->type_id(), 16);
241+
if (new_type_id == 0) {
242+
status_ = Status::Failure;
243+
return false;
244+
}
245+
inst->SetResultType(new_type_id);
197246
converted_ids_.insert(inst->result_id());
198247
modified = true;
199248
}
@@ -211,6 +260,9 @@ bool ConvertToHalfPass::ProcessPhi(Instruction* inst, uint32_t from_width,
211260
bool modified = false;
212261
inst->ForEachInId([&ocnt, &prev_idp, &from_width, &to_width, &modified,
213262
this](uint32_t* idp) {
263+
if (status_ == Status::Failure) {
264+
return;
265+
}
214266
if (ocnt % 2 == 0) {
215267
prev_idp = idp;
216268
} else {
@@ -230,8 +282,16 @@ bool ConvertToHalfPass::ProcessPhi(Instruction* inst, uint32_t from_width,
230282
}
231283
++ocnt;
232284
});
285+
if (status_ == Status::Failure) {
286+
return false;
287+
}
233288
if (to_width == 16u) {
234-
inst->SetResultType(EquivFloatTypeId(inst->type_id(), 16u));
289+
uint32_t new_type_id = EquivFloatTypeId(inst->type_id(), 16u);
290+
if (new_type_id == 0) {
291+
status_ = Status::Failure;
292+
return false;
293+
}
294+
inst->SetResultType(new_type_id);
235295
converted_ids_.insert(inst->result_id());
236296
modified = true;
237297
}
@@ -242,7 +302,12 @@ bool ConvertToHalfPass::ProcessPhi(Instruction* inst, uint32_t from_width,
242302
bool ConvertToHalfPass::ProcessConvert(Instruction* inst) {
243303
// If float32 and relaxed, change to float16 convert
244304
if (IsFloat(inst, 32) && IsRelaxed(inst->result_id())) {
245-
inst->SetResultType(EquivFloatTypeId(inst->type_id(), 16));
305+
uint32_t new_type_id = EquivFloatTypeId(inst->type_id(), 16);
306+
if (new_type_id == 0) {
307+
status_ = Status::Failure;
308+
return false;
309+
}
310+
inst->SetResultType(new_type_id);
246311
get_def_use_mgr()->AnalyzeInstUse(inst);
247312
converted_ids_.insert(inst->result_id());
248313
}
@@ -255,7 +320,7 @@ bool ConvertToHalfPass::ProcessConvert(Instruction* inst) {
255320
Instruction* val_inst = get_def_use_mgr()->GetDef(val_id);
256321
if (inst->type_id() == val_inst->type_id())
257322
inst->SetOpcode(spv::Op::OpCopyObject);
258-
return true; // modified
323+
return true;
259324
}
260325

261326
bool ConvertToHalfPass::ProcessImageRef(Instruction* inst) {
@@ -265,6 +330,9 @@ bool ConvertToHalfPass::ProcessImageRef(Instruction* inst) {
265330
uint32_t dref_id = inst->GetSingleWordInOperand(kImageSampleDrefIdInIdx);
266331
if (converted_ids_.count(dref_id) > 0) {
267332
GenConvert(&dref_id, 32, inst);
333+
if (status_ == Status::Failure) {
334+
return false;
335+
}
268336
inst->SetInOperand(kImageSampleDrefIdInIdx, {dref_id});
269337
get_def_use_mgr()->AnalyzeInstUse(inst);
270338
modified = true;
@@ -279,11 +347,17 @@ bool ConvertToHalfPass::ProcessDefault(Instruction* inst) {
279347
if (inst->opcode() == spv::Op::OpPhi) return ProcessPhi(inst, 16u, 32u);
280348
bool modified = false;
281349
inst->ForEachInId([&inst, &modified, this](uint32_t* idp) {
350+
if (status_ == Status::Failure) {
351+
return;
352+
}
282353
if (converted_ids_.count(*idp) == 0) return;
283354
uint32_t old_id = *idp;
284355
GenConvert(idp, 32, inst);
285356
if (*idp != old_id) modified = true;
286357
});
358+
if (status_ == Status::Failure) {
359+
return false;
360+
}
287361
if (modified) get_def_use_mgr()->AnalyzeInstUse(inst);
288362
return modified;
289363
}
@@ -370,19 +444,38 @@ bool ConvertToHalfPass::ProcessFunction(Function* func) {
370444
});
371445
// Replace invalid converts of matrix into equivalent vector extracts,
372446
// converts and finally a composite construct
447+
bool ok = true;
373448
cfg()->ForEachBlockInReversePostOrder(
374-
func->entry().get(), [&modified, this](BasicBlock* bb) {
375-
for (auto ii = bb->begin(); ii != bb->end(); ++ii)
376-
modified |= MatConvertCleanup(&*ii);
449+
func->entry().get(), [&modified, &ok, this](BasicBlock* bb) {
450+
if (!ok) {
451+
return;
452+
}
453+
for (auto ii = bb->begin(); ii != bb->end(); ++ii) {
454+
bool Mmodified = MatConvertCleanup(&*ii);
455+
if (status_ == Status::Failure) {
456+
ok = false;
457+
break;
458+
}
459+
modified |= Mmodified;
460+
}
377461
});
462+
463+
if (!ok) {
464+
return false;
465+
}
378466
return modified;
379467
}
380468

381469
Pass::Status ConvertToHalfPass::ProcessImpl() {
470+
status_ = Status::SuccessWithoutChange;
382471
Pass::ProcessFunction pfn = [this](Function* fp) {
383472
return ProcessFunction(fp);
384473
};
385474
bool modified = context()->ProcessReachableCallTree(pfn);
475+
if (status_ == Status::Failure) {
476+
return status_;
477+
}
478+
386479
// If modified, make sure module has Float16 capability
387480
if (modified) context()->AddCapability(spv::Capability::Float16);
388481
// Remove all RelaxedPrecision decorations from instructions and globals
@@ -514,4 +607,4 @@ void ConvertToHalfPass::Initialize() {
514607
}
515608

516609
} // namespace opt
517-
} // namespace spvtools
610+
} // namespace spvtools

source/opt/convert_to_half_pass.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,9 @@ class ConvertToHalfPass : public Pass {
130130
}
131131
};
132132

133+
// The status of the pass.
134+
Pass::Status status_;
135+
133136
// Set of core operations to be processed
134137
std::unordered_set<spv::Op, hasher> target_ops_core_;
135138

0 commit comments

Comments
 (0)