MLIR  21.0.0git
LLVMDialect.cpp
Go to the documentation of this file.
1 //===- LLVMDialect.cpp - LLVM IR Ops and Dialect registration -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types and operation details for the LLVM IR dialect in
10 // MLIR, and the LLVM IR dialect. It also registers the dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 #include "TypeDetail.h"
19 #include "mlir/IR/Attributes.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/BuiltinOps.h"
22 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/MLIRContext.h"
25 #include "mlir/IR/Matchers.h"
28 
29 #include "llvm/ADT/SCCIterator.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/AsmParser/Parser.h"
32 #include "llvm/Bitcode/BitcodeReader.h"
33 #include "llvm/Bitcode/BitcodeWriter.h"
34 #include "llvm/IR/Attributes.h"
35 #include "llvm/IR/Function.h"
36 #include "llvm/IR/Type.h"
37 #include "llvm/Support/Error.h"
38 #include "llvm/Support/Mutex.h"
39 #include "llvm/Support/SourceMgr.h"
40 
41 #include <numeric>
42 #include <optional>
43 
44 using namespace mlir;
45 using namespace mlir::LLVM;
46 using mlir::LLVM::cconv::getMaxEnumValForCConv;
47 using mlir::LLVM::linkage::getMaxEnumValForLinkage;
48 using mlir::LLVM::tailcallkind::getMaxEnumValForTailCallKind;
49 
50 #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
51 
52 //===----------------------------------------------------------------------===//
53 // Attribute Helpers
54 //===----------------------------------------------------------------------===//
55 
/// Name of the attribute holding the element type of pointer-producing ops
/// (used by the alloca parser/printer).
static constexpr char kElemTypeAttrName[] = "elem_type";
57 
59  SmallVector<NamedAttribute, 8> filteredAttrs(
60  llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
61  if (attr.getName() == "fastmathFlags") {
62  auto defAttr =
63  FastmathFlagsAttr::get(attr.getValue().getContext(), {});
64  return defAttr != attr.getValue();
65  }
66  return true;
67  }));
68  return filteredAttrs;
69 }
70 
71 /// Verifies `symbol`'s use in `op` to ensure the symbol is a valid and
72 /// fully defined llvm.func.
73 static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol,
74  Operation *op,
75  SymbolTableCollection &symbolTable) {
76  StringRef name = symbol.getValue();
77  auto func =
78  symbolTable.lookupNearestSymbolFrom<LLVMFuncOp>(op, symbol.getAttr());
79  if (!func)
80  return op->emitOpError("'")
81  << name << "' does not reference a valid LLVM function";
82  if (func.isExternal())
83  return op->emitOpError("'") << name << "' does not have a definition";
84  return success();
85 }
86 
87 /// Returns a boolean type that has the same shape as `type`. It supports both
88 /// fixed size vectors as well as scalable vectors.
89 static Type getI1SameShape(Type type) {
90  Type i1Type = IntegerType::get(type.getContext(), 1);
92  return LLVM::getVectorType(i1Type, LLVM::getVectorNumElements(type));
93  return i1Type;
94 }
95 
96 // Parses one of the keywords provided in the list `keywords` and returns the
97 // position of the parsed keyword in the list. If none of the keywords from the
98 // list is parsed, returns -1.
100  ArrayRef<StringRef> keywords) {
101  for (const auto &en : llvm::enumerate(keywords)) {
102  if (succeeded(parser.parseOptionalKeyword(en.value())))
103  return en.index();
104  }
105  return -1;
106 }
107 
108 namespace {
109 template <typename Ty>
110 struct EnumTraits {};
111 
112 #define REGISTER_ENUM_TYPE(Ty) \
113  template <> \
114  struct EnumTraits<Ty> { \
115  static StringRef stringify(Ty value) { return stringify##Ty(value); } \
116  static unsigned getMaxEnumVal() { return getMaxEnumValFor##Ty(); } \
117  }
118 
119 REGISTER_ENUM_TYPE(Linkage);
120 REGISTER_ENUM_TYPE(UnnamedAddr);
121 REGISTER_ENUM_TYPE(CConv);
122 REGISTER_ENUM_TYPE(TailCallKind);
124 } // namespace
125 
126 /// Parse an enum from the keyword, or default to the provided default value.
127 /// The return type is the enum type by default, unless overridden with the
128 /// second template argument.
129 template <typename EnumTy, typename RetTy = EnumTy>
131  OperationState &result,
132  EnumTy defaultValue) {
134  for (unsigned i = 0, e = EnumTraits<EnumTy>::getMaxEnumVal(); i <= e; ++i)
135  names.push_back(EnumTraits<EnumTy>::stringify(static_cast<EnumTy>(i)));
136 
137  int index = parseOptionalKeywordAlternative(parser, names);
138  if (index == -1)
139  return static_cast<RetTy>(defaultValue);
140  return static_cast<RetTy>(index);
141 }
142 
143 //===----------------------------------------------------------------------===//
144 // Operand bundle helpers.
145 //===----------------------------------------------------------------------===//
146 
147 static void printOneOpBundle(OpAsmPrinter &p, OperandRange operands,
148  TypeRange operandTypes, StringRef tag) {
149  p.printString(tag);
150  p << "(";
151 
152  if (!operands.empty()) {
153  p.printOperands(operands);
154  p << " : ";
155  llvm::interleaveComma(operandTypes, p);
156  }
157 
158  p << ")";
159 }
160 
162  OperandRangeRange opBundleOperands,
163  TypeRangeRange opBundleOperandTypes,
164  std::optional<ArrayAttr> opBundleTags) {
165  if (opBundleOperands.empty())
166  return;
167  assert(opBundleTags && "expect operand bundle tags");
168 
169  p << "[";
170  llvm::interleaveComma(
171  llvm::zip(opBundleOperands, opBundleOperandTypes, *opBundleTags), p,
172  [&p](auto bundle) {
173  auto bundleTag = cast<StringAttr>(std::get<2>(bundle)).getValue();
174  printOneOpBundle(p, std::get<0>(bundle), std::get<1>(bundle),
175  bundleTag);
176  });
177  p << "]";
178 }
179 
180 static ParseResult parseOneOpBundle(
181  OpAsmParser &p,
183  SmallVector<SmallVector<Type>> &opBundleOperandTypes,
184  SmallVector<Attribute> &opBundleTags) {
185  SMLoc currentParserLoc = p.getCurrentLocation();
187  SmallVector<Type> types;
188  std::string tag;
189 
190  if (p.parseString(&tag))
191  return p.emitError(currentParserLoc, "expect operand bundle tag");
192 
193  if (p.parseLParen())
194  return failure();
195 
196  if (p.parseOptionalRParen()) {
197  if (p.parseOperandList(operands) || p.parseColon() ||
198  p.parseTypeList(types) || p.parseRParen())
199  return failure();
200  }
201 
202  opBundleOperands.push_back(std::move(operands));
203  opBundleOperandTypes.push_back(std::move(types));
204  opBundleTags.push_back(StringAttr::get(p.getContext(), tag));
205 
206  return success();
207 }
208 
209 static std::optional<ParseResult> parseOpBundles(
210  OpAsmParser &p,
212  SmallVector<SmallVector<Type>> &opBundleOperandTypes,
213  ArrayAttr &opBundleTags) {
214  if (p.parseOptionalLSquare())
215  return std::nullopt;
216 
217  if (succeeded(p.parseOptionalRSquare()))
218  return success();
219 
220  SmallVector<Attribute> opBundleTagAttrs;
221  auto bundleParser = [&] {
222  return parseOneOpBundle(p, opBundleOperands, opBundleOperandTypes,
223  opBundleTagAttrs);
224  };
225  if (p.parseCommaSeparatedList(bundleParser))
226  return failure();
227 
228  if (p.parseRSquare())
229  return failure();
230 
231  opBundleTags = ArrayAttr::get(p.getContext(), opBundleTagAttrs);
232 
233  return success();
234 }
235 
236 //===----------------------------------------------------------------------===//
237 // Printing, parsing, folding and builder for LLVM::CmpOp.
238 //===----------------------------------------------------------------------===//
239 
240 void ICmpOp::print(OpAsmPrinter &p) {
241  p << " \"" << stringifyICmpPredicate(getPredicate()) << "\" " << getOperand(0)
242  << ", " << getOperand(1);
243  p.printOptionalAttrDict((*this)->getAttrs(), {"predicate"});
244  p << " : " << getLhs().getType();
245 }
246 
247 void FCmpOp::print(OpAsmPrinter &p) {
248  p << " \"" << stringifyFCmpPredicate(getPredicate()) << "\" " << getOperand(0)
249  << ", " << getOperand(1);
250  p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()), {"predicate"});
251  p << " : " << getLhs().getType();
252 }
253 
254 // <operation> ::= `llvm.icmp` string-literal ssa-use `,` ssa-use
255 // attribute-dict? `:` type
256 // <operation> ::= `llvm.fcmp` string-literal ssa-use `,` ssa-use
257 // attribute-dict? `:` type
258 template <typename CmpPredicateType>
259 static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result) {
260  StringAttr predicateAttr;
262  Type type;
263  SMLoc predicateLoc, trailingTypeLoc;
264  if (parser.getCurrentLocation(&predicateLoc) ||
265  parser.parseAttribute(predicateAttr, "predicate", result.attributes) ||
266  parser.parseOperand(lhs) || parser.parseComma() ||
267  parser.parseOperand(rhs) ||
268  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
269  parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type) ||
270  parser.resolveOperand(lhs, type, result.operands) ||
271  parser.resolveOperand(rhs, type, result.operands))
272  return failure();
273 
274  // Replace the string attribute `predicate` with an integer attribute.
275  int64_t predicateValue = 0;
276  if (std::is_same<CmpPredicateType, ICmpPredicate>()) {
277  std::optional<ICmpPredicate> predicate =
278  symbolizeICmpPredicate(predicateAttr.getValue());
279  if (!predicate)
280  return parser.emitError(predicateLoc)
281  << "'" << predicateAttr.getValue()
282  << "' is an incorrect value of the 'predicate' attribute";
283  predicateValue = static_cast<int64_t>(*predicate);
284  } else {
285  std::optional<FCmpPredicate> predicate =
286  symbolizeFCmpPredicate(predicateAttr.getValue());
287  if (!predicate)
288  return parser.emitError(predicateLoc)
289  << "'" << predicateAttr.getValue()
290  << "' is an incorrect value of the 'predicate' attribute";
291  predicateValue = static_cast<int64_t>(*predicate);
292  }
293 
294  result.attributes.set("predicate",
295  parser.getBuilder().getI64IntegerAttr(predicateValue));
296 
297  // The result type is either i1 or a vector type <? x i1> if the inputs are
298  // vectors.
299  if (!isCompatibleType(type))
300  return parser.emitError(trailingTypeLoc,
301  "expected LLVM dialect-compatible type");
302  result.addTypes(getI1SameShape(type));
303  return success();
304 }
305 
306 ParseResult ICmpOp::parse(OpAsmParser &parser, OperationState &result) {
307  return parseCmpOp<ICmpPredicate>(parser, result);
308 }
309 
310 ParseResult FCmpOp::parse(OpAsmParser &parser, OperationState &result) {
311  return parseCmpOp<FCmpPredicate>(parser, result);
312 }
313 
314 /// Returns a scalar or vector boolean attribute of the given type.
315 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
316  auto boolAttr = BoolAttr::get(ctx, value);
317  ShapedType shapedType = dyn_cast<ShapedType>(type);
318  if (!shapedType)
319  return boolAttr;
320  return DenseElementsAttr::get(shapedType, boolAttr);
321 }
322 
323 OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
324  if (getPredicate() != ICmpPredicate::eq &&
325  getPredicate() != ICmpPredicate::ne)
326  return {};
327 
328  // cmpi(eq/ne, x, x) -> true/false
329  if (getLhs() == getRhs())
330  return getBoolAttribute(getType(), getContext(),
331  getPredicate() == ICmpPredicate::eq);
332 
333  // cmpi(eq/ne, alloca, null) -> false/true
334  if (getLhs().getDefiningOp<AllocaOp>() && getRhs().getDefiningOp<ZeroOp>())
335  return getBoolAttribute(getType(), getContext(),
336  getPredicate() == ICmpPredicate::ne);
337 
338  // cmpi(eq/ne, null, alloca) -> cmpi(eq/ne, alloca, null)
339  if (getLhs().getDefiningOp<ZeroOp>() && getRhs().getDefiningOp<AllocaOp>()) {
340  Value lhs = getLhs();
341  Value rhs = getRhs();
342  getLhsMutable().assign(rhs);
343  getRhsMutable().assign(lhs);
344  return getResult();
345  }
346 
347  return {};
348 }
349 
350 //===----------------------------------------------------------------------===//
351 // Printing, parsing and verification for LLVM::AllocaOp.
352 //===----------------------------------------------------------------------===//
353 
354 void AllocaOp::print(OpAsmPrinter &p) {
355  auto funcTy =
356  FunctionType::get(getContext(), {getArraySize().getType()}, {getType()});
357 
358  if (getInalloca())
359  p << " inalloca";
360 
361  p << ' ' << getArraySize() << " x " << getElemType();
362  if (getAlignment() && *getAlignment() != 0)
363  p.printOptionalAttrDict((*this)->getAttrs(),
364  {kElemTypeAttrName, getInallocaAttrName()});
365  else
367  (*this)->getAttrs(),
368  {getAlignmentAttrName(), kElemTypeAttrName, getInallocaAttrName()});
369  p << " : " << funcTy;
370 }
371 
372 // <operation> ::= `llvm.alloca` `inalloca`? ssa-use `x` type
373 // attribute-dict? `:` type `,` type
374 ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) {
376  Type type, elemType;
377  SMLoc trailingTypeLoc;
378 
379  if (succeeded(parser.parseOptionalKeyword("inalloca")))
380  result.addAttribute(getInallocaAttrName(result.name),
381  UnitAttr::get(parser.getContext()));
382 
383  if (parser.parseOperand(arraySize) || parser.parseKeyword("x") ||
384  parser.parseType(elemType) ||
385  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
386  parser.getCurrentLocation(&trailingTypeLoc) || parser.parseType(type))
387  return failure();
388 
389  std::optional<NamedAttribute> alignmentAttr =
390  result.attributes.getNamed("alignment");
391  if (alignmentAttr.has_value()) {
392  auto alignmentInt = llvm::dyn_cast<IntegerAttr>(alignmentAttr->getValue());
393  if (!alignmentInt)
394  return parser.emitError(parser.getNameLoc(),
395  "expected integer alignment");
396  if (alignmentInt.getValue().isZero())
397  result.attributes.erase("alignment");
398  }
399 
400  // Extract the result type from the trailing function type.
401  auto funcType = llvm::dyn_cast<FunctionType>(type);
402  if (!funcType || funcType.getNumInputs() != 1 ||
403  funcType.getNumResults() != 1)
404  return parser.emitError(
405  trailingTypeLoc,
406  "expected trailing function type with one argument and one result");
407 
408  if (parser.resolveOperand(arraySize, funcType.getInput(0), result.operands))
409  return failure();
410 
411  Type resultType = funcType.getResult(0);
412  if (auto ptrResultType = llvm::dyn_cast<LLVMPointerType>(resultType))
413  result.addAttribute(kElemTypeAttrName, TypeAttr::get(elemType));
414 
415  result.addTypes({funcType.getResult(0)});
416  return success();
417 }
418 
419 LogicalResult AllocaOp::verify() {
420  // Only certain target extension types can be used in 'alloca'.
421  if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getElemType());
422  targetExtType && !targetExtType.supportsMemOps())
423  return emitOpError()
424  << "this target extension type cannot be used in alloca";
425 
426  return success();
427 }
428 
429 //===----------------------------------------------------------------------===//
430 // LLVM::BrOp
431 //===----------------------------------------------------------------------===//
432 
434  assert(index == 0 && "invalid successor index");
435  return SuccessorOperands(getDestOperandsMutable());
436 }
437 
438 //===----------------------------------------------------------------------===//
439 // LLVM::CondBrOp
440 //===----------------------------------------------------------------------===//
441 
443  assert(index < getNumSuccessors() && "invalid successor index");
444  return SuccessorOperands(index == 0 ? getTrueDestOperandsMutable()
445  : getFalseDestOperandsMutable());
446 }
447 
448 void CondBrOp::build(OpBuilder &builder, OperationState &result,
449  Value condition, Block *trueDest, ValueRange trueOperands,
450  Block *falseDest, ValueRange falseOperands,
451  std::optional<std::pair<uint32_t, uint32_t>> weights) {
452  DenseI32ArrayAttr weightsAttr;
453  if (weights)
454  weightsAttr =
455  builder.getDenseI32ArrayAttr({static_cast<int32_t>(weights->first),
456  static_cast<int32_t>(weights->second)});
457 
458  build(builder, result, condition, trueOperands, falseOperands, weightsAttr,
459  /*loop_annotation=*/{}, trueDest, falseDest);
460 }
461 
462 //===----------------------------------------------------------------------===//
463 // LLVM::SwitchOp
464 //===----------------------------------------------------------------------===//
465 
466 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
467  Block *defaultDestination, ValueRange defaultOperands,
468  DenseIntElementsAttr caseValues,
469  BlockRange caseDestinations,
470  ArrayRef<ValueRange> caseOperands,
471  ArrayRef<int32_t> branchWeights) {
472  DenseI32ArrayAttr weightsAttr;
473  if (!branchWeights.empty())
474  weightsAttr = builder.getDenseI32ArrayAttr(branchWeights);
475 
476  build(builder, result, value, defaultOperands, caseOperands, caseValues,
477  weightsAttr, defaultDestination, caseDestinations);
478 }
479 
480 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
481  Block *defaultDestination, ValueRange defaultOperands,
482  ArrayRef<APInt> caseValues, BlockRange caseDestinations,
483  ArrayRef<ValueRange> caseOperands,
484  ArrayRef<int32_t> branchWeights) {
485  DenseIntElementsAttr caseValuesAttr;
486  if (!caseValues.empty()) {
487  ShapedType caseValueType = VectorType::get(
488  static_cast<int64_t>(caseValues.size()), value.getType());
489  caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
490  }
491 
492  build(builder, result, value, defaultDestination, defaultOperands,
493  caseValuesAttr, caseDestinations, caseOperands, branchWeights);
494 }
495 
496 void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
497  Block *defaultDestination, ValueRange defaultOperands,
498  ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
499  ArrayRef<ValueRange> caseOperands,
500  ArrayRef<int32_t> branchWeights) {
501  DenseIntElementsAttr caseValuesAttr;
502  if (!caseValues.empty()) {
503  ShapedType caseValueType = VectorType::get(
504  static_cast<int64_t>(caseValues.size()), value.getType());
505  caseValuesAttr = DenseIntElementsAttr::get(caseValueType, caseValues);
506  }
507 
508  build(builder, result, value, defaultDestination, defaultOperands,
509  caseValuesAttr, caseDestinations, caseOperands, branchWeights);
510 }
511 
512 /// <cases> ::= `[` (case (`,` case )* )? `]`
513 /// <case> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
514 static ParseResult parseSwitchOpCases(
515  OpAsmParser &parser, Type flagType, DenseIntElementsAttr &caseValues,
516  SmallVectorImpl<Block *> &caseDestinations,
518  SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
519  if (failed(parser.parseLSquare()))
520  return failure();
521  if (succeeded(parser.parseOptionalRSquare()))
522  return success();
523  SmallVector<APInt> values;
524  unsigned bitWidth = flagType.getIntOrFloatBitWidth();
525  auto parseCase = [&]() {
526  int64_t value = 0;
527  if (failed(parser.parseInteger(value)))
528  return failure();
529  values.push_back(APInt(bitWidth, value, /*isSigned=*/true));
530 
531  Block *destination;
533  SmallVector<Type> operandTypes;
534  if (parser.parseColon() || parser.parseSuccessor(destination))
535  return failure();
536  if (!parser.parseOptionalLParen()) {
537  if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
538  /*allowResultNumber=*/false) ||
539  parser.parseColonTypeList(operandTypes) || parser.parseRParen())
540  return failure();
541  }
542  caseDestinations.push_back(destination);
543  caseOperands.emplace_back(operands);
544  caseOperandTypes.emplace_back(operandTypes);
545  return success();
546  };
547  if (failed(parser.parseCommaSeparatedList(parseCase)))
548  return failure();
549 
550  ShapedType caseValueType =
551  VectorType::get(static_cast<int64_t>(values.size()), flagType);
552  caseValues = DenseIntElementsAttr::get(caseValueType, values);
553  return parser.parseRSquare();
554 }
555 
556 static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType,
557  DenseIntElementsAttr caseValues,
558  SuccessorRange caseDestinations,
559  OperandRangeRange caseOperands,
560  const TypeRangeRange &caseOperandTypes) {
561  p << '[';
562  p.printNewline();
563  if (!caseValues) {
564  p << ']';
565  return;
566  }
567 
568  size_t index = 0;
569  llvm::interleave(
570  llvm::zip(caseValues, caseDestinations),
571  [&](auto i) {
572  p << " ";
573  p << std::get<0>(i);
574  p << ": ";
575  p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]);
576  },
577  [&] {
578  p << ',';
579  p.printNewline();
580  });
581  p.printNewline();
582  p << ']';
583 }
584 
585 LogicalResult SwitchOp::verify() {
586  if ((!getCaseValues() && !getCaseDestinations().empty()) ||
587  (getCaseValues() &&
588  getCaseValues()->size() !=
589  static_cast<int64_t>(getCaseDestinations().size())))
590  return emitOpError("expects number of case values to match number of "
591  "case destinations");
592  if (getCaseValues() &&
593  getValue().getType() != getCaseValues()->getElementType())
594  return emitError("expects case value type to match condition value type");
595  return success();
596 }
597 
599  assert(index < getNumSuccessors() && "invalid successor index");
600  return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
601  : getCaseOperandsMutable(index - 1));
602 }
603 
604 //===----------------------------------------------------------------------===//
605 // Code for LLVM::GEPOp.
606 //===----------------------------------------------------------------------===//
607 
// Out-of-line definition of the static constexpr data member. Since C++17 such
// members are implicitly inline, so this definition is redundant but harmless;
// presumably kept for compatibility with older language modes.
constexpr int32_t GEPOp::kDynamicIndex;
609 
611  return GEPIndicesAdaptor<ValueRange>(getRawConstantIndicesAttr(),
612  getDynamicIndices());
613 }
614 
615 /// Returns the elemental type of any LLVM-compatible vector type or self.
617  if (auto vectorType = llvm::dyn_cast<VectorType>(type))
618  return vectorType.getElementType();
619  return type;
620 }
621 
622 /// Destructures the 'indices' parameter into 'rawConstantIndices' and
623 /// 'dynamicIndices', encoding the former in the process. In the process,
624 /// dynamic indices which are used to index into a structure type are converted
625 /// to constant indices when possible. To do this, the GEPs element type should
626 /// be passed as first parameter.
627 static void destructureIndices(Type currType, ArrayRef<GEPArg> indices,
628  SmallVectorImpl<int32_t> &rawConstantIndices,
629  SmallVectorImpl<Value> &dynamicIndices) {
630  for (const GEPArg &iter : indices) {
631  // If the thing we are currently indexing into is a struct we must turn
632  // any integer constants into constant indices. If this is not possible
633  // we don't do anything here. The verifier will catch it and emit a proper
634  // error. All other canonicalization is done in the fold method.
635  bool requiresConst = !rawConstantIndices.empty() &&
636  isa_and_nonnull<LLVMStructType>(currType);
637  if (Value val = llvm::dyn_cast_if_present<Value>(iter)) {
638  APInt intC;
639  if (requiresConst && matchPattern(val, m_ConstantInt(&intC)) &&
640  intC.isSignedIntN(kGEPConstantBitWidth)) {
641  rawConstantIndices.push_back(intC.getSExtValue());
642  } else {
643  rawConstantIndices.push_back(GEPOp::kDynamicIndex);
644  dynamicIndices.push_back(val);
645  }
646  } else {
647  rawConstantIndices.push_back(cast<GEPConstantIndex>(iter));
648  }
649 
650  // Skip for very first iteration of this loop. First index does not index
651  // within the aggregates, but is just a pointer offset.
652  if (rawConstantIndices.size() == 1 || !currType)
653  continue;
654 
655  currType = TypeSwitch<Type, Type>(currType)
656  .Case<VectorType, LLVMArrayType>([](auto containerType) {
657  return containerType.getElementType();
658  })
659  .Case([&](LLVMStructType structType) -> Type {
660  int64_t memberIndex = rawConstantIndices.back();
661  if (memberIndex >= 0 && static_cast<size_t>(memberIndex) <
662  structType.getBody().size())
663  return structType.getBody()[memberIndex];
664  return nullptr;
665  })
666  .Default(Type(nullptr));
667  }
668 }
669 
670 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
671  Type elementType, Value basePtr, ArrayRef<GEPArg> indices,
672  GEPNoWrapFlags noWrapFlags,
673  ArrayRef<NamedAttribute> attributes) {
674  SmallVector<int32_t> rawConstantIndices;
675  SmallVector<Value> dynamicIndices;
676  destructureIndices(elementType, indices, rawConstantIndices, dynamicIndices);
677 
678  result.addTypes(resultType);
679  result.addAttributes(attributes);
680  result.getOrAddProperties<Properties>().rawConstantIndices =
681  builder.getDenseI32ArrayAttr(rawConstantIndices);
682  result.getOrAddProperties<Properties>().noWrapFlags = noWrapFlags;
683  result.getOrAddProperties<Properties>().elem_type =
684  TypeAttr::get(elementType);
685  result.addOperands(basePtr);
686  result.addOperands(dynamicIndices);
687 }
688 
689 void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
690  Type elementType, Value basePtr, ValueRange indices,
691  GEPNoWrapFlags noWrapFlags,
692  ArrayRef<NamedAttribute> attributes) {
693  build(builder, result, resultType, elementType, basePtr,
694  SmallVector<GEPArg>(indices), noWrapFlags, attributes);
695 }
696 
697 static ParseResult
700  DenseI32ArrayAttr &rawConstantIndices) {
701  SmallVector<int32_t> constantIndices;
702 
703  auto idxParser = [&]() -> ParseResult {
704  int32_t constantIndex;
705  OptionalParseResult parsedInteger =
707  if (parsedInteger.has_value()) {
708  if (failed(parsedInteger.value()))
709  return failure();
710  constantIndices.push_back(constantIndex);
711  return success();
712  }
713 
714  constantIndices.push_back(LLVM::GEPOp::kDynamicIndex);
715  return parser.parseOperand(indices.emplace_back());
716  };
717  if (parser.parseCommaSeparatedList(idxParser))
718  return failure();
719 
720  rawConstantIndices =
721  DenseI32ArrayAttr::get(parser.getContext(), constantIndices);
722  return success();
723 }
724 
725 static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
726  OperandRange indices,
727  DenseI32ArrayAttr rawConstantIndices) {
728  llvm::interleaveComma(
729  GEPIndicesAdaptor<OperandRange>(rawConstantIndices, indices), printer,
731  if (Value val = llvm::dyn_cast_if_present<Value>(cst))
732  printer.printOperand(val);
733  else
734  printer << cast<IntegerAttr>(cst).getInt();
735  });
736 }
737 
738 /// For the given `indices`, check if they comply with `baseGEPType`,
739 /// especially check against LLVMStructTypes nested within.
740 static LogicalResult
741 verifyStructIndices(Type baseGEPType, unsigned indexPos,
743  function_ref<InFlightDiagnostic()> emitOpError) {
744  if (indexPos >= indices.size())
745  // Stop searching
746  return success();
747 
748  return TypeSwitch<Type, LogicalResult>(baseGEPType)
749  .Case<LLVMStructType>([&](LLVMStructType structType) -> LogicalResult {
750  auto attr = dyn_cast<IntegerAttr>(indices[indexPos]);
751  if (!attr)
752  return emitOpError() << "expected index " << indexPos
753  << " indexing a struct to be constant";
754 
755  int32_t gepIndex = attr.getInt();
756  ArrayRef<Type> elementTypes = structType.getBody();
757  if (gepIndex < 0 ||
758  static_cast<size_t>(gepIndex) >= elementTypes.size())
759  return emitOpError() << "index " << indexPos
760  << " indexing a struct is out of bounds";
761 
762  // Instead of recursively going into every children types, we only
763  // dive into the one indexed by gepIndex.
764  return verifyStructIndices(elementTypes[gepIndex], indexPos + 1,
765  indices, emitOpError);
766  })
767  .Case<VectorType, LLVMArrayType>(
768  [&](auto containerType) -> LogicalResult {
769  return verifyStructIndices(containerType.getElementType(),
770  indexPos + 1, indices, emitOpError);
771  })
772  .Default([&](auto otherType) -> LogicalResult {
773  return emitOpError()
774  << "type " << otherType << " cannot be indexed (index #"
775  << indexPos << ")";
776  });
777 }
778 
779 /// Driver function around `verifyStructIndices`.
780 static LogicalResult
782  function_ref<InFlightDiagnostic()> emitOpError) {
783  return verifyStructIndices(baseGEPType, /*indexPos=*/1, indices, emitOpError);
784 }
785 
786 LogicalResult LLVM::GEPOp::verify() {
787  if (static_cast<size_t>(
788  llvm::count(getRawConstantIndices(), kDynamicIndex)) !=
789  getDynamicIndices().size())
790  return emitOpError("expected as many dynamic indices as specified in '")
791  << getRawConstantIndicesAttrName().getValue() << "'";
792 
793  if (getNoWrapFlags() == GEPNoWrapFlags::inboundsFlag)
794  return emitOpError("'inbounds_flag' cannot be used directly.");
795 
796  return verifyStructIndices(getElemType(), getIndices(),
797  [&] { return emitOpError(); });
798 }
799 
800 //===----------------------------------------------------------------------===//
801 // LoadOp
802 //===----------------------------------------------------------------------===//
803 
804 void LoadOp::getEffects(
806  &effects) {
807  effects.emplace_back(MemoryEffects::Read::get(), &getAddrMutable());
808  // Volatile operations can have target-specific read-write effects on
809  // memory besides the one referred to by the pointer operand.
810  // Similarly, atomic operations that are monotonic or stricter cause
811  // synchronization that from a language point-of-view, are arbitrary
812  // read-writes into memory.
813  if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic &&
814  getOrdering() != AtomicOrdering::unordered)) {
815  effects.emplace_back(MemoryEffects::Write::get());
816  effects.emplace_back(MemoryEffects::Read::get());
817  }
818 }
819 
820 /// Returns true if the given type is supported by atomic operations. All
821 /// integer, float, and pointer types with a power-of-two bitsize and a minimal
822 /// size of 8 bits are supported.
824  const DataLayout &dataLayout) {
825  if (!isa<IntegerType, LLVMPointerType>(type))
827  return false;
828 
829  llvm::TypeSize bitWidth = dataLayout.getTypeSizeInBits(type);
830  if (bitWidth.isScalable())
831  return false;
832  // Needs to be at least 8 bits and a power of two.
833  return bitWidth >= 8 && (bitWidth & (bitWidth - 1)) == 0;
834 }
835 
836 /// Verifies the attributes and the type of atomic memory access operations.
837 template <typename OpTy>
838 LogicalResult verifyAtomicMemOp(OpTy memOp, Type valueType,
839  ArrayRef<AtomicOrdering> unsupportedOrderings) {
840  if (memOp.getOrdering() != AtomicOrdering::not_atomic) {
841  DataLayout dataLayout = DataLayout::closest(memOp);
842  if (!isTypeCompatibleWithAtomicOp(valueType, dataLayout))
843  return memOp.emitOpError("unsupported type ")
844  << valueType << " for atomic access";
845  if (llvm::is_contained(unsupportedOrderings, memOp.getOrdering()))
846  return memOp.emitOpError("unsupported ordering '")
847  << stringifyAtomicOrdering(memOp.getOrdering()) << "'";
848  if (!memOp.getAlignment())
849  return memOp.emitOpError("expected alignment for atomic access");
850  return success();
851  }
852  if (memOp.getSyncscope())
853  return memOp.emitOpError(
854  "expected syncscope to be null for non-atomic access");
855  return success();
856 }
857 
858 LogicalResult LoadOp::verify() {
859  Type valueType = getResult().getType();
860  return verifyAtomicMemOp(*this, valueType,
861  {AtomicOrdering::release, AtomicOrdering::acq_rel});
862 }
863 
864 void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
865  Value addr, unsigned alignment, bool isVolatile,
866  bool isNonTemporal, bool isInvariant, bool isInvariantGroup,
867  AtomicOrdering ordering, StringRef syncscope) {
868  build(builder, state, type, addr,
869  alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
870  isNonTemporal, isInvariant, isInvariantGroup, ordering,
871  syncscope.empty() ? nullptr : builder.getStringAttr(syncscope),
872  /*dereferenceable=*/nullptr,
873  /*access_groups=*/nullptr,
874  /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr,
875  /*tbaa=*/nullptr);
876 }
877 
878 //===----------------------------------------------------------------------===//
879 // StoreOp
880 //===----------------------------------------------------------------------===//
881 
/// Reports the memory effects of this store. The store always writes through
/// its address operand. Volatile stores, and atomic stores whose ordering is
/// neither `not_atomic` nor `unordered`, additionally report read and write
/// effects that are not tied to any specific value.
882 void StoreOp::getEffects(
884  &effects) {
885  effects.emplace_back(MemoryEffects::Write::get(), &getAddrMutable());
886  // Volatile operations can have target-specific read-write effects on
887  // memory besides the one referred to by the pointer operand.
888  // Similarly, atomic operations that are monotonic or stricter cause
889  // synchronization that, from a language point of view, amounts to arbitrary
890  // reads and writes of memory.
891  if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic &&
892  getOrdering() != AtomicOrdering::unordered)) {
893  effects.emplace_back(MemoryEffects::Write::get());
894  effects.emplace_back(MemoryEffects::Read::get());
895  }
896 }
897 
898 LogicalResult StoreOp::verify() {
899  Type valueType = getValue().getType();
900  return verifyAtomicMemOp(*this, valueType,
901  {AtomicOrdering::acquire, AtomicOrdering::acq_rel});
902 }
903 
904 void StoreOp::build(OpBuilder &builder, OperationState &state, Value value,
905  Value addr, unsigned alignment, bool isVolatile,
906  bool isNonTemporal, bool isInvariantGroup,
907  AtomicOrdering ordering, StringRef syncscope) {
908  build(builder, state, value, addr,
909  alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
910  isNonTemporal, isInvariantGroup, ordering,
911  syncscope.empty() ? nullptr : builder.getStringAttr(syncscope),
912  /*access_groups=*/nullptr,
913  /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
914 }
915 
916 //===----------------------------------------------------------------------===//
917 // CallOp
918 //===----------------------------------------------------------------------===//
919 
920 /// Gets the MLIR Op-like result types of a LLVMFunctionType.
921 static SmallVector<Type, 1> getCallOpResultTypes(LLVMFunctionType calleeType) {
922  SmallVector<Type, 1> results;
923  Type resultType = calleeType.getReturnType();
924  if (!isa<LLVM::LLVMVoidType>(resultType))
925  results.push_back(resultType);
926  return results;
927 }
928 
929 /// Gets the variadic callee type for a LLVMFunctionType.
930 static TypeAttr getCallOpVarCalleeType(LLVMFunctionType calleeType) {
931  return calleeType.isVarArg() ? TypeAttr::get(calleeType) : nullptr;
932 }
933 
934 /// Constructs a LLVMFunctionType from MLIR `results` and `args`.
935 static LLVMFunctionType getLLVMFuncType(MLIRContext *context, TypeRange results,
936  ValueRange args) {
937  Type resultType;
938  if (results.empty())
939  resultType = LLVMVoidType::get(context);
940  else
941  resultType = results.front();
942  return LLVMFunctionType::get(resultType, llvm::to_vector(args.getTypes()),
943  /*isVarArg=*/false);
944 }
945 
946 void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
947  StringRef callee, ValueRange args) {
948  build(builder, state, results, builder.getStringAttr(callee), args);
949 }
950 
951 void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
952  StringAttr callee, ValueRange args) {
953  build(builder, state, results, SymbolRefAttr::get(callee), args);
954 }
955 
956 void CallOp::build(OpBuilder &builder, OperationState &state, TypeRange results,
957  FlatSymbolRefAttr callee, ValueRange args) {
958  assert(callee && "expected non-null callee in direct call builder");
959  build(builder, state, results,
960  /*var_callee_type=*/nullptr, callee, args, /*fastmathFlags=*/nullptr,
961  /*CConv=*/nullptr, /*TailCallKind=*/nullptr,
962  /*memory_effects=*/nullptr,
963  /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
964  /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
965  /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
966  /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
967  /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
968  /*no_inline=*/nullptr, /*always_inline=*/nullptr,
969  /*inline_hint=*/nullptr);
970 }
971 
972 void CallOp::build(OpBuilder &builder, OperationState &state,
973  LLVMFunctionType calleeType, StringRef callee,
974  ValueRange args) {
975  build(builder, state, calleeType, builder.getStringAttr(callee), args);
976 }
977 
978 void CallOp::build(OpBuilder &builder, OperationState &state,
979  LLVMFunctionType calleeType, StringAttr callee,
980  ValueRange args) {
981  build(builder, state, calleeType, SymbolRefAttr::get(callee), args);
982 }
983 
984 void CallOp::build(OpBuilder &builder, OperationState &state,
985  LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
986  ValueRange args) {
987  build(builder, state, getCallOpResultTypes(calleeType),
988  getCallOpVarCalleeType(calleeType), callee, args,
989  /*fastmathFlags=*/nullptr,
990  /*CConv=*/nullptr,
991  /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
992  /*convergent=*/nullptr,
993  /*no_unwind=*/nullptr, /*will_return=*/nullptr,
994  /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
995  /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
996  /*access_groups=*/nullptr,
997  /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
998  /*no_inline=*/nullptr, /*always_inline=*/nullptr,
999  /*inline_hint=*/nullptr);
1000 }
1001 
1002 void CallOp::build(OpBuilder &builder, OperationState &state,
1003  LLVMFunctionType calleeType, ValueRange args) {
1004  build(builder, state, getCallOpResultTypes(calleeType),
1005  getCallOpVarCalleeType(calleeType),
1006  /*callee=*/nullptr, args,
1007  /*fastmathFlags=*/nullptr,
1008  /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
1009  /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
1010  /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
1011  /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
1012  /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
1013  /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
1014  /*no_inline=*/nullptr, /*always_inline=*/nullptr,
1015  /*inline_hint=*/nullptr);
1016 }
1017 
1018 void CallOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
1019  ValueRange args) {
1020  auto calleeType = func.getFunctionType();
1021  build(builder, state, getCallOpResultTypes(calleeType),
1022  getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), args,
1023  /*fastmathFlags=*/nullptr,
1024  /*CConv=*/nullptr, /*TailCallKind=*/nullptr, /*memory_effects=*/nullptr,
1025  /*convergent=*/nullptr, /*no_unwind=*/nullptr, /*will_return=*/nullptr,
1026  /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{},
1027  /*access_groups=*/nullptr, /*alias_scopes=*/nullptr,
1028  /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
1029  /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr,
1030  /*no_inline=*/nullptr, /*always_inline=*/nullptr,
1031  /*inline_hint=*/nullptr);
1032 }
1033 
1034 CallInterfaceCallable CallOp::getCallableForCallee() {
1035  // Direct call.
1036  if (FlatSymbolRefAttr calleeAttr = getCalleeAttr())
1037  return calleeAttr;
1038  // Indirect call, callee Value is the first operand.
1039  return getOperand(0);
1040 }
1041 
1042 void CallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1043  // Direct call.
1044  if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) {
1045  auto symRef = cast<SymbolRefAttr>(callee);
1046  return setCalleeAttr(cast<FlatSymbolRefAttr>(symRef));
1047  }
1048  // Indirect call, callee Value is the first operand.
1049  return setOperand(0, cast<Value>(callee));
1050 }
1051 
1052 Operation::operand_range CallOp::getArgOperands() {
1053  return getCalleeOperands().drop_front(getCallee().has_value() ? 0 : 1);
1054 }
1055 
1056 MutableOperandRange CallOp::getArgOperandsMutable() {
1057  return MutableOperandRange(*this, getCallee().has_value() ? 0 : 1,
1058  getCalleeOperands().size());
1059 }
1060 
1061 /// Verify that an inlinable callsite of a debug-info-bearing function in a
1062 /// debug-info-bearing function has a debug location attached to it. This
1063 /// mirrors an LLVM IR verifier.
1064 static LogicalResult verifyCallOpDebugInfo(CallOp callOp, LLVMFuncOp callee) {
1065  if (callee.isExternal())
1066  return success();
1067  auto parentFunc = callOp->getParentOfType<FunctionOpInterface>();
1068  if (!parentFunc)
1069  return success();
1070 
1071  auto hasSubprogram = [](Operation *op) {
1072  return op->getLoc()
1073  ->findInstanceOf<FusedLocWith<LLVM::DISubprogramAttr>>() !=
1074  nullptr;
1075  };
1076  if (!hasSubprogram(parentFunc) || !hasSubprogram(callee))
1077  return success();
1078  bool containsLoc = !isa<UnknownLoc>(callOp->getLoc());
1079  if (!containsLoc)
1080  return callOp.emitError()
1081  << "inlinable function call in a function with a DISubprogram "
1082  "location must have a debug location";
1083  return success();
1084 }
1085 
1086 /// Verify that the parameter and return types of the variadic callee type match
1087 /// the `callOp` argument and result types.
1088 template <typename OpTy>
1089 LogicalResult verifyCallOpVarCalleeType(OpTy callOp) {
1090  std::optional<LLVMFunctionType> varCalleeType = callOp.getVarCalleeType();
1091  if (!varCalleeType)
1092  return success();
1093 
1094  // Verify the variadic callee type is a variadic function type.
1095  if (!varCalleeType->isVarArg())
1096  return callOp.emitOpError(
1097  "expected var_callee_type to be a variadic function type");
1098 
1099  // Verify the variadic callee type has at most as many parameters as the call
1100  // has argument operands.
1101  if (varCalleeType->getNumParams() > callOp.getArgOperands().size())
1102  return callOp.emitOpError("expected var_callee_type to have at most ")
1103  << callOp.getArgOperands().size() << " parameters";
1104 
1105  // Verify the variadic callee type matches the call argument types.
1106  for (auto [paramType, operand] :
1107  llvm::zip(varCalleeType->getParams(), callOp.getArgOperands()))
1108  if (paramType != operand.getType())
1109  return callOp.emitOpError()
1110  << "var_callee_type parameter type mismatch: " << paramType
1111  << " != " << operand.getType();
1112 
1113  // Verify the variadic callee type matches the call result type.
1114  if (!callOp.getNumResults()) {
1115  if (!isa<LLVMVoidType>(varCalleeType->getReturnType()))
1116  return callOp.emitOpError("expected var_callee_type to return void");
1117  } else {
1118  if (callOp.getResult().getType() != varCalleeType->getReturnType())
1119  return callOp.emitOpError("var_callee_type return type mismatch: ")
1120  << varCalleeType->getReturnType()
1121  << " != " << callOp.getResult().getType();
1122  }
1123  return success();
1124 }
1125 
1126 template <typename OpType>
1127 static LogicalResult verifyOperandBundles(OpType &op) {
1128  OperandRangeRange opBundleOperands = op.getOpBundleOperands();
1129  std::optional<ArrayAttr> opBundleTags = op.getOpBundleTags();
1130 
1131  auto isStringAttr = [](Attribute tagAttr) {
1132  return isa<StringAttr>(tagAttr);
1133  };
1134  if (opBundleTags && !llvm::all_of(*opBundleTags, isStringAttr))
1135  return op.emitError("operand bundle tag must be a StringAttr");
1136 
1137  size_t numOpBundles = opBundleOperands.size();
1138  size_t numOpBundleTags = opBundleTags ? opBundleTags->size() : 0;
1139  if (numOpBundles != numOpBundleTags)
1140  return op.emitError("expected ")
1141  << numOpBundles << " operand bundle tags, but actually got "
1142  << numOpBundleTags;
1143 
1144  return success();
1145 }
1146 
1147 LogicalResult CallOp::verify() { return verifyOperandBundles(*this); }
1148 
1149 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1150  if (failed(verifyCallOpVarCalleeType(*this)))
1151  return failure();
1152 
1153  // Type for the callee, we'll get it differently depending if it is a direct
1154  // or indirect call.
1155  Type fnType;
1156 
1157  bool isIndirect = false;
1158 
1159  // If this is an indirect call, the callee attribute is missing.
1160  FlatSymbolRefAttr calleeName = getCalleeAttr();
1161  if (!calleeName) {
1162  isIndirect = true;
1163  if (!getNumOperands())
1164  return emitOpError(
1165  "must have either a `callee` attribute or at least an operand");
1166  auto ptrType = llvm::dyn_cast<LLVMPointerType>(getOperand(0).getType());
1167  if (!ptrType)
1168  return emitOpError("indirect call expects a pointer as callee: ")
1169  << getOperand(0).getType();
1170 
1171  return success();
1172  } else {
1173  Operation *callee =
1174  symbolTable.lookupNearestSymbolFrom(*this, calleeName.getAttr());
1175  if (!callee)
1176  return emitOpError()
1177  << "'" << calleeName.getValue()
1178  << "' does not reference a symbol in the current scope";
1179  auto fn = dyn_cast<LLVMFuncOp>(callee);
1180  if (!fn)
1181  return emitOpError() << "'" << calleeName.getValue()
1182  << "' does not reference a valid LLVM function";
1183 
1184  if (failed(verifyCallOpDebugInfo(*this, fn)))
1185  return failure();
1186  fnType = fn.getFunctionType();
1187  }
1188 
1189  LLVMFunctionType funcType = llvm::dyn_cast<LLVMFunctionType>(fnType);
1190  if (!funcType)
1191  return emitOpError("callee does not have a functional type: ") << fnType;
1192 
1193  if (funcType.isVarArg() && !getVarCalleeType())
1194  return emitOpError() << "missing var_callee_type attribute for vararg call";
1195 
1196  // Verify that the operand and result types match the callee.
1197 
1198  if (!funcType.isVarArg() &&
1199  funcType.getNumParams() != (getCalleeOperands().size() - isIndirect))
1200  return emitOpError() << "incorrect number of operands ("
1201  << (getCalleeOperands().size() - isIndirect)
1202  << ") for callee (expecting: "
1203  << funcType.getNumParams() << ")";
1204 
1205  if (funcType.getNumParams() > (getCalleeOperands().size() - isIndirect))
1206  return emitOpError() << "incorrect number of operands ("
1207  << (getCalleeOperands().size() - isIndirect)
1208  << ") for varargs callee (expecting at least: "
1209  << funcType.getNumParams() << ")";
1210 
1211  for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i)
1212  if (getOperand(i + isIndirect).getType() != funcType.getParamType(i))
1213  return emitOpError() << "operand type mismatch for operand " << i << ": "
1214  << getOperand(i + isIndirect).getType()
1215  << " != " << funcType.getParamType(i);
1216 
1217  if (getNumResults() == 0 &&
1218  !llvm::isa<LLVM::LLVMVoidType>(funcType.getReturnType()))
1219  return emitOpError() << "expected function call to produce a value";
1220 
1221  if (getNumResults() != 0 &&
1222  llvm::isa<LLVM::LLVMVoidType>(funcType.getReturnType()))
1223  return emitOpError()
1224  << "calling function with void result must not produce values";
1225 
1226  if (getNumResults() > 1)
1227  return emitOpError()
1228  << "expected LLVM function call to produce 0 or 1 result";
1229 
1230  if (getNumResults() && getResult().getType() != funcType.getReturnType())
1231  return emitOpError() << "result type mismatch: " << getResult().getType()
1232  << " != " << funcType.getReturnType();
1233 
1234  return success();
1235 }
1236 
/// Prints the custom assembly form of `llvm.call`, in this order:
///   - the calling convention and tail-call kind, each only when not the
///     default;
///   - the direct callee symbol or the indirect callee operand;
///   - the argument list and an optional `vararg(...)` clause;
///   - any operand bundles and the remaining attributes;
///   - the trailing types.
1237 void CallOp::print(OpAsmPrinter &p) {
1238  auto callee = getCallee();
1239  bool isDirect = callee.has_value();
1240 
1241  p << ' ';
1242 
1243  // Print calling convention.
1244  if (getCConv() != LLVM::CConv::C)
1245  p << stringifyCConv(getCConv()) << ' ';
1246 
  // Print the tail call kind unless it is `None`.
1247  if (getTailCallKind() != LLVM::TailCallKind::None)
1248  p << tailcallkind::stringifyTailCallKind(getTailCallKind()) << ' ';
1249 
1250  // Print the direct callee if present as a function attribute, or an indirect
1251  // callee (first operand) otherwise.
1252  if (isDirect)
1253  p.printSymbolName(callee.value());
1254  else
1255  p << getOperand(0);
1256 
  // The call arguments exclude the function pointer of an indirect call.
1257  auto args = getCalleeOperands().drop_front(isDirect ? 0 : 1);
1258  p << '(' << args << ')';
1259 
1260  // Print the variadic callee type if the call is variadic.
1261  if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1262  p << " vararg(" << *varCalleeType << ")";
1263 
1264  if (!getOpBundleOperands().empty()) {
1265  p << " ";
1266  printOpBundles(p, *this, getOpBundleOperands(),
1267  getOpBundleOperands().getTypes(), getOpBundleTags());
1268  }
1269 
  // Print the remaining attributes (after `processFMFAttr`), eliding the
  // ones listed here.
1270  p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
1271  {getCalleeAttrName(), getTailCallKindAttrName(),
1272  getVarCalleeTypeAttrName(), getCConvAttrName(),
1273  getOperandSegmentSizesAttrName(),
1274  getOpBundleSizesAttrName(),
1275  getOpBundleTagsAttrName(), getArgAttrsAttrName(),
1276  getResAttrsAttrName()});
1277 
  // Indirect calls print the callee pointer type before the function type.
