MLIR 22.0.0git
LLVMDialect.cpp
Go to the documentation of this file.
1//===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the types and operation details for the LLVM IR dialect in
10// MLIR, as well as the dialect itself. It also registers the dialect.
11//
12//===----------------------------------------------------------------------===//
13
17#include "mlir/IR/Attributes.h"
18#include "mlir/IR/Builders.h"
19#include "mlir/IR/BuiltinOps.h"
22#include "mlir/IR/MLIRContext.h"
23#include "mlir/IR/Matchers.h"
26
27#include "llvm/ADT/APFloat.h"
28#include "llvm/ADT/TypeSwitch.h"
29#include "llvm/IR/DataLayout.h"
30#include "llvm/Support/Error.h"
31
32#include "LLVMDialectBytecode.h"
33
34#include <numeric>
35#include <optional>
36
37using namespace mlir;
38using namespace mlir::LLVM;
39using mlir::LLVM::cconv::getMaxEnumValForCConv;
40using mlir::LLVM::linkage::getMaxEnumValForLinkage;
41using mlir::LLVM::tailcallkind::getMaxEnumValForTailCallKind;
42
43#include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
44
45//===----------------------------------------------------------------------===//
46// Attribute Helpers
47//===----------------------------------------------------------------------===//
48
/// Name of the attribute that carries an operation's element type. In this
/// file it is attached by the `llvm.alloca` parser and elided by the
/// `llvm.alloca` printer.
49static constexpr const char kElemTypeAttrName[] = "elem_type";
50
53 llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
 // Keep every attribute except a `fastmathFlags` attribute whose value is
 // equal to the attribute built from an empty `{}` flag value.
54 if (attr.getName() == "fastmathFlags") {
55 auto defAttr =
56 FastmathFlagsAttr::get(attr.getValue().getContext(), {});
57 return defAttr != attr.getValue();
58 }
59 return true;
60 }));
61 return filteredAttrs;
62}
63
64/// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and
65/// fully defined llvm.func.
66static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol,
67 Operation *op,
68 SymbolTableCollection &symbolTable) {
69 StringRef name = symbol.getValue();
70 auto func =
71 symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(op, symbol.getAttr());
72 if (!func)
73 return op->emitOpError("'")
74 << name << "' does not reference a valid LLVM function";
75 if (func.isExternal())
76 return op->emitOpError("'") << name << "' does not have a definition";
77 return success();
78}
79
80/// Returns a boolean type that has the same shape as `type`. It supports both
81/// fixed size vectors as well as scalable vectors.
82static Type getI1SameShape(Type type) {
 // 1-bit integer type created in the same context as `type`.
83 Type i1Type = IntegerType::get(type.getContext(), 1);
 // Fall-through result: the plain i1 type.
86 return i1Type;
87}
88
89// Parses one of the keywords provided in the list `keywords` and returns the
90// position of the parsed keyword in the list. If none of the keywords from the
91// list is parsed, returns -1.
93 ArrayRef<StringRef> keywords) {
 // Keywords are tried in list order; the first one the parser accepts wins.
94 for (const auto &en : llvm::enumerate(keywords)) {
95 if (succeeded(parser.parseOptionalKeyword(en.value())))
96 return en.index();
97 }
98 return -1;
99}
100
101namespace {
102template <typename Ty>
103struct EnumTraits {};
104
105#define REGISTER_ENUM_TYPE(Ty) \
106 template <> \
107 struct EnumTraits<Ty> { \
108 static StringRef stringify(Ty value) { return stringify##Ty(value); } \
109 static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \
110 }
111
112REGISTER_ENUM_TYPE(Linkage);
113REGISTER_ENUM_TYPE(UnnamedAddr);
114REGISTER_ENUM_TYPE(CConv);
115REGISTER_ENUM_TYPE(TailCallKind);
116REGISTER_ENUM_TYPE(Visibility);
117} // namespace
118
119/// Parse an enum from the keyword, or default to the provided default value.
120/// The return type is the enum type by default, unless overridden with the
121/// second template argument.
122template <typename EnumTy, typename RetTy = EnumTy>
124 EnumTy defaultValue) {
 // Gather the keyword spelling of every enumerator value from 0 through
 // `getMaxEnumVal()` inclusive, so that a keyword's position in `names`
 // equals its numeric enum value.
126 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
127 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
128
129 int index = parseOptionalKeywordAlternative(parser, names);
 // -1 means that none of the keywords was parsed: use `defaultValue`.
130 if (index == -1)
131 return static_cast<RetTy>(defaultValue);
132 return static_cast<RetTy>(index);
133}
134
135static void printLLVMLinkage(OpAsmPrinter &p, Operation *, LinkageAttr val) {
136 p << stringifyLinkage(val.getLinkage());
137}
138
139static ParseResult parseLLVMLinkage(OpAsmParser &p, LinkageAttr &val) {
140 val = LinkageAttr::get(
141 p.getContext(),
142 parseOptionalLLVMKeyword<LLVM::Linkage>(p, LLVM::Linkage::External));
143 return success();
144}
145
147 bool isExpandLoad,
148 uint64_t alignment = 1) {
149 // From
150 // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
151 // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
152 //
153 // The pointer alignment defaults to 1.
 // For that default a null attribute is returned instead of an array.
154 if (alignment == 1) {
155 return nullptr;
156 }
157
 // The result holds three dictionaries: empty ones for the positions that
 // carry no attribute and one that holds only the alignment.
158 auto emptyDictAttr = builder.getDictionaryAttr({});
159 auto alignmentAttr = builder.getI64IntegerAttr(alignment);
160 auto namedAttr =
161 builder.getNamedAttr(LLVMDialect::getAlignAttrName(), alignmentAttr);
162 SmallVector<mlir::NamedAttribute> attrs = {namedAttr};
163 auto alignDictAttr = builder.getDictionaryAttr(attrs);
164 // From
165 // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
166 // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
167 //
168 // The align parameter attribute can be provided for [expandload]'s first
169 // argument. The align parameter attribute can be provided for
170 // [compressstore]'s second argument.
171 int pos = isExpandLoad ? 0 : 1;
 // Place the alignment dictionary at index `pos`; the other two stay empty.
172 return pos == 0 ? builder.getArrayAttr(
173 {alignDictAttr, emptyDictAttr, emptyDictAttr})
174 : builder.getArrayAttr(
175 {emptyDictAttr, alignDictAttr, emptyDictAttr});
176}
177
178//===----------------------------------------------------------------------===//
179// Operand bundle helpers.
180//===----------------------------------------------------------------------===//
181
183 TypeRange operandTypes, StringRef tag) {
 // Emits the tag followed by `(operands : types)`; the parentheses are left
 // empty when the bundle has no operands.
184 p.printString(tag);
185 p << "(";
186
187 if (!operands.empty()) {
188 p.printOperands(operands);
189 p << " : ";
190 llvm::interleaveComma(operandTypes, p);
191 }
192
193 p << ")";
194}
195
197 OperandRangeRange opBundleOperands,
198 TypeRangeRange opBundleOperandTypes,
199 std::optional<ArrayAttr> opBundleTags) {
 // Nothing is printed when there are no operand bundles.
200 if (opBundleOperands.empty())
201 return;
202 assert(opBundleTags && "expect operand bundle tags");
203
 // Bundles are printed as a comma-separated list inside square brackets.
 // Operands, operand types and tags are walked in lock-step, and every tag
 // is expected to be a StringAttr.
204 p << "[";
205 llvm::interleaveComma(
206 llvm::zip(opBundleOperands, opBundleOperandTypes, *opBundleTags), p,
207 [&p](auto bundle) {
208 auto bundleTag = cast<StringAttr>(std::get<2>(bundle)).getValue();
209 printOneOpBundle(p, std::get<0>(bundle), std::get<1>(bundle),
210 bundleTag);
211 });
212 p << "]";
213}
214
/// Parses a single operand bundle: a string tag followed by a parenthesized,
/// possibly empty, `operands : types` list. On success the parsed operands,
/// types and tag are appended to the three output vectors.
215static ParseResult parseOneOpBundle(
216 OpAsmParser &p,
218 SmallVector<SmallVector<Type>> &opBundleOperandTypes,
219 SmallVector<Attribute> &opBundleTags) {
 // Remember where the bundle starts so that a missing tag is reported there.
220 SMLoc currentParserLoc = p.getCurrentLocation();
222 SmallVector<Type> types;
223 std::string tag;
224
225 if (p.parseString(&tag))
226 return p.emitError(currentParserLoc, "expect operand bundle tag");
227
228 if (p.parseLParen())
229 return failure();
230
 // If no `)` follows immediately, the bundle has operands: parse
 // `operands : types )`.
231 if (p.parseOptionalRParen()) {
232 if (p.parseOperandList(operands) || p.parseColon() ||
233 p.parseTypeList(types) || p.parseRParen())
234 return failure();
235 }
236
237 opBundleOperands.push_back(std::move(operands));
238 opBundleOperandTypes.push_back(std::move(types));
239 opBundleTags.push_back(StringAttr::get(p.getContext(), tag));
240
241 return success();
242}
243
/// Parses an optional square-bracketed, comma-separated list of operand
/// bundles. Returns std::nullopt when no `[` is present, and a ParseResult
/// otherwise.
244static std::optional<ParseResult> parseOpBundles(
245 OpAsmParser &p,
247 SmallVector<SmallVector<Type>> &opBundleOperandTypes,
248 ArrayAttr &opBundleTags) {
249 if (p.parseOptionalLSquare())
250 return std::nullopt;
251
 // `[]` is accepted as an empty list; `opBundleTags` is not assigned in that
 // case.
252 if (succeeded(p.parseOptionalRSquare()))
253 return success();
254
255 SmallVector<Attribute> opBundleTagAttrs;
256 auto bundleParser = [&] {
257 return parseOneOpBundle(p, opBundleOperands, opBundleOperandTypes,
258 opBundleTagAttrs);
259 };
260 if (p.parseCommaSeparatedList(bundleParser))
261 return failure();
262
263 if (p.parseRSquare())
264 return failure();
265
 // The collected tags are returned as a single array attribute.
266 opBundleTags = ArrayAttr::get(p.getContext(), opBundleTagAttrs);
267
268 return success();
269}
270
271//===----------------------------------------------------------------------===//
272// Printing, parsing, folding and builder for LLVM::CmpOp.
273//===----------------------------------------------------------------------===//
274
275void ICmpOp::print(OpAsmPrinter &p) {
276 p << " \"" << stringifyICmpPredicate(getPredicate()) << "\" " << getOperand(0)
277 << ", " << getOperand(1);
278 p.printOptionalAttrDict((*this)->getAttrs(), {"predicate"});
279 p << " : " << getLhs().getType();
280}
281
282void FCmpOp::print(OpAsmPrinter &p) {
283 p << " \"" << stringifyFCmpPredicate(getPredicate()) << "\" " << getOperand(0)
284 << ", " << getOperand(1);
285 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"predicate"});
286 p << " : " << getLhs().getType();
287}
288
289// <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use
290// attribute-dict? `:` type
291// <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use
292// attribute-dict? `:` type
//
// Shared parser for both comparison ops. `CmpPredicateType` selects which
// predicate enum the leading string literal is symbolized against:
// ICmpPredicate when it is that type, FCmpPredicate otherwise.
293template <typename CmpPredicateType>
294static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
295 StringAttr predicateAttr;
297 Type type;
298 SMLoc predicateLoc, trailingTypeLoc;
 // The predicate is first stored as a string attribute named `predicate`.
 // Both operands are resolved against the single trailing type.
299 if (parser.getCurrentLocation(&predicateLoc) ||
300 parser.parseAttribute(predicateAttr, "predicate", result.attributes) ||
301 parser.parseOperand(lhs) || parser.parseComma() ||
302 parser.parseOperand(rhs) ||
303 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
304 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
305 parser.resolveOperand(lhs, type, result.operands) ||
306 parser.resolveOperand(rhs, type, result.operands))
307 return failure();
308
309 // Replace the string attribute `predicate` with an integer attribute.
310 int64_t predicateValue = 0;
311 if (std::is_same<CmpPredicateType, ICmpPredicate>()) {
312 std::optional<ICmpPredicate> predicate =
313 symbolizeICmpPredicate(predicateAttr.getValue());
314 if (!predicate)
315 return parser.emitError(predicateLoc)
316 << "'" << predicateAttr.getValue()
317 << "' is an incorrect value of the 'predicate' attribute";
318 predicateValue = static_cast<int64_t>(*predicate);
319 } else {
320 std::optional<FCmpPredicate> predicate =
321 symbolizeFCmpPredicate(predicateAttr.getValue());
322 if (!predicate)
323 return parser.emitError(predicateLoc)
324 << "'" << predicateAttr.getValue()
325 << "' is an incorrect value of the 'predicate' attribute";
326 predicateValue = static_cast<int64_t>(*predicate);
327 }
328
 // Overwrite the string-valued `predicate` entry with its numeric value as
 // a 64-bit integer attribute.
329 result.attributes.set("predicate",
330 parser.getBuilder().getI64IntegerAttr(predicateValue));
331
332 // The result type is either i1 or a vector type <? x i1> if the inputs are
333 // vectors.
334 if (!isCompatibleType(type))
335 return parser.emitError(trailingTypeLoc,
336 "expected LLVM dialect-compatible type");
337 result.addTypes(getI1SameShape(type));
338 return success();
339}
340
341ParseResult ICmpOp::parse(OpAsmParser &parser, OperationState &result) {
342 return parseCmpOp<ICmpPredicate>(parser, result);
343}
344
345ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) {
346 return parseCmpOp<FCmpPredicate>(parser, result);
347}
348
349/// Returns a scalar or vector boolean attribute of the given type.
350static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
351 auto boolAttr = BoolAttr::get(ctx, value);
352 ShapedType shapedType = dyn_cast<ShapedType>(type);
353 if (!shapedType)
354 return boolAttr;
355 return DenseElementsAttr::get(shapedType, boolAttr);
356}
357
/// Folds `eq`/`ne` comparisons for the operand patterns handled below. Any
/// other predicate, and any other operand pattern, is left unfolded.
358OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
359 if (getPredicate() != ICmpPredicate::eq &&
360 getPredicate() != ICmpPredicate::ne)
361 return {};
362
363 // cmpi(eq/ne, x, x) -> true/false
364 if (getLhs() == getRhs())
366 getPredicate() == ICmpPredicate::eq);
367
368 // cmpi(eq/ne, alloca, null) -> false/true
369 if (getLhs().getDefiningOp<AllocaOp>() && getRhs().getDefiningOp<ZeroOp>())
371 getPredicate() == ICmpPredicate::ne);
372
373 // cmpi(eq/ne, null, alloca) -> cmpi(eq/ne, alloca, null)
 // The operands are swapped in place and the op's own result is returned.
374 if (getLhs().getDefiningOp<ZeroOp>() && getRhs().getDefiningOp<AllocaOp>()) {
375 Value lhs = getLhs();
376 Value rhs = getRhs();
377 getLhsMutable().assign(rhs);
378 getRhsMutable().assign(lhs);
379 return getResult();
380 }
381
382 return {};
383}
384
385//===----------------------------------------------------------------------===//
386// Printing, parsing and verification for LLVM::AllocaOp.
387//===----------------------------------------------------------------------===//
388
/// Prints `llvm.alloca` as
///   `inalloca`? array-size `x` elem-type attribute-dict? `:` function-type
389void AllocaOp::print(OpAsmPrinter &p) {
 // The trailing type is a function type from the array-size type to the
 // result type.
390 auto funcTy =
391 FunctionType::get(getContext(), {getArraySize().getType()}, {getType()});
392
393 if (getInalloca())
394 p << " inalloca";
395
396 p << ' ' << getArraySize() << " x " << getElemType();
 // The element type and `inalloca` are already part of the custom syntax, so
 // both are always elided. The alignment attribute is additionally elided
 // unless it is present and non-zero.
397 if (getAlignment() && *getAlignment() != 0)
398 p.printOptionalAttrDict((*this)->getAttrs(),
399 {kElemTypeAttrName, getInallocaAttrName()});
400 else
402 (*this)->getAttrs(),
403 {getAlignmentAttrName(), kElemTypeAttrName, getInallocaAttrName()});
404 p << " : " << funcTy;
405}
406
407// <operation> ::= `llvm.alloca` `inalloca`? ssa-use `x` type
408// attribute-dict? `:` type `,` type
409ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) {
411 Type type, elemType;
412 SMLoc trailingTypeLoc;
413
 // The optional `inalloca` keyword is recorded as a unit attribute.
414 if (succeeded(parser.parseOptionalKeyword("inalloca")))
415 result.addAttribute(getInallocaAttrName(result.name),
416 UnitAttr::get(parser.getContext()));
417
418 if (parser.parseOperand(arraySize) || parser.parseKeyword("x") ||
419 parser.parseType(elemType) ||
420 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
421 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
422 return failure();
423
 // An explicit `alignment` must be an integer attribute; a zero alignment is
 // dropped from the attribute list.
424 std::optional<NamedAttribute> alignmentAttr =
425 result.attributes.getNamed("alignment");
426 if (alignmentAttr.has_value()) {
427 auto alignmentInt = llvm::dyn_cast<IntegerAttr>(alignmentAttr->getValue());
428 if (!alignmentInt)
429 return parser.emitError(parser.getNameLoc(),
430 "expected integer alignment");
431 if (alignmentInt.getValue().isZero())
432 result.attributes.erase("alignment");
433 }
434
435 // Extract the result type from the trailing function type.
436 auto funcType = llvm::dyn_cast<FunctionType>(type);
437 if (!funcType || funcType.getNumInputs() != 1 ||
438 funcType.getNumResults() != 1)
439 return parser.emitError(
440 trailingTypeLoc,
441 "expected trailing function type with one argument and one result");
442
 // The single function input is the type of the array-size operand.
443 if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands))
444 return failure();
445
 // The element type attribute is attached only when the result is an LLVM
 // pointer type.
 // NOTE(review): `ptrResultType` is only used as the condition here.
446 Type resultType = funcType.getResult(0);
447 if (auto ptrResultType = llvm::dyn_cast<LLVMPointerType>(resultType))
448 result.addAttribute(kElemTypeAttrName, TypeAttr::get(elemType));
449
450 result.addTypes({funcType.getResult(0)});
451 return success();
452}
453
454LogicalResult AllocaOp::verify() {
455 // Only certain target extension types can be used in 'alloca'.
456 if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getElemType());
457 targetExtType && !targetExtType.supportsMemOps())
458 return emitOpError()
459 << "this target extension type cannot be used in alloca";
460
461 return success();
462}
463
464//===----------------------------------------------------------------------===//
465// LLVM::BrOp
466//===----------------------------------------------------------------------===//
467
468SuccessorOperands BrOp::getSuccessorOperands(unsigned index) {
469 assert(index == 0 && "invalid successor index");
470 return SuccessorOperands(getDestOperandsMutable());
471}
472
473//===----------------------------------------------------------------------===//
474// LLVM::CondBrOp
475//===----------------------------------------------------------------------===//
476
477SuccessorOperands CondBrOp::getSuccessorOperands(unsigned index) {
478 assert(index < getNumSuccessors() && "invalid successor index");
479 return SuccessorOperands(index == 0 ? getTrueDestOperandsMutable()
480 : getFalseDestOperandsMutable());
481}
482
483void CondBrOp::build(OpBuilder &builder, OperationState &result,
484 Value condition, Block *trueDest, ValueRange trueOperands,
485 Block *falseDest, ValueRange falseOperands,
486 std::optional<std::pair<uint32_t, uint32_t>> weights) {
487 DenseI32ArrayAttr weightsAttr;
488 if (weights)
489 weightsAttr =
490 builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights->first),
491 static_cast<int32_t>(weights->second)});
492
493 build(builder, result, condition, trueOperands, falseOperands, weightsAttr,
494 /*loop_annotation=*/{}, trueDest, falseDest);
495}
496
497//===----------------------------------------------------------------------===//
498// LLVM::SwitchOp
499//===----------------------------------------------------------------------===//
500
501void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
502 Block *defaultDestination, ValueRange defaultOperands,
503 DenseIntElementsAttr caseValues,
504 BlockRange caseDestinations,
505 ArrayRef<ValueRange> caseOperands,
506 ArrayRef<int32_t> branchWeights) {
507 DenseI32ArrayAttr weightsAttr;
508 if (!branchWeights.empty())
509 weightsAttr = builder.getDenseI32ArrayAttr(branchWeights);
510
511 build(builder, result, value, defaultOperands, caseOperands, caseValues,
512 weightsAttr, defaultDestination, caseDestinations);
513}
514
515void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
516 Block *defaultDestination, ValueRange defaultOperands,
517 ArrayRef<APInt> caseValues, BlockRange caseDestinations,
518 ArrayRef<ValueRange> caseOperands,
519 ArrayRef<int32_t> branchWeights) {
520 DenseIntElementsAttr caseValuesAttr;
521 if (!caseValues.empty()) {
522 ShapedType caseValueType = VectorType::get(
523 static_cast<int64_t>(caseValues.size()), value.getType());
524 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
525 }
526
527 build(builder, result, value, defaultDestination, defaultOperands,
528 caseValuesAttr, caseDestinations, caseOperands, branchWeights);
529}
530
531void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
532 Block *defaultDestination, ValueRange defaultOperands,
533 ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
534 ArrayRef<ValueRange> caseOperands,
535 ArrayRef<int32_t> branchWeights) {
536 DenseIntElementsAttr caseValuesAttr;
537 if (!caseValues.empty()) {
538 ShapedType caseValueType = VectorType::get(
539 static_cast<int64_t>(caseValues.size()), value.getType());
540 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
541 }
542
543 build(builder, result, value, defaultDestination, defaultOperands,
544 caseValuesAttr, caseDestinations, caseOperands, branchWeights);
545}
546
547/// <cases> ::= `[` (case (`,` case )* )? `]`
548/// <case> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
///
/// On success, one entry per parsed case is appended to `caseDestinations`,
/// `caseOperands` and `caseOperandTypes`, and `caseValues` is set to a dense
/// attribute holding all case values. For an empty `[]` list nothing is
/// appended and `caseValues` is not assigned.
549static ParseResult parseSwitchOpCases(
550 OpAsmParser &parser, Type flagType, DenseIntElementsAttr &caseValues,
551 SmallVectorImpl<Block *> &caseDestinations,
553 SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
554 if (failed(parser.parseLSquare()))
555 return failure();
556 if (succeeded(parser.parseOptionalRSquare()))
557 return success();
558 SmallVector<APInt> values;
 // Case values are stored with the bit width of the switch flag type.
559 unsigned bitWidth = flagType.getIntOrFloatBitWidth();
560 auto parseCase = [&]() {
561 int64_t value = 0;
562 if (failed(parser.parseInteger(value)))
563 return failure();
564 values.push_back(APInt(bitWidth, value, /*isSigned=*/true));
565
566 Block *destination;
568 SmallVector<Type> operandTypes;
569 if (parser.parseColon() || parser.parseSuccessor(destination))
570 return failure();
 // The parenthesized operand list after the successor is optional.
571 if (!parser.parseOptionalLParen()) {
573 /*allowResultNumber=*/false) ||
574 parser.parseColonTypeList(operandTypes) || parser.parseRParen())
575 return failure();
576 }
577 caseDestinations.push_back(destination);
578 caseOperands.emplace_back(operands);
579 caseOperandTypes.emplace_back(operandTypes);
580 return success();
581 };
582 if (failed(parser.parseCommaSeparatedList(parseCase)))
583 return failure();
584
 // Pack the values into a one-dimensional vector of the flag type.
585 ShapedType caseValueType =
586 VectorType::get(static_cast<int64_t>(values.size()), flagType);
587 caseValues = DenseIntElementsAttr::get(caseValueType, values);
588 return parser.parseRSquare();
589}
590
591static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType,
592 DenseIntElementsAttr caseValues,
593 SuccessorRange caseDestinations,
594 OperandRangeRange caseOperands,
595 const TypeRangeRange &caseOperandTypes) {
596 p << '[';
597 p.printNewline();
598 if (!caseValues) {
599 p << ']';
600 return;
601 }
602
603 size_t index = 0;
604 llvm::interleave(
605 llvm::zip(caseValues, caseDestinations),
606 [&](auto i) {
607 p << " ";
608 p << std::get<0>(i);
609 p << ": ";
610 p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]);
611 },
612 [&] {
613 p << ',';
614 p.printNewline();
615 });
616 p.printNewline();
617 p << ']';
618}
619
620LogicalResult SwitchOp::verify() {
621 if ((!getCaseValues() && !getCaseDestinations().empty()) ||
622 (getCaseValues() &&
623 getCaseValues()->size() !=
624 static_cast<int64_t>(getCaseDestinations().size())))
625 return emitOpError("expects number of case values to match number of "
626 "case destinations");
627 if (getCaseValues() &&
628 getValue().getType() != getCaseValues()->getElementType())
629 return emitError("expects case value type to match condition value type");
630 return success();
631}
632
633SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
634 assert(index < getNumSuccessors() && "invalid successor index");
635 return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
636 : getCaseOperandsMutable(index - 1));
637}
638
639//===----------------------------------------------------------------------===//
640// Code for LLVM::GEPOp.
641//===----------------------------------------------------------------------===//
642
643GEPIndicesAdaptor<ValueRange> GEPOp::getIndices() {
644 return GEPIndicesAdaptor<ValueRange>(getRawConstantIndicesAttr(),
645 getDynamicIndices());
646}
647
648/// Returns the elemental type of any LLVM-compatible vector type or self.
 // Only `VectorType` is unwrapped; any other type is returned as is.
