MLIR  18.0.0git
ByteCode.cpp
Go to the documentation of this file.
1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ByteCode.h"
14 #include "mlir/Analysis/Liveness.h"
17 #include "mlir/IR/BuiltinOps.h"
19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/Format.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include <numeric>
26 #include <optional>
27 
28 #define DEBUG_TYPE "pdl-bytecode"
29 
30 using namespace mlir;
31 using namespace mlir::detail;
32 
33 //===----------------------------------------------------------------------===//
34 // PDLByteCodePattern
35 //===----------------------------------------------------------------------===//
36 
37 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
38  PDLPatternConfigSet *configSet,
39  ByteCodeAddr rewriterAddr) {
40  PatternBenefit benefit = matchOp.getBenefit();
41  MLIRContext *ctx = matchOp.getContext();
42 
43  // Collect the set of generated operations.
44  SmallVector<StringRef, 8> generatedOps;
45  if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
46  generatedOps =
47  llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
48 
49  // Check to see if this is pattern matches a specific operation type.
50  if (std::optional<StringRef> rootKind = matchOp.getRootKind())
51  return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx,
52  generatedOps);
53  return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(),
54  benefit, ctx, generatedOps);
55 }
56 
57 //===----------------------------------------------------------------------===//
58 // PDLByteCodeMutableState
59 //===----------------------------------------------------------------------===//
60 
61 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
62 /// to the position of the pattern within the range returned by
63 /// `PDLByteCode::getPatterns`.
65  PatternBenefit benefit) {
66  currentPatternBenefits[patternIndex] = benefit;
67 }
68 
69 /// Cleanup any allocated state after a full match/rewrite has been completed.
70 /// This method should be called irregardless of whether the match+rewrite was a
71 /// success or not.
73  allocatedTypeRangeMemory.clear();
74  allocatedValueRangeMemory.clear();
75 }
76 
77 //===----------------------------------------------------------------------===//
78 // Bytecode OpCodes
79 //===----------------------------------------------------------------------===//
80 
namespace {
/// Opcodes of the PDL bytecode. The generator writes one of these values into
/// the bytecode stream (via `ByteCodeWriter::append(OpCode)`) for each
/// instruction, followed by instruction-specific operand fields. Enumerators use
/// their implicit declaration-order values, so inserting or reordering entries
/// shifts the numeric encoding of every later opcode.
enum OpCode : ByteCodeField {
  /// Apply an externally registered constraint.
  ApplyConstraint,
  /// Apply an externally registered rewrite.
  ApplyRewrite,
  /// Check if two generic values are equal. This opcode is also used to encode
  /// `pdl_interp.check_attribute`, which has no dedicated opcode.
  AreEqual,
  /// Check if two ranges are equal.
  AreRangesEqual,
  /// Unconditional branch.
  Branch,
  /// Compare the operand count of an operation with a constant.
  CheckOperandCount,
  /// Compare the name of an operation with a constant.
  CheckOperationName,
  /// Compare the result count of an operation with a constant.
  CheckResultCount,
  /// Compare a range of types to a constant range of types.
  CheckTypes,
  /// Continue to the next iteration of a loop.
  Continue,
  /// Create a type range from a list of constant types.
  CreateConstantTypeRange,
  /// Create an operation.
  CreateOperation,
  /// Create a type range from a list of dynamic types.
  CreateDynamicTypeRange,
  /// Create a value range.
  CreateDynamicValueRange,
  /// Erase an operation.
  EraseOp,
  /// Extract the op from a range at the specified index.
  ExtractOp,
  /// Extract the type from a range at the specified index.
  ExtractType,
  /// Extract the value from a range at the specified index.
  ExtractValue,
  /// Terminate a matcher or rewrite sequence.
  Finalize,
  /// Iterate over a range of values.
  ForEach,
  /// Get a specific attribute of an operation.
  GetAttribute,
  /// Get the type of an attribute.
  GetAttributeType,
  /// Get the defining operation of a value.
  GetDefiningOp,
  /// Get a specific operand of an operation. `GetOperand0`..`GetOperand3`
  /// presumably hard-code the operand index while `GetOperandN` carries it as
  /// an explicit field; the encoding is not visible here — TODO confirm.
  GetOperand0,
  GetOperand1,
  GetOperand2,
  GetOperand3,
  GetOperandN,
  /// Get a specific operand group of an operation.
  GetOperands,
  /// Get a specific result of an operation. As with the operand opcodes,
  /// `GetResult0`..`GetResult3` presumably fix the index and `GetResultN`
  /// encodes it explicitly — TODO confirm.
  GetResult0,
  GetResult1,
  GetResult2,
  GetResult3,
  GetResultN,
  /// Get a specific result group of an operation.
  GetResults,
  /// Get the users of a value or a range of values.
  GetUsers,
  /// Get the type of a value.
  GetValueType,
  /// Get the types of a value range.
  GetValueRangeTypes,
  /// Check if a generic value is not null.
  IsNotNull,
  /// Record a successful pattern match.
  RecordMatch,
  /// Replace an operation.
  ReplaceOp,
  /// Compare an attribute with a set of constants.
  SwitchAttribute,
  /// Compare the operand count of an operation with a set of constants.
  SwitchOperandCount,
  /// Compare the name of an operation with a set of constants.
  SwitchOperationName,
  /// Compare the result count of an operation with a set of constants.
  SwitchResultCount,
  /// Compare a type with a set of constants.
  SwitchType,
  /// Compare a range of types with a set of constants.
  SwitchTypes,
};
} // namespace
171 
172 /// A marker used to indicate if an operation should infer types.
175 
176 //===----------------------------------------------------------------------===//
177 // ByteCode Generation
178 //===----------------------------------------------------------------------===//
179 
180 //===----------------------------------------------------------------------===//
181 // Generator
182 
183 namespace {
// Forward declarations for types defined later in this namespace.
struct ByteCodeWriter;
struct ByteCodeLiveRange;

/// Detection alias for "`T` exposes `getAsOpaquePointer()`": it names that
/// member's return type and is ill-formed (so usable with `llvm::is_detected`)
/// when `T` has no such member. Any trailing `Args` are accepted and ignored.
template <typename T, typename... Args>
using has_pointer_traits =
    decltype(std::declval<T>().getAsOpaquePointer());
190 
191 /// This class represents the main generator for the pattern bytecode.
192 class Generator {
193 public:
194  Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
195  SmallVectorImpl<ByteCodeField> &matcherByteCode,
196  SmallVectorImpl<ByteCodeField> &rewriterByteCode,
198  ByteCodeField &maxValueMemoryIndex,
199  ByteCodeField &maxOpRangeMemoryIndex,
200  ByteCodeField &maxTypeRangeMemoryIndex,
201  ByteCodeField &maxValueRangeMemoryIndex,
202  ByteCodeField &maxLoopLevel,
203  llvm::StringMap<PDLConstraintFunction> &constraintFns,
204  llvm::StringMap<PDLRewriteFunction> &rewriteFns,
206  : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
207  rewriterByteCode(rewriterByteCode), patterns(patterns),
208  maxValueMemoryIndex(maxValueMemoryIndex),
209  maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
210  maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
211  maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
212  maxLoopLevel(maxLoopLevel), configMap(configMap) {
213  for (const auto &it : llvm::enumerate(constraintFns))
214  constraintToMemIndex.try_emplace(it.value().first(), it.index());
215  for (const auto &it : llvm::enumerate(rewriteFns))
216  externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
217  }
218 
219  /// Generate the bytecode for the given PDL interpreter module.
220  void generate(ModuleOp module);
221 
222  /// Return the memory index to use for the given value.
223  ByteCodeField &getMemIndex(Value value) {
224  assert(valueToMemIndex.count(value) &&
225  "expected memory index to be assigned");
226  return valueToMemIndex[value];
227  }
228 
229  /// Return the range memory index used to store the given range value.
230  ByteCodeField &getRangeStorageIndex(Value value) {
231  assert(valueToRangeIndex.count(value) &&
232  "expected range index to be assigned");
233  return valueToRangeIndex[value];
234  }
235 
236  /// Return an index to use when referring to the given data that is uniqued in
237  /// the MLIR context.
238  template <typename T>
239  std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
240  getMemIndex(T val) {
241  const void *opaqueVal = val.getAsOpaquePointer();
242 
243  // Get or insert a reference to this value.
244  auto it = uniquedDataToMemIndex.try_emplace(
245  opaqueVal, maxValueMemoryIndex + uniquedData.size());
246  if (it.second)
247  uniquedData.push_back(opaqueVal);
248  return it.first->second;
249  }
250 
251 private:
252  /// Allocate memory indices for the results of operations within the matcher
253  /// and rewriters.
254  void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
255  ModuleOp rewriterModule);
256 
257  /// Generate the bytecode for the given operation.
258  void generate(Region *region, ByteCodeWriter &writer);
259  void generate(Operation *op, ByteCodeWriter &writer);
260  void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
261  void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
262  void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
263  void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
264  void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
265  void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
266  void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
267  void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
268  void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
269  void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
270  void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
271  void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
272  void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
273  void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer);
274  void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
275  void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
276  void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
277  void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
278  void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
279  void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
280  void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
281  void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
282  void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
283  void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
284  void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
285  void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
286  void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
287  void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
288  void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
289  void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
290  void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
291  void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
292  void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
293  void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
294  void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
295  void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
296  void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
297  void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
298 
299  /// Mapping from value to its corresponding memory index.
300  DenseMap<Value, ByteCodeField> valueToMemIndex;
301 
302  /// Mapping from a range value to its corresponding range storage index.
303  DenseMap<Value, ByteCodeField> valueToRangeIndex;
304 
305  /// Mapping from the name of an externally registered rewrite to its index in
306  /// the bytecode registry.
307  llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
308 
309  /// Mapping from the name of an externally registered constraint to its index
310  /// in the bytecode registry.
311  llvm::StringMap<ByteCodeField> constraintToMemIndex;
312 
313  /// Mapping from rewriter function name to the bytecode address of the
314  /// rewriter function in byte.
315  llvm::StringMap<ByteCodeAddr> rewriterToAddr;
316 
317  /// Mapping from a uniqued storage object to its memory index within
318  /// `uniquedData`.
319  DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
320 
321  /// The current level of the foreach loop.
322  ByteCodeField curLoopLevel = 0;
323 
324  /// The current MLIR context.
325  MLIRContext *ctx;
326 
327  /// Mapping from block to its address.
329 
330  /// Data of the ByteCode class to be populated.
331  std::vector<const void *> &uniquedData;
332  SmallVectorImpl<ByteCodeField> &matcherByteCode;
333  SmallVectorImpl<ByteCodeField> &rewriterByteCode;
335  ByteCodeField &maxValueMemoryIndex;
336  ByteCodeField &maxOpRangeMemoryIndex;
337  ByteCodeField &maxTypeRangeMemoryIndex;
338  ByteCodeField &maxValueRangeMemoryIndex;
339  ByteCodeField &maxLoopLevel;
340 
341  /// A map of pattern configurations.
343 };
344 
345 /// This class provides utilities for writing a bytecode stream.
346 struct ByteCodeWriter {
347  ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
348  : bytecode(bytecode), generator(generator) {}
349 
350  /// Append a field to the bytecode.
351  void append(ByteCodeField field) { bytecode.push_back(field); }
352  void append(OpCode opCode) { bytecode.push_back(opCode); }
353 
354  /// Append an address to the bytecode.
355  void append(ByteCodeAddr field) {
356  static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
357  "unexpected ByteCode address size");
358 
359  ByteCodeField fieldParts[2];
360  std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
361  bytecode.append({fieldParts[0], fieldParts[1]});
362  }
363 
364  /// Append a single successor to the bytecode, the exact address will need to
365  /// be resolved later.
366  void append(Block *successor) {
367  // Add back a reference to the successor so that the address can be resolved
368  // later.
369  unresolvedSuccessorRefs[successor].push_back(bytecode.size());
370  append(ByteCodeAddr(0));
371  }
372 
373  /// Append a successor range to the bytecode, the exact address will need to
374  /// be resolved later.
375  void append(SuccessorRange successors) {
376  for (Block *successor : successors)
377  append(successor);
378  }
379 
380  /// Append a range of values that will be read as generic PDLValues.
381  void appendPDLValueList(OperandRange values) {
382  bytecode.push_back(values.size());
383  for (Value value : values)
384  appendPDLValue(value);
385  }
386 
387  /// Append a value as a PDLValue.
388  void appendPDLValue(Value value) {
389  appendPDLValueKind(value);
390  append(value);
391  }
392 
393  /// Append the PDLValue::Kind of the given value.
394  void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); }
395 
396  /// Append the PDLValue::Kind of the given type.
397  void appendPDLValueKind(Type type) {
398  PDLValue::Kind kind =
400  .Case<pdl::AttributeType>(
401  [](Type) { return PDLValue::Kind::Attribute; })
402  .Case<pdl::OperationType>(
403  [](Type) { return PDLValue::Kind::Operation; })
404  .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
405  if (isa<pdl::TypeType>(rangeTy.getElementType()))
408  })
409  .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
410  .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
411  bytecode.push_back(static_cast<ByteCodeField>(kind));
412  }
413 
414  /// Append a value that will be stored in a memory slot and not inline within
415  /// the bytecode.
416  template <typename T>
417  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
418  std::is_pointer<T>::value>
419  append(T value) {
420  bytecode.push_back(generator.getMemIndex(value));
421  }
422 
423  /// Append a range of values.
424  template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
425  std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
426  append(T range) {
427  bytecode.push_back(llvm::size(range));
428  for (auto it : range)
429  append(it);
430  }
431 
432  /// Append a variadic number of fields to the bytecode.
433  template <typename FieldTy, typename Field2Ty, typename... FieldTys>
434  void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
435  append(field);
436  append(field2, fields...);
437  }
438 
439  /// Appends a value as a pointer, stored inline within the bytecode.
440  template <typename T>
441  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
442  appendInline(T value) {
443  constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField);
444  const void *pointer = value.getAsOpaquePointer();
445  ByteCodeField fieldParts[numParts];
446  std::memcpy(fieldParts, &pointer, sizeof(const void *));
447  bytecode.append(fieldParts, fieldParts + numParts);
448  }
449 
450  /// Successor references in the bytecode that have yet to be resolved.
451  DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
452 
453  /// The underlying bytecode buffer.
455 
456  /// The main generator producing PDL.
457  Generator &generator;
458 };
459 
460 /// This class represents a live range of PDL Interpreter values, containing
461 /// information about when values are live within a match/rewrite.
462 struct ByteCodeLiveRange {
463  using Set = llvm::IntervalMap<uint64_t, char, 16>;
464  using Allocator = Set::Allocator;
465 
466  ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
467 
468  /// Union this live range with the one provided.
469  void unionWith(const ByteCodeLiveRange &rhs) {
470  for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
471  ++it)
472  liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0);
473  }
474 
475  /// Returns true if this range overlaps with the one provided.
476  bool overlaps(const ByteCodeLiveRange &rhs) const {
477  return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
478  .valid();
479  }
480 
481  /// A map representing the ranges of the match/rewrite that a value is live in
482  /// the interpreter.
483  ///
484  /// We use std::unique_ptr here, because IntervalMap does not provide a
485  /// correct copy or move constructor. We can eliminate the pointer once
486  /// https://reviews.llvm.org/D113240 lands.
487  std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
488 
489  /// The operation range storage index for this range.
490  std::optional<unsigned> opRangeIndex;
491 
492  /// The type range storage index for this range.
493  std::optional<unsigned> typeRangeIndex;
494 
495  /// The value range storage index for this range.
496  std::optional<unsigned> valueRangeIndex;
497 };
498 } // namespace
499 
500 void Generator::generate(ModuleOp module) {
501  auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
502  pdl_interp::PDLInterpDialect::getMatcherFunctionName());
503  ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
504  pdl_interp::PDLInterpDialect::getRewriterModuleName());
505  assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
506 
507  // Allocate memory indices for the results of operations within the matcher
508  // and rewriters.
509  allocateMemoryIndices(matcherFunc, rewriterModule);
510 
511  // Generate code for the rewriter functions.
512  ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
513  for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
514  rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
515  for (Operation &op : rewriterFunc.getOps())
516  generate(&op, rewriterByteCodeWriter);
517  }
518  assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
519  "unexpected branches in rewriter function");
520 
521  // Generate code for the matcher function.
522  ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
523  generate(&matcherFunc.getBody(), matcherByteCodeWriter);
524 
525  // Resolve successor references in the matcher.
526  for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
527  ByteCodeAddr addr = blockToAddr[it.first];
528  for (unsigned offsetToFix : it.second)
529  std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
530  }
531 }
532 
533 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
534  ModuleOp rewriterModule) {
535  // Rewriters use simplistic allocation scheme that simply assigns an index to
536  // each result.
537  for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
538  ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
539  auto processRewriterValue = [&](Value val) {
540  valueToMemIndex.try_emplace(val, index++);
541  if (pdl::RangeType rangeType = dyn_cast<pdl::RangeType>(val.getType())) {
542  Type elementTy = rangeType.getElementType();
543  if (isa<pdl::TypeType>(elementTy))
544  valueToRangeIndex.try_emplace(val, typeRangeIndex++);
545  else if (isa<pdl::ValueType>(elementTy))
546  valueToRangeIndex.try_emplace(val, valueRangeIndex++);
547  }
548  };
549 
550  for (BlockArgument arg : rewriterFunc.getArguments())
551  processRewriterValue(arg);
552  rewriterFunc.getBody().walk([&](Operation *op) {
553  for (Value result : op->getResults())
554  processRewriterValue(result);
555  });
556  if (index > maxValueMemoryIndex)
557  maxValueMemoryIndex = index;
558  if (typeRangeIndex > maxTypeRangeMemoryIndex)
559  maxTypeRangeMemoryIndex = typeRangeIndex;
560  if (valueRangeIndex > maxValueRangeMemoryIndex)
561  maxValueRangeMemoryIndex = valueRangeIndex;
562  }
563 
564  // The matcher function uses a more sophisticated numbering that tries to
565  // minimize the number of memory indices assigned. This is done by determining
566  // a live range of the values within the matcher, then the allocation is just
567  // finding the minimal number of overlapping live ranges. This is essentially
568  // a simplified form of register allocation where we don't necessarily have a
569  // limited number of registers, but we still want to minimize the number used.
570  DenseMap<Operation *, unsigned> opToFirstIndex;
571  DenseMap<Operation *, unsigned> opToLastIndex;
572 
573  // A custom walk that marks the first and the last index of each operation.
574  // The entry marks the beginning of the liveness range for this operation,
575  // followed by nested operations, followed by the end of the liveness range.
576  unsigned index = 0;
577  llvm::unique_function<void(Operation *)> walk = [&](Operation *op) {
578  opToFirstIndex.try_emplace(op, index++);
579  for (Region &region : op->getRegions())
580  for (Block &block : region.getBlocks())
581  for (Operation &nested : block)
582  walk(&nested);
583  opToLastIndex.try_emplace(op, index++);
584  };
585  walk(matcherFunc);
586 
587  // Liveness info for each of the defs within the matcher.
588  ByteCodeLiveRange::Allocator allocator;
589  DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
590 
591  // Assign the root operation being matched to slot 0.
592  BlockArgument rootOpArg = matcherFunc.getArgument(0);
593  valueToMemIndex[rootOpArg] = 0;
594 
595  // Walk each of the blocks, computing the def interval that the value is used.
596  Liveness matcherLiveness(matcherFunc);
597  matcherFunc->walk([&](Block *block) {
598  const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
599  assert(info && "expected liveness info for block");
600  auto processValue = [&](Value value, Operation *firstUseOrDef) {
601  // We don't need to process the root op argument, this value is always
602  // assigned to the first memory slot.
603  if (value == rootOpArg)
604  return;
605 
606  // Set indices for the range of this block that the value is used.
607  auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
608  defRangeIt->second.liveness->insert(
609  opToFirstIndex[firstUseOrDef],
610  opToLastIndex[info->getEndOperation(value, firstUseOrDef)],
611  /*dummyValue*/ 0);
612 
613  // Check to see if this value is a range type.
614  if (auto rangeTy = dyn_cast<pdl::RangeType>(value.getType())) {
615  Type eleType = rangeTy.getElementType();
616  if (isa<pdl::OperationType>(eleType))
617  defRangeIt->second.opRangeIndex = 0;
618  else if (isa<pdl::TypeType>(eleType))
619  defRangeIt->second.typeRangeIndex = 0;
620  else if (isa<pdl::ValueType>(eleType))
621  defRangeIt->second.valueRangeIndex = 0;
622  }
623  };
624 
625  // Process the live-ins of this block.
626  for (Value liveIn : info->in()) {
627  // Only process the value if it has been defined in the current region.
628  // Other values that span across pdl_interp.foreach will be added higher
629  // up. This ensures that the we keep them alive for the entire duration
630  // of the loop.
631  if (liveIn.getParentRegion() == block->getParent())
632  processValue(liveIn, &block->front());
633  }
634 
635  // Process the block arguments for the entry block (those are not live-in).
636  if (block->isEntryBlock()) {
637  for (Value argument : block->getArguments())
638  processValue(argument, &block->front());
639  }
640 
641  // Process any new defs within this block.
642  for (Operation &op : *block)
643  for (Value result : op.getResults())
644  processValue(result, &op);
645  });
646 
647  // Greedily allocate memory slots using the computed def live ranges.
648  std::vector<ByteCodeLiveRange> allocatedIndices;
649 
650  // The number of memory indices currently allocated (and its next value).
651  // Recall that the root gets allocated memory index 0.
652  ByteCodeField numIndices = 1;
653 
654  // The number of memory ranges of various types (and their next values).
655  ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
656 
657  for (auto &defIt : valueDefRanges) {
658  ByteCodeField &memIndex = valueToMemIndex[defIt.first];
659  ByteCodeLiveRange &defRange = defIt.second;
660 
661  // Try to allocate to an existing index.
662  for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) {
663  ByteCodeLiveRange &existingRange = existingIndexIt.value();
664  if (!defRange.overlaps(existingRange)) {
665  existingRange.unionWith(defRange);
666  memIndex = existingIndexIt.index() + 1;
667 
668  if (defRange.opRangeIndex) {
669  if (!existingRange.opRangeIndex)
670  existingRange.opRangeIndex = numOpRanges++;
671  valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
672  } else if (defRange.typeRangeIndex) {
673  if (!existingRange.typeRangeIndex)
674  existingRange.typeRangeIndex = numTypeRanges++;
675  valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
676  } else if (defRange.valueRangeIndex) {
677  if (!existingRange.valueRangeIndex)
678  existingRange.valueRangeIndex = numValueRanges++;
679  valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
680  }
681  break;
682  }
683  }
684 
685  // If no existing index could be used, add a new one.
686  if (memIndex == 0) {
687  allocatedIndices.emplace_back(allocator);
688  ByteCodeLiveRange &newRange = allocatedIndices.back();
689  newRange.unionWith(defRange);
690 
691  // Allocate an index for op/type/value ranges.
692  if (defRange.opRangeIndex) {
693  newRange.opRangeIndex = numOpRanges;
694  valueToRangeIndex[defIt.first] = numOpRanges++;
695  } else if (defRange.typeRangeIndex) {
696  newRange.typeRangeIndex = numTypeRanges;
697  valueToRangeIndex[defIt.first] = numTypeRanges++;
698  } else if (defRange.valueRangeIndex) {
699  newRange.valueRangeIndex = numValueRanges;
700  valueToRangeIndex[defIt.first] = numValueRanges++;
701  }
702 
703  memIndex = allocatedIndices.size();
704  ++numIndices;
705  }
706  }
707 
708  // Print the index usage and ensure that we did not run out of index space.
709  LLVM_DEBUG({
710  llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices "
711  << "(down from initial " << valueDefRanges.size() << ").\n";
712  });
713  assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
714  "Ran out of memory for allocated indices");
715 
716  // Update the max number of indices.
717  if (numIndices > maxValueMemoryIndex)
718  maxValueMemoryIndex = numIndices;
719  if (numOpRanges > maxOpRangeMemoryIndex)
720  maxOpRangeMemoryIndex = numOpRanges;
721  if (numTypeRanges > maxTypeRangeMemoryIndex)
722  maxTypeRangeMemoryIndex = numTypeRanges;
723  if (numValueRanges > maxValueRangeMemoryIndex)
724  maxValueRangeMemoryIndex = numValueRanges;
725 }
726 
727 void Generator::generate(Region *region, ByteCodeWriter &writer) {
728  llvm::ReversePostOrderTraversal<Region *> rpot(region);
729  for (Block *block : rpot) {
730  // Keep track of where this block begins within the matcher function.
731  blockToAddr.try_emplace(block, matcherByteCode.size());
732  for (Operation &op : *block)
733  generate(&op, writer);
734  }
735 }
736 
737 void Generator::generate(Operation *op, ByteCodeWriter &writer) {
738  LLVM_DEBUG({
739  // The following list must contain all the operations that do not
740  // produce any bytecode.
741  if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
742  writer.appendInline(op->getLoc());
743  });
745  .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
746  pdl_interp::AreEqualOp, pdl_interp::BranchOp,
747  pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
748  pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
749  pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
750  pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
751  pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp,
752  pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
753  pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
754  pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
755  pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
756  pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
757  pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
758  pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
759  pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
760  pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
761  pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
762  pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
763  pdl_interp::SwitchResultCountOp>(
764  [&](auto interpOp) { this->generate(interpOp, writer); })
765  .Default([](Operation *) {
766  llvm_unreachable("unknown `pdl_interp` operation");
767  });
768 }
769 
770 void Generator::generate(pdl_interp::ApplyConstraintOp op,
771  ByteCodeWriter &writer) {
772  assert(constraintToMemIndex.count(op.getName()) &&
773  "expected index for constraint function");
774  writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
775  writer.appendPDLValueList(op.getArgs());
776  writer.append(ByteCodeField(op.getIsNegated()));
777  writer.append(op.getSuccessors());
778 }
779 void Generator::generate(pdl_interp::ApplyRewriteOp op,
780  ByteCodeWriter &writer) {
781  assert(externalRewriterToMemIndex.count(op.getName()) &&
782  "expected index for rewrite function");
783  writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
784  writer.appendPDLValueList(op.getArgs());
785 
786  ResultRange results = op.getResults();
787  writer.append(ByteCodeField(results.size()));
788  for (Value result : results) {
789  // In debug mode we also record the expected kind of the result, so that we
790  // can provide extra verification of the native rewrite function.
791 #ifndef NDEBUG
792  writer.appendPDLValueKind(result);
793 #endif
794 
795  // Range results also need to append the range storage index.
796  if (isa<pdl::RangeType>(result.getType()))
797  writer.append(getRangeStorageIndex(result));
798  writer.append(result);
799  }
800 }
801 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
802  Value lhs = op.getLhs();
803  if (isa<pdl::RangeType>(lhs.getType())) {
804  writer.append(OpCode::AreRangesEqual);
805  writer.appendPDLValueKind(lhs);
806  writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
807  return;
808  }
809 
810  writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
811 }
812 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
813  writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
814 }
815 void Generator::generate(pdl_interp::CheckAttributeOp op,
816  ByteCodeWriter &writer) {
817  writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
818  op.getSuccessors());
819 }
820 void Generator::generate(pdl_interp::CheckOperandCountOp op,
821  ByteCodeWriter &writer) {
822  writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
823  static_cast<ByteCodeField>(op.getCompareAtLeast()),
824  op.getSuccessors());
825 }
826 void Generator::generate(pdl_interp::CheckOperationNameOp op,
827  ByteCodeWriter &writer) {
828  writer.append(OpCode::CheckOperationName, op.getInputOp(),
829  OperationName(op.getName(), ctx), op.getSuccessors());
830 }
831 void Generator::generate(pdl_interp::CheckResultCountOp op,
832  ByteCodeWriter &writer) {
833  writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
834  static_cast<ByteCodeField>(op.getCompareAtLeast()),
835  op.getSuccessors());
836 }
837 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
838  writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
839  op.getSuccessors());
840 }
841 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
842  writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
843  op.getSuccessors());
844 }
845 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
846  assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
847  writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
848 }
849 void Generator::generate(pdl_interp::CreateAttributeOp op,
850  ByteCodeWriter &writer) {
851  // Simply repoint the memory index of the result to the constant.
852  getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
853 }
854 void Generator::generate(pdl_interp::CreateOperationOp op,
855  ByteCodeWriter &writer) {
856  writer.append(OpCode::CreateOperation, op.getResultOp(),
857  OperationName(op.getName(), ctx));
858  writer.appendPDLValueList(op.getInputOperands());
859 
860  // Add the attributes.
861  OperandRange attributes = op.getInputAttributes();
862  writer.append(static_cast<ByteCodeField>(attributes.size()));
863  for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
864  writer.append(std::get<0>(it), std::get<1>(it));
865 
866  // Add the result types. If the operation has inferred results, we use a
867  // marker "size" value. Otherwise, we add the list of explicit result types.
868  if (op.getInferredResultTypes())
869  writer.append(kInferTypesMarker);
870  else
871  writer.appendPDLValueList(op.getInputResultTypes());
872 }
873 void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
874  // Append the correct opcode for the range type.
875  TypeSwitch<Type>(op.getType().getElementType())
876  .Case(
877  [&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); })
878  .Case([&](pdl::ValueType) {
879  writer.append(OpCode::CreateDynamicValueRange);
880  });
881 
882  writer.append(op.getResult(), getRangeStorageIndex(op.getResult()));
883  writer.appendPDLValueList(op->getOperands());
884 }
885 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
886  // Simply repoint the memory index of the result to the constant.
887  getMemIndex(op.getResult()) = getMemIndex(op.getValue());
888 }
889 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
890  writer.append(OpCode::CreateConstantTypeRange, op.getResult(),
891  getRangeStorageIndex(op.getResult()), op.getValue());
892 }
893 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
894  writer.append(OpCode::EraseOp, op.getInputOp());
895 }
896 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
897  OpCode opCode =
899  .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
900  .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
901  .Case([](pdl::TypeType) { return OpCode::ExtractType; })
902  .Default([](Type) -> OpCode {
903  llvm_unreachable("unsupported element type");
904  });
905  writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
906 }
907 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
908  writer.append(OpCode::Finalize);
909 }
910 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
911  BlockArgument arg = op.getLoopVariable();
912  writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
913  writer.appendPDLValueKind(arg.getType());
914  writer.append(curLoopLevel, op.getSuccessor());
915  ++curLoopLevel;
916  if (curLoopLevel > maxLoopLevel)
917  maxLoopLevel = curLoopLevel;
918  generate(&op.getRegion(), writer);
919  --curLoopLevel;
920 }
921 void Generator::generate(pdl_interp::GetAttributeOp op,
922  ByteCodeWriter &writer) {
923  writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
924  op.getNameAttr());
925 }
926 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
927  ByteCodeWriter &writer) {
928  writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
929 }
930 void Generator::generate(pdl_interp::GetDefiningOpOp op,
931  ByteCodeWriter &writer) {
932  writer.append(OpCode::GetDefiningOp, op.getInputOp());
933  writer.appendPDLValue(op.getValue());
934 }
935 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
936  uint32_t index = op.getIndex();
937  if (index < 4)
938  writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
939  else
940  writer.append(OpCode::GetOperandN, index);
941  writer.append(op.getInputOp(), op.getValue());
942 }
943 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
944  Value result = op.getValue();
945  std::optional<uint32_t> index = op.getIndex();
946  writer.append(OpCode::GetOperands,
947  index.value_or(std::numeric_limits<uint32_t>::max()),
948  op.getInputOp());
949  if (isa<pdl::RangeType>(result.getType()))
950  writer.append(getRangeStorageIndex(result));
951  else
953  writer.append(result);
954 }
955 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
956  uint32_t index = op.getIndex();
957  if (index < 4)
958  writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
959  else
960  writer.append(OpCode::GetResultN, index);
961  writer.append(op.getInputOp(), op.getValue());
962 }
963 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
964  Value result = op.getValue();
965  std::optional<uint32_t> index = op.getIndex();
966  writer.append(OpCode::GetResults,
967  index.value_or(std::numeric_limits<uint32_t>::max()),
968  op.getInputOp());
969  if (isa<pdl::RangeType>(result.getType()))
970  writer.append(getRangeStorageIndex(result));
971  else
973  writer.append(result);
974 }
975 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
976  Value operations = op.getOperations();
977  ByteCodeField rangeIndex = getRangeStorageIndex(operations);
978  writer.append(OpCode::GetUsers, operations, rangeIndex);
979  writer.appendPDLValue(op.getValue());
980 }
981 void Generator::generate(pdl_interp::GetValueTypeOp op,
982  ByteCodeWriter &writer) {
983  if (isa<pdl::RangeType>(op.getType())) {
984  Value result = op.getResult();
985  writer.append(OpCode::GetValueRangeTypes, result,
986  getRangeStorageIndex(result), op.getValue());
987  } else {
988  writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
989  }
990 }
991 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
992  writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
993 }
994 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
995  ByteCodeField patternIndex = patterns.size();
996  patterns.emplace_back(PDLByteCodePattern::create(
997  op, configMap.lookup(op),
998  rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
999  writer.append(OpCode::RecordMatch, patternIndex,
1000  SuccessorRange(op.getOperation()), op.getMatchedOps());
1001  writer.appendPDLValueList(op.getInputs());
1002 }
1003 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
1004  writer.append(OpCode::ReplaceOp, op.getInputOp());
1005  writer.appendPDLValueList(op.getReplValues());
1006 }
1007 void Generator::generate(pdl_interp::SwitchAttributeOp op,
1008  ByteCodeWriter &writer) {
1009  writer.append(OpCode::SwitchAttribute, op.getAttribute(),
1010  op.getCaseValuesAttr(), op.getSuccessors());
1011 }
1012 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
1013  ByteCodeWriter &writer) {
1014  writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
1015  op.getCaseValuesAttr(), op.getSuccessors());
1016 }
1017 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
1018  ByteCodeWriter &writer) {
1019  auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
1020  return OperationName(cast<StringAttr>(attr).getValue(), ctx);
1021  });
1022  writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
1023  op.getSuccessors());
1024 }
1025 void Generator::generate(pdl_interp::SwitchResultCountOp op,
1026  ByteCodeWriter &writer) {
1027  writer.append(OpCode::SwitchResultCount, op.getInputOp(),
1028  op.getCaseValuesAttr(), op.getSuccessors());
1029 }
1030 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1031  writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1032  op.getSuccessors());
1033 }
1034 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1035  writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
1036  op.getSuccessors());
1037 }
1038 
1039 //===----------------------------------------------------------------------===//
1040 // PDLByteCode
1041 //===----------------------------------------------------------------------===//
1042 
    // NOTE(review): two lines appear to be missing from this listing: the
    // constructor's declarator (presumably `PDLByteCode::PDLByteCode(`) and
    // the `configMap` parameter, which is used below. Restore both from the
    // upstream source.
    ModuleOp module, SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
    llvm::StringMap<PDLConstraintFunction> constraintFns,
    llvm::StringMap<PDLRewriteFunction> rewriteFns)
    : configs(std::move(configs)) {
  // Lower the `pdl_interp` module to bytecode. The bytecode buffers, uniqued
  // data, pattern list and the max-size counters are handed to the generator
  // by reference (presumably as outputs it fills in -- TODO confirm).
  Generator generator(module.getContext(), uniquedData, matcherByteCode,
                      rewriterByteCode, patterns, maxValueMemoryIndex,
                      maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
                      maxLoopLevel, constraintFns, rewriteFns, configMap);
  generator.generate(module);

  // Initialize the external functions.
  // NOTE(review): the bytecode refers to these functions by table index, so
  // this StringMap iteration order must match the order in which the generator
  // assigned the indices. Confirm against the Generator implementation.
  for (auto &it : constraintFns)
    constraintFunctions.push_back(std::move(it.second));
  for (auto &it : rewriteFns)
    rewriteFunctions.push_back(std::move(it.second));
}
1061 
/// Initialize the given state such that it can be used to execute the current
/// bytecode.
///
/// Each memory pool is resized to the maximum recorded for this bytecode:
/// value slots, operation/type/value range slots and loop indices. The
/// per-pattern benefit table is then seeded with each pattern's benefit.
// NOTE(review): the function's declarator line appears to be missing from this
// listing. It is presumably
// `void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {`.
  state.memory.resize(maxValueMemoryIndex, nullptr);
  state.opRangeMemory.resize(maxOpRangeCount);
  state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
  state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
  state.loopIndex.resize(maxLoopLevel, 0);
  // One benefit entry per pattern, in the same order as `patterns`.
  state.currentPatternBenefits.reserve(patterns.size());
  for (const PDLByteCodePattern &pattern : patterns)
    state.currentPatternBenefits.push_back(pattern.getBenefit());
}
1074 
1075 //===----------------------------------------------------------------------===//
1076 // ByteCode Execution
1077 
1078 namespace {
1079 /// This class provides support for executing a bytecode stream.
1080 class ByteCodeExecutor {
1081 public:
1082  ByteCodeExecutor(
1083  const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
1084  MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory,
1085  MutableArrayRef<TypeRange> typeRangeMemory,
1086  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
1087  MutableArrayRef<ValueRange> valueRangeMemory,
1088  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
1089  MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory,
1091  ArrayRef<PatternBenefit> currentPatternBenefits,
1093  ArrayRef<PDLConstraintFunction> constraintFunctions,
1094  ArrayRef<PDLRewriteFunction> rewriteFunctions)
1095  : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
1096  typeRangeMemory(typeRangeMemory),
1097  allocatedTypeRangeMemory(allocatedTypeRangeMemory),
1098  valueRangeMemory(valueRangeMemory),
1099  allocatedValueRangeMemory(allocatedValueRangeMemory),
1100  loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
1101  currentPatternBenefits(currentPatternBenefits), patterns(patterns),
1102  constraintFunctions(constraintFunctions),
1103  rewriteFunctions(rewriteFunctions) {}
1104 
1105  /// Start executing the code at the current bytecode index. `matches` is an
1106  /// optional field provided when this function is executed in a matching
1107  /// context.
1109  execute(PatternRewriter &rewriter,
1110  SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
1111  std::optional<Location> mainRewriteLoc = {});
1112 
1113 private:
1114  /// Internal implementation of executing each of the bytecode commands.
1115  void executeApplyConstraint(PatternRewriter &rewriter);
1116  LogicalResult executeApplyRewrite(PatternRewriter &rewriter);
1117  void executeAreEqual();
1118  void executeAreRangesEqual();
1119  void executeBranch();
1120  void executeCheckOperandCount();
1121  void executeCheckOperationName();
1122  void executeCheckResultCount();
1123  void executeCheckTypes();
1124  void executeContinue();
1125  void executeCreateConstantTypeRange();
1126  void executeCreateOperation(PatternRewriter &rewriter,
1127  Location mainRewriteLoc);
1128  template <typename T>
1129  void executeDynamicCreateRange(StringRef type);
1130  void executeEraseOp(PatternRewriter &rewriter);
1131  template <typename T, typename Range, PDLValue::Kind kind>
1132  void executeExtract();
1133  void executeFinalize();
1134  void executeForEach();
1135  void executeGetAttribute();
1136  void executeGetAttributeType();
1137  void executeGetDefiningOp();
1138  void executeGetOperand(unsigned index);
1139  void executeGetOperands();
1140  void executeGetResult(unsigned index);
1141  void executeGetResults();
1142  void executeGetUsers();
1143  void executeGetValueType();
1144  void executeGetValueRangeTypes();
1145  void executeIsNotNull();
1146  void executeRecordMatch(PatternRewriter &rewriter,
1148  void executeReplaceOp(PatternRewriter &rewriter);
1149  void executeSwitchAttribute();
1150  void executeSwitchOperandCount();
1151  void executeSwitchOperationName();
1152  void executeSwitchResultCount();
1153  void executeSwitchType();
1154  void executeSwitchTypes();
1155 
1156  /// Pushes a code iterator to the stack.
1157  void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }
1158 
1159  /// Pops a code iterator from the stack, returning true on success.
1160  void popCodeIt() {
1161  assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
1162  curCodeIt = resumeCodeIt.back();
1163  resumeCodeIt.pop_back();
1164  }
1165 
1166  /// Return the bytecode iterator at the start of the current op code.
1167  const ByteCodeField *getPrevCodeIt() const {
1168  LLVM_DEBUG({
1169  // Account for the op code and the Location stored inline.
1170  return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField);
1171  });
1172 
1173  // Account for the op code only.
1174  return curCodeIt - 1;
1175  }
1176 
1177  /// Read a value from the bytecode buffer, optionally skipping a certain
1178  /// number of prefix values. These methods always update the buffer to point
1179  /// to the next field after the read data.
1180  template <typename T = ByteCodeField>
1181  T read(size_t skipN = 0) {
1182  curCodeIt += skipN;
1183  return readImpl<T>();
1184  }
1185  ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
1186 
1187  /// Read a list of values from the bytecode buffer.
1188  template <typename ValueT, typename T>
1189  void readList(SmallVectorImpl<T> &list) {
1190  list.clear();
1191  for (unsigned i = 0, e = read(); i != e; ++i)
1192  list.push_back(read<ValueT>());
1193  }
1194 
1195  /// Read a list of values from the bytecode buffer. The values may be encoded
1196  /// either as a single element or a range of elements.
1197  void readList(SmallVectorImpl<Type> &list) {
1198  for (unsigned i = 0, e = read(); i != e; ++i) {
1199  if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1200  list.push_back(read<Type>());
1201  } else {
1202  TypeRange *values = read<TypeRange *>();
1203  list.append(values->begin(), values->end());
1204  }
1205  }
1206  }
1207  void readList(SmallVectorImpl<Value> &list) {
1208  for (unsigned i = 0, e = read(); i != e; ++i) {
1209  if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1210  list.push_back(read<Value>());
1211  } else {
1212  ValueRange *values = read<ValueRange *>();
1213  list.append(values->begin(), values->end());
1214  }
1215  }
1216  }
1217 
1218  /// Read a value stored inline as a pointer.
1219  template <typename T>
1220  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
1221  readInline() {
1222  const void *pointer;
1223  std::memcpy(&pointer, curCodeIt, sizeof(const void *));
1224  curCodeIt += sizeof(const void *) / sizeof(ByteCodeField);
1225  return T::getFromOpaquePointer(pointer);
1226  }
1227 
1228  /// Jump to a specific successor based on a predicate value.
1229  void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
1230  /// Jump to a specific successor based on a destination index.
1231  void selectJump(size_t destIndex) {
1232  curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
1233  }
1234 
1235  /// Handle a switch operation with the provided value and cases.
1236  template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
1237  void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
1238  LLVM_DEBUG({
1239  llvm::dbgs() << " * Value: " << value << "\n"
1240  << " * Cases: ";
1241  llvm::interleaveComma(cases, llvm::dbgs());
1242  llvm::dbgs() << "\n";
1243  });
1244 
1245  // Check to see if the attribute value is within the case list. Jump to
1246  // the correct successor index based on the result.
1247  for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
1248  if (cmp(*it, value))
1249  return selectJump(size_t((it - cases.begin()) + 1));
1250  selectJump(size_t(0));
1251  }
1252 
1253  /// Store a pointer to memory.
1254  void storeToMemory(unsigned index, const void *value) {
1255  memory[index] = value;
1256  }
1257 
1258  /// Store a value to memory as an opaque pointer.
1259  template <typename T>
1260  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
1261  storeToMemory(unsigned index, T value) {
1262  memory[index] = value.getAsOpaquePointer();
1263  }
1264 
1265  /// Internal implementation of reading various data types from the bytecode
1266  /// stream.
1267  template <typename T>
1268  const void *readFromMemory() {
1269  size_t index = *curCodeIt++;
1270 
1271  // If this type is an SSA value, it can only be stored in non-const memory.
1272  if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
1273  Value>::value ||
1274  index < memory.size())
1275  return memory[index];
1276 
1277  // Otherwise, if this index is not inbounds it is uniqued.
1278  return uniquedMemory[index - memory.size()];
1279  }
1280  template <typename T>
1281  std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
1282  return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
1283  }
1284  template <typename T>
1285  std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
1286  T>
1287  readImpl() {
1288  return T(T::getFromOpaquePointer(readFromMemory<T>()));
1289  }
1290  template <typename T>
1291  std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
1292  switch (read<PDLValue::Kind>()) {
1294  return read<Attribute>();
1296  return read<Operation *>();
1297  case PDLValue::Kind::Type:
1298  return read<Type>();
1299  case PDLValue::Kind::Value:
1300  return read<Value>();
1302  return read<TypeRange *>();
1304  return read<ValueRange *>();
1305  }
1306  llvm_unreachable("unhandled PDLValue::Kind");
1307  }
1308  template <typename T>
1309  std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
1310  static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
1311  "unexpected ByteCode address size");
1312  ByteCodeAddr result;
1313  std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
1314  curCodeIt += 2;
1315  return result;
1316  }
1317  template <typename T>
1318  std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
1319  return *curCodeIt++;
1320  }
1321  template <typename T>
1322  std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
1323  return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
1324  }
1325 
1326  /// Assign the given range to the given memory index. This allocates a new
1327  /// range object if necessary.
1328  template <typename RangeT, typename T = llvm::detail::ValueOfRange<RangeT>>
1329  void assignRangeToMemory(RangeT &&range, unsigned memIndex,
1330  unsigned rangeIndex) {
1331  // Utility functor used to type-erase the assignment.
1332  auto assignRange = [&](auto &allocatedRangeMemory, auto &rangeMemory) {
1333  // If the input range is empty, we don't need to allocate anything.
1334  if (range.empty()) {
1335  rangeMemory[rangeIndex] = {};
1336  } else {
1337  // Allocate a buffer for this type range.
1338  llvm::OwningArrayRef<T> storage(llvm::size(range));
1339  llvm::copy(range, storage.begin());
1340 
1341  // Assign this to the range slot and use the range as the value for the
1342  // memory index.
1343  allocatedRangeMemory.emplace_back(std::move(storage));
1344  rangeMemory[rangeIndex] = allocatedRangeMemory.back();
1345  }
1346  memory[memIndex] = &rangeMemory[rangeIndex];
1347  };
1348 
1349  // Dispatch based on the concrete range type.
1350  if constexpr (std::is_same_v<T, Type>) {
1351  return assignRange(allocatedTypeRangeMemory, typeRangeMemory);
1352  } else if constexpr (std::is_same_v<T, Value>) {
1353  return assignRange(allocatedValueRangeMemory, valueRangeMemory);
1354  } else {
1355  llvm_unreachable("unhandled range type");
1356  }
1357  }
1358 
1359  /// The underlying bytecode buffer.
1360  const ByteCodeField *curCodeIt;
1361 
1362  /// The stack of bytecode positions at which to resume operation.
1364 
1365  /// The current execution memory.
1367  MutableArrayRef<OwningOpRange> opRangeMemory;
1368  MutableArrayRef<TypeRange> typeRangeMemory;
1369  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
1370  MutableArrayRef<ValueRange> valueRangeMemory;
1371  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
1372 
1373  /// The current loop indices.
1374  MutableArrayRef<unsigned> loopIndex;
1375 
1376  /// References to ByteCode data necessary for execution.
1377  ArrayRef<const void *> uniquedMemory;
1379  ArrayRef<PatternBenefit> currentPatternBenefits;
1381  ArrayRef<PDLConstraintFunction> constraintFunctions;
1382  ArrayRef<PDLRewriteFunction> rewriteFunctions;
1383 };
1384 
1385 /// This class is an instantiation of the PDLResultList that provides access to
1386 /// the returned results. This API is not on `PDLResultList` to avoid
1387 /// overexposing access to information specific solely to the ByteCode.
1388 class ByteCodeRewriteResultList : public PDLResultList {
1389 public:
1390  ByteCodeRewriteResultList(unsigned maxNumResults)
1391  : PDLResultList(maxNumResults) {}
1392 
1393  /// Return the list of PDL results.
1394  MutableArrayRef<PDLValue> getResults() { return results; }
1395 
1396  /// Return the type ranges allocated by this list.
1397  MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
1398  return allocatedTypeRanges;
1399  }
1400 
1401  /// Return the value ranges allocated by this list.
1402  MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
1403  return allocatedValueRanges;
1404  }
1405 };
1406 } // namespace
1407 
1408 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1409  LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
1410  const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
1412  readList<PDLValue>(args);
1413 
1414  LLVM_DEBUG({
1415  llvm::dbgs() << " * Arguments: ";
1416  llvm::interleaveComma(args, llvm::dbgs());
1417  llvm::dbgs() << "\n";
1418  });
1419 
1420  ByteCodeField isNegated = read();
1421  LLVM_DEBUG({
1422  llvm::dbgs() << " * isNegated: " << isNegated << "\n";
1423  llvm::interleaveComma(args, llvm::dbgs());
1424  });
1425  // Invoke the constraint and jump to the proper destination.
1426  selectJump(isNegated != succeeded(constraintFn(rewriter, args)));
1427 }
1428 
1429 LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1430  LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
1431  const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1433  readList<PDLValue>(args);
1434 
1435  LLVM_DEBUG({
1436  llvm::dbgs() << " * Arguments: ";
1437  llvm::interleaveComma(args, llvm::dbgs());
1438  });
1439 
1440  // Execute the rewrite function.
1441  ByteCodeField numResults = read();
1442  ByteCodeRewriteResultList results(numResults);
1443  LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
1444 
1445  assert(results.getResults().size() == numResults &&
1446  "native PDL rewrite function returned unexpected number of results");
1447 
1448  // Store the results in the bytecode memory.
1449  for (PDLValue &result : results.getResults()) {
1450  LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
1451 
1452 // In debug mode we also verify the expected kind of the result.
1453 #ifndef NDEBUG
1454  assert(result.getKind() == read<PDLValue::Kind>() &&
1455  "native PDL rewrite function returned an unexpected type of result");
1456 #endif
1457 
1458  // If the result is a range, we need to copy it over to the bytecodes
1459  // range memory.
1460  if (std::optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
1461  unsigned rangeIndex = read();
1462  typeRangeMemory[rangeIndex] = *typeRange;
1463  memory[read()] = &typeRangeMemory[rangeIndex];
1464  } else if (std::optional<ValueRange> valueRange =
1465  result.dyn_cast<ValueRange>()) {
1466  unsigned rangeIndex = read();
1467  valueRangeMemory[rangeIndex] = *valueRange;
1468  memory[read()] = &valueRangeMemory[rangeIndex];
1469  } else {
1470  memory[read()] = result.getAsOpaquePointer();
1471  }
1472  }
1473 
1474  // Copy over any underlying storage allocated for result ranges.
1475  for (auto &it : results.getAllocatedTypeRanges())
1476  allocatedTypeRangeMemory.push_back(std::move(it));
1477  for (auto &it : results.getAllocatedValueRanges())
1478  allocatedValueRangeMemory.push_back(std::move(it));
1479 
1480  // Process the result of the rewrite.
1481  if (failed(rewriteResult)) {
1482  LLVM_DEBUG(llvm::dbgs() << " - Failed");
1483  return failure();
1484  }
1485  return success();
1486 }
1487 
1488 void ByteCodeExecutor::executeAreEqual() {
1489  LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1490  const void *lhs = read<const void *>();
1491  const void *rhs = read<const void *>();
1492 
1493  LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n");
1494  selectJump(lhs == rhs);
1495 }
1496 
1497 void ByteCodeExecutor::executeAreRangesEqual() {
1498  LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
1499  PDLValue::Kind valueKind = read<PDLValue::Kind>();
1500  const void *lhs = read<const void *>();
1501  const void *rhs = read<const void *>();
1502 
1503  switch (valueKind) {
1505  const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1506  const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1507  LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1508  selectJump(*lhsRange == *rhsRange);
1509  break;
1510  }
1512  const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1513  const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1514  LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1515  selectJump(*lhsRange == *rhsRange);
1516  break;
1517  }
1518  default:
1519  llvm_unreachable("unexpected `AreRangesEqual` value kind");
1520  }
1521 }
1522 
1523 void ByteCodeExecutor::executeBranch() {
1524  LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
1525  curCodeIt = &code[read<ByteCodeAddr>()];
1526 }
1527 
1528 void ByteCodeExecutor::executeCheckOperandCount() {
1529  LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
1530  Operation *op = read<Operation *>();
1531  uint32_t expectedCount = read<uint32_t>();
1532  bool compareAtLeast = read();
1533 
1534  LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n"
1535  << " * Expected: " << expectedCount << "\n"
1536  << " * Comparator: "
1537  << (compareAtLeast ? ">=" : "==") << "\n");
1538  if (compareAtLeast)
1539  selectJump(op->getNumOperands() >= expectedCount);
1540  else
1541  selectJump(op->getNumOperands() == expectedCount);
1542 }
1543 
1544 void ByteCodeExecutor::executeCheckOperationName() {
1545  LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
1546  Operation *op = read<Operation *>();
1547  OperationName expectedName = read<OperationName>();
1548 
1549  LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n"
1550  << " * Expected: \"" << expectedName << "\"\n");
1551  selectJump(op->getName() == expectedName);
1552 }
1553 
1554 void ByteCodeExecutor::executeCheckResultCount() {
1555  LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
1556  Operation *op = read<Operation *>();
1557  uint32_t expectedCount = read<uint32_t>();
1558  bool compareAtLeast = read();
1559 
1560  LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n"
1561  << " * Expected: " << expectedCount << "\n"
1562  << " * Comparator: "
1563  << (compareAtLeast ? ">=" : "==") << "\n");
1564  if (compareAtLeast)
1565  selectJump(op->getNumResults() >= expectedCount);
1566  else
1567  selectJump(op->getNumResults() == expectedCount);
1568 }
1569 
1570 void ByteCodeExecutor::executeCheckTypes() {
1571  LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1572  TypeRange *lhs = read<TypeRange *>();
1573  Attribute rhs = read<Attribute>();
1574  LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1575 
1576  selectJump(*lhs == cast<ArrayAttr>(rhs).getAsValueRange<TypeAttr>());
1577 }
1578 
1579 void ByteCodeExecutor::executeContinue() {
1580  ByteCodeField level = read();
1581  LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n"
1582  << " * Level: " << level << "\n");
1583  ++loopIndex[level];
1584  popCodeIt();
1585 }
1586 
1587 void ByteCodeExecutor::executeCreateConstantTypeRange() {
1588  LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n");
1589  unsigned memIndex = read();
1590  unsigned rangeIndex = read();
1591  ArrayAttr typesAttr = cast<ArrayAttr>(read<Attribute>());
1592 
1593  LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n");
1594  assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
1595  rangeIndex);
1596 }
1597 
1598 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1599  Location mainRewriteLoc) {
1600  LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
1601 
1602  unsigned memIndex = read();
1603  OperationState state(mainRewriteLoc, read<OperationName>());
1604  readList(state.operands);
1605  for (unsigned i = 0, e = read(); i != e; ++i) {
1606  StringAttr name = read<StringAttr>();
1607  if (Attribute attr = read<Attribute>())
1608  state.addAttribute(name, attr);
1609  }
1610 
1611  // Read in the result types. If the "size" is the sentinel value, this
1612  // indicates that the result types should be inferred.
1613  unsigned numResults = read();
1614  if (numResults == kInferTypesMarker) {
1615  InferTypeOpInterface::Concept *inferInterface =
1616  state.name.getInterface<InferTypeOpInterface>();
1617  assert(inferInterface &&
1618  "expected operation to provide InferTypeOpInterface");
1619 
1620  // TODO: Handle failure.
1621  if (failed(inferInterface->inferReturnTypes(
1622  state.getContext(), state.location, state.operands,
1623  state.attributes.getDictionary(state.getContext()),
1624  state.getRawProperties(), state.regions, state.types)))
1625  return;
1626  } else {
1627  // Otherwise, this is a fixed number of results.
1628  for (unsigned i = 0; i != numResults; ++i) {
1629  if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1630  state.types.push_back(read<Type>());
1631  } else {
1632  TypeRange *resultTypes = read<TypeRange *>();
1633  state.types.append(resultTypes->begin(), resultTypes->end());
1634  }
1635  }
1636  }
1637 
1638  Operation *resultOp = rewriter.create(state);
1639  memory[memIndex] = resultOp;
1640 
1641  LLVM_DEBUG({
1642  llvm::dbgs() << " * Attributes: "
1643  << state.attributes.getDictionary(state.getContext())
1644  << "\n * Operands: ";
1645  llvm::interleaveComma(state.operands, llvm::dbgs());
1646  llvm::dbgs() << "\n * Result Types: ";
1647  llvm::interleaveComma(state.types, llvm::dbgs());
1648  llvm::dbgs() << "\n * Result: " << *resultOp << "\n";
1649  });
1650 }
1651 
1652 template <typename T>
1653 void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
1654  LLVM_DEBUG(llvm::dbgs() << "Executing CreateDynamic" << type << "Range:\n");
1655  unsigned memIndex = read();
1656  unsigned rangeIndex = read();
1657  SmallVector<T> values;
1658  readList(values);
1659 
1660  LLVM_DEBUG({
1661  llvm::dbgs() << "\n * " << type << "s: ";
1662  llvm::interleaveComma(values, llvm::dbgs());
1663  llvm::dbgs() << "\n";
1664  });
1665 
1666  assignRangeToMemory(values, memIndex, rangeIndex);
1667 }
1668 
1669 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1670  LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1671  Operation *op = read<Operation *>();
1672 
1673  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
1674  rewriter.eraseOp(op);
1675 }
1676 
1677 template <typename T, typename Range, PDLValue::Kind kind>
1678 void ByteCodeExecutor::executeExtract() {
1679  LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n");
1680  Range *range = read<Range *>();
1681  unsigned index = read<uint32_t>();
1682  unsigned memIndex = read();
1683 
1684  if (!range) {
1685  memory[memIndex] = nullptr;
1686  return;
1687  }
1688 
1689  T result = index < range->size() ? (*range)[index] : T();
1690  LLVM_DEBUG(llvm::dbgs() << " * " << kind << "s(" << range->size() << ")\n"
1691  << " * Index: " << index << "\n"
1692  << " * Result: " << result << "\n");
1693  storeToMemory(memIndex, result);
1694 }
1695 
1696 void ByteCodeExecutor::executeFinalize() {
1697  LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
1698 }
1699 
1700 void ByteCodeExecutor::executeForEach() {
1701  LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
1702  const ByteCodeField *prevCodeIt = getPrevCodeIt();
1703  unsigned rangeIndex = read();
1704  unsigned memIndex = read();
1705  const void *value = nullptr;
1706 
1707  switch (read<PDLValue::Kind>()) {
1709  unsigned &index = loopIndex[read()];
1710  ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
1711  assert(index <= array.size() && "iterated past the end");
1712  if (index < array.size()) {
1713  LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n");
1714  value = array[index];
1715  break;
1716  }
1717 
1718  LLVM_DEBUG(llvm::dbgs() << " * Done\n");
1719  index = 0;
1720  selectJump(size_t(0));
1721  return;
1722  }
1723  default:
1724  llvm_unreachable("unexpected `ForEach` value kind");
1725  }
1726 
1727  // Store the iterate value and the stack address.
1728  memory[memIndex] = value;
1729  pushCodeIt(prevCodeIt);
1730 
1731  // Skip over the successor (we will enter the body of the loop).
1732  read<ByteCodeAddr>();
1733 }
1734 
1735 void ByteCodeExecutor::executeGetAttribute() {
1736  LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1737  unsigned memIndex = read();
1738  Operation *op = read<Operation *>();
1739  StringAttr attrName = read<StringAttr>();
1740  Attribute attr = op->getAttr(attrName);
1741 
1742  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1743  << " * Attribute: " << attrName << "\n"
1744  << " * Result: " << attr << "\n");
1745  memory[memIndex] = attr.getAsOpaquePointer();
1746 }
1747 
1748 void ByteCodeExecutor::executeGetAttributeType() {
1749  LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1750  unsigned memIndex = read();
1751  Attribute attr = read<Attribute>();
1752  Type type;
1753  if (auto typedAttr = dyn_cast<TypedAttr>(attr))
1754  type = typedAttr.getType();
1755 
1756  LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n"
1757  << " * Result: " << type << "\n");
1758  memory[memIndex] = type.getAsOpaquePointer();
1759 }
1760 
1761 void ByteCodeExecutor::executeGetDefiningOp() {
1762  LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1763  unsigned memIndex = read();
1764  Operation *op = nullptr;
1765  if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1766  Value value = read<Value>();
1767  if (value)
1768  op = value.getDefiningOp();
1769  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
1770  } else {
1771  ValueRange *values = read<ValueRange *>();
1772  if (values && !values->empty()) {
1773  op = values->front().getDefiningOp();
1774  }
1775  LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n");
1776  }
1777 
1778  LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n");
1779  memory[memIndex] = op;
1780 }
1781 
1782 void ByteCodeExecutor::executeGetOperand(unsigned index) {
1783  Operation *op = read<Operation *>();
1784  unsigned memIndex = read();
1785  Value operand =
1786  index < op->getNumOperands() ? op->getOperand(index) : Value();
1787 
1788  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1789  << " * Index: " << index << "\n"
1790  << " * Result: " << operand << "\n");
1791  memory[memIndex] = operand.getAsOpaquePointer();
1792 }
1793 
1794 /// This function is the internal implementation of `GetResults` and
1795 /// `GetOperands` that provides support for extracting a value range from the
1796 /// given operation.
1797 template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1798 static void *
1799 executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1800  ByteCodeField rangeIndex, StringRef attrSizedSegments,
1801  MutableArrayRef<ValueRange> valueRangeMemory) {
1802  // Check for the sentinel index that signals that all values should be
1803  // returned.
1804  if (index == std::numeric_limits<uint32_t>::max()) {
1805  LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n");
1806  // `values` is already the full value range.
1807 
1808  // Otherwise, check to see if this operation uses AttrSizedSegments.
1809  } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1810  LLVM_DEBUG(llvm::dbgs()
1811  << " * Extracting values from `" << attrSizedSegments << "`\n");
1812 
1813  auto segmentAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrSizedSegments);
1814  if (!segmentAttr || segmentAttr.asArrayRef().size() <= index)
1815  return nullptr;
1816 
1817  ArrayRef<int32_t> segments = segmentAttr;
1818  unsigned startIndex =
1819  std::accumulate(segments.begin(), segments.begin() + index, 0);
1820  values = values.slice(startIndex, *std::next(segments.begin(), index));
1821 
1822  LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", "
1823  << *std::next(segments.begin(), index) << "]\n");
1824 
1825  // Otherwise, assume this is the last operand group of the operation.
1826  // FIXME: We currently don't support operations with
1827  // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1828  // have a way to detect it's presence.
1829  } else if (values.size() >= index) {
1830  LLVM_DEBUG(llvm::dbgs()
1831  << " * Treating values as trailing variadic range\n");
1832  values = values.drop_front(index);
1833 
1834  // If we couldn't detect a way to compute the values, bail out.
1835  } else {
1836  return nullptr;
1837  }
1838 
1839  // If the range index is valid, we are returning a range.
1840  if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1841  valueRangeMemory[rangeIndex] = values;
1842  return &valueRangeMemory[rangeIndex];
1843  }
1844 
1845  // If a range index wasn't provided, the range is required to be non-variadic.
1846  return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1847 }
1848 
1849 void ByteCodeExecutor::executeGetOperands() {
1850  LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
1851  unsigned index = read<uint32_t>();
1852  Operation *op = read<Operation *>();
1853  ByteCodeField rangeIndex = read();
1854 
1855  void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1856  op->getOperands(), op, index, rangeIndex, "operandSegmentSizes",
1857  valueRangeMemory);
1858  if (!result)
1859  LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n");
1860  memory[read()] = result;
1861 }
1862 
1863 void ByteCodeExecutor::executeGetResult(unsigned index) {
1864  Operation *op = read<Operation *>();
1865  unsigned memIndex = read();
1866  OpResult result =
1867  index < op->getNumResults() ? op->getResult(index) : OpResult();
1868 
1869  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1870  << " * Index: " << index << "\n"
1871  << " * Result: " << result << "\n");
1872  memory[memIndex] = result.getAsOpaquePointer();
1873 }
1874 
1875 void ByteCodeExecutor::executeGetResults() {
1876  LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
1877  unsigned index = read<uint32_t>();
1878  Operation *op = read<Operation *>();
1879  ByteCodeField rangeIndex = read();
1880 
1881  void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1882  op->getResults(), op, index, rangeIndex, "resultSegmentSizes",
1883  valueRangeMemory);
1884  if (!result)
1885  LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n");
1886  memory[read()] = result;
1887 }
1888 
1889 void ByteCodeExecutor::executeGetUsers() {
1890  LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n");
1891  unsigned memIndex = read();
1892  unsigned rangeIndex = read();
1893  OwningOpRange &range = opRangeMemory[rangeIndex];
1894  memory[memIndex] = &range;
1895 
1896  range = OwningOpRange();
1897  if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1898  // Read the value.
1899  Value value = read<Value>();
1900  if (!value)
1901  return;
1902  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
1903 
1904  // Extract the users of a single value.
1905  range = OwningOpRange(std::distance(value.user_begin(), value.user_end()));
1906  llvm::copy(value.getUsers(), range.begin());
1907  } else {
1908  // Read a range of values.
1909  ValueRange *values = read<ValueRange *>();
1910  if (!values)
1911  return;
1912  LLVM_DEBUG({
1913  llvm::dbgs() << " * Values (" << values->size() << "): ";
1914  llvm::interleaveComma(*values, llvm::dbgs());
1915  llvm::dbgs() << "\n";
1916  });
1917 
1918  // Extract all the users of a range of values.
1920  for (Value value : *values)
1921  users.append(value.user_begin(), value.user_end());
1922  range = OwningOpRange(users.size());
1923  llvm::copy(users, range.begin());
1924  }
1925 
1926  LLVM_DEBUG(llvm::dbgs() << " * Result: " << range.size() << " operations\n");
1927 }
1928 
1929 void ByteCodeExecutor::executeGetValueType() {
1930  LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1931  unsigned memIndex = read();
1932  Value value = read<Value>();
1933  Type type = value ? value.getType() : Type();
1934 
1935  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"
1936  << " * Result: " << type << "\n");
1937  memory[memIndex] = type.getAsOpaquePointer();
1938 }
1939 
1940 void ByteCodeExecutor::executeGetValueRangeTypes() {
1941  LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
1942  unsigned memIndex = read();
1943  unsigned rangeIndex = read();
1944  ValueRange *values = read<ValueRange *>();
1945  if (!values) {
1946  LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n");
1947  memory[memIndex] = nullptr;
1948  return;
1949  }
1950 
1951  LLVM_DEBUG({
1952  llvm::dbgs() << " * Values (" << values->size() << "): ";
1953  llvm::interleaveComma(*values, llvm::dbgs());
1954  llvm::dbgs() << "\n * Result: ";
1955  llvm::interleaveComma(values->getType(), llvm::dbgs());
1956  llvm::dbgs() << "\n";
1957  });
1958  typeRangeMemory[rangeIndex] = values->getType();
1959  memory[memIndex] = &typeRangeMemory[rangeIndex];
1960 }
1961 
1962 void ByteCodeExecutor::executeIsNotNull() {
1963  LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1964  const void *value = read<const void *>();
1965 
1966  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
1967  selectJump(value != nullptr);
1968 }
1969 
1970 void ByteCodeExecutor::executeRecordMatch(
1971  PatternRewriter &rewriter,
1973  LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
1974  unsigned patternIndex = read();
1975  PatternBenefit benefit = currentPatternBenefits[patternIndex];
1976  const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1977 
1978  // If the benefit of the pattern is impossible, skip the processing of the
1979  // rest of the pattern.
1980  if (benefit.isImpossibleToMatch()) {
1981  LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n");
1982  curCodeIt = dest;
1983  return;
1984  }
1985 
1986  // Create a fused location containing the locations of each of the
1987  // operations used in the match. This will be used as the location for
1988  // created operations during the rewrite that don't already have an
1989  // explicit location set.
1990  unsigned numMatchLocs = read();
1991  SmallVector<Location, 4> matchLocs;
1992  matchLocs.reserve(numMatchLocs);
1993  for (unsigned i = 0; i != numMatchLocs; ++i)
1994  matchLocs.push_back(read<Operation *>()->getLoc());
1995  Location matchLoc = rewriter.getFusedLoc(matchLocs);
1996 
1997  LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n"
1998  << " * Location: " << matchLoc << "\n");
1999  matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
2000  PDLByteCode::MatchResult &match = matches.back();
2001 
2002  // Record all of the inputs to the match. If any of the inputs are ranges, we
2003  // will also need to remap the range pointer to memory stored in the match
2004  // state.
2005  unsigned numInputs = read();
2006  match.values.reserve(numInputs);
2007  match.typeRangeValues.reserve(numInputs);
2008  match.valueRangeValues.reserve(numInputs);
2009  for (unsigned i = 0; i < numInputs; ++i) {
2010  switch (read<PDLValue::Kind>()) {
2012  match.typeRangeValues.push_back(*read<TypeRange *>());
2013  match.values.push_back(&match.typeRangeValues.back());
2014  break;
2016  match.valueRangeValues.push_back(*read<ValueRange *>());
2017  match.values.push_back(&match.valueRangeValues.back());
2018  break;
2019  default:
2020  match.values.push_back(read<const void *>());
2021  break;
2022  }
2023  }
2024  curCodeIt = dest;
2025 }
2026 
2027 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
2028  LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
2029  Operation *op = read<Operation *>();
2031  readList(args);
2032 
2033  LLVM_DEBUG({
2034  llvm::dbgs() << " * Operation: " << *op << "\n"
2035  << " * Values: ";
2036  llvm::interleaveComma(args, llvm::dbgs());
2037  llvm::dbgs() << "\n";
2038  });
2039  rewriter.replaceOp(op, args);
2040 }
2041 
2042 void ByteCodeExecutor::executeSwitchAttribute() {
2043  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
2044  Attribute value = read<Attribute>();
2045  ArrayAttr cases = read<ArrayAttr>();
2046  handleSwitch(value, cases);
2047 }
2048 
2049 void ByteCodeExecutor::executeSwitchOperandCount() {
2050  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
2051  Operation *op = read<Operation *>();
2052  auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2053 
2054  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
2055  handleSwitch(op->getNumOperands(), cases);
2056 }
2057 
2058 void ByteCodeExecutor::executeSwitchOperationName() {
2059  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
2060  OperationName value = read<Operation *>()->getName();
2061  size_t caseCount = read();
2062 
2063  // The operation names are stored in-line, so to print them out for
2064  // debugging purposes we need to read the array before executing the
2065  // switch so that we can display all of the possible values.
2066  LLVM_DEBUG({
2067  const ByteCodeField *prevCodeIt = curCodeIt;
2068  llvm::dbgs() << " * Value: " << value << "\n"
2069  << " * Cases: ";
2070  llvm::interleaveComma(
2071  llvm::map_range(llvm::seq<size_t>(0, caseCount),
2072  [&](size_t) { return read<OperationName>(); }),
2073  llvm::dbgs());
2074  llvm::dbgs() << "\n";
2075  curCodeIt = prevCodeIt;
2076  });
2077 
2078  // Try to find the switch value within any of the cases.
2079  for (size_t i = 0; i != caseCount; ++i) {
2080  if (read<OperationName>() == value) {
2081  curCodeIt += (caseCount - i - 1);
2082  return selectJump(i + 1);
2083  }
2084  }
2085  selectJump(size_t(0));
2086 }
2087 
2088 void ByteCodeExecutor::executeSwitchResultCount() {
2089  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
2090  Operation *op = read<Operation *>();
2091  auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2092 
2093  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
2094  handleSwitch(op->getNumResults(), cases);
2095 }
2096 
2097 void ByteCodeExecutor::executeSwitchType() {
2098  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
2099  Type value = read<Type>();
2100  auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2101  handleSwitch(value, cases);
2102 }
2103 
2104 void ByteCodeExecutor::executeSwitchTypes() {
2105  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
2106  TypeRange *value = read<TypeRange *>();
2107  auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2108  if (!value) {
2109  LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
2110  return selectJump(size_t(0));
2111  }
2112  handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
2113  return value == caseValue.getAsValueRange<TypeAttr>();
2114  });
2115 }
2116 
2118 ByteCodeExecutor::execute(PatternRewriter &rewriter,
2120  std::optional<Location> mainRewriteLoc) {
2121  while (true) {
2122  // Print the location of the operation being executed.
2123  LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n");
2124 
2125  OpCode opCode = static_cast<OpCode>(read());
2126  switch (opCode) {
2127  case ApplyConstraint:
2128  executeApplyConstraint(rewriter);
2129  break;
2130  case ApplyRewrite:
2131  if (failed(executeApplyRewrite(rewriter)))
2132  return failure();
2133  break;
2134  case AreEqual:
2135  executeAreEqual();
2136  break;
2137  case AreRangesEqual:
2138  executeAreRangesEqual();
2139  break;
2140  case Branch:
2141  executeBranch();
2142  break;
2143  case CheckOperandCount:
2144  executeCheckOperandCount();
2145  break;
2146  case CheckOperationName:
2147  executeCheckOperationName();
2148  break;
2149  case CheckResultCount:
2150  executeCheckResultCount();
2151  break;
2152  case CheckTypes:
2153  executeCheckTypes();
2154  break;
2155  case Continue:
2156  executeContinue();
2157  break;
2158  case CreateConstantTypeRange:
2159  executeCreateConstantTypeRange();
2160  break;
2161  case CreateOperation:
2162  executeCreateOperation(rewriter, *mainRewriteLoc);
2163  break;
2164  case CreateDynamicTypeRange:
2165  executeDynamicCreateRange<Type>("Type");
2166  break;
2167  case CreateDynamicValueRange:
2168  executeDynamicCreateRange<Value>("Value");
2169  break;
2170  case EraseOp:
2171  executeEraseOp(rewriter);
2172  break;
2173  case ExtractOp:
2174  executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
2175  break;
2176  case ExtractType:
2177  executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2178  break;
2179  case ExtractValue:
2180  executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2181  break;
2182  case Finalize:
2183  executeFinalize();
2184  LLVM_DEBUG(llvm::dbgs() << "\n");
2185  return success();
2186  case ForEach:
2187  executeForEach();
2188  break;
2189  case GetAttribute:
2190  executeGetAttribute();
2191  break;
2192  case GetAttributeType:
2193  executeGetAttributeType();
2194  break;
2195  case GetDefiningOp:
2196  executeGetDefiningOp();
2197  break;
2198  case GetOperand0:
2199  case GetOperand1:
2200  case GetOperand2:
2201  case GetOperand3: {
2202  unsigned index = opCode - GetOperand0;
2203  LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
2204  executeGetOperand(index);
2205  break;
2206  }
2207  case GetOperandN:
2208  LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
2209  executeGetOperand(read<uint32_t>());
2210  break;
2211  case GetOperands:
2212  executeGetOperands();
2213  break;
2214  case GetResult0:
2215  case GetResult1:
2216  case GetResult2:
2217  case GetResult3: {
2218  unsigned index = opCode - GetResult0;
2219  LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
2220  executeGetResult(index);
2221  break;
2222  }
2223  case GetResultN:
2224  LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
2225  executeGetResult(read<uint32_t>());
2226  break;
2227  case GetResults:
2228  executeGetResults();
2229  break;
2230  case GetUsers:
2231  executeGetUsers();
2232  break;
2233  case GetValueType:
2234  executeGetValueType();
2235  break;
2236  case GetValueRangeTypes:
2237  executeGetValueRangeTypes();
2238  break;
2239  case IsNotNull:
2240  executeIsNotNull();
2241  break;
2242  case RecordMatch:
2243  assert(matches &&
2244  "expected matches to be provided when executing the matcher");
2245  executeRecordMatch(rewriter, *matches);
2246  break;
2247  case ReplaceOp:
2248  executeReplaceOp(rewriter);
2249  break;
2250  case SwitchAttribute:
2251  executeSwitchAttribute();
2252  break;
2253  case SwitchOperandCount:
2254  executeSwitchOperandCount();
2255  break;
2256  case SwitchOperationName:
2257  executeSwitchOperationName();
2258  break;
2259  case SwitchResultCount:
2260  executeSwitchResultCount();
2261  break;
2262  case SwitchType:
2263  executeSwitchType();
2264  break;
2265  case SwitchTypes:
2266  executeSwitchTypes();
2267  break;
2268  }
2269  LLVM_DEBUG(llvm::dbgs() << "\n");
2270  }
2271 }
2272 
2275  PDLByteCodeMutableState &state) const {
2276  // The first memory slot is always the root operation.
2277  state.memory[0] = op;
2278 
2279  // The matcher function always starts at code address 0.
2280  ByteCodeExecutor executor(
2281  matcherByteCode.data(), state.memory, state.opRangeMemory,
2282  state.typeRangeMemory, state.allocatedTypeRangeMemory,
2283  state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2284  uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
2285  constraintFunctions, rewriteFunctions);
2286  LogicalResult executeResult = executor.execute(rewriter, &matches);
2287  (void)executeResult;
2288  assert(succeeded(executeResult) && "unexpected matcher execution failure");
2289 
2290  // Order the found matches by benefit.
2291  std::stable_sort(matches.begin(), matches.end(),
2292  [](const MatchResult &lhs, const MatchResult &rhs) {
2293  return lhs.benefit > rhs.benefit;
2294  });
2295 }
2296 
2298  const MatchResult &match,
2299  PDLByteCodeMutableState &state) const {
2300  auto *configSet = match.pattern->getConfigSet();
2301  if (configSet)
2302  configSet->notifyRewriteBegin(rewriter);
2303 
2304  // The arguments of the rewrite function are stored at the start of the
2305  // memory buffer.
2306  llvm::copy(match.values, state.memory.begin());
2307 
2308  ByteCodeExecutor executor(
2309  &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
2310  state.opRangeMemory, state.typeRangeMemory,
2311  state.allocatedTypeRangeMemory, state.valueRangeMemory,
2312  state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2313  rewriterByteCode, state.currentPatternBenefits, patterns,
2314  constraintFunctions, rewriteFunctions);
2315  LogicalResult result =
2316  executor.execute(rewriter, /*matches=*/nullptr, match.location);
2317 
2318  if (configSet)
2319  configSet->notifyRewriteEnd(rewriter);
2320 
2321  // If the rewrite failed, check if the pattern rewriter can recover. If it
2322  // can, we can signal to the pattern applicator to keep trying patterns. If it
2323  // doesn't, we need to bail. Bailing here should be fine, given that we have
2324  // no means to propagate such a failure to the user, and it also indicates a
2325  // bug in the user code (i.e. failable rewrites should not be used with
2326  // pattern rewriters that don't support it).
2327  if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) {
2328  LLVM_DEBUG(llvm::dbgs() << " and rollback is not supported - aborting");
2329  llvm::report_fatal_error(
2330  "Native PDL Rewrite failed, but the pattern "
2331  "rewriter doesn't support recovery. Failable pattern rewrites should "
2332  "not be used with pattern rewriters that do not support them.");
2333  }
2334  return result;
2335 }
static void * executeGetOperandsResults(RangeT values, Operation *op, unsigned index, ByteCodeField rangeIndex, StringRef attrSizedSegments, MutableArrayRef< ValueRange > valueRangeMemory)
This function is the internal implementation of GetResults and GetOperands that provides support for ...
Definition: ByteCode.cpp:1799
static constexpr ByteCodeField kInferTypesMarker
A marker used to indicate if an operation should infer types.
Definition: ByteCode.cpp:173
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static const mlir::GenInfo * generator
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void processValue(Value value, LiveMap &liveMap)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
const void * getAsOpaquePointer() const
Get an opaque pointer to the attribute.
Definition: Attributes.h:82
This class represents an argument of a Block.
Definition: Value.h:315
Block represents an ordered list of Operations.
Definition: Block.h:30
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
BlockArgListType getArguments()
Definition: Block.h:80
Operation & front()
Definition: Block.h:146
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition: Block.cpp:35
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
Definition: Builders.cpp:29
This class represents liveness information on block level.
Definition: Liveness.h:99
const ValueSetT & in() const
Returns all values that are live at the beginning of the block (unordered).
Definition: Liveness.h:110
Operation * getEndOperation(Value value, Operation *startOperation) const
Gets the end operation for the given value using the start operation provided (must be referenced in ...
Definition: Liveness.cpp:375
Represents an analysis for computing liveness information from a given top-level operation.
Definition: Liveness.h:47
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
This is a value defined by a result of an operation.
Definition: Value.h:453
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:728
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:528
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:512
Block * getSuccessor(unsigned index)
Definition: Operation.h:687
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:341
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:665
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:655
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
SuccessorRange getSuccessors()
Definition: Operation.h:682
result_range getResults()
Definition: Operation.h:410
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class contains a set of configurations for a specific pattern.
Definition: PatternMatch.h:979
The class represents a list of PDL results, returned by a native rewrite method.
Definition: PatternMatch.h:859
Storage type of byte-code interpreter values.
Definition: PatternMatch.h:747
Kind
The underlying kind of a PDL value.
Definition: PatternMatch.h:750
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:33
bool isImpossibleToMatch() const
Definition: PatternMatch.h:43
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:727
virtual bool canRecoverFromRewriteFailure() const
A hook used to indicate if the pattern rewriter can recover from failure during the rewrite stage of ...
Definition: PatternMatch.h:735
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:239
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class implements the successor iterators for Block.
Definition: BlockSupport.h:73
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
const void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Types.h:176
auto walk(WalkFns &&...walkFns)
Walk this type and all attibutes/types nested within using the provided walk functions.
Definition: Types.h:227
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:378
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:125
user_iterator user_begin() const
Definition: Value.h:222
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Value.h:239
user_iterator user_end() const
Definition: Value.h:223
user_range getUsers() const
Definition: Value.h:224
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
This class contains the mutable state of a bytecode instance.
Definition: ByteCode.h:71
void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit)
Set the new benefit for a bytecode pattern.
Definition: ByteCode.cpp:64
void cleanupAfterMatchAndRewrite()
Cleanup any allocated state after a match/rewrite has been completed.
Definition: ByteCode.cpp:72
All of the data pertaining to a specific pattern within the bytecode.
Definition: ByteCode.h:38
static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp, PDLPatternConfigSet *configSet, ByteCodeAddr rewriterAddr)
Definition: ByteCode.cpp:37
void match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl< MatchResult > &matches, PDLByteCodeMutableState &state) const
Run the pattern matcher on the given root operation, collecting the matched patterns in matches.
Definition: ByteCode.cpp:2273
PDLByteCode(ModuleOp module, SmallVector< std::unique_ptr< PDLPatternConfigSet >> configs, const DenseMap< Operation *, PDLPatternConfigSet * > &configMap, llvm::StringMap< PDLConstraintFunction > constraintFns, llvm::StringMap< PDLRewriteFunction > rewriteFns)
Create a ByteCode instance from the given module containing operations in the PDL interpreter dialect...
Definition: ByteCode.cpp:1043
void initializeMutableState(PDLByteCodeMutableState &state) const
Initialize the given state such that it can be used to execute the current bytecode.
Definition: ByteCode.cpp:1064
LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const
Run the rewriter of the given pattern that was previously matched in match.
Definition: ByteCode.cpp:2297
Detect if any of the given parameter types has a sub-element handler.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:137
uint32_t ByteCodeAddr
Definition: ByteCode.h:30
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
uint16_t ByteCodeField
Use generic bytecode types.
Definition: ByteCode.h:29
llvm::OwningArrayRef< Operation * > OwningOpRange
Definition: ByteCode.h:31
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
std::function< LogicalResult(PatternRewriter &, ArrayRef< PDLValue >)> PDLConstraintFunction
A generic PDL pattern constraint function.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
std::function< LogicalResult(PatternRewriter &, PDLResultList &, ArrayRef< PDLValue >)> PDLRewriteFunction
A native PDL rewrite function.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This represents an operation in an abstracted form, suitable for use with the builder APIs.
This class acts as a special tag that makes the desire to match "any" operation type explicit.
Definition: PatternMatch.h:158
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult size
Each successful match returns a MatchResult, which contains information necessary to execute the rewr...
Definition: ByteCode.h:132
SmallVector< TypeRange, 0 > typeRangeValues
Memory used for the range input values.
Definition: ByteCode.h:146
SmallVector< ValueRange, 0 > valueRangeValues
Definition: ByteCode.h:147
SmallVector< const void * > values
Memory values defined in the matcher that are passed to the rewriter.
Definition: ByteCode.h:144