MLIR  20.0.0git
AsmPrinter.cpp
Go to the documentation of this file.
1 //===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the MLIR AsmPrinter class, which is used to implement
10 // the various print() methods on the core IR objects.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/IR/AffineExpr.h"
15 #include "mlir/IR/AffineMap.h"
16 #include "mlir/IR/AsmState.h"
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinDialect.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/Dialect.h"
26 #include "mlir/IR/IntegerSet.h"
27 #include "mlir/IR/MLIRContext.h"
29 #include "mlir/IR/Operation.h"
30 #include "mlir/IR/Verifier.h"
31 #include "llvm/ADT/APFloat.h"
32 #include "llvm/ADT/ArrayRef.h"
33 #include "llvm/ADT/DenseMap.h"
34 #include "llvm/ADT/MapVector.h"
35 #include "llvm/ADT/STLExtras.h"
36 #include "llvm/ADT/ScopeExit.h"
37 #include "llvm/ADT/ScopedHashTable.h"
38 #include "llvm/ADT/SetVector.h"
39 #include "llvm/ADT/SmallString.h"
40 #include "llvm/ADT/StringExtras.h"
41 #include "llvm/ADT/StringSet.h"
42 #include "llvm/ADT/TypeSwitch.h"
43 #include "llvm/Support/CommandLine.h"
44 #include "llvm/Support/Debug.h"
45 #include "llvm/Support/Endian.h"
46 #include "llvm/Support/ManagedStatic.h"
47 #include "llvm/Support/Regex.h"
48 #include "llvm/Support/SaveAndRestore.h"
49 #include "llvm/Support/Threading.h"
50 #include "llvm/Support/raw_ostream.h"
51 #include <type_traits>
52 
53 #include <optional>
54 #include <tuple>
55 
56 using namespace mlir;
57 using namespace mlir::detail;
58 
59 #define DEBUG_TYPE "mlir-asm-printer"
60 
61 void OperationName::print(raw_ostream &os) const { os << getStringRef(); }
62 
63 void OperationName::dump() const { print(llvm::errs()); }
64 
65 //===--------------------------------------------------------------------===//
66 // AsmParser
67 //===--------------------------------------------------------------------===//
68 
// Out-of-line definitions of the defaulted parser destructors.
69 AsmParser::~AsmParser() = default;
71 OpAsmParser::~OpAsmParser() = default;
72 
73 MLIRContext *AsmParser::getContext() const { return getBuilder().getContext(); }
74 
75 /// Parse a type list.
76 /// This is defined out-of-line to work around
77 /// https://github.com/llvm/llvm-project/issues/62918
80  [&]() { return parseType(result.emplace_back()); });
81 }
82 
83 //===----------------------------------------------------------------------===//
84 // DialectAsmPrinter
85 //===----------------------------------------------------------------------===//
86 
88 
89 //===----------------------------------------------------------------------===//
90 // OpAsmPrinter
91 //===----------------------------------------------------------------------===//
92 
// Out-of-line definition of the defaulted printer destructor.
93 OpAsmPrinter::~OpAsmPrinter() = default;
94 
96  auto &os = getStream();
97  os << '(';
98  llvm::interleaveComma(op->getOperands(), os, [&](Value operand) {
99  // Print the types of null values as <<NULL TYPE>>.
100  *this << (operand ? operand.getType() : Type());
101  });
102  os << ") -> ";
103 
104  // Print the result list. We don't parenthesize single result types unless
105  // it is a function (avoiding a grammar ambiguity).
106  bool wrapped = op->getNumResults() != 1;
107  if (!wrapped && op->getResult(0).getType() &&
108  llvm::isa<FunctionType>(op->getResult(0).getType()))
109  wrapped = true;
110 
111  if (wrapped)
112  os << '(';
113 
114  llvm::interleaveComma(op->getResults(), os, [&](const OpResult &result) {
115  // Print the types of null values as <<NULL TYPE>>.
116  *this << (result ? result.getType() : Type());
117  });
118 
119  if (wrapped)
120  os << ')';
121 }
122 
123 //===----------------------------------------------------------------------===//
124 // Operation OpAsm interface.
125 //===----------------------------------------------------------------------===//
126 
127 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
128 #include "mlir/IR/OpAsmInterface.cpp.inc"
129 
130 LogicalResult
132  return entry.emitError() << "unknown 'resource' key '" << entry.getKey()
133  << "' for dialect '" << getDialect()->getNamespace()
134  << "'";
135 }
136 
137 //===----------------------------------------------------------------------===//
138 // OpPrintingFlags
139 //===----------------------------------------------------------------------===//
140 
141 namespace {
142 /// This struct contains command line options that can be used to initialize
143 /// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need
144 /// for global command line options.
145 struct AsmPrinterOptions {
146  llvm::cl::opt<int64_t> printElementsAttrWithHexIfLarger{
147  "mlir-print-elementsattrs-with-hex-if-larger",
148  llvm::cl::desc(
149  "Print DenseElementsAttrs with a hex string that have "
150  "more elements than the given upper limit (use -1 to disable)")};
151 
152  llvm::cl::opt<unsigned> elideElementsAttrIfLarger{
153  "mlir-elide-elementsattrs-if-larger",
154  llvm::cl::desc("Elide ElementsAttrs with \"...\" that have "
155  "more elements than the given upper limit")};
156 
157  llvm::cl::opt<unsigned> elideResourceStringsIfLarger{
158  "mlir-elide-resource-strings-if-larger",
159  llvm::cl::desc(
160  "Elide printing value of resources if string is too long in chars.")};
161 
162  llvm::cl::opt<bool> printDebugInfoOpt{
163  "mlir-print-debuginfo", llvm::cl::init(false),
164  llvm::cl::desc("Print debug info in MLIR output")};
165 
166  llvm::cl::opt<bool> printPrettyDebugInfoOpt{
167  "mlir-pretty-debuginfo", llvm::cl::init(false),
168  llvm::cl::desc("Print pretty debug info in MLIR output")};
169 
170  // Use the generic op output form in the operation printer even if the custom
171  // form is defined.
172  llvm::cl::opt<bool> printGenericOpFormOpt{
173  "mlir-print-op-generic", llvm::cl::init(false),
174  llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden};
175 
176  llvm::cl::opt<bool> assumeVerifiedOpt{
177  "mlir-print-assume-verified", llvm::cl::init(false),
178  llvm::cl::desc("Skip op verification when using custom printers"),
179  llvm::cl::Hidden};
180 
181  llvm::cl::opt<bool> printLocalScopeOpt{
182  "mlir-print-local-scope", llvm::cl::init(false),
183  llvm::cl::desc("Print with local scope and inline information (eliding "
184  "aliases for attributes, types, and locations")};
185 
186  llvm::cl::opt<bool> skipRegionsOpt{
187  "mlir-print-skip-regions", llvm::cl::init(false),
188  llvm::cl::desc("Skip regions when printing ops.")};
189 
190  llvm::cl::opt<bool> printValueUsers{
191  "mlir-print-value-users", llvm::cl::init(false),
192  llvm::cl::desc(
193  "Print users of operation results and block arguments as a comment")};
194 
195  llvm::cl::opt<bool> printUniqueSSAIDs{
196  "mlir-print-unique-ssa-ids", llvm::cl::init(false),
197  llvm::cl::desc("Print unique SSA ID numbers for values, block arguments "
198  "and naming conflicts across all regions")};
199 
200  llvm::cl::opt<bool> useNameLocAsPrefix{
201  "mlir-use-nameloc-as-prefix", llvm::cl::init(false),
202  llvm::cl::desc("Print SSA IDs using NameLocs as prefixes")};
203 };
204 } // namespace
205 
// Storage for the command-line options declared above. Readers check
// `isConstructed()` before use, since the options may not be available.
206 static llvm::ManagedStatic<AsmPrinterOptions> clOptions;
207 
208 /// Register a set of useful command-line options that can be used to configure
209 /// various flags within the AsmPrinter.
211  // Make sure that the options struct has been initialized.
212  *clOptions;
213 }
214 
215 /// Initialize the printing flags with default supplied by the cl::opts above.
217  : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
218  printGenericOpFormFlag(false), skipRegionsFlag(false),
219  assumeVerifiedFlag(false), printLocalScope(false),
220  printValueUsersFlag(false), printUniqueSSAIDsFlag(false),
221  useNameLocAsPrefix(false) {
222  // Initialize based upon command line options, if they are available.
223  if (!clOptions.isConstructed())
224  return;
225  if (clOptions->elideElementsAttrIfLarger.getNumOccurrences())
226  elementsAttrElementLimit = clOptions->elideElementsAttrIfLarger;
227  if (clOptions->printElementsAttrWithHexIfLarger.getNumOccurrences())
228  elementsAttrHexElementLimit =
229  clOptions->printElementsAttrWithHexIfLarger.getValue();
230  if (clOptions->elideResourceStringsIfLarger.getNumOccurrences())
231  resourceStringCharLimit = clOptions->elideResourceStringsIfLarger;
232  printDebugInfoFlag = clOptions->printDebugInfoOpt;
233  printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt;
234  printGenericOpFormFlag = clOptions->printGenericOpFormOpt;
235  assumeVerifiedFlag = clOptions->assumeVerifiedOpt;
236  printLocalScope = clOptions->printLocalScopeOpt;
237  skipRegionsFlag = clOptions->skipRegionsOpt;
238  printValueUsersFlag = clOptions->printValueUsers;
239  printUniqueSSAIDsFlag = clOptions->printUniqueSSAIDs;
240  useNameLocAsPrefix = clOptions->useNameLocAsPrefix;
241 }
242 
243 /// Enable the elision of large elements attributes, by printing a '...'
244 /// instead of the element data, when the number of elements is greater than
245 /// `largeElementLimit`. Note: The IR generated with this option is not
246 /// parsable.
248 OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) {
249  elementsAttrElementLimit = largeElementLimit;
250  return *this;
251 }
252 
255  elementsAttrHexElementLimit = largeElementLimit;
256  return *this;
257 }
258 
260 OpPrintingFlags::elideLargeResourceString(int64_t largeResourceLimit) {
261  resourceStringCharLimit = largeResourceLimit;
262  return *this;
263 }
264 
265 /// Enable printing of debug information. If 'prettyForm' is set to true,
266 /// debug information is printed in a more readable 'pretty' form.
268  bool prettyForm) {
269  printDebugInfoFlag = enable;
270  printDebugInfoPrettyFormFlag = prettyForm;
271  return *this;
272 }
273 
274 /// Always print operations in the generic form.
276  printGenericOpFormFlag = enable;
277  return *this;
278 }
279 
280 /// Always skip Regions.
282  skipRegionsFlag = skip;
283  return *this;
284 }
285 
286 /// Do not verify the operation when using custom operation printers.
288  assumeVerifiedFlag = enable;
289  return *this;
290 }
291 
292 /// Use local scope when printing the operation. This allows for using the
293 /// printer in a more localized and thread-safe setting, but may not necessarily
294 /// be identical to what the IR will look like when dumping the full module.
296  printLocalScope = enable;
297  return *this;
298 }
299 
300 /// Print users of values as comments.
302  printValueUsersFlag = enable;
303  return *this;
304 }
305 
306 /// Print unique SSA ID numbers for values, block arguments and naming conflicts
307 /// across all regions
309  printUniqueSSAIDsFlag = enable;
310  return *this;
311 }
312 
313 /// Return if the given ElementsAttr should be elided.
314 bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const {
315  return elementsAttrElementLimit &&
316  *elementsAttrElementLimit < int64_t(attr.getNumElements()) &&
317  !llvm::isa<SplatElementsAttr>(attr);
318 }
319 
320 /// Return if the given ElementsAttr should be printed as hex string.
321 bool OpPrintingFlags::shouldPrintElementsAttrWithHex(ElementsAttr attr) const {
322  // -1 is used to disable hex printing.
323  return (elementsAttrHexElementLimit != -1) &&
324  (elementsAttrHexElementLimit < int64_t(attr.getNumElements())) &&
325  !llvm::isa<SplatElementsAttr>(attr);
326 }
327 
328 /// Return the element-count limit above which ElementsAttrs are elided, if any.
329 std::optional<int64_t> OpPrintingFlags::getLargeElementsAttrLimit() const {
330  return elementsAttrElementLimit;
331 }
332 
333 /// Return the size limit for printing large ElementsAttr as hex string.
335  return elementsAttrHexElementLimit;
336 }
337 
338 /// Return the size limit in chars for eliding large resource strings, if any.
339 std::optional<uint64_t> OpPrintingFlags::getLargeResourceStringLimit() const {
340  return resourceStringCharLimit;
341 }
342 
343 /// Return if debug information should be printed.
345  return printDebugInfoFlag;
346 }
347 
348 /// Return if debug information should be printed in the pretty form.
350  return printDebugInfoPrettyFormFlag;
351 }
352 
353 /// Return if operations should be printed in the generic form.
355  return printGenericOpFormFlag;
356 }
357 
358 /// Return if Region should be skipped.
359 bool OpPrintingFlags::shouldSkipRegions() const { return skipRegionsFlag; }
360 
361 /// Return if operation verification should be skipped.
363  return assumeVerifiedFlag;
364 }
365 
366 /// Return if the printer should use local scope when dumping the IR.
367 bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
368 
369 /// Return if the printer should print users of values.
371  return printValueUsersFlag;
372 }
373 
374 /// Return if the printer should use unique IDs.
376  return printUniqueSSAIDsFlag || shouldPrintGenericOpForm();
377 }
378 
379 /// Return if the printer should use NameLocs as prefixes when printing SSA IDs.
381  return useNameLocAsPrefix;
382 }
383 
384 //===----------------------------------------------------------------------===//
385 // NewLineCounter
386 //===----------------------------------------------------------------------===//
387 
388 namespace {
389 /// This class is a simple formatter that emits a new line when inputted into a
390 /// stream, that enables counting the number of newlines emitted. This class
391 /// should be used whenever emitting newlines in the printer.
392 struct NewLineCounter {
393  unsigned curLine = 1;
394 };
395 
396 static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) {
397  ++newLine.curLine;
398  return os << '\n';
399 }
400 } // namespace
401 
402 //===----------------------------------------------------------------------===//
403 // AsmPrinter::Impl
404 //===----------------------------------------------------------------------===//
405 
406 namespace mlir {
408 public:
409  Impl(raw_ostream &os, AsmStateImpl &state);
410  explicit Impl(Impl &other) : Impl(other.os, other.state) {}
411 
412  /// Returns the output stream of the printer.
413  raw_ostream &getStream() { return os; }
414 
415  template <typename Container, typename UnaryFunctor>
416  inline void interleaveComma(const Container &c, UnaryFunctor eachFn) const {
417  llvm::interleaveComma(c, os, eachFn);
418  }
419 
420  /// This enum describes the different kinds of elision for the type of an
421  /// attribute when printing it.
422  enum class AttrTypeElision {
423  /// The type must not be elided,
424  Never,
425  /// The type may be elided when it matches the default used in the parser
426  /// (for example i64 is the default for integer attributes).
427  May,
428  /// The type must be elided.
429  Must
430  };
431 
432  /// Print the given attribute or an alias.
433  void printAttribute(Attribute attr,
435  /// Print the given attribute without considering an alias.
436  void printAttributeImpl(Attribute attr,
438 
439  /// Print the alias for the given attribute, return failure if no alias could
440  /// be printed.
441  LogicalResult printAlias(Attribute attr);
442 
443  /// Print the given type or an alias.
444  void printType(Type type);
445  /// Print the given type.
446  void printTypeImpl(Type type);
447 
448  /// Print the alias for the given type, return failure if no alias could
449  /// be printed.
450  LogicalResult printAlias(Type type);
451 
452  /// Print the given location to the stream. If `allowAlias` is true, this
453  /// allows for the internal location to use an attribute alias.
454  void printLocation(LocationAttr loc, bool allowAlias = false);
455 
456  /// Print a reference to the given resource that is owned by the given
457  /// dialect.
458  void printResourceHandle(const AsmDialectResourceHandle &resource);
459 
460  void printAffineMap(AffineMap map);
461  void
463  function_ref<void(unsigned, bool)> printValueName = nullptr);
464  void printAffineConstraint(AffineExpr expr, bool isEq);
465  void printIntegerSet(IntegerSet set);
466 
467  LogicalResult pushCyclicPrinting(const void *opaquePointer);
468 
469  void popCyclicPrinting();
470 
472 
473 protected:
475  ArrayRef<StringRef> elidedAttrs = {},
476  bool withKeyword = false);
478  void printTrailingLocation(Location loc, bool allowAlias = true);
479  void printLocationInternal(LocationAttr loc, bool pretty = false,
480  bool isTopLevel = false);
481 
482  /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
483  /// used instead of individual elements when the elements attr is large.
484  void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex);
485 
486  /// Print a dense string elements attribute.
487  void printDenseStringElementsAttr(DenseStringElementsAttr attr);
488 
489  /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
490  /// used instead of individual elements when the elements attr is large.
492  bool allowHex);
493 
494  /// Print a dense array attribute.
495  void printDenseArrayAttr(DenseArrayAttr attr);
496 
497  void printDialectAttribute(Attribute attr);
498  void printDialectType(Type type);
499 
500  /// Print an escaped string, wrapped with "".
501  void printEscapedString(StringRef str);
502 
503  /// Print a hex string, wrapped with "".
504  void printHexString(StringRef str);
505  void printHexString(ArrayRef<char> data);
506 
507  /// This enum is used to represent the binding strength of the enclosing
508  /// context that an AffineExprStorage is being printed in, so we can
509  /// intelligently produce parens.
510  enum class BindingStrength {
511  Weak, // + and -
512  Strong, // All other binary operators.
513  };
515  AffineExpr expr, BindingStrength enclosingTightness,
516  function_ref<void(unsigned, bool)> printValueName = nullptr);
517 
518  /// The output stream for the printer.
519  raw_ostream &os;
520 
521  /// An underlying assembly printer state.
523 
524  /// A set of flags to control the printer's behavior.
526 
527  /// A tracker for the number of new lines emitted during printing.
528  NewLineCounter newLine;
529 };
530 } // namespace mlir
531 
532 //===----------------------------------------------------------------------===//
533 // AliasInitializer
534 //===----------------------------------------------------------------------===//
535 
536 namespace {
537 /// This class represents a specific instance of a symbol Alias.
538 class SymbolAlias {
539 public:
540  SymbolAlias(StringRef name, uint32_t suffixIndex, bool isType,
541  bool isDeferrable)
542  : name(name), suffixIndex(suffixIndex), isType(isType),
543  isDeferrable(isDeferrable) {}
544 
545  /// Print this alias to the given stream.
546  void print(raw_ostream &os) const {
547  os << (isType ? "!" : "#") << name;
548  if (suffixIndex)
549  os << suffixIndex;
550  }
551 
552  /// Returns true if this is a type alias.
553  bool isTypeAlias() const { return isType; }
554 
555  /// Returns true if this alias supports deferred resolution when parsing.
556  bool canBeDeferred() const { return isDeferrable; }
557 
558 private:
559  /// The main name of the alias.
560  StringRef name;
561  /// The suffix index of the alias.
562  uint32_t suffixIndex : 30;
563  /// A flag indicating whether this alias is for a type.
564  bool isType : 1;
565  /// A flag indicating whether this alias may be deferred or not.
566  bool isDeferrable : 1;
567 
568 public:
569  /// Used to avoid printing incomplete aliases for recursive types.
570  bool isPrinted = false;
571 };
572 
573 /// This class represents a utility that initializes the set of attribute and
574 /// type aliases, without the need to store the extra information within the
575 /// main AliasState class or pass it around via function arguments.
576 class AliasInitializer {
577 public:
578  AliasInitializer(
580  llvm::BumpPtrAllocator &aliasAllocator)
581  : interfaces(interfaces), aliasAllocator(aliasAllocator),
582  aliasOS(aliasBuffer) {}
583 
584  void initialize(Operation *op, const OpPrintingFlags &printerFlags,
585  llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias);
586 
587  /// Visit the given attribute to see if it has an alias. `canBeDeferred` is
588  /// set to true if the originator of this attribute can resolve the alias
589  /// after parsing has completed (e.g. in the case of operation locations).
590  /// `elideType` indicates if the type of the attribute should be skipped when
591  /// looking for nested aliases. Returns the maximum alias depth of the
592  /// attribute, and the alias index of this attribute.
593  std::pair<size_t, size_t> visit(Attribute attr, bool canBeDeferred = false,
594  bool elideType = false) {
595  return visitImpl(attr, aliases, canBeDeferred, elideType);
596  }
597 
598  /// Visit the given type to see if it has an alias. `canBeDeferred` is
599  /// set to true if the originator of this attribute can resolve the alias
600  /// after parsing has completed. Returns the maximum alias depth of the type,
601  /// and the alias index of this type.
602  std::pair<size_t, size_t> visit(Type type, bool canBeDeferred = false) {
603  return visitImpl(type, aliases, canBeDeferred);
604  }
605 
606 private:
607  struct InProgressAliasInfo {
608  InProgressAliasInfo()
609  : aliasDepth(0), isType(false), canBeDeferred(false) {}
610  InProgressAliasInfo(StringRef alias)
611  : alias(alias), aliasDepth(1), isType(false), canBeDeferred(false) {}
612 
613  bool operator<(const InProgressAliasInfo &rhs) const {
614  // Order first by depth, then by attr/type kind, and then by name.
615  if (aliasDepth != rhs.aliasDepth)
616  return aliasDepth < rhs.aliasDepth;
617  if (isType != rhs.isType)
618  return isType;
619  return alias < rhs.alias;
620  }
621 
622  /// The alias for the attribute or type, or std::nullopt if the value has no
623  /// alias.
624  std::optional<StringRef> alias;
625  /// The alias depth of this attribute or type, i.e. an indication of the
626  /// relative ordering of when to print this alias.
627  unsigned aliasDepth : 30;
628  /// If this alias represents a type or an attribute.
629  bool isType : 1;
630  /// If this alias can be deferred or not.
631  bool canBeDeferred : 1;
632  /// Indices for child aliases.
633  SmallVector<size_t> childIndices;
634  };
635 
636  /// Visit the given attribute or type to see if it has an alias.
637  /// `canBeDeferred` is set to true if the originator of this value can resolve
638  /// the alias after parsing has completed (e.g. in the case of operation
639  /// locations). Returns the maximum alias depth of the value, and its alias
640  /// index.
641  template <typename T, typename... PrintArgs>
642  std::pair<size_t, size_t>
643  visitImpl(T value,
644  llvm::MapVector<const void *, InProgressAliasInfo> &aliases,
645  bool canBeDeferred, PrintArgs &&...printArgs);
646 
647  /// Mark the given alias as non-deferrable.
648  void markAliasNonDeferrable(size_t aliasIndex);
649 
650  /// Try to generate an alias for the provided symbol. If an alias is
651  /// generated, the provided alias mapping and reverse mapping are updated.
652  template <typename T>
653  void generateAlias(T symbol, InProgressAliasInfo &alias, bool canBeDeferred);
654 
655  /// Given a collection of aliases and symbols, initialize a mapping from a
656  /// symbol to a given alias.
657  static void initializeAliases(
658  llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols,
659  llvm::MapVector<const void *, SymbolAlias> &symbolToAlias);
660 
661  /// The set of asm interfaces within the context.
663 
664  /// An allocator used for alias names.
665  llvm::BumpPtrAllocator &aliasAllocator;
666 
667  /// The set of built aliases.
668  llvm::MapVector<const void *, InProgressAliasInfo> aliases;
669 
670  /// Storage and stream used when generating an alias.
671  SmallString<32> aliasBuffer;
672  llvm::raw_svector_ostream aliasOS;
673 };
674 
675 /// This class implements a dummy OpAsmPrinter that doesn't print any output,
676 /// and merely collects the attributes and types that *would* be printed in a
677 /// normal print invocation so that we can generate proper aliases. This allows
678 /// for us to generate aliases only for the attributes and types that would be
679 /// in the output, and trims down unnecessary output.
680 class DummyAliasOperationPrinter : private OpAsmPrinter {
681 public:
/// Both arguments are stored in reference members, so they must outlive this
/// printer.
682  explicit DummyAliasOperationPrinter(const OpPrintingFlags &printerFlags,
683  AliasInitializer &initializer)
684  : printerFlags(printerFlags), initializer(initializer) {}
685 
686  /// Prints the entire operation with the custom assembly form, if available,
687  /// or the generic assembly form, otherwise.
688  void printCustomOrGenericOp(Operation *op) override {
689  // Visit the operation location.
690  if (printerFlags.shouldPrintDebugInfo())
691  initializer.visit(op->getLoc(), /*canBeDeferred=*/true);
692 
693  // If requested, always print the generic form.
694  if (!printerFlags.shouldPrintGenericOpForm()) {
695  op->getName().printAssembly(op, *this, /*defaultDialect=*/"");
696  return;
697  }
698 
699  // Otherwise print with the generic assembly form.
700  printGenericOp(op);
701  }
702 
703 private:
704  /// Print the given operation in the generic form.
705  void printGenericOp(Operation *op, bool printOpName = true) override {
706  // Consider nested operations for aliases.
707  if (!printerFlags.shouldSkipRegions()) {
708  for (Region &region : op->getRegions())
709  printRegion(region, /*printEntryBlockArgs=*/true,
710  /*printBlockTerminators=*/true);
711  }
712 
713  // Visit all the types used in the operation.
714  for (Type type : op->getOperandTypes())
715  printType(type);
716  for (Type type : op->getResultTypes())
717  printType(type);
718 
719  // Consider the attributes of the operation for aliases.
720  for (const NamedAttribute &attr : op->getAttrs())
721  printAttribute(attr.getValue());
722  }
723 
724  /// Print the given block. If 'printBlockArgs' is false, the arguments of the
725  /// block are not printed. If 'printBlockTerminator' is false, the terminator
726  /// operation of the block is not printed.
727  void print(Block *block, bool printBlockArgs = true,
728  bool printBlockTerminator = true) {
729  // Consider the types of the block arguments for aliases if 'printBlockArgs'
730  // is set to true.
731  if (printBlockArgs) {
732  for (BlockArgument arg : block->getArguments()) {
733  printType(arg.getType());
734 
735  // Visit the argument location.
736  if (printerFlags.shouldPrintDebugInfo())
737  // TODO: Allow deferring argument locations.
738  initializer.visit(arg.getLoc(), /*canBeDeferred=*/false);
739  }
740  }
741 
742  // Consider the operations within this block, ignoring the terminator if
743  // requested.
744  bool hasTerminator =
745  !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
746  auto range = llvm::make_range(
747  block->begin(),
748  std::prev(block->end(),
749  (!hasTerminator || printBlockTerminator) ? 0 : 1));
750  for (Operation &op : range)
751  printCustomOrGenericOp(&op);
752  }
753 
754  /// Print the given region.
755  void printRegion(Region &region, bool printEntryBlockArgs,
756  bool printBlockTerminators,
757  bool printEmptyBlock = false) override {
758  if (region.empty())
759  return;
760  if (printerFlags.shouldSkipRegions()) {
761  os << "{...}";
762  return;
763  }
764 
765  auto *entryBlock = &region.front();
766  print(entryBlock, printEntryBlockArgs, printBlockTerminators);
767  for (Block &b : llvm::drop_begin(region, 1))
768  print(&b);
769  }
770 
771  void printRegionArgument(BlockArgument arg, ArrayRef<NamedAttribute> argAttrs,
772  bool omitType) override {
773  printType(arg.getType());
774  // Visit the argument location.
775  if (printerFlags.shouldPrintDebugInfo())
776  // TODO: Allow deferring argument locations.
777  initializer.visit(arg.getLoc(), /*canBeDeferred=*/false);
778  }
779 
780  /// Consider the given type to be printed for an alias.
781  void printType(Type type) override { initializer.visit(type); }
782 
783  /// Consider the given attribute to be printed for an alias.
784  void printAttribute(Attribute attr) override { initializer.visit(attr); }
/// The type-less variant visits the attribute exactly like printAttribute.
785  void printAttributeWithoutType(Attribute attr) override {
786  printAttribute(attr);
787  }
/// Visit the attribute and always report success.
788  LogicalResult printAlias(Attribute attr) override {
789  initializer.visit(attr);
790  return success();
791  }
/// Visit the type and always report success.
792  LogicalResult printAlias(Type type) override {
793  initializer.visit(type);
794  return success();
795  }
796 
797  /// Consider the given location to be printed for an alias.
798  void printOptionalLocationSpecifier(Location loc) override {
// The location is visited through the attribute path.
799  printAttribute(loc);
800  }
801 
802  /// Print the given set of attributes with names not included within
803  /// 'elidedAttrs'.
804  void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
805  ArrayRef<StringRef> elidedAttrs = {}) override {
806  if (attrs.empty())
807  return;
808  if (elidedAttrs.empty()) {
809  for (const NamedAttribute &attr : attrs)
810  printAttribute(attr.getValue());
811  return;
812  }
813  llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
814  elidedAttrs.end());
815  for (const NamedAttribute &attr : attrs)
816  if (!elidedAttrsSet.contains(attr.getName().strref()))
817  printAttribute(attr.getValue());
818  }
819  void printOptionalAttrDictWithKeyword(
821  ArrayRef<StringRef> elidedAttrs = {}) override {
822  printOptionalAttrDict(attrs, elidedAttrs);
823  }
824 
825  /// Return a null stream as the output stream, this will ignore any data fed
826  /// to it.
827  raw_ostream &getStream() const override { return os; }
828 
829  /// The following are hooks of `OpAsmPrinter` that are not necessary for
830  /// determining potential aliases.
/// All of them have empty bodies, except for the stream overload of
/// printOperand.
831  void printFloat(const APFloat &) override {}
832  void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {}
833  void printAffineExprOfSSAIds(AffineExpr, ValueRange, ValueRange) override {}
834  void printNewline() override {}
835  void increaseIndent() override {}
836  void decreaseIndent() override {}
837  void printOperand(Value) override {}
838  void printOperand(Value, raw_ostream &os) override {
839  // Users expect the output string to have at least the prefixed % to signal
840  // a value name. To maintain this invariant, emit a name even if it is
841  // guaranteed to go unused.
// Note: `os` here is the caller-provided stream parameter, which shadows the
// dummy stream member of this class.
842  os << "%";
843  }
844  void printKeywordOrString(StringRef) override {}
845  void printString(StringRef) override {}
846  void printResourceHandle(const AsmDialectResourceHandle &) override {}
847  void printSymbolName(StringRef) override {}
848  void printSuccessor(Block *) override {}
849  void printSuccessorAndUseList(Block *, ValueRange) override {}
850  void shadowRegionArgs(Region &, ValueRange) override {}
851 
852  /// The printer flags to use when determining potential aliases.
853  const OpPrintingFlags &printerFlags;
854 
855  /// The initializer to use when identifying aliases.
856  AliasInitializer &initializer;
857 
858  /// A dummy output stream.
859  mutable llvm::raw_null_ostream os;
860 };
861 
862 class DummyAliasDialectAsmPrinter : public DialectAsmPrinter {
863 public:
864  explicit DummyAliasDialectAsmPrinter(AliasInitializer &initializer,
865  bool canBeDeferred,
866  SmallVectorImpl<size_t> &childIndices)
867  : initializer(initializer), canBeDeferred(canBeDeferred),
868  childIndices(childIndices) {}
869 
870  /// Print the given attribute/type, visiting any nested aliases that would be
871  /// generated as part of printing. Returns the maximum alias depth found while
872  /// printing the given value.
873  template <typename T, typename... PrintArgs>
874  size_t printAndVisitNestedAliases(T value, PrintArgs &&...printArgs) {
875  printAndVisitNestedAliasesImpl(value, printArgs...);
876  return maxAliasDepth;
877  }
878 
879 private:
880  /// Print the given attribute/type, visiting any nested aliases that would be
881  /// generated as part of printing.
882  void printAndVisitNestedAliasesImpl(Attribute attr, bool elideType) {
883  if (!isa<BuiltinDialect>(attr.getDialect())) {
884  attr.getDialect().printAttribute(attr, *this);
885 
886  // Process the builtin attributes.
887  } else if (llvm::isa<AffineMapAttr, DenseArrayAttr, FloatAttr, IntegerAttr,
888  IntegerSetAttr, UnitAttr>(attr)) {
889  return;
890  } else if (auto distinctAttr = dyn_cast<DistinctAttr>(attr)) {
891  printAttribute(distinctAttr.getReferencedAttr());
892  } else if (auto dictAttr = dyn_cast<DictionaryAttr>(attr)) {
893  for (const NamedAttribute &nestedAttr : dictAttr.getValue()) {
894  printAttribute(nestedAttr.getName());
895  printAttribute(nestedAttr.getValue());
896  }
897  } else if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
898  for (Attribute nestedAttr : arrayAttr.getValue())
899  printAttribute(nestedAttr);
900  } else if (auto typeAttr = dyn_cast<TypeAttr>(attr)) {
901  printType(typeAttr.getValue());
902  } else if (auto locAttr = dyn_cast<OpaqueLoc>(attr)) {
903  printAttribute(locAttr.getFallbackLocation());
904  } else if (auto locAttr = dyn_cast<NameLoc>(attr)) {
905  if (!isa<UnknownLoc>(locAttr.getChildLoc()))
906  printAttribute(locAttr.getChildLoc());
907  } else if (auto locAttr = dyn_cast<CallSiteLoc>(attr)) {
908  printAttribute(locAttr.getCallee());
909  printAttribute(locAttr.getCaller());
910  } else if (auto locAttr = dyn_cast<FusedLoc>(attr)) {
911  if (Attribute metadata = locAttr.getMetadata())
912  printAttribute(metadata);
913  for (Location nestedLoc : locAttr.getLocations())
914  printAttribute(nestedLoc);
915  }
916 
917  // Don't print the type if we must elide it, or if it is a None type.
918  if (!elideType) {
919  if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr)) {
920  Type attrType = typedAttr.getType();
921  if (!llvm::isa<NoneType>(attrType))
922  printType(attrType);
923  }
924  }
925  }
926  void printAndVisitNestedAliasesImpl(Type type) {
927  if (!isa<BuiltinDialect>(type.getDialect()))
928  return type.getDialect().printType(type, *this);
929 
930  // Only visit the layout of memref if it isn't the identity.
931  if (auto memrefTy = llvm::dyn_cast<MemRefType>(type)) {
932  printType(memrefTy.getElementType());
933  MemRefLayoutAttrInterface layout = memrefTy.getLayout();
934  if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity())
935  printAttribute(memrefTy.getLayout());
936  if (memrefTy.getMemorySpace())
937  printAttribute(memrefTy.getMemorySpace());
938  return;
939  }
940 
941  // For most builtin types, we can simply walk the sub elements.
942  auto visitFn = [&](auto element) {
943  if (element)
944  (void)printAlias(element);
945  };
946  type.walkImmediateSubElements(visitFn, visitFn);
947  }
948 
949  /// Consider the given type to be printed for an alias.
950  void printType(Type type) override {
951  recordAliasResult(initializer.visit(type, canBeDeferred));
952  }
953 
954  /// Consider the given attribute to be printed for an alias.
955  void printAttribute(Attribute attr) override {
956  recordAliasResult(initializer.visit(attr, canBeDeferred));
957  }
958  void printAttributeWithoutType(Attribute attr) override {
959  recordAliasResult(
960  initializer.visit(attr, canBeDeferred, /*elideType=*/true));
961  }
962  LogicalResult printAlias(Attribute attr) override {
963  printAttribute(attr);
964  return success();
965  }
966  LogicalResult printAlias(Type type) override {
967  printType(type);
968  return success();
969  }
970 
971  /// Record the alias result of a child element.
972  void recordAliasResult(std::pair<size_t, size_t> aliasDepthAndIndex) {
973  childIndices.push_back(aliasDepthAndIndex.second);
974  if (aliasDepthAndIndex.first > maxAliasDepth)
975  maxAliasDepth = aliasDepthAndIndex.first;
976  }
977 
978  /// Return a null stream as the output stream, this will ignore any data fed
979  /// to it.
980  raw_ostream &getStream() const override { return os; }
981 
982  /// The following are hooks of `DialectAsmPrinter` that are not necessary for
983  /// determining potential aliases.
984  void printFloat(const APFloat &) override {}
985  void printKeywordOrString(StringRef) override {}
986  void printString(StringRef) override {}
987  void printSymbolName(StringRef) override {}
988  void printResourceHandle(const AsmDialectResourceHandle &) override {}
989 
990  LogicalResult pushCyclicPrinting(const void *opaquePointer) override {
991  return success(cyclicPrintingStack.insert(opaquePointer));
992  }
993 
994  void popCyclicPrinting() override { cyclicPrintingStack.pop_back(); }
995 
996  /// Stack of potentially cyclic mutable attributes or type currently being
997  /// printed.
998  SetVector<const void *> cyclicPrintingStack;
999 
1000  /// The initializer to use when identifying aliases.
1001  AliasInitializer &initializer;
1002 
1003  /// If the aliases visited by this printer can be deferred.
1004  bool canBeDeferred;
1005 
1006  /// The indices of child aliases.
1007  SmallVectorImpl<size_t> &childIndices;
1008 
1009  /// The maximum alias depth found by the printer.
1010  size_t maxAliasDepth = 0;
1011 
1012  /// A dummy output stream.
1013  mutable llvm::raw_null_ostream os;
1014 };
1015 } // namespace
1016 
1017 /// Sanitize the given name such that it can be used as a valid identifier. If
1018 /// the string needs to be modified in any way, the provided buffer is used to
1019 /// store the new copy,
1020 static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
1021  StringRef allowedPunctChars = "$._-",
1022  bool allowTrailingDigit = true) {
1023  assert(!name.empty() && "Shouldn't have an empty name here");
1024 
1025  auto validChar = [&](char ch) {
1026  return llvm::isAlnum(ch) || allowedPunctChars.contains(ch);
1027  };
1028 
1029  auto copyNameToBuffer = [&] {
1030  for (char ch : name) {
1031  if (validChar(ch))
1032  buffer.push_back(ch);
1033  else if (ch == ' ')
1034  buffer.push_back('_');
1035  else
1036  buffer.append(llvm::utohexstr((unsigned char)ch));
1037  }
1038  };
1039 
1040  // Check to see if this name is valid. If it starts with a digit, then it
1041  // could conflict with the autogenerated numeric ID's, so add an underscore
1042  // prefix to avoid problems.
1043  if (isdigit(name[0]) || (!validChar(name[0]) && name[0] != ' ')) {
1044  buffer.push_back('_');
1045  copyNameToBuffer();
1046  return buffer;
1047  }
1048 
1049  // If the name ends with a trailing digit, add a '_' to avoid potential
1050  // conflicts with autogenerated ID's.
1051  if (!allowTrailingDigit && isdigit(name.back())) {
1052  copyNameToBuffer();
1053  buffer.push_back('_');
1054  return buffer;
1055  }
1056 
1057  // Check to see that the name consists of only valid identifier characters.
1058  for (char ch : name) {
1059  if (!validChar(ch)) {
1060  copyNameToBuffer();
1061  return buffer;
1062  }
1063  }
1064 
1065  // If there are no invalid characters, return the original name.
1066  return name;
1067 }
1068 
1069 /// Given a collection of aliases and symbols, initialize a mapping from a
1070 /// symbol to a given alias.
1071 void AliasInitializer::initializeAliases(
1072  llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols,
1073  llvm::MapVector<const void *, SymbolAlias> &symbolToAlias) {
1075  unprocessedAliases = visitedSymbols.takeVector();
1076  llvm::stable_sort(unprocessedAliases, [](const auto &lhs, const auto &rhs) {
1077  return lhs.second < rhs.second;
1078  });
1079 
1080  llvm::StringMap<unsigned> nameCounts;
1081  for (auto &[symbol, aliasInfo] : unprocessedAliases) {
1082  if (!aliasInfo.alias)
1083  continue;
1084  StringRef alias = *aliasInfo.alias;
1085  unsigned nameIndex = nameCounts[alias]++;
1086  symbolToAlias.insert(
1087  {symbol, SymbolAlias(alias, nameIndex, aliasInfo.isType,
1088  aliasInfo.canBeDeferred)});
1089  }
1090 }
1091 
1092 void AliasInitializer::initialize(
1093  Operation *op, const OpPrintingFlags &printerFlags,
1094  llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias) {
1095  // Use a dummy printer when walking the IR so that we can collect the
1096  // attributes/types that will actually be used during printing when
1097  // considering aliases.
1098  DummyAliasOperationPrinter aliasPrinter(printerFlags, *this);
1099  aliasPrinter.printCustomOrGenericOp(op);
1100 
1101  // Initialize the aliases.
1102  initializeAliases(aliases, attrTypeToAlias);
1103 }
1104 
1105 template <typename T, typename... PrintArgs>
1106 std::pair<size_t, size_t> AliasInitializer::visitImpl(
1107  T value, llvm::MapVector<const void *, InProgressAliasInfo> &aliases,
1108  bool canBeDeferred, PrintArgs &&...printArgs) {
1109  auto [it, inserted] =
1110  aliases.insert({value.getAsOpaquePointer(), InProgressAliasInfo()});
1111  size_t aliasIndex = std::distance(aliases.begin(), it);
1112  if (!inserted) {
1113  // Make sure that the alias isn't deferred if we don't permit it.
1114  if (!canBeDeferred)
1115  markAliasNonDeferrable(aliasIndex);
1116  return {static_cast<size_t>(it->second.aliasDepth), aliasIndex};
1117  }
1118 
1119  // Try to generate an alias for this value.
1120  generateAlias(value, it->second, canBeDeferred);
1121  it->second.isType = std::is_base_of_v<Type, T>;
1122  it->second.canBeDeferred = canBeDeferred;
1123 
1124  // Print the value, capturing any nested elements that require aliases.
1125  SmallVector<size_t> childAliases;
1126  DummyAliasDialectAsmPrinter printer(*this, canBeDeferred, childAliases);
1127  size_t maxAliasDepth =
1128  printer.printAndVisitNestedAliases(value, printArgs...);
1129 
1130  // Make sure to recompute `it` in case the map was reallocated.
1131  it = std::next(aliases.begin(), aliasIndex);
1132 
1133  // If we had sub elements, update to account for the depth.
1134  it->second.childIndices = std::move(childAliases);
1135  if (maxAliasDepth)
1136  it->second.aliasDepth = maxAliasDepth + 1;
1137 
1138  // Propagate the alias depth of the value.
1139  return {(size_t)it->second.aliasDepth, aliasIndex};
1140 }
1141 
1142 void AliasInitializer::markAliasNonDeferrable(size_t aliasIndex) {
1143  auto *it = std::next(aliases.begin(), aliasIndex);
1144 
1145  // If already marked non-deferrable stop the recursion.
1146  // All children should already be marked non-deferrable as well.
1147  if (!it->second.canBeDeferred)
1148  return;
1149 
1150  it->second.canBeDeferred = false;
1151 
1152  // Propagate the non-deferrable flag to any child aliases.
1153  for (size_t childIndex : it->second.childIndices)
1154  markAliasNonDeferrable(childIndex);
1155 }
1156 
1157 template <typename T>
1158 void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias,
1159  bool canBeDeferred) {
1160  SmallString<32> nameBuffer;
1161  for (const auto &interface : interfaces) {
1163  interface.getAlias(symbol, aliasOS);
1165  continue;
1166  nameBuffer = std::move(aliasBuffer);
1167  assert(!nameBuffer.empty() && "expected valid alias name");
1169  break;
1170  }
1171 
1172  if (nameBuffer.empty())
1173  return;
1174 
1175  SmallString<16> tempBuffer;
1176  StringRef name =
1177  sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-",
1178  /*allowTrailingDigit=*/false);
1179  name = name.copy(aliasAllocator);
1180  alias = InProgressAliasInfo(name);
1181 }
1182 
1183 //===----------------------------------------------------------------------===//
1184 // AliasState
1185 //===----------------------------------------------------------------------===//
1186 
1187 namespace {
1188 /// This class manages the state for type and attribute aliases.
1189 class AliasState {
1190 public:
1191  // Initialize the internal aliases.
1192  void
1193  initialize(Operation *op, const OpPrintingFlags &printerFlags,
1195 
1196  /// Get an alias for the given attribute if it has one and print it in `os`.
1197  /// Returns success if an alias was printed, failure otherwise.
1198  LogicalResult getAlias(Attribute attr, raw_ostream &os) const;
1199 
1200  /// Get an alias for the given type if it has one and print it in `os`.
1201  /// Returns success if an alias was printed, failure otherwise.
1202  LogicalResult getAlias(Type ty, raw_ostream &os) const;
1203 
1204  /// Print all of the referenced aliases that can not be resolved in a deferred
1205  /// manner.
1206  void printNonDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) {
1207  printAliases(p, newLine, /*isDeferred=*/false);
1208  }
1209 
1210  /// Print all of the referenced aliases that support deferred resolution.
1211  void printDeferredAliases(AsmPrinter::Impl &p, NewLineCounter &newLine) {
1212  printAliases(p, newLine, /*isDeferred=*/true);
1213  }
1214 
1215 private:
1216  /// Print all of the referenced aliases that support the provided resolution
1217  /// behavior.
1218  void printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,
1219  bool isDeferred);
1220 
1221  /// Mapping between attribute/type and alias.
1222  llvm::MapVector<const void *, SymbolAlias> attrTypeToAlias;
1223 
1224  /// An allocator used for alias names.
1225  llvm::BumpPtrAllocator aliasAllocator;
1226 };
1227 } // namespace
1228 
1229 void AliasState::initialize(
1230  Operation *op, const OpPrintingFlags &printerFlags,
1232  AliasInitializer initializer(interfaces, aliasAllocator);
1233  initializer.initialize(op, printerFlags, attrTypeToAlias);
1234 }
1235 
1236 LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const {
1237  const auto *it = attrTypeToAlias.find(attr.getAsOpaquePointer());
1238  if (it == attrTypeToAlias.end())
1239  return failure();
1240  it->second.print(os);
1241  return success();
1242 }
1243 
1244 LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
1245  const auto *it = attrTypeToAlias.find(ty.getAsOpaquePointer());
1246  if (it == attrTypeToAlias.end())
1247  return failure();
1248  if (!it->second.isPrinted)
1249  return failure();
1250 
1251  it->second.print(os);
1252  return success();
1253 }
1254 
1255 void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,
1256  bool isDeferred) {
1257  auto filterFn = [=](const auto &aliasIt) {
1258  return aliasIt.second.canBeDeferred() == isDeferred;
1259  };
1260  for (auto &[opaqueSymbol, alias] :
1261  llvm::make_filter_range(attrTypeToAlias, filterFn)) {
1262  alias.print(p.getStream());
1263  p.getStream() << " = ";
1264 
1265  if (alias.isTypeAlias()) {
1266  Type type = Type::getFromOpaquePointer(opaqueSymbol);
1267  p.printTypeImpl(type);
1268  alias.isPrinted = true;
1269  } else {
1270  // TODO: Support nested aliases in mutable attributes.
1271  Attribute attr = Attribute::getFromOpaquePointer(opaqueSymbol);
1272  if (attr.hasTrait<AttributeTrait::IsMutable>())
1273  p.getStream() << attr;
1274  else
1275  p.printAttributeImpl(attr);
1276  }
1277 
1278  p.getStream() << newLine;
1279  }
1280 }
1281 
1282 //===----------------------------------------------------------------------===//
1283 // SSANameState
1284 //===----------------------------------------------------------------------===//
1285 
1286 namespace {
1287 /// Info about block printing: a number which is its position in the visitation
1288 /// order, and a name that is used to print reference to it, e.g. ^bb42.
1289 struct BlockInfo {
1290  int ordering;
1291  StringRef name;
1292 };
1293 
1294 /// This class manages the state of SSA value names.
1295 class SSANameState {
1296 public:
1297  /// A sentinel value used for values with names set.
1298  enum : unsigned { NameSentinel = ~0U };
1299 
1300  SSANameState(Operation *op, const OpPrintingFlags &printerFlags);
1301  SSANameState() = default;
1302 
1303  /// Print the SSA identifier for the given value to 'stream'. If
1304  /// 'printResultNo' is true, it also presents the result number ('#' number)
1305  /// of this value.
1306  void printValueID(Value value, bool printResultNo, raw_ostream &stream) const;
1307 
1308  /// Print the operation identifier.
1309  void printOperationID(Operation *op, raw_ostream &stream) const;
1310 
1311  /// Return the result indices for each of the result groups registered by this
1312  /// operation, or empty if none exist.
1313  ArrayRef<int> getOpResultGroups(Operation *op);
1314 
1315  /// Get the info for the given block.
1316  BlockInfo getBlockInfo(Block *block);
1317 
1318  /// Renumber the arguments for the specified region to the same names as the
1319  /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for
1320  /// details.
1321  void shadowRegionArgs(Region &region, ValueRange namesToUse);
1322 
1323 private:
1324  /// Number the SSA values within the given IR unit.
1325  void numberValuesInRegion(Region &region);
1326  void numberValuesInBlock(Block &block);
1327  void numberValuesInOp(Operation &op);
1328 
1329  /// Given a result of an operation 'result', find the result group head
1330  /// 'lookupValue' and the result of 'result' within that group in
1331  /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group
1332  /// has more than 1 result.
1333  void getResultIDAndNumber(OpResult result, Value &lookupValue,
1334  std::optional<int> &lookupResultNo) const;
1335 
1336  /// Set a special value name for the given value.
1337  void setValueName(Value value, StringRef name);
1338 
1339  /// Uniques the given value name within the printer. If the given name
1340  /// conflicts, it is automatically renamed.
1341  StringRef uniqueValueName(StringRef name);
1342 
1343  /// This is the value ID for each SSA value. If this returns NameSentinel,
1344  /// then the valueID has an entry in valueNames.
1345  DenseMap<Value, unsigned> valueIDs;
1346  DenseMap<Value, StringRef> valueNames;
1347 
1348  /// When printing users of values, an operation without a result might
1349  /// be the user. This map holds ids for such operations.
1350  DenseMap<Operation *, unsigned> operationIDs;
1351 
1352  /// This is a map of operations that contain multiple named result groups,
1353  /// i.e. there may be multiple names for the results of the operation. The
1354  /// value of this map are the result numbers that start a result group.
1356 
1357  /// This maps blocks to there visitation number in the current region as well
1358  /// as the string representing their name.
1359  DenseMap<Block *, BlockInfo> blockNames;
1360 
1361  /// This keeps track of all of the non-numeric names that are in flight,
1362  /// allowing us to check for duplicates.
1363  /// Note: the value of the map is unused.
1364  llvm::ScopedHashTable<StringRef, char> usedNames;
1365  llvm::BumpPtrAllocator usedNameAllocator;
1366 
1367  /// This is the next value ID to assign in numbering.
1368  unsigned nextValueID = 0;
1369  /// This is the next ID to assign to a region entry block argument.
1370  unsigned nextArgumentID = 0;
1371  /// This is the next ID to assign when a name conflict is detected.
1372  unsigned nextConflictID = 0;
1373 
1374  /// These are the printing flags. They control, eg., whether to print in
1375  /// generic form.
1376  OpPrintingFlags printerFlags;
1377 };
1378 } // namespace
1379 
1380 SSANameState::SSANameState(Operation *op, const OpPrintingFlags &printerFlags)
1381  : printerFlags(printerFlags) {
1382  llvm::SaveAndRestore valueIDSaver(nextValueID);
1383  llvm::SaveAndRestore argumentIDSaver(nextArgumentID);
1384  llvm::SaveAndRestore conflictIDSaver(nextConflictID);
1385 
1386  // The naming context includes `nextValueID`, `nextArgumentID`,
1387  // `nextConflictID` and `usedNames` scoped HashTable. This information is
1388  // carried from the parent region.
1389  using UsedNamesScopeTy = llvm::ScopedHashTable<StringRef, char>::ScopeTy;
1390  using NamingContext =
1391  std::tuple<Region *, unsigned, unsigned, unsigned, UsedNamesScopeTy *>;
1392 
1393  // Allocator for UsedNamesScopeTy
1394  llvm::BumpPtrAllocator allocator;
1395 
1396  // Add a scope for the top level operation.
1397  auto *topLevelNamesScope =
1398  new (allocator.Allocate<UsedNamesScopeTy>()) UsedNamesScopeTy(usedNames);
1399 
1400  SmallVector<NamingContext, 8> nameContext;
1401  for (Region &region : op->getRegions())
1402  nameContext.push_back(std::make_tuple(&region, nextValueID, nextArgumentID,
1403  nextConflictID, topLevelNamesScope));
1404 
1405  numberValuesInOp(*op);
1406 
1407  while (!nameContext.empty()) {
1408  Region *region;
1409  UsedNamesScopeTy *parentScope;
1410 
1411  if (printerFlags.shouldPrintUniqueSSAIDs())
1412  // To print unique SSA IDs, ignore saved ID counts from parent regions
1413  std::tie(region, std::ignore, std::ignore, std::ignore, parentScope) =
1414  nameContext.pop_back_val();
1415  else
1416  std::tie(region, nextValueID, nextArgumentID, nextConflictID,
1417  parentScope) = nameContext.pop_back_val();
1418 
1419  // When we switch from one subtree to another, pop the scopes(needless)
1420  // until the parent scope.
1421  while (usedNames.getCurScope() != parentScope) {
1422  usedNames.getCurScope()->~UsedNamesScopeTy();
1423  assert((usedNames.getCurScope() != nullptr || parentScope == nullptr) &&
1424  "top level parentScope must be a nullptr");
1425  }
1426 
1427  // Add a scope for the current region.
1428  auto *curNamesScope = new (allocator.Allocate<UsedNamesScopeTy>())
1429  UsedNamesScopeTy(usedNames);
1430 
1431  numberValuesInRegion(*region);
1432 
1433  for (Operation &op : region->getOps())
1434  for (Region &region : op.getRegions())
1435  nameContext.push_back(std::make_tuple(&region, nextValueID,
1436  nextArgumentID, nextConflictID,
1437  curNamesScope));
1438  }
1439 
1440  // Manually remove all the scopes.
1441  while (usedNames.getCurScope() != nullptr)
1442  usedNames.getCurScope()->~UsedNamesScopeTy();
1443 }
1444 
1445 void SSANameState::printValueID(Value value, bool printResultNo,
1446  raw_ostream &stream) const {
1447  if (!value) {
1448  stream << "<<NULL VALUE>>";
1449  return;
1450  }
1451 
1452  std::optional<int> resultNo;
1453  auto lookupValue = value;
1454 
1455  // If this is an operation result, collect the head lookup value of the result
1456  // group and the result number of 'result' within that group.
1457  if (OpResult result = dyn_cast<OpResult>(value))
1458  getResultIDAndNumber(result, lookupValue, resultNo);
1459 
1460  auto it = valueIDs.find(lookupValue);
1461  if (it == valueIDs.end()) {
1462  stream << "<<UNKNOWN SSA VALUE>>";
1463  return;
1464  }
1465 
1466  stream << '%';
1467  if (it->second != NameSentinel) {
1468  stream << it->second;
1469  } else {
1470  auto nameIt = valueNames.find(lookupValue);
1471  assert(nameIt != valueNames.end() && "Didn't have a name entry?");
1472  stream << nameIt->second;
1473  }
1474 
1475  if (resultNo && printResultNo)
1476  stream << '#' << *resultNo;
1477 }
1478 
1479 void SSANameState::printOperationID(Operation *op, raw_ostream &stream) const {
1480  auto it = operationIDs.find(op);
1481  if (it == operationIDs.end()) {
1482  stream << "<<UNKNOWN OPERATION>>";
1483  } else {
1484  stream << '%' << it->second;
1485  }
1486 }
1487 
1488 ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) {
1489  auto it = opResultGroups.find(op);
1490  return it == opResultGroups.end() ? ArrayRef<int>() : it->second;
1491 }
1492 
1493 BlockInfo SSANameState::getBlockInfo(Block *block) {
1494  auto it = blockNames.find(block);
1495  BlockInfo invalidBlock{-1, "INVALIDBLOCK"};
1496  return it != blockNames.end() ? it->second : invalidBlock;
1497 }
1498 
1499 void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
1500  assert(!region.empty() && "cannot shadow arguments of an empty region");
1501  assert(region.getNumArguments() == namesToUse.size() &&
1502  "incorrect number of names passed in");
1503  assert(region.getParentOp()->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
1504  "only KnownIsolatedFromAbove ops can shadow names");
1505 
1506  SmallVector<char, 16> nameStr;
1507  for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) {
1508  auto nameToUse = namesToUse[i];
1509  if (nameToUse == nullptr)
1510  continue;
1511  auto nameToReplace = region.getArgument(i);
1512 
1513  nameStr.clear();
1514  llvm::raw_svector_ostream nameStream(nameStr);
1515  printValueID(nameToUse, /*printResultNo=*/true, nameStream);
1516 
1517  // Entry block arguments should already have a pretty "arg" name.
1518  assert(valueIDs[nameToReplace] == NameSentinel);
1519 
1520  // Use the name without the leading %.
1521  auto name = StringRef(nameStream.str()).drop_front();
1522 
1523  // Overwrite the name.
1524  valueNames[nameToReplace] = name.copy(usedNameAllocator);
1525  }
1526 }
1527 
1528 namespace {
1529 /// Try to get value name from value's location, fallback to `name`.
1530 StringRef maybeGetValueNameFromLoc(Value value, StringRef name) {
1531  if (auto maybeNameLoc = value.getLoc()->findInstanceOf<NameLoc>())
1532  return maybeNameLoc.getName();
1533  return name;
1534 }
1535 } // namespace
1536 
1537 void SSANameState::numberValuesInRegion(Region &region) {
1538  auto setBlockArgNameFn = [&](Value arg, StringRef name) {
1539  assert(!valueIDs.count(arg) && "arg numbered multiple times");
1540  assert(llvm::cast<BlockArgument>(arg).getOwner()->getParent() == &region &&
1541  "arg not defined in current region");
1542  if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
1543  name = maybeGetValueNameFromLoc(arg, name);
1544  setValueName(arg, name);
1545  };
1546 
1547  if (!printerFlags.shouldPrintGenericOpForm()) {
1548  if (Operation *op = region.getParentOp()) {
1549  if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
1550  asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
1551  }
1552  }
1553 
1554  // Number the values within this region in a breadth-first order.
1555  unsigned nextBlockID = 0;
1556  for (auto &block : region) {
1557  // Each block gets a unique ID, and all of the operations within it get
1558  // numbered as well.
1559  auto blockInfoIt = blockNames.insert({&block, {-1, ""}});
1560  if (blockInfoIt.second) {
1561  // This block hasn't been named through `getAsmBlockArgumentNames`, use
1562  // default `^bbNNN` format.
1563  std::string name;
1564  llvm::raw_string_ostream(name) << "^bb" << nextBlockID;
1565  blockInfoIt.first->second.name = StringRef(name).copy(usedNameAllocator);
1566  }
1567  blockInfoIt.first->second.ordering = nextBlockID++;
1568 
1569  numberValuesInBlock(block);
1570  }
1571 }
1572 
1573 void SSANameState::numberValuesInBlock(Block &block) {
1574  // Number the block arguments. We give entry block arguments a special name
1575  // 'arg'.
1576  bool isEntryBlock = block.isEntryBlock();
1577  SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "");
1578  llvm::raw_svector_ostream specialName(specialNameBuffer);
1579  for (auto arg : block.getArguments()) {
1580  if (valueIDs.count(arg))
1581  continue;
1582  if (isEntryBlock) {
1583  specialNameBuffer.resize(strlen("arg"));
1584  specialName << nextArgumentID++;
1585  }
1586  StringRef specialNameStr = specialName.str();
1587  if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
1588  specialNameStr = maybeGetValueNameFromLoc(arg, specialNameStr);
1589  setValueName(arg, specialNameStr);
1590  }
1591 
1592  // Number the operations in this block.
1593  for (auto &op : block)
1594  numberValuesInOp(op);
1595 }
1596 
1597 void SSANameState::numberValuesInOp(Operation &op) {
1598  // Function used to set the special result names for the operation.
1599  SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
1600  auto setResultNameFn = [&](Value result, StringRef name) {
1601  assert(!valueIDs.count(result) && "result numbered multiple times");
1602  assert(result.getDefiningOp() == &op && "result not defined by 'op'");
1603  if (LLVM_UNLIKELY(printerFlags.shouldUseNameLocAsPrefix()))
1604  name = maybeGetValueNameFromLoc(result, name);
1605  setValueName(result, name);
1606 
1607  // Record the result number for groups not anchored at 0.
1608  if (int resultNo = llvm::cast<OpResult>(result).getResultNumber())
1609  resultGroups.push_back(resultNo);
1610  };
1611  // Operations can customize the printing of block names in OpAsmOpInterface.
1612  auto setBlockNameFn = [&](Block *block, StringRef name) {
1613  assert(block->getParentOp() == &op &&
1614  "getAsmBlockArgumentNames callback invoked on a block not directly "
1615  "nested under the current operation");
1616  assert(!blockNames.count(block) && "block numbered multiple times");
1617  SmallString<16> tmpBuffer{"^"};
1618  name = sanitizeIdentifier(name, tmpBuffer);
1619  if (name.data() != tmpBuffer.data()) {
1620  tmpBuffer.append(name);
1621  name = tmpBuffer.str();
1622  }
1623  name = name.copy(usedNameAllocator);
1624  blockNames[block] = {-1, name};
1625  };
1626 
1627  if (!printerFlags.shouldPrintGenericOpForm()) {
1628  if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op)) {
1629  asmInterface.getAsmBlockNames(setBlockNameFn);
1630  asmInterface.getAsmResultNames(setResultNameFn);
1631  }
1632  }
1633 
1634  unsigned numResults = op.getNumResults();
1635  if (numResults == 0) {
1636  // If value users should be printed, operations with no result need an id.
1637  if (printerFlags.shouldPrintValueUsers()) {
1638  if (operationIDs.try_emplace(&op, nextValueID).second)
1639  ++nextValueID;
1640  }
1641  return;
1642  }
1643  Value resultBegin = op.getResult(0);
1644 
1645  if (printerFlags.shouldUseNameLocAsPrefix() && !valueIDs.count(resultBegin)) {
1646  if (auto nameLoc = resultBegin.getLoc()->findInstanceOf<NameLoc>()) {
1647  setValueName(resultBegin, nameLoc.getName());
1648  }
1649  }
1650 
1651  // If the first result wasn't numbered, give it a default number.
1652  if (valueIDs.try_emplace(resultBegin, nextValueID).second)
1653  ++nextValueID;
1654 
1655  // If this operation has multiple result groups, mark it.
1656  if (resultGroups.size() != 1) {
1657  llvm::array_pod_sort(resultGroups.begin(), resultGroups.end());
1658  opResultGroups.try_emplace(&op, std::move(resultGroups));
1659  }
1660 }
1661 
1662 void SSANameState::getResultIDAndNumber(
1663  OpResult result, Value &lookupValue,
1664  std::optional<int> &lookupResultNo) const {
1665  Operation *owner = result.getOwner();
1666  if (owner->getNumResults() == 1)
1667  return;
1668  int resultNo = result.getResultNumber();
1669 
1670  // If this operation has multiple result groups, we will need to find the
1671  // one corresponding to this result.
1672  auto resultGroupIt = opResultGroups.find(owner);
1673  if (resultGroupIt == opResultGroups.end()) {
1674  // If not, just use the first result.
1675  lookupResultNo = resultNo;
1676  lookupValue = owner->getResult(0);
1677  return;
1678  }
1679 
1680  // Find the correct index using a binary search, as the groups are ordered.
1681  ArrayRef<int> resultGroups = resultGroupIt->second;
1682  const auto *it = llvm::upper_bound(resultGroups, resultNo);
1683  int groupResultNo = 0, groupSize = 0;
1684 
1685  // If there are no smaller elements, the last result group is the lookup.
1686  if (it == resultGroups.end()) {
1687  groupResultNo = resultGroups.back();
1688  groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back();
1689  } else {
1690  // Otherwise, the previous element is the lookup.
1691  groupResultNo = *std::prev(it);
1692  groupSize = *it - groupResultNo;
1693  }
1694 
1695  // We only record the result number for a group of size greater than 1.
1696  if (groupSize != 1)
1697  lookupResultNo = resultNo - groupResultNo;
1698  lookupValue = owner->getResult(groupResultNo);
1699 }
1700 
1701 void SSANameState::setValueName(Value value, StringRef name) {
1702  // If the name is empty, the value uses the default numbering.
1703  if (name.empty()) {
1704  valueIDs[value] = nextValueID++;
1705  return;
1706  }
1707 
1708  valueIDs[value] = NameSentinel;
1709  valueNames[value] = uniqueValueName(name);
1710 }
1711 
1712 StringRef SSANameState::uniqueValueName(StringRef name) {
1713  SmallString<16> tmpBuffer;
1714  name = sanitizeIdentifier(name, tmpBuffer);
1715 
1716  // Check to see if this name is already unique.
1717  if (!usedNames.count(name)) {
1718  name = name.copy(usedNameAllocator);
1719  } else {
1720  // Otherwise, we had a conflict - probe until we find a unique name. This
1721  // is guaranteed to terminate (and usually in a single iteration) because it
1722  // generates new names by incrementing nextConflictID.
1723  SmallString<64> probeName(name);
1724  probeName.push_back('_');
1725  while (true) {
1726  probeName += llvm::utostr(nextConflictID++);
1727  if (!usedNames.count(probeName)) {
1728  name = probeName.str().copy(usedNameAllocator);
1729  break;
1730  }
1731  probeName.resize(name.size() + 1);
1732  }
1733  }
1734 
1735  usedNames.insert(name, char());
1736  return name;
1737 }
1738 
1739 //===----------------------------------------------------------------------===//
1740 // DistinctState
1741 //===----------------------------------------------------------------------===//
1742 
1743 namespace {
1744 /// This class manages the state for distinct attributes.
1745 class DistinctState {
1746 public:
1747  /// Returns a unique identifier for the given distinct attribute.
1748  uint64_t getId(DistinctAttr distinctAttr);
1749 
1750 private:
1751  uint64_t distinctCounter = 0;
1752  DenseMap<DistinctAttr, uint64_t> distinctAttrMap;
1753 };
1754 } // namespace
1755 
1756 uint64_t DistinctState::getId(DistinctAttr distinctAttr) {
1757  auto [it, inserted] =
1758  distinctAttrMap.try_emplace(distinctAttr, distinctCounter);
1759  if (inserted)
1760  distinctCounter++;
1761  return it->getSecond();
1762 }
1763 
1764 //===----------------------------------------------------------------------===//
1765 // Resources
1766 //===----------------------------------------------------------------------===//
1767 
// Out-of-line defaulted destructors for the resource entry, builder, parser
// and printer classes.
1768 AsmParsedResourceEntry::~AsmParsedResourceEntry() = default;
1769 AsmResourceBuilder::~AsmResourceBuilder() = default;
1770 AsmResourceParser::~AsmResourceParser() = default;
1771 AsmResourcePrinter::~AsmResourcePrinter() = default;
1772 
1774  switch (kind) {
1775  case AsmResourceEntryKind::Blob:
1776  return "blob";
1777  case AsmResourceEntryKind::Bool:
1778  return "bool";
1779  case AsmResourceEntryKind::String:
1780  return "string";
1781  }
1782  llvm_unreachable("unknown AsmResourceEntryKind");
1783 }
1784 
1785 AsmResourceParser &FallbackAsmResourceMap::getParserFor(StringRef key) {
1786  std::unique_ptr<ResourceCollection> &collection = keyToResources[key.str()];
1787  if (!collection)
1788  collection = std::make_unique<ResourceCollection>(key);
1789  return *collection;
1790 }
1791 
1792 std::vector<std::unique_ptr<AsmResourcePrinter>>
1794  std::vector<std::unique_ptr<AsmResourcePrinter>> printers;
1795  for (auto &it : keyToResources) {
1796  ResourceCollection *collection = it.second.get();
1797  auto buildValues = [=](Operation *op, AsmResourceBuilder &builder) {
1798  return collection->buildResources(op, builder);
1799  };
1800  printers.emplace_back(
1801  AsmResourcePrinter::fromCallable(collection->getName(), buildValues));
1802  }
1803  return printers;
1804 }
1805 
1806 LogicalResult FallbackAsmResourceMap::ResourceCollection::parseResource(
1807  AsmParsedResourceEntry &entry) {
1808  switch (entry.getKind()) {
1810  FailureOr<AsmResourceBlob> blob = entry.parseAsBlob();
1811  if (failed(blob))
1812  return failure();
1813  resources.emplace_back(entry.getKey(), std::move(*blob));
1814  return success();
1815  }
1817  FailureOr<bool> value = entry.parseAsBool();
1818  if (failed(value))
1819  return failure();
1820  resources.emplace_back(entry.getKey(), *value);
1821  break;
1822  }
1824  FailureOr<std::string> str = entry.parseAsString();
1825  if (failed(str))
1826  return failure();
1827  resources.emplace_back(entry.getKey(), std::move(*str));
1828  break;
1829  }
1830  }
1831  return success();
1832 }
1833 
1834 void FallbackAsmResourceMap::ResourceCollection::buildResources(
1835  Operation *op, AsmResourceBuilder &builder) const {
1836  for (const auto &entry : resources) {
1837  if (const auto *value = std::get_if<AsmResourceBlob>(&entry.value))
1838  builder.buildBlob(entry.key, *value);
1839  else if (const auto *value = std::get_if<bool>(&entry.value))
1840  builder.buildBool(entry.key, *value);
1841  else if (const auto *value = std::get_if<std::string>(&entry.value))
1842  builder.buildString(entry.key, *value);
1843  else
1844  llvm_unreachable("unknown AsmResourceEntryKind");
1845  }
1846 }
1847 
1848 //===----------------------------------------------------------------------===//
1849 // AsmState
1850 //===----------------------------------------------------------------------===//
1851 
1852 namespace mlir {
1853 namespace detail {
1855 public:
1856  explicit AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags,
1857  AsmState::LocationMap *locationMap)
1858  : interfaces(op->getContext()), nameState(op, printerFlags),
1859  printerFlags(printerFlags), locationMap(locationMap) {}
1860  explicit AsmStateImpl(MLIRContext *ctx, const OpPrintingFlags &printerFlags,
1861  AsmState::LocationMap *locationMap)
1862  : interfaces(ctx), printerFlags(printerFlags), locationMap(locationMap) {}
1863 
1864  /// Initialize the alias state to enable the printing of aliases.
1866  aliasState.initialize(op, printerFlags, interfaces);
1867  }
1868 
1869  /// Get the state used for aliases.
1870  AliasState &getAliasState() { return aliasState; }
1871 
1872  /// Get the state used for SSA names.
1873  SSANameState &getSSANameState() { return nameState; }
1874 
1875  /// Get the state used for distinct attribute identifiers.
1876  DistinctState &getDistinctState() { return distinctState; }
1877 
1878  /// Return the dialects within the context that implement
1879  /// OpAsmDialectInterface.
1881  return interfaces;
1882  }
1883 
1884  /// Return the non-dialect resource printers.
1886  return llvm::make_pointee_range(externalResourcePrinters);
1887  }
1888 
1889  /// Get the printer flags.
1890  const OpPrintingFlags &getPrinterFlags() const { return printerFlags; }
1891 
1892  /// Register the location, line and column, within the buffer that the given
1893  /// operation was printed at.
1894  void registerOperationLocation(Operation *op, unsigned line, unsigned col) {
1895  if (locationMap)
1896  (*locationMap)[op] = std::make_pair(line, col);
1897  }
1898 
1899  /// Return the referenced dialect resources within the printer.
1902  return dialectResources;
1903  }
1904 
1905  LogicalResult pushCyclicPrinting(const void *opaquePointer) {
1906  return success(cyclicPrintingStack.insert(opaquePointer));
1907  }
1908 
1909  void popCyclicPrinting() { cyclicPrintingStack.pop_back(); }
1910 
1911 private:
1912  /// Collection of OpAsm interfaces implemented in the context.
1914 
1915  /// A collection of non-dialect resource printers.
1916  SmallVector<std::unique_ptr<AsmResourcePrinter>> externalResourcePrinters;
1917 
1918  /// A set of dialect resources that were referenced during printing.
1920 
1921  /// The state used for attribute and type aliases.
1922  AliasState aliasState;
1923 
1924  /// The state used for SSA value names.
1925  SSANameState nameState;
1926 
1927  /// The state used for distinct attribute identifiers.
1928  DistinctState distinctState;
1929 
1930  /// Flags that control op output.
1931  OpPrintingFlags printerFlags;
1932 
1933  /// An optional location map to be populated.
1934  AsmState::LocationMap *locationMap;
1935 
1936  /// Stack of potentially cyclic mutable attributes or type currently being
1937  /// printed.
1938  SetVector<const void *> cyclicPrintingStack;
1939 
1940  // Allow direct access to the impl fields.
1941  friend AsmState;
1942 };
1943 
1944 template <typename Range>
1945 void printDimensionList(raw_ostream &stream, Range &&shape) {
1946  llvm::interleave(
1947  shape, stream,
1948  [&stream](const auto &dimSize) {
1949  if (ShapedType::isDynamic(dimSize))
1950  stream << "?";
1951  else
1952  stream << dimSize;
1953  },
1954  "x");
1955 }
1956 
1957 } // namespace detail
1958 } // namespace mlir
1959 
1960 /// Verifies the operation and switches to generic op printing if verification
1961 /// fails. We need to do this because custom print functions may fail for
1962 /// invalid ops.
1964  OpPrintingFlags printerFlags) {
1965  if (printerFlags.shouldPrintGenericOpForm() ||
1966  printerFlags.shouldAssumeVerified())
1967  return printerFlags;
1968 
1969  // Ignore errors emitted by the verifier. We check the thread id to avoid
1970  // consuming other threads' errors.
1971  auto parentThreadId = llvm::get_threadid();
1972  ScopedDiagnosticHandler diagHandler(op->getContext(), [&](Diagnostic &diag) {
1973  if (parentThreadId == llvm::get_threadid()) {
1974  LLVM_DEBUG({
1975  diag.print(llvm::dbgs());
1976  llvm::dbgs() << "\n";
1977  });
1978  return success();
1979  }
1980  return failure();
1981  });
1982  if (failed(verify(op))) {
1983  LLVM_DEBUG(llvm::dbgs()
1984  << DEBUG_TYPE << ": '" << op->getName()
1985  << "' failed to verify and will be printed in generic form\n");
1986  printerFlags.printGenericOpForm();
1987  }
1988 
1989  return printerFlags;
1990 }
1991 
1992 AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags,
1993  LocationMap *locationMap, FallbackAsmResourceMap *map)
1994  : impl(std::make_unique<AsmStateImpl>(
1995  op, verifyOpAndAdjustFlags(op, printerFlags), locationMap)) {
1996  if (map)
1998 }
1999 AsmState::AsmState(MLIRContext *ctx, const OpPrintingFlags &printerFlags,
2000  LocationMap *locationMap, FallbackAsmResourceMap *map)
2001  : impl(std::make_unique<AsmStateImpl>(ctx, printerFlags, locationMap)) {
2002  if (map)
2004 }
2005 AsmState::~AsmState() = default;
2006 
2008  return impl->getPrinterFlags();
2009 }
2010 
2012  std::unique_ptr<AsmResourcePrinter> printer) {
2013  impl->externalResourcePrinters.emplace_back(std::move(printer));
2014 }
2015 
2018  return impl->getDialectResources();
2019 }
2020 
2021 //===----------------------------------------------------------------------===//
2022 // AsmPrinter::Impl
2023 //===----------------------------------------------------------------------===//
2024 
/// Construct a printer implementation that writes to `os` and uses `state`;
/// `printerFlags` is initialized from the flags stored in `state`.
2025 AsmPrinter::Impl::Impl(raw_ostream &os, AsmStateImpl &state)
2026  : os(os), state(state), printerFlags(state.getPrinterFlags()) {}
2027 
2029  // Check to see if we are printing debug information.
2030  if (!printerFlags.shouldPrintDebugInfo())
2031  return;
2032 
2033  os << " ";
2034  printLocation(loc, /*allowAlias=*/allowAlias);
2035 }
2036 
2038  bool isTopLevel) {
2039  // If this isn't a top-level location, check for an alias.
2040  if (!isTopLevel && succeeded(state.getAliasState().getAlias(loc, os)))
2041  return;
2042 
2044  .Case<OpaqueLoc>([&](OpaqueLoc loc) {
2045  printLocationInternal(loc.getFallbackLocation(), pretty);
2046  })
2047  .Case<UnknownLoc>([&](UnknownLoc loc) {
2048  if (pretty)
2049  os << "[unknown]";
2050  else
2051  os << "unknown";
2052  })
2053  .Case<FileLineColRange>([&](FileLineColRange loc) {
2054  if (pretty)
2055  os << loc.getFilename().getValue();
2056  else
2057  printEscapedString(loc.getFilename());
2058  if (loc.getEndColumn() == loc.getStartColumn() &&
2059  loc.getStartLine() == loc.getEndLine()) {
2060  os << ':' << loc.getStartLine() << ':' << loc.getStartColumn();
2061  return;
2062  }
2063  if (loc.getStartLine() == loc.getEndLine()) {
2064  os << ':' << loc.getStartLine() << ':' << loc.getStartColumn()
2065  << " to :" << loc.getEndColumn();
2066  return;
2067  }
2068  os << ':' << loc.getStartLine() << ':' << loc.getStartColumn() << " to "
2069  << loc.getEndLine() << ':' << loc.getEndColumn();
2070  })
2071  .Case<NameLoc>([&](NameLoc loc) {
2072  printEscapedString(loc.getName());
2073 
2074  // Print the child if it isn't unknown.
2075  auto childLoc = loc.getChildLoc();
2076  if (!llvm::isa<UnknownLoc>(childLoc)) {
2077  os << '(';
2078  printLocationInternal(childLoc, pretty);
2079  os << ')';
2080  }
2081  })
2082  .Case<CallSiteLoc>([&](CallSiteLoc loc) {
2083  Location caller = loc.getCaller();
2084  Location callee = loc.getCallee();
2085  if (!pretty)
2086  os << "callsite(";
2087  printLocationInternal(callee, pretty);
2088  if (pretty) {
2089  if (llvm::isa<NameLoc>(callee)) {
2090  if (llvm::isa<FileLineColLoc>(caller)) {
2091  os << " at ";
2092  } else {
2093  os << newLine << " at ";
2094  }
2095  } else {
2096  os << newLine << " at ";
2097  }
2098  } else {
2099  os << " at ";
2100  }
2101  printLocationInternal(caller, pretty);
2102  if (!pretty)
2103  os << ")";
2104  })
2105  .Case<FusedLoc>([&](FusedLoc loc) {
2106  if (!pretty)
2107  os << "fused";
2108  if (Attribute metadata = loc.getMetadata()) {
2109  os << '<';
2110  printAttribute(metadata);
2111  os << '>';
2112  }
2113  os << '[';
2114  interleave(
2115  loc.getLocations(),
2116  [&](Location loc) { printLocationInternal(loc, pretty); },
2117  [&]() { os << ", "; });
2118  os << ']';
2119  })
2120  .Default([&](LocationAttr loc) {
2121  // Assumes that this is a dialect-specific attribute and prints it
2122  // directly.
2123  printAttribute(loc);
2124  });
2125 }
2126 
2127 /// Print a floating point value in a way that the parser will be able to
2128 /// round-trip losslessly.
2129 static void printFloatValue(const APFloat &apValue, raw_ostream &os,
2130  bool *printedHex = nullptr) {
2131  // We would like to output the FP constant value in exponential notation,
2132  // but we cannot do this if doing so will lose precision. Check here to
2133  // make sure that we only output it in exponential format if we can parse
2134  // the value back and get the same value.
2135  bool isInf = apValue.isInfinity();
2136  bool isNaN = apValue.isNaN();
2137  if (!isInf && !isNaN) {
2138  SmallString<128> strValue;
2139  apValue.toString(strValue, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0,
2140  /*TruncateZero=*/false);
2141 
2142  // Check to make sure that the stringized number is not some string like
2143  // "Inf" or NaN, that atof will accept, but the lexer will not. Check
2144  // that the string matches the "[-+]?[0-9]" regex.
2145  assert(((strValue[0] >= '0' && strValue[0] <= '9') ||
2146  ((strValue[0] == '-' || strValue[0] == '+') &&
2147  (strValue[1] >= '0' && strValue[1] <= '9'))) &&
2148  "[-+]?[0-9] regex does not match!");
2149 
2150  // Parse back the stringized version and check that the value is equal
2151  // (i.e., there is no precision loss).
2152  if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) {
2153  os << strValue;
2154  return;
2155  }
2156 
2157  // If it is not, use the default format of APFloat instead of the
2158  // exponential notation.
2159  strValue.clear();
2160  apValue.toString(strValue);
2161 
2162  // Make sure that we can parse the default form as a float.
2163  if (strValue.str().contains('.')) {
2164  os << strValue;
2165  return;
2166  }
2167  }
2168 
2169  // Print special values in hexadecimal format. The sign bit should be included
2170  // in the literal.
2171  if (printedHex)
2172  *printedHex = true;
2174  APInt apInt = apValue.bitcastToAPInt();
2175  apInt.toString(str, /*Radix=*/16, /*Signed=*/false,
2176  /*formatAsCLiteral=*/true);
2177  os << str;
2178 }
2179 
2180 void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) {
2181  if (printerFlags.shouldPrintDebugInfoPrettyForm())
2182  return printLocationInternal(loc, /*pretty=*/true, /*isTopLevel=*/true);
2183 
2184  os << "loc(";
2185  if (!allowAlias || failed(printAlias(loc)))
2186  printLocationInternal(loc, /*pretty=*/false, /*isTopLevel=*/true);
2187  os << ')';
2188 }
2189 
2191  const AsmDialectResourceHandle &resource) {
2192  auto *interface = cast<OpAsmDialectInterface>(resource.getDialect());
2193  os << interface->getResourceKey(resource);
2194  state.getDialectResources()[resource.getDialect()].insert(resource);
2195 }
2196 
2197 /// Returns true if the given dialect symbol data is simple enough to print in
2198 /// the pretty form. This is essentially when the symbol takes the form:
2199 /// identifier (`<` body `>`)?
2200 static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) {
2201  // The name must start with an identifier.
2202  if (symName.empty() || !isalpha(symName.front()))
2203  return false;
2204 
2205  // Ignore all the characters that are valid in an identifier in the symbol
2206  // name.
2207  symName = symName.drop_while(
2208  [](char c) { return llvm::isAlnum(c) || c == '.' || c == '_'; });
2209  if (symName.empty())
2210  return true;
2211 
2212  // If we got to an unexpected character, then it must be a <>. Check that the
2213  // rest of the symbol is wrapped within <>.
2214  return symName.front() == '<' && symName.back() == '>';
2215 }
2216 
2217 /// Print the given dialect symbol to the stream.
2218 static void printDialectSymbol(raw_ostream &os, StringRef symPrefix,
2219  StringRef dialectName, StringRef symString) {
2220  os << symPrefix << dialectName;
2221 
2222  // If this symbol name is simple enough, print it directly in pretty form,
2223  // otherwise, we print it as an escaped string.
2225  os << '.' << symString;
2226  return;
2227  }
2228 
2229  os << '<' << symString << '>';
2230 }
2231 
2232 /// Returns true if the given string can be represented as a bare identifier.
2233 static bool isBareIdentifier(StringRef name) {
2234  // By making this unsigned, the value passed in to isalnum will always be
2235  // in the range 0-255. This is important when building with MSVC because
2236  // its implementation will assert. This situation can arise when dealing
2237  // with UTF-8 multibyte characters.
2238  if (name.empty() || (!isalpha(name[0]) && name[0] != '_'))
2239  return false;
2240  return llvm::all_of(name.drop_front(), [](unsigned char c) {
2241  return isalnum(c) || c == '_' || c == '$' || c == '.';
2242  });
2243 }
2244 
2245 /// Print the given string as a keyword, or a quoted and escaped string if it
2246 /// has any special or non-printable characters in it.
2247 static void printKeywordOrString(StringRef keyword, raw_ostream &os) {
2248  // If it can be represented as a bare identifier, write it directly.
2249  if (isBareIdentifier(keyword)) {
2250  os << keyword;
2251  return;
2252  }
2253 
2254  // Otherwise, output the keyword wrapped in quotes with proper escaping.
2255  os << "\"";
2256  printEscapedString(keyword, os);
2257  os << '"';
2258 }
2259 
2260 /// Print the given string as a symbol reference. A symbol reference is
2261 /// represented as a string prefixed with '@'. The reference is surrounded with
2262 /// ""'s and escaped if it has any special or non-printable characters in it.
2263 static void printSymbolReference(StringRef symbolRef, raw_ostream &os) {
2264  if (symbolRef.empty()) {
2265  os << "@<<INVALID EMPTY SYMBOL>>";
2266  return;
2267  }
2268  os << '@';
2269  printKeywordOrString(symbolRef, os);
2270 }
2271 
2272 // Print out a valid ElementsAttr that is succinct and can represent any
2273 // potential shape/type, for use when eliding a large ElementsAttr.
2274 //
2275 // We choose to use a dense resource ElementsAttr literal with conspicuous
2276 // content to hopefully alert readers to the fact that this has been elided.
2277 static void printElidedElementsAttr(raw_ostream &os) {
2278  os << R"(dense_resource<__elided__>)";
2279 }
2280 
2281 LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) {
2282  return state.getAliasState().getAlias(attr, os);
2283 }
2284 
2285 LogicalResult AsmPrinter::Impl::printAlias(Type type) {
2286  return state.getAliasState().getAlias(type, os);
2287 }
2288 
2289 void AsmPrinter::Impl::printAttribute(Attribute attr,
2290  AttrTypeElision typeElision) {
2291  if (!attr) {
2292  os << "<<NULL ATTRIBUTE>>";
2293  return;
2294  }
2295 
2296  // Try to print an alias for this attribute.
2297  if (succeeded(printAlias(attr)))
2298  return;
2299  return printAttributeImpl(attr, typeElision);
2300 }
2301 
/// Print the full textual form of `attr` (no alias lookup is done here).
///
/// Attributes that do not belong to the builtin dialect are forwarded to
/// printDialectAttribute. Builtin attributes are handled by the chain of
/// branches below; an unrecognized builtin attribute is a fatal error.
///
/// Branches that `return` never print a trailing type. Every other branch
/// falls through to the end of the function, where ` : <type>` is appended if
/// `typeElision` is not `Must`, the attribute is a TypedAttr, and its type is
/// not NoneType.
2302 void AsmPrinter::Impl::printAttributeImpl(Attribute attr,
2303  AttrTypeElision typeElision) {
2304  if (!isa<BuiltinDialect>(attr.getDialect())) {
2305  printDialectAttribute(attr);
2306  } else if (auto opaqueAttr = llvm::dyn_cast<OpaqueAttr>(attr)) {
2307  printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
2308  opaqueAttr.getAttrData());
2309  } else if (llvm::isa<UnitAttr>(attr)) {
2310  os << "unit";
2311  return;
2312  } else if (auto distinctAttr = llvm::dyn_cast<DistinctAttr>(attr)) {
2313  os << "distinct[" << state.getDistinctState().getId(distinctAttr) << "]<";
// A referenced unit attribute is left out, yielding `distinct[N]<>`.
2314  if (!llvm::isa<UnitAttr>(distinctAttr.getReferencedAttr())) {
2315  printAttribute(distinctAttr.getReferencedAttr());
2316  }
2317  os << '>';
2318  return;
2319  } else if (auto dictAttr = llvm::dyn_cast<DictionaryAttr>(attr)) {
2320  os << '{';
2321  interleaveComma(dictAttr.getValue(),
2322  [&](NamedAttribute attr) { printNamedAttribute(attr); });
2323  os << '}';
2324 
2325  } else if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr)) {
2326  Type intType = intAttr.getType();
2327  if (intType.isSignlessInteger(1)) {
2328  os << (intAttr.getValue().getBoolValue() ? "true" : "false");
2329 
2330  // Boolean integer attributes always elides the type.
2331  return;
2332  }
2333 
2334  // Only print attributes as unsigned if they are explicitly unsigned or are
2335  // signless 1-bit values. Indexes, signed values, and multi-bit signless
2336  // values print as signed.
// Note: signless 1-bit values have already returned above, so at this point
// only the explicit-unsigned check can make `isUnsigned` true.
2337  bool isUnsigned =
2338  intType.isUnsignedInteger() || intType.isSignlessInteger(1);
2339  intAttr.getValue().print(os, !isUnsigned);
2340 
2341  // IntegerAttr elides the type if I64.
2342  if (typeElision == AttrTypeElision::May && intType.isSignlessInteger(64))
2343  return;
2344 
2345  } else if (auto floatAttr = llvm::dyn_cast<FloatAttr>(attr)) {
2346  bool printedHex = false;
2347  printFloatValue(floatAttr.getValue(), os, &printedHex);
2348 
2349  // FloatAttr elides the type if F64.
// The type is kept when the value was printed in hexadecimal form.
2350  if (typeElision == AttrTypeElision::May && floatAttr.getType().isF64() &&
2351  !printedHex)
2352  return;
2353 
2354  } else if (auto strAttr = llvm::dyn_cast<StringAttr>(attr)) {
2355  printEscapedString(strAttr.getValue());
2356 
2357  } else if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(attr)) {
2358  os << '[';
2359  interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
2360  printAttribute(attr, AttrTypeElision::May);
2361  });
2362  os << ']';
2363 
2364  } else if (auto affineMapAttr = llvm::dyn_cast<AffineMapAttr>(attr)) {
2365  os << "affine_map<";
2366  affineMapAttr.getValue().print(os);
2367  os << '>';
2368 
2369  // AffineMap always elides the type.
2370  return;
2371 
2372  } else if (auto integerSetAttr = llvm::dyn_cast<IntegerSetAttr>(attr)) {
2373  os << "affine_set<";
2374  integerSetAttr.getValue().print(os);
2375  os << '>';
2376 
2377  // IntegerSet always elides the type.
2378  return;
2379 
2380  } else if (auto typeAttr = llvm::dyn_cast<TypeAttr>(attr)) {
2381  printType(typeAttr.getValue());
2382 
2383  } else if (auto refAttr = llvm::dyn_cast<SymbolRefAttr>(attr)) {
// Printed as the root reference followed by `::`-separated nested references.
2384  printSymbolReference(refAttr.getRootReference().getValue(), os);
2385  for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
2386  os << "::";
2387  printSymbolReference(nestedRef.getValue(), os);
2388  }
2389 
2390  } else if (auto intOrFpEltAttr =
2391  llvm::dyn_cast<DenseIntOrFPElementsAttr>(attr)) {
2392  if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
2393  printElidedElementsAttr(os);
2394  } else {
2395  os << "dense<";
2396  printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
2397  os << '>';
2398  }
2399 
2400  } else if (auto strEltAttr = llvm::dyn_cast<DenseStringElementsAttr>(attr)) {
2401  if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
2402  printElidedElementsAttr(os);
2403  } else {
2404  os << "dense<";
2405  printDenseStringElementsAttr(strEltAttr);
2406  os << '>';
2407  }
2408 
2409  } else if (auto sparseEltAttr = llvm::dyn_cast<SparseElementsAttr>(attr)) {
// The whole sparse attribute is elided if either its indices or its values
// would be elided.
2410  if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) ||
2411  printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) {
2412  printElidedElementsAttr(os);
2413  } else {
2414  os << "sparse<";
2415  DenseIntElementsAttr indices = sparseEltAttr.getIndices();
// With no indices the body is left empty, i.e. `sparse<>`.
2416  if (indices.getNumElements() != 0) {
2417  printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
2418  os << ", ";
2419  printDenseElementsAttr(sparseEltAttr.getValues(), /*allowHex=*/true);
2420  }
2421  os << '>';
2422  }
2423  } else if (auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(attr)) {
2424  stridedLayoutAttr.print(os);
2425  } else if (auto denseArrayAttr = llvm::dyn_cast<DenseArrayAttr>(attr)) {
2426  os << "array<";
2427  printType(denseArrayAttr.getElementType());
2428  if (!denseArrayAttr.empty()) {
2429  os << ": ";
2430  printDenseArrayAttr(denseArrayAttr);
2431  }
2432  os << ">";
2433  return;
2434  } else if (auto resourceAttr =
2435  llvm::dyn_cast<DenseResourceElementsAttr>(attr)) {
2436  os << "dense_resource<";
2437  printResourceHandle(resourceAttr.getRawHandle());
2438  os << ">";
2439  } else if (auto locAttr = llvm::dyn_cast<LocationAttr>(attr)) {
2440  printLocation(locAttr);
2441  } else {
2442  llvm::report_fatal_error("Unknown builtin attribute");
2443  }
2444  // Don't print the type if we must elide it, or if it is a None type.
2445  if (typeElision != AttrTypeElision::Must) {
2446  if (auto typedAttr = llvm::dyn_cast<TypedAttr>(attr)) {
2447  Type attrType = typedAttr.getType();
2448  if (!llvm::isa<NoneType>(attrType)) {
2449  os << " : ";
2450  printType(attrType);
2451  }
2452  }
2453  }
2454 }
2455 
2456 /// Print the integer element of a DenseElementsAttr.
2457 static void printDenseIntElement(const APInt &value, raw_ostream &os,
2458  Type type) {
2459  if (type.isInteger(1))
2460  os << (value.getBoolValue() ? "true" : "false");
2461  else
2462  value.print(os, !type.isUnsignedInteger());
2463 }
2464 
/// Print the elements of a dense elements attribute as a nested, bracketed
/// list that follows the shape of `type`.
///
/// `printEltFn` is called with the flat index of each element to print it.
/// When `isSplat` is true only element 0 is printed, without any brackets.
/// When the type has zero elements nothing is printed. Otherwise the elements
/// are printed in flat index order, separated by ", ", with one bracket level
/// per dimension of the shape.
2465 static void
2466 printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
2467  function_ref<void(unsigned)> printEltFn) {
2468  // Special case for 0-d and splat tensors.
2469  if (isSplat)
2470  return printEltFn(0);
2471 
2472  // Special case for degenerate tensors.
2473  auto numElements = type.getNumElements();
2474  if (numElements == 0)
2475  return;
2476 
2477  // We use a mixed-radix counter to iterate through the shape. When we bump a
2478  // non-least-significant digit, we emit a close bracket. When we next emit an
2479  // element we re-open all closed brackets.
2480 
2481  // The mixed-radix counter, with radices in 'shape'.
2482  int64_t rank = type.getRank();
2483  SmallVector<unsigned, 4> counter(rank, 0);
2484  // The number of brackets that have been opened and not closed.
2485  unsigned openBrackets = 0;
2486 
2487  auto shape = type.getShape();
2488  auto bumpCounter = [&] {
2489  // Bump the least significant digit.
2490  ++counter[rank - 1];
2491  // Iterate backwards bubbling back the increment.
// Digit 0 is never rolled over here; the element loop below ends first.
2492  for (unsigned i = rank - 1; i > 0; --i)
2493  if (counter[i] >= shape[i]) {
2494  // Index 'i' is rolled over. Bump (i-1) and close a bracket.
2495  counter[i] = 0;
2496  ++counter[i - 1];
2497  --openBrackets;
2498  os << ']';
2499  }
2500  };
2501 
2502  for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
2503  if (idx != 0)
2504  os << ", ";
// Re-open every bracket level closed by the previous bump. The post-increment
// overshoots by one when the loop exits, hence the reset to `rank` below.
2505  while (openBrackets++ < rank)
2506  os << '[';
2507  openBrackets = rank;
2508  printEltFn(idx);
2509  bumpCounter();
2510  }
// Close whatever bracket levels are still open after the last element.
2511  while (openBrackets-- > 0)
2512  os << ']';
2513 }
2514 
2515 void AsmPrinter::Impl::printDenseElementsAttr(DenseElementsAttr attr,
2516  bool allowHex) {
2517  if (auto stringAttr = llvm::dyn_cast<DenseStringElementsAttr>(attr))
2518  return printDenseStringElementsAttr(stringAttr);
2519 
2520  printDenseIntOrFPElementsAttr(llvm::cast<DenseIntOrFPElementsAttr>(attr),
2521  allowHex);
2522 }
2523 
2524 void AsmPrinter::Impl::printDenseIntOrFPElementsAttr(
2525  DenseIntOrFPElementsAttr attr, bool allowHex) {
2526  auto type = attr.getType();
2527  auto elementType = type.getElementType();
2528 
2529  // Check to see if we should format this attribute as a hex string.
2530  if (allowHex && printerFlags.shouldPrintElementsAttrWithHex(attr)) {
2531  ArrayRef<char> rawData = attr.getRawData();
2532  if (llvm::endianness::native == llvm::endianness::big) {
2533  // Convert endianess in big-endian(BE) machines. `rawData` is BE in BE
2534  // machines. It is converted here to print in LE format.
2535  SmallVector<char, 64> outDataVec(rawData.size());
2536  MutableArrayRef<char> convRawData(outDataVec);
2537  DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
2538  rawData, convRawData, type);
2539  printHexString(convRawData);
2540  } else {
2541  printHexString(rawData);
2542  }
2543 
2544  return;
2545  }
2546 
2547  if (ComplexType complexTy = llvm::dyn_cast<ComplexType>(elementType)) {
2548  Type complexElementType = complexTy.getElementType();
2549  // Note: The if and else below had a common lambda function which invoked
2550  // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2
2551  // and hence was replaced.
2552  if (llvm::isa<IntegerType>(complexElementType)) {
2553  auto valueIt = attr.value_begin<std::complex<APInt>>();
2554  printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
2555  auto complexValue = *(valueIt + index);
2556  os << "(";
2557  printDenseIntElement(complexValue.real(), os, complexElementType);
2558  os << ",";
2559  printDenseIntElement(complexValue.imag(), os, complexElementType);
2560  os << ")";
2561  });
2562  } else {
2563  auto valueIt = attr.value_begin<std::complex<APFloat>>();
2564  printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
2565  auto complexValue = *(valueIt + index);
2566  os << "(";
2567  printFloatValue(complexValue.real(), os);
2568  os << ",";
2569  printFloatValue(complexValue.imag(), os);
2570  os << ")";
2571  });
2572  }
2573  } else if (elementType.isIntOrIndex()) {
2574  auto valueIt = attr.value_begin<APInt>();
2575  printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
2576  printDenseIntElement(*(valueIt + index), os, elementType);
2577  });
2578  } else {
2579  assert(llvm::isa<FloatType>(elementType) && "unexpected element type");
2580  auto valueIt = attr.value_begin<APFloat>();
2581  printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
2582  printFloatValue(*(valueIt + index), os);
2583  });
2584  }
2585 }
2586 
2587 void AsmPrinter::Impl::printDenseStringElementsAttr(
2588  DenseStringElementsAttr attr) {
2589  ArrayRef<StringRef> data = attr.getRawStringData();
2590  auto printFn = [&](unsigned index) { printEscapedString(data[index]); };
2591  printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
2592 }
2593 
2594 void AsmPrinter::Impl::printDenseArrayAttr(DenseArrayAttr attr) {
2595  Type type = attr.getElementType();
2596  unsigned bitwidth = type.isInteger(1) ? 8 : type.getIntOrFloatBitWidth();
2597  unsigned byteSize = bitwidth / 8;
2598  ArrayRef<char> data = attr.getRawData();
2599 
2600  auto printElementAt = [&](unsigned i) {
2601  APInt value(bitwidth, 0);
2602  if (bitwidth) {
2603  llvm::LoadIntFromMemory(
2604  value, reinterpret_cast<const uint8_t *>(data.begin() + byteSize * i),
2605  byteSize);
2606  }
2607  // Print the data as-is or as a float.
2608  if (type.isIntOrIndex()) {
2609  printDenseIntElement(value, getStream(), type);
2610  } else {
2611  APFloat fltVal(llvm::cast<FloatType>(type).getFloatSemantics(), value);
2612  printFloatValue(fltVal, getStream());
2613  }
2614  };
2615  llvm::interleaveComma(llvm::seq<unsigned>(0, attr.size()), getStream(),
2616  printElementAt);
2617 }
2618 
2619 void AsmPrinter::Impl::printType(Type type) {
2620  if (!type) {
2621  os << "<<NULL TYPE>>";
2622  return;
2623  }
2624 
2625  // Try to print an alias for this type.
2626  if (succeeded(printAlias(type)))
2627  return;
2628  return printTypeImpl(type);
2629 }
2630 
2631 void AsmPrinter::Impl::printTypeImpl(Type type) {
2632  TypeSwitch<Type>(type)
2633  .Case<OpaqueType>([&](OpaqueType opaqueTy) {
2634  printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
2635  opaqueTy.getTypeData());
2636  })
2637  .Case<IndexType>([&](Type) { os << "index"; })
2638  .Case<Float4E2M1FNType>([&](Type) { os << "f4E2M1FN"; })
2639  .Case<Float6E2M3FNType>([&](Type) { os << "f6E2M3FN"; })
2640  .Case<Float6E3M2FNType>([&](Type) { os << "f6E3M2FN"; })
2641  .Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; })
2642  .Case<Float8E4M3Type>([&](Type) { os << "f8E4M3"; })
2643  .Case<Float8E4M3FNType>([&](Type) { os << "f8E4M3FN"; })
2644  .Case<Float8E5M2FNUZType>([&](Type) { os << "f8E5M2FNUZ"; })
2645  .Case<Float8E4M3FNUZType>([&](Type) { os << "f8E4M3FNUZ"; })
2646  .Case<Float8E4M3B11FNUZType>([&](Type) { os << "f8E4M3B11FNUZ"; })
2647  .Case<Float8E3M4Type>([&](Type) { os << "f8E3M4"; })
2648  .Case<Float8E8M0FNUType>([&](Type) { os << "f8E8M0FNU"; })
2649  .Case<BFloat16Type>([&](Type) { os << "bf16"; })
2650  .Case<Float16Type>([&](Type) { os << "f16"; })
2651  .Case<FloatTF32Type>([&](Type) { os << "tf32"; })
2652  .Case<Float32Type>([&](Type) { os << "f32"; })
2653  .Case<Float64Type>([&](Type) { os << "f64"; })
2654  .Case<Float80Type>([&](Type) { os << "f80"; })
2655  .Case<Float128Type>([&](Type) { os << "f128"; })
2656  .Case<IntegerType>([&](IntegerType integerTy) {
2657  if (integerTy.isSigned())
2658  os << 's';
2659  else if (integerTy.isUnsigned())
2660  os << 'u';
2661  os << 'i' << integerTy.getWidth();
2662  })
2663  .Case<FunctionType>([&](FunctionType funcTy) {
2664  os << '(';
2665  interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); });
2666  os << ") -> ";
2667  ArrayRef<Type> results = funcTy.getResults();
2668  if (results.size() == 1 && !llvm::isa<FunctionType>(results[0])) {
2669  printType(results[0]);
2670  } else {
2671  os << '(';
2672  interleaveComma(results, [&](Type ty) { printType(ty); });
2673  os << ')';
2674  }
2675  })
2676  .Case<VectorType>([&](VectorType vectorTy) {
2677  auto scalableDims = vectorTy.getScalableDims();
2678  os << "vector<";
2679  auto vShape = vectorTy.getShape();
2680  unsigned lastDim = vShape.size();
2681  unsigned dimIdx = 0;
2682  for (dimIdx = 0; dimIdx < lastDim; dimIdx++) {
2683  if (!scalableDims.empty() && scalableDims[dimIdx])
2684  os << '[';
2685  os << vShape[dimIdx];
2686  if (!scalableDims.empty() && scalableDims[dimIdx])
2687  os << ']';
2688  os << 'x';
2689  }
2690  printType(vectorTy.getElementType());
2691  os << '>';
2692  })
2693  .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
2694  os << "tensor<";
2695  printDimensionList(tensorTy.getShape());
2696  if (!tensorTy.getShape().empty())
2697  os << 'x';
2698  printType(tensorTy.getElementType());
2699  // Only print the encoding attribute value if set.
2700  if (tensorTy.getEncoding()) {
2701  os << ", ";
2702  printAttribute(tensorTy.getEncoding());
2703  }
2704  os << '>';
2705  })
2706  .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) {
2707  os << "tensor<*x";
2708  printType(tensorTy.getElementType());
2709  os << '>';
2710  })
2711  .Case<MemRefType>([&](MemRefType memrefTy) {
2712  os << "memref<";
2713  printDimensionList(memrefTy.getShape());
2714  if (!memrefTy.getShape().empty())
2715  os << 'x';
2716  printType(memrefTy.getElementType());
2717  MemRefLayoutAttrInterface layout = memrefTy.getLayout();
2718  if (!llvm::isa<AffineMapAttr>(layout) || !layout.isIdentity()) {
2719  os << ", ";
2720  printAttribute(memrefTy.getLayout(), AttrTypeElision::May);
2721  }
2722  // Only print the memory space if it is the non-default one.
2723  if (memrefTy.getMemorySpace()) {
2724  os << ", ";
2725  printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
2726  }
2727  os << '>';
2728  })
2729  .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
2730  os << "memref<*x";
2731  printType(memrefTy.getElementType());
2732  // Only print the memory space if it is the non-default one.
2733  if (memrefTy.getMemorySpace()) {
2734  os << ", ";
2735  printAttribute(memrefTy.getMemorySpace(), AttrTypeElision::May);
2736  }
2737  os << '>';
2738  })
2739  .Case<ComplexType>([&](ComplexType complexTy) {
2740  os << "complex<";
2741  printType(complexTy.getElementType());
2742  os << '>';
2743  })
2744  .Case<TupleType>([&](TupleType tupleTy) {
2745  os << "tuple<";
2746  interleaveComma(tupleTy.getTypes(),
2747  [&](Type type) { printType(type); });
2748  os << '>';
2749  })
2750  .Case<NoneType>([&](Type) { os << "none"; })
2751  .Default([&](Type type) { return printDialectType(type); });
2752 }
2753 
2754 void AsmPrinter::Impl::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
2755  ArrayRef<StringRef> elidedAttrs,
2756  bool withKeyword) {
2757  // If there are no attributes, then there is nothing to be done.
2758  if (attrs.empty())
2759  return;
2760 
2761  // Functor used to print a filtered attribute list.
2762  auto printFilteredAttributesFn = [&](auto filteredAttrs) {
2763  // Print the 'attributes' keyword if necessary.
2764  if (withKeyword)
2765  os << " attributes";
2766 
2767  // Otherwise, print them all out in braces.
2768  os << " {";
2769  interleaveComma(filteredAttrs,
2770  [&](NamedAttribute attr) { printNamedAttribute(attr); });
2771  os << '}';
2772  };
2773 
2774  // If no attributes are elided, we can directly print with no filtering.
2775  if (elidedAttrs.empty())
2776  return printFilteredAttributesFn(attrs);
2777 
2778  // Otherwise, filter out any attributes that shouldn't be included.
2779  llvm::SmallDenseSet<StringRef> elidedAttrsSet(elidedAttrs.begin(),
2780  elidedAttrs.end());
2781  auto filteredAttrs = llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
2782  return !elidedAttrsSet.contains(attr.getName().strref());
2783  });
2784  if (!filteredAttrs.empty())
2785  printFilteredAttributesFn(filteredAttrs);
2786 }
2787 void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) {
2788  // Print the name without quotes if possible.
2789  ::printKeywordOrString(attr.getName().strref(), os);
2790 
2791  // Pretty printing elides the attribute value for unit attributes.
2792  if (llvm::isa<UnitAttr>(attr.getValue()))
2793  return;
2794 
2795  os << " = ";
2796  printAttribute(attr.getValue());
2797 }
2798 
2799 void AsmPrinter::Impl::printDialectAttribute(Attribute attr) {
2800  auto &dialect = attr.getDialect();
2801 
2802  // Ask the dialect to serialize the attribute to a string.
2803  std::string attrName;
2804  {
2805  llvm::raw_string_ostream attrNameStr(attrName);
2806  Impl subPrinter(attrNameStr, state);
2807  DialectAsmPrinter printer(subPrinter);
2808  dialect.printAttribute(attr, printer);
2809  }
2810  printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
2811 }
2812 
2813 void AsmPrinter::Impl::printDialectType(Type type) {
2814  auto &dialect = type.getDialect();
2815 
2816  // Ask the dialect to serialize the type to a string.
2817  std::string typeName;
2818  {
2819  llvm::raw_string_ostream typeNameStr(typeName);
2820  Impl subPrinter(typeNameStr, state);
2821  DialectAsmPrinter printer(subPrinter);
2822  dialect.printType(type, printer);
2823  }
2824  printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
2825 }
2826 
2827 void AsmPrinter::Impl::printEscapedString(StringRef str) {
2828  os << "\"";
2829  llvm::printEscapedString(str, os);
2830  os << "\"";
2831 }
2832 
2834  os << "\"0x" << llvm::toHex(str) << "\"";
2835 }
2837  printHexString(StringRef(data.data(), data.size()));
2838 }
2839 
2840 LogicalResult AsmPrinter::Impl::pushCyclicPrinting(const void *opaquePointer) {
2841  return state.pushCyclicPrinting(opaquePointer);
2842 }
2843 
2844 void AsmPrinter::Impl::popCyclicPrinting() { state.popCyclicPrinting(); }
2845 
2847  detail::printDimensionList(os, shape);
2848 }
2849 
2850 //===--------------------------------------------------------------------===//
2851 // AsmPrinter
2852 //===--------------------------------------------------------------------===//
2853 
/// Defaulted destructor, defined out of line in this file.
AsmPrinter::~AsmPrinter() = default;
2855 
2856 raw_ostream &AsmPrinter::getStream() const {
2857  assert(impl && "expected AsmPrinter::getStream to be overriden");
2858  return impl->getStream();
2859 }
2860 
2861 /// Print the given floating point value in a stablized form.
2862 void AsmPrinter::printFloat(const APFloat &value) {
2863  assert(impl && "expected AsmPrinter::printFloat to be overriden");
2864  printFloatValue(value, impl->getStream());
2865 }
2866 
2868  assert(impl && "expected AsmPrinter::printType to be overriden");
2869  impl->printType(type);
2870 }
2871 
2873  assert(impl && "expected AsmPrinter::printAttribute to be overriden");
2874  impl->printAttribute(attr);
2875 }
2876 
2877 LogicalResult AsmPrinter::printAlias(Attribute attr) {
2878  assert(impl && "expected AsmPrinter::printAlias to be overriden");
2879  return impl->printAlias(attr);
2880 }
2881 
2882 LogicalResult AsmPrinter::printAlias(Type type) {
2883  assert(impl && "expected AsmPrinter::printAlias to be overriden");
2884  return impl->printAlias(type);
2885 }
2886 
2888  assert(impl &&
2889  "expected AsmPrinter::printAttributeWithoutType to be overriden");
2890  impl->printAttribute(attr, Impl::AttrTypeElision::Must);
2891 }
2892 
2893 void AsmPrinter::printKeywordOrString(StringRef keyword) {
2894  assert(impl && "expected AsmPrinter::printKeywordOrString to be overriden");
2895  ::printKeywordOrString(keyword, impl->getStream());
2896 }
2897 
2898 void AsmPrinter::printString(StringRef keyword) {
2899  assert(impl && "expected AsmPrinter::printString to be overriden");
2900  *this << '"';
2901  printEscapedString(keyword, getStream());
2902  *this << '"';
2903 }
2904 
2905 void AsmPrinter::printSymbolName(StringRef symbolRef) {
2906  assert(impl && "expected AsmPrinter::printSymbolName to be overriden");
2907  ::printSymbolReference(symbolRef, impl->getStream());
2908 }
2909 
2911  assert(impl && "expected AsmPrinter::printResourceHandle to be overriden");
2912  impl->printResourceHandle(resource);
2913 }
2914 
2917 }
2918 
2919 LogicalResult AsmPrinter::pushCyclicPrinting(const void *opaquePointer) {
2920  return impl->pushCyclicPrinting(opaquePointer);
2921 }
2922 
2923 void AsmPrinter::popCyclicPrinting() { impl->popCyclicPrinting(); }
2924 
2925 //===----------------------------------------------------------------------===//
2926 // Affine expressions and maps
2927 //===----------------------------------------------------------------------===//
2928 
2930  AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) {
2931  printAffineExprInternal(expr, BindingStrength::Weak, printValueName);
2932 }
2933 
2935  AffineExpr expr, BindingStrength enclosingTightness,
2936  function_ref<void(unsigned, bool)> printValueName) {
2937  const char *binopSpelling = nullptr;
2938  switch (expr.getKind()) {
2939  case AffineExprKind::SymbolId: {
2940  unsigned pos = cast<AffineSymbolExpr>(expr).getPosition();
2941  if (printValueName)
2942  printValueName(pos, /*isSymbol=*/true);
2943  else
2944  os << 's' << pos;
2945  return;
2946  }
2947  case AffineExprKind::DimId: {
2948  unsigned pos = cast<AffineDimExpr>(expr).getPosition();
2949  if (printValueName)
2950  printValueName(pos, /*isSymbol=*/false);
2951  else
2952  os << 'd' << pos;
2953  return;
2954  }
2956  os << cast<AffineConstantExpr>(expr).getValue();
2957  return;
2958  case AffineExprKind::Add:
2959  binopSpelling = " + ";
2960  break;
2961  case AffineExprKind::Mul:
2962  binopSpelling = " * ";
2963  break;
2965  binopSpelling = " floordiv ";
2966  break;
2968  binopSpelling = " ceildiv ";
2969  break;
2970  case AffineExprKind::Mod:
2971  binopSpelling = " mod ";
2972  break;
2973  }
2974 
2975  auto binOp = cast<AffineBinaryOpExpr>(expr);
2976  AffineExpr lhsExpr = binOp.getLHS();
2977  AffineExpr rhsExpr = binOp.getRHS();
2978 
2979  // Handle tightly binding binary operators.
2980  if (binOp.getKind() != AffineExprKind::Add) {
2981  if (enclosingTightness == BindingStrength::Strong)
2982  os << '(';
2983 
2984  // Pretty print multiplication with -1.
2985  auto rhsConst = dyn_cast<AffineConstantExpr>(rhsExpr);
2986  if (rhsConst && binOp.getKind() == AffineExprKind::Mul &&
2987  rhsConst.getValue() == -1) {
2988  os << "-";
2989  printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
2990  if (enclosingTightness == BindingStrength::Strong)
2991  os << ')';
2992  return;
2993  }
2994 
2995  printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
2996 
2997  os << binopSpelling;
2998  printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName);
2999 
3000  if (enclosingTightness == BindingStrength::Strong)
3001  os << ')';
3002  return;
3003  }
3004 
3005  // Print out special "pretty" forms for add.
3006  if (enclosingTightness == BindingStrength::Strong)
3007  os << '(';
3008 
3009  // Pretty print addition to a product that has a negative operand as a
3010  // subtraction.
3011  if (auto rhs = dyn_cast<AffineBinaryOpExpr>(rhsExpr)) {
3012  if (rhs.getKind() == AffineExprKind::Mul) {
3013  AffineExpr rrhsExpr = rhs.getRHS();
3014  if (auto rrhs = dyn_cast<AffineConstantExpr>(rrhsExpr)) {
3015  if (rrhs.getValue() == -1) {
3016  printAffineExprInternal(lhsExpr, BindingStrength::Weak,
3017  printValueName);
3018  os << " - ";
3019  if (rhs.getLHS().getKind() == AffineExprKind::Add) {
3020  printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
3021  printValueName);
3022  } else {
3023  printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak,
3024  printValueName);
3025  }
3026 
3027  if (enclosingTightness == BindingStrength::Strong)
3028  os << ')';
3029  return;
3030  }
3031 
3032  if (rrhs.getValue() < -1) {
3033  printAffineExprInternal(lhsExpr, BindingStrength::Weak,
3034  printValueName);
3035  os << " - ";
3036  printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
3037  printValueName);
3038  os << " * " << -rrhs.getValue();
3039  if (enclosingTightness == BindingStrength::Strong)
3040  os << ')';
3041  return;
3042  }
3043  }
3044  }
3045  }
3046 
3047  // Pretty print addition to a negative number as a subtraction.
3048  if (auto rhsConst = dyn_cast<AffineConstantExpr>(rhsExpr)) {
3049  if (rhsConst.getValue() < 0) {
3050  printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
3051  os << " - " << -rhsConst.getValue();
3052  if (enclosingTightness == BindingStrength::Strong)
3053  os << ')';
3054  return;
3055  }
3056  }
3057 
3058  printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
3059 
3060  os << " + ";
3061  printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName);
3062 
3063  if (enclosingTightness == BindingStrength::Strong)
3064  os << ')';
3065 }
3066 
3068  printAffineExprInternal(expr, BindingStrength::Weak);
3069  isEq ? os << " == 0" : os << " >= 0";
3070 }
3071 
3073  // Dimension identifiers.
3074  os << '(';
3075  for (int i = 0; i < (int)map.getNumDims() - 1; ++i)
3076  os << 'd' << i << ", ";
3077  if (map.getNumDims() >= 1)
3078  os << 'd' << map.getNumDims() - 1;
3079  os << ')';
3080 
3081  // Symbolic identifiers.
3082  if (map.getNumSymbols() != 0) {
3083  os << '[';
3084  for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i)
3085  os << 's' << i << ", ";
3086  if (map.getNumSymbols() >= 1)
3087  os << 's' << map.getNumSymbols() - 1;
3088  os << ']';
3089  }
3090 
3091  // Result affine expressions.
3092  os << " -> (";
3093  interleaveComma(map.getResults(),
3094  [&](AffineExpr expr) { printAffineExpr(expr); });
3095  os << ')';
3096 }
3097 
3099  // Dimension identifiers.
3100  os << '(';
3101  for (unsigned i = 1; i < set.getNumDims(); ++i)
3102  os << 'd' << i - 1 << ", ";
3103  if (set.getNumDims() >= 1)
3104  os << 'd' << set.getNumDims() - 1;
3105  os << ')';
3106 
3107  // Symbolic identifiers.
3108  if (set.getNumSymbols() != 0) {
3109  os << '[';
3110  for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i)
3111  os << 's' << i << ", ";
3112  if (set.getNumSymbols() >= 1)
3113  os << 's' << set.getNumSymbols() - 1;
3114  os << ']';
3115  }
3116 
3117  // Print constraints.
3118  os << " : (";
3119  int numConstraints = set.getNumConstraints();
3120  for (int i = 1; i < numConstraints; ++i) {
3121  printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1));
3122  os << ", ";
3123  }
3124  if (numConstraints >= 1)
3125  printAffineConstraint(set.getConstraint(numConstraints - 1),
3126  set.isEq(numConstraints - 1));
3127  os << ')';
3128 }
3129 
3130 //===----------------------------------------------------------------------===//
3131 // OperationPrinter
3132 //===----------------------------------------------------------------------===//
3133 
3134 namespace {
3135 /// This class contains the logic for printing operations, regions, and blocks.
3136 class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
3137 public:
3138  using Impl = AsmPrinter::Impl;
3139  using Impl::printType;
3140 
3141  explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state)
3142  : Impl(os, state), OpAsmPrinter(static_cast<Impl &>(*this)) {}
3143 
3144  /// Print the given top-level operation.
3145  void printTopLevelOperation(Operation *op);
3146 
3147  /// Print the given operation, including its left-hand side and its right-hand
3148  /// side, with its indent and location.
3149  void printFullOpWithIndentAndLoc(Operation *op);
3150  /// Print the given operation, including its left-hand side and its right-hand
3151  /// side, but not including indentation and location.
3152  void printFullOp(Operation *op);
3153  /// Print the right-hand size of the given operation in the custom or generic
3154  /// form.
3155  void printCustomOrGenericOp(Operation *op) override;
3156  /// Print the right-hand side of the given operation in the generic form.
3157  void printGenericOp(Operation *op, bool printOpName) override;
3158 
3159  /// Print the name of the given block.
3160  void printBlockName(Block *block);
3161 
3162  /// Print the given block. If 'printBlockArgs' is false, the arguments of the
3163  /// block are not printed. If 'printBlockTerminator' is false, the terminator
3164  /// operation of the block is not printed.
3165  void print(Block *block, bool printBlockArgs = true,
3166  bool printBlockTerminator = true);
3167 
3168  /// Print the ID of the given value, optionally with its result number.
3169  void printValueID(Value value, bool printResultNo = true,
3170  raw_ostream *streamOverride = nullptr) const;
3171 
3172  /// Print the ID of the given operation.
3173  void printOperationID(Operation *op,
3174  raw_ostream *streamOverride = nullptr) const;
3175 
3176  //===--------------------------------------------------------------------===//
3177  // OpAsmPrinter methods
3178  //===--------------------------------------------------------------------===//
3179 
3180  /// Print a loc(...) specifier if printing debug info is enabled. Locations
3181  /// may be deferred with an alias.
3182  void printOptionalLocationSpecifier(Location loc) override {
3183  printTrailingLocation(loc);
3184  }
3185 
3186  /// Print a newline and indent the printer to the start of the current
3187  /// operation.
3188  void printNewline() override {
3189  os << newLine;
3190  os.indent(currentIndent);
3191  }
3192 
3193  /// Increase indentation.
3194  void increaseIndent() override { currentIndent += indentWidth; }
3195 
3196  /// Decrease indentation.
3197  void decreaseIndent() override { currentIndent -= indentWidth; }
3198 
3199  /// Print a block argument in the usual format of:
3200  /// %ssaName : type {attr1=42} loc("here")
3201  /// where location printing is controlled by the standard internal option.
3202  /// You may pass omitType=true to not print a type, and pass an empty
3203  /// attribute list if you don't care for attributes.
3204  void printRegionArgument(BlockArgument arg,
3205  ArrayRef<NamedAttribute> argAttrs = {},
3206  bool omitType = false) override;
3207 
3208  /// Print the ID for the given value.
3209  void printOperand(Value value) override { printValueID(value); }
3210  void printOperand(Value value, raw_ostream &os) override {
3211  printValueID(value, /*printResultNo=*/true, &os);
3212  }
3213 
3214  /// Print an optional attribute dictionary with a given set of elided values.
3215  void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
3216  ArrayRef<StringRef> elidedAttrs = {}) override {
3217  Impl::printOptionalAttrDict(attrs, elidedAttrs);
3218  }
3219  void printOptionalAttrDictWithKeyword(
3221  ArrayRef<StringRef> elidedAttrs = {}) override {
3222  Impl::printOptionalAttrDict(attrs, elidedAttrs,
3223  /*withKeyword=*/true);
3224  }
3225 
3226  /// Print the given successor.
3227  void printSuccessor(Block *successor) override;
3228 
3229  /// Print an operation successor with the operands used for the block
3230  /// arguments.
3231  void printSuccessorAndUseList(Block *successor,
3232  ValueRange succOperands) override;
3233 
3234  /// Print the given region.
3235  void printRegion(Region &region, bool printEntryBlockArgs,
3236  bool printBlockTerminators, bool printEmptyBlock) override;
3237 
3238  /// Renumber the arguments for the specified region to the same names as the
3239  /// SSA values in namesToUse. This may only be used for IsolatedFromAbove
3240  /// operations. If any entry in namesToUse is null, the corresponding
3241  /// argument name is left alone.
3242  void shadowRegionArgs(Region &region, ValueRange namesToUse) override {
3243  state.getSSANameState().shadowRegionArgs(region, namesToUse);
3244  }
3245 
3246  /// Print the given affine map with the symbol and dimension operands printed
3247  /// inline with the map.
3248  void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
3249  ValueRange operands) override;
3250 
3251  /// Print the given affine expression with the symbol and dimension operands
3252  /// printed inline with the expression.
3253  void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
3254  ValueRange symOperands) override;
3255 
3256  /// Print users of this operation or id of this operation if it has no result.
3257  void printUsersComment(Operation *op);
3258 
3259  /// Print users of this block arg.
3260  void printUsersComment(BlockArgument arg);
3261 
3262  /// Print the users of a value.
3263  void printValueUsers(Value value);
3264 
3265  /// Print either the ids of the result values or the id of the operation if
3266  /// the operation has no results.
3267  void printUserIDs(Operation *user, bool prefixComma = false);
3268 
3269 private:
3270  /// This class represents a resource builder implementation for the MLIR
3271  /// textual assembly format.
3272  class ResourceBuilder : public AsmResourceBuilder {
3273  public:
3274  using ValueFn = function_ref<void(raw_ostream &)>;
3275  using PrintFn = function_ref<void(StringRef, ValueFn)>;
3276 
3277  ResourceBuilder(PrintFn printFn) : printFn(printFn) {}
3278  ~ResourceBuilder() override = default;
3279 
3280  void buildBool(StringRef key, bool data) final {
3281  printFn(key, [&](raw_ostream &os) { os << (data ? "true" : "false"); });
3282  }
3283 
3284  void buildString(StringRef key, StringRef data) final {
3285  printFn(key, [&](raw_ostream &os) {
3286  os << "\"";
3287  llvm::printEscapedString(data, os);
3288  os << "\"";
3289  });
3290  }
3291 
3292  void buildBlob(StringRef key, ArrayRef<char> data,
3293  uint32_t dataAlignment) final {
3294  printFn(key, [&](raw_ostream &os) {
3295  // Store the blob in a hex string containing the alignment and the data.
3296  llvm::support::ulittle32_t dataAlignmentLE(dataAlignment);
3297  os << "\"0x"
3298  << llvm::toHex(StringRef(reinterpret_cast<char *>(&dataAlignmentLE),
3299  sizeof(dataAlignment)))
3300  << llvm::toHex(StringRef(data.data(), data.size())) << "\"";
3301  });
3302  }
3303 
3304  private:
3305  PrintFn printFn;
3306  };
3307 
3308  /// Print the metadata dictionary for the file, eliding it if it is empty.
3309  void printFileMetadataDictionary(Operation *op);
3310 
3311  /// Print the resource sections for the file metadata dictionary.
3312  /// `checkAddMetadataDict` is used to indicate that metadata is going to be
3313  /// added, and the file metadata dictionary should be started if it hasn't
3314  /// yet.
3315  void printResourceFileMetadata(function_ref<void()> checkAddMetadataDict,
3316  Operation *op);
3317 
3318  // Contains the stack of default dialects to use when printing regions.
3319  // A new dialect is pushed to the stack before parsing regions nested under an
3320  // operation implementing `OpAsmOpInterface`, and popped when done. At the
3321  // top-level we start with "builtin" as the default, so that the top-level
3322  // `module` operation prints as-is.
3323  SmallVector<StringRef> defaultDialectStack{"builtin"};
3324 
3325  /// The number of spaces used for indenting nested operations.
3326  const static unsigned indentWidth = 2;
3327 
3328  // This is the current indentation level for nested structures.
3329  unsigned currentIndent = 0;
3330 };
3331 } // namespace
3332 
3333 void OperationPrinter::printTopLevelOperation(Operation *op) {
3334  // Output the aliases at the top level that can't be deferred.
3335  state.getAliasState().printNonDeferredAliases(*this, newLine);
3336 
3337  // Print the module.
3338  printFullOpWithIndentAndLoc(op);
3339  os << newLine;
3340 
3341  // Output the aliases at the top level that can be deferred.
3342  state.getAliasState().printDeferredAliases(*this, newLine);
3343 
3344  // Output any file level metadata.
3345  printFileMetadataDictionary(op);
3346 }
3347 
3348 void OperationPrinter::printFileMetadataDictionary(Operation *op) {
3349  bool sawMetadataEntry = false;
3350  auto checkAddMetadataDict = [&] {
3351  if (!std::exchange(sawMetadataEntry, true))
3352  os << newLine << "{-#" << newLine;
3353  };
3354 
3355  // Add the various types of metadata.
3356  printResourceFileMetadata(checkAddMetadataDict, op);
3357 
3358  // If the file dictionary exists, close it.
3359  if (sawMetadataEntry)
3360  os << newLine << "#-}" << newLine;
3361 }
3362 
3363 void OperationPrinter::printResourceFileMetadata(
3364  function_ref<void()> checkAddMetadataDict, Operation *op) {
3365  // Functor used to add data entries to the file metadata dictionary.
3366  bool hadResource = false;
3367  bool needResourceComma = false;
3368  bool needEntryComma = false;
3369  auto processProvider = [&](StringRef dictName, StringRef name, auto &provider,
3370  auto &&...providerArgs) {
3371  bool hadEntry = false;
3372  auto printFn = [&](StringRef key, ResourceBuilder::ValueFn valueFn) {
3373  checkAddMetadataDict();
3374 
3375  auto printFormatting = [&]() {
3376  // Emit the top-level resource entry if we haven't yet.
3377  if (!std::exchange(hadResource, true)) {
3378  if (needResourceComma)
3379  os << "," << newLine;
3380  os << " " << dictName << "_resources: {" << newLine;
3381  }
3382  // Emit the parent resource entry if we haven't yet.
3383  if (!std::exchange(hadEntry, true)) {
3384  if (needEntryComma)
3385  os << "," << newLine;
3386  os << " " << name << ": {" << newLine;
3387  } else {
3388  os << "," << newLine;
3389  }
3390  };
3391 
3392  std::optional<uint64_t> charLimit =
3393  printerFlags.getLargeResourceStringLimit();
3394  if (charLimit.has_value()) {
3395  std::string resourceStr;
3396  llvm::raw_string_ostream ss(resourceStr);
3397  valueFn(ss);
3398 
3399  // Only print entry if it's string is small enough
3400  if (resourceStr.size() > charLimit.value())
3401  return;
3402 
3403  printFormatting();
3404  os << " " << key << ": " << resourceStr;
3405  } else {
3406  printFormatting();
3407  os << " " << key << ": ";
3408  valueFn(os);
3409  }
3410  };
3411  ResourceBuilder entryBuilder(printFn);
3412  provider.buildResources(op, providerArgs..., entryBuilder);
3413 
3414  needEntryComma |= hadEntry;
3415  if (hadEntry)
3416  os << newLine << " }";
3417  };
3418 
3419  // Print the `dialect_resources` section if we have any dialects with
3420  // resources.
3421  for (const OpAsmDialectInterface &interface : state.getDialectInterfaces()) {
3422  auto &dialectResources = state.getDialectResources();
3423  StringRef name = interface.getDialect()->getNamespace();
3424  auto it = dialectResources.find(interface.getDialect());
3425  if (it != dialectResources.end())
3426  processProvider("dialect", name, interface, it->second);
3427  else
3428  processProvider("dialect", name, interface,
3430  }
3431  if (hadResource)
3432  os << newLine << " }";
3433 
3434  // Print the `external_resources` section if we have any external clients with
3435  // resources.
3436  needEntryComma = false;
3437  needResourceComma = hadResource;
3438  hadResource = false;
3439  for (const auto &printer : state.getResourcePrinters())
3440  processProvider("external", printer.getName(), printer);
3441  if (hadResource)
3442  os << newLine << " }";
3443 }
3444 
3445 /// Print a block argument in the usual format of:
3446 /// %ssaName : type {attr1=42} loc("here")
3447 /// where location printing is controlled by the standard internal option.
3448 /// You may pass omitType=true to not print a type, and pass an empty
3449 /// attribute list if you don't care for attributes.
3450 void OperationPrinter::printRegionArgument(BlockArgument arg,
3451  ArrayRef<NamedAttribute> argAttrs,
3452  bool omitType) {
3453  printOperand(arg);
3454  if (!omitType) {
3455  os << ": ";
3456  printType(arg.getType());
3457  }
3458  printOptionalAttrDict(argAttrs);
3459  // TODO: We should allow location aliases on block arguments.
3460  printTrailingLocation(arg.getLoc(), /*allowAlias*/ false);
3461 }
3462 
3463 void OperationPrinter::printFullOpWithIndentAndLoc(Operation *op) {
3464  // Track the location of this operation.
3465  state.registerOperationLocation(op, newLine.curLine, currentIndent);
3466 
3467  os.indent(currentIndent);
3468  printFullOp(op);
3469  printTrailingLocation(op->getLoc());
3470  if (printerFlags.shouldPrintValueUsers())
3471  printUsersComment(op);
3472 }
3473 
3474 void OperationPrinter::printFullOp(Operation *op) {
3475  if (size_t numResults = op->getNumResults()) {
3476  auto printResultGroup = [&](size_t resultNo, size_t resultCount) {
3477  printValueID(op->getResult(resultNo), /*printResultNo=*/false);
3478  if (resultCount > 1)
3479  os << ':' << resultCount;
3480  };
3481 
3482  // Check to see if this operation has multiple result groups.
3483  ArrayRef<int> resultGroups = state.getSSANameState().getOpResultGroups(op);
3484  if (!resultGroups.empty()) {
3485  // Interleave the groups excluding the last one, this one will be handled
3486  // separately.
3487  interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) {
3488  printResultGroup(resultGroups[i],
3489  resultGroups[i + 1] - resultGroups[i]);
3490  });
3491  os << ", ";
3492  printResultGroup(resultGroups.back(), numResults - resultGroups.back());
3493 
3494  } else {
3495  printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults);
3496  }
3497 
3498  os << " = ";
3499  }
3500 
3501  printCustomOrGenericOp(op);
3502 }
3503 
3504 void OperationPrinter::printUsersComment(Operation *op) {
3505  unsigned numResults = op->getNumResults();
3506  if (!numResults && op->getNumOperands()) {
3507  os << " // id: ";
3508  printOperationID(op);
3509  } else if (numResults && op->use_empty()) {
3510  os << " // unused";
3511  } else if (numResults && !op->use_empty()) {
3512  // Print "user" if the operation has one result used to compute one other
3513  // result, or is used in one operation with no result.
3514  unsigned usedInNResults = 0;
3515  unsigned usedInNOperations = 0;
3517  for (Operation *user : op->getUsers()) {
3518  if (userSet.insert(user).second) {
3519  ++usedInNOperations;
3520  usedInNResults += user->getNumResults();
3521  }
3522  }
3523 
3524  // We already know that users is not empty.
3525  bool exactlyOneUniqueUse =
3526  usedInNResults <= 1 && usedInNOperations <= 1 && numResults == 1;
3527  os << " // " << (exactlyOneUniqueUse ? "user" : "users") << ": ";
3528  bool shouldPrintBrackets = numResults > 1;
3529  auto printOpResult = [&](OpResult opResult) {
3530  if (shouldPrintBrackets)
3531  os << "(";
3532  printValueUsers(opResult);
3533  if (shouldPrintBrackets)
3534  os << ")";
3535  };
3536 
3537  interleaveComma(op->getResults(), printOpResult);
3538  }
3539 }
3540 
3541 void OperationPrinter::printUsersComment(BlockArgument arg) {
3542  os << "// ";
3543  printValueID(arg);
3544  if (arg.use_empty()) {
3545  os << " is unused";
3546  } else {
3547  os << " is used by ";
3548  printValueUsers(arg);
3549  }
3550  os << newLine;
3551 }
3552 
3553 void OperationPrinter::printValueUsers(Value value) {
3554  if (value.use_empty())
3555  os << "unused";
3556 
3557  // One value might be used as the operand of an operation more than once.
3558  // Only print the operations results once in that case.
3560  for (auto [index, user] : enumerate(value.getUsers())) {
3561  if (userSet.insert(user).second)
3562  printUserIDs(user, index);
3563  }
3564 }
3565 
3566 void OperationPrinter::printUserIDs(Operation *user, bool prefixComma) {
3567  if (prefixComma)
3568  os << ", ";
3569 
3570  if (!user->getNumResults()) {
3571  printOperationID(user);
3572  } else {
3573  interleaveComma(user->getResults(),
3574  [this](Value result) { printValueID(result); });
3575  }
3576 }
3577 
3578 void OperationPrinter::printCustomOrGenericOp(Operation *op) {
3579  // If requested, always print the generic form.
3580  if (!printerFlags.shouldPrintGenericOpForm()) {
3581  // Check to see if this is a known operation. If so, use the registered
3582  // custom printer hook.
3583  if (auto opInfo = op->getRegisteredInfo()) {
3584  opInfo->printAssembly(op, *this, defaultDialectStack.back());
3585  return;
3586  }
3587  // Otherwise try to dispatch to the dialect, if available.
3588  if (Dialect *dialect = op->getDialect()) {
3589  if (auto opPrinter = dialect->getOperationPrinter(op)) {
3590  // Print the op name first.
3591  StringRef name = op->getName().getStringRef();
3592  // Only drop the default dialect prefix when it cannot lead to
3593  // ambiguities.
3594  if (name.count('.') == 1)
3595  name.consume_front((defaultDialectStack.back() + ".").str());
3596  os << name;
3597 
3598  // Print the rest of the op now.
3599  opPrinter(op, *this);
3600  return;
3601  }
3602  }
3603  }
3604 
3605  // Otherwise print with the generic assembly form.
3606  printGenericOp(op, /*printOpName=*/true);
3607 }
3608 
3609 void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
3610  if (printOpName)
3611  printEscapedString(op->getName().getStringRef());
3612  os << '(';
3613  interleaveComma(op->getOperands(), [&](Value value) { printValueID(value); });
3614  os << ')';
3615 
3616  // For terminators, print the list of successors and their operands.
3617  if (op->getNumSuccessors() != 0) {
3618  os << '[';
3619  interleaveComma(op->getSuccessors(),
3620  [&](Block *successor) { printBlockName(successor); });
3621  os << ']';
3622  }
3623 
3624  // Print the properties.
3625  if (Attribute prop = op->getPropertiesAsAttribute()) {
3626  os << " <";
3627  Impl::printAttribute(prop);
3628  os << '>';
3629  }
3630 
3631  // Print regions.
3632  if (op->getNumRegions() != 0) {
3633  os << " (";
3634  interleaveComma(op->getRegions(), [&](Region &region) {
3635  printRegion(region, /*printEntryBlockArgs=*/true,
3636  /*printBlockTerminators=*/true, /*printEmptyBlock=*/true);
3637  });
3638  os << ')';
3639  }
3640 
3641  printOptionalAttrDict(op->getPropertiesStorage()
3642  ? llvm::to_vector(op->getDiscardableAttrs())
3643  : op->getAttrs());
3644 
3645  // Print the type signature of the operation.
3646  os << " : ";
3647  printFunctionalType(op);
3648 }
3649 
3650 void OperationPrinter::printBlockName(Block *block) {
3651  os << state.getSSANameState().getBlockInfo(block).name;
3652 }
3653 
3654 void OperationPrinter::print(Block *block, bool printBlockArgs,
3655  bool printBlockTerminator) {
3656  // Print the block label and argument list if requested.
3657  if (printBlockArgs) {
3658  os.indent(currentIndent);
3659  printBlockName(block);
3660 
3661  // Print the argument list if non-empty.
3662  if (!block->args_empty()) {
3663  os << '(';
3664  interleaveComma(block->getArguments(), [&](BlockArgument arg) {
3665  printValueID(arg);
3666  os << ": ";
3667  printType(arg.getType());
3668  // TODO: We should allow location aliases on block arguments.
3669  printTrailingLocation(arg.getLoc(), /*allowAlias*/ false);
3670  });
3671  os << ')';
3672  }
3673  os << ':';
3674 
3675  // Print out some context information about the predecessors of this block.
3676  if (!block->getParent()) {
3677  os << " // block is not in a region!";
3678  } else if (block->hasNoPredecessors()) {
3679  if (!block->isEntryBlock())
3680  os << " // no predecessors";
3681  } else if (auto *pred = block->getSinglePredecessor()) {
3682  os << " // pred: ";
3683  printBlockName(pred);
3684  } else {
3685  // We want to print the predecessors in a stable order, not in
3686  // whatever order the use-list is in, so gather and sort them.
3687  SmallVector<BlockInfo, 4> predIDs;
3688  for (auto *pred : block->getPredecessors())
3689  predIDs.push_back(state.getSSANameState().getBlockInfo(pred));
3690  llvm::sort(predIDs, [](BlockInfo lhs, BlockInfo rhs) {
3691  return lhs.ordering < rhs.ordering;
3692  });
3693 
3694  os << " // " << predIDs.size() << " preds: ";
3695 
3696  interleaveComma(predIDs, [&](BlockInfo pred) { os << pred.name; });
3697  }
3698  os << newLine;
3699  }
3700 
3701  currentIndent += indentWidth;
3702 
3703  if (printerFlags.shouldPrintValueUsers()) {
3704  for (BlockArgument arg : block->getArguments()) {
3705  os.indent(currentIndent);
3706  printUsersComment(arg);
3707  }
3708  }
3709 
3710  bool hasTerminator =
3711  !block->empty() && block->back().hasTrait<OpTrait::IsTerminator>();
3712  auto range = llvm::make_range(
3713  block->begin(),
3714  std::prev(block->end(),
3715  (!hasTerminator || printBlockTerminator) ? 0 : 1));
3716  for (auto &op : range) {
3717  printFullOpWithIndentAndLoc(&op);
3718  os << newLine;
3719  }
3720  currentIndent -= indentWidth;
3721 }
3722 
3723 void OperationPrinter::printValueID(Value value, bool printResultNo,
3724  raw_ostream *streamOverride) const {
3725  state.getSSANameState().printValueID(value, printResultNo,
3726  streamOverride ? *streamOverride : os);
3727 }
3728 
3729 void OperationPrinter::printOperationID(Operation *op,
3730  raw_ostream *streamOverride) const {
3731  state.getSSANameState().printOperationID(op, streamOverride ? *streamOverride
3732  : os);
3733 }
3734 
3735 void OperationPrinter::printSuccessor(Block *successor) {
3736  printBlockName(successor);
3737 }
3738 
3739 void OperationPrinter::printSuccessorAndUseList(Block *successor,
3740  ValueRange succOperands) {
3741  printBlockName(successor);
3742  if (succOperands.empty())
3743  return;
3744 
3745  os << '(';
3746  interleaveComma(succOperands,
3747  [this](Value operand) { printValueID(operand); });
3748  os << " : ";
3749  interleaveComma(succOperands,
3750  [this](Value operand) { printType(operand.getType()); });
3751  os << ')';
3752 }
3753 
3754 void OperationPrinter::printRegion(Region &region, bool printEntryBlockArgs,
3755  bool printBlockTerminators,
3756  bool printEmptyBlock) {
3757  if (printerFlags.shouldSkipRegions()) {
3758  os << "{...}";
3759  return;
3760  }
3761  os << "{" << newLine;
3762  if (!region.empty()) {
3763  auto restoreDefaultDialect =
3764  llvm::make_scope_exit([&]() { defaultDialectStack.pop_back(); });
3765  if (auto iface = dyn_cast<OpAsmOpInterface>(region.getParentOp()))
3766  defaultDialectStack.push_back(iface.getDefaultDialect());
3767  else
3768  defaultDialectStack.push_back("");
3769 
3770  auto *entryBlock = &region.front();
3771  // Force printing the block header if printEmptyBlock is set and the block
3772  // is empty or if printEntryBlockArgs is set and there are arguments to
3773  // print.
3774  bool shouldAlwaysPrintBlockHeader =
3775  (printEmptyBlock && entryBlock->empty()) ||
3776  (printEntryBlockArgs && entryBlock->getNumArguments() != 0);
3777  print(entryBlock, shouldAlwaysPrintBlockHeader, printBlockTerminators);
3778  for (auto &b : llvm::drop_begin(region.getBlocks(), 1))
3779  print(&b);
3780  }
3781  os.indent(currentIndent) << "}";
3782 }
3783 
3784 void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr,
3785  ValueRange operands) {
3786  if (!mapAttr) {
3787  os << "<<NULL AFFINE MAP>>";
3788  return;
3789  }
3790  AffineMap map = mapAttr.getValue();
3791  unsigned numDims = map.getNumDims();
3792  auto printValueName = [&](unsigned pos, bool isSymbol) {
3793  unsigned index = isSymbol ? numDims + pos : pos;
3794  assert(index < operands.size());
3795  if (isSymbol)
3796  os << "symbol(";
3797  printValueID(operands[index]);
3798  if (isSymbol)
3799  os << ')';
3800  };
3801 
3802  interleaveComma(map.getResults(), [&](AffineExpr expr) {
3803  printAffineExpr(expr, printValueName);
3804  });
3805 }
3806 
3807 void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr,
3808  ValueRange dimOperands,
3809  ValueRange symOperands) {
3810  auto printValueName = [&](unsigned pos, bool isSymbol) {
3811  if (!isSymbol)
3812  return printValueID(dimOperands[pos]);
3813  os << "symbol(";
3814  printValueID(symOperands[pos]);
3815  os << ')';
3816  };
3817  printAffineExpr(expr, printValueName);
3818 }
3819 
3820 //===----------------------------------------------------------------------===//
3821 // print and dump methods
3822 //===----------------------------------------------------------------------===//
3823 
3824 void Attribute::print(raw_ostream &os, bool elideType) const {
3825  if (!*this) {
3826  os << "<<NULL ATTRIBUTE>>";
3827  return;
3828  }
3829 
3830  AsmState state(getContext());
3831  print(os, state, elideType);
3832 }
3833 void Attribute::print(raw_ostream &os, AsmState &state, bool elideType) const {
3834  using AttrTypeElision = AsmPrinter::Impl::AttrTypeElision;
3835  AsmPrinter::Impl(os, state.getImpl())
3836  .printAttribute(*this, elideType ? AttrTypeElision::Must
3837  : AttrTypeElision::Never);
3838 }
3839 
3840 void Attribute::dump() const {
3841  print(llvm::errs());
3842  llvm::errs() << "\n";
3843 }
3844 
3845 void Attribute::printStripped(raw_ostream &os, AsmState &state) const {
3846  if (!*this) {
3847  os << "<<NULL ATTRIBUTE>>";
3848  return;
3849  }
3850 
3851  AsmPrinter::Impl subPrinter(os, state.getImpl());
3852  if (succeeded(subPrinter.printAlias(*this)))
3853  return;
3854 
3855  auto &dialect = this->getDialect();
3856  uint64_t posPrior = os.tell();
3857  DialectAsmPrinter printer(subPrinter);
3858  dialect.printAttribute(*this, printer);
3859  if (posPrior != os.tell())
3860  return;
3861 
3862  // Fallback to printing with prefix if the above failed to write anything
3863  // to the output stream.
3864  print(os, state);
3865 }
3866 void Attribute::printStripped(raw_ostream &os) const {
3867  if (!*this) {
3868  os << "<<NULL ATTRIBUTE>>";
3869  return;
3870  }
3871 
3872  AsmState state(getContext());
3873  printStripped(os, state);
3874 }
3875 
3876 void Type::print(raw_ostream &os) const {
3877  if (!*this) {
3878  os << "<<NULL TYPE>>";
3879  return;
3880  }
3881 
3882  AsmState state(getContext());
3883  print(os, state);
3884 }
3885 void Type::print(raw_ostream &os, AsmState &state) const {
3886  AsmPrinter::Impl(os, state.getImpl()).printType(*this);
3887 }
3888 
3889 void Type::dump() const {
3890  print(llvm::errs());
3891  llvm::errs() << "\n";
3892 }
3893 
3894 void AffineMap::dump() const {
3895  print(llvm::errs());
3896  llvm::errs() << "\n";
3897 }
3898 
3899 void IntegerSet::dump() const {
3900  print(llvm::errs());
3901  llvm::errs() << "\n";
3902 }
3903 
3904 void AffineExpr::print(raw_ostream &os) const {
3905  if (!expr) {
3906  os << "<<NULL AFFINE EXPR>>";
3907  return;
3908  }
3909  AsmState state(getContext());
3910  AsmPrinter::Impl(os, state.getImpl()).printAffineExpr(*this);
3911 }
3912 
3913 void AffineExpr::dump() const {
3914  print(llvm::errs());
3915  llvm::errs() << "\n";
3916 }
3917 
3918 void AffineMap::print(raw_ostream &os) const {
3919  if (!map) {
3920  os << "<<NULL AFFINE MAP>>";
3921  return;
3922  }
3923  AsmState state(getContext());
3924  AsmPrinter::Impl(os, state.getImpl()).printAffineMap(*this);
3925 }
3926 
3927 void IntegerSet::print(raw_ostream &os) const {
3928  AsmState state(getContext());
3929  AsmPrinter::Impl(os, state.getImpl()).printIntegerSet(*this);
3930 }
3931 
3932 void Value::print(raw_ostream &os) const { print(os, OpPrintingFlags()); }
3933 void Value::print(raw_ostream &os, const OpPrintingFlags &flags) const {
3934  if (!impl) {
3935  os << "<<NULL VALUE>>";
3936  return;
3937  }
3938 
3939  if (auto *op = getDefiningOp())
3940  return op->print(os, flags);
3941  // TODO: Improve BlockArgument print'ing.
3942  BlockArgument arg = llvm::cast<BlockArgument>(*this);
3943  os << "<block argument> of type '" << arg.getType()
3944  << "' at index: " << arg.getArgNumber();
3945 }
3946 void Value::print(raw_ostream &os, AsmState &state) const {
3947  if (!impl) {
3948  os << "<<NULL VALUE>>";
3949  return;
3950  }
3951 
3952  if (auto *op = getDefiningOp())
3953  return op->print(os, state);
3954 
3955  // TODO: Improve BlockArgument print'ing.
3956  BlockArgument arg = llvm::cast<BlockArgument>(*this);
3957  os << "<block argument> of type '" << arg.getType()
3958  << "' at index: " << arg.getArgNumber();
3959 }
3960 
3961 void Value::dump() const {
3962  print(llvm::errs());
3963  llvm::errs() << "\n";
3964 }
3965 
3966 void Value::printAsOperand(raw_ostream &os, AsmState &state) const {
3967  // TODO: This doesn't necessarily capture all potential cases.
3968  // Currently, region arguments can be shadowed when printing the main
3969  // operation. If the IR hasn't been printed, this will produce the old SSA
3970  // name and not the shadowed name.
3971  state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true,
3972  os);
3973 }
3974 
3975 static Operation *findParent(Operation *op, bool shouldUseLocalScope) {
3976  do {
3977  // If we are printing local scope, stop at the first operation that is
3978  // isolated from above.
3979  if (shouldUseLocalScope && op->hasTrait<OpTrait::IsIsolatedFromAbove>())
3980  break;
3981 
3982  // Otherwise, traverse up to the next parent.
3983  Operation *parentOp = op->getParentOp();
3984  if (!parentOp)
3985  break;
3986  op = parentOp;
3987  } while (true);
3988  return op;
3989 }
3990 
3991 void Value::printAsOperand(raw_ostream &os,
3992  const OpPrintingFlags &flags) const {
3993  Operation *op;
3994  if (auto result = llvm::dyn_cast<OpResult>(*this)) {
3995  op = result.getOwner();
3996  } else {
3997  op = llvm::cast<BlockArgument>(*this).getOwner()->getParentOp();
3998  if (!op) {
3999  os << "<<UNKNOWN SSA VALUE>>";
4000  return;
4001  }
4002  }
4003  op = findParent(op, flags.shouldUseLocalScope());
4004  AsmState state(op, flags);
4005  printAsOperand(os, state);
4006 }
4007 
4008 void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) {
4009  // Find the operation to number from based upon the provided flags.
4010  Operation *op = findParent(this, printerFlags.shouldUseLocalScope());
4011  AsmState state(op, printerFlags);
4012  print(os, state);
4013 }
4014 void Operation::print(raw_ostream &os, AsmState &state) {
4015  OperationPrinter printer(os, state.getImpl());
4016  if (!getParent() && !state.getPrinterFlags().shouldUseLocalScope()) {
4017  state.getImpl().initializeAliases(this);
4018  printer.printTopLevelOperation(this);
4019  } else {
4020  printer.printFullOpWithIndentAndLoc(this);
4021  }
4022 }
4023 
4025  print(llvm::errs(), OpPrintingFlags().useLocalScope());
4026  llvm::errs() << "\n";
4027 }
4028 
4030  print(llvm::errs(), OpPrintingFlags().useLocalScope().assumeVerified());
4031  llvm::errs() << "\n";
4032 }
4033 
4034 void Block::print(raw_ostream &os) {
4035  Operation *parentOp = getParentOp();
4036  if (!parentOp) {
4037  os << "<<UNLINKED BLOCK>>\n";
4038  return;
4039  }
4040  // Get the top-level op.
4041  while (auto *nextOp = parentOp->getParentOp())
4042  parentOp = nextOp;
4043 
4044  AsmState state(parentOp);
4045  print(os, state);
4046 }
4047 void Block::print(raw_ostream &os, AsmState &state) {
4048  OperationPrinter(os, state.getImpl()).print(this);
4049 }
4050 
4051 void Block::dump() { print(llvm::errs()); }
4052 
4053 /// Print out the name of the block without printing its body.
4054 void Block::printAsOperand(raw_ostream &os, bool printType) {
4055  Operation *parentOp = getParentOp();
4056  if (!parentOp) {
4057  os << "<<UNLINKED BLOCK>>\n";
4058  return;
4059  }
4060  AsmState state(parentOp);
4061  printAsOperand(os, state);
4062 }
4063 void Block::printAsOperand(raw_ostream &os, AsmState &state) {
4064  OperationPrinter printer(os, state.getImpl());
4065  printer.printBlockName(this);
4066 }
4067 
4068 raw_ostream &mlir::operator<<(raw_ostream &os, Block &block) {
4069  block.print(os);
4070  return os;
4071 }
4072 
4073 //===--------------------------------------------------------------------===//
4074 // Custom printers
4075 //===--------------------------------------------------------------------===//
4076 namespace mlir {
4077 
4079  ArrayRef<int64_t> dimensions) {
4080  if (dimensions.empty())
4081  printer << "[";
4082  printer.printDimensionList(dimensions);
4083  if (dimensions.empty())
4084  printer << "]";
4085 }
4086 
4087 ParseResult parseDimensionList(OpAsmParser &parser,
4088  DenseI64ArrayAttr &dimensions) {
4089  // Empty list case denoted by "[]".
4090  if (succeeded(parser.parseOptionalLSquare())) {
4091  if (failed(parser.parseRSquare())) {
4092  return parser.emitError(parser.getCurrentLocation())
4093  << "Failed parsing dimension list.";
4094  }
4095  dimensions =
4097  return success();
4098  }
4099 
4100  // Non-empty list case.
4101  SmallVector<int64_t> shapeArr;
4102  if (failed(parser.parseDimensionList(shapeArr, true, false))) {
4103  return parser.emitError(parser.getCurrentLocation())
4104  << "Failed parsing dimension list.";
4105  }
4106  if (shapeArr.empty()) {
4107  return parser.emitError(parser.getCurrentLocation())
4108  << "Failed parsing dimension list. Did you mean an empty list? It "
4109  "must be denoted by \"[]\".";
4110  }
4111  dimensions = DenseI64ArrayAttr::get(parser.getContext(), shapeArr);
4112  return success();
4113 }
4114 
4115 } // namespace mlir
static StringRef sanitizeIdentifier(StringRef name, SmallString< 16 > &buffer, StringRef allowedPunctChars="$._-", bool allowTrailingDigit=true)
Sanitize the given name such that it can be used as a valid identifier.
static void printSymbolReference(StringRef symbolRef, raw_ostream &os)
Print the given string as a symbol reference.
static void printFloatValue(const APFloat &apValue, raw_ostream &os, bool *printedHex=nullptr)
Print a floating point value in a way that the parser will be able to round-trip losslessly.
static llvm::ManagedStatic< AsmPrinterOptions > clOptions
Definition: AsmPrinter.cpp:206
static Operation * findParent(Operation *op, bool shouldUseLocalScope)
static void printKeywordOrString(StringRef keyword, raw_ostream &os)
Print the given string as a keyword, or a quoted and escaped string if it has any special or non-prin...
static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName)
Returns true if the given dialect symbol data is simple enough to print in the pretty form.
static void printDialectSymbol(raw_ostream &os, StringRef symPrefix, StringRef dialectName, StringRef symString)
Print the given dialect symbol to the stream.
#define DEBUG_TYPE
Definition: AsmPrinter.cpp:59
static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op, OpPrintingFlags printerFlags)
Verifies the operation and switches to generic op printing if verification fails.
MLIR_CRUNNERUTILS_EXPORT void printString(char const *s)
MLIR_CRUNNERUTILS_EXPORT void printNewline()
static void visit(Operation *op, DenseSet< Operation * > &visited)
Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s) connected to the given operation.
Definition: PDL.cpp:63
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static void printRegion(llvm::raw_ostream &os, Region *region, OpPrintingFlags &flags)
Definition: Unit.cpp:28
Base type for affine expression.
Definition: AffineExpr.h:68
AffineExprKind getKind() const
Return the classification for this type.
Definition: AffineExpr.cpp:35
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
unsigned getNumSymbols() const
Definition: AffineMap.cpp:398
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
This class represents an opaque handle to a dialect resource entry.
Dialect * getDialect() const
Return the dialect that owns the resource.
This class represents a single parsed resource entry.
Definition: AsmState.h:290
virtual FailureOr< AsmResourceBlob > parseAsBlob(BlobAllocatorFn allocator) const =0
Parse the resource entry represented by a binary blob.
virtual InFlightDiagnostic emitError() const =0
Emit an error at the location of this entry.
virtual AsmResourceEntryKind getKind() const =0
Return the kind of this value.
virtual FailureOr< std::string > parseAsString() const =0
Parse the resource entry represented by a human-readable string.
virtual FailureOr< bool > parseAsBool() const =0
Parse the resource entry represented by a boolean.
virtual StringRef getKey() const =0
Return the key of the resource entry.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ~AsmParser()
ParseResult parseTypeList(SmallVectorImpl< Type > &result)
Parse a type list.
Definition: AsmPrinter.cpp:78
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
Impl(raw_ostream &os, AsmStateImpl &state)
BindingStrength
This enum is used to represent the binding strength of the enclosing context that an AffineExprStorag...
Definition: AsmPrinter.cpp:510
void printHexString(StringRef str)
Print a hex string, wrapped with "".
void printDenseArrayAttr(DenseArrayAttr attr)
Print a dense array attribute.
void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex)
Print a dense elements attribute.
void printAttribute(Attribute attr, AttrTypeElision typeElision=AttrTypeElision::Never)
Print the given attribute or an alias.
void printDimensionList(ArrayRef< int64_t > shape)
OpPrintingFlags printerFlags
A set of flags to control the printer's behavior.
Definition: AsmPrinter.cpp:525
raw_ostream & os
The output stream for the printer.
Definition: AsmPrinter.cpp:519
void printResourceHandle(const AsmDialectResourceHandle &resource)
Print a reference to the given resource that is owned by the given dialect.
raw_ostream & getStream()
Returns the output stream of the printer.
Definition: AsmPrinter.cpp:413
LogicalResult printAlias(Attribute attr)
Print the alias for the given attribute, return failure if no alias could be printed.
void printDialectAttribute(Attribute attr)
void interleaveComma(const Container &c, UnaryFunctor eachFn) const
Definition: AsmPrinter.cpp:416
void printDialectType(Type type)
void printLocation(LocationAttr loc, bool allowAlias=false)
Print the given location to the stream.
AsmStateImpl & state
An underlying assembly printer state.
Definition: AsmPrinter.cpp:522
void printAffineMap(AffineMap map)
void printTrailingLocation(Location loc, bool allowAlias=true)
void printAffineExprInternal(AffineExpr expr, BindingStrength enclosingTightness, function_ref< void(unsigned, bool)> printValueName=nullptr)
void printEscapedString(StringRef str)
Print an escaped string, wrapped with "".
void printAffineExpr(AffineExpr expr, function_ref< void(unsigned, bool)> printValueName=nullptr)
void printDenseStringElementsAttr(DenseStringElementsAttr attr)
Print a dense string elements attribute.
void printAttributeImpl(Attribute attr, AttrTypeElision typeElision=AttrTypeElision::Never)
Print the given attribute without considering an alias.
void printAffineConstraint(AffineExpr expr, bool isEq)
void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr, bool allowHex)
Print a dense elements attribute.
AttrTypeElision
This enum describes the different kinds of elision for the type of an attribute when printing it.
Definition: AsmPrinter.cpp:422
@ May
The type may be elided when it matches the default used in the parser (for example i64 is the default...
@ Never
The type must not be elided.
@ Must
The type must be elided.
LogicalResult pushCyclicPrinting(const void *opaquePointer)
void printIntegerSet(IntegerSet set)
NewLineCounter newLine
A tracker for the number of new lines emitted during printing.
Definition: AsmPrinter.cpp:528
void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={}, bool withKeyword=false)
void printType(Type type)
Print the given type or an alias.
void printLocationInternal(LocationAttr loc, bool pretty=false, bool isTopLevel=false)
void printTypeImpl(Type type)
Print the given type.
void printNamedAttribute(NamedAttribute attr)
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
virtual LogicalResult printAlias(Attribute attr)
Print the alias for the given attribute, return failure if no alias could be printed.
virtual void popCyclicPrinting()
Removes the element that was last inserted with a successful call to pushCyclicPrinting.
virtual LogicalResult pushCyclicPrinting(const void *opaquePointer)
Pushes a new attribute or type in the form of a type erased pointer into an internal set.
virtual void printType(Type type)
virtual void printKeywordOrString(StringRef keyword)
Print the given string as a keyword, or a quoted and escaped string if it has any special or non-prin...
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
virtual void printString(StringRef string)
Print the given string as a quoted string, escaping any special or non-printable characters in it.
virtual void printAttribute(Attribute attr)
void printDimensionList(ArrayRef< int64_t > shape)
virtual ~AsmPrinter()
virtual raw_ostream & getStream() const
Return the raw output stream used by this printer.
virtual void printResourceHandle(const AsmDialectResourceHandle &resource)
Print a handle to the given dialect resource.
virtual void printFloat(const APFloat &value)
Print the given floating point value in a stabilized form that can be roundtripped through the IR.
This class is used to build resource entries for use by the printer.
Definition: AsmState.h:246
virtual void buildString(StringRef key, StringRef data)=0
Build a resource entry represented by the given human-readable string value.
virtual void buildBool(StringRef key, bool data)=0
Build a resource entry represented by the given bool.
virtual void buildBlob(StringRef key, ArrayRef< char > data, uint32_t dataAlignment)=0
Build a resource entry represented by the given binary blob data.
This class represents an instance of a resource parser.
Definition: AsmState.h:337
static std::unique_ptr< AsmResourcePrinter > fromCallable(StringRef name, CallableT &&printFn)
Return a resource printer implemented via the given callable, whose form should match that of buildRe...
Definition: AsmState.h:398
This class provides management for the lifetime of the state used when printing the IR.
Definition: AsmState.h:540
void attachResourcePrinter(std::unique_ptr< AsmResourcePrinter > printer)
Attach the given resource printer to the AsmState.
DenseMap< Dialect *, SetVector< AsmDialectResourceHandle > > & getDialectResources() const
Returns a map of dialect resources that were referenced when using this state to print IR.
void attachFallbackResourcePrinter(FallbackAsmResourceMap &map)
Attach resource printers to the AsmState for the fallback resources in the given map.
Definition: AsmState.h:586
const OpPrintingFlags & getPrinterFlags() const
Get the printer flags.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Dialect & getDialect() const
Get the dialect this attribute is registered to.
Definition: Attributes.h:76
bool hasTrait()
Returns true if the type was registered with a particular trait.
Definition: Attributes.h:110
const void * getAsOpaquePointer() const
Get an opaque pointer to the attribute.
Definition: Attributes.h:91
static Attribute getFromOpaquePointer(const void *ptr)
Construct an attribute from the opaque pointer representation.
Definition: Attributes.h:93
This class represents an argument of a Block.
Definition: Value.h:319
Location getLoc() const
Return the location for this argument.
Definition: Value.h:334
unsigned getArgNumber() const
Returns the number of this argument.
Definition: Value.h:331
Block represents an ordered list of Operations.
Definition: Block.h:33
bool empty()
Definition: Block.h:148
Operation & back()
Definition: Block.h:152
void printAsOperand(raw_ostream &os, bool printType=true)
Print out the name of the block without printing its body.
void print(raw_ostream &os)
BlockArgListType getArguments()
Definition: Block.h:87
iterator end()
Definition: Block.h:144
iterator begin()
Definition: Block.h:143
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition: Block.cpp:38
Operation * getParentOp()
Returns the closest surrounding operation that contains this block.
Definition: Block.cpp:33
An attribute that represents a reference to a dense vector or tensor object.
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine.
Definition: Diagnostics.h:155
~DialectAsmParser() override
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
~DialectAsmPrinter() override
A collection of dialect interfaces within a context, for a given concrete interface type.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition: Dialect.h:38
virtual void printAttribute(Attribute, DialectAsmPrinter &) const
Print an attribute registered to this dialect.
Definition: Dialect.h:99
virtual void printType(Type, DialectAsmPrinter &) const
Print a type registered to this dialect.
Definition: Dialect.h:107
An attribute that associates a referenced attribute with a unique identifier.
A fallback map containing external resources not explicitly handled by another parser/printer.
Definition: AsmState.h:419
std::vector< std::unique_ptr< AsmResourcePrinter > > getPrinters()
Build a set of resource printers to print the resources within this map.
An integer set representing a conjunction of one or more affine equalities and inequalities.
Definition: IntegerSet.h:44
unsigned getNumDims() const
Definition: IntegerSet.cpp:15
unsigned getNumConstraints() const
Definition: IntegerSet.cpp:21
AffineExpr getConstraint(unsigned idx) const
Definition: IntegerSet.cpp:45
bool isEq(unsigned idx) const
Returns true if the idx^th constraint is an equality, false if it is an inequality.
Definition: IntegerSet.cpp:55
unsigned getNumSymbols() const
Definition: IntegerSet.cpp:16
Location objects represent source location information in MLIR.
Definition: Location.h:31
T findInstanceOf()
Return an instance of the given location type if one is nested under the current location.
Definition: Location.h:44
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:221
virtual std::string getResourceKey(const AsmDialectResourceHandle &handle) const
Return a key to use for the given resource.
virtual LogicalResult parseResource(AsmParsedResourceEntry &entry) const
Hook for parsing resource entries.
Definition: AsmPrinter.cpp:131
AliasResult
Holds the result of getAlias hook call.
@ FinalAlias
An alias was provided and it should be used (no other hooks will be checked).
@ NoAlias
The object (type or attribute) is not supported by the hook and an alias was not provided.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
~OpAsmParser() override
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
Definition: AsmPrinter.cpp:95
~OpAsmPrinter() override
Set of flags used to control the behavior of the various IR print methods (e.g.
bool shouldElideElementsAttr(ElementsAttr attr) const
Return if the given ElementsAttr should be elided.
Definition: AsmPrinter.cpp:314
std::optional< int64_t > getLargeElementsAttrLimit() const
Return the size limit for printing large ElementsAttr.
Definition: AsmPrinter.cpp:329
bool shouldUseNameLocAsPrefix() const
Return if the printer should use NameLocs as prefixes when printing SSA IDs.
Definition: AsmPrinter.cpp:380
bool shouldAssumeVerified() const
Return if operation verification should be skipped.
Definition: AsmPrinter.cpp:362
OpPrintingFlags & printLargeElementsAttrWithHex(int64_t largeElementLimit=100)
Enables the printing of large element attributes with a hex string.
Definition: AsmPrinter.cpp:254
bool shouldUseLocalScope() const
Return if the printer should use local scope when dumping the IR.
Definition: AsmPrinter.cpp:367
bool shouldPrintDebugInfoPrettyForm() const
Return if debug information should be printed in the pretty form.
Definition: AsmPrinter.cpp:349
bool shouldPrintElementsAttrWithHex(ElementsAttr attr) const
Return if the given ElementsAttr should be printed as hex string.
Definition: AsmPrinter.cpp:321
bool shouldPrintUniqueSSAIDs() const
Return if printer should use unique SSA IDs.
Definition: AsmPrinter.cpp:375
bool shouldPrintValueUsers() const
Return if the printer should print users of values.
Definition: AsmPrinter.cpp:370
int64_t getLargeElementsAttrHexLimit() const
Return the size limit for printing large ElementsAttr as hex string.
Definition: AsmPrinter.cpp:334
bool shouldPrintGenericOpForm() const
Return if operations should be printed in the generic form.
Definition: AsmPrinter.cpp:354
OpPrintingFlags & elideLargeResourceString(int64_t largeResourceLimit=64)
Enables the elision of large resource strings by omitting them from the dialect_resources section.
Definition: AsmPrinter.cpp:260
bool shouldPrintDebugInfo() const
Return if debug information should be printed.
Definition: AsmPrinter.cpp:344
OpPrintingFlags & elideLargeElementsAttrs(int64_t largeElementLimit=16)
Enables the elision of large elements attributes by printing a lexically valid but otherwise meaningl...
Definition: AsmPrinter.cpp:248
OpPrintingFlags & printValueUsers(bool enable=true)
Print users of values as comments.
Definition: AsmPrinter.cpp:301
OpPrintingFlags & enableDebugInfo(bool enable=true, bool prettyForm=false)
Enable or disable printing of debug information (based on enable).
Definition: AsmPrinter.cpp:267
OpPrintingFlags()
Initialize the printing flags with the defaults supplied by the cl::opts above.
Definition: AsmPrinter.cpp:216
bool shouldSkipRegions() const
Return if regions should be skipped.
Definition: AsmPrinter.cpp:359
OpPrintingFlags & printGenericOpForm(bool enable=true)
Always print operations in the generic form.
Definition: AsmPrinter.cpp:275
OpPrintingFlags & useLocalScope(bool enable=true)
Use local scope when printing the operation.
Definition: AsmPrinter.cpp:295
std::optional< uint64_t > getLargeResourceStringLimit() const
Return the size limit in chars for printing large resources.
Definition: AsmPrinter.cpp:339
OpPrintingFlags & assumeVerified(bool enable=true)
Do not verify the operation when using custom operation printers.
Definition: AsmPrinter.cpp:287
OpPrintingFlags & skipRegions(bool skip=true)
Skip printing regions.
Definition: AsmPrinter.cpp:281
OpPrintingFlags & printUniqueSSAIDs(bool enable=true)
Print unique SSA ID numbers for values, block arguments and naming conflicts across all regions.
Definition: AsmPrinter.cpp:308
This is a value defined by a result of an operation.
Definition: Value.h:457
Operation * getOwner() const
Returns the operation that owns this result.
Definition: Value.h:466
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
Definition: OpDefinition.h:765
void dump() const
Definition: AsmPrinter.cpp:63
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
void printAssembly(Operation *op, OpAsmPrinter &p, StringRef defaultDialect) const
This hook implements the AsmPrinter for this operation.
void print(raw_ostream &os) const
Definition: AsmPrinter.cpp:61
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool use_empty()
Returns true if this operation has no uses.
Definition: Operation.h:853
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
unsigned getNumSuccessors()
Definition: Operation.h:707
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
void print(raw_ostream &os, const OpPrintingFlags &flags=std::nullopt)
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
std::optional< RegisteredOperationName > getRegisteredInfo()
If this operation has a registered operation description, return it.
Definition: Operation.h:123
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
Attribute getPropertiesAsAttribute()
Return the properties converted to an attribute.
Definition: Operation.cpp:349
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
LLVM_DUMP_METHOD void dumpPretty()
operand_range getOperands()
Returns an iterator on the underlying Values.
Definition: Operation.h:378
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:874
SuccessorRange getSuccessors()
Definition: Operation.h:704
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator_range< OpIterator > getOps()
Definition: Region.h:172
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition: Region.h:200
bool empty()
Definition: Region.h:60
unsigned getNumArguments()
Definition: Region.h:123
BlockArgument getArgument(unsigned i)
Definition: Region.h:124
Block & front()
Definition: Region.h:65
This diagnostic handler is a simple RAII class that registers and erases a diagnostic handler on a gi...
Definition: Diagnostics.h:522
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
const void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Types.h:188
Dialect & getDialect() const
Get the dialect this type is registered to.
Definition: Types.h:123
static Type getFromOpaquePointer(const void *pointer)
Definition: Types.h:191
void walkImmediateSubElements(function_ref< void(Attribute)> walkAttrsFn, function_ref< void(Type)> walkTypesFn) const
Walk all of the immediately nested sub-attributes and sub-types.
Definition: Types.h:218
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
bool use_empty() const
Returns true if this value has no uses.
Definition: Value.h:218
void dump() const
void print(raw_ostream &os) const
Type getType() const
Return the type of this value.
Definition: Value.h:129
void printAsOperand(raw_ostream &os, AsmState &state) const
Print this value as if it were an operand.
user_range getUsers() const
Definition: Value.h:228
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
void registerOperationLocation(Operation *op, unsigned line, unsigned col)
Register the location, line and column, within the buffer that the given operation was printed at.
const OpPrintingFlags & getPrinterFlags() const
Get the printer flags.
auto getResourcePrinters()
Return the non-dialect resource printers.
LogicalResult pushCyclicPrinting(const void *opaquePointer)
SSANameState & getSSANameState()
Get the state used for SSA names.
DialectInterfaceCollection< OpAsmDialectInterface > & getDialectInterfaces()
Return the dialects within the context that implement OpAsmDialectInterface.
AliasState & getAliasState()
Get the state used for aliases.
void initializeAliases(Operation *op)
Initialize the alias state to enable the printing of aliases.
AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags, AsmState::LocationMap *locationMap)
DenseMap< Dialect *, SetVector< AsmDialectResourceHandle > > & getDialectResources()
Return the referenced dialect resources within the printer.
AsmStateImpl(MLIRContext *ctx, const OpPrintingFlags &printerFlags, AsmState::LocationMap *locationMap)
DistinctState & getDistinctState()
Get the state used for distinct attribute identifiers.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
void printType(Type type, AsmPrinter &printer)
Prints an LLVM Dialect type.
LogicalResult parseCommaSeparatedList(llvm::cl::Option &opt, StringRef argName, StringRef optionStr, function_ref< LogicalResult(StringRef)> elementParseFn)
Parse a string containing a list of comma-delimited elements, invoking the given parser for each sub-...
AttrTypeReplacer.
void printDimensionList(raw_ostream &stream, Range &&shape)
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
bool operator<(const Fraction &x, const Fraction &y)
Definition: Fraction.h:83
Include the generated interface declarations.
ParseResult parseDimensionList(OpAsmParser &parser, DenseI64ArrayAttr &dimensions)
StringRef toString(AsmResourceEntryKind kind)
@ CeilDiv
RHS of ceildiv is always a constant or a symbolic expression.
@ Mul
RHS of mul is always a constant or a symbolic expression.
@ Mod
RHS of mod is always a constant or a symbolic expression with a positive value.
@ DimId
Dimensional identifier.
@ FloorDiv
RHS of floordiv is always a constant or a symbolic expression.
@ Constant
Constant integer.
@ SymbolId
Symbolic identifier.
void registerAsmPrinterCLOptions()
Register a set of useful command-line options that can be used to configure various flags within the ...
Definition: AsmPrinter.cpp:210
Type parseType(llvm::StringRef typeStr, MLIRContext *context, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR type to an MLIR context if it was valid.
AsmResourceEntryKind
This enum represents the different kinds of resource values.
Definition: AsmState.h:279
@ String
A string value.
@ Bool
A boolean value.
@ Blob
A blob of data with an accompanying alignment.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:425
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
Definition: AliasAnalysis.h:78
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
This trait is used to determine if a storage user, like Type, is mutable or not.