650 if (auto vectorType = llvm::dyn_cast<VectorType>(type))
651 return vectorType.getElementType();
652 return type;
653}
654
655/// Destructures the 'indices' parameter into 'rawConstantIndices' and
656/// 'dynamicIndices', encoding the former in the process. In the process,
657/// dynamic indices which are used to index into a structure type are converted
658/// to constant indices when possible. To do this, the GEPs element type should
659/// be passed as first parameter.
661 SmallVectorImpl<int32_t> &rawConstantIndices,
662 SmallVectorImpl<Value> &dynamicIndices) {
663 for (const GEPArg &iter : indices) {
664 // If the thing we are currently indexing into is a struct we must turn
665 // any integer constants into constant indices. If this is not possible
666 // we don't do anything here. The verifier will catch it and emit a proper
667 // error. All other canonicalization is done in the fold method.
668 bool requiresConst = !rawConstantIndices.empty() &&
669 isa_and_nonnull<LLVMStructType>(currType);
670 if (Value val = llvm::dyn_cast_if_present<Value>(iter)) {
671 APInt intC;
 // A value index is stored as a constant only when a constant is required,
 // the value matches a constant integer, and that integer fits in
 // kGEPConstantBitWidth signed bits. Otherwise a kDynamicIndex placeholder
 // is recorded and the value is kept as a dynamic index.
672 if (requiresConst && matchPattern(val, m_ConstantInt(&intC)) &&
673 intC.isSignedIntN(kGEPConstantBitWidth)) {
674 rawConstantIndices.push_back(intC.getSExtValue());
675 } else {
676 rawConstantIndices.push_back(GEPOp::kDynamicIndex);
677 dynamicIndices.push_back(val);
678 }
679 } else {
680 rawConstantIndices.push_back(cast<GEPConstantIndex>(iter));
681 }
682
683 // Skip for very first iteration of this loop. First index does not index
684 // within the aggregates, but is just a pointer offset.
685 if (rawConstantIndices.size() == 1 || !currType)
686 continue;
687
 // Advance `currType` to the type selected by the index just recorded. It
 // becomes null for an out-of-range struct member index and for any type
 // other than a vector, array or struct.
688 currType = TypeSwitch<Type, Type>(currType)
689 .Case<VectorType, LLVMArrayType>([](auto containerType) {
690 return containerType.getElementType();
691 })
692 .Case([&](LLVMStructType structType) -> Type {
693 int64_t memberIndex = rawConstantIndices.back();
694 if (memberIndex >= 0 && static_cast<size_t>(memberIndex) <
695 structType.getBody().size())
696 return structType.getBody()[memberIndex];
697 return nullptr;
698 })
699 .Default(nullptr);
700 }
701}
702
703void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
704 Type elementType, Value basePtr, ArrayRef<GEPArg> indices,
705 GEPNoWrapFlags noWrapFlags,
706 ArrayRef<NamedAttribute> attributes) {
707 SmallVector<int32_t> rawConstantIndices;
708 SmallVector<Value> dynamicIndices;
709 destructureIndices(elementType, indices, rawConstantIndices, dynamicIndices);
710
711 result.addTypes(resultType);
712 result.addAttributes(attributes);
713 result.getOrAddProperties<Properties>().rawConstantIndices =
714 builder.getDenseI32ArrayAttr(rawConstantIndices);
715 result.getOrAddProperties<Properties>().noWrapFlags = noWrapFlags;
716 result.getOrAddProperties<Properties>().elem_type =
717 TypeAttr::get(elementType);
718 result.addOperands(basePtr);
719 result.addOperands(dynamicIndices);
720}
721
722void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
723 Type elementType, Value basePtr, ValueRange indices,
724 GEPNoWrapFlags noWrapFlags,
725 ArrayRef<NamedAttribute> attributes) {
726 build(builder, result, resultType, elementType, basePtr,
727 SmallVector<GEPArg>(indices), noWrapFlags, attributes);
728}
729
730static ParseResult
733 DenseI32ArrayAttr &rawConstantIndices) {
734 SmallVector<int32_t> constantIndices;
735
 // Each list element is either an integer literal, recorded as a constant
 // index, or an SSA operand, recorded as a kDynamicIndex placeholder with
 // the operand itself appended to `indices`.
736 auto idxParser = [&]() -> ParseResult {
737 int32_t constantIndex;
738 OptionalParseResult parsedInteger =
739 parser.parseOptionalInteger(constantIndex);
740 if (parsedInteger.has_value()) {
741 if (failed(parsedInteger.value()))
742 return failure();
743 constantIndices.push_back(constantIndex);
744 return success();
745 }
746
747 constantIndices.push_back(LLVM::GEPOp::kDynamicIndex);
748 return parser.parseOperand(indices.emplace_back());
749 };
750 if (parser.parseCommaSeparatedList(idxParser))
751 return failure();
752
753 rawConstantIndices =
754 DenseI32ArrayAttr::get(parser.getContext(), constantIndices);
755 return success();
756}
757
758static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
760 DenseI32ArrayAttr rawConstantIndices) {
 // Indices are printed comma-separated: value indices as operands, all
 // others as the integer held by their IntegerAttr.
761 llvm::interleaveComma(
762 GEPIndicesAdaptor<OperandRange>(rawConstantIndices, indices), printer,
764 if (Value val = llvm::dyn_cast_if_present<Value>(cst))
765 printer.printOperand(val);
766 else
767 printer << cast<IntegerAttr>(cst).getInt();
768 });
769}
770
771/// For the given `indices`, check if they comply with `baseGEPType`,
772/// especially check against LLVMStructTypes nested within.
///
/// The check is recursive: `indexPos` is the position of the index applied
/// to `baseGEPType`, and each step descends into the type selected by that
/// index. Recursion ends successfully once `indexPos` is past the last index.
773static LogicalResult
774verifyStructIndices(Type baseGEPType, unsigned indexPos,
777 if (indexPos >= indices.size())
778 // Stop searching
779 return success();
780
781 return TypeSwitch<Type, LogicalResult>(baseGEPType)
782 .Case<LLVMStructType>([&](LLVMStructType structType) -> LogicalResult {
 // A struct may only be indexed by a constant that is within the bounds
 // of its body.
783 auto attr = dyn_cast<IntegerAttr>(indices[indexPos]);
784 if (!attr)
785 return emitOpError() << "expected index " << indexPos
786 << " indexing a struct to be constant";
787
788 int32_t gepIndex = attr.getInt();
789 ArrayRef<Type> elementTypes = structType.getBody();
790 if (gepIndex < 0 ||
791 static_cast<size_t>(gepIndex) >= elementTypes.size())
792 return emitOpError() << "index " << indexPos
793 << " indexing a struct is out of bounds";
794
795 // Instead of recursively going into every children types, we only
796 // dive into the one indexed by gepIndex.
797 return verifyStructIndices(elementTypes[gepIndex], indexPos + 1,
799 })
800 .Case<VectorType, LLVMArrayType>(
801 [&](auto containerType) -> LogicalResult {
 // Vectors and arrays have one element type, so the index value
 // itself is not inspected here.
802 return verifyStructIndices(containerType.getElementType(),
803 indexPos + 1, indices, emitOpError);
804 })
805 .Default([&](auto otherType) -> LogicalResult {
 // An index remains, but the current type is not a struct, vector or
 // array.
806 return emitOpError()
807 << "type " << otherType << " cannot be indexed (index #"
808 << indexPos << ")";
809 });
810}
811
812/// Driver function around `verifyStructIndices`.
813static LogicalResult
818
819LogicalResult LLVM::GEPOp::verify() {
820 if (static_cast<size_t>(
821 llvm::count(getRawConstantIndices(), kDynamicIndex)) !=
822 getDynamicIndices().size())
823 return emitOpError("expected as many dynamic indices as specified in '")
824 << getRawConstantIndicesAttrName().getValue() << "'";
825
826 if (getNoWrapFlags() == GEPNoWrapFlags::inboundsFlag)
827 return emitOpError("'inbounds_flag' cannot be used directly.");
828
829 return verifyStructIndices(getElemType(), getIndices(),
830 [&] { return emitOpError(); });
831}
832
833//===----------------------------------------------------------------------===//
834// LoadOp
835//===----------------------------------------------------------------------===//
836
/// Reports the memory effects of `llvm.load`: always a read of the address
/// operand, plus a generic write and read for the accesses described below.
837void LoadOp::getEffects(
839 &effects) {
840 effects.emplace_back(MemoryEffects::Read::get(), &getAddrMutable());
841 // Volatile operations can have target-specific read-write effects on
842 // memory besides the one referred to by the pointer operand.
843 // Similarly, atomic operations that are monotonic or stricter cause
844 // synchronization that, from a language point of view, amounts to
845 // arbitrary reads and writes of memory.
846 if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic &&
847 getOrdering() != AtomicOrdering::unordered)) {
848 effects.emplace_back(MemoryEffects::Write::get());
849 effects.emplace_back(MemoryEffects::Read::get());
850 }
851}
852
853/// Returns true if the given type is supported by atomic operations. All
854/// integer, float, and pointer types with a power-of-two bitsize and a minimal
855/// size of 8 bits are supported.
857 const DataLayout &dataLayout) {
858 if (!isa<IntegerType, LLVMPointerType>(type))
860 return false;
861
 // The size is taken from the data layout; scalable sizes are rejected.
862 llvm::TypeSize bitWidth = dataLayout.getTypeSizeInBits(type);
863 if (bitWidth.isScalable())
864 return false;
865 // Needs to be at least 8 bits and a power of two.
 // `x & (x - 1)` clears the lowest set bit, so it is zero exactly when a
 // non-zero `x` has a single bit set.
866 return bitWidth >= 8 && (bitWidth & (bitWidth - 1)) == 0;
867}
868
869/// Verifies the attributes and the type of atomic memory access operations.
870template <typename OpTy>
871static LogicalResult
872verifyAtomicMemOp(OpTy memOp, Type valueType,
873 ArrayRef<AtomicOrdering> unsupportedOrderings) {
874 if (memOp.getOrdering() != AtomicOrdering::not_atomic) {
875 DataLayout dataLayout = DataLayout::closest(memOp);
876 if (!isTypeCompatibleWithAtomicOp(valueType, dataLayout))
877 return memOp.emitOpError("unsupported type ")
878 << valueType << " for atomic access";
879 if (llvm::is_contained(unsupportedOrderings, memOp.getOrdering()))
880 return memOp.emitOpError("unsupported ordering '")
881 << stringifyAtomicOrdering(memOp.getOrdering()) << "'";
882 if (!memOp.getAlignment())
883 return memOp.emitOpError("expected alignment for atomic access");
884 return success();
885 }
886 if (memOp.getSyncscope())
887 return memOp.emitOpError(
888 "expected syncscope to be null for non-atomic access");
889 return success();
890}
891
892LogicalResult LoadOp::verify() {
893 Type valueType = getResult().getType();
894 return verifyAtomicMemOp(*this, valueType,
895 {AtomicOrdering::release, AtomicOrdering::acq_rel});
896}
897
898void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
899 Value addr, unsigned alignment, bool isVolatile,
900 bool isNonTemporal, bool isInvariant, bool isInvariantGroup,
901 AtomicOrdering ordering, StringRef syncscope) {
902 build(builder, state, type, addr,
903 alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
904 isNonTemporal, isInvariant, isInvariantGroup, ordering,
905 syncscope.empty() ? nullptr : builder.getStringAttr(syncscope),
906 /*dereferenceable=*/nullptr,
907 /*access_groups=*/nullptr,
908 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr,
909 /*tbaa=*/nullptr);
910}
911
912//===----------------------------------------------------------------------===//
913// StoreOp
914//===----------------------------------------------------------------------===//
915
916void StoreOp::getEffects(
918 &effects) {
919 effects.emplace_back(MemoryEffects::Write::get(), &getAddrMutable());
920 // Volatile operations can have target-specific read-write effects on
921 // memory besides the one referred to by the pointer operand.
922 // Similarly, atomic operations that are monotonic or stricter cause
923 // synchronization that from a language point-of-view, are arbitrary
924 // read-writes into memory.
925 if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic &&
926 getOrdering() != AtomicOrdering::unordered)) {
927 effects.emplace_back(MemoryEffects::Write::get());
928 effects.emplace_back(MemoryEffects::Read::get());
929 }
930}
931
932LogicalResult StoreOp::verify() {
933 Type valueType = getValue().getType();
934 return verifyAtomicMemOp(*this, valueType,
935 {AtomicOrdering::acquire, AtomicOrdering::acq_rel});
936}
937
938void StoreOp::build(OpBuilder &builder, OperationState &state, Value value,
939 Value addr, unsigned alignment, bool isVolatile,
940 bool isNonTemporal, bool isInvariantGroup,
941 AtomicOrdering ordering, StringRef syncscope) {
942 build(builder, state, value, addr,
943 alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
944 isNonTemporal, isInvariantGroup, ordering,
945 syncscope.empty() ? nullptr : builder.getStringAttr(syncscope),
946 /*access_groups=*/nullptr,
947 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
948}
949
950//===----------------------------------------------------------------------===//
951// CallOp
952//===----------------------------------------------------------------------===//
953
954/// Gets the MLIR Op-like result types of a LLVMFunctionType.
955static SmallVector<Type, 1> getCallOpResultTypes(LLVMFunctionType calleeType) {
956 SmallVector<Type, 1> results;
957 Type resultType = calleeType.getReturnType();
958 if (!isa<LLVM::LLVMVoidType>(resultType))
959 results.push_back(resultType);
960 return results;
961}
962
963/// Gets the variadic callee type for a LLVMFunctionType.
964static TypeAttr getCallOpVarCalleeType(LLVMFunctionType calleeType) {
965 return calleeType.isVarArg() ? TypeAttr::get(calleeType) : nullptr;
966}
967
968/// Constructs a LLVMFunctionType from MLIR `results` and `args`.
969static LLVMFunctionType getLLVMFuncType(MLIRContext *context, TypeRange results,
970 ValueRange args) {
971 Type resultType;
972 if (results.empty())
973 resultType = LLVMVoidType::get(context);
974 else
975 resultType = results.front();
976 return LLVMFunctionType::get(resultType, llvm::to_vector(args.getTypes()),
977 /*isVarArg=*/false);
978}
979
980void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
981 StringRef callee, ValueRange args) {
982 build(builder, state, results, builder.getStringAttr(callee), args);
983}
984
985void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
986 StringAttr callee, ValueRange args) {
987 build(builder, state, results, SymbolRefAttr::get(callee), args);
988}
989
990void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
991 FlatSymbolRefAttr callee, ValueRange args) {
992 assert(callee && "expected non-null callee in direct call builder");
993 build(builder, state, results,
994 /*var_callee_type=*/nullptr, callee, args, /*fastmathFlags=*/nullptr,
995 /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
996 /*memory_effects=*/nullptr,
997 /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
998 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
999 /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
1000 /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
1001 /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
1002 /*no_inline=*/nullptr, /*always_inline=*/nullptr,
1003 /*inline_hint=*/nullptr);
1004}
1005
1006void CallOp::build(OpBuilder &builder, OperationState &state,
1007 LLVMFunctionType calleeType, StringRef callee,
1008 ValueRange args) {
1009 build(builder, state, calleeType, builder.getStringAttr(callee), args);
1010}
1011
1012void CallOp::build(OpBuilder &builder, OperationState &state,
1013 LLVMFunctionType calleeType, StringAttr callee,
1014 ValueRange args) {
1015 build(builder, state, calleeType, SymbolRefAttr::get(callee), args);
1016}
1017
1018void CallOp::build(OpBuilder &builder, OperationState &state,
1019 LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
1020 ValueRange args) {
1021 build(builder, state, getCallOpResultTypes(calleeType),
1022 getCallOpVarCalleeType(calleeType), callee, args,
1023 /*fastmathFlags=*/nullptr,
1024 /*CConv=*/nullptr,
1025 /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
1026 /*convergent=*/nullptr,
1027 /*no_unwind=*/nullptr, /*will_return=*/nullptr,
1028 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
1029 /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
1030 /*access_groups=*/nullptr,
1031 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
1032 /*no_inline=*/nullptr, /*always_inline=*/nullptr,
1033 /*inline_hint=*/nullptr);
1034}
1035
1036void CallOp::build(OpBuilder &builder, OperationState &state,
1037 LLVMFunctionType calleeType, ValueRange args) {
1038 build(builder, state, getCallOpResultTypes(calleeType),
1039 getCallOpVarCalleeType(calleeType),
1040 /*callee=*/nullptr, args,
1041 /*fastmathFlags=*/nullptr,
1042 /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
1043 /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
1044 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
1045 /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
1046 /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
1047 /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
1048 /*no_inline=*/nullptr, /*always_inline=*/nullptr,
1049 /*inline_hint=*/nullptr);
1050}
1051
1052void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
1053 ValueRange args) {
1054 auto calleeType = func.getFunctionType();
1055 build(builder, state, getCallOpResultTypes(calleeType),
1056 getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), args,
1057 /*fastmathFlags=*/nullptr,
1058 /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
1059 /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
1060 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
1061 /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
1062 /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
1063 /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
1064 /*no_inline=*/nullptr, /*always_inline=*/nullptr,
1065 /*inline_hint=*/nullptr);
1066}
1067
1068CallInterfaceCallable CallOp::getCallableForCallee() {
1069 // Direct call.
1070 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr())
1071 return calleeAttr;
1072 // Indirect call, callee Value is the first operand.
1073 return getOperand(0);
1074}
1075
1076void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1077 // Direct call.
1078 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) {
1079 auto symRef = cast<SymbolRefAttr>(callee);
1080 return setCalleeAttr(cast<FlatSymbolRefAttr>(symRef));
1081 }
1082 // Indirect call, callee Value is the first operand.
1083 return setOperand(0, cast<Value>(callee));
1084}
1085
1086Operation::operand_range CallOp::getArgOperands() {
1087 return getCalleeOperands().drop_front(getCallee().has_value() ? 0 : 1);
1088}
1089
1090MutableOperandRange CallOp::getArgOperandsMutable() {
1091 return MutableOperandRange(*this, getCallee().has_value() ? 0 : 1,
1092 getCalleeOperands().size());
1093}
1094
1095/// Verify that an inlinable callsite of a debug-info-bearing function in a
1096/// debug-info-bearing function has a debug location attached to it. This
1097/// mirrors an LLVM IR verifier.
1098static LogicalResult verifyCallOpDebugInfo(CallOp callOp, LLVMFuncOp callee) {
1099 if (callee.isExternal())
1100 return success();
1101 auto parentFunc = callOp->getParentOfType<FunctionOpInterface>();
1102 if (!parentFunc)
1103 return success();
1104
1105 auto hasSubprogram = [](Operation *op) {
1106 return op->getLoc()
1107 ->findInstanceOf<FusedLocWith<LLVM::DISubprogramAttr>>() !=
1108 nullptr;
1109 };
1110 if (!hasSubprogram(parentFunc) || !hasSubprogram(callee))
1111 return success();
1112 bool containsLoc = !isa<UnknownLoc>(callOp->getLoc());
1113 if (!containsLoc)
1114 return callOp.emitError()
1115 << "inlinable function call in a function with a DISubprogram "
1116 "location must have a debug location";
1117 return success();
1118}
1119
1120/// Verify that the parameter and return types of the variadic callee type match
1121/// the `callOp` argument and result types.
1122template <typename OpTy>
1123static LogicalResult verifyCallOpVarCalleeType(OpTy callOp) {
1124 std::optional<LLVMFunctionType> varCalleeType = callOp.getVarCalleeType();
1125 if (!varCalleeType)
1126 return success();
1127
1128 // Verify the variadic callee type is a variadic function type.
1129 if (!varCalleeType->isVarArg())
1130 return callOp.emitOpError(
1131 "expected var_callee_type to be a variadic function type");
1132
1133 // Verify the variadic callee type has at most as many parameters as the call
1134 // has argument operands.
1135 if (varCalleeType->getNumParams() > callOp.getArgOperands().size())
1136 return callOp.emitOpError("expected var_callee_type to have at most ")
1137 << callOp.getArgOperands().size() << " parameters";
1138
1139 // Verify the variadic callee type matches the call argument types.
1140 for (auto [paramType, operand] :
1141 llvm::zip(varCalleeType->getParams(), callOp.getArgOperands()))
1142 if (paramType != operand.getType())
1143 return callOp.emitOpError()
1144 << "var_callee_type parameter type mismatch: " << paramType
1145 << " != " << operand.getType();
1146
1147 // Verify the variadic callee type matches the call result type.
1148 if (!callOp.getNumResults()) {
1149 if (!isa<LLVMVoidType>(varCalleeType->getReturnType()))
1150 return callOp.emitOpError("expected var_callee_type to return void");
1151 } else {
1152 if (callOp.getResult().getType() != varCalleeType->getReturnType())
1153 return callOp.emitOpError("var_callee_type return type mismatch: ")
1154 << varCalleeType->getReturnType()
1155 << " != " << callOp.getResult().getType();
1156 }
1157 return success();
1158}
1159
1160template <typename OpType>
1161static LogicalResult verifyOperandBundles(OpType &op) {
1162 OperandRangeRange opBundleOperands = op.getOpBundleOperands();
1163 std::optional<ArrayAttr> opBundleTags = op.getOpBundleTags();
1164
1165 auto isStringAttr = [](Attribute tagAttr) {
1166 return isa<StringAttr>(tagAttr);
1167 };
1168 if (opBundleTags && !llvm::all_of(*opBundleTags, isStringAttr))
1169 return op.emitError("operand bundle tag must be a StringAttr");
1170
1171 size_t numOpBundles = opBundleOperands.size();
1172 size_t numOpBundleTags = opBundleTags ? opBundleTags->size() : 0;
1173 if (numOpBundles != numOpBundleTags)
1174 return op.emitError("expected ")
1175 << numOpBundles << " operand bundle tags, but actually got "
1176 << numOpBundleTags;
1177
1178 return success();
1179}
1180
1181LogicalResult CallOp::verify() { return verifyOperandBundles(*this); }
1182
1183LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1185 return failure();
1186
1187 // Type for the callee, we'll get it differently depending if it is a direct
1188 // or indirect call.
1189 Type fnType;
1190
1191 bool isIndirect = false;
1192
1193 // If this is an indirect call, the callee attribute is missing.
1194 FlatSymbolRefAttr calleeName = getCalleeAttr();
1195 if (!calleeName) {
1196 isIndirect = true;
1197 if (!getNumOperands())
1198 return emitOpError(
1199 "must have either a `callee` attribute or at least an operand");
1200 auto ptrType = llvm::dyn_cast<LLVMPointerType>(getOperand(0).getType());
1201 if (!ptrType)
1202 return emitOpError("indirect call expects a pointer as callee: ")
1203 << getOperand(0).getType();
1204
1205 return success();
1206 } else {
1207 Operation *callee =
1208 symbolTable.lookupNearestSymbolFrom(*this, calleeName.getAttr());
1209 if (!callee)
1210 return emitOpError()
1211 << "'" << calleeName.getValue()
1212 << "' does not reference a symbol in the current scope";
1213 if (auto fn = dyn_cast<LLVMFuncOp>(callee)) {
1214 if (failed(verifyCallOpDebugInfo(*this, fn)))
1215 return failure();
1216 fnType = fn.getFunctionType();
1217 } else if (auto ifunc = dyn_cast<IFuncOp>(callee)) {
1218 fnType = ifunc.getIFuncType();
1219 } else {
1220 return emitOpError()
1221 << "'" << calleeName.getValue()
1222 << "' does not reference a valid LLVM function or IFunc";
1223 }
1224 }
1225
1226 LLVMFunctionType funcType = llvm::dyn_cast<LLVMFunctionType>(fnType);
1227 if (!funcType)
1228 return emitOpError("callee does not have a functional type: ") << fnType;
1229
1230 if (funcType.isVarArg() && !getVarCalleeType())
1231 return emitOpError() << "missing var_callee_type attribute for vararg call";
1232
1233 // Verify that the operand and result types match the callee.
1234
1235 if (!funcType.isVarArg() &&
1236 funcType.getNumParams() != (getCalleeOperands().size() - isIndirect))
1237 return emitOpError() << "incorrect number of operands ("
1238 << (getCalleeOperands().size() - isIndirect)
1239 << ") for callee (expecting: "
1240 << funcType.getNumParams() << ")";
1241
1242 if (funcType.getNumParams() > (getCalleeOperands().size() - isIndirect))
1243 return emitOpError() << "incorrect number of operands ("
1244 << (getCalleeOperands().size() - isIndirect)
1245 << ") for varargs callee (expecting at least: "
1246 << funcType.getNumParams() << ")";
1247
1248 for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i)
1249 if (getOperand(i + isIndirect).getType() != funcType.getParamType(i))
1250 return emitOpError() << "operand type mismatch for operand " << i << ": "
1251 << getOperand(i + isIndirect).getType()
1252 << " != " << funcType.getParamType(i);
1253
1254 if (getNumResults() == 0 &&
1255 !llvm::isa<LLVM::LLVMVoidType>(funcType.getReturnType()))
1256 return emitOpError() << "expected function call to produce a value";
1257
1258 if (getNumResults() != 0 &&
1259 llvm::isa<LLVM::LLVMVoidType>(funcType.getReturnType()))
1260 return emitOpError()
1261 << "calling function with void result must not produce values";
1262
1263 if (getNumResults() > 1)
1264 return emitOpError()
1265 << "expected LLVM function call to produce 0 or 1 result";
1266
1267 if (getNumResults() && getResult().getType() != funcType.getReturnType())
1268 return emitOpError() << "result type mismatch: " << getResult().getType()
1269 << " != " << funcType.getReturnType();
1270
1271 return success();
1272}
1273
1274void CallOp::print(OpAsmPrinter &p) {
1275 auto callee = getCallee();
1276 bool isDirect = callee.has_value();
1277
1278 p << ' ';
1279
1280 // Print calling convention.
1281 if (getCConv() != LLVM::CConv::C)
1282 p << stringifyCConv(getCConv()) << ' ';
1283
1284 if (getTailCallKind() != LLVM::TailCallKind::None)
1285 p << tailcallkind::stringifyTailCallKind(getTailCallKind()) << ' ';
1286
1287 // Print the direct callee if present as a function attribute, or an indirect
1288 // callee (first operand) otherwise.
1289 if (isDirect)
1290 p.printSymbolName(callee.value());
1291 else
1292 p << getOperand(0);
1293
1294 auto args = getCalleeOperands().drop_front(isDirect ? 0 : 1);
1295 p << '(' << args << ')';
1296
1297 // Print the variadic callee type if the call is variadic.
1298 if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1299 p << " vararg(" << *varCalleeType << ")";
1300
1301 if (!getOpBundleOperands().empty()) {
1302 p << " ";
1303 printOpBundles(p, *this, getOpBundleOperands(),
1304 getOpBundleOperands().getTypes(), getOpBundleTags());
1305 }
1306
1307 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
1308 {getCalleeAttrName(), getTailCallKindAttrName(),
1309 getVarCalleeTypeAttrName(), getCConvAttrName(),
1310 getOperandSegmentSizesAttrName(),
1311 getOpBundleSizesAttrName(),
1312 getOpBundleTagsAttrName(), getArgAttrsAttrName(),
1313 getResAttrsAttrName()});
1314
1315 p << " : ";
1316 if (!isDirect)
1317 p << getOperand(0).getType() << ", ";
1318
1319 // Reconstruct the MLIR function type from operand and result types.
1321 p, args.getTypes(), getArgAttrsAttr(),
1322 /*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
1323}
1324
1325/// Parses the type of a call operation and resolves the operands if the parsing
1326/// succeeds. Returns failure otherwise.
1328 OpAsmParser &parser, OperationState &result, bool isDirect,
1331 SmallVectorImpl<DictionaryAttr> &resultAttrs) {
1332 SMLoc trailingTypesLoc = parser.getCurrentLocation();
1333 SmallVector<Type> types;
1334 if (parser.parseColon())
1335 return failure();
1336 if (!isDirect) {
1337 types.emplace_back();
1338 if (parser.parseType(types.back()))
1339 return failure();
1340 if (parser.parseOptionalComma())
1341 return parser.emitError(
1342 trailingTypesLoc, "expected indirect call to have 2 trailing types");
1343 }
1344 SmallVector<Type> argTypes;
1345 SmallVector<Type> resTypes;
1346 if (call_interface_impl::parseFunctionSignature(parser, argTypes, argAttrs,
1347 resTypes, resultAttrs)) {
1348 if (isDirect)
1349 return parser.emitError(trailingTypesLoc,
1350 "expected direct call to have 1 trailing types");
1351 return parser.emitError(trailingTypesLoc,
1352 "expected trailing function type");
1353 }
1354
1355 if (resTypes.size() > 1)
1356 return parser.emitError(trailingTypesLoc,
1357 "expected function with 0 or 1 result");
1358 if (resTypes.size() == 1 && llvm::isa<LLVM::LLVMVoidType>(resTypes[0]))
1359 return parser.emitError(trailingTypesLoc,
1360 "expected a non-void result type");
1361
1362 // The head element of the types list matches the callee type for
1363 // indirect calls, while the types list is emtpy for direct calls.
1364 // Append the function input types to resolve the call operation
1365 // operands.
1366 llvm::append_range(types, argTypes);
1367 if (parser.resolveOperands(operands, types, parser.getNameLoc(),
1368 result.operands))
1369 return failure();
1370 if (!resTypes.empty())
1371 result.addTypes(resTypes);
1372
1373 return success();
1374}
1375
1376/// Parses an optional function pointer operand before the call argument list
1377/// for indirect calls, or stops parsing at the function identifier otherwise.
1378static ParseResult parseOptionalCallFuncPtr(
1379 OpAsmParser &parser,
1381 OpAsmParser::UnresolvedOperand funcPtrOperand;
1382 OptionalParseResult parseResult = parser.parseOptionalOperand(funcPtrOperand);
1383 if (parseResult.has_value()) {
1384 if (failed(*parseResult))
1385 return *parseResult;
1386 operands.push_back(funcPtrOperand);
1387 }
1388 return success();
1389}
1390
1391static ParseResult resolveOpBundleOperands(
1392 OpAsmParser &parser, SMLoc loc, OperationState &state,
1394 ArrayRef<SmallVector<Type>> opBundleOperandTypes,
1395 StringAttr opBundleSizesAttrName) {
1396 unsigned opBundleIndex = 0;
1397 for (const auto &[operands, types] :
1398 llvm::zip_equal(opBundleOperands, opBundleOperandTypes)) {
1399 if (operands.size() != types.size())
1400 return parser.emitError(loc, "expected ")
1401 << operands.size()
1402 << " types for operand bundle operands for operand bundle #"
1403 << opBundleIndex << ", but actually got " << types.size();
1404 if (parser.resolveOperands(operands, types, loc, state.operands))
1405 return failure();
1406 }
1407
1408 SmallVector<int32_t> opBundleSizes;
1409 opBundleSizes.reserve(opBundleOperands.size());
1410 for (const auto &operands : opBundleOperands)
1411 opBundleSizes.push_back(operands.size());
1412
1413 state.addAttribute(
1414 opBundleSizesAttrName,
1415 DenseI32ArrayAttr::get(parser.getContext(), opBundleSizes));
1416
1417 return success();
1418}
1419
1420// <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use)
1421// `(` ssa-use-list `)`
1422// ( `vararg(` var-callee-type `)` )?
1423// ( `[` op-bundles-list `]` )?
1424// attribute-dict? `:` (type `,`)? function-type
1425ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
1426 SymbolRefAttr funcAttr;
1427 TypeAttr varCalleeType;
1430 SmallVector<SmallVector<Type>> opBundleOperandTypes;
1431 ArrayAttr opBundleTags;
1432
1433 // Default to C Calling Convention if no keyword is provided.
1434 result.addAttribute(
1435 getCConvAttrName(result.name),
1436 CConvAttr::get(parser.getContext(),
1437 parseOptionalLLVMKeyword<CConv>(parser, LLVM::CConv::C)));
1438
1439 result.addAttribute(
1440 getTailCallKindAttrName(result.name),
1441 TailCallKindAttr::get(parser.getContext(),
1443 parser, LLVM::TailCallKind::None)));
1444
1445 // Parse a function pointer for indirect calls.
1446 if (parseOptionalCallFuncPtr(parser, operands))
1447 return failure();
1448 bool isDirect = operands.empty();
1449
1450 // Parse a function identifier for direct calls.
1451 if (isDirect)
1452 if (parser.parseAttribute(funcAttr, "callee", result.attributes))
1453 return failure();
1454
1455 // Parse the function arguments.
1456 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren))
1457 return failure();
1458
1459 bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
1460 if (isVarArg) {
1461 StringAttr varCalleeTypeAttrName =
1462 CallOp::getVarCalleeTypeAttrName(result.name);
1463 if (parser.parseLParen().failed() ||
1464 parser
1465 .parseAttribute(varCalleeType, varCalleeTypeAttrName,
1466 result.attributes)
1467 .failed() ||
1468 parser.parseRParen().failed())
1469 return failure();
1470 }
1471
1472 SMLoc opBundlesLoc = parser.getCurrentLocation();
1473 if (std::optional<ParseResult> result = parseOpBundles(
1474 parser, opBundleOperands, opBundleOperandTypes, opBundleTags);
1475 result && failed(*result))
1476 return failure();
1477 if (opBundleTags && !opBundleTags.empty())
1478 result.addAttribute(CallOp::getOpBundleTagsAttrName(result.name).getValue(),
1479 opBundleTags);
1480
1481 if (parser.parseOptionalAttrDict(result.attributes))
1482 return failure();
1483
1484 // Parse the trailing type list and resolve the operands.
1486 SmallVector<DictionaryAttr> resultAttrs;
1487 if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
1488 argAttrs, resultAttrs))
1489 return failure();
1491 parser.getBuilder(), result, argAttrs, resultAttrs,
1492 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
1493 if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
1494 opBundleOperandTypes,
1495 getOpBundleSizesAttrName(result.name)))
1496 return failure();
1497
1498 int32_t numOpBundleOperands = 0;
1499 for (const auto &operands : opBundleOperands)
1500 numOpBundleOperands += operands.size();
1501
1502 result.addAttribute(
1503 CallOp::getOperandSegmentSizeAttr(),
1505 {static_cast<int32_t>(operands.size()), numOpBundleOperands}));
1506 return success();
1507}
1508
1509LLVMFunctionType CallOp::getCalleeFunctionType() {
1510 if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1511 return *varCalleeType;
1512 return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
1513}
1514
1515///===---------------------------------------------------------------------===//
1516/// LLVM::InvokeOp
1517///===---------------------------------------------------------------------===//
1518
1519void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
1520 ValueRange ops, Block *normal, ValueRange normalOps,
1521 Block *unwind, ValueRange unwindOps) {
1522 auto calleeType = func.getFunctionType();
1523 build(builder, state, getCallOpResultTypes(calleeType),
1524 getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), ops,
1525 /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, normalOps, unwindOps,
1526 nullptr, nullptr, {}, {}, normal, unwind);
1527}
1528
1529void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
1530 FlatSymbolRefAttr callee, ValueRange ops, Block *normal,
1531 ValueRange normalOps, Block *unwind,
1532 ValueRange unwindOps) {
1533 build(builder, state, tys,
1534 /*var_callee_type=*/nullptr, callee, ops, /*arg_attrs=*/nullptr,
1535 /*res_attrs=*/nullptr, normalOps, unwindOps, nullptr, nullptr, {}, {},
1536 normal, unwind);
1537}
1538
1539void InvokeOp::build(OpBuilder &builder, OperationState &state,
1540 LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
1541 ValueRange ops, Block *normal, ValueRange normalOps,
1542 Block *unwind, ValueRange unwindOps) {
1543 build(builder, state, getCallOpResultTypes(calleeType),
1544 getCallOpVarCalleeType(calleeType), callee, ops,
1545 /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, normalOps, unwindOps,
1546 nullptr, nullptr, {}, {}, normal, unwind);
1547}
1548
1549SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
1550 assert(index < getNumSuccessors() && "invalid successor index");
1551 return SuccessorOperands(index == 0 ? getNormalDestOperandsMutable()
1552 : getUnwindDestOperandsMutable());
1553}
1554
1555CallInterfaceCallable InvokeOp::getCallableForCallee() {
1556 // Direct call.
1557 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr())
1558 return calleeAttr;
1559 // Indirect call, callee Value is the first operand.
1560 return getOperand(0);
1561}
1562
1563void InvokeOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1564 // Direct call.
1565 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) {
1566 auto symRef = cast<SymbolRefAttr>(callee);
1567 return setCalleeAttr(cast<FlatSymbolRefAttr>(symRef));
1568 }
1569 // Indirect call, callee Value is the first operand.
1570 return setOperand(0, cast<Value>(callee));
1571}
1572
1573Operation::operand_range InvokeOp::getArgOperands() {
1574 return getCalleeOperands().drop_front(getCallee().has_value() ? 0 : 1);
1575}
1576
1577MutableOperandRange InvokeOp::getArgOperandsMutable() {
1578 return MutableOperandRange(*this, getCallee().has_value() ? 0 : 1,
1579 getCalleeOperands().size());
1580}
1581
1582LogicalResult InvokeOp::verify() {
1584 return failure();
1585
1586 Block *unwindDest = getUnwindDest();
1587 if (unwindDest->empty())
1588 return emitError("must have at least one operation in unwind destination");
1589
1590 // In unwind destination, first operation must be LandingpadOp
1591 if (!isa<LandingpadOp>(unwindDest->front()))
1592 return emitError("first operation in unwind destination should be a "
1593 "llvm.landingpad operation");
1594
1595 if (failed(verifyOperandBundles(*this)))
1596 return failure();
1597
1598 return success();
1599}
1600
1601void InvokeOp::print(OpAsmPrinter &p) {
1602 auto callee = getCallee();
1603 bool isDirect = callee.has_value();
1604
1605 p << ' ';
1606
1607 // Print calling convention.
1608 if (getCConv() != LLVM::CConv::C)
1609 p << stringifyCConv(getCConv()) << ' ';
1610
1611 // Either function name or pointer
1612 if (isDirect)
1613 p.printSymbolName(callee.value());
1614 else
1615 p << getOperand(0);
1616
1617 p << '(' << getCalleeOperands().drop_front(isDirect ? 0 : 1) << ')';
1618 p << " to ";
1619 p.printSuccessorAndUseList(getNormalDest(), getNormalDestOperands());
1620 p << " unwind ";
1621 p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands());
1622
1623 // Print the variadic callee type if the invoke is variadic.
1624 if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1625 p << " vararg(" << *varCalleeType << ")";
1626
1627 if (!getOpBundleOperands().empty()) {
1628 p << " ";
1629 printOpBundles(p, *this, getOpBundleOperands(),
1630 getOpBundleOperands().getTypes(), getOpBundleTags());
1631 }
1632
1633 p.printOptionalAttrDict((*this)->getAttrs(),
1634 {getCalleeAttrName(), getOperandSegmentSizeAttr(),
1635 getCConvAttrName(), getVarCalleeTypeAttrName(),
1636 getOpBundleSizesAttrName(),
1637 getOpBundleTagsAttrName(), getArgAttrsAttrName(),
1638 getResAttrsAttrName()});
1639
1640 p << " : ";
1641 if (!isDirect)
1642 p << getOperand(0).getType() << ", ";
1644 p, getCalleeOperands().drop_front(isDirect ? 0 : 1).getTypes(),
1645 getArgAttrsAttr(),
1646 /*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
1647}
1648
1649// <operation> ::= `llvm.invoke` (cconv)? (function-id | ssa-use)
1650// `(` ssa-use-list `)`
1651// `to` bb-id (`[` ssa-use-and-type-list `]`)?
1652// `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
1653// ( `vararg(` var-callee-type `)` )?
1654// ( `[` op-bundles-list `]` )?
1655// attribute-dict? `:` (type `,`)?
1656// function-type-with-argument-attributes
1657ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
1659 SymbolRefAttr funcAttr;
1660 TypeAttr varCalleeType;
1662 SmallVector<SmallVector<Type>> opBundleOperandTypes;
1663 ArrayAttr opBundleTags;
1664 Block *normalDest, *unwindDest;
1665 SmallVector<Value, 4> normalOperands, unwindOperands;
1666 Builder &builder = parser.getBuilder();
1667
1668 // Default to C Calling Convention if no keyword is provided.
1669 result.addAttribute(
1670 getCConvAttrName(result.name),
1671 CConvAttr::get(parser.getContext(),
1672 parseOptionalLLVMKeyword<CConv>(parser, LLVM::CConv::C)));
1673
1674 // Parse a function pointer for indirect calls.
1675 if (parseOptionalCallFuncPtr(parser, operands))
1676 return failure();
1677 bool isDirect = operands.empty();
1678
1679 // Parse a function identifier for direct calls.
1680 if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes))
1681 return failure();
1682
1683 // Parse the function arguments.
1684 if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
1685 parser.parseKeyword("to") ||
1686 parser.parseSuccessorAndUseList(normalDest, normalOperands) ||
1687 parser.parseKeyword("unwind") ||
1688 parser.parseSuccessorAndUseList(unwindDest, unwindOperands))
1689 return failure();
1690
1691 bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
1692 if (isVarArg) {
1693 StringAttr varCalleeTypeAttrName =
1694 InvokeOp::getVarCalleeTypeAttrName(result.name);
1695 if (parser.parseLParen().failed() ||
1696 parser
1697 .parseAttribute(varCalleeType, varCalleeTypeAttrName,
1698 result.attributes)
1699 .failed() ||
1700 parser.parseRParen().failed())
1701 return failure();
1702 }
1703
1704 SMLoc opBundlesLoc = parser.getCurrentLocation();
1705 if (std::optional<ParseResult> result = parseOpBundles(
1706 parser, opBundleOperands, opBundleOperandTypes, opBundleTags);
1707 result && failed(*result))
1708 return failure();
1709 if (opBundleTags && !opBundleTags.empty())
1710 result.addAttribute(
1711 InvokeOp::getOpBundleTagsAttrName(result.name).getValue(),
1712 opBundleTags);
1713
1714 if (parser.parseOptionalAttrDict(result.attributes))
1715 return failure();
1716
1717 // Parse the trailing type list and resolve the function operands.
1719 SmallVector<DictionaryAttr> resultAttrs;
1720 if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
1721 argAttrs, resultAttrs))
1722 return failure();
1724 parser.getBuilder(), result, argAttrs, resultAttrs,
1725 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
1726
1727 if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
1728 opBundleOperandTypes,
1729 getOpBundleSizesAttrName(result.name)))
1730 return failure();
1731
1732 result.addSuccessors({normalDest, unwindDest});
1733 result.addOperands(normalOperands);
1734 result.addOperands(unwindOperands);
1735
1736 int32_t numOpBundleOperands = 0;
1737 for (const auto &operands : opBundleOperands)
1738 numOpBundleOperands += operands.size();
1739
1740 result.addAttribute(
1741 InvokeOp::getOperandSegmentSizeAttr(),
1742 builder.getDenseI32ArrayAttr({static_cast<int32_t>(operands.size()),
1743 static_cast<int32_t>(normalOperands.size()),
1744 static_cast<int32_t>(unwindOperands.size()),
1745 numOpBundleOperands}));
1746 return success();
1747}
1748
1749LLVMFunctionType InvokeOp::getCalleeFunctionType() {
1750 if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1751 return *varCalleeType;
1752 return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
1753}
1754
1755///===----------------------------------------------------------------------===//
1756/// Verifying/Printing/Parsing for LLVM::LandingpadOp.
1757///===----------------------------------------------------------------------===//
1758
1759LogicalResult LandingpadOp::verify() {
1760 Value value;
1761 if (LLVMFuncOp func = (*this)->getParentOfType<LLVMFuncOp>()) {
1762 if (!func.getPersonality())
1763 return emitError(
1764 "llvm.landingpad needs to be in a function with a personality");
1765 }
1766
1767 // Consistency of llvm.landingpad result types is checked in
1768 // LLVMFuncOp::verify().
1769
1770 if (!getCleanup() && getOperands().empty())
1771 return emitError("landingpad instruction expects at least one clause or "
1772 "cleanup attribute");
1773
1774 for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) {
1775 value = getOperand(idx);
1776 bool isFilter = llvm::isa<LLVMArrayType>(value.getType());
1777 if (isFilter) {
1778 // FIXME: Verify filter clauses when arrays are appropriately handled
1779 } else {
1780 // catch - global addresses only.
1781 // Bitcast ops should have global addresses as their args.
1782 if (auto bcOp = value.getDefiningOp<BitcastOp>()) {
1783 if (auto addrOp = bcOp.getArg().getDefiningOp<AddressOfOp>())
1784 continue;
1785 return emitError("constant clauses expected").attachNote(bcOp.getLoc())
1786 << "global addresses expected as operand to "
1787 "bitcast used in clauses for landingpad";
1788 }
1789 // ZeroOp and AddressOfOp allowed
1790 if (value.getDefiningOp<ZeroOp>())
1791 continue;
1792 if (value.getDefiningOp<AddressOfOp>())
1793 continue;
1794 return emitError("clause #")
1795 << idx << " is not a known constant - null, addressof, bitcast";
1796 }
1797 }
1798 return success();
1799}
1800
1801void LandingpadOp::print(OpAsmPrinter &p) {
1802 p << (getCleanup() ? " cleanup " : " ");
1803
1804 // Clauses
1805 for (auto value : getOperands()) {
1806 // Similar to llvm - if clause is an array type then it is filter
1807 // clause else catch clause
1808 bool isArrayTy = llvm::isa<LLVMArrayType>(value.getType());
1809 p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : "
1810 << value.getType() << ") ";
1811 }
1812
1813 p.printOptionalAttrDict((*this)->getAttrs(), {"cleanup"});
1814
1815 p << ": " << getType();
1816}
1817
1818// <operation> ::= `llvm.landingpad` `cleanup`?
1819// ((`catch` | `filter`) operand-type ssa-use)* attribute-dict?
1820ParseResult LandingpadOp::parse(OpAsmParser &parser, OperationState &result) {
1821 // Check for cleanup
1822 if (succeeded(parser.parseOptionalKeyword("cleanup")))
1823 result.addAttribute("cleanup", parser.getBuilder().getUnitAttr());
1824
1825 // Parse clauses with types
1826 while (succeeded(parser.parseOptionalLParen()) &&
1827 (succeeded(parser.parseOptionalKeyword("filter")) ||
1828 succeeded(parser.parseOptionalKeyword("catch")))) {
1830 Type ty;
1831 if (parser.parseOperand(operand) || parser.parseColon() ||
1832 parser.parseType(ty) ||
1833 parser.resolveOperand(operand, ty, result.operands) ||
1834 parser.parseRParen())
1835 return failure();
1836 }
1837
1838 Type type;
1839 if (parser.parseColon() || parser.parseType(type))
1840 return failure();
1841
1842 result.addTypes(type);
1843 return success();
1844}
1845
1846//===----------------------------------------------------------------------===//
1847// ExtractValueOp
1848//===----------------------------------------------------------------------===//
1849
1850/// Extract the type at `position` in the LLVM IR aggregate type
1851/// `containerType`. Each element of `position` is an index into a nested
1852/// aggregate type. Return the resulting type or emit an error.
1854 function_ref<InFlightDiagnostic(StringRef)> emitError, Type containerType,
1855 ArrayRef<int64_t> position) {
1856 Type llvmType = containerType;
1857 if (!isCompatibleType(containerType)) {
1858 emitError("expected LLVM IR Dialect type, got ") << containerType;
1859 return {};
1860 }
1861
1862 // Infer the element type from the structure type: iteratively step inside the
1863 // type by taking the element type, indexed by the position attribute for
1864 // structures. Check the position index before accessing, it is supposed to
1865 // be in bounds.
1866 for (int64_t idx : position) {
1867 if (auto arrayType = llvm::dyn_cast<LLVMArrayType>(llvmType)) {
1868 if (idx < 0 || static_cast<unsigned>(idx) >= arrayType.getNumElements()) {
1869 emitError("position out of bounds: ") << idx;
1870 return {};
1871 }
1872 llvmType = arrayType.getElementType();
1873 } else if (auto structType = llvm::dyn_cast<LLVMStructType>(llvmType)) {
1874 if (idx < 0 ||
1875 static_cast<unsigned>(idx) >= structType.getBody().size()) {
1876 emitError("position out of bounds: ") << idx;
1877 return {};
1878 }
1879 llvmType = structType.getBody()[idx];
1880 } else {
1881 emitError("expected LLVM IR structure/array type, got: ") << llvmType;
1882 return {};
1883 }
1884 }
1885 return llvmType;
1886}
1887
1888/// Extract the type at `position` in the wrapped LLVM IR aggregate type
1889/// `containerType`.
/// No bounds or kind checking happens in this walk: a struct level is indexed
/// directly with `idx`, and any non-struct level is `llvm::cast` to
/// `LLVMArrayType`.
1891 ArrayRef<int64_t> position) {
1892 for (int64_t idx : position) {
  // Structs select the member at `idx`; for arrays only the (single) element
  // type is taken, so the index value itself is not used.
1893 if (auto structType = llvm::dyn_cast<LLVMStructType>(llvmType))
1894 llvmType = structType.getBody()[idx];
1895 else
1896 llvmType = llvm::cast<LLVMArrayType>(llvmType).getElementType();
1897 }
1898 return llvmType;
1899}
1900
1901/// Extracts the element at the given index from an attribute. For
1902/// `ElementsAttr` and `ArrayAttr`, returns the element at the specified index.
1903/// For `ZeroAttr`, `UndefAttr`, and `PoisonAttr`, returns the attribute itself
1904/// unchanged. Returns `nullptr` if the attribute is not one of these types or
1905/// if the index is out of bounds.
1907 if (auto elementsAttr = dyn_cast<ElementsAttr>(attr)) {
  // Bounds-check against the element count before indexing the value range.
1908 if (index < static_cast<size_t>(elementsAttr.getNumElements()))
1909 return elementsAttr.getValues<Attribute>()[index];
1910 return nullptr;
1911 }
1912 if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
1913 if (index < arrayAttr.getValue().size())
1914 return arrayAttr[index];
1915 return nullptr;
1916 }
  // These attributes are returned as-is, independent of `index`.
1917 if (isa<ZeroAttr, UndefAttr, PoisonAttr>(attr))
1918 return attr;
1919 return nullptr;
1920}
1921
1922OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
1923 if (auto extractValueOp = getContainer().getDefiningOp<ExtractValueOp>()) {
1924 SmallVector<int64_t, 4> newPos(extractValueOp.getPosition());
1925 newPos.append(getPosition().begin(), getPosition().end());
1926 setPosition(newPos);
1927 getContainerMutable().set(extractValueOp.getContainer());
1928 return getResult();
1929 }
1930
1931 Operation *container = getContainer().getDefiningOp();
1932 OpFoldResult result = {};
1933 ArrayRef<int64_t> extractPos = getPosition();
1934 bool switchedToInsertedValue = false;
1935 while (auto insertValueOp = dyn_cast_if_present<InsertValueOp>(container)) {
1936 ArrayRef<int64_t> insertPos = insertValueOp.getPosition();
1937 auto extractPosSize = extractPos.size();
1938 auto insertPosSize = insertPos.size();
1939
1940 // Case 1: Exact match of positions.
1941 if (extractPos == insertPos)
1942 return insertValueOp.getValue();
1943
1944 // Case 2: Insert position is a prefix of extract position. Continue
1945 // traversal with the inserted value. Example:
1946 // ```
1947 // %0 = llvm.insertvalue %arg1, %undef[0] : !llvm.struct<(i32, i32, i32)>
1948 // %1 = llvm.insertvalue %arg2, %0[1] : !llvm.struct<(i32, i32, i32)>
1949 // %2 = llvm.insertvalue %arg3, %1[2] : !llvm.struct<(i32, i32, i32)>
1950 // %3 = llvm.insertvalue %2, %foo[0]
1951 // : !llvm.struct<(struct<(i32, i32, i32)>, i64)>
1952 // %4 = llvm.extractvalue %3[0, 0]
1953 // : !llvm.struct<(struct<(i32, i32, i32)>, i64)>
1954 // ```
1955 // In the above example, %4 is folded to %arg1.
1956 if (extractPosSize > insertPosSize &&
1957 extractPos.take_front(insertPosSize) == insertPos) {
1958 container = insertValueOp.getValue().getDefiningOp();
1959 extractPos = extractPos.drop_front(insertPosSize);
1960 switchedToInsertedValue = true;
1961 continue;
1962 }
1963
1964 // Case 3: Try to continue the traversal with the container value.
1965 unsigned min = std::min(extractPosSize, insertPosSize);
1966
1967 // If one is fully prefix of the other, stop propagating back as it will
1968 // miss dependencies. For instance, %3 should not fold to %f0 in the
1969 // following example:
1970 // ```
1971 // %1 = llvm.insertvalue %f0, %0[0, 0] :
1972 // !llvm.array<4 x !llvm.array<4 x f32>>
1973 // %2 = llvm.insertvalue %arr, %1[0] :
1974 // !llvm.array<4 x !llvm.array<4 x f32>>
1975 // %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4 x f32>>
1976 // ```
1977 if (extractPos.take_front(min) == insertPos.take_front(min))
1978 return result;
1979 // If neither a prefix, nor the exact position, we can extract out of the
1980 // value being inserted into. Moreover, we can try again if that operand
1981 // is itself an insertvalue expression.
1982 if (!switchedToInsertedValue) {
1983 // Do not swap out the container operand if we decided earlier to
1984 // continue the traversal with the inserted value (Case 2).
1985 getContainerMutable().assign(insertValueOp.getContainer());
1986 result = getResult();
1987 }
1988 container = insertValueOp.getContainer().getDefiningOp();
1989 }
1990 if (!container)
1991 return result;
1992
1993 Attribute containerAttr;
1994 if (!matchPattern(container, m_Constant(&containerAttr)))
1995 return nullptr;
1996 for (int64_t pos : extractPos) {
1997 containerAttr = extractElementAt(containerAttr, pos);
1998 if (!containerAttr)
1999 return nullptr;
2000 }
2001 return containerAttr;
2002}
2003
/// Verifies that the result type of this op equals the type computed for
/// `getPosition()` inside the container type.
2004LogicalResult ExtractValueOp::verify() {
  // Diagnostics produced while computing the element type are emitted as
  // errors on this op.
2005 auto emitError = [this](StringRef msg) { return emitOpError(msg); };
2007 emitError, getContainer().getType(), getPosition());
  // A null type signals failure; presumably the diagnostic was already
  // emitted through `emitError` above.
2008 if (!valueType)
2009 return failure();
2010
2011 if (getRes().getType() != valueType)
2012 return emitOpError() << "Type mismatch: extracting from "
2013 << getContainer().getType() << " should produce "
2014 << valueType << " but this op returns "
2015 << getRes().getType();
2016 return success();
2017}
2018
2019void ExtractValueOp::build(OpBuilder &builder, OperationState &state,
2020 Value container, ArrayRef<int64_t> position) {
2021 build(builder, state,
2022 getInsertExtractValueElementType(container.getType(), position),
2023 container, builder.getAttr<DenseI64ArrayAttr>(position));
2024}
2025
2026//===----------------------------------------------------------------------===//
2027// InsertValueOp
2028//===----------------------------------------------------------------------===//
2029
2030/// Infer the value type from the container type and position.
/// Errors raised during inference are reported at the parser's current
/// location. Parsing succeeds only if a non-null value type was inferred.
2031static ParseResult
2033 Type containerType,
2034 DenseI64ArrayAttr position) {
2036 [&](StringRef msg) {
2037 return parser.emitError(parser.getCurrentLocation(), msg);
2038 },
2039 containerType, position.asArrayRef());
2040 return success(!!valueType);
2041}
2042
2043/// Nothing to print for an inferred type.
/// The body is intentionally empty; none of the parameters are used.
2045 Operation *op, Type valueType,
2046 Type containerType,
2047 DenseI64ArrayAttr position) {}
2048
/// Verifies that the type of the inserted value equals the type computed for
/// `getPosition()` inside the container type.
2049LogicalResult InsertValueOp::verify() {
  // Diagnostics produced while computing the element type are emitted as
  // errors on this op.
2050 auto emitError = [this](StringRef msg) { return emitOpError(msg); };
2052 emitError, getContainer().getType(), getPosition());
  // A null type signals failure; presumably the diagnostic was already
  // emitted through `emitError` above.
2053 if (!valueType)
2054 return failure();
2055
2056 if (getValue().getType() != valueType)
2057 return emitOpError() << "Type mismatch: cannot insert "
2058 << getValue().getType() << " into "
2059 << getContainer().getType();
2060
2061 return success();
2062}
2063
2064//===----------------------------------------------------------------------===//
2065// ReturnOp
2066//===----------------------------------------------------------------------===//
2067
2068LogicalResult ReturnOp::verify() {
2069 auto parent = (*this)->getParentOfType<LLVMFuncOp>();
2070 if (!parent)
2071 return success();
2072
2073 Type expectedType = parent.getFunctionType().getReturnType();
2074 if (llvm::isa<LLVMVoidType>(expectedType)) {
2075 if (!getArg())
2076 return success();
2077 InFlightDiagnostic diag = emitOpError("expected no operands");
2078 diag.attachNote(parent->getLoc()) << "when returning from function";
2079 return diag;
2080 }
2081 if (!getArg()) {
2082 if (llvm::isa<LLVMVoidType>(expectedType))
2083 return success();
2084 InFlightDiagnostic diag = emitOpError("expected 1 operand");
2085 diag.attachNote(parent->getLoc()) << "when returning from function";
2086 return diag;
2087 }
2088 if (expectedType != getArg().getType()) {
2089 InFlightDiagnostic diag = emitOpError("mismatching result types");
2090 diag.attachNote(parent->getLoc()) << "when returning from function";
2091 return diag;
2092 }
2093 return success();
2094}
2095
2096//===----------------------------------------------------------------------===//
2097// LLVM::AddressOfOp.
2098//===----------------------------------------------------------------------===//
2099
2100GlobalOp AddressOfOp::getGlobal(SymbolTableCollection &symbolTable) {
2101 return dyn_cast_or_null<GlobalOp>(
2102 symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
2103}
2104
2105LLVMFuncOp AddressOfOp::getFunction(SymbolTableCollection &symbolTable) {
2106 return dyn_cast_or_null<LLVMFuncOp>(
2107 symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
2108}
2109
2110AliasOp AddressOfOp::getAlias(SymbolTableCollection &symbolTable) {
2111 return dyn_cast_or_null<AliasOp>(
2112 symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
2113}
2114
2115IFuncOp AddressOfOp::getIFunc(SymbolTableCollection &symbolTable) {
2116 return dyn_cast_or_null<IFuncOp>(
2117 symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
2118}
2119
2120LogicalResult
2121AddressOfOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2122 Operation *symbol =
2123 symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr());
2124
2125 auto global = dyn_cast_or_null<GlobalOp>(symbol);
2126 auto function = dyn_cast_or_null<LLVMFuncOp>(symbol);
2127 auto alias = dyn_cast_or_null<AliasOp>(symbol);
2128 auto ifunc = dyn_cast_or_null<IFuncOp>(symbol);
2129
2130 if (!global && !function && !alias && !ifunc)
2131 return emitOpError("must reference a global defined by 'llvm.mlir.global', "
2132 "'llvm.mlir.alias' or 'llvm.func' or 'llvm.mlir.ifunc'");
2133
2134 LLVMPointerType type = getType();
2135 if ((global && global.getAddrSpace() != type.getAddressSpace()) ||
2136 (alias && alias.getAddrSpace() != type.getAddressSpace()))
2137 return emitOpError("pointer address space must match address space of the "
2138 "referenced global or alias");
2139
2140 return success();
2141}
2142
2143// AddressOfOp constant-folds to the global symbol name.
2144OpFoldResult LLVM::AddressOfOp::fold(FoldAdaptor) {
2145 return getGlobalNameAttr();
2146}
2147
2148//===----------------------------------------------------------------------===//
2149// LLVM::DSOLocalEquivalentOp
2150//===----------------------------------------------------------------------===//
2151
2152LLVMFuncOp
2153DSOLocalEquivalentOp::getFunction(SymbolTableCollection &symbolTable) {
2154 return dyn_cast_or_null<LLVMFuncOp>(symbolTable.lookupSymbolIn(
2155 parentLLVMModule(*this), getFunctionNameAttr()));
2156}
2157
2158AliasOp DSOLocalEquivalentOp::getAlias(SymbolTableCollection &symbolTable) {
2159 return dyn_cast_or_null<AliasOp>(symbolTable.lookupSymbolIn(
2160 parentLLVMModule(*this), getFunctionNameAttr()));
2161}
2162
2163LogicalResult
2164DSOLocalEquivalentOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2165 Operation *symbol = symbolTable.lookupSymbolIn(parentLLVMModule(*this),
2166 getFunctionNameAttr());
2167 auto function = dyn_cast_or_null<LLVMFuncOp>(symbol);
2168 auto alias = dyn_cast_or_null<AliasOp>(symbol);
2169
2170 if (!function && !alias)
2171 return emitOpError(
2172 "must reference a global defined by 'llvm.func' or 'llvm.mlir.alias'");
2173
2174 if (alias) {
2175 if (alias.getInitializer()
2176 .walk([&](AddressOfOp addrOp) {
2177 if (addrOp.getGlobal(symbolTable))
2178 return WalkResult::interrupt();
2179 return WalkResult::advance();
2180 })
2181 .wasInterrupted())
2182 return emitOpError("must reference an alias to a function");
2183 }
2184
2185 if ((function && function.getLinkage() == LLVM::Linkage::ExternWeak) ||
2186 (alias && alias.getLinkage() == LLVM::Linkage::ExternWeak))
2187 return emitOpError(
2188 "target function with 'extern_weak' linkage not allowed");
2189
2190 return success();
2191}
2192
2193/// Fold a dso_local_equivalent operation to a dedicated dso_local_equivalent
2194/// attribute.
2195OpFoldResult DSOLocalEquivalentOp::fold(FoldAdaptor) {
2196 return DSOLocalEquivalentAttr::get(getContext(), getFunctionNameAttr());
2197}
2198
2199//===----------------------------------------------------------------------===//
2200// Verifier for LLVM::ComdatOp.
2201//===----------------------------------------------------------------------===//
2202
2203void ComdatOp::build(OpBuilder &builder, OperationState &result,
2204 StringRef symName) {
2205 result.addAttribute(getSymNameAttrName(result.name),
2206 builder.getStringAttr(symName));
2207 Region *body = result.addRegion();
2208 body->emplaceBlock();
2209}
2210
2211LogicalResult ComdatOp::verifyRegions() {
2212 Region &body = getBody();
2213 for (Operation &op : body.getOps())
2214 if (!isa<ComdatSelectorOp>(op))
2215 return op.emitError(
2216 "only comdat selector symbols can appear in a comdat region");
2217
2218 return success();
2219}
2220
2221//===----------------------------------------------------------------------===//
2222// Builder, printer and verifier for LLVM::GlobalOp.
2223//===----------------------------------------------------------------------===//
2224
/// Populates `result` for a global: the symbol name, global type and linkage
/// are always attached; the remaining attributes are attached only when the
/// corresponding argument is set (true, non-null, non-zero or non-empty). An
/// initializer region is always added, without any block.
2225void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type,
2226 bool isConstant, Linkage linkage, StringRef name,
2227 Attribute value, uint64_t alignment, unsigned addrSpace,
2228 bool dsoLocal, bool threadLocal, SymbolRefAttr comdat,
2230 ArrayRef<Attribute> dbgExprs) {
2231 result.addAttribute(getSymNameAttrName(result.name),
2232 builder.getStringAttr(name));
2233 result.addAttribute(getGlobalTypeAttrName(result.name), TypeAttr::get(type));
  // Boolean properties are encoded as unit attributes, present only when set.
2234 if (isConstant)
2235 result.addAttribute(getConstantAttrName(result.name),
2236 builder.getUnitAttr());
2237 if (value)
2238 result.addAttribute(getValueAttrName(result.name), value);
2239 if (dsoLocal)
2240 result.addAttribute(getDsoLocalAttrName(result.name),
2241 builder.getUnitAttr());
2242 if (threadLocal)
2243 result.addAttribute(getThreadLocal_AttrName(result.name),
2244 builder.getUnitAttr());
2245 if (comdat)
2246 result.addAttribute(getComdatAttrName(result.name), comdat);
2247
2248 // Only add an alignment attribute if the "alignment" input
2249 // is different from 0. The value must also be a power of two, but
2250 // this is tested in GlobalOp::verify, not here.
2251 if (alignment != 0)
2252 result.addAttribute(getAlignmentAttrName(result.name),
2253 builder.getI64IntegerAttr(alignment));
2254
2255 result.addAttribute(getLinkageAttrName(result.name),
2256 LinkageAttr::get(builder.getContext(), linkage));
  // Address space 0 is left implicit.
2257 if (addrSpace != 0)
2258 result.addAttribute(getAddrSpaceAttrName(result.name),
2259 builder.getI32IntegerAttr(addrSpace));
  // Caller-provided attributes are appended after the ones set above.
2260 result.attributes.append(attrs.begin(), attrs.end());
2261
2262 if (!dbgExprs.empty())
2263 result.addAttribute(getDbgExprsAttrName(result.name),
2264 ArrayAttr::get(builder.getContext(), dbgExprs));
2265
2266 result.addRegion();
2267}
2268
2269template <typename OpType>
2270static void printCommonGlobalAndAlias(OpAsmPrinter &p, OpType op) {
2271 p << ' ' << stringifyLinkage(op.getLinkage()) << ' ';
2272 StringRef visibility = stringifyVisibility(op.getVisibility_());
2273 if (!visibility.empty())
2274 p << visibility << ' ';
2275 if (op.getThreadLocal_())
2276 p << "thread_local ";
2277 if (auto unnamedAddr = op.getUnnamedAddr()) {
2278 StringRef str = stringifyUnnamedAddr(*unnamedAddr);
2279 if (!str.empty())
2280 p << str << ' ';
2281 }
2282}
2283
/// Prints the custom form of a global: optional `constant` keyword, symbol
/// name, parenthesized initializer value (empty parentheses when there is
/// none), optional `comdat(...)`, the remaining attributes, and - unless the
/// value is a string attribute - the global type followed by the initializer
/// region when that region is non-empty.
2284void GlobalOp::print(OpAsmPrinter &p) {
2286 if (getConstant())
2287 p << "constant ";
2288 p.printSymbolName(getSymName());
2289 p << '(';
2290 if (auto value = getValueOrNull())
2291 p.printAttribute(value);
2292 p << ')';
2293 if (auto comdat = getComdat())
2294 p << " comdat(" << *comdat << ')';
2295
2296 // Note that the alignment attribute is printed using the
2297 // default syntax here, even though it is an inherent attribute
2298 // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes)
  // The attributes elided below are the ones already printed through
  // dedicated syntax (or, for the global type, printed after the dictionary).
2299 p.printOptionalAttrDict((*this)->getAttrs(),
2300 {SymbolTable::getSymbolAttrName(),
2301 getGlobalTypeAttrName(), getConstantAttrName(),
2302 getValueAttrName(), getLinkageAttrName(),
2303 getUnnamedAddrAttrName(), getThreadLocal_AttrName(),
2304 getVisibility_AttrName(), getComdatAttrName()});
2305
2306 // Print the trailing type unless it's a string global.
2307 if (llvm::dyn_cast_or_null<StringAttr>(getValueOrNull()))
2308 return;
2309 p << " : " << getType();
2310
2311 Region &initializer = getInitializerRegion();
2312 if (!initializer.empty()) {
2313 p << ' ';
2314 p.printRegion(initializer, /*printEntryBlockArgs=*/false);
2315 }
2316}
2317
2318static LogicalResult verifyComdat(Operation *op,
2319 std::optional<SymbolRefAttr> attr) {
2320 if (!attr)
2321 return success();
2322
2323 auto *comdatSelector = SymbolTable::lookupNearestSymbolFrom(op, *attr);
2324 if (!isa_and_nonnull<ComdatSelectorOp>(comdatSelector))
2325 return op->emitError() << "expected comdat symbol";
2326
2327 return success();
2328}
2329
/// Verifies that no two `BlockTagOp`s nested in `funcOp` carry the same tag.
/// An error is emitted on the first duplicate encountered by the walk.
2330static LogicalResult verifyBlockTags(LLVMFuncOp funcOp) {
2332 // Note that presence of `BlockTagOp`s currently can't prevent an unreachable
2333 // block from being removed by canonicalizer's region simplify pass, which
2334 // needs to be dialect aware to allow extra constraints to be described.
2335 WalkResult res = funcOp.walk([&](BlockTagOp blockTagOp) {
2336 if (blockTags.contains(blockTagOp.getTag())) {
2337 blockTagOp.emitError()
2338 << "duplicate block tag '" << blockTagOp.getTag().getId()
2339 << "' in the same function: ";
2340 return WalkResult::interrupt();
2341 }
2342 blockTags.insert(blockTagOp.getTag());
2343 return WalkResult::advance();
2344 });
2345
  // An interrupted walk means a duplicate tag was reported.
2346 return failure(res.wasInterrupted());
2347}
2348
2349/// Parse common attributes that might show up in the same order in both
2350/// GlobalOp and AliasOp.
/// The attributes are handled in this order: linkage, visibility,
/// `thread_local`, unnamed-addr. Linkage, visibility and unnamed-addr are
/// always attached to `result` (with a default when the keyword is absent);
/// the `thread_local` unit attribute only when the keyword is present.
2351template <typename OpType>
2352static ParseResult parseCommonGlobalAndAlias(OpAsmParser &parser,
2354 MLIRContext *ctx = parser.getContext();
2355 // Parse optional linkage, default to External.
2356 result.addAttribute(
2357 OpType::getLinkageAttrName(result.name),
2358 LLVM::LinkageAttr::get(ctx, parseOptionalLLVMKeyword<Linkage>(
2359 parser, LLVM::Linkage::External)));
2360
2361 // Parse optional visibility, default to Default.
2362 result.addAttribute(OpType::getVisibility_AttrName(result.name),
2365 parser, LLVM::Visibility::Default)));
2366
2367 if (succeeded(parser.parseOptionalKeyword("thread_local")))
2368 result.addAttribute(OpType::getThreadLocal_AttrName(result.name),
2369 parser.getBuilder().getUnitAttr());
2370
2371 // Parse optional UnnamedAddr, default to None.
2372 result.addAttribute(OpType::getUnnamedAddrAttrName(result.name),
2375 parser, LLVM::UnnamedAddr::None)));
2376
  // None of the steps above reports a failure, so this always succeeds.
