MLIR 22.0.0git
LLVMToLLVMIRTranslation.cpp
Go to the documentation of this file.
1//===- LLVMToLLVMIRTranslation.cpp - Translate LLVM dialect to LLVM IR ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a translation between the MLIR LLVM dialect and LLVM IR.
10//
11//===----------------------------------------------------------------------===//
12
15#include "mlir/IR/Operation.h"
17#include "mlir/Support/LLVM.h"
19
20#include "llvm/ADT/TypeSwitch.h"
21#include "llvm/IR/DIBuilder.h"
22#include "llvm/IR/IRBuilder.h"
23#include "llvm/IR/InlineAsm.h"
24#include "llvm/IR/Instructions.h"
25#include "llvm/IR/MDBuilder.h"
26#include "llvm/IR/MatrixBuilder.h"
27#include "llvm/IR/MemoryModelRelaxationAnnotations.h"
28#include "llvm/Support/LogicalResult.h"
29
30using namespace mlir;
31using namespace mlir::LLVM;
33
34#include "mlir/Dialect/LLVMIR/LLVMConversionEnumsToLLVM.inc"
35
36static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op) {
37 using llvmFMF = llvm::FastMathFlags;
38 using FuncT = void (llvmFMF::*)(bool);
39 const std::pair<FastmathFlags, FuncT> handlers[] = {
40 // clang-format off
41 {FastmathFlags::nnan, &llvmFMF::setNoNaNs},
42 {FastmathFlags::ninf, &llvmFMF::setNoInfs},
43 {FastmathFlags::nsz, &llvmFMF::setNoSignedZeros},
44 {FastmathFlags::arcp, &llvmFMF::setAllowReciprocal},
45 {FastmathFlags::contract, &llvmFMF::setAllowContract},
46 {FastmathFlags::afn, &llvmFMF::setApproxFunc},
47 {FastmathFlags::reassoc, &llvmFMF::setAllowReassoc},
48 // clang-format on
49 };
50 llvm::FastMathFlags ret;
51 ::mlir::LLVM::FastmathFlags fmfMlir = op.getFastmathAttr().getValue();
52 for (auto it : handlers)
53 if (bitEnumContainsAll(fmfMlir, it.first))
54 (ret.*(it.second))(true);
55 return ret;
56}
57
58/// Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
60 SmallVector<unsigned> position;
61 llvm::append_range(position, indices);
62 return position;
63}
64
65/// Convert an LLVM type to a string for printing in diagnostics.
66static std::string diagStr(const llvm::Type *type) {
67 std::string str;
68 llvm::raw_string_ostream os(str);
69 type->print(os);
70 return str;
71}
72
73/// Get the declaration of an overloaded llvm intrinsic. First we get the
74/// overloaded argument types and/or result type from the CallIntrinsicOp, and
75/// then use those to get the correct declaration of the overloaded intrinsic.
76static FailureOr<llvm::Function *>
77getOverloadedDeclaration(CallIntrinsicOp op, llvm::Intrinsic::ID id,
78 llvm::Module *module,
79 LLVM::ModuleTranslation &moduleTranslation) {
81 for (Type type : op->getOperandTypes())
82 allArgTys.push_back(moduleTranslation.convertType(type));
83
84 llvm::Type *resTy;
85 if (op.getNumResults() == 0)
86 resTy = llvm::Type::getVoidTy(module->getContext());
87 else
88 resTy = moduleTranslation.convertType(op.getResult(0).getType());
89
90 // ATM we do not support variadic intrinsics.
91 llvm::FunctionType *ft = llvm::FunctionType::get(resTy, allArgTys, false);
92
94 getIntrinsicInfoTableEntries(id, table);
96
97 SmallVector<llvm::Type *, 8> overloadedArgTys;
98 if (llvm::Intrinsic::matchIntrinsicSignature(ft, tableRef,
99 overloadedArgTys) !=
100 llvm::Intrinsic::MatchIntrinsicTypesResult::MatchIntrinsicTypes_Match) {
101 return mlir::emitError(op.getLoc(), "call intrinsic signature ")
102 << diagStr(ft) << " to overloaded intrinsic " << op.getIntrinAttr()
103 << " does not match any of the overloads";
104 }
105
106 ArrayRef<llvm::Type *> overloadedArgTysRef = overloadedArgTys;
107 return llvm::Intrinsic::getOrInsertDeclaration(module, id,
108 overloadedArgTysRef);
109}
110
111static llvm::OperandBundleDef
112convertOperandBundle(OperandRange bundleOperands, StringRef bundleTag,
113 LLVM::ModuleTranslation &moduleTranslation) {
114 std::vector<llvm::Value *> operands;
115 operands.reserve(bundleOperands.size());
116 for (Value bundleArg : bundleOperands)
117 operands.push_back(moduleTranslation.lookupValue(bundleArg));
118 return llvm::OperandBundleDef(bundleTag.str(), std::move(operands));
119}
120
123 LLVM::ModuleTranslation &moduleTranslation) {
125 bundles.reserve(bundleOperands.size());
126
127 for (auto [operands, tagAttr] : llvm::zip_equal(bundleOperands, bundleTags)) {
128 StringRef tag = cast<StringAttr>(tagAttr).getValue();
129 bundles.push_back(convertOperandBundle(operands, tag, moduleTranslation));
130 }
131 return bundles;
132}
133
136 std::optional<ArrayAttr> bundleTags,
137 LLVM::ModuleTranslation &moduleTranslation) {
138 if (!bundleTags)
139 return {};
140 return convertOperandBundles(bundleOperands, *bundleTags, moduleTranslation);
141}
142
143/// Builder for LLVM_CallIntrinsicOp
144static LogicalResult
145convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
146 LLVM::ModuleTranslation &moduleTranslation) {
147 llvm::Module *module = builder.GetInsertBlock()->getModule();
148 llvm::Intrinsic::ID id =
149 llvm::Intrinsic::lookupIntrinsicID(op.getIntrinAttr());
150 if (!id)
151 return mlir::emitError(op.getLoc(), "could not find LLVM intrinsic: ")
152 << op.getIntrinAttr();
153
154 llvm::Function *fn = nullptr;
155 if (llvm::Intrinsic::isOverloaded(id)) {
156 auto fnOrFailure =
157 getOverloadedDeclaration(op, id, module, moduleTranslation);
158 if (failed(fnOrFailure))
159 return failure();
160 fn = *fnOrFailure;
161 } else {
162 fn = llvm::Intrinsic::getOrInsertDeclaration(module, id, {});
163 }
164
165 // Check the result type of the call.
166 const llvm::Type *intrinType =
167 op.getNumResults() == 0
168 ? llvm::Type::getVoidTy(module->getContext())
169 : moduleTranslation.convertType(op.getResultTypes().front());
170 if (intrinType != fn->getReturnType()) {
171 return mlir::emitError(op.getLoc(), "intrinsic call returns ")
172 << diagStr(intrinType) << " but " << op.getIntrinAttr()
173 << " actually returns " << diagStr(fn->getReturnType());
174 }
175
176 // Check the argument types of the call. If the function is variadic, check
177 // the subrange of required arguments.
178 if (!fn->getFunctionType()->isVarArg() &&
179 op.getArgs().size() != fn->arg_size()) {
180 return mlir::emitError(op.getLoc(), "intrinsic call has ")
181 << op.getArgs().size() << " operands but " << op.getIntrinAttr()
182 << " expects " << fn->arg_size();
183 }
184 if (fn->getFunctionType()->isVarArg() &&
185 op.getArgs().size() < fn->arg_size()) {
186 return mlir::emitError(op.getLoc(), "intrinsic call has ")
187 << op.getArgs().size() << " operands but variadic "
188 << op.getIntrinAttr() << " expects at least " << fn->arg_size();
189 }
190 // Check the arguments up to the number the function requires.
191 for (unsigned i = 0, e = fn->arg_size(); i != e; ++i) {
192 const llvm::Type *expected = fn->getArg(i)->getType();
193 const llvm::Type *actual =
194 moduleTranslation.convertType(op.getOperandTypes()[i]);
195 if (actual != expected) {
196 return mlir::emitError(op.getLoc(), "intrinsic call operand #")
197 << i << " has type " << diagStr(actual) << " but "
198 << op.getIntrinAttr() << " expects " << diagStr(expected);
199 }
200 }
201
202 FastmathFlagsInterface itf = op;
203 builder.setFastMathFlags(getFastmathFlags(itf));
204
205 auto *inst = builder.CreateCall(
206 fn, moduleTranslation.lookupValues(op.getArgs()),
207 convertOperandBundles(op.getOpBundleOperands(), op.getOpBundleTags(),
208 moduleTranslation));
209
210 if (failed(moduleTranslation.convertArgAndResultAttrs(op, inst)))
211 return failure();
212
213 if (op.getNumResults() == 1)
214 moduleTranslation.mapValue(op->getResults().front()) = inst;
215 return success();
216}
217
219 llvm::IRBuilderBase &builder,
220 LLVM::ModuleTranslation &moduleTranslation) {
221 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
222 llvm::LLVMContext &context = llvmModule->getContext();
223 llvm::NamedMDNode *linkerMDNode =
224 llvmModule->getOrInsertNamedMetadata("llvm.linker.options");
226 mdNodes.reserve(options.size());
227 for (auto s : options.getAsRange<StringAttr>()) {
228 auto *mdNode = llvm::MDString::get(context, s.getValue());
229 mdNodes.push_back(mdNode);
230 }
231
232 auto *listMDNode = llvm::MDTuple::get(context, mdNodes);
233 linkerMDNode->addOperand(listMDNode);
234}
235
236static llvm::Metadata *
237convertModuleFlagValue(StringRef key, ArrayAttr arrayAttr,
238 llvm::IRBuilderBase &builder,
239 LLVM::ModuleTranslation &moduleTranslation) {
240 llvm::LLVMContext &context = builder.getContext();
241 llvm::MDBuilder mdb(context);
243
244 if (key == LLVMDialect::getModuleFlagKeyCGProfileName()) {
245 for (auto entry : arrayAttr.getAsRange<ModuleFlagCGProfileEntryAttr>()) {
246 auto getFuncMetadata = [&](FlatSymbolRefAttr sym) -> llvm::Metadata * {
247 if (!sym)
248 return nullptr;
249 if (llvm::Function *fn =
250 moduleTranslation.lookupFunction(sym.getValue()))
251 return llvm::ValueAsMetadata::get(fn);
252 return nullptr;
253 };
254 llvm::Metadata *fromMetadata = getFuncMetadata(entry.getFrom());
255 llvm::Metadata *toMetadata = getFuncMetadata(entry.getTo());
256
257 llvm::Metadata *vals[] = {
258 fromMetadata, toMetadata,
259 mdb.createConstant(llvm::ConstantInt::get(
260 llvm::Type::getInt64Ty(context), entry.getCount()))};
261 nodes.push_back(llvm::MDNode::get(context, vals));
262 }
263 return llvm::MDTuple::getDistinct(context, nodes);
264 }
265 return nullptr;
266}
267
269 StringRef key, ModuleFlagProfileSummaryAttr summaryAttr,
270 llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) {
271 llvm::LLVMContext &context = builder.getContext();
272 llvm::MDBuilder mdb(context);
273
274 auto getIntTuple = [&](StringRef key, uint64_t val) -> llvm::MDTuple * {
276 mdb.createString(key), mdb.createConstant(llvm::ConstantInt::get(
277 llvm::Type::getInt64Ty(context), val))};
278 return llvm::MDTuple::get(context, tupleNodes);
279 };
280
282 mdb.createString("ProfileFormat"),
283 mdb.createString(
284 stringifyProfileSummaryFormatKind(summaryAttr.getFormat()))};
285
287 llvm::MDTuple::get(context, fmtNode),
288 getIntTuple("TotalCount", summaryAttr.getTotalCount()),
289 getIntTuple("MaxCount", summaryAttr.getMaxCount()),
290 getIntTuple("MaxInternalCount", summaryAttr.getMaxInternalCount()),
291 getIntTuple("MaxFunctionCount", summaryAttr.getMaxFunctionCount()),
292 getIntTuple("NumCounts", summaryAttr.getNumCounts()),
293 getIntTuple("NumFunctions", summaryAttr.getNumFunctions()),
294 };
295
296 if (summaryAttr.getIsPartialProfile())
297 vals.push_back(
298 getIntTuple("IsPartialProfile", *summaryAttr.getIsPartialProfile()));
299
300 if (summaryAttr.getPartialProfileRatio()) {
302 mdb.createString("PartialProfileRatio"),
303 mdb.createConstant(llvm::ConstantFP::get(
304 llvm::Type::getDoubleTy(context),
305 summaryAttr.getPartialProfileRatio().getValue()))};
306 vals.push_back(llvm::MDTuple::get(context, tupleNodes));
307 }
308
309 SmallVector<llvm::Metadata *> detailedEntries;
310 llvm::Type *llvmInt64Type = llvm::Type::getInt64Ty(context);
311 for (ModuleFlagProfileSummaryDetailedAttr detailedEntry :
312 summaryAttr.getDetailedSummary()) {
314 mdb.createConstant(
315 llvm::ConstantInt::get(llvmInt64Type, detailedEntry.getCutOff())),
316 mdb.createConstant(
317 llvm::ConstantInt::get(llvmInt64Type, detailedEntry.getMinCount())),
318 mdb.createConstant(llvm::ConstantInt::get(
319 llvmInt64Type, detailedEntry.getNumCounts()))};
320 detailedEntries.push_back(llvm::MDTuple::get(context, tupleNodes));
321 }
322 SmallVector<llvm::Metadata *> detailedSummary{
323 mdb.createString("DetailedSummary"),
324 llvm::MDTuple::get(context, detailedEntries)};
325 vals.push_back(llvm::MDTuple::get(context, detailedSummary));
326
327 return llvm::MDNode::get(context, vals);
328}
329
330static void convertModuleFlagsOp(ArrayAttr flags, llvm::IRBuilderBase &builder,
331 LLVM::ModuleTranslation &moduleTranslation) {
332 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
333 for (auto flagAttr : flags.getAsRange<ModuleFlagAttr>()) {
334 llvm::Metadata *valueMetadata =
336 .Case<StringAttr>([&](auto strAttr) {
337 return llvm::MDString::get(builder.getContext(),
338 strAttr.getValue());
339 })
340 .Case<IntegerAttr>([&](auto intAttr) {
341 return llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(
342 llvm::Type::getInt32Ty(builder.getContext()),
343 intAttr.getInt()));
344 })
345 .Case<ArrayAttr>([&](auto arrayAttr) {
346 return convertModuleFlagValue(flagAttr.getKey().getValue(),
347 arrayAttr, builder,
348 moduleTranslation);
349 })
350 .Case([&](ModuleFlagProfileSummaryAttr summaryAttr) {
352 flagAttr.getKey().getValue(), summaryAttr, builder,
353 moduleTranslation);
354 })
355 .Default([](auto) { return nullptr; });
356
357 assert(valueMetadata && "expected valid metadata");
358 llvmModule->addModuleFlag(
359 convertModFlagBehaviorToLLVM(flagAttr.getBehavior()),
360 flagAttr.getKey().getValue(), valueMetadata);
361 }
362}
363
364static llvm::DILocalScope *
365getLocalScopeFromLoc(llvm::IRBuilderBase &builder, Location loc,
366 LLVM::ModuleTranslation &moduleTranslation) {
367 if (auto scopeLoc =
369 if (auto *localScope = llvm::dyn_cast<llvm::DILocalScope>(
370 moduleTranslation.translateDebugInfo(scopeLoc.getMetadata())))
371 return localScope;
372 return builder.GetInsertBlock()->getParent()->getSubprogram();
373}
374
375static LogicalResult
376convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
377 LLVM::ModuleTranslation &moduleTranslation) {
378
379 llvm::IRBuilder<>::FastMathFlagGuard fmfGuard(builder);
380 if (auto fmf = dyn_cast<FastmathFlagsInterface>(opInst))
381 builder.setFastMathFlags(getFastmathFlags(fmf));
382
383#include "mlir/Dialect/LLVMIR/LLVMConversions.inc"
384#include "mlir/Dialect/LLVMIR/LLVMIntrinsicConversions.inc"
385
386 // Emit function calls. If the "callee" attribute is present, this is a
387 // direct function call and we also need to look up the remapped function
388 // itself. Otherwise, this is an indirect call and the callee is the first
389 // operand, look it up as a normal value.
390 if (auto callOp = dyn_cast<LLVM::CallOp>(opInst)) {
391 auto operands = moduleTranslation.lookupValues(callOp.getCalleeOperands());
393 convertOperandBundles(callOp.getOpBundleOperands(),
394 callOp.getOpBundleTags(), moduleTranslation);
395 ArrayRef<llvm::Value *> operandsRef(operands);
396 llvm::CallInst *call;
397 if (auto attr = callOp.getCalleeAttr()) {
398 if (llvm::Function *function =
399 moduleTranslation.lookupFunction(attr.getValue())) {
400 call = builder.CreateCall(function, operandsRef, opBundles);
401 } else {
402 Operation *moduleOp = parentLLVMModule(&opInst);
403 Operation *ifuncOp =
404 moduleTranslation.symbolTable().lookupSymbolIn(moduleOp, attr);
405 llvm::GlobalValue *ifunc = moduleTranslation.lookupIFunc(ifuncOp);
406 llvm::FunctionType *calleeType = llvm::cast<llvm::FunctionType>(
407 moduleTranslation.convertType(callOp.getCalleeFunctionType()));
408 call = builder.CreateCall(calleeType, ifunc, operandsRef, opBundles);
409 }
410 } else {
411 llvm::FunctionType *calleeType = llvm::cast<llvm::FunctionType>(
412 moduleTranslation.convertType(callOp.getCalleeFunctionType()));
413 call = builder.CreateCall(calleeType, operandsRef.front(),
414 operandsRef.drop_front(), opBundles);
415 }
416 call->setCallingConv(convertCConvToLLVM(callOp.getCConv()));
417 call->setTailCallKind(convertTailCallKindToLLVM(callOp.getTailCallKind()));
418 if (callOp.getConvergentAttr())
419 call->addFnAttr(llvm::Attribute::Convergent);
420 if (callOp.getNoUnwindAttr())
421 call->addFnAttr(llvm::Attribute::NoUnwind);
422 if (callOp.getWillReturnAttr())
423 call->addFnAttr(llvm::Attribute::WillReturn);
424 if (callOp.getNoInlineAttr())
425 call->addFnAttr(llvm::Attribute::NoInline);
426 if (callOp.getAlwaysInlineAttr())
427 call->addFnAttr(llvm::Attribute::AlwaysInline);
428 if (callOp.getInlineHintAttr())
429 call->addFnAttr(llvm::Attribute::InlineHint);
430
431 if (failed(moduleTranslation.convertArgAndResultAttrs(callOp, call)))
432 return failure();
433
434 if (MemoryEffectsAttr memAttr = callOp.getMemoryEffectsAttr()) {
435 llvm::MemoryEffects memEffects =
436 llvm::MemoryEffects(llvm::MemoryEffects::Location::ArgMem,
437 convertModRefInfoToLLVM(memAttr.getArgMem())) |
438 llvm::MemoryEffects(
439 llvm::MemoryEffects::Location::InaccessibleMem,
440 convertModRefInfoToLLVM(memAttr.getInaccessibleMem())) |
441 llvm::MemoryEffects(llvm::MemoryEffects::Location::Other,
442 convertModRefInfoToLLVM(memAttr.getOther())) |
443 llvm::MemoryEffects(llvm::MemoryEffects::Location::ErrnoMem,
444 convertModRefInfoToLLVM(memAttr.getErrnoMem())) |
445 llvm::MemoryEffects(
446 llvm::MemoryEffects::Location::TargetMem0,
447 convertModRefInfoToLLVM(memAttr.getTargetMem0())) |
448 llvm::MemoryEffects(llvm::MemoryEffects::Location::TargetMem1,
449 convertModRefInfoToLLVM(memAttr.getTargetMem1()));
450 call->setMemoryEffects(memEffects);
451 }
452
453 moduleTranslation.setAccessGroupsMetadata(callOp, call);
454 moduleTranslation.setAliasScopeMetadata(callOp, call);
455 moduleTranslation.setTBAAMetadata(callOp, call);
456 // If the called function has a result, remap the corresponding value. Note
457 // that LLVM IR dialect CallOp has either 0 or 1 result.
458 if (opInst.getNumResults() != 0)
459 moduleTranslation.mapValue(opInst.getResult(0), call);
460 // Check that LLVM call returns void for 0-result functions.
461 else if (!call->getType()->isVoidTy())
462 return failure();
463 moduleTranslation.mapCall(callOp, call);
464 return success();
465 }
466
467 if (auto inlineAsmOp = dyn_cast<LLVM::InlineAsmOp>(opInst)) {
468 // TODO: refactor function type creation which usually occurs in std-LLVM
469 // conversion.
470 SmallVector<Type, 8> operandTypes;
471 llvm::append_range(operandTypes, inlineAsmOp.getOperands().getTypes());
472
473 Type resultType;
474 if (inlineAsmOp.getNumResults() == 0) {
475 resultType = LLVM::LLVMVoidType::get(&moduleTranslation.getContext());
476 } else {
477 assert(inlineAsmOp.getNumResults() == 1);
478 resultType = inlineAsmOp.getResultTypes()[0];
479 }
480 auto ft = LLVM::LLVMFunctionType::get(resultType, operandTypes);
481 llvm::InlineAsm *inlineAsmInst =
482 inlineAsmOp.getAsmDialect()
483 ? llvm::InlineAsm::get(
484 static_cast<llvm::FunctionType *>(
485 moduleTranslation.convertType(ft)),
486 inlineAsmOp.getAsmString(), inlineAsmOp.getConstraints(),
487 inlineAsmOp.getHasSideEffects(),
488 inlineAsmOp.getIsAlignStack(),
489 convertAsmDialectToLLVM(*inlineAsmOp.getAsmDialect()))
490 : llvm::InlineAsm::get(static_cast<llvm::FunctionType *>(
491 moduleTranslation.convertType(ft)),
492 inlineAsmOp.getAsmString(),
493 inlineAsmOp.getConstraints(),
494 inlineAsmOp.getHasSideEffects(),
495 inlineAsmOp.getIsAlignStack());
496 llvm::CallInst *inst = builder.CreateCall(
497 inlineAsmInst,
498 moduleTranslation.lookupValues(inlineAsmOp.getOperands()));
499 inst->setTailCallKind(convertTailCallKindToLLVM(
500 inlineAsmOp.getTailCallKindAttr().getTailCallKind()));
501 if (auto maybeOperandAttrs = inlineAsmOp.getOperandAttrs()) {
502 llvm::AttributeList attrList;
503 for (const auto &it : llvm::enumerate(*maybeOperandAttrs)) {
504 Attribute attr = it.value();
505 if (!attr)
506 continue;
507 DictionaryAttr dAttr = cast<DictionaryAttr>(attr);
508 if (dAttr.empty())
509 continue;
510 TypeAttr tAttr =
511 cast<TypeAttr>(dAttr.get(InlineAsmOp::getElementTypeAttrName()));
512 llvm::AttrBuilder b(moduleTranslation.getLLVMContext());
513 llvm::Type *ty = moduleTranslation.convertType(tAttr.getValue());
514 b.addTypeAttr(llvm::Attribute::ElementType, ty);
515 // shift to account for the returned value (this is always 1 aggregate
516 // value in LLVM).
517 int shift = (opInst.getNumResults() > 0) ? 1 : 0;
518 attrList = attrList.addAttributesAtIndex(
519 moduleTranslation.getLLVMContext(), it.index() + shift, b);
520 }
521 inst->setAttributes(attrList);
522 }
523
524 if (opInst.getNumResults() != 0)
525 moduleTranslation.mapValue(opInst.getResult(0), inst);
526 return success();
527 }
528
529 if (auto invOp = dyn_cast<LLVM::InvokeOp>(opInst)) {
530 auto operands = moduleTranslation.lookupValues(invOp.getCalleeOperands());
532 convertOperandBundles(invOp.getOpBundleOperands(),
533 invOp.getOpBundleTags(), moduleTranslation);
534 ArrayRef<llvm::Value *> operandsRef(operands);
535 llvm::InvokeInst *result;
536 if (auto attr = opInst.getAttrOfType<FlatSymbolRefAttr>("callee")) {
537 result = builder.CreateInvoke(
538 moduleTranslation.lookupFunction(attr.getValue()),
539 moduleTranslation.lookupBlock(invOp.getSuccessor(0)),
540 moduleTranslation.lookupBlock(invOp.getSuccessor(1)), operandsRef,
541 opBundles);
542 } else {
543 llvm::FunctionType *calleeType = llvm::cast<llvm::FunctionType>(
544 moduleTranslation.convertType(invOp.getCalleeFunctionType()));
545 result = builder.CreateInvoke(
546 calleeType, operandsRef.front(),
547 moduleTranslation.lookupBlock(invOp.getSuccessor(0)),
548 moduleTranslation.lookupBlock(invOp.getSuccessor(1)),
549 operandsRef.drop_front(), opBundles);
550 }
551 result->setCallingConv(convertCConvToLLVM(invOp.getCConv()));
552 if (failed(moduleTranslation.convertArgAndResultAttrs(invOp, result)))
553 return failure();
554 moduleTranslation.mapBranch(invOp, result);
555 // InvokeOp can only have 0 or 1 result
556 if (invOp->getNumResults() != 0) {
557 moduleTranslation.mapValue(opInst.getResult(0), result);
558 return success();
559 }
560 return success(result->getType()->isVoidTy());
561 }
562
563 if (auto lpOp = dyn_cast<LLVM::LandingpadOp>(opInst)) {
564 llvm::Type *ty = moduleTranslation.convertType(lpOp.getType());
565 llvm::LandingPadInst *lpi =
566 builder.CreateLandingPad(ty, lpOp.getNumOperands());
567 lpi->setCleanup(lpOp.getCleanup());
568
569 // Add clauses
570 for (llvm::Value *operand :
571 moduleTranslation.lookupValues(lpOp.getOperands())) {
572 // All operands should be constant - checked by verifier
573 if (auto *constOperand = dyn_cast<llvm::Constant>(operand))
574 lpi->addClause(constOperand);
575 }
576 moduleTranslation.mapValue(lpOp.getResult(), lpi);
577 return success();
578 }
579
580 // Emit branches. We need to look up the remapped blocks and ignore the
581 // block arguments that were transformed into PHI nodes.
582 if (auto brOp = dyn_cast<LLVM::BrOp>(opInst)) {
583 llvm::BranchInst *branch =
584 builder.CreateBr(moduleTranslation.lookupBlock(brOp.getSuccessor()));
585 moduleTranslation.mapBranch(&opInst, branch);
586 moduleTranslation.setLoopMetadata(&opInst, branch);
587 return success();
588 }
589 if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
590 llvm::BranchInst *branch = builder.CreateCondBr(
591 moduleTranslation.lookupValue(condbrOp.getOperand(0)),
592 moduleTranslation.lookupBlock(condbrOp.getSuccessor(0)),
593 moduleTranslation.lookupBlock(condbrOp.getSuccessor(1)));
594 moduleTranslation.mapBranch(&opInst, branch);
595 moduleTranslation.setLoopMetadata(&opInst, branch);
596 return success();
597 }
598 if (auto switchOp = dyn_cast<LLVM::SwitchOp>(opInst)) {
599 llvm::SwitchInst *switchInst = builder.CreateSwitch(
600 moduleTranslation.lookupValue(switchOp.getValue()),
601 moduleTranslation.lookupBlock(switchOp.getDefaultDestination()),
602 switchOp.getCaseDestinations().size());
603
604 // Handle switch with zero cases.
605 if (!switchOp.getCaseValues())
606 return success();
607
608 auto *ty = llvm::cast<llvm::IntegerType>(
609 moduleTranslation.convertType(switchOp.getValue().getType()));
610 for (auto i :
611 llvm::zip(llvm::cast<DenseIntElementsAttr>(*switchOp.getCaseValues()),
612 switchOp.getCaseDestinations()))
613 switchInst->addCase(
614 llvm::ConstantInt::get(ty, std::get<0>(i).getLimitedValue()),
615 moduleTranslation.lookupBlock(std::get<1>(i)));
616
617 moduleTranslation.mapBranch(&opInst, switchInst);
618 return success();
619 }
620 if (auto indBrOp = dyn_cast<LLVM::IndirectBrOp>(opInst)) {
621 llvm::IndirectBrInst *indBr = builder.CreateIndirectBr(
622 moduleTranslation.lookupValue(indBrOp.getAddr()),
623 indBrOp->getNumSuccessors());
624 for (auto *succ : indBrOp.getSuccessors())
625 indBr->addDestination(moduleTranslation.lookupBlock(succ));
626 moduleTranslation.mapBranch(&opInst, indBr);
627 return success();
628 }
629
630 // Emit addressof. We need to look up the global value referenced by the
631 // operation and store it in the MLIR-to-LLVM value mapping. This does not
632 // emit any LLVM instruction.
633 if (auto addressOfOp = dyn_cast<LLVM::AddressOfOp>(opInst)) {
634 LLVM::GlobalOp global =
635 addressOfOp.getGlobal(moduleTranslation.symbolTable());
636 LLVM::LLVMFuncOp function =
637 addressOfOp.getFunction(moduleTranslation.symbolTable());
638 LLVM::AliasOp alias = addressOfOp.getAlias(moduleTranslation.symbolTable());
639 LLVM::IFuncOp ifunc = addressOfOp.getIFunc(moduleTranslation.symbolTable());
640
641 // The verifier should not have allowed this.
642 assert((global || function || alias || ifunc) &&
643 "referencing an undefined global, function, alias, or ifunc");
644
645 llvm::Value *llvmValue = nullptr;
646 if (global)
647 llvmValue = moduleTranslation.lookupGlobal(global);
648 else if (alias)
649 llvmValue = moduleTranslation.lookupAlias(alias);
650 else if (function)
651 llvmValue = moduleTranslation.lookupFunction(function.getName());
652 else
653 llvmValue = moduleTranslation.lookupIFunc(ifunc);
654
655 moduleTranslation.mapValue(addressOfOp.getResult(), llvmValue);
656 return success();
657 }
658
659 // Emit dso_local_equivalent. We need to look up the global value referenced
660 // by the operation and store it in the MLIR-to-LLVM value mapping.
661 if (auto dsoLocalEquivalentOp =
662 dyn_cast<LLVM::DSOLocalEquivalentOp>(opInst)) {
663 LLVM::LLVMFuncOp function =
664 dsoLocalEquivalentOp.getFunction(moduleTranslation.symbolTable());
665 LLVM::AliasOp alias =
666 dsoLocalEquivalentOp.getAlias(moduleTranslation.symbolTable());
667
668 // The verifier should not have allowed this.
669 assert((function || alias) &&
670 "referencing an undefined function, or alias");
671
672 llvm::Value *llvmValue = nullptr;
673 if (alias)
674 llvmValue = moduleTranslation.lookupAlias(alias);
675 else
676 llvmValue = moduleTranslation.lookupFunction(function.getName());
677
678 moduleTranslation.mapValue(
679 dsoLocalEquivalentOp.getResult(),
680 llvm::DSOLocalEquivalent::get(cast<llvm::GlobalValue>(llvmValue)));
681 return success();
682 }
683
684 // Emit blockaddress. We first need to find the LLVM block referenced by this
685 // operation and then create a LLVM block address for it.
686 if (auto blockAddressOp = dyn_cast<LLVM::BlockAddressOp>(opInst)) {
687 BlockAddressAttr blockAddressAttr = blockAddressOp.getBlockAddr();
688 llvm::BasicBlock *llvmBlock =
689 moduleTranslation.lookupBlockAddress(blockAddressAttr);
690
691 llvm::Value *llvmValue = nullptr;
692 StringRef fnName = blockAddressAttr.getFunction().getValue();
693 if (llvmBlock) {
694 llvm::Function *llvmFn = moduleTranslation.lookupFunction(fnName);
695 llvmValue = llvm::BlockAddress::get(llvmFn, llvmBlock);
696 } else {
697 // The matching LLVM block is not yet emitted, a placeholder is created
698 // in its place. When the LLVM block is emitted later in translation,
699 // the llvmValue is replaced with the actual llvm::BlockAddress.
700 // A GlobalVariable is chosen as placeholder because in general LLVM
701 // constants are uniqued and are not proper for RAUW, since that could
702 // harm unrelated uses of the constant.
703 llvmValue = new llvm::GlobalVariable(
704 *moduleTranslation.getLLVMModule(),
705 llvm::PointerType::getUnqual(moduleTranslation.getLLVMContext()),
706 /*isConstant=*/true, llvm::GlobalValue::LinkageTypes::ExternalLinkage,
707 /*Initializer=*/nullptr,
708 Twine("__mlir_block_address_")
709 .concat(Twine(fnName))
710 .concat(Twine((uint64_t)blockAddressOp.getOperation())));
711 moduleTranslation.mapUnresolvedBlockAddress(blockAddressOp, llvmValue);
712 }
713
714 moduleTranslation.mapValue(blockAddressOp.getResult(), llvmValue);
715 return success();
716 }
717
718 // Emit block label. If this label is seen before BlockAddressOp is
719 // translated, go ahead and already map it.
720 if (auto blockTagOp = dyn_cast<LLVM::BlockTagOp>(opInst)) {
721 auto funcOp = blockTagOp->getParentOfType<LLVMFuncOp>();
722 BlockAddressAttr blockAddressAttr = BlockAddressAttr::get(
723 &moduleTranslation.getContext(),
724 FlatSymbolRefAttr::get(&moduleTranslation.getContext(),
725 funcOp.getName()),
726 blockTagOp.getTag());
727 moduleTranslation.mapBlockAddress(blockAddressAttr,
728 builder.GetInsertBlock());
729 return success();
730 }
731
732 return failure();
733}
734
735static LogicalResult
737 NamedAttribute attribute,
738 LLVM::ModuleTranslation &moduleTranslation) {
739 StringRef name = attribute.getName();
740 if (name == LLVMDialect::getMmraAttrName()) {
742 if (auto oneTag = dyn_cast<LLVM::MMRATagAttr>(attribute.getValue())) {
743 tags.emplace_back(oneTag.getPrefix(), oneTag.getSuffix());
744 } else if (auto manyTags = dyn_cast<ArrayAttr>(attribute.getValue())) {
745 for (Attribute attr : manyTags) {
746 auto tag = dyn_cast<MMRATagAttr>(attr);
747 if (!tag)
748 return op.emitOpError(
749 "MMRA annotations array contains value that isn't an MMRA tag");
750 tags.emplace_back(tag.getPrefix(), tag.getSuffix());
751 }
752 } else {
753 return op.emitOpError(
754 "llvm.mmra is something other than an MMRA tag or an array of them");
755 }
756 llvm::MDTuple *mmraMd =
757 llvm::MMRAMetadata::getMD(moduleTranslation.getLLVMContext(), tags);
758 if (!mmraMd) {
759 // Empty list, canonicalizes to nothing
760 return success();
761 }
762 for (llvm::Instruction *inst : instructions)
763 inst->setMetadata(llvm::LLVMContext::MD_mmra, mmraMd);
764 return success();
765 }
766 return success();
767}
768
769namespace {
770/// Implementation of the dialect interface that converts operations belonging
771/// to the LLVM dialect to LLVM IR.
772class LLVMDialectLLVMIRTranslationInterface
774public:
776
777 /// Translates the given operation to LLVM IR using the provided IR builder
778 /// and saving the state in `moduleTranslation`.
779 LogicalResult
780 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
781 LLVM::ModuleTranslation &moduleTranslation) const final {
782 return convertOperationImpl(*op, builder, moduleTranslation);
783 }
784
785 /// Handle some metadata that is represented as a discardable attribute.
786 LogicalResult
787 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
788 NamedAttribute attribute,
789 LLVM::ModuleTranslation &moduleTranslation) const final {
790 return amendOperationImpl(*op, instructions, attribute, moduleTranslation);
791 }
792};
793} // namespace
794
796 registry.insert<LLVM::LLVMDialect>();
797 registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
798 dialect->addInterfaces<LLVMDialectLLVMIRTranslationInterface>();
799 });
800}
801
803 DialectRegistry registry;
805 context.appendDialectRegistry(registry);
806}
return success()
static std::string diagStr(const llvm::Type *type)
Convert an LLVM type to a string for printing in diagnostics.
static llvm::Metadata * convertModuleFlagValue(StringRef key, ArrayAttr arrayAttr, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static SmallVector< llvm::OperandBundleDef > convertOperandBundles(OperandRangeRange bundleOperands, ArrayAttr bundleTags, LLVM::ModuleTranslation &moduleTranslation)
static FailureOr< llvm::Function * > getOverloadedDeclaration(CallIntrinsicOp op, llvm::Intrinsic::ID id, llvm::Module *module, LLVM::ModuleTranslation &moduleTranslation)
Get the declaration of an overloaded llvm intrinsic.
static void convertModuleFlagsOp(ArrayAttr flags, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static LogicalResult convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
Builder for LLVM_CallIntrinsicOp.
static LogicalResult amendOperationImpl(Operation &op, ArrayRef< llvm::Instruction * > instructions, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation)
static llvm::FastMathFlags getFastmathFlags(FastmathFlagsInterface &op)
static llvm::OperandBundleDef convertOperandBundle(OperandRange bundleOperands, StringRef bundleTag, LLVM::ModuleTranslation &moduleTranslation)
static SmallVector< unsigned > extractPosition(ArrayRef< int64_t > indices)
Convert the value of a DenseI64ArrayAttr to a vector of unsigned indices.
static llvm::DILocalScope * getLocalScopeFromLoc(llvm::IRBuilderBase &builder, Location loc, LLVM::ModuleTranslation &moduleTranslation)
static llvm::Metadata * convertModuleFlagProfileSummaryAttr(StringRef key, ModuleFlagProfileSummaryAttr summaryAttr, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
static void convertLinkerOptionsOp(ArrayAttr options, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
Definition Attributes.h:25
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
A symbol reference with a reference path containing a single element.
static FlatSymbolRefAttr get(StringAttr value)
Construct a symbol reference for the given value name.
This class represents a fused location whose metadata is known to be an instance of the given type.
Definition Location.h:149
Base class for dialect interfaces providing translation to LLVM IR.
Implementation class for module translation.
void mapUnresolvedBlockAddress(BlockAddressOp op, llvm::Value *cst)
Maps a blockaddress operation to its corresponding placeholder LLVM value.
void mapCall(Operation *mlir, llvm::CallInst *llvm)
Stores a mapping between an MLIR call operation and a corresponding LLVM call instruction.
llvm::BasicBlock * lookupBlock(Block *block) const
Finds an LLVM IR basic block that corresponds to the given MLIR block.
MLIRContext & getContext()
Returns the MLIR context of the module being translated.
void mapBranch(Operation *mlir, llvm::Instruction *llvm)
Stores the mapping between an MLIR operation with successors and a corresponding LLVM IR instruction.
SmallVector< llvm::Value * > lookupValues(ValueRange values)
Looks up remapped a list of remapped values.
llvm::Value * lookupValue(Value value) const
Finds an LLVM IR value corresponding to the given MLIR value.
LogicalResult convertArgAndResultAttrs(ArgAndResultAttrsOpInterface attrsOp, llvm::CallBase *call, ArrayRef< unsigned > immArgPositions={})
Converts argument and result attributes from attrsOp to LLVM IR attributes on the call instruction.
SymbolTableCollection & symbolTable()
llvm::Type * convertType(Type type)
Converts the type from MLIR LLVM dialect to LLVM.
void setTBAAMetadata(AliasAnalysisOpInterface op, llvm::Instruction *inst)
Sets LLVM TBAA metadata for memory operations that have TBAA attributes.
llvm::GlobalValue * lookupGlobal(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global value.
llvm::BasicBlock * lookupBlockAddress(BlockAddressAttr attr) const
Finds the LLVM basic block that corresponds to the given BlockAddressAttr.
llvm::GlobalValue * lookupIFunc(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining an IFunc.
llvm::Metadata * translateDebugInfo(LLVM::DINodeAttr attr)
Translates the given LLVM debug info metadata.
llvm::GlobalValue * lookupAlias(Operation *op)
Finds an LLVM IR global value that corresponds to the given MLIR operation defining a global alias va...
llvm::Function * lookupFunction(StringRef name) const
Finds an LLVM IR function by its name.
llvm::Module * getLLVMModule()
Returns the LLVM module in which the IR is being constructed.
void setAliasScopeMetadata(AliasAnalysisOpInterface op, llvm::Instruction *inst)
void setAccessGroupsMetadata(AccessGroupOpInterface op, llvm::Instruction *inst)
void mapValue(Value mlir, llvm::Value *llvm)
Stores the mapping between an MLIR value and its LLVM IR counterpart.
llvm::LLVMContext & getLLVMContext() const
Returns the LLVM context in which the IR is being constructed.
void mapBlockAddress(BlockAddressAttr attr, llvm::BasicBlock *block)
Maps a BlockAddressAttr to its corresponding LLVM basic block.
void setLoopMetadata(Operation *op, llvm::Instruction *inst)
Sets LLVM loop metadata for branch operations that have a loop annotation attribute.
T findInstanceOf()
Return an instance of the given location type if one is nested under the current location.
Definition Location.h:45
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void appendDialectRegistry(const DialectRegistry &registry)
Append the contents of the given dialect registry to the registry associated with this context.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
This class represents a contiguous range of operand ranges, e.g.
Definition ValueRange.h:84
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
AttrClass getAttrOfType(StringAttr name)
Definition Operation.h:550
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
virtual Operation * lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol)
Look up a symbol with the specified name within the specified symbol table operation,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
llvm::Constant * getLLVMConstant(llvm::Type *llvmType, Attribute attr, Location loc, const ModuleTranslation &moduleTranslation)
Create an LLVM IR constant of llvmType from the MLIR attribute attr.
Operation * parentLLVMModule(Operation *op)
Lookup parent Module satisfying LLVM conditions on the Module Operation.
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
void registerLLVMDialectTranslation(DialectRegistry &registry)
Register the LLVM dialect and the translation from it to the LLVM IR in the given registry;.