@@ -75,6 +75,9 @@ analysis::Type* ConvertToHalfPass::FloatScalarType(uint32_t width) {
7575analysis::Type* ConvertToHalfPass::FloatVectorType (uint32_t v_len,
7676 uint32_t width) {
7777 analysis::Type* reg_float_ty = FloatScalarType (width);
78+ if (reg_float_ty == nullptr ) {
79+ return nullptr ;
80+ }
7881 analysis::Vector vec_ty (reg_float_ty, v_len);
7982 return context ()->get_type_mgr ()->GetRegisteredType (&vec_ty);
8083}
@@ -85,6 +88,9 @@ analysis::Type* ConvertToHalfPass::FloatMatrixType(uint32_t v_cnt,
8588 Instruction* vty_inst = get_def_use_mgr ()->GetDef (vty_id);
8689 uint32_t v_len = vty_inst->GetSingleWordInOperand (1 );
8790 analysis::Type* reg_vec_ty = FloatVectorType (v_len, width);
91+ if (reg_vec_ty == nullptr ) {
92+ return nullptr ;
93+ }
8894 analysis::Matrix mat_ty (reg_vec_ty, v_cnt);
8995 return context ()->get_type_mgr ()->GetRegisteredType (&mat_ty);
9096}
@@ -99,6 +105,9 @@ uint32_t ConvertToHalfPass::EquivFloatTypeId(uint32_t ty_id, uint32_t width) {
99105 reg_equiv_ty = FloatVectorType (ty_inst->GetSingleWordInOperand (1 ), width);
100106 else // spv::Op::OpTypeFloat
101107 reg_equiv_ty = FloatScalarType (width);
108+ if (reg_equiv_ty == nullptr ) {
109+ return 0 ;
110+ }
102111 return context ()->get_type_mgr ()->GetTypeInstruction (reg_equiv_ty);
103112}
104113
@@ -107,6 +116,10 @@ void ConvertToHalfPass::GenConvert(uint32_t* val_idp, uint32_t width,
107116 Instruction* val_inst = get_def_use_mgr ()->GetDef (*val_idp);
108117 uint32_t ty_id = val_inst->type_id ();
109118 uint32_t nty_id = EquivFloatTypeId (ty_id, width);
119+ if (nty_id == 0 ) {
120+ status_ = Status::Failure;
121+ return ;
122+ }
110123 if (nty_id == ty_id) return ;
111124 Instruction* cvt_inst;
112125 InstructionBuilder builder (
@@ -116,6 +129,10 @@ void ConvertToHalfPass::GenConvert(uint32_t* val_idp, uint32_t width,
116129 cvt_inst = builder.AddNullaryOp (nty_id, spv::Op::OpUndef);
117130 else
118131 cvt_inst = builder.AddUnaryOp (nty_id, spv::Op::OpFConvert, *val_idp);
132+ if (cvt_inst == nullptr ) {
133+ status_ = Status::Failure;
134+ return ;
135+ }
119136 *val_idp = cvt_inst->result_id ();
120137}
121138
@@ -137,22 +154,43 @@ bool ConvertToHalfPass::MatConvertCleanup(Instruction* inst) {
137154 uint32_t orig_width = (cty_inst->GetSingleWordInOperand (0 ) == 16 ) ? 32 : 16 ;
138155 uint32_t orig_mat_id = inst->GetSingleWordInOperand (0 );
139156 uint32_t orig_vty_id = EquivFloatTypeId (vty_id, orig_width);
157+ if (orig_vty_id == 0 ) {
158+ status_ = Status::Failure;
159+ return false ;
160+ }
140161 std::vector<Operand> opnds = {};
141162 for (uint32_t vidx = 0 ; vidx < v_cnt; ++vidx) {
142163 Instruction* ext_inst = builder.AddIdLiteralOp (
143164 orig_vty_id, spv::Op::OpCompositeExtract, orig_mat_id, vidx);
165+ if (ext_inst == nullptr ) {
166+ status_ = Status::Failure;
167+ return false ;
168+ }
144169 Instruction* cvt_inst =
145170 builder.AddUnaryOp (vty_id, spv::Op::OpFConvert, ext_inst->result_id ());
171+ if (cvt_inst == nullptr ) {
172+ status_ = Status::Failure;
173+ return false ;
174+ }
146175 opnds.push_back ({SPV_OPERAND_TYPE_ID, {cvt_inst->result_id ()}});
147176 }
148177 uint32_t mat_id = TakeNextId ();
178+ if (mat_id == 0 ) {
179+ status_ = Status::Failure;
180+ return false ;
181+ }
149182 std::unique_ptr<Instruction> mat_inst (new Instruction (
150183 context (), spv::Op::OpCompositeConstruct, mty_id, mat_id, opnds));
151184 (void )builder.AddInstruction (std::move (mat_inst));
152185 context ()->ReplaceAllUsesWith (inst->result_id (), mat_id);
153186 // Turn original instruction into copy so it is valid.
187+ uint32_t new_type_id = EquivFloatTypeId (mty_id, orig_width);
188+ if (new_type_id == 0 ) {
189+ status_ = Status::Failure;
190+ return false ;
191+ }
154192 inst->SetOpcode (spv::Op::OpCopyObject);
155- inst->SetResultType (EquivFloatTypeId (mty_id, orig_width) );
193+ inst->SetResultType (new_type_id );
156194 get_def_use_mgr ()->AnalyzeInstUse (inst);
157195 return true ;
158196}
@@ -187,13 +225,24 @@ bool ConvertToHalfPass::GenHalfArith(Instruction* inst) {
187225 // Convert all float32 based operands to float16 equivalent and change
188226 // instruction type to float16 equivalent.
189227 inst->ForEachInId ([&inst, &modified, this ](uint32_t * idp) {
228+ if (status_ == Status::Failure) {
229+ return ;
230+ }
190231 Instruction* op_inst = get_def_use_mgr ()->GetDef (*idp);
191232 if (!IsFloat (op_inst, 32 )) return ;
192233 GenConvert (idp, 16 , inst);
193234 modified = true ;
194235 });
236+ if (status_ == Status::Failure) {
237+ return false ;
238+ }
195239 if (IsFloat (inst, 32 )) {
196- inst->SetResultType (EquivFloatTypeId (inst->type_id (), 16 ));
240+ uint32_t new_type_id = EquivFloatTypeId (inst->type_id (), 16 );
241+ if (new_type_id == 0 ) {
242+ status_ = Status::Failure;
243+ return false ;
244+ }
245+ inst->SetResultType (new_type_id);
197246 converted_ids_.insert (inst->result_id ());
198247 modified = true ;
199248 }
@@ -211,6 +260,9 @@ bool ConvertToHalfPass::ProcessPhi(Instruction* inst, uint32_t from_width,
211260 bool modified = false ;
212261 inst->ForEachInId ([&ocnt, &prev_idp, &from_width, &to_width, &modified,
213262 this ](uint32_t * idp) {
263+ if (status_ == Status::Failure) {
264+ return ;
265+ }
214266 if (ocnt % 2 == 0 ) {
215267 prev_idp = idp;
216268 } else {
@@ -230,8 +282,16 @@ bool ConvertToHalfPass::ProcessPhi(Instruction* inst, uint32_t from_width,
230282 }
231283 ++ocnt;
232284 });
285+ if (status_ == Status::Failure) {
286+ return false ;
287+ }
233288 if (to_width == 16u ) {
234- inst->SetResultType (EquivFloatTypeId (inst->type_id (), 16u ));
289+ uint32_t new_type_id = EquivFloatTypeId (inst->type_id (), 16u );
290+ if (new_type_id == 0 ) {
291+ status_ = Status::Failure;
292+ return false ;
293+ }
294+ inst->SetResultType (new_type_id);
235295 converted_ids_.insert (inst->result_id ());
236296 modified = true ;
237297 }
@@ -242,7 +302,12 @@ bool ConvertToHalfPass::ProcessPhi(Instruction* inst, uint32_t from_width,
242302bool ConvertToHalfPass::ProcessConvert (Instruction* inst) {
243303 // If float32 and relaxed, change to float16 convert
244304 if (IsFloat (inst, 32 ) && IsRelaxed (inst->result_id ())) {
245- inst->SetResultType (EquivFloatTypeId (inst->type_id (), 16 ));
305+ uint32_t new_type_id = EquivFloatTypeId (inst->type_id (), 16 );
306+ if (new_type_id == 0 ) {
307+ status_ = Status::Failure;
308+ return false ;
309+ }
310+ inst->SetResultType (new_type_id);
246311 get_def_use_mgr ()->AnalyzeInstUse (inst);
247312 converted_ids_.insert (inst->result_id ());
248313 }
@@ -255,7 +320,7 @@ bool ConvertToHalfPass::ProcessConvert(Instruction* inst) {
255320 Instruction* val_inst = get_def_use_mgr ()->GetDef (val_id);
256321 if (inst->type_id () == val_inst->type_id ())
257322 inst->SetOpcode (spv::Op::OpCopyObject);
258- return true ; // modified
323+ return true ;
259324}
260325
261326bool ConvertToHalfPass::ProcessImageRef (Instruction* inst) {
@@ -265,6 +330,9 @@ bool ConvertToHalfPass::ProcessImageRef(Instruction* inst) {
265330 uint32_t dref_id = inst->GetSingleWordInOperand (kImageSampleDrefIdInIdx );
266331 if (converted_ids_.count (dref_id) > 0 ) {
267332 GenConvert (&dref_id, 32 , inst);
333+ if (status_ == Status::Failure) {
334+ return false ;
335+ }
268336 inst->SetInOperand (kImageSampleDrefIdInIdx , {dref_id});
269337 get_def_use_mgr ()->AnalyzeInstUse (inst);
270338 modified = true ;
@@ -279,11 +347,17 @@ bool ConvertToHalfPass::ProcessDefault(Instruction* inst) {
279347 if (inst->opcode () == spv::Op::OpPhi) return ProcessPhi (inst, 16u , 32u );
280348 bool modified = false ;
281349 inst->ForEachInId ([&inst, &modified, this ](uint32_t * idp) {
350+ if (status_ == Status::Failure) {
351+ return ;
352+ }
282353 if (converted_ids_.count (*idp) == 0 ) return ;
283354 uint32_t old_id = *idp;
284355 GenConvert (idp, 32 , inst);
285356 if (*idp != old_id) modified = true ;
286357 });
358+ if (status_ == Status::Failure) {
359+ return false ;
360+ }
287361 if (modified) get_def_use_mgr ()->AnalyzeInstUse (inst);
288362 return modified;
289363}
@@ -370,19 +444,38 @@ bool ConvertToHalfPass::ProcessFunction(Function* func) {
370444 });
371445 // Replace invalid converts of matrix into equivalent vector extracts,
372446 // converts and finally a composite construct
447+ bool ok = true ;
373448 cfg ()->ForEachBlockInReversePostOrder (
374- func->entry ().get (), [&modified, this ](BasicBlock* bb) {
375- for (auto ii = bb->begin (); ii != bb->end (); ++ii)
376- modified |= MatConvertCleanup (&*ii);
449+ func->entry ().get (), [&modified, &ok, this ](BasicBlock* bb) {
450+ if (!ok) {
451+ return ;
452+ }
453+ for (auto ii = bb->begin (); ii != bb->end (); ++ii) {
454+ bool Mmodified = MatConvertCleanup (&*ii);
455+ if (status_ == Status::Failure) {
456+ ok = false ;
457+ break ;
458+ }
459+ modified |= Mmodified;
460+ }
377461 });
462+
463+ if (!ok) {
464+ return false ;
465+ }
378466 return modified;
379467}
380468
381469Pass::Status ConvertToHalfPass::ProcessImpl () {
470+ status_ = Status::SuccessWithoutChange;
382471 Pass::ProcessFunction pfn = [this ](Function* fp) {
383472 return ProcessFunction (fp);
384473 };
385474 bool modified = context ()->ProcessReachableCallTree (pfn);
475+ if (status_ == Status::Failure) {
476+ return status_;
477+ }
478+
386479 // If modified, make sure module has Float16 capability
387480 if (modified) context ()->AddCapability (spv::Capability::Float16);
388481 // Remove all RelaxedPrecision decorations from instructions and globals
@@ -514,4 +607,4 @@ void ConvertToHalfPass::Initialize() {
514607}
515608
516609} // namespace opt
517- } // namespace spvtools
610+ } // namespace spvtools
0 commit comments