2377 return success();
2378}
2379
2380// operation ::= `llvm.mlir.global` linkage? visibility?
2381// `thread_local`? (`unnamed_addr` | `local_unnamed_addr`)?
2382// `constant`? `@` identifier
2383// `(` attribute? `)` (`comdat(` symbol-ref-id `)`)?
2384// attribute-list? (`:` type)? region?
2385//
2386// The type can be omitted for string attributes, in which case it will be
2387// inferred from the value of the string as [strlen(value) x i8].
2388ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) {
2389 // Call into common parsing between GlobalOp and AliasOp.
2391 return failure();
2392
2393 if (succeeded(parser.parseOptionalKeyword("constant")))
2394 result.addAttribute(getConstantAttrName(result.name),
2395 parser.getBuilder().getUnitAttr());
2396
2397 StringAttr name;
2398 if (parser.parseSymbolName(name, getSymNameAttrName(result.name),
2399 result.attributes) ||
2400 parser.parseLParen())
2401 return failure();
2402
  // An immediately following `)` means there is no initializer value;
  // otherwise an attribute is expected before the closing parenthesis.
2403 Attribute value;
2404 if (parser.parseOptionalRParen()) {
2405 if (parser.parseAttribute(value, getValueAttrName(result.name),
2406 result.attributes) ||
2407 parser.parseRParen())
2408 return failure();
2409 }
2410
2411 if (succeeded(parser.parseOptionalKeyword("comdat"))) {
2412 SymbolRefAttr comdat;
2413 if (parser.parseLParen() || parser.parseAttribute(comdat) ||
2414 parser.parseRParen())
2415 return failure();
2416
2417 result.addAttribute(getComdatAttrName(result.name), comdat);
2418 }
2419
2421 if (parser.parseOptionalAttrDict(result.attributes) ||
2422 parser.parseOptionalColonTypeList(types))
2423 return failure();
2424
2425 if (types.size() > 1)
2426 return parser.emitError(parser.getNameLoc(), "expected zero or one type");
2427
  // The initializer region is always added to the op; it is only parsed (and
  // only optionally present) when an explicit type was given.
2428 Region &initRegion = *result.addRegion();
2429 if (types.empty()) {
    // Without an explicit type, only string values are accepted: the type
    // becomes an i8 array whose length is the string's size.
2430 if (auto strAttr = llvm::dyn_cast_or_null<StringAttr>(value)) {
2431 MLIRContext *context = parser.getContext();
2432 auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8),
2433 strAttr.getValue().size());
2434 types.push_back(arrayType);
2435 } else {
2436 return parser.emitError(parser.getNameLoc(),
2437 "type can only be omitted for string globals");
2438 }
2439 } else {
2440 OptionalParseResult parseResult =
2441 parser.parseOptionalRegion(initRegion, /*arguments=*/{},
2442 /*argTypes=*/{});
2443 if (parseResult.has_value() && failed(*parseResult))
2444 return failure();
2445 }
2446
  // At this point `types` holds exactly one type, explicit or inferred.
