MLIR  16.0.0git
ByteCode.cpp
Go to the documentation of this file.
1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ByteCode.h"
14 #include "mlir/Analysis/Liveness.h"
17 #include "mlir/IR/BuiltinOps.h"
19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/Format.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include <numeric>
26 
27 #define DEBUG_TYPE "pdl-bytecode"
28 
29 using namespace mlir;
30 using namespace mlir::detail;
31 
32 //===----------------------------------------------------------------------===//
33 // PDLByteCodePattern
34 //===----------------------------------------------------------------------===//
35 
36 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
37  PDLPatternConfigSet *configSet,
38  ByteCodeAddr rewriterAddr) {
39  PatternBenefit benefit = matchOp.getBenefit();
40  MLIRContext *ctx = matchOp.getContext();
41 
42  // Collect the set of generated operations.
43  SmallVector<StringRef, 8> generatedOps;
44  if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
45  generatedOps =
46  llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
47 
48  // Check to see if this is pattern matches a specific operation type.
49  if (Optional<StringRef> rootKind = matchOp.getRootKind())
50  return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx,
51  generatedOps);
52  return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(),
53  benefit, ctx, generatedOps);
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // PDLByteCodeMutableState
58 //===----------------------------------------------------------------------===//
59 
60 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
61 /// to the position of the pattern within the range returned by
62 /// `PDLByteCode::getPatterns`.
/// The body performs no explicit bounds check on `patternIndex`; callers must
/// pass an index that is valid for `currentPatternBenefits`.
/// NOTE(review): source line 63 (the start of the signature, which declares
/// `patternIndex`) is missing from this listing.
64  PatternBenefit benefit) {
65  currentPatternBenefits[patternIndex] = benefit;
66 }
67 
68 /// Clean up any allocated state after a full match/rewrite has been completed.
69 /// This method should be called regardless of whether the match+rewrite was a
70 /// success or not.
/// The visible body clears both the allocated type-range memory and the
/// allocated value-range memory.
/// NOTE(review): source line 71 (the function signature) is missing from this
/// listing.
72  allocatedTypeRangeMemory.clear();
73  allocatedValueRangeMemory.clear();
74 }
75 
76 //===----------------------------------------------------------------------===//
77 // Bytecode OpCodes
78 //===----------------------------------------------------------------------===//
79 
80 namespace {
/// Opcodes of the PDL bytecode. The underlying type is ByteCodeField and
/// `ByteCodeWriter::append(OpCode)` pushes an opcode as a single field, so the
/// declaration order below defines the encoded values. Do not reorder or insert
/// entries without updating every reader of the bytecode (presumably the
/// interpreter later in this file).
81 enum OpCode : ByteCodeField {
82  /// Apply an externally registered constraint.
83  ApplyConstraint,
84  /// Apply an externally registered rewrite.
85  ApplyRewrite,
86  /// Check if two generic values are equal.
/// `pdl_interp.check_attribute` is also lowered to this opcode.
87  AreEqual,
88  /// Check if two ranges are equal.
89  AreRangesEqual,
90  /// Unconditional branch.
91  Branch,
92  /// Compare the operand count of an operation with a constant.
93  CheckOperandCount,
94  /// Compare the name of an operation with a constant.
95  CheckOperationName,
96  /// Compare the result count of an operation with a constant.
97  CheckResultCount,
98  /// Compare a range of types to a constant range of types.
99  CheckTypes,
100  /// Continue to the next iteration of a loop.
101  Continue,
102  /// Create a type range from a list of constant types.
103  CreateConstantTypeRange,
104  /// Create an operation.
105  CreateOperation,
106  /// Create a type range from a list of dynamic types.
107  CreateDynamicTypeRange,
108  /// Create a value range from a list of dynamic values.
109  CreateDynamicValueRange,
110  /// Erase an operation.
111  EraseOp,
112  /// Extract the op from a range at the specified index.
113  ExtractOp,
114  /// Extract the type from a range at the specified index.
115  ExtractType,
116  /// Extract the value from a range at the specified index.
117  ExtractValue,
118  /// Terminate a matcher or rewrite sequence.
119  Finalize,
120  /// Iterate over a range of values.
121  ForEach,
122  /// Get a specific attribute of an operation.
123  GetAttribute,
124  /// Get the type of an attribute.
125  GetAttributeType,
126  /// Get the defining operation of a value.
127  GetDefiningOp,
128  /// Get a specific operand of an operation.
/// NOTE(review): presumably 0-3 encode the operand index in the opcode and
/// GetOperandN reads it from the stream; confirm against the generator.
129  GetOperand0,
130  GetOperand1,
131  GetOperand2,
132  GetOperand3,
133  GetOperandN,
134  /// Get a specific operand group of an operation.
135  GetOperands,
136  /// Get a specific result of an operation.
/// NOTE(review): presumably the same 0-3 / N split as GetOperand*; confirm.
137  GetResult0,
138  GetResult1,
139  GetResult2,
140  GetResult3,
141  GetResultN,
142  /// Get a specific result group of an operation.
143  GetResults,
144  /// Get the users of a value or a range of values.
145  GetUsers,
146  /// Get the type of a value.
147  GetValueType,
148  /// Get the types of a value range.
149  GetValueRangeTypes,
150  /// Check if a generic value is not null.
151  IsNotNull,
152  /// Record a successful pattern match.
153  RecordMatch,
154  /// Replace an operation.
155  ReplaceOp,
156  /// Compare an attribute with a set of constants.
157  SwitchAttribute,
158  /// Compare the operand count of an operation with a set of constants.
159  SwitchOperandCount,
160  /// Compare the name of an operation with a set of constants.
161  SwitchOperationName,
162  /// Compare the result count of an operation with a set of constants.
163  SwitchResultCount,
164  /// Compare a type with a set of constants.
165  SwitchType,
166  /// Compare a range of types with a set of constants.
167  SwitchTypes,
168 };
169 } // namespace
170 
171 /// A marker used to indicate if an operation should infer types.
174 
175 //===----------------------------------------------------------------------===//
176 // ByteCode Generation
177 //===----------------------------------------------------------------------===//
178 
179 //===----------------------------------------------------------------------===//
180 // Generator
181 
182 namespace {
/// Forward declarations; both structs are defined later in this namespace.
183 struct ByteCodeLiveRange;
184 struct ByteCodeWriter;
185 
186 /// Check if the given class `T` can be converted to an opaque pointer.
/// This alias is well-formed only when `T` has a `getAsOpaquePointer()` member,
/// which makes it usable with a detection-idiom helper. `Args` is not used by
/// the alias itself. NOTE(review): the consumers (presumably
/// `llvm::is_detected`) are on lines missing from this listing.
187 template <typename T, typename... Args>
188 using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
189 
190 /// This class represents the main generator for the pattern bytecode.
/// It writes into the caller-owned buffers and counters passed to the
/// constructor (held by reference below), and keeps its own bookkeeping maps
/// from values, functions, blocks and uniqued data to their indices/addresses.
191 class Generator {
192 public:
193  Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
194  SmallVectorImpl<ByteCodeField> &matcherByteCode,
195  SmallVectorImpl<ByteCodeField> &rewriterByteCode,
// NOTE(review): source line 196 is missing from this listing; the initializer
// `patterns(patterns)` below implies a `patterns` parameter declared there.
197  ByteCodeField &maxValueMemoryIndex,
198  ByteCodeField &maxOpRangeMemoryIndex,
199  ByteCodeField &maxTypeRangeMemoryIndex,
200  ByteCodeField &maxValueRangeMemoryIndex,
201  ByteCodeField &maxLoopLevel,
202  llvm::StringMap<PDLConstraintFunction> &constraintFns,
203  llvm::StringMap<PDLRewriteFunction> &rewriteFns,
// NOTE(review): source line 204 is missing from this listing; the initializer
// `configMap(configMap)` below implies a `configMap` parameter declared there.
205  : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
206  rewriterByteCode(rewriterByteCode), patterns(patterns),
207  maxValueMemoryIndex(maxValueMemoryIndex),
208  maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
209  maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
210  maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
211  maxLoopLevel(maxLoopLevel), configMap(configMap) {
// External functions are numbered by their position in StringMap iteration
// order. Presumably the interpreter indexes its function tables in the same
// order; confirm before changing either side.
212  for (const auto &it : llvm::enumerate(constraintFns))
213  constraintToMemIndex.try_emplace(it.value().first(), it.index());
214  for (const auto &it : llvm::enumerate(rewriteFns))
215  externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
216  }
217 
218  /// Generate the bytecode for the given PDL interpreter module.
219  void generate(ModuleOp module);
220 
221  /// Return the memory index to use for the given value.
/// Asserts (in debug builds) that allocateMemoryIndices assigned one.
222  ByteCodeField &getMemIndex(Value value) {
223  assert(valueToMemIndex.count(value) &&
224  "expected memory index to be assigned");
225  return valueToMemIndex[value];
226  }
227 
228  /// Return the range memory index used to store the given range value.
229  ByteCodeField &getRangeStorageIndex(Value value) {
230  assert(valueToRangeIndex.count(value) &&
231  "expected range index to be assigned");
232  return valueToRangeIndex[value];
233  }
234 
235  /// Return an index to use when referring to the given data that is uniqued in
236  /// the MLIR context.
/// Uniqued data is numbered after the value slots:
/// index = maxValueMemoryIndex + position in `uniquedData`. This relies on
/// `maxValueMemoryIndex` being final, which generate(ModuleOp) ensures by
/// calling allocateMemoryIndices before emitting any code.
/// NOTE(review): source line 238 (the return type / enable_if constraint) is
/// missing from this listing.
237  template <typename T>
239  getMemIndex(T val) {
240  const void *opaqueVal = val.getAsOpaquePointer();
241 
242  // Get or insert a reference to this value.
243  auto it = uniquedDataToMemIndex.try_emplace(
244  opaqueVal, maxValueMemoryIndex + uniquedData.size());
245  if (it.second)
246  uniquedData.push_back(opaqueVal);
247  return it.first->second;
248  }
249 
250 private:
251  /// Allocate memory indices for the results of operations within the matcher
252  /// and rewriters.
253  void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
254  ModuleOp rewriterModule);
255 
256  /// Generate the bytecode for the given operation.
257  void generate(Region *region, ByteCodeWriter &writer);
258  void generate(Operation *op, ByteCodeWriter &writer);
259  void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
260  void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
261  void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
262  void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
263  void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
264  void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
265  void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
266  void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
267  void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
268  void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
269  void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
270  void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
271  void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
272  void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer);
273  void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
274  void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
275  void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
276  void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
277  void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
278  void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
279  void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
280  void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
281  void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
282  void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
283  void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
284  void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
285  void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
286  void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
287  void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
288  void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
289  void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
290  void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
291  void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
292  void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
293  void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
294  void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
295  void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
296  void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
297 
298  /// Mapping from value to its corresponding memory index.
299  DenseMap<Value, ByteCodeField> valueToMemIndex;
300 
301  /// Mapping from a range value to its corresponding range storage index.
302  DenseMap<Value, ByteCodeField> valueToRangeIndex;
303 
304  /// Mapping from the name of an externally registered rewrite to its index in
305  /// the bytecode registry.
306  llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
307 
308  /// Mapping from the name of an externally registered constraint to its index
309  /// in the bytecode registry.
310  llvm::StringMap<ByteCodeField> constraintToMemIndex;
311 
312  /// Mapping from rewriter function name to the bytecode address of the
313  /// rewriter function within `rewriterByteCode`.
314  llvm::StringMap<ByteCodeAddr> rewriterToAddr;
315 
316  /// Mapping from a uniqued storage object to its memory index. The index is
317  /// `maxValueMemoryIndex` plus the object's position in `uniquedData`.
318  DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
319 
320  /// The current level of the foreach loop.
321  ByteCodeField curLoopLevel = 0;
322 
323  /// The current MLIR context.
324  MLIRContext *ctx;
325 
326  /// Mapping from block to its address.
/// NOTE(review): source line 327 (presumably the `blockToAddr` map declaration
/// used by generate(Region *) and generate(ModuleOp)) is missing from this
/// listing.
328 
329  /// Data of the ByteCode class to be populated.
330  std::vector<const void *> &uniquedData;
331  SmallVectorImpl<ByteCodeField> &matcherByteCode;
332  SmallVectorImpl<ByteCodeField> &rewriterByteCode;
// NOTE(review): source line 333 (presumably the `patterns` reference member
// initialized in the constructor) is missing from this listing.
334  ByteCodeField &maxValueMemoryIndex;
335  ByteCodeField &maxOpRangeMemoryIndex;
336  ByteCodeField &maxTypeRangeMemoryIndex;
337  ByteCodeField &maxValueRangeMemoryIndex;
338  ByteCodeField &maxLoopLevel;
339 
340  /// A map of pattern configurations.
/// NOTE(review): source line 341 (the `configMap` member declaration) is
/// missing from this listing.
342 };
343 
344 /// This class provides utilities for writing a bytecode stream.
/// Values that live in memory slots are written as their slot index (looked up
/// through `generator`), and successor blocks are written as placeholders that
/// are recorded in `unresolvedSuccessorRefs` for later patching.
345 struct ByteCodeWriter {
346  ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
347  : bytecode(bytecode), generator(generator) {}
348 
349  /// Append a field to the bytecode.
350  void append(ByteCodeField field) { bytecode.push_back(field); }
351  void append(OpCode opCode) { bytecode.push_back(opCode); }
352 
353  /// Append an address to the bytecode.
/// An address is exactly two fields wide (checked by the static_assert); its
/// bytes are copied as-is, in host byte order, into two consecutive fields.
354  void append(ByteCodeAddr field) {
355  static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
356  "unexpected ByteCode address size");
357 
358  ByteCodeField fieldParts[2];
359  std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
360  bytecode.append({fieldParts[0], fieldParts[1]});
361  }
362 
363  /// Append a single successor to the bytecode, the exact address will need to
364  /// be resolved later.
/// A zero address is written as a placeholder; Generator::generate(ModuleOp)
/// overwrites it once the block's address is known.
365  void append(Block *successor) {
366  // Add back a reference to the successor so that the address can be resolved
367  // later.
368  unresolvedSuccessorRefs[successor].push_back(bytecode.size());
369  append(ByteCodeAddr(0));
370  }
371 
372  /// Append a successor range to the bytecode, the exact address will need to
373  /// be resolved later.
374  void append(SuccessorRange successors) {
375  for (Block *successor : successors)
376  append(successor);
377  }
378 
379  /// Append a range of values that will be read as generic PDLValues.
/// The element count is written first and is narrowed to a ByteCodeField.
380  void appendPDLValueList(OperandRange values) {
381  bytecode.push_back(values.size());
382  for (Value value : values)
383  appendPDLValue(value);
384  }
385 
386  /// Append a value as a PDLValue.
/// The value's kind is written first, followed by the value itself.
387  void appendPDLValue(Value value) {
388  appendPDLValueKind(value);
389  append(value);
390  }
391 
392  /// Append the PDLValue::Kind of the given value.
393  void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); }
394 
395  /// Append the PDLValue::Kind of the given type.
/// NOTE(review): source lines 398 (the type-switch head) and 405-406 (the
/// range-element return values) are missing from this listing.
396  void appendPDLValueKind(Type type) {
397  PDLValue::Kind kind =
399  .Case<pdl::AttributeType>(
400  [](Type) { return PDLValue::Kind::Attribute; })
401  .Case<pdl::OperationType>(
402  [](Type) { return PDLValue::Kind::Operation; })
403  .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
404  if (rangeTy.getElementType().isa<pdl::TypeType>())
407  })
408  .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
409  .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
410  bytecode.push_back(static_cast<ByteCodeField>(kind));
411  }
412 
413  /// Append a value that will be stored in a memory slot and not inline within
414  /// the bytecode.
/// Only the slot index returned by Generator::getMemIndex is written.
/// NOTE(review): source lines 416-417 (the return type / enable_if
/// constraint) are missing from this listing.
415  template <typename T>
418  append(T value) {
419  bytecode.push_back(generator.getMemIndex(value));
420  }
421 
422  /// Append a range of values.
/// The element count is written first, then each element via append().
/// NOTE(review): source line 424 (the return type / constraint) is missing
/// from this listing.
423  template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
425  append(T range) {
426  bytecode.push_back(llvm::size(range));
427  for (auto it : range)
428  append(it);
429  }
430 
431  /// Append a variadic number of fields to the bytecode.
/// Each argument is dispatched to the matching append() overload, left to
/// right.
432  template <typename FieldTy, typename Field2Ty, typename... FieldTys>
433  void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
434  append(field);
435  append(field2, fields...);
436  }
437 
438  /// Appends a value as a pointer, stored inline within the bytecode.
/// The raw bits of `value.getAsOpaquePointer()` are copied across
/// sizeof(void *) / sizeof(ByteCodeField) consecutive fields.
/// Generator::generate(Operation *) uses this to embed op locations when
/// debug output is enabled.
/// NOTE(review): source line 440 (the return type / constraint) is missing
/// from this listing.
439  template <typename T>
441  appendInline(T value) {
442  constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField);
443  const void *pointer = value.getAsOpaquePointer();
444  ByteCodeField fieldParts[numParts];
445  std::memcpy(fieldParts, &pointer, sizeof(const void *));
446  bytecode.append(fieldParts, fieldParts + numParts);
447  }
448 
449  /// Successor references in the bytecode that have yet to be resolved.
/// Maps each successor block to the offsets of its placeholder addresses.
450  DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
451 
452  /// The underlying bytecode buffer.
/// NOTE(review): source line 453 (the `bytecode` member declaration) is
/// missing from this listing.
454 
455  /// The generator that owns this writer; used to look up memory indices.
456  Generator &generator;
457 };
458 
459 /// This class represents a live range of PDL Interpreter values, containing
460 /// information about when values are live within a match/rewrite.
461 struct ByteCodeLiveRange {
462  using Set = llvm::IntervalMap<uint64_t, char, 16>;
463  using Allocator = Set::Allocator;
464 
465  ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
466 
467  /// Union this live range with the one provided.
468  void unionWith(const ByteCodeLiveRange &rhs) {
469  for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
470  ++it)
471  liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0);
472  }
473 
474  /// Returns true if this range overlaps with the one provided.
475  bool overlaps(const ByteCodeLiveRange &rhs) const {
476  return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
477  .valid();
478  }
479 
480  /// A map representing the ranges of the match/rewrite that a value is live in
481  /// the interpreter.
482  ///
483  /// We use std::unique_ptr here, because IntervalMap does not provide a
484  /// correct copy or move constructor. We can eliminate the pointer once
485  /// https://reviews.llvm.org/D113240 lands.
486  std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
487 
488  /// The operation range storage index for this range.
489  Optional<unsigned> opRangeIndex;
490 
491  /// The type range storage index for this range.
492  Optional<unsigned> typeRangeIndex;
493 
494  /// The value range storage index for this range.
495  Optional<unsigned> valueRangeIndex;
496 };
497 } // namespace
498 
499 void Generator::generate(ModuleOp module) {
500  auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
501  pdl_interp::PDLInterpDialect::getMatcherFunctionName());
502  ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
503  pdl_interp::PDLInterpDialect::getRewriterModuleName());
504  assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
505 
506  // Allocate memory indices for the results of operations within the matcher
507  // and rewriters.
508  allocateMemoryIndices(matcherFunc, rewriterModule);
509 
510  // Generate code for the rewriter functions.
511  ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
512  for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
513  rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
514  for (Operation &op : rewriterFunc.getOps())
515  generate(&op, rewriterByteCodeWriter);
516  }
517  assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
518  "unexpected branches in rewriter function");
519 
520  // Generate code for the matcher function.
521  ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
522  generate(&matcherFunc.getBody(), matcherByteCodeWriter);
523 
524  // Resolve successor references in the matcher.
525  for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
526  ByteCodeAddr addr = blockToAddr[it.first];
527  for (unsigned offsetToFix : it.second)
528  std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
529  }
530 }
531 
532 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
533  ModuleOp rewriterModule) {
534  // Rewriters use simplistic allocation scheme that simply assigns an index to
535  // each result.
536  for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
537  ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
538  auto processRewriterValue = [&](Value val) {
539  valueToMemIndex.try_emplace(val, index++);
540  if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) {
541  Type elementTy = rangeType.getElementType();
542  if (elementTy.isa<pdl::TypeType>())
543  valueToRangeIndex.try_emplace(val, typeRangeIndex++);
544  else if (elementTy.isa<pdl::ValueType>())
545  valueToRangeIndex.try_emplace(val, valueRangeIndex++);
546  }
547  };
548 
549  for (BlockArgument arg : rewriterFunc.getArguments())
550  processRewriterValue(arg);
551  rewriterFunc.getBody().walk([&](Operation *op) {
552  for (Value result : op->getResults())
553  processRewriterValue(result);
554  });
555  if (index > maxValueMemoryIndex)
556  maxValueMemoryIndex = index;
557  if (typeRangeIndex > maxTypeRangeMemoryIndex)
558  maxTypeRangeMemoryIndex = typeRangeIndex;
559  if (valueRangeIndex > maxValueRangeMemoryIndex)
560  maxValueRangeMemoryIndex = valueRangeIndex;
561  }
562 
563  // The matcher function uses a more sophisticated numbering that tries to
564  // minimize the number of memory indices assigned. This is done by determining
565  // a live range of the values within the matcher, then the allocation is just
566  // finding the minimal number of overlapping live ranges. This is essentially
567  // a simplified form of register allocation where we don't necessarily have a
568  // limited number of registers, but we still want to minimize the number used.
569  DenseMap<Operation *, unsigned> opToFirstIndex;
570  DenseMap<Operation *, unsigned> opToLastIndex;
571 
572  // A custom walk that marks the first and the last index of each operation.
573  // The entry marks the beginning of the liveness range for this operation,
574  // followed by nested operations, followed by the end of the liveness range.
575  unsigned index = 0;
576  llvm::unique_function<void(Operation *)> walk = [&](Operation *op) {
577  opToFirstIndex.try_emplace(op, index++);
578  for (Region &region : op->getRegions())
579  for (Block &block : region.getBlocks())
580  for (Operation &nested : block)
581  walk(&nested);
582  opToLastIndex.try_emplace(op, index++);
583  };
584  walk(matcherFunc);
585 
586  // Liveness info for each of the defs within the matcher.
587  ByteCodeLiveRange::Allocator allocator;
588  DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
589 
590  // Assign the root operation being matched to slot 0.
591  BlockArgument rootOpArg = matcherFunc.getArgument(0);
592  valueToMemIndex[rootOpArg] = 0;
593 
594  // Walk each of the blocks, computing the def interval that the value is used.
595  Liveness matcherLiveness(matcherFunc);
596  matcherFunc->walk([&](Block *block) {
597  const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
598  assert(info && "expected liveness info for block");
599  auto processValue = [&](Value value, Operation *firstUseOrDef) {
600  // We don't need to process the root op argument, this value is always
601  // assigned to the first memory slot.
602  if (value == rootOpArg)
603  return;
604 
605  // Set indices for the range of this block that the value is used.
606  auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
607  defRangeIt->second.liveness->insert(
608  opToFirstIndex[firstUseOrDef],
609  opToLastIndex[info->getEndOperation(value, firstUseOrDef)],
610  /*dummyValue*/ 0);
611 
612  // Check to see if this value is a range type.
613  if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) {
614  Type eleType = rangeTy.getElementType();
615  if (eleType.isa<pdl::OperationType>())
616  defRangeIt->second.opRangeIndex = 0;
617  else if (eleType.isa<pdl::TypeType>())
618  defRangeIt->second.typeRangeIndex = 0;
619  else if (eleType.isa<pdl::ValueType>())
620  defRangeIt->second.valueRangeIndex = 0;
621  }
622  };
623 
624  // Process the live-ins of this block.
625  for (Value liveIn : info->in()) {
626  // Only process the value if it has been defined in the current region.
627  // Other values that span across pdl_interp.foreach will be added higher
628  // up. This ensures that the we keep them alive for the entire duration
629  // of the loop.
630  if (liveIn.getParentRegion() == block->getParent())
631  processValue(liveIn, &block->front());
632  }
633 
634  // Process the block arguments for the entry block (those are not live-in).
635  if (block->isEntryBlock()) {
636  for (Value argument : block->getArguments())
637  processValue(argument, &block->front());
638  }
639 
640  // Process any new defs within this block.
641  for (Operation &op : *block)
642  for (Value result : op.getResults())
643  processValue(result, &op);
644  });
645 
646  // Greedily allocate memory slots using the computed def live ranges.
647  std::vector<ByteCodeLiveRange> allocatedIndices;
648 
649  // The number of memory indices currently allocated (and its next value).
650  // Recall that the root gets allocated memory index 0.
651  ByteCodeField numIndices = 1;
652 
653  // The number of memory ranges of various types (and their next values).
654  ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
655 
656  for (auto &defIt : valueDefRanges) {
657  ByteCodeField &memIndex = valueToMemIndex[defIt.first];
658  ByteCodeLiveRange &defRange = defIt.second;
659 
660  // Try to allocate to an existing index.
661  for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) {
662  ByteCodeLiveRange &existingRange = existingIndexIt.value();
663  if (!defRange.overlaps(existingRange)) {
664  existingRange.unionWith(defRange);
665  memIndex = existingIndexIt.index() + 1;
666 
667  if (defRange.opRangeIndex) {
668  if (!existingRange.opRangeIndex)
669  existingRange.opRangeIndex = numOpRanges++;
670  valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
671  } else if (defRange.typeRangeIndex) {
672  if (!existingRange.typeRangeIndex)
673  existingRange.typeRangeIndex = numTypeRanges++;
674  valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
675  } else if (defRange.valueRangeIndex) {
676  if (!existingRange.valueRangeIndex)
677  existingRange.valueRangeIndex = numValueRanges++;
678  valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
679  }
680  break;
681  }
682  }
683 
684  // If no existing index could be used, add a new one.
685  if (memIndex == 0) {
686  allocatedIndices.emplace_back(allocator);
687  ByteCodeLiveRange &newRange = allocatedIndices.back();
688  newRange.unionWith(defRange);
689 
690  // Allocate an index for op/type/value ranges.
691  if (defRange.opRangeIndex) {
692  newRange.opRangeIndex = numOpRanges;
693  valueToRangeIndex[defIt.first] = numOpRanges++;
694  } else if (defRange.typeRangeIndex) {
695  newRange.typeRangeIndex = numTypeRanges;
696  valueToRangeIndex[defIt.first] = numTypeRanges++;
697  } else if (defRange.valueRangeIndex) {
698  newRange.valueRangeIndex = numValueRanges;
699  valueToRangeIndex[defIt.first] = numValueRanges++;
700  }
701 
702  memIndex = allocatedIndices.size();
703  ++numIndices;
704  }
705  }
706 
707  // Print the index usage and ensure that we did not run out of index space.
708  LLVM_DEBUG({
709  llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices "
710  << "(down from initial " << valueDefRanges.size() << ").\n";
711  });
712  assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
713  "Ran out of memory for allocated indices");
714 
715  // Update the max number of indices.
716  if (numIndices > maxValueMemoryIndex)
717  maxValueMemoryIndex = numIndices;
718  if (numOpRanges > maxOpRangeMemoryIndex)
719  maxOpRangeMemoryIndex = numOpRanges;
720  if (numTypeRanges > maxTypeRangeMemoryIndex)
721  maxTypeRangeMemoryIndex = numTypeRanges;
722  if (numValueRanges > maxValueRangeMemoryIndex)
723  maxValueRangeMemoryIndex = numValueRanges;
724 }
725 
726 void Generator::generate(Region *region, ByteCodeWriter &writer) {
727  llvm::ReversePostOrderTraversal<Region *> rpot(region);
728  for (Block *block : rpot) {
729  // Keep track of where this block begins within the matcher function.
730  blockToAddr.try_emplace(block, matcherByteCode.size());
731  for (Operation &op : *block)
732  generate(&op, writer);
733  }
734 }
735 
/// Dispatch `op` to the `generate` overload for its concrete pdl_interp type;
/// an unknown op is unreachable.
/// When LLVM_DEBUG output for this file's DEBUG_TYPE ("pdl-bytecode") is
/// enabled, the op's Location is first written inline into the stream for
/// every op that emits bytecode. The stream layout therefore depends on that
/// flag. NOTE(review): presumably the interpreter reads the location under the
/// same condition; confirm.
/// NOTE(review): source line 743 (the type-switch head over `op`) is missing
/// from this listing.
736 void Generator::generate(Operation *op, ByteCodeWriter &writer) {
737  LLVM_DEBUG({
738  // The following list must contain all the operations that do not
739  // produce any bytecode.
740  if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
741  writer.appendInline(op->getLoc());
742  });
744  .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
745  pdl_interp::AreEqualOp, pdl_interp::BranchOp,
746  pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
747  pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
748  pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
749  pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
750  pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp,
751  pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
752  pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
753  pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
754  pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
755  pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
756  pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
757  pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
758  pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
759  pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
760  pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
761  pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
762  pdl_interp::SwitchResultCountOp>(
763  [&](auto interpOp) { this->generate(interpOp, writer); })
764  .Default([](Operation *) {
765  llvm_unreachable("unknown `pdl_interp` operation");
766  });
767 }
768 
769 void Generator::generate(pdl_interp::ApplyConstraintOp op,
770  ByteCodeWriter &writer) {
771  assert(constraintToMemIndex.count(op.getName()) &&
772  "expected index for constraint function");
773  writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
774  writer.appendPDLValueList(op.getArgs());
775  writer.append(op.getSuccessors());
776 }
777 void Generator::generate(pdl_interp::ApplyRewriteOp op,
778  ByteCodeWriter &writer) {
779  assert(externalRewriterToMemIndex.count(op.getName()) &&
780  "expected index for rewrite function");
781  writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
782  writer.appendPDLValueList(op.getArgs());
783 
784  ResultRange results = op.getResults();
785  writer.append(ByteCodeField(results.size()));
786  for (Value result : results) {
787  // In debug mode we also record the expected kind of the result, so that we
788  // can provide extra verification of the native rewrite function.
789 #ifndef NDEBUG
790  writer.appendPDLValueKind(result);
791 #endif
792 
793  // Range results also need to append the range storage index.
794  if (result.getType().isa<pdl::RangeType>())
795  writer.append(getRangeStorageIndex(result));
796  writer.append(result);
797  }
798 }
799 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
800  Value lhs = op.getLhs();
801  if (lhs.getType().isa<pdl::RangeType>()) {
802  writer.append(OpCode::AreRangesEqual);
803  writer.appendPDLValueKind(lhs);
804  writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
805  return;
806  }
807 
808  writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
809 }
810 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
811  writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
812 }
813 void Generator::generate(pdl_interp::CheckAttributeOp op,
814  ByteCodeWriter &writer) {
815  writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
816  op.getSuccessors());
817 }
818 void Generator::generate(pdl_interp::CheckOperandCountOp op,
819  ByteCodeWriter &writer) {
820  writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
821  static_cast<ByteCodeField>(op.getCompareAtLeast()),
822  op.getSuccessors());
823 }
824 void Generator::generate(pdl_interp::CheckOperationNameOp op,
825  ByteCodeWriter &writer) {
826  writer.append(OpCode::CheckOperationName, op.getInputOp(),
827  OperationName(op.getName(), ctx), op.getSuccessors());
828 }
829 void Generator::generate(pdl_interp::CheckResultCountOp op,
830  ByteCodeWriter &writer) {
831  writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
832  static_cast<ByteCodeField>(op.getCompareAtLeast()),
833  op.getSuccessors());
834 }
835 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
836  writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
837  op.getSuccessors());
838 }
839 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
840  writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
841  op.getSuccessors());
842 }
843 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
844  assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
845  writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
846 }
847 void Generator::generate(pdl_interp::CreateAttributeOp op,
848  ByteCodeWriter &writer) {
849  // Simply repoint the memory index of the result to the constant.
850  getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
851 }
852 void Generator::generate(pdl_interp::CreateOperationOp op,
853  ByteCodeWriter &writer) {
854  writer.append(OpCode::CreateOperation, op.getResultOp(),
855  OperationName(op.getName(), ctx));
856  writer.appendPDLValueList(op.getInputOperands());
857 
858  // Add the attributes.
859  OperandRange attributes = op.getInputAttributes();
860  writer.append(static_cast<ByteCodeField>(attributes.size()));
861  for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
862  writer.append(std::get<0>(it), std::get<1>(it));
863 
864  // Add the result types. If the operation has inferred results, we use a
865  // marker "size" value. Otherwise, we add the list of explicit result types.
866  if (op.getInferredResultTypes())
867  writer.append(kInferTypesMarker);
868  else
869  writer.appendPDLValueList(op.getInputResultTypes());
870 }
871 void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
872  // Append the correct opcode for the range type.
873  TypeSwitch<Type>(op.getType().getElementType())
874  .Case(
875  [&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); })
876  .Case([&](pdl::ValueType) {
877  writer.append(OpCode::CreateDynamicValueRange);
878  });
879 
880  writer.append(op.getResult(), getRangeStorageIndex(op.getResult()));
881  writer.appendPDLValueList(op->getOperands());
882 }
883 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
884  // Simply repoint the memory index of the result to the constant.
885  getMemIndex(op.getResult()) = getMemIndex(op.getValue());
886 }
887 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
888  writer.append(OpCode::CreateConstantTypeRange, op.getResult(),
889  getRangeStorageIndex(op.getResult()), op.getValue());
890 }
891 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
892  writer.append(OpCode::EraseOp, op.getInputOp());
893 }
894 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
895  OpCode opCode =
896  TypeSwitch<Type, OpCode>(op.getResult().getType())
897  .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
898  .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
899  .Case([](pdl::TypeType) { return OpCode::ExtractType; })
900  .Default([](Type) -> OpCode {
901  llvm_unreachable("unsupported element type");
902  });
903  writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
904 }
905 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
906  writer.append(OpCode::Finalize);
907 }
908 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
909  BlockArgument arg = op.getLoopVariable();
910  writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
911  writer.appendPDLValueKind(arg.getType());
912  writer.append(curLoopLevel, op.getSuccessor());
913  ++curLoopLevel;
914  if (curLoopLevel > maxLoopLevel)
915  maxLoopLevel = curLoopLevel;
916  generate(&op.getRegion(), writer);
917  --curLoopLevel;
918 }
919 void Generator::generate(pdl_interp::GetAttributeOp op,
920  ByteCodeWriter &writer) {
921  writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
922  op.getNameAttr());
923 }
924 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
925  ByteCodeWriter &writer) {
926  writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
927 }
928 void Generator::generate(pdl_interp::GetDefiningOpOp op,
929  ByteCodeWriter &writer) {
930  writer.append(OpCode::GetDefiningOp, op.getInputOp());
931  writer.appendPDLValue(op.getValue());
932 }
933 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
934  uint32_t index = op.getIndex();
935  if (index < 4)
936  writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
937  else
938  writer.append(OpCode::GetOperandN, index);
939  writer.append(op.getInputOp(), op.getValue());
940 }
941 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
942  Value result = op.getValue();
943  Optional<uint32_t> index = op.getIndex();
944  writer.append(OpCode::GetOperands,
945  index.value_or(std::numeric_limits<uint32_t>::max()),
946  op.getInputOp());
947  if (result.getType().isa<pdl::RangeType>())
948  writer.append(getRangeStorageIndex(result));
949  else
951  writer.append(result);
952 }
953 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
954  uint32_t index = op.getIndex();
955  if (index < 4)
956  writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
957  else
958  writer.append(OpCode::GetResultN, index);
959  writer.append(op.getInputOp(), op.getValue());
960 }
961 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
962  Value result = op.getValue();
963  Optional<uint32_t> index = op.getIndex();
964  writer.append(OpCode::GetResults,
965  index.value_or(std::numeric_limits<uint32_t>::max()),
966  op.getInputOp());
967  if (result.getType().isa<pdl::RangeType>())
968  writer.append(getRangeStorageIndex(result));
969  else
971  writer.append(result);
972 }
973 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
974  Value operations = op.getOperations();
975  ByteCodeField rangeIndex = getRangeStorageIndex(operations);
976  writer.append(OpCode::GetUsers, operations, rangeIndex);
977  writer.appendPDLValue(op.getValue());
978 }
979 void Generator::generate(pdl_interp::GetValueTypeOp op,
980  ByteCodeWriter &writer) {
981  if (op.getType().isa<pdl::RangeType>()) {
982  Value result = op.getResult();
983  writer.append(OpCode::GetValueRangeTypes, result,
984  getRangeStorageIndex(result), op.getValue());
985  } else {
986  writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
987  }
988 }
989 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
990  writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
991 }
992 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
993  ByteCodeField patternIndex = patterns.size();
994  patterns.emplace_back(PDLByteCodePattern::create(
995  op, configMap.lookup(op),
996  rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
997  writer.append(OpCode::RecordMatch, patternIndex,
998  SuccessorRange(op.getOperation()), op.getMatchedOps());
999  writer.appendPDLValueList(op.getInputs());
1000 }
1001 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
1002  writer.append(OpCode::ReplaceOp, op.getInputOp());
1003  writer.appendPDLValueList(op.getReplValues());
1004 }
1005 void Generator::generate(pdl_interp::SwitchAttributeOp op,
1006  ByteCodeWriter &writer) {
1007  writer.append(OpCode::SwitchAttribute, op.getAttribute(),
1008  op.getCaseValuesAttr(), op.getSuccessors());
1009 }
1010 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
1011  ByteCodeWriter &writer) {
1012  writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
1013  op.getCaseValuesAttr(), op.getSuccessors());
1014 }
1015 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
1016  ByteCodeWriter &writer) {
1017  auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
1018  return OperationName(attr.cast<StringAttr>().getValue(), ctx);
1019  });
1020  writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
1021  op.getSuccessors());
1022 }
1023 void Generator::generate(pdl_interp::SwitchResultCountOp op,
1024  ByteCodeWriter &writer) {
1025  writer.append(OpCode::SwitchResultCount, op.getInputOp(),
1026  op.getCaseValuesAttr(), op.getSuccessors());
1027 }
1028 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1029  writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1030  op.getSuccessors());
1031 }
1032 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1033  writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
1034  op.getSuccessors());
1035 }
1036 
1037 //===----------------------------------------------------------------------===//
1038 // PDLByteCode
1039 //===----------------------------------------------------------------------===//
1040 
    // NOTE(review): this listing is missing the opening constructor line and
    // one parameter line (source lines 1041 and 1043). Only the parameters
    // below are visible; restore the rest from the declaration in ByteCode.h.
    //
    // Builds the bytecode for `module`. A Generator is handed this object's
    // storage members (uniqued data, matcher/rewriter bytecode, pattern list,
    // memory-size maxima) plus the function maps and config map, and walks the
    // module. The external constraint/rewrite functions are then moved into
    // flat vectors.
    ModuleOp module, SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
    llvm::StringMap<PDLConstraintFunction> constraintFns,
    llvm::StringMap<PDLRewriteFunction> rewriteFns)
    : configs(std::move(configs)) {
  Generator generator(module.getContext(), uniquedData, matcherByteCode,
                      rewriterByteCode, patterns, maxValueMemoryIndex,
                      maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
                      maxLoopLevel, constraintFns, rewriteFns, configMap);
  generator.generate(module);

  // Initialize the external functions.
  // NOTE(review): functions are appended in StringMap iteration order. This
  // presumably matches the indices the Generator assigned while iterating the
  // same maps; confirm against the Generator constructor.
  for (auto &it : constraintFns)
    constraintFunctions.push_back(std::move(it.second));
  for (auto &it : rewriteFns)
    rewriteFunctions.push_back(std::move(it.second));
}
1059 
/// Initialize the given state such that it can be used to execute the current
/// bytecode.
///
/// Sizes the value memory, the op/type/value range pools and the loop-index
/// stack from the maxima recorded during generation. It also seeds the
/// per-pattern benefits with each pattern's benefit.
// NOTE(review): the function signature (source line 1062) is missing from
// this listing.
  state.memory.resize(maxValueMemoryIndex, nullptr);
  state.opRangeMemory.resize(maxOpRangeCount);
  state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
  state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
  state.loopIndex.resize(maxLoopLevel, 0);
  state.currentPatternBenefits.reserve(patterns.size());
  for (const PDLByteCodePattern &pattern : patterns)
    state.currentPatternBenefits.push_back(pattern.getBenefit());
}
1072 
1073 //===----------------------------------------------------------------------===//
1074 // ByteCode Execution
1075 
1076 namespace {
1077 /// This class provides support for executing a bytecode stream.
1078 class ByteCodeExecutor {
1079 public:
1080  ByteCodeExecutor(
1081  const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
1082  MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory,
1083  MutableArrayRef<TypeRange> typeRangeMemory,
1084  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
1085  MutableArrayRef<ValueRange> valueRangeMemory,
1086  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
1087  MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory,
1089  ArrayRef<PatternBenefit> currentPatternBenefits,
1091  ArrayRef<PDLConstraintFunction> constraintFunctions,
1092  ArrayRef<PDLRewriteFunction> rewriteFunctions)
1093  : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
1094  typeRangeMemory(typeRangeMemory),
1095  allocatedTypeRangeMemory(allocatedTypeRangeMemory),
1096  valueRangeMemory(valueRangeMemory),
1097  allocatedValueRangeMemory(allocatedValueRangeMemory),
1098  loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
1099  currentPatternBenefits(currentPatternBenefits), patterns(patterns),
1100  constraintFunctions(constraintFunctions),
1101  rewriteFunctions(rewriteFunctions) {}
1102 
1103  /// Start executing the code at the current bytecode index. `matches` is an
1104  /// optional field provided when this function is executed in a matching
1105  /// context.
1107  execute(PatternRewriter &rewriter,
1108  SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
1109  Optional<Location> mainRewriteLoc = {});
1110 
1111 private:
1112  /// Internal implementation of executing each of the bytecode commands.
1113  void executeApplyConstraint(PatternRewriter &rewriter);
1114  LogicalResult executeApplyRewrite(PatternRewriter &rewriter);
1115  void executeAreEqual();
1116  void executeAreRangesEqual();
1117  void executeBranch();
1118  void executeCheckOperandCount();
1119  void executeCheckOperationName();
1120  void executeCheckResultCount();
1121  void executeCheckTypes();
1122  void executeContinue();
1123  void executeCreateConstantTypeRange();
1124  void executeCreateOperation(PatternRewriter &rewriter,
1125  Location mainRewriteLoc);
1126  template <typename T>
1127  void executeDynamicCreateRange(StringRef type);
1128  void executeEraseOp(PatternRewriter &rewriter);
1129  template <typename T, typename Range, PDLValue::Kind kind>
1130  void executeExtract();
1131  void executeFinalize();
1132  void executeForEach();
1133  void executeGetAttribute();
1134  void executeGetAttributeType();
1135  void executeGetDefiningOp();
1136  void executeGetOperand(unsigned index);
1137  void executeGetOperands();
1138  void executeGetResult(unsigned index);
1139  void executeGetResults();
1140  void executeGetUsers();
1141  void executeGetValueType();
1142  void executeGetValueRangeTypes();
1143  void executeIsNotNull();
1144  void executeRecordMatch(PatternRewriter &rewriter,
1146  void executeReplaceOp(PatternRewriter &rewriter);
1147  void executeSwitchAttribute();
1148  void executeSwitchOperandCount();
1149  void executeSwitchOperationName();
1150  void executeSwitchResultCount();
1151  void executeSwitchType();
1152  void executeSwitchTypes();
1153 
1154  /// Pushes a code iterator to the stack.
1155  void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }
1156 
1157  /// Pops a code iterator from the stack, returning true on success.
1158  void popCodeIt() {
1159  assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
1160  curCodeIt = resumeCodeIt.back();
1161  resumeCodeIt.pop_back();
1162  }
1163 
1164  /// Return the bytecode iterator at the start of the current op code.
1165  const ByteCodeField *getPrevCodeIt() const {
1166  LLVM_DEBUG({
1167  // Account for the op code and the Location stored inline.
1168  return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField);
1169  });
1170 
1171  // Account for the op code only.
1172  return curCodeIt - 1;
1173  }
1174 
1175  /// Read a value from the bytecode buffer, optionally skipping a certain
1176  /// number of prefix values. These methods always update the buffer to point
1177  /// to the next field after the read data.
1178  template <typename T = ByteCodeField>
1179  T read(size_t skipN = 0) {
1180  curCodeIt += skipN;
1181  return readImpl<T>();
1182  }
1183  ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
1184 
1185  /// Read a list of values from the bytecode buffer.
1186  template <typename ValueT, typename T>
1187  void readList(SmallVectorImpl<T> &list) {
1188  list.clear();
1189  for (unsigned i = 0, e = read(); i != e; ++i)
1190  list.push_back(read<ValueT>());
1191  }
1192 
1193  /// Read a list of values from the bytecode buffer. The values may be encoded
1194  /// either as a single element or a range of elements.
1195  void readList(SmallVectorImpl<Type> &list) {
1196  for (unsigned i = 0, e = read(); i != e; ++i) {
1197  if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1198  list.push_back(read<Type>());
1199  } else {
1200  TypeRange *values = read<TypeRange *>();
1201  list.append(values->begin(), values->end());
1202  }
1203  }
1204  }
1205  void readList(SmallVectorImpl<Value> &list) {
1206  for (unsigned i = 0, e = read(); i != e; ++i) {
1207  if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1208  list.push_back(read<Value>());
1209  } else {
1210  ValueRange *values = read<ValueRange *>();
1211  list.append(values->begin(), values->end());
1212  }
1213  }
1214  }
1215 
1216  /// Read a value stored inline as a pointer.
1217  template <typename T>
1219  readInline() {
1220  const void *pointer;
1221  std::memcpy(&pointer, curCodeIt, sizeof(const void *));
1222  curCodeIt += sizeof(const void *) / sizeof(ByteCodeField);
1223  return T::getFromOpaquePointer(pointer);
1224  }
1225 
1226  /// Jump to a specific successor based on a predicate value.
1227  void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
1228  /// Jump to a specific successor based on a destination index.
1229  void selectJump(size_t destIndex) {
1230  curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
1231  }
1232 
1233  /// Handle a switch operation with the provided value and cases.
1234  template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
1235  void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
1236  LLVM_DEBUG({
1237  llvm::dbgs() << " * Value: " << value << "\n"
1238  << " * Cases: ";
1239  llvm::interleaveComma(cases, llvm::dbgs());
1240  llvm::dbgs() << "\n";
1241  });
1242 
1243  // Check to see if the attribute value is within the case list. Jump to
1244  // the correct successor index based on the result.
1245  for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
1246  if (cmp(*it, value))
1247  return selectJump(size_t((it - cases.begin()) + 1));
1248  selectJump(size_t(0));
1249  }
1250 
1251  /// Store a pointer to memory.
1252  void storeToMemory(unsigned index, const void *value) {
1253  memory[index] = value;
1254  }
1255 
1256  /// Store a value to memory as an opaque pointer.
1257  template <typename T>
1259  storeToMemory(unsigned index, T value) {
1260  memory[index] = value.getAsOpaquePointer();
1261  }
1262 
1263  /// Internal implementation of reading various data types from the bytecode
1264  /// stream.
1265  template <typename T>
1266  const void *readFromMemory() {
1267  size_t index = *curCodeIt++;
1268 
1269  // If this type is an SSA value, it can only be stored in non-const memory.
1270  if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
1271  Value>::value ||
1272  index < memory.size())
1273  return memory[index];
1274 
1275  // Otherwise, if this index is not inbounds it is uniqued.
1276  return uniquedMemory[index - memory.size()];
1277  }
1278  template <typename T>
1280  return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
1281  }
1282  template <typename T>
1284  T>
1285  readImpl() {
1286  return T(T::getFromOpaquePointer(readFromMemory<T>()));
1287  }
1288  template <typename T>
1290  switch (read<PDLValue::Kind>()) {
1292  return read<Attribute>();
1294  return read<Operation *>();
1295  case PDLValue::Kind::Type:
1296  return read<Type>();
1297  case PDLValue::Kind::Value:
1298  return read<Value>();
1300  return read<TypeRange *>();
1302  return read<ValueRange *>();
1303  }
1304  llvm_unreachable("unhandled PDLValue::Kind");
1305  }
1306  template <typename T>
1308  static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
1309  "unexpected ByteCode address size");
1310  ByteCodeAddr result;
1311  std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
1312  curCodeIt += 2;
1313  return result;
1314  }
1315  template <typename T>
1317  return *curCodeIt++;
1318  }
1319  template <typename T>
1321  return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
1322  }
1323 
1324  /// Assign the given range to the given memory index. This allocates a new
1325  /// range object if necessary.
1326  template <typename RangeT, typename T = llvm::detail::ValueOfRange<RangeT>>
1327  void assignRangeToMemory(RangeT &&range, unsigned memIndex,
1328  unsigned rangeIndex) {
1329  // Utility functor used to type-erase the assignment.
1330  auto assignRange = [&](auto &allocatedRangeMemory, auto &rangeMemory) {
1331  // If the input range is empty, we don't need to allocate anything.
1332  if (range.empty()) {
1333  rangeMemory[rangeIndex] = {};
1334  } else {
1335  // Allocate a buffer for this type range.
1336  llvm::OwningArrayRef<T> storage(llvm::size(range));
1337  llvm::copy(range, storage.begin());
1338 
1339  // Assign this to the range slot and use the range as the value for the
1340  // memory index.
1341  allocatedRangeMemory.emplace_back(std::move(storage));
1342  rangeMemory[rangeIndex] = allocatedRangeMemory.back();
1343  }
1344  memory[memIndex] = &rangeMemory[rangeIndex];
1345  };
1346 
1347  // Dispatch based on the concrete range type.
1348  if constexpr (std::is_same_v<T, Type>) {
1349  return assignRange(allocatedTypeRangeMemory, typeRangeMemory);
1350  } else if constexpr (std::is_same_v<T, Value>) {
1351  return assignRange(allocatedValueRangeMemory, valueRangeMemory);
1352  } else {
1353  llvm_unreachable("unhandled range type");
1354  }
1355  }
1356 
1357  /// The underlying bytecode buffer.
1358  const ByteCodeField *curCodeIt;
1359 
1360  /// The stack of bytecode positions at which to resume operation.
1362 
1363  /// The current execution memory.
1365  MutableArrayRef<OwningOpRange> opRangeMemory;
1366  MutableArrayRef<TypeRange> typeRangeMemory;
1367  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
1368  MutableArrayRef<ValueRange> valueRangeMemory;
1369  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
1370 
1371  /// The current loop indices.
1372  MutableArrayRef<unsigned> loopIndex;
1373 
1374  /// References to ByteCode data necessary for execution.
1375  ArrayRef<const void *> uniquedMemory;
1377  ArrayRef<PatternBenefit> currentPatternBenefits;
1379  ArrayRef<PDLConstraintFunction> constraintFunctions;
1380  ArrayRef<PDLRewriteFunction> rewriteFunctions;
1381 };
1382 
1383 /// This class is an instantiation of the PDLResultList that provides access to
1384 /// the returned results. This API is not on `PDLResultList` to avoid
1385 /// overexposing access to information specific solely to the ByteCode.
1386 class ByteCodeRewriteResultList : public PDLResultList {
1387 public:
1388  ByteCodeRewriteResultList(unsigned maxNumResults)
1389  : PDLResultList(maxNumResults) {}
1390 
1391  /// Return the list of PDL results.
1392  MutableArrayRef<PDLValue> getResults() { return results; }
1393 
1394  /// Return the type ranges allocated by this list.
1395  MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
1396  return allocatedTypeRanges;
1397  }
1398 
1399  /// Return the value ranges allocated by this list.
1400  MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
1401  return allocatedValueRanges;
1402  }
1403 };
1404 } // namespace
1405 
1406 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1407  LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
1408  const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
1410  readList<PDLValue>(args);
1411 
1412  LLVM_DEBUG({
1413  llvm::dbgs() << " * Arguments: ";
1414  llvm::interleaveComma(args, llvm::dbgs());
1415  });
1416 
1417  // Invoke the constraint and jump to the proper destination.
1418  selectJump(succeeded(constraintFn(rewriter, args)));
1419 }
1420 
1421 LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1422  LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
1423  const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1425  readList<PDLValue>(args);
1426 
1427  LLVM_DEBUG({
1428  llvm::dbgs() << " * Arguments: ";
1429  llvm::interleaveComma(args, llvm::dbgs());
1430  });
1431 
1432  // Execute the rewrite function.
1433  ByteCodeField numResults = read();
1434  ByteCodeRewriteResultList results(numResults);
1435  LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
1436 
1437  assert(results.getResults().size() == numResults &&
1438  "native PDL rewrite function returned unexpected number of results");
1439 
1440  // Store the results in the bytecode memory.
1441  for (PDLValue &result : results.getResults()) {
1442  LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
1443 
1444 // In debug mode we also verify the expected kind of the result.
1445 #ifndef NDEBUG
1446  assert(result.getKind() == read<PDLValue::Kind>() &&
1447  "native PDL rewrite function returned an unexpected type of result");
1448 #endif
1449 
1450  // If the result is a range, we need to copy it over to the bytecodes
1451  // range memory.
1452  if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
1453  unsigned rangeIndex = read();
1454  typeRangeMemory[rangeIndex] = *typeRange;
1455  memory[read()] = &typeRangeMemory[rangeIndex];
1456  } else if (Optional<ValueRange> valueRange =
1457  result.dyn_cast<ValueRange>()) {
1458  unsigned rangeIndex = read();
1459  valueRangeMemory[rangeIndex] = *valueRange;
1460  memory[read()] = &valueRangeMemory[rangeIndex];
1461  } else {
1462  memory[read()] = result.getAsOpaquePointer();
1463  }
1464  }
1465 
1466  // Copy over any underlying storage allocated for result ranges.
1467  for (auto &it : results.getAllocatedTypeRanges())
1468  allocatedTypeRangeMemory.push_back(std::move(it));
1469  for (auto &it : results.getAllocatedValueRanges())
1470  allocatedValueRangeMemory.push_back(std::move(it));
1471 
1472  // Process the result of the rewrite.
1473  if (failed(rewriteResult)) {
1474  LLVM_DEBUG(llvm::dbgs() << " - Failed");
1475  return failure();
1476  }
1477  return success();
1478 }
1479 
1480 void ByteCodeExecutor::executeAreEqual() {
1481  LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1482  const void *lhs = read<const void *>();
1483  const void *rhs = read<const void *>();
1484 
1485  LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n");
1486  selectJump(lhs == rhs);
1487 }
1488 
1489 void ByteCodeExecutor::executeAreRangesEqual() {
1490  LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
1491  PDLValue::Kind valueKind = read<PDLValue::Kind>();
1492  const void *lhs = read<const void *>();
1493  const void *rhs = read<const void *>();
1494 
1495  switch (valueKind) {
1497  const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1498  const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1499  LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1500  selectJump(*lhsRange == *rhsRange);
1501  break;
1502  }
1504  const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1505  const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1506  LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1507  selectJump(*lhsRange == *rhsRange);
1508  break;
1509  }
1510  default:
1511  llvm_unreachable("unexpected `AreRangesEqual` value kind");
1512  }
1513 }
1514 
1515 void ByteCodeExecutor::executeBranch() {
1516  LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
1517  curCodeIt = &code[read<ByteCodeAddr>()];
1518 }
1519 
1520 void ByteCodeExecutor::executeCheckOperandCount() {
1521  LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
1522  Operation *op = read<Operation *>();
1523  uint32_t expectedCount = read<uint32_t>();
1524  bool compareAtLeast = read();
1525 
1526  LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n"
1527  << " * Expected: " << expectedCount << "\n"
1528  << " * Comparator: "
1529  << (compareAtLeast ? ">=" : "==") << "\n");
1530  if (compareAtLeast)
1531  selectJump(op->getNumOperands() >= expectedCount);
1532  else
1533  selectJump(op->getNumOperands() == expectedCount);
1534 }
1535 
1536 void ByteCodeExecutor::executeCheckOperationName() {
1537  LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
1538  Operation *op = read<Operation *>();
1539  OperationName expectedName = read<OperationName>();
1540 
1541  LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n"
1542  << " * Expected: \"" << expectedName << "\"\n");
1543  selectJump(op->getName() == expectedName);
1544 }
1545 
1546 void ByteCodeExecutor::executeCheckResultCount() {
1547  LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
1548  Operation *op = read<Operation *>();
1549  uint32_t expectedCount = read<uint32_t>();
1550  bool compareAtLeast = read();
1551 
1552  LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n"
1553  << " * Expected: " << expectedCount << "\n"
1554  << " * Comparator: "
1555  << (compareAtLeast ? ">=" : "==") << "\n");
1556  if (compareAtLeast)
1557  selectJump(op->getNumResults() >= expectedCount);
1558  else
1559  selectJump(op->getNumResults() == expectedCount);
1560 }
1561 
1562 void ByteCodeExecutor::executeCheckTypes() {
1563  LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1564  TypeRange *lhs = read<TypeRange *>();
1565  Attribute rhs = read<Attribute>();
1566  LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1567 
1568  selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>());
1569 }
1570 
1571 void ByteCodeExecutor::executeContinue() {
1572  ByteCodeField level = read();
1573  LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n"
1574  << " * Level: " << level << "\n");
1575  ++loopIndex[level];
1576  popCodeIt();
1577 }
1578 
1579 void ByteCodeExecutor::executeCreateConstantTypeRange() {
1580  LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n");
1581  unsigned memIndex = read();
1582  unsigned rangeIndex = read();
1583  ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
1584 
1585  LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n");
1586  assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
1587  rangeIndex);
1588 }
1589 
// Build an OperationState from inline bytecode operands and create the
// operation with `rewriter`, storing the new operation at the encoded memory
// slot.
// Bytecode layout consumed, in order:
//   memory index, operation name, operand list, attribute count followed by
//   (name, attribute) pairs, result count, and then either nothing (if the
//   count equals kInferTypesMarker) or one (kind, type | type-range) entry
//   per result.
// `mainRewriteLoc` becomes the location of the created operation.
1590 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1591  Location mainRewriteLoc) {
1592  LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
1593 
1594  unsigned memIndex = read();
1595  OperationState state(mainRewriteLoc, read<OperationName>());
1596  readList(state.operands);
1597  for (unsigned i = 0, e = read(); i != e; ++i) {
1598  StringAttr name = read<StringAttr>();
// A null attribute value means "not set"; it is consumed but not added.
1599  if (Attribute attr = read<Attribute>())
1600  state.addAttribute(name, attr);
1601  }
1602 
1603  // Read in the result types. If the "size" is the sentinel value, this
1604  // indicates that the result types should be inferred.
1605  unsigned numResults = read();
1606  if (numResults == kInferTypesMarker) {
// NOTE(review): the result of getRegisteredInfo() is dereferenced without a
// check, and the assert below only guards the interface pointer (and only in
// assert-enabled builds). Confirm that unregistered ops can never reach here.
1607  InferTypeOpInterface::Concept *inferInterface =
1608  state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>();
1609  assert(inferInterface &&
1610  "expected operation to provide InferTypeOpInterface");
1611 
1612  // TODO: Handle failure.
// NOTE(review): on inference failure this returns without writing
// memory[memIndex], so the slot keeps whatever value it held before.
1613  if (failed(inferInterface->inferReturnTypes(
1614  state.getContext(), state.location, state.operands,
1615  state.attributes.getDictionary(state.getContext()), state.regions,
1616  state.types)))
1617  return;
1618  } else {
1619  // Otherwise, this is a fixed number of results.
1620  for (unsigned i = 0; i != numResults; ++i) {
// Each result entry is either a single Type or a whole TypeRange that is
// spliced in.
1621  if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1622  state.types.push_back(read<Type>());
1623  } else {
// NOTE(review): resultTypes is dereferenced without a null check.
1624  TypeRange *resultTypes = read<TypeRange *>();
1625  state.types.append(resultTypes->begin(), resultTypes->end());
1626  }
1627  }
1628  }
1629 
1630  Operation *resultOp = rewriter.create(state);
1631  memory[memIndex] = resultOp;
1632 
1633  LLVM_DEBUG({
1634  llvm::dbgs() << " * Attributes: "
1635  << state.attributes.getDictionary(state.getContext())
1636  << "\n * Operands: ";
1637  llvm::interleaveComma(state.operands, llvm::dbgs());
1638  llvm::dbgs() << "\n * Result Types: ";
1639  llvm::interleaveComma(state.types, llvm::dbgs());
1640  llvm::dbgs() << "\n * Result: " << *resultOp << "\n";
1641  });
1642 }
1643 
1644 template <typename T>
1645 void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
1646  LLVM_DEBUG(llvm::dbgs() << "Executing CreateDynamic" << type << "Range:\n");
1647  unsigned memIndex = read();
1648  unsigned rangeIndex = read();
1649  SmallVector<T> values;
1650  readList(values);
1651 
1652  LLVM_DEBUG({
1653  llvm::dbgs() << "\n * " << type << "s: ";
1654  llvm::interleaveComma(values, llvm::dbgs());
1655  llvm::dbgs() << "\n";
1656  });
1657 
1658  assignRangeToMemory(values, memIndex, rangeIndex);
1659 }
1660 
1661 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1662  LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1663  Operation *op = read<Operation *>();
1664 
1665  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
1666  rewriter.eraseOp(op);
1667 }
1668 
1669 template <typename T, typename Range, PDLValue::Kind kind>
1670 void ByteCodeExecutor::executeExtract() {
1671  LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n");
1672  Range *range = read<Range *>();
1673  unsigned index = read<uint32_t>();
1674  unsigned memIndex = read();
1675 
1676  if (!range) {
1677  memory[memIndex] = nullptr;
1678  return;
1679  }
1680 
1681  T result = index < range->size() ? (*range)[index] : T();
1682  LLVM_DEBUG(llvm::dbgs() << " * " << kind << "s(" << range->size() << ")\n"
1683  << " * Index: " << index << "\n"
1684  << " * Result: " << result << "\n");
1685  storeToMemory(memIndex, result);
1686 }
1687 
1688 void ByteCodeExecutor::executeFinalize() {
1689  LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
1690 }
1691 
// Execute one step of a ForEach loop over an operation range.
// If the loop counter (selected by an encoded loop level) still points into
// the range, the current element is stored at the given memory slot. The
// saved code position is then pushed so that a later `Continue` (which bumps
// the counter and pops it) re-enters this instruction, and the successor
// address is skipped so execution falls into the loop body.
// Once the range is exhausted, the counter is reset to 0 and successor 0 is
// taken.
// NOTE(review): presumably getPrevCodeIt() yields the address of this ForEach
// instruction itself; confirm against its definition.
1692 void ByteCodeExecutor::executeForEach() {
1693  LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
1694  const ByteCodeField *prevCodeIt = getPrevCodeIt();
1695  unsigned rangeIndex = read();
1696  unsigned memIndex = read();
1697  const void *value = nullptr;
1698 
1699  switch (read<PDLValue::Kind>()) {
// NOTE(review): the case label (source line 1700) is missing from this
// listing. The body reads opRangeMemory, so it is presumably the
// PDLValue::Kind::Operation case.
1701  unsigned &index = loopIndex[read()];
1702  ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
1703  assert(index <= array.size() && "iterated past the end");
1704  if (index < array.size()) {
1705  LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n");
1706  value = array[index];
1707  break;
1708  }
1709 
// Range exhausted: reset the counter so the loop can be re-entered later,
// then exit through successor 0.
1710  LLVM_DEBUG(llvm::dbgs() << " * Done\n");
1711  index = 0;
1712  selectJump(size_t(0));
1713  return;
1714  }
1715  default:
1716  llvm_unreachable("unexpected `ForEach` value kind");
1717  }
1718 
1719  // Store the iterated value and the stack address.
1720  memory[memIndex] = value;
1721  pushCodeIt(prevCodeIt);
1722 
1723  // Skip over the successor (we will enter the body of the loop).
1724  read<ByteCodeAddr>();
1725 }
1726 
1727 void ByteCodeExecutor::executeGetAttribute() {
1728  LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1729  unsigned memIndex = read();
1730  Operation *op = read<Operation *>();
1731  StringAttr attrName = read<StringAttr>();
1732  Attribute attr = op->getAttr(attrName);
1733 
1734  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1735  << " * Attribute: " << attrName << "\n"
1736  << " * Result: " << attr << "\n");
1737  memory[memIndex] = attr.getAsOpaquePointer();
1738 }
1739 
1740 void ByteCodeExecutor::executeGetAttributeType() {
1741  LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1742  unsigned memIndex = read();
1743  Attribute attr = read<Attribute>();
1744  Type type;
1745  if (auto typedAttr = attr.dyn_cast<TypedAttr>())
1746  type = typedAttr.getType();
1747 
1748  LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n"
1749  << " * Result: " << type << "\n");
1750  memory[memIndex] = type.getAsOpaquePointer();
1751 }
1752 
1753 void ByteCodeExecutor::executeGetDefiningOp() {
1754  LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1755  unsigned memIndex = read();
1756  Operation *op = nullptr;
1757  if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1758  Value value = read<Value>();
1759  if (value)
1760  op = value.getDefiningOp();
1761  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
1762  } else {
1763  ValueRange *values = read<ValueRange *>();
1764  if (values && !values->empty()) {
1765  op = values->front().getDefiningOp();
1766  }
1767  LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n");
1768  }
1769 
1770  LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n");
1771  memory[memIndex] = op;
1772 }
1773 
1774 void ByteCodeExecutor::executeGetOperand(unsigned index) {
1775  Operation *op = read<Operation *>();
1776  unsigned memIndex = read();
1777  Value operand =
1778  index < op->getNumOperands() ? op->getOperand(index) : Value();
1779 
1780  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1781  << " * Index: " << index << "\n"
1782  << " * Result: " << operand << "\n");
1783  memory[memIndex] = operand.getAsOpaquePointer();
1784 }
1785 
1786 /// This function is the internal implementation of `GetResults` and
1787 /// `GetOperands` that provides support for extracting a value range from the
1788 /// given operation.
1789 template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1790 static void *
1791 executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1792  ByteCodeField rangeIndex, StringRef attrSizedSegments,
1793  MutableArrayRef<ValueRange> valueRangeMemory) {
1794  // Check for the sentinel index that signals that all values should be
1795  // returned.
1796  if (index == std::numeric_limits<uint32_t>::max()) {
1797  LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n");
1798  // `values` is already the full value range.
1799 
1800  // Otherwise, check to see if this operation uses AttrSizedSegments.
1801  } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1802  LLVM_DEBUG(llvm::dbgs()
1803  << " * Extracting values from `" << attrSizedSegments << "`\n");
1804 
1805  auto segmentAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrSizedSegments);
1806  if (!segmentAttr || segmentAttr.asArrayRef().size() <= index)
1807  return nullptr;
1808 
1809  ArrayRef<int32_t> segments = segmentAttr;
1810  unsigned startIndex =
1811  std::accumulate(segments.begin(), segments.begin() + index, 0);
1812  values = values.slice(startIndex, *std::next(segments.begin(), index));
1813 
1814  LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", "
1815  << *std::next(segments.begin(), index) << "]\n");
1816 
1817  // Otherwise, assume this is the last operand group of the operation.
1818  // FIXME: We currently don't support operations with
1819  // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1820  // have a way to detect it's presence.
1821  } else if (values.size() >= index) {
1822  LLVM_DEBUG(llvm::dbgs()
1823  << " * Treating values as trailing variadic range\n");
1824  values = values.drop_front(index);
1825 
1826  // If we couldn't detect a way to compute the values, bail out.
1827  } else {
1828  return nullptr;
1829  }
1830 
1831  // If the range index is valid, we are returning a range.
1832  if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1833  valueRangeMemory[rangeIndex] = values;
1834  return &valueRangeMemory[rangeIndex];
1835  }
1836 
1837  // If a range index wasn't provided, the range is required to be non-variadic.
1838  return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1839 }
1840 
1841 void ByteCodeExecutor::executeGetOperands() {
1842  LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
1843  unsigned index = read<uint32_t>();
1844  Operation *op = read<Operation *>();
1845  ByteCodeField rangeIndex = read();
1846 
1847  void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1848  op->getOperands(), op, index, rangeIndex, "operand_segment_sizes",
1849  valueRangeMemory);
1850  if (!result)
1851  LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n");
1852  memory[read()] = result;
1853 }
1854 
1855 void ByteCodeExecutor::executeGetResult(unsigned index) {
1856  Operation *op = read<Operation *>();
1857  unsigned memIndex = read();
1858  OpResult result =
1859  index < op->getNumResults() ? op->getResult(index) : OpResult();
1860 
1861  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1862  << " * Index: " << index << "\n"
1863  << " * Result: " << result << "\n");
1864  memory[memIndex] = result.getAsOpaquePointer();
1865 }
1866 
1867 void ByteCodeExecutor::executeGetResults() {
1868  LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
1869  unsigned index = read<uint32_t>();
1870  Operation *op = read<Operation *>();
1871  ByteCodeField rangeIndex = read();
1872 
1873  void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1874  op->getResults(), op, index, rangeIndex, "result_segment_sizes",
1875  valueRangeMemory);
1876  if (!result)
1877  LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n");
1878  memory[read()] = result;
1879 }
1880 
// Collect the users of a value, or of every value in a value range, into the
// owning operation range at the given range slot. A pointer to that range is
// published at the given memory slot.
// A null value or null range leaves the range empty.
// NOTE(review): users are appended per value with no de-duplication, so an
// operation may appear more than once.
1881 void ByteCodeExecutor::executeGetUsers() {
1882  LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n");
1883  unsigned memIndex = read();
1884  unsigned rangeIndex = read();
1885  OwningOpRange &range = opRangeMemory[rangeIndex];
// The memory slot points at the range storage itself, so it is valid to
// publish it before the range is filled in below.
1886  memory[memIndex] = &range;
1887 
1888  range = OwningOpRange();
1889  if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1890  // Read the value.
1891  Value value = read<Value>();
1892  if (!value)
1893  return;
1894  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
1895 
1896  // Extract the users of a single value.
1897  range = OwningOpRange(std::distance(value.user_begin(), value.user_end()));
1898  llvm::copy(value.getUsers(), range.begin());
1899  } else {
1900  // Read a range of values.
1901  ValueRange *values = read<ValueRange *>();
1902  if (!values)
1903  return;
1904  LLVM_DEBUG({
1905  llvm::dbgs() << " * Values (" << values->size() << "): ";
1906  llvm::interleaveComma(*values, llvm::dbgs());
1907  llvm::dbgs() << "\n";
1908  });
1909 
1910  // Extract all the users of a range of values.
// NOTE(review): source line 1911 (the declaration of the `users` buffer) is
// missing from this listing.
1912  for (Value value : *values)
1913  users.append(value.user_begin(), value.user_end());
1914  range = OwningOpRange(users.size());
1915  llvm::copy(users, range.begin());
1916  }
1917 
1918  LLVM_DEBUG(llvm::dbgs() << " * Result: " << range.size() << " operations\n");
1919 }
1920 
1921 void ByteCodeExecutor::executeGetValueType() {
1922  LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1923  unsigned memIndex = read();
1924  Value value = read<Value>();
1925  Type type = value ? value.getType() : Type();
1926 
1927  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"
1928  << " * Result: " << type << "\n");
1929  memory[memIndex] = type.getAsOpaquePointer();
1930 }
1931 
1932 void ByteCodeExecutor::executeGetValueRangeTypes() {
1933  LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
1934  unsigned memIndex = read();
1935  unsigned rangeIndex = read();
1936  ValueRange *values = read<ValueRange *>();
1937  if (!values) {
1938  LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n");
1939  memory[memIndex] = nullptr;
1940  return;
1941  }
1942 
1943  LLVM_DEBUG({
1944  llvm::dbgs() << " * Values (" << values->size() << "): ";
1945  llvm::interleaveComma(*values, llvm::dbgs());
1946  llvm::dbgs() << "\n * Result: ";
1947  llvm::interleaveComma(values->getType(), llvm::dbgs());
1948  llvm::dbgs() << "\n";
1949  });
1950  typeRangeMemory[rangeIndex] = values->getType();
1951  memory[memIndex] = &typeRangeMemory[rangeIndex];
1952 }
1953 
1954 void ByteCodeExecutor::executeIsNotNull() {
1955  LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1956  const void *value = read<const void *>();
1957 
1958  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
1959  selectJump(value != nullptr);
1960 }
1961 
// Record a successful pattern match.
// Appends a MatchResult to `matches` containing:
//   - a fused location built from the matched operations,
//   - the pattern and its current benefit,
//   - a copy of every input value.
// Range inputs are copied into storage owned by the MatchResult. Execution
// then continues at the encoded destination address. A pattern whose benefit
// is "impossible to match" is skipped without recording anything.
// NOTE(review): source line 1964 (the second parameter) is missing from this
// listing. The call in `execute` passes `*matches`, and `matches.emplace_back`
// is used below.
1962 void ByteCodeExecutor::executeRecordMatch(
1963  PatternRewriter &rewriter,
1965  LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
1966  unsigned patternIndex = read();
1967  PatternBenefit benefit = currentPatternBenefits[patternIndex];
1968  const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1969 
1970  // If the benefit of the pattern is impossible, skip the processing of the
1971  // rest of the pattern.
1972  if (benefit.isImpossibleToMatch()) {
1973  LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n");
1974  curCodeIt = dest;
1975  return;
1976  }
1977 
1978  // Create a fused location containing the locations of each of the
1979  // operations used in the match. This will be used as the location for
1980  // created operations during the rewrite that don't already have an
1981  // explicit location set.
1982  unsigned numMatchLocs = read();
1983  SmallVector<Location, 4> matchLocs;
1984  matchLocs.reserve(numMatchLocs);
1985  for (unsigned i = 0; i != numMatchLocs; ++i)
1986  matchLocs.push_back(read<Operation *>()->getLoc());
1987  Location matchLoc = rewriter.getFusedLoc(matchLocs);
1988 
1989  LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n"
1990  << " * Location: " << matchLoc << "\n");
1991  matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
1992  PDLByteCode::MatchResult &match = matches.back();
1993 
1994  // Record all of the inputs to the match. If any of the inputs are ranges, we
1995  // will also need to remap the range pointer to memory stored in the match
1996  // state.
1997  unsigned numInputs = read();
1998  match.values.reserve(numInputs);
// These reservations are load-bearing. `match.values` stores pointers to
// `typeRangeValues.back()` / `valueRangeValues.back()`. At most `numInputs`
// elements are pushed to either vector, so neither reallocates and those
// pointers stay valid.
1999  match.typeRangeValues.reserve(numInputs);
2000  match.valueRangeValues.reserve(numInputs);
2001  for (unsigned i = 0; i < numInputs; ++i) {
2002  switch (read<PDLValue::Kind>()) {
// NOTE(review): case labels at source lines 2003 and 2007 are missing from
// this listing. From the bodies they are the TypeRange and ValueRange kinds.
2004  match.typeRangeValues.push_back(*read<TypeRange *>());
2005  match.values.push_back(&match.typeRangeValues.back());
2006  break;
2008  match.valueRangeValues.push_back(*read<ValueRange *>());
2009  match.values.push_back(&match.valueRangeValues.back());
2010  break;
2011  default:
2012  match.values.push_back(read<const void *>());
2013  break;
2014  }
2015  }
2016  curCodeIt = dest;
2017 }
2018 
// Replace all results of an operation with a list of values read from the
// bytecode, via `rewriter.replaceOp`.
// NOTE(review): source line 2022 (the declaration of `args`) is missing from
// this listing.
2019 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
2020  LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
2021  Operation *op = read<Operation *>();
2023  readList(args);
2024 
2025  LLVM_DEBUG({
2026  llvm::dbgs() << " * Operation: " << *op << "\n"
2027  << " * Values: ";
2028  llvm::interleaveComma(args, llvm::dbgs());
2029  llvm::dbgs() << "\n";
2030  });
2031  rewriter.replaceOp(op, args);
2032 }
2033 
2034 void ByteCodeExecutor::executeSwitchAttribute() {
2035  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
2036  Attribute value = read<Attribute>();
2037  ArrayAttr cases = read<ArrayAttr>();
2038  handleSwitch(value, cases);
2039 }
2040 
2041 void ByteCodeExecutor::executeSwitchOperandCount() {
2042  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
2043  Operation *op = read<Operation *>();
2044  auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2045 
2046  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
2047  handleSwitch(op->getNumOperands(), cases);
2048 }
2049 
// Dispatch on the name of an operation against `caseCount` operation names
// stored inline in the bytecode.
// On a match at case `i`, successor `i + 1` is taken; if nothing matches,
// successor 0 is taken.
// Before jumping, the cursor is moved past the remaining inline names. The
// `+= (caseCount - i - 1)` arithmetic implies each inline OperationName
// occupies one ByteCodeField.
2050 void ByteCodeExecutor::executeSwitchOperationName() {
2051  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
2052  OperationName value = read<Operation *>()->getName();
2053  size_t caseCount = read();
2054 
2055  // The operation names are stored in-line, so to print them out for
2056  // debugging purposes we need to read the array before executing the
2057  // switch so that we can display all of the possible values.
// The cursor is saved and restored around the debug-only reads, so debug and
// release builds consume the stream identically.
2058  LLVM_DEBUG({
2059  const ByteCodeField *prevCodeIt = curCodeIt;
2060  llvm::dbgs() << " * Value: " << value << "\n"
2061  << " * Cases: ";
2062  llvm::interleaveComma(
2063  llvm::map_range(llvm::seq<size_t>(0, caseCount),
2064  [&](size_t) { return read<OperationName>(); }),
2065  llvm::dbgs());
2066  llvm::dbgs() << "\n";
2067  curCodeIt = prevCodeIt;
2068  });
2069 
2070  // Try to find the switch value within any of the cases.
2071  for (size_t i = 0; i != caseCount; ++i) {
2072  if (read<OperationName>() == value) {
// Skip the names not yet read so the cursor lands on the successor table.
2073  curCodeIt += (caseCount - i - 1);
2074  return selectJump(i + 1);
2075  }
2076  }
2077  selectJump(size_t(0));
2078 }
2079 
2080 void ByteCodeExecutor::executeSwitchResultCount() {
2081  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
2082  Operation *op = read<Operation *>();
2083  auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2084 
2085  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
2086  handleSwitch(op->getNumResults(), cases);
2087 }
2088 
2089 void ByteCodeExecutor::executeSwitchType() {
2090  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
2091  Type value = read<Type>();
2092  auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2093  handleSwitch(value, cases);
2094 }
2095 
2096 void ByteCodeExecutor::executeSwitchTypes() {
2097  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
2098  TypeRange *value = read<TypeRange *>();
2099  auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2100  if (!value) {
2101  LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
2102  return selectJump(size_t(0));
2103  }
2104  handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
2105  return value == caseValue.getAsValueRange<TypeAttr>();
2106  });
2107 }
2108 
// Main interpreter loop. It repeatedly decodes the next opcode from the
// bytecode stream and dispatches to the matching `execute*` handler. It
// returns success on `Finalize`, and failure as soon as a native rewrite
// (`ApplyRewrite`) fails.
// Preconditions:
//   - `matches` must be non-null when the stream contains `RecordMatch`
//     (asserted below).
//   - `mainRewriteLoc` must hold a value when the stream contains
//     `CreateOperation`, because it is dereferenced unconditionally there.
// NOTE(review): source lines 2109 and 2111 are missing from this listing.
// From the call sites `execute(rewriter, &matches)` and
// `execute(rewriter, /*matches=*/nullptr, match.location)`, they are
// presumably the return type and a `matches` pointer parameter. Confirm
// against the original file.
2110 ByteCodeExecutor::execute(PatternRewriter &rewriter,
2112  Optional<Location> mainRewriteLoc) {
2113  while (true) {
2114  // Print the location of the operation being executed.
2115  LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n");
2116 
2117  OpCode opCode = static_cast<OpCode>(read());
2118  switch (opCode) {
2119  case ApplyConstraint:
2120  executeApplyConstraint(rewriter);
2121  break;
2122  case ApplyRewrite:
2123  if (failed(executeApplyRewrite(rewriter)))
2124  return failure();
2125  break;
2126  case AreEqual:
2127  executeAreEqual();
2128  break;
2129  case AreRangesEqual:
2130  executeAreRangesEqual();
2131  break;
2132  case Branch:
2133  executeBranch();
2134  break;
2135  case CheckOperandCount:
2136  executeCheckOperandCount();
2137  break;
2138  case CheckOperationName:
2139  executeCheckOperationName();
2140  break;
2141  case CheckResultCount:
2142  executeCheckResultCount();
2143  break;
2144  case CheckTypes:
2145  executeCheckTypes();
2146  break;
2147  case Continue:
2148  executeContinue();
2149  break;
2150  case CreateConstantTypeRange:
2151  executeCreateConstantTypeRange();
2152  break;
2153  case CreateOperation:
2154  executeCreateOperation(rewriter, *mainRewriteLoc);
2155  break;
2156  case CreateDynamicTypeRange:
2157  executeDynamicCreateRange<Type>("Type");
2158  break;
2159  case CreateDynamicValueRange:
2160  executeDynamicCreateRange<Value>("Value");
2161  break;
2162  case EraseOp:
2163  executeEraseOp(rewriter);
2164  break;
2165  case ExtractOp:
2166  executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
2167  break;
2168  case ExtractType:
2169  executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2170  break;
2171  case ExtractValue:
2172  executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2173  break;
2174  case Finalize:
2175  executeFinalize();
2176  LLVM_DEBUG(llvm::dbgs() << "\n");
2177  return success();
2178  case ForEach:
2179  executeForEach();
2180  break;
2181  case GetAttribute:
2182  executeGetAttribute();
2183  break;
2184  case GetAttributeType:
2185  executeGetAttributeType();
2186  break;
2187  case GetDefiningOp:
2188  executeGetDefiningOp();
2189  break;
// Compact opcodes: the operand index (0-3) is implied by the opcode itself.
2190  case GetOperand0:
2191  case GetOperand1:
2192  case GetOperand2:
2193  case GetOperand3: {
2194  unsigned index = opCode - GetOperand0;
2195  LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
2196  executeGetOperand(index);
2197  break;
2198  }
// General form: the operand index is encoded in the stream.
2199  case GetOperandN:
2200  LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
2201  executeGetOperand(read<uint32_t>());
2202  break;
2203  case GetOperands:
2204  executeGetOperands();
2205  break;
// Compact opcodes: the result index (0-3) is implied by the opcode itself.
2206  case GetResult0:
2207  case GetResult1:
2208  case GetResult2:
2209  case GetResult3: {
2210  unsigned index = opCode - GetResult0;
2211  LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
2212  executeGetResult(index);
2213  break;
2214  }
// General form: the result index is encoded in the stream.
2215  case GetResultN:
2216  LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
2217  executeGetResult(read<uint32_t>());
2218  break;
2219  case GetResults:
2220  executeGetResults();
2221  break;
2222  case GetUsers:
2223  executeGetUsers();
2224  break;
2225  case GetValueType:
2226  executeGetValueType();
2227  break;
2228  case GetValueRangeTypes:
2229  executeGetValueRangeTypes();
2230  break;
2231  case IsNotNull:
2232  executeIsNotNull();
2233  break;
2234  case RecordMatch:
2235  assert(matches &&
2236  "expected matches to be provided when executing the matcher");
2237  executeRecordMatch(rewriter, *matches);
2238  break;
2239  case ReplaceOp:
2240  executeReplaceOp(rewriter);
2241  break;
2242  case SwitchAttribute:
2243  executeSwitchAttribute();
2244  break;
2245  case SwitchOperandCount:
2246  executeSwitchOperandCount();
2247  break;
2248  case SwitchOperationName:
2249  executeSwitchOperationName();
2250  break;
2251  case SwitchResultCount:
2252  executeSwitchResultCount();
2253  break;
2254  case SwitchType:
2255  executeSwitchType();
2256  break;
2257  case SwitchTypes:
2258  executeSwitchTypes();
2259  break;
2260  }
2261  LLVM_DEBUG(llvm::dbgs() << "\n");
2262  }
2263 }
2264 
2267  PDLByteCodeMutableState &state) const {
// Run the matcher bytecode against root operation `op`, appending every
// recorded match to `matches`, then order them by descending benefit.
// NOTE(review): source lines 2265-2266 (the start of this signature) are
// missing from this listing. Per the declaration, the full signature is
// `match(Operation *op, PatternRewriter &rewriter,
//        SmallVectorImpl<MatchResult> &matches,
//        PDLByteCodeMutableState &state) const`.
2268  // The first memory slot is always the root operation.
2269  state.memory[0] = op;
2270 
2271  // The matcher function always starts at code address 0.
2272  ByteCodeExecutor executor(
2273  matcherByteCode.data(), state.memory, state.opRangeMemory,
2274  state.typeRangeMemory, state.allocatedTypeRangeMemory,
2275  state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2276  uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
2277  constraintFunctions, rewriteFunctions);
// No rewrite location is passed: the matcher stream is not expected to
// contain CreateOperation.
2278  LogicalResult executeResult = executor.execute(rewriter, &matches);
2279  (void)executeResult;
2280  assert(succeeded(executeResult) && "unexpected matcher execution failure");
2281 
2282  // Order the found matches by benefit.
// A stable sort keeps equal-benefit matches in the order they were recorded.
2283  std::stable_sort(matches.begin(), matches.end(),
2284  [](const MatchResult &lhs, const MatchResult &rhs) {
2285  return lhs.benefit > rhs.benefit;
2286  });
2287 }
2288 
2290  const MatchResult &match,
2291  PDLByteCodeMutableState &state) const {
// Execute the rewriter bytecode of a previously recorded match, with the
// match's captured values preloaded into memory and `match.location` used as
// the default location for created operations.
// The pattern's config set (if any) is notified before and after the rewrite,
// even when the rewrite fails. A failed rewrite with a rewriter that cannot
// recover is a fatal error.
// NOTE(review): source line 2289 (the start of this signature) is missing
// from this listing.
2292  auto *configSet = match.pattern->getConfigSet();
2293  if (configSet)
2294  configSet->notifyRewriteBegin(rewriter);
2295 
2296  // The arguments of the rewrite function are stored at the start of the
2297  // memory buffer.
2298  llvm::copy(match.values, state.memory.begin());
2299 
2300  ByteCodeExecutor executor(
2301  &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
2302  state.opRangeMemory, state.typeRangeMemory,
2303  state.allocatedTypeRangeMemory, state.valueRangeMemory,
2304  state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2305  rewriterByteCode, state.currentPatternBenefits, patterns,
2306  constraintFunctions, rewriteFunctions);
2307  LogicalResult result =
2308  executor.execute(rewriter, /*matches=*/nullptr, match.location);
2309 
2310  if (configSet)
2311  configSet->notifyRewriteEnd(rewriter);
2312 
2313  // If the rewrite failed, check if the pattern rewriter can recover. If it
2314  // can, we can signal to the pattern applicator to keep trying patterns. If it
2315  // cannot, we need to bail. Bailing here should be fine, given that we have
2316  // no means to propagate such a failure to the user, and it also indicates a
2317  // bug in the user code (i.e. failable rewrites should not be used with
2318  // pattern rewriters that don't support it).
2319  if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) {
// NOTE(review): this debug message has no trailing newline and reads as a
// continuation of an earlier message; presumably one printed when the
// native rewrite failed. Confirm.
2320  LLVM_DEBUG(llvm::dbgs() << " and rollback is not supported - aborting");
2321  llvm::report_fatal_error(
2322  "Native PDL Rewrite failed, but the pattern "
2323  "rewriter doesn't support recovery. Failable pattern rewrites should "
2324  "not be used with pattern rewriters that do not support them.");
2325  }
2326  return result;
2327 }
2327 }
static void * executeGetOperandsResults(RangeT values, Operation *op, unsigned index, ByteCodeField rangeIndex, StringRef attrSizedSegments, MutableArrayRef< ValueRange > valueRangeMemory)
This function is the internal implementation of GetResults and GetOperands that provides support for extracting a value range from the given operation.
Definition: ByteCode.cpp:1791
static constexpr ByteCodeField kInferTypesMarker
A marker used to indicate if an operation should infer types.
Definition: ByteCode.cpp:172
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static constexpr const bool value
static const mlir::GenInfo * generator
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void processValue(Value value, LiveMap &liveMap)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
U dyn_cast() const
Definition: Attributes.h:127
const void * getAsOpaquePointer() const
Get an opaque pointer to the attribute.
Definition: Attributes.h:85
This class represents an argument of a Block.
Definition: Value.h:296
Block represents an ordered list of Operations.
Definition: Block.h:30
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
BlockArgListType getArguments()
Definition: Block.h:76
Operation & front()
Definition: Block.h:142
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition: Block.cpp:35
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
Definition: Builders.cpp:28
This class represents liveness information on block level.
Definition: Liveness.h:99
const ValueSetT & in() const
Returns all values that are live at the beginning of the block (unordered).
Definition: Liveness.h:110
Operation * getEndOperation(Value value, Operation *startOperation) const
Gets the end operation for the given value using the start operation provided (must be referenced in this block).
Definition: Liveness.cpp:375
Represents an analysis for computing liveness information from a given top-level operation.
Definition: Liveness.h:47
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:64
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:422
This is a value defined by a result of an operation.
Definition: Value.h:442
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:41
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:31
Value getOperand(unsigned idx)
Definition: Operation.h:267
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
Definition: Operation.h:532
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:375
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:371
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:324
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:154
unsigned getNumOperands()
Definition: Operation.h:263
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:480
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:50
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:295
result_range getResults()
Definition: Operation.h:332
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:321
This class contains a set of configurations for a specific pattern.
Definition: PatternMatch.h:862
The class represents a list of PDL results, returned by a native rewrite method.
Definition: PatternMatch.h:742
Storage type of byte-code interpreter values.
Definition: PatternMatch.h:630
Kind
The underlying kind of a PDL value.
Definition: PatternMatch.h:633
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
Definition: PatternMatch.h:32
bool isImpossibleToMatch() const
Definition: PatternMatch.h:42
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the corresponding pattern isImpossibleToMatch() then this aborts.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:610
virtual bool canRecoverFromRewriteFailure() const
A hook used to indicate if the pattern rewriter can recover from failure during the rewrite stage of a pattern.
Definition: PatternMatch.h:618
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:230
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class implements the successor iterators for Block.
Definition: BlockSupport.h:73
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
const void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Types.h:175
bool isa() const
Definition: Types.h:260
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:349
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:85
Type getType() const
Return the type of this value.
Definition: Value.h:114
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Value.h:223
U cast() const
Definition: Value.h:105
This class contains the mutable state of a bytecode instance.
Definition: ByteCode.h:71
void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit)
Set the new benefit for a bytecode pattern.
Definition: ByteCode.cpp:63
void cleanupAfterMatchAndRewrite()
Cleanup any allocated state after a match/rewrite has been completed.
Definition: ByteCode.cpp:71
All of the data pertaining to a specific pattern within the bytecode.
Definition: ByteCode.h:38
static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp, PDLPatternConfigSet *configSet, ByteCodeAddr rewriterAddr)
Definition: ByteCode.cpp:36
void match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl< MatchResult > &matches, PDLByteCodeMutableState &state) const
Run the pattern matcher on the given root operation, collecting the matched patterns in matches.
Definition: ByteCode.cpp:2265
PDLByteCode(ModuleOp module, SmallVector< std::unique_ptr< PDLPatternConfigSet >> configs, const DenseMap< Operation *, PDLPatternConfigSet * > &configMap, llvm::StringMap< PDLConstraintFunction > constraintFns, llvm::StringMap< PDLRewriteFunction > rewriteFns)
Create a ByteCode instance from the given module containing operations in the PDL interpreter dialect.
Definition: ByteCode.cpp:1041
void initializeMutableState(PDLByteCodeMutableState &state) const
Initialize the given state such that it can be used to execute the current bytecode.
Definition: ByteCode.cpp:1062
LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const
Run the rewriter of the given pattern that was previously matched in match.
Definition: ByteCode.cpp:2289
Detect if any of the given parameter types has a sub-element handler.
uint32_t ByteCodeAddr
Definition: ByteCode.h:30
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.cpp:24
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:230
uint16_t ByteCodeField
Use generic bytecode types.
Definition: ByteCode.h:29
llvm::OwningArrayRef< Operation * > OwningOpRange
Definition: ByteCode.h:31
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
Definition: LogicalResult.h:68
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
std::function< LogicalResult(PatternRewriter &, ArrayRef< PDLValue >)> PDLConstraintFunction
A generic PDL pattern constraint function.
Definition: PatternMatch.h:924
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
std::function< LogicalResult(PatternRewriter &, PDLResultList &, ArrayRef< PDLValue >)> PDLRewriteFunction
A native PDL rewrite function.
Definition: PatternMatch.h:932
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
This represents an operation in an abstracted form, suitable for use with the builder APIs.
This class acts as a special tag that makes the desire to match "any" operation type explicit.
Definition: PatternMatch.h:157
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or static.
OpFoldResult size
Each successful match returns a MatchResult, which contains information necessary to execute the rewriter for the matched pattern.
Definition: ByteCode.h:132
SmallVector< TypeRange, 0 > typeRangeValues
Memory used for the range input values.
Definition: ByteCode.h:146
SmallVector< ValueRange, 0 > valueRangeValues
Definition: ByteCode.h:147
SmallVector< const void * > values
Memory values defined in the matcher that are passed to the rewriter.
Definition: ByteCode.h:144