MLIR 23.0.0git
LLVMDialect.cpp
Go to the documentation of this file.
1//===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the types and operation details for the LLVM IR dialect in
10// MLIR, and the LLVM IR dialect. It also registers the dialect.
11//
12//===----------------------------------------------------------------------===//
13
17#include "mlir/IR/Attributes.h"
18#include "mlir/IR/Builders.h"
19#include "mlir/IR/BuiltinOps.h"
22#include "mlir/IR/MLIRContext.h"
23#include "mlir/IR/Matchers.h"
26
27#include "llvm/ADT/APFloat.h"
28#include "llvm/ADT/TypeSwitch.h"
29#include "llvm/IR/DataLayout.h"
30#include "llvm/Support/Error.h"
31
32#include "LLVMDialectBytecode.h"
33
34#include <numeric>
35#include <optional>
36
37using namespace mlir;
38using namespace mlir::LLVM;
39using mlir::LLVM::cconv::getMaxEnumValForCConv;
40using mlir::LLVM::linkage::getMaxEnumValForLinkage;
41using mlir::LLVM::tailcallkind::getMaxEnumValForTailCallKind;
42
43#include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
44
45//===----------------------------------------------------------------------===//
46// Attribute Helpers
47//===----------------------------------------------------------------------===//
48
/// Attribute name under which the `llvm.alloca` parser records the allocated
/// element type.
static constexpr char kElemTypeAttrName[] = "elem_type";
50
53 llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
54 if (attr.getName() == "fastmathFlags") {
55 auto defAttr =
56 FastmathFlagsAttr::get(attr.getValue().getContext(), {});
57 return defAttr != attr.getValue();
58 }
59 return true;
60 }));
61 return filteredAttrs;
62}
63
64/// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and
65/// fully defined llvm.func.
66static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol,
67 Operation *op,
68 SymbolTableCollection &symbolTable) {
69 StringRef name = symbol.getValue();
70 auto func =
71 symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(op, symbol.getAttr());
72 if (!func)
73 return op->emitOpError("'")
74 << name << "' does not reference a valid LLVM function";
75 if (func.isExternal())
76 return op->emitOpError("'") << name << "' does not have a definition";
77 return success();
78}
79
80/// Returns a boolean type that has the same shape as `type`. It supports both
81/// fixed size vectors as well as scalable vectors.
82static Type getI1SameShape(Type type) {
83 Type i1Type = IntegerType::get(type.getContext(), 1);
86 return i1Type;
87}
88
89// Parses one of the keywords provided in the list `keywords` and returns the
90// position of the parsed keyword in the list. If none of the keywords from the
91// list is parsed, returns -1.
93 ArrayRef<StringRef> keywords) {
94 for (const auto &en : llvm::enumerate(keywords)) {
95 if (succeeded(parser.parseOptionalKeyword(en.value())))
96 return en.index();
97 }
98 return -1;
99}
100
101namespace {
102template <typename Ty>
103struct EnumTraits {};
104
105#define REGISTER_ENUM_TYPE(Ty) \
106 template <> \
107 struct EnumTraits<Ty> { \
108 static StringRef stringify(Ty value) { return stringify##Ty(value); } \
109 static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \
110 }
111
112REGISTER_ENUM_TYPE(Linkage);
113REGISTER_ENUM_TYPE(UnnamedAddr);
114REGISTER_ENUM_TYPE(CConv);
115REGISTER_ENUM_TYPE(TailCallKind);
116REGISTER_ENUM_TYPE(Visibility);
117} // namespace
118
119/// Parse an enum from the keyword, or default to the provided default value.
120/// The return type is the enum type by default, unless overridden with the
121/// second template argument.
122template <typename EnumTy, typename RetTy = EnumTy>
124 EnumTy defaultValue) {
126 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
127 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
128
129 int index = parseOptionalKeywordAlternative(parser, names);
130 if (index == -1)
131 return static_cast<RetTy>(defaultValue);
132 return static_cast<RetTy>(index);
133}
134
135static void printLLVMLinkage(OpAsmPrinter &p, Operation *, LinkageAttr val) {
136 p << stringifyLinkage(val.getLinkage());
137}
138
139static ParseResult parseLLVMLinkage(OpAsmParser &p, LinkageAttr &val) {
140 val = LinkageAttr::get(
141 p.getContext(),
142 parseOptionalLLVMKeyword<LLVM::Linkage>(p, LLVM::Linkage::External));
143 return success();
144}
145
147 bool isExpandLoad,
148 uint64_t alignment = 1) {
149 // From
150 // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
151 // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
152 //
153 // The pointer alignment defaults to 1.
154 if (alignment == 1) {
155 return nullptr;
156 }
157
158 auto emptyDictAttr = builder.getDictionaryAttr({});
159 auto alignmentAttr = builder.getI64IntegerAttr(alignment);
160 auto namedAttr =
161 builder.getNamedAttr(LLVMDialect::getAlignAttrName(), alignmentAttr);
162 SmallVector<mlir::NamedAttribute> attrs = {namedAttr};
163 auto alignDictAttr = builder.getDictionaryAttr(attrs);
164 // From
165 // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
166 // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
167 //
168 // The align parameter attribute can be provided for [expandload]'s first
169 // argument. The align parameter attribute can be provided for
170 // [compressstore]'s second argument.
171 int pos = isExpandLoad ? 0 : 1;
172 return pos == 0 ? builder.getArrayAttr(
173 {alignDictAttr, emptyDictAttr, emptyDictAttr})
174 : builder.getArrayAttr(
175 {emptyDictAttr, alignDictAttr, emptyDictAttr});
176}
177
178//===----------------------------------------------------------------------===//
179// Operand bundle helpers.
180//===----------------------------------------------------------------------===//
181
183 TypeRange operandTypes, StringRef tag) {
184 p.printString(tag);
185 p << "(";
186
187 if (!operands.empty()) {
188 p.printOperands(operands);
189 p << " : ";
190 llvm::interleaveComma(operandTypes, p);
191 }
192
193 p << ")";
194}
195
197 OperandRangeRange opBundleOperands,
198 TypeRangeRange opBundleOperandTypes,
199 std::optional<ArrayAttr> opBundleTags) {
200 if (opBundleOperands.empty())
201 return;
202 assert(opBundleTags && "expect operand bundle tags");
203
204 p << "[";
205 llvm::interleaveComma(
206 llvm::zip(opBundleOperands, opBundleOperandTypes, *opBundleTags), p,
207 [&p](auto bundle) {
208 auto bundleTag = cast<StringAttr>(std::get<2>(bundle)).getValue();
209 printOneOpBundle(p, std::get<0>(bundle), std::get<1>(bundle),
210 bundleTag);
211 });
212 p << "]";
213}
214
215static ParseResult parseOneOpBundle(
216 OpAsmParser &p,
218 SmallVector<SmallVector<Type>> &opBundleOperandTypes,
219 SmallVector<Attribute> &opBundleTags) {
220 SMLoc currentParserLoc = p.getCurrentLocation();
222 SmallVector<Type> types;
223 std::string tag;
224
225 if (p.parseString(&tag))
226 return p.emitError(currentParserLoc, "expect operand bundle tag");
227
228 if (p.parseLParen())
229 return failure();
230
231 if (p.parseOptionalRParen()) {
232 if (p.parseOperandList(operands) || p.parseColon() ||
233 p.parseTypeList(types) || p.parseRParen())
234 return failure();
235 }
236
237 opBundleOperands.push_back(std::move(operands));
238 opBundleOperandTypes.push_back(std::move(types));
239 opBundleTags.push_back(StringAttr::get(p.getContext(), tag));
240
241 return success();
242}
243
244static std::optional<ParseResult> parseOpBundles(
245 OpAsmParser &p,
247 SmallVector<SmallVector<Type>> &opBundleOperandTypes,
248 ArrayAttr &opBundleTags) {
249 if (p.parseOptionalLSquare())
250 return std::nullopt;
251
252 if (succeeded(p.parseOptionalRSquare()))
253 return success();
254
255 SmallVector<Attribute> opBundleTagAttrs;
256 auto bundleParser = [&] {
257 return parseOneOpBundle(p, opBundleOperands, opBundleOperandTypes,
258 opBundleTagAttrs);
259 };
260 if (p.parseCommaSeparatedList(bundleParser))
261 return failure();
262
263 if (p.parseRSquare())
264 return failure();
265
266 opBundleTags = ArrayAttr::get(p.getContext(), opBundleTagAttrs);
267
268 return success();
269}
270
271//===----------------------------------------------------------------------===//
272// Printing, parsing, folding and builder for LLVM::CmpOp.
273//===----------------------------------------------------------------------===//
274
275void ICmpOp::print(OpAsmPrinter &p) {
276 p << " \"" << stringifyICmpPredicate(getPredicate()) << "\" " << getOperand(0)
277 << ", " << getOperand(1);
278 p.printOptionalAttrDict((*this)->getAttrs(), {"predicate"});
279 p << " : " << getLhs().getType();
280}
281
282void FCmpOp::print(OpAsmPrinter &p) {
283 p << " \"" << stringifyFCmpPredicate(getPredicate()) << "\" " << getOperand(0)
284 << ", " << getOperand(1);
285 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"predicate"});
286 p << " : " << getLhs().getType();
287}
288
289// <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use
290// attribute-dict? `:` type
291// <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use
292// attribute-dict? `:` type
293template <typename CmpPredicateType>
294static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
295 StringAttr predicateAttr;
297 Type type;
298 SMLoc predicateLoc, trailingTypeLoc;
299 if (parser.getCurrentLocation(&predicateLoc) ||
300 parser.parseAttribute(predicateAttr, "predicate", result.attributes) ||
301 parser.parseOperand(lhs) || parser.parseComma() ||
302 parser.parseOperand(rhs) ||
303 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
304 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
305 parser.resolveOperand(lhs, type, result.operands) ||
306 parser.resolveOperand(rhs, type, result.operands))
307 return failure();
308
309 // Replace the string attribute `predicate` with an integer attribute.
310 int64_t predicateValue = 0;
311 if (std::is_same<CmpPredicateType, ICmpPredicate>()) {
312 std::optional<ICmpPredicate> predicate =
313 symbolizeICmpPredicate(predicateAttr.getValue());
314 if (!predicate)
315 return parser.emitError(predicateLoc)
316 << "'" << predicateAttr.getValue()
317 << "' is an incorrect value of the 'predicate' attribute";
318 predicateValue = static_cast<int64_t>(*predicate);
319 } else {
320 std::optional<FCmpPredicate> predicate =
321 symbolizeFCmpPredicate(predicateAttr.getValue());
322 if (!predicate)
323 return parser.emitError(predicateLoc)
324 << "'" << predicateAttr.getValue()
325 << "' is an incorrect value of the 'predicate' attribute";
326 predicateValue = static_cast<int64_t>(*predicate);
327 }
328
329 result.attributes.set("predicate",
330 parser.getBuilder().getI64IntegerAttr(predicateValue));
331
332 // The result type is either i1 or a vector type <? x i1> if the inputs are
333 // vectors.
334 if (!isCompatibleType(type))
335 return parser.emitError(trailingTypeLoc,
336 "expected LLVM dialect-compatible type");
337 result.addTypes(getI1SameShape(type));
338 return success();
339}
340
341ParseResult ICmpOp::parse(OpAsmParser &parser, OperationState &result) {
342 return parseCmpOp<ICmpPredicate>(parser, result);
343}
344
345ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) {
346 return parseCmpOp<FCmpPredicate>(parser, result);
347}
348
349/// Returns a scalar or vector boolean attribute of the given type.
350static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
351 auto boolAttr = BoolAttr::get(ctx, value);
352 ShapedType shapedType = dyn_cast<ShapedType>(type);
353 if (!shapedType)
354 return boolAttr;
355 return DenseElementsAttr::get(shapedType, boolAttr);
356}
357
358OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
359 if (getPredicate() != ICmpPredicate::eq &&
360 getPredicate() != ICmpPredicate::ne)
361 return {};
362
363 // cmpi(eq/ne, x, x) -> true/false
364 if (getLhs() == getRhs())
366 getPredicate() == ICmpPredicate::eq);
367
368 // cmpi(eq/ne, alloca, null) -> false/true
369 if (getLhs().getDefiningOp<AllocaOp>() && getRhs().getDefiningOp<ZeroOp>())
371 getPredicate() == ICmpPredicate::ne);
372
373 // cmpi(eq/ne, null, alloca) -> cmpi(eq/ne, alloca, null)
374 if (getLhs().getDefiningOp<ZeroOp>() && getRhs().getDefiningOp<AllocaOp>()) {
375 Value lhs = getLhs();
376 Value rhs = getRhs();
377 getLhsMutable().assign(rhs);
378 getRhsMutable().assign(lhs);
379 return getResult();
380 }
381
382 return {};
383}
384
385//===----------------------------------------------------------------------===//
386// Printing, parsing and verification for LLVM::AllocaOp.
387//===----------------------------------------------------------------------===//
388
389void AllocaOp::print(OpAsmPrinter &p) {
390 auto funcTy =
391 FunctionType::get(getContext(), {getArraySize().getType()}, {getType()});
392
393 if (getInalloca())
394 p << " inalloca";
395
396 p << ' ' << getArraySize() << " x " << getElemType();
397 if (getAlignment() && *getAlignment() != 0)
398 p.printOptionalAttrDict((*this)->getAttrs(),
399 {kElemTypeAttrName, getInallocaAttrName()});
400 else
402 (*this)->getAttrs(),
403 {getAlignmentAttrName(), kElemTypeAttrName, getInallocaAttrName()});
404 p << " : " << funcTy;
405}
406
407// <operation> ::= `llvm.alloca` `inalloca`? ssa-use `x` type
408// attribute-dict? `:` type `,` type
409ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) {
411 Type type, elemType;
412 SMLoc trailingTypeLoc;
413
414 if (succeeded(parser.parseOptionalKeyword("inalloca")))
415 result.addAttribute(getInallocaAttrName(result.name),
416 UnitAttr::get(parser.getContext()));
417
418 if (parser.parseOperand(arraySize) || parser.parseKeyword("x") ||
419 parser.parseType(elemType) ||
420 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
421 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
422 return failure();
423
424 std::optional<NamedAttribute> alignmentAttr =
425 result.attributes.getNamed("alignment");
426 if (alignmentAttr.has_value()) {
427 auto alignmentInt = llvm::dyn_cast<IntegerAttr>(alignmentAttr->getValue());
428 if (!alignmentInt)
429 return parser.emitError(parser.getNameLoc(),
430 "expected integer alignment");
431 if (alignmentInt.getValue().isZero())
432 result.attributes.erase("alignment");
433 }
434
435 // Extract the result type from the trailing function type.
436 auto funcType = llvm::dyn_cast<FunctionType>(type);
437 if (!funcType || funcType.getNumInputs() != 1 ||
438 funcType.getNumResults() != 1)
439 return parser.emitError(
440 trailingTypeLoc,
441 "expected trailing function type with one argument and one result");
442
443 if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands))
444 return failure();
445
446 Type resultType = funcType.getResult(0);
447 if (auto ptrResultType = llvm::dyn_cast<LLVMPointerType>(resultType))
448 result.addAttribute(kElemTypeAttrName, TypeAttr::get(elemType));
449
450 result.addTypes({funcType.getResult(0)});
451 return success();
452}
453
454LogicalResult AllocaOp::verify() {
455 // Only certain target extension types can be used in 'alloca'.
456 if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getElemType());
457 targetExtType && !targetExtType.supportsMemOps())
458 return emitOpError()
459 << "this target extension type cannot be used in alloca";
460
461 return success();
462}
463
464//===----------------------------------------------------------------------===//
465// LLVM::BrOp
466//===----------------------------------------------------------------------===//
467
468SuccessorOperands BrOp::getSuccessorOperands(unsigned index) {
469 assert(index == 0 && "invalid successor index");
470 return SuccessorOperands(getDestOperandsMutable());
471}
472
473//===----------------------------------------------------------------------===//
474// LLVM::CondBrOp
475//===----------------------------------------------------------------------===//
476
477SuccessorOperands CondBrOp::getSuccessorOperands(unsigned index) {
478 assert(index < getNumSuccessors() && "invalid successor index");
479 return SuccessorOperands(index == 0 ? getTrueDestOperandsMutable()
480 : getFalseDestOperandsMutable());
481}
482
483void CondBrOp::build(OpBuilder &builder, OperationState &result,
484 Value condition, Block *trueDest, ValueRange trueOperands,
485 Block *falseDest, ValueRange falseOperands,
486 std::optional<std::pair<uint32_t, uint32_t>> weights) {
487 DenseI32ArrayAttr weightsAttr;
488 if (weights)
489 weightsAttr =
490 builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights->first),
491 static_cast<int32_t>(weights->second)});
492
493 build(builder, result, condition, trueOperands, falseOperands, weightsAttr,
494 /*loop_annotation=*/{}, trueDest, falseDest);
495}
496
497//===----------------------------------------------------------------------===//
498// LLVM::SwitchOp
499//===----------------------------------------------------------------------===//
500
501void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
502 Block *defaultDestination, ValueRange defaultOperands,
503 DenseIntElementsAttr caseValues,
504 BlockRange caseDestinations,
505 ArrayRef<ValueRange> caseOperands,
506 ArrayRef<int32_t> branchWeights) {
507 DenseI32ArrayAttr weightsAttr;
508 if (!branchWeights.empty())
509 weightsAttr = builder.getDenseI32ArrayAttr(branchWeights);
510
511 build(builder, result, value, defaultOperands, caseOperands, caseValues,
512 weightsAttr, defaultDestination, caseDestinations);
513}
514
515void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
516 Block *defaultDestination, ValueRange defaultOperands,
517 ArrayRef<APInt> caseValues, BlockRange caseDestinations,
518 ArrayRef<ValueRange> caseOperands,
519 ArrayRef<int32_t> branchWeights) {
520 DenseIntElementsAttr caseValuesAttr;
521 if (!caseValues.empty()) {
522 ShapedType caseValueType = VectorType::get(
523 static_cast<int64_t>(caseValues.size()), value.getType());
524 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
525 }
526
527 build(builder, result, value, defaultDestination, defaultOperands,
528 caseValuesAttr, caseDestinations, caseOperands, branchWeights);
529}
530
531void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
532 Block *defaultDestination, ValueRange defaultOperands,
533 ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
534 ArrayRef<ValueRange> caseOperands,
535 ArrayRef<int32_t> branchWeights) {
536 DenseIntElementsAttr caseValuesAttr;
537 if (!caseValues.empty()) {
538 ShapedType caseValueType = VectorType::get(
539 static_cast<int64_t>(caseValues.size()), value.getType());
540 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
541 }
542
543 build(builder, result, value, defaultDestination, defaultOperands,
544 caseValuesAttr, caseDestinations, caseOperands, branchWeights);
545}
546
547/// <cases> ::= `[` (case (`,` case )* )? `]`
548/// <case> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
549static ParseResult parseSwitchOpCases(
550 OpAsmParser &parser, Type flagType, DenseIntElementsAttr &caseValues,
551 SmallVectorImpl<Block *> &caseDestinations,
553 SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
554 if (failed(parser.parseLSquare()))
555 return failure();
556 if (succeeded(parser.parseOptionalRSquare()))
557 return success();
558 SmallVector<APInt> values;
559 unsigned bitWidth = flagType.getIntOrFloatBitWidth();
560 auto parseCase = [&]() {
561 int64_t value = 0;
562 if (failed(parser.parseInteger(value)))
563 return failure();
564 values.push_back(APInt(bitWidth, value, /*isSigned=*/true));
565
566 Block *destination;
568 SmallVector<Type> operandTypes;
569 if (parser.parseColon() || parser.parseSuccessor(destination))
570 return failure();
571 if (!parser.parseOptionalLParen()) {
573 /*allowResultNumber=*/false) ||
574 parser.parseColonTypeList(operandTypes) || parser.parseRParen())
575 return failure();
576 }
577 caseDestinations.push_back(destination);
578 caseOperands.emplace_back(operands);
579 caseOperandTypes.emplace_back(operandTypes);
580 return success();
581 };
582 if (failed(parser.parseCommaSeparatedList(parseCase)))
583 return failure();
584
585 ShapedType caseValueType =
586 VectorType::get(static_cast<int64_t>(values.size()), flagType);
587 caseValues = DenseIntElementsAttr::get(caseValueType, values);
588 return parser.parseRSquare();
589}
590
591static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType,
592 DenseIntElementsAttr caseValues,
593 SuccessorRange caseDestinations,
594 OperandRangeRange caseOperands,
595 const TypeRangeRange &caseOperandTypes) {
596 p << '[';
597 p.printNewline();
598 if (!caseValues) {
599 p << ']';
600 return;
601 }
602
603 size_t index = 0;
604 llvm::interleave(
605 llvm::zip(caseValues, caseDestinations),
606 [&](auto i) {
607 p << " ";
608 p << std::get<0>(i);
609 p << ": ";
610 p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]);
611 },
612 [&] {
613 p << ',';
614 p.printNewline();
615 });
616 p.printNewline();
617 p << ']';
618}
619
620LogicalResult SwitchOp::verify() {
621 if ((!getCaseValues() && !getCaseDestinations().empty()) ||
622 (getCaseValues() &&
623 getCaseValues()->size() !=
624 static_cast<int64_t>(getCaseDestinations().size())))
625 return emitOpError("expects number of case values to match number of "
626 "case destinations");
627 if (getCaseValues() &&
628 getValue().getType() != getCaseValues()->getElementType())
629 return emitError("expects case value type to match condition value type");
630 return success();
631}
632
633SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
634 assert(index < getNumSuccessors() && "invalid successor index");
635 return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
636 : getCaseOperandsMutable(index - 1));
637}
638
639//===----------------------------------------------------------------------===//
640// Code for LLVM::GEPOp.
641//===----------------------------------------------------------------------===//
642
643GEPIndicesAdaptor<ValueRange> GEPOp::getIndices() {
644 return GEPIndicesAdaptor<ValueRange>(getRawConstantIndicesAttr(),
645 getDynamicIndices());
646}
647
648/// Returns the elemental type of any LLVM-compatible vector type or self.
650 if (auto vectorType = llvm::dyn_cast<VectorType>(type))
651 return vectorType.getElementType();
652 return type;
653}
654
655/// Destructures the 'indices' parameter into 'rawConstantIndices' and
656/// 'dynamicIndices', encoding the former in the process. In the process,
657/// dynamic indices which are used to index into a structure type are converted
658/// to constant indices when possible. To do this, the GEPs element type should
659/// be passed as first parameter.
661 SmallVectorImpl<int32_t> &rawConstantIndices,
662 SmallVectorImpl<Value> &dynamicIndices) {
663 for (const GEPArg &iter : indices) {
664 // If the thing we are currently indexing into is a struct we must turn
665 // any integer constants into constant indices. If this is not possible
666 // we don't do anything here. The verifier will catch it and emit a proper
667 // error. All other canonicalization is done in the fold method.
668 bool requiresConst = !rawConstantIndices.empty() &&
669 isa_and_nonnull<LLVMStructType>(currType);
670 if (Value val = llvm::dyn_cast_if_present<Value>(iter)) {
671 APInt intC;
672 if (requiresConst && matchPattern(val, m_ConstantInt(&intC)) &&
673 intC.isSignedIntN(kGEPConstantBitWidth)) {
674 rawConstantIndices.push_back(intC.getSExtValue());
675 } else {
676 rawConstantIndices.push_back(GEPOp::kDynamicIndex);
677 dynamicIndices.push_back(val);
678 }
679 } else {
680 rawConstantIndices.push_back(cast<GEPConstantIndex>(iter));
681 }
682
683 // Skip for very first iteration of this loop. First index does not index
684 // within the aggregates, but is just a pointer offset.
685 if (rawConstantIndices.size() == 1 || !currType)
686 continue;
687
688 currType = TypeSwitch<Type, Type>(currType)
689 .Case<VectorType, LLVMArrayType>([](auto containerType) {
690 return containerType.getElementType();
691 })
692 .Case([&](LLVMStructType structType) -> Type {
693 int64_t memberIndex = rawConstantIndices.back();
694 if (memberIndex >= 0 && static_cast<size_t>(memberIndex) <
695 structType.getBody().size())
696 return structType.getBody()[memberIndex];
697 return nullptr;
698 })
699 .Default(nullptr);
700 }
701}
702
703void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
704 Type elementType, Value basePtr, ArrayRef<GEPArg> indices,
705 GEPNoWrapFlags noWrapFlags,
706 ArrayRef<NamedAttribute> attributes) {
707 SmallVector<int32_t> rawConstantIndices;
708 SmallVector<Value> dynamicIndices;
709 destructureIndices(elementType, indices, rawConstantIndices, dynamicIndices);
710
711 result.addTypes(resultType);
712 result.addAttributes(attributes);
713 result.getOrAddProperties<Properties>().rawConstantIndices =
714 builder.getDenseI32ArrayAttr(rawConstantIndices);
715 result.getOrAddProperties<Properties>().noWrapFlags = noWrapFlags;
716 result.getOrAddProperties<Properties>().elem_type =
717 TypeAttr::get(elementType);
718 result.addOperands(basePtr);
719 result.addOperands(dynamicIndices);
720}
721
722void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
723 Type elementType, Value basePtr, ValueRange indices,
724 GEPNoWrapFlags noWrapFlags,
725 ArrayRef<NamedAttribute> attributes) {
726 build(builder, result, resultType, elementType, basePtr,
727 SmallVector<GEPArg>(indices), noWrapFlags, attributes);
728}
729
730static ParseResult
733 DenseI32ArrayAttr &rawConstantIndices) {
734 SmallVector<int32_t> constantIndices;
735
736 auto idxParser = [&]() -> ParseResult {
737 int32_t constantIndex;
738 OptionalParseResult parsedInteger =
739 parser.parseOptionalInteger(constantIndex);
740 if (parsedInteger.has_value()) {
741 if (failed(parsedInteger.value()))
742 return failure();
743 constantIndices.push_back(constantIndex);
744 return success();
745 }
746
747 constantIndices.push_back(LLVM::GEPOp::kDynamicIndex);
748 return parser.parseOperand(indices.emplace_back());
749 };
750 if (parser.parseCommaSeparatedList(idxParser))
751 return failure();
752
753 rawConstantIndices =
754 DenseI32ArrayAttr::get(parser.getContext(), constantIndices);
755 return success();
756}
757
758static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
760 DenseI32ArrayAttr rawConstantIndices) {
761 llvm::interleaveComma(
762 GEPIndicesAdaptor<OperandRange>(rawConstantIndices, indices), printer,
764 if (Value val = llvm::dyn_cast_if_present<Value>(cst))
765 printer.printOperand(val);
766 else
767 printer << cast<IntegerAttr>(cst).getInt();
768 });
769}
770
771/// For the given `indices`, check if they comply with `baseGEPType`,
772/// especially check against LLVMStructTypes nested within.
773static LogicalResult
774verifyStructIndices(Type baseGEPType, unsigned indexPos,
777 if (indexPos >= indices.size())
778 // Stop searching
779 return success();
780
781 return TypeSwitch<Type, LogicalResult>(baseGEPType)
782 .Case([&](LLVMStructType structType) -> LogicalResult {
783 auto attr = dyn_cast<IntegerAttr>(indices[indexPos]);
784 if (!attr)
785 return emitOpError() << "expected index " << indexPos
786 << " indexing a struct to be constant";
787
788 int32_t gepIndex = attr.getInt();
789 ArrayRef<Type> elementTypes = structType.getBody();
790 if (gepIndex < 0 ||
791 static_cast<size_t>(gepIndex) >= elementTypes.size())
792 return emitOpError() << "index " << indexPos
793 << " indexing a struct is out of bounds";
794
795 // Instead of recursively going into every children types, we only
796 // dive into the one indexed by gepIndex.
797 return verifyStructIndices(elementTypes[gepIndex], indexPos + 1,
799 })
800 .Case<VectorType, LLVMArrayType>(
801 [&](auto containerType) -> LogicalResult {
802 return verifyStructIndices(containerType.getElementType(),
803 indexPos + 1, indices, emitOpError);
804 })
805 .Default([&](auto otherType) -> LogicalResult {
806 return emitOpError()
807 << "type " << otherType << " cannot be indexed (index #"
808 << indexPos << ")";
809 });
810}
811
812/// Driver function around `verifyStructIndices`.
813static LogicalResult
818
819LogicalResult LLVM::GEPOp::verify() {
820 if (static_cast<size_t>(
821 llvm::count(getRawConstantIndices(), kDynamicIndex)) !=
822 getDynamicIndices().size())
823 return emitOpError("expected as many dynamic indices as specified in '")
824 << getRawConstantIndicesAttrName().getValue() << "'";
825
826 if (getNoWrapFlags() == GEPNoWrapFlags::inboundsFlag)
827 return emitOpError("'inbounds_flag' cannot be used directly.");
828
829 return verifyStructIndices(getElemType(), getIndices(),
830 [&] { return emitOpError(); });
831}
832
833//===----------------------------------------------------------------------===//
834// LoadOp
835//===----------------------------------------------------------------------===//
836
837void LoadOp::getEffects(
839 &effects) {
840 effects.emplace_back(MemoryEffects::Read::get(), &getAddrMutable());
841 // Volatile operations can have target-specific read-write effects on
842 // memory besides the one referred to by the pointer operand.
843 // Similarly, atomic operations that are monotonic or stricter cause
844 // synchronization that from a language point-of-view, are arbitrary
845 // read-writes into memory.
846 if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic &&
847 getOrdering() != AtomicOrdering::unordered)) {
848 effects.emplace_back(MemoryEffects::Write::get());
849 effects.emplace_back(MemoryEffects::Read::get());
850 }
851}
852
853/// Returns true if the given type is supported by atomic operations. All
854/// integer, float, and pointer types with a power-of-two bitsize and a minimal
855/// size of 8 bits are supported.
857 const DataLayout &dataLayout) {
858 if (!isa<IntegerType, LLVMPointerType>(type))
860 return false;
861
862 llvm::TypeSize bitWidth = dataLayout.getTypeSizeInBits(type);
863 if (bitWidth.isScalable())
864 return false;
865 // Needs to be at least 8 bits and a power of two.
866 return bitWidth >= 8 && (bitWidth & (bitWidth - 1)) == 0;
867}
868
869/// Verifies the attributes and the type of atomic memory access operations.
870template <typename OpTy>
871static LogicalResult
872verifyAtomicMemOp(OpTy memOp, Type valueType,
873 ArrayRef<AtomicOrdering> unsupportedOrderings) {
874 if (memOp.getOrdering() != AtomicOrdering::not_atomic) {
875 DataLayout dataLayout = DataLayout::closest(memOp);
876 if (!isTypeCompatibleWithAtomicOp(valueType, dataLayout))
877 return memOp.emitOpError("unsupported type ")
878 << valueType << " for atomic access";
879 if (llvm::is_contained(unsupportedOrderings, memOp.getOrdering()))
880 return memOp.emitOpError("unsupported ordering '")
881 << stringifyAtomicOrdering(memOp.getOrdering()) << "'";
882 if (!memOp.getAlignment())
883 return memOp.emitOpError("expected alignment for atomic access");
884 return success();
885 }
886 if (memOp.getSyncscope())
887 return memOp.emitOpError(
888 "expected syncscope to be null for non-atomic access");
889 return success();
890}
891
892LogicalResult LoadOp::verify() {
893 Type valueType = getResult().getType();
894 return verifyAtomicMemOp(*this, valueType,
895 {AtomicOrdering::release, AtomicOrdering::acq_rel});
896}
897
898void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
899 Value addr, unsigned alignment, bool isVolatile,
900 bool isNonTemporal, bool isInvariant, bool isInvariantGroup,
901 AtomicOrdering ordering, StringRef syncscope) {
902 build(builder, state, type, addr,
903 alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
904 isNonTemporal, isInvariant, isInvariantGroup, ordering,
905 syncscope.empty() ? nullptr : builder.getStringAttr(syncscope),
906 /*dereferenceable=*/nullptr,
907 /*access_groups=*/nullptr,
908 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr,
909 /*tbaa=*/nullptr);
910}
911
912//===----------------------------------------------------------------------===//
913// StoreOp
914//===----------------------------------------------------------------------===//
915
/// Reports the memory effects of a store: a write to the `addr` operand and,
/// when the store is volatile or atomic with an ordering other than
/// `unordered`, additional read and write effects not tied to any value.
// NOTE(review): the parameter-type line of the signature is missing from this
// listing.
916void StoreOp::getEffects(
918 &effects) {
919 effects.emplace_back(MemoryEffects::Write::get(), &getAddrMutable());
920 // Volatile operations can have target-specific read-write effects on
921 // memory besides the one referred to by the pointer operand.
922 // Similarly, atomic operations that are monotonic or stricter cause
923 // synchronization that from a language point-of-view, are arbitrary
924 // read-writes into memory.
925 if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic &&
926 getOrdering() != AtomicOrdering::unordered)) {
// These effects carry no value, i.e. they may touch any memory.
927 effects.emplace_back(MemoryEffects::Write::get());
928 effects.emplace_back(MemoryEffects::Read::get());
929 }
930}
931
932LogicalResult StoreOp::verify() {
933 Type valueType = getValue().getType();
934 return verifyAtomicMemOp(*this, valueType,
935 {AtomicOrdering::acquire, AtomicOrdering::acq_rel});
936}
937
938void StoreOp::build(OpBuilder &builder, OperationState &state, Value value,
939 Value addr, unsigned alignment, bool isVolatile,
940 bool isNonTemporal, bool isInvariantGroup,
941 AtomicOrdering ordering, StringRef syncscope) {
942 build(builder, state, value, addr,
943 alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
944 isNonTemporal, isInvariantGroup, ordering,
945 syncscope.empty() ? nullptr : builder.getStringAttr(syncscope),
946 /*access_groups=*/nullptr,
947 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
948}
949
950//===----------------------------------------------------------------------===//
951// CallOp
952//===----------------------------------------------------------------------===//
953
954/// Gets the MLIR Op-like result types of a LLVMFunctionType.
955static SmallVector<Type, 1> getCallOpResultTypes(LLVMFunctionType calleeType) {
956 SmallVector<Type, 1> results;
957 Type resultType = calleeType.getReturnType();
958 if (!isa<LLVM::LLVMVoidType>(resultType))
959 results.push_back(resultType);
960 return results;
961}
962
963/// Gets the variadic callee type for a LLVMFunctionType.
964static TypeAttr getCallOpVarCalleeType(LLVMFunctionType calleeType) {
965 return calleeType.isVarArg() ? TypeAttr::get(calleeType) : nullptr;
966}
967
968/// Constructs a LLVMFunctionType from MLIR `results` and `args`.
969static LLVMFunctionType getLLVMFuncType(MLIRContext *context, TypeRange results,
970 ValueRange args) {
971 Type resultType;
972 if (results.empty())
973 resultType = LLVMVoidType::get(context);
974 else
975 resultType = results.front();
976 return LLVMFunctionType::get(resultType, llvm::to_vector(args.getTypes()),
977 /*isVarArg=*/false);
978}
979
980void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
981 StringRef callee, ValueRange args) {
982 build(builder, state, results, builder.getStringAttr(callee), args);
983}
984
985void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
986 StringAttr callee, ValueRange args) {
987 build(builder, state, results, SymbolRefAttr::get(callee), args);
988}
989
990void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
991 FlatSymbolRefAttr callee, ValueRange args) {
992 assert(callee && "expected non-null callee in direct call builder");
993 build(builder, state, results,
994 /*var_callee_type=*/nullptr, callee, args, /*fastmathFlags=*/nullptr,
995 /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
996 /*memory_effects=*/nullptr,
997 /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
998 /*noreturn=*/nullptr, /*returns_twice=*/nullptr, /*hot=*/nullptr,
999 /*cold=*/nullptr, /*noduplicate=*/nullptr,
1000 /*no_caller_saved_registers=*/nullptr, /*nocallback=*/nullptr,
1001 /*modular_format=*/nullptr, /*nobuiltins=*/nullptr,
1002 /*allocsize=*/nullptr, /*optsize=*/nullptr, /*minsize=*/nullptr,
1003 /*builtin=*/nullptr, /*nobuiltin=*/nullptr,
1004 /*save_reg_params=*/nullptr,
1005 /*zero_call_used_regs=*/nullptr, /*trap_func_name=*/nullptr,
1006 /*default_func_attrs=*/nullptr,
1007 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
1008 /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
1009 /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
1010 /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
1011 /*no_inline=*/nullptr, /*always_inline=*/nullptr,
1012 /*inline_hint=*/nullptr);
1013}
1014
1015void CallOp::build(OpBuilder &builder, OperationState &state,
1016 LLVMFunctionType calleeType, StringRef callee,
1017 ValueRange args) {
1018 build(builder, state, calleeType, builder.getStringAttr(callee), args);
1019}
1020
1021void CallOp::build(OpBuilder &builder, OperationState &state,
1022 LLVMFunctionType calleeType, StringAttr callee,
1023 ValueRange args) {
1024 build(builder, state, calleeType, SymbolRefAttr::get(callee), args);
1025}
1026
1027void CallOp::build(OpBuilder &builder, OperationState &state,
1028 LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
1029 ValueRange args) {
1030 build(builder, state, getCallOpResultTypes(calleeType),
1031 getCallOpVarCalleeType(calleeType), callee, args,
1032 /*fastmathFlags=*/nullptr,
1033 /*CConv=*/nullptr,
1034 /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
1035 /*convergent=*/nullptr,
1036 /*no_unwind=*/nullptr, /*will_return=*/nullptr,
1037 /*noreturn=*/nullptr,
1038 /*returns_twice=*/nullptr, /*hot=*/nullptr,
1039 /*cold=*/nullptr, /*noduplicate=*/nullptr,
1040 /*no_caller_saved_registers=*/nullptr, /*nocallback=*/nullptr,
1041 /*modular_format=*/nullptr, /*nobuiltins=*/nullptr,
1042 /*allocsize=*/nullptr, /*optsize=*/nullptr, /*minsize=*/nullptr,
1043 /*builtin=*/nullptr, /*nobuiltin=*/nullptr,
1044 /*save_reg_params=*/nullptr,
1045 /*zero_call_used_regs=*/nullptr, /*trap_func_name=*/nullptr,
1046 /*default_func_attrs=*/nullptr,
1047 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
1048 /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
1049 /*access_groups=*/nullptr,
1050 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
1051 /*no_inline=*/nullptr, /*always_inline=*/nullptr,
1052 /*inline_hint=*/nullptr);
1053}
1054
1055void CallOp::build(OpBuilder &builder, OperationState &state,
1056 LLVMFunctionType calleeType, ValueRange args) {
1057 build(builder, state, getCallOpResultTypes(calleeType),
1058 getCallOpVarCalleeType(calleeType),
1059 /*callee=*/nullptr, args,
1060 /*fastmathFlags=*/nullptr,
1061 /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
1062 /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
1063 /*noreturn=*/nullptr,
1064 /*returns_twice=*/nullptr, /*hot=*/nullptr,
1065 /*cold=*/nullptr, /*noduplicate=*/nullptr,
1066 /*no_caller_saved_registers=*/nullptr, /*nocallback=*/nullptr,
1067 /*modular_format=*/nullptr, /*nobuiltins=*/nullptr,
1068 /*allocsize=*/nullptr, /*optsize=*/nullptr, /*minsize=*/nullptr,
1069 /*builtin=*/nullptr, /*nobuiltin=*/nullptr,
1070 /*save_reg_params=*/nullptr,
1071 /*zero_call_used_regs=*/nullptr, /*trap_func_name=*/nullptr,
1072 /*default_func_attrs=*/nullptr,
1073 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
1074 /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
1075 /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
1076 /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
1077 /*no_inline=*/nullptr, /*always_inline=*/nullptr,
1078 /*inline_hint=*/nullptr);
1079}
1080
1081void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
1082 ValueRange args) {
1083 auto calleeType = func.getFunctionType();
1084 build(builder, state, getCallOpResultTypes(calleeType),
1085 getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), args,
1086 /*fastmathFlags=*/nullptr,
1087 /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
1088 /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
1089 /*noreturn=*/nullptr,
1090 /*returns_twice=*/nullptr, /*hot=*/nullptr,
1091 /*cold=*/nullptr, /*noduplicate=*/nullptr,
1092 /*no_caller_saved_registers=*/nullptr, /*nocallback=*/nullptr,
1093 /*modular_format=*/nullptr, /*nobuiltins=*/nullptr,
1094 /*allocsize=*/nullptr, /*optsize=*/nullptr, /*minsize=*/nullptr,
1095 /*builtin=*/nullptr, /*nobuiltin=*/nullptr,
1096 /*save_reg_params=*/nullptr,
1097 /*zero_call_used_regs=*/nullptr, /*trap_func_name=*/nullptr,
1098 /*default_func_attrs=*/nullptr,
1099 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
1100 /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
1101 /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
1102 /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
1103 /*no_inline=*/nullptr, /*always_inline=*/nullptr,
1104 /*inline_hint=*/nullptr);
1105}
1106
1107CallInterfaceCallable CallOp::getCallableForCallee() {
1108 // Direct call.
1109 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr())
1110 return calleeAttr;
1111 // Indirect call, callee Value is the first operand.
1112 return getOperand(0);
1113}
1114
1115void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1116 // Direct call.
1117 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) {
1118 auto symRef = cast<SymbolRefAttr>(callee);
1119 return setCalleeAttr(cast<FlatSymbolRefAttr>(symRef));
1120 }
1121 // Indirect call, callee Value is the first operand.
1122 return setOperand(0, cast<Value>(callee));
1123}
1124
1125Operation::operand_range CallOp::getArgOperands() {
1126 return getCalleeOperands().drop_front(getCallee().has_value() ? 0 : 1);
1127}
1128
1129MutableOperandRange CallOp::getArgOperandsMutable() {
1130 return MutableOperandRange(*this, getCallee().has_value() ? 0 : 1,
1131 getCalleeOperands().size());
1132}
1133
1134/// Verify that an inlinable callsite of a debug-info-bearing function in a
1135/// debug-info-bearing function has a debug location attached to it. This
1136/// mirrors an LLVM IR verifier.
1137static LogicalResult verifyCallOpDebugInfo(CallOp callOp, LLVMFuncOp callee) {
1138 if (callee.isExternal())
1139 return success();
1140 auto parentFunc = callOp->getParentOfType<FunctionOpInterface>();
1141 if (!parentFunc)
1142 return success();
1143
1144 auto hasSubprogram = [](Operation *op) {
1145 return op->getLoc()
1146 ->findInstanceOf<FusedLocWith<LLVM::DISubprogramAttr>>() !=
1147 nullptr;
1148 };
1149 if (!hasSubprogram(parentFunc) || !hasSubprogram(callee))
1150 return success();
1151 bool containsLoc = !isa<UnknownLoc>(callOp->getLoc());
1152 if (!containsLoc)
1153 return callOp.emitError()
1154 << "inlinable function call in a function with a DISubprogram "
1155 "location must have a debug location";
1156 return success();
1157}
1158
1159/// Verify that the parameter and return types of the variadic callee type match
1160/// the `callOp` argument and result types.
1161template <typename OpTy>
1162static LogicalResult verifyCallOpVarCalleeType(OpTy callOp) {
1163 std::optional<LLVMFunctionType> varCalleeType = callOp.getVarCalleeType();
1164 if (!varCalleeType)
1165 return success();
1166
1167 // Verify the variadic callee type is a variadic function type.
1168 if (!varCalleeType->isVarArg())
1169 return callOp.emitOpError(
1170 "expected var_callee_type to be a variadic function type");
1171
1172 // Verify the variadic callee type has at most as many parameters as the call
1173 // has argument operands.
1174 if (varCalleeType->getNumParams() > callOp.getArgOperands().size())
1175 return callOp.emitOpError("expected var_callee_type to have at most ")
1176 << callOp.getArgOperands().size() << " parameters";
1177
1178 // Verify the variadic callee type matches the call argument types.
1179 for (auto [paramType, operand] :
1180 llvm::zip(varCalleeType->getParams(), callOp.getArgOperands()))
1181 if (paramType != operand.getType())
1182 return callOp.emitOpError()
1183 << "var_callee_type parameter type mismatch: " << paramType
1184 << " != " << operand.getType();
1185
1186 // Verify the variadic callee type matches the call result type.
1187 if (!callOp.getNumResults()) {
1188 if (!isa<LLVMVoidType>(varCalleeType->getReturnType()))
1189 return callOp.emitOpError("expected var_callee_type to return void");
1190 } else {
1191 if (callOp.getResult().getType() != varCalleeType->getReturnType())
1192 return callOp.emitOpError("var_callee_type return type mismatch: ")
1193 << varCalleeType->getReturnType()
1194 << " != " << callOp.getResult().getType();
1195 }
1196 return success();
1197}
1198
1199template <typename OpType>
1200static LogicalResult verifyOperandBundles(OpType &op) {
1201 OperandRangeRange opBundleOperands = op.getOpBundleOperands();
1202 std::optional<ArrayAttr> opBundleTags = op.getOpBundleTags();
1203
1204 auto isStringAttr = [](Attribute tagAttr) {
1205 return isa<StringAttr>(tagAttr);
1206 };
1207 if (opBundleTags && !llvm::all_of(*opBundleTags, isStringAttr))
1208 return op.emitError("operand bundle tag must be a StringAttr");
1209
1210 size_t numOpBundles = opBundleOperands.size();
1211 size_t numOpBundleTags = opBundleTags ? opBundleTags->size() : 0;
1212 if (numOpBundles != numOpBundleTags)
1213 return op.emitError("expected ")
1214 << numOpBundles << " operand bundle tags, but actually got "
1215 << numOpBundleTags;
1216
1217 return success();
1218}
1219
1220LogicalResult CallOp::verify() { return verifyOperandBundles(*this); }
1221
/// Verifies the call against its callee.
///
/// For indirect calls, only checks that a pointer-typed callee operand
/// exists. For direct calls, the callee symbol must resolve to an LLVMFuncOp,
/// IFuncOp or AliasOp. For an alias, the expected type is taken from the call
/// itself. The callee's function type is then checked against the call's
/// argument operands and result.
// NOTE(review): the line preceding `return failure();` below is missing from
// this listing; confirm what precondition it checks against the full source.
1222LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1224 return failure();
1225
1226 // Type for the callee, we'll get it differently depending if it is a direct
1227 // or indirect call.
1228 Type fnType;
1229
1230 bool isIndirect = false;
1231
1232 // If this is an indirect call, the callee attribute is missing.
1233 FlatSymbolRefAttr calleeName = getCalleeAttr();
1234 if (!calleeName) {
1235 isIndirect = true;
1236 if (!getNumOperands())
1237 return emitOpError(
1238 "must have either a `callee` attribute or at least an operand");
1239 auto ptrType = llvm::dyn_cast<LLVMPointerType>(getOperand(0).getType());
1240 if (!ptrType)
1241 return emitOpError("indirect call expects a pointer as callee: ")
1242 << getOperand(0).getType();
1243
// NOTE(review): indirect calls return here, so none of the type checks below
// run for them, and `isIndirect` is always false when those checks execute.
// The `- isIndirect` / `+ isIndirect` adjustments below are therefore dead.
1244 return success();
1245 } else {
1246 Operation *callee =
1247 symbolTable.lookupNearestSymbolFrom(*this, calleeName.getAttr());
1248 if (!callee)
1249 return emitOpError()
1250 << "'" << calleeName.getValue()
1251 << "' does not reference a symbol in the current scope";
1252 if (auto fn = dyn_cast<LLVMFuncOp>(callee)) {
1253 if (failed(verifyCallOpDebugInfo(*this, fn)))
1254 return failure();
1255 fnType = fn.getFunctionType();
1256 } else if (auto ifunc = dyn_cast<IFuncOp>(callee)) {
1257 fnType = ifunc.getIFuncType();
1258 } else if (isa<AliasOp>(callee)) {
1259 // Aliases can alias functions, so calling through an alias is valid.
1260 // The function type is determined by the call's operands and result
1261 // types.
1262 fnType = getCalleeFunctionType();
1263 } else {
1264 return emitOpError()
1265 << "'" << calleeName.getValue()
1266 << "' does not reference a valid LLVM function, IFunc, or alias";
1267 }
1268 }
1269
1270 LLVMFunctionType funcType = llvm::dyn_cast<LLVMFunctionType>(fnType);
1271 if (!funcType)
1272 return emitOpError("callee does not have a functional type: ") << fnType;
1273
1274 if (funcType.isVarArg() && !getVarCalleeType())
1275 return emitOpError() << "missing var_callee_type attribute for vararg call";
1276
1277 // Verify that the operand and result types match the callee.
1278
// Non-variadic callees need an exact argument count.
1279 if (!funcType.isVarArg() &&
1280 funcType.getNumParams() != (getCalleeOperands().size() - isIndirect))
1281 return emitOpError() << "incorrect number of operands ("
1282 << (getCalleeOperands().size() - isIndirect)
1283 << ") for callee (expecting: "
1284 << funcType.getNumParams() << ")";
1285
// After the exact-count check above, this can only fire for variadic
// callees, which need at least their fixed parameters.
1286 if (funcType.getNumParams() > (getCalleeOperands().size() - isIndirect))
1287 return emitOpError() << "incorrect number of operands ("
1288 << (getCalleeOperands().size() - isIndirect)
1289 << ") for varargs callee (expecting at least: "
1290 << funcType.getNumParams() << ")";
1291
// Compare fixed parameters positionally; callee operands precede all other
// operands, so `getOperand` indexes into them directly.
1292 for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i)
1293 if (getOperand(i + isIndirect).getType() != funcType.getParamType(i))
1294 return emitOpError() << "operand type mismatch for operand " << i << ": "
1295 << getOperand(i + isIndirect).getType()
1296 << " != " << funcType.getParamType(i);
1297
1298 if (getNumResults() == 0 &&
1299 !llvm::isa<LLVM::LLVMVoidType>(funcType.getReturnType()))
1300 return emitOpError() << "expected function call to produce a value";
1301
1302 if (getNumResults() != 0 &&
1303 llvm::isa<LLVM::LLVMVoidType>(funcType.getReturnType()))
1304 return emitOpError()
1305 << "calling function with void result must not produce values";
1306
1307 if (getNumResults() > 1)
1308 return emitOpError()
1309 << "expected LLVM function call to produce 0 or 1 result";
1310
1311 if (getNumResults() && getResult().getType() != funcType.getReturnType())
1312 return emitOpError() << "result type mismatch: " << getResult().getType()
1313 << " != " << funcType.getReturnType();
1314
1315 return success();
1316}
1317
/// Prints the custom assembly form. The order is: optional non-C calling
/// convention, optional tail-call kind, callee symbol or function-pointer
/// operand, argument list, optional `vararg(...)` type, optional operand
/// bundles, remaining attributes, and the trailing type signature.
// NOTE(review): the line that opens the trailing function-signature print call
// is missing from this listing.
1318void CallOp::print(OpAsmPrinter &p) {
1319 auto callee = getCallee();
1320 bool isDirect = callee.has_value();
1321
1322 p << ' ';
1323
1324 // Print calling convention.
1325 if (getCConv() != LLVM::CConv::C)
1326 p << stringifyCConv(getCConv()) << ' ';
1327
1328 if (getTailCallKind() != LLVM::TailCallKind::None)
1329 p << tailcallkind::stringifyTailCallKind(getTailCallKind()) << ' ';
1330
1331 // Print the direct callee if present as a function attribute, or an indirect
1332 // callee (first operand) otherwise.
1333 if (isDirect)
1334 p.printSymbolName(callee.value());
1335 else
1336 p << getOperand(0);
1337
// For indirect calls, skip the function pointer printed above.
1338 auto args = getCalleeOperands().drop_front(isDirect ? 0 : 1);
1339 p << '(' << args << ')';
1340
1341 // Print the variadic callee type if the call is variadic.
1342 if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1343 p << " vararg(" << *varCalleeType << ")";
1344
1345 if (!getOpBundleOperands().empty()) {
1346 p << " ";
1347 printOpBundles(p, *this, getOpBundleOperands(),
1348 getOpBundleOperands().getTypes(), getOpBundleTags());
1349 }
1350
// Attributes already expressed by the custom syntax are elided.
// processFMFAttr drops a fastmathFlags attribute equal to the default.
1351 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
1352 {getCalleeAttrName(), getTailCallKindAttrName(),
1353 getVarCalleeTypeAttrName(), getCConvAttrName(),
1354 getOperandSegmentSizesAttrName(),
1355 getOpBundleSizesAttrName(),
1356 getOpBundleTagsAttrName(), getArgAttrsAttrName(),
1357 getResAttrsAttrName()});
1358
1359 p << " : ";
1360 if (!isDirect)
1361 p << getOperand(0).getType() << ", ";
1362
1363 // Reconstruct the MLIR function type from operand and result types.
1365 p, args.getTypes(), getArgAttrsAttr(),
1366 /*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
1367}
1368
1369/// Parses the type of a call operation and resolves the operands if the parsing
1370/// succeeds. Returns failure otherwise.
/// Indirect calls (`isDirect == false`) expect `<callee-ptr-type>, ` before the
/// function signature. At most one result is accepted, and it must not be
/// `void`.
// NOTE(review): the first signature line(s) are missing from this listing.
1372 OpAsmParser &parser, OperationState &result, bool isDirect,
1375 SmallVectorImpl<DictionaryAttr> &resultAttrs) {
1376 SMLoc trailingTypesLoc = parser.getCurrentLocation();
1377 SmallVector<Type> types;
1378 if (parser.parseColon())
1379 return failure();
1380 if (!isDirect) {
1381 types.emplace_back();
1382 if (parser.parseType(types.back()))
1383 return failure();
1384 if (parser.parseOptionalComma())
1385 return parser.emitError(
1386 trailingTypesLoc, "expected indirect call to have 2 trailing types");
1387 }
1388 SmallVector<Type> argTypes;
1389 SmallVector<Type> resTypes;
1390 if (call_interface_impl::parseFunctionSignature(parser, argTypes, argAttrs,
1391 resTypes, resultAttrs)) {
1392 if (isDirect)
1393 return parser.emitError(trailingTypesLoc,
1394 "expected direct call to have 1 trailing types");
1395 return parser.emitError(trailingTypesLoc,
1396 "expected trailing function type");
1397 }
1398
1399 if (resTypes.size() > 1)
1400 return parser.emitError(trailingTypesLoc,
1401 "expected function with 0 or 1 result");
1402 if (resTypes.size() == 1 && llvm::isa<LLVM::LLVMVoidType>(resTypes[0]))
1403 return parser.emitError(trailingTypesLoc,
1404 "expected a non-void result type");
1405
1406 // The head element of the types list matches the callee type for
1407 // indirect calls, while the types list is empty for direct calls.
1408 // Append the function input types to resolve the call operation
1409 // operands.
1410 llvm::append_range(types, argTypes);
1411 if (parser.resolveOperands(operands, types, parser.getNameLoc(),
1412 result.operands))
1413 return failure();
1414 if (!resTypes.empty())
1415 result.addTypes(resTypes);
1416
1417 return success();
1418}
1419
1420/// Parses an optional function pointer operand before the call argument list
1421/// for indirect calls, or stops parsing at the function identifier otherwise.
/// On success, `operands` has gained the function-pointer operand if one was
/// present. It is unchanged when the next token is not an operand. A malformed
/// operand is reported as a failure.
// NOTE(review): the `operands` parameter line is missing from this listing.
1422static ParseResult parseOptionalCallFuncPtr(
1423 OpAsmParser &parser,
1425 OpAsmParser::UnresolvedOperand funcPtrOperand;
1426 OptionalParseResult parseResult = parser.parseOptionalOperand(funcPtrOperand);
// has_value() means an operand was attempted; failure means it was malformed.
1427 if (parseResult.has_value()) {
1428 if (failed(*parseResult))
1429 return *parseResult;
1430 operands.push_back(funcPtrOperand);
1431 }
1432 return success();
1433}
1434
/// Resolves each operand bundle's operands against its parsed types, appending
/// them to `state.operands`. Then records the per-bundle operand counts as a
/// DenseI32ArrayAttr under `opBundleSizesAttrName`. A bundle whose operand and
/// type counts differ is reported at `loc`.
// NOTE(review): the `opBundleOperands` parameter line is missing from this
// listing.
1435static ParseResult resolveOpBundleOperands(
1436 OpAsmParser &parser, SMLoc loc, OperationState &state,
1438 ArrayRef<SmallVector<Type>> opBundleOperandTypes,
1439 StringAttr opBundleSizesAttrName) {
// NOTE(review): `opBundleIndex` is never incremented inside the loop below,
// so the diagnostic always reports "operand bundle #0" regardless of which
// bundle is malformed. It should be bumped at the end of each iteration.
1440 unsigned opBundleIndex = 0;
1441 for (const auto &[operands, types] :
1442 llvm::zip_equal(opBundleOperands, opBundleOperandTypes)) {
1443 if (operands.size() != types.size())
1444 return parser.emitError(loc, "expected ")
1445 << operands.size()
1446 << " types for operand bundle operands for operand bundle #"
1447 << opBundleIndex << ", but actually got " << types.size();
1448 if (parser.resolveOperands(operands, types, loc, state.operands))
1449 return failure();
1450 }
1451
// One size entry per bundle; used to split the bundle operand segment.
1452 SmallVector<int32_t> opBundleSizes;
1453 opBundleSizes.reserve(opBundleOperands.size());
1454 for (const auto &operands : opBundleOperands)
1455 opBundleSizes.push_back(operands.size());
1456
1457 state.addAttribute(
1458 opBundleSizesAttrName,
1459 DenseI32ArrayAttr::get(parser.getContext(), opBundleSizes));
1460
1461 return success();
1462}
1463
1464// <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use)
1465// `(` ssa-use-list `)`
1466// ( `vararg(` var-callee-type `)` )?
1467// ( `[` op-bundles-list `]` )?
1468// attribute-dict? `:` (type `,`)? function-type
// Parses the custom form above. Missing calling-convention and tail-call
// keywords default to C and None, and both attributes are always added.
// NOTE(review): several local declarations and continuation lines are missing
// from this listing.
1469ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
1470 SymbolRefAttr funcAttr;
1471 TypeAttr varCalleeType;
1474 SmallVector<SmallVector<Type>> opBundleOperandTypes;
1475 ArrayAttr opBundleTags;
1476
1477 // Default to C Calling Convention if no keyword is provided.
1478 result.addAttribute(
1479 getCConvAttrName(result.name),
1480 CConvAttr::get(parser.getContext(),
1481 parseOptionalLLVMKeyword<CConv>(parser, LLVM::CConv::C)));
1482
1483 result.addAttribute(
1484 getTailCallKindAttrName(result.name),
1485 TailCallKindAttr::get(parser.getContext(),
1487 parser, LLVM::TailCallKind::None)));
1488
1489 // Parse a function pointer for indirect calls.
1490 if (parseOptionalCallFuncPtr(parser, operands))
1491 return failure();
// The call is direct exactly when no function-pointer operand was parsed.
1492 bool isDirect = operands.empty();
1493
1494 // Parse a function identifier for direct calls.
1495 if (isDirect)
1496 if (parser.parseAttribute(funcAttr, "callee", result.attributes))
1497 return failure();
1498
1499 // Parse the function arguments.
1500 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren))
1501 return failure();
1502
1503 bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
1504 if (isVarArg) {
1505 StringAttr varCalleeTypeAttrName =
1506 CallOp::getVarCalleeTypeAttrName(result.name);
1507 if (parser.parseLParen().failed() ||
1508 parser
1509 .parseAttribute(varCalleeType, varCalleeTypeAttrName,
1510 result.attributes)
1511 .failed() ||
1512 parser.parseRParen().failed())
1513 return failure();
1514 }
1515
1516 SMLoc opBundlesLoc = parser.getCurrentLocation();
// NOTE(review): the `result` declared in this if-init shadows the
// OperationState parameter; it is scoped to the `if`, so the
// `result.addAttribute` just below still refers to the OperationState.
1517 if (std::optional<ParseResult> result = parseOpBundles(
1518 parser, opBundleOperands, opBundleOperandTypes, opBundleTags);
1519 result && failed(*result))
1520 return failure();
1521 if (opBundleTags && !opBundleTags.empty())
1522 result.addAttribute(CallOp::getOpBundleTagsAttrName(result.name).getValue(),
1523 opBundleTags);
1524
1525 if (parser.parseOptionalAttrDict(result.attributes))
1526 return failure();
1527
1528 // Parse the trailing type list and resolve the operands.
1530 SmallVector<DictionaryAttr> resultAttrs;
1531 if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
1532 argAttrs, resultAttrs))
1533 return failure();
1535 parser.getBuilder(), result, argAttrs, resultAttrs,
1536 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
// Bundle operands are resolved after the callee operands, so they follow
// them in the operand list.
1537 if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
1538 opBundleOperandTypes,
1539 getOpBundleSizesAttrName(result.name)))
1540 return failure();
1541
// NOTE(review): the loop variable `operands` shadows the callee-operand list
// declared above; the `operands.size()` used for the segment sizes below
// refers to that outer list again.
1542 int32_t numOpBundleOperands = 0;
1543 for (const auto &operands : opBundleOperands)
1544 numOpBundleOperands += operands.size();
1545
// Two operand segments: callee operands (incl. any function pointer), then
// all bundle operands.
1546 result.addAttribute(
1547 CallOp::getOperandSegmentSizeAttr(),
1549 {static_cast<int32_t>(operands.size()), numOpBundleOperands}));
1550 return success();
1551}
1552
1553LLVMFunctionType CallOp::getCalleeFunctionType() {
1554 if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1555 return *varCalleeType;
1556 return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
1557}
1558
1559///===---------------------------------------------------------------------===//
1560/// LLVM::InvokeOp
1561///===---------------------------------------------------------------------===//
1562
1563void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
1564 ValueRange ops, Block *normal, ValueRange normalOps,
1565 Block *unwind, ValueRange unwindOps) {
1566 auto calleeType = func.getFunctionType();
1567 build(builder, state, getCallOpResultTypes(calleeType),
1568 getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), ops,
1569 /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, normalOps, unwindOps,
1570 nullptr, nullptr, {}, {}, normal, unwind);
1571}
1572
1573void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
1574 FlatSymbolRefAttr callee, ValueRange ops, Block *normal,
1575 ValueRange normalOps, Block *unwind,
1576 ValueRange unwindOps) {
1577 build(builder, state, tys,
1578 /*var_callee_type=*/nullptr, callee, ops, /*arg_attrs=*/nullptr,
1579 /*res_attrs=*/nullptr, normalOps, unwindOps, nullptr, nullptr, {}, {},
1580 normal, unwind);
1581}
1582
1583void InvokeOp::build(OpBuilder &builder, OperationState &state,
1584 LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
1585 ValueRange ops, Block *normal, ValueRange normalOps,
1586 Block *unwind, ValueRange unwindOps) {
1587 build(builder, state, getCallOpResultTypes(calleeType),
1588 getCallOpVarCalleeType(calleeType), callee, ops,
1589 /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, normalOps, unwindOps,
1590 nullptr, nullptr, {}, {}, normal, unwind);
1591}
1592
1593SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
1594 assert(index < getNumSuccessors() && "invalid successor index");
1595 return SuccessorOperands(index == 0 ? getNormalDestOperandsMutable()
1596 : getUnwindDestOperandsMutable());
1597}
1598
1599CallInterfaceCallable InvokeOp::getCallableForCallee() {
1600 // Direct call.
1601 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr())
1602 return calleeAttr;
1603 // Indirect call, callee Value is the first operand.
1604 return getOperand(0);
1605}
1606
1607void InvokeOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1608 // Direct call.
1609 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) {
1610 auto symRef = cast<SymbolRefAttr>(callee);
1611 return setCalleeAttr(cast<FlatSymbolRefAttr>(symRef));
1612 }
1613 // Indirect call, callee Value is the first operand.
1614 return setOperand(0, cast<Value>(callee));
1615}
1616
1617Operation::operand_range InvokeOp::getArgOperands() {
1618 return getCalleeOperands().drop_front(getCallee().has_value() ? 0 : 1);
1619}
1620
1621MutableOperandRange InvokeOp::getArgOperandsMutable() {
1622 return MutableOperandRange(*this, getCallee().has_value() ? 0 : 1,
1623 getCalleeOperands().size());
1624}
1625
/// Verifies an invoke. The unwind destination must be non-empty and start
/// with an llvm.landingpad, and the operand bundles must be well-formed.
// NOTE(review): the line preceding the first `return failure();` is missing
// from this listing; confirm which precondition it checks.
1626LogicalResult InvokeOp::verify() {
1628 return failure();
1629
1630 Block *unwindDest = getUnwindDest();
1631 if (unwindDest->empty())
1632 return emitError("must have at least one operation in unwind destination");
1633
1634 // In unwind destination, first operation must be LandingpadOp
1635 if (!isa<LandingpadOp>(unwindDest->front()))
1636 return emitError("first operation in unwind destination should be a "
1637 "llvm.landingpad operation");
1638
1639 if (failed(verifyOperandBundles(*this)))
1640 return failure();
1641
1642 return success();
1643}
1644
/// Prints the custom assembly form. The order is: optional non-C calling
/// convention, callee symbol or function-pointer operand, argument list,
/// `to` normal successor, `unwind` successor, optional `vararg(...)` type,
/// optional operand bundles, remaining attributes, and the trailing type
/// signature.
// NOTE(review): the line that opens the trailing function-signature print call
// is missing from this listing.
1645void InvokeOp::print(OpAsmPrinter &p) {
1646 auto callee = getCallee();
1647 bool isDirect = callee.has_value();
1648
1649 p << ' ';
1650
1651 // Print calling convention.
1652 if (getCConv() != LLVM::CConv::C)
1653 p << stringifyCConv(getCConv()) << ' ';
1654
1655 // Either function name or pointer
1656 if (isDirect)
1657 p.printSymbolName(callee.value());
1658 else
1659 p << getOperand(0);
1660
// For indirect invokes, skip the function pointer printed above.
1661 p << '(' << getCalleeOperands().drop_front(isDirect ? 0 : 1) << ')';
1662 p << " to ";
1663 p.printSuccessorAndUseList(getNormalDest(), getNormalDestOperands());
1664 p << " unwind ";
1665 p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands());
1666
1667 // Print the variadic callee type if the invoke is variadic.
1668 if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1669 p << " vararg(" << *varCalleeType << ")";
1670
1671 if (!getOpBundleOperands().empty()) {
1672 p << " ";
1673 printOpBundles(p, *this, getOpBundleOperands(),
1674 getOpBundleOperands().getTypes(), getOpBundleTags());
1675 }
1676
// Attributes already expressed by the custom syntax are elided.
1677 p.printOptionalAttrDict((*this)->getAttrs(),
1678 {getCalleeAttrName(), getOperandSegmentSizeAttr(),
1679 getCConvAttrName(), getVarCalleeTypeAttrName(),
1680 getOpBundleSizesAttrName(),
1681 getOpBundleTagsAttrName(), getArgAttrsAttrName(),
1682 getResAttrsAttrName()});
1683
1684 p << " : ";
1685 if (!isDirect)
1686 p << getOperand(0).getType() << ", ";
1688 p, getCalleeOperands().drop_front(isDirect ? 0 : 1).getTypes(),
1689 getArgAttrsAttr(),
1690 /*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
1691}
1692
1693// <operation> ::= `llvm.invoke` (cconv)? (function-id | ssa-use)
1694// `(` ssa-use-list `)`
1695// `to` bb-id (`[` ssa-use-and-type-list `]`)?
1696// `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
1697// ( `vararg(` var-callee-type `)` )?
1698// ( `[` op-bundles-list `]` )?
1699// attribute-dict? `:` (type `,`)?
1700// function-type-with-argument-attributes
1701ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
1703 SymbolRefAttr funcAttr;
1704 TypeAttr varCalleeType;
1706 SmallVector<SmallVector<Type>> opBundleOperandTypes;
1707 ArrayAttr opBundleTags;
1708 Block *normalDest, *unwindDest;
1709 SmallVector<Value, 4> normalOperands, unwindOperands;
1710 Builder &builder = parser.getBuilder();
1711
1712 // Default to C Calling Convention if no keyword is provided.
1713 result.addAttribute(
1714 getCConvAttrName(result.name),
1715 CConvAttr::get(parser.getContext(),
1716 parseOptionalLLVMKeyword<CConv>(parser, LLVM::CConv::C)));
1717
1718 // Parse a function pointer for indirect calls.
1719 if (parseOptionalCallFuncPtr(parser, operands))
1720 return failure();
1721 bool isDirect = operands.empty();
1722
1723 // Parse a function identifier for direct calls.
1724 if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes))
1725 return failure();
1726
1727 // Parse the function arguments.
1728 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
1729 parser.parseKeyword("to") ||
1730 parser.parseSuccessorAndUseList(normalDest, normalOperands) ||
1731 parser.parseKeyword("unwind") ||
1732 parser.parseSuccessorAndUseList(unwindDest, unwindOperands))
1733 return failure();
1734
1735 bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
1736 if (isVarArg) {
1737 StringAttr varCalleeTypeAttrName =
1738 InvokeOp::getVarCalleeTypeAttrName(result.name);
1739 if (parser.parseLParen().failed() ||
1740 parser
1741 .parseAttribute(varCalleeType, varCalleeTypeAttrName,
1742 result.attributes)
1743 .failed() ||
1744 parser.parseRParen().failed())
1745 return failure();
1746 }
1747
1748 SMLoc opBundlesLoc = parser.getCurrentLocation();
1749 if (std::optional<ParseResult> result = parseOpBundles(
1750 parser, opBundleOperands, opBundleOperandTypes, opBundleTags);
1751 result && failed(*result))
1752 return failure();
1753 if (opBundleTags && !opBundleTags.empty())
1754 result.addAttribute(
1755 InvokeOp::getOpBundleTagsAttrName(result.name).getValue(),
1756 opBundleTags);
1757
1758 if (parser.parseOptionalAttrDict(result.attributes))
1759 return failure();
1760
1761 // Parse the trailing type list and resolve the function operands.
1763 SmallVector<DictionaryAttr> resultAttrs;
1764 if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
1765 argAttrs, resultAttrs))
1766 return failure();
1768 parser.getBuilder(), result, argAttrs, resultAttrs,
1769 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
1770
1771 if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
1772 opBundleOperandTypes,
1773 getOpBundleSizesAttrName(result.name)))
1774 return failure();
1775
1776 result.addSuccessors({normalDest, unwindDest});
1777 result.addOperands(normalOperands);
1778 result.addOperands(unwindOperands);
1779
1780 int32_t numOpBundleOperands = 0;
1781 for (const auto &operands : opBundleOperands)
1782 numOpBundleOperands += operands.size();
1783
1784 result.addAttribute(
1785 InvokeOp::getOperandSegmentSizeAttr(),
1786 builder.getDenseI32ArrayAttr({static_cast<int32_t>(operands.size()),
1787 static_cast<int32_t>(normalOperands.size()),
1788 static_cast<int32_t>(unwindOperands.size()),
1789 numOpBundleOperands}));
1790 return success();
1791}
1792
1793LLVMFunctionType InvokeOp::getCalleeFunctionType() {
1794 if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1795 return *varCalleeType;
1796 return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
1797}
1798
//===----------------------------------------------------------------------===//
// Verifying/Printing/Parsing for LLVM::LandingpadOp.
//===----------------------------------------------------------------------===//
1802
1803LogicalResult LandingpadOp::verify() {
1804 Value value;
1805 if (LLVMFuncOp func = (*this)->getParentOfType<LLVMFuncOp>()) {
1806 if (!func.getPersonality())
1807 return emitError(
1808 "llvm.landingpad needs to be in a function with a personality");
1809 }
1810
1811 // Consistency of llvm.landingpad result types is checked in
1812 // LLVMFuncOp::verify().
1813
1814 if (!getCleanup() && getOperands().empty())
1815 return emitError("landingpad instruction expects at least one clause or "
1816 "cleanup attribute");
1817
1818 for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) {
1819 value = getOperand(idx);
1820 bool isFilter = llvm::isa<LLVMArrayType>(value.getType());
1821 if (isFilter) {
1822 // FIXME: Verify filter clauses when arrays are appropriately handled
1823 } else {
1824 // catch - global addresses only.
1825 // Bitcast ops should have global addresses as their args.
1826 if (auto bcOp = value.getDefiningOp<BitcastOp>()) {
1827 if (auto addrOp = bcOp.getArg().getDefiningOp<AddressOfOp>())
1828 continue;
1829 return emitError("constant clauses expected").attachNote(bcOp.getLoc())
1830 << "global addresses expected as operand to "
1831 "bitcast used in clauses for landingpad";
1832 }
1833 // ZeroOp and AddressOfOp allowed
1834 if (value.getDefiningOp<ZeroOp>())
1835 continue;
1836 if (value.getDefiningOp<AddressOfOp>())
1837 continue;
1838 return emitError("clause #")
1839 << idx << " is not a known constant - null, addressof, bitcast";
1840 }
1841 }
1842 return success();
1843}
1844
1845void LandingpadOp::print(OpAsmPrinter &p) {
1846 p << (getCleanup() ? " cleanup " : " ");
1847
1848 // Clauses
1849 for (auto value : getOperands()) {
1850 // Similar to llvm - if clause is an array type then it is filter
1851 // clause else catch clause
1852 bool isArrayTy = llvm::isa<LLVMArrayType>(value.getType());
1853 p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : "
1854 << value.getType() << ") ";
1855 }
1856
1857 p.printOptionalAttrDict((*this)->getAttrs(), {"cleanup"});
1858
1859 p << ": " << getType();
1860}
1861
1862// <operation> ::= `llvm.landingpad` `cleanup`?
1863// ((`catch` | `filter`) operand-type ssa-use)* attribute-dict?
1864ParseResult LandingpadOp::parse(OpAsmParser &parser, OperationState &result) {
1865 // Check for cleanup
1866 if (succeeded(parser.parseOptionalKeyword("cleanup")))
1867 result.addAttribute("cleanup", parser.getBuilder().getUnitAttr());
1868
1869 // Parse clauses with types
1870 while (succeeded(parser.parseOptionalLParen()) &&
1871 (succeeded(parser.parseOptionalKeyword("filter")) ||
1872 succeeded(parser.parseOptionalKeyword("catch")))) {
1874 Type ty;
1875 if (parser.parseOperand(operand) || parser.parseColon() ||
1876 parser.parseType(ty) ||
1877 parser.resolveOperand(operand, ty, result.operands) ||
1878 parser.parseRParen())
1879 return failure();
1880 }
1881
1882 Type type;
1883 if (parser.parseColon() || parser.parseType(type))
1884 return failure();
1885
1886 result.addTypes(type);
1887 return success();
1888}
1889
1890//===----------------------------------------------------------------------===//
1891// ExtractValueOp
1892//===----------------------------------------------------------------------===//
1893
1894/// Extract the type at `position` in the LLVM IR aggregate type
1895/// `containerType`. Each element of `position` is an index into a nested
1896/// aggregate type. Return the resulting type or emit an error.
1898 function_ref<InFlightDiagnostic(StringRef)> emitError, Type containerType,
1899 ArrayRef<int64_t> position) {
1900 Type llvmType = containerType;
1901 if (!isCompatibleType(containerType)) {
1902 emitError("expected LLVM IR Dialect type, got ") << containerType;
1903 return {};
1904 }
1905
1906 // Infer the element type from the structure type: iteratively step inside the
1907 // type by taking the element type, indexed by the position attribute for
1908 // structures. Check the position index before accessing, it is supposed to
1909 // be in bounds.
1910 for (int64_t idx : position) {
1911 if (auto arrayType = llvm::dyn_cast<LLVMArrayType>(llvmType)) {
1912 if (idx < 0 || static_cast<unsigned>(idx) >= arrayType.getNumElements()) {
1913 emitError("position out of bounds: ") << idx;
1914 return {};
1915 }
1916 llvmType = arrayType.getElementType();
1917 } else if (auto structType = llvm::dyn_cast<LLVMStructType>(llvmType)) {
1918 if (idx < 0 ||
1919 static_cast<unsigned>(idx) >= structType.getBody().size()) {
1920 emitError("position out of bounds: ") << idx;
1921 return {};
1922 }
1923 llvmType = structType.getBody()[idx];
1924 } else {
1925 emitError("expected LLVM IR structure/array type, got: ") << llvmType;
1926 return {};
1927 }
1928 }
1929 return llvmType;
1930}
1931
1932/// Extract the type at `position` in the wrapped LLVM IR aggregate type
1933/// `containerType`.
1935 ArrayRef<int64_t> position) {
1936 for (int64_t idx : position) {
1937 if (auto structType = llvm::dyn_cast<LLVMStructType>(llvmType))
1938 llvmType = structType.getBody()[idx];
1939 else
1940 llvmType = llvm::cast<LLVMArrayType>(llvmType).getElementType();
1941 }
1942 return llvmType;
1943}
1944
1945/// Extracts the element at the given index from an attribute. For
1946/// `ElementsAttr`, returns the element at the specified index, or `nullptr` if
1947/// the shaped type does not have rank 1. For `ArrayAttr`, returns the element
1948/// at the specified index. For `ZeroAttr`, `UndefAttr`, and `PoisonAttr`,
1949/// returns the attribute itself unchanged. Returns `nullptr` if the attribute
1950/// is not one of these types or if the index is out of bounds.
1952 if (auto elementsAttr = dyn_cast<ElementsAttr>(attr)) {
1953 ShapedType shapedType = elementsAttr.getShapedType();
1954 if (!shapedType.hasRank() || shapedType.getRank() != 1)
1955 return nullptr;
1956 if (index < static_cast<size_t>(elementsAttr.getNumElements()))
1957 return elementsAttr.getValues<Attribute>()[index];
1958 return nullptr;
1959 }
1960 if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
1961 if (index < arrayAttr.getValue().size())
1962 return arrayAttr[index];
1963 return nullptr;
1964 }
1965 if (isa<ZeroAttr, UndefAttr, PoisonAttr>(attr))
1966 return attr;
1967 return nullptr;
1968}
1969
/// Folds this extractvalue. Three strategies are tried in order:
///  1. The container is another extractvalue: concatenate both positions and
///     read directly from the inner op's container (updating this op in
///     place).
///  2. The container is a constant: extract the addressed element from the
///     constant attribute, one nesting level per position entry.
///  3. The container is a chain of insertvalue ops: walk it backwards to find
///     the insertion that defines the extracted element, then either return
///     the inserted value or update this op in place to read from the closest
///     container reached by the walk.
/// Returning this op's own result reports an in-place update; returning null
/// means nothing was folded.
OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
  if (auto extractValueOp = getContainer().getDefiningOp<ExtractValueOp>()) {
    // The inner op's position comes first, followed by this op's position.
    SmallVector<int64_t, 4> newPos(extractValueOp.getPosition());
    newPos.append(getPosition().begin(), getPosition().end());
    setPosition(newPos);
    getContainerMutable().set(extractValueOp.getContainer());
    return getResult();
  }

  Attribute containerAttr;
  if (matchPattern(getContainer(), m_Constant(&containerAttr))) {
    // Peel one aggregate level per position entry; give up as soon as an
    // element cannot be extracted.
    for (int64_t pos : getPosition()) {
      containerAttr = extractElementAt(containerAttr, pos);
      if (!containerAttr)
        return nullptr;
    }
    return containerAttr;
  }

  // `container` / `extractPos` describe where the extracted element currently
  // lives; both are refined as the insertvalue chain is walked.
  Value container = getContainer();
  ArrayRef<int64_t> extractPos = getPosition();
  while (auto insertValueOp = container.getDefiningOp<InsertValueOp>()) {
    ArrayRef<int64_t> insertPos = insertValueOp.getPosition();
    auto extractPosSize = extractPos.size();
    auto insertPosSize = insertPos.size();

    // Case 1: Exact match of positions.
    if (extractPos == insertPos)
      return insertValueOp.getValue();

    // Case 2: Insert position is a prefix of extract position. Continue
    // traversal with the inserted value. Example:
    // ```
    //   %0 = llvm.insertvalue %arg1, %undef[0] : !llvm.struct<(i32, i32, i32)>
    //   %1 = llvm.insertvalue %arg2, %0[1] : !llvm.struct<(i32, i32, i32)>
    //   %2 = llvm.insertvalue %arg3, %1[2] : !llvm.struct<(i32, i32, i32)>
    //   %3 = llvm.insertvalue %2, %foo[0]
    //       : !llvm.struct<(struct<(i32, i32, i32)>, i64)>
    //   %4 = llvm.extractvalue %3[0, 0]
    //       : !llvm.struct<(struct<(i32, i32, i32)>, i64)>
    // ```
    // In the above example, %4 is folded to %arg1.
    if (extractPosSize > insertPosSize &&
        extractPos.take_front(insertPosSize) == insertPos) {
      container = insertValueOp.getValue();
      extractPos = extractPos.drop_front(insertPosSize);
      continue;
    }

    // Case 3: Try to continue the traversal with the container value.

    // If extract position is a prefix of insert position, stop propagating back
    // as it will miss dependencies. For instance, %3 should not fold to %f0 in
    // the following example:
    // ```
    //   %1 = llvm.insertvalue %f0, %0[0, 0] :
    //     !llvm.array<4 x !llvm.array<4 x f32>>
    //   %2 = llvm.insertvalue %arr, %1[0] :
    //     !llvm.array<4 x !llvm.array<4 x f32>>
    //   %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4 x f32>>
    // ```
    if (insertPosSize > extractPosSize &&
        extractPos == insertPos.take_front(extractPosSize))
      break;
    // If neither a prefix, nor the exact position, we can extract out of the
    // value being inserted into. Moreover, we can try again if that operand
    // is itself an insertvalue expression.
    container = insertValueOp.getContainer();
  }

  // We failed to resolve past this container either because it is not an
  // InsertValueOp, or it is an InsertValueOp that partially overlaps with the
  // value being extracted. Update to read from this container instead.
  if (container == getContainer())
    return {};
  setPosition(extractPos);
  getContainerMutable().assign(container);
  return getResult();
}
2049
2050LogicalResult ExtractValueOp::verify() {
2051 auto emitError = [this](StringRef msg) { return emitOpError(msg); };
2053 emitError, getContainer().getType(), getPosition());
2054 if (!valueType)
2055 return failure();
2056
2057 if (getRes().getType() != valueType)
2058 return emitOpError() << "Type mismatch: extracting from "
2059 << getContainer().getType() << " should produce "
2060 << valueType << " but this op returns "
2061 << getRes().getType();
2062 return success();
2063}
2064
2065void ExtractValueOp::build(OpBuilder &builder, OperationState &state,
2066 Value container, ArrayRef<int64_t> position) {
2067 build(builder, state,
2068 getInsertExtractValueElementType(container.getType(), position),
2069 container, builder.getAttr<DenseI64ArrayAttr>(position));
2070}
2071
2072//===----------------------------------------------------------------------===//
2073// InsertValueOp
2074//===----------------------------------------------------------------------===//
2075
2076namespace {
2077/// Update any ExtractValueOps using a given InsertValueOp to instead read from
2078/// the closest InsertValueOp in the chain leading up to the current op that
2079/// writes to the same member. This traversal could be done entirely in
2080/// ExtractValueOp::fold, but doing it here significantly speeds things up
2081/// because we can handle several ExtractValueOps with a single traversal.
2082/// For instance, in this example:
2083/// %i0 = llvm.insertvalue %v0, %undef[0]
2084/// %i1 = llvm.insertvalue %v1, %0[1]
2085/// ...
2086/// %i999 = llvm.insertvalue %v999, %998[999]
2087/// %e0 = llvm.extractvalue %i999[0]
2088/// %e1 = llvm.extractvalue %i999[1]
2089/// ...
2090/// %e999 = llvm.extractvalue %i999[999]
2091/// Individually running the folder on each extractvalue would require
2092/// traversing the insertvalue chain 1000 times, but running this pattern on the
2093/// InsertValueOp would allow us to achieve the same result with a single
2094/// traversal. The resulting IR after this pattern will then be:
2095/// %i0 = llvm.insertvalue %v0, %undef[0]
2096/// %i1 = llvm.insertvalue %v1, %0[1]
2097/// ...
2098/// %i999 = llvm.insertvalue %v999, %998[999]
2099/// %e0 = llvm.extractvalue %i0[0]
2100/// %e1 = llvm.extractvalue %i1[1]
2101/// ...
2102/// %e999 = llvm.extractvalue %i999[999]
2103struct ResolveExtractValueSource : public OpRewritePattern<InsertValueOp> {
2105
2106 LogicalResult matchAndRewrite(InsertValueOp insertOp,
2107 PatternRewriter &rewriter) const override {
2108 bool changed = false;
2109 // Map each position in the top-level struct to the ExtractOps that read
2110 // from it. For the example in the doc-comment above this map will be empty
2111 // when we visit ops %i0 - %i998. For %i999, it will contain:
2112 // 0 -> { %e0 }, 1 -> { %e1 }, ... 999-> { %e999 }
2114 auto insertBaseIdx = insertOp.getPosition()[0];
2115 for (auto &use : insertOp->getUses()) {
2116 if (auto extractOp = dyn_cast<ExtractValueOp>(use.getOwner())) {
2117 auto baseIdx = extractOp.getPosition()[0];
2118 // We can skip reads of the member that insertOp writes to since they
2119 // will not be updated.
2120 if (baseIdx == insertBaseIdx)
2121 continue;
2122 posToExtractOps[baseIdx].push_back(extractOp);
2123 }
2124 }
2125 // Walk up the chain of insertions and try to resolve the remaining
2126 // extractions that access the same member.
2127 Value nextContainer = insertOp.getContainer();
2128 while (!posToExtractOps.empty()) {
2129 auto curInsert =
2130 dyn_cast_or_null<InsertValueOp>(nextContainer.getDefiningOp());
2131 if (!curInsert)
2132 break;
2133 nextContainer = curInsert.getContainer();
2134
2135 // Check if any extractions read the member written by this insertion.
2136 auto curInsertBaseIdx = curInsert.getPosition()[0];
2137 auto it = posToExtractOps.find(curInsertBaseIdx);
2138 if (it == posToExtractOps.end())
2139 continue;
2140
2141 // Update the ExtractOps to read from the current insertion.
2142 for (auto &extractOp : it->second) {
2143 rewriter.modifyOpInPlace(extractOp, [&] {
2144 extractOp.getContainerMutable().assign(curInsert);
2145 });
2146 }
2147 // The entry should never be empty if it exists, so if we are at this
2148 // point, set changed to true.
2149 assert(!it->second.empty());
2150 changed |= true;
2151 posToExtractOps.erase(it);
2152 }
2153 // There was no insertion along the chain that wrote the member accessed by
2154 // these extracts. So we can update them to use the top of the chain.
2155 for (auto &[baseIdx, extracts] : posToExtractOps) {
2156 for (auto &extractOp : extracts) {
2157 rewriter.modifyOpInPlace(extractOp, [&] {
2158 extractOp.getContainerMutable().assign(nextContainer);
2159 });
2160 }
2161 assert(!extracts.empty() && "Empty list in map");
2162 changed = true;
2163 }
2164 return success(changed);
2165 }
2166};
2167} // namespace
2168
2169void InsertValueOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2170 MLIRContext *context) {
2171 patterns.add<ResolveExtractValueSource>(context);
2172}
2173
2174/// Infer the value type from the container type and position.
2175static ParseResult
2177 Type containerType,
2178 DenseI64ArrayAttr position) {
2180 [&](StringRef msg) {
2181 return parser.emitError(parser.getCurrentLocation(), msg);
2182 },
2183 containerType, position.asArrayRef());
2184 return success(!!valueType);
2185}
2186
2187/// Nothing to print for an inferred type.
2189 Operation *op, Type valueType,
2190 Type containerType,
2191 DenseI64ArrayAttr position) {}
2192
2193LogicalResult InsertValueOp::verify() {
2194 auto emitError = [this](StringRef msg) { return emitOpError(msg); };
2196 emitError, getContainer().getType(), getPosition());
2197 if (!valueType)
2198 return failure();
2199
2200 if (getValue().getType() != valueType)
2201 return emitOpError() << "Type mismatch: cannot insert "
2202 << getValue().getType() << " into "
2203 << getContainer().getType();
2204
2205 return success();
2206}
2207
2208//===----------------------------------------------------------------------===//
2209// ReturnOp
2210//===----------------------------------------------------------------------===//
2211
2212LogicalResult ReturnOp::verify() {
2213 auto parent = (*this)->getParentOfType<LLVMFuncOp>();
2214 if (!parent)
2215 return success();
2216
2217 Type expectedType = parent.getFunctionType().getReturnType();
2218 if (llvm::isa<LLVMVoidType>(expectedType)) {
2219 if (!getArg())
2220 return success();
2221 InFlightDiagnostic diag = emitOpError("expected no operands");
2222 diag.attachNote(parent->getLoc()) << "when returning from function";
2223 return diag;
2224 }
2225 if (!getArg()) {
2226 if (llvm::isa<LLVMVoidType>(expectedType))
2227 return success();
2228 InFlightDiagnostic diag = emitOpError("expected 1 operand");
2229 diag.attachNote(parent->getLoc()) << "when returning from function";
2230 return diag;
2231 }
2232 if (expectedType != getArg().getType()) {
2233 InFlightDiagnostic diag = emitOpError("mismatching result types");
2234 diag.attachNote(parent->getLoc()) << "when returning from function";
2235 return diag;
2236 }
2237 return success();
2238}
2239
2240//===----------------------------------------------------------------------===//
2241// LLVM::AddressOfOp.
2242//===----------------------------------------------------------------------===//
2243
2244GlobalOp AddressOfOp::getGlobal(SymbolTableCollection &symbolTable) {
2245 return dyn_cast_or_null<GlobalOp>(
2246 symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
2247}
2248
2249LLVMFuncOp AddressOfOp::getFunction(SymbolTableCollection &symbolTable) {
2250 return dyn_cast_or_null<LLVMFuncOp>(
2251 symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
2252}
2253
2254AliasOp AddressOfOp::getAlias(SymbolTableCollection &symbolTable) {
2255 return dyn_cast_or_null<AliasOp>(
2256 symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
2257}
2258
2259IFuncOp AddressOfOp::getIFunc(SymbolTableCollection &symbolTable) {
2260 return dyn_cast_or_null<IFuncOp>(
2261 symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
2262}
2263
2264LogicalResult
2265AddressOfOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2266 Operation *symbol =
2267 symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr());
2268
2269 auto global = dyn_cast_or_null<GlobalOp>(symbol);
2270 auto function = dyn_cast_or_null<LLVMFuncOp>(symbol);
2271 auto alias = dyn_cast_or_null<AliasOp>(symbol);
2272 auto ifunc = dyn_cast_or_null<IFuncOp>(symbol);
2273
2274 if (!global && !function && !alias && !ifunc)
2275 return emitOpError("must reference a global defined by 'llvm.mlir.global', "
2276 "'llvm.mlir.alias' or 'llvm.func' or 'llvm.mlir.ifunc'");
2277
2278 LLVMPointerType type = getType();
2279 if ((global && global.getAddrSpace() != type.getAddressSpace()) ||
2280 (alias && alias.getAddrSpace() != type.getAddressSpace()))
2281 return emitOpError("pointer address space must match address space of the "
2282 "referenced global or alias");
2283
2284 return success();
2285}
2286
2287// AddressOfOp constant-folds to the global symbol name.
2288OpFoldResult LLVM::AddressOfOp::fold(FoldAdaptor) {
2289 return getGlobalNameAttr();
2290}
2291
2292//===----------------------------------------------------------------------===//
2293// LLVM::DSOLocalEquivalentOp
2294//===----------------------------------------------------------------------===//
2295
2296LLVMFuncOp
2297DSOLocalEquivalentOp::getFunction(SymbolTableCollection &symbolTable) {
2298 return dyn_cast_or_null<LLVMFuncOp>(symbolTable.lookupSymbolIn(
2299 parentLLVMModule(*this), getFunctionNameAttr()));
2300}
2301
2302AliasOp DSOLocalEquivalentOp::getAlias(SymbolTableCollection &symbolTable) {
2303 return dyn_cast_or_null<AliasOp>(symbolTable.lookupSymbolIn(
2304 parentLLVMModule(*this), getFunctionNameAttr()));
2305}
2306
2307LogicalResult
2308DSOLocalEquivalentOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2309 Operation *symbol = symbolTable.lookupSymbolIn(parentLLVMModule(*this),
2310 getFunctionNameAttr());
2311 auto function = dyn_cast_or_null<LLVMFuncOp>(symbol);
2312 auto alias = dyn_cast_or_null<AliasOp>(symbol);
2313
2314 if (!function && !alias)
2315 return emitOpError(
2316 "must reference a global defined by 'llvm.func' or 'llvm.mlir.alias'");
2317
2318 if (alias) {
2319 if (alias.getInitializer()
2320 .walk([&](AddressOfOp addrOp) {
2321 if (addrOp.getGlobal(symbolTable))
2322 return WalkResult::interrupt();
2323 return WalkResult::advance();
2324 })
2325 .wasInterrupted())
2326 return emitOpError("must reference an alias to a function");
2327 }
2328
2329 if ((function && function.getLinkage() == LLVM::Linkage::ExternWeak) ||
2330 (alias && alias.getLinkage() == LLVM::Linkage::ExternWeak))
2331 return emitOpError(
2332 "target function with 'extern_weak' linkage not allowed");
2333
2334 return success();
2335}
2336
2337/// Fold a dso_local_equivalent operation to a dedicated dso_local_equivalent
2338/// attribute.
2339OpFoldResult DSOLocalEquivalentOp::fold(FoldAdaptor) {
2340 return DSOLocalEquivalentAttr::get(getContext(), getFunctionNameAttr());
2341}
2342
2343//===----------------------------------------------------------------------===//
// Builder and verifier for LLVM::ComdatOp.
2345//===----------------------------------------------------------------------===//
2346
2347void ComdatOp::build(OpBuilder &builder, OperationState &result,
2348 StringRef symName) {
2349 result.addAttribute(getSymNameAttrName(result.name),
2350 builder.getStringAttr(symName));
2351 Region *body = result.addRegion();
2352 body->emplaceBlock();
2353}
2354
2355LogicalResult ComdatOp::verifyRegions() {
2356 Region &body = getBody();
2357 for (Operation &op : body.getOps())
2358 if (!isa<ComdatSelectorOp>(op))
2359 return op.emitError(
2360 "only comdat selector symbols can appear in a comdat region");
2361
2362 return success();
2363}
2364
2365//===----------------------------------------------------------------------===//
// Builder, printer, parser and verifier for LLVM::GlobalOp.
2367//===----------------------------------------------------------------------===//
2368
2369void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type,
2370 bool isConstant, Linkage linkage, StringRef name,
2371 Attribute value, uint64_t alignment, unsigned addrSpace,
2372 bool dsoLocal, bool threadLocal, SymbolRefAttr comdat,
2374 ArrayRef<Attribute> dbgExprs) {
2375 result.addAttribute(getSymNameAttrName(result.name),
2376 builder.getStringAttr(name));
2377 result.addAttribute(getGlobalTypeAttrName(result.name), TypeAttr::get(type));
2378 if (isConstant)
2379 result.addAttribute(getConstantAttrName(result.name),
2380 builder.getUnitAttr());
2381 if (value)
2382 result.addAttribute(getValueAttrName(result.name), value);
2383 if (dsoLocal)
2384 result.addAttribute(getDsoLocalAttrName(result.name),
2385 builder.getUnitAttr());
2386 if (threadLocal)
2387 result.addAttribute(getThreadLocal_AttrName(result.name),
2388 builder.getUnitAttr());
2389 if (comdat)
2390 result.addAttribute(getComdatAttrName(result.name), comdat);
2391
2392 // Only add an alignment attribute if the "alignment" input
2393 // is different from 0. The value must also be a power of two, but
2394 // this is tested in GlobalOp::verify, not here.
2395 if (alignment != 0)
2396 result.addAttribute(getAlignmentAttrName(result.name),
2397 builder.getI64IntegerAttr(alignment));
2398
2399 result.addAttribute(getLinkageAttrName(result.name),
2400 LinkageAttr::get(builder.getContext(), linkage));
2401 if (addrSpace != 0)
2402 result.addAttribute(getAddrSpaceAttrName(result.name),
2403 builder.getI32IntegerAttr(addrSpace));
2404 result.attributes.append(attrs.begin(), attrs.end());
2405
2406 if (!dbgExprs.empty())
2407 result.addAttribute(getDbgExprsAttrName(result.name),
2408 ArrayAttr::get(builder.getContext(), dbgExprs));
2409
2410 result.addRegion();
2411}
2412
2413template <typename OpType>
2414static void printCommonGlobalAndAlias(OpAsmPrinter &p, OpType op) {
2415 p << ' ' << stringifyLinkage(op.getLinkage()) << ' ';
2416 StringRef visibility = stringifyVisibility(op.getVisibility_());
2417 if (!visibility.empty())
2418 p << visibility << ' ';
2419 if (op.getThreadLocal_())
2420 p << "thread_local ";
2421 if (auto unnamedAddr = op.getUnnamedAddr()) {
2422 StringRef str = stringifyUnnamedAddr(*unnamedAddr);
2423 if (!str.empty())
2424 p << str << ' ';
2425 }
2426}
2427
2428void GlobalOp::print(OpAsmPrinter &p) {
2430 if (getConstant())
2431 p << "constant ";
2432 p.printSymbolName(getSymName());
2433 p << '(';
2434 if (auto value = getValueOrNull())
2435 p.printAttribute(value);
2436 p << ')';
2437 if (auto comdat = getComdat())
2438 p << " comdat(" << *comdat << ')';
2439
2440 // Note that the alignment attribute is printed using the
2441 // default syntax here, even though it is an inherent attribute
2442 // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes)
2443 p.printOptionalAttrDict((*this)->getAttrs(),
2444 {SymbolTable::getSymbolAttrName(),
2445 getGlobalTypeAttrName(), getConstantAttrName(),
2446 getValueAttrName(), getLinkageAttrName(),
2447 getUnnamedAddrAttrName(), getThreadLocal_AttrName(),
2448 getVisibility_AttrName(), getComdatAttrName()});
2449
2450 // Print the trailing type unless it's a string global.
2451 if (llvm::dyn_cast_or_null<StringAttr>(getValueOrNull()))
2452 return;
2453 p << " : " << getType();
2454
2455 Region &initializer = getInitializerRegion();
2456 if (!initializer.empty()) {
2457 p << ' ';
2458 p.printRegion(initializer, /*printEntryBlockArgs=*/false);
2459 }
2460}
2461
2462static LogicalResult verifyComdat(Operation *op,
2463 std::optional<SymbolRefAttr> attr) {
2464 if (!attr)
2465 return success();
2466
2467 auto *comdatSelector = SymbolTable::lookupNearestSymbolFrom(op, *attr);
2468 if (!isa_and_nonnull<ComdatSelectorOp>(comdatSelector))
2469 return op->emitError() << "expected comdat symbol";
2470
2471 return success();
2472}
2473
2474static LogicalResult verifyBlockTags(LLVMFuncOp funcOp) {
2476 // Note that presence of `BlockTagOp`s currently can't prevent an unrecheable
2477 // block to be removed by canonicalizer's region simplify pass, which needs to
2478 // be dialect aware to allow extra constraints to be described.
2479 WalkResult res = funcOp.walk([&](BlockTagOp blockTagOp) {
2480 if (blockTags.contains(blockTagOp.getTag())) {
2481 blockTagOp.emitError()
2482 << "duplicate block tag '" << blockTagOp.getTag().getId()
2483 << "' in the same function: ";
2484 return WalkResult::interrupt();
2485 }
2486 blockTags.insert(blockTagOp.getTag());
2487 return WalkResult::advance();
2488 });
2489
2490 return failure(res.wasInterrupted());
2491}
2492
/// Parse common attributes that might show up in the same order in both
/// GlobalOp and AliasOp, in this order: an optional linkage keyword (defaults
/// to External), an optional visibility keyword (defaults to Default), an
/// optional `thread_local` keyword (added as a unit attribute only when
/// present), and an optional unnamed_addr keyword (defaults to None).
/// Linkage, visibility and unnamed_addr are always added to `result`.
template <typename OpType>
static ParseResult parseCommonGlobalAndAlias(OpAsmParser &parser,
    // NOTE(review): the second parameter line (declaring `result`, which the
    // body uses) appears to be missing from this copy of the file; restore it
    // from upstream.
  MLIRContext *ctx = parser.getContext();
  // Parse optional linkage, default to External.
  result.addAttribute(
      OpType::getLinkageAttrName(result.name),
      LLVM::LinkageAttr::get(ctx, parseOptionalLLVMKeyword<Linkage>(
                                      parser, LLVM::Linkage::External)));

  // Parse optional visibility, default to Default.
  result.addAttribute(OpType::getVisibility_AttrName(result.name),
                      // NOTE(review): the lines building the visibility
                      // attribute value appear to be missing here (the
                      // parentheses below do not balance); restore from
                      // upstream.
                          parser, LLVM::Visibility::Default)));

  // Optional `thread_local` keyword; only recorded when present.
  if (succeeded(parser.parseOptionalKeyword("thread_local")))
    result.addAttribute(OpType::getThreadLocal_AttrName(result.name),
                        parser.getBuilder().getUnitAttr());

  // Parse optional UnnamedAddr, default to None.
  result.addAttribute(OpType::getUnnamedAddrAttrName(result.name),
                      // NOTE(review): as for visibility above, the lines
                      // building the unnamed_addr attribute value appear to be
                      // missing here; restore from upstream.
                          parser, LLVM::UnnamedAddr::None)));

  return success();
}
2523
2524// operation ::= `llvm.mlir.global` linkage? visibility?
2525// (`unnamed_addr` | `local_unnamed_addr`)?
2526// `thread_local`? `constant`? `@` identifier
2527// `(` attribute? `)` (`comdat(` symbol-ref-id `)`)?
2528// attribute-list? (`:` type)? region?
2529//
2530// The type can be omitted for string attributes, in which case it will be
2531// inferred from the value of the string as [strlen(value) x i8].
2532ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) {
2533 // Call into common parsing between GlobalOp and AliasOp.
2535 return failure();
2536
2537 if (succeeded(parser.parseOptionalKeyword("constant")))
2538 result.addAttribute(getConstantAttrName(result.name),
2539 parser.getBuilder().getUnitAttr());
2540
2541 StringAttr name;
2542 if (parser.parseSymbolName(name, getSymNameAttrName(result.name),
2543 result.attributes) ||
2544 parser.parseLParen())
2545 return failure();
2546
2547 Attribute value;
2548 if (parser.parseOptionalRParen()) {
2549 if (parser.parseAttribute(value, getValueAttrName(result.name),
2550 result.attributes) ||
2551 parser.parseRParen())
2552 return failure();
2553 }
2554
2555 if (succeeded(parser.parseOptionalKeyword("comdat"))) {
2556 SymbolRefAttr comdat;
2557 if (parser.parseLParen() || parser.parseAttribute(comdat) ||
2558 parser.parseRParen())
2559 return failure();
2560
2561 result.addAttribute(getComdatAttrName(result.name), comdat);
2562 }
2563
2565 if (parser.parseOptionalAttrDict(result.attributes) ||
2566 parser.parseOptionalColonTypeList(types))
2567 return failure();
2568
2569 if (types.size() > 1)
2570 return parser.emitError(parser.getNameLoc(), "expected zero or one type");
2571
2572 Region &initRegion = *result.addRegion();
2573 if (types.empty()) {
2574 if (auto strAttr = llvm::dyn_cast_or_null<StringAttr>(value)) {
2575 MLIRContext *context = parser.getContext();
2576 auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8),
2577 strAttr.getValue().size());
2578 types.push_back(arrayType);
2579 } else {
2580 return parser.emitError(parser.getNameLoc(),
2581 "type can only be omitted for string globals");
2582 }
2583 } else {
2584 OptionalParseResult parseResult =
2585 parser.parseOptionalRegion(initRegion, /*arguments=*/{},
2586 /*argTypes=*/{});
2587 if (parseResult.has_value() && failed(*parseResult))
2588 return failure();
2589 }
2590
2591 result.addAttribute(getGlobalTypeAttrName(result.name),
2592 TypeAttr::get(types[0]));
2593 return success();
2594}
2595
2596static bool isZeroAttribute(Attribute value) {
2597 if (auto intValue = llvm::dyn_cast<IntegerAttr>(value))
2598 return intValue.getValue().isZero();
2599 if (auto fpValue = llvm::dyn_cast<FloatAttr>(value))
2600 return fpValue.getValue().isZero();
2601 if (auto splatValue = llvm::dyn_cast<SplatElementsAttr>(value))
2602 return isZeroAttribute(splatValue.getSplatValue<Attribute>());
2603 if (auto elementsValue = llvm::dyn_cast<ElementsAttr>(value))
2604 return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute);
2605 if (auto arrayValue = llvm::dyn_cast<ArrayAttr>(value))
2606 return llvm::all_of(arrayValue.getValue(), isZeroAttribute);
2607 return false;
2608}
2609
2610LogicalResult GlobalOp::verify() {
2611 bool validType = isCompatibleOuterType(getType())
2612 ? !llvm::isa<LLVMVoidType, TokenType, LLVMMetadataType,
2613 LLVMLabelType>(getType())
2614 : llvm::isa<PointerElementTypeInterface>(getType());
2615 if (!validType)
2616 return emitOpError(
2617 "expects type to be a valid element type for an LLVM global");
2618 if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp()))
2619 return emitOpError("must appear at the module level");
2620
2621 if (auto strAttr = llvm::dyn_cast_or_null<StringAttr>(getValueOrNull())) {
2622 auto type = llvm::dyn_cast<LLVMArrayType>(getType());
2623 IntegerType elementType =
2624 type ? llvm::dyn_cast<IntegerType>(type.getElementType()) : nullptr;
2625 if (!elementType || elementType.getWidth() != 8 ||
2626 type.getNumElements() != strAttr.getValue().size())
2627 return emitOpError(
2628 "requires an i8 array type of the length equal to that of the string "
2629 "attribute");
2630 }
2631
2632 if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
2633 if (!targetExtType.hasProperty(LLVMTargetExtType::CanBeGlobal))
2634 return emitOpError()
2635 << "this target extension type cannot be used in a global";
2636
2637 if (Attribute value = getValueOrNull())
2638 return emitOpError() << "global with target extension type can only be "
2639 "initialized with zero-initializer";
2640 }
2641
2642 if (getLinkage() == Linkage::Common) {
2643 if (Attribute value = getValueOrNull()) {
2644 if (!isZeroAttribute(value)) {
2645 return emitOpError()
2646 << "expected zero value for '"
2647 << stringifyLinkage(Linkage::Common) << "' linkage";
2648 }
2649 }
2650 }
2651
2652 if (getLinkage() == Linkage::Appending) {
2653 if (!llvm::isa<LLVMArrayType>(getType())) {
2654 return emitOpError() << "expected array type for '"
2655 << stringifyLinkage(Linkage::Appending)
2656 << "' linkage";
2657 }
2658 }
2659
2660 if (failed(verifyComdat(*this, getComdat())))
2661 return failure();
2662
2663 std::optional<uint64_t> alignAttr = getAlignment();
2664 if (alignAttr.has_value()) {
2665 uint64_t value = alignAttr.value();
2666 if (!llvm::isPowerOf2_64(value))
2667 return emitError() << "alignment attribute is not a power of 2";
2668 }
2669
2670 return success();
2671}
2672
2673LogicalResult GlobalOp::verifyRegions() {
2674 if (Block *b = getInitializerBlock()) {
2675 ReturnOp ret = cast<ReturnOp>(b->getTerminator());
2676 if (ret.operand_type_begin() == ret.operand_type_end())
2677 return emitOpError("initializer region cannot return void");
2678 if (*ret.operand_type_begin() != getType())
2679 return emitOpError("initializer region type ")
2680 << *ret.operand_type_begin() << " does not match global type "
2681 << getType();
2682
2683 for (Operation &op : *b) {
2684 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
2685 if (!iface || !iface.hasNoEffect())
2686 return op.emitError()
2687 << "ops with side effects not allowed in global initializers";
2688 }
2689
2690 if (getValueOrNull())
2691 return emitOpError("cannot have both initializer value and region");
2692 }
2693
2694 return success();
2695}
2696
2697//===----------------------------------------------------------------------===//
2698// LLVM::GlobalCtorsOp
2699//===----------------------------------------------------------------------===//
2700
2701static LogicalResult checkGlobalXtorData(Operation *op, ArrayAttr data) {
2702 if (data.empty())
2703 return success();
2704
2705 if (llvm::all_of(data.getAsRange<Attribute>(), [](Attribute v) {
2706 return isa<FlatSymbolRefAttr, ZeroAttr>(v);
2707 }))
2708 return success();
2709 return op->emitError("data element must be symbol or #llvm.zero");
2710}
2711
2712LogicalResult
2713GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2714 for (Attribute ctor : getCtors()) {
2715 if (failed(verifySymbolAttrUse(llvm::cast<FlatSymbolRefAttr>(ctor), *this,
2716 symbolTable)))
2717 return failure();
2718 }
2719 return success();
2720}
2721
2722LogicalResult GlobalCtorsOp::verify() {
2723 if (checkGlobalXtorData(*this, getData()).failed())
2724 return failure();
2725
2726 if (getCtors().size() == getPriorities().size() &&
2727 getCtors().size() == getData().size())
2728 return success();
2729 return emitError(
2730 "ctors, priorities, and data must have the same number of elements");
2731}
2732
2733//===----------------------------------------------------------------------===//
2734// LLVM::GlobalDtorsOp
2735//===----------------------------------------------------------------------===//
2736
2737LogicalResult
2738GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2739 for (Attribute dtor : getDtors()) {
2740 if (failed(verifySymbolAttrUse(llvm::cast<FlatSymbolRefAttr>(dtor), *this,
2741 symbolTable)))
2742 return failure();
2743 }
2744 return success();
2745}
2746
2747LogicalResult GlobalDtorsOp::verify() {
2748 if (checkGlobalXtorData(*this, getData()).failed())
2749 return failure();
2750
2751 if (getDtors().size() == getPriorities().size() &&
2752 getDtors().size() == getData().size())
2753 return success();
2754 return emitError(
2755 "dtors, priorities, and data must have the same number of elements");
2756}
2757
2758//===----------------------------------------------------------------------===//
2759// Builder, printer and verifier for LLVM::AliasOp.
2760//===----------------------------------------------------------------------===//
2761
2762void AliasOp::build(OpBuilder &builder, OperationState &result, Type type,
2763 Linkage linkage, StringRef name, bool dsoLocal,
2764 bool threadLocal, ArrayRef<NamedAttribute> attrs) {
2765 result.addAttribute(getSymNameAttrName(result.name),
2766 builder.getStringAttr(name));
2767 result.addAttribute(getAliasTypeAttrName(result.name), TypeAttr::get(type));
2768 if (dsoLocal)
2769 result.addAttribute(getDsoLocalAttrName(result.name),
2770 builder.getUnitAttr());
2771 if (threadLocal)
2772 result.addAttribute(getThreadLocal_AttrName(result.name),
2773 builder.getUnitAttr());
2774
2775 result.addAttribute(getLinkageAttrName(result.name),
2776 LinkageAttr::get(builder.getContext(), linkage));
2777 result.attributes.append(attrs.begin(), attrs.end());
2778
2779 result.addRegion();
2780}
2781
2782void AliasOp::print(OpAsmPrinter &p) {
2784
2785 p.printSymbolName(getSymName());
2786 p.printOptionalAttrDict((*this)->getAttrs(),
2787 {SymbolTable::getSymbolAttrName(),
2788 getAliasTypeAttrName(), getLinkageAttrName(),
2789 getUnnamedAddrAttrName(), getThreadLocal_AttrName(),
2790 getVisibility_AttrName()});
2791
2792 // Print the trailing type.
2793 p << " : " << getType() << ' ';
2794 // Print the initializer region.
2795 p.printRegion(getInitializerRegion(), /*printEntryBlockArgs=*/false);
2796}
2797
// operation ::= `llvm.mlir.alias` linkage? visibility? `thread_local`?
//                 (`unnamed_addr` | `local_unnamed_addr`)? `@` identifier
//                 attribute-list? `:` type region
2803//
2804ParseResult AliasOp::parse(OpAsmParser &parser, OperationState &result) {
2805 // Call into common parsing between GlobalOp and AliasOp.
2807 return failure();
2808
2809 StringAttr name;
2810 if (parser.parseSymbolName(name, getSymNameAttrName(result.name),
2811 result.attributes))
2812 return failure();
2813
2815 if (parser.parseOptionalAttrDict(result.attributes) ||
2816 parser.parseOptionalColonTypeList(types))
2817 return failure();
2818
2819 if (types.size() > 1)
2820 return parser.emitError(parser.getNameLoc(), "expected zero or one type");
2821
2822 Region &initRegion = *result.addRegion();
2823 if (parser.parseRegion(initRegion).failed())
2824 return failure();
2825
2826 result.addAttribute(getAliasTypeAttrName(result.name),
2827 TypeAttr::get(types[0]));
2828 return success();
2829}
2830
2831LogicalResult AliasOp::verify() {
2832 bool validType = isCompatibleOuterType(getType())
2833 ? !llvm::isa<LLVMVoidType, TokenType, LLVMMetadataType,
2834 LLVMLabelType>(getType())
2835 : llvm::isa<PointerElementTypeInterface>(getType());
2836 if (!validType)
2837 return emitOpError(
2838 "expects type to be a valid element type for an LLVM global alias");
2839
2840 // This matches LLVM IR verification logic, see llvm/lib/IR/Verifier.cpp
2841 switch (getLinkage()) {
2842 case Linkage::External:
2843 case Linkage::Internal:
2844 case Linkage::Private:
2845 case Linkage::Weak:
2846 case Linkage::WeakODR:
2847 case Linkage::Linkonce:
2848 case Linkage::LinkonceODR:
2849 case Linkage::AvailableExternally:
2850 break;
2851 default:
2852 return emitOpError()
2853 << "'" << stringifyLinkage(getLinkage())
2854 << "' linkage not supported in aliases, available options: private, "
2855 "internal, linkonce, weak, linkonce_odr, weak_odr, external or "
2856 "available_externally";
2857 }
2858
2859 return success();
2860}
2861
2862LogicalResult AliasOp::verifyRegions() {
2863 Block &b = getInitializerBlock();
2864 auto ret = cast<ReturnOp>(b.getTerminator());
2865 if (ret.getNumOperands() == 0 ||
2866 !isa<LLVM::LLVMPointerType>(ret.getOperand(0).getType()))
2867 return emitOpError("initializer region must always return a pointer");
2868
2869 for (Operation &op : b) {
2870 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
2871 if (!iface || !iface.hasNoEffect())
2872 return op.emitError()
2873 << "ops with side effects are not allowed in alias initializers";
2874 }
2875
2876 return success();
2877}
2878
2879unsigned AliasOp::getAddrSpace() {
2880 Block &initializer = getInitializerBlock();
2881 auto ret = cast<ReturnOp>(initializer.getTerminator());
2882 auto ptrTy = cast<LLVMPointerType>(ret.getOperand(0).getType());
2883 return ptrTy.getAddressSpace();
2884}
2885
2886//===----------------------------------------------------------------------===//
2887// IFuncOp
2888//===----------------------------------------------------------------------===//
2889
2890void IFuncOp::build(OpBuilder &builder, OperationState &result, StringRef name,
2891 Type iFuncType, StringRef resolverName, Type resolverType,
2892 Linkage linkage, LLVM::Visibility visibility) {
2893 return build(builder, result, name, iFuncType, resolverName, resolverType,
2894 linkage, /*dso_local=*/false, /*address_space=*/0,
2895 UnnamedAddr::None, visibility);
2896}
2897
2898LogicalResult IFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2899 Operation *symbol =
2900 symbolTable.lookupSymbolIn(parentLLVMModule(*this), getResolverAttr());
2901 // This matches LLVM IR verification logic, see llvm/lib/IR/Verifier.cpp
2902 auto resolver = dyn_cast<LLVMFuncOp>(symbol);
2903 auto alias = dyn_cast<AliasOp>(symbol);
2904 while (alias) {
2905 Block &initBlock = alias.getInitializerBlock();
2906 auto returnOp = cast<ReturnOp>(initBlock.getTerminator());
2907 auto addrOp = returnOp.getArg().getDefiningOp<AddressOfOp>();
2908 // FIXME: This is a best effort solution. The AliasOp body might be more
2909 // complex and in that case we bail out with success. To completely match
2910 // the LLVM IR logic it would be necessary to implement proper alias and
2911 // cast stripping.
2912 if (!addrOp)
2913 return success();
2914 resolver = addrOp.getFunction(symbolTable);
2915 alias = addrOp.getAlias(symbolTable);
2916 }
2917 if (!resolver)
2918 return emitOpError("must have a function resolver");
2919 Linkage linkage = resolver.getLinkage();
2920 if (resolver.isExternal() || linkage == Linkage::AvailableExternally)
2921 return emitOpError("resolver must be a definition");
2922 if (!isa<LLVMPointerType>(resolver.getFunctionType().getReturnType()))
2923 return emitOpError("resolver must return a pointer");
2924 auto resolverPtr = dyn_cast<LLVMPointerType>(getResolverType());
2925 if (!resolverPtr || resolverPtr.getAddressSpace() != getAddressSpace())
2926 return emitOpError("resolver has incorrect type");
2927 return success();
2928}
2929
2930LogicalResult IFuncOp::verify() {
2931 switch (getLinkage()) {
2932 case Linkage::External:
2933 case Linkage::Internal:
2934 case Linkage::Private:
2935 case Linkage::Weak:
2936 case Linkage::WeakODR:
2937 case Linkage::Linkonce:
2938 case Linkage::LinkonceODR:
2939 break;
2940 default:
2941 return emitOpError() << "'" << stringifyLinkage(getLinkage())
2942 << "' linkage not supported in ifuncs, available "
2943 "options: private, internal, linkonce, weak, "
2944 "linkonce_odr, weak_odr, or external linkage";
2945 }
2946 return success();
2947}
2948
2949//===----------------------------------------------------------------------===//
2950// ShuffleVectorOp
2951//===----------------------------------------------------------------------===//
2952
2953void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1,
2954 Value v2, DenseI32ArrayAttr mask,
2956 auto containerType = v1.getType();
2957 auto vType = LLVM::getVectorType(
2958 cast<VectorType>(containerType).getElementType(), mask.size(),
2959 LLVM::isScalableVectorType(containerType));
2960 build(builder, state, vType, v1, v2, mask);
2961 state.addAttributes(attrs);
2962}
2963
2964void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1,
2965 Value v2, ArrayRef<int32_t> mask) {
2966 build(builder, state, v1, v2, builder.getDenseI32ArrayAttr(mask));
2967}
2968
2969/// Build the result type of a shuffle vector operation.
2970static ParseResult parseShuffleType(AsmParser &parser, Type v1Type,
2971 Type &resType, DenseI32ArrayAttr mask) {
2972 if (!LLVM::isCompatibleVectorType(v1Type))
2973 return parser.emitError(parser.getCurrentLocation(),
2974 "expected an LLVM compatible vector type");
2975 resType =
2976 LLVM::getVectorType(cast<VectorType>(v1Type).getElementType(),
2977 mask.size(), LLVM::isScalableVectorType(v1Type));
2978 return success();
2979}
2980
2981/// Nothing to do when the result type is inferred.
2982static void printShuffleType(AsmPrinter &printer, Operation *op, Type v1Type,
2983 Type resType, DenseI32ArrayAttr mask) {}
2984
2985LogicalResult ShuffleVectorOp::verify() {
2986 if (LLVM::isScalableVectorType(getV1().getType()) &&
2987 llvm::any_of(getMask(), [](int32_t v) { return v != 0; }))
2988 return emitOpError("expected a splat operation for scalable vectors");
2989 return success();
2990}
2991
2992// Folding for shufflevector op when v1 is single element 1D vector
2993// and the mask is a single zero. OpFoldResult will be v1 in this case.
2994OpFoldResult ShuffleVectorOp::fold(FoldAdaptor adaptor) {
2995 // Check if operand 0 is a single element vector.
2996 auto vecType = llvm::dyn_cast<VectorType>(getV1().getType());
2997 if (!vecType || vecType.getRank() != 1 || vecType.getNumElements() != 1)
2998 return {};
2999 // Check if the mask is a single zero.
3000 // Note: The mask is guaranteed to be non-empty.
3001 if (getMask().size() != 1 || getMask()[0] != 0)
3002 return {};
3003 return getV1();
3004}
3005
3006//===----------------------------------------------------------------------===//
3007// Implementations for LLVM::LLVMFuncOp.
3008//===----------------------------------------------------------------------===//
3009
3010// Add the entry block to the function.
3011Block *LLVMFuncOp::addEntryBlock(OpBuilder &builder) {
3012 assert(empty() && "function already has an entry block");
3013 OpBuilder::InsertionGuard g(builder);
3014 Block *entry = builder.createBlock(&getBody());
3015
3016 // FIXME: Allow passing in proper locations for the entry arguments.
3017 LLVMFunctionType type = getFunctionType();
3018 for (unsigned i = 0, e = type.getNumParams(); i < e; ++i)
3019 entry->addArgument(type.getParamType(i), getLoc());
3020 return entry;
3021}
3022
3023void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
3024 StringRef name, Type type, LLVM::Linkage linkage,
3025 bool dsoLocal, CConv cconv, SymbolRefAttr comdat,
3027 ArrayRef<DictionaryAttr> argAttrs,
3028 std::optional<uint64_t> functionEntryCount) {
3029 result.addRegion();
3031 builder.getStringAttr(name));
3032 result.addAttribute(getFunctionTypeAttrName(result.name),
3033 TypeAttr::get(type));
3034 result.addAttribute(getLinkageAttrName(result.name),
3035 LinkageAttr::get(builder.getContext(), linkage));
3036 result.addAttribute(getCConvAttrName(result.name),
3037 CConvAttr::get(builder.getContext(), cconv));
3038 result.attributes.append(attrs.begin(), attrs.end());
3039 if (dsoLocal)
3040 result.addAttribute(getDsoLocalAttrName(result.name),
3041 builder.getUnitAttr());
3042 if (comdat)
3043 result.addAttribute(getComdatAttrName(result.name), comdat);
3044 if (functionEntryCount)
3045 result.addAttribute(getFunctionEntryCountAttrName(result.name),
3046 builder.getI64IntegerAttr(functionEntryCount.value()));
3047#ifndef NDEBUG
3048 std::optional<NamedAttribute> duplicate = result.attributes.findDuplicate();
3049 if (duplicate.has_value()) {
3050 llvm::report_fatal_error(
3051 Twine("LLVMFuncOp propagated an attribute that is meant "
3052 "to be constructed by the builder: ") +
3053 duplicate->getName().str());
3054 }
3055#endif
3056 if (argAttrs.empty())
3057 return;
3058
3059 assert(llvm::cast<LLVMFunctionType>(type).getNumParams() == argAttrs.size() &&
3060 "expected as many argument attribute lists as arguments");
3062 builder, result, argAttrs, /*resultAttrs=*/{},
3063 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
3064}
3065
3066// Builds an LLVM function type from the given lists of input and output types.
3067// Returns a null type if any of the types provided are non-LLVM types, or if
3068// there is more than one output type.
3069static Type
3071 ArrayRef<Type> outputs,
3073 Builder &b = parser.getBuilder();
3074 if (outputs.size() > 1) {
3075 parser.emitError(loc, "failed to construct function type: expected zero or "
3076 "one function result");
3077 return {};
3078 }
3079
3080 // Convert inputs to LLVM types, exit early on error.
3081 SmallVector<Type, 4> llvmInputs;
3082 for (auto t : inputs) {
3083 if (!isCompatibleType(t)) {
3084 parser.emitError(loc, "failed to construct function type: expected LLVM "
3085 "type for function arguments");
3086 return {};
3087 }
3088 llvmInputs.push_back(t);
3089 }
3090
3091 // No output is denoted as "void" in LLVM type system.
3092 Type llvmOutput =
3093 outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front();
3094 if (!isCompatibleType(llvmOutput)) {
3095 parser.emitError(loc, "failed to construct function type: expected LLVM "
3096 "type for function results")
3097 << llvmOutput;
3098 return {};
3099 }
3100 return LLVMFunctionType::get(llvmOutput, llvmInputs,
3101 variadicFlag.isVariadic());
3102}
3103
3104// Parses an LLVM function.
3105//
// operation ::= `llvm.func` linkage? visibility?
//                 (`unnamed_addr` | `local_unnamed_addr`)? cconv?
//                 function-signature
//                 (`vscale_range(` integer `,` integer `)`)?
//                 (`comdat(` symbol-ref-id `)`)?
//                 function-attributes?
//                 function-body
3110//
3111ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
3112 // Default to external linkage if no keyword is provided.
3113 result.addAttribute(getLinkageAttrName(result.name),
3114 LinkageAttr::get(parser.getContext(),
3116 parser, LLVM::Linkage::External)));
3117
3118 // Parse optional visibility, default to Default.
3119 result.addAttribute(getVisibility_AttrName(result.name),
3122 parser, LLVM::Visibility::Default)));
3123
3124 // Parse optional UnnamedAddr, default to None.
3125 result.addAttribute(getUnnamedAddrAttrName(result.name),
3128 parser, LLVM::UnnamedAddr::None)));
3129
3130 // Default to C Calling Convention if no keyword is provided.
3131 result.addAttribute(
3132 getCConvAttrName(result.name),
3133 CConvAttr::get(parser.getContext(),
3134 parseOptionalLLVMKeyword<CConv>(parser, LLVM::CConv::C)));
3135
3136 StringAttr nameAttr;
3138 SmallVector<DictionaryAttr> resultAttrs;
3139 SmallVector<Type> resultTypes;
3140 bool isVariadic;
3141
3142 auto signatureLocation = parser.getCurrentLocation();
3143 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
3144 result.attributes) ||
3146 parser, /*allowVariadic=*/true, entryArgs, isVariadic, resultTypes,
3147 resultAttrs))
3148 return failure();
3149
3150 SmallVector<Type> argTypes;
3151 for (auto &arg : entryArgs)
3152 argTypes.push_back(arg.type);
3153 auto type =
3154 buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes,
3156 if (!type)
3157 return failure();
3158 result.addAttribute(getFunctionTypeAttrName(result.name),
3159 TypeAttr::get(type));
3160
3161 if (succeeded(parser.parseOptionalKeyword("vscale_range"))) {
3162 int64_t minRange, maxRange;
3163 if (parser.parseLParen() || parser.parseInteger(minRange) ||
3164 parser.parseComma() || parser.parseInteger(maxRange) ||
3165 parser.parseRParen())
3166 return failure();
3167 auto intTy = IntegerType::get(parser.getContext(), 32);
3168 result.addAttribute(
3169 getVscaleRangeAttrName(result.name),
3170 LLVM::VScaleRangeAttr::get(parser.getContext(),
3171 IntegerAttr::get(intTy, minRange),
3172 IntegerAttr::get(intTy, maxRange)));
3173 }
3174 // Parse the optional comdat selector.
3175 if (succeeded(parser.parseOptionalKeyword("comdat"))) {
3176 SymbolRefAttr comdat;
3177 if (parser.parseLParen() || parser.parseAttribute(comdat) ||
3178 parser.parseRParen())
3179 return failure();
3180
3181 result.addAttribute(getComdatAttrName(result.name), comdat);
3182 }
3183
3184 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
3185 return failure();
3187 parser.getBuilder(), result, entryArgs, resultAttrs,
3188 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
3189
3190 auto *body = result.addRegion();
3191 OptionalParseResult parseResult =
3192 parser.parseOptionalRegion(*body, entryArgs);
3193 return failure(parseResult.has_value() && failed(*parseResult));
3194}
3195
3196// Print the LLVMFuncOp. Collects argument and result types and passes them to
3197// helper functions. Drops "void" result since it cannot be parsed back. Skips
3198// the external linkage since it is the default value.
3199void LLVMFuncOp::print(OpAsmPrinter &p) {
3200 p << ' ';
3201 if (getLinkage() != LLVM::Linkage::External)
3202 p << stringifyLinkage(getLinkage()) << ' ';
3203 StringRef visibility = stringifyVisibility(getVisibility_());
3204 if (!visibility.empty())
3205 p << visibility << ' ';
3206 if (auto unnamedAddr = getUnnamedAddr()) {
3207 StringRef str = stringifyUnnamedAddr(*unnamedAddr);
3208 if (!str.empty())
3209 p << str << ' ';
3210 }
3211 if (getCConv() != LLVM::CConv::C)
3212 p << stringifyCConv(getCConv()) << ' ';
3213
3214 p.printSymbolName(getName());
3215
3216 LLVMFunctionType fnType = getFunctionType();
3217 SmallVector<Type, 8> argTypes;
3218 SmallVector<Type, 1> resTypes;
3219 argTypes.reserve(fnType.getNumParams());
3220 for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i)
3221 argTypes.push_back(fnType.getParamType(i));
3222
3223 Type returnType = fnType.getReturnType();
3224 if (!llvm::isa<LLVMVoidType>(returnType))
3225 resTypes.push_back(returnType);
3226
3228 isVarArg(), resTypes);
3229
3230 // Print vscale range if present
3231 if (std::optional<VScaleRangeAttr> vscale = getVscaleRange())
3232 p << " vscale_range(" << vscale->getMinRange().getInt() << ", "
3233 << vscale->getMaxRange().getInt() << ')';
3234
3235 // Print the optional comdat selector.
3236 if (auto comdat = getComdat())
3237 p << " comdat(" << *comdat << ')';
3238
3240 p, *this,
3241 {getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
3242 getLinkageAttrName(), getCConvAttrName(), getVisibility_AttrName(),
3243 getComdatAttrName(), getUnnamedAddrAttrName(),
3244 getVscaleRangeAttrName()});
3245
3246 // Print the body if this is not an external function.
3247 Region &body = getBody();
3248 if (!body.empty()) {
3249 p << ' ';
3250 p.printRegion(body, /*printEntryBlockArgs=*/false,
3251 /*printBlockTerminators=*/true);
3252 }
3253}
3254
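// Verifies LLVM- and implementation-specific properties of the LLVM func Op:
// - functions don't have 'common' linkage;
// - the optional comdat reference passes `verifyComdat`;
// - external functions have 'external' or 'extern_weak' linkage;
// - for functions with a body:
//   - no_inline and always_inline are mutually exclusive;
//   - optimize_none requires no_inline;
//   - all 'llvm.landingpad' results and 'llvm.resume' operands share one type;
//   - block tags pass `verifyBlockTags`.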
3259LogicalResult LLVMFuncOp::verify() {
3260 if (getLinkage() == LLVM::Linkage::Common)
3261 return emitOpError() << "functions cannot have '"
3262 << stringifyLinkage(LLVM::Linkage::Common)
3263 << "' linkage";
3264
3265 if (failed(verifyComdat(*this, getComdat())))
3266 return failure();
3267
3268 if (isExternal()) {
3269 if (getLinkage() != LLVM::Linkage::External &&
3270 getLinkage() != LLVM::Linkage::ExternWeak)
3271 return emitOpError() << "external functions must have '"
3272 << stringifyLinkage(LLVM::Linkage::External)
3273 << "' or '"
3274 << stringifyLinkage(LLVM::Linkage::ExternWeak)
3275 << "' linkage";
3276 return success();
3277 }
3278
3279 // In LLVM IR, these attributes are composed by convention, not by design.
3280 if (isNoInline() && isAlwaysInline())
3281 return emitError("no_inline and always_inline attributes are incompatible");
3282
3283 if (isOptimizeNone() && !isNoInline())
3284 return emitOpError("with optimize_none must also be no_inline");
3285
3286 Type landingpadResultTy;
3287 StringRef diagnosticMessage;
3288 bool isLandingpadTypeConsistent =
3289 !walk([&](Operation *op) {
3290 const auto checkType = [&](Type type, StringRef errorMessage) {
3291 if (!landingpadResultTy) {
3292 landingpadResultTy = type;
3293 return WalkResult::advance();
3294 }
3295 if (landingpadResultTy != type) {
3296 diagnosticMessage = errorMessage;
3297 return WalkResult::interrupt();
3298 }
3299 return WalkResult::advance();
3300 };
3302 .Case([&](LandingpadOp landingpad) {
3303 constexpr StringLiteral errorMessage =
3304 "'llvm.landingpad' should have a consistent result type "
3305 "inside a function";
3306 return checkType(landingpad.getType(), errorMessage);
3307 })
3308 .Case([&](ResumeOp resume) {
3309 constexpr StringLiteral errorMessage =
3310 "'llvm.resume' should have a consistent input type inside a "
3311 "function";
3312 return checkType(resume.getValue().getType(), errorMessage);
3313 })
3314 .Default([](auto) { return WalkResult::skip(); });
3315 }).wasInterrupted();
3316 if (!isLandingpadTypeConsistent) {
3317 assert(!diagnosticMessage.empty() &&
3318 "Expecting a non-empty diagnostic message");
3319 return emitError(diagnosticMessage);
3320 }
3321
3322 if (failed(verifyBlockTags(*this)))
3323 return failure();
3324
3325 return success();
3326}
3327
3328/// Verifies LLVM- and implementation-specific properties of the LLVM func Op:
3329/// - entry block arguments are of LLVM types.
3330LogicalResult LLVMFuncOp::verifyRegions() {
3331 if (isExternal())
3332 return success();
3333
3334 unsigned numArguments = getFunctionType().getNumParams();
3335 Block &entryBlock = front();
3336 for (unsigned i = 0; i < numArguments; ++i) {
3337 Type argType = entryBlock.getArgument(i).getType();
3338 if (!isCompatibleType(argType))
3339 return emitOpError("entry block argument #")
3340 << i << " is not of LLVM type";
3341 }
3342
3343 return success();
3344}
3345
3346Region *LLVMFuncOp::getCallableRegion() {
3347 if (isExternal())
3348 return nullptr;
3349 return &getBody();
3350}
3351
3352//===----------------------------------------------------------------------===//
3353// UndefOp.
3354//===----------------------------------------------------------------------===//
3355
3356/// Fold an undef operation to a dedicated undef attribute.
3357OpFoldResult LLVM::UndefOp::fold(FoldAdaptor) {
3358 return LLVM::UndefAttr::get(getContext());
3359}
3360
3361//===----------------------------------------------------------------------===//
3362// PoisonOp.
3363//===----------------------------------------------------------------------===//
3364
3365/// Fold a poison operation to a dedicated poison attribute.
3366OpFoldResult LLVM::PoisonOp::fold(FoldAdaptor) {
3367 return LLVM::PoisonAttr::get(getContext());
3368}
3369
3370//===----------------------------------------------------------------------===//
3371// MetadataAsValueOp.
3372//===----------------------------------------------------------------------===//
3373
3374/// Fold a metadata-as-value operation to its wrapped metadata attribute.
3375OpFoldResult LLVM::MetadataAsValueOp::fold(FoldAdaptor) {
3376 return getMetadataAttr();
3377}
3378
3379//===----------------------------------------------------------------------===//
3380// ZeroOp.
3381//===----------------------------------------------------------------------===//
3382
3383LogicalResult LLVM::ZeroOp::verify() {
3384 if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType()))
3385 if (!targetExtType.hasProperty(LLVM::LLVMTargetExtType::HasZeroInit))
3386 return emitOpError()
3387 << "target extension type does not support zero-initializer";
3388
3389 return success();
3390}
3391
3392/// Fold a zero operation to a builtin zero attribute when possible and fall
3393/// back to a dedicated zero attribute.
3394OpFoldResult LLVM::ZeroOp::fold(FoldAdaptor) {
3396 if (result)
3397 return result;
3398 return LLVM::ZeroAttr::get(getContext());
3399}
3400
3401//===----------------------------------------------------------------------===//
3402// ConstantOp.
3403//===----------------------------------------------------------------------===//
3404
3405/// Compute the total number of elements in the given type, also taking into
3406/// account nested types. Supported types are `VectorType` and `LLVMArrayType`.
3407/// Everything else is treated as a scalar.
3409 if (auto vecType = dyn_cast<VectorType>(t)) {
3410 assert(!vecType.isScalable() &&
3411 "number of elements of a scalable vector type is unknown");
3412 return vecType.getNumElements() * getNumElements(vecType.getElementType());
3413 }
3414 if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t))
3415 return arrayType.getNumElements() *
3416 getNumElements(arrayType.getElementType());
3417 return 1;
3418}
3419
3420/// Determine the element type of `type`. Supported types are `VectorType`,
3421/// `TensorType`, and `LLVMArrayType`. Everything else is treated as a scalar.
3423 while (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(type))
3424 type = arrayType.getElementType();
3425 if (auto vecType = dyn_cast<VectorType>(type))
3426 return vecType.getElementType();
3427 if (auto tenType = dyn_cast<TensorType>(type))
3428 return tenType.getElementType();
3429 return type;
3430}
3431
3432/// Check if the given type is a scalable vector type or a vector/array type
3433/// that contains a nested scalable vector type.
3435 if (auto vecType = dyn_cast<VectorType>(t)) {
3436 if (vecType.isScalable())
3437 return true;
3438 return hasScalableVectorType(vecType.getElementType());
3439 }
3440 if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t))
3441 return hasScalableVectorType(arrayType.getElementType());
3442 return false;
3443}
3444
3445/// Verifies the constant array represented by `arrayAttr` matches the provided
3446/// `arrayType`.
3447static LogicalResult verifyStructArrayConstant(LLVM::ConstantOp op,
3448 LLVM::LLVMArrayType arrayType,
3449 ArrayAttr arrayAttr, int dim) {
3450 if (arrayType.getNumElements() != arrayAttr.size())
3451 return op.emitOpError()
3452 << "array attribute size does not match array type size in "
3453 "dimension "
3454 << dim << ": " << arrayAttr.size() << " vs. "
3455 << arrayType.getNumElements();
3456
3457 llvm::DenseSet<Attribute> elementsVerified;
3458
3459 // Recursively verify sub-dimensions for multidimensional arrays.
3460 if (auto subArrayType =
3461 dyn_cast<LLVM::LLVMArrayType>(arrayType.getElementType())) {
3462 for (auto [idx, elementAttr] : llvm::enumerate(arrayAttr))
3463 if (elementsVerified.insert(elementAttr).second) {
3464 if (isa<LLVM::ZeroAttr, LLVM::UndefAttr>(elementAttr))
3465 continue;
3466 auto subArrayAttr = dyn_cast<ArrayAttr>(elementAttr);
3467 if (!subArrayAttr)
3468 return op.emitOpError()
3469 << "nested attribute for sub-array in dimension " << dim
3470 << " at index " << idx
3471 << " must be a zero, or undef, or array attribute";
3472 if (failed(verifyStructArrayConstant(op, subArrayType, subArrayAttr,
3473 dim + 1)))
3474 return failure();
3475 }
3476 return success();
3477 }
3478
3479 // Forbid usages of ArrayAttr for simple array types that should use
3480 // DenseElementsAttr instead. Note that there would be a use case for such
3481 // array types when one element value is obtained via a ptr-to-int conversion
3482 // from a symbol and cannot be represented in a DenseElementsAttr, but no MLIR
3483 // user needs this so far, and it seems better to avoid people misusing the
3484 // ArrayAttr for simple types.
3485 Type elementType = arrayType.getElementType();
3486 if (isa<LLVM::LLVMPointerType>(elementType)) {
3487 for (auto [idx, elementAttr] : llvm::enumerate(arrayAttr)) {
3488 if (isa<FlatSymbolRefAttr, LLVM::ZeroAttr, LLVM::UndefAttr,
3489 LLVM::PoisonAttr>(elementAttr))
3490 continue;
3491 return op.emitOpError()
3492 << "pointer array element at index " << idx
3493 << " must be a flat symbol reference, zero, undef, or poison";
3494 }
3495 return success();
3496 }
3497 auto structType = dyn_cast<LLVM::LLVMStructType>(elementType);
3498 if (!structType)
3499 return op.emitOpError() << "for array with an array attribute must have a "
3500 "struct element type";
3501
3502 // Shallow verification that leaf attributes are appropriate as struct initial
3503 // value.
3504 size_t numStructElements = structType.getBody().size();
3505 for (auto [idx, elementAttr] : llvm::enumerate(arrayAttr)) {
3506 if (elementsVerified.insert(elementAttr).second) {
3507 if (isa<LLVM::ZeroAttr, LLVM::UndefAttr>(elementAttr))
3508 continue;
3509 auto subArrayAttr = dyn_cast<ArrayAttr>(elementAttr);
3510 if (!subArrayAttr)
3511 return op.emitOpError()
3512 << "nested attribute for struct element at index " << idx
3513 << " must be a zero, or undef, or array attribute";
3514 if (subArrayAttr.size() != numStructElements)
3515 return op.emitOpError()
3516 << "nested array attribute size for struct element at index "
3517 << idx << " must match struct size: " << subArrayAttr.size()
3518 << " vs. " << numStructElements;
3519 }
3520 }
3521
3522 return success();
3523}
3524
3525LogicalResult LLVM::ConstantOp::verify() {
3526 if (StringAttr sAttr = llvm::dyn_cast<StringAttr>(getValue())) {
3527 auto arrayType = llvm::dyn_cast<LLVMArrayType>(getType());
3528 if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() ||
3529 !arrayType.getElementType().isInteger(8)) {
3530 return emitOpError() << "expected array type of "
3531 << sAttr.getValue().size()
3532 << " i8 elements for the string constant";
3533 }
3534 return success();
3535 }
3536 if (auto structType = dyn_cast<LLVMStructType>(getType())) {
3537 auto arrayAttr = dyn_cast<ArrayAttr>(getValue());
3538 if (!arrayAttr)
3539 return emitOpError() << "expected array attribute for struct type";
3540
3541 ArrayRef<Type> elementTypes = structType.getBody();
3542 if (arrayAttr.size() != elementTypes.size()) {
3543 return emitOpError() << "expected array attribute of size "
3544 << elementTypes.size();
3545 }
3546 for (auto [i, attr, type] : llvm::enumerate(arrayAttr, elementTypes)) {
3547 if (!type.isSignlessIntOrIndexOrFloat()) {
3548 return emitOpError() << "expected struct element types to be floating "
3549 "point type or integer type";
3550 }
3551 if (!isa<FloatAttr, IntegerAttr>(attr)) {
3552 return emitOpError() << "expected element of array attribute to be "
3553 "floating point or integer";
3554 }
3555 if (cast<TypedAttr>(attr).getType() != type)
3556 return emitOpError()
3557 << "struct element at index " << i << " is of wrong type";
3558 }
3559
3560 return success();
3561 }
3562 if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType()))
3563 return emitOpError() << "does not support target extension type.";
3564
3565 // Check that an attribute whose element type has floating point semantics
3566 // `attributeFloatSemantics` is compatible with a type whose element type
3567 // is `constantElementType`.
3568 //
3569 // Requirement is that either
3570 // 1) They have identical floating point types.
3571 // 2) `constantElementType` is an integer type of the same width as the float
3572 // attribute. This is to support builtin MLIR float types without LLVM
3573 // equivalents, see comments in getLLVMConstant for more details.
3574 auto verifyFloatSemantics =
3575 [this](const llvm::fltSemantics &attributeFloatSemantics,
3576 Type constantElementType) -> LogicalResult {
3577 if (auto floatType = dyn_cast<FloatType>(constantElementType)) {
3578 if (&floatType.getFloatSemantics() != &attributeFloatSemantics) {
3579 return emitOpError()
3580 << "attribute and type have different float semantics";
3581 }
3582 return success();
3583 }
3584 unsigned floatWidth = APFloat::getSizeInBits(attributeFloatSemantics);
3585 if (isa<IntegerType>(constantElementType)) {
3586 if (!constantElementType.isInteger(floatWidth))
3587 return emitOpError() << "expected integer type of width " << floatWidth;
3588
3589 return success();
3590 }
3591 return success();
3592 };
3593
3594 // Verification of IntegerAttr, FloatAttr, ElementsAttr, ArrayAttr.
3595 if (isa<IntegerAttr>(getValue())) {
3596 if (!llvm::isa<IntegerType>(getType()))
3597 return emitOpError() << "expected integer type";
3598 } else if (auto floatAttr = dyn_cast<FloatAttr>(getValue())) {
3599 return verifyFloatSemantics(floatAttr.getValue().getSemantics(), getType());
3600 } else if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue())) {
3602 // The exact number of elements of a scalable vector is unknown, so we
3603 // allow only splat attributes.
3604 auto splatElementsAttr = dyn_cast<SplatElementsAttr>(getValue());
3605 if (!splatElementsAttr)
3606 return emitOpError()
3607 << "scalable vector type requires a splat attribute";
3608 return success();
3609 }
3610 if (!isa<VectorType, LLVM::LLVMArrayType>(getType()))
3611 return emitOpError() << "expected vector or array type";
3612
3613 // The number of elements of the attribute and the type must match.
3614 int64_t attrNumElements = elementsAttr.getNumElements();
3615 if (getNumElements(getType()) != attrNumElements) {
3616 return emitOpError()
3617 << "type and attribute have a different number of elements: "
3618 << getNumElements(getType()) << " vs. " << attrNumElements;
3619 }
3620
3621 Type attrElmType = getElementType(elementsAttr.getType());
3622 Type resultElmType = getElementType(getType());
3623 if (auto floatType = dyn_cast<FloatType>(attrElmType))
3624 return verifyFloatSemantics(floatType.getFloatSemantics(), resultElmType);
3625
3626 if (isa<IntegerType>(attrElmType) && !isa<IntegerType>(resultElmType)) {
3627 return emitOpError(
3628 "expected integer element type for integer elements attribute");
3629 }
3630 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(getValue())) {
3631
3632 // The case where the constant is LLVMStructType has already been handled.
3633 auto arrayType = dyn_cast<LLVM::LLVMArrayType>(getType());
3634 if (!arrayType)
3635 return emitOpError()
3636 << "expected array or struct type for array attribute";
3637
3638 // When the attribute is an ArrayAttr, check that its nesting matches the
3639 // corresponding ArrayType or VectorType nesting.
3640 return verifyStructArrayConstant(*this, arrayType, arrayAttr, /*dim=*/0);
3641 } else {
3642 return emitOpError()
3643 << "only supports integer, float, string or elements attributes";
3644 }
3645
3646 return success();
3647}
3648
3649bool LLVM::ConstantOp::isBuildableWith(Attribute value, Type type) {
3650 // The value's type must be the same as the provided type.
3651 auto typedAttr = dyn_cast<TypedAttr>(value);
3652 if (!typedAttr || typedAttr.getType() != type || !isCompatibleType(type))
3653 return false;
3654 // The value's type must be an LLVM compatible type.
3655 if (!isCompatibleType(type))
3656 return false;
3657 // TODO: Add support for additional attributes kinds once needed.
3658 return isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
3659}
3660
3661ConstantOp LLVM::ConstantOp::materialize(OpBuilder &builder, Attribute value,
3662 Type type, Location loc) {
3663 if (isBuildableWith(value, type))
3664 return LLVM::ConstantOp::create(builder, loc, cast<TypedAttr>(value));
3665 return nullptr;
3666}
3667
3668// Constant op constant-folds to its value.
3669OpFoldResult LLVM::ConstantOp::fold(FoldAdaptor) { return getValue(); }
3670
3671//===----------------------------------------------------------------------===//
3672// AtomicRMWOp
3673//===----------------------------------------------------------------------===//
3674
3675void AtomicRMWOp::build(OpBuilder &builder, OperationState &state,
3676 AtomicBinOp binOp, Value ptr, Value val,
3677 AtomicOrdering ordering, StringRef syncscope,
3678 unsigned alignment, bool isVolatile) {
3679 build(builder, state, val.getType(), binOp, ptr, val, ordering,
3680 !syncscope.empty() ? builder.getStringAttr(syncscope) : nullptr,
3681 alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
3682 /*access_groups=*/nullptr,
3683 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
3684}
3685
3686LogicalResult AtomicRMWOp::verify() {
3687 auto valType = getVal().getType();
3688 if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub ||
3689 getBinOp() == AtomicBinOp::fmin || getBinOp() == AtomicBinOp::fmax ||
3690 getBinOp() == AtomicBinOp::fminimum ||
3691 getBinOp() == AtomicBinOp::fmaximum ||
3692 getBinOp() == AtomicBinOp::fminimumnum ||
3693 getBinOp() == AtomicBinOp::fmaximumnum) {
3694 if (isCompatibleVectorType(valType)) {
3695 if (isScalableVectorType(valType))
3696 return emitOpError("expected LLVM IR fixed vector type");
3697 Type elemType = llvm::cast<VectorType>(valType).getElementType();
3698 if (!isCompatibleFloatingPointType(elemType))
3699 return emitOpError(
3700 "expected LLVM IR floating point type for vector element");
3701 } else if (!isCompatibleFloatingPointType(valType)) {
3702 return emitOpError("expected LLVM IR floating point type");
3703 }
3704 } else if (getBinOp() == AtomicBinOp::xchg) {
3705 DataLayout dataLayout = DataLayout::closest(*this);
3706 if (!isTypeCompatibleWithAtomicOp(valType, dataLayout))
3707 return emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
3708 } else {
3709 auto intType = llvm::dyn_cast<IntegerType>(valType);
3710 unsigned intBitWidth = intType ? intType.getWidth() : 0;
3711 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
3712 intBitWidth != 64)
3713 return emitOpError("expected LLVM IR integer type");
3714 }
3715
3716 if (static_cast<unsigned>(getOrdering()) <
3717 static_cast<unsigned>(AtomicOrdering::monotonic))
3718 return emitOpError() << "expected at least '"
3719 << stringifyAtomicOrdering(AtomicOrdering::monotonic)
3720 << "' ordering";
3721
3722 return success();
3723}
3724
3725//===----------------------------------------------------------------------===//
3726// AtomicCmpXchgOp
3727//===----------------------------------------------------------------------===//
3728
3729/// Returns an LLVM struct type that contains a value type and a boolean type.
3730static LLVMStructType getValAndBoolStructType(Type valType) {
3731 auto boolType = IntegerType::get(valType.getContext(), 1);
3732 return LLVMStructType::getLiteral(valType.getContext(), {valType, boolType});
3733}
3734
3735void AtomicCmpXchgOp::build(OpBuilder &builder, OperationState &state,
3736 Value ptr, Value cmp, Value val,
3737 AtomicOrdering successOrdering,
3738 AtomicOrdering failureOrdering, StringRef syncscope,
3739 unsigned alignment, bool isWeak, bool isVolatile) {
3740 build(builder, state, getValAndBoolStructType(val.getType()), ptr, cmp, val,
3741 successOrdering, failureOrdering,
3742 !syncscope.empty() ? builder.getStringAttr(syncscope) : nullptr,
3743 alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isWeak,
3744 isVolatile, /*access_groups=*/nullptr,
3745 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
3746}
3747
3748LogicalResult AtomicCmpXchgOp::verify() {
3749 auto ptrType = llvm::cast<LLVM::LLVMPointerType>(getPtr().getType());
3750 if (!ptrType)
3751 return emitOpError("expected LLVM IR pointer type for operand #0");
3752 auto valType = getVal().getType();
3753 DataLayout dataLayout = DataLayout::closest(*this);
3754 if (!isTypeCompatibleWithAtomicOp(valType, dataLayout))
3755 return emitOpError("unexpected LLVM IR type");
3756 if (getSuccessOrdering() < AtomicOrdering::monotonic ||
3757 getFailureOrdering() < AtomicOrdering::monotonic)
3758 return emitOpError("ordering must be at least 'monotonic'");
3759 if (getFailureOrdering() == AtomicOrdering::release ||
3760 getFailureOrdering() == AtomicOrdering::acq_rel)
3761 return emitOpError("failure ordering cannot be 'release' or 'acq_rel'");
3762 return success();
3763}
3764
3765//===----------------------------------------------------------------------===//
3766// FenceOp
3767//===----------------------------------------------------------------------===//
3768
3769void FenceOp::build(OpBuilder &builder, OperationState &state,
3770 AtomicOrdering ordering, StringRef syncscope) {
3771 build(builder, state, ordering,
3772 syncscope.empty() ? nullptr : builder.getStringAttr(syncscope));
3773}
3774
3775LogicalResult FenceOp::verify() {
3776 if (getOrdering() == AtomicOrdering::not_atomic ||
3777 getOrdering() == AtomicOrdering::unordered ||
3778 getOrdering() == AtomicOrdering::monotonic)
3779 return emitOpError("can be given only acquire, release, acq_rel, "
3780 "and seq_cst orderings");
3781 return success();
3782}
3783
3784//===----------------------------------------------------------------------===//
3785// Verifier for extension ops
3786//===----------------------------------------------------------------------===//
3787
3788/// Verifies that the given extension operation operates on consistent scalars
3789/// or vectors, and that the target width is larger than the input width.
3790template <class ExtOp>
3791static LogicalResult verifyExtOp(ExtOp op) {
3792 IntegerType inputType, outputType;
3793 if (isCompatibleVectorType(op.getArg().getType())) {
3794 if (!isCompatibleVectorType(op.getResult().getType()))
3795 return op.emitError(
3796 "input type is a vector but output type is an integer");
3797 if (getVectorNumElements(op.getArg().getType()) !=
3798 getVectorNumElements(op.getResult().getType()))
3799 return op.emitError("input and output vectors are of incompatible shape");
3800 // Because this is a CastOp, the element of vectors is guaranteed to be an
3801 // integer.
3802 inputType = cast<IntegerType>(
3803 cast<VectorType>(op.getArg().getType()).getElementType());
3804 outputType = cast<IntegerType>(
3805 cast<VectorType>(op.getResult().getType()).getElementType());
3806 } else {
3807 // Because this is a CastOp and arg is not a vector, arg is guaranteed to be
3808 // an integer.
3809 inputType = cast<IntegerType>(op.getArg().getType());
3810 outputType = dyn_cast<IntegerType>(op.getResult().getType());
3811 if (!outputType)
3812 return op.emitError(
3813 "input type is an integer but output type is a vector");
3814 }
3815
3816 if (outputType.getWidth() <= inputType.getWidth())
3817 return op.emitError("integer width of the output type is smaller or "
3818 "equal to the integer width of the input type");
3819 return success();
3820}
3821
3822//===----------------------------------------------------------------------===//
3823// ZExtOp
3824//===----------------------------------------------------------------------===//
3825
3826LogicalResult ZExtOp::verify() { return verifyExtOp<ZExtOp>(*this); }
3827
3828OpFoldResult LLVM::ZExtOp::fold(FoldAdaptor adaptor) {
3829 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
3830 if (!arg)
3831 return {};
3832
3833 size_t targetSize = cast<IntegerType>(getType()).getWidth();
3834 return IntegerAttr::get(getType(), arg.getValue().zext(targetSize));
3835}
3836
3837//===----------------------------------------------------------------------===//
3838// SExtOp
3839//===----------------------------------------------------------------------===//
3840
3841LogicalResult SExtOp::verify() { return verifyExtOp<SExtOp>(*this); }
3842
3843//===----------------------------------------------------------------------===//
3844// Folder and verifier for LLVM::BitcastOp
3845//===----------------------------------------------------------------------===//
3846
3847/// Folds a cast op that can be chained.
3848template <typename T>
3850 typename T::FoldAdaptor adaptor) {
3851 // cast(x : T0, T0) -> x
3852 if (castOp.getArg().getType() == castOp.getType())
3853 return castOp.getArg();
3854 if (auto prev = castOp.getArg().template getDefiningOp<T>()) {
3855 // cast(cast(x : T0, T1), T0) -> x
3856 if (prev.getArg().getType() == castOp.getType())
3857 return prev.getArg();
3858 // cast(cast(x : T0, T1), T2) -> cast(x: T0, T2)
3859 castOp.getArgMutable().set(prev.getArg());
3860 return Value{castOp};
3861 }
3862 return {};
3863}
3864
3865OpFoldResult LLVM::BitcastOp::fold(FoldAdaptor adaptor) {
3866 return foldChainableCast(*this, adaptor);
3867}
3868
3869LogicalResult LLVM::BitcastOp::verify() {
3870 auto resultType = llvm::dyn_cast<LLVMPointerType>(
3871 extractVectorElementType(getResult().getType()));
3872 auto sourceType = llvm::dyn_cast<LLVMPointerType>(
3873 extractVectorElementType(getArg().getType()));
3874
3875 // If one of the types is a pointer (or vector of pointers), then
3876 // both source and result type have to be pointers.
3877 if (static_cast<bool>(resultType) != static_cast<bool>(sourceType))
3878 return emitOpError("can only cast pointers from and to pointers");
3879
3880 if (!resultType)
3881 return success();
3882
3883 auto isVector = llvm::IsaPred<VectorType>;
3884
3885 // Due to bitcast requiring both operands to be of the same size, it is not
3886 // possible for only one of the two to be a pointer of vectors.
3887 if (isVector(getResult().getType()) && !isVector(getArg().getType()))
3888 return emitOpError("cannot cast pointer to vector of pointers");
3889
3890 if (!isVector(getResult().getType()) && isVector(getArg().getType()))
3891 return emitOpError("cannot cast vector of pointers to pointer");
3892
3893 // Bitcast cannot cast between pointers of different address spaces.
3894 // 'llvm.addrspacecast' must be used for this purpose instead.
3895 if (resultType.getAddressSpace() != sourceType.getAddressSpace())
3896 return emitOpError("cannot cast pointers of different address spaces, "
3897 "use 'llvm.addrspacecast' instead");
3898
3899 return success();
3900}
3901
3902LogicalResult LLVM::PtrToAddrOp::verify() {
3903 auto pointerType =
3904 cast<LLVM::LLVMPointerType>(extractVectorElementType(getArg().getType()));
3905 auto integerType = cast<IntegerType>(extractVectorElementType(getType()));
3906
3907 auto dataLayout = DataLayout::closest(*this);
3908 std::optional<unsigned> width = dataLayout.getTypeIndexBitwidth(pointerType);
3909 assert(width && "pointers always return an index bitwidth");
3910 if (width != integerType.getWidth())
3911 return emitOpError("bit-width of integer result type ")
3912 << integerType << " must match the pointer bitwidth (" << *width
3913 << ") specified in the datalayout";
3914
3915 return success();
3916}
3917
3918//===----------------------------------------------------------------------===//
3919// Folder for LLVM::AddrSpaceCastOp
3920//===----------------------------------------------------------------------===//
3921
3922OpFoldResult LLVM::AddrSpaceCastOp::fold(FoldAdaptor adaptor) {
3923 return foldChainableCast(*this, adaptor);
3924}
3925
3926Value LLVM::AddrSpaceCastOp::getViewSource() { return getArg(); }
3927
3928//===----------------------------------------------------------------------===//
3929// Folder for LLVM::GEPOp
3930//===----------------------------------------------------------------------===//
3931
/// Folds a GEP in two ways:
///   1. A single-index GEP whose only index is the constant 0, and whose base
///      type equals the result type, folds to its base.
///   2. Otherwise, dynamic indices whose folded operand is an IntegerAttr that
///      fits in `kGEPConstantBitWidth` signed bits are moved into the raw
///      constant indices. The op is updated in place and returned.
/// Returns a null result when neither applies.
OpFoldResult LLVM::GEPOp::fold(FoldAdaptor adaptor) {
  // Combined view over constant and (folded) dynamic indices; entries for
  // non-constant dynamic operands are null attributes here.
  GEPIndicesAdaptor<ArrayRef<Attribute>> indices(getRawConstantIndicesAttr(),
                                                 adaptor.getDynamicIndices());

  // gep %x:T, 0 -> %x
  if (getBase().getType() == getType() && indices.size() == 1)
    if (auto integer = llvm::dyn_cast_or_null<IntegerAttr>(indices[0]))
      if (integer.getValue().isZero())
        return getBase();

  // Canonicalize any dynamic indices of constant value to constant indices.
  bool changed = false;
  SmallVector<GEPArg> gepArgs;
  for (auto iter : llvm::enumerate(indices)) {
    auto integer = llvm::dyn_cast_or_null<IntegerAttr>(iter.value());
    // Constant indices can only be int32_t, so if integer does not fit we
    // are forced to keep it dynamic, despite being a constant.
    if (!indices.isDynamicIndex(iter.index()) || !integer ||
        !integer.getValue().isSignedIntN(kGEPConstantBitWidth)) {

      // Keep this index exactly as the op currently stores it (SSA value or
      // constant).
      PointerUnion<IntegerAttr, Value> existing = getIndices()[iter.index()];
      if (Value val = llvm::dyn_cast_if_present<Value>(existing))
        gepArgs.emplace_back(val);
      else
        gepArgs.emplace_back(cast<IntegerAttr>(existing).getInt());

      continue;
    }

    changed = true;
    gepArgs.emplace_back(integer.getInt());
  }
  if (changed) {
    // Rebuild the raw constant / dynamic index encoding and update the op in
    // place. Returning the op's own result signals an in-place fold.
    SmallVector<int32_t> rawConstantIndices;
    SmallVector<Value> dynamicIndices;
    destructureIndices(getElemType(), gepArgs, rawConstantIndices,
                       dynamicIndices);

    getDynamicIndicesMutable().assign(dynamicIndices);
    setRawConstantIndices(rawConstantIndices);
    return Value{*this};
  }

  return {};
}
3977
3978Value LLVM::GEPOp::getViewSource() { return getBase(); }
3979
3980//===----------------------------------------------------------------------===//
3981// ShlOp
3982//===----------------------------------------------------------------------===//
3983
3984OpFoldResult LLVM::ShlOp::fold(FoldAdaptor adaptor) {
3985 auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs());
3986 if (!rhs)
3987 return {};
3988
3989 if (rhs.getValue().getZExtValue() >=
3990 getLhs().getType().getIntOrFloatBitWidth())
3991 return {}; // TODO: Fold into poison.
3992
3993 auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs());
3994 if (!lhs)
3995 return {};
3996
3997 return IntegerAttr::get(getType(), lhs.getValue().shl(rhs.getValue()));
3998}
3999
4000//===----------------------------------------------------------------------===//
4001// OrOp
4002//===----------------------------------------------------------------------===//
4003
4004OpFoldResult LLVM::OrOp::fold(FoldAdaptor adaptor) {
4005 auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs());
4006 if (!lhs)
4007 return {};
4008
4009 auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs());
4010 if (!rhs)
4011 return {};
4012
4013 return IntegerAttr::get(getType(), lhs.getValue() | rhs.getValue());
4014}
4015
4016//===----------------------------------------------------------------------===//
4017// CallIntrinsicOp
4018//===----------------------------------------------------------------------===//
4019
4020LogicalResult CallIntrinsicOp::verify() {
4021 if (!getIntrin().starts_with("llvm."))
4022 return emitOpError() << "intrinsic name must start with 'llvm.'";
4023 if (failed(verifyOperandBundles(*this)))
4024 return failure();
4025 return success();
4026}
4027
4028void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
4029 mlir::StringAttr intrin, mlir::ValueRange args) {
4030 build(builder, state, /*resultTypes=*/TypeRange{}, intrin, args,
4031 FastmathFlagsAttr{},
4032 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
4033 /*res_attrs=*/{});
4034}
4035
4036void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
4037 mlir::StringAttr intrin, mlir::ValueRange args,
4038 mlir::LLVM::FastmathFlagsAttr fastMathFlags) {
4039 build(builder, state, /*resultTypes=*/TypeRange{}, intrin, args,
4040 fastMathFlags,
4041 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
4042 /*res_attrs=*/{});
4043}
4044
4045void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
4046 mlir::Type resultType, mlir::StringAttr intrin,
4047 mlir::ValueRange args) {
4048 build(builder, state, {resultType}, intrin, args, FastmathFlagsAttr{},
4049 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
4050 /*res_attrs=*/{});
4051}
4052
/// Builds an intrinsic call with arbitrary result types and fast-math flags.
/// The call carries no operand bundles and no argument or result attributes.
void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
                            mlir::TypeRange resultTypes,
                            mlir::StringAttr intrin, mlir::ValueRange args,
                            mlir::LLVM::FastmathFlagsAttr fastMathFlags) {
  build(builder, state, resultTypes, intrin, args, fastMathFlags,
        /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
        /*res_attrs=*/{});
}
4061
4062ParseResult CallIntrinsicOp::parse(OpAsmParser &parser,
4064 StringAttr intrinAttr;
4067 SmallVector<SmallVector<Type>> opBundleOperandTypes;
4068 ArrayAttr opBundleTags;
4069
4070 // Parse intrinsic name.
4072 intrinAttr, parser.getBuilder().getType<NoneType>()))
4073 return failure();
4074 result.addAttribute(CallIntrinsicOp::getIntrinAttrName(result.name),
4075 intrinAttr);
4076
4077 if (parser.parseLParen())
4078 return failure();
4079
4080 // Parse the function arguments.
4081 if (parser.parseOperandList(operands))
4082 return mlir::failure();
4083
4084 if (parser.parseRParen())
4085 return mlir::failure();
4086
4087 // Handle bundles.
4088 SMLoc opBundlesLoc = parser.getCurrentLocation();
4089 if (std::optional<ParseResult> result = parseOpBundles(
4090 parser, opBundleOperands, opBundleOperandTypes, opBundleTags);
4091 result && failed(*result))
4092 return failure();
4093 if (opBundleTags && !opBundleTags.empty())
4094 result.addAttribute(
4095 CallIntrinsicOp::getOpBundleTagsAttrName(result.name).getValue(),
4096 opBundleTags);
4097
4098 if (parser.parseOptionalAttrDict(result.attributes))
4099 return mlir::failure();
4100
4102 SmallVector<DictionaryAttr> resultAttrs;
4103 if (parseCallTypeAndResolveOperands(parser, result, /*isDirect=*/true,
4104 operands, argAttrs, resultAttrs))
4105 return failure();
4107 parser.getBuilder(), result, argAttrs, resultAttrs,
4108 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
4109
4110 if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
4111 opBundleOperandTypes,
4112 getOpBundleSizesAttrName(result.name)))
4113 return failure();
4114
4115 int32_t numOpBundleOperands = 0;
4116 for (const auto &operands : opBundleOperands)
4117 numOpBundleOperands += operands.size();
4118
4119 result.addAttribute(
4120 CallIntrinsicOp::getOperandSegmentSizeAttr(),
4122 {static_cast<int32_t>(operands.size()), numOpBundleOperands}));
4123
4124 return mlir::success();
4125}
4126
4127void CallIntrinsicOp::print(OpAsmPrinter &p) {
4128 p << ' ';
4129 p.printAttributeWithoutType(getIntrinAttr());
4130
4131 OperandRange args = getArgs();
4132 p << "(" << args << ")";
4133
4134 // Operand bundles.
4135 if (!getOpBundleOperands().empty()) {
4136 p << ' ';
4137 printOpBundles(p, *this, getOpBundleOperands(),
4138 getOpBundleOperands().getTypes(), getOpBundleTagsAttr());
4139 }
4140
4141 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
4142 {getOperandSegmentSizesAttrName(),
4143 getOpBundleSizesAttrName(), getIntrinAttrName(),
4144 getOpBundleTagsAttrName(), getArgAttrsAttrName(),
4145 getResAttrsAttrName()});
4146
4147 p << " : ";
4148
4149 // Reconstruct the MLIR function type from operand and result types.
4151 p, args.getTypes(), getArgAttrsAttr(),
4152 /*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
4153}
4154
4155//===----------------------------------------------------------------------===//
4156// LinkerOptionsOp
4157//===----------------------------------------------------------------------===//
4158
4159LogicalResult LinkerOptionsOp::verify() {
4160 if (mlir::Operation *parentOp = (*this)->getParentOp();
4161 parentOp && !satisfiesLLVMModule(parentOp))
4162 return emitOpError("must appear at the module level");
4163 return success();
4164}
4165
4166//===----------------------------------------------------------------------===//
4167// ModuleFlagsOp
4168//===----------------------------------------------------------------------===//
4169
4170LogicalResult ModuleFlagsOp::verify() {
4171 if (Operation *parentOp = (*this)->getParentOp();
4172 parentOp && !satisfiesLLVMModule(parentOp))
4173 return emitOpError("must appear at the module level");
4174 for (Attribute flag : getFlags())
4175 if (!isa<ModuleFlagAttr>(flag))
4176 return emitOpError("expected a module flag attribute");
4177 return success();
4178}
4179
4180//===----------------------------------------------------------------------===//
4181// InlineAsmOp
4182//===----------------------------------------------------------------------===//
4183
4184void InlineAsmOp::getEffects(
4186 &effects) {
4187 if (getHasSideEffects()) {
4188 effects.emplace_back(MemoryEffects::Write::get());
4189 effects.emplace_back(MemoryEffects::Read::get());
4190 }
4191}
4192
4193//===----------------------------------------------------------------------===//
4194// BlockAddressOp
4195//===----------------------------------------------------------------------===//
4196
4197LogicalResult
4198BlockAddressOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
4199 Operation *symbol = symbolTable.lookupSymbolIn(parentLLVMModule(*this),
4200 getBlockAddr().getFunction());
4201 auto function = dyn_cast_or_null<LLVMFuncOp>(symbol);
4202
4203 if (!function)
4204 return emitOpError("must reference a function defined by 'llvm.func'");
4205
4206 return success();
4207}
4208
4209LLVMFuncOp BlockAddressOp::getFunction(SymbolTableCollection &symbolTable) {
4210 return dyn_cast_or_null<LLVMFuncOp>(symbolTable.lookupSymbolIn(
4211 parentLLVMModule(*this), getBlockAddr().getFunction()));
4212}
4213
4214BlockTagOp BlockAddressOp::getBlockTagOp() {
4216 parentLLVMModule(*this), getBlockAddr().getFunction());
4217 if (!sym)
4218 return nullptr;
4219 auto funcOp = dyn_cast<LLVMFuncOp>(sym);
4220 if (!funcOp)
4221 return nullptr;
4222 BlockTagOp blockTagOp = nullptr;
4223 funcOp.walk([&](LLVM::BlockTagOp labelOp) {
4224 if (labelOp.getTag() == getBlockAddr().getTag()) {
4225 blockTagOp = labelOp;
4226 return WalkResult::interrupt();
4227 }
4228 return WalkResult::advance();
4229 });
4230 return blockTagOp;
4231}
4232
4233LogicalResult BlockAddressOp::verify() {
4234 if (!getBlockTagOp())
4235 return emitOpError(
4236 "expects an existing block label target in the referenced function");
4237
4238 return success();
4239}
4240
4241/// Fold a blockaddress operation to a dedicated blockaddress
4242/// attribute.
4243OpFoldResult BlockAddressOp::fold(FoldAdaptor) { return getBlockAddr(); }
4244
4245//===----------------------------------------------------------------------===//
4246// LLVM::IndirectBrOp
4247//===----------------------------------------------------------------------===//
4248
4249SuccessorOperands IndirectBrOp::getSuccessorOperands(unsigned index) {
4250 assert(index < getNumSuccessors() && "invalid successor index");
4251 return SuccessorOperands(getSuccOperandsMutable()[index]);
4252}
4253
4254void IndirectBrOp::build(OpBuilder &odsBuilder, OperationState &odsState,
4255 Value addr, ArrayRef<ValueRange> succOperands,
4256 BlockRange successors) {
4257 odsState.addOperands(addr);
4258 for (ValueRange range : succOperands)
4259 odsState.addOperands(range);
4260 SmallVector<int32_t> rangeSegments;
4261 for (ValueRange range : succOperands)
4262 rangeSegments.push_back(range.size());
4263 odsState.getOrAddProperties<Properties>().indbr_operand_segments =
4264 odsBuilder.getDenseI32ArrayAttr(rangeSegments);
4265 odsState.addSuccessors(successors);
4266}
4267
4269 OpAsmParser &parser, Type &flagType,
4270 SmallVectorImpl<Block *> &succOperandBlocks,
4272 SmallVectorImpl<SmallVector<Type>> &succOperandsTypes) {
4273 if (failed(parser.parseCommaSeparatedList(
4275 [&]() {
4276 Block *destination = nullptr;
4277 SmallVector<OpAsmParser::UnresolvedOperand> operands;
4278 SmallVector<Type> operandTypes;
4279
4280 if (parser.parseSuccessor(destination).failed())
4281 return failure();
4282
4283 if (succeeded(parser.parseOptionalLParen())) {
4284 if (failed(parser.parseOperandList(
4285 operands, OpAsmParser::Delimiter::None)) ||
4286 failed(parser.parseColonTypeList(operandTypes)) ||
4287 failed(parser.parseRParen()))
4288 return failure();
4289 }
4290 succOperandBlocks.push_back(destination);
4291 succOperands.emplace_back(operands);
4292 succOperandsTypes.emplace_back(operandTypes);
4293 return success();
4294 },
4295 "successor blocks")))
4296 return failure();
4297 return success();
4298}
4299
4300static void
4301printIndirectBrOpSucessors(OpAsmPrinter &p, IndirectBrOp op, Type flagType,
4302 SuccessorRange succs, OperandRangeRange succOperands,
4303 const TypeRangeRange &succOperandsTypes) {
4304 p << "[";
4305 llvm::interleave(
4306 llvm::zip(succs, succOperands),
4307 [&](auto i) {
4308 p.printNewline();
4309 p.printSuccessorAndUseList(std::get<0>(i), std::get<1>(i));
4310 },
4311 [&] { p << ','; });
4312 if (!succOperands.empty())
4313 p.printNewline();
4314 p << "]";
4315}
4316
4317//===----------------------------------------------------------------------===//
4318// SincosOp (intrinsic)
4319//===----------------------------------------------------------------------===//
4320
4321LogicalResult LLVM::SincosOp::verify() {
4322 auto operandType = getOperand().getType();
4323 auto resultType = getResult().getType();
4324 auto resultStructType =
4325 mlir::dyn_cast<mlir::LLVM::LLVMStructType>(resultType);
4326 if (!resultStructType || resultStructType.getBody().size() != 2 ||
4327 resultStructType.getBody()[0] != operandType ||
4328 resultStructType.getBody()[1] != operandType) {
4329 return emitOpError("expected result type to be an homogeneous struct with "
4330 "two elements matching the operand type, but got ")
4331 << resultType;
4332 }
4333 return success();
4334}
4335
4336//===----------------------------------------------------------------------===//
4337// AssumeOp (intrinsic)
4338//===----------------------------------------------------------------------===//
4339
4340void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
4341 mlir::Value cond) {
4342 return build(builder, state, cond, /*op_bundle_operands=*/{},
4343 /*op_bundle_tags=*/ArrayAttr{});
4344}
4345
4346void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
4347 Value cond, llvm::StringRef tag, ValueRange args) {
4348 return build(builder, state, cond, ArrayRef<ValueRange>(args),
4349 builder.getStrArrayAttr(tag));
4350}
4351
4352void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
4353 Value cond, AssumeAlignTag, Value ptr, Value align) {
4354 return build(builder, state, cond, "align", ValueRange{ptr, align});
4355}
4356
4357void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
4359 Value ptr2) {
4360 return build(builder, state, cond, "separate_storage",
4361 ValueRange{ptr1, ptr2});
4362}
4363
4364LogicalResult LLVM::AssumeOp::verify() { return verifyOperandBundles(*this); }
4365
4366//===----------------------------------------------------------------------===//
4367// masked_gather (intrinsic)
4368//===----------------------------------------------------------------------===//
4369
4370LogicalResult LLVM::masked_gather::verify() {
4371 auto ptrsVectorType = getPtrs().getType();
4372 Type expectedPtrsVectorType =
4375 // Vector of pointers type should match result vector type, other than the
4376 // element type.
4377 if (ptrsVectorType != expectedPtrsVectorType)
4378 return emitOpError("expected operand #1 type to be ")
4379 << expectedPtrsVectorType;
4380 return success();
4381}
4382
4383//===----------------------------------------------------------------------===//
4384// masked_scatter (intrinsic)
4385//===----------------------------------------------------------------------===//
4386
4387LogicalResult LLVM::masked_scatter::verify() {
4388 auto ptrsVectorType = getPtrs().getType();
4389 Type expectedPtrsVectorType =
4391 LLVM::getVectorNumElements(getValue().getType()));
4392 // Vector of pointers type should match value vector type, other than the
4393 // element type.
4394 if (ptrsVectorType != expectedPtrsVectorType)
4395 return emitOpError("expected operand #2 type to be ")
4396 << expectedPtrsVectorType;
4397 return success();
4398}
4399
4400//===----------------------------------------------------------------------===//
4401// masked_expandload (intrinsic)
4402//===----------------------------------------------------------------------===//
4403
4404void LLVM::masked_expandload::build(OpBuilder &builder, OperationState &state,
4405 mlir::TypeRange resTys, Value ptr,
4406 Value mask, Value passthru,
4407 uint64_t align) {
4408 ArrayAttr argAttrs = getLLVMAlignParamForCompressExpand(builder, true, align);
4409 build(builder, state, resTys, ptr, mask, passthru, /*arg_attrs=*/argAttrs,
4410 /*res_attrs=*/nullptr);
4411}
4412
4413//===----------------------------------------------------------------------===//
4414// masked_compressstore (intrinsic)
4415//===----------------------------------------------------------------------===//
4416
4417void LLVM::masked_compressstore::build(OpBuilder &builder,
4418 OperationState &state, Value value,
4419 Value ptr, Value mask, uint64_t align) {
4420 ArrayAttr argAttrs =
4421 getLLVMAlignParamForCompressExpand(builder, false, align);
4422 build(builder, state, value, ptr, mask, /*arg_attrs=*/argAttrs,
4423 /*res_attrs=*/nullptr);
4424}
4425
4426//===----------------------------------------------------------------------===//
4427// InlineAsmOp
4428//===----------------------------------------------------------------------===//
4429
4430LogicalResult InlineAsmOp::verify() {
4431 if (!getTailCallKindAttr())
4432 return success();
4433
4434 if (getTailCallKindAttr().getTailCallKind() == TailCallKind::MustTail)
4435 return emitOpError(
4436 "tail call kind 'musttail' is not supported by this operation");
4437
4438 return success();
4439}
4440
4441//===----------------------------------------------------------------------===//
4442// UDivOp
4443//===----------------------------------------------------------------------===//
4444Speculation::Speculatability UDivOp::getSpeculatability() {
4445 // X / 0 => UB
4446 Value divisor = getRhs();
4447 if (matchPattern(divisor, m_IntRangeWithoutZeroU()))
4449
4451}
4452
4453//===----------------------------------------------------------------------===//
4454// SDivOp
4455//===----------------------------------------------------------------------===//
4456Speculation::Speculatability SDivOp::getSpeculatability() {
4457 // This function conservatively assumes that all signed division by -1 are
4458 // not speculatable.
4459 // X / 0 => UB
4460 // INT_MIN / -1 => UB
4461 Value divisor = getRhs();
4462 if (matchPattern(divisor, m_IntRangeWithoutZeroS()) &&
4465
4467}
4468
4469//===----------------------------------------------------------------------===//
4470// LLVMDialect initialization, type parsing, and registration.
4471//===----------------------------------------------------------------------===//
4472
4473void LLVMDialect::initialize() {
4474 registerAttributes();
4475
4476 // clang-format off
4477 addTypes<LLVMVoidType,
4478 LLVMLabelType,
4479 LLVMMetadataType>();
4480 // clang-format on
4481 registerTypes();
4482
4483 addOperations<
4484#define GET_OP_LIST
4485#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
4486
4487 ,
4488#define GET_OP_LIST
4489#include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc"
4490
4491 >();
4492
4493 // Support unknown operations because not all LLVM operations are registered.
4494 allowUnknownOperations();
4495 declarePromisedInterface<DialectInlinerInterface, LLVMDialect>();
4497}
4498
4499#define GET_OP_CLASSES
4500#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
4501
4502#define GET_OP_CLASSES
4503#include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc"
4504
4505LogicalResult LLVMDialect::verifyDataLayoutString(
4506 StringRef descr, llvm::function_ref<void(const Twine &)> reportError) {
4507 llvm::Expected<llvm::DataLayout> maybeDataLayout =
4508 llvm::DataLayout::parse(descr);
4509 if (maybeDataLayout)
4510 return success();
4511
4512 std::string message;
4513 llvm::raw_string_ostream messageStream(message);
4514 llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream);
4515 reportError("invalid data layout descriptor: " + message);
4516 return failure();
4517}
4518
4519/// Verify LLVM dialect attributes.
4520LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op,
4521 NamedAttribute attr) {
4522 // If the data layout attribute is present, it must use the LLVM data layout
4523 // syntax. Try parsing it and report errors in case of failure. Users of this
4524 // attribute may assume it is well-formed and can pass it to the (asserting)
4525 // llvm::DataLayout constructor.
4526 if (attr.getName() != LLVM::LLVMDialect::getDataLayoutAttrName())
4527 return success();
4528 if (auto stringAttr = llvm::dyn_cast<StringAttr>(attr.getValue()))
4529 return verifyDataLayoutString(
4530 stringAttr.getValue(),
4531 [op](const Twine &message) { op->emitOpError() << message.str(); });
4532
4533 return op->emitOpError() << "expected '"
4534 << LLVM::LLVMDialect::getDataLayoutAttrName()
4535 << "' to be a string attributes";
4536}
4537
4538LogicalResult LLVMDialect::verifyParameterAttribute(Operation *op,
4539 Type paramType,
4540 NamedAttribute paramAttr) {
4541 // LLVM attribute may be attached to a result of operation that has not been
4542 // converted to LLVM dialect yet, so the result may have a type with unknown
4543 // representation in LLVM dialect type space. In this case we cannot verify
4544 // whether the attribute may be
4545 bool verifyValueType = isCompatibleType(paramType);
4546 StringAttr name = paramAttr.getName();
4547
4548 auto checkUnitAttrType = [&]() -> LogicalResult {
4549 if (!llvm::isa<UnitAttr>(paramAttr.getValue()))
4550 return op->emitError() << name << " should be a unit attribute";
4551 return success();
4552 };
4553 auto checkTypeAttrType = [&]() -> LogicalResult {
4554 if (!llvm::isa<TypeAttr>(paramAttr.getValue()))
4555 return op->emitError() << name << " should be a type attribute";
4556 return success();
4557 };
4558 auto checkIntegerAttrType = [&]() -> LogicalResult {
4559 if (!llvm::isa<IntegerAttr>(paramAttr.getValue()))
4560 return op->emitError() << name << " should be an integer attribute";
4561 return success();
4562 };
4563 auto checkPointerType = [&]() -> LogicalResult {
4564 if (!llvm::isa<LLVMPointerType>(paramType))
4565 return op->emitError()
4566 << name << " attribute attached to non-pointer LLVM type";
4567 return success();
4568 };
4569 auto checkIntegerType = [&]() -> LogicalResult {
4570 if (!llvm::isa<IntegerType>(paramType))
4571 return op->emitError()
4572 << name << " attribute attached to non-integer LLVM type";
4573 return success();
4574 };
4575 auto checkPointerTypeMatches = [&]() -> LogicalResult {
4576 if (failed(checkPointerType()))
4577 return failure();
4578
4579 return success();
4580 };
4581
4582 // Check a unit attribute that is attached to a pointer value.
4583 if (name == LLVMDialect::getNoAliasAttrName() ||
4584 name == LLVMDialect::getReadonlyAttrName() ||
4585 name == LLVMDialect::getReadnoneAttrName() ||
4586 name == LLVMDialect::getWriteOnlyAttrName() ||
4587 name == LLVMDialect::getNestAttrName() ||
4588 name == LLVMDialect::getNoCaptureAttrName() ||
4589 name == LLVMDialect::getNoFreeAttrName() ||
4590 name == LLVMDialect::getNonNullAttrName()) {
4591 if (failed(checkUnitAttrType()))
4592 return failure();
4593 if (verifyValueType && failed(checkPointerType()))
4594 return failure();
4595 return success();
4596 }
4597
4598 // Check a type attribute that is attached to a pointer value.
4599 if (name == LLVMDialect::getStructRetAttrName() ||
4600 name == LLVMDialect::getByValAttrName() ||
4601 name == LLVMDialect::getByRefAttrName() ||
4602 name == LLVMDialect::getElementTypeAttrName() ||
4603 name == LLVMDialect::getInAllocaAttrName() ||
4604 name == LLVMDialect::getPreallocatedAttrName()) {
4605 if (failed(checkTypeAttrType()))
4606 return failure();
4607 if (verifyValueType && failed(checkPointerTypeMatches()))
4608 return failure();
4609 return success();
4610 }
4611
4612 // Check a unit attribute that is attached to an integer value.
4613 if (name == LLVMDialect::getSExtAttrName() ||
4614 name == LLVMDialect::getZExtAttrName()) {
4615 if (failed(checkUnitAttrType()))
4616 return failure();
4617 if (verifyValueType && failed(checkIntegerType()))
4618 return failure();
4619 return success();
4620 }
4621
4622 // Check an integer attribute that is attached to a pointer value.
4623 if (name == LLVMDialect::getAlignAttrName() ||
4624 name == LLVMDialect::getDereferenceableAttrName() ||
4625 name == LLVMDialect::getDereferenceableOrNullAttrName()) {
4626 if (failed(checkIntegerAttrType()))
4627 return failure();
4628 if (verifyValueType && failed(checkPointerType()))
4629 return failure();
4630 return success();
4631 }
4632
4633 // Check an integer attribute that is attached to a pointer value.
4634 if (name == LLVMDialect::getStackAlignmentAttrName()) {
4635 if (failed(checkIntegerAttrType()))
4636 return failure();
4637 return success();
4638 }
4639
4640 // Check a unit attribute that can be attached to arbitrary types.
4641 if (name == LLVMDialect::getNoUndefAttrName() ||
4642 name == LLVMDialect::getInRegAttrName() ||
4643 name == LLVMDialect::getReturnedAttrName())
4644 return checkUnitAttrType();
4645
4646 return success();
4647}
4648
4649/// Verify LLVMIR function argument attributes.
4650LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op,
4651 unsigned regionIdx,
4652 unsigned argIdx,
4653 NamedAttribute argAttr) {
4654 auto funcOp = dyn_cast<FunctionOpInterface>(op);
4655 if (!funcOp)
4656 return success();
4657 Type argType = funcOp.getArgumentTypes()[argIdx];
4658
4659 return verifyParameterAttribute(op, argType, argAttr);
4660}
4661
4662LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op,
4663 unsigned regionIdx,
4664 unsigned resIdx,
4665 NamedAttribute resAttr) {
4666 auto funcOp = dyn_cast<FunctionOpInterface>(op);
4667 if (!funcOp)
4668 return success();
4669 Type resType = funcOp.getResultTypes()[resIdx];
4670
4671 // Check to see if this function has a void return with a result attribute
4672 // to it. It isn't clear what semantics we would assign to that.
4673 if (llvm::isa<LLVMVoidType>(resType))
4674 return op->emitError() << "cannot attach result attributes to functions "
4675 "with a void return";
4676
4677 // Check to see if this attribute is allowed as a result attribute. Only
4678 // explicitly forbidden LLVM attributes will cause an error.
4679 auto name = resAttr.getName();
4680 if (name == LLVMDialect::getAllocAlignAttrName() ||
4681 name == LLVMDialect::getAllocatedPointerAttrName() ||
4682 name == LLVMDialect::getByValAttrName() ||
4683 name == LLVMDialect::getByRefAttrName() ||
4684 name == LLVMDialect::getInAllocaAttrName() ||
4685 name == LLVMDialect::getNestAttrName() ||
4686 name == LLVMDialect::getNoCaptureAttrName() ||
4687 name == LLVMDialect::getNoFreeAttrName() ||
4688 name == LLVMDialect::getPreallocatedAttrName() ||
4689 name == LLVMDialect::getReadnoneAttrName() ||
4690 name == LLVMDialect::getReadonlyAttrName() ||
4691 name == LLVMDialect::getReturnedAttrName() ||
4692 name == LLVMDialect::getStackAlignmentAttrName() ||
4693 name == LLVMDialect::getStructRetAttrName() ||
4694 name == LLVMDialect::getWriteOnlyAttrName())
4695 return op->emitError() << name << " is not a valid result attribute";
4696 return verifyParameterAttribute(op, resType, resAttr);
4697}
4698
4699Operation *LLVMDialect::materializeConstant(OpBuilder &builder, Attribute value,
4700 Type type, Location loc) {
4701 // If this was folded from an operation other than llvm.mlir.constant, it
4702 // should be materialized as such. Note that an llvm.mlir.zero may fold into
4703 // a builtin zero attribute and thus will materialize as a llvm.mlir.constant.
4704 if (auto symbol = dyn_cast<FlatSymbolRefAttr>(value))
4705 if (isa<LLVM::LLVMPointerType>(type))
4706 return LLVM::AddressOfOp::create(builder, loc, type, symbol);
4707 if (isa<LLVM::UndefAttr>(value))
4708 return LLVM::UndefOp::create(builder, loc, type);
4709 if (isa<LLVM::PoisonAttr>(value))
4710 return LLVM::PoisonOp::create(builder, loc, type);
4711 if (isa<LLVM::ZeroAttr>(value))
4712 return LLVM::ZeroOp::create(builder, loc, type);
4713 if (isa<LLVM::MDStringAttr, LLVM::MDConstantAttr, LLVM::MDFuncAttr,
4714 LLVM::MDNodeAttr>(value))
4715 if (isa<LLVM::LLVMMetadataType>(type))
4716 return LLVM::MetadataAsValueOp::create(builder, loc, type, value);
4717 // Otherwise try materializing it as a regular llvm.mlir.constant op.
4718 return LLVM::ConstantOp::materialize(builder, value, type, loc);
4719}
4720
4721//===----------------------------------------------------------------------===//
4722// Utility functions.
4723//===----------------------------------------------------------------------===//
4724
4726 StringRef name, StringRef value,
4727 LLVM::Linkage linkage) {
4728 assert(builder.getInsertionBlock() &&
4729 builder.getInsertionBlock()->getParentOp() &&
4730 "expected builder to point to a block constrained in an op");
4731 auto module =
4732 builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>();
4733 assert(module && "builder points to an op outside of a module");
4734
4735 // Create the global at the entry of the module.
4736 OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener());
4737 MLIRContext *ctx = builder.getContext();
4738 auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size());
4739 auto global = LLVM::GlobalOp::create(
4740 moduleBuilder, loc, type, /*isConstant=*/true, linkage, name,
4741 builder.getStringAttr(value), /*alignment=*/0);
4742
4743 LLVMPointerType ptrType = LLVMPointerType::get(ctx);
4744 // Get the pointer to the first character in the global string.
4745 Value globalPtr =
4746 LLVM::AddressOfOp::create(builder, loc, ptrType, global.getSymNameAttr());
4747 return LLVM::GEPOp::create(builder, loc, ptrType, type, globalPtr,
4748 ArrayRef<GEPArg>{0, 0});
4749}
4750
4755
4757 Operation *module = op->getParentOp();
4758 while (module && !satisfiesLLVMModule(module))
4759 module = module->getParentOp();
4760 assert(module && "unexpected operation outside of a module");
4761 return module;
4762}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
lhs
static int parseOptionalKeywordAlternative(OpAsmParser &parser, ArrayRef< StringRef > keywords)
static ArrayAttr getLLVMAlignParamForCompressExpand(OpBuilder &builder, bool isExpandLoad, uint64_t alignment=1)
static LogicalResult verifyAtomicMemOp(OpTy memOp, Type valueType, ArrayRef< AtomicOrdering > unsupportedOrderings)
Verifies the attributes and the type of atomic memory access operations.
static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser, EnumTy defaultValue)
Parse an enum from the keyword, or default to the provided default value.
static LogicalResult checkGlobalXtorData(Operation *op, ArrayAttr data)
static ParseResult parseGEPIndices(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &indices, DenseI32ArrayAttr &rawConstantIndices)
static LogicalResult verifyOperandBundles(OpType &op)
static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result)
static void printOneOpBundle(OpAsmPrinter &p, OperandRange operands, TypeRange operandTypes, StringRef tag)
static LogicalResult verifyComdat(Operation *op, std::optional< SymbolRefAttr > attr)
static LLVMFunctionType getLLVMFuncType(MLIRContext *context, TypeRange results, ValueRange args)
Constructs a LLVMFunctionType from MLIR results and args.
static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes)
static ParseResult parseSwitchOpCases(OpAsmParser &parser, Type flagType, DenseIntElementsAttr &caseValues, SmallVectorImpl< Block * > &caseDestinations, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand > > &caseOperands, SmallVectorImpl< SmallVector< Type > > &caseOperandTypes)
<cases> ::= `[` (case (`,` case )* )? `]`
static LogicalResult verifyCallOpVarCalleeType(OpTy callOp)
Verify that the parameter and return types of the variadic callee type match the callOp argument and result types.
static ParseResult parseOptionalCallFuncPtr(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands)
Parses an optional function pointer operand before the call argument list for indirect calls, or stops parsing at the function identifier otherwise.
static bool isZeroAttribute(Attribute value)
static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp, OperandRange indices, DenseI32ArrayAttr rawConstantIndices)
static std::optional< ParseResult > parseOpBundles(OpAsmParser &p, SmallVector< SmallVector< OpAsmParser::UnresolvedOperand > > &opBundleOperands, SmallVector< SmallVector< Type > > &opBundleOperandTypes, ArrayAttr &opBundleTags)
static LLVMStructType getValAndBoolStructType(Type valType)
Returns an LLVM struct type that contains a value type and a boolean type.
static void printOpBundles(OpAsmPrinter &p, Operation *op, OperandRangeRange opBundleOperands, TypeRangeRange opBundleOperandTypes, std::optional< ArrayAttr > opBundleTags)
static void printShuffleType(AsmPrinter &printer, Operation *op, Type v1Type, Type resType, DenseI32ArrayAttr mask)
Nothing to do when the result type is inferred.
static LogicalResult verifyBlockTags(LLVMFuncOp funcOp)
static Type buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef< Type > inputs, ArrayRef< Type > outputs, function_interface_impl::VariadicFlag variadicFlag)
static auto processFMFAttr(ArrayRef< NamedAttribute > attrs)
static TypeAttr getCallOpVarCalleeType(LLVMFunctionType calleeType)
Gets the variadic callee type for a LLVMFunctionType.
static Type getInsertExtractValueElementType(function_ref< InFlightDiagnostic(StringRef)> emitError, Type containerType, ArrayRef< int64_t > position)
Extract the type at position in the LLVM IR aggregate type containerType.
static ParseResult parseOneOpBundle(OpAsmParser &p, SmallVector< SmallVector< OpAsmParser::UnresolvedOperand > > &opBundleOperands, SmallVector< SmallVector< Type > > &opBundleOperandTypes, SmallVector< Attribute > &opBundleTags)
static Type getElementType(Type type)
Determine the element type of type.
static void printIndirectBrOpSucessors(OpAsmPrinter &p, IndirectBrOp op, Type flagType, SuccessorRange succs, OperandRangeRange succOperands, const TypeRangeRange &succOperandsTypes)
static ParseResult resolveOpBundleOperands(OpAsmParser &parser, SMLoc loc, OperationState &state, ArrayRef< SmallVector< OpAsmParser::UnresolvedOperand > > opBundleOperands, ArrayRef< SmallVector< Type > > opBundleOperandTypes, StringAttr opBundleSizesAttrName)
static void printLLVMLinkage(OpAsmPrinter &p, Operation *, LinkageAttr val)
static LogicalResult verifyStructArrayConstant(LLVM::ConstantOp op, LLVM::LLVMArrayType arrayType, ArrayAttr arrayAttr, int dim)
Verifies the constant array represented by arrayAttr matches the provided arrayType.
static ParseResult parseCallTypeAndResolveOperands(OpAsmParser &parser, OperationState &result, bool isDirect, ArrayRef< OpAsmParser::UnresolvedOperand > operands, SmallVectorImpl< DictionaryAttr > &argAttrs, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses the type of a call operation and resolves the operands if the parsing succeeds.
static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol, Operation *op, SymbolTableCollection &symbolTable)
Verifies symbol's use in op to ensure the symbol is a valid and fully defined llvm.func.
static Type extractVectorElementType(Type type)
Returns the elemental type of any LLVM-compatible vector type or self.
static bool hasScalableVectorType(Type t)
Check if the given type is a scalable vector type or a vector/array type that contains a nested scalable vector type.
static SmallVector< Type, 1 > getCallOpResultTypes(LLVMFunctionType calleeType)
Gets the MLIR Op-like result types of a LLVMFunctionType.
static OpFoldResult foldChainableCast(T castOp, typename T::FoldAdaptor adaptor)
Folds a cast op that can be chained.
static void destructureIndices(Type currType, ArrayRef< GEPArg > indices, SmallVectorImpl< int32_t > &rawConstantIndices, SmallVectorImpl< Value > &dynamicIndices)
Destructures the 'indices' parameter into 'rawConstantIndices' and 'dynamicIndices',...
static ParseResult parseCommonGlobalAndAlias(OpAsmParser &parser, OperationState &result)
Parse common attributes that might show up in the same order in both GlobalOp and AliasOp.
static Type getI1SameShape(Type type)
Returns a boolean type that has the same shape as type.
static void printCommonGlobalAndAlias(OpAsmPrinter &p, OpType op)
static ParseResult parseLLVMLinkage(OpAsmParser &p, LinkageAttr &val)
static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value)
Returns a scalar or vector boolean attribute of the given type.
static LogicalResult verifyCallOpDebugInfo(CallOp callOp, LLVMFuncOp callee)
Verify that an inlinable callsite of a debug-info-bearing function in a debug-info-bearing function has a debug location attached to it.
static ParseResult parseShuffleType(AsmParser &parser, Type v1Type, Type &resType, DenseI32ArrayAttr mask)
Build the result type of a shuffle vector operation.
static LogicalResult verifyExtOp(ExtOp op)
Verifies that the given extension operation operates on consistent scalars or vectors, and that the target width is larger than the input width.
static constexpr const char kElemTypeAttrName[]
static ParseResult parseInsertExtractValueElementType(AsmParser &parser, Type &valueType, Type containerType, DenseI64ArrayAttr position)
Infer the value type from the container type and position.
static LogicalResult verifyStructIndices(Type baseGEPType, unsigned indexPos, GEPIndicesAdaptor< ValueRange > indices, function_ref< InFlightDiagnostic()> emitOpError)
For the given indices, check if they comply with baseGEPType, especially check against LLVMStructTypes nested within.
static Attribute extractElementAt(Attribute attr, size_t index)
Extracts the element at the given index from an attribute.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static void printInsertExtractValueElementType(AsmPrinter &printer, Operation *op, Type valueType, Type containerType, DenseI64ArrayAttr position)
Nothing to print for an inferred type.
static ParseResult parseIndirectBrOpSucessors(OpAsmParser &parser, Type &flagType, SmallVectorImpl< Block * > &succOperandBlocks, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand > > &succOperands, SmallVectorImpl< SmallVector< Type > > &succOperandsTypes)
#define REGISTER_ENUM_TYPE(Ty)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
static std::string diag(const llvm::Value &value)
This base class exposes generic asm parser hooks, usable across the various derived parsers.
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
@ Square
Square brackets surrounding zero or more operands.
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attributes.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalColonTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional colon followed by a type list, which if present must have at least one type.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseCustomAttributeWithFallback(Attribute &result, Type type, function_ref< ParseResult(Attribute &result, Type type)> parseAttribute)=0
Parse a custom attribute with the provided callback, unless the next token is #, in which case the generic parser is invoked.
ParseResult parseString(std::string *string)
Parse a quoted string token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalRSquare()=0
Parse a ] token if present.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e. a form representable by a SymbolRefAttr.
virtual void printString(StringRef string)
Print the given string as a quoted string, escaping any special or non-printable characters in it.
virtual void printAttribute(Attribute attr)
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
Attributes are known-constant values of operations.
Definition Attributes.h:25
MLIRContext * getContext() const
Return the context this attribute belongs to.
This class provides an abstraction over the different types of ranges over Blocks.
Block represents an ordered list of Operations.
Definition Block.h:33
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
UnitAttr getUnitAttr()
Definition Builders.cpp:102
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:204
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:267
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:329
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:271
MLIRContext * getContext() const
Definition Builders.h:56
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
Definition Builders.cpp:108
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition Builders.cpp:98
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
Definition Builders.cpp:311
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:100
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
std::optional< uint64_t > getTypeIndexBitwidth(Type t) const
Returns the bitwidth that should be used when performing index computations for the given pointer-lik...
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
A symbol reference with a reference path containing a single element.
StringRef getValue() const
Returns the name of the held symbol reference.
StringAttr getAttr() const
Returns the name of the held symbol reference as a StringAttr.
This class represents a fused location whose metadata is known to be an instance of the given type.
Definition Location.h:149
This class represents a diagnostic that is inflight and set to be reported.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
Class used for building a 'llvm.getelementptr'.
Definition LLVMDialect.h:71
Class used for convenient access and iteration over GEP indices.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:119
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseSuccessor(Block *&dest)=0
Parse a single operation successor.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseSuccessorAndUseList(Block *&dest, SmallVectorImpl< Value > &operands)=0
Parse a single operation successor and its operand list.
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printSuccessorAndUseList(Block *successor, ValueRange succOperands)=0
Print the successor and its operands.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:435
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition Builders.h:322
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition Builders.h:444
This class represents a single result from folding an operation.
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
This class represents a contiguous range of operand ranges, e.g.
Definition ValueRange.h:85
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
type_range getTypes() const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:774
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:251
OperandRange operand_range
Definition Operation.h:396
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & emplaceBlock()
Definition Region.h:46
iterator_range< OpIterator > getOps()
Definition Region.h:172
bool empty()
Definition Region.h:60
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
This class represents a specific instance of an effect.
This class models how operands are forwarded to block arguments in control flow.
This class implements the successor iterators for Block.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
virtual Operation * lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol)
Look up a symbol with the specified name within the specified symbol table operation,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction for a range of TypeRange.
Definition TypeRange.h:107
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
Definition Types.cpp:106
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
A named class for passing around the variadic flag.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition CallGraph.h:227
void addBytecodeInterface(LLVMDialect *dialect)
Add the interfaces necessary for encoding the LLVM dialect components in bytecode.
Value createGlobalString(Location loc, OpBuilder &builder, StringRef name, StringRef value, Linkage linkage)
Create an LLVM global containing the string "value" at the module containing surrounding the insertio...
Operation * parentLLVMModule(Operation *op)
Lookup parent Module satisfying LLVM conditions on the Module Operation.
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
bool isScalableVectorType(Type vectorType)
Returns whether a vector type is scalable or not.
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
bool isCompatibleOuterType(Type type)
Returns true if the given outer type is compatible with the LLVM dialect without checking its potenti...
bool satisfiesLLVMModule(Operation *op)
LLVM requires some operations to be inside of a Module operation.
constexpr int kGEPConstantBitWidth
Bit-width of a 'GEPConstantIndex' within GEPArg.
Definition LLVMDialect.h:62
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
bool isTypeCompatibleWithAtomicOp(Type type, const DataLayout &dataLayout)
Returns true if the given type is supported by atomic operations.
bool isCompatibleFloatingPointType(Type type)
Returns true if the given type is a floating-point type compatible with the LLVM dialect.
llvm::ElementCount getVectorNumElements(Type type)
Returns the element count of any LLVM-compatible vector type.
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
void printFunctionSignature(OpAsmPrinter &p, TypeRange argTypes, ArrayAttr argAttrs, bool isVariadic, TypeRange resultTypes, ArrayAttr resultAttrs, Region *body=nullptr, bool printEmptyResult=true)
Print a function signature for a call or callable operation.
ParseResult parseFunctionSignature(OpAsmParser &parser, SmallVectorImpl< Type > &argTypes, SmallVectorImpl< DictionaryAttr > &argAttrs, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs, bool mustParseEmptyResult=true)
Parses a function signature using parser.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition Visitors.h:102
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition Matchers.h:471
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition Matchers.h:462
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or a unsigned integer range that does...
Definition Matchers.h:455
A callable is either a symbol, or an SSA value, that is referenced by a call-like operation.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) the properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addSuccessors(Block *successor)
Adds a successor to the operation sate. successor must not be null.