2447 result.addAttribute(getGlobalTypeAttrName(result.name),
2448 TypeAttr::get(types[0]));
2449 return success();
2450}
2451
2452static bool isZeroAttribute(Attribute value) {
2453 if (auto intValue = llvm::dyn_cast<IntegerAttr>(value))
2454 return intValue.getValue().isZero();
2455 if (auto fpValue = llvm::dyn_cast<FloatAttr>(value))
2456 return fpValue.getValue().isZero();
2457 if (auto splatValue = llvm::dyn_cast<SplatElementsAttr>(value))
2458 return isZeroAttribute(splatValue.getSplatValue<Attribute>());
2459 if (auto elementsValue = llvm::dyn_cast<ElementsAttr>(value))
2460 return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute);
2461 if (auto arrayValue = llvm::dyn_cast<ArrayAttr>(value))
2462 return llvm::all_of(arrayValue.getValue(), isZeroAttribute);
2463 return false;
2464}
2465
2466LogicalResult GlobalOp::verify() {
2467 bool validType = isCompatibleOuterType(getType())
2468 ? !llvm::isa<LLVMVoidType, LLVMTokenType,
2469 LLVMMetadataType, LLVMLabelType>(getType())
2470 : llvm::isa<PointerElementTypeInterface>(getType());
2471 if (!validType)
2472 return emitOpError(
2473 "expects type to be a valid element type for an LLVM global");
2474 if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp()))
2475 return emitOpError("must appear at the module level");
2476
2477 if (auto strAttr = llvm::dyn_cast_or_null<StringAttr>(getValueOrNull())) {
2478 auto type = llvm::dyn_cast<LLVMArrayType>(getType());
2479 IntegerType elementType =
2480 type ? llvm::dyn_cast<IntegerType>(type.getElementType()) : nullptr;
2481 if (!elementType || elementType.getWidth() != 8 ||
2482 type.getNumElements() != strAttr.getValue().size())
2483 return emitOpError(
2484 "requires an i8 array type of the length equal to that of the string "
2485 "attribute");
2486 }
2487
2488 if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
2489 if (!targetExtType.hasProperty(LLVMTargetExtType::CanBeGlobal))
2490 return emitOpError()
2491 << "this target extension type cannot be used in a global";
2492
2493 if (Attribute value = getValueOrNull())
2494 return emitOpError() << "global with target extension type can only be "
2495 "initialized with zero-initializer";
2496 }
2497
2498 if (getLinkage() == Linkage::Common) {
2499 if (Attribute value = getValueOrNull()) {
2500 if (!isZeroAttribute(value)) {
2501 return emitOpError()
2502 << "expected zero value for '"
2503 << stringifyLinkage(Linkage::Common) << "' linkage";
2504 }
2505 }
2506 }
2507
2508 if (getLinkage() == Linkage::Appending) {
2509 if (!llvm::isa<LLVMArrayType>(getType())) {
2510 return emitOpError() << "expected array type for '"
2511 << stringifyLinkage(Linkage::Appending)
2512 << "' linkage";
2513 }
2514 }
2515
2516 if (failed(verifyComdat(*this, getComdat())))
2517 return failure();
2518
2519 std::optional<uint64_t> alignAttr = getAlignment();
2520 if (alignAttr.has_value()) {
2521 uint64_t value = alignAttr.value();
2522 if (!llvm::isPowerOf2_64(value))
2523 return emitError() << "alignment attribute is not a power of 2";
2524 }
2525
2526 return success();
2527}
2528
2529LogicalResult GlobalOp::verifyRegions() {
2530 if (Block *b = getInitializerBlock()) {
2531 ReturnOp ret = cast<ReturnOp>(b->getTerminator());
2532 if (ret.operand_type_begin() == ret.operand_type_end())
2533 return emitOpError("initializer region cannot return void");
2534 if (*ret.operand_type_begin() != getType())
2535 return emitOpError("initializer region type ")
2536 << *ret.operand_type_begin() << " does not match global type "
2537 << getType();
2538
2539 for (Operation &op : *b) {
2540 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
2541 if (!iface || !iface.hasNoEffect())
2542 return op.emitError()
2543 << "ops with side effects not allowed in global initializers";
2544 }
2545
2546 if (getValueOrNull())
2547 return emitOpError("cannot have both initializer value and region");
2548 }
2549
2550 return success();
2551}
2552
2553//===----------------------------------------------------------------------===//
2554// LLVM::GlobalCtorsOp
2555//===----------------------------------------------------------------------===//
2556
2557static LogicalResult checkGlobalXtorData(Operation *op, ArrayAttr data) {
2558 if (data.empty())
2559 return success();
2560
2561 if (llvm::all_of(data.getAsRange<Attribute>(), [](Attribute v) {
2562 return isa<FlatSymbolRefAttr, ZeroAttr>(v);
2563 }))
2564 return success();
2565 return op->emitError("data element must be symbol or #llvm.zero");
2566}
2567
2568LogicalResult
2569GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2570 for (Attribute ctor : getCtors()) {
2571 if (failed(verifySymbolAttrUse(llvm::cast<FlatSymbolRefAttr>(ctor), *this,
2572 symbolTable)))
2573 return failure();
2574 }
2575 return success();
2576}
2577
2578LogicalResult GlobalCtorsOp::verify() {
2579 if (checkGlobalXtorData(*this, getData()).failed())
2580 return failure();
2581
2582 if (getCtors().size() == getPriorities().size() &&
2583 getCtors().size() == getData().size())
2584 return success();
2585 return emitError(
2586 "ctors, priorities, and data must have the same number of elements");
2587}
2588
2589//===----------------------------------------------------------------------===//
2590// LLVM::GlobalDtorsOp
2591//===----------------------------------------------------------------------===//
2592
2593LogicalResult
2594GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2595 for (Attribute dtor : getDtors()) {
2596 if (failed(verifySymbolAttrUse(llvm::cast<FlatSymbolRefAttr>(dtor), *this,
2597 symbolTable)))
2598 return failure();
2599 }
2600 return success();
2601}
2602
2603LogicalResult GlobalDtorsOp::verify() {
2604 if (checkGlobalXtorData(*this, getData()).failed())
2605 return failure();
2606
2607 if (getDtors().size() == getPriorities().size() &&
2608 getDtors().size() == getData().size())
2609 return success();
2610 return emitError(
2611 "dtors, priorities, and data must have the same number of elements");
2612}
2613
2614//===----------------------------------------------------------------------===//
2615// Builder, printer and verifier for LLVM::AliasOp.
2616//===----------------------------------------------------------------------===//
2617
2618void AliasOp::build(OpBuilder &builder, OperationState &result, Type type,
2619 Linkage linkage, StringRef name, bool dsoLocal,
2620 bool threadLocal, ArrayRef<NamedAttribute> attrs) {
2621 result.addAttribute(getSymNameAttrName(result.name),
2622 builder.getStringAttr(name));
2623 result.addAttribute(getAliasTypeAttrName(result.name), TypeAttr::get(type));
2624 if (dsoLocal)
2625 result.addAttribute(getDsoLocalAttrName(result.name),
2626 builder.getUnitAttr());
2627 if (threadLocal)
2628 result.addAttribute(getThreadLocal_AttrName(result.name),
2629 builder.getUnitAttr());
2630
2631 result.addAttribute(getLinkageAttrName(result.name),
2632 LinkageAttr::get(builder.getContext(), linkage));
2633 result.attributes.append(attrs.begin(), attrs.end());
2634
2635 result.addRegion();
2636}
2637
2638void AliasOp::print(OpAsmPrinter &p) {
2640
2641 p.printSymbolName(getSymName());
2642 p.printOptionalAttrDict((*this)->getAttrs(),
2643 {SymbolTable::getSymbolAttrName(),
2644 getAliasTypeAttrName(), getLinkageAttrName(),
2645 getUnnamedAddrAttrName(), getThreadLocal_AttrName(),
2646 getVisibility_AttrName()});
2647
2648 // Print the trailing type.
2649 p << " : " << getType() << ' ';
2650 // Print the initializer region.
2651 p.printRegion(getInitializerRegion(), /*printEntryBlockArgs=*/false);
2652}
2653
2654// operation ::= `llvm.mlir.alias` linkage? visibility?
2655// (`unnamed_addr` | `local_unnamed_addr`)?
2656// `thread_local`? `@` identifier
2657// `(` attribute? `)`
2658// attribute-list? `:` type region
2659//
2660ParseResult AliasOp::parse(OpAsmParser &parser, OperationState &result) {
2661 // Call into common parsing between GlobalOp and AliasOp.
2663 return failure();
2664
2665 StringAttr name;
2666 if (parser.parseSymbolName(name, getSymNameAttrName(result.name),
2667 result.attributes))
2668 return failure();
2669
2671 if (parser.parseOptionalAttrDict(result.attributes) ||
2672 parser.parseOptionalColonTypeList(types))
2673 return failure();
2674
2675 if (types.size() > 1)
2676 return parser.emitError(parser.getNameLoc(), "expected zero or one type");
2677
2678 Region &initRegion = *result.addRegion();
2679 if (parser.parseRegion(initRegion).failed())
2680 return failure();
2681
2682 result.addAttribute(getAliasTypeAttrName(result.name),
2683 TypeAttr::get(types[0]));
2684 return success();
2685}
2686
2687LogicalResult AliasOp::verify() {
2688 bool validType = isCompatibleOuterType(getType())
2689 ? !llvm::isa<LLVMVoidType, LLVMTokenType,
2690 LLVMMetadataType, LLVMLabelType>(getType())
2691 : llvm::isa<PointerElementTypeInterface>(getType());
2692 if (!validType)
2693 return emitOpError(
2694 "expects type to be a valid element type for an LLVM global alias");
2695
2696 // This matches LLVM IR verification logic, see llvm/lib/IR/Verifier.cpp
2697 switch (getLinkage()) {
2698 case Linkage::External:
2699 case Linkage::Internal:
2700 case Linkage::Private:
2701 case Linkage::Weak:
2702 case Linkage::WeakODR:
2703 case Linkage::Linkonce:
2704 case Linkage::LinkonceODR:
2705 case Linkage::AvailableExternally:
2706 break;
2707 default:
2708 return emitOpError()
2709 << "'" << stringifyLinkage(getLinkage())
2710 << "' linkage not supported in aliases, available options: private, "
2711 "internal, linkonce, weak, linkonce_odr, weak_odr, external or "
2712 "available_externally";
2713 }
2714
2715 return success();
2716}
2717
2718LogicalResult AliasOp::verifyRegions() {
2719 Block &b = getInitializerBlock();
2720 auto ret = cast<ReturnOp>(b.getTerminator());
2721 if (ret.getNumOperands() == 0 ||
2722 !isa<LLVM::LLVMPointerType>(ret.getOperand(0).getType()))
2723 return emitOpError("initializer region must always return a pointer");
2724
2725 for (Operation &op : b) {
2726 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
2727 if (!iface || !iface.hasNoEffect())
2728 return op.emitError()
2729 << "ops with side effects are not allowed in alias initializers";
2730 }
2731
2732 return success();
2733}
2734
2735unsigned AliasOp::getAddrSpace() {
2736 Block &initializer = getInitializerBlock();
2737 auto ret = cast<ReturnOp>(initializer.getTerminator());
2738 auto ptrTy = cast<LLVMPointerType>(ret.getOperand(0).getType());
2739 return ptrTy.getAddressSpace();
2740}
2741
2742//===----------------------------------------------------------------------===//
2743// IFuncOp
2744//===----------------------------------------------------------------------===//
2745
2746void IFuncOp::build(OpBuilder &builder, OperationState &result, StringRef name,
2747 Type iFuncType, StringRef resolverName, Type resolverType,
2748 Linkage linkage, LLVM::Visibility visibility) {
2749 return build(builder, result, name, iFuncType, resolverName, resolverType,
2750 linkage, /*dso_local=*/false, /*address_space=*/0,
2751 UnnamedAddr::None, visibility);
2752}
2753
2754LogicalResult IFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2755 Operation *symbol =
2756 symbolTable.lookupSymbolIn(parentLLVMModule(*this), getResolverAttr());
2757 // This matches LLVM IR verification logic, see llvm/lib/IR/Verifier.cpp
2758 auto resolver = dyn_cast<LLVMFuncOp>(symbol);
2759 auto alias = dyn_cast<AliasOp>(symbol);
2760 while (alias) {
2761 Block &initBlock = alias.getInitializerBlock();
2762 auto returnOp = cast<ReturnOp>(initBlock.getTerminator());
2763 auto addrOp = returnOp.getArg().getDefiningOp<AddressOfOp>();
2764 // FIXME: This is a best effort solution. The AliasOp body might be more
2765 // complex and in that case we bail out with success. To completely match
2766 // the LLVM IR logic it would be necessary to implement proper alias and
2767 // cast stripping.
2768 if (!addrOp)
2769 return success();
2770 resolver = addrOp.getFunction(symbolTable);
2771 alias = addrOp.getAlias(symbolTable);
2772 }
2773 if (!resolver)
2774 return emitOpError("must have a function resolver");
2775 Linkage linkage = resolver.getLinkage();
2776 if (resolver.isExternal() || linkage == Linkage::AvailableExternally)
2777 return emitOpError("resolver must be a definition");
2778 if (!isa<LLVMPointerType>(resolver.getFunctionType().getReturnType()))
2779 return emitOpError("resolver must return a pointer");
2780 auto resolverPtr = dyn_cast<LLVMPointerType>(getResolverType());
2781 if (!resolverPtr || resolverPtr.getAddressSpace() != getAddressSpace())
2782 return emitOpError("resolver has incorrect type");
2783 return success();
2784}
2785
2786LogicalResult IFuncOp::verify() {
2787 switch (getLinkage()) {
2788 case Linkage::External:
2789 case Linkage::Internal:
2790 case Linkage::Private:
2791 case Linkage::Weak:
2792 case Linkage::WeakODR:
2793 case Linkage::Linkonce:
2794 case Linkage::LinkonceODR:
2795 break;
2796 default:
2797 return emitOpError() << "'" << stringifyLinkage(getLinkage())
2798 << "' linkage not supported in ifuncs, available "
2799 "options: private, internal, linkonce, weak, "
2800 "linkonce_odr, weak_odr, or external linkage";
2801 }
2802 return success();
2803}
2804
2805//===----------------------------------------------------------------------===//
2806// ShuffleVectorOp
2807//===----------------------------------------------------------------------===//
2808
2809void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1,
2810 Value v2, DenseI32ArrayAttr mask,
2812 auto containerType = v1.getType();
2813 auto vType = LLVM::getVectorType(
2814 cast<VectorType>(containerType).getElementType(), mask.size(),
2815 LLVM::isScalableVectorType(containerType));
2816 build(builder, state, vType, v1, v2, mask);
2817 state.addAttributes(attrs);
2818}
2819
/// Convenience builder taking the mask as a plain list of integers; wraps it
/// in a `DenseI32ArrayAttr` and forwards to the attribute-based builder.
void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1,
                            Value v2, ArrayRef<int32_t> mask) {
  DenseI32ArrayAttr maskAttr = builder.getDenseI32ArrayAttr(mask);
  build(builder, state, v1, v2, maskAttr);
}
2824
2825/// Build the result type of a shuffle vector operation.
2826static ParseResult parseShuffleType(AsmParser &parser, Type v1Type,
2827 Type &resType, DenseI32ArrayAttr mask) {
2828 if (!LLVM::isCompatibleVectorType(v1Type))
2829 return parser.emitError(parser.getCurrentLocation(),
2830 "expected an LLVM compatible vector type");
2831 resType =
2832 LLVM::getVectorType(cast<VectorType>(v1Type).getElementType(),
2833 mask.size(), LLVM::isScalableVectorType(v1Type));
2834 return success();
2835}
2836
2837/// Nothing to do when the result type is inferred.
2838static void printShuffleType(AsmPrinter &printer, Operation *op, Type v1Type,
2839 Type resType, DenseI32ArrayAttr mask) {}
2840
/// For scalable input vectors only an all-zero mask (a splat) is accepted;
/// fixed-size vectors are not restricted here.
LogicalResult ShuffleVectorOp::verify() {
  if (!LLVM::isScalableVectorType(getV1().getType()))
    return success();

  bool isSplatMask =
      llvm::all_of(getMask(), [](int32_t index) { return index == 0; });
  if (isSplatMask)
    return success();
  return emitOpError("expected a splat operation for scalable vectors");
}
2847
2848// Folding for shufflevector op when v1 is single element 1D vector
2849// and the mask is a single zero. OpFoldResult will be v1 in this case.
2850OpFoldResult ShuffleVectorOp::fold(FoldAdaptor adaptor) {
2851 // Check if operand 0 is a single element vector.
2852 auto vecType = llvm::dyn_cast<VectorType>(getV1().getType());
2853 if (!vecType || vecType.getRank() != 1 || vecType.getNumElements() != 1)
2854 return {};
2855 // Check if the mask is a single zero.
2856 // Note: The mask is guaranteed to be non-empty.
2857 if (getMask().size() != 1 || getMask()[0] != 0)
2858 return {};
2859 return getV1();
2860}
2861
2862//===----------------------------------------------------------------------===//
2863// Implementations for LLVM::LLVMFuncOp.
2864//===----------------------------------------------------------------------===//
2865
// Creates the entry block of a function that has no body yet and gives it one
// argument per parameter of the function type. The builder's insertion point
// is restored on return.
Block *LLVMFuncOp::addEntryBlock(OpBuilder &builder) {
  assert(empty() && "function already has an entry block");
  OpBuilder::InsertionGuard guard(builder);
  Block *entryBlock = builder.createBlock(&getBody());

  // FIXME: Allow passing in proper locations for the entry arguments.
  for (Type paramType : getFunctionType().getParams())
    entryBlock->addArgument(paramType, getLoc());
  return entryBlock;
}
2878
/// Populates `result` for a function named `name` of function type `type`.
/// Linkage and calling convention are always set; `dso_local`, `comdat` and
/// the function entry count are only added when provided. When `argAttrs` is
/// non-empty it must contain one dictionary per function parameter.
void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
                       StringRef name, Type type, LLVM::Linkage linkage,
                       bool dsoLocal, CConv cconv, SymbolRefAttr comdat,
                       ArrayRef<NamedAttribute> attrs,
                       ArrayRef<DictionaryAttr> argAttrs,
                       std::optional<uint64_t> functionEntryCount) {
  result.addRegion();
  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));
  result.addAttribute(getLinkageAttrName(result.name),
                      LinkageAttr::get(builder.getContext(), linkage));
  result.addAttribute(getCConvAttrName(result.name),
                      CConvAttr::get(builder.getContext(), cconv));
  result.attributes.append(attrs.begin(), attrs.end());
  if (dsoLocal)
    result.addAttribute(getDsoLocalAttrName(result.name),
                        builder.getUnitAttr());
  if (comdat)
    result.addAttribute(getComdatAttrName(result.name), comdat);
  if (functionEntryCount)
    result.addAttribute(getFunctionEntryCountAttrName(result.name),
                        builder.getI64IntegerAttr(functionEntryCount.value()));
  // Without argument attributes there is nothing left to attach.
  if (argAttrs.empty())
    return;

  assert(llvm::cast<LLVMFunctionType>(type).getNumParams() == argAttrs.size() &&
         "expected as many argument attribute lists as arguments");
  call_interface_impl::addArgAndResultAttrs(
      builder, result, argAttrs, /*resultAttrs=*/{},
      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}