1278  p << " : ";
1279  if (!isDirect)
1280  p << getOperand(0).getType() << ", ";
1281 
1282  // Reconstruct the MLIR function type from operand and result types.
1284  p, args.getTypes(), getArgAttrsAttr(),
1285  /*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
1286 }
1287 
1288 /// Parses the type of a call operation and resolves the operands if the parsing
1289 /// succeeds. Returns failure otherwise.
///
/// The expected form is `:`, then (for indirect calls only) the callee
/// pointer type and a comma, then a function signature. The signature's
/// argument and result attributes are returned through `argAttrs` and
/// `resultAttrs`. At most one result type is accepted; it must not be void
/// and is added to `result`'s types. The resolved operands are appended to
/// `result.operands`.
1291  OpAsmParser &parser, OperationState &result, bool isDirect,
1294  SmallVectorImpl<DictionaryAttr> &resultAttrs) {
1295  SMLoc trailingTypesLoc = parser.getCurrentLocation();
1296  SmallVector<Type> types;
1297  if (parser.parseColon())
1298  return failure();
  // Indirect calls start with the type of the callee pointer operand.
1299  if (!isDirect) {
1300  types.emplace_back();
1301  if (parser.parseType(types.back()))
1302  return failure();
1303  if (parser.parseOptionalComma())
1304  return parser.emitError(
1305  trailingTypesLoc, "expected indirect call to have 2 trailing types");
1306  }
1307  SmallVector<Type> argTypes;
1308  SmallVector<Type> resTypes;
1309  if (call_interface_impl::parseFunctionSignature(parser, argTypes, argAttrs,
1310  resTypes, resultAttrs)) {
1311  if (isDirect)
1312  return parser.emitError(trailingTypesLoc,
1313  "expected direct call to have 1 trailing types");
1314  return parser.emitError(trailingTypesLoc,
1315  "expected trailing function type");
1316  }
1317 
1318  if (resTypes.size() > 1)
1319  return parser.emitError(trailingTypesLoc,
1320  "expected function with 0 or 1 result");
1321  if (resTypes.size() == 1 && llvm::isa<LLVM::LLVMVoidType>(resTypes[0]))
1322  return parser.emitError(trailingTypesLoc,
1323  "expected a non-void result type");
1324 
1325  // The head element of the types list matches the callee type for
1326  // indirect calls, while the types list is empty for direct calls.
1327  // Append the function input types to resolve the call operation
1328  // operands.
1329  llvm::append_range(types, argTypes);
1330  if (parser.resolveOperands(operands, types, parser.getNameLoc(),
1331  result.operands))
1332  return failure();
1333  if (resTypes.size() != 0)
1334  result.addTypes(resTypes);
1335 
1336  return success();
1337 }
1338 
1339 /// Parses an optional function pointer operand before the call argument list
1340 /// for indirect calls, or stops parsing at the function identifier otherwise.
///
/// When an SSA operand is present, it is appended to `operands`. When none
/// is present, `operands` is left unchanged and success is returned. A
/// present but malformed operand propagates the parse failure.
1341 static ParseResult parseOptionalCallFuncPtr(
1342  OpAsmParser &parser,
1344  OpAsmParser::UnresolvedOperand funcPtrOperand;
1345  OptionalParseResult parseResult = parser.parseOptionalOperand(funcPtrOperand);
  // An empty optional means no operand was written: not an error.
1346  if (parseResult.has_value()) {
1347  if (failed(*parseResult))
1348  return *parseResult;
1349  operands.push_back(funcPtrOperand);
1350  }
1351  return success();
1352 }
1353 
1354 static ParseResult resolveOpBundleOperands(
1355  OpAsmParser &parser, SMLoc loc, OperationState &state,
1357  ArrayRef<SmallVector<Type>> opBundleOperandTypes,
1358  StringAttr opBundleSizesAttrName) {
1359  unsigned opBundleIndex = 0;
1360  for (const auto &[operands, types] :
1361  llvm::zip_equal(opBundleOperands, opBundleOperandTypes)) {
1362  if (operands.size() != types.size())
1363  return parser.emitError(loc, "expected ")
1364  << operands.size()
1365  << " types for operand bundle operands for operand bundle #"
1366  << opBundleIndex << ", but actually got " << types.size();
1367  if (parser.resolveOperands(operands, types, loc, state.operands))
1368  return failure();
1369  }
1370 
1371  SmallVector<int32_t> opBundleSizes;
1372  opBundleSizes.reserve(opBundleOperands.size());
1373  for (const auto &operands : opBundleOperands)
1374  opBundleSizes.push_back(operands.size());
1375 
1376  state.addAttribute(
1377  opBundleSizesAttrName,
1378  DenseI32ArrayAttr::get(parser.getContext(), opBundleSizes));
1379 
1380  return success();
1381 }
1382 
1383 // <operation> ::= `llvm.call` (cconv)? (tailcallkind)? (function-id | ssa-use)
1384 // `(` ssa-use-list `)`
1385 // ( `vararg(` var-callee-type `)` )?
1386 // ( `[` op-bundles-list `]` )?
1387 // attribute-dict? `:` (type `,`)? function-type
//
// The callee operands are collected in `operands`: first the function
// pointer (indirect calls only), then the call arguments. Operand bundle
// operands are resolved after them. The operand segment sizes attribute
// records these two groups in that order.
1388 ParseResult CallOp::parse(OpAsmParser &parser, OperationState &result) {
1389  SymbolRefAttr funcAttr;
1390  TypeAttr varCalleeType;
1393  SmallVector<SmallVector<Type>> opBundleOperandTypes;
1394  ArrayAttr opBundleTags;
1395 
1396  // Default to C Calling Convention if no keyword is provided.
1397  result.addAttribute(
1398  getCConvAttrName(result.name),
1399  CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
1400  parser, result, LLVM::CConv::C)));
1401 
  // Default to no tail call kind if no keyword is provided.
1402  result.addAttribute(
1403  getTailCallKindAttrName(result.name),
1405  parseOptionalLLVMKeyword<TailCallKind>(
1406  parser, result, LLVM::TailCallKind::None)));
1407 
1408  // Parse a function pointer for indirect calls.
1409  if (parseOptionalCallFuncPtr(parser, operands))
1410  return failure();
1411  bool isDirect = operands.empty();
1412 
1413  // Parse a function identifier for direct calls.
1414  if (isDirect)
1415  if (parser.parseAttribute(funcAttr, "callee", result.attributes))
1416  return failure();
1417 
1418  // Parse the function arguments.
1419  if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren))
1420  return failure();
1421 
  // Parse the optional `vararg(<function type>)` clause.
1422  bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
1423  if (isVarArg) {
1424  StringAttr varCalleeTypeAttrName =
1425  CallOp::getVarCalleeTypeAttrName(result.name);
1426  if (parser.parseLParen().failed() ||
1427  parser
1428  .parseAttribute(varCalleeType, varCalleeTypeAttrName,
1429  result.attributes)
1430  .failed() ||
1431  parser.parseRParen().failed())
1432  return failure();
1433  }
1434 
1435  SMLoc opBundlesLoc = parser.getCurrentLocation();
  // NOTE(review): the `result` declared in this `if` shadows the
  // OperationState parameter inside the condition and body. A distinct name
  // would be clearer.
1436  if (std::optional<ParseResult> result = parseOpBundles(
1437  parser, opBundleOperands, opBundleOperandTypes, opBundleTags);
1438  result && failed(*result))
1439  return failure();
1440  if (opBundleTags && !opBundleTags.empty())
1441  result.addAttribute(CallOp::getOpBundleTagsAttrName(result.name).getValue(),
1442  opBundleTags);
1443 
1444  if (parser.parseOptionalAttrDict(result.attributes))
1445  return failure();
1446 
1447  // Parse the trailing type list and resolve the operands.
1448  SmallVector<DictionaryAttr> argAttrs;
1449  SmallVector<DictionaryAttr> resultAttrs;
1450  if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
1451  argAttrs, resultAttrs))
1452  return failure();
1454  parser.getBuilder(), result, argAttrs, resultAttrs,
1455  getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
  // Bundle operands are appended after the callee operands resolved above.
1456  if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
1457  opBundleOperandTypes,
1458  getOpBundleSizesAttrName(result.name)))
1459  return failure();
1460 
1461  int32_t numOpBundleOperands = 0;
1462  for (const auto &operands : opBundleOperands)
1463  numOpBundleOperands += operands.size();
1464 
1465  result.addAttribute(
1466  CallOp::getOperandSegmentSizeAttr(),
1468  {static_cast<int32_t>(operands.size()), numOpBundleOperands}));
1469  return success();
1470 }
1471 
1472 LLVMFunctionType CallOp::getCalleeFunctionType() {
1473  if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1474  return *varCalleeType;
1475  return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
1476 }
1477 
1478 //===----------------------------------------------------------------------===//
1479 // LLVM::InvokeOp
1480 //===----------------------------------------------------------------------===//
1481 
1482 void InvokeOp::build(OpBuilder &builder, OperationState &state, LLVMFuncOp func,
1483  ValueRange ops, Block *normal, ValueRange normalOps,
1484  Block *unwind, ValueRange unwindOps) {
1485  auto calleeType = func.getFunctionType();
1486  build(builder, state, getCallOpResultTypes(calleeType),
1487  getCallOpVarCalleeType(calleeType), SymbolRefAttr::get(func), ops,
1488  /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, normalOps, unwindOps,
1489  nullptr, nullptr, {}, {}, normal, unwind);
1490 }
1491 
1492 void InvokeOp::build(OpBuilder &builder, OperationState &state, TypeRange tys,
1493  FlatSymbolRefAttr callee, ValueRange ops, Block *normal,
1494  ValueRange normalOps, Block *unwind,
1495  ValueRange unwindOps) {
1496  build(builder, state, tys,
1497  /*var_callee_type=*/nullptr, callee, ops, /*arg_attrs=*/nullptr,
1498  /*res_attrs=*/nullptr, normalOps, unwindOps, nullptr, nullptr, {}, {},
1499  normal, unwind);
1500 }
1501 
1502 void InvokeOp::build(OpBuilder &builder, OperationState &state,
1503  LLVMFunctionType calleeType, FlatSymbolRefAttr callee,
1504  ValueRange ops, Block *normal, ValueRange normalOps,
1505  Block *unwind, ValueRange unwindOps) {
1506  build(builder, state, getCallOpResultTypes(calleeType),
1507  getCallOpVarCalleeType(calleeType), callee, ops,
1508  /*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, normalOps, unwindOps,
1509  nullptr, nullptr, {}, {}, normal, unwind);
1510 }
1511 
1513  assert(index < getNumSuccessors() && "invalid successor index");
  // Successor 0 is the normal destination; the other valid index is the
  // unwind destination.
1514  return SuccessorOperands(index == 0 ? getNormalDestOperandsMutable()
1515  : getUnwindDestOperandsMutable());
1516 }
1517 
1518 CallInterfaceCallable InvokeOp::getCallableForCallee() {
1519  // Direct call.
1520  if (FlatSymbolRefAttr calleeAttr = getCalleeAttr())
1521  return calleeAttr;
1522  // Indirect call, callee Value is the first operand.
1523  return getOperand(0);
1524 }
1525 
1526 void InvokeOp::setCalleeFromCallable(CallInterfaceCallable callee) {
1527  // Direct call.
1528  if (FlatSymbolRefAttr calleeAttr = getCalleeAttr()) {
1529  auto symRef = cast<SymbolRefAttr>(callee);
1530  return setCalleeAttr(cast<FlatSymbolRefAttr>(symRef));
1531  }
1532  // Indirect call, callee Value is the first operand.
1533  return setOperand(0, cast<Value>(callee));
1534 }
1535 
1536 Operation::operand_range InvokeOp::getArgOperands() {
1537  return getCalleeOperands().drop_front(getCallee().has_value() ? 0 : 1);
1538 }
1539 
1540 MutableOperandRange InvokeOp::getArgOperandsMutable() {
1541  return MutableOperandRange(*this, getCallee().has_value() ? 0 : 1,
1542  getCalleeOperands().size());
1543 }
1544 
1545 LogicalResult InvokeOp::verify() {
1546  if (failed(verifyCallOpVarCalleeType(*this)))
1547  return failure();
1548 
1549  Block *unwindDest = getUnwindDest();
1550  if (unwindDest->empty())
1551  return emitError("must have at least one operation in unwind destination");
1552 
1553  // In unwind destination, first operation must be LandingpadOp
1554  if (!isa<LandingpadOp>(unwindDest->front()))
1555  return emitError("first operation in unwind destination should be a "
1556  "llvm.landingpad operation");
1557 
1558  if (failed(verifyOperandBundles(*this)))
1559  return failure();
1560 
1561  return success();
1562 }
1563 
/// Prints the custom assembly form of `llvm.invoke`, in this order:
///   - a non-default calling convention;
///   - the direct callee symbol or the indirect callee operand;
///   - the argument list, then the `to` and `unwind` successors with their
///     operands;
///   - an optional `vararg(...)` clause and any operand bundles;
///   - the remaining attributes and the trailing types.
1564 void InvokeOp::print(OpAsmPrinter &p) {
1565  auto callee = getCallee();
1566  bool isDirect = callee.has_value();
1567 
1568  p << ' ';
1569 
1570  // Print calling convention.
1571  if (getCConv() != LLVM::CConv::C)
1572  p << stringifyCConv(getCConv()) << ' ';
1573 
1574  // Print the direct callee symbol, or the indirect callee pointer (operand 0).
1575  if (isDirect)
1576  p.printSymbolName(callee.value());
1577  else
1578  p << getOperand(0);
1579 
  // The arguments exclude the function pointer of an indirect invoke.
1580  p << '(' << getCalleeOperands().drop_front(isDirect ? 0 : 1) << ')';
1581  p << " to ";
1582  p.printSuccessorAndUseList(getNormalDest(), getNormalDestOperands());
1583  p << " unwind ";
1584  p.printSuccessorAndUseList(getUnwindDest(), getUnwindDestOperands());
1585 
1586  // Print the variadic callee type if the invoke is variadic.
1587  if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1588  p << " vararg(" << *varCalleeType << ")";
1589 
1590  if (!getOpBundleOperands().empty()) {
1591  p << " ";
1592  printOpBundles(p, *this, getOpBundleOperands(),
1593  getOpBundleOperands().getTypes(), getOpBundleTags());
1594  }
1595 
  // Print the remaining attributes, eliding the ones listed here.
1596  p.printOptionalAttrDict((*this)->getAttrs(),
1597  {getCalleeAttrName(), getOperandSegmentSizeAttr(),
1598  getCConvAttrName(), getVarCalleeTypeAttrName(),
1599  getOpBundleSizesAttrName(),
1600  getOpBundleTagsAttrName(), getArgAttrsAttrName(),
1601  getResAttrsAttrName()});
1602 
  // Indirect invokes print the callee pointer type before the function type.
1603  p << " : ";
1604  if (!isDirect)
1605  p << getOperand(0).getType() << ", ";
1607  p, getCalleeOperands().drop_front(isDirect ? 0 : 1).getTypes(),
1608  getArgAttrsAttr(),
1609  /*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
1610 }
1611 
1612 // <operation> ::= `llvm.invoke` (cconv)? (function-id | ssa-use)
1613 // `(` ssa-use-list `)`
1614 // `to` bb-id (`[` ssa-use-and-type-list `]`)?
1615 // `unwind` bb-id (`[` ssa-use-and-type-list `]`)?
1616 // ( `vararg(` var-callee-type `)` )?
1617 // ( `[` op-bundles-list `]` )?
1618 // attribute-dict? `:` (type `,`)?
1619 // function-type-with-argument-attributes
//
// The callee operands are collected in `operands`: first the function
// pointer (indirect invokes only), then the arguments. The operand segment
// sizes attribute records four groups: callee operands, normal destination
// operands, unwind destination operands and operand bundle operands.
1620 ParseResult InvokeOp::parse(OpAsmParser &parser, OperationState &result) {
1622  SymbolRefAttr funcAttr;
1623  TypeAttr varCalleeType;
1625  SmallVector<SmallVector<Type>> opBundleOperandTypes;
1626  ArrayAttr opBundleTags;
1627  Block *normalDest, *unwindDest;
1628  SmallVector<Value, 4> normalOperands, unwindOperands;
1629  Builder &builder = parser.getBuilder();
1630 
1631  // Default to C Calling Convention if no keyword is provided.
1632  result.addAttribute(
1633  getCConvAttrName(result.name),
1634  CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
1635  parser, result, LLVM::CConv::C)));
1636 
1637  // Parse a function pointer for indirect calls.
1638  if (parseOptionalCallFuncPtr(parser, operands))
1639  return failure();
1640  bool isDirect = operands.empty();
1641 
1642  // Parse a function identifier for direct calls.
1643  if (isDirect && parser.parseAttribute(funcAttr, "callee", result.attributes))
1644  return failure();
1645 
1646  // Parse the function arguments.
1647  if (parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
1648  parser.parseKeyword("to") ||
1649  parser.parseSuccessorAndUseList(normalDest, normalOperands) ||
1650  parser.parseKeyword("unwind") ||
1651  parser.parseSuccessorAndUseList(unwindDest, unwindOperands))
1652  return failure();
1653 
  // Parse the optional `vararg(<function type>)` clause.
1654  bool isVarArg = parser.parseOptionalKeyword("vararg").succeeded();
1655  if (isVarArg) {
1656  StringAttr varCalleeTypeAttrName =
1657  InvokeOp::getVarCalleeTypeAttrName(result.name);
1658  if (parser.parseLParen().failed() ||
1659  parser
1660  .parseAttribute(varCalleeType, varCalleeTypeAttrName,
1661  result.attributes)
1662  .failed() ||
1663  parser.parseRParen().failed())
1664  return failure();
1665  }
1666 
1667  SMLoc opBundlesLoc = parser.getCurrentLocation();
  // NOTE(review): the `result` declared in this `if` shadows the
  // OperationState parameter inside the condition and body. A distinct name
  // would be clearer.
1668  if (std::optional<ParseResult> result = parseOpBundles(
1669  parser, opBundleOperands, opBundleOperandTypes, opBundleTags);
1670  result && failed(*result))
1671  return failure();
1672  if (opBundleTags && !opBundleTags.empty())
1673  result.addAttribute(
1674  InvokeOp::getOpBundleTagsAttrName(result.name).getValue(),
1675  opBundleTags);
1676 
1677  if (parser.parseOptionalAttrDict(result.attributes))
1678  return failure();
1679 
1680  // Parse the trailing type list and resolve the function operands.
1681  SmallVector<DictionaryAttr> argAttrs;
1682  SmallVector<DictionaryAttr> resultAttrs;
1683  if (parseCallTypeAndResolveOperands(parser, result, isDirect, operands,
1684  argAttrs, resultAttrs))
1685  return failure();
1687  parser.getBuilder(), result, argAttrs, resultAttrs,
1688  getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
1689 
1690  if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
1691  opBundleOperandTypes,
1692  getOpBundleSizesAttrName(result.name)))
1693  return failure();
1694 
  // NOTE(review): resolveOpBundleOperands above has already appended the
  // operand bundle operands to `result.operands`, i.e. before the
  // normal/unwind destination operands added here. The segment sizes recorded
  // below, however, list the bundle operands last. Verify the resulting
  // operand order for invokes that have both operand bundles and destination
  // operands.
1695  result.addSuccessors({normalDest, unwindDest});
1696  result.addOperands(normalOperands);
1697  result.addOperands(unwindOperands);
1698 
1699  int32_t numOpBundleOperands = 0;
1700  for (const auto &operands : opBundleOperands)
1701  numOpBundleOperands += operands.size();
1702 
1703  result.addAttribute(
1704  InvokeOp::getOperandSegmentSizeAttr(),
1705  builder.getDenseI32ArrayAttr({static_cast<int32_t>(operands.size()),
1706  static_cast<int32_t>(normalOperands.size()),
1707  static_cast<int32_t>(unwindOperands.size()),
1708  numOpBundleOperands}));
1709  return success();
1710 }
1711 
1712 LLVMFunctionType InvokeOp::getCalleeFunctionType() {
1713  if (std::optional<LLVMFunctionType> varCalleeType = getVarCalleeType())
1714  return *varCalleeType;
1715  return getLLVMFuncType(getContext(), getResultTypes(), getArgOperands());
1716 }
1717 
//===----------------------------------------------------------------------===//
// Verifying/Printing/Parsing for LLVM::LandingpadOp.
//===----------------------------------------------------------------------===//
1721 
1722 LogicalResult LandingpadOp::verify() {
1723  Value value;
1724  if (LLVMFuncOp func = (*this)->getParentOfType<LLVMFuncOp>()) {
1725  if (!func.getPersonality())
1726  return emitError(
1727  "llvm.landingpad needs to be in a function with a personality");
1728  }
1729 
1730  // Consistency of llvm.landingpad result types is checked in
1731  // LLVMFuncOp::verify().
1732 
1733  if (!getCleanup() && getOperands().empty())
1734  return emitError("landingpad instruction expects at least one clause or "
1735  "cleanup attribute");
1736 
1737  for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) {
1738  value = getOperand(idx);
1739  bool isFilter = llvm::isa<LLVMArrayType>(value.getType());
1740  if (isFilter) {
1741  // FIXME: Verify filter clauses when arrays are appropriately handled
1742  } else {
1743  // catch - global addresses only.
1744  // Bitcast ops should have global addresses as their args.
1745  if (auto bcOp = value.getDefiningOp<BitcastOp>()) {
1746  if (auto addrOp = bcOp.getArg().getDefiningOp<AddressOfOp>())
1747  continue;
1748  return emitError("constant clauses expected").attachNote(bcOp.getLoc())
1749  << "global addresses expected as operand to "
1750  "bitcast used in clauses for landingpad";
1751  }
1752  // ZeroOp and AddressOfOp allowed
1753  if (value.getDefiningOp<ZeroOp>())
1754  continue;
1755  if (value.getDefiningOp<AddressOfOp>())
1756  continue;
1757  return emitError("clause #")
1758  << idx << " is not a known constant - null, addressof, bitcast";
1759  }
1760  }
1761  return success();
1762 }
1763 
1765  p << (getCleanup() ? " cleanup " : " ");
1766 
1767  // Clauses
1768  for (auto value : getOperands()) {
1769  // Similar to llvm - if clause is an array type then it is filter
1770  // clause else catch clause
1771  bool isArrayTy = llvm::isa<LLVMArrayType>(value.getType());
1772  p << '(' << (isArrayTy ? "filter " : "catch ") << value << " : "
1773  << value.getType() << ") ";
1774  }
1775 
1776  p.printOptionalAttrDict((*this)->getAttrs(), {"cleanup"});
1777 
1778  p << ": " << getType();
1779 }
1780 
1781 // <operation> ::= `llvm.landingpad` `cleanup`?
1782 // ((`catch` | `filter`) operand-type ssa-use)* attribute-dict?
1783 ParseResult LandingpadOp::parse(OpAsmParser &parser, OperationState &result) {
1784  // Check for cleanup
1785  if (succeeded(parser.parseOptionalKeyword("cleanup")))
1786  result.addAttribute("cleanup", parser.getBuilder().getUnitAttr());
1787 
1788  // Parse clauses with types
1789  while (succeeded(parser.parseOptionalLParen()) &&
1790  (succeeded(parser.parseOptionalKeyword("filter")) ||
1791  succeeded(parser.parseOptionalKeyword("catch")))) {
1793  Type ty;
1794  if (parser.parseOperand(operand) || parser.parseColon() ||
1795  parser.parseType(ty) ||
1796  parser.resolveOperand(operand, ty, result.operands) ||
1797  parser.parseRParen())
1798  return failure();
1799  }
1800 
1801  Type type;
1802  if (parser.parseColon() || parser.parseType(type))
1803  return failure();
1804 
1805  result.addTypes(type);
1806  return success();
1807 }
1808 
1809 //===----------------------------------------------------------------------===//
1810 // ExtractValueOp
1811 //===----------------------------------------------------------------------===//
1812 
1813 /// Extract the type at `position` in the LLVM IR aggregate type
1814 /// `containerType`. Each element of `position` is an index into a nested
1815 /// aggregate type. Return the resulting type or emit an error.
1817  function_ref<InFlightDiagnostic(StringRef)> emitError, Type containerType,
1818  ArrayRef<int64_t> position) {
1819  Type llvmType = containerType;
1820  if (!isCompatibleType(containerType)) {
1821  emitError("expected LLVM IR Dialect type, got ") << containerType;
1822  return {};
1823  }
1824 
1825  // Infer the element type from the structure type: iteratively step inside the
1826  // type by taking the element type, indexed by the position attribute for
1827  // structures. Check the position index before accessing, it is supposed to
1828  // be in bounds.
1829  for (int64_t idx : position) {
1830  if (auto arrayType = llvm::dyn_cast<LLVMArrayType>(llvmType)) {
1831  if (idx < 0 || static_cast<unsigned>(idx) >= arrayType.getNumElements()) {
1832  emitError("position out of bounds: ") << idx;
1833  return {};
1834  }
1835  llvmType = arrayType.getElementType();
1836  } else if (auto structType = llvm::dyn_cast<LLVMStructType>(llvmType)) {
1837  if (idx < 0 ||
1838  static_cast<unsigned>(idx) >= structType.getBody().size()) {
1839  emitError("position out of bounds: ") << idx;
1840  return {};
1841  }
1842  llvmType = structType.getBody()[idx];
1843  } else {
1844  emitError("expected LLVM IR structure/array type, got: ") << llvmType;
1845  return {};
1846  }
1847  }
1848  return llvmType;
1849 }
1850 
1851 /// Extract the type at `position` in the wrapped LLVM IR aggregate type
1852 /// `containerType`.
1854  ArrayRef<int64_t> position) {
1855  for (int64_t idx : position) {
1856  if (auto structType = llvm::dyn_cast<LLVMStructType>(llvmType))
1857  llvmType = structType.getBody()[idx];
1858  else
1859  llvmType = llvm::cast<LLVMArrayType>(llvmType).getElementType();
1860  }
1861  return llvmType;
1862 }
1863 
1864 OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
1865  if (auto extractValueOp = getContainer().getDefiningOp<ExtractValueOp>()) {
1866  SmallVector<int64_t, 4> newPos(extractValueOp.getPosition());
1867  newPos.append(getPosition().begin(), getPosition().end());
1868  setPosition(newPos);
1869  getContainerMutable().set(extractValueOp.getContainer());
1870  return getResult();
1871  }
1872 
1873  {
1874  DenseElementsAttr constval;
1875  matchPattern(getContainer(), m_Constant(&constval));
1876  if (constval && constval.getElementType() == getType()) {
1877  if (isa<SplatElementsAttr>(constval))
1878  return constval.getSplatValue<Attribute>();
1879  if (getPosition().size() == 1)
1880  return constval.getValues<Attribute>()[getPosition()[0]];
1881  }
1882  }
1883 
1884  auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>();
1885  OpFoldResult result = {};
1886  ArrayRef<int64_t> extractPos = getPosition();
1887  bool switchedToInsertedValue = false;
1888  while (insertValueOp) {
1889  ArrayRef<int64_t> insertPos = insertValueOp.getPosition();
1890  auto extractPosSize = extractPos.size();
1891  auto insertPosSize = insertPos.size();
1892 
1893  // Case 1: Exact match of positions.
1894  if (extractPos == insertPos)
1895  return insertValueOp.getValue();
1896 
1897  // Case 2: Insert position is a prefix of extract position. Continue
1898  // traversal with the inserted value. Example:
1899  // ```
1900  // %0 = llvm.insertvalue %arg1, %undef[0] : !llvm.struct<(i32, i32, i32)>
1901  // %1 = llvm.insertvalue %arg2, %0[1] : !llvm.struct<(i32, i32, i32)>
1902  // %2 = llvm.insertvalue %arg3, %1[2] : !llvm.struct<(i32, i32, i32)>
1903  // %3 = llvm.insertvalue %2, %foo[0]
1904  // : !llvm.struct<(struct<(i32, i32, i32)>, i64)>
1905  // %4 = llvm.extractvalue %3[0, 0]
1906  // : !llvm.struct<(struct<(i32, i32, i32)>, i64)>
1907  // ```
1908  // In the above example, %4 is folded to %arg1.
1909  if (extractPosSize > insertPosSize &&
1910  extractPos.take_front(insertPosSize) == insertPos) {
1911  insertValueOp = insertValueOp.getValue().getDefiningOp<InsertValueOp>();
1912  extractPos = extractPos.drop_front(insertPosSize);
1913  switchedToInsertedValue = true;
1914  continue;
1915  }
1916 
1917  // Case 3: Try to continue the traversal with the container value.
1918  unsigned min = std::min(extractPosSize, insertPosSize);
1919 
1920  // If one is fully prefix of the other, stop propagating back as it will
1921  // miss dependencies. For instance, %3 should not fold to %f0 in the
1922  // following example:
1923  // ```
1924  // %1 = llvm.insertvalue %f0, %0[0, 0] :
1925  // !llvm.array<4 x !llvm.array<4 x f32>>
1926  // %2 = llvm.insertvalue %arr, %1[0] :
1927  // !llvm.array<4 x !llvm.array<4 x f32>>
1928  // %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4 x f32>>
1929  // ```
1930  if (extractPos.take_front(min) == insertPos.take_front(min))
1931  return result;
1932  // If neither a prefix, nor the exact position, we can extract out of the
1933  // value being inserted into. Moreover, we can try again if that operand
1934  // is itself an insertvalue expression.
1935  if (!switchedToInsertedValue) {
1936  // Do not swap out the container operand if we decided earlier to
1937  // continue the traversal with the inserted value (Case 2).
1938  getContainerMutable().assign(insertValueOp.getContainer());
1939  result = getResult();
1940  }
1941  insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>();
1942  }
1943  return result;
1944 }
1945 
1946 LogicalResult ExtractValueOp::verify() {
1947  auto emitError = [this](StringRef msg) { return emitOpError(msg); };
1949  emitError, getContainer().getType(), getPosition());
1950  if (!valueType)
1951  return failure();
1952 
1953  if (getRes().getType() != valueType)
1954  return emitOpError() << "Type mismatch: extracting from "
1955  << getContainer().getType() << " should produce "
1956  << valueType << " but this op returns "
1957  << getRes().getType();
1958  return success();
1959 }
1960 
1961 void ExtractValueOp::build(OpBuilder &builder, OperationState &state,
1962  Value container, ArrayRef<int64_t> position) {
1963  build(builder, state,
1964  getInsertExtractValueElementType(container.getType(), position),
1965  container, builder.getAttr<DenseI64ArrayAttr>(position));
1966 }
1967 
1968 //===----------------------------------------------------------------------===//
1969 // InsertValueOp
1970 //===----------------------------------------------------------------------===//
1971 
1972 /// Infer the value type from the container type and position.
1973 static ParseResult
1975  Type containerType,
1976  DenseI64ArrayAttr position) {
1978  [&](StringRef msg) {
1979  return parser.emitError(parser.getCurrentLocation(), msg);
1980  },
1981  containerType, position.asArrayRef());
1982  return success(!!valueType);
1983 }
1984 
1985 /// Nothing to print for an inferred type.
1987  Operation *op, Type valueType,
1988  Type containerType,
1989  DenseI64ArrayAttr position) {}
1990 
1991 LogicalResult InsertValueOp::verify() {
1992  auto emitError = [this](StringRef msg) { return emitOpError(msg); };
1994  emitError, getContainer().getType(), getPosition());
1995  if (!valueType)
1996  return failure();
1997 
1998  if (getValue().getType() != valueType)
1999  return emitOpError() << "Type mismatch: cannot insert "
2000  << getValue().getType() << " into "
2001  << getContainer().getType();
2002 
2003  return success();
2004 }
2005 
2006 //===----------------------------------------------------------------------===//
2007 // ReturnOp
2008 //===----------------------------------------------------------------------===//
2009 
2010 LogicalResult ReturnOp::verify() {
2011  auto parent = (*this)->getParentOfType<LLVMFuncOp>();
2012  if (!parent)
2013  return success();
2014 
2015  Type expectedType = parent.getFunctionType().getReturnType();
2016  if (llvm::isa<LLVMVoidType>(expectedType)) {
2017  if (!getArg())
2018  return success();
2019  InFlightDiagnostic diag = emitOpError("expected no operands");
2020  diag.attachNote(parent->getLoc()) << "when returning from function";
2021  return diag;
2022  }
2023  if (!getArg()) {
2024  if (llvm::isa<LLVMVoidType>(expectedType))
2025  return success();
2026  InFlightDiagnostic diag = emitOpError("expected 1 operand");
2027  diag.attachNote(parent->getLoc()) << "when returning from function";
2028  return diag;
2029  }
2030  if (expectedType != getArg().getType()) {
2031  InFlightDiagnostic diag = emitOpError("mismatching result types");
2032  diag.attachNote(parent->getLoc()) << "when returning from function";
2033  return diag;
2034  }
2035  return success();
2036 }
2037 
2038 //===----------------------------------------------------------------------===//
2039 // LLVM::AddressOfOp.
2040 //===----------------------------------------------------------------------===//
2041 
2043  Operation *module = op->getParentOp();
2044  while (module && !satisfiesLLVMModule(module))
2045  module = module->getParentOp();
2046  assert(module && "unexpected operation outside of a module");
2047  return module;
2048 }
2049 
2050 GlobalOp AddressOfOp::getGlobal(SymbolTableCollection &symbolTable) {
2051  return dyn_cast_or_null<GlobalOp>(
2052  symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
2053 }
2054 
2055 LLVMFuncOp AddressOfOp::getFunction(SymbolTableCollection &symbolTable) {
2056  return dyn_cast_or_null<LLVMFuncOp>(
2057  symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
2058 }
2059 
2060 AliasOp AddressOfOp::getAlias(SymbolTableCollection &symbolTable) {
2061  return dyn_cast_or_null<AliasOp>(
2062  symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr()));
2063 }
2064 
2065 LogicalResult
2066 AddressOfOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2067  Operation *symbol =
2068  symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr());
2069 
2070  auto global = dyn_cast_or_null<GlobalOp>(symbol);
2071  auto function = dyn_cast_or_null<LLVMFuncOp>(symbol);
2072  auto alias = dyn_cast_or_null<AliasOp>(symbol);
2073 
2074  if (!global && !function && !alias)
2075  return emitOpError("must reference a global defined by 'llvm.mlir.global', "
2076  "'llvm.mlir.alias' or 'llvm.func'");
2077 
2078  LLVMPointerType type = getType();
2079  if ((global && global.getAddrSpace() != type.getAddressSpace()) ||
2080  (alias && alias.getAddrSpace() != type.getAddressSpace()))
2081  return emitOpError("pointer address space must match address space of the "
2082  "referenced global or alias");
2083 
2084  return success();
2085 }
2086 
2087 // AddressOfOp constant-folds to the global symbol name.
2088 OpFoldResult LLVM::AddressOfOp::fold(FoldAdaptor) {
2089  return getGlobalNameAttr();
2090 }
2091 
2092 //===----------------------------------------------------------------------===//
2093 // LLVM::DSOLocalEquivalentOp
2094 //===----------------------------------------------------------------------===//
2095 
2096 LLVMFuncOp
2097 DSOLocalEquivalentOp::getFunction(SymbolTableCollection &symbolTable) {
2098  return dyn_cast_or_null<LLVMFuncOp>(symbolTable.lookupSymbolIn(
2099  parentLLVMModule(*this), getFunctionNameAttr()));
2100 }
2101 
2102 AliasOp DSOLocalEquivalentOp::getAlias(SymbolTableCollection &symbolTable) {
2103  return dyn_cast_or_null<AliasOp>(symbolTable.lookupSymbolIn(
2104  parentLLVMModule(*this), getFunctionNameAttr()));
2105 }
2106 
2107 LogicalResult
2108 DSOLocalEquivalentOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2109  Operation *symbol = symbolTable.lookupSymbolIn(parentLLVMModule(*this),
2110  getFunctionNameAttr());
2111  auto function = dyn_cast_or_null<LLVMFuncOp>(symbol);
2112  auto alias = dyn_cast_or_null<AliasOp>(symbol);
2113 
2114  if (!function && !alias)
2115  return emitOpError(
2116  "must reference a global defined by 'llvm.func' or 'llvm.mlir.alias'");
2117 
2118  if (alias) {
2119  if (alias.getInitializer()
2120  .walk([&](AddressOfOp addrOp) {
2121  if (addrOp.getGlobal(symbolTable))
2122  return WalkResult::interrupt();
2123  return WalkResult::advance();
2124  })
2125  .wasInterrupted())
2126  return emitOpError("must reference an alias to a function");
2127  }
2128 
2129  if ((function && function.getLinkage() == LLVM::Linkage::ExternWeak) ||
2130  (alias && alias.getLinkage() == LLVM::Linkage::ExternWeak))
2131  return emitOpError(
2132  "target function with 'extern_weak' linkage not allowed");
2133 
2134  return success();
2135 }
2136 
2137 /// Fold a dso_local_equivalent operation to a dedicated dso_local_equivalent
2138 /// attribute.
2139 OpFoldResult DSOLocalEquivalentOp::fold(FoldAdaptor) {
2140  return DSOLocalEquivalentAttr::get(getContext(), getFunctionNameAttr());
2141 }
2142 
2143 //===----------------------------------------------------------------------===//
// Builder and verifier for LLVM::ComdatOp.
2145 //===----------------------------------------------------------------------===//
2146 
2147 void ComdatOp::build(OpBuilder &builder, OperationState &result,
2148  StringRef symName) {
2149  result.addAttribute(getSymNameAttrName(result.name),
2150  builder.getStringAttr(symName));
2151  Region *body = result.addRegion();
2152  body->emplaceBlock();
2153 }
2154 
2155 LogicalResult ComdatOp::verifyRegions() {
2156  Region &body = getBody();
2157  for (Operation &op : body.getOps())
2158  if (!isa<ComdatSelectorOp>(op))
2159  return op.emitError(
2160  "only comdat selector symbols can appear in a comdat region");
2161 
2162  return success();
2163 }
2164 
2165 //===----------------------------------------------------------------------===//
2166 // Builder, printer and verifier for LLVM::GlobalOp.
2167 //===----------------------------------------------------------------------===//
2168 
2169 void GlobalOp::build(OpBuilder &builder, OperationState &result, Type type,
2170  bool isConstant, Linkage linkage, StringRef name,
2171  Attribute value, uint64_t alignment, unsigned addrSpace,
2172  bool dsoLocal, bool threadLocal, SymbolRefAttr comdat,
2174  ArrayRef<Attribute> dbgExprs) {
2175  result.addAttribute(getSymNameAttrName(result.name),
2176  builder.getStringAttr(name));
2177  result.addAttribute(getGlobalTypeAttrName(result.name), TypeAttr::get(type));
2178  if (isConstant)
2179  result.addAttribute(getConstantAttrName(result.name),
2180  builder.getUnitAttr());
2181  if (value)
2182  result.addAttribute(getValueAttrName(result.name), value);
2183  if (dsoLocal)
2184  result.addAttribute(getDsoLocalAttrName(result.name),
2185  builder.getUnitAttr());
2186  if (threadLocal)
2187  result.addAttribute(getThreadLocal_AttrName(result.name),
2188  builder.getUnitAttr());
2189  if (comdat)
2190  result.addAttribute(getComdatAttrName(result.name), comdat);
2191 
2192  // Only add an alignment attribute if the "alignment" input
2193  // is different from 0. The value must also be a power of two, but
2194  // this is tested in GlobalOp::verify, not here.
2195  if (alignment != 0)
2196  result.addAttribute(getAlignmentAttrName(result.name),
2197  builder.getI64IntegerAttr(alignment));
2198 
2199  result.addAttribute(getLinkageAttrName(result.name),
2200  LinkageAttr::get(builder.getContext(), linkage));
2201  if (addrSpace != 0)
2202  result.addAttribute(getAddrSpaceAttrName(result.name),
2203  builder.getI32IntegerAttr(addrSpace));
2204  result.attributes.append(attrs.begin(), attrs.end());
2205 
2206  if (!dbgExprs.empty())
2207  result.addAttribute(getDbgExprsAttrName(result.name),
2208  ArrayAttr::get(builder.getContext(), dbgExprs));
2209 
2210  result.addRegion();
2211 }
2212 
2213 template <typename OpType>
2214 static void printCommonGlobalAndAlias(OpAsmPrinter &p, OpType op) {
2215  p << ' ' << stringifyLinkage(op.getLinkage()) << ' ';
2216  StringRef visibility = stringifyVisibility(op.getVisibility_());
2217  if (!visibility.empty())
2218  p << visibility << ' ';
2219  if (op.getThreadLocal_())
2220  p << "thread_local ";
2221  if (auto unnamedAddr = op.getUnnamedAddr()) {
2222  StringRef str = stringifyUnnamedAddr(*unnamedAddr);
2223  if (!str.empty())
2224  p << str << ' ';
2225  }
2226 }
2227 
2228 void GlobalOp::print(OpAsmPrinter &p) {
2229  printCommonGlobalAndAlias<GlobalOp>(p, *this);
2230  if (getConstant())
2231  p << "constant ";
2232  p.printSymbolName(getSymName());
2233  p << '(';
2234  if (auto value = getValueOrNull())
2235  p.printAttribute(value);
2236  p << ')';
2237  if (auto comdat = getComdat())
2238  p << " comdat(" << *comdat << ')';
2239 
2240  // Note that the alignment attribute is printed using the
2241  // default syntax here, even though it is an inherent attribute
2242  // (as defined in https://mlir.llvm.org/docs/LangRef/#attributes)
2243  p.printOptionalAttrDict((*this)->getAttrs(),
2244  {SymbolTable::getSymbolAttrName(),
2245  getGlobalTypeAttrName(), getConstantAttrName(),
2246  getValueAttrName(), getLinkageAttrName(),
2247  getUnnamedAddrAttrName(), getThreadLocal_AttrName(),
2248  getVisibility_AttrName(), getComdatAttrName(),
2249  getUnnamedAddrAttrName()});
2250 
2251  // Print the trailing type unless it's a string global.
2252  if (llvm::dyn_cast_or_null<StringAttr>(getValueOrNull()))
2253  return;
2254  p << " : " << getType();
2255 
2256  Region &initializer = getInitializerRegion();
2257  if (!initializer.empty()) {
2258  p << ' ';
2259  p.printRegion(initializer, /*printEntryBlockArgs=*/false);
2260  }
2261 }
2262 
2263 static LogicalResult verifyComdat(Operation *op,
2264  std::optional<SymbolRefAttr> attr) {
2265  if (!attr)
2266  return success();
2267 
2268  auto *comdatSelector = SymbolTable::lookupNearestSymbolFrom(op, *attr);
2269  if (!isa_and_nonnull<ComdatSelectorOp>(comdatSelector))
2270  return op->emitError() << "expected comdat symbol";
2271 
2272  return success();
2273 }
2274 
2275 static LogicalResult verifyBlockTags(LLVMFuncOp funcOp) {
2276  llvm::DenseSet<BlockTagAttr> blockTags;
2277  // Note that presence of `BlockTagOp`s currently can't prevent an unrecheable
2278  // block to be removed by canonicalizer's region simplify pass, which needs to
2279  // be dialect aware to allow extra constraints to be described.
2280  WalkResult res = funcOp.walk([&](BlockTagOp blockTagOp) {
2281  if (blockTags.contains(blockTagOp.getTag())) {
2282  blockTagOp.emitError()
2283  << "duplicate block tag '" << blockTagOp.getTag().getId()
2284  << "' in the same function: ";
2285  return WalkResult::interrupt();
2286  }
2287  blockTags.insert(blockTagOp.getTag());
2288  return WalkResult::advance();
2289  });
2290 
2291  return failure(res.wasInterrupted());
2292 }
2293 
2294 /// Parse common attributes that might show up in the same order in both
2295 /// GlobalOp and AliasOp.
2296 template <typename OpType>
2297 static ParseResult parseCommonGlobalAndAlias(OpAsmParser &parser,
2298  OperationState &result) {
2299  MLIRContext *ctx = parser.getContext();
2300  // Parse optional linkage, default to External.
2301  result.addAttribute(OpType::getLinkageAttrName(result.name),
2303  ctx, parseOptionalLLVMKeyword<Linkage>(
2304  parser, result, LLVM::Linkage::External)));
2305 
2306  // Parse optional visibility, default to Default.
2307  result.addAttribute(OpType::getVisibility_AttrName(result.name),
2308  parser.getBuilder().getI64IntegerAttr(
2309  parseOptionalLLVMKeyword<LLVM::Visibility, int64_t>(
2310  parser, result, LLVM::Visibility::Default)));
2311 
2312  if (succeeded(parser.parseOptionalKeyword("thread_local")))
2313  result.addAttribute(OpType::getThreadLocal_AttrName(result.name),
2314  parser.getBuilder().getUnitAttr());
2315 
2316  // Parse optional UnnamedAddr, default to None.
2317  result.addAttribute(OpType::getUnnamedAddrAttrName(result.name),
2318  parser.getBuilder().getI64IntegerAttr(
2319  parseOptionalLLVMKeyword<UnnamedAddr, int64_t>(
2320  parser, result, LLVM::UnnamedAddr::None)));
2321 
2322  return success();
2323 }
2324 
2325 // operation ::= `llvm.mlir.global` linkage? visibility?
2326 // (`unnamed_addr` | `local_unnamed_addr`)?
2327 // `thread_local`? `constant`? `@` identifier
2328 // `(` attribute? `)` (`comdat(` symbol-ref-id `)`)?
2329 // attribute-list? (`:` type)? region?
2330 //
2331 // The type can be omitted for string attributes, in which case it will be
2332 // inferred from the value of the string as [strlen(value) x i8].
2333 ParseResult GlobalOp::parse(OpAsmParser &parser, OperationState &result) {
2334  // Call into common parsing between GlobalOp and AliasOp.
2335  if (parseCommonGlobalAndAlias<GlobalOp>(parser, result).failed())
2336  return failure();
2337 
2338  if (succeeded(parser.parseOptionalKeyword("constant")))
2339  result.addAttribute(getConstantAttrName(result.name),
2340  parser.getBuilder().getUnitAttr());
2341 
2342  StringAttr name;
2343  if (parser.parseSymbolName(name, getSymNameAttrName(result.name),
2344  result.attributes) ||
2345  parser.parseLParen())
2346  return failure();
2347 
2348  Attribute value;
2349  if (parser.parseOptionalRParen()) {
2350  if (parser.parseAttribute(value, getValueAttrName(result.name),
2351  result.attributes) ||
2352  parser.parseRParen())
2353  return failure();
2354  }
2355 
2356  if (succeeded(parser.parseOptionalKeyword("comdat"))) {
2357  SymbolRefAttr comdat;
2358  if (parser.parseLParen() || parser.parseAttribute(comdat) ||
2359  parser.parseRParen())
2360  return failure();
2361 
2362  result.addAttribute(getComdatAttrName(result.name), comdat);
2363  }
2364 
2365  SmallVector<Type, 1> types;
2366  if (parser.parseOptionalAttrDict(result.attributes) ||
2367  parser.parseOptionalColonTypeList(types))
2368  return failure();
2369 
2370  if (types.size() > 1)
2371  return parser.emitError(parser.getNameLoc(), "expected zero or one type");
2372 
2373  Region &initRegion = *result.addRegion();
2374  if (types.empty()) {
2375  if (auto strAttr = llvm::dyn_cast_or_null<StringAttr>(value)) {
2376  MLIRContext *context = parser.getContext();
2377  auto arrayType = LLVM::LLVMArrayType::get(IntegerType::get(context, 8),
2378  strAttr.getValue().size());
2379  types.push_back(arrayType);
2380  } else {
2381  return parser.emitError(parser.getNameLoc(),
2382  "type can only be omitted for string globals");
2383  }
2384  } else {
2385  OptionalParseResult parseResult =
2386  parser.parseOptionalRegion(initRegion, /*arguments=*/{},
2387  /*argTypes=*/{});
2388  if (parseResult.has_value() && failed(*parseResult))
2389  return failure();
2390  }
2391 
2392  result.addAttribute(getGlobalTypeAttrName(result.name),
2393  TypeAttr::get(types[0]));
2394  return success();
2395 }
2396 
2397 static bool isZeroAttribute(Attribute value) {
2398  if (auto intValue = llvm::dyn_cast<IntegerAttr>(value))
2399  return intValue.getValue().isZero();
2400  if (auto fpValue = llvm::dyn_cast<FloatAttr>(value))
2401  return fpValue.getValue().isZero();
2402  if (auto splatValue = llvm::dyn_cast<SplatElementsAttr>(value))
2403  return isZeroAttribute(splatValue.getSplatValue<Attribute>());
2404  if (auto elementsValue = llvm::dyn_cast<ElementsAttr>(value))
2405  return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute);
2406  if (auto arrayValue = llvm::dyn_cast<ArrayAttr>(value))
2407  return llvm::all_of(arrayValue.getValue(), isZeroAttribute);
2408  return false;
2409 }
2410 
2411 LogicalResult GlobalOp::verify() {
2412  bool validType = isCompatibleOuterType(getType())
2413  ? !llvm::isa<LLVMVoidType, LLVMTokenType,
2414  LLVMMetadataType, LLVMLabelType>(getType())
2415  : llvm::isa<PointerElementTypeInterface>(getType());
2416  if (!validType)
2417  return emitOpError(
2418  "expects type to be a valid element type for an LLVM global");
2419  if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp()))
2420  return emitOpError("must appear at the module level");
2421 
2422  if (auto strAttr = llvm::dyn_cast_or_null<StringAttr>(getValueOrNull())) {
2423  auto type = llvm::dyn_cast<LLVMArrayType>(getType());
2424  IntegerType elementType =
2425  type ? llvm::dyn_cast<IntegerType>(type.getElementType()) : nullptr;
2426  if (!elementType || elementType.getWidth() != 8 ||
2427  type.getNumElements() != strAttr.getValue().size())
2428  return emitOpError(
2429  "requires an i8 array type of the length equal to that of the string "
2430  "attribute");
2431  }
2432 
2433  if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
2434  if (!targetExtType.hasProperty(LLVMTargetExtType::CanBeGlobal))
2435  return emitOpError()
2436  << "this target extension type cannot be used in a global";
2437 
2438  if (Attribute value = getValueOrNull())
2439  return emitOpError() << "global with target extension type can only be "
2440  "initialized with zero-initializer";
2441  }
2442 
2443  if (getLinkage() == Linkage::Common) {
2444  if (Attribute value = getValueOrNull()) {
2445  if (!isZeroAttribute(value)) {
2446  return emitOpError()
2447  << "expected zero value for '"
2448  << stringifyLinkage(Linkage::Common) << "' linkage";
2449  }
2450  }
2451  }
2452 
2453  if (getLinkage() == Linkage::Appending) {
2454  if (!llvm::isa<LLVMArrayType>(getType())) {
2455  return emitOpError() << "expected array type for '"
2456  << stringifyLinkage(Linkage::Appending)
2457  << "' linkage";
2458  }
2459  }
2460 
2461  if (failed(verifyComdat(*this, getComdat())))
2462  return failure();
2463 
2464  std::optional<uint64_t> alignAttr = getAlignment();
2465  if (alignAttr.has_value()) {
2466  uint64_t value = alignAttr.value();
2467  if (!llvm::isPowerOf2_64(value))
2468  return emitError() << "alignment attribute is not a power of 2";
2469  }
2470 
2471  return success();
2472 }
2473 
2474 LogicalResult GlobalOp::verifyRegions() {
2475  if (Block *b = getInitializerBlock()) {
2476  ReturnOp ret = cast<ReturnOp>(b->getTerminator());
2477  if (ret.operand_type_begin() == ret.operand_type_end())
2478  return emitOpError("initializer region cannot return void");
2479  if (*ret.operand_type_begin() != getType())
2480  return emitOpError("initializer region type ")
2481  << *ret.operand_type_begin() << " does not match global type "
2482  << getType();
2483 
2484  for (Operation &op : *b) {
2485  auto iface = dyn_cast<MemoryEffectOpInterface>(op);
2486  if (!iface || !iface.hasNoEffect())
2487  return op.emitError()
2488  << "ops with side effects not allowed in global initializers";
2489  }
2490 
2491  if (getValueOrNull())
2492  return emitOpError("cannot have both initializer value and region");
2493  }
2494 
2495  return success();
2496 }
2497 
2498 //===----------------------------------------------------------------------===//
2499 // LLVM::GlobalCtorsOp
2500 //===----------------------------------------------------------------------===//
2501 
2502 LogicalResult checkGlobalXtorData(Operation *op, ArrayAttr data) {
2503  if (data.empty())
2504  return success();
2505 
2506  if (llvm::all_of(data.getAsRange<Attribute>(), [](Attribute v) {
2507  return isa<FlatSymbolRefAttr, ZeroAttr>(v);
2508  }))
2509  return success();
2510  return op->emitError("data element must be symbol or #llvm.zero");
2511 }
2512 
2513 LogicalResult
2514 GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2515  for (Attribute ctor : getCtors()) {
2516  if (failed(verifySymbolAttrUse(llvm::cast<FlatSymbolRefAttr>(ctor), *this,
2517  symbolTable)))
2518  return failure();
2519  }
2520  return success();
2521 }
2522 
2523 LogicalResult GlobalCtorsOp::verify() {
2524  if (checkGlobalXtorData(*this, getData()).failed())
2525  return failure();
2526 
2527  if (getCtors().size() == getPriorities().size() &&
2528  getCtors().size() == getData().size())
2529  return success();
2530  return emitError(
2531  "ctors, priorities, and data must have the same number of elements");
2532 }
2533 
2534 //===----------------------------------------------------------------------===//
2535 // LLVM::GlobalDtorsOp
2536 //===----------------------------------------------------------------------===//
2537 
2538 LogicalResult
2539 GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
2540  for (Attribute dtor : getDtors()) {
2541  if (failed(verifySymbolAttrUse(llvm::cast<FlatSymbolRefAttr>(dtor), *this,
2542  symbolTable)))
2543  return failure();
2544  }
2545  return success();
2546 }
2547 
2548 LogicalResult GlobalDtorsOp::verify() {
2549  if (checkGlobalXtorData(*this, getData()).failed())
2550  return failure();
2551 
2552  if (getDtors().size() == getPriorities().size() &&
2553  getDtors().size() == getData().size())
2554  return success();
2555  return emitError(
2556  "dtors, priorities, and data must have the same number of elements");
2557 }
2558 
2559 //===----------------------------------------------------------------------===//
2560 // Builder, printer and verifier for LLVM::AliasOp.
2561 //===----------------------------------------------------------------------===//
2562 
2563 void AliasOp::build(OpBuilder &builder, OperationState &result, Type type,
2564  Linkage linkage, StringRef name, bool dsoLocal,
2565  bool threadLocal, ArrayRef<NamedAttribute> attrs) {
2566  result.addAttribute(getSymNameAttrName(result.name),
2567  builder.getStringAttr(name));
2568  result.addAttribute(getAliasTypeAttrName(result.name), TypeAttr::get(type));
2569  if (dsoLocal)
2570  result.addAttribute(getDsoLocalAttrName(result.name),
2571  builder.getUnitAttr());
2572  if (threadLocal)
2573  result.addAttribute(getThreadLocal_AttrName(result.name),
2574  builder.getUnitAttr());
2575 
2576  result.addAttribute(getLinkageAttrName(result.name),
2577  LinkageAttr::get(builder.getContext(), linkage));
2578  result.attributes.append(attrs.begin(), attrs.end());
2579 
2580  result.addRegion();
2581 }
2582 
2583 void AliasOp::print(OpAsmPrinter &p) {
2584  printCommonGlobalAndAlias<AliasOp>(p, *this);
2585 
2586  p.printSymbolName(getSymName());
2587  p.printOptionalAttrDict((*this)->getAttrs(),
2588  {SymbolTable::getSymbolAttrName(),
2589  getAliasTypeAttrName(), getLinkageAttrName(),
2590  getUnnamedAddrAttrName(), getThreadLocal_AttrName(),
2591  getVisibility_AttrName(), getUnnamedAddrAttrName()});
2592 
2593  // Print the trailing type.
2594  p << " : " << getType() << ' ';
2595  // Print the initializer region.
2596  p.printRegion(getInitializerRegion(), /*printEntryBlockArgs=*/false);
2597 }
2598 
2599 // operation ::= `llvm.mlir.alias` linkage? visibility?
2600 // (`unnamed_addr` | `local_unnamed_addr`)?
2601 // `thread_local`? `@` identifier
2602 // `(` attribute? `)`
2603 // attribute-list? `:` type region
2604 //
2605 ParseResult AliasOp::parse(OpAsmParser &parser, OperationState &result) {
2606  // Call into common parsing between GlobalOp and AliasOp.
2607  if (parseCommonGlobalAndAlias<AliasOp>(parser, result).failed())
2608  return failure();
2609 
2610  StringAttr name;
2611  if (parser.parseSymbolName(name, getSymNameAttrName(result.name),
2612  result.attributes))
2613  return failure();
2614 
2615  SmallVector<Type, 1> types;
2616  if (parser.parseOptionalAttrDict(result.attributes) ||
2617  parser.parseOptionalColonTypeList(types))
2618  return failure();
2619 
2620  if (types.size() > 1)
2621  return parser.emitError(parser.getNameLoc(), "expected zero or one type");
2622 
2623  Region &initRegion = *result.addRegion();
2624  if (parser.parseRegion(initRegion).failed())
2625  return failure();
2626 
2627  result.addAttribute(getAliasTypeAttrName(result.name),
2628  TypeAttr::get(types[0]));
2629  return success();
2630 }
2631 
2632 LogicalResult AliasOp::verify() {
2633  bool validType = isCompatibleOuterType(getType())
2634  ? !llvm::isa<LLVMVoidType, LLVMTokenType,
2635  LLVMMetadataType, LLVMLabelType>(getType())
2636  : llvm::isa<PointerElementTypeInterface>(getType());
2637  if (!validType)
2638  return emitOpError(
2639  "expects type to be a valid element type for an LLVM global alias");
2640 
2641  // This matches LLVM IR verification logic, see llvm/lib/IR/Verifier.cpp
2642  switch (getLinkage()) {
2643  case Linkage::External:
2644  case Linkage::Internal:
2645  case Linkage::Private:
2646  case Linkage::Weak:
2647  case Linkage::WeakODR:
2648  case Linkage::Linkonce:
2649  case Linkage::LinkonceODR:
2650  case Linkage::AvailableExternally:
2651  break;
2652  default:
2653  return emitOpError()
2654  << "'" << stringifyLinkage(getLinkage())
2655  << "' linkage not supported in aliases, available options: private, "
2656  "internal, linkonce, weak, linkonce_odr, weak_odr, external or "
2657  "available_externally";
2658  }
2659 
2660  return success();
2661 }
2662 
2663 LogicalResult AliasOp::verifyRegions() {
2664  Block &b = getInitializerBlock();
2665  auto ret = cast<ReturnOp>(b.getTerminator());
2666  if (ret.getNumOperands() == 0 ||
2667  !isa<LLVM::LLVMPointerType>(ret.getOperand(0).getType()))
2668  return emitOpError("initializer region must always return a pointer");
2669 
2670  for (Operation &op : b) {
2671  auto iface = dyn_cast<MemoryEffectOpInterface>(op);
2672  if (!iface || !iface.hasNoEffect())
2673  return op.emitError()
2674  << "ops with side effects are not allowed in alias initializers";
2675  }
2676 
2677  return success();
2678 }
2679 
2680 unsigned AliasOp::getAddrSpace() {
2681  Block &initializer = getInitializerBlock();
2682  auto ret = cast<ReturnOp>(initializer.getTerminator());
2683  auto ptrTy = cast<LLVMPointerType>(ret.getOperand(0).getType());
2684  return ptrTy.getAddressSpace();
2685 }
2686 
2687 //===----------------------------------------------------------------------===//
2688 // ShuffleVectorOp
2689 //===----------------------------------------------------------------------===//
2690 
2691 void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1,
2692  Value v2, DenseI32ArrayAttr mask,
2693  ArrayRef<NamedAttribute> attrs) {
2694  auto containerType = v1.getType();
2695  auto vType = LLVM::getVectorType(
2696  cast<VectorType>(containerType).getElementType(), mask.size(),
2697  LLVM::isScalableVectorType(containerType));
2698  build(builder, state, vType, v1, v2, mask);
2699  state.addAttributes(attrs);
2700 }
2701 
2702 void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1,
2703  Value v2, ArrayRef<int32_t> mask) {
2704  build(builder, state, v1, v2, builder.getDenseI32ArrayAttr(mask));
2705 }
2706 
2707 /// Build the result type of a shuffle vector operation.
2708 static ParseResult parseShuffleType(AsmParser &parser, Type v1Type,
2709  Type &resType, DenseI32ArrayAttr mask) {
2710  if (!LLVM::isCompatibleVectorType(v1Type))
2711  return parser.emitError(parser.getCurrentLocation(),
2712  "expected an LLVM compatible vector type");
2713  resType =
2714  LLVM::getVectorType(cast<VectorType>(v1Type).getElementType(),
2715  mask.size(), LLVM::isScalableVectorType(v1Type));
2716  return success();
2717 }
2718 
2719 /// Nothing to do when the result type is inferred.
2720 static void printShuffleType(AsmPrinter &printer, Operation *op, Type v1Type,
2721  Type resType, DenseI32ArrayAttr mask) {}
2722 
2723 LogicalResult ShuffleVectorOp::verify() {
2724  if (LLVM::isScalableVectorType(getV1().getType()) &&
2725  llvm::any_of(getMask(), [](int32_t v) { return v != 0; }))
2726  return emitOpError("expected a splat operation for scalable vectors");
2727  return success();
2728 }
2729 
2730 //===----------------------------------------------------------------------===//
2731 // Implementations for LLVM::LLVMFuncOp.
2732 //===----------------------------------------------------------------------===//
2733 
2734 // Add the entry block to the function.
2735 Block *LLVMFuncOp::addEntryBlock(OpBuilder &builder) {
2736  assert(empty() && "function already has an entry block");
2737  OpBuilder::InsertionGuard g(builder);
2738  Block *entry = builder.createBlock(&getBody());
2739 
2740  // FIXME: Allow passing in proper locations for the entry arguments.
2741  LLVMFunctionType type = getFunctionType();
2742  for (unsigned i = 0, e = type.getNumParams(); i < e; ++i)
2743  entry->addArgument(type.getParamType(i), getLoc());
2744  return entry;
2745 }
2746 
/// Builds an `llvm.func` from a name, a function type, a linkage, a calling
/// convention and an optional comdat.
/// - An empty body region is created first.
/// - The function type, linkage and calling-convention attributes are set,
///   then `attrs` are appended.
/// - `dso_local`, `comdat` and the function entry count are attached only
///   when requested.
/// - When `argAttrs` is non-empty it must hold one dictionary per parameter
///   of `type` (asserted); no result attributes are recorded.
void LLVMFuncOp::build(OpBuilder &builder, OperationState &result,
                       StringRef name, Type type, LLVM::Linkage linkage,
                       bool dsoLocal, CConv cconv, SymbolRefAttr comdat,
                       // NOTE(review): a parameter line is missing from this
                       // listing here. `attrs` is used below, so it presumably
                       // declares `ArrayRef<NamedAttribute> attrs` -- confirm
                       // against upstream.
                       ArrayRef<DictionaryAttr> argAttrs,
                       std::optional<uint64_t> functionEntryCount) {
  result.addRegion();
  // NOTE(review): a line is missing from this listing here. It presumably
  // opens a `result.addAttribute(<symbol name attr>,` call whose value is the
  // string attribute on the next line -- confirm against upstream.
                      builder.getStringAttr(name));
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));
  result.addAttribute(getLinkageAttrName(result.name),
                      LinkageAttr::get(builder.getContext(), linkage));
  result.addAttribute(getCConvAttrName(result.name),
                      CConvAttr::get(builder.getContext(), cconv));
  result.attributes.append(attrs.begin(), attrs.end());
  // Optional attributes are only materialized when requested.
  if (dsoLocal)
    result.addAttribute(getDsoLocalAttrName(result.name),
                        builder.getUnitAttr());
  if (comdat)
    result.addAttribute(getComdatAttrName(result.name), comdat);
  if (functionEntryCount)
    result.addAttribute(getFunctionEntryCountAttrName(result.name),
                        builder.getI64IntegerAttr(functionEntryCount.value()));
  if (argAttrs.empty())
    return;

  assert(llvm::cast<LLVMFunctionType>(type).getNumParams() == argAttrs.size() &&
         "expected as many argument attribute lists as arguments");
  // NOTE(review): a line is missing from this listing here. It presumably
  // names the helper that stores `argAttrs` (and no result attributes) under
  // the arg_attrs/res_attrs names passed below -- confirm against upstream.
      builder, result, argAttrs, /*resultAttrs=*/std::nullopt,
      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
}
2780 
// Builds an LLVM function type from the given lists of input and output types.
// Returns a null type, after emitting an error at `loc`, if any of the
// provided types is not LLVM-compatible or if there is more than one output
// type. An empty output list is mapped to the LLVM `void` result type.
static Type
// NOTE(review): the line carrying the function name and leading parameters is
// missing from this listing. LLVMFuncOp::parse calls it as
// `buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes, ...)`,
// and the body uses `parser`, `loc` and `inputs` -- confirm the exact
// signature against upstream.
                      ArrayRef<Type> outputs,
