MLIR  17.0.0git
ByteCode.cpp
Go to the documentation of this file.
1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ByteCode.h"
14 #include "mlir/Analysis/Liveness.h"
17 #include "mlir/IR/BuiltinOps.h"
19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/Format.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include <numeric>
26 #include <optional>
27 
28 #define DEBUG_TYPE "pdl-bytecode"
29 
30 using namespace mlir;
31 using namespace mlir::detail;
32 
33 //===----------------------------------------------------------------------===//
34 // PDLByteCodePattern
35 //===----------------------------------------------------------------------===//
36 
37 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
38  PDLPatternConfigSet *configSet,
39  ByteCodeAddr rewriterAddr) {
40  PatternBenefit benefit = matchOp.getBenefit();
41  MLIRContext *ctx = matchOp.getContext();
42 
43  // Collect the set of generated operations.
44  SmallVector<StringRef, 8> generatedOps;
45  if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
46  generatedOps =
47  llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
48 
49  // Check to see if this is pattern matches a specific operation type.
50  if (std::optional<StringRef> rootKind = matchOp.getRootKind())
51  return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx,
52  generatedOps);
53  return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(),
54  benefit, ctx, generatedOps);
55 }
56 
57 //===----------------------------------------------------------------------===//
58 // PDLByteCodeMutableState
59 //===----------------------------------------------------------------------===//
60 
61 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
62 /// to the position of the pattern within the range returned by
63 /// `PDLByteCode::getPatterns`.
65  PatternBenefit benefit) {
66  currentPatternBenefits[patternIndex] = benefit;
67 }
68 
69 /// Cleanup any allocated state after a full match/rewrite has been completed.
70 /// This method should be called irregardless of whether the match+rewrite was a
71 /// success or not.
73  allocatedTypeRangeMemory.clear();
74  allocatedValueRangeMemory.clear();
75 }
76 
77 //===----------------------------------------------------------------------===//
78 // Bytecode OpCodes
79 //===----------------------------------------------------------------------===//
80 
81 namespace {
82 enum OpCode : ByteCodeField {
83  /// Apply an externally registered constraint.
84  ApplyConstraint,
85  /// Apply an externally registered rewrite.
86  ApplyRewrite,
87  /// Check if two generic values are equal.
88  AreEqual,
89  /// Check if two ranges are equal.
90  AreRangesEqual,
91  /// Unconditional branch.
92  Branch,
93  /// Compare the operand count of an operation with a constant.
94  CheckOperandCount,
95  /// Compare the name of an operation with a constant.
96  CheckOperationName,
97  /// Compare the result count of an operation with a constant.
98  CheckResultCount,
99  /// Compare a range of types to a constant range of types.
100  CheckTypes,
101  /// Continue to the next iteration of a loop.
102  Continue,
103  /// Create a type range from a list of constant types.
104  CreateConstantTypeRange,
105  /// Create an operation.
106  CreateOperation,
107  /// Create a type range from a list of dynamic types.
108  CreateDynamicTypeRange,
109  /// Create a value range.
110  CreateDynamicValueRange,
111  /// Erase an operation.
112  EraseOp,
113  /// Extract the op from a range at the specified index.
114  ExtractOp,
115  /// Extract the type from a range at the specified index.
116  ExtractType,
117  /// Extract the value from a range at the specified index.
118  ExtractValue,
119  /// Terminate a matcher or rewrite sequence.
120  Finalize,
121  /// Iterate over a range of values.
122  ForEach,
123  /// Get a specific attribute of an operation.
124  GetAttribute,
125  /// Get the type of an attribute.
126  GetAttributeType,
127  /// Get the defining operation of a value.
128  GetDefiningOp,
129  /// Get a specific operand of an operation.
130  GetOperand0,
131  GetOperand1,
132  GetOperand2,
133  GetOperand3,
134  GetOperandN,
135  /// Get a specific operand group of an operation.
136  GetOperands,
137  /// Get a specific result of an operation.
138  GetResult0,
139  GetResult1,
140  GetResult2,
141  GetResult3,
142  GetResultN,
143  /// Get a specific result group of an operation.
144  GetResults,
145  /// Get the users of a value or a range of values.
146  GetUsers,
147  /// Get the type of a value.
148  GetValueType,
149  /// Get the types of a value range.
150  GetValueRangeTypes,
151  /// Check if a generic value is not null.
152  IsNotNull,
153  /// Record a successful pattern match.
154  RecordMatch,
155  /// Replace an operation.
156  ReplaceOp,
157  /// Compare an attribute with a set of constants.
158  SwitchAttribute,
159  /// Compare the operand count of an operation with a set of constants.
160  SwitchOperandCount,
161  /// Compare the name of an operation with a set of constants.
162  SwitchOperationName,
163  /// Compare the result count of an operation with a set of constants.
164  SwitchResultCount,
165  /// Compare a type with a set of constants.
166  SwitchType,
167  /// Compare a range of types with a set of constants.
168  SwitchTypes,
169 };
170 } // namespace
171 
172 /// A marker used to indicate if an operation should infer types.
175 
176 //===----------------------------------------------------------------------===//
177 // ByteCode Generation
178 //===----------------------------------------------------------------------===//
179 
180 //===----------------------------------------------------------------------===//
181 // Generator
182 
183 namespace {
184 struct ByteCodeLiveRange;
185 struct ByteCodeWriter;
186 
/// Detection alias: well-formed only when an object of type `ValueT` exposes a
/// `getAsOpaquePointer()` member, in which case it names that call's result
/// type. The trailing parameter pack is accepted but unused.
template <typename ValueT, typename... UnusedArgs>
using has_pointer_traits =
    decltype(std::declval<ValueT>().getAsOpaquePointer());
190 
191 /// This class represents the main generator for the pattern bytecode.
192 class Generator {
193 public:
194  Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
195  SmallVectorImpl<ByteCodeField> &matcherByteCode,
196  SmallVectorImpl<ByteCodeField> &rewriterByteCode,
198  ByteCodeField &maxValueMemoryIndex,
199  ByteCodeField &maxOpRangeMemoryIndex,
200  ByteCodeField &maxTypeRangeMemoryIndex,
201  ByteCodeField &maxValueRangeMemoryIndex,
202  ByteCodeField &maxLoopLevel,
203  llvm::StringMap<PDLConstraintFunction> &constraintFns,
204  llvm::StringMap<PDLRewriteFunction> &rewriteFns,
206  : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
207  rewriterByteCode(rewriterByteCode), patterns(patterns),
208  maxValueMemoryIndex(maxValueMemoryIndex),
209  maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
210  maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
211  maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
212  maxLoopLevel(maxLoopLevel), configMap(configMap) {
213  for (const auto &it : llvm::enumerate(constraintFns))
214  constraintToMemIndex.try_emplace(it.value().first(), it.index());
215  for (const auto &it : llvm::enumerate(rewriteFns))
216  externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
217  }
218 
219  /// Generate the bytecode for the given PDL interpreter module.
220  void generate(ModuleOp module);
221 
222  /// Return the memory index to use for the given value.
223  ByteCodeField &getMemIndex(Value value) {
224  assert(valueToMemIndex.count(value) &&
225  "expected memory index to be assigned");
226  return valueToMemIndex[value];
227  }
228 
229  /// Return the range memory index used to store the given range value.
230  ByteCodeField &getRangeStorageIndex(Value value) {
231  assert(valueToRangeIndex.count(value) &&
232  "expected range index to be assigned");
233  return valueToRangeIndex[value];
234  }
235 
236  /// Return an index to use when referring to the given data that is uniqued in
237  /// the MLIR context.
238  template <typename T>
239  std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
240  getMemIndex(T val) {
241  const void *opaqueVal = val.getAsOpaquePointer();
242 
243  // Get or insert a reference to this value.
244  auto it = uniquedDataToMemIndex.try_emplace(
245  opaqueVal, maxValueMemoryIndex + uniquedData.size());
246  if (it.second)
247  uniquedData.push_back(opaqueVal);
248  return it.first->second;
249  }
250 
251 private:
252  /// Allocate memory indices for the results of operations within the matcher
253  /// and rewriters.
254  void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
255  ModuleOp rewriterModule);
256 
257  /// Generate the bytecode for the given operation.
258  void generate(Region *region, ByteCodeWriter &writer);
259  void generate(Operation *op, ByteCodeWriter &writer);
260  void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
261  void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
262  void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
263  void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
264  void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
265  void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
266  void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
267  void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
268  void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
269  void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
270  void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
271  void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
272  void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
273  void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer);
274  void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
275  void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
276  void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
277  void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
278  void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
279  void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
280  void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
281  void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
282  void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
283  void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
284  void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
285  void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
286  void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
287  void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
288  void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
289  void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
290  void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
291  void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
292  void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
293  void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
294  void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
295  void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
296  void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
297  void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
298 
299  /// Mapping from value to its corresponding memory index.
300  DenseMap<Value, ByteCodeField> valueToMemIndex;
301 
302  /// Mapping from a range value to its corresponding range storage index.
303  DenseMap<Value, ByteCodeField> valueToRangeIndex;
304 
305  /// Mapping from the name of an externally registered rewrite to its index in
306  /// the bytecode registry.
307  llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
308 
309  /// Mapping from the name of an externally registered constraint to its index
310  /// in the bytecode registry.
311  llvm::StringMap<ByteCodeField> constraintToMemIndex;
312 
313  /// Mapping from rewriter function name to the bytecode address of the
314  /// rewriter function in byte.
315  llvm::StringMap<ByteCodeAddr> rewriterToAddr;
316 
317  /// Mapping from a uniqued storage object to its memory index within
318  /// `uniquedData`.
319  DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
320 
321  /// The current level of the foreach loop.
322  ByteCodeField curLoopLevel = 0;
323 
324  /// The current MLIR context.
325  MLIRContext *ctx;
326 
327  /// Mapping from block to its address.
329 
330  /// Data of the ByteCode class to be populated.
331  std::vector<const void *> &uniquedData;
332  SmallVectorImpl<ByteCodeField> &matcherByteCode;
333  SmallVectorImpl<ByteCodeField> &rewriterByteCode;
335  ByteCodeField &maxValueMemoryIndex;
336  ByteCodeField &maxOpRangeMemoryIndex;
337  ByteCodeField &maxTypeRangeMemoryIndex;
338  ByteCodeField &maxValueRangeMemoryIndex;
339  ByteCodeField &maxLoopLevel;
340 
341  /// A map of pattern configurations.
343 };
344 
345 /// This class provides utilities for writing a bytecode stream.
346 struct ByteCodeWriter {
347  ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
348  : bytecode(bytecode), generator(generator) {}
349 
350  /// Append a field to the bytecode.
351  void append(ByteCodeField field) { bytecode.push_back(field); }
352  void append(OpCode opCode) { bytecode.push_back(opCode); }
353 
354  /// Append an address to the bytecode.
355  void append(ByteCodeAddr field) {
356  static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
357  "unexpected ByteCode address size");
358 
359  ByteCodeField fieldParts[2];
360  std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
361  bytecode.append({fieldParts[0], fieldParts[1]});
362  }
363 
364  /// Append a single successor to the bytecode, the exact address will need to
365  /// be resolved later.
366  void append(Block *successor) {
367  // Add back a reference to the successor so that the address can be resolved
368  // later.
369  unresolvedSuccessorRefs[successor].push_back(bytecode.size());
370  append(ByteCodeAddr(0));
371  }
372 
373  /// Append a successor range to the bytecode, the exact address will need to
374  /// be resolved later.
375  void append(SuccessorRange successors) {
376  for (Block *successor : successors)
377  append(successor);
378  }
379 
380  /// Append a range of values that will be read as generic PDLValues.
381  void appendPDLValueList(OperandRange values) {
382  bytecode.push_back(values.size());
383  for (Value value : values)
384  appendPDLValue(value);
385  }
386 
387  /// Append a value as a PDLValue.
388  void appendPDLValue(Value value) {
389  appendPDLValueKind(value);
390  append(value);
391  }
392 
393  /// Append the PDLValue::Kind of the given value.
394  void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); }
395 
396  /// Append the PDLValue::Kind of the given type.
397  void appendPDLValueKind(Type type) {
398  PDLValue::Kind kind =
400  .Case<pdl::AttributeType>(
401  [](Type) { return PDLValue::Kind::Attribute; })
402  .Case<pdl::OperationType>(
403  [](Type) { return PDLValue::Kind::Operation; })
404  .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
405  if (rangeTy.getElementType().isa<pdl::TypeType>())
408  })
409  .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
410  .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
411  bytecode.push_back(static_cast<ByteCodeField>(kind));
412  }
413 
414  /// Append a value that will be stored in a memory slot and not inline within
415  /// the bytecode.
416  template <typename T>
417  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
418  std::is_pointer<T>::value>
419  append(T value) {
420  bytecode.push_back(generator.getMemIndex(value));
421  }
422 
423  /// Append a range of values.
424  template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
425  std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
426  append(T range) {
427  bytecode.push_back(llvm::size(range));
428  for (auto it : range)
429  append(it);
430  }
431 
432  /// Append a variadic number of fields to the bytecode.
433  template <typename FieldTy, typename Field2Ty, typename... FieldTys>
434  void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
435  append(field);
436  append(field2, fields...);
437  }
438 
439  /// Appends a value as a pointer, stored inline within the bytecode.
440  template <typename T>
441  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
442  appendInline(T value) {
443  constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField);
444  const void *pointer = value.getAsOpaquePointer();
445  ByteCodeField fieldParts[numParts];
446  std::memcpy(fieldParts, &pointer, sizeof(const void *));
447  bytecode.append(fieldParts, fieldParts + numParts);
448  }
449 
450  /// Successor references in the bytecode that have yet to be resolved.
451  DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
452 
453  /// The underlying bytecode buffer.
455 
456  /// The main generator producing PDL.
457  Generator &generator;
458 };
459 
460 /// This class represents a live range of PDL Interpreter values, containing
461 /// information about when values are live within a match/rewrite.
462 struct ByteCodeLiveRange {
463  using Set = llvm::IntervalMap<uint64_t, char, 16>;
464  using Allocator = Set::Allocator;
465 
466  ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
467 
468  /// Union this live range with the one provided.
469  void unionWith(const ByteCodeLiveRange &rhs) {
470  for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
471  ++it)
472  liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0);
473  }
474 
475  /// Returns true if this range overlaps with the one provided.
476  bool overlaps(const ByteCodeLiveRange &rhs) const {
477  return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
478  .valid();
479  }
480 
481  /// A map representing the ranges of the match/rewrite that a value is live in
482  /// the interpreter.
483  ///
484  /// We use std::unique_ptr here, because IntervalMap does not provide a
485  /// correct copy or move constructor. We can eliminate the pointer once
486  /// https://reviews.llvm.org/D113240 lands.
487  std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
488 
489  /// The operation range storage index for this range.
490  std::optional<unsigned> opRangeIndex;
491 
492  /// The type range storage index for this range.
493  std::optional<unsigned> typeRangeIndex;
494 
495  /// The value range storage index for this range.
496  std::optional<unsigned> valueRangeIndex;
497 };
498 } // namespace
499 
500 void Generator::generate(ModuleOp module) {
501  auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
502  pdl_interp::PDLInterpDialect::getMatcherFunctionName());
503  ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
504  pdl_interp::PDLInterpDialect::getRewriterModuleName());
505  assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
506 
507  // Allocate memory indices for the results of operations within the matcher
508  // and rewriters.
509  allocateMemoryIndices(matcherFunc, rewriterModule);
510 
511  // Generate code for the rewriter functions.
512  ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
513  for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
514  rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
515  for (Operation &op : rewriterFunc.getOps())
516  generate(&op, rewriterByteCodeWriter);
517  }
518  assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
519  "unexpected branches in rewriter function");
520 
521  // Generate code for the matcher function.
522  ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
523  generate(&matcherFunc.getBody(), matcherByteCodeWriter);
524 
525  // Resolve successor references in the matcher.
526  for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
527  ByteCodeAddr addr = blockToAddr[it.first];
528  for (unsigned offsetToFix : it.second)
529  std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
530  }
531 }
532 
533 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
534  ModuleOp rewriterModule) {
535  // Rewriters use simplistic allocation scheme that simply assigns an index to
536  // each result.
537  for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
538  ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
539  auto processRewriterValue = [&](Value val) {
540  valueToMemIndex.try_emplace(val, index++);
541  if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) {
542  Type elementTy = rangeType.getElementType();
543  if (elementTy.isa<pdl::TypeType>())
544  valueToRangeIndex.try_emplace(val, typeRangeIndex++);
545  else if (elementTy.isa<pdl::ValueType>())
546  valueToRangeIndex.try_emplace(val, valueRangeIndex++);
547  }
548  };
549 
550  for (BlockArgument arg : rewriterFunc.getArguments())
551  processRewriterValue(arg);
552  rewriterFunc.getBody().walk([&](Operation *op) {
553  for (Value result : op->getResults())
554  processRewriterValue(result);
555  });
556  if (index > maxValueMemoryIndex)
557  maxValueMemoryIndex = index;
558  if (typeRangeIndex > maxTypeRangeMemoryIndex)
559  maxTypeRangeMemoryIndex = typeRangeIndex;
560  if (valueRangeIndex > maxValueRangeMemoryIndex)
561  maxValueRangeMemoryIndex = valueRangeIndex;
562  }
563 
564  // The matcher function uses a more sophisticated numbering that tries to
565  // minimize the number of memory indices assigned. This is done by determining
566  // a live range of the values within the matcher, then the allocation is just
567  // finding the minimal number of overlapping live ranges. This is essentially
568  // a simplified form of register allocation where we don't necessarily have a
569  // limited number of registers, but we still want to minimize the number used.
570  DenseMap<Operation *, unsigned> opToFirstIndex;
571  DenseMap<Operation *, unsigned> opToLastIndex;
572 
573  // A custom walk that marks the first and the last index of each operation.
574  // The entry marks the beginning of the liveness range for this operation,
575  // followed by nested operations, followed by the end of the liveness range.
576  unsigned index = 0;
577  llvm::unique_function<void(Operation *)> walk = [&](Operation *op) {
578  opToFirstIndex.try_emplace(op, index++);
579  for (Region &region : op->getRegions())
580  for (Block &block : region.getBlocks())
581  for (Operation &nested : block)
582  walk(&nested);
583  opToLastIndex.try_emplace(op, index++);
584  };
585  walk(matcherFunc);
586 
587  // Liveness info for each of the defs within the matcher.
588  ByteCodeLiveRange::Allocator allocator;
589  DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
590 
591  // Assign the root operation being matched to slot 0.
592  BlockArgument rootOpArg = matcherFunc.getArgument(0);
593  valueToMemIndex[rootOpArg] = 0;
594 
595  // Walk each of the blocks, computing the def interval that the value is used.
596  Liveness matcherLiveness(matcherFunc);
597  matcherFunc->walk([&](Block *block) {
598  const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
599  assert(info && "expected liveness info for block");
600  auto processValue = [&](Value value, Operation *firstUseOrDef) {
601  // We don't need to process the root op argument, this value is always
602  // assigned to the first memory slot.
603  if (value == rootOpArg)
604  return;
605 
606  // Set indices for the range of this block that the value is used.
607  auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
608  defRangeIt->second.liveness->insert(
609  opToFirstIndex[firstUseOrDef],
610  opToLastIndex[info->getEndOperation(value, firstUseOrDef)],
611  /*dummyValue*/ 0);
612 
613  // Check to see if this value is a range type.
614  if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) {
615  Type eleType = rangeTy.getElementType();
616  if (eleType.isa<pdl::OperationType>())
617  defRangeIt->second.opRangeIndex = 0;
618  else if (eleType.isa<pdl::TypeType>())
619  defRangeIt->second.typeRangeIndex = 0;
620  else if (eleType.isa<pdl::ValueType>())
621  defRangeIt->second.valueRangeIndex = 0;
622  }
623  };
624 
625  // Process the live-ins of this block.
626  for (Value liveIn : info->in()) {
627  // Only process the value if it has been defined in the current region.
628  // Other values that span across pdl_interp.foreach will be added higher
629  // up. This ensures that the we keep them alive for the entire duration
630  // of the loop.
631  if (liveIn.getParentRegion() == block->getParent())
632  processValue(liveIn, &block->front());
633  }
634 
635  // Process the block arguments for the entry block (those are not live-in).
636  if (block->isEntryBlock()) {
637  for (Value argument : block->getArguments())
638  processValue(argument, &block->front());
639  }
640 
641  // Process any new defs within this block.
642  for (Operation &op : *block)
643  for (Value result : op.getResults())
644  processValue(result, &op);
645  });
646 
647  // Greedily allocate memory slots using the computed def live ranges.
648  std::vector<ByteCodeLiveRange> allocatedIndices;
649 
650  // The number of memory indices currently allocated (and its next value).
651  // Recall that the root gets allocated memory index 0.
652  ByteCodeField numIndices = 1;
653 
654  // The number of memory ranges of various types (and their next values).
655  ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
656 
657  for (auto &defIt : valueDefRanges) {
658  ByteCodeField &memIndex = valueToMemIndex[defIt.first];
659  ByteCodeLiveRange &defRange = defIt.second;
660 
661  // Try to allocate to an existing index.
662  for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) {
663  ByteCodeLiveRange &existingRange = existingIndexIt.value();
664  if (!defRange.overlaps(existingRange)) {
665  existingRange.unionWith(defRange);
666  memIndex = existingIndexIt.index() + 1;
667 
668  if (defRange.opRangeIndex) {
669  if (!existingRange.opRangeIndex)
670  existingRange.opRangeIndex = numOpRanges++;
671  valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
672  } else if (defRange.typeRangeIndex) {
673  if (!existingRange.typeRangeIndex)
674  existingRange.typeRangeIndex = numTypeRanges++;
675  valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
676  } else if (defRange.valueRangeIndex) {
677  if (!existingRange.valueRangeIndex)
678  existingRange.valueRangeIndex = numValueRanges++;
679  valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
680  }
681  break;
682  }
683  }
684 
685  // If no existing index could be used, add a new one.
686  if (memIndex == 0) {
687  allocatedIndices.emplace_back(allocator);
688  ByteCodeLiveRange &newRange = allocatedIndices.back();
689  newRange.unionWith(defRange);
690 
691  // Allocate an index for op/type/value ranges.
692  if (defRange.opRangeIndex) {
693  newRange.opRangeIndex = numOpRanges;
694  valueToRangeIndex[defIt.first] = numOpRanges++;
695  } else if (defRange.typeRangeIndex) {
696  newRange.typeRangeIndex = numTypeRanges;
697  valueToRangeIndex[defIt.first] = numTypeRanges++;
698  } else if (defRange.valueRangeIndex) {
699  newRange.valueRangeIndex = numValueRanges;
700  valueToRangeIndex[defIt.first] = numValueRanges++;
701  }
702 
703  memIndex = allocatedIndices.size();
704  ++numIndices;
705  }
706  }
707 
708  // Print the index usage and ensure that we did not run out of index space.
709  LLVM_DEBUG({
710  llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices "
711  << "(down from initial " << valueDefRanges.size() << ").\n";
712  });
713  assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
714  "Ran out of memory for allocated indices");
715 
716  // Update the max number of indices.
717  if (numIndices > maxValueMemoryIndex)
718  maxValueMemoryIndex = numIndices;
719  if (numOpRanges > maxOpRangeMemoryIndex)
720  maxOpRangeMemoryIndex = numOpRanges;
721  if (numTypeRanges > maxTypeRangeMemoryIndex)
722  maxTypeRangeMemoryIndex = numTypeRanges;
723  if (numValueRanges > maxValueRangeMemoryIndex)
724  maxValueRangeMemoryIndex = numValueRanges;
725 }
726 
727 void Generator::generate(Region *region, ByteCodeWriter &writer) {
728  llvm::ReversePostOrderTraversal<Region *> rpot(region);
729  for (Block *block : rpot) {
730  // Keep track of where this block begins within the matcher function.
731  blockToAddr.try_emplace(block, matcherByteCode.size());
732  for (Operation &op : *block)
733  generate(&op, writer);
734  }
735 }
736 
737 void Generator::generate(Operation *op, ByteCodeWriter &writer) {
738  LLVM_DEBUG({
739  // The following list must contain all the operations that do not
740  // produce any bytecode.
741  if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
742  writer.appendInline(op->getLoc());
743  });
745  .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
746  pdl_interp::AreEqualOp, pdl_interp::BranchOp,
747  pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
748  pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
749  pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
750  pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
751  pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp,
752  pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
753  pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
754  pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
755  pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
756  pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
757  pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
758  pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
759  pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
760  pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
761  pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
762  pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
763  pdl_interp::SwitchResultCountOp>(
764  [&](auto interpOp) { this->generate(interpOp, writer); })
765  .Default([](Operation *) {
766  llvm_unreachable("unknown `pdl_interp` operation");
767  });
768 }
769 
770 void Generator::generate(pdl_interp::ApplyConstraintOp op,
771  ByteCodeWriter &writer) {
772  assert(constraintToMemIndex.count(op.getName()) &&
773  "expected index for constraint function");
774  writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
775  writer.appendPDLValueList(op.getArgs());
776  writer.append(op.getSuccessors());
777 }
778 void Generator::generate(pdl_interp::ApplyRewriteOp op,
779  ByteCodeWriter &writer) {
780  assert(externalRewriterToMemIndex.count(op.getName()) &&
781  "expected index for rewrite function");
782  writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
783  writer.appendPDLValueList(op.getArgs());
784 
785  ResultRange results = op.getResults();
786  writer.append(ByteCodeField(results.size()));
787  for (Value result : results) {
788  // In debug mode we also record the expected kind of the result, so that we
789  // can provide extra verification of the native rewrite function.
790 #ifndef NDEBUG
791  writer.appendPDLValueKind(result);
792 #endif
793 
794  // Range results also need to append the range storage index.
795  if (result.getType().isa<pdl::RangeType>())
796  writer.append(getRangeStorageIndex(result));
797  writer.append(result);
798  }
799 }
800 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
801  Value lhs = op.getLhs();
802  if (lhs.getType().isa<pdl::RangeType>()) {
803  writer.append(OpCode::AreRangesEqual);
804  writer.appendPDLValueKind(lhs);
805  writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
806  return;
807  }
808 
809  writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
810 }
811 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
812  writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
813 }
814 void Generator::generate(pdl_interp::CheckAttributeOp op,
815  ByteCodeWriter &writer) {
816  writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
817  op.getSuccessors());
818 }
819 void Generator::generate(pdl_interp::CheckOperandCountOp op,
820  ByteCodeWriter &writer) {
821  writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
822  static_cast<ByteCodeField>(op.getCompareAtLeast()),
823  op.getSuccessors());
824 }
825 void Generator::generate(pdl_interp::CheckOperationNameOp op,
826  ByteCodeWriter &writer) {
827  writer.append(OpCode::CheckOperationName, op.getInputOp(),
828  OperationName(op.getName(), ctx), op.getSuccessors());
829 }
830 void Generator::generate(pdl_interp::CheckResultCountOp op,
831  ByteCodeWriter &writer) {
832  writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
833  static_cast<ByteCodeField>(op.getCompareAtLeast()),
834  op.getSuccessors());
835 }
836 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
837  writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
838  op.getSuccessors());
839 }
840 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
841  writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
842  op.getSuccessors());
843 }
844 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
845  assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
846  writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
847 }
848 void Generator::generate(pdl_interp::CreateAttributeOp op,
849  ByteCodeWriter &writer) {
850  // Simply repoint the memory index of the result to the constant.
851  getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
852 }
853 void Generator::generate(pdl_interp::CreateOperationOp op,
854  ByteCodeWriter &writer) {
855  writer.append(OpCode::CreateOperation, op.getResultOp(),
856  OperationName(op.getName(), ctx));
857  writer.appendPDLValueList(op.getInputOperands());
858 
859  // Add the attributes.
860  OperandRange attributes = op.getInputAttributes();
861  writer.append(static_cast<ByteCodeField>(attributes.size()));
862  for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
863  writer.append(std::get<0>(it), std::get<1>(it));
864 
865  // Add the result types. If the operation has inferred results, we use a
866  // marker "size" value. Otherwise, we add the list of explicit result types.
867  if (op.getInferredResultTypes())
868  writer.append(kInferTypesMarker);
869  else
870  writer.appendPDLValueList(op.getInputResultTypes());
871 }
872 void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
873  // Append the correct opcode for the range type.
874  TypeSwitch<Type>(op.getType().getElementType())
875  .Case(
876  [&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); })
877  .Case([&](pdl::ValueType) {
878  writer.append(OpCode::CreateDynamicValueRange);
879  });
880 
881  writer.append(op.getResult(), getRangeStorageIndex(op.getResult()));
882  writer.appendPDLValueList(op->getOperands());
883 }
884 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
885  // Simply repoint the memory index of the result to the constant.
886  getMemIndex(op.getResult()) = getMemIndex(op.getValue());
887 }
888 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
889  writer.append(OpCode::CreateConstantTypeRange, op.getResult(),
890  getRangeStorageIndex(op.getResult()), op.getValue());
891 }
892 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
893  writer.append(OpCode::EraseOp, op.getInputOp());
894 }
895 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
896  OpCode opCode =
897  TypeSwitch<Type, OpCode>(op.getResult().getType())
898  .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
899  .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
900  .Case([](pdl::TypeType) { return OpCode::ExtractType; })
901  .Default([](Type) -> OpCode {
902  llvm_unreachable("unsupported element type");
903  });
904  writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
905 }
906 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
907  writer.append(OpCode::Finalize);
908 }
909 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
910  BlockArgument arg = op.getLoopVariable();
911  writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
912  writer.appendPDLValueKind(arg.getType());
913  writer.append(curLoopLevel, op.getSuccessor());
914  ++curLoopLevel;
915  if (curLoopLevel > maxLoopLevel)
916  maxLoopLevel = curLoopLevel;
917  generate(&op.getRegion(), writer);
918  --curLoopLevel;
919 }
920 void Generator::generate(pdl_interp::GetAttributeOp op,
921  ByteCodeWriter &writer) {
922  writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
923  op.getNameAttr());
924 }
925 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
926  ByteCodeWriter &writer) {
927  writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
928 }
929 void Generator::generate(pdl_interp::GetDefiningOpOp op,
930  ByteCodeWriter &writer) {
931  writer.append(OpCode::GetDefiningOp, op.getInputOp());
932  writer.appendPDLValue(op.getValue());
933 }
934 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
935  uint32_t index = op.getIndex();
936  if (index < 4)
937  writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
938  else
939  writer.append(OpCode::GetOperandN, index);
940  writer.append(op.getInputOp(), op.getValue());
941 }
942 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
943  Value result = op.getValue();
944  std::optional<uint32_t> index = op.getIndex();
945  writer.append(OpCode::GetOperands,
946  index.value_or(std::numeric_limits<uint32_t>::max()),
947  op.getInputOp());
948  if (result.getType().isa<pdl::RangeType>())
949  writer.append(getRangeStorageIndex(result));
950  else
952  writer.append(result);
953 }
954 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
955  uint32_t index = op.getIndex();
956  if (index < 4)
957  writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
958  else
959  writer.append(OpCode::GetResultN, index);
960  writer.append(op.getInputOp(), op.getValue());
961 }
962 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
963  Value result = op.getValue();
964  std::optional<uint32_t> index = op.getIndex();
965  writer.append(OpCode::GetResults,
966  index.value_or(std::numeric_limits<uint32_t>::max()),
967  op.getInputOp());
968  if (result.getType().isa<pdl::RangeType>())
969  writer.append(getRangeStorageIndex(result));
970  else
972  writer.append(result);
973 }
974 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
975  Value operations = op.getOperations();
976  ByteCodeField rangeIndex = getRangeStorageIndex(operations);
977  writer.append(OpCode::GetUsers, operations, rangeIndex);
978  writer.appendPDLValue(op.getValue());
979 }
980 void Generator::generate(pdl_interp::GetValueTypeOp op,
981  ByteCodeWriter &writer) {
982  if (op.getType().isa<pdl::RangeType>()) {
983  Value result = op.getResult();
984  writer.append(OpCode::GetValueRangeTypes, result,
985  getRangeStorageIndex(result), op.getValue());
986  } else {
987  writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
988  }
989 }
990 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
991  writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
992 }
993 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
994  ByteCodeField patternIndex = patterns.size();
995  patterns.emplace_back(PDLByteCodePattern::create(
996  op, configMap.lookup(op),
997  rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
998  writer.append(OpCode::RecordMatch, patternIndex,
999  SuccessorRange(op.getOperation()), op.getMatchedOps());
1000  writer.appendPDLValueList(op.getInputs());
1001 }
1002 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
1003  writer.append(OpCode::ReplaceOp, op.getInputOp());
1004  writer.appendPDLValueList(op.getReplValues());
1005 }
1006 void Generator::generate(pdl_interp::SwitchAttributeOp op,
1007  ByteCodeWriter &writer) {
1008  writer.append(OpCode::SwitchAttribute, op.getAttribute(),
1009  op.getCaseValuesAttr(), op.getSuccessors());
1010 }
1011 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
1012  ByteCodeWriter &writer) {
1013  writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
1014  op.getCaseValuesAttr(), op.getSuccessors());
1015 }
1016 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
1017  ByteCodeWriter &writer) {
1018  auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
1019  return OperationName(attr.cast<StringAttr>().getValue(), ctx);
1020  });
1021  writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
1022  op.getSuccessors());
1023 }
1024 void Generator::generate(pdl_interp::SwitchResultCountOp op,
1025  ByteCodeWriter &writer) {
1026  writer.append(OpCode::SwitchResultCount, op.getInputOp(),
1027  op.getCaseValuesAttr(), op.getSuccessors());
1028 }
1029 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1030  writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1031  op.getSuccessors());
1032 }
1033 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1034  writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
1035  op.getSuccessors());
1036 }
1037 
1038 //===----------------------------------------------------------------------===//
1039 // PDLByteCode
1040 //===----------------------------------------------------------------------===//
1041 
1043  ModuleOp module, SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
1045  llvm::StringMap<PDLConstraintFunction> constraintFns,
1046  llvm::StringMap<PDLRewriteFunction> rewriteFns)
1047  : configs(std::move(configs)) {
1048  Generator generator(module.getContext(), uniquedData, matcherByteCode,
1049  rewriterByteCode, patterns, maxValueMemoryIndex,
1050  maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1051  maxLoopLevel, constraintFns, rewriteFns, configMap);
1052  generator.generate(module);
1053 
1054  // Initialize the external functions.
1055  for (auto &it : constraintFns)
1056  constraintFunctions.push_back(std::move(it.second));
1057  for (auto &it : rewriteFns)
1058  rewriteFunctions.push_back(std::move(it.second));
1059 }
1060 
1061 /// Initialize the given state such that it can be used to execute the current
1062 /// bytecode.
1064  state.memory.resize(maxValueMemoryIndex, nullptr);
1065  state.opRangeMemory.resize(maxOpRangeCount);
1066  state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
1067  state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
1068  state.loopIndex.resize(maxLoopLevel, 0);
1069  state.currentPatternBenefits.reserve(patterns.size());
1070  for (const PDLByteCodePattern &pattern : patterns)
1071  state.currentPatternBenefits.push_back(pattern.getBenefit());
1072 }
1073 
1074 //===----------------------------------------------------------------------===//
1075 // ByteCode Execution
1076 
1077 namespace {
1078 /// This class provides support for executing a bytecode stream.
1079 class ByteCodeExecutor {
1080 public:
1081  ByteCodeExecutor(
1082  const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
1083  MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory,
1084  MutableArrayRef<TypeRange> typeRangeMemory,
1085  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
1086  MutableArrayRef<ValueRange> valueRangeMemory,
1087  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
1088  MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory,
1090  ArrayRef<PatternBenefit> currentPatternBenefits,
1092  ArrayRef<PDLConstraintFunction> constraintFunctions,
1093  ArrayRef<PDLRewriteFunction> rewriteFunctions)
1094  : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
1095  typeRangeMemory(typeRangeMemory),
1096  allocatedTypeRangeMemory(allocatedTypeRangeMemory),
1097  valueRangeMemory(valueRangeMemory),
1098  allocatedValueRangeMemory(allocatedValueRangeMemory),
1099  loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
1100  currentPatternBenefits(currentPatternBenefits), patterns(patterns),
1101  constraintFunctions(constraintFunctions),
1102  rewriteFunctions(rewriteFunctions) {}
1103 
1104  /// Start executing the code at the current bytecode index. `matches` is an
1105  /// optional field provided when this function is executed in a matching
1106  /// context.
1108  execute(PatternRewriter &rewriter,
1109  SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
1110  std::optional<Location> mainRewriteLoc = {});
1111 
1112 private:
1113  /// Internal implementation of executing each of the bytecode commands.
1114  void executeApplyConstraint(PatternRewriter &rewriter);
1115  LogicalResult executeApplyRewrite(PatternRewriter &rewriter);
1116  void executeAreEqual();
1117  void executeAreRangesEqual();
1118  void executeBranch();
1119  void executeCheckOperandCount();
1120  void executeCheckOperationName();
1121  void executeCheckResultCount();
1122  void executeCheckTypes();
1123  void executeContinue();
1124  void executeCreateConstantTypeRange();
1125  void executeCreateOperation(PatternRewriter &rewriter,
1126  Location mainRewriteLoc);
1127  template <typename T>
1128  void executeDynamicCreateRange(StringRef type);
1129  void executeEraseOp(PatternRewriter &rewriter);
1130  template <typename T, typename Range, PDLValue::Kind kind>
1131  void executeExtract();
1132  void executeFinalize();
1133  void executeForEach();
1134  void executeGetAttribute();
1135  void executeGetAttributeType();
1136  void executeGetDefiningOp();
1137  void executeGetOperand(unsigned index);
1138  void executeGetOperands();
1139  void executeGetResult(unsigned index);
1140  void executeGetResults();
1141  void executeGetUsers();
1142  void executeGetValueType();
1143  void executeGetValueRangeTypes();
1144  void executeIsNotNull();
1145  void executeRecordMatch(PatternRewriter &rewriter,
1147  void executeReplaceOp(PatternRewriter &rewriter);
1148  void executeSwitchAttribute();
1149  void executeSwitchOperandCount();
1150  void executeSwitchOperationName();
1151  void executeSwitchResultCount();
1152  void executeSwitchType();
1153  void executeSwitchTypes();
1154 
1155  /// Pushes a code iterator to the stack.
1156  void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }
1157 
1158  /// Pops a code iterator from the stack, returning true on success.
1159  void popCodeIt() {
1160  assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
1161  curCodeIt = resumeCodeIt.back();
1162  resumeCodeIt.pop_back();
1163  }
1164 
1165  /// Return the bytecode iterator at the start of the current op code.
1166  const ByteCodeField *getPrevCodeIt() const {
1167  LLVM_DEBUG({
1168  // Account for the op code and the Location stored inline.
1169  return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField);
1170  });
1171 
1172  // Account for the op code only.
1173  return curCodeIt - 1;
1174  }
1175 
1176  /// Read a value from the bytecode buffer, optionally skipping a certain
1177  /// number of prefix values. These methods always update the buffer to point
1178  /// to the next field after the read data.
1179  template <typename T = ByteCodeField>
1180  T read(size_t skipN = 0) {
1181  curCodeIt += skipN;
1182  return readImpl<T>();
1183  }
1184  ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
1185 
1186  /// Read a list of values from the bytecode buffer.
1187  template <typename ValueT, typename T>
1188  void readList(SmallVectorImpl<T> &list) {
1189  list.clear();
1190  for (unsigned i = 0, e = read(); i != e; ++i)
1191  list.push_back(read<ValueT>());
1192  }
1193 
1194  /// Read a list of values from the bytecode buffer. The values may be encoded
1195  /// either as a single element or a range of elements.
1196  void readList(SmallVectorImpl<Type> &list) {
1197  for (unsigned i = 0, e = read(); i != e; ++i) {
1198  if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1199  list.push_back(read<Type>());
1200  } else {
1201  TypeRange *values = read<TypeRange *>();
1202  list.append(values->begin(), values->end());
1203  }
1204  }
1205  }
1206  void readList(SmallVectorImpl<Value> &list) {
1207  for (unsigned i = 0, e = read(); i != e; ++i) {
1208  if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1209  list.push_back(read<Value>());
1210  } else {
1211  ValueRange *values = read<ValueRange *>();
1212  list.append(values->begin(), values->end());
1213  }
1214  }
1215  }
1216 
1217  /// Read a value stored inline as a pointer.
1218  template <typename T>
1219  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
1220  readInline() {
1221  const void *pointer;
1222  std::memcpy(&pointer, curCodeIt, sizeof(const void *));
1223  curCodeIt += sizeof(const void *) / sizeof(ByteCodeField);
1224  return T::getFromOpaquePointer(pointer);
1225  }
1226 
1227  /// Jump to a specific successor based on a predicate value.
1228  void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
1229  /// Jump to a specific successor based on a destination index.
1230  void selectJump(size_t destIndex) {
1231  curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
1232  }
1233 
1234  /// Handle a switch operation with the provided value and cases.
1235  template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
1236  void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
1237  LLVM_DEBUG({
1238  llvm::dbgs() << " * Value: " << value << "\n"
1239  << " * Cases: ";
1240  llvm::interleaveComma(cases, llvm::dbgs());
1241  llvm::dbgs() << "\n";
1242  });
1243 
1244  // Check to see if the attribute value is within the case list. Jump to
1245  // the correct successor index based on the result.
1246  for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
1247  if (cmp(*it, value))
1248  return selectJump(size_t((it - cases.begin()) + 1));
1249  selectJump(size_t(0));
1250  }
1251 
1252  /// Store a pointer to memory.
1253  void storeToMemory(unsigned index, const void *value) {
1254  memory[index] = value;
1255  }
1256 
1257  /// Store a value to memory as an opaque pointer.
1258  template <typename T>
1259  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
1260  storeToMemory(unsigned index, T value) {
1261  memory[index] = value.getAsOpaquePointer();
1262  }
1263 
1264  /// Internal implementation of reading various data types from the bytecode
1265  /// stream.
1266  template <typename T>
1267  const void *readFromMemory() {
1268  size_t index = *curCodeIt++;
1269 
1270  // If this type is an SSA value, it can only be stored in non-const memory.
1271  if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
1272  Value>::value ||
1273  index < memory.size())
1274  return memory[index];
1275 
1276  // Otherwise, if this index is not inbounds it is uniqued.
1277  return uniquedMemory[index - memory.size()];
1278  }
1279  template <typename T>
1280  std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
1281  return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
1282  }
1283  template <typename T>
1284  std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
1285  T>
1286  readImpl() {
1287  return T(T::getFromOpaquePointer(readFromMemory<T>()));
1288  }
1289  template <typename T>
1290  std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
1291  switch (read<PDLValue::Kind>()) {
1293  return read<Attribute>();
1295  return read<Operation *>();
1296  case PDLValue::Kind::Type:
1297  return read<Type>();
1298  case PDLValue::Kind::Value:
1299  return read<Value>();
1301  return read<TypeRange *>();
1303  return read<ValueRange *>();
1304  }
1305  llvm_unreachable("unhandled PDLValue::Kind");
1306  }
1307  template <typename T>
1308  std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
1309  static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
1310  "unexpected ByteCode address size");
1311  ByteCodeAddr result;
1312  std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
1313  curCodeIt += 2;
1314  return result;
1315  }
1316  template <typename T>
1317  std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
1318  return *curCodeIt++;
1319  }
1320  template <typename T>
1321  std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
1322  return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
1323  }
1324 
1325  /// Assign the given range to the given memory index. This allocates a new
1326  /// range object if necessary.
1327  template <typename RangeT, typename T = llvm::detail::ValueOfRange<RangeT>>
1328  void assignRangeToMemory(RangeT &&range, unsigned memIndex,
1329  unsigned rangeIndex) {
1330  // Utility functor used to type-erase the assignment.
1331  auto assignRange = [&](auto &allocatedRangeMemory, auto &rangeMemory) {
1332  // If the input range is empty, we don't need to allocate anything.
1333  if (range.empty()) {
1334  rangeMemory[rangeIndex] = {};
1335  } else {
1336  // Allocate a buffer for this type range.
1337  llvm::OwningArrayRef<T> storage(llvm::size(range));
1338  llvm::copy(range, storage.begin());
1339 
1340  // Assign this to the range slot and use the range as the value for the
1341  // memory index.
1342  allocatedRangeMemory.emplace_back(std::move(storage));
1343  rangeMemory[rangeIndex] = allocatedRangeMemory.back();
1344  }
1345  memory[memIndex] = &rangeMemory[rangeIndex];
1346  };
1347 
1348  // Dispatch based on the concrete range type.
1349  if constexpr (std::is_same_v<T, Type>) {
1350  return assignRange(allocatedTypeRangeMemory, typeRangeMemory);
1351  } else if constexpr (std::is_same_v<T, Value>) {
1352  return assignRange(allocatedValueRangeMemory, valueRangeMemory);
1353  } else {
1354  llvm_unreachable("unhandled range type");
1355  }
1356  }
1357 
1358  /// The underlying bytecode buffer.
1359  const ByteCodeField *curCodeIt;
1360 
1361  /// The stack of bytecode positions at which to resume operation.
1363 
1364  /// The current execution memory.
1366  MutableArrayRef<OwningOpRange> opRangeMemory;
1367  MutableArrayRef<TypeRange> typeRangeMemory;
1368  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
1369  MutableArrayRef<ValueRange> valueRangeMemory;
1370  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
1371 
1372  /// The current loop indices.
1373  MutableArrayRef<unsigned> loopIndex;
1374 
1375  /// References to ByteCode data necessary for execution.
1376  ArrayRef<const void *> uniquedMemory;
1378  ArrayRef<PatternBenefit> currentPatternBenefits;
1380  ArrayRef<PDLConstraintFunction> constraintFunctions;
1381  ArrayRef<PDLRewriteFunction> rewriteFunctions;
1382 };
1383 
1384 /// This class is an instantiation of the PDLResultList that provides access to
1385 /// the returned results. This API is not on `PDLResultList` to avoid
1386 /// overexposing access to information specific solely to the ByteCode.
1387 class ByteCodeRewriteResultList : public PDLResultList {
1388 public:
1389  ByteCodeRewriteResultList(unsigned maxNumResults)
1390  : PDLResultList(maxNumResults) {}
1391 
1392  /// Return the list of PDL results.
1393  MutableArrayRef<PDLValue> getResults() { return results; }
1394 
1395  /// Return the type ranges allocated by this list.
1396  MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
1397  return allocatedTypeRanges;
1398  }
1399 
1400  /// Return the value ranges allocated by this list.
1401  MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
1402  return allocatedValueRanges;
1403  }
1404 };
1405 } // namespace
1406 
1407 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1408  LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
1409  const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
1411  readList<PDLValue>(args);
1412 
1413  LLVM_DEBUG({
1414  llvm::dbgs() << " * Arguments: ";
1415  llvm::interleaveComma(args, llvm::dbgs());
1416  });
1417 
1418  // Invoke the constraint and jump to the proper destination.
1419  selectJump(succeeded(constraintFn(rewriter, args)));
1420 }
1421 
1422 LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1423  LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
1424  const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1426  readList<PDLValue>(args);
1427 
1428  LLVM_DEBUG({
1429  llvm::dbgs() << " * Arguments: ";
1430  llvm::interleaveComma(args, llvm::dbgs());
1431  });
1432 
1433  // Execute the rewrite function.
1434  ByteCodeField numResults = read();
1435  ByteCodeRewriteResultList results(numResults);
1436  LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
1437 
1438  assert(results.getResults().size() == numResults &&
1439  "native PDL rewrite function returned unexpected number of results");
1440 
1441  // Store the results in the bytecode memory.
1442  for (PDLValue &result : results.getResults()) {
1443  LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
1444 
1445 // In debug mode we also verify the expected kind of the result.
1446 #ifndef NDEBUG
1447  assert(result.getKind() == read<PDLValue::Kind>() &&
1448  "native PDL rewrite function returned an unexpected type of result");
1449 #endif
1450 
1451  // If the result is a range, we need to copy it over to the bytecodes
1452  // range memory.
1453  if (std::optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
1454  unsigned rangeIndex = read();
1455  typeRangeMemory[rangeIndex] = *typeRange;
1456  memory[read()] = &typeRangeMemory[rangeIndex];
1457  } else if (std::optional<ValueRange> valueRange =
1458  result.dyn_cast<ValueRange>()) {
1459  unsigned rangeIndex = read();
1460  valueRangeMemory[rangeIndex] = *valueRange;
1461  memory[read()] = &valueRangeMemory[rangeIndex];
1462  } else {
1463  memory[read()] = result.getAsOpaquePointer();
1464  }
1465  }
1466 
1467  // Copy over any underlying storage allocated for result ranges.
1468  for (auto &it : results.getAllocatedTypeRanges())
1469  allocatedTypeRangeMemory.push_back(std::move(it));
1470  for (auto &it : results.getAllocatedValueRanges())
1471  allocatedValueRangeMemory.push_back(std::move(it));
1472 
1473  // Process the result of the rewrite.
1474  if (failed(rewriteResult)) {
1475  LLVM_DEBUG(llvm::dbgs() << " - Failed");
1476  return failure();
1477  }
1478  return success();
1479 }
1480 
1481 void ByteCodeExecutor::executeAreEqual() {
1482  LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1483  const void *lhs = read<const void *>();
1484  const void *rhs = read<const void *>();
1485 
1486  LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n");
1487  selectJump(lhs == rhs);
1488 }
1489 
1490 void ByteCodeExecutor::executeAreRangesEqual() {
1491  LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
1492  PDLValue::Kind valueKind = read<PDLValue::Kind>();
1493  const void *lhs = read<const void *>();
1494  const void *rhs = read<const void *>();
1495 
1496  switch (valueKind) {
1498  const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1499  const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1500  LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1501  selectJump(*lhsRange == *rhsRange);
1502  break;
1503  }
1505  const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1506  const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1507  LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1508  selectJump(*lhsRange == *rhsRange);
1509  break;
1510  }
1511  default:
1512  llvm_unreachable("unexpected `AreRangesEqual` value kind");
1513  }
1514 }
1515 
1516 void ByteCodeExecutor::executeBranch() {
1517  LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
1518  curCodeIt = &code[read<ByteCodeAddr>()];
1519 }
1520 
1521 void ByteCodeExecutor::executeCheckOperandCount() {
1522  LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
1523  Operation *op = read<Operation *>();
1524  uint32_t expectedCount = read<uint32_t>();
1525  bool compareAtLeast = read();
1526 
1527  LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n"
1528  << " * Expected: " << expectedCount << "\n"
1529  << " * Comparator: "
1530  << (compareAtLeast ? ">=" : "==") << "\n");
1531  if (compareAtLeast)
1532  selectJump(op->getNumOperands() >= expectedCount);
1533  else
1534  selectJump(op->getNumOperands() == expectedCount);
1535 }
1536 
1537 void ByteCodeExecutor::executeCheckOperationName() {
1538  LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
1539  Operation *op = read<Operation *>();
1540  OperationName expectedName = read<OperationName>();
1541 
1542  LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n"
1543  << " * Expected: \"" << expectedName << "\"\n");
1544  selectJump(op->getName() == expectedName);
1545 }
1546 
1547 void ByteCodeExecutor::executeCheckResultCount() {
1548  LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
1549  Operation *op = read<Operation *>();
1550  uint32_t expectedCount = read<uint32_t>();
1551  bool compareAtLeast = read();
1552 
1553  LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n"
1554  << " * Expected: " << expectedCount << "\n"
1555  << " * Comparator: "
1556  << (compareAtLeast ? ">=" : "==") << "\n");
1557  if (compareAtLeast)
1558  selectJump(op->getNumResults() >= expectedCount);
1559  else
1560  selectJump(op->getNumResults() == expectedCount);
1561 }
1562 
1563 void ByteCodeExecutor::executeCheckTypes() {
1564  LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1565  TypeRange *lhs = read<TypeRange *>();
1566  Attribute rhs = read<Attribute>();
1567  LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1568 
1569  selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>());
1570 }
1571 
1572 void ByteCodeExecutor::executeContinue() {
1573  ByteCodeField level = read();
1574  LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n"
1575  << " * Level: " << level << "\n");
1576  ++loopIndex[level];
1577  popCodeIt();
1578 }
1579 
1580 void ByteCodeExecutor::executeCreateConstantTypeRange() {
1581  LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n");
1582  unsigned memIndex = read();
1583  unsigned rangeIndex = read();
1584  ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
1585 
1586  LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n");
1587  assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
1588  rangeIndex);
1589 }
1590 
1591 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1592  Location mainRewriteLoc) {
1593  LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
1594 
1595  unsigned memIndex = read();
1596  OperationState state(mainRewriteLoc, read<OperationName>());
1597  readList(state.operands);
1598  for (unsigned i = 0, e = read(); i != e; ++i) {
1599  StringAttr name = read<StringAttr>();
1600  if (Attribute attr = read<Attribute>())
1601  state.addAttribute(name, attr);
1602  }
1603 
1604  // Read in the result types. If the "size" is the sentinel value, this
1605  // indicates that the result types should be inferred.
1606  unsigned numResults = read();
1607  if (numResults == kInferTypesMarker) {
1608  InferTypeOpInterface::Concept *inferInterface =
1609  state.name.getInterface<InferTypeOpInterface>();
1610  assert(inferInterface &&
1611  "expected operation to provide InferTypeOpInterface");
1612 
1613  // TODO: Handle failure.
1614  if (failed(inferInterface->inferReturnTypes(
1615  state.getContext(), state.location, state.operands,
1616  state.attributes.getDictionary(state.getContext()), state.regions,
1617  state.types)))
1618  return;
1619  } else {
1620  // Otherwise, this is a fixed number of results.
1621  for (unsigned i = 0; i != numResults; ++i) {
1622  if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1623  state.types.push_back(read<Type>());
1624  } else {
1625  TypeRange *resultTypes = read<TypeRange *>();
1626  state.types.append(resultTypes->begin(), resultTypes->end());
1627  }
1628  }
1629  }
1630 
1631  Operation *resultOp = rewriter.create(state);
1632  memory[memIndex] = resultOp;
1633 
1634  LLVM_DEBUG({
1635  llvm::dbgs() << " * Attributes: "
1636  << state.attributes.getDictionary(state.getContext())
1637  << "\n * Operands: ";
1638  llvm::interleaveComma(state.operands, llvm::dbgs());
1639  llvm::dbgs() << "\n * Result Types: ";
1640  llvm::interleaveComma(state.types, llvm::dbgs());
1641  llvm::dbgs() << "\n * Result: " << *resultOp << "\n";
1642  });
1643 }
1644 
1645 template <typename T>
1646 void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
1647  LLVM_DEBUG(llvm::dbgs() << "Executing CreateDynamic" << type << "Range:\n");
1648  unsigned memIndex = read();
1649  unsigned rangeIndex = read();
1650  SmallVector<T> values;
1651  readList(values);
1652 
1653  LLVM_DEBUG({
1654  llvm::dbgs() << "\n * " << type << "s: ";
1655  llvm::interleaveComma(values, llvm::dbgs());
1656  llvm::dbgs() << "\n";
1657  });
1658 
1659  assignRangeToMemory(values, memIndex, rangeIndex);
1660 }
1661 
1662 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1663  LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1664  Operation *op = read<Operation *>();
1665 
1666  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
1667  rewriter.eraseOp(op);
1668 }
1669 
1670 template <typename T, typename Range, PDLValue::Kind kind>
1671 void ByteCodeExecutor::executeExtract() {
1672  LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n");
1673  Range *range = read<Range *>();
1674  unsigned index = read<uint32_t>();
1675  unsigned memIndex = read();
1676 
1677  if (!range) {
1678  memory[memIndex] = nullptr;
1679  return;
1680  }
1681 
1682  T result = index < range->size() ? (*range)[index] : T();
1683  LLVM_DEBUG(llvm::dbgs() << " * " << kind << "s(" << range->size() << ")\n"
1684  << " * Index: " << index << "\n"
1685  << " * Result: " << result << "\n");
1686  storeToMemory(memIndex, result);
1687 }
1688 
1689 void ByteCodeExecutor::executeFinalize() {
1690  LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
1691 }
1692 
1693 void ByteCodeExecutor::executeForEach() {
1694  LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
1695  const ByteCodeField *prevCodeIt = getPrevCodeIt();
1696  unsigned rangeIndex = read();
1697  unsigned memIndex = read();
1698  const void *value = nullptr;
1699 
1700  switch (read<PDLValue::Kind>()) {
1702  unsigned &index = loopIndex[read()];
1703  ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
1704  assert(index <= array.size() && "iterated past the end");
1705  if (index < array.size()) {
1706  LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n");
1707  value = array[index];
1708  break;
1709  }
1710 
1711  LLVM_DEBUG(llvm::dbgs() << " * Done\n");
1712  index = 0;
1713  selectJump(size_t(0));
1714  return;
1715  }
1716  default:
1717  llvm_unreachable("unexpected `ForEach` value kind");
1718  }
1719 
1720  // Store the iterate value and the stack address.
1721  memory[memIndex] = value;
1722  pushCodeIt(prevCodeIt);
1723 
1724  // Skip over the successor (we will enter the body of the loop).
1725  read<ByteCodeAddr>();
1726 }
1727 
1728 void ByteCodeExecutor::executeGetAttribute() {
1729  LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1730  unsigned memIndex = read();
1731  Operation *op = read<Operation *>();
1732  StringAttr attrName = read<StringAttr>();
1733  Attribute attr = op->getAttr(attrName);
1734 
1735  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1736  << " * Attribute: " << attrName << "\n"
1737  << " * Result: " << attr << "\n");
1738  memory[memIndex] = attr.getAsOpaquePointer();
1739 }
1740 
1741 void ByteCodeExecutor::executeGetAttributeType() {
1742  LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1743  unsigned memIndex = read();
1744  Attribute attr = read<Attribute>();
1745  Type type;
1746  if (auto typedAttr = attr.dyn_cast<TypedAttr>())
1747  type = typedAttr.getType();
1748 
1749  LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n"
1750  << " * Result: " << type << "\n");
1751  memory[memIndex] = type.getAsOpaquePointer();
1752 }
1753 
1754 void ByteCodeExecutor::executeGetDefiningOp() {
1755  LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1756  unsigned memIndex = read();
1757  Operation *op = nullptr;
1758  if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1759  Value value = read<Value>();
1760  if (value)
1761  op = value.getDefiningOp();
1762  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
1763  } else {
1764  ValueRange *values = read<ValueRange *>();
1765  if (values && !values->empty()) {
1766  op = values->front().getDefiningOp();
1767  }
1768  LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n");
1769  }
1770 
1771  LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n");
1772  memory[memIndex] = op;
1773 }
1774 
1775 void ByteCodeExecutor::executeGetOperand(unsigned index) {
1776  Operation *op = read<Operation *>();
1777  unsigned memIndex = read();
1778  Value operand =
1779  index < op->getNumOperands() ? op->getOperand(index) : Value();
1780 
1781  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1782  << " * Index: " << index << "\n"
1783  << " * Result: " << operand << "\n");
1784  memory[memIndex] = operand.getAsOpaquePointer();
1785 }
1786 
1787 /// This function is the internal implementation of `GetResults` and
1788 /// `GetOperands` that provides support for extracting a value range from the
1789 /// given operation.
1790 template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1791 static void *
1792 executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1793  ByteCodeField rangeIndex, StringRef attrSizedSegments,
1794  MutableArrayRef<ValueRange> valueRangeMemory) {
1795  // Check for the sentinel index that signals that all values should be
1796  // returned.
1797  if (index == std::numeric_limits<uint32_t>::max()) {
1798  LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n");
1799  // `values` is already the full value range.
1800 
1801  // Otherwise, check to see if this operation uses AttrSizedSegments.
1802  } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1803  LLVM_DEBUG(llvm::dbgs()
1804  << " * Extracting values from `" << attrSizedSegments << "`\n");
1805 
1806  auto segmentAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrSizedSegments);
1807  if (!segmentAttr || segmentAttr.asArrayRef().size() <= index)
1808  return nullptr;
1809 
1810  ArrayRef<int32_t> segments = segmentAttr;
1811  unsigned startIndex =
1812  std::accumulate(segments.begin(), segments.begin() + index, 0);
1813  values = values.slice(startIndex, *std::next(segments.begin(), index));
1814 
1815  LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", "
1816  << *std::next(segments.begin(), index) << "]\n");
1817 
1818  // Otherwise, assume this is the last operand group of the operation.
1819  // FIXME: We currently don't support operations with
1820  // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1821  // have a way to detect it's presence.
1822  } else if (values.size() >= index) {
1823  LLVM_DEBUG(llvm::dbgs()
1824  << " * Treating values as trailing variadic range\n");
1825  values = values.drop_front(index);
1826 
1827  // If we couldn't detect a way to compute the values, bail out.
1828  } else {
1829  return nullptr;
1830  }
1831 
1832  // If the range index is valid, we are returning a range.
1833  if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1834  valueRangeMemory[rangeIndex] = values;
1835  return &valueRangeMemory[rangeIndex];
1836  }
1837 
1838  // If a range index wasn't provided, the range is required to be non-variadic.
1839  return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1840 }
1841 
1842 void ByteCodeExecutor::executeGetOperands() {
1843  LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
1844  unsigned index = read<uint32_t>();
1845  Operation *op = read<Operation *>();
1846  ByteCodeField rangeIndex = read();
1847 
1848  void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1849  op->getOperands(), op, index, rangeIndex, "operand_segment_sizes",
1850  valueRangeMemory);
1851  if (!result)
1852  LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n");
1853  memory[read()] = result;
1854 }
1855 
1856 void ByteCodeExecutor::executeGetResult(unsigned index) {
1857  Operation *op = read<Operation *>();
1858  unsigned memIndex = read();
1859  OpResult result =
1860  index < op->getNumResults() ? op->getResult(index) : OpResult();
1861 
1862  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1863  << " * Index: " << index << "\n"
1864  << " * Result: " << result << "\n");
1865  memory[memIndex] = result.getAsOpaquePointer();
1866 }
1867 
1868 void ByteCodeExecutor::executeGetResults() {
1869  LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
1870  unsigned index = read<uint32_t>();
1871  Operation *op = read<Operation *>();
1872  ByteCodeField rangeIndex = read();
1873 
1874  void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1875  op->getResults(), op, index, rangeIndex, "result_segment_sizes",
1876  valueRangeMemory);
1877  if (!result)
1878  LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n");
1879  memory[read()] = result;
1880 }
1881 
1882 void ByteCodeExecutor::executeGetUsers() {
1883  LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n");
1884  unsigned memIndex = read();
1885  unsigned rangeIndex = read();
1886  OwningOpRange &range = opRangeMemory[rangeIndex];
1887  memory[memIndex] = &range;
1888 
1889  range = OwningOpRange();
1890  if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1891  // Read the value.
1892  Value value = read<Value>();
1893  if (!value)
1894  return;
1895  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
1896 
1897  // Extract the users of a single value.
1898  range = OwningOpRange(std::distance(value.user_begin(), value.user_end()));
1899  llvm::copy(value.getUsers(), range.begin());
1900  } else {
1901  // Read a range of values.
1902  ValueRange *values = read<ValueRange *>();
1903  if (!values)
1904  return;
1905  LLVM_DEBUG({
1906  llvm::dbgs() << " * Values (" << values->size() << "): ";
1907  llvm::interleaveComma(*values, llvm::dbgs());
1908  llvm::dbgs() << "\n";
1909  });
1910 
1911  // Extract all the users of a range of values.
1913  for (Value value : *values)
1914  users.append(value.user_begin(), value.user_end());
1915  range = OwningOpRange(users.size());
1916  llvm::copy(users, range.begin());
1917  }
1918 
1919  LLVM_DEBUG(llvm::dbgs() << " * Result: " << range.size() << " operations\n");
1920 }
1921 
1922 void ByteCodeExecutor::executeGetValueType() {
1923  LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1924  unsigned memIndex = read();
1925  Value value = read<Value>();
1926  Type type = value ? value.getType() : Type();
1927 
1928  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"
1929  << " * Result: " << type << "\n");
1930  memory[memIndex] = type.getAsOpaquePointer();
1931 }
1932 
1933 void ByteCodeExecutor::executeGetValueRangeTypes() {
1934  LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
1935  unsigned memIndex = read();
1936  unsigned rangeIndex = read();
1937  ValueRange *values = read<ValueRange *>();
1938  if (!values) {
1939  LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n");
1940  memory[memIndex] = nullptr;
1941  return;
1942  }
1943 
1944  LLVM_DEBUG({
1945  llvm::dbgs() << " * Values (" << values->size() << "): ";
1946  llvm::interleaveComma(*values, llvm::dbgs());
1947  llvm::dbgs() << "\n * Result: ";
1948  llvm::interleaveComma(values->getType(), llvm::dbgs());
1949  llvm::dbgs() << "\n";
1950  });
1951  typeRangeMemory[rangeIndex] = values->getType();
1952  memory[memIndex] = &typeRangeMemory[rangeIndex];
1953 }
1954 
1955 void ByteCodeExecutor::executeIsNotNull() {
1956  LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1957  const void *value = read<const void *>();
1958 
1959  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
1960  selectJump(value != nullptr);
1961 }
1962 
1963 void ByteCodeExecutor::executeRecordMatch(
1964  PatternRewriter &rewriter,
1966  LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
1967  unsigned patternIndex = read();
1968  PatternBenefit benefit = currentPatternBenefits[patternIndex];
1969  const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1970 
1971  // If the benefit of the pattern is impossible, skip the processing of the
1972  // rest of the pattern.
1973  if (benefit.isImpossibleToMatch()) {
1974  LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n");
1975  curCodeIt = dest;
1976  return;
1977  }
1978 
1979  // Create a fused location containing the locations of each of the
1980  // operations used in the match. This will be used as the location for
1981  // created operations during the rewrite that don't already have an
1982  // explicit location set.
1983  unsigned numMatchLocs = read();
1984  SmallVector<Location, 4> matchLocs;
1985  matchLocs.reserve(numMatchLocs);
1986  for (unsigned i = 0; i != numMatchLocs; ++i)
1987  matchLocs.push_back(read<Operation *>()->getLoc());
1988  Location matchLoc = rewriter.getFusedLoc(matchLocs);
1989 
1990  LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n"
1991  << " * Location: " << matchLoc << "\n");
1992  matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
1993  PDLByteCode::MatchResult &match = matches.back();
1994 
1995  // Record all of the inputs to the match. If any of the inputs are ranges, we
1996  // will also need to remap the range pointer to memory stored in the match
1997  // state.
1998  unsigned numInputs = read();
1999  match.values.reserve(numInputs);
2000  match.typeRangeValues.reserve(numInputs);
2001  match.valueRangeValues.reserve(numInputs);
2002  for (unsigned i = 0; i < numInputs; ++i) {
2003  switch (read<PDLValue::Kind>()) {
2005  match.typeRangeValues.push_back(*read<TypeRange *>());
2006  match.values.push_back(&match.typeRangeValues.back());
2007  break;
2009  match.valueRangeValues.push_back(*read<ValueRange *>());
2010  match.values.push_back(&match.valueRangeValues.back());
2011  break;
2012  default:
2013  match.values.push_back(read<const void *>());
2014  break;
2015  }
2016  }
2017  curCodeIt = dest;
2018 }
2019 
2020 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
2021  LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
2022  Operation *op = read<Operation *>();
2024  readList(args);
2025 
2026  LLVM_DEBUG({
2027  llvm::dbgs() << " * Operation: " << *op << "\n"
2028  << " * Values: ";
2029  llvm::interleaveComma(args, llvm::dbgs());
2030  llvm::dbgs() << "\n";
2031  });
2032  rewriter.replaceOp(op, args);
2033 }
2034 
2035 void ByteCodeExecutor::executeSwitchAttribute() {
2036  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
2037  Attribute value = read<Attribute>();
2038  ArrayAttr cases = read<ArrayAttr>();
2039  handleSwitch(value, cases);
2040 }
2041 
2042 void ByteCodeExecutor::executeSwitchOperandCount() {
2043  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
2044  Operation *op = read<Operation *>();
2045  auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2046 
2047  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
2048  handleSwitch(op->getNumOperands(), cases);
2049 }
2050 
2051 void ByteCodeExecutor::executeSwitchOperationName() {
2052  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
2053  OperationName value = read<Operation *>()->getName();
2054  size_t caseCount = read();
2055 
2056  // The operation names are stored in-line, so to print them out for
2057  // debugging purposes we need to read the array before executing the
2058  // switch so that we can display all of the possible values.
2059  LLVM_DEBUG({
2060  const ByteCodeField *prevCodeIt = curCodeIt;
2061  llvm::dbgs() << " * Value: " << value << "\n"
2062  << " * Cases: ";
2063  llvm::interleaveComma(
2064  llvm::map_range(llvm::seq<size_t>(0, caseCount),
2065  [&](size_t) { return read<OperationName>(); }),
2066  llvm::dbgs());
2067  llvm::dbgs() << "\n";
2068  curCodeIt = prevCodeIt;
2069  });
2070 
2071  // Try to find the switch value within any of the cases.
2072  for (size_t i = 0; i != caseCount; ++i) {
2073  if (read<OperationName>() == value) {
2074  curCodeIt += (caseCount - i - 1);
2075  return selectJump(i + 1);
2076  }
2077  }
2078  selectJump(size_t(0));
2079 }
2080 
2081 void ByteCodeExecutor::executeSwitchResultCount() {
2082  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
2083  Operation *op = read<Operation *>();
2084  auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2085 
2086  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
2087  handleSwitch(op->getNumResults(), cases);
2088 }
2089 
2090 void ByteCodeExecutor::executeSwitchType() {
2091  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
2092  Type value = read<Type>();
2093  auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2094  handleSwitch(value, cases);
2095 }
2096 
2097 void ByteCodeExecutor::executeSwitchTypes() {
2098  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
2099  TypeRange *value = read<TypeRange *>();
2100  auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2101  if (!value) {
2102  LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
2103  return selectJump(size_t(0));
2104  }
2105  handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
2106  return value == caseValue.getAsValueRange<TypeAttr>();
2107  });
2108 }
2109 
2111 ByteCodeExecutor::execute(PatternRewriter &rewriter,
2113  std::optional<Location> mainRewriteLoc) {
2114  while (true) {
2115  // Print the location of the operation being executed.
2116  LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n");
2117 
2118  OpCode opCode = static_cast<OpCode>(read());
2119  switch (opCode) {
2120  case ApplyConstraint:
2121  executeApplyConstraint(rewriter);
2122  break;
2123  case ApplyRewrite:
2124  if (failed(executeApplyRewrite(rewriter)))
2125  return failure();
2126  break;
2127  case AreEqual:
2128  executeAreEqual();
2129  break;
2130  case AreRangesEqual:
2131  executeAreRangesEqual();
2132  break;
2133  case Branch:
2134  executeBranch();
2135  break;
2136  case CheckOperandCount:
2137  executeCheckOperandCount();
2138  break;
2139  case CheckOperationName:
2140  executeCheckOperationName();
2141  break;
2142  case CheckResultCount:
2143  executeCheckResultCount();
2144  break;
2145  case CheckTypes:
2146  executeCheckTypes();
2147  break;
2148  case Continue:
2149  executeContinue();
2150  break;
2151  case CreateConstantTypeRange:
2152  executeCreateConstantTypeRange();
2153  break;
2154  case CreateOperation:
2155  executeCreateOperation(rewriter, *mainRewriteLoc);
2156  break;
2157  case CreateDynamicTypeRange:
2158  executeDynamicCreateRange<Type>("Type");
2159  break;
2160  case CreateDynamicValueRange:
2161  executeDynamicCreateRange<Value>("Value");
2162  break;
2163  case EraseOp:
2164  executeEraseOp(rewriter);
2165  break;
2166  case ExtractOp:
2167  executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
2168  break;
2169  case ExtractType:
2170  executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2171  break;
2172  case ExtractValue:
2173  executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2174  break;
2175  case Finalize:
2176  executeFinalize();
2177  LLVM_DEBUG(llvm::dbgs() << "\n");
2178  return success();
2179  case ForEach:
2180  executeForEach();
2181  break;
2182  case GetAttribute:
2183  executeGetAttribute();
2184  break;
2185  case GetAttributeType:
2186  executeGetAttributeType();
2187  break;
2188  case GetDefiningOp:
2189  executeGetDefiningOp();
2190  break;
2191  case GetOperand0:
2192  case GetOperand1:
2193  case GetOperand2:
2194  case GetOperand3: {
2195  unsigned index = opCode - GetOperand0;
2196  LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
2197  executeGetOperand(index);
2198  break;
2199  }
2200  case GetOperandN:
2201  LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
2202  executeGetOperand(read<uint32_t>());
2203  break;
2204  case GetOperands:
2205  executeGetOperands();
2206  break;
2207  case GetResult0:
2208  case GetResult1:
2209  case GetResult2:
2210  case GetResult3: {
2211  unsigned index = opCode - GetResult0;
2212  LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
2213  executeGetResult(index);
2214  break;
2215  }
2216  case GetResultN:
2217  LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
2218  executeGetResult(read<uint32_t>());
2219  break;
2220  case GetResults:
2221  executeGetResults();
2222  break;
2223  case GetUsers:
2224  executeGetUsers();
2225  break;
2226  case GetValueType:
2227  executeGetValueType();
2228  break;
2229  case GetValueRangeTypes:
2230  executeGetValueRangeTypes();
2231  break;
2232  case IsNotNull:
2233  executeIsNotNull();
2234  break;
2235  case RecordMatch:
2236  assert(matches &&
2237  "expected matches to be provided when executing the matcher");
2238  executeRecordMatch(rewriter, *matches);
2239  break;
2240  case ReplaceOp:
2241  executeReplaceOp(rewriter);
2242  break;
2243  case SwitchAttribute:
2244  executeSwitchAttribute();
2245  break;
2246  case SwitchOperandCount:
2247  executeSwitchOperandCount();
2248  break;
2249  case SwitchOperationName:
2250  executeSwitchOperationName();
2251  break;
2252  case SwitchResultCount:
2253  executeSwitchResultCount();
2254  break;
2255  case SwitchType:
2256  executeSwitchType();
2257  break;
2258  case SwitchTypes:
2259  executeSwitchTypes();
2260  break;
2261  }
2262  LLVM_DEBUG(llvm::dbgs() << "\n");
2263  }
2264 }
2265 
2268  PDLByteCodeMutableState &state) const {
2269  // The first memory slot is always the root operation.
2270  state.memory[0] = op;
2271 
2272  // The matcher function always starts at code address 0.
2273  ByteCodeExecutor executor(
2274  matcherByteCode.data(), state.memory, state.opRangeMemory,
2275  state.typeRangeMemory, state.allocatedTypeRangeMemory,
2276  state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2277  uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
2278  constraintFunctions, rewriteFunctions);
2279  LogicalResult executeResult = executor.execute(rewriter, &matches);
2280  (void)executeResult;
2281  assert(succeeded(executeResult) && "unexpected matcher execution failure");
2282 
2283  // Order the found matches by benefit.
2284  std::stable_sort(matches.begin(), matches.end(),
2285  [](const MatchResult &lhs, const MatchResult &rhs) {
2286  return lhs.benefit > rhs.benefit;
2287  });
2288 }
2289 
2291  const MatchResult &match,
2292  PDLByteCodeMutableState &state) const {
2293  auto *configSet = match.pattern->getConfigSet();
2294  if (configSet)
2295  configSet->notifyRewriteBegin(rewriter);
2296 
2297  // The arguments of the rewrite function are stored at the start of the
2298  // memory buffer.
2299  llvm::copy(match.values, state.memory.begin());
2300 
2301  ByteCodeExecutor executor(
2302  &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
2303  state.opRangeMemory, state.typeRangeMemory,
2304  state.allocatedTypeRangeMemory, state.valueRangeMemory,
2305  state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2306  rewriterByteCode, state.currentPatternBenefits, patterns,
2307  constraintFunctions, rewriteFunctions);
2308  LogicalResult result =
2309  executor.execute(rewriter, /*matches=*/nullptr, match.location);
2310 
2311  if (configSet)
2312  configSet->notifyRewriteEnd(rewriter);
2313 
2314  // If the rewrite failed, check if the pattern rewriter can recover. If it
2315  // can, we can signal to the pattern applicator to keep trying patterns. If it
2316  // doesn't, we need to bail. Bailing here should be fine, given that we have
2317  // no means to propagate such a failure to the user, and it also indicates a
2318  // bug in the user code (i.e. failable rewrites should not be used with
2319  // pattern rewriters that don't support it).
2320  if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) {
2321  LLVM_DEBUG(llvm::dbgs() << " and rollback is not supported - aborting");
2322  llvm::report_fatal_error(
2323  "Native PDL Rewrite failed, but the pattern "
2324  "rewriter doesn't support recovery. Failable pattern rewrites should "
2325  "not be used with pattern rewriters that do not support them.");
2326  }
2327  return result;
2328 }
static void * executeGetOperandsResults(RangeT values, Operation *op, unsigned index, ByteCodeField rangeIndex, StringRef attrSizedSegments, MutableArrayRef< ValueRange > valueRangeMemory)
This function is the internal implementation of GetResults and GetOperands that provides support for ...
Definition: ByteCode.cpp:1792
static constexpr ByteCodeField kInferTypesMarker
A marker used to indicate if an operation should infer types.
Definition: ByteCode.cpp:173
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static const mlir::GenInfo * generator
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void processValue(Value value, LiveMap &liveMap)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
U dyn_cast() const
Definition: Attributes.h:166
U cast() const
Definition: Attributes.h:176
const void * getAsOpaquePointer() const
Get an opaque pointer to the attribute.
Definition: Attributes.h:82
This class represents an argument of a Block.
Definition: Value.h:304
Block represents an ordered list of Operations.
Definition: Block.h:30
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
BlockArgListType getArguments()
Definition: Block.h:76
Operation & front()
Definition: Block.h:142
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition: Block.cpp:35
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
Definition: Builders.cpp:28
This class represents liveness information on block level.
Definition: Liveness.h:99
const ValueSetT & in() const
Returns all values that are live at the beginning of the block (unordered).
Definition: Liveness.h:110
Operation * getEndOperation(Value value, Operation *startOperation) const
Gets the end operation for the given value using the start operation provided (must be referenced in ...
Definition: Liveness.cpp:375
Represents an analysis for computing liveness information from a given top-level operation.
Definition: Liveness.h:47
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:432
This is a value defined by a result of an operation.
Definition: Value.h:442
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:75
Value getOperand(unsigned idx)
Definition: Operation.h:329
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:592
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:437
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:433
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:386
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:207
unsigned getNumOperands()
Definition: Operation.h:325
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:540
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:103
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:357
result_range getResults()
Definition: Operation.h:394
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:383
This class contains a set of configurations for a specific pattern.
Definition: PatternMatch.h:920
The class represents a list of PDL results, returned by a native rewrite method.
Definition: PatternMatch.h:800
Storage type of byte-code interpreter values.
Definition: PatternMatch.h:688
Kind
The underlying kind of a PDL value.
Definition: PatternMatch.h:691
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
bool isImpossibleToMatch() const
Definition: PatternMatch.h:43
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the corresponding pattern isImpossibleToMatch() then this aborts.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:668
virtual bool canRecoverFromRewriteFailure() const
A hook used to indicate if the pattern rewriter can recover from failure during the rewrite stage of a pattern.
Definition: PatternMatch.h:676
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:231
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class implements the successor iterators for Block.
Definition: BlockSupport.h:73
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
const void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Types.h:175
U dyn_cast() const
Definition: Types.h:311
bool isa() const
Definition: Types.h:301
auto walk(WalkFns &&...walkFns)
Walk this type and all attributes/types nested within using the provided walk functions.
Definition: Types.h:217
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:370
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
user_iterator user_begin() const
Definition: Value.h:215
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Value.h:231
user_iterator user_end() const
Definition: Value.h:216
user_range getUsers() const
Definition: Value.h:217
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
This class contains the mutable state of a bytecode instance.
Definition: ByteCode.h:71
void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit)
Set the new benefit for a bytecode pattern.
Definition: ByteCode.cpp:64
void cleanupAfterMatchAndRewrite()
Cleanup any allocated state after a match/rewrite has been completed.
Definition: ByteCode.cpp:72
All of the data pertaining to a specific pattern within the bytecode.
Definition: ByteCode.h:38
static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp, PDLPatternConfigSet *configSet, ByteCodeAddr rewriterAddr)
Definition: ByteCode.cpp:37
void match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl< MatchResult > &matches, PDLByteCodeMutableState &state) const
Run the pattern matcher on the given root operation, collecting the matched patterns in matches.
Definition: ByteCode.cpp:2266
PDLByteCode(ModuleOp module, SmallVector< std::unique_ptr< PDLPatternConfigSet >> configs, const DenseMap< Operation *, PDLPatternConfigSet * > &configMap, llvm::StringMap< PDLConstraintFunction > constraintFns, llvm::StringMap< PDLRewriteFunction > rewriteFns)
Create a ByteCode instance from the given module containing operations in the PDL interpreter dialect.
Definition: ByteCode.cpp:1042
void initializeMutableState(PDLByteCodeMutableState &state) const
Initialize the given state such that it can be used to execute the current bytecode.
Definition: ByteCode.cpp:1063
LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const
Run the rewriter of the given pattern that was previously matched in match.
Definition: ByteCode.cpp:2290
Detect if any of the given parameter types has a sub-element handler.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:137
uint32_t ByteCodeAddr
Definition: ByteCode.h:30
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:223
uint16_t ByteCodeField
Use generic bytecode types.
Definition: ByteCode.h:29
llvm::OwningArrayRef< Operation * > OwningOpRange
Definition: ByteCode.h:31
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
std::function< LogicalResult(PatternRewriter &, ArrayRef< PDLValue >)> PDLConstraintFunction
A generic PDL pattern constraint function.
Definition: PatternMatch.h:982
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
std::function< LogicalResult(PatternRewriter &, PDLResultList &, ArrayRef< PDLValue >)> PDLRewriteFunction
A native PDL rewrite function.
Definition: PatternMatch.h:990
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This represents an operation in an abstracted form, suitable for use with the builder APIs.
This class acts as a special tag that makes the desire to match "any" operation type explicit.
Definition: PatternMatch.h:158
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or static.
OpFoldResult size
Each successful match returns a MatchResult, which contains information necessary to execute the rewr...
Definition: ByteCode.h:132
SmallVector< TypeRange, 0 > typeRangeValues
Memory used for the range input values.
Definition: ByteCode.h:146
SmallVector< ValueRange, 0 > valueRangeValues
Definition: ByteCode.h:147
SmallVector< const void * > values
Memory values defined in the matcher that are passed to the rewriter.
Definition: ByteCode.h:144