MLIR 23.0.0git
LLVMDialect.cpp
Go to the documentation of this file.
1//===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the types and operation details for the LLVM IR dialect in
10// MLIR, and the LLVM IR dialect. It also registers the dialect.
11//
12//===----------------------------------------------------------------------===//
13
17#include "mlir/IR/Attributes.h"
18#include "mlir/IR/Builders.h"
19#include "mlir/IR/BuiltinOps.h"
22#include "mlir/IR/MLIRContext.h"
23#include "mlir/IR/Matchers.h"
26
27#include "llvm/ADT/APFloat.h"
28#include "llvm/ADT/DenseSet.h"
29#include "llvm/ADT/TypeSwitch.h"
30#include "llvm/IR/DataLayout.h"
31#include "llvm/Support/Error.h"
32
33#include "LLVMDialectBytecode.h"
34
35#include <numeric>
36#include <optional>
37
38using namespace mlir;
39using namespace mlir::LLVM;
40using mlir::LLVM::cconv::getMaxEnumValForCConv;
41using mlir::LLVM::linkage::getMaxEnumValForLinkage;
42using mlir::LLVM::tailcallkind::getMaxEnumValForTailCallKind;
43
44#include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
45
46//===----------------------------------------------------------------------===//
47// Attribute Helpers
48//===----------------------------------------------------------------------===//
49
/// Attribute name under which an element type is stored as a TypeAttr (used,
/// for example, by the custom `llvm.alloca` syntax).
static constexpr char kElemTypeAttrName[] = "elem_type";
51
54 llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
55 if (attr.getName() == "fastmathFlags") {
56 auto defAttr =
57 FastmathFlagsAttr::get(attr.getValue().getContext(), {});
58 return defAttr != attr.getValue();
59 }
60 return true;
61 }));
62 return filteredAttrs;
63}
64
65/// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and
66/// fully defined llvm.func.
67static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol,
68 Operation *op,
69 SymbolTableCollection &symbolTable) {
70 StringRef name = symbol.getValue();
71 auto func =
72 symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(op, symbol.getAttr());
73 if (!func)
74 return op->emitOpError("'")
75 << name << "' does not reference a valid LLVM function";
76 if (func.isExternal())
77 return op->emitOpError("'") << name << "' does not have a definition";
78 return success();
79}
80
81/// Returns a boolean type that has the same shape as `type`. It supports both
82/// fixed size vectors as well as scalable vectors.
83static Type getI1SameShape(Type type) {
84 Type i1Type = IntegerType::get(type.getContext(), 1);
87 return i1Type;
88}
89
90// Parses one of the keywords provided in the list `keywords` and returns the
91// position of the parsed keyword in the list. If none of the keywords from the
92// list is parsed, returns -1.
94 ArrayRef<StringRef> keywords) {
95 for (const auto &en : llvm::enumerate(keywords)) {
96 if (succeeded(parser.parseOptionalKeyword(en.value())))
97 return en.index();
98 }
99 return -1;
100}
101
102namespace {
103template <typename Ty>
104struct EnumTraits {};
105
106#define REGISTER_ENUM_TYPE(Ty) \
107 template <> \
108 struct EnumTraits<Ty> { \
109 static StringRef stringify(Ty value) { return stringify##Ty(value); } \
110 static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \
111 }
112
113REGISTER_ENUM_TYPE(Linkage);
114REGISTER_ENUM_TYPE(UnnamedAddr);
115REGISTER_ENUM_TYPE(CConv);
116REGISTER_ENUM_TYPE(TailCallKind);
117REGISTER_ENUM_TYPE(Visibility);
118} // namespace
119
120/// Parse an enum from the keyword, or default to the provided default value.
121/// The return type is the enum type by default, unless overridden with the
122/// second template argument.
123template <typename EnumTy, typename RetTy = EnumTy>
125 EnumTy defaultValue) {
127 for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
128 names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
129
130 int index = parseOptionalKeywordAlternative(parser, names);
131 if (index == -1)
132 return static_cast<RetTy>(defaultValue);
133 return static_cast<RetTy>(index);
134}
135
136static void printLLVMLinkage(OpAsmPrinter &p, Operation *, LinkageAttr val) {
137 p << stringifyLinkage(val.getLinkage());
138}
139
140static ParseResult parseLLVMLinkage(OpAsmParser &p, LinkageAttr &val) {
141 val = LinkageAttr::get(
142 p.getContext(),
143 parseOptionalLLVMKeyword<LLVM::Linkage>(p, LLVM::Linkage::External));
144 return success();
145}
146
148 bool isExpandLoad,
149 uint64_t alignment = 1) {
150 // From
151 // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
152 // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
153 //
154 // The pointer alignment defaults to 1.
155 if (alignment == 1) {
156 return nullptr;
157 }
158
159 auto emptyDictAttr = builder.getDictionaryAttr({});
160 auto alignmentAttr = builder.getI64IntegerAttr(alignment);
161 auto namedAttr =
162 builder.getNamedAttr(LLVMDialect::getAlignAttrName(), alignmentAttr);
163 SmallVector<mlir::NamedAttribute> attrs = {namedAttr};
164 auto alignDictAttr = builder.getDictionaryAttr(attrs);
165 // From
166 // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
167 // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
168 //
169 // The align parameter attribute can be provided for [expandload]'s first
170 // argument. The align parameter attribute can be provided for
171 // [compressstore]'s second argument.
172 int pos = isExpandLoad ? 0 : 1;
173 return pos == 0 ? builder.getArrayAttr(
174 {alignDictAttr, emptyDictAttr, emptyDictAttr})
175 : builder.getArrayAttr(
176 {emptyDictAttr, alignDictAttr, emptyDictAttr});
177}
178
179//===----------------------------------------------------------------------===//
180// Operand bundle helpers.
181//===----------------------------------------------------------------------===//
182
184 TypeRange operandTypes, StringRef tag) {
185 p.printString(tag);
186 p << "(";
187
188 if (!operands.empty()) {
189 p.printOperands(operands);
190 p << " : ";
191 llvm::interleaveComma(operandTypes, p);
192 }
193
194 p << ")";
195}
196
198 OperandRangeRange opBundleOperands,
199 TypeRangeRange opBundleOperandTypes,
200 std::optional<ArrayAttr> opBundleTags) {
201 if (opBundleOperands.empty())
202 return;
203 assert(opBundleTags && "expect operand bundle tags");
204
205 p << "[";
206 llvm::interleaveComma(
207 llvm::zip(opBundleOperands, opBundleOperandTypes, *opBundleTags), p,
208 [&p](auto bundle) {
209 auto bundleTag = cast<StringAttr>(std::get<2>(bundle)).getValue();
210 printOneOpBundle(p, std::get<0>(bundle), std::get<1>(bundle),
211 bundleTag);
212 });
213 p << "]";
214}
215
216static ParseResult parseOneOpBundle(
217 OpAsmParser &p,
219 SmallVector<SmallVector<Type>> &opBundleOperandTypes,
220 SmallVector<Attribute> &opBundleTags) {
221 SMLoc currentParserLoc = p.getCurrentLocation();
223 SmallVector<Type> types;
224 std::string tag;
225
226 if (p.parseString(&tag))
227 return p.emitError(currentParserLoc, "expect operand bundle tag");
228
229 if (p.parseLParen())
230 return failure();
231
232 if (p.parseOptionalRParen()) {
233 if (p.parseOperandList(operands) || p.parseColon() ||
234 p.parseTypeList(types) || p.parseRParen())
235 return failure();
236 }
237
238 opBundleOperands.push_back(std::move(operands));
239 opBundleOperandTypes.push_back(std::move(types));
240 opBundleTags.push_back(StringAttr::get(p.getContext(), tag));
241
242 return success();
243}
244
245static std::optional<ParseResult> parseOpBundles(
246 OpAsmParser &p,
248 SmallVector<SmallVector<Type>> &opBundleOperandTypes,
249 ArrayAttr &opBundleTags) {
250 if (p.parseOptionalLSquare())
251 return std::nullopt;
252
253 if (succeeded(p.parseOptionalRSquare()))
254 return success();
255
256 SmallVector<Attribute> opBundleTagAttrs;
257 auto bundleParser = [&] {
258 return parseOneOpBundle(p, opBundleOperands, opBundleOperandTypes,
259 opBundleTagAttrs);
260 };
261 if (p.parseCommaSeparatedList(bundleParser))
262 return failure();
263
264 if (p.parseRSquare())
265 return failure();
266
267 opBundleTags = ArrayAttr::get(p.getContext(), opBundleTagAttrs);
268
269 return success();
270}
271
272//===----------------------------------------------------------------------===//
273// Printing, parsing, folding and builder for LLVM::CmpOp.
274//===----------------------------------------------------------------------===//
275
276void ICmpOp::print(OpAsmPrinter &p) {
277 p << " \"" << stringifyICmpPredicate(getPredicate()) << "\" " << getOperand(0)
278 << ", " << getOperand(1);
279 p.printOptionalAttrDict((*this)->getAttrs(), {"predicate"});
280 p << " : " << getLhs().getType();
281}
282
283void FCmpOp::print(OpAsmPrinter &p) {
284 p << " \"" << stringifyFCmpPredicate(getPredicate()) << "\" " << getOperand(0)
285 << ", " << getOperand(1);
286 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"predicate"});
287 p << " : " << getLhs().getType();
288}
289
290// <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use
291// attribute-dict? `:` type
292// <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use
293// attribute-dict? `:` type
294template <typename CmpPredicateType>
295static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
296 StringAttr predicateAttr;
298 Type type;
299 SMLoc predicateLoc, trailingTypeLoc;
300 if (parser.getCurrentLocation(&predicateLoc) ||
301 parser.parseAttribute(predicateAttr, "predicate", result.attributes) ||
302 parser.parseOperand(lhs) || parser.parseComma() ||
303 parser.parseOperand(rhs) ||
304 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
305 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
306 parser.resolveOperand(lhs, type, result.operands) ||
307 parser.resolveOperand(rhs, type, result.operands))
308 return failure();
309
310 // Replace the string attribute `predicate` with an integer attribute.
311 int64_t predicateValue = 0;
312 if (std::is_same<CmpPredicateType, ICmpPredicate>()) {
313 std::optional<ICmpPredicate> predicate =
314 symbolizeICmpPredicate(predicateAttr.getValue());
315 if (!predicate)
316 return parser.emitError(predicateLoc)
317 << "'" << predicateAttr.getValue()
318 << "' is an incorrect value of the 'predicate' attribute";
319 predicateValue = static_cast<int64_t>(*predicate);
320 } else {
321 std::optional<FCmpPredicate> predicate =
322 symbolizeFCmpPredicate(predicateAttr.getValue());
323 if (!predicate)
324 return parser.emitError(predicateLoc)
325 << "'" << predicateAttr.getValue()
326 << "' is an incorrect value of the 'predicate' attribute";
327 predicateValue = static_cast<int64_t>(*predicate);
328 }
329
330 result.attributes.set("predicate",
331 parser.getBuilder().getI64IntegerAttr(predicateValue));
332
333 // The result type is either i1 or a vector type <? x i1> if the inputs are
334 // vectors.
335 if (!isCompatibleType(type))
336 return parser.emitError(trailingTypeLoc,
337 "expected LLVM dialect-compatible type");
338 result.addTypes(getI1SameShape(type));
339 return success();
340}
341
342ParseResult ICmpOp::parse(OpAsmParser &parser, OperationState &result) {
343 return parseCmpOp<ICmpPredicate>(parser, result);
344}
345
346ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) {
347 return parseCmpOp<FCmpPredicate>(parser, result);
348}
349
350/// Returns a scalar or vector boolean attribute of the given type.
351static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
352 auto boolAttr = BoolAttr::get(ctx, value);
353 ShapedType shapedType = dyn_cast<ShapedType>(type);
354 if (!shapedType)
355 return boolAttr;
356 return DenseElementsAttr::get(shapedType, boolAttr);
357}
358
359OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
360 if (getPredicate() != ICmpPredicate::eq &&
361 getPredicate() != ICmpPredicate::ne)
362 return {};
363
364 // cmpi(eq/ne, x, x) -> true/false
365 if (getLhs() == getRhs())
367 getPredicate() == ICmpPredicate::eq);
368
369 // cmpi(eq/ne, alloca, null) -> false/true
370 if (getLhs().getDefiningOp<AllocaOp>() && getRhs().getDefiningOp<ZeroOp>())
372 getPredicate() == ICmpPredicate::ne);
373
374 // cmpi(eq/ne, null, alloca) -> cmpi(eq/ne, alloca, null)
375 if (getLhs().getDefiningOp<ZeroOp>() && getRhs().getDefiningOp<AllocaOp>()) {
376 Value lhs = getLhs();
377 Value rhs = getRhs();
378 getLhsMutable().assign(rhs);
379 getRhsMutable().assign(lhs);
380 return getResult();
381 }
382
383 return {};
384}
385
386//===----------------------------------------------------------------------===//
387// Printing, parsing and verification for LLVM::AllocaOp.
388//===----------------------------------------------------------------------===//
389
390void AllocaOp::print(OpAsmPrinter &p) {
391 auto funcTy =
392 FunctionType::get(getContext(), {getArraySize().getType()}, {getType()});
393
394 if (getInalloca())
395 p << " inalloca";
396
397 p << ' ' << getArraySize() << " x " << getElemType();
398 if (getAlignment() && *getAlignment() != 0)
399 p.printOptionalAttrDict((*this)->getAttrs(),
400 {kElemTypeAttrName, getInallocaAttrName()});
401 else
403 (*this)->getAttrs(),
404 {getAlignmentAttrName(), kElemTypeAttrName, getInallocaAttrName()});
405 p << " : " << funcTy;
406}
407
408// <operation> ::= `llvm.alloca` `inalloca`? ssa-use `x` type
409// attribute-dict? `:` type `,` type
410ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) {
412 Type type, elemType;
413 SMLoc trailingTypeLoc;
414
415 if (succeeded(parser.parseOptionalKeyword("inalloca")))
416 result.addAttribute(getInallocaAttrName(result.name),
417 UnitAttr::get(parser.getContext()));
418
419 if (parser.parseOperand(arraySize) || parser.parseKeyword("x") ||
420 parser.parseType(elemType) ||
421 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
422 parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
423 return failure();
424
425 std::optional<NamedAttribute> alignmentAttr =
426 result.attributes.getNamed("alignment");
427 if (alignmentAttr.has_value()) {
428 auto alignmentInt = llvm::dyn_cast<IntegerAttr>(alignmentAttr->getValue());
429 if (!alignmentInt)
430 return parser.emitError(parser.getNameLoc(),
431 "expected integer alignment");
432 if (alignmentInt.getValue().isZero())
433 result.attributes.erase("alignment");
434 }
435
436 // Extract the result type from the trailing function type.
437 auto funcType = llvm::dyn_cast<FunctionType>(type);
438 if (!funcType || funcType.getNumInputs() != 1 ||
439 funcType.getNumResults() != 1)
440 return parser.emitError(
441 trailingTypeLoc,
442 "expected trailing function type with one argument and one result");
443
444 if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands))
445 return failure();
446
447 Type resultType = funcType.getResult(0);
448 if (auto ptrResultType = llvm::dyn_cast<LLVMPointerType>(resultType))
449 result.addAttribute(kElemTypeAttrName, TypeAttr::get(elemType));
450
451 result.addTypes({funcType.getResult(0)});
452 return success();
453}
454
455LogicalResult AllocaOp::verify() {
456 // Only certain target extension types can be used in 'alloca'.
457 if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getElemType());
458 targetExtType && !targetExtType.supportsMemOps())
459 return emitOpError()
460 << "this target extension type cannot be used in alloca";
461
462 return success();
463}
464
465//===----------------------------------------------------------------------===//
466// LLVM::BrOp
467//===----------------------------------------------------------------------===//
468
469SuccessorOperands BrOp::getSuccessorOperands(unsigned index) {
470 assert(index == 0 && "invalid successor index");
471 return SuccessorOperands(getDestOperandsMutable());
472}
473
474//===----------------------------------------------------------------------===//
475// LLVM::CondBrOp
476//===----------------------------------------------------------------------===//
477
478SuccessorOperands CondBrOp::getSuccessorOperands(unsigned index) {
479 assert(index < getNumSuccessors() && "invalid successor index");
480 return SuccessorOperands(index == 0 ? getTrueDestOperandsMutable()
481 : getFalseDestOperandsMutable());
482}
483
484void CondBrOp::build(OpBuilder &builder, OperationState &result,
485 Value condition, Block *trueDest, ValueRange trueOperands,
486 Block *falseDest, ValueRange falseOperands,
487 std::optional<std::pair<uint32_t, uint32_t>> weights) {
488 DenseI32ArrayAttr weightsAttr;
489 if (weights)
490 weightsAttr =
491 builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights->first),
492 static_cast<int32_t>(weights->second)});
493
494 build(builder, result, condition, trueOperands, falseOperands, weightsAttr,
495 /*loop_annotation=*/{}, trueDest, falseDest);
496}
497
498//===----------------------------------------------------------------------===//
499// LLVM::SwitchOp
500//===----------------------------------------------------------------------===//
501
502void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
503 Block *defaultDestination, ValueRange defaultOperands,
504 DenseIntElementsAttr caseValues,
505 BlockRange caseDestinations,
506 ArrayRef<ValueRange> caseOperands,
507 ArrayRef<int32_t> branchWeights) {
508 DenseI32ArrayAttr weightsAttr;
509 if (!branchWeights.empty())
510 weightsAttr = builder.getDenseI32ArrayAttr(branchWeights);
511
512 build(builder, result, value, defaultOperands, caseOperands, caseValues,
513 weightsAttr, defaultDestination, caseDestinations);
514}
515
516void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
517 Block *defaultDestination, ValueRange defaultOperands,
518 ArrayRef<APInt> caseValues, BlockRange caseDestinations,
519 ArrayRef<ValueRange> caseOperands,
520 ArrayRef<int32_t> branchWeights) {
521 DenseIntElementsAttr caseValuesAttr;
522 if (!caseValues.empty()) {
523 ShapedType caseValueType = VectorType::get(
524 static_cast<int64_t>(caseValues.size()), value.getType());
525 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
526 }
527
528 build(builder, result, value, defaultDestination, defaultOperands,
529 caseValuesAttr, caseDestinations, caseOperands, branchWeights);
530}
531
532void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
533 Block *defaultDestination, ValueRange defaultOperands,
534 ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
535 ArrayRef<ValueRange> caseOperands,
536 ArrayRef<int32_t> branchWeights) {
537 DenseIntElementsAttr caseValuesAttr;
538 if (!caseValues.empty()) {
539 ShapedType caseValueType = VectorType::get(
540 static_cast<int64_t>(caseValues.size()), value.getType());
541 caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
542 }
543
544 build(builder, result, value, defaultDestination, defaultOperands,
545 caseValuesAttr, caseDestinations, caseOperands, branchWeights);
546}
547
548/// <cases> ::= `[` (case (`,` case )* )? `]`
549/// <case> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
550static ParseResult parseSwitchOpCases(
551 OpAsmParser &parser, Type flagType, DenseIntElementsAttr &caseValues,
552 SmallVectorImpl<Block *> &caseDestinations,
554 SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
555 if (failed(parser.parseLSquare()))
556 return failure();
557 if (succeeded(parser.parseOptionalRSquare()))
558 return success();
559 SmallVector<APInt> values;
560 unsigned bitWidth = flagType.getIntOrFloatBitWidth();
561 auto parseCase = [&]() {
562 int64_t value = 0;
563 if (failed(parser.parseInteger(value)))
564 return failure();
565 values.push_back(APInt(bitWidth, value, /*isSigned=*/true));
566
567 Block *destination;
569 SmallVector<Type> operandTypes;
570 if (parser.parseColon() || parser.parseSuccessor(destination))
571 return failure();
572 if (!parser.parseOptionalLParen()) {
574 /*allowResultNumber=*/false) ||
575 parser.parseColonTypeList(operandTypes) || parser.parseRParen())
576 return failure();
577 }
578 caseDestinations.push_back(destination);
579 caseOperands.emplace_back(operands);
580 caseOperandTypes.emplace_back(operandTypes);
581 return success();
582 };
583 if (failed(parser.parseCommaSeparatedList(parseCase)))
584 return failure();
585
586 ShapedType caseValueType =
587 VectorType::get(static_cast<int64_t>(values.size()), flagType);
588 caseValues = DenseIntElementsAttr::get(caseValueType, values);
589 return parser.parseRSquare();
590}
591
592static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType,
593 DenseIntElementsAttr caseValues,
594 SuccessorRange caseDestinations,
595 OperandRangeRange caseOperands,
596 const TypeRangeRange &caseOperandTypes) {
597 p << '[';
598 p.printNewline();
599 if (!caseValues) {
600 p << ']';
601 return;
602 }
603
604 size_t index = 0;
605 llvm::interleave(
606 llvm::zip(caseValues, caseDestinations),
607 [&](auto i) {
608 p << " ";
609 p << std::get<0>(i);
610 p << ": ";
611 p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]);
612 },
613 [&] {
614 p << ',';
615 p.printNewline();
616 });
617 p.printNewline();
618 p << ']';
619}
620
621LogicalResult SwitchOp::verify() {
622 if ((!getCaseValues() && !getCaseDestinations().empty()) ||
623 (getCaseValues() &&
624 getCaseValues()->size() !=
625 static_cast<int64_t>(getCaseDestinations().size())))
626 return emitOpError("expects number of case values to match number of "
627 "case destinations");
628 if (getCaseValues() &&
629 getValue().getType() != getCaseValues()->getElementType())
630 return emitError("expects case value type to match condition value type");
631 return success();
632}
633
634SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
635 assert(index < getNumSuccessors() && "invalid successor index");
636 return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
637 : getCaseOperandsMutable(index - 1));
638}
639
640//===----------------------------------------------------------------------===//
641// Code for LLVM::GEPOp.
642//===----------------------------------------------------------------------===//
643
644GEPIndicesAdaptor<ValueRange> GEPOp::getIndices() {
645 return GEPIndicesAdaptor<ValueRange>(getRawConstantIndicesAttr(),
646 getDynamicIndices());
647}
648
649/// Returns the elemental type of any LLVM-compatible vector type or self.
651 if (auto vectorType = llvm::dyn_cast<VectorType>(type))
652 return vectorType.getElementType();
653 return type;
654}
655
656/// Destructures the 'indices' parameter into 'rawConstantIndices' and
657/// 'dynamicIndices', encoding the former in the process. In the process,
658/// dynamic indices which are used to index into a structure type are converted
659/// to constant indices when possible. To do this, the GEPs element type should
660/// be passed as first parameter.
662 SmallVectorImpl<int32_t> &rawConstantIndices,
663 SmallVectorImpl<Value> &dynamicIndices) {
664 for (const GEPArg &iter : indices) {
665 // If the thing we are currently indexing into is a struct we must turn
666 // any integer constants into constant indices. If this is not possible
667 // we don't do anything here. The verifier will catch it and emit a proper
668 // error. All other canonicalization is done in the fold method.
669 bool requiresConst = !rawConstantIndices.empty() &&
670 isa_and_nonnull<LLVMStructType>(currType);
671 if (Value val = llvm::dyn_cast_if_present<Value>(iter)) {
672 APInt intC;
673 if (requiresConst && matchPattern(val, m_ConstantInt(&intC)) &&
674 intC.isSignedIntN(kGEPConstantBitWidth)) {
675 rawConstantIndices.push_back(intC.getSExtValue());
676 } else {
677 rawConstantIndices.push_back(GEPOp::kDynamicIndex);
678 dynamicIndices.push_back(val);
679 }
680 } else {
681 rawConstantIndices.push_back(cast<GEPConstantIndex>(iter));
682 }
683
684 // Skip for very first iteration of this loop. First index does not index
685 // within the aggregates, but is just a pointer offset.
686 if (rawConstantIndices.size() == 1 || !currType)
687 continue;
688
689 currType = TypeSwitch<Type, Type>(currType)
690 .Case<VectorType, LLVMArrayType>([](auto containerType) {
691 return containerType.getElementType();
692 })
693 .Case([&](LLVMStructType structType) -> Type {
694 int64_t memberIndex = rawConstantIndices.back();
695 if (memberIndex >= 0 && static_cast<size_t>(memberIndex) <
696 structType.getBody().size())
697 return structType.getBody()[memberIndex];
698 return nullptr;
699 })
700 .Default(nullptr);
701 }
702}
703
704void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
705 Type elementType, Value basePtr, ArrayRef<GEPArg> indices,
706 GEPNoWrapFlags noWrapFlags,
707 ArrayRef<NamedAttribute> attributes) {
708 SmallVector<int32_t> rawConstantIndices;
709 SmallVector<Value> dynamicIndices;
710 destructureIndices(elementType, indices, rawConstantIndices, dynamicIndices);
711
712 result.addTypes(resultType);
713 result.addAttributes(attributes);
714 result.getOrAddProperties<Properties>().rawConstantIndices =
715 builder.getDenseI32ArrayAttr(rawConstantIndices);
716 result.getOrAddProperties<Properties>().noWrapFlags = noWrapFlags;
717 result.getOrAddProperties<Properties>().elem_type =
718 TypeAttr::get(elementType);
719 result.addOperands(basePtr);
720 result.addOperands(dynamicIndices);
721}
722
723void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
724 Type elementType, Value basePtr, ValueRange indices,
725 GEPNoWrapFlags noWrapFlags,
726 ArrayRef<NamedAttribute> attributes) {
727 build(builder, result, resultType, elementType, basePtr,
728 SmallVector<GEPArg>(indices), noWrapFlags, attributes);
729}
730
731static ParseResult
734 DenseI32ArrayAttr &rawConstantIndices) {
735 SmallVector<int32_t> constantIndices;
736
737 auto idxParser = [&]() -> ParseResult {
738 int32_t constantIndex;
739 OptionalParseResult parsedInteger =
740 parser.parseOptionalInteger(constantIndex);
741 if (parsedInteger.has_value()) {
742 if (failed(parsedInteger.value()))
743 return failure();
744 constantIndices.push_back(constantIndex);
745 return success();
746 }
747
748 constantIndices.push_back(LLVM::GEPOp::kDynamicIndex);
749 return parser.parseOperand(indices.emplace_back());
750 };
751 if (parser.parseCommaSeparatedList(idxParser))
752 return failure();
753
754 rawConstantIndices =
755 DenseI32ArrayAttr::get(parser.getContext(), constantIndices);
756 return success();
757}
758
759static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
761 DenseI32ArrayAttr rawConstantIndices) {
762 llvm::interleaveComma(
763 GEPIndicesAdaptor<OperandRange>(rawConstantIndices, indices), printer,
765 if (Value val = llvm::dyn_cast_if_present<Value>(cst))
766 printer.printOperand(val);
767 else
768 printer << cast<IntegerAttr>(cst).getInt();
769 });
770}
771
772/// For the given `indices`, check if they comply with `baseGEPType`,
773/// especially check against LLVMStructTypes nested within.
774static LogicalResult
775verifyStructIndices(Type baseGEPType, unsigned indexPos,
778 if (indexPos >= indices.size())
779 // Stop searching
780 return success();
781
782 return TypeSwitch<Type, LogicalResult>(baseGEPType)
783 .Case([&](LLVMStructType structType) -> LogicalResult {
784 auto attr = dyn_cast<IntegerAttr>(indices[indexPos]);
785 if (!attr)
786 return emitOpError() << "expected index " << indexPos
787 << " indexing a struct to be constant";
788
789 int32_t gepIndex = attr.getInt();
790 ArrayRef<Type> elementTypes = structType.getBody();
791 if (gepIndex < 0 ||
792 static_cast<size_t>(gepIndex) >= elementTypes.size())
793 return emitOpError() << "index " << indexPos
794 << " indexing a struct is out of bounds";
795
796 // Instead of recursively going into every children types, we only
797 // dive into the one indexed by gepIndex.
798 return verifyStructIndices(elementTypes[gepIndex], indexPos + 1,
800 })
801 .Case<VectorType, LLVMArrayType>(
802 [&](auto containerType) -> LogicalResult {
803 return verifyStructIndices(containerType.getElementType(),
804 indexPos + 1, indices, emitOpError);
805 })
806 .Default([&](auto otherType) -> LogicalResult {
807 return emitOpError()
808 << "type " << otherType << " cannot be indexed (index #"
809 << indexPos << ")";
810 });
811}
812
813/// Driver function around `verifyStructIndices`.
814static LogicalResult
819
820LogicalResult LLVM::GEPOp::verify() {
821 if (static_cast<size_t>(
822 llvm::count(getRawConstantIndices(), kDynamicIndex)) !=
823 getDynamicIndices().size())
824 return emitOpError("expected as many dynamic indices as specified in '")
825 << getRawConstantIndicesAttrName().getValue() << "'";
826
827 if (getNoWrapFlags() == GEPNoWrapFlags::inboundsFlag)
828 return emitOpError("'inbounds_flag' cannot be used directly.");
829
830 return verifyStructIndices(getElemType(), getIndices(),
831 [&] { return emitOpError(); });
832}
833
834//===----------------------------------------------------------------------===//
835// LoadOp
836//===----------------------------------------------------------------------===//
837
838void LoadOp::getEffects(
840 &effects) {
841 effects.emplace_back(MemoryEffects::Read::get(), &getAddrMutable());
842 // Volatile operations can have target-specific read-write effects on
843 // memory besides the one referred to by the pointer operand.
844 // Similarly, atomic operations that are monotonic or stricter cause
845 // synchronization that from a language point-of-view, are arbitrary
846 // read-writes into memory.
847 if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic &&
848 getOrdering() != AtomicOrdering::unordered)) {
849 effects.emplace_back(MemoryEffects::Write::get());
850 effects.emplace_back(MemoryEffects::Read::get());
851 }
852}
853
854/// Returns true if the given type is supported by atomic operations. All
855/// integer, float, and pointer types with a power-of-two bitsize and a minimal
856/// size of 8 bits are supported.
858 const DataLayout &dataLayout) {
859 if (!isa<IntegerType, LLVMPointerType>(type))
861 return false;
862
863 llvm::TypeSize bitWidth = dataLayout.getTypeSizeInBits(type);
864 if (bitWidth.isScalable())
865 return false;
866 // Needs to be at least 8 bits and a power of two.
867 return bitWidth >= 8 && (bitWidth & (bitWidth - 1)) == 0;
868}
869
870/// Verifies the attributes and the type of atomic memory access operations.
871template <typename OpTy>
872static LogicalResult
873verifyAtomicMemOp(OpTy memOp, Type valueType,
874 ArrayRef<AtomicOrdering> unsupportedOrderings) {
875 if (memOp.getOrdering() != AtomicOrdering::not_atomic) {
876 DataLayout dataLayout = DataLayout::closest(memOp);
877 if (!isTypeCompatibleWithAtomicOp(valueType, dataLayout))
878 return memOp.emitOpError("unsupported type ")
879 << valueType << " for atomic access";
880 if (llvm::is_contained(unsupportedOrderings, memOp.getOrdering()))
881 return memOp.emitOpError("unsupported ordering '")
882 << stringifyAtomicOrdering(memOp.getOrdering()) << "'";
883 if (!memOp.getAlignment())
884 return memOp.emitOpError("expected alignment for atomic access");
885 return success();
886 }
887 if (memOp.getSyncscope())
888 return memOp.emitOpError(
889 "expected syncscope to be null for non-atomic access");
890 return success();
891}
892
893LogicalResult LoadOp::verify() {
894 Type valueType = getResult().getType();
895 return verifyAtomicMemOp(*this, valueType,
896 {AtomicOrdering::release, AtomicOrdering::acq_rel});
897}
898
899void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
900 Value addr, unsigned alignment, bool isVolatile,
901 bool isNonTemporal, bool isInvariant, bool isInvariantGroup,
902 AtomicOrdering ordering, StringRef syncscope) {
903 build(builder, state, type, addr,
904 alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
905 isNonTemporal, isInvariant, isInvariantGroup, ordering,
906 syncscope.empty() ? nullptr : builder.getStringAttr(syncscope),
907 /*dereferenceable=*/nullptr,
908 /*access_groups=*/nullptr,
909 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr,
910 /*tbaa=*/nullptr);
911}
912
913//===----------------------------------------------------------------------===//
914// StoreOp
915//===----------------------------------------------------------------------===//
916
/// Reports the memory effects of a store: always a write through the address
/// operand. Volatile stores, and atomic stores with any ordering other than
/// `unordered`, additionally get a read effect and a write effect that are not
/// tied to a specific value.
void StoreOp::getEffects(
        &effects) {
  effects.emplace_back(MemoryEffects::Write::get(), &getAddrMutable());
  // Volatile operations can have target-specific read-write effects on
  // memory besides the one referred to by the pointer operand.
  // Similarly, atomic operations that are monotonic or stricter cause
  // synchronization that, from a language point of view, amounts to
  // arbitrary reads and writes of memory.
  if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic &&
                         getOrdering() != AtomicOrdering::unordered)) {
    effects.emplace_back(MemoryEffects::Write::get());
    effects.emplace_back(MemoryEffects::Read::get());
  }
}
932
933LogicalResult StoreOp::verify() {
934 Type valueType = getValue().getType();
935 return verifyAtomicMemOp(*this, valueType,
936 {AtomicOrdering::acquire, AtomicOrdering::acq_rel});
937}
938
939void StoreOp::build(OpBuilder &builder, OperationState &state, Value value,
940 Value addr, unsigned alignment, bool isVolatile,
941 bool isNonTemporal, bool isInvariantGroup,
942 AtomicOrdering ordering, StringRef syncscope) {
943 build(builder, state, value, addr,
944 alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
945 isNonTemporal, isInvariantGroup, ordering,
946 syncscope.empty() ? nullptr : builder.getStringAttr(syncscope),
947 /*access_groups=*/nullptr,
948 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
949}
950
951//===----------------------------------------------------------------------===//
952// CallOp
953//===----------------------------------------------------------------------===//
954
955/// Gets the MLIR Op-like result types of a LLVMFunctionType.
956static SmallVector<Type, 1> getCallOpResultTypes(LLVMFunctionType calleeType) {
957 SmallVector<Type, 1> results;
958 Type resultType = calleeType.getReturnType();
959 if (!isa<LLVM::LLVMVoidType>(resultType))
960 results.push_back(resultType);
961 return results;
962}
963
964/// Gets the variadic callee type for a LLVMFunctionType.
965static TypeAttr getCallOpVarCalleeType(LLVMFunctionType calleeType) {
966 return calleeType.isVarArg() ? TypeAttr::get(calleeType) : nullptr;
967}
968
969/// Constructs a LLVMFunctionType from MLIR `results` and `args`.
970static LLVMFunctionType getLLVMFuncType(MLIRContext *context, TypeRange results,
971 ValueRange args) {
972 Type resultType;
973 if (results.empty())
974 resultType = LLVMVoidType::get(context);
975 else
976 resultType = results.front();
977 return LLVMFunctionType::get(resultType, llvm::to_vector(args.getTypes()),
978 /*isVarArg=*/false);
979}
980
981void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
982 StringRef callee, ValueRange args) {
983 build(builder, state, results, builder.getStringAttr(callee), args);
984}
985
986void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
987 StringAttr callee, ValueRange args) {
988 build(builder, state, results, SymbolRefAttr::get(callee), args);
989}
990
/// Builds a direct call to `callee` with explicit result types.
///
/// No var_callee_type is set, so this builder suits non-variadic callees only:
/// verifySymbolUses rejects vararg calls that lack one. All other optional call
/// attributes, operand bundles, and argument/result attributes are left unset.
void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
                   FlatSymbolRefAttr callee, ValueRange args) {
  assert(callee && "expected non-null callee in direct call builder");
  build(builder, state, results,
        /*var_callee_type=*/nullptr, callee, args, /*fastmathFlags=*/nullptr,
        /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
        /*memory_effects=*/nullptr,
        /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
        /*noreturn=*/nullptr, /*returns_twice=*/nullptr, /*hot=*/nullptr,
        /*cold=*/nullptr, /*noduplicate=*/nullptr,
        /*no_caller_saved_registers=*/nullptr, /*nocallback=*/nullptr,
        /*modular_format=*/nullptr, /*nobuiltins=*/nullptr,
        /*allocsize=*/nullptr, /*optsize=*/nullptr, /*minsize=*/nullptr,
        /*builtin=*/nullptr, /*nobuiltin=*/nullptr,
        /*save_reg_params=*/nullptr,
        /*zero_call_used_regs=*/nullptr, /*trap_func_name=*/nullptr,
        /*default_func_attrs=*/nullptr,
        /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
        /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
        /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
        /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
        /*no_inline=*/nullptr, /*always_inline=*/nullptr,
        /*inline_hint=*/nullptr);
}
1015
1016void CallOp::build(OpBuilder &builder, OperationState &state,
1017 LLVMFunctionType calleeType, StringRef callee,
1018 ValueRange args) {
1019 build(builder, state, calleeType, builder.getStringAttr(callee), args);
1020}
1021
1022void CallOp::build(OpBuilder &builder, OperationState &state,
1023 LLVMFunctionType calleeType, StringAttr callee,
1024 ValueRange args) {
1025 build(builder, state, calleeType, SymbolRefAttr::get(callee), args);
1026}
1027
/// Builds a direct call to `callee`.
///
/// The result types come from the return type of `calleeType` (none for
/// void). `calleeType` itself is recorded as var_callee_type only when it is
/// variadic. All other optional call attributes, operand bundles, and
/// argument/result attributes are left unset.
void CallOp::build(OpBuilder &builder, OperationState &state,
                   LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
                   ValueRange args) {
  build(builder, state, getCallOpResultTypes(calleeType),
        getCallOpVarCalleeType(calleeType), callee, args,
        /*fastmathFlags=*/nullptr,
        /*CConv=*/nullptr,
        /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
        /*convergent=*/nullptr,
        /*no_unwind=*/nullptr, /*will_return=*/nullptr,
        /*noreturn=*/nullptr,
        /*returns_twice=*/nullptr, /*hot=*/nullptr,
        /*cold=*/nullptr, /*noduplicate=*/nullptr,
        /*no_caller_saved_registers=*/nullptr, /*nocallback=*/nullptr,
        /*modular_format=*/nullptr, /*nobuiltins=*/nullptr,
        /*allocsize=*/nullptr, /*optsize=*/nullptr, /*minsize=*/nullptr,
        /*builtin=*/nullptr, /*nobuiltin=*/nullptr,
        /*save_reg_params=*/nullptr,
        /*zero_call_used_regs=*/nullptr, /*trap_func_name=*/nullptr,
        /*default_func_attrs=*/nullptr,
        /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
        /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
        /*access_groups=*/nullptr,
        /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
        /*no_inline=*/nullptr, /*always_inline=*/nullptr,
        /*inline_hint=*/nullptr);
}
1055
/// Builds an indirect call: no callee symbol is set.
///
/// `args` is used verbatim as the callee operands. For indirect calls the
/// callee value is the first of these operands (see getCallableForCallee).
/// Result types and var_callee_type are derived from `calleeType`, and all
/// other optional attributes are left unset.
void CallOp::build(OpBuilder &builder, OperationState &state,
                   LLVMFunctionType calleeType, ValueRange args) {
  build(builder, state, getCallOpResultTypes(calleeType),
        getCallOpVarCalleeType(calleeType),
        /*callee=*/nullptr, args,
        /*fastmathFlags=*/nullptr,
        /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
        /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
        /*noreturn=*/nullptr,
        /*returns_twice=*/nullptr, /*hot=*/nullptr,
        /*cold=*/nullptr, /*noduplicate=*/nullptr,
        /*no_caller_saved_registers=*/nullptr, /*nocallback=*/nullptr,
        /*modular_format=*/nullptr, /*nobuiltins=*/nullptr,
        /*allocsize=*/nullptr, /*optsize=*/nullptr, /*minsize=*/nullptr,
        /*builtin=*/nullptr, /*nobuiltin=*/nullptr,
        /*save_reg_params=*/nullptr,
        /*zero_call_used_regs=*/nullptr, /*trap_func_name=*/nullptr,
        /*default_func_attrs=*/nullptr,
        /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
        /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
        /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
        /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
        /*no_inline=*/nullptr, /*always_inline=*/nullptr,
        /*inline_hint=*/nullptr);
}
1081
1082void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
1083 ValueRange args) {
1084 auto calleeType = func.getFunctionType();
1085 build(builder, state, getCallOpResultTypes(calleeType),
1086 getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), args,
1087 /*fastmathFlags=*/nullptr,
1088 /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
1089 /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
1090 /*noreturn=*/nullptr,
1091 /*returns_twice=*/nullptr, /*hot=*/nullptr,
1092 /*cold=*/nullptr, /*noduplicate=*/nullptr,
1093 /*no_caller_saved_registers=*/nullptr, /*nocallback=*/nullptr,
1094 /*modular_format=*/nullptr, /*nobuiltins=*/nullptr,
1095 /*allocsize=*/nullptr, /*optsize=*/nullptr, /*minsize=*/nullptr,
1096 /*builtin=*/nullptr, /*nobuiltin=*/nullptr,
1097 /*save_reg_params=*/nullptr,
1098 /*zero_call_used_regs=*/nullptr, /*trap_func_name=*/nullptr,
1099 /*default_func_attrs=*/nullptr,
1100 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
1101 /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
1102 /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
1103 /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
1104 /*no_inline=*/nullptr, /*always_inline=*/nullptr,
1105 /*inline_hint=*/nullptr);
1106}
1107
1108CallInterfaceCallable CallOp::getCallableForCallee() {
1109 // Direct call.
1110 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr())
1111 return calleeAttr;
1112 // Indirect call, callee Value is the first operand.
1113 return getOperand(0);
1114}
1115
1116void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1117 // Direct call.
1118 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) {
1119 auto symRef = cast<SymbolRefAttr>(callee);
1120 return setCalleeAttr(cast<FlatSymbolRefAttr>(symRef));
1121 }
1122 // Indirect call, callee Value is the first operand.
1123 return setOperand(0, cast<Value>(callee));
1124}
1125
1126Operation::operand_range CallOp::getArgOperands() {
1127 return getCalleeOperands().drop_front(getCallee().has_value() ? 0 : 1);
1128}
1129
1130MutableOperandRange CallOp::getArgOperandsMutable() {
1131 return MutableOperandRange(*this, getCallee().has_value() ? 0 : 1,
1132 getCalleeOperands().size());
1133}
1134
1135/// Verify that an inlinable callsite of a debug-info-bearing function in a
1136/// debug-info-bearing function has a debug location attached to it. This
1137/// mirrors an LLVM IR verifier.
1138static LogicalResult verifyCallOpDebugInfo(CallOp callOp, LLVMFuncOp callee) {
1139 if (callee.isExternal())
1140 return success();
1141 auto parentFunc = callOp->getParentOfType<FunctionOpInterface>();
1142 if (!parentFunc)
1143 return success();
1144
1145 auto hasSubprogram = [](Operation *op) {
1146 return op->getLoc()
1147 ->findInstanceOf<FusedLocWith<LLVM::DISubprogramAttr>>() !=
1148 nullptr;
1149 };
1150 if (!hasSubprogram(parentFunc) || !hasSubprogram(callee))
1151 return success();
1152 bool containsLoc = !isa<UnknownLoc>(callOp->getLoc());
1153 if (!containsLoc)
1154 return callOp.emitError()
1155 << "inlinable function call in a function with a DISubprogram "
1156 "location must have a debug location";
1157 return success();
1158}
1159
1160/// Verify that the parameter and return types of the variadic callee type match
1161/// the `callOp` argument and result types.
1162template <typename OpTy>
1163static LogicalResult verifyCallOpVarCalleeType(OpTy callOp) {
1164 std::optional<LLVMFunctionType> varCalleeType = callOp.getVarCalleeType();
1165 if (!varCalleeType)
1166 return success();
1167
1168 // Verify the variadic callee type is a variadic function type.
1169 if (!varCalleeType->isVarArg())
1170 return callOp.emitOpError(
1171 "expected var_callee_type to be a variadic function type");
1172
1173 // Verify the variadic callee type has at most as many parameters as the call
1174 // has argument operands.
1175 if (varCalleeType->getNumParams() > callOp.getArgOperands().size())
1176 return callOp.emitOpError("expected var_callee_type to have at most ")
1177 << callOp.getArgOperands().size() << " parameters";
1178
1179 // Verify the variadic callee type matches the call argument types.
1180 for (auto [paramType, operand] :
1181 llvm::zip(varCalleeType->getParams(), callOp.getArgOperands()))
1182 if (paramType != operand.getType())
1183 return callOp.emitOpError()
1184 << "var_callee_type parameter type mismatch: " << paramType
1185 << " != " << operand.getType();
1186
1187 // Verify the variadic callee type matches the call result type.
1188 if (!callOp.getNumResults()) {
1189 if (!isa<LLVMVoidType>(varCalleeType->getReturnType()))
1190 return callOp.emitOpError("expected var_callee_type to return void");
1191 } else {
1192 if (callOp.getResult().getType() != varCalleeType->getReturnType())
1193 return callOp.emitOpError("var_callee_type return type mismatch: ")
1194 << varCalleeType->getReturnType()
1195 << " != " << callOp.getResult().getType();
1196 }
1197 return success();
1198}
1199
1200template <typename OpType>
1201static LogicalResult verifyOperandBundles(OpType &op) {
1202 OperandRangeRange opBundleOperands = op.getOpBundleOperands();
1203 std::optional<ArrayAttr> opBundleTags = op.getOpBundleTags();
1204
1205 auto isStringAttr = [](Attribute tagAttr) {
1206 return isa<StringAttr>(tagAttr);
1207 };
1208 if (opBundleTags && !llvm::all_of(*opBundleTags, isStringAttr))
1209 return op.emitError("operand bundle tag must be a StringAttr");
1210
1211 size_t numOpBundles = opBundleOperands.size();
1212 size_t numOpBundleTags = opBundleTags ? opBundleTags->size() : 0;
1213 if (numOpBundles != numOpBundleTags)
1214 return op.emitError("expected ")
1215 << numOpBundles << " operand bundle tags, but actually got "
1216 << numOpBundleTags;
1217
1218 return success();
1219}
1220
1221LogicalResult CallOp::verify() { return verifyOperandBundles(*this); }
1222
/// Verifies the call against its callee.
///
/// A direct call must reference an LLVMFuncOp, IFuncOp, or AliasOp, and the
/// call's operand and result types must match that callee's function type.
/// An indirect call only needs a pointer-typed first operand.
LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
    return failure();

  // Type for the callee, we'll get it differently depending if it is a direct
  // or indirect call.
  Type fnType;

  bool isIndirect = false;

  // If this is an indirect call, the callee attribute is missing.
  FlatSymbolRefAttr calleeName = getCalleeAttr();
  if (!calleeName) {
    isIndirect = true;
    if (!getNumOperands())
      return emitOpError(
          "must have either a `callee` attribute or at least an operand");
    auto ptrType = llvm::dyn_cast<LLVMPointerType>(getOperand(0).getType());
    if (!ptrType)
      return emitOpError("indirect call expects a pointer as callee: ")
             << getOperand(0).getType();

    // Indirect calls return here, so none of the type checks below run for
    // them.
    // NOTE(review): as a consequence `isIndirect` is always false past this
    // if/else, and the `- isIndirect` / `+ isIndirect` adjustments below are
    // currently no-ops.
    return success();
  } else {
    Operation *callee =
        symbolTable.lookupNearestSymbolFrom(*this, calleeName.getAttr());
    if (!callee)
      return emitOpError()
             << "'" << calleeName.getValue()
             << "' does not reference a symbol in the current scope";
    if (auto fn = dyn_cast<LLVMFuncOp>(callee)) {
      if (failed(verifyCallOpDebugInfo(*this, fn)))
        return failure();
      fnType = fn.getFunctionType();
    } else if (auto ifunc = dyn_cast<IFuncOp>(callee)) {
      fnType = ifunc.getIFuncType();
    } else if (isa<AliasOp>(callee)) {
      // Aliases can alias functions, so calling through an alias is valid.
      // The function type is determined by the call's operands and result
      // types.
      fnType = getCalleeFunctionType();
    } else {
      return emitOpError()
             << "'" << calleeName.getValue()
             << "' does not reference a valid LLVM function, IFunc, or alias";
    }
  }

  LLVMFunctionType funcType = llvm::dyn_cast<LLVMFunctionType>(fnType);
  if (!funcType)
    return emitOpError("callee does not have a functional type: ") << fnType;

  // Variadic callees require the explicit var_callee_type attribute.
  if (funcType.isVarArg() && !getVarCalleeType())
    return emitOpError() << "missing var_callee_type attribute for vararg call";

  // Verify that the operand and result types match the callee.

  // Non-variadic callees need an exact argument count...
  if (!funcType.isVarArg() &&
      funcType.getNumParams() != (getCalleeOperands().size() - isIndirect))
    return emitOpError() << "incorrect number of operands ("
                         << (getCalleeOperands().size() - isIndirect)
                         << ") for callee (expecting: "
                         << funcType.getNumParams() << ")";

  // ...while variadic callees need at least the fixed parameters.
  if (funcType.getNumParams() > (getCalleeOperands().size() - isIndirect))
    return emitOpError() << "incorrect number of operands ("
                         << (getCalleeOperands().size() - isIndirect)
                         << ") for varargs callee (expecting at least: "
                         << funcType.getNumParams() << ")";

  // The fixed parameters must match the leading operands type by type.
  for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i)
    if (getOperand(i + isIndirect).getType() != funcType.getParamType(i))
      return emitOpError() << "operand type mismatch for operand " << i << ": "
                           << getOperand(i + isIndirect).getType()
                           << " != " << funcType.getParamType(i);

  if (getNumResults() == 0 &&
      !llvm::isa<LLVM::LLVMVoidType>(funcType.getReturnType()))
    return emitOpError() << "expected function call to produce a value";

  if (getNumResults() != 0 &&
      llvm::isa<LLVM::LLVMVoidType>(funcType.getReturnType()))
    return emitOpError()
           << "calling function with void result must not produce values";

  if (getNumResults() > 1)
    return emitOpError()
           << "expected LLVM function call to produce 0 or 1 result";

  if (getNumResults() && getResult().getType() != funcType.getReturnType())
    return emitOpError() << "result type mismatch: " << getResult().getType()
                         << " != " << funcType.getReturnType();

  return success();
}
1318
/// Prints the custom assembly form that CallOp::parse reads back.
void CallOp::print(OpAsmPrinter &p) {
  auto callee = getCallee();
  bool isDirect = callee.has_value();

  p << ' ';

  // Print calling convention.
  if (getCConv() != LLVM::CConv::C)
    p << stringifyCConv(getCConv()) << ' ';

  // Print the tail call kind unless it is the default `None`.
  if (getTailCallKind() != LLVM::TailCallKind::None)
    p << tailcallkind::stringifyTailCallKind(getTailCallKind()) << ' ';

  // Print the direct callee if present as a function attribute, or an indirect
  // callee (first operand) otherwise.
  if (isDirect)
    p.printSymbolName(callee.value());
  else
    p << getOperand(0);

  // The argument list excludes the function pointer of an indirect call.
  auto args = getCalleeOperands().drop_front(isDirect ? 0 : 1);
  p << '(' << args << ')';

  // Print the variadic callee type if the call is variadic.
  if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
    p << " vararg(" << *varCalleeType << ")";

  if (!getOpBundleOperands().empty()) {
    p << " ";
    printOpBundles(p, *this, getOpBundleOperands(),
                   getOpBundleOperands().getTypes(), getOpBundleTags());
  }

  // Attributes already expressed by the custom syntax are elided from the
  // dictionary. processFMFAttr additionally drops a default fastmath flags
  // attribute.
  p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
                          {getCalleeAttrName(), getTailCallKindAttrName(),
                           getVarCalleeTypeAttrName(), getCConvAttrName(),
                           getOperandSegmentSizesAttrName(),
                           getOpBundleSizesAttrName(),
                           getOpBundleTagsAttrName(), getArgAttrsAttrName(),
                           getResAttrsAttrName()});

  p << " : ";
  // Indirect calls additionally print the type of the function pointer.
  if (!isDirect)
    p << getOperand(0).getType() << ", ";

  // Reconstruct the MLIR function type from operand and result types.
      p, args.getTypes(), getArgAttrsAttr(),
      /*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
}
1369
/// Parses the type of a call operation and resolves the operands if the parsing
/// succeeds. Returns failure otherwise.
///
/// Expected form: `:` (type `,`)? function-type. The leading type is present
/// only for indirect calls (`isDirect == false`). At most one result type is
/// accepted, and it must not be void.
    OpAsmParser &parser, OperationState &result, bool isDirect,
    SmallVectorImpl<DictionaryAttr> &resultAttrs) {
  SMLoc trailingTypesLoc = parser.getCurrentLocation();
  SmallVector<Type> types;
  if (parser.parseColon())
    return failure();
  // Indirect calls start with the type of the callee operand.
  if (!isDirect) {
    types.emplace_back();
    if (parser.parseType(types.back()))
      return failure();
    if (parser.parseOptionalComma())
      return parser.emitError(
          trailingTypesLoc, "expected indirect call to have 2 trailing types");
  }
  SmallVector<Type> argTypes;
  SmallVector<Type> resTypes;
  if (call_interface_impl::parseFunctionSignature(parser, argTypes, argAttrs,
                                                  resTypes, resultAttrs)) {
    if (isDirect)
      return parser.emitError(trailingTypesLoc,
                              "expected direct call to have 1 trailing types");
    return parser.emitError(trailingTypesLoc,
                            "expected trailing function type");
  }

  if (resTypes.size() > 1)
    return parser.emitError(trailingTypesLoc,
                            "expected function with 0 or 1 result");
  if (resTypes.size() == 1 && llvm::isa<LLVM::LLVMVoidType>(resTypes[0]))
    return parser.emitError(trailingTypesLoc,
                            "expected a non-void result type");

  // The head element of the types list matches the callee type for
  // indirect calls, while the types list is empty for direct calls.
  // Append the function input types to resolve the call operation
  // operands.
  llvm::append_range(types, argTypes);
  if (parser.resolveOperands(operands, types, parser.getNameLoc(),
                             result.operands))
    return failure();
  if (!resTypes.empty())
    result.addTypes(resTypes);

  return success();
}
1420
/// Parses an optional function pointer operand before the call argument list
/// for indirect calls, or stops parsing at the function identifier otherwise.
/// On success the pointer, if any, is appended to `operands`, so callers can
/// tell a direct call by `operands` staying empty.
static ParseResult parseOptionalCallFuncPtr(
    OpAsmParser &parser,
  OpAsmParser::UnresolvedOperand funcPtrOperand;
  OptionalParseResult parseResult = parser.parseOptionalOperand(funcPtrOperand);
  // No value means no operand was present (a direct call); only a present but
  // malformed operand is an error.
  if (parseResult.has_value()) {
    if (failed(*parseResult))
      return *parseResult;
    operands.push_back(funcPtrOperand);
  }
  return success();
}
1435
/// Resolves the parsed operand bundle operands against their parsed types and
/// appends them to `state.operands`.
///
/// Also records the per-bundle operand counts as a DenseI32ArrayAttr named
/// `opBundleSizesAttrName`. Count mismatches are reported at `loc`.
static ParseResult resolveOpBundleOperands(
    OpAsmParser &parser, SMLoc loc, OperationState &state,
    ArrayRef<SmallVector<Type>> opBundleOperandTypes,
    StringAttr opBundleSizesAttrName) {
  // NOTE(review): `opBundleIndex` is never incremented in the loop below, so
  // the diagnostic always names bundle #0 — confirm and bump it per bundle.
  unsigned opBundleIndex = 0;
  for (const auto &[operands, types] :
       llvm::zip_equal(opBundleOperands, opBundleOperandTypes)) {
    if (operands.size() != types.size())
      return parser.emitError(loc, "expected ")
             << operands.size()
             << " types for operand bundle operands for operand bundle #"
             << opBundleIndex << ", but actually got " << types.size();
    if (parser.resolveOperands(operands, types, loc, state.operands))
      return failure();
  }

  // One size entry per bundle, in parse order.
  SmallVector<int32_t> opBundleSizes;
  opBundleSizes.reserve(opBundleOperands.size());
  for (const auto &operands : opBundleOperands)
    opBundleSizes.push_back(operands.size());

  state.addAttribute(
      opBundleSizesAttrName,
      DenseI32ArrayAttr::get(parser.getContext(), opBundleSizes));

  return success();
}
1464
// <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use)
//                 `(` ssa-use-list `)`
//                 ( `vararg(` var-callee-type `)` )?
//                 ( `[` op-bundles-list `]` )?
//                 attribute-dict? `:` (type `,`)? function-type
ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
  SymbolRefAttr funcAttr;
  TypeAttr varCalleeType;
  SmallVector<SmallVector<Type>> opBundleOperandTypes;
  ArrayAttr opBundleTags;

  // Default to C Calling Convention if no keyword is provided.
  result.addAttribute(
      getCConvAttrName(result.name),
      CConvAttr::get(parser.getContext(),
                     parseOptionalLLVMKeyword<CConv>(parser, LLVM::CConv::C)));

  // Likewise, default to the `None` tail call kind if no keyword is provided.
  result.addAttribute(
      getTailCallKindAttrName(result.name),
      TailCallKindAttr::get(parser.getContext(),
                                parser, LLVM::TailCallKind::None)));

  // Parse a function pointer for indirect calls.
  if (parseOptionalCallFuncPtr(parser, operands))
    return failure();
  // No leading function pointer means this is a direct call.
  bool isDirect = operands.empty();

  // Parse a function identifier for direct calls.
  if (isDirect)
    if (parser.parseAttribute(funcAttr, "callee", result.attributes))
      return failure();

  // Parse the function arguments.
  if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren))
    return failure();

  bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
  if (isVarArg) {
    StringAttr varCalleeTypeAttrName =
        CallOp::getVarCalleeTypeAttrName(result.name);
    if (parser.parseLParen().failed() ||
        parser
            .parseAttribute(varCalleeType, varCalleeTypeAttrName,
                            result.attributes)
            .failed() ||
        parser.parseRParen().failed())
      return failure();
  }

  // Remember where the bundles start; errors found while resolving the bundle
  // operands below are reported at this location.
  SMLoc opBundlesLoc = parser.getCurrentLocation();
  // NOTE: this `result` shadows the OperationState parameter inside the
  // if-statement only.
  if (std::optional<ParseResult> result = parseOpBundles(
          parser, opBundleOperands, opBundleOperandTypes, opBundleTags);
      result && failed(*result))
    return failure();
  // Tags are only attached when at least one was parsed.
  if (opBundleTags && !opBundleTags.empty())
    result.addAttribute(CallOp::getOpBundleTagsAttrName(result.name).getValue(),
                        opBundleTags);

  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // Parse the trailing type list and resolve the operands.
  SmallVector<DictionaryAttr> resultAttrs;
  if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
                                      argAttrs, resultAttrs))
    return failure();
      parser.getBuilder(), result, argAttrs, resultAttrs,
      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
  if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
                              opBundleOperandTypes,
                              getOpBundleSizesAttrName(result.name)))
    return failure();

  // The segment sizes record how many operands are callee operands and how
  // many belong to operand bundles.
  int32_t numOpBundleOperands = 0;
  for (const auto &operands : opBundleOperands)
    numOpBundleOperands += operands.size();

  result.addAttribute(
      CallOp::getOperandSegmentSizeAttr(),
          {static_cast<int32_t>(operands.size()), numOpBundleOperands}));
  return success();
}
1553
1554LLVMFunctionType CallOp::getCalleeFunctionType() {
1555 if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1556 return *varCalleeType;
1557 return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
1558}
1559
//===----------------------------------------------------------------------===//
// LLVM::InvokeOp
//===----------------------------------------------------------------------===//
1563
1564void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
1565 ValueRange ops, Block *normal, ValueRange normalOps,
1566 Block *unwind, ValueRange unwindOps) {
1567 auto calleeType = func.getFunctionType();
1568 build(builder, state, getCallOpResultTypes(calleeType),
1569 getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), ops,
1570 /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, normalOps, unwindOps,
1571 nullptr, nullptr, {}, {}, normal, unwind);
1572}
1573
/// Builds a direct invoke of `callee` with explicit result types `tys`.
///
/// No var_callee_type is set, and the argument/result attributes are left
/// unset. Control continues at `normal` (forwarding `normalOps`) or at
/// `unwind` (forwarding `unwindOps`).
void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
                     FlatSymbolRefAttr callee, ValueRange ops, Block *normal,
                     ValueRange normalOps, Block *unwind,
                     ValueRange unwindOps) {
  // NOTE(review): the unlabeled `nullptr, nullptr, {}, {}` arguments are
  // presumably the remaining optional attributes and empty operand bundles —
  // confirm against the ODS definition and label them.
  build(builder, state, tys,
        /*var_callee_type=*/nullptr, callee, ops, /*arg_attrs=*/nullptr,
        /*res_attrs=*/nullptr, normalOps, unwindOps, nullptr, nullptr, {}, {},
        normal, unwind);
}
1583
1584void InvokeOp::build(OpBuilder &builder, OperationState &state,
1585 LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
1586 ValueRange ops, Block *normal, ValueRange normalOps,
1587 Block *unwind, ValueRange unwindOps) {
1588 build(builder, state, getCallOpResultTypes(calleeType),
1589 getCallOpVarCalleeType(calleeType), callee, ops,
1590 /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, normalOps, unwindOps,
1591 nullptr, nullptr, {}, {}, normal, unwind);
1592}
1593
1594SuccessorOperands InvokeOp::getSuccessorOperands(unsigned index) {
1595 assert(index < getNumSuccessors() && "invalid successor index");
1596 return SuccessorOperands(index == 0 ? getNormalDestOperandsMutable()
1597 : getUnwindDestOperandsMutable());
1598}
1599
1600CallInterfaceCallable InvokeOp::getCallableForCallee() {
1601 // Direct call.
1602 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr())
1603 return calleeAttr;
1604 // Indirect call, callee Value is the first operand.
1605 return getOperand(0);
1606}
1607
1608void InvokeOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1609 // Direct call.
1610 if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) {
1611 auto symRef = cast<SymbolRefAttr>(callee);
1612 return setCalleeAttr(cast<FlatSymbolRefAttr>(symRef));
1613 }
1614 // Indirect call, callee Value is the first operand.
1615 return setOperand(0, cast<Value>(callee));
1616}
1617
1618Operation::operand_range InvokeOp::getArgOperands() {
1619 return getCalleeOperands().drop_front(getCallee().has_value() ? 0 : 1);
1620}
1621
1622MutableOperandRange InvokeOp::getArgOperandsMutable() {
1623 return MutableOperandRange(*this, getCallee().has_value() ? 0 : 1,
1624 getCalleeOperands().size());
1625}
1626
/// Verifies that the unwind destination is non-empty and starts with an
/// llvm.landingpad, and that the operand bundles are well formed.
LogicalResult InvokeOp::verify() {
    return failure();

  Block *unwindDest = getUnwindDest();
  if (unwindDest->empty())
    return emitError("must have at least one operation in unwind destination");

  // In unwind destination, first operation must be LandingpadOp
  if (!isa<LandingpadOp>(unwindDest->front()))
    return emitError("first operation in unwind destination should be a "
                     "llvm.landingpad operation");

  if (failed(verifyOperandBundles(*this)))
    return failure();

  return success();
}
1645
/// Prints the custom assembly form of llvm.invoke, including both successors
/// and their forwarded operands.
void InvokeOp::print(OpAsmPrinter &p) {
  auto callee = getCallee();
  bool isDirect = callee.has_value();

  p << ' ';

  // Print calling convention.
  if (getCConv() != LLVM::CConv::C)
    p << stringifyCConv(getCConv()) << ' ';

  // Either function name or pointer
  if (isDirect)
    p.printSymbolName(callee.value());
  else
    p << getOperand(0);

  // Arguments (without the function pointer), then both successors.
  p << '(' << getCalleeOperands().drop_front(isDirect ? 0 : 1) << ')';
  p << " to ";
  p.printSuccessorAndUseList(getNormalDest(), getNormalDestOperands());
  p << " unwind ";
  p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands());

  // Print the variadic callee type if the invoke is variadic.
  if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
    p << " vararg(" << *varCalleeType << ")";

  if (!getOpBundleOperands().empty()) {
    p << " ";
    printOpBundles(p, *this, getOpBundleOperands(),
                   getOpBundleOperands().getTypes(), getOpBundleTags());
  }

  // Attributes already expressed by the custom syntax are elided from the
  // dictionary.
  p.printOptionalAttrDict((*this)->getAttrs(),
                          {getCalleeAttrName(), getOperandSegmentSizeAttr(),
                           getCConvAttrName(), getVarCalleeTypeAttrName(),
                           getOpBundleSizesAttrName(),
                           getOpBundleTagsAttrName(), getArgAttrsAttrName(),
                           getResAttrsAttrName()});

  p << " : ";
  // Indirect invokes additionally print the type of the function pointer.
  if (!isDirect)
    p << getOperand(0).getType() << ", ";
      p, getCalleeOperands().drop_front(isDirect ? 0 : 1).getTypes(),
      getArgAttrsAttr(),
      /*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
}
1693
// Parses the custom assembly form of `llvm.invoke`:
//
// <operation> ::= `llvm.invoke` (cconv)? (function-id | ssa-use)
//                  `(` ssa-use-list `)`
//                  `to` bb-id (`[` ssa-use-and-type-list `]`)?
//                  `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
//                  ( `vararg(` var-callee-type `)` )?
//                  ( `[` op-bundles-list `]` )?
//                  attribute-dict? `:` (type `,`)?
//                  function-type-with-argument-attributes
ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
  SymbolRefAttr funcAttr;
  TypeAttr varCalleeType;
  SmallVector<SmallVector<Type>> opBundleOperandTypes;
  ArrayAttr opBundleTags;
  Block *normalDest, *unwindDest;
  SmallVector<Value, 4> normalOperands, unwindOperands;
  Builder &builder = parser.getBuilder();

  // Default to C Calling Convention if no keyword is provided.
  result.addAttribute(
      getCConvAttrName(result.name),
      CConvAttr::get(parser.getContext(),
                     parseOptionalLLVMKeyword<CConv>(parser, LLVM::CConv::C)));

  // Parse a function pointer for indirect calls. Any parsed callee pointer is
  // placed in `operands`, so an empty list afterwards means a direct call.
  if (parseOptionalCallFuncPtr(parser, operands))
    return failure();
  bool isDirect = operands.empty();

  // Parse a function identifier for direct calls.
  if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes))
    return failure();

  // Parse the function arguments.
  if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
      parser.parseKeyword("to") ||
      parser.parseSuccessorAndUseList(normalDest, normalOperands) ||
      parser.parseKeyword("unwind") ||
      parser.parseSuccessorAndUseList(unwindDest, unwindOperands))
    return failure();

  // Optional `vararg(<function type>)` records the variadic callee type.
  bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
  if (isVarArg) {
    StringAttr varCalleeTypeAttrName =
        InvokeOp::getVarCalleeTypeAttrName(result.name);
    if (parser.parseLParen().failed() ||
        parser
            .parseAttribute(varCalleeType, varCalleeTypeAttrName,
                            result.attributes)
            .failed() ||
        parser.parseRParen().failed())
      return failure();
  }

  // Remember where the bundles start so operand-resolution errors can point
  // back here once the operand types are known.
  // NOTE: the `result` declared in this if-statement shadows the
  // OperationState parameter for the duration of the condition.
  SMLoc opBundlesLoc = parser.getCurrentLocation();
  if (std::optional<ParseResult> result = parseOpBundles(
          parser, opBundleOperands, opBundleOperandTypes, opBundleTags);
      result && failed(*result))
    return failure();
  if (opBundleTags && !opBundleTags.empty())
    result.addAttribute(
        InvokeOp::getOpBundleTagsAttrName(result.name).getValue(),
        opBundleTags);

  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // Parse the trailing type list and resolve the function operands.
  SmallVector<DictionaryAttr> resultAttrs;
  if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
                                      argAttrs, resultAttrs))
    return failure();
      parser.getBuilder(), result, argAttrs, resultAttrs,
      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));

  if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
                              opBundleOperandTypes,
                              getOpBundleSizesAttrName(result.name)))
    return failure();

  result.addSuccessors({normalDest, unwindDest});
  result.addOperands(normalOperands);
  result.addOperands(unwindOperands);

  // NOTE: the loop variable `operands` shadows the callee operand list above;
  // the outer list is used again in the segment sizes below.
  int32_t numOpBundleOperands = 0;
  for (const auto &operands : opBundleOperands)
    numOpBundleOperands += operands.size();

  // Operand segments, in order: callee operands (including an indirect callee
  // pointer), normal-destination operands, unwind-destination operands, and
  // all op-bundle operands flattened together.
  result.addAttribute(
      InvokeOp::getOperandSegmentSizeAttr(),
      builder.getDenseI32ArrayAttr({static_cast<int32_t>(operands.size()),
                                    static_cast<int32_t>(normalOperands.size()),
                                    static_cast<int32_t>(unwindOperands.size()),
                                    numOpBundleOperands}));
  return success();
}
1793
1794LLVMFunctionType InvokeOp::getCalleeFunctionType() {
1795 if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1796 return *varCalleeType;
1797 return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
1798}
1799
//===----------------------------------------------------------------------===//
// Verifying/Printing/Parsing for LLVM::LandingpadOp.
//===----------------------------------------------------------------------===//
1803
1804LogicalResult LandingpadOp::verify() {
1805 Value value;
1806 if (LLVMFuncOp func = (*this)->getParentOfType<LLVMFuncOp>()) {
1807 if (!func.getPersonality())
1808 return emitError(
1809 "llvm.landingpad needs to be in a function with a personality");
1810 }
1811
1812 // Consistency of llvm.landingpad result types is checked in
1813 // LLVMFuncOp::verify().
1814
1815 if (!getCleanup() && getOperands().empty())
1816 return emitError("landingpad instruction expects at least one clause or "
1817 "cleanup attribute");
1818
1819 for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) {
1820 value = getOperand(idx);
1821 bool isFilter = llvm::isa<LLVMArrayType>(value.getType());
1822 if (isFilter) {
1823 // FIXME: Verify filter clauses when arrays are appropriately handled
1824 } else {
1825 // catch - global addresses only.
1826 // Bitcast ops should have global addresses as their args.
1827 if (auto bcOp = value.getDefiningOp<BitcastOp>()) {
1828 if (auto addrOp = bcOp.getArg().getDefiningOp<AddressOfOp>())
1829 continue;
1830 return emitError("constant clauses expected").attachNote(bcOp.getLoc())
1831 << "global addresses expected as operand to "
1832 "bitcast used in clauses for landingpad";
1833 }
1834 // ZeroOp and AddressOfOp allowed
1835 if (value.getDefiningOp<ZeroOp>())
1836 continue;
1837 if (value.getDefiningOp<AddressOfOp>())
1838 continue;
1839 return emitError("clause #")
1840 << idx << " is not a known constant - null, addressof, bitcast";
1841 }
1842 }
1843 return success();
1844}
1845
1846void LandingpadOp::print(OpAsmPrinter &p) {
1847 p << (getCleanup() ? " cleanup " : " ");
1848
1849 // Clauses
1850 for (auto value : getOperands()) {
1851 // Similar to llvm - if clause is an array type then it is filter
1852 // clause else catch clause
1853 bool isArrayTy = llvm::isa<LLVMArrayType>(value.getType());
1854 p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : "
1855 << value.getType() << ") ";
1856 }
1857
1858 p.printOptionalAttrDict((*this)->getAttrs(), {"cleanup"});
1859
1860 p << ": " << getType();
1861}
1862
1863// <operation> ::= `llvm.landingpad` `cleanup`?
1864// ((`catch` | `filter`) operand-type ssa-use)* attribute-dict?
1865ParseResult LandingpadOp::parse(OpAsmParser &parser, OperationState &result) {
1866 // Check for cleanup
1867 if (succeeded(parser.parseOptionalKeyword("cleanup")))
1868 result.addAttribute("cleanup", parser.getBuilder().getUnitAttr());
1869
1870 // Parse clauses with types
1871 while (succeeded(parser.parseOptionalLParen()) &&
1872 (succeeded(parser.parseOptionalKeyword("filter")) ||
1873 succeeded(parser.parseOptionalKeyword("catch")))) {
1875 Type ty;
1876 if (parser.parseOperand(operand) || parser.parseColon() ||
1877 parser.parseType(ty) ||
1878 parser.resolveOperand(operand, ty, result.operands) ||
1879 parser.parseRParen())
1880 return failure();
1881 }
1882
1883 Type type;
1884 if (parser.parseColon() || parser.parseType(type))
1885 return failure();
1886
1887 result.addTypes(type);
1888 return success();
1889}
1890
1891//===----------------------------------------------------------------------===//
1892// ExtractValueOp
1893//===----------------------------------------------------------------------===//
1894
1895/// Extract the type at `position` in the LLVM IR aggregate type
1896/// `containerType`. Each element of `position` is an index into a nested
1897/// aggregate type. Return the resulting type or emit an error.
1899 function_ref<InFlightDiagnostic(StringRef)> emitError, Type containerType,
1900 ArrayRef<int64_t> position) {
1901 Type llvmType = containerType;
1902 if (!isCompatibleType(containerType)) {
1903 emitError("expected LLVM IR Dialect type, got ") << containerType;
1904 return {};
1905 }
1906
1907 // Infer the element type from the structure type: iteratively step inside the
1908 // type by taking the element type, indexed by the position attribute for
1909 // structures. Check the position index before accessing, it is supposed to
1910 // be in bounds.
1911 for (int64_t idx : position) {
1912 if (auto arrayType = llvm::dyn_cast<LLVMArrayType>(llvmType)) {
1913 if (idx < 0 || static_cast<unsigned>(idx) >= arrayType.getNumElements()) {
1914 emitError("position out of bounds: ") << idx;
1915 return {};
1916 }
1917 llvmType = arrayType.getElementType();
1918 } else if (auto structType = llvm::dyn_cast<LLVMStructType>(llvmType)) {
1919 if (idx < 0 ||
1920 static_cast<unsigned>(idx) >= structType.getBody().size()) {
1921 emitError("position out of bounds: ") << idx;
1922 return {};
1923 }
1924 llvmType = structType.getBody()[idx];
1925 } else {
1926 emitError("expected LLVM IR structure/array type, got: ") << llvmType;
1927 return {};
1928 }
1929 }
1930 return llvmType;
1931}
1932
/// Extract the type at `position` in the LLVM IR aggregate type `llvmType`.
/// Unlike the overload taking an `emitError` callback, this variant performs
/// no validation. It assumes every intermediate type is an LLVM struct or
/// array (`llvm::cast` is used for the array case). Struct indices are used
/// without a range check, and array indices are not inspected at all, because
/// only the element type is needed.
    ArrayRef<int64_t> position) {
  for (int64_t idx : position) {
    if (auto structType = llvm::dyn_cast<LLVMStructType>(llvmType))
      llvmType = structType.getBody()[idx];
    else
      llvmType = llvm::cast<LLVMArrayType>(llvmType).getElementType();
  }
  return llvmType;
}
1945
/// Extracts the element at `index` from a constant attribute:
///  - `ElementsAttr`: the element at `index`, or `nullptr` if the shaped type
///    is unranked or does not have rank 1;
///  - `ArrayAttr`: the element at `index`;
///  - `ZeroAttr`, `UndefAttr`, `PoisonAttr`: the attribute itself, unchanged
///    (presumably because such an attribute stands for every element of the
///    aggregate as well).
/// Returns `nullptr` if the attribute is of any other kind or if `index` is
/// out of bounds.
  if (auto elementsAttr = dyn_cast<ElementsAttr>(attr)) {
    ShapedType shapedType = elementsAttr.getShapedType();
    // Only flat (rank-1) element attributes can be indexed with one index.
    if (!shapedType.hasRank() || shapedType.getRank() != 1)
      return nullptr;
    if (index < static_cast<size_t>(elementsAttr.getNumElements()))
      return elementsAttr.getValues<Attribute>()[index];
    return nullptr;
  }
  if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
    if (index < arrayAttr.getValue().size())
      return arrayAttr[index];
    return nullptr;
  }
  if (isa<ZeroAttr, UndefAttr, PoisonAttr>(attr))
    return attr;
  return nullptr;
}
1970
/// Folds an llvm.extractvalue. Three strategies are tried in order:
///  1. extract-of-extract: the two position lists are concatenated and the op
///     is updated in place to read from the inner container;
///  2. constant container: the constant is walked with `extractElementAt`
///     along the position and the resulting attribute is returned; if any step
///     cannot be resolved the fold fails (returns null);
///  3. llvm.insertvalue chains: the chain is walked backwards to either return
///     the inserted value directly, or retarget this op (in place) to the
///     earliest container that still provides the requested member.
/// An in-place update is signalled by returning the op's own result.
OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
  if (auto extractValueOp = getContainer().getDefiningOp<ExtractValueOp>()) {
    // extract(extract(x, [a...]), [b...]) -> extract(x, [a..., b...]).
    SmallVector<int64_t, 4> newPos(extractValueOp.getPosition());
    newPos.append(getPosition().begin(), getPosition().end());
    setPosition(newPos);
    getContainerMutable().set(extractValueOp.getContainer());
    return getResult();
  }

  Attribute containerAttr;
  if (matchPattern(getContainer(), m_Constant(&containerAttr))) {
    for (int64_t pos : getPosition()) {
      containerAttr = extractElementAt(containerAttr, pos);
      if (!containerAttr)
        return nullptr;
    }
    return containerAttr;
  }

  Value container = getContainer();
  ArrayRef<int64_t> extractPos = getPosition();
  while (auto insertValueOp = container.getDefiningOp<InsertValueOp>()) {
    ArrayRef<int64_t> insertPos = insertValueOp.getPosition();
    auto extractPosSize = extractPos.size();
    auto insertPosSize = insertPos.size();

    // Case 1: Exact match of positions.
    if (extractPos == insertPos)
      return insertValueOp.getValue();

    // Case 2: Insert position is a prefix of extract position. Continue
    // traversal with the inserted value. Example:
    // ```
    // %0 = llvm.insertvalue %arg1, %undef[0] : !llvm.struct<(i32, i32, i32)>
    // %1 = llvm.insertvalue %arg2, %0[1] : !llvm.struct<(i32, i32, i32)>
    // %2 = llvm.insertvalue %arg3, %1[2] : !llvm.struct<(i32, i32, i32)>
    // %3 = llvm.insertvalue %2, %foo[0]
    //     : !llvm.struct<(struct<(i32, i32, i32)>, i64)>
    // %4 = llvm.extractvalue %3[0, 0]
    //     : !llvm.struct<(struct<(i32, i32, i32)>, i64)>
    // ```
    // In the above example, %4 is folded to %arg1.
    if (extractPosSize > insertPosSize &&
        extractPos.take_front(insertPosSize) == insertPos) {
      container = insertValueOp.getValue();
      extractPos = extractPos.drop_front(insertPosSize);
      continue;
    }

    // Case 3: Try to continue the traversal with the container value.

    // If extract position is a prefix of insert position, stop propagating back
    // as it will miss dependencies. For instance, %3 should not fold to %f0 in
    // the following example:
    // ```
    //   %1 = llvm.insertvalue %f0, %0[0, 0] :
    //     !llvm.array<4 x !llvm.array<4 x f32>>
    //   %2 = llvm.insertvalue %arr, %1[0] :
    //     !llvm.array<4 x !llvm.array<4 x f32>>
    //   %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4 x f32>>
    // ```
    if (insertPosSize > extractPosSize &&
        extractPos == insertPos.take_front(extractPosSize))
      break;
    // If neither a prefix, nor the exact position, we can extract out of the
    // value being inserted into. Moreover, we can try again if that operand
    // is itself an insertvalue expression.
    container = insertValueOp.getContainer();
  }

  // We failed to resolve past this container either because it is not an
  // InsertValueOp, or it is an InsertValueOp that partially overlaps with the
  // value being extracted. Update to read from this container instead.
  // If the walk made no progress at all, report that nothing was folded.
  if (container == getContainer())
    return {};
  setPosition(extractPos);
  getContainerMutable().assign(container);
  return getResult();
}
2050
/// Verifies that `position` is a valid path into the container type and that
/// the op's result type equals the type found at that path.
LogicalResult ExtractValueOp::verify() {
  auto emitError = [this](StringRef msg) { return emitOpError(msg); };
      emitError, getContainer().getType(), getPosition());
  // A null type means getInsertExtractValueElementType already emitted a
  // diagnostic through `emitError`.
  if (!valueType)
    return failure();

  if (getRes().getType() != valueType)
    return emitOpError() << "Type mismatch: extracting from "
                         << getContainer().getType() << " should produce "
                         << valueType << " but this op returns "
                         << getRes().getType();
  return success();
}
2065
2066void ExtractValueOp::build(OpBuilder &builder, OperationState &state,
2067 Value container, ArrayRef<int64_t> position) {
2068 build(builder, state,
2069 getInsertExtractValueElementType(container.getType(), position),
2070 container, builder.getAttr<DenseI64ArrayAttr>(position));
2071}
2072
2073//===----------------------------------------------------------------------===//
2074// InsertValueOp
2075//===----------------------------------------------------------------------===//
2076
namespace {
/// Update any ExtractValueOps using a given InsertValueOp to instead read from
/// the closest InsertValueOp in the chain leading up to the current op that
/// writes to the same member. This traversal could be done entirely in
/// ExtractValueOp::fold, but doing it here significantly speeds things up
/// because we can handle several ExtractValueOps with a single traversal.
/// For instance, in this example:
///   %i0 = llvm.insertvalue %v0, %undef[0]
///   %i1 = llvm.insertvalue %v1, %i0[1]
///   ...
///   %i999 = llvm.insertvalue %v999, %i998[999]
///   %e0 = llvm.extractvalue %i999[0]
///   %e1 = llvm.extractvalue %i999[1]
///   ...
///   %e999 = llvm.extractvalue %i999[999]
/// Individually running the folder on each extractvalue would require
/// traversing the insertvalue chain 1000 times, but running this pattern on the
/// InsertValueOp would allow us to achieve the same result with a single
/// traversal. The resulting IR after this pattern will then be:
///   %i0 = llvm.insertvalue %v0, %undef[0]
///   %i1 = llvm.insertvalue %v1, %i0[1]
///   ...
///   %i999 = llvm.insertvalue %v999, %i998[999]
///   %e0 = llvm.extractvalue %i0[0]
///   %e1 = llvm.extractvalue %i1[1]
///   ...
///   %e999 = llvm.extractvalue %i999[999]
///
/// Only the first (top-level) index of each position is compared, so an
/// extract is never redirected past an insertion into the same top-level
/// member, even if the nested indices differ.
struct ResolveExtractValueSource : public OpRewritePattern<InsertValueOp> {

  LogicalResult matchAndRewrite(InsertValueOp insertOp,
                                PatternRewriter &rewriter) const override {
    bool changed = false;
    // Map each position in the top-level struct to the ExtractOps that read
    // from it. For the example in the doc-comment above this map will be empty
    // when we visit ops %i0 - %i998. For %i999, it will contain:
    // 0 -> { %e0 }, 1 -> { %e1 }, ... 998 -> { %e998 }
    // (%e999 reads the member that %i999 itself writes, so it is skipped.)
    // NOTE(review): `[0]` is taken unconditionally, i.e. positions are assumed
    // to be non-empty; confirm the op definition enforces this.
    auto insertBaseIdx = insertOp.getPosition()[0];
    for (auto &use : insertOp->getUses()) {
      if (auto extractOp = dyn_cast<ExtractValueOp>(use.getOwner())) {
        auto baseIdx = extractOp.getPosition()[0];
        // We can skip reads of the member that insertOp writes to since they
        // will not be updated.
        if (baseIdx == insertBaseIdx)
          continue;
        posToExtractOps[baseIdx].push_back(extractOp);
      }
    }
    // Walk up the chain of insertions and try to resolve the remaining
    // extractions that access the same member.
    Value nextContainer = insertOp.getContainer();
    while (!posToExtractOps.empty()) {
      auto curInsert =
          dyn_cast_or_null<InsertValueOp>(nextContainer.getDefiningOp());
      if (!curInsert)
        break;
      nextContainer = curInsert.getContainer();

      // Check if any extractions read the member written by this insertion.
      auto curInsertBaseIdx = curInsert.getPosition()[0];
      auto it = posToExtractOps.find(curInsertBaseIdx);
      if (it == posToExtractOps.end())
        continue;

      // Update the ExtractOps to read from the current insertion.
      for (auto &extractOp : it->second) {
        rewriter.modifyOpInPlace(extractOp, [&] {
          extractOp.getContainerMutable().assign(curInsert);
        });
      }
      // The entry should never be empty if it exists, so if we are at this
      // point, set changed to true.
      assert(!it->second.empty());
      changed |= true;
      posToExtractOps.erase(it);
    }
    // There was no insertion along the chain that wrote the member accessed by
    // these extracts. So we can update them to read from the value the chain
    // starts from (the first container that is not an llvm.insertvalue).
    for (auto &[baseIdx, extracts] : posToExtractOps) {
      for (auto &extractOp : extracts) {
        rewriter.modifyOpInPlace(extractOp, [&] {
          extractOp.getContainerMutable().assign(nextContainer);
        });
      }
      assert(!extracts.empty() && "Empty list in map");
      changed = true;
    }
    return success(changed);
  }
};
} // namespace
2169
2170void InsertValueOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2171 MLIRContext *context) {
2172 patterns.add<ResolveExtractValueSource>(context);
2173}
2174
/// Infer the value type from the container type and position. This custom
/// directive hook consumes no tokens. It only computes the element type and
/// reports failures at the parser's current location; it fails if no valid
/// type could be derived.
static ParseResult
                                  Type containerType,
                                  DenseI64ArrayAttr position) {
      [&](StringRef msg) {
        return parser.emitError(parser.getCurrentLocation(), msg);
      },
      containerType, position.asArrayRef());
  return success(!!valueType);
}
2187
/// Nothing to print for an inferred type: the value type is implied by the
/// container type and position, which the parser hook re-derives.
                                  Operation *op, Type valueType,
                                  Type containerType,
                                  DenseI64ArrayAttr position) {}
2193
/// Verifies that `position` is a valid path into the container type and that
/// the inserted value has exactly the type found at that path.
LogicalResult InsertValueOp::verify() {
  auto emitError = [this](StringRef msg) { return emitOpError(msg); };
      emitError, getContainer().getType(), getPosition());
  // A null type means getInsertExtractValueElementType already emitted a
  // diagnostic through `emitError`.
  if (!valueType)
    return failure();

  if (getValue().getType() != valueType)
    return emitOpError() << "Type mismatch: cannot insert "
                         << getValue().getType() << " into "
                         << getContainer().getType();

  return success();
}
2208
2209//===----------------------------------------------------------------------===//
2210// ReturnOp
2211//===----------------------------------------------------------------------===//
2212
2213LogicalResult ReturnOp::verify() {
2214 auto parent = (*this)->getParentOfType<LLVMFuncOp>();
2215 if (!parent)
2216 return success();
2217
2218 Type expectedType = parent.getFunctionType().getReturnType();
2219 if (llvm::isa<LLVMVoidType>(expectedType)) {
2220 if (!getArg())
2221 return success();
2222 InFlightDiagnostic diag = emitOpError("expected no operands");
2223 diag.attachNote(parent->getLoc()) << "when returning from function";
2224 return diag;
2225 }
2226 if (!getArg()) {
2227 if (llvm::isa<LLVMVoidType>(expectedType))
2228 return success();
2229 InFlightDiagnostic diag = emitOpError("expected 1 operand");
2230 diag.attachNote(parent->getLoc()) << "when returning from function";
2231 return diag;
2232 }
2233 if (expectedType != getArg().getType()) {
2234 InFlightDiagnostic diag = emitOpError("mismatching result types");
2235 diag.attachNote(parent->getLoc()) << "when returning from function";
2236 return diag;
2237 }
2238 return success();
2239}
2240
2241//===----------------------------------------------------------------------===//
2242// LLVM::AddressOfOp.
2243//===----------------------------------------------------------------------===//
2244
2245GlobalOp AddressOfOp::getGlobal(SymbolTableCollection &symbolTable) {
2246 return dyn_cast_or_null<GlobalOp>(
2247 symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
2248}
2249
2250LLVMFuncOp AddressOfOp::getFunction(SymbolTableCollection &symbolTable) {
2251 return dyn_cast_or_null<LLVMFuncOp>(
2252 symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
2253}
2254
2255AliasOp AddressOfOp::getAlias(SymbolTableCollection &symbolTable) {
2256 return dyn_cast_or_null<AliasOp>(
2257 symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
2258}
2259
2260IFuncOp AddressOfOp::getIFunc(SymbolTableCollection &symbolTable) {
2261 return dyn_cast_or_null<IFuncOp>(
2262 symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
2263}
2264
2265LogicalResult
2266AddressOfOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2267 Operation *symbol =
2268 symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr());
2269
2270 auto global = dyn_cast_or_null<GlobalOp>(symbol);
2271 auto function = dyn_cast_or_null<LLVMFuncOp>(symbol);
2272 auto alias = dyn_cast_or_null<AliasOp>(symbol);
2273 auto ifunc = dyn_cast_or_null<IFuncOp>(symbol);
2274
2275 if (!global && !function && !alias && !ifunc)
2276 return emitOpError("must reference a global defined by 'llvm.mlir.global', "
2277 "'llvm.mlir.alias' or 'llvm.func' or 'llvm.mlir.ifunc'");
2278
2279 LLVMPointerType type = getType();
2280 if ((global && global.getAddrSpace() != type.getAddressSpace()) ||
2281 (alias && alias.getAddrSpace() != type.getAddressSpace()))
2282 return emitOpError("pointer address space must match address space of the "
2283 "referenced global or alias");
2284
2285 return success();
2286}
2287
2288// AddressOfOp constant-folds to the global symbol name.
2289OpFoldResult LLVM::AddressOfOp::fold(FoldAdaptor) {
2290 return getGlobalNameAttr();
2291}
2292
2293//===----------------------------------------------------------------------===//
2294// LLVM::DSOLocalEquivalentOp
2295//===----------------------------------------------------------------------===//
2296
2297LLVMFuncOp
2298DSOLocalEquivalentOp::getFunction(SymbolTableCollection &symbolTable) {
2299 return dyn_cast_or_null<LLVMFuncOp>(symbolTable.lookupSymbolIn(
2300 parentLLVMModule(*this), getFunctionNameAttr()));
2301}
2302
2303AliasOp DSOLocalEquivalentOp::getAlias(SymbolTableCollection &symbolTable) {
2304 return dyn_cast_or_null<AliasOp>(symbolTable.lookupSymbolIn(
2305 parentLLVMModule(*this), getFunctionNameAttr()));
2306}
2307
2308LogicalResult
2309DSOLocalEquivalentOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2310 Operation *symbol = symbolTable.lookupSymbolIn(parentLLVMModule(*this),
2311 getFunctionNameAttr());
2312 auto function = dyn_cast_or_null<LLVMFuncOp>(symbol);
2313 auto alias = dyn_cast_or_null<AliasOp>(symbol);
2314
2315 if (!function && !alias)
2316 return emitOpError(
2317 "must reference a global defined by 'llvm.func' or 'llvm.mlir.alias'");
2318
2319 if (alias) {
2320 if (alias.getInitializer()
2321 .walk([&](AddressOfOp addrOp) {
2322 if (addrOp.getGlobal(symbolTable))
2323 return WalkResult::interrupt();
2324 return WalkResult::advance();
2325 })
2326 .wasInterrupted())
2327 return emitOpError("must reference an alias to a function");
2328 }
2329
2330 if ((function && function.getLinkage() == LLVM::Linkage::ExternWeak) ||
2331 (alias && alias.getLinkage() == LLVM::Linkage::ExternWeak))
2332 return emitOpError(
2333 "target function with 'extern_weak' linkage not allowed");
2334
2335 return success();
2336}
2337
2338/// Fold a dso_local_equivalent operation to a dedicated dso_local_equivalent
2339/// attribute.
2340OpFoldResult DSOLocalEquivalentOp::fold(FoldAdaptor) {
2341 return DSOLocalEquivalentAttr::get(getContext(), getFunctionNameAttr());
2342}
2343
2344//===----------------------------------------------------------------------===//
// Builder and verifier for LLVM::ComdatOp.
2346//===----------------------------------------------------------------------===//
2347
2348void ComdatOp::build(OpBuilder &builder, OperationState &result,
2349 StringRef symName) {
2350 result.addAttribute(getSymNameAttrName(result.name),
2351 builder.getStringAttr(symName));
2352 Region *body = result.addRegion();
2353 body->emplaceBlock();
2354}
2355
2356LogicalResult ComdatOp::verifyRegions() {
2357 Region &body = getBody();
2358 for (Operation &op : body.getOps())
2359 if (!isa<ComdatSelectorOp>(op))
2360 return op.emitError(
2361 "only comdat selector symbols can appear in a comdat region");
2362
2363 return success();
2364}
2365
2366//===----------------------------------------------------------------------===//
2367// Builder, printer and verifier for LLVM::GlobalOp.
2368//===----------------------------------------------------------------------===//
2369
/// Builds an llvm.mlir.global. The symbol name, global type and linkage are
/// always recorded. The remaining attributes are only added when meaningful:
/// constant/dsoLocal/threadLocal as unit attributes when true, `value` and
/// `comdat` when non-null, alignment (as i64) and addrSpace (as i32) when
/// non-zero, and dbgExprs when non-empty. The extra `attrs` are appended as
/// given. An initializer region is always created, initially without blocks.
void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type,
                     bool isConstant, Linkage linkage, StringRef name,
                     Attribute value, uint64_t alignment, unsigned addrSpace,
                     bool dsoLocal, bool threadLocal, SymbolRefAttr comdat,
                     ArrayRef<Attribute> dbgExprs) {
  result.addAttribute(getSymNameAttrName(result.name),
                      builder.getStringAttr(name));
  result.addAttribute(getGlobalTypeAttrName(result.name), TypeAttr::get(type));
  if (isConstant)
    result.addAttribute(getConstantAttrName(result.name),
                        builder.getUnitAttr());
  if (value)
    result.addAttribute(getValueAttrName(result.name), value);
  if (dsoLocal)
    result.addAttribute(getDsoLocalAttrName(result.name),
                        builder.getUnitAttr());
  if (threadLocal)
    result.addAttribute(getThreadLocal_AttrName(result.name),
                        builder.getUnitAttr());
  if (comdat)
    result.addAttribute(getComdatAttrName(result.name), comdat);

  // Only add an alignment attribute if the "alignment" input
  // is different from 0. The value must also be a power of two, but
  // this is tested in GlobalOp::verify, not here.
  if (alignment != 0)
    result.addAttribute(getAlignmentAttrName(result.name),
                        builder.getI64IntegerAttr(alignment));

  result.addAttribute(getLinkageAttrName(result.name),
                      LinkageAttr::get(builder.getContext(), linkage));
  // Address space 0 is left implicit.
  if (addrSpace != 0)
    result.addAttribute(getAddrSpaceAttrName(result.name),
                        builder.getI32IntegerAttr(addrSpace));
  result.attributes.append(attrs.begin(), attrs.end());

  if (!dbgExprs.empty())
    result.addAttribute(getDbgExprsAttrName(result.name),
                        ArrayAttr::get(builder.getContext(), dbgExprs));

  result.addRegion();
}
2413
2414template <typename OpType>
2415static void printCommonGlobalAndAlias(OpAsmPrinter &p, OpType op) {
2416 p << ' ' << stringifyLinkage(op.getLinkage()) << ' ';
2417 StringRef visibility = stringifyVisibility(op.getVisibility_());
2418 if (!visibility.empty())
2419 p << visibility << ' ';
2420 if (op.getThreadLocal_())
2421 p << "thread_local ";
2422 if (auto unnamedAddr = op.getUnnamedAddr()) {
2423 StringRef str = stringifyUnnamedAddr(*unnamedAddr);
2424 if (!str.empty())
2425 p << str << ' ';
2426 }
2427}
2428
/// Prints an llvm.mlir.global in its custom form: the shared keyword prefix,
/// an optional `constant`, the symbol name, the parenthesized (possibly empty)
/// initial value, an optional `comdat(...)`, the remaining attributes, and,
/// except for string-valued globals, the trailing type and any initializer
/// region.
void GlobalOp::print(OpAsmPrinter &p) {
  if (getConstant())
    p << "constant ";
  p.printSymbolName(getSymName());
  p << '(';
  if (auto value = getValueOrNull())
    p.printAttribute(value);
  p << ')';
  if (auto comdat = getComdat())
    p << " comdat(" << *comdat << ')';

  // Note that the alignment attribute is printed using the
  // default syntax here, even though it is an inherent attribute
  // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes)
  p.printOptionalAttrDict((*this)->getAttrs(),
                          {SymbolTable::getSymbolAttrName(),
                           getGlobalTypeAttrName(), getConstantAttrName(),
                           getValueAttrName(), getLinkageAttrName(),
                           getUnnamedAddrAttrName(), getThreadLocal_AttrName(),
                           getVisibility_AttrName(), getComdatAttrName()});

  // Print the trailing type unless it's a string global.
  // This early return also skips the initializer region.
  // NOTE(review): relies on string-valued globals never having an
  // initializer region; confirm GlobalOp::verify enforces that.
  if (llvm::dyn_cast_or_null<StringAttr>(getValueOrNull()))
    return;
  p << " : " << getType();

  Region &initializer = getInitializerRegion();
  if (!initializer.empty()) {
    p << ' ';
    p.printRegion(initializer, /*printEntryBlockArgs=*/false);
  }
}
2462
2463static LogicalResult verifyComdat(Operation *op,
2464 std::optional<SymbolRefAttr> attr) {
2465 if (!attr)
2466 return success();
2467
2468 auto *comdatSelector = SymbolTable::lookupNearestSymbolFrom(op, *attr);
2469 if (!isa_and_nonnull<ComdatSelectorOp>(comdatSelector))
2470 return op->emitError() << "expected comdat symbol";
2471
2472 return success();
2473}
2474
/// Verifies that no two BlockTagOps nested in `funcOp` carry the same tag.
/// The walk stops at, and reports, the first duplicate it encounters.
static LogicalResult verifyBlockTags(LLVMFuncOp funcOp) {
  // Note that the presence of `BlockTagOp`s currently can't prevent an
  // unreachable block from being removed by the canonicalizer's region
  // simplification, which would need to be dialect-aware to allow such extra
  // constraints to be described.
  WalkResult res = funcOp.walk([&](BlockTagOp blockTagOp) {
    if (blockTags.contains(blockTagOp.getTag())) {
      // NOTE(review): the emitted message ends with a dangling ": ".
      blockTagOp.emitError()
          << "duplicate block tag '" << blockTagOp.getTag().getId()
          << "' in the same function: ";
      return WalkResult::interrupt();
    }
    blockTags.insert(blockTagOp.getTag());
    return WalkResult::advance();
  });

  return failure(res.wasInterrupted());
}
2493
/// Parse common attributes that might show up in the same order in both
/// GlobalOp and AliasOp. The accepted order is: linkage, visibility,
/// `thread_local`, unnamed_addr, which matches the order in which
/// printCommonGlobalAndAlias emits them. Missing linkage, visibility and
/// unnamed_addr keywords fall back to External, Default and None respectively.
template <typename OpType>
static ParseResult parseCommonGlobalAndAlias(OpAsmParser &parser,
  MLIRContext *ctx = parser.getContext();
  // Parse optional linkage, default to External.
  result.addAttribute(
      OpType::getLinkageAttrName(result.name),
      LLVM::LinkageAttr::get(ctx, parseOptionalLLVMKeyword<Linkage>(
                                      parser, LLVM::Linkage::External)));

  // Parse optional visibility, default to Default.
  result.addAttribute(OpType::getVisibility_AttrName(result.name),
                                  parser, LLVM::Visibility::Default)));

  if (succeeded(parser.parseOptionalKeyword("thread_local")))
    result.addAttribute(OpType::getThreadLocal_AttrName(result.name),
                        parser.getBuilder().getUnitAttr());

  // Parse optional UnnamedAddr, default to None.
  result.addAttribute(OpType::getUnnamedAddrAttrName(result.name),
                                  parser, LLVM::UnnamedAddr::None)));

  return success();
}
2524
// operation ::= `llvm.mlir.global` linkage? visibility? `thread_local`?
//               (`unnamed_addr` | `local_unnamed_addr`)?
//               `constant`? `@` identifier
//               `(` attribute? `)` (`comdat(` symbol-ref-id `)`)?
//               attribute-list? (`:` type)? region?
//
// The type can be omitted for string attributes, in which case it will be
// inferred from the value of the string as [strlen(value) x i8]. An
// initializer region is only parsed when the type is given explicitly.
ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) {
  // Call into common parsing between GlobalOp and AliasOp.
    return failure();

  if (succeeded(parser.parseOptionalKeyword("constant")))
    result.addAttribute(getConstantAttrName(result.name),
                        parser.getBuilder().getUnitAttr());

  StringAttr name;
  if (parser.parseSymbolName(name, getSymNameAttrName(result.name),
                             result.attributes) ||
      parser.parseLParen())
    return failure();

  // parseOptionalRParen() fails when the next token is not `)`, in which case
  // an initial-value attribute must precede the closing paren.
  Attribute value;
  if (parser.parseOptionalRParen()) {
    if (parser.parseAttribute(value, getValueAttrName(result.name),
                              result.attributes) ||
        parser.parseRParen())
      return failure();
  }

  if (succeeded(parser.parseOptionalKeyword("comdat"))) {
    SymbolRefAttr comdat;
    if (parser.parseLParen() || parser.parseAttribute(comdat) ||
        parser.parseRParen())
      return failure();

    result.addAttribute(getComdatAttrName(result.name), comdat);
  }

  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseOptionalColonTypeList(types))
    return failure();

  if (types.size() > 1)
    return parser.emitError(parser.getNameLoc(), "expected zero or one type");

  Region &initRegion = *result.addRegion();
  if (types.empty()) {
    // Without an explicit type, infer an i8 array from a string initializer.
    if (auto strAttr = llvm::dyn_cast_or_null<StringAttr>(value)) {
      MLIRContext *context = parser.getContext();
      auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8),
                                                strAttr.getValue().size());
      types.push_back(arrayType);
    } else {
      return parser.emitError(parser.getNameLoc(),
                              "type can only be omitted for string globals");
    }
  } else {
    OptionalParseResult parseResult =
        parser.parseOptionalRegion(initRegion, /*arguments=*/{},
                                   /*argTypes=*/{});
    if (parseResult.has_value() && failed(*parseResult))
      return failure();
  }

  result.addAttribute(getGlobalTypeAttrName(result.name),
                      TypeAttr::get(types[0]));
  return success();
}
2596
2597static bool isZeroAttribute(Attribute value) {
2598 if (auto intValue = llvm::dyn_cast<IntegerAttr>(value))
2599 return intValue.getValue().isZero();
2600 if (auto fpValue = llvm::dyn_cast<FloatAttr>(value))
2601 return fpValue.getValue().isZero();
2602 if (auto splatValue = llvm::dyn_cast<SplatElementsAttr>(value))
2603 return isZeroAttribute(splatValue.getSplatValue<Attribute>());
2604 if (auto elementsValue = llvm::dyn_cast<ElementsAttr>(value))
2605 return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute);
2606 if (auto arrayValue = llvm::dyn_cast<ArrayAttr>(value))
2607 return llvm::all_of(arrayValue.getValue(), isZeroAttribute);
2608 return false;
2609}
2610
2611LogicalResult GlobalOp::verify() {
2612 bool validType = isCompatibleOuterType(getType())
2613 ? !llvm::isa<LLVMVoidType, TokenType, LLVMMetadataType,
2614 LLVMLabelType>(getType())
2615 : llvm::isa<PointerElementTypeInterface>(getType());
2616 if (!validType)
2617 return emitOpError(
2618 "expects type to be a valid element type for an LLVM global");
2619 if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp()))
2620 return emitOpError("must appear at the module level");
2621
2622 if (auto strAttr = llvm::dyn_cast_or_null<StringAttr>(getValueOrNull())) {
2623 auto type = llvm::dyn_cast<LLVMArrayType>(getType());
2624 IntegerType elementType =
2625 type ? llvm::dyn_cast<IntegerType>(type.getElementType()) : nullptr;
2626 if (!elementType || elementType.getWidth() != 8 ||
2627 type.getNumElements() != strAttr.getValue().size())
2628 return emitOpError(
2629 "requires an i8 array type of the length equal to that of the string "
2630 "attribute");
2631 }
2632
2633 if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
2634 if (!targetExtType.hasProperty(LLVMTargetExtType::CanBeGlobal))
2635 return emitOpError()
2636 << "this target extension type cannot be used in a global";
2637
2638 if (Attribute value = getValueOrNull())
2639 return emitOpError() << "global with target extension type can only be "
2640 "initialized with zero-initializer";
2641 }
2642
2643 if (getLinkage() == Linkage::Common) {
2644 if (Attribute value = getValueOrNull()) {
2645 if (!isZeroAttribute(value)) {
2646 return emitOpError()
2647 << "expected zero value for '"
2648 << stringifyLinkage(Linkage::Common) << "' linkage";
2649 }
2650 }
2651 }
2652
2653 if (getLinkage() == Linkage::Appending) {
2654 if (!llvm::isa<LLVMArrayType>(getType())) {
2655 return emitOpError() << "expected array type for '"
2656 << stringifyLinkage(Linkage::Appending)
2657 << "' linkage";
2658 }
2659 }
2660
2661 if (failed(verifyComdat(*this, getComdat())))
2662 return failure();
2663
2664 std::optional<uint64_t> alignAttr = getAlignment();
2665 if (alignAttr.has_value()) {
2666 uint64_t value = alignAttr.value();
2667 if (!llvm::isPowerOf2_64(value))
2668 return emitError() << "alignment attribute is not a power of 2";
2669 }
2670
2671 return success();
2672}
2673
2674LogicalResult GlobalOp::verifyRegions() {
2675 if (Block *b = getInitializerBlock()) {
2676 ReturnOp ret = cast<ReturnOp>(b->getTerminator());
2677 if (ret.operand_type_begin() == ret.operand_type_end())
2678 return emitOpError("initializer region cannot return void");
2679 if (*ret.operand_type_begin() != getType())
2680 return emitOpError("initializer region type ")
2681 << *ret.operand_type_begin() << " does not match global type "
2682 << getType();
2683
2684 for (Operation &op : *b) {
2685 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
2686 if (!iface || !iface.hasNoEffect())
2687 return op.emitError()
2688 << "ops with side effects not allowed in global initializers";
2689 }
2690
2691 if (getValueOrNull())
2692 return emitOpError("cannot have both initializer value and region");
2693 }
2694
2695 return success();
2696}
2697
2698//===----------------------------------------------------------------------===//
2699// LLVM::GlobalCtorsOp
2700//===----------------------------------------------------------------------===//
2701
2702static LogicalResult checkGlobalXtorData(Operation *op, ArrayAttr data) {
2703 if (data.empty())
2704 return success();
2705
2706 if (llvm::all_of(data.getAsRange<Attribute>(), [](Attribute v) {
2707 return isa<FlatSymbolRefAttr, ZeroAttr>(v);
2708 }))
2709 return success();
2710 return op->emitError("data element must be symbol or #llvm.zero");
2711}
2712
2713LogicalResult
2714GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2715 for (Attribute ctor : getCtors()) {
2716 if (failed(verifySymbolAttrUse(llvm::cast<FlatSymbolRefAttr>(ctor), *this,
2717 symbolTable)))
2718 return failure();
2719 }
2720 return success();
2721}
2722
2723LogicalResult GlobalCtorsOp::verify() {
2724 if (checkGlobalXtorData(*this, getData()).failed())
2725 return failure();
2726
2727 if (getCtors().size() == getPriorities().size() &&
2728 getCtors().size() == getData().size())
2729 return success();
2730 return emitError(
2731 "ctors, priorities, and data must have the same number of elements");
2732}
2733
2734//===----------------------------------------------------------------------===//
2735// LLVM::GlobalDtorsOp
2736//===----------------------------------------------------------------------===//
2737
2738LogicalResult
2739GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2740 for (Attribute dtor : getDtors()) {
2741 if (failed(verifySymbolAttrUse(llvm::cast<FlatSymbolRefAttr>(dtor), *this,
2742 symbolTable)))
2743 return failure();
2744 }
2745 return success();
2746}
2747
2748LogicalResult GlobalDtorsOp::verify() {
2749 if (checkGlobalXtorData(*this, getData()).failed())
2750 return failure();
2751
2752 if (getDtors().size() == getPriorities().size() &&
2753 getDtors().size() == getData().size())
2754 return success();
2755 return emitError(
2756 "dtors, priorities, and data must have the same number of elements");
2757}
2758
2759//===----------------------------------------------------------------------===//
2760// Builder, printer and verifier for LLVM::AliasOp.
2761//===----------------------------------------------------------------------===//
2762
2763void AliasOp::build(OpBuilder &builder, OperationState &result, Type type,
2764 Linkage linkage, StringRef name, bool dsoLocal,
2765 bool threadLocal, ArrayRef<NamedAttribute> attrs) {
2766 result.addAttribute(getSymNameAttrName(result.name),
2767 builder.getStringAttr(name));
2768 result.addAttribute(getAliasTypeAttrName(result.name), TypeAttr::get(type));
2769 if (dsoLocal)
2770 result.addAttribute(getDsoLocalAttrName(result.name),
2771 builder.getUnitAttr());
2772 if (threadLocal)
2773 result.addAttribute(getThreadLocal_AttrName(result.name),
2774 builder.getUnitAttr());
2775
2776 result.addAttribute(getLinkageAttrName(result.name),
2777 LinkageAttr::get(builder.getContext(), linkage));
2778 result.attributes.append(attrs.begin(), attrs.end());
2779
2780 result.addRegion();
2781}
2782
/// Prints `llvm.mlir.alias` in the custom form read by AliasOp::parse.
///
/// The linkage, visibility, thread_local and unnamed_addr attributes are left
/// out of the attribute dictionary below. They are presumably printed as
/// keywords by the helper shared with GlobalOp (its body is not visible here;
/// confirm). The symbol name, the remaining attributes, the trailing type and
/// the initializer region follow.
void AliasOp::print(OpAsmPrinter &p) {
  printCommonGlobalAndAlias<AliasOp>(p, *this);

  p.printSymbolName(getSymName());
  // Attributes already expressed by dedicated syntax are elided here.
  p.printOptionalAttrDict((*this)->getAttrs(),
                          {SymbolTable::getSymbolAttrName(),
                           getAliasTypeAttrName(), getLinkageAttrName(),
                           getUnnamedAddrAttrName(), getThreadLocal_AttrName(),
                           getVisibility_AttrName()});

  // Print the trailing type.
  p << " : " << getType() << ' ';
  // Print the initializer region.
  p.printRegion(getInitializerRegion(), /*printEntryBlockArgs=*/false);
}
2798
2799// operation ::= `llvm.mlir.alias` linkage? visibility?
2800// (`unnamed_addr` | `local_unnamed_addr`)?
2801// `thread_local`? `@` identifier
2802// `(` attribute? `)`
2803// attribute-list? `:` type region
2804//
2805ParseResult AliasOp::parse(OpAsmParser &parser, OperationState &result) {
2806 // Call into common parsing between GlobalOp and AliasOp.
2808 return failure();
2809
2810 StringAttr name;
2811 if (parser.parseSymbolName(name, getSymNameAttrName(result.name),
2812 result.attributes))
2813 return failure();
2814
2816 if (parser.parseOptionalAttrDict(result.attributes) ||
2817 parser.parseOptionalColonTypeList(types))
2818 return failure();
2819
2820 if (types.size() > 1)
2821 return parser.emitError(parser.getNameLoc(), "expected zero or one type");
2822
2823 Region &initRegion = *result.addRegion();
2824 if (parser.parseRegion(initRegion).failed())
2825 return failure();
2826
2827 result.addAttribute(getAliasTypeAttrName(result.name),
2828 TypeAttr::get(types[0]));
2829 return success();
2830}
2831
2832LogicalResult AliasOp::verify() {
2833 bool validType = isCompatibleOuterType(getType())
2834 ? !llvm::isa<LLVMVoidType, TokenType, LLVMMetadataType,
2835 LLVMLabelType>(getType())
2836 : llvm::isa<PointerElementTypeInterface>(getType());
2837 if (!validType)
2838 return emitOpError(
2839 "expects type to be a valid element type for an LLVM global alias");
2840
2841 // This matches LLVM IR verification logic, see llvm/lib/IR/Verifier.cpp
2842 switch (getLinkage()) {
2843 case Linkage::External:
2844 case Linkage::Internal:
2845 case Linkage::Private:
2846 case Linkage::Weak:
2847 case Linkage::WeakODR:
2848 case Linkage::Linkonce:
2849 case Linkage::LinkonceODR:
2850 case Linkage::AvailableExternally:
2851 break;
2852 default:
2853 return emitOpError()
2854 << "'" << stringifyLinkage(getLinkage())
2855 << "' linkage not supported in aliases, available options: private, "
2856 "internal, linkonce, weak, linkonce_odr, weak_odr, external or "
2857 "available_externally";
2858 }
2859
2860 return success();
2861}
2862
2863LogicalResult AliasOp::verifyRegions() {
2864 Block &b = getInitializerBlock();
2865 auto ret = cast<ReturnOp>(b.getTerminator());
2866 if (ret.getNumOperands() == 0 ||
2867 !isa<LLVM::LLVMPointerType>(ret.getOperand(0).getType()))
2868 return emitOpError("initializer region must always return a pointer");
2869
2870 for (Operation &op : b) {
2871 auto iface = dyn_cast<MemoryEffectOpInterface>(op);
2872 if (!iface || !iface.hasNoEffect())
2873 return op.emitError()
2874 << "ops with side effects are not allowed in alias initializers";
2875 }
2876
2877 return success();
2878}
2879
2880unsigned AliasOp::getAddrSpace() {
2881 Block &initializer = getInitializerBlock();
2882 auto ret = cast<ReturnOp>(initializer.getTerminator());
2883 auto ptrTy = cast<LLVMPointerType>(ret.getOperand(0).getType());
2884 return ptrTy.getAddressSpace();
2885}
2886
2887//===----------------------------------------------------------------------===//
2888// IFuncOp
2889//===----------------------------------------------------------------------===//
2890
2891void IFuncOp::build(OpBuilder &builder, OperationState &result, StringRef name,
2892 Type iFuncType, StringRef resolverName, Type resolverType,
2893 Linkage linkage, LLVM::Visibility visibility) {
2894 return build(builder, result, name, iFuncType, resolverName, resolverType,
2895 linkage, /*dso_local=*/false, /*address_space=*/0,
2896 UnnamedAddr::None, visibility);
2897}
2898
2899LogicalResult IFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2900 Operation *symbol =
2901 symbolTable.lookupSymbolIn(parentLLVMModule(*this), getResolverAttr());
2902 // This matches LLVM IR verification logic, see llvm/lib/IR/Verifier.cpp
2903 auto resolver = dyn_cast<LLVMFuncOp>(symbol);
2904 auto alias = dyn_cast<AliasOp>(symbol);
2905 while (alias) {
2906 Block &initBlock = alias.getInitializerBlock();
2907 auto returnOp = cast<ReturnOp>(initBlock.getTerminator());
2908 auto addrOp = returnOp.getArg().getDefiningOp<AddressOfOp>();
2909 // FIXME: This is a best effort solution. The AliasOp body might be more
2910 // complex and in that case we bail out with success. To completely match
2911 // the LLVM IR logic it would be necessary to implement proper alias and
2912 // cast stripping.
2913 if (!addrOp)
2914 return success();
2915 resolver = addrOp.getFunction(symbolTable);
2916 alias = addrOp.getAlias(symbolTable);
2917 }
2918 if (!resolver)
2919 return emitOpError("must have a function resolver");
2920 Linkage linkage = resolver.getLinkage();
2921 if (resolver.isExternal() || linkage == Linkage::AvailableExternally)
2922 return emitOpError("resolver must be a definition");
2923 if (!isa<LLVMPointerType>(resolver.getFunctionType().getReturnType()))
2924 return emitOpError("resolver must return a pointer");
2925 auto resolverPtr = dyn_cast<LLVMPointerType>(getResolverType());
2926 if (!resolverPtr || resolverPtr.getAddressSpace() != getAddressSpace())
2927 return emitOpError("resolver has incorrect type");
2928 return success();
2929}
2930
2931LogicalResult IFuncOp::verify() {
2932 switch (getLinkage()) {
2933 case Linkage::External:
2934 case Linkage::Internal:
2935 case Linkage::Private:
2936 case Linkage::Weak:
2937 case Linkage::WeakODR:
2938 case Linkage::Linkonce:
2939 case Linkage::LinkonceODR:
2940 break;
2941 default:
2942 return emitOpError() << "'" << stringifyLinkage(getLinkage())
2943 << "' linkage not supported in ifuncs, available "
2944 "options: private, internal, linkonce, weak, "
2945 "linkonce_odr, weak_odr, or external linkage";
2946 }
2947 return success();
2948}
2949
2950//===----------------------------------------------------------------------===//
2951// ShuffleVectorOp
2952//===----------------------------------------------------------------------===//
2953
2954void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1,
2955 Value v2, DenseI32ArrayAttr mask,
2957 auto containerType = v1.getType();
2958 auto vType = LLVM::getVectorType(
2959 cast<VectorType>(containerType).getElementType(), mask.size(),
2960 LLVM::isScalableVectorType(containerType));
2961 build(builder, state, vType, v1, v2, mask);
2962 state.addAttributes(attrs);
2963}
2964
2965void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1,
2966 Value v2, ArrayRef<int32_t> mask) {
2967 build(builder, state, v1, v2, builder.getDenseI32ArrayAttr(mask));
2968}
2969
2970/// Build the result type of a shuffle vector operation.
2971static ParseResult parseShuffleType(AsmParser &parser, Type v1Type,
2972 Type &resType, DenseI32ArrayAttr mask) {
2973 if (!LLVM::isCompatibleVectorType(v1Type))
2974 return parser.emitError(parser.getCurrentLocation(),
2975 "expected an LLVM compatible vector type");
2976 resType =
2977 LLVM::getVectorType(cast<VectorType>(v1Type).getElementType(),
2978 mask.size(), LLVM::isScalableVectorType(v1Type));
2979 return success();
2980}
2981
2982/// Nothing to do when the result type is inferred.
2983static void printShuffleType(AsmPrinter &printer, Operation *op, Type v1Type,
2984 Type resType, DenseI32ArrayAttr mask) {}
2985
2986LogicalResult ShuffleVectorOp::verify() {
2987 if (LLVM::isScalableVectorType(getV1().getType()) &&
2988 llvm::any_of(getMask(), [](int32_t v) { return v != 0; }))
2989 return emitOpError("expected a splat operation for scalable vectors");
2990 return success();
2991}
2992
2993// Folding for shufflevector op when v1 is single element 1D vector
2994// and the mask is a single zero. OpFoldResult will be v1 in this case.
2995OpFoldResult ShuffleVectorOp::fold(FoldAdaptor adaptor) {
2996 // Check if operand 0 is a single element vector.
2997 auto vecType = llvm::dyn_cast<VectorType>(getV1().getType());
2998 if (!vecType || vecType.getRank() != 1 || vecType.getNumElements() != 1)
2999 return {};
3000 // Check if the mask is a single zero.
3001 // Note: The mask is guaranteed to be non-empty.
3002 if (getMask().size() != 1 || getMask()[0] != 0)
3003 return {};
3004 return getV1();
3005}
3006
3007//===----------------------------------------------------------------------===//
3008// Implementations for LLVM::LLVMFuncOp.
3009//===----------------------------------------------------------------------===//
3010
3011// Add the entry block to the function.
3012Block *LLVMFuncOp::addEntryBlock(OpBuilder &builder) {
3013 assert(empty() && "function already has an entry block");
3014 OpBuilder::InsertionGuard g(builder);
3015 Block *entry = builder.createBlock(&getBody());
3016
3017 // FIXME: Allow passing in proper locations for the entry arguments.
3018 LLVMFunctionType type = getFunctionType();
3019 for (unsigned i = 0, e = type.getNumParams(); i < e; ++i)
3020 entry->addArgument(type.getParamType(i), getLoc());
3021 return entry;
3022}
3023
/// Builds an `llvm.func` named `name` with LLVM function type `type`.
///
/// The builder itself sets the symbol name, function type, linkage and calling
/// convention, then appends the caller-provided `attrs`. `dso_local`, `comdat`
/// and the function entry count are added afterwards, and only when requested.
/// In builds without NDEBUG, any duplicated attribute name (e.g. an entry of
/// `attrs` that clashes with a builder-owned attribute) is a fatal error.
/// `argAttrs`, when non-empty, must hold one dictionary per function parameter.
void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
                       StringRef name, Type type, LLVM::Linkage linkage,
                       bool dsoLocal, CConv cconv, SymbolRefAttr comdat,
                       ArrayRef<NamedAttribute> attrs,
                       ArrayRef<DictionaryAttr> argAttrs,
                       std::optional<uint64_t> functionEntryCount) {
  result.addRegion();
  result.addAttribute(SymbolTable::getSymbolAttrName(),
                      builder.getStringAttr(name));
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));
  result.addAttribute(getLinkageAttrName(result.name),
                      LinkageAttr::get(builder.getContext(), linkage));
  result.addAttribute(getCConvAttrName(result.name),
                      CConvAttr::get(builder.getContext(), cconv));
  result.attributes.append(attrs.begin(), attrs.end());
  if (dsoLocal)
    result.addAttribute(getDsoLocalAttrName(result.name),
                        builder.getUnitAttr());
  if (comdat)
    result.addAttribute(getComdatAttrName(result.name), comdat);
  if (functionEntryCount)
    result.addAttribute(getFunctionEntryCountAttrName(result.name),
                        builder.getI64IntegerAttr(functionEntryCount.value()));