2912
2913// Builds an LLVM function type from the given lists of input and output types.
2914// Returns a null type if any of the types provided are non-LLVM types, or if
2915// there is more than one output type.
2916static Type
2918 ArrayRef<Type> outputs,
2920 Builder &b = parser.getBuilder();
2921 if (outputs.size() > 1) {
2922 parser.emitError(loc, "failed to construct function type: expected zero or "
2923 "one function result");
2924 return {};
2925 }
2926
2927 // Convert inputs to LLVM types, exit early on error.
2928 SmallVector<Type, 4> llvmInputs;
2929 for (auto t : inputs) {
2930 if (!isCompatibleType(t)) {
2931 parser.emitError(loc, "failed to construct function type: expected LLVM "
2932 "type for function arguments");
2933 return {};
2934 }
2935 llvmInputs.push_back(t);
2936 }
2937
2938 // No output is denoted as "void" in LLVM type system.
2939 Type llvmOutput =
2940 outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front();
2941 if (!isCompatibleType(llvmOutput)) {
2942 parser.emitError(loc, "failed to construct function type: expected LLVM "
2943 "type for function results")
2944 << llvmOutput;
2945 return {};
2946 }
2947 return LLVMFunctionType::get(llvmOutput, llvmInputs,
2948 variadicFlag.isVariadic());
2949}
2950
// Parses an LLVM function.
//
// operation ::= `llvm.func` linkage? visibility? unnamed-addr? cconv?
//               function-signature
//               (`vscale_range(` integer `,` integer `)`)?
//               (`comdat(` symbol-ref-id `)`)?
//               function-attributes?
//               function-body
//
ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
  // Default to external linkage if no keyword is provided.
  result.addAttribute(getLinkageAttrName(result.name),
                      LinkageAttr::get(parser.getContext(),
                                       parseOptionalLLVMKeyword<Linkage>(
                                           parser, LLVM::Linkage::External)));

  // Parse optional visibility, default to Default.
  result.addAttribute(getVisibility_AttrName(result.name),
                      parser.getBuilder().getI64IntegerAttr(
                          parseOptionalLLVMKeyword<LLVM::Visibility, int64_t>(
                              parser, LLVM::Visibility::Default)));

  // Parse optional UnnamedAddr, default to None.
  result.addAttribute(getUnnamedAddrAttrName(result.name),
                      parser.getBuilder().getI64IntegerAttr(
                          parseOptionalLLVMKeyword<UnnamedAddr, int64_t>(
                              parser, LLVM::UnnamedAddr::None)));

  // Default to C Calling Convention if no keyword is provided.
  result.addAttribute(
      getCConvAttrName(result.name),
      CConvAttr::get(parser.getContext(),
                     parseOptionalLLVMKeyword<CConv>(parser, LLVM::CConv::C)));

  StringAttr nameAttr;
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  bool isVariadic;

  // Remember where the signature starts so that type construction errors can
  // point at it.
  auto signatureLocation = parser.getCurrentLocation();
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             result.attributes) ||
      function_interface_impl::parseFunctionSignatureWithArguments(
          parser, /*allowVariadic=*/true, entryArgs, isVariadic, resultTypes,
          resultAttrs))
    return failure();

  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  auto type =
      buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes,
                            function_interface_impl::VariadicFlag(isVariadic));
  if (!type)
    return failure();
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));

  // Parse the optional vscale range; both bounds are stored as i32 attributes.
  if (succeeded(parser.parseOptionalKeyword("vscale_range"))) {
    int64_t minRange, maxRange;
    if (parser.parseLParen() || parser.parseInteger(minRange) ||
        parser.parseComma() || parser.parseInteger(maxRange) ||
        parser.parseRParen())
      return failure();
    auto intTy = IntegerType::get(parser.getContext(), 32);
    result.addAttribute(
        getVscaleRangeAttrName(result.name),
        LLVM::VScaleRangeAttr::get(parser.getContext(),
                                   IntegerAttr::get(intTy, minRange),
                                   IntegerAttr::get(intTy, maxRange)));
  }
  // Parse the optional comdat selector.
  if (succeeded(parser.parseOptionalKeyword("comdat"))) {
    SymbolRefAttr comdat;
    if (parser.parseLParen() || parser.parseAttribute(comdat) ||
        parser.parseRParen())
      return failure();

    result.addAttribute(getComdatAttrName(result.name), comdat);
  }

  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();
  call_interface_impl::addArgAndResultAttrs(
      parser.getBuilder(), result, entryArgs, resultAttrs,
      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));

  // The body is optional: a function without a region is a declaration.
  auto *body = result.addRegion();
  OptionalParseResult parseResult =
      parser.parseOptionalRegion(*body, entryArgs);
  return failure(parseResult.has_value() && failed(*parseResult));
}
3042
// Print the LLVMFuncOp. Collects argument and result types and passes them to
// helper functions. Drops "void" result since it cannot be parsed back. Skips
// the external linkage and the C calling convention since they are the
// default values.
void LLVMFuncOp::print(OpAsmPrinter &p) {
  p << ' ';
  if (getLinkage() != LLVM::Linkage::External)
    p << stringifyLinkage(getLinkage()) << ' ';
  StringRef visibility = stringifyVisibility(getVisibility_());
  if (!visibility.empty())
    p << visibility << ' ';
  if (auto unnamedAddr = getUnnamedAddr()) {
    StringRef str = stringifyUnnamedAddr(*unnamedAddr);
    if (!str.empty())
      p << str << ' ';
  }
  if (getCConv() != LLVM::CConv::C)
    p << stringifyCConv(getCConv()) << ' ';

  p.printSymbolName(getName());

  LLVMFunctionType fnType = getFunctionType();
  SmallVector<Type, 8> argTypes;
  SmallVector<Type, 1> resTypes;
  argTypes.reserve(fnType.getNumParams());
  for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i)
    argTypes.push_back(fnType.getParamType(i));

  Type returnType = fnType.getReturnType();
  if (!llvm::isa<LLVMVoidType>(returnType))
    resTypes.push_back(returnType);

  function_interface_impl::printFunctionSignature(p, *this, argTypes,
                                                  isVarArg(), resTypes);

  // Print vscale range if present
  if (std::optional<VScaleRangeAttr> vscale = getVscaleRange())
    p << " vscale_range(" << vscale->getMinRange().getInt() << ", "
      << vscale->getMaxRange().getInt() << ')';

  // Print the optional comdat selector.
  if (auto comdat = getComdat())
    p << " comdat(" << *comdat << ')';

  // Attributes that already have dedicated syntax are elided.
  function_interface_impl::printFunctionAttributes(
      p, *this,
      {getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
       getLinkageAttrName(), getCConvAttrName(), getVisibility_AttrName(),
       getComdatAttrName(), getUnnamedAddrAttrName(),
       getVscaleRangeAttrName()});

  // Print the body if this is not an external function.
  Region &body = getBody();
  if (!body.empty()) {
    p << ' ';
    p.printRegion(body, /*printEntryBlockArgs=*/false,
                  /*printBlockTerminators=*/true);
  }
}
3101
// Verifies LLVM- and implementation-specific properties of the LLVM func Op:
// - functions don't have 'common' linkage;
// - the comdat selector, if any, is valid;
// - external functions have 'external' or 'extern_weak' linkage;
// - no_inline and always_inline are mutually exclusive, and optimize_none
//   requires no_inline;
// - all landingpad results and resume operands in the body share one type;
// - block tags are valid.
LogicalResult LLVMFuncOp::verify() {
  if (getLinkage() == LLVM::Linkage::Common)
    return emitOpError() << "functions cannot have '"
                         << stringifyLinkage(LLVM::Linkage::Common)
                         << "' linkage";

  if (failed(verifyComdat(*this, getComdat())))
    return failure();

  // Declarations only need a linkage check; the remaining checks apply to
  // definitions.
  if (isExternal()) {
    if (getLinkage() != LLVM::Linkage::External &&
        getLinkage() != LLVM::Linkage::ExternWeak)
      return emitOpError() << "external functions must have '"
                           << stringifyLinkage(LLVM::Linkage::External)
                           << "' or '"
                           << stringifyLinkage(LLVM::Linkage::ExternWeak)
                           << "' linkage";
    return success();
  }

  // In LLVM IR, these attributes are composed by convention, not by design.
  if (isNoInline() && isAlwaysInline())
    return emitError("no_inline and always_inline attributes are incompatible");

  if (isOptimizeNone() && !isNoInline())
    return emitOpError("with optimize_none must also be no_inline");

  // The first landingpad/resume encountered fixes the expected type; any later
  // one with a different type interrupts the walk and records the message.
  Type landingpadResultTy;
  StringRef diagnosticMessage;
  bool isLandingpadTypeConsistent =
      !walk([&](Operation *op) {
         const auto checkType = [&](Type type, StringRef errorMessage) {
           if (!landingpadResultTy) {
             landingpadResultTy = type;
             return WalkResult::advance();
           }
           if (landingpadResultTy != type) {
             diagnosticMessage = errorMessage;
             return WalkResult::interrupt();
           }
           return WalkResult::advance();
         };
         return TypeSwitch<Operation *, WalkResult>(op)
             .Case<LandingpadOp>([&](auto landingpad) {
               constexpr StringLiteral errorMessage =
                   "'llvm.landingpad' should have a consistent result type "
                   "inside a function";
               return checkType(landingpad.getType(), errorMessage);
             })
             .Case<ResumeOp>([&](auto resume) {
               constexpr StringLiteral errorMessage =
                   "'llvm.resume' should have a consistent input type inside a "
                   "function";
               return checkType(resume.getValue().getType(), errorMessage);
             })
             .Default([](auto) { return WalkResult::skip(); });
       }).wasInterrupted();
  if (!isLandingpadTypeConsistent) {
    assert(!diagnosticMessage.empty() &&
           "Expecting a non-empty diagnostic message");
    return emitError(diagnosticMessage);
  }

  if (failed(verifyBlockTags(*this)))
    return failure();

  return success();
}
3174
3175/// Verifies LLVM- and implementation-specific properties of the LLVM func Op:
3176/// - entry block arguments are of LLVM types.
3177LogicalResult LLVMFuncOp::verifyRegions() {
3178 if (isExternal())
3179 return success();
3180
3181 unsigned numArguments = getFunctionType().getNumParams();
3182 Block &entryBlock = front();
3183 for (unsigned i = 0; i < numArguments; ++i) {
3184 Type argType = entryBlock.getArgument(i).getType();
3185 if (!isCompatibleType(argType))
3186 return emitOpError("entry block argument #")
3187 << i << " is not of LLVM type";
3188 }
3189
3190 return success();
3191}
3192
3193Region *LLVMFuncOp::getCallableRegion() {
3194 if (isExternal())
3195 return nullptr;
3196 return &getBody();
3197}
3198
3199//===----------------------------------------------------------------------===//
3200// UndefOp.
3201//===----------------------------------------------------------------------===//
3202
3203/// Fold an undef operation to a dedicated undef attribute.
3204OpFoldResult LLVM::UndefOp::fold(FoldAdaptor) {
3205 return LLVM::UndefAttr::get(getContext());
3206}
3207
3208//===----------------------------------------------------------------------===//
3209// PoisonOp.
3210//===----------------------------------------------------------------------===//
3211
3212/// Fold a poison operation to a dedicated poison attribute.
3213OpFoldResult LLVM::PoisonOp::fold(FoldAdaptor) {
3214 return LLVM::PoisonAttr::get(getContext());
3215}
3216
3217//===----------------------------------------------------------------------===//
3218// ZeroOp.
3219//===----------------------------------------------------------------------===//
3220
3221LogicalResult LLVM::ZeroOp::verify() {
3222 if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType()))
3223 if (!targetExtType.hasProperty(LLVM::LLVMTargetExtType::HasZeroInit))
3224 return emitOpError()
3225 << "target extension type does not support zero-initializer";
3226
3227 return success();
3228}
3229
3230/// Fold a zero operation to a builtin zero attribute when possible and fall
3231/// back to a dedicated zero attribute.
3232OpFoldResult LLVM::ZeroOp::fold(FoldAdaptor) {
3234 if (result)
3235 return result;
3236 return LLVM::ZeroAttr::get(getContext());
3237}
3238
3239//===----------------------------------------------------------------------===//
3240// ConstantOp.
3241//===----------------------------------------------------------------------===//
3242
/// Compute the total number of elements in the given type, also taking into
/// account nested types. Supported types are `VectorType` and `LLVMArrayType`.
/// Everything else is treated as a scalar. Scalable vectors are not supported
/// and trigger an assertion.
static int64_t getNumElements(Type t) {
  if (auto vecType = dyn_cast<VectorType>(t)) {
    assert(!vecType.isScalable() &&
           "number of elements of a scalable vector type is unknown");
    return vecType.getNumElements() * getNumElements(vecType.getElementType());
  }
  if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t))
    return arrayType.getNumElements() *
           getNumElements(arrayType.getElementType());
  return 1;
}
3257
3258/// Determine the element type of `type`. Supported types are `VectorType`,
3259/// `TensorType`, and `LLVMArrayType`. Everything else is treated as a scalar.
3261 while (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(type))
3262 type = arrayType.getElementType();
3263 if (auto vecType = dyn_cast<VectorType>(type))
3264 return vecType.getElementType();
3265 if (auto tenType = dyn_cast<TensorType>(type))
3266 return tenType.getElementType();
3267 return type;
3268}
3269
3270/// Check if the given type is a scalable vector type or a vector/array type
3271/// that contains a nested scalable vector type.
3273 if (auto vecType = dyn_cast<VectorType>(t)) {
3274 if (vecType.isScalable())
3275 return true;
3276 return hasScalableVectorType(vecType.getElementType());
3277 }
3278 if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t))
3279 return hasScalableVectorType(arrayType.getElementType());
3280 return false;
3281}
3282
/// Verifies that the constant array described by `arrayAttr` matches
/// `arrayType`. `dim` is the current nesting depth and is only used in
/// diagnostics. Nested array types are checked recursively; at the innermost
/// level the element type must be a struct, and each element attribute must be
/// zero, undef, or an array attribute with one entry per struct member.
/// Each distinct element attribute is checked only once.
static LogicalResult verifyStructArrayConstant(LLVM::ConstantOp op,
                                               LLVM::LLVMArrayType arrayType,
                                               ArrayAttr arrayAttr, int dim) {
  if (arrayType.getNumElements() != arrayAttr.size())
    return op.emitOpError()
           << "array attribute size does not match array type size in "
              "dimension "
           << dim << ": " << arrayAttr.size() << " vs. "
           << arrayType.getNumElements();

  // Attributes that were already checked; repeated elements are skipped.
  llvm::DenseSet<Attribute> seenElements;
  Type elementType = arrayType.getElementType();

  // Recursively verify sub-dimensions for multidimensional arrays.
  if (auto nestedArrayType = dyn_cast<LLVM::LLVMArrayType>(elementType)) {
    for (auto [index, element] : llvm::enumerate(arrayAttr)) {
      if (!seenElements.insert(element).second)
        continue;
      if (isa<LLVM::ZeroAttr, LLVM::UndefAttr>(element))
        continue;
      auto nestedArrayAttr = dyn_cast<ArrayAttr>(element);
      if (!nestedArrayAttr)
        return op.emitOpError()
               << "nested attribute for sub-array in dimension " << dim
               << " at index " << index
               << " must be a zero, or undef, or array attribute";
      if (failed(verifyStructArrayConstant(op, nestedArrayType,
                                           nestedArrayAttr, dim + 1)))
        return failure();
    }
    return success();
  }

  // Forbid usages of ArrayAttr for simple array types that should use
  // DenseElementsAttr instead. Note that there would be a use case for such
  // array types when one element value is obtained via a ptr-to-int conversion
  // from a symbol and cannot be represented in a DenseElementsAttr, but no MLIR
  // user needs this so far, and it seems better to avoid people misusing the
  // ArrayAttr for simple types.
  auto structType = dyn_cast<LLVM::LLVMStructType>(elementType);
  if (!structType)
    return op.emitOpError() << "for array with an array attribute must have a "
                               "struct element type";

  // Shallow verification that leaf attributes are appropriate as struct initial
  // value.
  size_t numMembers = structType.getBody().size();
  for (auto [index, element] : llvm::enumerate(arrayAttr)) {
    if (!seenElements.insert(element).second)
      continue;
    if (isa<LLVM::ZeroAttr, LLVM::UndefAttr>(element))
      continue;
    auto memberAttrs = dyn_cast<ArrayAttr>(element);
    if (!memberAttrs)
      return op.emitOpError()
             << "nested attribute for struct element at index " << index
             << " must be a zero, or undef, or array attribute";
    if (memberAttrs.size() != numMembers)
      return op.emitOpError()
             << "nested array attribute size for struct element at index "
             << index << " must match struct size: " << memberAttrs.size()
             << " vs. " << numMembers;
  }

  return success();
}
3350
/// Verifies that the value attribute is compatible with the result type.
/// The checks are dispatched in this order:
///  - string attributes require an array of i8 of matching length;
///  - struct result types require an array attribute of integer/float
///    attributes whose types match the struct members;
///  - target extension types are rejected;
///  - integer, float, elements and array attributes are then checked against
///    the (possibly nested) result type.
LogicalResult LLVM::ConstantOp::verify() {
  if (StringAttr sAttr = llvm::dyn_cast<StringAttr>(getValue())) {
    auto arrayType = llvm::dyn_cast<LLVMArrayType>(getType());
    if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() ||
        !arrayType.getElementType().isInteger(8)) {
      return emitOpError() << "expected array type of "
                           << sAttr.getValue().size()
                           << " i8 elements for the string constant";
    }
    return success();
  }
  if (auto structType = dyn_cast<LLVMStructType>(getType())) {
    auto arrayAttr = dyn_cast<ArrayAttr>(getValue());
    if (!arrayAttr)
      return emitOpError() << "expected array attribute for struct type";

    ArrayRef<Type> elementTypes = structType.getBody();
    if (arrayAttr.size() != elementTypes.size()) {
      return emitOpError() << "expected array attribute of size "
                           << elementTypes.size();
    }
    for (auto [i, attr, type] : llvm::enumerate(arrayAttr, elementTypes)) {
      if (!type.isSignlessIntOrIndexOrFloat()) {
        return emitOpError() << "expected struct element types to be floating "
                                "point type or integer type";
      }
      if (!isa<FloatAttr, IntegerAttr>(attr)) {
        return emitOpError() << "expected element of array attribute to be "
                                "floating point or integer";
      }
      if (cast<TypedAttr>(attr).getType() != type)
        return emitOpError()
               << "struct element at index " << i << " is of wrong type";
    }

    return success();
  }
  if (isa<LLVMTargetExtType>(getType()))
    return emitOpError() << "does not support target extension type.";

  // Check that an attribute whose element type has floating point semantics
  // `attributeFloatSemantics` is compatible with a type whose element type
  // is `constantElementType`.
  //
  // Requirement is that either
  // 1) They have identical floating point types.
  // 2) `constantElementType` is an integer type of the same width as the float
  //    attribute. This is to support builtin MLIR float types without LLVM
  //    equivalents, see comments in getLLVMConstant for more details.
  // Any other element type is accepted without further checks.
  auto verifyFloatSemantics =
      [this](const llvm::fltSemantics &attributeFloatSemantics,
             Type constantElementType) -> LogicalResult {
    if (auto floatType = dyn_cast<FloatType>(constantElementType)) {
      if (&floatType.getFloatSemantics() != &attributeFloatSemantics) {
        return emitOpError()
               << "attribute and type have different float semantics";
      }
      return success();
    }
    unsigned floatWidth = APFloat::getSizeInBits(attributeFloatSemantics);
    if (isa<IntegerType>(constantElementType)) {
      if (!constantElementType.isInteger(floatWidth))
        return emitOpError() << "expected integer type of width " << floatWidth;

      return success();
    }
    return success();
  };

  // Verification of IntegerAttr, FloatAttr, ElementsAttr, ArrayAttr.
  if (isa<IntegerAttr>(getValue())) {
    if (!llvm::isa<IntegerType>(getType()))
      return emitOpError() << "expected integer type";
  } else if (auto floatAttr = dyn_cast<FloatAttr>(getValue())) {
    return verifyFloatSemantics(floatAttr.getValue().getSemantics(), getType());
  } else if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue())) {
    if (hasScalableVectorType(getType())) {
      // The exact number of elements of a scalable vector is unknown, so we
      // allow only splat attributes.
      auto splatElementsAttr = dyn_cast<SplatElementsAttr>(getValue());
      if (!splatElementsAttr)
        return emitOpError()
               << "scalable vector type requires a splat attribute";
      return success();
    }
    if (!isa<VectorType, LLVM::LLVMArrayType>(getType()))
      return emitOpError() << "expected vector or array type";

    // The number of elements of the attribute and the type must match.
    int64_t attrNumElements = elementsAttr.getNumElements();
    if (getNumElements(getType()) != attrNumElements) {
      return emitOpError()
             << "type and attribute have a different number of elements: "
             << getNumElements(getType()) << " vs. " << attrNumElements;
    }

    Type attrElmType = getElementType(elementsAttr.getType());
    Type resultElmType = getElementType(getType());
    if (auto floatType = dyn_cast<FloatType>(attrElmType))
      return verifyFloatSemantics(floatType.getFloatSemantics(), resultElmType);

    if (isa<IntegerType>(attrElmType) && !isa<IntegerType>(resultElmType)) {
      return emitOpError(
          "expected integer element type for integer elements attribute");
    }
  } else if (auto arrayAttr = dyn_cast<ArrayAttr>(getValue())) {

    // The case where the constant is LLVMStructType has already been handled.
    auto arrayType = dyn_cast<LLVM::LLVMArrayType>(getType());
    if (!arrayType)
      return emitOpError()
             << "expected array or struct type for array attribute";

    // When the attribute is an ArrayAttr, check that its nesting matches the
    // nesting of the corresponding array type.
    return verifyStructArrayConstant(*this, arrayType, arrayAttr, /*dim=*/0);
  } else {
    return emitOpError()
           << "only supports integer, float, string or elements attributes";
  }

  return success();
}
3474
3475bool LLVM::ConstantOp::isBuildableWith(Attribute value, Type type) {
3476 // The value's type must be the same as the provided type.
3477 auto typedAttr = dyn_cast<TypedAttr>(value);
3478 if (!typedAttr || typedAttr.getType() != type || !isCompatibleType(type))
3479 return false;
3480 // The value's type must be an LLVM compatible type.
3481 if (!isCompatibleType(type))
3482 return false;
3483 // TODO: Add support for additional attributes kinds once needed.
3484 return isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
3485}
3486
3487ConstantOp LLVM::ConstantOp::materialize(OpBuilder &builder, Attribute value,
3488 Type type, Location loc) {
3489 if (isBuildableWith(value, type))
3490 return LLVM::ConstantOp::create(builder, loc, cast<TypedAttr>(value));
3491 return nullptr;
3492}
3493
3494// Constant op constant-folds to its value.
3495OpFoldResult LLVM::ConstantOp::fold(FoldAdaptor) { return getValue(); }
3496
3497//===----------------------------------------------------------------------===//
3498// AtomicRMWOp
3499//===----------------------------------------------------------------------===//
3500
3501void AtomicRMWOp::build(OpBuilder &builder, OperationState &state,
3502 AtomicBinOp binOp, Value ptr, Value val,
3503 AtomicOrdering ordering, StringRef syncscope,
3504 unsigned alignment, bool isVolatile) {
3505 build(builder, state, val.getType(), binOp, ptr, val, ordering,
3506 !syncscope.empty() ? builder.getStringAttr(syncscope) : nullptr,
3507 alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
3508 /*access_groups=*/nullptr,
3509 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
3510}
3511
3512LogicalResult AtomicRMWOp::verify() {
3513 auto valType = getVal().getType();
3514 if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub ||
3515 getBinOp() == AtomicBinOp::fmin || getBinOp() == AtomicBinOp::fmax ||
3516 getBinOp() == AtomicBinOp::fminimum ||
3517 getBinOp() == AtomicBinOp::fmaximum) {
3518 if (isCompatibleVectorType(valType)) {
3519 if (isScalableVectorType(valType))
3520 return emitOpError("expected LLVM IR fixed vector type");
3521 Type elemType = llvm::cast<VectorType>(valType).getElementType();
3522 if (!isCompatibleFloatingPointType(elemType))
3523 return emitOpError(
3524 "expected LLVM IR floating point type for vector element");
3525 } else if (!isCompatibleFloatingPointType(valType)) {
3526 return emitOpError("expected LLVM IR floating point type");
3527 }
3528 } else if (getBinOp() == AtomicBinOp::xchg) {
3529 DataLayout dataLayout = DataLayout::closest(*this);
3530 if (!isTypeCompatibleWithAtomicOp(valType, dataLayout))
3531 return emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
3532 } else {
3533 auto intType = llvm::dyn_cast<IntegerType>(valType);
3534 unsigned intBitWidth = intType ? intType.getWidth() : 0;
3535 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
3536 intBitWidth != 64)
3537 return emitOpError("expected LLVM IR integer type");
3538 }
3539
3540 if (static_cast<unsigned>(getOrdering()) <
3541 static_cast<unsigned>(AtomicOrdering::monotonic))
3542 return emitOpError() << "expected at least '"
3543 << stringifyAtomicOrdering(AtomicOrdering::monotonic)
3544 << "' ordering";
3545
3546 return success();
3547}
3548
3549//===----------------------------------------------------------------------===//
3550// AtomicCmpXchgOp
3551//===----------------------------------------------------------------------===//
3552
3553/// Returns an LLVM struct type that contains a value type and a boolean type.
3554static LLVMStructType getValAndBoolStructType(Type valType) {
3555 auto boolType = IntegerType::get(valType.getContext(), 1);
3556 return LLVMStructType::getLiteral(valType.getContext(), {valType, boolType});
3557}
3558
3559void AtomicCmpXchgOp::build(OpBuilder &builder, OperationState &state,
3560 Value ptr, Value cmp, Value val,
3561 AtomicOrdering successOrdering,
3562 AtomicOrdering failureOrdering, StringRef syncscope,
3563 unsigned alignment, bool isWeak, bool isVolatile) {
3564 build(builder, state, getValAndBoolStructType(val.getType()), ptr, cmp, val,
3565 successOrdering, failureOrdering,
3566 !syncscope.empty() ? builder.getStringAttr(syncscope) : nullptr,
3567 alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isWeak,
3568 isVolatile, /*access_groups=*/nullptr,
3569 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
3570}
3571
3572LogicalResult AtomicCmpXchgOp::verify() {
3573 auto ptrType = llvm::cast<LLVM::LLVMPointerType>(getPtr().getType());
3574 if (!ptrType)
3575 return emitOpError("expected LLVM IR pointer type for operand #0");
3576 auto valType = getVal().getType();
3577 DataLayout dataLayout = DataLayout::closest(*this);
3578 if (!isTypeCompatibleWithAtomicOp(valType, dataLayout))
3579 return emitOpError("unexpected LLVM IR type");
3580 if (getSuccessOrdering() < AtomicOrdering::monotonic ||
3581 getFailureOrdering() < AtomicOrdering::monotonic)
3582 return emitOpError("ordering must be at least 'monotonic'");
3583 if (getFailureOrdering() == AtomicOrdering::release ||
3584 getFailureOrdering() == AtomicOrdering::acq_rel)
3585 return emitOpError("failure ordering cannot be 'release' or 'acq_rel'");
3586 return success();
3587}
3588
3589//===----------------------------------------------------------------------===//
3590// FenceOp
3591//===----------------------------------------------------------------------===//
3592
3593void FenceOp::build(OpBuilder &builder, OperationState &state,
3594 AtomicOrdering ordering, StringRef syncscope) {
3595 build(builder, state, ordering,
3596 syncscope.empty() ? nullptr : builder.getStringAttr(syncscope));
3597}
3598
3599LogicalResult FenceOp::verify() {
3600 if (getOrdering() == AtomicOrdering::not_atomic ||
3601 getOrdering() == AtomicOrdering::unordered ||
3602 getOrdering() == AtomicOrdering::monotonic)
3603 return emitOpError("can be given only acquire, release, acq_rel, "
3604 "and seq_cst orderings");
3605 return success();
3606}
3607
3608//===----------------------------------------------------------------------===//
3609// Verifier for extension ops
3610//===----------------------------------------------------------------------===//
3611
3612/// Verifies that the given extension operation operates on consistent scalars
3613/// or vectors, and that the target width is larger than the input width.
3614template <class ExtOp>
3615static LogicalResult verifyExtOp(ExtOp op) {
3616 IntegerType inputType, outputType;
3617 if (isCompatibleVectorType(op.getArg().getType())) {
3618 if (!isCompatibleVectorType(op.getResult().getType()))
3619 return op.emitError(
3620 "input type is a vector but output type is an integer");
3621 if (getVectorNumElements(op.getArg().getType()) !=
3622 getVectorNumElements(op.getResult().getType()))
3623 return op.emitError("input and output vectors are of incompatible shape");
3624 // Because this is a CastOp, the element of vectors is guaranteed to be an
3625 // integer.
3626 inputType = cast<IntegerType>(
3627 cast<VectorType>(op.getArg().getType()).getElementType());
3628 outputType = cast<IntegerType>(
3629 cast<VectorType>(op.getResult().getType()).getElementType());
3630 } else {
3631 // Because this is a CastOp and arg is not a vector, arg is guaranteed to be
3632 // an integer.
3633 inputType = cast<IntegerType>(op.getArg().getType());
3634 outputType = dyn_cast<IntegerType>(op.getResult().getType());
3635 if (!outputType)
3636 return op.emitError(
3637 "input type is an integer but output type is a vector");
3638 }
3639
3640 if (outputType.getWidth() <= inputType.getWidth())
3641 return op.emitError("integer width of the output type is smaller or "
3642 "equal to the integer width of the input type");
3643 return success();
3644}
3645
3646//===----------------------------------------------------------------------===//
3647// ZExtOp
3648//===----------------------------------------------------------------------===//
3649
3650LogicalResult ZExtOp::verify() { return verifyExtOp<ZExtOp>(*this); }
3651
3652OpFoldResult LLVM::ZExtOp::fold(FoldAdaptor adaptor) {
3653 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
3654 if (!arg)
3655 return {};
3656
3657 size_t targetSize = cast<IntegerType>(getType()).getWidth();
3658 return IntegerAttr::get(getType(), arg.getValue().zext(targetSize));
3659}
3660
3661//===----------------------------------------------------------------------===//
3662// SExtOp
3663//===----------------------------------------------------------------------===//
3664
3665LogicalResult SExtOp::verify() { return verifyExtOp<SExtOp>(*this); }
3666
3667//===----------------------------------------------------------------------===//
3668// Folder and verifier for LLVM::BitcastOp
3669//===----------------------------------------------------------------------===//
3670
3671/// Folds a cast op that can be chained.
3672template <typename T>
3674 typename T::FoldAdaptor adaptor) {
3675 // cast(x : T0, T0) -> x
3676 if (castOp.getArg().getType() == castOp.getType())
3677 return castOp.getArg();
3678 if (auto prev = castOp.getArg().template getDefiningOp<T>()) {
3679 // cast(cast(x : T0, T1), T0) -> x
3680 if (prev.getArg().getType() == castOp.getType())
3681 return prev.getArg();
3682 // cast(cast(x : T0, T1), T2) -> cast(x: T0, T2)
3683 castOp.getArgMutable().set(prev.getArg());
3684 return Value{castOp};
3685 }
3686 return {};
3687}
3688
3689OpFoldResult LLVM::BitcastOp::fold(FoldAdaptor adaptor) {
3690 return foldChainableCast(*this, adaptor);
3691}
3692
3693LogicalResult LLVM::BitcastOp::verify() {
3694 auto resultType = llvm::dyn_cast<LLVMPointerType>(
3695 extractVectorElementType(getResult().getType()));
3696 auto sourceType = llvm::dyn_cast<LLVMPointerType>(
3697 extractVectorElementType(getArg().getType()));
3698
3699 // If one of the types is a pointer (or vector of pointers), then
3700 // both source and result type have to be pointers.
3701 if (static_cast<bool>(resultType) != static_cast<bool>(sourceType))
3702 return emitOpError("can only cast pointers from and to pointers");
3703
3704 if (!resultType)
3705 return success();
3706
3707 auto isVector = llvm::IsaPred<VectorType>;
3708
3709 // Due to bitcast requiring both operands to be of the same size, it is not
3710 // possible for only one of the two to be a pointer of vectors.
3711 if (isVector(getResult().getType()) && !isVector(getArg().getType()))
3712 return emitOpError("cannot cast pointer to vector of pointers");
3713
3714 if (!isVector(getResult().getType()) && isVector(getArg().getType()))
3715 return emitOpError("cannot cast vector of pointers to pointer");
3716
3717 // Bitcast cannot cast between pointers of different address spaces.
3718 // 'llvm.addrspacecast' must be used for this purpose instead.
3719 if (resultType.getAddressSpace() != sourceType.getAddressSpace())
3720 return emitOpError("cannot cast pointers of different address spaces, "
3721 "use 'llvm.addrspacecast' instead");
3722
3723 return success();
3724}
3725
3726//===----------------------------------------------------------------------===//
3727// Folder for LLVM::AddrSpaceCastOp
3728//===----------------------------------------------------------------------===//
3729
3730OpFoldResult LLVM::AddrSpaceCastOp::fold(FoldAdaptor adaptor) {
3731 return foldChainableCast(*this, adaptor);
3732}
3733
3734Value LLVM::AddrSpaceCastOp::getViewSource() { return getArg(); }
3735
3736//===----------------------------------------------------------------------===//
3737// Folder for LLVM::GEPOp
3738//===----------------------------------------------------------------------===//
3739
3740OpFoldResult LLVM::GEPOp::fold(FoldAdaptor adaptor) {
3741 GEPIndicesAdaptor<ArrayRef<Attribute>> indices(getRawConstantIndicesAttr(),
3742 adaptor.getDynamicIndices());
3743
3744 // gep %x:T, 0 -> %x
3745 if (getBase().getType() == getType() && indices.size() == 1)
3746 if (auto integer = llvm::dyn_cast_or_null<IntegerAttr>(indices[0]))
3747 if (integer.getValue().isZero())
3748 return getBase();
3749
3750 // Canonicalize any dynamic indices of constant value to constant indices.
3751 bool changed = false;
3752 SmallVector<GEPArg> gepArgs;
3753 for (auto iter : llvm::enumerate(indices)) {
3754 auto integer = llvm::dyn_cast_or_null<IntegerAttr>(iter.value());
3755 // Constant indices can only be int32_t, so if integer does not fit we
3756 // are forced to keep it dynamic, despite being a constant.
3757 if (!indices.isDynamicIndex(iter.index()) || !integer ||
3758 !integer.getValue().isSignedIntN(kGEPConstantBitWidth)) {
3759
3760 PointerUnion<IntegerAttr, Value> existing = getIndices()[iter.index()];
3761 if (Value val = llvm::dyn_cast_if_present<Value>(existing))
3762 gepArgs.emplace_back(val);
3763 else
3764 gepArgs.emplace_back(cast<IntegerAttr>(existing).getInt());
3765
3766 continue;
3767 }
3768
3769 changed = true;
3770 gepArgs.emplace_back(integer.getInt());
3771 }
3772 if (changed) {
3773 SmallVector<int32_t> rawConstantIndices;
3774 SmallVector<Value> dynamicIndices;
3775 destructureIndices(getElemType(), gepArgs, rawConstantIndices,
3776 dynamicIndices);
3777
3778 getDynamicIndicesMutable().assign(dynamicIndices);
3779 setRawConstantIndices(rawConstantIndices);
3780 return Value{*this};
3781 }
3782
3783 return {};
3784}
3785
3786Value LLVM::GEPOp::getViewSource() { return getBase(); }
3787
3788//===----------------------------------------------------------------------===//
3789// ShlOp
3790//===----------------------------------------------------------------------===//
3791
3792OpFoldResult LLVM::ShlOp::fold(FoldAdaptor adaptor) {
3793 auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs());
3794 if (!rhs)
3795 return {};
3796
3797 if (rhs.getValue().getZExtValue() >=
3798 getLhs().getType().getIntOrFloatBitWidth())
3799 return {}; // TODO: Fold into poison.
3800
3801 auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs());
3802 if (!lhs)
3803 return {};
3804
3805 return IntegerAttr::get(getType(), lhs.getValue().shl(rhs.getValue()));
3806}
3807
3808//===----------------------------------------------------------------------===//
3809// OrOp
3810//===----------------------------------------------------------------------===//
3811
3812OpFoldResult LLVM::OrOp::fold(FoldAdaptor adaptor) {
3813 auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs());
3814 if (!lhs)
3815 return {};
3816
3817 auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs());
3818 if (!rhs)
3819 return {};
3820
3821 return IntegerAttr::get(getType(), lhs.getValue() | rhs.getValue());
3822}
3823
3824//===----------------------------------------------------------------------===//
3825// CallIntrinsicOp
3826//===----------------------------------------------------------------------===//
3827
3828LogicalResult CallIntrinsicOp::verify() {
3829 if (!getIntrin().starts_with("llvm."))
3830 return emitOpError() << "intrinsic name must start with 'llvm.'";
3831 if (failed(verifyOperandBundles(*this)))
3832 return failure();
3833 return success();
3834}
3835
3836void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
3837 mlir::StringAttr intrin, mlir::ValueRange args) {
3838 build(builder, state, /*resultTypes=*/TypeRange{}, intrin, args,
3839 FastmathFlagsAttr{},
3840 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
3841 /*res_attrs=*/{});
3842}
3843
3844void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
3845 mlir::StringAttr intrin, mlir::ValueRange args,
3846 mlir::LLVM::FastmathFlagsAttr fastMathFlags) {
3847 build(builder, state, /*resultTypes=*/TypeRange{}, intrin, args,
3848 fastMathFlags,
3849 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
3850 /*res_attrs=*/{});
3851}
3852
3853void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
3854 mlir::Type resultType, mlir::StringAttr intrin,
3855 mlir::ValueRange args) {
3856 build(builder, state, {resultType}, intrin, args, FastmathFlagsAttr{},
3857 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
3858 /*res_attrs=*/{});
3859}
3860
3861void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
3862 mlir::TypeRange resultTypes,
3863 mlir::StringAttr intrin, mlir::ValueRange args,
3864 mlir::LLVM::FastmathFlagsAttr fastMathFlags) {
3865 build(builder, state, resultTypes, intrin, args, fastMathFlags,
3866 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
3867 /*res_attrs=*/{});
3868}
3869
3870ParseResult CallIntrinsicOp::parse(OpAsmParser &parser,
3872 StringAttr intrinAttr;
3875 SmallVector<SmallVector<Type>> opBundleOperandTypes;
3876 ArrayAttr opBundleTags;
3877
3878 // Parse intrinsic name.
3880 intrinAttr, parser.getBuilder().getType<NoneType>()))
3881 return failure();
3882 result.addAttribute(CallIntrinsicOp::getIntrinAttrName(result.name),
3883 intrinAttr);
3884
3885 if (parser.parseLParen())
3886 return failure();
3887
3888 // Parse the function arguments.
3889 if (parser.parseOperandList(operands))
3890 return mlir::failure();
3891
3892 if (parser.parseRParen())
3893 return mlir::failure();
3894
3895 // Handle bundles.
3896 SMLoc opBundlesLoc = parser.getCurrentLocation();
3897 if (std::optional<ParseResult> result = parseOpBundles(
3898 parser, opBundleOperands, opBundleOperandTypes, opBundleTags);
3899 result && failed(*result))
3900 return failure();
3901 if (opBundleTags && !opBundleTags.empty())
3902 result.addAttribute(
3903 CallIntrinsicOp::getOpBundleTagsAttrName(result.name).getValue(),
3904 opBundleTags);
3905
3906 if (parser.parseOptionalAttrDict(result.attributes))
3907 return mlir::failure();
3908
3910 SmallVector<DictionaryAttr> resultAttrs;
3911 if (parseCallTypeAndResolveOperands(parser, result, /*isDirect=*/true,
3912 operands, argAttrs, resultAttrs))
3913 return failure();
3915 parser.getBuilder(), result, argAttrs, resultAttrs,
3916 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
3917
3918 if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
3919 opBundleOperandTypes,
3920 getOpBundleSizesAttrName(result.name)))
3921 return failure();
3922
3923 int32_t numOpBundleOperands = 0;
3924 for (const auto &operands : opBundleOperands)
3925 numOpBundleOperands += operands.size();
3926
3927 result.addAttribute(
3928 CallIntrinsicOp::getOperandSegmentSizeAttr(),
3930 {static_cast<int32_t>(operands.size()), numOpBundleOperands}));
3931
3932 return mlir::success();
3933}
3934
3935void CallIntrinsicOp::print(OpAsmPrinter &p) {
3936 p << ' ';
3937 p.printAttributeWithoutType(getIntrinAttr());
3938
3939 OperandRange args = getArgs();
3940 p << "(" << args << ")";
3941
3942 // Operand bundles.
3943 if (!getOpBundleOperands().empty()) {
3944 p << ' ';
3945 printOpBundles(p, *this, getOpBundleOperands(),
3946 getOpBundleOperands().getTypes(), getOpBundleTagsAttr());
3947 }
3948
3949 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
3950 {getOperandSegmentSizesAttrName(),
3951 getOpBundleSizesAttrName(), getIntrinAttrName(),
3952 getOpBundleTagsAttrName(), getArgAttrsAttrName(),
3953 getResAttrsAttrName()});
3954
3955 p << " : ";
3956
3957 // Reconstruct the MLIR function type from operand and result types.
3959 p, args.getTypes(), getArgAttrsAttr(),
3960 /*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
3961}
3962
3963//===----------------------------------------------------------------------===//
3964// LinkerOptionsOp
3965//===----------------------------------------------------------------------===//
3966
3967LogicalResult LinkerOptionsOp::verify() {
3968 if (mlir::Operation *parentOp = (*this)->getParentOp();
3969 parentOp && !satisfiesLLVMModule(parentOp))
3970 return emitOpError("must appear at the module level");
3971 return success();
3972}
3973
3974//===----------------------------------------------------------------------===//
3975// ModuleFlagsOp
3976//===----------------------------------------------------------------------===//
3977
3978LogicalResult ModuleFlagsOp::verify() {
3979 if (Operation *parentOp = (*this)->getParentOp();
3980 parentOp && !satisfiesLLVMModule(parentOp))
3981 return emitOpError("must appear at the module level");
3982 for (Attribute flag : getFlags())
3983 if (!isa<ModuleFlagAttr>(flag))
3984 return emitOpError("expected a module flag attribute");
3985 return success();
3986}
3987
3988//===----------------------------------------------------------------------===//
3989// InlineAsmOp
3990//===----------------------------------------------------------------------===//
3991
3992void InlineAsmOp::getEffects(
3994 &effects) {
3995 if (getHasSideEffects()) {
3996 effects.emplace_back(MemoryEffects::Write::get());
3997 effects.emplace_back(MemoryEffects::Read::get());
3998 }
3999}
4000
4001//===----------------------------------------------------------------------===//
4002// BlockAddressOp
4003//===----------------------------------------------------------------------===//
4004
4005LogicalResult
4006BlockAddressOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
4007 Operation *symbol = symbolTable.lookupSymbolIn(parentLLVMModule(*this),
4008 getBlockAddr().getFunction());
4009 auto function = dyn_cast_or_null<LLVMFuncOp>(symbol);
4010
4011 if (!function)
4012 return emitOpError("must reference a function defined by 'llvm.func'");
4013
4014 return success();
4015}
4016
4017LLVMFuncOp BlockAddressOp::getFunction(SymbolTableCollection &symbolTable) {
4018 return dyn_cast_or_null<LLVMFuncOp>(symbolTable.lookupSymbolIn(
4019 parentLLVMModule(*this), getBlockAddr().getFunction()));
4020}
4021
4022BlockTagOp BlockAddressOp::getBlockTagOp() {
4023 auto funcOp = dyn_cast<LLVMFuncOp>(mlir::SymbolTable::lookupNearestSymbolFrom(
4024 parentLLVMModule(*this), getBlockAddr().getFunction()));
4025 if (!funcOp)
4026 return nullptr;
4027
4028 BlockTagOp blockTagOp = nullptr;
4029 funcOp.walk([&](LLVM::BlockTagOp labelOp) {
4030 if (labelOp.getTag() == getBlockAddr().getTag()) {
4031 blockTagOp = labelOp;
4032 return WalkResult::interrupt();
4033 }
4034 return WalkResult::advance();
4035 });
4036 return blockTagOp;
4037}
4038
4039LogicalResult BlockAddressOp::verify() {
4040 if (!getBlockTagOp())
4041 return emitOpError(
4042 "expects an existing block label target in the referenced function");
4043
4044 return success();
4045}
4046
4047/// Fold a blockaddress operation to a dedicated blockaddress
4048/// attribute.
4049OpFoldResult BlockAddressOp::fold(FoldAdaptor) { return getBlockAddr(); }
4050
4051//===----------------------------------------------------------------------===//
4052// LLVM::IndirectBrOp
4053//===----------------------------------------------------------------------===//
4054
4055SuccessorOperands IndirectBrOp::getSuccessorOperands(unsigned index) {
4056 assert(index < getNumSuccessors() && "invalid successor index");
4057 return SuccessorOperands(getSuccOperandsMutable()[index]);
4058}
4059
4060void IndirectBrOp::build(OpBuilder &odsBuilder, OperationState &odsState,
4061 Value addr, ArrayRef<ValueRange> succOperands,
4062 BlockRange successors) {
4063 odsState.addOperands(addr);
4064 for (ValueRange range : succOperands)
4065 odsState.addOperands(range);
4066 SmallVector<int32_t> rangeSegments;
4067 for (ValueRange range : succOperands)
4068 rangeSegments.push_back(range.size());
4069 odsState.getOrAddProperties<Properties>().indbr_operand_segments =
4070 odsBuilder.getDenseI32ArrayAttr(rangeSegments);
4071 odsState.addSuccessors(successors);
4072}
4073
4075 OpAsmParser &parser, Type &flagType,
4076 SmallVectorImpl<Block *> &succOperandBlocks,
4078 SmallVectorImpl<SmallVector<Type>> &succOperandsTypes) {
4079 if (failed(parser.parseCommaSeparatedList(
4081 [&]() {
4082 Block *destination = nullptr;
4083 SmallVector<OpAsmParser::UnresolvedOperand> operands;
4084 SmallVector<Type> operandTypes;
4085
4086 if (parser.parseSuccessor(destination).failed())
4087 return failure();
4088
4089 if (succeeded(parser.parseOptionalLParen())) {
4090 if (failed(parser.parseOperandList(
4091 operands, OpAsmParser::Delimiter::None)) ||
4092 failed(parser.parseColonTypeList(operandTypes)) ||
4093 failed(parser.parseRParen()))
4094 return failure();
4095 }
4096 succOperandBlocks.push_back(destination);
4097 succOperands.emplace_back(operands);
4098 succOperandsTypes.emplace_back(operandTypes);
4099 return success();
4100 },
4101 "successor blocks")))
4102 return failure();
4103 return success();
4104}
4105
4106static void
4107printIndirectBrOpSucessors(OpAsmPrinter &p, IndirectBrOp op, Type flagType,
4108 SuccessorRange succs, OperandRangeRange succOperands,
4109 const TypeRangeRange &succOperandsTypes) {
4110 p << "[";
4111 llvm::interleave(
4112 llvm::zip(succs, succOperands),
4113 [&](auto i) {
4114 p.printNewline();
4115 p.printSuccessorAndUseList(std::get<0>(i), std::get<1>(i));
4116 },
4117 [&] { p << ','; });
4118 if (!succOperands.empty())
4119 p.printNewline();
4120 p << "]";
4121}
4122
4123//===----------------------------------------------------------------------===//
4124// SincosOp (intrinsic)
4125//===----------------------------------------------------------------------===//
4126
4127LogicalResult LLVM::SincosOp::verify() {
4128 auto operandType = getOperand().getType();
4129 auto resultType = getResult().getType();
4130 auto resultStructType =
4131 mlir::dyn_cast<mlir::LLVM::LLVMStructType>(resultType);
4132 if (!resultStructType || resultStructType.getBody().size() != 2 ||
4133 resultStructType.getBody()[0] != operandType ||
4134 resultStructType.getBody()[1] != operandType) {
4135 return emitOpError("expected result type to be an homogeneous struct with "
4136 "two elements matching the operand type, but got ")
4137 << resultType;
4138 }
4139 return success();
4140}
4141
4142//===----------------------------------------------------------------------===//
4143// AssumeOp (intrinsic)
4144//===----------------------------------------------------------------------===//
4145
4146void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
4147 mlir::Value cond) {
4148 return build(builder, state, cond, /*op_bundle_operands=*/{},
4149 /*op_bundle_tags=*/ArrayAttr{});
4150}
4151
4152void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
4153 Value cond, llvm::StringRef tag, ValueRange args) {
4154 return build(builder, state, cond, ArrayRef<ValueRange>(args),
4155 builder.getStrArrayAttr(tag));
4156}
4157
4158void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
4159 Value cond, AssumeAlignTag, Value ptr, Value align) {
4160 return build(builder, state, cond, "align", ValueRange{ptr, align});
4161}
4162
4163void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
4165 Value ptr2) {
4166 return build(builder, state, cond, "separate_storage",
4167 ValueRange{ptr1, ptr2});
4168}
4169
4170LogicalResult LLVM::AssumeOp::verify() { return verifyOperandBundles(*this); }
4171
4172//===----------------------------------------------------------------------===//
4173// masked_gather (intrinsic)
4174//===----------------------------------------------------------------------===//
4175
4176LogicalResult LLVM::masked_gather::verify() {
4177 auto ptrsVectorType = getPtrs().getType();
4178 Type expectedPtrsVectorType =
4181 // Vector of pointers type should match result vector type, other than the
4182 // element type.
4183 if (ptrsVectorType != expectedPtrsVectorType)
4184 return emitOpError("expected operand #1 type to be ")
4185 << expectedPtrsVectorType;
4186 return success();
4187}
4188
4189//===----------------------------------------------------------------------===//
4190// masked_scatter (intrinsic)
4191//===----------------------------------------------------------------------===//
4192
4193LogicalResult LLVM::masked_scatter::verify() {
4194 auto ptrsVectorType = getPtrs().getType();
4195 Type expectedPtrsVectorType =
4197 LLVM::getVectorNumElements(getValue().getType()));
4198 // Vector of pointers type should match value vector type, other than the
4199 // element type.
4200 if (ptrsVectorType != expectedPtrsVectorType)
4201 return emitOpError("expected operand #2 type to be ")
4202 << expectedPtrsVectorType;
4203 return success();
4204}
4205
4206//===----------------------------------------------------------------------===//
4207// masked_expandload (intrinsic)
4208//===----------------------------------------------------------------------===//
4209
4210void LLVM::masked_expandload::build(OpBuilder &builder, OperationState &state,
4211 mlir::TypeRange resTys, Value ptr,
4212 Value mask, Value passthru,
4213 uint64_t align) {
4214 ArrayAttr argAttrs = getLLVMAlignParamForCompressExpand(builder, true, align);
4215 build(builder, state, resTys, ptr, mask, passthru, /*arg_attrs=*/argAttrs,
4216 /*res_attrs=*/nullptr);
4217}
4218
4219//===----------------------------------------------------------------------===//
4220// masked_compressstore (intrinsic)
4221//===----------------------------------------------------------------------===//
4222
4223void LLVM::masked_compressstore::build(OpBuilder &builder,
4224 OperationState &state, Value value,
4225 Value ptr, Value mask, uint64_t align) {
4226 ArrayAttr argAttrs =
4227 getLLVMAlignParamForCompressExpand(builder, false, align);
4228 build(builder, state, value, ptr, mask, /*arg_attrs=*/argAttrs,
4229 /*res_attrs=*/nullptr);
4230}
4231
4232//===----------------------------------------------------------------------===//
4233// InlineAsmOp
4234//===----------------------------------------------------------------------===//
4235
4236LogicalResult InlineAsmOp::verify() {
4237 if (!getTailCallKindAttr())
4238 return success();
4239
4240 if (getTailCallKindAttr().getTailCallKind() == TailCallKind::MustTail)
4241 return emitOpError(
4242 "tail call kind 'musttail' is not supported by this operation");
4243
4244 return success();
4245}
4246
4247//===----------------------------------------------------------------------===//
4248// UDivOp
4249//===----------------------------------------------------------------------===//
4250Speculation::Speculatability UDivOp::getSpeculatability() {
4251 // X / 0 => UB
4252 Value divisor = getRhs();
4253 if (matchPattern(divisor, m_IntRangeWithoutZeroU()))
4255
4257}
4258
4259//===----------------------------------------------------------------------===//
4260// SDivOp
4261//===----------------------------------------------------------------------===//
4262Speculation::Speculatability SDivOp::getSpeculatability() {
4263 // This function conservatively assumes that all signed division by -1 are
4264 // not speculatable.
4265 // X / 0 => UB
4266 // INT_MIN / -1 => UB
4267 Value divisor = getRhs();
4268 if (matchPattern(divisor, m_IntRangeWithoutZeroS()) &&
4271
4273}
4274
4275//===----------------------------------------------------------------------===//
4276// LLVMDialect initialization, type parsing, and registration.
4277//===----------------------------------------------------------------------===//
4278
4279void LLVMDialect::initialize() {
4280 registerAttributes();
4281
4282 // clang-format off
4283 addTypes<LLVMVoidType,
4284 LLVMTokenType,
4285 LLVMLabelType,
4286 LLVMMetadataType>();
4287 // clang-format on
4288 registerTypes();
4289
4290 addOperations<
4291#define GET_OP_LIST
4292#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
4293
4294 ,
4295#define GET_OP_LIST
4296#include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc"
4297
4298 >();
4299
4300 // Support unknown operations because not all LLVM operations are registered.
4301 allowUnknownOperations();
4302 declarePromisedInterface<DialectInlinerInterface, LLVMDialect>();
4304}
4305
4306#define GET_OP_CLASSES
4307#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
4308
4309#define GET_OP_CLASSES
4310#include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc"
4311
4312LogicalResult LLVMDialect::verifyDataLayoutString(
4313 StringRef descr, llvm::function_ref<void(const Twine &)> reportError) {
4314 llvm::Expected<llvm::DataLayout> maybeDataLayout =
4315 llvm::DataLayout::parse(descr);
4316 if (maybeDataLayout)
4317 return success();
4318
4319 std::string message;
4320 llvm::raw_string_ostream messageStream(message);
4321 llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream);
4322 reportError("invalid data layout descriptor: " + message);
4323 return failure();
4324}
4325
4326/// Verify LLVM dialect attributes.
4327LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op,
4328 NamedAttribute attr) {
4329 // If the data layout attribute is present, it must use the LLVM data layout
4330 // syntax. Try parsing it and report errors in case of failure. Users of this
4331 // attribute may assume it is well-formed and can pass it to the (asserting)
4332 // llvm::DataLayout constructor.
4333 if (attr.getName() != LLVM::LLVMDialect::getDataLayoutAttrName())
4334 return success();
4335 if (auto stringAttr = llvm::dyn_cast<StringAttr>(attr.getValue()))
4336 return verifyDataLayoutString(
4337 stringAttr.getValue(),
4338 [op](const Twine &message) { op->emitOpError() << message.str(); });
4339
4340 return op->emitOpError() << "expected '"
4341 << LLVM::LLVMDialect::getDataLayoutAttrName()
4342 << "' to be a string attributes";
4343}
4344
4345LogicalResult LLVMDialect::verifyParameterAttribute(Operation *op,
4346 Type paramType,
4347 NamedAttribute paramAttr) {
4348 // LLVM attribute may be attached to a result of operation that has not been
4349 // converted to LLVM dialect yet, so the result may have a type with unknown
4350 // representation in LLVM dialect type space. In this case we cannot verify
4351 // whether the attribute may be
4352 bool verifyValueType = isCompatibleType(paramType);
4353 StringAttr name = paramAttr.getName();
4354
4355 auto checkUnitAttrType = [&]() -> LogicalResult {
4356 if (!llvm::isa<UnitAttr>(paramAttr.getValue()))
4357 return op->emitError() << name << " should be a unit attribute";
4358 return success();
4359 };
4360 auto checkTypeAttrType = [&]() -> LogicalResult {
4361 if (!llvm::isa<TypeAttr>(paramAttr.getValue()))
4362 return op->emitError() << name << " should be a type attribute";
4363 return success();
4364 };
4365 auto checkIntegerAttrType = [&]() -> LogicalResult {
4366 if (!llvm::isa<IntegerAttr>(paramAttr.getValue()))
4367 return op->emitError() << name << " should be an integer attribute";
4368 return success();
4369 };
4370 auto checkPointerType = [&]() -> LogicalResult {
4371 if (!llvm::isa<LLVMPointerType>(paramType))
4372 return op->emitError()
4373 << name << " attribute attached to non-pointer LLVM type";
4374 return success();
4375 };
4376 auto checkIntegerType = [&]() -> LogicalResult {
4377 if (!llvm::isa<IntegerType>(paramType))
4378 return op->emitError()
4379 << name << " attribute attached to non-integer LLVM type";
4380 return success();
4381 };
4382 auto checkPointerTypeMatches = [&]() -> LogicalResult {
4383 if (failed(checkPointerType()))
4384 return failure();
4385
4386 return success();
4387 };
4388
4389 // Check a unit attribute that is attached to a pointer value.
4390 if (name == LLVMDialect::getNoAliasAttrName() ||
4391 name == LLVMDialect::getReadonlyAttrName() ||
4392 name == LLVMDialect::getReadnoneAttrName() ||
4393 name == LLVMDialect::getWriteOnlyAttrName() ||
4394 name == LLVMDialect::getNestAttrName() ||
4395 name == LLVMDialect::getNoCaptureAttrName() ||
4396 name == LLVMDialect::getNoFreeAttrName() ||
4397 name == LLVMDialect::getNonNullAttrName()) {
4398 if (failed(checkUnitAttrType()))
4399 return failure();
4400 if (verifyValueType && failed(checkPointerType()))
4401 return failure();
4402 return success();
4403 }
4404
4405 // Check a type attribute that is attached to a pointer value.
4406 if (name == LLVMDialect::getStructRetAttrName() ||
4407 name == LLVMDialect::getByValAttrName() ||
4408 name == LLVMDialect::getByRefAttrName() ||
4409 name == LLVMDialect::getElementTypeAttrName() ||
4410 name == LLVMDialect::getInAllocaAttrName() ||
4411 name == LLVMDialect::getPreallocatedAttrName()) {
4412 if (failed(checkTypeAttrType()))
4413 return failure();
4414 if (verifyValueType && failed(checkPointerTypeMatches()))
4415 return failure();
4416 return success();
4417 }
4418
4419 // Check a unit attribute that is attached to an integer value.
4420 if (name == LLVMDialect::getSExtAttrName() ||
4421 name == LLVMDialect::getZExtAttrName()) {
4422 if (failed(checkUnitAttrType()))
4423 return failure();
4424 if (verifyValueType && failed(checkIntegerType()))
4425 return failure();
4426 return success();
4427 }
4428
4429 // Check an integer attribute that is attached to a pointer value.
4430 if (name == LLVMDialect::getAlignAttrName() ||
4431 name == LLVMDialect::getDereferenceableAttrName() ||
4432 name == LLVMDialect::getDereferenceableOrNullAttrName()) {
4433 if (failed(checkIntegerAttrType()))
4434 return failure();
4435 if (verifyValueType && failed(checkPointerType()))
4436 return failure();
4437 return success();
4438 }
4439
4440 // Check an integer attribute that is attached to a pointer value.
4441 if (name == LLVMDialect::getStackAlignmentAttrName()) {
4442 if (failed(checkIntegerAttrType()))
4443 return failure();
4444 return success();
4445 }
4446
4447 // Check a unit attribute that can be attached to arbitrary types.
4448 if (name == LLVMDialect::getNoUndefAttrName() ||
4449 name == LLVMDialect::getInRegAttrName() ||
4450 name == LLVMDialect::getReturnedAttrName())
4451 return checkUnitAttrType();
4452
4453 return success();
4454}
4455
4456/// Verify LLVMIR function argument attributes.
4457LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op,
4458 unsigned regionIdx,
4459 unsigned argIdx,
4460 NamedAttribute argAttr) {
4461 auto funcOp = dyn_cast<FunctionOpInterface>(op);
4462 if (!funcOp)
4463 return success();
4464 Type argType = funcOp.getArgumentTypes()[argIdx];
4465
4466 return verifyParameterAttribute(op, argType, argAttr);
4467}
4468
4469LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op,
4470 unsigned regionIdx,
4471 unsigned resIdx,
4472 NamedAttribute resAttr) {
4473 auto funcOp = dyn_cast<FunctionOpInterface>(op);
4474 if (!funcOp)
4475 return success();
4476 Type resType = funcOp.getResultTypes()[resIdx];
4477
4478 // Check to see if this function has a void return with a result attribute
4479 // to it. It isn't clear what semantics we would assign to that.
4480 if (llvm::isa<LLVMVoidType>(resType))
4481 return op->emitError() << "cannot attach result attributes to functions "
4482 "with a void return";
4483
4484 // Check to see if this attribute is allowed as a result attribute. Only
4485 // explicitly forbidden LLVM attributes will cause an error.
4486 auto name = resAttr.getName();
4487 if (name == LLVMDialect::getAllocAlignAttrName() ||
4488 name == LLVMDialect::getAllocatedPointerAttrName() ||
4489 name == LLVMDialect::getByValAttrName() ||
4490 name == LLVMDialect::getByRefAttrName() ||
4491 name == LLVMDialect::getInAllocaAttrName() ||
4492 name == LLVMDialect::getNestAttrName() ||
4493 name == LLVMDialect::getNoCaptureAttrName() ||
4494 name == LLVMDialect::getNoFreeAttrName() ||
4495 name == LLVMDialect::getPreallocatedAttrName() ||
4496 name == LLVMDialect::getReadnoneAttrName() ||
4497 name == LLVMDialect::getReadonlyAttrName() ||
4498 name == LLVMDialect::getReturnedAttrName() ||
4499 name == LLVMDialect::getStackAlignmentAttrName() ||
4500 name == LLVMDialect::getStructRetAttrName() ||
4501 name == LLVMDialect::getWriteOnlyAttrName())
4502 return op->emitError() << name << " is not a valid result attribute";
4503 return verifyParameterAttribute(op, resType, resAttr);
4504}
4505
4506Operation *LLVMDialect::materializeConstant(OpBuilder &builder, Attribute value,
4507 Type type, Location loc) {
4508 // If this was folded from an operation other than llvm.mlir.constant, it
4509 // should be materialized as such. Note that an llvm.mlir.zero may fold into
4510 // a builtin zero attribute and thus will materialize as a llvm.mlir.constant.
4511 if (auto symbol = dyn_cast<FlatSymbolRefAttr>(value))
4512 if (isa<LLVM::LLVMPointerType>(type))
4513 return LLVM::AddressOfOp::create(builder, loc, type, symbol);
4514 if (isa<LLVM::UndefAttr>(value))
4515 return LLVM::UndefOp::create(builder, loc, type);
4516 if (isa<LLVM::PoisonAttr>(value))
4517 return LLVM::PoisonOp::create(builder, loc, type);
4518 if (isa<LLVM::ZeroAttr>(value))
4519 return LLVM::ZeroOp::create(builder, loc, type);
4520 // Otherwise try materializing it as a regular llvm.mlir.constant op.
4521 return LLVM::ConstantOp::materialize(builder, value, type, loc);
4522}
4523
4524//===----------------------------------------------------------------------===//
4525// Utility functions.
4526//===----------------------------------------------------------------------===//
4527
4529 StringRef name, StringRef value,
4530 LLVM::Linkage linkage) {
4531 assert(builder.getInsertionBlock() &&
4532 builder.getInsertionBlock()->getParentOp() &&
4533 "expected builder to point to a block constrained in an op");
4534 auto module =
4535 builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>();
4536 assert(module && "builder points to an op outside of a module");
4537
4538 // Create the global at the entry of the module.
4539 OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener());
4540 MLIRContext *ctx = builder.getContext();
4541 auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size());
4542 auto global = LLVM::GlobalOp::create(
4543 moduleBuilder, loc, type, /*isConstant=*/true, linkage, name,
4544 builder.getStringAttr(value), /*alignment=*/0);
4545
4546 LLVMPointerType ptrType = LLVMPointerType::get(ctx);
4547 // Get the pointer to the first character in the global string.
4548 Value globalPtr =
4549 LLVM::AddressOfOp::create(builder, loc, ptrType, global.getSymNameAttr());
4550 return LLVM::GEPOp::create(builder, loc, ptrType, type, globalPtr,
4551 ArrayRef<GEPArg>{0, 0});
4552}
4553
4558
4560 Operation *module = op->getParentOp();
4561 while (module && !satisfiesLLVMModule(module))
4562 module = module->getParentOp();
4563 assert(module && "unexpected operation outside of a module");
4564 return module;
4565}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
lhs
static int parseOptionalKeywordAlternative(OpAsmParser &parser, ArrayRef< StringRef > keywords)
static ArrayAttr getLLVMAlignParamForCompressExpand(OpBuilder &builder, bool isExpandLoad, uint64_t alignment=1)
static LogicalResult verifyAtomicMemOp(OpTy memOp, Type valueType, ArrayRef< AtomicOrdering > unsupportedOrderings)
Verifies the attributes and the type of atomic memory access operations.
static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser, EnumTy defaultValue)
Parse an enum from the keyword, or default to the provided default value.
static LogicalResult checkGlobalXtorData(Operation *op, ArrayAttr data)
static ParseResult parseGEPIndices(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &indices, DenseI32ArrayAttr &rawConstantIndices)
static LogicalResult verifyOperandBundles(OpType &op)
static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result)
static void printOneOpBundle(OpAsmPrinter &p, OperandRange operands, TypeRange operandTypes, StringRef tag)
static LogicalResult verifyComdat(Operation *op, std::optional< SymbolRefAttr > attr)
static LLVMFunctionType getLLVMFuncType(MLIRContext *context, TypeRange results, ValueRange args)
Constructs a LLVMFunctionType from MLIR results and args.
static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes)
static ParseResult parseSwitchOpCases(OpAsmParser &parser, Type flagType, DenseIntElementsAttr &caseValues, SmallVectorImpl< Block * > &caseDestinations, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand > > &caseOperands, SmallVectorImpl< SmallVector< Type > > &caseOperandTypes)
<cases> ::= [ (case (, case )* )?
static LogicalResult verifyCallOpVarCalleeType(OpTy callOp)
Verify that the parameter and return types of the variadic callee type match the callOp argument and ...
static ParseResult parseOptionalCallFuncPtr(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands)
Parses an optional function pointer operand before the call argument list for indirect calls,...
static bool isZeroAttribute(Attribute value)
static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp, OperandRange indices, DenseI32ArrayAttr rawConstantIndices)
static std::optional< ParseResult > parseOpBundles(OpAsmParser &p, SmallVector< SmallVector< OpAsmParser::UnresolvedOperand > > &opBundleOperands, SmallVector< SmallVector< Type > > &opBundleOperandTypes, ArrayAttr &opBundleTags)
static LLVMStructType getValAndBoolStructType(Type valType)
Returns an LLVM struct type that contains a value type and a boolean type.
static void printOpBundles(OpAsmPrinter &p, Operation *op, OperandRangeRange opBundleOperands, TypeRangeRange opBundleOperandTypes, std::optional< ArrayAttr > opBundleTags)
static void printShuffleType(AsmPrinter &printer, Operation *op, Type v1Type, Type resType, DenseI32ArrayAttr mask)
Nothing to do when the result type is inferred.
static LogicalResult verifyBlockTags(LLVMFuncOp funcOp)
static Type buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef< Type > inputs, ArrayRef< Type > outputs, function_interface_impl::VariadicFlag variadicFlag)
static auto processFMFAttr(ArrayRef< NamedAttribute > attrs)
static TypeAttr getCallOpVarCalleeType(LLVMFunctionType calleeType)
Gets the variadic callee type for a LLVMFunctionType.
static Type getInsertExtractValueElementType(function_ref< InFlightDiagnostic(StringRef)> emitError, Type containerType, ArrayRef< int64_t > position)
Extract the type at position in the LLVM IR aggregate type containerType.
static ParseResult parseOneOpBundle(OpAsmParser &p, SmallVector< SmallVector< OpAsmParser::UnresolvedOperand > > &opBundleOperands, SmallVector< SmallVector< Type > > &opBundleOperandTypes, SmallVector< Attribute > &opBundleTags)
static Type getElementType(Type type)
Determine the element type of type.
static void printIndirectBrOpSucessors(OpAsmPrinter &p, IndirectBrOp op, Type flagType, SuccessorRange succs, OperandRangeRange succOperands, const TypeRangeRange &succOperandsTypes)
static ParseResult resolveOpBundleOperands(OpAsmParser &parser, SMLoc loc, OperationState &state, ArrayRef< SmallVector< OpAsmParser::UnresolvedOperand > > opBundleOperands, ArrayRef< SmallVector< Type > > opBundleOperandTypes, StringAttr opBundleSizesAttrName)
static void printLLVMLinkage(OpAsmPrinter &p, Operation *, LinkageAttr val)
static LogicalResult verifyStructArrayConstant(LLVM::ConstantOp op, LLVM::LLVMArrayType arrayType, ArrayAttr arrayAttr, int dim)
Verifies the constant array represented by arrayAttr matches the provided arrayType.
static ParseResult parseCallTypeAndResolveOperands(OpAsmParser &parser, OperationState &result, bool isDirect, ArrayRef< OpAsmParser::UnresolvedOperand > operands, SmallVectorImpl< DictionaryAttr > &argAttrs, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses the type of a call operation and resolves the operands if the parsing succeeds.
static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol, Operation *op, SymbolTableCollection &symbolTable)
Verifies symbol's use in op to ensure the symbol is a valid and fully defined llvm.func.
static Type extractVectorElementType(Type type)
Returns the elemental type of any LLVM-compatible vector type or self.
static bool hasScalableVectorType(Type t)
Check if the given type is a scalable vector type or a vector/array type that contains a nested scalable vector type.
static SmallVector< Type, 1 > getCallOpResultTypes(LLVMFunctionType calleeType)
Gets the MLIR Op-like result types of a LLVMFunctionType.
static OpFoldResult foldChainableCast(T castOp, typename T::FoldAdaptor adaptor)
Folds a cast op that can be chained.
static void destructureIndices(Type currType, ArrayRef< GEPArg > indices, SmallVectorImpl< int32_t > &rawConstantIndices, SmallVectorImpl< Value > &dynamicIndices)
Destructures the 'indices' parameter into 'rawConstantIndices' and 'dynamicIndices',...
static ParseResult parseCommonGlobalAndAlias(OpAsmParser &parser, OperationState &result)
Parse common attributes that might show up in the same order in both GlobalOp and AliasOp.
static Type getI1SameShape(Type type)
Returns a boolean type that has the same shape as type.
static void printCommonGlobalAndAlias(OpAsmPrinter &p, OpType op)
static ParseResult parseLLVMLinkage(OpAsmParser &p, LinkageAttr &val)
static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value)
Returns a scalar or vector boolean attribute of the given type.
static LogicalResult verifyCallOpDebugInfo(CallOp callOp, LLVMFuncOp callee)
Verify that an inlinable callsite of a debug-info-bearing function in a debug-info-bearing function h...
static ParseResult parseShuffleType(AsmParser &parser, Type v1Type, Type &resType, DenseI32ArrayAttr mask)
Build the result type of a shuffle vector operation.
static LogicalResult verifyExtOp(ExtOp op)
Verifies that the given extension operation operates on consistent scalars or vectors,...
static constexpr const char kElemTypeAttrName[]
static ParseResult parseInsertExtractValueElementType(AsmParser &parser, Type &valueType, Type containerType, DenseI64ArrayAttr position)
Infer the value type from the container type and position.
static LogicalResult verifyStructIndices(Type baseGEPType, unsigned indexPos, GEPIndicesAdaptor< ValueRange > indices, function_ref< InFlightDiagnostic()> emitOpError)
For the given indices, check if they comply with baseGEPType, especially check against LLVMStructType...
static Attribute extractElementAt(Attribute attr, size_t index)
Extracts the element at the given index from an attribute.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static void printInsertExtractValueElementType(AsmPrinter &printer, Operation *op, Type valueType, Type containerType, DenseI64ArrayAttr position)
Nothing to print for an inferred type.
static ParseResult parseIndirectBrOpSucessors(OpAsmParser &parser, Type &flagType, SmallVectorImpl< Block * > &succOperandBlocks, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand > > &succOperands, SmallVectorImpl< SmallVector< Type > > &succOperandsTypes)
#define REGISTER_ENUM_TYPE(Ty)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
static std::string diag(const llvm::Value &value)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
This base class exposes generic asm parser hooks, usable across the various derived parsers.
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
@ Square
Square brackets surrounding zero or more operands.
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalColonTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional colon followed by a type list, which if present must have at least one type.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseCustomAttributeWithFallback(Attribute &result, Type type, function_ref< ParseResult(Attribute &result, Type type)> parseAttribute)=0
Parse a custom attribute with the provided callback, unless the next token is #, in which case the ge...
ParseResult parseString(std::string *string)
Parse a quoted string token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalRSquare()=0
Parse a ] token if present.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
virtual void printString(StringRef string)
Print the given string as a quoted string, escaping any special or non-printable characters in it.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
Definition Attributes.h:25
MLIRContext * getContext() const
Return the context this attribute belongs to.
This class provides an abstraction over the different types of ranges over Blocks.
Block represents an ordered list of Operations.
Definition Block.h:33
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
UnitAttr getUnitAttr()
Definition Builders.cpp:98
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:200
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:163
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:112
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:91
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:262
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:324
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:266
MLIRContext * getContext() const
Definition Builders.h:56
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
Definition Builders.cpp:104
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition Builders.cpp:94
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
Definition Builders.cpp:306
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:98
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
A symbol reference with a reference path containing a single element.
StringRef getValue() const
Returns the name of the held symbol reference.
StringAttr getAttr() const
Returns the name of the held symbol reference as a StringAttr.
This class represents a fused location whose metadata is known to be an instance of the given type.
Definition Location.h:149
This class represents a diagnostic that is inflight and set to be reported.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
Class used for building a 'llvm.getelementptr'.
Definition LLVMDialect.h:71
Class used for convenient access and iteration over GEP indices.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:118
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseSuccessor(Block *&dest)=0
Parse a single operation successor.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseSuccessorAndUseList(Block *&dest, SmallVectorImpl< Value > &operands)=0
Parse a single operation successor and its operand list.
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printSuccessorAndUseList(Block *successor, ValueRange succOperands)=0
Print the successor and its operands.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition Builders.h:320
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition Builders.h:442
This class represents a single result from folding an operation.
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
This class represents a contiguous range of operand ranges, e.g.
Definition ValueRange.h:84
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
type_range getTypes() const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
OperandRange operand_range
Definition Operation.h:371
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & emplaceBlock()
Definition Region.h:46
iterator_range< OpIterator > getOps()
Definition Region.h:172
bool empty()
Definition Region.h:60
This class represents a specific instance of an effect.
This class models how operands are forwarded to block arguments in control flow.
This class implements the successor iterators for Block.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
virtual Operation * lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol)
Look up a symbol with the specified name within the specified symbol table operation,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction for a range of TypeRange.
Definition TypeRange.h:95
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
Definition Types.cpp:104
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
A named class for passing around the variadic flag.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition CallGraph.h:229
void addBytecodeInterface(LLVMDialect *dialect)
Add the interfaces necessary for encoding the LLVM dialect components in bytecode.
Value createGlobalString(Location loc, OpBuilder &builder, StringRef name, StringRef value, Linkage linkage)
Create an LLVM global containing the string "value" at the module surrounding the insertio...
Operation * parentLLVMModule(Operation *op)
Look up the parent Module satisfying LLVM conditions on the Module Operation.
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
bool isScalableVectorType(Type vectorType)
Returns whether a vector type is scalable or not.
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
bool isCompatibleOuterType(Type type)
Returns true if the given outer type is compatible with the LLVM dialect without checking its potenti...
bool satisfiesLLVMModule(Operation *op)
LLVM requires some operations to be inside of a Module operation.
constexpr int kGEPConstantBitWidth
Bit-width of a 'GEPConstantIndex' within GEPArg.
Definition LLVMDialect.h:62
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
bool isTypeCompatibleWithAtomicOp(Type type, const DataLayout &dataLayout)
Returns true if the given type is supported by atomic operations.
bool isCompatibleFloatingPointType(Type type)
Returns true if the given type is a floating-point type compatible with the LLVM dialect.
llvm::ElementCount getVectorNumElements(Type type)
Returns the element count of any LLVM-compatible vector type.
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
void printFunctionSignature(OpAsmPrinter &p, TypeRange argTypes, ArrayAttr argAttrs, bool isVariadic, TypeRange resultTypes, ArrayAttr resultAttrs, Region *body=nullptr, bool printEmptyResult=true)
Print a function signature for a call or callable operation.
ParseResult parseFunctionSignature(OpAsmParser &parser, SmallVectorImpl< Type > &argTypes, SmallVectorImpl< DictionaryAttr > &argAttrs, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs, bool mustParseEmptyResult=true)
Parses a function signature using parser.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition Visitors.h:102
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function attributes prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition Matchers.h:471
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:144
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition Matchers.h:462
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or an unsigned integer range that does...
Definition Matchers.h:455
A callable is either a symbol, or an SSA value, that is referenced by a call-like operation.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties object of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addSuccessors(Block *successor)
Adds a successor to the operation state. successor must not be null.