// NOTE(review): another line is missing from this listing here. It presumably
// declares the `variadicFlag` parameter used in the return statement and opens
// the function body.
  Builder &b = parser.getBuilder();
  // LLVM functions have at most one result.
  if (outputs.size() > 1) {
    parser.emitError(loc, "failed to construct function type: expected zero or "
                          "one function result");
    return {};
  }

  // Inputs are used as-is; each one must already be an LLVM-compatible type.
  // Exit early on the first incompatible input.
  SmallVector<Type, 4> llvmInputs;
  for (auto t : inputs) {
    if (!isCompatibleType(t)) {
      parser.emitError(loc, "failed to construct function type: expected LLVM "
                            "type for function arguments");
      return {};
    }
    llvmInputs.push_back(t);
  }

  // No output is denoted as "void" in LLVM type system.
  Type llvmOutput =
      outputs.empty() ? LLVMVoidType::get(b.getContext()) : outputs.front();
  if (!isCompatibleType(llvmOutput)) {
    // NOTE(review): the offending type is streamed directly after "results"
    // with no separator, so the diagnostic reads e.g. "...function resultsT".
    parser.emitError(loc, "failed to construct function type: expected LLVM "
                          "type for function results")
        << llvmOutput;
    return {};
  }
  return LLVMFunctionType::get(llvmOutput, llvmInputs,
                               variadicFlag.isVariadic());
}
2818 
// Parses an LLVM function.
//
// operation ::= `llvm.func` linkage? visibility? unnamed-addr? cconv?
//               function-signature
//               (`vscale_range(` integer `,` integer `)`)?
//               (`comdat(` symbol-ref-id `)`)?
//               function-attributes?
//               function-body?
//
// Omitted keywords default to external linkage, default visibility, no
// unnamed_addr and the C calling convention. The body is optional.
ParseResult LLVMFuncOp::parse(OpAsmParser &parser, OperationState &result) {
  // Default to external linkage if no keyword is provided.
  result.addAttribute(
      getLinkageAttrName(result.name),
      LinkageAttr::get(parser.getContext(),
                       parseOptionalLLVMKeyword<Linkage>(
                           parser, result, LLVM::Linkage::External)));

  // Parse optional visibility, default to Default.
  result.addAttribute(getVisibility_AttrName(result.name),
                      parser.getBuilder().getI64IntegerAttr(
                          parseOptionalLLVMKeyword<LLVM::Visibility, int64_t>(
                              parser, result, LLVM::Visibility::Default)));

  // Parse optional UnnamedAddr, default to None.
  result.addAttribute(getUnnamedAddrAttrName(result.name),
                      parser.getBuilder().getI64IntegerAttr(
                          parseOptionalLLVMKeyword<UnnamedAddr, int64_t>(
                              parser, result, LLVM::UnnamedAddr::None)));

  // Default to C Calling Convention if no keyword is provided.
  result.addAttribute(
      getCConvAttrName(result.name),
      CConvAttr::get(parser.getContext(), parseOptionalLLVMKeyword<CConv>(
                                              parser, result, LLVM::CConv::C)));

  StringAttr nameAttr;
  // NOTE(review): a line is missing from this listing here. It presumably
  // declares `entryArgs` (the parsed entry-block arguments used below) --
  // confirm against upstream.
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  bool isVariadic;

  auto signatureLocation = parser.getCurrentLocation();
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             result.attributes) ||
      // NOTE(review): a line is missing from this listing here. It presumably
      // names the signature-parsing helper whose arguments follow -- confirm
      // against upstream.
          parser, /*allowVariadic=*/true, entryArgs, isVariadic, resultTypes,
          resultAttrs))
    return failure();

  // The function type's inputs are the types of the parsed entry arguments.
  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  auto type =
      buildLLVMFunctionType(parser, signatureLocation, argTypes, resultTypes,
                            // NOTE(review): the final argument line is missing
                            // from this listing; presumably it passes a
                            // variadic flag built from `isVariadic`.
  if (!type)
    return failure();
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(type));

  // Optional `vscale_range(min, max)`; both bounds are stored as i32
  // IntegerAttrs.
  if (succeeded(parser.parseOptionalKeyword("vscale_range"))) {
    int64_t minRange, maxRange;
    if (parser.parseLParen() || parser.parseInteger(minRange) ||
        parser.parseComma() || parser.parseInteger(maxRange) ||
        parser.parseRParen())
      return failure();
    auto intTy = IntegerType::get(parser.getContext(), 32);
    result.addAttribute(
        getVscaleRangeAttrName(result.name),
        // NOTE(review): a line is missing from this listing here; presumably
        // it opens the vscale-range attribute construction over the two bounds.
        IntegerAttr::get(intTy, minRange),
        IntegerAttr::get(intTy, maxRange)));
  }
  // Parse the optional comdat selector.
  if (succeeded(parser.parseOptionalKeyword("comdat"))) {
    SymbolRefAttr comdat;
    if (parser.parseLParen() || parser.parseAttribute(comdat) ||
        parser.parseRParen())
      return failure();

    result.addAttribute(getComdatAttrName(result.name), comdat);
  }

  if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
    return failure();
  // NOTE(review): a line is missing from this listing here. It presumably names
  // the helper that records the per-argument and result attributes from the
  // arguments below -- confirm against upstream.
      parser.getBuilder(), result, entryArgs, resultAttrs,
      getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));

  // The body region is optional: without it the function is a declaration.
  auto *body = result.addRegion();
  OptionalParseResult parseResult =
      parser.parseOptionalRegion(*body, entryArgs);
  return failure(parseResult.has_value() && failed(*parseResult));
}
2911 
// Print the LLVMFuncOp. Collects argument and result types and passes them to
// helper functions. Drops the "void" result since it cannot be parsed back.
// Default spellings are elided: external linkage, the C calling convention,
// and empty visibility/unnamed_addr strings are not printed.
// NOTE(review): the line that opens this function is missing from this
// listing. The body uses `p` as the printer, so it is presumably
// `void LLVMFuncOp::print(OpAsmPrinter &p) {` -- confirm against upstream.
  p << ' ';
  if (getLinkage() != LLVM::Linkage::External)
    p << stringifyLinkage(getLinkage()) << ' ';
  StringRef visibility = stringifyVisibility(getVisibility_());
  if (!visibility.empty())
    p << visibility << ' ';
  if (auto unnamedAddr = getUnnamedAddr()) {
    StringRef str = stringifyUnnamedAddr(*unnamedAddr);
    if (!str.empty())
      p << str << ' ';
  }
  if (getCConv() != LLVM::CConv::C)
    p << stringifyCConv(getCConv()) << ' ';

  p.printSymbolName(getName());

  // Split the LLVM function type into parameter types and non-void results.
  LLVMFunctionType fnType = getFunctionType();
  SmallVector<Type, 8> argTypes;
  SmallVector<Type, 1> resTypes;
  argTypes.reserve(fnType.getNumParams());
  for (unsigned i = 0, e = fnType.getNumParams(); i < e; ++i)
    argTypes.push_back(fnType.getParamType(i));

  Type returnType = fnType.getReturnType();
  if (!llvm::isa<LLVMVoidType>(returnType))
    resTypes.push_back(returnType);

  // NOTE(review): a line is missing from this listing here. It presumably opens
  // the signature-printing call whose trailing arguments follow.
                                                 isVarArg(), resTypes);

  // Print vscale range if present
  if (std::optional<VScaleRangeAttr> vscale = getVscaleRange())
    p << " vscale_range(" << vscale->getMinRange().getInt() << ", "
      << vscale->getMaxRange().getInt() << ')';

  // Print the optional comdat selector.
  if (auto comdat = getComdat())
    p << " comdat(" << *comdat << ')';

  // NOTE(review): a line is missing from this listing here. It presumably opens
  // the attribute-printing call; the name list below presumably contains the
  // attributes to elide because the custom syntax above already covers them.
      p, *this,
      {getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
       getLinkageAttrName(), getCConvAttrName(), getVisibility_AttrName(),
       getComdatAttrName(), getUnnamedAddrAttrName(),
       getVscaleRangeAttrName()});

  // Print the body if this is not an external function.
  Region &body = getBody();
  if (!body.empty()) {
    p << ' ';
    p.printRegion(body, /*printEntryBlockArgs=*/false,
                  /*printBlockTerminators=*/true);
  }
}
2970 
// Verifies LLVM- and implementation-specific properties of the LLVM func Op:
// - functions don't have 'common' linkage;
// - the comdat reference, if any, passes verifyComdat;
// - external functions have 'external' or 'extern_weak' linkage;
// - for functions with a body:
//   - no_inline and always_inline are mutually exclusive;
//   - optimize_none requires no_inline;
//   - all llvm.landingpad results and llvm.resume operands share one type;
//   - block tags pass verifyBlockTags.
LogicalResult LLVMFuncOp::verify() {
  if (getLinkage() == LLVM::Linkage::Common)
    return emitOpError() << "functions cannot have '"
                         << stringifyLinkage(LLVM::Linkage::Common)
                         << "' linkage";

  if (failed(verifyComdat(*this, getComdat())))
    return failure();

  // Declarations (no body) stop after the linkage check.
  if (isExternal()) {
    if (getLinkage() != LLVM::Linkage::External &&
        getLinkage() != LLVM::Linkage::ExternWeak)
      return emitOpError() << "external functions must have '"
                           << stringifyLinkage(LLVM::Linkage::External)
                           << "' or '"
                           << stringifyLinkage(LLVM::Linkage::ExternWeak)
                           << "' linkage";
    return success();
  }

  // In LLVM IR, these attributes are composed by convention, not by design.
  if (isNoInline() && isAlwaysInline())
    return emitError("no_inline and always_inline attributes are incompatible");

  if (isOptimizeNone() && !isNoInline())
    return emitOpError("with optimize_none must also be no_inline");

  // All llvm.landingpad results and llvm.resume operands must share a single
  // type. The first one seen fixes the type. On the first mismatch the walk
  // is interrupted and the matching diagnostic is recorded.
  Type landingpadResultTy;
  StringRef diagnosticMessage;
  bool isLandingpadTypeConsistent =
      !walk([&](Operation *op) {
         const auto checkType = [&](Type type, StringRef errorMessage) {
           if (!landingpadResultTy) {
             landingpadResultTy = type;
             return WalkResult::advance();
           }
           if (landingpadResultTy != type) {
             diagnosticMessage = errorMessage;
             return WalkResult::interrupt();
           }
           return WalkResult::advance();
         };
         // NOTE(review): a line is missing from this listing here. It
         // presumably starts a TypeSwitch over `op` producing a WalkResult,
         // continued by the `.Case`/`.Default` calls below -- confirm against
         // upstream.
             .Case<LandingpadOp>([&](auto landingpad) {
               constexpr StringLiteral errorMessage =
                   "'llvm.landingpad' should have a consistent result type "
                   "inside a function";
               return checkType(landingpad.getType(), errorMessage);
             })
             .Case<ResumeOp>([&](auto resume) {
               constexpr StringLiteral errorMessage =
                   "'llvm.resume' should have a consistent input type inside a "
                   "function";
               return checkType(resume.getValue().getType(), errorMessage);
             })
             .Default([](auto) { return WalkResult::skip(); });
       }).wasInterrupted();
  if (!isLandingpadTypeConsistent) {
    assert(!diagnosticMessage.empty() &&
           "Expecting a non-empty diagnostic message");
    return emitError(diagnosticMessage);
  }

  if (failed(verifyBlockTags(*this)))
    return failure();

  return success();
}
3043 
3044 /// Verifies LLVM- and implementation-specific properties of the LLVM func Op:
3045 /// - entry block arguments are of LLVM types.
3046 LogicalResult LLVMFuncOp::verifyRegions() {
3047  if (isExternal())
3048  return success();
3049 
3050  unsigned numArguments = getFunctionType().getNumParams();
3051  Block &entryBlock = front();
3052  for (unsigned i = 0; i < numArguments; ++i) {
3053  Type argType = entryBlock.getArgument(i).getType();
3054  if (!isCompatibleType(argType))
3055  return emitOpError("entry block argument #")
3056  << i << " is not of LLVM type";
3057  }
3058 
3059  return success();
3060 }
3061 
3062 Region *LLVMFuncOp::getCallableRegion() {
3063  if (isExternal())
3064  return nullptr;
3065  return &getBody();
3066 }
3067 
3068 //===----------------------------------------------------------------------===//
3069 // UndefOp.
3070 //===----------------------------------------------------------------------===//
3071 
3072 /// Fold an undef operation to a dedicated undef attribute.
3073 OpFoldResult LLVM::UndefOp::fold(FoldAdaptor) {
3074  return LLVM::UndefAttr::get(getContext());
3075 }
3076 
3077 //===----------------------------------------------------------------------===//
3078 // PoisonOp.
3079 //===----------------------------------------------------------------------===//
3080 
/// Fold a poison operation to a dedicated poison attribute.
OpFoldResult LLVM::PoisonOp::fold(FoldAdaptor) {
  // NOTE(review): the body line is missing from this listing. By analogy with
  // UndefOp::fold it presumably returns LLVM::PoisonAttr::get(getContext()) --
  // confirm against upstream.
}
3085 
3086 //===----------------------------------------------------------------------===//
3087 // ZeroOp.
3088 //===----------------------------------------------------------------------===//
3089 
3090 LogicalResult LLVM::ZeroOp::verify() {
3091  if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType()))
3092  if (!targetExtType.hasProperty(LLVM::LLVMTargetExtType::HasZeroInit))
3093  return emitOpError()
3094  << "target extension type does not support zero-initializer";
3095 
3096  return success();
3097 }
3098 
/// Fold a zero operation to a builtin zero attribute when possible and fall
/// back to a dedicated zero attribute.
OpFoldResult LLVM::ZeroOp::fold(FoldAdaptor) {
  // NOTE(review): the line declaring `result` is missing from this listing.
  // Per the doc comment it presumably holds the builtin zero attribute for
  // getType(), or null when none exists -- confirm against upstream.
  if (result)
    return result;
  // No builtin zero attribute is available: use the dialect's LLVM::ZeroAttr.
  return LLVM::ZeroAttr::get(getContext());
}
3107 
3108 //===----------------------------------------------------------------------===//
3109 // ConstantOp.
3110 //===----------------------------------------------------------------------===//
3111 
3112 /// Compute the total number of elements in the given type, also taking into
3113 /// account nested types. Supported types are `VectorType` and `LLVMArrayType`.
3114 /// Everything else is treated as a scalar.
3115 static int64_t getNumElements(Type t) {
3116  if (auto vecType = dyn_cast<VectorType>(t)) {
3117  assert(!vecType.isScalable() &&
3118  "number of elements of a scalable vector type is unknown");
3119  return vecType.getNumElements() * getNumElements(vecType.getElementType());
3120  }
3121  if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t))
3122  return arrayType.getNumElements() *
3123  getNumElements(arrayType.getElementType());
3124  return 1;
3125 }
3126 
3127 /// Check if the given type is a scalable vector type or a vector/array type
3128 /// that contains a nested scalable vector type.
3129 static bool hasScalableVectorType(Type t) {
3130  if (auto vecType = dyn_cast<VectorType>(t)) {
3131  if (vecType.isScalable())
3132  return true;
3133  return hasScalableVectorType(vecType.getElementType());
3134  }
3135  if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t))
3136  return hasScalableVectorType(arrayType.getElementType());
3137  return false;
3138 }
3139 
3140 /// Verifies the constant array represented by `arrayAttr` matches the provided
3141 /// `arrayType`.
3142 static LogicalResult verifyStructArrayConstant(LLVM::ConstantOp op,
3143  LLVM::LLVMArrayType arrayType,
3144  ArrayAttr arrayAttr, int dim) {
3145  if (arrayType.getNumElements() != arrayAttr.size())
3146  return op.emitOpError()
3147  << "array attribute size does not match array type size in "
3148  "dimension "
3149  << dim << ": " << arrayAttr.size() << " vs. "
3150  << arrayType.getNumElements();
3151 
3152  llvm::DenseSet<Attribute> elementsVerified;
3153 
3154  // Recursively verify sub-dimensions for multidimensional arrays.
3155  if (auto subArrayType =
3156  dyn_cast<LLVM::LLVMArrayType>(arrayType.getElementType())) {
3157  for (auto [idx, elementAttr] : llvm::enumerate(arrayAttr))
3158  if (elementsVerified.insert(elementAttr).second) {
3159  if (isa<LLVM::ZeroAttr, LLVM::UndefAttr>(elementAttr))
3160  continue;
3161  auto subArrayAttr = dyn_cast<ArrayAttr>(elementAttr);
3162  if (!subArrayAttr)
3163  return op.emitOpError()
3164  << "nested attribute for sub-array in dimension " << dim
3165  << " at index " << idx
3166  << " must be a zero, or undef, or array attribute";
3167  if (failed(verifyStructArrayConstant(op, subArrayType, subArrayAttr,
3168  dim + 1)))
3169  return failure();
3170  }
3171  return success();
3172  }
3173 
3174  // Forbid usages of ArrayAttr for simple array types that should use
3175  // DenseElementsAttr instead. Note that there would be a use case for such
3176  // array types when one element value is obtained via a ptr-to-int conversion
3177  // from a symbol and cannot be represented in a DenseElementsAttr, but no MLIR
3178  // user needs this so far, and it seems better to avoid people misusing the
3179  // ArrayAttr for simple types.
3180  auto structType = dyn_cast<LLVM::LLVMStructType>(arrayType.getElementType());
3181  if (!structType)
3182  return op.emitOpError() << "for array with an array attribute must have a "
3183  "struct element type";
3184 
3185  // Shallow verification that leaf attributes are appropriate as struct initial
3186  // value.
3187  size_t numStructElements = structType.getBody().size();
3188  for (auto [idx, elementAttr] : llvm::enumerate(arrayAttr)) {
3189  if (elementsVerified.insert(elementAttr).second) {
3190  if (isa<LLVM::ZeroAttr, LLVM::UndefAttr>(elementAttr))
3191  continue;
3192  auto subArrayAttr = dyn_cast<ArrayAttr>(elementAttr);
3193  if (!subArrayAttr)
3194  return op.emitOpError()
3195  << "nested attribute for struct element at index " << idx
3196  << " must be a zero, or undef, or array attribute";
3197  if (subArrayAttr.size() != numStructElements)
3198  return op.emitOpError()
3199  << "nested array attribute size for struct element at index "
3200  << idx << " must match struct size: " << subArrayAttr.size()
3201  << " vs. " << numStructElements;
3202  }
3203  }
3204 
3205  return success();
3206 }
3207 
3208 LogicalResult LLVM::ConstantOp::verify() {
3209  if (StringAttr sAttr = llvm::dyn_cast<StringAttr>(getValue())) {
3210  auto arrayType = llvm::dyn_cast<LLVMArrayType>(getType());
3211  if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() ||
3212  !arrayType.getElementType().isInteger(8)) {
3213  return emitOpError() << "expected array type of "
3214  << sAttr.getValue().size()
3215  << " i8 elements for the string constant";
3216  }
3217  return success();
3218  }
3219  if (auto structType = dyn_cast<LLVMStructType>(getType())) {
3220  auto arrayAttr = dyn_cast<ArrayAttr>(getValue());
3221  if (!arrayAttr) {
3222  return emitOpError() << "expected array attribute for a struct constant";
3223  }
3224 
3225  ArrayRef<Type> elementTypes = structType.getBody();
3226  if (arrayAttr.size() != elementTypes.size()) {
3227  return emitOpError() << "expected array attribute of size "
3228  << elementTypes.size();
3229  }
3230  for (auto elementTy : elementTypes) {
3231  if (!isa<IntegerType, FloatType, LLVMPPCFP128Type>(elementTy)) {
3232  return emitOpError() << "expected struct element types to be floating "
3233  "point type or integer type";
3234  }
3235  }
3236 
3237  for (size_t i = 0; i < elementTypes.size(); ++i) {
3238  Attribute element = arrayAttr[i];
3239  if (!isa<IntegerAttr, FloatAttr>(element)) {
3240  return emitOpError()
3241  << "expected struct element attribute types to be floating "
3242  "point type or integer type";
3243  }
3244  auto elementType = cast<TypedAttr>(element).getType();
3245  if (elementType != elementTypes[i]) {
3246  return emitOpError()
3247  << "struct element at index " << i << " is of wrong type";
3248  }
3249  }
3250 
3251  return success();
3252  }
3253  if (auto targetExtType = dyn_cast<LLVMTargetExtType>(getType())) {
3254  return emitOpError() << "does not support target extension type.";
3255  }
3256 
3257  // Verification of IntegerAttr, FloatAttr, ElementsAttr, ArrayAttr.
3258  if (auto intAttr = dyn_cast<IntegerAttr>(getValue())) {
3259  if (!llvm::isa<IntegerType>(getType()))
3260  return emitOpError() << "expected integer type";
3261  } else if (auto floatAttr = dyn_cast<FloatAttr>(getValue())) {
3262  const llvm::fltSemantics &sem = floatAttr.getValue().getSemantics();
3263  unsigned floatWidth = APFloat::getSizeInBits(sem);
3264  if (auto floatTy = dyn_cast<FloatType>(getType())) {
3265  if (floatTy.getWidth() != floatWidth) {
3266  return emitOpError() << "expected float type of width " << floatWidth;
3267  }
3268  }
3269  // See the comment for getLLVMConstant for more details about why 8-bit
3270  // floats can be represented by integers.
3271  if (isa<IntegerType>(getType()) && !getType().isInteger(floatWidth)) {
3272  return emitOpError() << "expected integer type of width " << floatWidth;
3273  }
3274  } else if (isa<ElementsAttr>(getValue())) {
3275  if (hasScalableVectorType(getType())) {
3276  // The exact number of elements of a scalable vector is unknown, so we
3277  // allow only splat attributes.
3278  auto splatElementsAttr = dyn_cast<SplatElementsAttr>(getValue());
3279  if (!splatElementsAttr)
3280  return emitOpError()
3281  << "scalable vector type requires a splat attribute";
3282  return success();
3283  }
3284  if (!isa<VectorType, LLVM::LLVMArrayType>(getType()))
3285  return emitOpError() << "expected vector or array type";
3286  // The number of elements of the attribute and the type must match.
3287  if (auto elementsAttr = dyn_cast<ElementsAttr>(getValue())) {
3288  int64_t attrNumElements = elementsAttr.getNumElements();
3289  if (getNumElements(getType()) != attrNumElements)
3290  return emitOpError()
3291  << "type and attribute have a different number of elements: "
3292  << getNumElements(getType()) << " vs. " << attrNumElements;
3293  }
3294  } else if (auto arrayAttr = dyn_cast<ArrayAttr>(getValue())) {
3295  auto arrayType = dyn_cast<LLVM::LLVMArrayType>(getType());
3296  if (!arrayType)
3297  return emitOpError() << "expected array type";
3298  // When the attribute is an ArrayAttr, check that its nesting matches the
3299  // corresponding ArrayType or VectorType nesting.
3300  return verifyStructArrayConstant(*this, arrayType, arrayAttr, /*dim=*/0);
3301  } else {
3302  return emitOpError()
3303  << "only supports integer, float, string or elements attributes";
3304  }
3305 
3306  return success();
3307 }
3308 
3309 bool LLVM::ConstantOp::isBuildableWith(Attribute value, Type type) {
3310  // The value's type must be the same as the provided type.
3311  auto typedAttr = dyn_cast<TypedAttr>(value);
3312  if (!typedAttr || typedAttr.getType() != type || !isCompatibleType(type))
3313  return false;
3314  // The value's type must be an LLVM compatible type.
3315  if (!isCompatibleType(type))
3316  return false;
3317  // TODO: Add support for additional attributes kinds once needed.
3318  return isa<IntegerAttr, FloatAttr, ElementsAttr>(value);
3319 }
3320 
3321 ConstantOp LLVM::ConstantOp::materialize(OpBuilder &builder, Attribute value,
3322  Type type, Location loc) {
3323  if (isBuildableWith(value, type))
3324  return builder.create<LLVM::ConstantOp>(loc, cast<TypedAttr>(value));
3325  return nullptr;
3326 }
3327 
3328 // Constant op constant-folds to its value.
3329 OpFoldResult LLVM::ConstantOp::fold(FoldAdaptor) { return getValue(); }
3330 
3331 //===----------------------------------------------------------------------===//
3332 // AtomicRMWOp
3333 //===----------------------------------------------------------------------===//
3334 
3335 void AtomicRMWOp::build(OpBuilder &builder, OperationState &state,
3336  AtomicBinOp binOp, Value ptr, Value val,
3337  AtomicOrdering ordering, StringRef syncscope,
3338  unsigned alignment, bool isVolatile) {
3339  build(builder, state, val.getType(), binOp, ptr, val, ordering,
3340  !syncscope.empty() ? builder.getStringAttr(syncscope) : nullptr,
3341  alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isVolatile,
3342  /*access_groups=*/nullptr,
3343  /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
3344 }
3345 
3346 LogicalResult AtomicRMWOp::verify() {
3347  auto valType = getVal().getType();
3348  if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub ||
3349  getBinOp() == AtomicBinOp::fmin || getBinOp() == AtomicBinOp::fmax ||
3350  getBinOp() == AtomicBinOp::fminimum ||
3351  getBinOp() == AtomicBinOp::fmaximum) {
3352  if (isCompatibleVectorType(valType)) {
3353  if (isScalableVectorType(valType))
3354  return emitOpError("expected LLVM IR fixed vector type");
3355  Type elemType = llvm::cast<VectorType>(valType).getElementType();
3356  if (!isCompatibleFloatingPointType(elemType))
3357  return emitOpError(
3358  "expected LLVM IR floating point type for vector element");
3359  } else if (!isCompatibleFloatingPointType(valType)) {
3360  return emitOpError("expected LLVM IR floating point type");
3361  }
3362  } else if (getBinOp() == AtomicBinOp::xchg) {
3363  DataLayout dataLayout = DataLayout::closest(*this);
3364  if (!isTypeCompatibleWithAtomicOp(valType, dataLayout))
3365  return emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
3366  } else {
3367  auto intType = llvm::dyn_cast<IntegerType>(valType);
3368  unsigned intBitWidth = intType ? intType.getWidth() : 0;
3369  if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
3370  intBitWidth != 64)
3371  return emitOpError("expected LLVM IR integer type");
3372  }
3373 
3374  if (static_cast<unsigned>(getOrdering()) <
3375  static_cast<unsigned>(AtomicOrdering::monotonic))
3376  return emitOpError() << "expected at least '"
3377  << stringifyAtomicOrdering(AtomicOrdering::monotonic)
3378  << "' ordering";
3379 
3380  return success();
3381 }
3382 
3383 //===----------------------------------------------------------------------===//
3384 // AtomicCmpXchgOp
3385 //===----------------------------------------------------------------------===//
3386 
3387 /// Returns an LLVM struct type that contains a value type and a boolean type.
3388 static LLVMStructType getValAndBoolStructType(Type valType) {
3389  auto boolType = IntegerType::get(valType.getContext(), 1);
3390  return LLVMStructType::getLiteral(valType.getContext(), {valType, boolType});
3391 }
3392 
3393 void AtomicCmpXchgOp::build(OpBuilder &builder, OperationState &state,
3394  Value ptr, Value cmp, Value val,
3395  AtomicOrdering successOrdering,
3396  AtomicOrdering failureOrdering, StringRef syncscope,
3397  unsigned alignment, bool isWeak, bool isVolatile) {
3398  build(builder, state, getValAndBoolStructType(val.getType()), ptr, cmp, val,
3399  successOrdering, failureOrdering,
3400  !syncscope.empty() ? builder.getStringAttr(syncscope) : nullptr,
3401  alignment ? builder.getI64IntegerAttr(alignment) : nullptr, isWeak,
3402  isVolatile, /*access_groups=*/nullptr,
3403  /*alias_scopes=*/nullptr, /*noalias_scopes=*/nullptr, /*tbaa=*/nullptr);
3404 }
3405 
3406 LogicalResult AtomicCmpXchgOp::verify() {
3407  auto ptrType = llvm::cast<LLVM::LLVMPointerType>(getPtr().getType());
3408  if (!ptrType)
3409  return emitOpError("expected LLVM IR pointer type for operand #0");
3410  auto valType = getVal().getType();
3411  DataLayout dataLayout = DataLayout::closest(*this);
3412  if (!isTypeCompatibleWithAtomicOp(valType, dataLayout))
3413  return emitOpError("unexpected LLVM IR type");
3414  if (getSuccessOrdering() < AtomicOrdering::monotonic ||
3415  getFailureOrdering() < AtomicOrdering::monotonic)
3416  return emitOpError("ordering must be at least 'monotonic'");
3417  if (getFailureOrdering() == AtomicOrdering::release ||
3418  getFailureOrdering() == AtomicOrdering::acq_rel)
3419  return emitOpError("failure ordering cannot be 'release' or 'acq_rel'");
3420  return success();
3421 }
3422 
3423 //===----------------------------------------------------------------------===//
3424 // FenceOp
3425 //===----------------------------------------------------------------------===//
3426 
3427 void FenceOp::build(OpBuilder &builder, OperationState &state,
3428  AtomicOrdering ordering, StringRef syncscope) {
3429  build(builder, state, ordering,
3430  syncscope.empty() ? nullptr : builder.getStringAttr(syncscope));
3431 }
3432 
3433 LogicalResult FenceOp::verify() {
3434  if (getOrdering() == AtomicOrdering::not_atomic ||
3435  getOrdering() == AtomicOrdering::unordered ||
3436  getOrdering() == AtomicOrdering::monotonic)
3437  return emitOpError("can be given only acquire, release, acq_rel, "
3438  "and seq_cst orderings");
3439  return success();
3440 }
3441 
3442 //===----------------------------------------------------------------------===//
3443 // Verifier for extension ops
3444 //===----------------------------------------------------------------------===//
3445 
3446 /// Verifies that the given extension operation operates on consistent scalars
3447 /// or vectors, and that the target width is larger than the input width.
3448 template <class ExtOp>
3449 static LogicalResult verifyExtOp(ExtOp op) {
3450  IntegerType inputType, outputType;
3451  if (isCompatibleVectorType(op.getArg().getType())) {
3452  if (!isCompatibleVectorType(op.getResult().getType()))
3453  return op.emitError(
3454  "input type is a vector but output type is an integer");
3455  if (getVectorNumElements(op.getArg().getType()) !=
3456  getVectorNumElements(op.getResult().getType()))
3457  return op.emitError("input and output vectors are of incompatible shape");
3458  // Because this is a CastOp, the element of vectors is guaranteed to be an
3459  // integer.
3460  inputType = cast<IntegerType>(
3461  cast<VectorType>(op.getArg().getType()).getElementType());
3462  outputType = cast<IntegerType>(
3463  cast<VectorType>(op.getResult().getType()).getElementType());
3464  } else {
3465  // Because this is a CastOp and arg is not a vector, arg is guaranteed to be
3466  // an integer.
3467  inputType = cast<IntegerType>(op.getArg().getType());
3468  outputType = dyn_cast<IntegerType>(op.getResult().getType());
3469  if (!outputType)
3470  return op.emitError(
3471  "input type is an integer but output type is a vector");
3472  }
3473 
3474  if (outputType.getWidth() <= inputType.getWidth())
3475  return op.emitError("integer width of the output type is smaller or "
3476  "equal to the integer width of the input type");
3477  return success();
3478 }
3479 
3480 //===----------------------------------------------------------------------===//
3481 // ZExtOp
3482 //===----------------------------------------------------------------------===//
3483 
3484 LogicalResult ZExtOp::verify() { return verifyExtOp<ZExtOp>(*this); }
3485 
3486 OpFoldResult LLVM::ZExtOp::fold(FoldAdaptor adaptor) {
3487  auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
3488  if (!arg)
3489  return {};
3490 
3491  size_t targetSize = cast<IntegerType>(getType()).getWidth();
3492  return IntegerAttr::get(getType(), arg.getValue().zext(targetSize));
3493 }
3494 
3495 //===----------------------------------------------------------------------===//
3496 // SExtOp
3497 //===----------------------------------------------------------------------===//
3498 
3499 LogicalResult SExtOp::verify() { return verifyExtOp<SExtOp>(*this); }
3500 
3501 //===----------------------------------------------------------------------===//
3502 // Folder and verifier for LLVM::BitcastOp
3503 //===----------------------------------------------------------------------===//
3504 
/// Shared folder for cast ops of type `T` that compose:
///  - a cast whose operand already has the result type folds to the operand;
///  - a cast of another `T` cast whose original source has this cast's result
///    type folds to that source;
///  - otherwise a cast of a `T` cast is rewritten in place to cast the inner
///    source directly.
/// Returns an empty result when none of these apply.
3505 /// Folds a cast op that can be chained.
3506 template <typename T>
3508  typename T::FoldAdaptor adaptor) {
3509  // cast(x : T0, T0) -> x
3510  if (castOp.getArg().getType() == castOp.getType())
3511  return castOp.getArg();
3512  if (auto prev = castOp.getArg().template getDefiningOp<T>()) {
3513  // cast(cast(x : T0, T1), T0) -> x
3514  if (prev.getArg().getType() == castOp.getType())
3515  return prev.getArg();
3516  // cast(cast(x : T0, T1), T2) -> cast(x: T0, T2)
// In-place update: bypass `prev`, then return this op's own result to signal
// that the op was modified rather than replaced.
3517  castOp.getArgMutable().set(prev.getArg());
3518  return Value{castOp};
3519  }
3520  return {};
3521 }
3522 
3523 OpFoldResult LLVM::BitcastOp::fold(FoldAdaptor adaptor) {
3524  return foldChainableCast(*this, adaptor);
3525 }
3526 
3527 LogicalResult LLVM::BitcastOp::verify() {
3528  auto resultType = llvm::dyn_cast<LLVMPointerType>(
3529  extractVectorElementType(getResult().getType()));
3530  auto sourceType = llvm::dyn_cast<LLVMPointerType>(
3531  extractVectorElementType(getArg().getType()));
3532 
3533  // If one of the types is a pointer (or vector of pointers), then
3534  // both source and result type have to be pointers.
3535  if (static_cast<bool>(resultType) != static_cast<bool>(sourceType))
3536  return emitOpError("can only cast pointers from and to pointers");
3537 
3538  if (!resultType)
3539  return success();
3540 
3541  auto isVector = llvm::IsaPred<VectorType>;
3542 
3543  // Due to bitcast requiring both operands to be of the same size, it is not
3544  // possible for only one of the two to be a pointer of vectors.
3545  if (isVector(getResult().getType()) && !isVector(getArg().getType()))
3546  return emitOpError("cannot cast pointer to vector of pointers");
3547 
3548  if (!isVector(getResult().getType()) && isVector(getArg().getType()))
3549  return emitOpError("cannot cast vector of pointers to pointer");
3550 
3551  // Bitcast cannot cast between pointers of different address spaces.
3552  // 'llvm.addrspacecast' must be used for this purpose instead.
3553  if (resultType.getAddressSpace() != sourceType.getAddressSpace())
3554  return emitOpError("cannot cast pointers of different address spaces, "
3555  "use 'llvm.addrspacecast' instead");
3556 
3557  return success();
3558 }
3559 
3560 //===----------------------------------------------------------------------===//
3561 // Folder for LLVM::AddrSpaceCastOp
3562 //===----------------------------------------------------------------------===//
3563 
3564 OpFoldResult LLVM::AddrSpaceCastOp::fold(FoldAdaptor adaptor) {
3565  return foldChainableCast(*this, adaptor);
3566 }
3567 
3568 Value LLVM::AddrSpaceCastOp::getViewSource() { return getArg(); }
3569 
3570 //===----------------------------------------------------------------------===//
3571 // Folder for LLVM::GEPOp
3572 //===----------------------------------------------------------------------===//
3573 
3574 OpFoldResult LLVM::GEPOp::fold(FoldAdaptor adaptor) {
3575  GEPIndicesAdaptor<ArrayRef<Attribute>> indices(getRawConstantIndicesAttr(),
3576  adaptor.getDynamicIndices());
3577 
3578  // gep %x:T, 0 -> %x
3579  if (getBase().getType() == getType() && indices.size() == 1)
3580  if (auto integer = llvm::dyn_cast_or_null<IntegerAttr>(indices[0]))
3581  if (integer.getValue().isZero())
3582  return getBase();
3583 
3584  // Canonicalize any dynamic indices of constant value to constant indices.
3585  bool changed = false;
3586  SmallVector<GEPArg> gepArgs;
3587  for (auto iter : llvm::enumerate(indices)) {
3588  auto integer = llvm::dyn_cast_or_null<IntegerAttr>(iter.value());
3589  // Constant indices can only be int32_t, so if integer does not fit we
3590  // are forced to keep it dynamic, despite being a constant.
3591  if (!indices.isDynamicIndex(iter.index()) || !integer ||
3592  !integer.getValue().isSignedIntN(kGEPConstantBitWidth)) {
3593 
3594  PointerUnion<IntegerAttr, Value> existing = getIndices()[iter.index()];
3595  if (Value val = llvm::dyn_cast_if_present<Value>(existing))
3596  gepArgs.emplace_back(val);
3597  else
3598  gepArgs.emplace_back(cast<IntegerAttr>(existing).getInt());
3599 
3600  continue;
3601  }
3602 
3603  changed = true;
3604  gepArgs.emplace_back(integer.getInt());
3605  }
3606  if (changed) {
3607  SmallVector<int32_t> rawConstantIndices;
3608  SmallVector<Value> dynamicIndices;
3609  destructureIndices(getElemType(), gepArgs, rawConstantIndices,
3610  dynamicIndices);
3611 
3612  getDynamicIndicesMutable().assign(dynamicIndices);
3613  setRawConstantIndices(rawConstantIndices);
3614  return Value{*this};
3615  }
3616 
3617  return {};
3618 }
3619 
3620 Value LLVM::GEPOp::getViewSource() { return getBase(); }
3621 
3622 //===----------------------------------------------------------------------===//
3623 // ShlOp
3624 //===----------------------------------------------------------------------===//
3625 
3626 OpFoldResult LLVM::ShlOp::fold(FoldAdaptor adaptor) {
3627  auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs());
3628  if (!rhs)
3629  return {};
3630 
3631  if (rhs.getValue().getZExtValue() >=
3632  getLhs().getType().getIntOrFloatBitWidth())
3633  return {}; // TODO: Fold into poison.
3634 
3635  auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs());
3636  if (!lhs)
3637  return {};
3638 
3639  return IntegerAttr::get(getType(), lhs.getValue().shl(rhs.getValue()));
3640 }
3641 
3642 //===----------------------------------------------------------------------===//
3643 // OrOp
3644 //===----------------------------------------------------------------------===//
3645 
3646 OpFoldResult LLVM::OrOp::fold(FoldAdaptor adaptor) {
3647  auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs());
3648  if (!lhs)
3649  return {};
3650 
3651  auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs());
3652  if (!rhs)
3653  return {};
3654 
3655  return IntegerAttr::get(getType(), lhs.getValue() | rhs.getValue());
3656 }
3657 
3658 //===----------------------------------------------------------------------===//
3659 // CallIntrinsicOp
3660 //===----------------------------------------------------------------------===//
3661 
3662 LogicalResult CallIntrinsicOp::verify() {
3663  if (!getIntrin().starts_with("llvm."))
3664  return emitOpError() << "intrinsic name must start with 'llvm.'";
3665  if (failed(verifyOperandBundles(*this)))
3666  return failure();
3667  return success();
3668 }
3669 
3670 void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
3671  mlir::StringAttr intrin, mlir::ValueRange args) {
3672  build(builder, state, /*resultTypes=*/TypeRange{}, intrin, args,
3673  FastmathFlagsAttr{},
3674  /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
3675  /*res_attrs=*/{});
3676 }
3677 
3678 void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
3679  mlir::StringAttr intrin, mlir::ValueRange args,
3680  mlir::LLVM::FastmathFlagsAttr fastMathFlags) {
3681  build(builder, state, /*resultTypes=*/TypeRange{}, intrin, args,
3682  fastMathFlags,
3683  /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
3684  /*res_attrs=*/{});
3685 }
3686 
3687 void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
3688  mlir::Type resultType, mlir::StringAttr intrin,
3689  mlir::ValueRange args) {
3690  build(builder, state, {resultType}, intrin, args, FastmathFlagsAttr{},
3691  /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
3692  /*res_attrs=*/{});
3693 }
3694 
3695 void CallIntrinsicOp::build(OpBuilder &builder, OperationState &state,
3696  mlir::TypeRange resultTypes,
3697  mlir::StringAttr intrin, mlir::ValueRange args,
3698  mlir::LLVM::FastmathFlagsAttr fastMathFlags) {
3699  build(builder, state, resultTypes, intrin, args, fastMathFlags,
3700  /*op_bundle_operands=*/{}, /*op_bundle_tags=*/{}, /*arg_attrs=*/{},
3701  /*res_attrs=*/{});
3702 }
3703 
/// Parses the custom form of a call-intrinsic op, in this order:
///  - the intrinsic name;
///  - a parenthesized argument list;
///  - optional operand bundles;
///  - an optional attribute dictionary;
///  - the trailing call type, which also carries argument/result attributes.
/// Operand segment sizes are recorded as
/// [number of call arguments, total number of op-bundle operands].
3704 ParseResult CallIntrinsicOp::parse(OpAsmParser &parser,
3705  OperationState &result) {
3706  StringAttr intrinAttr;
3709  SmallVector<SmallVector<Type>> opBundleOperandTypes;
3710  ArrayAttr opBundleTags;
3711 
3712  // Parse intrinsic name.
3714  intrinAttr, parser.getBuilder().getType<NoneType>()))
3715  return failure();
3716  result.addAttribute(CallIntrinsicOp::getIntrinAttrName(result.name),
3717  intrinAttr);
3718 
3719  if (parser.parseLParen())
3720  return failure();
3721 
3722  // Parse the function arguments.
3723  if (parser.parseOperandList(operands))
3724  return mlir::failure();
3725 
3726  if (parser.parseRParen())
3727  return mlir::failure();
3728 
3729  // Handle bundles.
// NOTE: inside the following if-statement, `result` names the optional parse
// result and shadows the OperationState parameter; after the statement,
// `result` refers to the OperationState again.
3730  SMLoc opBundlesLoc = parser.getCurrentLocation();
3731  if (std::optional<ParseResult> result = parseOpBundles(
3732  parser, opBundleOperands, opBundleOperandTypes, opBundleTags);
3733  result && failed(*result))
3734  return failure();
// Only record the bundle tags attribute when at least one bundle was parsed.
3735  if (opBundleTags && !opBundleTags.empty())
3736  result.addAttribute(
3737  CallIntrinsicOp::getOpBundleTagsAttrName(result.name).getValue(),
3738  opBundleTags);
3739 
3740  if (parser.parseOptionalAttrDict(result.attributes))
3741  return mlir::failure();
3742 
3743  SmallVector<DictionaryAttr> argAttrs;
3744  SmallVector<DictionaryAttr> resultAttrs;
3745  if (parseCallTypeAndResolveOperands(parser, result, /*isDirect=*/true,
3746  operands, argAttrs, resultAttrs))
3747  return failure();
3749  parser.getBuilder(), result, argAttrs, resultAttrs,
3750  getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
3751 
3752  if (resolveOpBundleOperands(parser, opBundlesLoc, result, opBundleOperands,
3753  opBundleOperandTypes,
3754  getOpBundleSizesAttrName(result.name)))
3755  return failure();
3756 
// Total number of operands across all bundles. The loop variable shadows the
// outer `operands`; the use of `operands.size()` below refers to the call
// arguments again.
3757  int32_t numOpBundleOperands = 0;
3758  for (const auto &operands : opBundleOperands)
3759  numOpBundleOperands += operands.size();
3760 
3761  result.addAttribute(
3762  CallIntrinsicOp::getOperandSegmentSizeAttr(),
3764  {static_cast<int32_t>(operands.size()), numOpBundleOperands}));
3765 
3766  return mlir::success();
3767 }
3768 
// Prints the intrinsic name (without its type), then the parenthesized
// arguments, then any operand bundles and the remaining attributes, and
// finally the call signature.
3770  p << ' ';
3771  p.printAttributeWithoutType(getIntrinAttr());
3772 
3773  OperandRange args = getArgs();
3774  p << "(" << args << ")";
3775 
3776  // Operand bundles.
3777  if (!getOpBundleOperands().empty()) {
3778  p << ' ';
3779  printOpBundles(p, *this, getOpBundleOperands(),
3780  getOpBundleOperands().getTypes(), getOpBundleTagsAttr());
3781  }
3782 
// Elide attributes that are printed in custom form (intrinsic name, bundle
// tags, arg/res attrs in the signature) or are structural (segment/bundle
// sizes).
3783  p.printOptionalAttrDict(processFMFAttr((*this)->getAttrs()),
3784  {getOperandSegmentSizesAttrName(),
3785  getOpBundleSizesAttrName(), getIntrinAttrName(),
3786  getOpBundleTagsAttrName(), getArgAttrsAttrName(),
3787  getResAttrsAttrName()});
3788 
3789  p << " : ";
3790 
3791  // Reconstruct the MLIR function type from operand and result types.
3793  p, args.getTypes(), getArgAttrsAttr(),
3794  /*isVariadic=*/false, getResultTypes(), getResAttrsAttr());
3795 }
3796 
3797 //===----------------------------------------------------------------------===//
3798 // LinkerOptionsOp
3799 //===----------------------------------------------------------------------===//
3800 
3801 LogicalResult LinkerOptionsOp::verify() {
3802  if (mlir::Operation *parentOp = (*this)->getParentOp();
3803  parentOp && !satisfiesLLVMModule(parentOp))
3804  return emitOpError("must appear at the module level");
3805  return success();
3806 }
3807 
3808 //===----------------------------------------------------------------------===//
3809 // ModuleFlagsOp
3810 //===----------------------------------------------------------------------===//
3811 
3812 LogicalResult ModuleFlagsOp::verify() {
3813  if (Operation *parentOp = (*this)->getParentOp();
3814  parentOp && !satisfiesLLVMModule(parentOp))
3815  return emitOpError("must appear at the module level");
3816  for (Attribute flag : getFlags())
3817  if (!isa<ModuleFlagAttr>(flag))
3818  return emitOpError("expected a module flag attribute");
3819  return success();
3820 }
3821 
3822 //===----------------------------------------------------------------------===//
3823 // InlineAsmOp
3824 //===----------------------------------------------------------------------===//
3825 
/// Reports the memory effects of inline assembly. When the op is flagged as
/// having side effects, it is conservatively modeled as both writing and
/// reading memory. Otherwise no effects are reported.
3826 void InlineAsmOp::getEffects(
3828  &effects) {
3829  if (getHasSideEffects()) {
3830  effects.emplace_back(MemoryEffects::Write::get());
3831  effects.emplace_back(MemoryEffects::Read::get());
3832  }
3833 }
3834 
3835 //===----------------------------------------------------------------------===//
3836 // BlockAddressOp
3837 //===----------------------------------------------------------------------===//
3838 
3839 LogicalResult
3840 BlockAddressOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
3841  Operation *symbol = symbolTable.lookupSymbolIn(parentLLVMModule(*this),
3842  getBlockAddr().getFunction());
3843  auto function = dyn_cast_or_null<LLVMFuncOp>(symbol);
3844 
3845  if (!function)
3846  return emitOpError("must reference a function defined by 'llvm.func'");
3847 
3848  return success();
3849 }
3850 
3851 LLVMFuncOp BlockAddressOp::getFunction(SymbolTableCollection &symbolTable) {
3852  return dyn_cast_or_null<LLVMFuncOp>(symbolTable.lookupSymbolIn(
3853  parentLLVMModule(*this), getBlockAddr().getFunction()));
3854 }
3855 
3856 BlockTagOp BlockAddressOp::getBlockTagOp() {
3857  auto funcOp = dyn_cast<LLVMFuncOp>(mlir::SymbolTable::lookupNearestSymbolFrom(
3858  parentLLVMModule(*this), getBlockAddr().getFunction()));
3859  if (!funcOp)
3860  return nullptr;
3861 
3862  BlockTagOp blockTagOp = nullptr;
3863  funcOp.walk([&](LLVM::BlockTagOp labelOp) {
3864  if (labelOp.getTag() == getBlockAddr().getTag()) {
3865  blockTagOp = labelOp;
3866  return WalkResult::interrupt();
3867  }
3868  return WalkResult::advance();
3869  });
3870  return blockTagOp;
3871 }
3872 
3873 LogicalResult BlockAddressOp::verify() {
3874  if (!getBlockTagOp())
3875  return emitOpError(
3876  "expects an existing block label target in the referenced function");
3877 
3878  return success();
3879 }
3880 
3881 /// Fold a blockaddress operation to a dedicated blockaddress
3882 /// attribute.
3883 OpFoldResult BlockAddressOp::fold(FoldAdaptor) { return getBlockAddr(); }
3884 
3885 //===----------------------------------------------------------------------===//
3886 // LLVM::IndirectBrOp
3887 //===----------------------------------------------------------------------===//
3888 
// Returns the operands forwarded to successor `index`, taken from the
// `index`-th group of successor operands.
3890  assert(index < getNumSuccessors() && "invalid successor index");
3891  return SuccessorOperands(getSuccOperandsMutable()[index]);
3892 }
3893 
3894 void IndirectBrOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3895  Value addr, ArrayRef<ValueRange> succOperands,
3896  BlockRange successors) {
3897  odsState.addOperands(addr);
3898  for (ValueRange range : succOperands)
3899  odsState.addOperands(range);
3900  SmallVector<int32_t> rangeSegments;
3901  for (ValueRange range : succOperands)
3902  rangeSegments.push_back(range.size());
3903  odsState.getOrAddProperties<Properties>().indbr_operand_segments =
3904  odsBuilder.getDenseI32ArrayAttr(rangeSegments);
3905  odsState.addSuccessors(successors);
3906 }
3907 
/// Parses the comma-separated successor list of `llvm.indirectbr`. Each entry
/// is a successor block, optionally followed by a parenthesized
/// `operands : types` list. A successor without such a list gets empty operand
/// and type lists. `flagType` is not used by this parser.
/// NOTE(review): "Sucessors" is misspelled in the name. It is presumably
/// referenced from the op's declarative assembly format, so renaming it would
/// have to happen in both places.
3908 static ParseResult parseIndirectBrOpSucessors(
3909  OpAsmParser &parser, Type &flagType,
3910  SmallVectorImpl<Block *> &succOperandBlocks,
3912  SmallVectorImpl<SmallVector<Type>> &succOperandsTypes) {
3913  if (failed(parser.parseCommaSeparatedList(
3915  [&]() {
3916  Block *destination = nullptr;
3917  SmallVector<OpAsmParser::UnresolvedOperand> operands;
3918  SmallVector<Type> operandTypes;
3919 
3920  if (parser.parseSuccessor(destination).failed())
3921  return failure();
3922 
// Optional `(operands : types)` list forwarded to this successor.
3923  if (succeeded(parser.parseOptionalLParen())) {
3924  if (failed(parser.parseOperandList(
3925  operands, OpAsmParser::Delimiter::None)) ||
3926  failed(parser.parseColonTypeList(operandTypes)) ||
3927  failed(parser.parseRParen()))
3928  return failure();
3929  }
3930  succOperandBlocks.push_back(destination);
3931  succOperands.emplace_back(operands);
3932  succOperandsTypes.emplace_back(operandTypes);
3933  return success();
3934  },
3935  "successor blocks")))
3936  return failure();
3937  return success();
3938 }
3939 
3940 static void
3941 printIndirectBrOpSucessors(OpAsmPrinter &p, IndirectBrOp op, Type flagType,
3942  SuccessorRange succs, OperandRangeRange succOperands,
3943  const TypeRangeRange &succOperandsTypes) {
3944  p << "[";
3945  llvm::interleave(
3946  llvm::zip(succs, succOperands),
3947  [&](auto i) {
3948  p.printNewline();
3949  p.printSuccessorAndUseList(std::get<0>(i), std::get<1>(i));
3950  },
3951  [&] { p << ','; });
3952  if (!succOperands.empty())
3953  p.printNewline();
3954  p << "]";
3955 }
3956 
3957 //===----------------------------------------------------------------------===//
3958 // AssumeOp (intrinsic)
3959 //===----------------------------------------------------------------------===//
3960 
3961 void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
3962  mlir::Value cond) {
3963  return build(builder, state, cond, /*op_bundle_operands=*/{},
3964  /*op_bundle_tags=*/ArrayAttr{});
3965 }
3966 
3967 void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
3968  Value cond,
3969  ArrayRef<llvm::OperandBundleDefT<Value>> opBundles) {
3970  SmallVector<ValueRange> opBundleOperands;
3971  SmallVector<Attribute> opBundleTags;
3972  opBundleOperands.reserve(opBundles.size());
3973  opBundleTags.reserve(opBundles.size());
3974 
3975  for (const llvm::OperandBundleDefT<Value> &bundle : opBundles) {
3976  opBundleOperands.emplace_back(bundle.inputs());
3977  opBundleTags.push_back(
3978  StringAttr::get(builder.getContext(), bundle.getTag()));
3979  }
3980 
3981  auto opBundleTagsAttr = ArrayAttr::get(builder.getContext(), opBundleTags);
3982  return build(builder, state, cond, opBundleOperands, opBundleTagsAttr);
3983 }
3984 
3985 void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
3986  Value cond, llvm::StringRef tag, ValueRange args) {
3987  llvm::OperandBundleDefT<Value> opBundle(
3988  tag.str(), SmallVector<Value>(args.begin(), args.end()));
3989  return build(builder, state, cond, opBundle);
3990 }
3991 
3992 void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
3993  Value cond, AssumeAlignTag, Value ptr, Value align) {
3994  return build(builder, state, cond, "align", ValueRange{ptr, align});
3995 }
3996 
3997 void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state,
3998  Value cond, AssumeSeparateStorageTag, Value ptr1,
3999  Value ptr2) {
4000  return build(builder, state, cond, "separate_storage",
4001  ValueRange{ptr1, ptr2});
4002 }
4003 
4004 LogicalResult LLVM::AssumeOp::verify() { return verifyOperandBundles(*this); }
4005 
4006 //===----------------------------------------------------------------------===//
4007 // masked_gather (intrinsic)
4008 //===----------------------------------------------------------------------===//
4009 
/// Verifies that the pointer-vector operand of masked_gather has the expected
/// type. That type has as many elements as the result vector
/// (`getVectorNumElements(getRes().getType())`). The full construction of the
/// expected type is not visible in this listing; presumably it keeps the
/// operand's element type. TODO confirm.
4010 LogicalResult LLVM::masked_gather::verify() {
4011  auto ptrsVectorType = getPtrs().getType();
4012  Type expectedPtrsVectorType =
4014  LLVM::getVectorNumElements(getRes().getType()));
4015  // Vector of pointers type should match result vector type, other than the
4016  // element type.
4017  if (ptrsVectorType != expectedPtrsVectorType)
4018  return emitOpError("expected operand #1 type to be ")
4019  << expectedPtrsVectorType;
4020  return success();
4021 }
4022 
4023 //===----------------------------------------------------------------------===//
4024 // masked_scatter (intrinsic)
4025 //===----------------------------------------------------------------------===//
4026 
/// Verifies that the pointer-vector operand of masked_scatter has the
/// expected type. That type has as many elements as the stored value vector
/// (`getVectorNumElements(getValue().getType())`). The full construction of
/// the expected type is not visible in this listing; presumably it keeps the
/// operand's element type. TODO confirm.
4027 LogicalResult LLVM::masked_scatter::verify() {
4028  auto ptrsVectorType = getPtrs().getType();
4029  Type expectedPtrsVectorType =
4031  LLVM::getVectorNumElements(getValue().getType()));
4032  // Vector of pointers type should match value vector type, other than the
4033  // element type.
4034  if (ptrsVectorType != expectedPtrsVectorType)
4035  return emitOpError("expected operand #2 type to be ")
4036  << expectedPtrsVectorType;
4037  return success();
4038 }
4039 
4040 //===----------------------------------------------------------------------===//
4041 // InlineAsmOp
4042 //===----------------------------------------------------------------------===//
4043 
4044 LogicalResult InlineAsmOp::verify() {
4045  if (!getTailCallKindAttr())
4046  return success();
4047 
4048  if (getTailCallKindAttr().getTailCallKind() == TailCallKind::MustTail)
4049  return emitOpError(
4050  "tail call kind 'musttail' is not supported by this operation");
4051 
4052  return success();
4053 }
4054 
4055 //===----------------------------------------------------------------------===//
4056 // LLVMDialect initialization, type parsing, and registration.
4057 //===----------------------------------------------------------------------===//
4058 
/// Registers the LLVM dialect's attributes, types, and operations (core ops
/// plus intrinsic ops). It also allows unknown operations and declares that a
/// DialectInlinerInterface implementation for this dialect is promised.
4059 void LLVMDialect::initialize() {
4060  registerAttributes();
4061 
4062  // clang-format off
4063  addTypes<LLVMVoidType,
4064  LLVMTokenType,
4065  LLVMLabelType,
4066  LLVMMetadataType>();
4067  // clang-format on
4068  registerTypes();
4069 
// Both generated op lists are spliced into one addOperations<...> call.
// GET_OP_LIST is defined again before the second include, presumably because
// the first generated file #undefs it.
4070  addOperations<
4071 #define GET_OP_LIST
4072 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
4073  ,
4074 #define GET_OP_LIST
4075 #include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc"
4076  >();
4077 
4078  // Support unknown operations because not all LLVM operations are registered.
4079  allowUnknownOperations();
// The inliner interface is expected to be registered separately.
4080  declarePromisedInterface<DialectInlinerInterface, LLVMDialect>();
4081 }
4082 
// Pull in the TableGen-generated op class definitions, first for the core LLVM
// ops and then for the LLVM intrinsic ops.
4083 #define GET_OP_CLASSES
4084 #include "mlir/Dialect/LLVMIR/LLVMOps.cpp.inc"
4085 
4086 #define GET_OP_CLASSES
4087 #include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.cpp.inc"
4088 
4089 LogicalResult LLVMDialect::verifyDataLayoutString(
4090  StringRef descr, llvm::function_ref<void(const Twine &)> reportError) {
4091  llvm::Expected<llvm::DataLayout> maybeDataLayout =
4092  llvm::DataLayout::parse(descr);
4093  if (maybeDataLayout)
4094  return success();
4095 
4096  std::string message;
4097  llvm::raw_string_ostream messageStream(message);
4098  llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream);
4099  reportError("invalid data layout descriptor: " + message);
4100  return failure();
4101 }
4102 
4103 /// Verify LLVM dialect attributes.
4104 LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op,
4105  NamedAttribute attr) {
4106  // If the data layout attribute is present, it must use the LLVM data layout
4107  // syntax. Try parsing it and report errors in case of failure. Users of this
4108  // attribute may assume it is well-formed and can pass it to the (asserting)
4109  // llvm::DataLayout constructor.
4110  if (attr.getName() != LLVM::LLVMDialect::getDataLayoutAttrName())
4111  return success();
4112  if (auto stringAttr = llvm::dyn_cast<StringAttr>(attr.getValue()))
4113  return verifyDataLayoutString(
4114  stringAttr.getValue(),
4115  [op](const Twine &message) { op->emitOpError() << message.str(); });
4116 
4117  return op->emitOpError() << "expected '"
4118  << LLVM::LLVMDialect::getDataLayoutAttrName()
4119  << "' to be a string attributes";
4120 }
4121 
4122 LogicalResult LLVMDialect::verifyParameterAttribute(Operation *op,
4123  Type paramType,
4124  NamedAttribute paramAttr) {
4125  // LLVM attribute may be attached to a result of operation that has not been
4126  // converted to LLVM dialect yet, so the result may have a type with unknown
4127  // representation in LLVM dialect type space. In this case we cannot verify
4128  // whether the attribute may be
4129  bool verifyValueType = isCompatibleType(paramType);
4130  StringAttr name = paramAttr.getName();
4131 
4132  auto checkUnitAttrType = [&]() -> LogicalResult {
4133  if (!llvm::isa<UnitAttr>(paramAttr.getValue()))
4134  return op->emitError() << name << " should be a unit attribute";
4135  return success();
4136  };
4137  auto checkTypeAttrType = [&]() -> LogicalResult {
4138  if (!llvm::isa<TypeAttr>(paramAttr.getValue()))
4139  return op->emitError() << name << " should be a type attribute";
4140  return success();
4141  };
4142  auto checkIntegerAttrType = [&]() -> LogicalResult {
4143  if (!llvm::isa<IntegerAttr>(paramAttr.getValue()))
4144  return op->emitError() << name << " should be an integer attribute";
4145  return success();
4146  };
4147  auto checkPointerType = [&]() -> LogicalResult {
4148  if (!llvm::isa<LLVMPointerType>(paramType))
4149  return op->emitError()
4150  << name << " attribute attached to non-pointer LLVM type";
4151  return success();
4152  };
4153  auto checkIntegerType = [&]() -> LogicalResult {
4154  if (!llvm::isa<IntegerType>(paramType))
4155  return op->emitError()
4156  << name << " attribute attached to non-integer LLVM type";
4157  return success();
4158  };
4159  auto checkPointerTypeMatches = [&]() -> LogicalResult {
4160  if (failed(checkPointerType()))
4161  return failure();
4162 
4163  return success();
4164  };
4165 
4166  // Check a unit attribute that is attached to a pointer value.
4167  if (name == LLVMDialect::getNoAliasAttrName() ||
4168  name == LLVMDialect::getReadonlyAttrName() ||
4169  name == LLVMDialect::getReadnoneAttrName() ||
4170  name == LLVMDialect::getWriteOnlyAttrName() ||
4171  name == LLVMDialect::getNestAttrName() ||
4172  name == LLVMDialect::getNoCaptureAttrName() ||
4173  name == LLVMDialect::getNoFreeAttrName() ||
4174  name == LLVMDialect::getNonNullAttrName()) {
4175  if (failed(checkUnitAttrType()))
4176  return failure();
4177  if (verifyValueType && failed(checkPointerType()))
4178  return failure();
4179  return success();
4180  }
4181 
4182  // Check a type attribute that is attached to a pointer value.
4183  if (name == LLVMDialect::getStructRetAttrName() ||
4184  name == LLVMDialect::getByValAttrName() ||
4185  name == LLVMDialect::getByRefAttrName() ||
4186  name == LLVMDialect::getElementTypeAttrName() ||
4187  name == LLVMDialect::getInAllocaAttrName() ||
4188  name == LLVMDialect::getPreallocatedAttrName()) {
4189  if (failed(checkTypeAttrType()))
4190  return failure();
4191  if (verifyValueType && failed(checkPointerTypeMatches()))
4192  return failure();
4193  return success();
4194  }
4195 
4196  // Check a unit attribute that is attached to an integer value.
4197  if (name == LLVMDialect::getSExtAttrName() ||
4198  name == LLVMDialect::getZExtAttrName()) {
4199  if (failed(checkUnitAttrType()))
4200  return failure();
4201  if (verifyValueType && failed(checkIntegerType()))
4202  return failure();
4203  return success();
4204  }
4205 
4206  // Check an integer attribute that is attached to a pointer value.
4207  if (name == LLVMDialect::getAlignAttrName() ||
4208  name == LLVMDialect::getDereferenceableAttrName() ||
4209  name == LLVMDialect::getDereferenceableOrNullAttrName()) {
4210  if (failed(checkIntegerAttrType()))
4211  return failure();
4212  if (verifyValueType && failed(checkPointerType()))
4213  return failure();
4214  return success();
4215  }
4216 
4217  // Check an integer attribute that is attached to a pointer value.
4218  if (name == LLVMDialect::getStackAlignmentAttrName()) {
4219  if (failed(checkIntegerAttrType()))
4220  return failure();
4221  return success();
4222  }
4223 
4224  // Check a unit attribute that can be attached to arbitrary types.
4225  if (name == LLVMDialect::getNoUndefAttrName() ||
4226  name == LLVMDialect::getInRegAttrName() ||
4227  name == LLVMDialect::getReturnedAttrName())
4228  return checkUnitAttrType();
4229 
4230  return success();
4231 }
4232 
4233 /// Verify LLVMIR function argument attributes.
4234 LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op,
4235  unsigned regionIdx,
4236  unsigned argIdx,
4237  NamedAttribute argAttr) {
4238  auto funcOp = dyn_cast<FunctionOpInterface>(op);
4239  if (!funcOp)
4240  return success();
4241  Type argType = funcOp.getArgumentTypes()[argIdx];
4242 
4243  return verifyParameterAttribute(op, argType, argAttr);
4244 }
4245 
4246 LogicalResult LLVMDialect::verifyRegionResultAttribute(Operation *op,
4247  unsigned regionIdx,
4248  unsigned resIdx,
4249  NamedAttribute resAttr) {
4250  auto funcOp = dyn_cast<FunctionOpInterface>(op);
4251  if (!funcOp)
4252  return success();
4253  Type resType = funcOp.getResultTypes()[resIdx];
4254 
4255  // Check to see if this function has a void return with a result attribute
4256  // to it. It isn't clear what semantics we would assign to that.
4257  if (llvm::isa<LLVMVoidType>(resType))
4258  return op->emitError() << "cannot attach result attributes to functions "
4259  "with a void return";
4260 
4261  // Check to see if this attribute is allowed as a result attribute. Only
4262  // explicitly forbidden LLVM attributes will cause an error.
4263  auto name = resAttr.getName();
4264  if (name == LLVMDialect::getAllocAlignAttrName() ||
4265  name == LLVMDialect::getAllocatedPointerAttrName() ||
4266  name == LLVMDialect::getByValAttrName() ||
4267  name == LLVMDialect::getByRefAttrName() ||
4268  name == LLVMDialect::getInAllocaAttrName() ||
4269  name == LLVMDialect::getNestAttrName() ||
4270  name == LLVMDialect::getNoCaptureAttrName() ||
4271  name == LLVMDialect::getNoFreeAttrName() ||
4272  name == LLVMDialect::getPreallocatedAttrName() ||
4273  name == LLVMDialect::getReadnoneAttrName() ||
4274  name == LLVMDialect::getReadonlyAttrName() ||
4275  name == LLVMDialect::getReturnedAttrName() ||
4276  name == LLVMDialect::getStackAlignmentAttrName() ||
4277  name == LLVMDialect::getStructRetAttrName() ||
4278  name == LLVMDialect::getWriteOnlyAttrName())
4279  return op->emitError() << name << " is not a valid result attribute";
4280  return verifyParameterAttribute(op, resType, resAttr);
4281 }
4282 
4284  Type type, Location loc) {
4285  // If this was folded from an operation other than llvm.mlir.constant, it
4286  // should be materialized as such. Note that an llvm.mlir.zero may fold into
4287  // a builtin zero attribute and thus will materialize as a llvm.mlir.constant.
4288  if (auto symbol = dyn_cast<FlatSymbolRefAttr>(value))
4289  if (isa<LLVM::LLVMPointerType>(type))
4290  return builder.create<LLVM::AddressOfOp>(loc, type, symbol);
4291  if (isa<LLVM::UndefAttr>(value))
4292  return builder.create<LLVM::UndefOp>(loc, type);
4293  if (isa<LLVM::PoisonAttr>(value))
4294  return builder.create<LLVM::PoisonOp>(loc, type);
4295  if (isa<LLVM::ZeroAttr>(value))
4296  return builder.create<LLVM::ZeroOp>(loc, type);
4297  // Otherwise try materializing it as a regular llvm.mlir.constant op.
4298  return LLVM::ConstantOp::materialize(builder, value, type, loc);
4299 }
4300 
4301 //===----------------------------------------------------------------------===//
4302 // Utility functions.
4303 //===----------------------------------------------------------------------===//
4304 
4306  StringRef name, StringRef value,
4307  LLVM::Linkage linkage) {
4308  assert(builder.getInsertionBlock() &&
4309  builder.getInsertionBlock()->getParentOp() &&
4310  "expected builder to point to a block constrained in an op");
4311  auto module =
4312  builder.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>();
4313  assert(module && "builder points to an op outside of a module");
4314 
4315  // Create the global at the entry of the module.
4316  OpBuilder moduleBuilder(module.getBodyRegion(), builder.getListener());
4317  MLIRContext *ctx = builder.getContext();
4318  auto type = LLVM::LLVMArrayType::get(IntegerType::get(ctx, 8), value.size());
4319  auto global = moduleBuilder.create<LLVM::GlobalOp>(
4320  loc, type, /*isConstant=*/true, linkage, name,
4321  builder.getStringAttr(value), /*alignment=*/0);
4322 
4323  LLVMPointerType ptrType = LLVMPointerType::get(ctx);
4324  // Get the pointer to the first character in the global string.
4325  Value globalPtr =
4326  builder.create<LLVM::AddressOfOp>(loc, ptrType, global.getSymNameAttr());
4327  return builder.create<LLVM::GEPOp>(loc, ptrType, type, globalPtr,
4328  ArrayRef<GEPArg>{0, 0});
4329 }
4330 
4332  return op->hasTrait<OpTrait::SymbolTable>() &&
4334 }
static OperandRange getSuccessorOperands(Block *block, unsigned successorIndex)
Return the operand range used to transfer operands from block to its successor with the given index.
Definition: CFGToSCF.cpp:142
static Value getBase(Value v)
Looks through known "view-like" ops to find the base memref.
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static int parseOptionalKeywordAlternative(OpAsmParser &parser, ArrayRef< StringRef > keywords)
Definition: LLVMDialect.cpp:99
LogicalResult verifyCallOpVarCalleeType(OpTy callOp)
Verify that the parameter and return types of the variadic callee type match the callOp argument and ...
LogicalResult checkGlobalXtorData(Operation *op, ArrayAttr data)
static ParseResult parseGEPIndices(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &indices, DenseI32ArrayAttr &rawConstantIndices)
static LogicalResult verifyOperandBundles(OpType &op)
static ParseResult parseCmpOp(OpAsmParser &parser, OperationState &result)
LogicalResult verifyAtomicMemOp(OpTy memOp, Type valueType, ArrayRef< AtomicOrdering > unsupportedOrderings)
Verifies the attributes and the type of atomic memory access operations.
static void printOneOpBundle(OpAsmPrinter &p, OperandRange operands, TypeRange operandTypes, StringRef tag)
static LogicalResult verifyComdat(Operation *op, std::optional< SymbolRefAttr > attr)
static RetTy parseOptionalLLVMKeyword(OpAsmParser &parser, OperationState &result, EnumTy defaultValue)
Parse an enum from the keyword, or default to the provided default value.
static LLVMFunctionType getLLVMFuncType(MLIRContext *context, TypeRange results, ValueRange args)
Constructs a LLVMFunctionType from MLIR results and args.
static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, OperandRangeRange caseOperands, const TypeRangeRange &caseOperandTypes)
static ParseResult resolveOpBundleOperands(OpAsmParser &parser, SMLoc loc, OperationState &state, ArrayRef< SmallVector< OpAsmParser::UnresolvedOperand >> opBundleOperands, ArrayRef< SmallVector< Type >> opBundleOperandTypes, StringAttr opBundleSizesAttrName)
static ParseResult parseOptionalCallFuncPtr(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands)
Parses an optional function pointer operand before the call argument list for indirect calls,...
static bool isZeroAttribute(Attribute value)
static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp, OperandRange indices, DenseI32ArrayAttr rawConstantIndices)
static LLVMStructType getValAndBoolStructType(Type valType)
Returns an LLVM struct type that contains a value type and a boolean type.
static void printOpBundles(OpAsmPrinter &p, Operation *op, OperandRangeRange opBundleOperands, TypeRangeRange opBundleOperandTypes, std::optional< ArrayAttr > opBundleTags)
static void printShuffleType(AsmPrinter &printer, Operation *op, Type v1Type, Type resType, DenseI32ArrayAttr mask)
Nothing to do when the result type is inferred.
static LogicalResult verifyBlockTags(LLVMFuncOp funcOp)
static Type buildLLVMFunctionType(OpAsmParser &parser, SMLoc loc, ArrayRef< Type > inputs, ArrayRef< Type > outputs, function_interface_impl::VariadicFlag variadicFlag)
static auto processFMFAttr(ArrayRef< NamedAttribute > attrs)
Definition: LLVMDialect.cpp:58
static std::optional< ParseResult > parseOpBundles(OpAsmParser &p, SmallVector< SmallVector< OpAsmParser::UnresolvedOperand >> &opBundleOperands, SmallVector< SmallVector< Type >> &opBundleOperandTypes, ArrayAttr &opBundleTags)
static Operation * parentLLVMModule(Operation *op)
static ParseResult parseSwitchOpCases(OpAsmParser &parser, Type flagType, DenseIntElementsAttr &caseValues, SmallVectorImpl< Block * > &caseDestinations, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand >> &caseOperands, SmallVectorImpl< SmallVector< Type >> &caseOperandTypes)
<cases> ::= `[` (case (`,` case)*)? `]`; <case> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
static TypeAttr getCallOpVarCalleeType(LLVMFunctionType calleeType)
Gets the variadic callee type for a LLVMFunctionType.
static Type getInsertExtractValueElementType(function_ref< InFlightDiagnostic(StringRef)> emitError, Type containerType, ArrayRef< int64_t > position)
Extract the type at position in the LLVM IR aggregate type containerType.
static ParseResult parseOneOpBundle(OpAsmParser &p, SmallVector< SmallVector< OpAsmParser::UnresolvedOperand >> &opBundleOperands, SmallVector< SmallVector< Type >> &opBundleOperandTypes, SmallVector< Attribute > &opBundleTags)
static ParseResult parseIndirectBrOpSucessors(OpAsmParser &parser, Type &flagType, SmallVectorImpl< Block * > &succOperandBlocks, SmallVectorImpl< SmallVector< OpAsmParser::UnresolvedOperand >> &succOperands, SmallVectorImpl< SmallVector< Type >> &succOperandsTypes)
static void printIndirectBrOpSucessors(OpAsmPrinter &p, IndirectBrOp op, Type flagType, SuccessorRange succs, OperandRangeRange succOperands, const TypeRangeRange &succOperandsTypes)
static LogicalResult verifyStructArrayConstant(LLVM::ConstantOp op, LLVM::LLVMArrayType arrayType, ArrayAttr arrayAttr, int dim)
Verifies the constant array represented by arrayAttr matches the provided arrayType.
static ParseResult parseCallTypeAndResolveOperands(OpAsmParser &parser, OperationState &result, bool isDirect, ArrayRef< OpAsmParser::UnresolvedOperand > operands, SmallVectorImpl< DictionaryAttr > &argAttrs, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses the type of a call operation and resolves the operands if the parsing succeeds.
static LogicalResult verifySymbolAttrUse(FlatSymbolRefAttr symbol, Operation *op, SymbolTableCollection &symbolTable)
Verifies symbol's use in op to ensure the symbol is a valid and fully defined llvm....
Definition: LLVMDialect.cpp:73
static Type extractVectorElementType(Type type)
Returns the elemental type of any LLVM-compatible vector type or self.
static bool hasScalableVectorType(Type t)
Check if the given type is a scalable vector type or a vector/array type that contains a nested scala...
static OpFoldResult foldChainableCast(T castOp, typename T::FoldAdaptor adaptor)
Folds a cast op that can be chained.
static void destructureIndices(Type currType, ArrayRef< GEPArg > indices, SmallVectorImpl< int32_t > &rawConstantIndices, SmallVectorImpl< Value > &dynamicIndices)
Destructures the 'indices' parameter into 'rawConstantIndices' and 'dynamicIndices',...
static ParseResult parseCommonGlobalAndAlias(OpAsmParser &parser, OperationState &result)
Parse common attributes that might show up in the same order in both GlobalOp and AliasOp.
static SmallVector< Type, 1 > getCallOpResultTypes(LLVMFunctionType calleeType)
Gets the MLIR Op-like result types of a LLVMFunctionType.
static Type getI1SameShape(Type type)
Returns a boolean type that has the same shape as type.
Definition: LLVMDialect.cpp:89
static void printCommonGlobalAndAlias(OpAsmPrinter &p, OpType op)
static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value)
Returns a scalar or vector boolean attribute of the given type.
static LogicalResult verifyCallOpDebugInfo(CallOp callOp, LLVMFuncOp callee)
Verify that an inlinable callsite of a debug-info-bearing function in a debug-info-bearing function h...
static ParseResult parseShuffleType(AsmParser &parser, Type v1Type, Type &resType, DenseI32ArrayAttr mask)
Build the result type of a shuffle vector operation.
static LogicalResult verifyExtOp(ExtOp op)
Verifies that the given extension operation operates on consistent scalars or vectors,...
static constexpr const char kElemTypeAttrName[]
Definition: LLVMDialect.cpp:56
static ParseResult parseInsertExtractValueElementType(AsmParser &parser, Type &valueType, Type containerType, DenseI64ArrayAttr position)
Infer the value type from the container type and position.
static LogicalResult verifyStructIndices(Type baseGEPType, unsigned indexPos, GEPIndicesAdaptor< ValueRange > indices, function_ref< InFlightDiagnostic()> emitOpError)
For the given indices, check if they comply with baseGEPType, especially check against LLVMStructType...
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static bool isTypeCompatibleWithAtomicOp(Type type, const DataLayout &dataLayout)
Returns true if the given type is supported by atomic operations.
static void printInsertExtractValueElementType(AsmPrinter &printer, Operation *op, Type valueType, Type containerType, DenseI64ArrayAttr position)
Nothing to print for an inferred type.
#define REGISTER_ENUM_TYPE(Ty)
@ None
static std::string diag(const llvm::Value &value)
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:188
static VectorType getVectorType(Type scalarTy, const VectorizationStrategy *strategy)
Returns the vector type resulting from applying the provided vectorization strategy on the scalar typ...
This base class exposes generic asm parser hooks, usable across the various derived parsers.
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
@ Square
Square brackets surrounding zero or more operands.
virtual OptionalParseResult parseOptionalInteger(APInt &result)=0
Parse an optional integer value from the stream.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseOptionalColonTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional colon followed by a type list, which if present must have at least one type.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalRParen()=0
Parse a ) token if present.
virtual ParseResult parseCustomAttributeWithFallback(Attribute &result, Type type, function_ref< ParseResult(Attribute &result, Type type)> parseAttribute)=0
Parse a custom attribute with the provided callback, unless the next token is #, in which case the ge...
ParseResult parseString(std::string *string)
Parse a quoted string token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalRSquare()=0
Parse a ] token if present.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
Definition: AsmPrinter.cpp:78
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e. a form representable by a SymbolRefAttr.
virtual void printString(StringRef string)
Print the given string as a quoted string, escaping any special or non-printable characters in it.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class provides an abstraction over the different types of ranges over Blocks.
Definition: BlockSupport.h:106
Block represents an ordered list of Operations.
Definition: Block.h:33
bool empty()
Definition: Block.h:148
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:155
Operation & front()
Definition: Block.h:153
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
UnitAttr getUnitAttr()
Definition: Builders.cpp:96
IntegerAttr getI32IntegerAttr(int32_t value)
Definition: Builders.cpp:198
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:161
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:110
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:89
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:260
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:322
MLIRContext * getContext() const
Definition: Builders.h:55
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:96
The main mechanism for performing data layout queries.
static DataLayout closest(Operation *op)
Returns the layout of the closest parent operation carrying layout info.
llvm::TypeSize getTypeSizeInBits(Type t) const
Returns the size in bits of the given type in the current scope.
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
auto getValues() const
Return the held element values as a range of the given type.
Type getElementType() const
Return the element type of this DenseElementsAttr.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
A symbol reference with a reference path containing a single element.
StringRef getValue() const
Returns the name of the held symbol reference.
StringAttr getAttr() const
Returns the name of the held symbol reference as a StringAttr.
This class represents a fused location whose metadata is known to be an instance of the given type.
Definition: Location.h:149
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
Definition: Diagnostics.h:352
Class used for building a 'llvm.getelementptr'.
Definition: LLVMDialect.h:76
Class used for convenient access and iteration over GEP indices.
Definition: LLVMDialect.h:123
bool isDynamicIndex(size_t index) const
Returns whether the GEP index at the given position is a dynamic index.
Definition: LLVMDialect.h:152
size_t size() const
Returns the amount of indices of the GEPOp.
Definition: LLVMDialect.h:157
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class provides a mutable adaptor for a range of operands.
Definition: ValueRange.h:118
std::optional< NamedAttribute > getNamed(StringRef name) const
Return the specified named attribute if present, std::nullopt otherwise.
Attribute erase(StringAttr name)
Erase the attribute with the given name from the list.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:55
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseSuccessor(Block *&dest)=0
Parse a single operation successor.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
virtual ParseResult parseSuccessorAndUseList(Block *&dest, SmallVectorImpl< Value > &operands)=0
Parse a single operation successor and its operand list.
virtual OptionalParseResult parseOptionalRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region if present.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printSuccessorAndUseList(Block *successor, ValueRange succOperands)=0
Print the successor and its operands.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Listener * getListener() const
Returns the current listener of this builder, or nullptr if this builder doesn't have a listener.
Definition: Builders.h:318
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:428
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:455
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition: Builders.h:440
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class provides the API for ops that are known to be isolated from above.
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:451
This class represents a contiguous range of operand ranges, e.g. from a VariadicOfVariadic operand group.
Definition: ValueRange.h:84
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
type_range getTypes() const
Definition: ValueRange.cpp:28
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
Definition: Operation.h:749
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition: Operation.h:238
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:52
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
bool empty()
Definition: Region.h:60
Block & emplaceBlock()
Definition: Region.h:46
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
This class models how operands are forwarded to block arguments in control flow.
This class implements the successor iterators for Block.
Definition: BlockSupport.h:73
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
virtual Operation * lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol)
Look up a symbol with the specified name within the specified symbol table operation,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction for a range of TypeRange.
Definition: TypeRange.h:95
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
A utility result that is used to signal how to proceed with an ongoing walk:
Definition: Visitors.h:33
static WalkResult skip()
Definition: Visitors.h:52
static WalkResult advance()
Definition: Visitors.h:51
bool wasInterrupted() const
Returns true if the walk was interrupted.
Definition: Visitors.h:55
static WalkResult interrupt()
Definition: Visitors.h:50
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
A named class for passing around the variadic flag.
Value createGlobalString(Location loc, OpBuilder &builder, StringRef name, StringRef value, Linkage linkage)
Create an LLVM global containing the string "value" at the module containing surrounding the insertio...
Type getVectorType(Type elementType, unsigned numElements, bool isScalable=false)
Creates an LLVM dialect-compatible vector type with the given element type and length.
Definition: LLVMTypes.cpp:841
bool isScalableVectorType(Type vectorType)
Returns whether a vector type is scalable or not.
Definition: LLVMTypes.cpp:835
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:814
bool isCompatibleOuterType(Type type)
Returns true if the given outer type is compatible with the LLVM dialect without checking its potenti...
Definition: LLVMTypes.cpp:705
bool satisfiesLLVMModule(Operation *op)
LLVM requires some operations to be inside of a Module operation.
constexpr int kGEPConstantBitWidth
Bit-width of a 'GEPConstantIndex' within GEPArg.
Definition: LLVMDialect.h:67
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:796
bool isCompatibleFloatingPointType(Type type)
Returns true if the given type is a floating-point type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:809
llvm::ElementCount getVectorNumElements(Type type)
Returns the element count of any LLVM-compatible vector type.
Definition: LLVMTypes.cpp:827
void printFunctionSignature(OpAsmPrinter &p, TypeRange argTypes, ArrayAttr argAttrs, bool isVariadic, TypeRange resultTypes, ArrayAttr resultAttrs, Region *body=nullptr, bool printEmptyResult=true)
Print a function signature for a call or callable operation.
ParseResult parseFunctionSignature(OpAsmParser &parser, SmallVectorImpl< Type > &argTypes, SmallVectorImpl< DictionaryAttr > &argAttrs, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs, bool mustParseEmptyResult=true)
Parses a function signature using parser.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:136
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
ParseResult parseFunctionSignatureWithArguments(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl< OpAsmParser::Argument > &arguments, bool &isVariadic, SmallVectorImpl< Type > &resultTypes, SmallVectorImpl< DictionaryAttr > &resultAttrs)
Parses a function signature using parser.
void printFunctionAttributes(OpAsmPrinter &p, Operation *op, ArrayRef< StringRef > elided={})
Prints the list of function prefixed with the "attributes" keyword.
void printFunctionSignature(OpAsmPrinter &p, FunctionOpInterface op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes)
Prints the signature of the function-like operation op.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:22
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
Definition: CodegenUtils.h:331
Visibility
This enum describes C++ inheritance visibility.
Definition: Class.h:408
std::string stringify(T &&t)
Generically convert a value to a std::string.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition: Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition: Matchers.h:527
const FrozenRewritePatternSet GreedyRewriteConfig bool * changed
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition: Matchers.h:369
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
A callable is either a symbol, or an SSA value, that is referenced by a call-like operation.
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
T & getOrAddProperties()
Get (or create) a properties of the provided type to be set on the operation on creation.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addSuccessors(Block *successor)
Adds a successor to the operation state. successor must not be null.
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes
Region * addRegion()
Create a region that should be attached to the operation.