#ifndef NDEBUG
  // Runs after every builder-owned attribute is in place, so a clash with any
  // of them is detected.
  std::optional<NamedAttribute> duplicate = result.attributes.findDuplicate();
  if (duplicate.has_value()) {
    llvm::report_fatal_error(
        Twine("LLVMFuncOp propagated an attribute that is meant "
              "to be constructed by the builder: ") +
        duplicate->getName().str());
  }
#endif
  if (argAttrs.empty())
    return;

  assert(llvm::cast<LLVMFunctionType>(type).getNumParams() == argAttrs.size() &&
         "expected as many argument attribute lists as arguments");
  call_interface_impl::addArgAndResultAttrs(
      builder, result, argAttrs, /*resultAttrs=*/{},
      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}
3066
3067// Builds an LLVM function type from the given lists of input and output types.
3068// Returns a null type if any of the types provided are non-LLVM types, or if
3069// there is more than one output type.
3070static Type
3072 ArrayRef<Type> outputs,
3074 Builder &b = parser.getBuilder();
3075 if (outputs.size() > 1) {
3076 parser.emitError(loc, "failed to construct function type: expected zero or "
3077 "one function result");
3078 return {};
3079 }
3080
3081 // Convert inputs to LLVM types, exit early on error.
3082 SmallVector<Type, 4> llvmInputs;
3083 for (auto t : inputs) {
3084 if (!isCompatibleType(t)) {
3085 parser.emitError(loc, "failed to construct function type: expected LLVM "
3086 "type for function arguments");
3087 return {};
3088 }
3089 llvmInputs.push_back(t);
3090 }
3091
3092 // No output is denoted as "void" in LLVM type system.
3093 Type llvmOutput =
3094 outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front();
3095 if (!isCompatibleType(llvmOutput)) {
3096 parser.emitError(loc, "failed to construct function type: expected LLVM "
3097 "type for function results")
3098 << llvmOutput;
3099 return {};
3100 }
3101 return LLVMFunctionType::get(llvmOutput, llvmInputs,
3102 variadicFlag.isVariadic());
3103}
3104
// Parses an LLVM function.
//
// operation ::= `llvm.func` linkage? visibility?
//               (`unnamed_addr` | `local_unnamed_addr`)? cconv?
//               function-signature
//               (`vscale_range(` integer `,` integer `)`)?
//               (`comdat(` symbol-ref-id `)`)?
//               (`attributes` attribute-dict)?
//               function-body?
//
// Omitted keywords default to `external` linkage, default visibility,
// UnnamedAddr::None and the C calling convention. When the body is omitted,
// the function region is left empty.
ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
  // Default to external linkage if no keyword is provided.
  result.addAttribute(getLinkageAttrName(result.name),
                      LinkageAttr::get(parser.getContext(),
                                       parseOptionalLLVMKeyword<Linkage>(
                                           parser, LLVM::Linkage::External)));

  // Parse optional visibility, default to Default.
  result.addAttribute(getVisibility_AttrName(result.name),
                      parser.getBuilder().getI64IntegerAttr(
                          parseOptionalLLVMKeyword<LLVM::Visibility, int64_t>(
                              parser, LLVM::Visibility::Default)));

  // Parse optional UnnamedAddr, default to None.
  result.addAttribute(getUnnamedAddrAttrName(result.name),
                      parser.getBuilder().getI64IntegerAttr(
                          parseOptionalLLVMKeyword<UnnamedAddr, int64_t>(
                              parser, LLVM::UnnamedAddr::None)));

  // Default to C Calling Convention if no keyword is provided.
  result.addAttribute(
      getCConvAttrName(result.name),
      CConvAttr::get(parser.getContext(),
                     parseOptionalLLVMKeyword<CConv>(parser, LLVM::CConv::C)));

  StringAttr nameAttr;
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  // Set by the signature parser; true when the argument list ends in `...`.
  bool isVariadic;

  // Errors about the function type are reported at the signature start.
  auto signatureLocation = parser.getCurrentLocation();
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             result.attributes) ||
      function_interface_impl::parseFunctionSignatureWithArguments(
          parser, /*allowVariadic=*/true, entryArgs, isVariadic, resultTypes,
          resultAttrs))
    return failure();

  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  auto type =
      buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes,
                            function_interface_impl::VariadicFlag(isVariadic));
  if (!type)
    return failure();
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));

  // Optional `vscale_range(min, max)`; both bounds are stored as i32 attrs.
  if (succeeded(parser.parseOptionalKeyword("vscale_range"))) {
    int64_t minRange, maxRange;
    if (parser.parseLParen() || parser.parseInteger(minRange) ||
        parser.parseComma() || parser.parseInteger(maxRange) ||
        parser.parseRParen())
      return failure();
    auto intTy = IntegerType::get(parser.getContext(), 32);
    result.addAttribute(
        getVscaleRangeAttrName(result.name),
        LLVM::VScaleRangeAttr::get(parser.getContext(),
                                   IntegerAttr::get(intTy, minRange),
                                   IntegerAttr::get(intTy, maxRange)));
  }
  // Parse the optional comdat selector.
  if (succeeded(parser.parseOptionalKeyword("comdat"))) {
    SymbolRefAttr comdat;
    if (parser.parseLParen() || parser.parseAttribute(comdat) ||
        parser.parseRParen())
      return failure();

    result.addAttribute(getComdatAttrName(result.name), comdat);
  }

  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();
  // Attach the per-argument / per-result attribute dictionaries collected
  // while parsing the signature.
  call_interface_impl::addArgAndResultAttrs(
      parser.getBuilder(), result, entryArgs, resultAttrs,
      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));

  auto *body = result.addRegion();
  OptionalParseResult parseResult =
      parser.parseOptionalRegion(*body, entryArgs);
  return failure(parseResult.has_value() && failed(*parseResult));
}
3196
// Print the LLVMFuncOp. Collects argument and result types and passes them to
// helper functions. Drops "void" result since it cannot be parsed back.
// Keywords equal to the parser's defaults are omitted: `external` linkage, the
// C calling convention, and visibility / unnamed_addr values that stringify to
// an empty string. Keywords are emitted in the order LLVMFuncOp::parse reads
// them (linkage, visibility, unnamed_addr, cconv) so the output round-trips.
void LLVMFuncOp::print(OpAsmPrinter &p) {
  p << ' ';
  if (getLinkage() != LLVM::Linkage::External)
    p << stringifyLinkage(getLinkage()) << ' ';
  StringRef visibility = stringifyVisibility(getVisibility_());
  if (!visibility.empty())
    p << visibility << ' ';
  if (auto unnamedAddr = getUnnamedAddr()) {
    StringRef str = stringifyUnnamedAddr(*unnamedAddr);
    if (!str.empty())
      p << str << ' ';
  }
  if (getCConv() != LLVM::CConv::C)
    p << stringifyCConv(getCConv()) << ' ';

  p.printSymbolName(getName());

  LLVMFunctionType fnType = getFunctionType();
  SmallVector<Type, 8> argTypes;
  SmallVector<Type, 1> resTypes;
  argTypes.reserve(fnType.getNumParams());
  for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i)
    argTypes.push_back(fnType.getParamType(i));

  Type returnType = fnType.getReturnType();
  if (!llvm::isa<LLVMVoidType>(returnType))
    resTypes.push_back(returnType);

  function_interface_impl::printFunctionSignature(p, *this, argTypes,
                                                  isVarArg(), resTypes);

  // Print vscale range if present
  if (std::optional<VScaleRangeAttr> vscale = getVscaleRange())
    p << " vscale_range(" << vscale->getMinRange().getInt() << ", "
      << vscale->getMaxRange().getInt() << ')';

  // Print the optional comdat selector.
  if (auto comdat = getComdat())
    p << " comdat(" << *comdat << ')';

  // Attributes already printed through dedicated syntax are elided.
  function_interface_impl::printFunctionAttributes(
      p, *this,
      {getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
       getLinkageAttrName(), getCConvAttrName(), getVisibility_AttrName(),
       getComdatAttrName(), getUnnamedAddrAttrName(),
       getVscaleRangeAttrName()});

  // Print the body if this is not an external function.
  Region &body = getBody();
  if (!body.empty()) {
    p << ' ';
    p.printRegion(body, /*printEntryBlockArgs=*/false,
                  /*printBlockTerminators=*/true);
  }
}
3255
// Verifies LLVM- and implementation-specific properties of the LLVM func Op:
// - functions don't have 'common' linkage;
// - the comdat reference, if any, is valid;
// - external functions have 'external' or 'extern_weak' linkage; the checks
//   below are only run for functions with a body;
// - no_inline and always_inline are mutually exclusive, and optimize_none
//   requires no_inline;
// - every 'llvm.landingpad' result and 'llvm.resume' operand in the function
//   has the same type;
// - block tags pass verifyBlockTags.
LogicalResult LLVMFuncOp::verify() {
  if (getLinkage() == LLVM::Linkage::Common)
    return emitOpError() << "functions cannot have '"
                         << stringifyLinkage(LLVM::Linkage::Common)
                         << "' linkage";

  if (failed(verifyComdat(*this, getComdat())))
    return failure();

  if (isExternal()) {
    if (getLinkage() != LLVM::Linkage::External &&
        getLinkage() != LLVM::Linkage::ExternWeak)
      return emitOpError() << "external functions must have '"
                           << stringifyLinkage(LLVM::Linkage::External)
                           << "' or '"
                           << stringifyLinkage(LLVM::Linkage::ExternWeak)
                           << "' linkage";
    // Declarations stop here; the remaining checks need a body.
    return success();
  }

  // In LLVM IR, these attributes are composed by convention, not by design.
  if (isNoInline() && isAlwaysInline())
    return emitError("no_inline and always_inline attributes are incompatible");

  if (isOptimizeNone() && !isNoInline())
    return emitOpError("with optimize_none must also be no_inline");

  // The first landingpad result / resume operand type seen fixes the expected
  // type; the walk is interrupted at the first mismatch, recording which
  // message to report.
  Type landingpadResultTy;
  StringRef diagnosticMessage;
  bool isLandingpadTypeConsistent =
      !walk([&](Operation *op) {
        const auto checkType = [&](Type type, StringRef errorMessage) {
          if (!landingpadResultTy) {
            landingpadResultTy = type;
            return WalkResult::advance();
          }
          if (landingpadResultTy != type) {
            diagnosticMessage = errorMessage;
            return WalkResult::interrupt();
          }
          return WalkResult::advance();
        };
        return TypeSwitch<Operation *, WalkResult>(op)
            .Case([&](LandingpadOp landingpad) {
              constexpr StringLiteral errorMessage =
                  "'llvm.landingpad' should have a consistent result type "
                  "inside a function";
              return checkType(landingpad.getType(), errorMessage);
            })
            .Case([&](ResumeOp resume) {
              constexpr StringLiteral errorMessage =
                  "'llvm.resume' should have a consistent input type inside a "
                  "function";
              return checkType(resume.getValue().getType(), errorMessage);
            })
            .Default([](auto) { return WalkResult::skip(); });
      }).wasInterrupted();
  if (!isLandingpadTypeConsistent) {
    assert(!diagnosticMessage.empty() &&
           "Expecting a non-empty diagnostic message");
    return emitError(diagnosticMessage);
  }

  if (failed(verifyBlockTags(*this)))
    return failure();

  return success();
}
3328
3329/// Verifies LLVM- and implementation-specific properties of the LLVM func Op:
3330/// - entry block arguments are of LLVM types.
3331LogicalResult LLVMFuncOp::verifyRegions() {
3332 if (isExternal())
3333 return success();
3334
3335 unsigned numArguments = getFunctionType().getNumParams();
3336 Block &entryBlock = front();
3337 for (unsigned i = 0; i < numArguments; ++i) {
3338 Type argType = entryBlock.getArgument(i).getType();
3339 if (!isCompatibleType(argType))
3340 return emitOpError("entry block argument #")
3341 << i << " is not of LLVM type";
3342 }
3343
3344 return success();
3345}
3346
3347Region *LLVMFuncOp::getCallableRegion() {
3348 if (isExternal())
3349 return nullptr;
3350 return &getBody();
3351}
3352
3353//===----------------------------------------------------------------------===//
3354// UndefOp.
3355//===----------------------------------------------------------------------===//
3356
3357/// Fold an undef operation to a dedicated undef attribute.
3358OpFoldResult LLVM::UndefOp::fold(FoldAdaptor) {
3359 return LLVM::UndefAttr::get(getContext());
3360}
3361
3362//===----------------------------------------------------------------------===//
3363// PoisonOp.
3364//===----------------------------------------------------------------------===//
3365
3366/// Fold a poison operation to a dedicated poison attribute.
3367OpFoldResult LLVM::PoisonOp::fold(FoldAdaptor) {
3368 return LLVM::PoisonAttr::get(getContext());
3369}
3370
3371//===----------------------------------------------------------------------===//
3372// MetadataAsValueOp.
3373//===----------------------------------------------------------------------===//
3374
3375/// Fold a metadata-as-value operation to its wrapped metadata attribute.
3376OpFoldResult LLVM::MetadataAsValueOp::fold(FoldAdaptor) {
3377 return getMetadataAttr();
3378}
3379
3380//===----------------------------------------------------------------------===//
3381// ZeroOp.
3382//===----------------------------------------------------------------------===//
3383
3384LogicalResult LLVM::ZeroOp::verify() {
3385 if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType()))
3386 if (!targetExtType.hasProperty(LLVM::LLVMTargetExtType::HasZeroInit))
3387 return emitOpError()
3388 << "target extension type does not support zero-initializer";
3389
3390 return success();
3391}
3392
3393/// Fold a zero operation to a builtin zero attribute when possible and fall
3394/// back to a dedicated zero attribute.
3395OpFoldResult LLVM::ZeroOp::fold(FoldAdaptor) {
3397 if (result)
3398 return result;
3399 return LLVM::ZeroAttr::get(getContext());
3400}
3401
3402//===----------------------------------------------------------------------===//
3403// ConstantOp.
3404//===----------------------------------------------------------------------===//
3405
3406/// Compute the total number of elements in the given type, also taking into
3407/// account nested types. Supported types are `VectorType` and `LLVMArrayType`.
3408/// Everything else is treated as a scalar.
3410 if (auto vecType = dyn_cast<VectorType>(t)) {
3411 assert(!vecType.isScalable() &&
3412 "number of elements of a scalable vector type is unknown");
3413 return vecType.getNumElements() * getNumElements(vecType.getElementType());
3414 }
3415 if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t))
3416 return arrayType.getNumElements() *
3417 getNumElements(arrayType.getElementType());
3418 return 1;
3419}
3420
3421/// Determine the element type of `type`. Supported types are `VectorType`,
3422/// `TensorType`, and `LLVMArrayType`. Everything else is treated as a scalar.
3424 while (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(type))
3425 type = arrayType.getElementType();
3426 if (auto vecType = dyn_cast<VectorType>(type))
3427 return vecType.getElementType();
3428 if (auto tenType = dyn_cast<TensorType>(type))
3429 return tenType.getElementType();
3430 return type;
3431}
3432
3433/// Check if the given type is a scalable vector type or a vector/array type
3434/// that contains a nested scalable vector type.
3436 if (auto vecType = dyn_cast<VectorType>(t)) {
3437 if (vecType.isScalable())
3438 return true;
3439 return hasScalableVectorType(vecType.getElementType());
3440 }
3441 if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t))
3442 return hasScalableVectorType(arrayType.getElementType());
3443 return false;
3444}
3445
3446/// Verifies the constant array represented by `arrayAttr` matches the provided
3447/// `arrayType`.
3448static LogicalResult verifyStructArrayConstant(LLVM::ConstantOp op,
3449 LLVM::LLVMArrayType arrayType,
3450 ArrayAttr arrayAttr, int dim) {
3451 if (arrayType.getNumElements() != arrayAttr.size())
3452 return op.emitOpError()
3453 << "array attribute size does not match array type size in "
3454 "dimension "
3455 << dim << ": " << arrayAttr.size() << " vs. "
3456 << arrayType.getNumElements();
3457
3458 llvm::DenseSet<Attribute> elementsVerified;
3459
3460 // Recursively verify sub-dimensions for multidimensional arrays.
3461 if (auto subArrayType =
3462 dyn_cast<LLVM::LLVMArrayType>(arrayType.getElementType())) {
3463 for (auto [idx, elementAttr] : llvm::enumerate(arrayAttr))
3464 if (elementsVerified.insert(elementAttr).second) {
3465 if (isa<LLVM::ZeroAttr, LLVM::UndefAttr>(elementAttr))
3466 continue;
3467 auto subArrayAttr = dyn_cast<ArrayAttr>(elementAttr);
3468 if (!subArrayAttr)
3469 return op.emitOpError()
3470 << "nested attribute for sub-array in dimension " << dim
3471 << " at index " << idx
3472 << " must be a zero, or undef, or array attribute";
3473 if (failed(verifyStructArrayConstant(op, subArrayType, subArrayAttr,
3474 dim + 1)))
3475 return failure();
3476 }
3477 return success();
3478 }
3479
3480 // Forbid usages of ArrayAttr for simple array types that should use
3481 // DenseElementsAttr instead. Note that there would be a use case for such
3482 // array types when one element value is obtained via a ptr-to-int conversion
3483 // from a symbol and cannot be represented in a DenseElementsAttr, but no MLIR
3484 // user needs this so far, and it seems better to avoid people misusing the
3485 // ArrayAttr for simple types.
3486 Type elementType = arrayType.getElementType();
3487 if (isa<LLVM::LLVMPointerType>(elementType)) {
3488 for (auto [idx, elementAttr] : llvm::enumerate(arrayAttr)) {
3489 if (isa<FlatSymbolRefAttr, LLVM::ZeroAttr, LLVM::UndefAttr,
3490 LLVM::PoisonAttr>(elementAttr))
3491 continue;
3492 return op.emitOpError()
3493 << "pointer array element at index " << idx
3494 << " must be a flat symbol reference, zero, undef, or poison";
3495 }
3496 return success();
3497 }
3498 auto structType = dyn_cast<LLVM::LLVMStructType>(elementType);
3499 if (!structType)
3500 return op.emitOpError() << "for array with an array attribute must have a "
3501 "struct element type";
3502
3503 // Shallow verification that leaf attributes are appropriate as struct initial
3504 // value.
3505 size_t numStructElements = structType.getBody().size();
3506 for (auto [idx, elementAttr] : llvm::enumerate(arrayAttr)) {
3507 if (elementsVerified.insert(elementAttr).second) {
3508 if (isa<LLVM::ZeroAttr, LLVM::UndefAttr>(elementAttr))
3509 continue;
3510 auto subArrayAttr = dyn_cast<ArrayAttr>(elementAttr);
3511 if (!subArrayAttr)
3512 return op.emitOpError()
3513 << "nested attribute for struct element at index " << idx
3514 << " must be a zero, or undef, or array attribute";
3515 if (subArrayAttr.size() != numStructElements)
3516 return op.emitOpError()
3517 << "nested array attribute size for struct element at index "
3518 << idx << " must match struct size: " << subArrayAttr.size()
3519 << " vs. " << numStructElements;
3520 }
3521 }
3522
3523 return success();
3524}
3525
3526LogicalResult LLVM::ConstantOp::verify() {
3527 if (StringAttr sAttr = llvm::dyn_cast<StringAttr>(getValue())) {
3528 auto arrayType = llvm::dyn_cast<LLVMArrayType>(getType());
3529 if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() ||
3530 !arrayType.getElementType().isInteger(8)) {
3531 return emitOpError() << "expected array type of "
3532 << sAttr.getValue().size()
3533 << " i8 elements for the string constant";
3534 }
3535 return success();
3536 }
3537 if (auto structType = dyn_cast<LLVMStructType>(getType())) {
3538 auto arrayAttr = dyn_cast<ArrayAttr>(getValue());
3539 if (!arrayAttr)
3540 return emitOpError() << "expected array attribute for struct type";
3541
3542 ArrayRef<Type> elementTypes = structType.getBody();
3543 if (arrayAttr.size() != elementTypes.size()) {
3544 return emitOpError() << "expected array attribute of size "
3545 << elementTypes.size();
3546 }
3547 for (auto [i, attr, type] : llvm::enumerate(arrayAttr, elementTypes)) {
3548 if (!type.isSignlessIntOrIndexOrFloat()) {
3549 return emitOpError() << "expected struct element types to be floating "
3550 "point type or integer type";
3551 }
3552 if (!isa<FloatAttr, IntegerAttr>(attr)) {
3553 return emitOpError() << "expected element of array attribute to be "
3554 "floating point or integer";
3555 }
3556 if (cast<TypedAttr>(attr).getType() != type)
3557 return emitOpError()
3558 << "struct element at index " << i << " is of wrong type";
3559 }
3560
3561 return success();
3562 }
3563 if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType()))
3564 return emitOpError() << "does not support target extension type.";
3565
3566 // Check that an attribute whose element type has floating point semantics
3567 // `attributeFloatSemantics` is compatible with a type whose element type
3568 // is `constantElementType`.
3569 //
3570 // Requirement is that either
3571 // 1) They have identical floating point types.
3572 // 2) `constantElementType` is an integer type of the same width as the float
3573 // attribute. This is to support builtin MLIR float types without LLVM
3574 // equivalents, see comments in getLLVMConstant for more details.
3575 auto verifyFloatSemantics =
3576 [this](const llvm::fltSemantics &attributeFloatSemantics,
3577 Type constantElementType) -> LogicalResult {
3578 if (auto floatType = dyn_cast<FloatType>(constantElementType)) {
3579 if (&floatType.getFloatSemantics() != &attributeFloatSemantics) {
3580 return emitOpError()
3581 << "attribute and type have different float semantics";
3582 }
3583 return success();
3584 }
3585 unsigned floatWidth = APFloat::getSizeInBits(attributeFloatSemantics);
3586 if (isa<IntegerType>(constantElementType)) {
3587 if (!constantElementType.isInteger(floatWidth))
3588 return emitOpError() << "expected integer type of width " << floatWidth;
3589
3590 return success();
3591 }
3592 return success();
3593 };
3594
3595 // Verification of IntegerAttr, FloatAttr, ElementsAttr, ArrayAttr.
3596 if (isa<IntegerAttr>(getValue())) {
3597 if (!llvm::isa<IntegerType>(getType()))
3598 return emitOpError() << "expected integer type";
3599 } else if (auto floatAttr = dyn_cast<FloatAttr>(getValue())) {
3600 return verifyFloatSemantics(floatAttr.getValue().getSemantics(), getType());
3601 } else if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue())) {
3603 // The exact number of elements of a scalable vector is unknown, so we
3604 // allow only splat attributes.
3605 auto splatElementsAttr = dyn_cast<SplatElementsAttr>(getValue());
3606 if (!splatElementsAttr)
3607 return emitOpError()
3608 << "scalable vector type requires a splat attribute";
3609 return success();
3610 }
3611 if (!isa<VectorType, LLVM::LLVMArrayType>(getType()))
3612 return emitOpError() << "expected vector or array type";
3613
3614 // The number of elements of the attribute and the type must match.
3615 int64_t attrNumElements = elementsAttr.getNumElements();
3616 if (getNumElements(getType()) != attrNumElements) {
3617 return emitOpError()
3618 << "type and attribute have a different number of elements: "
3619 << getNumElements(getType()) << " vs. " << attrNumElements;
3620 }
3621
3622 Type attrElmType = getElementType(elementsAttr.getType());
3623 Type resultElmType = getElementType(getType());
3624 if (auto floatType = dyn_cast<FloatType>(attrElmType))
3625 return verifyFloatSemantics(floatType.getFloatSemantics(), resultElmType);
3626
3627 if (isa<IntegerType>(attrElmType) && !isa<IntegerType>(resultElmType)) {
3628 return emitOpError(
3629 "expected integer element type for integer elements attribute");
3630 }
3631 } else if (auto arrayAttr = dyn_cast<ArrayAttr>(getValue())) {
3632
3633 // The case where the constant is LLVMStructType has already been handled.
3634 auto arrayType = dyn_cast<LLVM::LLVMArrayType>(getType());
3635 if (!arrayType)
3636 return emitOpError()
3637 << "expected array or struct type for array attribute";
3638
3639 // When the attribute is an ArrayAttr, check that its nesting matches the
3640 // corresponding ArrayType or VectorType nesting.
3641 return verifyStructArrayConstant(*this, arrayType, arrayAttr, /*dim=*/0);
3642 } else {
3643 return emitOpError()
3644 << "only supports integer, float, string or elements attributes";
3645 }
3646
3647 return success();
3648}
3649
3650bool LLVM::ConstantOp::isBuildableWith(Attribute value, Type type) {
3651 // The value's type must be the same as the provided type.
3652 auto typedAttr = dyn_cast<TypedAttr>(value);
3653 if (!typedAttr || typedAttr.getType() != type || !isCompatibleType(type))
3654 return false;
3655 // The value's type must be an LLVM compatible type.
3656 if (!isCompatibleType(type))
3657 return false;
3658 // TODO: Add support for additional attributes kinds once needed.
3659 return isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
3660}
3661
3662ConstantOp LLVM::ConstantOp::materialize(OpBuilder &builder, Attribute value,
3663 Type type, Location loc) {
3664 if (isBuildableWith(value, type))
3665 return LLVM::ConstantOp::create(builder, loc, cast<TypedAttr>(value));
3666 return nullptr;
3667}
3668
3669// Constant op constant-folds to its value.
3670OpFoldResult LLVM::ConstantOp::fold(FoldAdaptor) { return getValue(); }
3671
3672//===----------------------------------------------------------------------===//
3673// AtomicRMWOp
3674//===----------------------------------------------------------------------===//
3675
3676void AtomicRMWOp::build(OpBuilder &builder, OperationState &state,
3677 AtomicBinOp binOp, Value ptr, Value val,
3678 AtomicOrdering ordering, StringRef syncscope,
3679 unsigned alignment, bool isVolatile) {
3680 build(builder, state, val.getType(), binOp, ptr, val, ordering,
3681 !syncscope.empty() ? builder.getStringAttr(syncscope) : nullptr,
3682 alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
3683 /*access_groups=*/nullptr,
3684 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
3685}
3686
3687LogicalResult AtomicRMWOp::verify() {
3688 auto valType = getVal().getType();
3689 if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub ||
3690 getBinOp() == AtomicBinOp::fmin || getBinOp() == AtomicBinOp::fmax ||
3691 getBinOp() == AtomicBinOp::fminimum ||
3692 getBinOp() == AtomicBinOp::fmaximum ||
3693 getBinOp() == AtomicBinOp::fminimumnum ||
3694 getBinOp() == AtomicBinOp::fmaximumnum) {
3695 if (isCompatibleVectorType(valType)) {
3696 if (isScalableVectorType(valType))
3697 return emitOpError("expected LLVM IR fixed vector type");
3698 Type elemType = llvm::cast<VectorType>(valType).getElementType();
3699 if (!isCompatibleFloatingPointType(elemType))
3700 return emitOpError(
3701 "expected LLVM IR floating point type for vector element");
3702 } else if (!isCompatibleFloatingPointType(valType)) {
3703 return emitOpError("expected LLVM IR floating point type");
3704 }
3705 } else if (getBinOp() == AtomicBinOp::xchg) {
3706 DataLayout dataLayout = DataLayout::closest(*this);
3707 if (!isTypeCompatibleWithAtomicOp(valType, dataLayout))
3708 return emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
3709 } else {
3710 auto intType = llvm::dyn_cast<IntegerType>(valType);
3711 unsigned intBitWidth = intType ? intType.getWidth() : 0;
3712 if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
3713 intBitWidth != 64)
3714 return emitOpError("expected LLVM IR integer type");
3715 }
3716
3717 if (static_cast<unsigned>(getOrdering()) <
3718 static_cast<unsigned>(AtomicOrdering::monotonic))
3719 return emitOpError() << "expected at least '"
3720 << stringifyAtomicOrdering(AtomicOrdering::monotonic)
3721 << "' ordering";
3722
3723 return success();
3724}
3725
3726//===----------------------------------------------------------------------===//
3727// AtomicCmpXchgOp
3728//===----------------------------------------------------------------------===//
3729
3730/// Returns an LLVM struct type that contains a value type and a boolean type.
3731static LLVMStructType getValAndBoolStructType(Type valType) {
3732 auto boolType = IntegerType::get(valType.getContext(), 1);
3733 return LLVMStructType::getLiteral(valType.getContext(), {valType, boolType});
3734}
3735
3736void AtomicCmpXchgOp::build(OpBuilder &builder, OperationState &state,
3737 Value ptr, Value cmp, Value val,
3738 AtomicOrdering successOrdering,
3739 AtomicOrdering failureOrdering, StringRef syncscope,
3740 unsigned alignment, bool isWeak, bool isVolatile) {
3741 build(builder, state, getValAndBoolStructType(val.getType()), ptr, cmp, val,
3742 successOrdering, failureOrdering,
3743 !syncscope.empty() ? builder.getStringAttr(syncscope) : nullptr,
3744 alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isWeak,
3745 isVolatile, /*access_groups=*/nullptr,
3746 /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
3747}
3748
3749LogicalResult AtomicCmpXchgOp::verify() {
3750 auto ptrType = llvm::cast<LLVM::LLVMPointerType>(getPtr().getType());
3751 if (!ptrType)
3752 return emitOpError("expected LLVM IR pointer type for operand #0");
3753 auto valType = getVal().getType();
3754 DataLayout dataLayout = DataLayout::closest(*this);
3755 if (!isTypeCompatibleWithAtomicOp(valType, dataLayout))
3756 return emitOpError("unexpected LLVM IR type");
3757 if (getSuccessOrdering() < AtomicOrdering::monotonic ||
3758 getFailureOrdering() < AtomicOrdering::monotonic)
3759 return emitOpError("ordering must be at least 'monotonic'");
3760 if (getFailureOrdering() == AtomicOrdering::release ||
3761 getFailureOrdering() == AtomicOrdering::acq_rel)
3762 return emitOpError("failure ordering cannot be 'release' or 'acq_rel'");
3763 return success();
3764}
3765
3766//===----------------------------------------------------------------------===//
3767// FenceOp
3768//===----------------------------------------------------------------------===//
3769
3770void FenceOp::build(OpBuilder &builder, OperationState &state,
3771 AtomicOrdering ordering, StringRef syncscope) {
3772 build(builder, state, ordering,
3773 syncscope.empty() ? nullptr : builder.getStringAttr(syncscope));
3774}
3775
3776LogicalResult FenceOp::verify() {
3777 if (getOrdering() == AtomicOrdering::not_atomic ||
3778 getOrdering() == AtomicOrdering::unordered ||
3779 getOrdering() == AtomicOrdering::monotonic)
3780 return emitOpError("can be given only acquire, release, acq_rel, "
3781 "and seq_cst orderings");
3782 return success();
3783}
3784
3785//===----------------------------------------------------------------------===//
3786// Verifier for extension ops
3787//===----------------------------------------------------------------------===//
3788
3789/// Verifies that the given extension operation operates on consistent scalars
3790/// or vectors, and that the target width is larger than the input width.
3791template <class ExtOp>
3792static LogicalResult verifyExtOp(ExtOp op) {
3793 IntegerType inputType, outputType;
3794 if (isCompatibleVectorType(op.getArg().getType())) {
3795 if (!isCompatibleVectorType(op.getResult().getType()))
3796 return op.emitError(
3797 "input type is a vector but output type is an integer");
3798 if (getVectorNumElements(op.getArg().getType()) !=
3799 getVectorNumElements(op.getResult().getType()))
3800 return op.emitError("input and output vectors are of incompatible shape");
3801 // Because this is a CastOp, the element of vectors is guaranteed to be an
3802 // integer.
3803 inputType = cast<IntegerType>(
3804 cast<VectorType>(op.getArg().getType()).getElementType());
3805 outputType = cast<IntegerType>(
3806 cast<VectorType>(op.getResult().getType()).getElementType());
3807 } else {
3808 // Because this is a CastOp and arg is not a vector, arg is guaranteed to be
3809 // an integer.
3810 inputType = cast<IntegerType>(op.getArg().getType());
3811 outputType = dyn_cast<IntegerType>(op.getResult().getType());
3812 if (!outputType)
3813 return op.emitError(
3814 "input type is an integer but output type is a vector");
3815 }
3816
3817 if (outputType.getWidth() <= inputType.getWidth())
3818 return op.emitError("integer width of the output type is smaller or "
3819 "equal to the integer width of the input type");
3820 return success();
3821}
3822
3823//===----------------------------------------------------------------------===//
3824// ZExtOp
3825//===----------------------------------------------------------------------===//
3826
3827LogicalResult ZExtOp::verify() { return verifyExtOp<ZExtOp>(*this); }
3828
3829OpFoldResult LLVM::ZExtOp::fold(FoldAdaptor adaptor) {
3830 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
3831 if (!arg)
3832 return {};
3833
3834 size_t targetSize = cast<IntegerType>(getType()).getWidth();
3835 return IntegerAttr::get(getType(), arg.getValue().zext(targetSize));
3836}
3837
3838//===----------------------------------------------------------------------===//
3839// SExtOp
3840//===----------------------------------------------------------------------===//
3841
3842LogicalResult SExtOp::verify() { return verifyExtOp<SExtOp>(*this); }
3843
3844//===----------------------------------------------------------------------===//
3845// Folder and verifier for LLVM::BitcastOp
3846//===----------------------------------------------------------------------===//
3847
3848/// Folds a cast op that can be chained.
3849template <typename T>
3851 typename T::FoldAdaptor adaptor) {
3852 // cast(x : T0, T0) -> x
3853 if (castOp.getArg().getType() == castOp.getType())
3854 return castOp.getArg();
3855 if (auto prev = castOp.getArg().template getDefiningOp<T>()) {
3856 // cast(cast(x : T0, T1), T0) -> x
3857 if (prev.getArg().getType() == castOp.getType())
3858 return prev.getArg();
3859 // cast(cast(x : T0, T1), T2) -> cast(x: T0, T2)
3860 castOp.getArgMutable().set(prev.getArg());
3861 return Value{castOp};
3862 }
3863 return {};
3864}
3865
3866OpFoldResult LLVM::BitcastOp::fold(FoldAdaptor adaptor) {
3867 return foldChainableCast(*this, adaptor);
3868}
3869
3870LogicalResult LLVM::BitcastOp::verify() {
3871 Type srcElemType = extractVectorElementType(getArg().getType());
3872 Type dstElemType = extractVectorElementType(getResult().getType());
3873
3874 // TODO: 'bitcast' requires result and operand type to be identical in size.
3875 // Byte types may be cast from/to any type pointer constraints.
3876 if (isa<LLVMByteType>(srcElemType) || isa<LLVMByteType>(dstElemType))
3877 return success();
3878
3879 auto resultType = llvm::dyn_cast<LLVMPointerType>(dstElemType);
3880 auto sourceType = llvm::dyn_cast<LLVMPointerType>(srcElemType);
3881
3882 // If one of the types is a pointer (or vector of pointers), then
3883 // both source and result type have to be pointers.
3884 if (static_cast<bool>(resultType) != static_cast<bool>(sourceType))
3885 return emitOpError("can only cast pointers from and to pointers");
3886
3887 if (!resultType)
3888 return success();
3889
3890 auto isVector = llvm::IsaPred<VectorType>;
3891
3892 // Due to bitcast requiring both operands to be of the same size, it is not
3893 // possible for only one of the two to be a pointer of vectors.
3894 if (isVector(getResult().getType()) && !isVector(getArg().getType()))
3895 return emitOpError("cannot cast pointer to vector of pointers");
3896
3897 if (!isVector(getResult().getType()) && isVector(getArg().getType()))
3898 return emitOpError("cannot cast vector of pointers to pointer");
3899
3900 // Bitcast cannot cast between pointers of different address spaces.
3901 // 'llvm.addrspacecast' must be used for this purpose instead.
3902 if (resultType.getAddressSpace() != sourceType.getAddressSpace())
3903 return emitOpError("cannot cast pointers of different address spaces, "
3904 "use 'llvm.addrspacecast' instead");
3905
3906 return success();
3907}
3908
3909LogicalResult LLVM::PtrToAddrOp::verify() {
3910 auto pointerType =
3911 cast<LLVM::LLVMPointerType>(extractVectorElementType(getArg().getType()));
3912 auto integerType = cast<IntegerType>(extractVectorElementType(getType()));
3913
3914 auto dataLayout = DataLayout::closest(*this);
3915 std::optional<unsigned> width = dataLayout.getTypeIndexBitwidth(pointerType);
3916 assert(width && "pointers always return an index bitwidth");
3917 if (width != integerType.getWidth())
3918 return emitOpError("bit-width of integer result type ")
3919 << integerType << " must match the pointer bitwidth (" << *width
3920 << ") specified in the datalayout";
3921
3922 return success();
3923}
3924
3925//===----------------------------------------------------------------------===//
3926// Folder for LLVM::AddrSpaceCastOp
3927//===----------------------------------------------------------------------===//
3928
3929OpFoldResult LLVM::AddrSpaceCastOp::fold(FoldAdaptor adaptor) {
3930 return foldChainableCast(*this, adaptor);
3931}
3932
3933Value LLVM::AddrSpaceCastOp::getViewSource() { return getArg(); }
3934
3935//===----------------------------------------------------------------------===//
3936// Folder for LLVM::GEPOp
3937//===----------------------------------------------------------------------===//
3938
/// Folds `llvm.getelementptr` in two ways:
///  1. A GEP with exactly one index equal to constant zero, whose base type is
///     the result type, folds to its base pointer.
///  2. Dynamic indices whose operands folded to integer constants that fit in
///     `kGEPConstantBitWidth` signed bits are moved into the raw constant
///     indices. The op is updated in place and its own result is returned to
///     signal that change.
OpFoldResult LLVM::GEPOp::fold(FoldAdaptor adaptor) {
  // View over all indices: raw constant indices interleaved with the folded
  // attributes (possibly null) of the dynamic index operands.
  GEPIndicesAdaptor<ArrayRef<Attribute>> indices(getRawConstantIndicesAttr(),
                                                 adaptor.getDynamicIndices());

  // gep %x:T, 0 -> %x
  if (getBase().getType() == getType() && indices.size() == 1)
    if (auto integer = llvm::dyn_cast_or_null<IntegerAttr>(indices[0]))
      if (integer.getValue().isZero())
        return getBase();

  // Canonicalize any dynamic indices of constant value to constant indices.
  bool changed = false;
  SmallVector<GEPArg> gepArgs;
  for (auto iter : llvm::enumerate(indices)) {
    auto integer = llvm::dyn_cast_or_null<IntegerAttr>(iter.value());
    // Constant indices can only be int32_t, so if integer does not fit we
    // are forced to keep it dynamic, despite being a constant.
    if (!indices.isDynamicIndex(iter.index()) || !integer ||
        !integer.getValue().isSignedIntN(kGEPConstantBitWidth)) {

      // Keep this index exactly as it currently is (operand or constant).
      PointerUnion<IntegerAttr, Value> existing = getIndices()[iter.index()];
      if (Value val = llvm::dyn_cast_if_present<Value>(existing))
        gepArgs.emplace_back(val);
      else
        gepArgs.emplace_back(cast<IntegerAttr>(existing).getInt());

      continue;
    }

    // A dynamic index with a small-enough constant value becomes a constant.
    changed = true;
    gepArgs.emplace_back(integer.getInt());
  }
  if (changed) {
    // Re-split the combined index list into the raw constant indices and the
    // remaining dynamic operands, then update this op in place.
    SmallVector<int32_t> rawConstantIndices;
    SmallVector<Value> dynamicIndices;
    destructureIndices(getElemType(), gepArgs, rawConstantIndices,
                       dynamicIndices);

    getDynamicIndicesMutable().assign(dynamicIndices);
    setRawConstantIndices(rawConstantIndices);
    return Value{*this};
  }

  return {};
}
3984
3985Value LLVM::GEPOp::getViewSource() { return getBase(); }
3986
3987//===----------------------------------------------------------------------===//
3988// ShlOp
3989//===----------------------------------------------------------------------===//
3990
3991OpFoldResult LLVM::ShlOp::fold(FoldAdaptor adaptor) {
3992 auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs());
3993 if (!rhs)
3994 return {};
3995
3996 if (rhs.getValue().getZExtValue() >=
3997 getLhs().getType().getIntOrFloatBitWidth())
3998 return {}; // TODO: Fold into poison.
3999
4000 auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs());
4001 if (!lhs)
4002 return {};
4003
4004 return IntegerAttr::get(getType(), lhs.getValue().shl(rhs.getValue()));
4005}
4006
4007//===----------------------------------------------------------------------===//
4008// OrOp
4009//===----------------------------------------------------------------------===//
4010
4011OpFoldResult LLVM::OrOp::fold(FoldAdaptor adaptor) {
4012 auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs());
4013 if (!lhs)
4014 return {};
4015
4016 auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs());
4017 if (!rhs)
4018 return {};
4019
4020 return IntegerAttr::get(getType(), lhs.getValue() | rhs.getValue());
4021}
4022
4023//===----------------------------------------------------------------------===//
4024// CallIntrinsicOp
4025//===----------------------------------------------------------------------===//
4026
4027LogicalResult CallIntrinsicOp::verify() {
4028 if (!getIntrin().starts_with("llvm."))
4029 return emitOpError() << "intrinsic name must start with 'llvm.'";
4030 if (failed(verifyOperandBundles(*this)))
4031 return failure();
4032 return success();
4033}
4034
4035void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
4036 mlir::StringAttr intrin, mlir::ValueRange args) {
4037 build(builder, state, /*resultTypes=*/TypeRange{}, intrin, args,
4038 FastmathFlagsAttr{},
4039 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
4040 /*res_attrs=*/{});
4041}
4042
4043void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
4044 mlir::StringAttr intrin, mlir::ValueRange args,
4045 mlir::LLVM::FastmathFlagsAttr fastMathFlags) {
4046 build(builder, state, /*resultTypes=*/TypeRange{}, intrin, args,
4047 fastMathFlags,
4048 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
4049 /*res_attrs=*/{});
4050}
4051
4052void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
4053 mlir::Type resultType, mlir::StringAttr intrin,
4054 mlir::ValueRange args) {
4055 build(builder, state, {resultType}, intrin, args, FastmathFlagsAttr{},
4056 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
4057 /*res_attrs=*/{});
4058}
4059
4060void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
4061 mlir::TypeRange resultTypes,
4062 mlir::StringAttr intrin, mlir::ValueRange args,
4063 mlir::LLVM::FastmathFlagsAttr fastMathFlags) {
4064 build(builder, state, resultTypes, intrin, args, fastMathFlags,
4065 /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
4066 /*res_attrs=*/{});
4067}
4068
4069ParseResult CallIntrinsicOp::parse(OpAsmParser &parser,
4071 StringAttr intrinAttr;
4074 SmallVector<SmallVector<Type>> opBundleOperandTypes;
4075 ArrayAttr opBundleTags;
4076
4077 // Parse intrinsic name.
4079 intrinAttr, parser.getBuilder().getType<NoneType>()))
4080 return failure();
4081 result.addAttribute(CallIntrinsicOp::getIntrinAttrName(result.name),
4082 intrinAttr);
4083
4084 if (parser.parseLParen())
4085 return failure();
4086
4087 // Parse the function arguments.
4088 if (parser.parseOperandList(operands))
4089 return mlir::failure();
4090
4091 if (parser.parseRParen())
4092 return mlir::failure();
4093
4094 // Handle bundles.
4095 SMLoc opBundlesLoc = parser.getCurrentLocation();
4096 if (std::optional<ParseResult> result = parseOpBundles(
4097 parser, opBundleOperands, opBundleOperandTypes, opBundleTags);
4098 result && failed(*result))
4099 return failure();
4100 if (opBundleTags && !opBundleTags.empty())
4101 result.addAttribute(
4102 CallIntrinsicOp::getOpBundleTagsAttrName(result.name).getValue(),
4103 opBundleTags);
4104
4105 if (parser.parseOptionalAttrDict(result.attributes))
4106 return mlir::failure();
4107
4109 SmallVector<DictionaryAttr> resultAttrs;
4110 if (parseCallTypeAndResolveOperands(parser, result, /*isDirect=*/true,
4111 operands, argAttrs, resultAttrs))
4112 return failure();
4114 parser.getBuilder(), result, argAttrs, resultAttrs,
4115 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
4116
4117 if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
4118 opBundleOperandTypes,
4119 getOpBundleSizesAttrName(result.name)))
4120 return failure();
4121
4122 int32_t numOpBundleOperands = 0;
4123 for (const auto &operands : opBundleOperands)
4124 numOpBundleOperands += operands.size();
4125
4126 result.addAttribute(
4127 CallIntrinsicOp::getOperandSegmentSizeAttr(),
4129 {static_cast<int32_t>(operands.size()), numOpBundleOperands}));
4130
4131 return mlir::success();
4132}
4133
4134void CallIntrinsicOp::print(OpAsmPrinter &p) {
4135 p << ' ';
4136 p.printAttributeWithoutType(getIntrinAttr());
4137
4138 OperandRange args = getArgs();
4139 p << "(" << args << ")";
4140
4141 // Operand bundles.
4142 if (!getOpBundleOperands().empty()) {
4143 p << ' ';
4144 printOpBundles(p, *this, getOpBundleOperands(),
4145 getOpBundleOperands().getTypes(), getOpBundleTagsAttr());
4146 }
4147
4148 p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
4149 {getOperandSegmentSizesAttrName(),
4150 getOpBundleSizesAttrName(), getIntrinAttrName(),
4151 getOpBundleTagsAttrName(), getArgAttrsAttrName(),
4152 getResAttrsAttrName()});
4153
4154 p << " : ";
4155
4156 // Reconstruct the MLIR function type from operand and result types.
4158 p, args.getTypes(), getArgAttrsAttr(),
4159 /*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
4160}
4161
4162//===----------------------------------------------------------------------===//
4163// LinkerOptionsOp
4164//===----------------------------------------------------------------------===//
4165
4166LogicalResult LinkerOptionsOp::verify() {
4167 if (mlir::Operation *parentOp = (*this)->getParentOp();
4168 parentOp && !satisfiesLLVMModule(parentOp))
4169 return emitOpError("must appear at the module level");
4170 return success();
4171}
4172
4173//===----------------------------------------------------------------------===//
4174// ModuleFlagsOp
4175//===----------------------------------------------------------------------===//
4176
/// Verifies `llvm.module_flags`:
///  - if the op has a parent, it must satisfy satisfiesLLVMModule;
///  - every entry of `flags` must implement ModuleFlagAttrInterface and pass
///    the per-flag key/value check;
///  - keys of flags whose behavior is not `Require` must be unique (`Require`
///    flags may repeat a key).
LogicalResult ModuleFlagsOp::verify() {
  if (Operation *parentOp = (*this)->getParentOp();
      parentOp && !satisfiesLLVMModule(parentOp))
    return emitOpError("must appear at the module level");

  // Keys already used by non-`Require` flags.
  llvm::DenseSet<StringAttr> seenNonRequireKeys;
  for (Attribute flag : getFlags()) {
    auto moduleFlag = dyn_cast<ModuleFlagAttrInterface>(flag);
    if (!moduleFlag)
      return emitOpError("expected a module flag attribute");
    // Per-flag validation of key and value; diagnostics are emitted through
    // this op via the emitOpError callback.
            moduleFlag.getModuleFlagKey(), moduleFlag.getModuleFlagValue(),
            [&] { return emitOpError(); })))
      return failure();
    // `Require` flags are exempt from the uniqueness check.
    if (moduleFlag.getModuleFlagBehavior() == ModFlagBehavior::Require)
      continue;
    StringAttr key = moduleFlag.getModuleFlagKey();
    if (!seenNonRequireKeys.insert(key).second)
      return emitOpError("expected module flag key '")
             << key.getValue() << "' to be unique for non-require flags";
  }
  return success();
}
4200
4201//===----------------------------------------------------------------------===//
4202// InlineAsmOp
4203//===----------------------------------------------------------------------===//
4204
4205void InlineAsmOp::getEffects(
4207 &effects) {
4208 if (getHasSideEffects()) {
4209 effects.emplace_back(MemoryEffects::Write::get());
4210 effects.emplace_back(MemoryEffects::Read::get());
4211 }
4212}
4213
4214//===----------------------------------------------------------------------===//
4215// BlockAddressOp
4216//===----------------------------------------------------------------------===//
4217
4218LogicalResult
4219BlockAddressOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
4220 Operation *symbol = symbolTable.lookupSymbolIn(parentLLVMModule(*this),
4221 getBlockAddr().getFunction());
4222 auto function = dyn_cast_or_null<LLVMFuncOp>(symbol);
4223
4224 if (!function)
4225 return emitOpError("must reference a function defined by 'llvm.func'");
4226
4227 return success();
4228}
4229
4230LLVMFuncOp BlockAddressOp::getFunction(SymbolTableCollection &symbolTable) {
4231 return dyn_cast_or_null<LLVMFuncOp>(symbolTable.lookupSymbolIn(
4232 parentLLVMModule(*this), getBlockAddr().getFunction()));
4233}
4234
4235BlockTagOp BlockAddressOp::getBlockTagOp() {
4237 parentLLVMModule(*this), getBlockAddr().getFunction());
4238 if (!sym)
4239 return nullptr;
4240 auto funcOp = dyn_cast<LLVMFuncOp>(sym);
4241 if (!funcOp)
4242 return nullptr;
4243 BlockTagOp blockTagOp = nullptr;
4244 funcOp.walk([&](LLVM::BlockTagOp labelOp) {
4245 if (labelOp.getTag() == getBlockAddr().getTag()) {
4246 blockTagOp = labelOp;
4247 return WalkResult::interrupt();
4248 }
4249 return WalkResult::advance();
4250 });
4251 return blockTagOp;
4252}
4253
4254LogicalResult BlockAddressOp::verify() {
4255 if (!getBlockTagOp())
4256 return emitOpError(
4257 "expects an existing block label target in the referenced function");
4258
4259 return success();
4260}
4261
4262/// Fold a blockaddress operation to a dedicated blockaddress
4263/// attribute.
4264OpFoldResult BlockAddressOp::fold(FoldAdaptor) { return getBlockAddr(); }
4265
4266//===----------------------------------------------------------------------===//
4267// LLVM::IndirectBrOp
4268//===----------------------------------------------------------------------===//
4269
4270SuccessorOperands IndirectBrOp::getSuccessorOperands(unsigned index) {
4271 assert(index < getNumSuccessors() && "invalid successor index");
4272 return SuccessorOperands(getSuccOperandsMutable()[index]);
4273}
4274
4275void IndirectBrOp::build(OpBuilder &odsBuilder, OperationState &odsState,
4276 Value addr, ArrayRef<ValueRange> succOperands,
4277 BlockRange successors) {
4278 odsState.addOperands(addr);
4279 for (ValueRange range : succOperands)
4280 odsState.addOperands(range);
4281 SmallVector<int32_t> rangeSegments;
4282 for (ValueRange range : succOperands)
4283 rangeSegments.push_back(range.size());
4284 odsState.getOrAddProperties<Properties>().indbr_operand_segments =
4285 odsBuilder.getDenseI32ArrayAttr(rangeSegments);
4286 odsState.addSuccessors(successors);
4287}
4288
4290 OpAsmParser &parser, Type &flagType,
4291 SmallVectorImpl<Block *> &succOperandBlocks,
4293 SmallVectorImpl<SmallVector<Type>> &succOperandsTypes) {
4294 if (failed(parser.parseCommaSeparatedList(
4296 [&]() {
4297 Block *destination = nullptr;
4298 SmallVector<OpAsmParser::UnresolvedOperand> operands;
4299 SmallVector<Type> operandTypes;
4300
4301 if (parser.parseSuccessor(destination).failed())
4302 return failure();
4303
4304 if (succeeded(parser.parseOptionalLParen())) {
4305 if (failed(parser.parseOperandList(
4306 operands, OpAsmParser::Delimiter::None)) ||
4307 failed(parser.parseColonTypeList(operandTypes)) ||
4308 failed(parser.parseRParen()))
4309 return failure();
4310 }
4311 succOperandBlocks.push_back(destination);
4312 succOperands.emplace_back(operands);
4313 succOperandsTypes.emplace_back(operandTypes);
4314 return success();
4315 },
4316 "successor blocks")))
4317 return failure();
4318 return success();
4319}
4320
4321static void
4322printIndirectBrOpSucessors(OpAsmPrinter &p, IndirectBrOp op, Type flagType,
4323 SuccessorRange succs, OperandRangeRange succOperands,
4324 const TypeRangeRange &succOperandsTypes) {
4325 p << "[";
4326 llvm::interleave(
4327 llvm::zip(succs, succOperands),
4328 [&](auto i) {
4329 p.printNewline();
4330 p.printSuccessorAndUseList(std::get<0>(i), std::get<1>(i));
4331 },
4332 [&] { p << ','; });
4333 if (!succOperands.empty())
4334 p.printNewline();
4335 p << "]";
4336}
4337
4338//===----------------------------------------------------------------------===//
4339// SincosOp (intrinsic)
4340//===----------------------------------------------------------------------===//
4341
4342LogicalResult LLVM::SincosOp::verify() {
4343 auto operandType = getOperand().getType();
4344 auto resultType = getResult().getType();
4345 auto resultStructType =
4346 mlir::dyn_cast<mlir::LLVM::LLVMStructType>(resultType);
4347 if (!resultStructType || resultStructType.getBody().size() != 2 ||
4348 resultStructType.getBody()[0] != operandType ||
4349 resultStructType.getBody()[1] != operandType) {
4350 return emitOpError("expected result type to be an homogeneous struct with "
4351 "two elements matching the operand type, but got ")
4352 << resultType;
4353 }
4354 return success();
4355}
4356
4357//===----------------------------------------------------------------------===//
4358// AssumeOp (intrinsic)
4359//===----------------------------------------------------------------------===//
4360
4361void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
4362 mlir::Value cond) {
4363 return build(builder, state, cond, /*op_bundle_operands=*/{},
4364 /*op_bundle_tags=*/ArrayAttr{});
4365}
4366
4367void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
4368 Value cond, llvm::StringRef tag, ValueRange args) {
4369 return build(builder, state, cond, ArrayRef<ValueRange>(args),
4370 builder.getStrArrayAttr(tag));
4371}
4372
4373void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
4374 Value cond, AssumeAlignTag, Value ptr, Value align) {
4375 return build(builder, state, cond, "align", ValueRange{ptr, align});
4376}
4377
4378void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
4380 Value ptr2) {
4381 return build(builder, state, cond, "separate_storage",
4382 ValueRange{ptr1, ptr2});
4383}
4384
4385LogicalResult LLVM::AssumeOp::verify() { return verifyOperandBundles(*this); }
4386
4387//===----------------------------------------------------------------------===//
4388// masked_gather (intrinsic)
4389//===----------------------------------------------------------------------===//
4390
4391LogicalResult LLVM::masked_gather::verify() {
4392 auto ptrsVectorType = getPtrs().getType();
4393 Type expectedPtrsVectorType =
4396 // Vector of pointers type should match result vector type, other than the
4397 // element type.
4398 if (ptrsVectorType != expectedPtrsVectorType)
4399 return emitOpError("expected operand #1 type to be ")
4400 << expectedPtrsVectorType;
4401 return success();
4402}
4403
4404//===----------------------------------------------------------------------===//
4405// masked_scatter (intrinsic)
4406//===----------------------------------------------------------------------===//
4407
4408LogicalResult LLVM::masked_scatter::verify() {
4409 auto ptrsVectorType = getPtrs().getType();
4410 Type expectedPtrsVectorType =
4412 LLVM::getVectorNumElements(getValue().getType()));
4413 // Vector of pointers type should match value vector type, other than the
4414 // element type.
4415 if (ptrsVectorType != expectedPtrsVectorType)
4416 return emitOpError("expected operand #2 type to be ")
4417 << expectedPtrsVectorType;
4418 return success();
4419}
4420
4421//===----------------------------------------------------------------------===//
4422// masked_expandload (intrinsic)
4423//===----------------------------------------------------------------------===//
4424
4425void LLVM::masked_expandload::build(OpBuilder &builder, OperationState &state,
4426 mlir::TypeRange resTys, Value ptr,
4427 Value mask, Value passthru,
4428 uint64_t align) {
4429 ArrayAttr argAttrs = getLLVMAlignParamForCompressExpand(builder, true, align);
4430 build(builder, state, resTys, ptr, mask, passthru, /*arg_attrs=*/argAttrs,
4431 /*res_attrs=*/nullptr);
4432}
4433
4434//===----------------------------------------------------------------------===//
4435// masked_compressstore (intrinsic)
4436//===----------------------------------------------------------------------===//
4437
4438void LLVM::masked_compressstore::build(OpBuilder &builder,
4439 OperationState &state, Value value,
4440 Value ptr, Value mask, uint64_t align) {
4441 ArrayAttr argAttrs =
4442 getLLVMAlignParamForCompressExpand(builder, false, align);
4443 build(builder, state, value, ptr, mask, /*arg_attrs=*/argAttrs,
4444 /*res_attrs=*/nullptr);
4445}
4446
4447//===----------------------------------------------------------------------===//
4448// InlineAsmOp
4449//===----------------------------------------------------------------------===//
4450
4451LogicalResult InlineAsmOp::verify() {
4452 if (!getTailCallKindAttr())
4453 return success();
4454
4455 if (getTailCallKindAttr().getTailCallKind() == TailCallKind::MustTail)
4456 return emitOpError(
4457 "tail call kind 'musttail' is not supported by this operation");
4458
4459 return success();
4460}
4461
4462//===----------------------------------------------------------------------===//
4463// UDivOp
4464//===----------------------------------------------------------------------===//
4465Speculation::Speculatability UDivOp::getSpeculatability() {
4466 // X / 0 => UB
4467 Value divisor = getRhs();
4468 if (matchPattern(divisor, m_IntRangeWithoutZeroU()))
4470
4472}
4473
4474//===----------------------------------------------------------------------===//
4475// SDivOp
4476//===----------------------------------------------------------------------===//
4477Speculation::Speculatability SDivOp::getSpeculatability() {
4478 // This function conservatively assumes that all signed division by -1 are
4479 // not speculatable.
4480 // X / 0 => UB
4481 // INT_MIN / -1 => UB
4482 Value divisor = getRhs();
4483 if (matchPattern(divisor, m_IntRangeWithoutZeroS()) &&
4486
4488}
4489
4490//===----------------------------------------------------------------------===//
4491// LLVMDialect initialization, type parsing, and registration.
4492//===----------------------------------------------------------------------===//
4493
4494void LLVMDialect::initialize() {
4495 registerAttributes();
4496
4497 // clang-format off
4498 addTypes<LLVMVoidType,
4499 LLVMLabelType,
4500 LLVMMetadataType>();
4501 // clang-format on
4502 registerTypes();
4503
4504 addOperations<
4505#define GET_OP_LIST
4506#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
4507
4508 ,
4509#define GET_OP_LIST
4510#include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc"
4511
4512 >();
4513
4514 // Support unknown operations because not all LLVM operations are registered.
4515 allowUnknownOperations();
4516 declarePromisedInterface<DialectInlinerInterface, LLVMDialect>();
4518}
4519
4520#define GET_OP_CLASSES
4521#include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
4522
4523#define GET_OP_CLASSES
4524#include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc"
4525
4526LogicalResult LLVMDialect::verifyDataLayoutString(
4527 StringRef descr, llvm::function_ref<void(const Twine &)> reportError) {
4528 llvm::Expected<llvm::DataLayout> maybeDataLayout =
4529 llvm::DataLayout::parse(descr);
4530 if (maybeDataLayout)
4531 return success();
4532
4533 std::string message;
4534 llvm::raw_string_ostream messageStream(message);
4535 llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream);
4536 reportError("invalid data layout descriptor: " + message);
4537 return failure();
4538}
4539
4540/// Verify LLVM dialect attributes.
4541LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op,
4542 NamedAttribute attr) {
4543 // If the data layout attribute is present, it must use the LLVM data layout
4544 // syntax. Try parsing it and report errors in case of failure. Users of this
4545 // attribute may assume it is well-formed and can pass it to the (asserting)
4546 // llvm::DataLayout constructor.
4547 if (attr.getName() != LLVM::LLVMDialect::getDataLayoutAttrName())
4548 return success();
4549 if (auto stringAttr = llvm::dyn_cast<StringAttr>(attr.getValue()))
4550 return verifyDataLayoutString(
4551 stringAttr.getValue(),
4552 [op](const Twine &message) { op->emitOpError() << message.str(); });
4553
4554 return op->emitOpError() << "expected '"
4555 << LLVM::LLVMDialect::getDataLayoutAttrName()
4556 << "' to be a string attributes";
4557}
4558
4559LogicalResult LLVMDialect::verifyParameterAttribute(Operation *op,
4560 Type paramType,
4561 NamedAttribute paramAttr) {
4562 // LLVM attribute may be attached to a result of operation that has not been
4563 // converted to LLVM dialect yet, so the result may have a type with unknown
4564 // representation in LLVM dialect type space. In this case we cannot verify
4565 // whether the attribute may be
4566 bool verifyValueType = isCompatibleType(paramType);
4567 StringAttr name = paramAttr.getName();
4568
4569 auto checkUnitAttrType = [&]() -> LogicalResult {
4570 if (!llvm::isa<UnitAttr>(paramAttr.getValue()))
4571 return op->emitError() << name << " should be a unit attribute";
4572 return success();
4573 };
4574 auto checkTypeAttrType = [&]() -> LogicalResult {
4575 if (!llvm::isa<TypeAttr>(paramAttr.getValue()))
4576 return op->emitError() << name << " should be a type attribute";
4577 return success();
4578 };
4579 auto checkIntegerAttrType = [&]() -> LogicalResult {
4580 if (!llvm::isa<IntegerAttr>(paramAttr.getValue()))
4581 return op->emitError() << name << " should be an integer attribute";
4582 return success();
4583 };
4584 auto checkPointerType = [&]() -> LogicalResult {
4585 if (!llvm::isa<LLVMPointerType>(paramType))
4586 return op->emitError()
4587 << name << " attribute attached to non-pointer LLVM type";
4588 return success();
4589 };
4590 auto checkIntegerType = [&]() -> LogicalResult {
4591 if (!llvm::isa<IntegerType>(paramType))
4592 return op->emitError()
4593 << name << " attribute attached to non-integer LLVM type";
4594 return success();
4595 };
4596 auto checkPointerTypeMatches = [&]() -> LogicalResult {
4597 if (failed(checkPointerType()))
4598 return failure();
4599
4600 return success();
4601 };
4602
4603 // Check a unit attribute that is attached to a pointer value.
4604 if (name == LLVMDialect::getNoAliasAttrName() ||
4605 name == LLVMDialect::getReadonlyAttrName() ||
4606 name == LLVMDialect::getReadnoneAttrName() ||
4607 name == LLVMDialect::getWriteOnlyAttrName() ||
4608 name == LLVMDialect::getNestAttrName() ||
4609 name == LLVMDialect::getNoCaptureAttrName() ||
4610 name == LLVMDialect::getNoFreeAttrName() ||
4611 name == LLVMDialect::getNonNullAttrName()) {
4612 if (failed(checkUnitAttrType()))
4613 return failure();
4614 if (verifyValueType && failed(checkPointerType()))
4615 return failure();
4616 return success();
4617 }
4618
4619 // Check a type attribute that is attached to a pointer value.
4620 if (name == LLVMDialect::getStructRetAttrName() ||
4621 name == LLVMDialect::getByValAttrName() ||
4622 name == LLVMDialect::getByRefAttrName() ||
4623 name == LLVMDialect::getElementTypeAttrName() ||
4624 name == LLVMDialect::getInAllocaAttrName() ||
4625 name == LLVMDialect::getPreallocatedAttrName()) {
4626 if (failed(checkTypeAttrType()))
4627 return failure();
4628 if (verifyValueType && failed(checkPointerTypeMatches()))
4629 return failure();
4630 return success();
4631 }
4632
4633 // Check a unit attribute that is attached to an integer value.
4634 if (name == LLVMDialect::getSExtAttrName() ||
4635 name == LLVMDialect::getZExtAttrName()) {
4636 if (failed(checkUnitAttrType()))
4637 return failure();
4638 if (verifyValueType && failed(checkIntegerType()))
4639 return failure();
4640 return success();
4641 }
4642
4643 // Check an integer attribute that is attached to a pointer value.
4644 if (name == LLVMDialect::getAlignAttrName() ||
4645 name == LLVMDialect::getDereferenceableAttrName() ||
4646 name == LLVMDialect::getDereferenceableOrNullAttrName()) {
4647 if (failed(checkIntegerAttrType()))
4648 return failure();
4649 if (verifyValueType && failed(checkPointerType()))
4650 return failure();
4651 return success();
4652 }
4653
4654 // Check an integer attribute that is attached to a pointer value.
4655 if (name == LLVMDialect::getStackAlignmentAttrName()) {
4656 if (failed(checkIntegerAttrType()))
4657 return failure();
4658 return success();
4659 }
4660
4661 // Check a unit attribute that can be attached to arbitrary types.
4662 if (name == LLVMDialect::getNoUndefAttrName() ||
4663 name == LLVMDialect::getInRegAttrName() ||
4664 name == LLVMDialect::getReturnedAttrName())
4665 return checkUnitAttrType();
4666
4667 return success();
4668}
4669
4670/// Verify LLVMIR function argument attributes.
4671LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op,
4672 unsigned regionIdx,
4673 unsigned argIdx,
4674 NamedAttribute argAttr) {
4675 auto funcOp = dyn_cast<FunctionOpInterface>(op);
4676 if (!funcOp)
4677 return success();
4678 Type argType = funcOp.getArgumentTypes()[argIdx];
4679
4680 return verifyParameterAttribute(op, argType, argAttr);
4681}
4682
4683LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op,
4684 unsigned regionIdx,
4685 unsigned resIdx,
4686 NamedAttribute resAttr) {
4687 auto funcOp = dyn_cast<FunctionOpInterface>(op);
4688 if (!funcOp)
4689 return success();
4690 Type resType = funcOp.getResultTypes()[resIdx];
4691
4692 // Check to see if this function has a void return with a result attribute
4693 // to it. It isn't clear what semantics we would assign to that.
4694 if (llvm::isa<LLVMVoidType>(resType))
4695 return op->emitError() << "cannot attach result attributes to functions "
4696 "with a void return";
4697
4698 // Check to see if this attribute is allowed as a result attribute. Only
4699 // explicitly forbidden LLVM attributes will cause an error.
4700 auto name = resAttr.getName();
4701 if (name == LLVMDialect::getAllocAlignAttrName() ||
4702 name == LLVMDialect::getAllocatedPointerAttrName() ||
4703 name == LLVMDialect::getByValAttrName() ||
4704 name == LLVMDialect::getByRefAttrName() ||
4705 name == LLVMDialect::getInAllocaAttrName() ||
4706 name == LLVMDialect::getNestAttrName() ||
4707 name == LLVMDialect::getNoCaptureAttrName() ||
4708 name == LLVMDialect::getNoFreeAttrName() ||
4709 name == LLVMDialect::getPreallocatedAttrName() ||
4710 name == LLVMDialect::getReadnoneAttrName() ||
4711 name == LLVMDialect::getReadonlyAttrName() ||
4712 name == LLVMDialect::getReturnedAttrName() ||
4713 name == LLVMDialect::getStackAlignmentAttrName() ||
4714 name == LLVMDialect::getStructRetAttrName() ||
4715 name == LLVMDialect::getWriteOnlyAttrName())
4716 return op->emitError() << name << " is not a valid result attribute";
4717 return verifyParameterAttribute(op, resType, resAttr);
4718}
4719
4720Operation *LLVMDialect::materializeConstant(OpBuilder &builder, Attribute value,
4721 Type type, Location loc) {
4722 // If this was folded from an operation other than llvm.mlir.constant, it
4723 // should be materialized as such. Note that an llvm.mlir.zero may fold into
4724 // a builtin zero attribute and thus will materialize as a llvm.mlir.constant.
4725 if (auto symbol = dyn_cast<FlatSymbolRefAttr>(value))
4726 if (isa<LLVM::LLVMPointerType>(type))
4727 return LLVM::AddressOfOp::create(builder, loc, type, symbol);
4728 if (isa<LLVM::UndefAttr>(value))
4729 return LLVM::UndefOp::create(builder, loc, type);
4730 if (isa<LLVM::PoisonAttr>(value))
4731 return LLVM::PoisonOp::create(builder, loc, type);
4732 if (isa<LLVM::ZeroAttr>(value))
4733 return LLVM::ZeroOp::create(builder, loc, type);
4734 if (isa<LLVM::MDStringAttr, LLVM::MDConstantAttr, LLVM::MDFuncAttr,
4735 LLVM::MDNodeAttr>(value))
4736 if (isa<LLVM::LLVMMetadataType>(type))
4737 return LLVM::MetadataAsValueOp::create(builder, loc, type, value);
4738 // Otherwise try materializing it as a regular llvm.mlir.constant op.
4739 return LLVM::ConstantOp::materialize(builder, value, type, loc);
4740}
4741
4742//===----------------------------------------------------------------------===//
4743// Utility functions.
4744//===----------------------------------------------------------------------===//
4745
4747 StringRef name, StringRef value,
4748 LLVM::Linkage linkage) {
4749 assert(builder.getInsertionBlock() &&
4750 builder.getInsertionBlock()->getParentOp() &&
4751 "expected builder to point to a block constrained in an op");
4752 auto module =
4753 builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>();
4754 assert(module && "builder points to an op outside of a module");
4755
4756 // Create the global at the entry of the module.
4757 OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener());
4758 MLIRContext *ctx = builder.getContext();
4759 auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size());
4760 auto global = LLVM::GlobalOp::create(
4761 moduleBuilder, loc, type, /*isConstant=*/true, linkage, name,
4762 builder.getStringAttr(value), /*alignment=*/0);
4763
4764 LLVMPointerType ptrType = LLVMPointerType::get(ctx);
4765 // Get the pointer to the first character in the global string.
4766 Value globalPtr =
4767 LLVM::AddressOfOp::create(builder, loc, ptrType, global.getSymNameAttr());
4768 return LLVM::GEPOp::create(builder, loc, ptrType, type, globalPtr,
4769 ArrayRef<GEPArg>{0, 0});
4770}
4771
4776
4778 Operation *module = op->getParentOp();
4779 while (module && !satisfiesLLVMModule(module))
4780 module = module->getParentOp();
4781 assert(module && "unexpected operation outside of a module");
4782 return module;
4783}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
lhs
static int parseOptionalKeywordAlternative(OpAsmParser &parser, ArrayRef< StringRef > keywords)
static ArrayAttr getLLVMAlignParamForCompressExpand(OpBuilder &builder, bool isExpandLoad, uint64_t alignment=1)
static LogicalResult verifyAtomicMemOp(OpTy memOp, Type valueType, ArrayRef< AtomicOrdering > unsupportedOrderings)
Verifies the attributes and the type of atomic memory access operations.
static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser, EnumTy defaultValue)
Parse an enum from the keyword, or default to the provided default value.
static LogicalResult checkGlobalXtorData(Operation *op, ArrayAttr data)
static ParseResult parseGEPIndices(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &indices, DenseI32ArrayAttr &rawConstantIndices)
static LogicalResult verifyOperandBundles(OpType &op)
static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result)
static void printOneOpBundle(OpAsmPrinter &p, OperandRange operands, TypeRange operandTypes, StringRef tag)
static LogicalResult verifyComdat(Operation *op, std::optional< SymbolRefAttr > attr)
static LLVMFunctionType getLLVMFuncType(MLIRContext *context, TypeRange results, ValueRange args)
Constructs a LLVMFunctionType from MLIR results and args.
static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes)
static ParseResult parseSwitchOpCases(OpAsmParser &parser, Type flagType, DenseIntElementsAttr &caseValues, SmallVectorImpl< Block * > &caseDestinations, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand > > &caseOperands, SmallVectorImpl< SmallVector< Type > > &caseOperandTypes)
<cases> ::= [ (case (, case )* )?
static LogicalResult verifyCallOpVarCalleeType(OpTy callOp)
Verify that the parameter and return types of the variadic callee type match the callOp argument and ...
static ParseResult parseOptionalCallFuncPtr(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands)
Parses an optional function pointer operand before the call argument list for indirect calls,...
static bool isZeroAttribute(Attribute value)
static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp, OperandRange indices, DenseI32ArrayAttr rawConstantIndices)
static std::optional< ParseResult > parseOpBundles(OpAsmParser &p, SmallVector< SmallVector< OpAsmParser::UnresolvedOperand > > &opBundleOperands, SmallVector< SmallVector< Type > > &opBundleOperandTypes, ArrayAttr &opBundleTags)
static LLVMStructType getValAndBoolStructType(Type valType)
Returns an LLVM struct type that contains a value type and a boolean type.
static void printOpBundles(OpAsmPrinter &p, Operation *op, OperandRangeRange opBundleOperands, TypeRangeRange opBundleOperandTypes, std::optional< ArrayAttr > opBundleTags)
static void printShuffleType(AsmPrinter &printer, Operation *op, Type v1Type, Type resType, DenseI32ArrayAttr mask)
Nothing to do when the result type is inferred.
static LogicalResult verifyBlockTags(LLVMFuncOp funcOp)
static Type buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef< Type > inputs, ArrayRef< Type > outputs, function_interface_impl::VariadicFlag variadicFlag)
static auto processFMFAttr(ArrayRef< NamedAttribute > attrs)
static TypeAttr getCallOpVarCalleeType(LLVMFunctionType calleeType)
Gets the variadic callee type for a LLVMFunctionType.
static Type getInsertExtractValueElementType(function_ref< InFlightDiagnostic(StringRef)> emitError, Type containerType, ArrayRef< int64_t > position)
Extract the type at position in the LLVM IR aggregate type containerType.
static ParseResult parseOneOpBundle(OpAsmParser &p, SmallVector< SmallVector< OpAsmParser::UnresolvedOperand > > &opBundleOperands, SmallVector< SmallVector< Type > > &opBundleOperandTypes, SmallVector< Attribute > &opBundleTags)
static Type getElementType(Type type)
Determine the element type of type.
static void printIndirectBrOpSucessors(OpAsmPrinter &p, IndirectBrOp op, Type flagType, SuccessorRange succs, OperandRangeRange succOperands, const TypeRangeRange &succOperandsTypes)
static ParseResult resolveOpBundleOperands(OpAsmParser &parser, SMLoc loc, OperationState &state, ArrayRef< SmallVector< OpAsmParser::UnresolvedOperand > > opBundleOperands, ArrayRef< SmallVector< Type > > opBundleOperandTypes, StringAttr opBundleSizesAttrName)
static void printLLVMLinkage(OpAsmPrinter &p, Operation *, LinkageAttr val)
static LogicalResult verifyStructArrayConstant(LLVM::ConstantOp op, LLVM::LLVMArrayType arrayType, ArrayAttr arrayAttr, int dim)
Verifies the constant array represented by arrayAttr matches the provided arrayType.
static ParseResult parseCallTypeAndResolveOperands(OpAsmParser &parser, OperationState &result, bool isDirect, ArrayRef< OpAsmParser::UnresolvedOperand > operands, SmallVectorImpl< DictionaryAttr > &argAttrs, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses the type of a call operation and resolves the operands if the parsing succeeds.
static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol, Operation *op, SymbolTableCollection &symbolTable)
Verifies symbol's use in op to ensure the symbol is a valid and fully defined llvm....
static Type extractVectorElementType(Type type)
Returns the elemental type of any LLVM-compatible vector type or self.
static bool hasScalableVectorType(Type t)
Check if the given type is a scalable vector type or a vector/array type that contains a nested scala...
static SmallVector< Type, 1 > getCallOpResultTypes(LLVMFunctionType calleeType)
Gets the MLIR Op-like result types of a LLVMFunctionType.
static OpFoldResult foldChainableCast(T castOp, typename T::FoldAdaptor adaptor)
Folds a cast op that can be chained.
static void destructureIndices(Type currType, ArrayRef< GEPArg > indices, SmallVectorImpl< int32_t > &rawConstantIndices, SmallVectorImpl< Value > &dynamicIndices)
Destructures the 'indices' parameter into 'rawConstantIndices' and 'dynamicIndices',...
static ParseResult parseCommonGlobalAndAlias(OpAsmParser &parser, OperationState &result)
Parse common attributes that might show up in the same order in both GlobalOp and AliasOp.
static Type getI1SameShape(Type type)
Returns a boolean type that has the same shape as type.
static void printCommonGlobalAndAlias(OpAsmPrinter &p, OpType op)
static ParseResult parseLLVMLinkage(OpAsmParser &p, LinkageAttr &val)
static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value)
Returns a scalar or vector boolean attribute of the given type.
static LogicalResult verifyCallOpDebugInfo(CallOp callOp, LLVMFuncOp callee)
Verify that an inlinable callsite of a debug-info-bearing function in a debug-info-bearing function h...
static ParseResult parseShuffleType(AsmParser &parser, Type v1Type, Type &resType, DenseI32ArrayAttr mask)
Build the result type of a shuffle vector operation.
static LogicalResult verifyExtOp(ExtOp op)
Verifies that the given extension operation operates on consistent scalars or vectors,...
static constexpr const char kElemTypeAttrName[]
static ParseResult parseInsertExtractValueElementType(AsmParser &parser, Type &valueType, Type containerType, DenseI64ArrayAttr position)
Infer the value type from the container type and position.
static LogicalResult verifyStructIndices(Type baseGEPType, unsigned indexPos, GEPIndicesAdaptor< ValueRange > indices, function_ref< InFlightDiagnostic()> emitOpError)
For the given indices, check if they comply with baseGEPType, especially check against LLVMStructType...
static Attribute extractElementAt(Attribute attr, size_t index)
Extracts the element at the given index from an attribute.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static void printInsertExtractValueElementType(AsmPrinter &printer, Operation *op, Type valueType, Type containerType, DenseI64ArrayAttr position)
Nothing to print for an inferred type.
static ParseResult parseIndirectBrOpSucessors(OpAsmParser &parser, Type &flagType, SmallVectorImpl< Block * > &succOperandBlocks, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand > > &succOperands, SmallVectorImpl< SmallVector< Type > > &succOperandsTypes)
#define REGISTER_ENUM_TYPE(Ty)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
static std::string diag(const llvm::Value &value)
This base class exposes generic asm parser hooks, usable across the various derived parsers.
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
@ Square
Square brackets surrounding zero or more operands.
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalColonTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional colon followed by a type list, which if present must have at least one type.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseCustomAttributeWithFallback(Attribute &result, Type type, function_ref< ParseResult(Attribute &result, Type type)> parseAttribute)=0
Parse a custom attribute with the provided callback, unless the next token is #, in which case the ge...
ParseResult parseString(std::string *string)
Parse a quoted string token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalRSquare()=0
Parse a ] token if present.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
virtual void printString(StringRef string)
Print the given string as a quoted string, escaping any special or non-printable characters in it.
virtual void printAttribute(Attribute attr)
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
Attributes are known-constant values of operations.
Definition Attributes.h:25
MLIRContext * getContext() const
Return the context this attribute belongs to.
This class provides an abstraction over the different types of ranges over Blocks.
Block represents an ordered list of Operations.
Definition Block.h:33
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition Block.cpp:31
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
UnitAttr getUnitAttr()
Definition Builders.cpp:102
IntegerAttr getI32IntegerAttr(int32_t value)
Definition Builders.cpp:204
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerAttr getI64IntegerAttr(int64_t value)
Definition Builders.cpp:116
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:267
TypedAttr getZeroAttr(Type type)
Definition Builders.cpp:329
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition Builders.cpp:271
MLIRContext * getContext() const
Definition Builders.h:56
DictionaryAttr getDictionaryAttr(ArrayRef< NamedAttribute > value)
Definition Builders.cpp:108
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition Builders.cpp:98
ArrayAttr getStrArrayAttr(ArrayRef< StringRef > values)
Definition Builders.cpp:311
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:100
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
std::optional< uint64_t > getTypeIndexBitwidth(Type t) const
Returns the bitwidth that should be used when performing index computations for the given pointer-lik...
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
A symbol reference with a reference path containing a single element.
StringRef getValue() const
Returns the name of the held symbol reference.
StringAttr getAttr() const
Returns the name of the held symbol reference as a StringAttr.
This class represents a fused location whose metadata is known to be an instance of the given type.
Definition Location.h:149
This class represents a diagnostic that is inflight and set to be reported.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
Class used for building a 'llvm.getelementptr'.
Definition LLVMDialect.h:71
Class used for convenient access and iteration over GEP indices.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:119
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseSuccessor(Block *&dest)=0
Parse a single operation successor.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseSuccessorAndUseList(Block *&dest, SmallVectorImpl< Value > &operands)=0
Parse a single operation successor and its operand list.
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printSuccessorAndUseList(Block *successor, ValueRange succOperands)=0
Print the successor and its operands.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:435
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition Builders.h:322
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition Builders.h:444
This class represents a single result from folding an operation.
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
This class represents a contiguous range of operand ranges, e.g.
Definition ValueRange.h:85
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
type_range getTypes() const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:774
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:251
OperandRange operand_range
Definition Operation.h:396
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class implements Optional functionality for ParseResult.
ParseResult value() const
Access the internal ParseResult value.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & emplaceBlock()
Definition Region.h:46
iterator_range< OpIterator > getOps()
Definition Region.h:172
bool empty()
Definition Region.h:60
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
This class represents a specific instance of an effect.
This class models how operands are forwarded to block arguments in control flow.
This class implements the successor iterators for Block.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
virtual Operation * lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol)
Look up a symbol with the specified name within the specified symbol table operation,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction for a range of TypeRange.
Definition TypeRange.h:107
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
Definition Types.cpp:106
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition WalkResult.h:51
static WalkResult interrupt()
Definition WalkResult.h:46
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
A named class for passing around the variadic flag.
The OpAsmOpInterface, see OpAsmInterface.td for more details.
Definition CallGraph.h:227
LogicalResult verifyModuleFlagValue(StringAttr key, Attribute value, function_ref< InFlightDiagnostic()> emitError)
Verifies that a module flag value can be exported to LLVM IR.
void addBytecodeInterface(LLVMDialect *dialect)
Add the interfaces necessary for encoding the LLVM dialect components in bytecode.
Value createGlobalString(Location loc, OpBuilder &builder, StringRef name, StringRef value, Linkage linkage)
Create an LLVM global containing the string "value" at the module containing surrounding the insertio...
Operation * parentLLVMModule(Operation *op)
Lookup parent Module satisfying LLVM conditions on the Module Operation.
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
bool isScalableVectorType(Type vectorType)
Returns whether a vector type is scalable or not.
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
bool isCompatibleOuterType(Type type)
Returns true if the given outer type is compatible with the LLVM dialect without checking its potenti...
bool satisfiesLLVMModule(Operation *op)
LLVM requires some operations to be inside of a Module operation.
constexpr int kGEPConstantBitWidth
Bit-width of a 'GEPConstantIndex' within GEPArg.
Definition LLVMDialect.h:62
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
bool isTypeCompatibleWithAtomicOp(Type type, const DataLayout &dataLayout)
Returns true if the given type is supported by atomic operations.
bool isCompatibleFloatingPointType(Type type)
Returns true if the given type is a floating-point type compatible with the LLVM dialect.
llvm::ElementCount getVectorNumElements(Type type)
Returns the element count of any LLVM-compatible vector type.
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
void printFunctionSignature(OpAsmPrinter &p, TypeRange argTypes, ArrayAttr argAttrs, bool isVariadic, TypeRange resultTypes, ArrayAttr resultAttrs, Region *body=nullptr, bool printEmptyResult=true)
Print a function signature for a call or callable operation.
ParseResult parseFunctionSignature(OpAsmParser &parser, SmallVectorImpl< Type > &argTypes, SmallVectorImpl< DictionaryAttr > &argAttrs, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs, bool mustParseEmptyResult=true)
Parses a function signature using parser.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition Visitors.h:102
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition Matchers.h:471
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS()
Matches a constant scalar / vector splat / tensor splat integer or a signed integer range that does n...
Definition Matchers.h:462
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU()
Matches a constant scalar / vector splat / tensor splat integer or a unsigned integer range that does...
Definition Matchers.h:455
A callable is either a symbol, or an SSA value, that is referenced by a call-like operation.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) the properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addSuccessors(Block *successor)
Adds a successor to the operation state. successor must not be null.