MLIR  21.0.0git
ByteCode.cpp
Go to the documentation of this file.
1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ByteCode.h"
14 #include "mlir/Analysis/Liveness.h"
17 #include "mlir/IR/BuiltinOps.h"
19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/Format.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include <numeric>
26 #include <optional>
27 
28 #define DEBUG_TYPE "pdl-bytecode"
29 
30 using namespace mlir;
31 using namespace mlir::detail;
32 
33 //===----------------------------------------------------------------------===//
34 // PDLByteCodePattern
35 //===----------------------------------------------------------------------===//
36 
37 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
38  PDLPatternConfigSet *configSet,
39  ByteCodeAddr rewriterAddr) {
40  PatternBenefit benefit = matchOp.getBenefit();
41  MLIRContext *ctx = matchOp.getContext();
42 
43  // Collect the set of generated operations.
44  SmallVector<StringRef, 8> generatedOps;
45  if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
46  generatedOps =
47  llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
48 
49  // Check to see if this is pattern matches a specific operation type.
50  if (std::optional<StringRef> rootKind = matchOp.getRootKind())
51  return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx,
52  generatedOps);
53  return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(),
54  benefit, ctx, generatedOps);
55 }
56 
57 //===----------------------------------------------------------------------===//
58 // PDLByteCodeMutableState
59 //===----------------------------------------------------------------------===//
60 
61 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
62 /// to the position of the pattern within the range returned by
63 /// `PDLByteCode::getPatterns`.
64 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
65  PatternBenefit benefit) {
66  currentPatternBenefits[patternIndex] = benefit;
67 }
68 
69 /// Cleanup any allocated state after a full match/rewrite has been completed.
70 /// This method should be called irregardless of whether the match+rewrite was a
71 /// success or not.
73  allocatedTypeRangeMemory.clear();
74  allocatedValueRangeMemory.clear();
75 }
76 
77 //===----------------------------------------------------------------------===//
78 // Bytecode OpCodes
79 //===----------------------------------------------------------------------===//
80 
81 namespace {
82 enum OpCode : ByteCodeField {
83  /// Apply an externally registered constraint.
84  ApplyConstraint,
85  /// Apply an externally registered rewrite.
86  ApplyRewrite,
87  /// Check if two generic values are equal.
88  AreEqual,
89  /// Check if two ranges are equal.
90  AreRangesEqual,
91  /// Unconditional branch.
92  Branch,
93  /// Compare the operand count of an operation with a constant.
94  CheckOperandCount,
95  /// Compare the name of an operation with a constant.
96  CheckOperationName,
97  /// Compare the result count of an operation with a constant.
98  CheckResultCount,
99  /// Compare a range of types to a constant range of types.
100  CheckTypes,
101  /// Continue to the next iteration of a loop.
102  Continue,
103  /// Create a type range from a list of constant types.
104  CreateConstantTypeRange,
105  /// Create an operation.
106  CreateOperation,
107  /// Create a type range from a list of dynamic types.
108  CreateDynamicTypeRange,
109  /// Create a value range.
110  CreateDynamicValueRange,
111  /// Erase an operation.
112  EraseOp,
113  /// Extract the op from a range at the specified index.
114  ExtractOp,
115  /// Extract the type from a range at the specified index.
116  ExtractType,
117  /// Extract the value from a range at the specified index.
118  ExtractValue,
119  /// Terminate a matcher or rewrite sequence.
120  Finalize,
121  /// Iterate over a range of values.
122  ForEach,
123  /// Get a specific attribute of an operation.
124  GetAttribute,
125  /// Get the type of an attribute.
126  GetAttributeType,
127  /// Get the defining operation of a value.
128  GetDefiningOp,
129  /// Get a specific operand of an operation.
130  GetOperand0,
131  GetOperand1,
132  GetOperand2,
133  GetOperand3,
134  GetOperandN,
135  /// Get a specific operand group of an operation.
136  GetOperands,
137  /// Get a specific result of an operation.
138  GetResult0,
139  GetResult1,
140  GetResult2,
141  GetResult3,
142  GetResultN,
143  /// Get a specific result group of an operation.
144  GetResults,
145  /// Get the users of a value or a range of values.
146  GetUsers,
147  /// Get the type of a value.
148  GetValueType,
149  /// Get the types of a value range.
150  GetValueRangeTypes,
151  /// Check if a generic value is not null.
152  IsNotNull,
153  /// Record a successful pattern match.
154  RecordMatch,
155  /// Replace an operation.
156  ReplaceOp,
157  /// Compare an attribute with a set of constants.
158  SwitchAttribute,
159  /// Compare the operand count of an operation with a set of constants.
160  SwitchOperandCount,
161  /// Compare the name of an operation with a set of constants.
162  SwitchOperationName,
163  /// Compare the result count of an operation with a set of constants.
164  SwitchResultCount,
165  /// Compare a type with a set of constants.
166  SwitchType,
167  /// Compare a range of types with a set of constants.
168  SwitchTypes,
169 };
170 } // namespace
171 
172 /// A marker used to indicate if an operation should infer types.
173 static constexpr ByteCodeField kInferTypesMarker =
175 
176 //===----------------------------------------------------------------------===//
177 // ByteCode Generation
178 //===----------------------------------------------------------------------===//
179 
180 //===----------------------------------------------------------------------===//
181 // Generator
182 //===----------------------------------------------------------------------===//
183 
184 namespace {
// Forward declarations of helpers that are defined later in this file.
struct ByteCodeLiveRange;
struct ByteCodeWriter;

/// Detection alias: names the type returned by `T::getAsOpaquePointer()`, and
/// is ill-formed (usable with detection idioms) when `T` has no such member.
/// Any trailing template arguments are accepted and ignored.
template <typename T, typename... IgnoredTs>
using has_pointer_traits =
    decltype(std::declval<T>().getAsOpaquePointer());
191 
192 /// This class represents the main generator for the pattern bytecode.
193 class Generator {
194 public:
195  Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
196  SmallVectorImpl<ByteCodeField> &matcherByteCode,
197  SmallVectorImpl<ByteCodeField> &rewriterByteCode,
199  ByteCodeField &maxValueMemoryIndex,
200  ByteCodeField &maxOpRangeMemoryIndex,
201  ByteCodeField &maxTypeRangeMemoryIndex,
202  ByteCodeField &maxValueRangeMemoryIndex,
203  ByteCodeField &maxLoopLevel,
204  llvm::StringMap<PDLConstraintFunction> &constraintFns,
205  llvm::StringMap<PDLRewriteFunction> &rewriteFns,
207  : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
208  rewriterByteCode(rewriterByteCode), patterns(patterns),
209  maxValueMemoryIndex(maxValueMemoryIndex),
210  maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
211  maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
212  maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
213  maxLoopLevel(maxLoopLevel), configMap(configMap) {
214  for (const auto &it : llvm::enumerate(constraintFns))
215  constraintToMemIndex.try_emplace(it.value().first(), it.index());
216  for (const auto &it : llvm::enumerate(rewriteFns))
217  externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
218  }
219 
220  /// Generate the bytecode for the given PDL interpreter module.
221  void generate(ModuleOp module);
222 
223  /// Return the memory index to use for the given value.
224  ByteCodeField &getMemIndex(Value value) {
225  assert(valueToMemIndex.count(value) &&
226  "expected memory index to be assigned");
227  return valueToMemIndex[value];
228  }
229 
230  /// Return the range memory index used to store the given range value.
231  ByteCodeField &getRangeStorageIndex(Value value) {
232  assert(valueToRangeIndex.count(value) &&
233  "expected range index to be assigned");
234  return valueToRangeIndex[value];
235  }
236 
237  /// Return an index to use when referring to the given data that is uniqued in
238  /// the MLIR context.
239  template <typename T>
240  std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
241  getMemIndex(T val) {
242  const void *opaqueVal = val.getAsOpaquePointer();
243 
244  // Get or insert a reference to this value.
245  auto it = uniquedDataToMemIndex.try_emplace(
246  opaqueVal, maxValueMemoryIndex + uniquedData.size());
247  if (it.second)
248  uniquedData.push_back(opaqueVal);
249  return it.first->second;
250  }
251 
252 private:
253  /// Allocate memory indices for the results of operations within the matcher
254  /// and rewriters.
255  void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
256  ModuleOp rewriterModule);
257 
258  /// Generate the bytecode for the given operation.
259  void generate(Region *region, ByteCodeWriter &writer);
260  void generate(Operation *op, ByteCodeWriter &writer);
261  void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
262  void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
263  void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
264  void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
265  void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
266  void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
267  void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
268  void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
269  void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
270  void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
271  void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
272  void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
273  void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
274  void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer);
275  void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
276  void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
277  void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
278  void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
279  void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
280  void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
281  void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
282  void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
283  void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
284  void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
285  void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
286  void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
287  void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
288  void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
289  void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
290  void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
291  void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
292  void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
293  void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
294  void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
295  void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
296  void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
297  void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
298  void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
299 
300  /// Mapping from value to its corresponding memory index.
301  DenseMap<Value, ByteCodeField> valueToMemIndex;
302 
303  /// Mapping from a range value to its corresponding range storage index.
304  DenseMap<Value, ByteCodeField> valueToRangeIndex;
305 
306  /// Mapping from the name of an externally registered rewrite to its index in
307  /// the bytecode registry.
308  llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
309 
310  /// Mapping from the name of an externally registered constraint to its index
311  /// in the bytecode registry.
312  llvm::StringMap<ByteCodeField> constraintToMemIndex;
313 
314  /// Mapping from rewriter function name to the bytecode address of the
315  /// rewriter function in byte.
316  llvm::StringMap<ByteCodeAddr> rewriterToAddr;
317 
318  /// Mapping from a uniqued storage object to its memory index within
319  /// `uniquedData`.
320  DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
321 
322  /// The current level of the foreach loop.
323  ByteCodeField curLoopLevel = 0;
324 
325  /// The current MLIR context.
326  MLIRContext *ctx;
327 
328  /// Mapping from block to its address.
330 
331  /// Data of the ByteCode class to be populated.
332  std::vector<const void *> &uniquedData;
333  SmallVectorImpl<ByteCodeField> &matcherByteCode;
334  SmallVectorImpl<ByteCodeField> &rewriterByteCode;
336  ByteCodeField &maxValueMemoryIndex;
337  ByteCodeField &maxOpRangeMemoryIndex;
338  ByteCodeField &maxTypeRangeMemoryIndex;
339  ByteCodeField &maxValueRangeMemoryIndex;
340  ByteCodeField &maxLoopLevel;
341 
342  /// A map of pattern configurations.
344 };
345 
346 /// This class provides utilities for writing a bytecode stream.
347 struct ByteCodeWriter {
348  ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
349  : bytecode(bytecode), generator(generator) {}
350 
351  /// Append a field to the bytecode.
352  void append(ByteCodeField field) { bytecode.push_back(field); }
353  void append(OpCode opCode) { bytecode.push_back(opCode); }
354 
355  /// Append an address to the bytecode.
356  void append(ByteCodeAddr field) {
357  static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
358  "unexpected ByteCode address size");
359 
360  ByteCodeField fieldParts[2];
361  std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
362  bytecode.append({fieldParts[0], fieldParts[1]});
363  }
364 
365  /// Append a single successor to the bytecode, the exact address will need to
366  /// be resolved later.
367  void append(Block *successor) {
368  // Add back a reference to the successor so that the address can be resolved
369  // later.
370  unresolvedSuccessorRefs[successor].push_back(bytecode.size());
371  append(ByteCodeAddr(0));
372  }
373 
374  /// Append a successor range to the bytecode, the exact address will need to
375  /// be resolved later.
376  void append(SuccessorRange successors) {
377  for (Block *successor : successors)
378  append(successor);
379  }
380 
381  /// Append a range of values that will be read as generic PDLValues.
382  void appendPDLValueList(OperandRange values) {
383  bytecode.push_back(values.size());
384  for (Value value : values)
385  appendPDLValue(value);
386  }
387 
388  /// Append a value as a PDLValue.
389  void appendPDLValue(Value value) {
390  appendPDLValueKind(value);
391  append(value);
392  }
393 
394  /// Append the PDLValue::Kind of the given value.
395  void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); }
396 
397  /// Append the PDLValue::Kind of the given type.
398  void appendPDLValueKind(Type type) {
401  .Case<pdl::AttributeType>(
402  [](Type) { return PDLValue::Kind::Attribute; })
403  .Case<pdl::OperationType>(
404  [](Type) { return PDLValue::Kind::Operation; })
405  .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
406  if (isa<pdl::TypeType>(rangeTy.getElementType()))
407  return PDLValue::Kind::TypeRange;
408  return PDLValue::Kind::ValueRange;
409  })
410  .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
411  .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
412  bytecode.push_back(static_cast<ByteCodeField>(kind));
413  }
414 
415  /// Append a value that will be stored in a memory slot and not inline within
416  /// the bytecode.
417  template <typename T>
418  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
419  std::is_pointer<T>::value>
420  append(T value) {
421  bytecode.push_back(generator.getMemIndex(value));
422  }
423 
424  /// Append a range of values.
425  template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
426  std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
427  append(T range) {
428  bytecode.push_back(llvm::size(range));
429  for (auto it : range)
430  append(it);
431  }
432 
433  /// Append a variadic number of fields to the bytecode.
434  template <typename FieldTy, typename Field2Ty, typename... FieldTys>
435  void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
436  append(field);
437  append(field2, fields...);
438  }
439 
440  /// Appends a value as a pointer, stored inline within the bytecode.
441  template <typename T>
442  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
443  appendInline(T value) {
444  constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField);
445  const void *pointer = value.getAsOpaquePointer();
446  ByteCodeField fieldParts[numParts];
447  std::memcpy(fieldParts, &pointer, sizeof(const void *));
448  bytecode.append(fieldParts, fieldParts + numParts);
449  }
450 
451  /// Successor references in the bytecode that have yet to be resolved.
452  DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
453 
454  /// The underlying bytecode buffer.
456 
457  /// The main generator producing PDL.
458  Generator &generator;
459 };
460 
461 /// This class represents a live range of PDL Interpreter values, containing
462 /// information about when values are live within a match/rewrite.
463 struct ByteCodeLiveRange {
464  using Set = llvm::IntervalMap<uint64_t, char, 16>;
465  using Allocator = Set::Allocator;
466 
467  ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
468 
469  /// Union this live range with the one provided.
470  void unionWith(const ByteCodeLiveRange &rhs) {
471  for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
472  ++it)
473  liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0);
474  }
475 
476  /// Returns true if this range overlaps with the one provided.
477  bool overlaps(const ByteCodeLiveRange &rhs) const {
478  return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
479  .valid();
480  }
481 
482  /// A map representing the ranges of the match/rewrite that a value is live in
483  /// the interpreter.
484  ///
485  /// We use std::unique_ptr here, because IntervalMap does not provide a
486  /// correct copy or move constructor. We can eliminate the pointer once
487  /// https://reviews.llvm.org/D113240 lands.
488  std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
489 
490  /// The operation range storage index for this range.
491  std::optional<unsigned> opRangeIndex;
492 
493  /// The type range storage index for this range.
494  std::optional<unsigned> typeRangeIndex;
495 
496  /// The value range storage index for this range.
497  std::optional<unsigned> valueRangeIndex;
498 };
499 } // namespace
500 
501 void Generator::generate(ModuleOp module) {
502  auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
503  pdl_interp::PDLInterpDialect::getMatcherFunctionName());
504  ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
505  pdl_interp::PDLInterpDialect::getRewriterModuleName());
506  assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
507 
508  // Allocate memory indices for the results of operations within the matcher
509  // and rewriters.
510  allocateMemoryIndices(matcherFunc, rewriterModule);
511 
512  // Generate code for the rewriter functions.
513  ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
514  for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
515  rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
516  for (Operation &op : rewriterFunc.getOps())
517  generate(&op, rewriterByteCodeWriter);
518  }
519  assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
520  "unexpected branches in rewriter function");
521 
522  // Generate code for the matcher function.
523  ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
524  generate(&matcherFunc.getBody(), matcherByteCodeWriter);
525 
526  // Resolve successor references in the matcher.
527  for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
528  ByteCodeAddr addr = blockToAddr[it.first];
529  for (unsigned offsetToFix : it.second)
530  std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
531  }
532 }
533 
534 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
535  ModuleOp rewriterModule) {
536  // Rewriters use simplistic allocation scheme that simply assigns an index to
537  // each result.
538  for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
539  ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
540  auto processRewriterValue = [&](Value val) {
541  valueToMemIndex.try_emplace(val, index++);
542  if (pdl::RangeType rangeType = dyn_cast<pdl::RangeType>(val.getType())) {
543  Type elementTy = rangeType.getElementType();
544  if (isa<pdl::TypeType>(elementTy))
545  valueToRangeIndex.try_emplace(val, typeRangeIndex++);
546  else if (isa<pdl::ValueType>(elementTy))
547  valueToRangeIndex.try_emplace(val, valueRangeIndex++);
548  }
549  };
550 
551  for (BlockArgument arg : rewriterFunc.getArguments())
552  processRewriterValue(arg);
553  rewriterFunc.getBody().walk([&](Operation *op) {
554  for (Value result : op->getResults())
555  processRewriterValue(result);
556  });
557  if (index > maxValueMemoryIndex)
558  maxValueMemoryIndex = index;
559  if (typeRangeIndex > maxTypeRangeMemoryIndex)
560  maxTypeRangeMemoryIndex = typeRangeIndex;
561  if (valueRangeIndex > maxValueRangeMemoryIndex)
562  maxValueRangeMemoryIndex = valueRangeIndex;
563  }
564 
565  // The matcher function uses a more sophisticated numbering that tries to
566  // minimize the number of memory indices assigned. This is done by determining
567  // a live range of the values within the matcher, then the allocation is just
568  // finding the minimal number of overlapping live ranges. This is essentially
569  // a simplified form of register allocation where we don't necessarily have a
570  // limited number of registers, but we still want to minimize the number used.
571  DenseMap<Operation *, unsigned> opToFirstIndex;
572  DenseMap<Operation *, unsigned> opToLastIndex;
573 
574  // A custom walk that marks the first and the last index of each operation.
575  // The entry marks the beginning of the liveness range for this operation,
576  // followed by nested operations, followed by the end of the liveness range.
577  unsigned index = 0;
578  llvm::unique_function<void(Operation *)> walk = [&](Operation *op) {
579  opToFirstIndex.try_emplace(op, index++);
580  for (Region &region : op->getRegions())
581  for (Block &block : region.getBlocks())
582  for (Operation &nested : block)
583  walk(&nested);
584  opToLastIndex.try_emplace(op, index++);
585  };
586  walk(matcherFunc);
587 
588  // Liveness info for each of the defs within the matcher.
589  ByteCodeLiveRange::Allocator allocator;
590  DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
591 
592  // Assign the root operation being matched to slot 0.
593  BlockArgument rootOpArg = matcherFunc.getArgument(0);
594  valueToMemIndex[rootOpArg] = 0;
595 
596  // Walk each of the blocks, computing the def interval that the value is used.
597  Liveness matcherLiveness(matcherFunc);
598  matcherFunc->walk([&](Block *block) {
599  const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
600  assert(info && "expected liveness info for block");
601  auto processValue = [&](Value value, Operation *firstUseOrDef) {
602  // We don't need to process the root op argument, this value is always
603  // assigned to the first memory slot.
604  if (value == rootOpArg)
605  return;
606 
607  // Set indices for the range of this block that the value is used.
608  auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
609  defRangeIt->second.liveness->insert(
610  opToFirstIndex[firstUseOrDef],
611  opToLastIndex[info->getEndOperation(value, firstUseOrDef)],
612  /*dummyValue*/ 0);
613 
614  // Check to see if this value is a range type.
615  if (auto rangeTy = dyn_cast<pdl::RangeType>(value.getType())) {
616  Type eleType = rangeTy.getElementType();
617  if (isa<pdl::OperationType>(eleType))
618  defRangeIt->second.opRangeIndex = 0;
619  else if (isa<pdl::TypeType>(eleType))
620  defRangeIt->second.typeRangeIndex = 0;
621  else if (isa<pdl::ValueType>(eleType))
622  defRangeIt->second.valueRangeIndex = 0;
623  }
624  };
625 
626  // Process the live-ins of this block.
627  for (Value liveIn : info->in()) {
628  // Only process the value if it has been defined in the current region.
629  // Other values that span across pdl_interp.foreach will be added higher
630  // up. This ensures that the we keep them alive for the entire duration
631  // of the loop.
632  if (liveIn.getParentRegion() == block->getParent())
633  processValue(liveIn, &block->front());
634  }
635 
636  // Process the block arguments for the entry block (those are not live-in).
637  if (block->isEntryBlock()) {
638  for (Value argument : block->getArguments())
639  processValue(argument, &block->front());
640  }
641 
642  // Process any new defs within this block.
643  for (Operation &op : *block)
644  for (Value result : op.getResults())
645  processValue(result, &op);
646  });
647 
648  // Greedily allocate memory slots using the computed def live ranges.
649  std::vector<ByteCodeLiveRange> allocatedIndices;
650 
651  // The number of memory indices currently allocated (and its next value).
652  // Recall that the root gets allocated memory index 0.
653  ByteCodeField numIndices = 1;
654 
655  // The number of memory ranges of various types (and their next values).
656  ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
657 
658  for (auto &defIt : valueDefRanges) {
659  ByteCodeField &memIndex = valueToMemIndex[defIt.first];
660  ByteCodeLiveRange &defRange = defIt.second;
661 
662  // Try to allocate to an existing index.
663  for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) {
664  ByteCodeLiveRange &existingRange = existingIndexIt.value();
665  if (!defRange.overlaps(existingRange)) {
666  existingRange.unionWith(defRange);
667  memIndex = existingIndexIt.index() + 1;
668 
669  if (defRange.opRangeIndex) {
670  if (!existingRange.opRangeIndex)
671  existingRange.opRangeIndex = numOpRanges++;
672  valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
673  } else if (defRange.typeRangeIndex) {
674  if (!existingRange.typeRangeIndex)
675  existingRange.typeRangeIndex = numTypeRanges++;
676  valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
677  } else if (defRange.valueRangeIndex) {
678  if (!existingRange.valueRangeIndex)
679  existingRange.valueRangeIndex = numValueRanges++;
680  valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
681  }
682  break;
683  }
684  }
685 
686  // If no existing index could be used, add a new one.
687  if (memIndex == 0) {
688  allocatedIndices.emplace_back(allocator);
689  ByteCodeLiveRange &newRange = allocatedIndices.back();
690  newRange.unionWith(defRange);
691 
692  // Allocate an index for op/type/value ranges.
693  if (defRange.opRangeIndex) {
694  newRange.opRangeIndex = numOpRanges;
695  valueToRangeIndex[defIt.first] = numOpRanges++;
696  } else if (defRange.typeRangeIndex) {
697  newRange.typeRangeIndex = numTypeRanges;
698  valueToRangeIndex[defIt.first] = numTypeRanges++;
699  } else if (defRange.valueRangeIndex) {
700  newRange.valueRangeIndex = numValueRanges;
701  valueToRangeIndex[defIt.first] = numValueRanges++;
702  }
703 
704  memIndex = allocatedIndices.size();
705  ++numIndices;
706  }
707  }
708 
709  // Print the index usage and ensure that we did not run out of index space.
710  LLVM_DEBUG({
711  llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices "
712  << "(down from initial " << valueDefRanges.size() << ").\n";
713  });
714  assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
715  "Ran out of memory for allocated indices");
716 
717  // Update the max number of indices.
718  if (numIndices > maxValueMemoryIndex)
719  maxValueMemoryIndex = numIndices;
720  if (numOpRanges > maxOpRangeMemoryIndex)
721  maxOpRangeMemoryIndex = numOpRanges;
722  if (numTypeRanges > maxTypeRangeMemoryIndex)
723  maxTypeRangeMemoryIndex = numTypeRanges;
724  if (numValueRanges > maxValueRangeMemoryIndex)
725  maxValueRangeMemoryIndex = numValueRanges;
726 }
727 
728 void Generator::generate(Region *region, ByteCodeWriter &writer) {
729  llvm::ReversePostOrderTraversal<Region *> rpot(region);
730  for (Block *block : rpot) {
731  // Keep track of where this block begins within the matcher function.
732  blockToAddr.try_emplace(block, matcherByteCode.size());
733  for (Operation &op : *block)
734  generate(&op, writer);
735  }
736 }
737 
738 void Generator::generate(Operation *op, ByteCodeWriter &writer) {
739  LLVM_DEBUG({
740  // The following list must contain all the operations that do not
741  // produce any bytecode.
742  if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
743  writer.appendInline(op->getLoc());
744  });
746  .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
747  pdl_interp::AreEqualOp, pdl_interp::BranchOp,
748  pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
749  pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
750  pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
751  pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
752  pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp,
753  pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
754  pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
755  pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
756  pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
757  pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
758  pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
759  pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
760  pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
761  pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
762  pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
763  pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
764  pdl_interp::SwitchResultCountOp>(
765  [&](auto interpOp) { this->generate(interpOp, writer); })
766  .Default([](Operation *) {
767  llvm_unreachable("unknown `pdl_interp` operation");
768  });
769 }
770 
771 void Generator::generate(pdl_interp::ApplyConstraintOp op,
772  ByteCodeWriter &writer) {
773  // Constraints that should return a value have to be registered as rewrites.
774  // If a constraint and a rewrite of similar name are registered the
775  // constraint takes precedence
776  writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
777  writer.appendPDLValueList(op.getArgs());
778  writer.append(ByteCodeField(op.getIsNegated()));
779  ResultRange results = op.getResults();
780  writer.append(ByteCodeField(results.size()));
781  for (Value result : results) {
782  // We record the expected kind of the result, so that we can provide extra
783  // verification of the native rewrite function and handle the failure case
784  // of constraints accordingly.
785  writer.appendPDLValueKind(result);
786 
787  // Range results also need to append the range storage index.
788  if (isa<pdl::RangeType>(result.getType()))
789  writer.append(getRangeStorageIndex(result));
790  writer.append(result);
791  }
792  writer.append(op.getSuccessors());
793 }
794 void Generator::generate(pdl_interp::ApplyRewriteOp op,
795  ByteCodeWriter &writer) {
796  assert(externalRewriterToMemIndex.count(op.getName()) &&
797  "expected index for rewrite function");
798  writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
799  writer.appendPDLValueList(op.getArgs());
800 
801  ResultRange results = op.getResults();
802  writer.append(ByteCodeField(results.size()));
803  for (Value result : results) {
804  // We record the expected kind of the result, so that we
805  // can provide extra verification of the native rewrite function.
806  writer.appendPDLValueKind(result);
807 
808  // Range results also need to append the range storage index.
809  if (isa<pdl::RangeType>(result.getType()))
810  writer.append(getRangeStorageIndex(result));
811  writer.append(result);
812  }
813 }
814 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
815  Value lhs = op.getLhs();
816  if (isa<pdl::RangeType>(lhs.getType())) {
817  writer.append(OpCode::AreRangesEqual);
818  writer.appendPDLValueKind(lhs);
819  writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
820  return;
821  }
822 
823  writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
824 }
825 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
826  writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
827 }
828 void Generator::generate(pdl_interp::CheckAttributeOp op,
829  ByteCodeWriter &writer) {
830  writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
831  op.getSuccessors());
832 }
833 void Generator::generate(pdl_interp::CheckOperandCountOp op,
834  ByteCodeWriter &writer) {
835  writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
836  static_cast<ByteCodeField>(op.getCompareAtLeast()),
837  op.getSuccessors());
838 }
839 void Generator::generate(pdl_interp::CheckOperationNameOp op,
840  ByteCodeWriter &writer) {
841  writer.append(OpCode::CheckOperationName, op.getInputOp(),
842  OperationName(op.getName(), ctx), op.getSuccessors());
843 }
844 void Generator::generate(pdl_interp::CheckResultCountOp op,
845  ByteCodeWriter &writer) {
846  writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
847  static_cast<ByteCodeField>(op.getCompareAtLeast()),
848  op.getSuccessors());
849 }
850 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
851  writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
852  op.getSuccessors());
853 }
854 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
855  writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
856  op.getSuccessors());
857 }
858 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
859  assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
860  writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
861 }
862 void Generator::generate(pdl_interp::CreateAttributeOp op,
863  ByteCodeWriter &writer) {
864  // Simply repoint the memory index of the result to the constant.
865  getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
866 }
867 void Generator::generate(pdl_interp::CreateOperationOp op,
868  ByteCodeWriter &writer) {
869  writer.append(OpCode::CreateOperation, op.getResultOp(),
870  OperationName(op.getName(), ctx));
871  writer.appendPDLValueList(op.getInputOperands());
872 
873  // Add the attributes.
874  OperandRange attributes = op.getInputAttributes();
875  writer.append(static_cast<ByteCodeField>(attributes.size()));
876  for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
877  writer.append(std::get<0>(it), std::get<1>(it));
878 
879  // Add the result types. If the operation has inferred results, we use a
880  // marker "size" value. Otherwise, we add the list of explicit result types.
881  if (op.getInferredResultTypes())
882  writer.append(kInferTypesMarker);
883  else
884  writer.appendPDLValueList(op.getInputResultTypes());
885 }
886 void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
887  // Append the correct opcode for the range type.
888  TypeSwitch<Type>(op.getType().getElementType())
889  .Case(
890  [&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); })
891  .Case([&](pdl::ValueType) {
892  writer.append(OpCode::CreateDynamicValueRange);
893  });
894 
895  writer.append(op.getResult(), getRangeStorageIndex(op.getResult()));
896  writer.appendPDLValueList(op->getOperands());
897 }
898 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
899  // Simply repoint the memory index of the result to the constant.
900  getMemIndex(op.getResult()) = getMemIndex(op.getValue());
901 }
902 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
903  writer.append(OpCode::CreateConstantTypeRange, op.getResult(),
904  getRangeStorageIndex(op.getResult()), op.getValue());
905 }
906 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
907  writer.append(OpCode::EraseOp, op.getInputOp());
908 }
909 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
910  OpCode opCode =
911  TypeSwitch<Type, OpCode>(op.getResult().getType())
912  .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
913  .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
914  .Case([](pdl::TypeType) { return OpCode::ExtractType; })
915  .Default([](Type) -> OpCode {
916  llvm_unreachable("unsupported element type");
917  });
918  writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
919 }
920 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
921  writer.append(OpCode::Finalize);
922 }
923 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
924  BlockArgument arg = op.getLoopVariable();
925  writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
926  writer.appendPDLValueKind(arg.getType());
927  writer.append(curLoopLevel, op.getSuccessor());
928  ++curLoopLevel;
929  if (curLoopLevel > maxLoopLevel)
930  maxLoopLevel = curLoopLevel;
931  generate(&op.getRegion(), writer);
932  --curLoopLevel;
933 }
934 void Generator::generate(pdl_interp::GetAttributeOp op,
935  ByteCodeWriter &writer) {
936  writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
937  op.getNameAttr());
938 }
939 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
940  ByteCodeWriter &writer) {
941  writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
942 }
943 void Generator::generate(pdl_interp::GetDefiningOpOp op,
944  ByteCodeWriter &writer) {
945  writer.append(OpCode::GetDefiningOp, op.getInputOp());
946  writer.appendPDLValue(op.getValue());
947 }
948 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
949  uint32_t index = op.getIndex();
950  if (index < 4)
951  writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
952  else
953  writer.append(OpCode::GetOperandN, index);
954  writer.append(op.getInputOp(), op.getValue());
955 }
956 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
957  Value result = op.getValue();
958  std::optional<uint32_t> index = op.getIndex();
959  writer.append(OpCode::GetOperands,
960  index.value_or(std::numeric_limits<uint32_t>::max()),
961  op.getInputOp());
962  if (isa<pdl::RangeType>(result.getType()))
963  writer.append(getRangeStorageIndex(result));
964  else
966  writer.append(result);
967 }
968 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
969  uint32_t index = op.getIndex();
970  if (index < 4)
971  writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
972  else
973  writer.append(OpCode::GetResultN, index);
974  writer.append(op.getInputOp(), op.getValue());
975 }
976 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
977  Value result = op.getValue();
978  std::optional<uint32_t> index = op.getIndex();
979  writer.append(OpCode::GetResults,
980  index.value_or(std::numeric_limits<uint32_t>::max()),
981  op.getInputOp());
982  if (isa<pdl::RangeType>(result.getType()))
983  writer.append(getRangeStorageIndex(result));
984  else
986  writer.append(result);
987 }
988 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
989  Value operations = op.getOperations();
990  ByteCodeField rangeIndex = getRangeStorageIndex(operations);
991  writer.append(OpCode::GetUsers, operations, rangeIndex);
992  writer.appendPDLValue(op.getValue());
993 }
994 void Generator::generate(pdl_interp::GetValueTypeOp op,
995  ByteCodeWriter &writer) {
996  if (isa<pdl::RangeType>(op.getType())) {
997  Value result = op.getResult();
998  writer.append(OpCode::GetValueRangeTypes, result,
999  getRangeStorageIndex(result), op.getValue());
1000  } else {
1001  writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
1002  }
1003 }
1004 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
1005  writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
1006 }
1007 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
1008  ByteCodeField patternIndex = patterns.size();
1009  patterns.emplace_back(PDLByteCodePattern::create(
1010  op, configMap.lookup(op),
1011  rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
1012  writer.append(OpCode::RecordMatch, patternIndex,
1013  SuccessorRange(op.getOperation()), op.getMatchedOps());
1014  writer.appendPDLValueList(op.getInputs());
1015 }
1016 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
1017  writer.append(OpCode::ReplaceOp, op.getInputOp());
1018  writer.appendPDLValueList(op.getReplValues());
1019 }
1020 void Generator::generate(pdl_interp::SwitchAttributeOp op,
1021  ByteCodeWriter &writer) {
1022  writer.append(OpCode::SwitchAttribute, op.getAttribute(),
1023  op.getCaseValuesAttr(), op.getSuccessors());
1024 }
1025 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
1026  ByteCodeWriter &writer) {
1027  writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
1028  op.getCaseValuesAttr(), op.getSuccessors());
1029 }
1030 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
1031  ByteCodeWriter &writer) {
1032  auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
1033  return OperationName(cast<StringAttr>(attr).getValue(), ctx);
1034  });
1035  writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
1036  op.getSuccessors());
1037 }
1038 void Generator::generate(pdl_interp::SwitchResultCountOp op,
1039  ByteCodeWriter &writer) {
1040  writer.append(OpCode::SwitchResultCount, op.getInputOp(),
1041  op.getCaseValuesAttr(), op.getSuccessors());
1042 }
1043 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1044  writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1045  op.getSuccessors());
1046 }
1047 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1048  writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
1049  op.getSuccessors());
1050 }
1051 
1052 //===----------------------------------------------------------------------===//
1053 // PDLByteCode
1054 //===----------------------------------------------------------------------===//
1055 
1056 PDLByteCode::PDLByteCode(
1057  ModuleOp module, SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
1059  llvm::StringMap<PDLConstraintFunction> constraintFns,
1060  llvm::StringMap<PDLRewriteFunction> rewriteFns)
1061  : configs(std::move(configs)) {
1062  Generator generator(module.getContext(), uniquedData, matcherByteCode,
1063  rewriterByteCode, patterns, maxValueMemoryIndex,
1064  maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1065  maxLoopLevel, constraintFns, rewriteFns, configMap);
1066  generator.generate(module);
1067 
1068  // Initialize the external functions.
1069  for (auto &it : constraintFns)
1070  constraintFunctions.push_back(std::move(it.second));
1071  for (auto &it : rewriteFns)
1072  rewriteFunctions.push_back(std::move(it.second));
1073 }
1074 
1075 /// Initialize the given state such that it can be used to execute the current
1076 /// bytecode.
1077 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
1078  state.memory.resize(maxValueMemoryIndex, nullptr);
1079  state.opRangeMemory.resize(maxOpRangeCount);
1080  state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
1081  state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
1082  state.loopIndex.resize(maxLoopLevel, 0);
1083  state.currentPatternBenefits.reserve(patterns.size());
1084  for (const PDLByteCodePattern &pattern : patterns)
1085  state.currentPatternBenefits.push_back(pattern.getBenefit());
1086 }
1087 
1088 //===----------------------------------------------------------------------===//
1089 // ByteCode Execution
1090 //===----------------------------------------------------------------------===//
1091 
1092 namespace {
1093 /// This class is an instantiation of the PDLResultList that provides access to
1094 /// the returned results. This API is not on `PDLResultList` to avoid
1095 /// overexposing access to information specific solely to the ByteCode.
1096 class ByteCodeRewriteResultList : public PDLResultList {
1097 public:
1098  ByteCodeRewriteResultList(unsigned maxNumResults)
1099  : PDLResultList(maxNumResults) {}
1100 
1101  /// Return the list of PDL results.
1102  MutableArrayRef<PDLValue> getResults() { return results; }
1103 
1104  /// Return the type ranges allocated by this list.
1105  MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
1106  return allocatedTypeRanges;
1107  }
1108 
1109  /// Return the value ranges allocated by this list.
1110  MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
1111  return allocatedValueRanges;
1112  }
1113 };
1114 
1115 /// This class provides support for executing a bytecode stream.
1116 class ByteCodeExecutor {
1117 public:
1118  ByteCodeExecutor(
1119  const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
1120  MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory,
1121  MutableArrayRef<TypeRange> typeRangeMemory,
1122  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
1123  MutableArrayRef<ValueRange> valueRangeMemory,
1124  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
1125  MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory,
1127  ArrayRef<PatternBenefit> currentPatternBenefits,
1129  ArrayRef<PDLConstraintFunction> constraintFunctions,
1130  ArrayRef<PDLRewriteFunction> rewriteFunctions)
1131  : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
1132  typeRangeMemory(typeRangeMemory),
1133  allocatedTypeRangeMemory(allocatedTypeRangeMemory),
1134  valueRangeMemory(valueRangeMemory),
1135  allocatedValueRangeMemory(allocatedValueRangeMemory),
1136  loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
1137  currentPatternBenefits(currentPatternBenefits), patterns(patterns),
1138  constraintFunctions(constraintFunctions),
1139  rewriteFunctions(rewriteFunctions) {}
1140 
1141  /// Start executing the code at the current bytecode index. `matches` is an
1142  /// optional field provided when this function is executed in a matching
1143  /// context.
1144  LogicalResult
1145  execute(PatternRewriter &rewriter,
1146  SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
1147  std::optional<Location> mainRewriteLoc = {});
1148 
1149 private:
1150  /// Internal implementation of executing each of the bytecode commands.
1151  void executeApplyConstraint(PatternRewriter &rewriter);
1152  LogicalResult executeApplyRewrite(PatternRewriter &rewriter);
1153  void executeAreEqual();
1154  void executeAreRangesEqual();
1155  void executeBranch();
1156  void executeCheckOperandCount();
1157  void executeCheckOperationName();
1158  void executeCheckResultCount();
1159  void executeCheckTypes();
1160  void executeContinue();
1161  void executeCreateConstantTypeRange();
1162  void executeCreateOperation(PatternRewriter &rewriter,
1163  Location mainRewriteLoc);
1164  template <typename T>
1165  void executeDynamicCreateRange(StringRef type);
1166  void executeEraseOp(PatternRewriter &rewriter);
1167  template <typename T, typename Range, PDLValue::Kind kind>
1168  void executeExtract();
1169  void executeFinalize();
1170  void executeForEach();
1171  void executeGetAttribute();
1172  void executeGetAttributeType();
1173  void executeGetDefiningOp();
1174  void executeGetOperand(unsigned index);
1175  void executeGetOperands();
1176  void executeGetResult(unsigned index);
1177  void executeGetResults();
1178  void executeGetUsers();
1179  void executeGetValueType();
1180  void executeGetValueRangeTypes();
1181  void executeIsNotNull();
1182  void executeRecordMatch(PatternRewriter &rewriter,
1184  void executeReplaceOp(PatternRewriter &rewriter);
1185  void executeSwitchAttribute();
1186  void executeSwitchOperandCount();
1187  void executeSwitchOperationName();
1188  void executeSwitchResultCount();
1189  void executeSwitchType();
1190  void executeSwitchTypes();
1191  void processNativeFunResults(ByteCodeRewriteResultList &results,
1192  unsigned numResults,
1193  LogicalResult &rewriteResult);
1194 
1195  /// Pushes a code iterator to the stack.
1196  void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }
1197 
1198  /// Pops a code iterator from the stack, returning true on success.
1199  void popCodeIt() {
1200  assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
1201  curCodeIt = resumeCodeIt.pop_back_val();
1202  }
1203 
1204  /// Return the bytecode iterator at the start of the current op code.
1205  const ByteCodeField *getPrevCodeIt() const {
1206  LLVM_DEBUG({
1207  // Account for the op code and the Location stored inline.
1208  return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField);
1209  });
1210 
1211  // Account for the op code only.
1212  return curCodeIt - 1;
1213  }
1214 
1215  /// Read a value from the bytecode buffer, optionally skipping a certain
1216  /// number of prefix values. These methods always update the buffer to point
1217  /// to the next field after the read data.
1218  template <typename T = ByteCodeField>
1219  T read(size_t skipN = 0) {
1220  curCodeIt += skipN;
1221  return readImpl<T>();
1222  }
1223  ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
1224 
1225  /// Read a list of values from the bytecode buffer.
1226  template <typename ValueT, typename T>
1227  void readList(SmallVectorImpl<T> &list) {
1228  list.clear();
1229  for (unsigned i = 0, e = read(); i != e; ++i)
1230  list.push_back(read<ValueT>());
1231  }
1232 
1233  /// Read a list of values from the bytecode buffer. The values may be encoded
1234  /// either as a single element or a range of elements.
1235  void readList(SmallVectorImpl<Type> &list) {
1236  for (unsigned i = 0, e = read(); i != e; ++i) {
1237  if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1238  list.push_back(read<Type>());
1239  } else {
1240  TypeRange *values = read<TypeRange *>();
1241  list.append(values->begin(), values->end());
1242  }
1243  }
1244  }
1245  void readList(SmallVectorImpl<Value> &list) {
1246  for (unsigned i = 0, e = read(); i != e; ++i) {
1247  if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1248  list.push_back(read<Value>());
1249  } else {
1250  ValueRange *values = read<ValueRange *>();
1251  list.append(values->begin(), values->end());
1252  }
1253  }
1254  }
1255 
1256  /// Read a value stored inline as a pointer.
1257  template <typename T>
1258  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
1259  readInline() {
1260  const void *pointer;
1261  std::memcpy(&pointer, curCodeIt, sizeof(const void *));
1262  curCodeIt += sizeof(const void *) / sizeof(ByteCodeField);
1263  return T::getFromOpaquePointer(pointer);
1264  }
1265 
1266  void skip(size_t skipN) { curCodeIt += skipN; }
1267 
1268  /// Jump to a specific successor based on a predicate value.
1269  void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
1270  /// Jump to a specific successor based on a destination index.
1271  void selectJump(size_t destIndex) {
1272  curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
1273  }
1274 
1275  /// Handle a switch operation with the provided value and cases.
1276  template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
1277  void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
1278  LLVM_DEBUG({
1279  llvm::dbgs() << " * Value: " << value << "\n"
1280  << " * Cases: ";
1281  llvm::interleaveComma(cases, llvm::dbgs());
1282  llvm::dbgs() << "\n";
1283  });
1284 
1285  // Check to see if the attribute value is within the case list. Jump to
1286  // the correct successor index based on the result.
1287  for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
1288  if (cmp(*it, value))
1289  return selectJump(size_t((it - cases.begin()) + 1));
1290  selectJump(size_t(0));
1291  }
1292 
1293  /// Store a pointer to memory.
1294  void storeToMemory(unsigned index, const void *value) {
1295  memory[index] = value;
1296  }
1297 
1298  /// Store a value to memory as an opaque pointer.
1299  template <typename T>
1300  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
1301  storeToMemory(unsigned index, T value) {
1302  memory[index] = value.getAsOpaquePointer();
1303  }
1304 
1305  /// Internal implementation of reading various data types from the bytecode
1306  /// stream.
1307  template <typename T>
1308  const void *readFromMemory() {
1309  size_t index = *curCodeIt++;
1310 
1311  // If this type is an SSA value, it can only be stored in non-const memory.
1312  if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
1313  Value>::value ||
1314  index < memory.size())
1315  return memory[index];
1316 
1317  // Otherwise, if this index is not inbounds it is uniqued.
1318  return uniquedMemory[index - memory.size()];
1319  }
1320  template <typename T>
1321  std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
1322  return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
1323  }
1324  template <typename T>
1325  std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
1326  T>
1327  readImpl() {
1328  return T(T::getFromOpaquePointer(readFromMemory<T>()));
1329  }
1330  template <typename T>
1331  std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
1332  switch (read<PDLValue::Kind>()) {
1333  case PDLValue::Kind::Attribute:
1334  return read<Attribute>();
1335  case PDLValue::Kind::Operation:
1336  return read<Operation *>();
1337  case PDLValue::Kind::Type:
1338  return read<Type>();
1339  case PDLValue::Kind::Value:
1340  return read<Value>();
1341  case PDLValue::Kind::TypeRange:
1342  return read<TypeRange *>();
1343  case PDLValue::Kind::ValueRange:
1344  return read<ValueRange *>();
1345  }
1346  llvm_unreachable("unhandled PDLValue::Kind");
1347  }
1348  template <typename T>
1349  std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
1350  static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
1351  "unexpected ByteCode address size");
1352  ByteCodeAddr result;
1353  std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
1354  curCodeIt += 2;
1355  return result;
1356  }
1357  template <typename T>
1358  std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
1359  return *curCodeIt++;
1360  }
1361  template <typename T>
1362  std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
1363  return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
1364  }
1365 
1366  /// Assign the given range to the given memory index. This allocates a new
1367  /// range object if necessary.
1368  template <typename RangeT, typename T = llvm::detail::ValueOfRange<RangeT>>
1369  void assignRangeToMemory(RangeT &&range, unsigned memIndex,
1370  unsigned rangeIndex) {
1371  // Utility functor used to type-erase the assignment.
1372  auto assignRange = [&](auto &allocatedRangeMemory, auto &rangeMemory) {
1373  // If the input range is empty, we don't need to allocate anything.
1374  if (range.empty()) {
1375  rangeMemory[rangeIndex] = {};
1376  } else {
1377  // Allocate a buffer for this type range.
1378  llvm::OwningArrayRef<T> storage(llvm::size(range));
1379  llvm::copy(range, storage.begin());
1380 
1381  // Assign this to the range slot and use the range as the value for the
1382  // memory index.
1383  allocatedRangeMemory.emplace_back(std::move(storage));
1384  rangeMemory[rangeIndex] = allocatedRangeMemory.back();
1385  }
1386  memory[memIndex] = &rangeMemory[rangeIndex];
1387  };
1388 
1389  // Dispatch based on the concrete range type.
1390  if constexpr (std::is_same_v<T, Type>) {
1391  return assignRange(allocatedTypeRangeMemory, typeRangeMemory);
1392  } else if constexpr (std::is_same_v<T, Value>) {
1393  return assignRange(allocatedValueRangeMemory, valueRangeMemory);
1394  } else {
1395  llvm_unreachable("unhandled range type");
1396  }
1397  }
1398 
1399  /// The underlying bytecode buffer.
1400  const ByteCodeField *curCodeIt;
1401 
1402  /// The stack of bytecode positions at which to resume operation.
1404 
1405  /// The current execution memory.
1407  MutableArrayRef<OwningOpRange> opRangeMemory;
1408  MutableArrayRef<TypeRange> typeRangeMemory;
1409  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
1410  MutableArrayRef<ValueRange> valueRangeMemory;
1411  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
1412 
1413  /// The current loop indices.
1414  MutableArrayRef<unsigned> loopIndex;
1415 
1416  /// References to ByteCode data necessary for execution.
1417  ArrayRef<const void *> uniquedMemory;
1419  ArrayRef<PatternBenefit> currentPatternBenefits;
1421  ArrayRef<PDLConstraintFunction> constraintFunctions;
1422  ArrayRef<PDLRewriteFunction> rewriteFunctions;
1423 };
1424 } // namespace
1425 
1426 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1427  LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
1428  ByteCodeField fun_idx = read();
1430  readList<PDLValue>(args);
1431 
1432  LLVM_DEBUG({
1433  llvm::dbgs() << " * Arguments: ";
1434  llvm::interleaveComma(args, llvm::dbgs());
1435  llvm::dbgs() << "\n";
1436  });
1437 
1438  ByteCodeField isNegated = read();
1439  LLVM_DEBUG({
1440  llvm::dbgs() << " * isNegated: " << isNegated << "\n";
1441  llvm::interleaveComma(args, llvm::dbgs());
1442  });
1443 
1444  ByteCodeField numResults = read();
1445  const PDLRewriteFunction &constraintFn = constraintFunctions[fun_idx];
1446  ByteCodeRewriteResultList results(numResults);
1447  LogicalResult rewriteResult = constraintFn(rewriter, results, args);
1448  [[maybe_unused]] ArrayRef<PDLValue> constraintResults = results.getResults();
1449  LLVM_DEBUG({
1450  if (succeeded(rewriteResult)) {
1451  llvm::dbgs() << " * Constraint succeeded\n";
1452  llvm::dbgs() << " * Results: ";
1453  llvm::interleaveComma(constraintResults, llvm::dbgs());
1454  llvm::dbgs() << "\n";
1455  } else {
1456  llvm::dbgs() << " * Constraint failed\n";
1457  }
1458  });
1459  assert((failed(rewriteResult) || constraintResults.size() == numResults) &&
1460  "native PDL rewrite function succeeded but returned "
1461  "unexpected number of results");
1462  processNativeFunResults(results, numResults, rewriteResult);
1463 
1464  // Depending on the constraint jump to the proper destination.
1465  selectJump(isNegated != succeeded(rewriteResult));
1466 }
1467 
1468 LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1469  LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
1470  const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1472  readList<PDLValue>(args);
1473 
1474  LLVM_DEBUG({
1475  llvm::dbgs() << " * Arguments: ";
1476  llvm::interleaveComma(args, llvm::dbgs());
1477  });
1478 
1479  // Execute the rewrite function.
1480  ByteCodeField numResults = read();
1481  ByteCodeRewriteResultList results(numResults);
1482  LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
1483 
1484  assert(results.getResults().size() == numResults &&
1485  "native PDL rewrite function returned unexpected number of results");
1486 
1487  processNativeFunResults(results, numResults, rewriteResult);
1488 
1489  if (failed(rewriteResult)) {
1490  LLVM_DEBUG(llvm::dbgs() << " - Failed");
1491  return failure();
1492  }
1493  return success();
1494 }
1495 
1496 void ByteCodeExecutor::processNativeFunResults(
1497  ByteCodeRewriteResultList &results, unsigned numResults,
1498  LogicalResult &rewriteResult) {
1499  // Store the results in the bytecode memory or handle missing results on
1500  // failure.
1501  for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1502  PDLValue::Kind resultKind = read<PDLValue::Kind>();
1503 
1504  // Skip the according number of values on the buffer on failure and exit
1505  // early as there are no results to process.
1506  if (failed(rewriteResult)) {
1507  if (resultKind == PDLValue::Kind::TypeRange ||
1508  resultKind == PDLValue::Kind::ValueRange) {
1509  skip(2);
1510  } else {
1511  skip(1);
1512  }
1513  return;
1514  }
1515  PDLValue result = results.getResults()[resultIdx];
1516  LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
1517  assert(result.getKind() == resultKind &&
1518  "native PDL rewrite function returned an unexpected type of "
1519  "result");
1520  // If the result is a range, we need to copy it over to the bytecodes
1521  // range memory.
1522  if (std::optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
1523  unsigned rangeIndex = read();
1524  typeRangeMemory[rangeIndex] = *typeRange;
1525  memory[read()] = &typeRangeMemory[rangeIndex];
1526  } else if (std::optional<ValueRange> valueRange =
1527  result.dyn_cast<ValueRange>()) {
1528  unsigned rangeIndex = read();
1529  valueRangeMemory[rangeIndex] = *valueRange;
1530  memory[read()] = &valueRangeMemory[rangeIndex];
1531  } else {
1532  memory[read()] = result.getAsOpaquePointer();
1533  }
1534  }
1535 
1536  // Copy over any underlying storage allocated for result ranges.
1537  for (auto &it : results.getAllocatedTypeRanges())
1538  allocatedTypeRangeMemory.push_back(std::move(it));
1539  for (auto &it : results.getAllocatedValueRanges())
1540  allocatedValueRangeMemory.push_back(std::move(it));
1541 }
1542 
1543 void ByteCodeExecutor::executeAreEqual() {
1544  LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1545  const void *lhs = read<const void *>();
1546  const void *rhs = read<const void *>();
1547 
1548  LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n");
1549  selectJump(lhs == rhs);
1550 }
1551 
1552 void ByteCodeExecutor::executeAreRangesEqual() {
1553  LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
1554  PDLValue::Kind valueKind = read<PDLValue::Kind>();
1555  const void *lhs = read<const void *>();
1556  const void *rhs = read<const void *>();
1557 
1558  switch (valueKind) {
1559  case PDLValue::Kind::TypeRange: {
1560  const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1561  const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1562  LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1563  selectJump(*lhsRange == *rhsRange);
1564  break;
1565  }
1566  case PDLValue::Kind::ValueRange: {
1567  const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1568  const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1569  LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1570  selectJump(*lhsRange == *rhsRange);
1571  break;
1572  }
1573  default:
1574  llvm_unreachable("unexpected `AreRangesEqual` value kind");
1575  }
1576 }
1577 
1578 void ByteCodeExecutor::executeBranch() {
1579  LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
1580  curCodeIt = &code[read<ByteCodeAddr>()];
1581 }
1582 
1583 void ByteCodeExecutor::executeCheckOperandCount() {
1584  LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
1585  Operation *op = read<Operation *>();
1586  uint32_t expectedCount = read<uint32_t>();
1587  bool compareAtLeast = read();
1588 
1589  LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n"
1590  << " * Expected: " << expectedCount << "\n"
1591  << " * Comparator: "
1592  << (compareAtLeast ? ">=" : "==") << "\n");
1593  if (compareAtLeast)
1594  selectJump(op->getNumOperands() >= expectedCount);
1595  else
1596  selectJump(op->getNumOperands() == expectedCount);
1597 }
1598 
1599 void ByteCodeExecutor::executeCheckOperationName() {
1600  LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
1601  Operation *op = read<Operation *>();
1602  OperationName expectedName = read<OperationName>();
1603 
1604  LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n"
1605  << " * Expected: \"" << expectedName << "\"\n");
1606  selectJump(op->getName() == expectedName);
1607 }
1608 
1609 void ByteCodeExecutor::executeCheckResultCount() {
1610  LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
1611  Operation *op = read<Operation *>();
1612  uint32_t expectedCount = read<uint32_t>();
1613  bool compareAtLeast = read();
1614 
1615  LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n"
1616  << " * Expected: " << expectedCount << "\n"
1617  << " * Comparator: "
1618  << (compareAtLeast ? ">=" : "==") << "\n");
1619  if (compareAtLeast)
1620  selectJump(op->getNumResults() >= expectedCount);
1621  else
1622  selectJump(op->getNumResults() == expectedCount);
1623 }
1624 
1625 void ByteCodeExecutor::executeCheckTypes() {
1626  LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1627  TypeRange *lhs = read<TypeRange *>();
1628  Attribute rhs = read<Attribute>();
1629  LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1630 
1631  selectJump(*lhs == cast<ArrayAttr>(rhs).getAsValueRange<TypeAttr>());
1632 }
1633 
1634 void ByteCodeExecutor::executeContinue() {
1635  ByteCodeField level = read();
1636  LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n"
1637  << " * Level: " << level << "\n");
1638  ++loopIndex[level];
1639  popCodeIt();
1640 }
1641 
1642 void ByteCodeExecutor::executeCreateConstantTypeRange() {
1643  LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n");
1644  unsigned memIndex = read();
1645  unsigned rangeIndex = read();
1646  ArrayAttr typesAttr = cast<ArrayAttr>(read<Attribute>());
1647 
1648  LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n");
1649  assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
1650  rangeIndex);
1651 }
1652 
1653 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1654  Location mainRewriteLoc) {
1655  LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
1656 
1657  unsigned memIndex = read();
1658  OperationState state(mainRewriteLoc, read<OperationName>());
1659  readList(state.operands);
1660  for (unsigned i = 0, e = read(); i != e; ++i) {
1661  StringAttr name = read<StringAttr>();
1662  if (Attribute attr = read<Attribute>())
1663  state.addAttribute(name, attr);
1664  }
1665 
1666  // Read in the result types. If the "size" is the sentinel value, this
1667  // indicates that the result types should be inferred.
1668  unsigned numResults = read();
1669  if (numResults == kInferTypesMarker) {
1670  InferTypeOpInterface::Concept *inferInterface =
1671  state.name.getInterface<InferTypeOpInterface>();
1672  assert(inferInterface &&
1673  "expected operation to provide InferTypeOpInterface");
1674 
1675  // TODO: Handle failure.
1676  if (failed(inferInterface->inferReturnTypes(
1677  state.getContext(), state.location, state.operands,
1678  state.attributes.getDictionary(state.getContext()),
1679  state.getRawProperties(), state.regions, state.types)))
1680  return;
1681  } else {
1682  // Otherwise, this is a fixed number of results.
1683  for (unsigned i = 0; i != numResults; ++i) {
1684  if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1685  state.types.push_back(read<Type>());
1686  } else {
1687  TypeRange *resultTypes = read<TypeRange *>();
1688  state.types.append(resultTypes->begin(), resultTypes->end());
1689  }
1690  }
1691  }
1692 
1693  Operation *resultOp = rewriter.create(state);
1694  memory[memIndex] = resultOp;
1695 
1696  LLVM_DEBUG({
1697  llvm::dbgs() << " * Attributes: "
1698  << state.attributes.getDictionary(state.getContext())
1699  << "\n * Operands: ";
1700  llvm::interleaveComma(state.operands, llvm::dbgs());
1701  llvm::dbgs() << "\n * Result Types: ";
1702  llvm::interleaveComma(state.types, llvm::dbgs());
1703  llvm::dbgs() << "\n * Result: " << *resultOp << "\n";
1704  });
1705 }
1706 
1707 template <typename T>
1708 void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
1709  LLVM_DEBUG(llvm::dbgs() << "Executing CreateDynamic" << type << "Range:\n");
1710  unsigned memIndex = read();
1711  unsigned rangeIndex = read();
1712  SmallVector<T> values;
1713  readList(values);
1714 
1715  LLVM_DEBUG({
1716  llvm::dbgs() << "\n * " << type << "s: ";
1717  llvm::interleaveComma(values, llvm::dbgs());
1718  llvm::dbgs() << "\n";
1719  });
1720 
1721  assignRangeToMemory(values, memIndex, rangeIndex);
1722 }
1723 
1724 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1725  LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1726  Operation *op = read<Operation *>();
1727 
1728  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
1729  rewriter.eraseOp(op);
1730 }
1731 
1732 template <typename T, typename Range, PDLValue::Kind kind>
1733 void ByteCodeExecutor::executeExtract() {
1734  LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n");
1735  Range *range = read<Range *>();
1736  unsigned index = read<uint32_t>();
1737  unsigned memIndex = read();
1738 
1739  if (!range) {
1740  memory[memIndex] = nullptr;
1741  return;
1742  }
1743 
1744  T result = index < range->size() ? (*range)[index] : T();
1745  LLVM_DEBUG(llvm::dbgs() << " * " << kind << "s(" << range->size() << ")\n"
1746  << " * Index: " << index << "\n"
1747  << " * Result: " << result << "\n");
1748  storeToMemory(memIndex, result);
1749 }
1750 
1751 void ByteCodeExecutor::executeFinalize() {
1752  LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
1753 }
1754 
1755 void ByteCodeExecutor::executeForEach() {
1756  LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
1757  const ByteCodeField *prevCodeIt = getPrevCodeIt();
1758  unsigned rangeIndex = read();
1759  unsigned memIndex = read();
1760  const void *value = nullptr;
1761 
1762  switch (read<PDLValue::Kind>()) {
1763  case PDLValue::Kind::Operation: {
1764  unsigned &index = loopIndex[read()];
1765  ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
1766  assert(index <= array.size() && "iterated past the end");
1767  if (index < array.size()) {
1768  LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n");
1769  value = array[index];
1770  break;
1771  }
1772 
1773  LLVM_DEBUG(llvm::dbgs() << " * Done\n");
1774  index = 0;
1775  selectJump(size_t(0));
1776  return;
1777  }
1778  default:
1779  llvm_unreachable("unexpected `ForEach` value kind");
1780  }
1781 
1782  // Store the iterate value and the stack address.
1783  memory[memIndex] = value;
1784  pushCodeIt(prevCodeIt);
1785 
1786  // Skip over the successor (we will enter the body of the loop).
1787  read<ByteCodeAddr>();
1788 }
1789 
1790 void ByteCodeExecutor::executeGetAttribute() {
1791  LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1792  unsigned memIndex = read();
1793  Operation *op = read<Operation *>();
1794  StringAttr attrName = read<StringAttr>();
1795  Attribute attr = op->getAttr(attrName);
1796 
1797  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1798  << " * Attribute: " << attrName << "\n"
1799  << " * Result: " << attr << "\n");
1800  memory[memIndex] = attr.getAsOpaquePointer();
1801 }
1802 
1803 void ByteCodeExecutor::executeGetAttributeType() {
1804  LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1805  unsigned memIndex = read();
1806  Attribute attr = read<Attribute>();
1807  Type type;
1808  if (auto typedAttr = dyn_cast<TypedAttr>(attr))
1809  type = typedAttr.getType();
1810 
1811  LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n"
1812  << " * Result: " << type << "\n");
1813  memory[memIndex] = type.getAsOpaquePointer();
1814 }
1815 
1816 void ByteCodeExecutor::executeGetDefiningOp() {
1817  LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1818  unsigned memIndex = read();
1819  Operation *op = nullptr;
1820  if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1821  Value value = read<Value>();
1822  if (value)
1823  op = value.getDefiningOp();
1824  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
1825  } else {
1826  ValueRange *values = read<ValueRange *>();
1827  if (values && !values->empty()) {
1828  op = values->front().getDefiningOp();
1829  }
1830  LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n");
1831  }
1832 
1833  LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n");
1834  memory[memIndex] = op;
1835 }
1836 
1837 void ByteCodeExecutor::executeGetOperand(unsigned index) {
1838  Operation *op = read<Operation *>();
1839  unsigned memIndex = read();
1840  Value operand =
1841  index < op->getNumOperands() ? op->getOperand(index) : Value();
1842 
1843  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1844  << " * Index: " << index << "\n"
1845  << " * Result: " << operand << "\n");
1846  memory[memIndex] = operand.getAsOpaquePointer();
1847 }
1848 
1849 /// This function is the internal implementation of `GetResults` and
1850 /// `GetOperands` that provides support for extracting a value range from the
1851 /// given operation.
1852 template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1853 static void *
1854 executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1855  ByteCodeField rangeIndex, StringRef attrSizedSegments,
1856  MutableArrayRef<ValueRange> valueRangeMemory) {
1857  // Check for the sentinel index that signals that all values should be
1858  // returned.
1859  if (index == std::numeric_limits<uint32_t>::max()) {
1860  LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n");
1861  // `values` is already the full value range.
1862 
1863  // Otherwise, check to see if this operation uses AttrSizedSegments.
1864  } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1865  LLVM_DEBUG(llvm::dbgs()
1866  << " * Extracting values from `" << attrSizedSegments << "`\n");
1867 
1868  auto segmentAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrSizedSegments);
1869  if (!segmentAttr || segmentAttr.asArrayRef().size() <= index)
1870  return nullptr;
1871 
1872  ArrayRef<int32_t> segments = segmentAttr;
1873  unsigned startIndex =
1874  std::accumulate(segments.begin(), segments.begin() + index, 0);
1875  values = values.slice(startIndex, *std::next(segments.begin(), index));
1876 
1877  LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", "
1878  << *std::next(segments.begin(), index) << "]\n");
1879 
1880  // Otherwise, assume this is the last operand group of the operation.
1881  // FIXME: We currently don't support operations with
1882  // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1883  // have a way to detect it's presence.
1884  } else if (values.size() >= index) {
1885  LLVM_DEBUG(llvm::dbgs()
1886  << " * Treating values as trailing variadic range\n");
1887  values = values.drop_front(index);
1888 
1889  // If we couldn't detect a way to compute the values, bail out.
1890  } else {
1891  return nullptr;
1892  }
1893 
1894  // If the range index is valid, we are returning a range.
1895  if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1896  valueRangeMemory[rangeIndex] = values;
1897  return &valueRangeMemory[rangeIndex];
1898  }
1899 
1900  // If a range index wasn't provided, the range is required to be non-variadic.
1901  return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1902 }
1903 
1904 void ByteCodeExecutor::executeGetOperands() {
1905  LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
1906  unsigned index = read<uint32_t>();
1907  Operation *op = read<Operation *>();
1908  ByteCodeField rangeIndex = read();
1909 
1910  void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1911  op->getOperands(), op, index, rangeIndex, "operandSegmentSizes",
1912  valueRangeMemory);
1913  if (!result)
1914  LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n");
1915  memory[read()] = result;
1916 }
1917 
1918 void ByteCodeExecutor::executeGetResult(unsigned index) {
1919  Operation *op = read<Operation *>();
1920  unsigned memIndex = read();
1921  OpResult result =
1922  index < op->getNumResults() ? op->getResult(index) : OpResult();
1923 
1924  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1925  << " * Index: " << index << "\n"
1926  << " * Result: " << result << "\n");
1927  memory[memIndex] = result.getAsOpaquePointer();
1928 }
1929 
1930 void ByteCodeExecutor::executeGetResults() {
1931  LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
1932  unsigned index = read<uint32_t>();
1933  Operation *op = read<Operation *>();
1934  ByteCodeField rangeIndex = read();
1935 
1936  void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1937  op->getResults(), op, index, rangeIndex, "resultSegmentSizes",
1938  valueRangeMemory);
1939  if (!result)
1940  LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n");
1941  memory[read()] = result;
1942 }
1943 
1944 void ByteCodeExecutor::executeGetUsers() {
1945  LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n");
1946  unsigned memIndex = read();
1947  unsigned rangeIndex = read();
1948  OwningOpRange &range = opRangeMemory[rangeIndex];
1949  memory[memIndex] = &range;
1950 
1951  range = OwningOpRange();
1952  if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1953  // Read the value.
1954  Value value = read<Value>();
1955  if (!value)
1956  return;
1957  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
1958 
1959  // Extract the users of a single value.
1960  range = OwningOpRange(std::distance(value.user_begin(), value.user_end()));
1961  llvm::copy(value.getUsers(), range.begin());
1962  } else {
1963  // Read a range of values.
1964  ValueRange *values = read<ValueRange *>();
1965  if (!values)
1966  return;
1967  LLVM_DEBUG({
1968  llvm::dbgs() << " * Values (" << values->size() << "): ";
1969  llvm::interleaveComma(*values, llvm::dbgs());
1970  llvm::dbgs() << "\n";
1971  });
1972 
1973  // Extract all the users of a range of values.
1975  for (Value value : *values)
1976  users.append(value.user_begin(), value.user_end());
1977  range = OwningOpRange(users.size());
1978  llvm::copy(users, range.begin());
1979  }
1980 
1981  LLVM_DEBUG(llvm::dbgs() << " * Result: " << range.size() << " operations\n");
1982 }
1983 
1984 void ByteCodeExecutor::executeGetValueType() {
1985  LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1986  unsigned memIndex = read();
1987  Value value = read<Value>();
1988  Type type = value ? value.getType() : Type();
1989 
1990  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"
1991  << " * Result: " << type << "\n");
1992  memory[memIndex] = type.getAsOpaquePointer();
1993 }
1994 
1995 void ByteCodeExecutor::executeGetValueRangeTypes() {
1996  LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
1997  unsigned memIndex = read();
1998  unsigned rangeIndex = read();
1999  ValueRange *values = read<ValueRange *>();
2000  if (!values) {
2001  LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n");
2002  memory[memIndex] = nullptr;
2003  return;
2004  }
2005 
2006  LLVM_DEBUG({
2007  llvm::dbgs() << " * Values (" << values->size() << "): ";
2008  llvm::interleaveComma(*values, llvm::dbgs());
2009  llvm::dbgs() << "\n * Result: ";
2010  llvm::interleaveComma(values->getType(), llvm::dbgs());
2011  llvm::dbgs() << "\n";
2012  });
2013  typeRangeMemory[rangeIndex] = values->getType();
2014  memory[memIndex] = &typeRangeMemory[rangeIndex];
2015 }
2016 
2017 void ByteCodeExecutor::executeIsNotNull() {
2018  LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
2019  const void *value = read<const void *>();
2020 
2021  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
2022  selectJump(value != nullptr);
2023 }
2024 
2025 void ByteCodeExecutor::executeRecordMatch(
2026  PatternRewriter &rewriter,
2028  LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
2029  unsigned patternIndex = read();
2030  PatternBenefit benefit = currentPatternBenefits[patternIndex];
2031  const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
2032 
2033  // If the benefit of the pattern is impossible, skip the processing of the
2034  // rest of the pattern.
2035  if (benefit.isImpossibleToMatch()) {
2036  LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n");
2037  curCodeIt = dest;
2038  return;
2039  }
2040 
2041  // Create a fused location containing the locations of each of the
2042  // operations used in the match. This will be used as the location for
2043  // created operations during the rewrite that don't already have an
2044  // explicit location set.
2045  unsigned numMatchLocs = read();
2046  SmallVector<Location, 4> matchLocs;
2047  matchLocs.reserve(numMatchLocs);
2048  for (unsigned i = 0; i != numMatchLocs; ++i)
2049  matchLocs.push_back(read<Operation *>()->getLoc());
2050  Location matchLoc = rewriter.getFusedLoc(matchLocs);
2051 
2052  LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n"
2053  << " * Location: " << matchLoc << "\n");
2054  matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
2055  PDLByteCode::MatchResult &match = matches.back();
2056 
2057  // Record all of the inputs to the match. If any of the inputs are ranges, we
2058  // will also need to remap the range pointer to memory stored in the match
2059  // state.
2060  unsigned numInputs = read();
2061  match.values.reserve(numInputs);
2062  match.typeRangeValues.reserve(numInputs);
2063  match.valueRangeValues.reserve(numInputs);
2064  for (unsigned i = 0; i < numInputs; ++i) {
2065  switch (read<PDLValue::Kind>()) {
2066  case PDLValue::Kind::TypeRange:
2067  match.typeRangeValues.push_back(*read<TypeRange *>());
2068  match.values.push_back(&match.typeRangeValues.back());
2069  break;
2070  case PDLValue::Kind::ValueRange:
2071  match.valueRangeValues.push_back(*read<ValueRange *>());
2072  match.values.push_back(&match.valueRangeValues.back());
2073  break;
2074  default:
2075  match.values.push_back(read<const void *>());
2076  break;
2077  }
2078  }
2079  curCodeIt = dest;
2080 }
2081 
2082 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
2083  LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
2084  Operation *op = read<Operation *>();
2086  readList(args);
2087 
2088  LLVM_DEBUG({
2089  llvm::dbgs() << " * Operation: " << *op << "\n"
2090  << " * Values: ";
2091  llvm::interleaveComma(args, llvm::dbgs());
2092  llvm::dbgs() << "\n";
2093  });
2094  rewriter.replaceOp(op, args);
2095 }
2096 
2097 void ByteCodeExecutor::executeSwitchAttribute() {
2098  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
2099  Attribute value = read<Attribute>();
2100  ArrayAttr cases = read<ArrayAttr>();
2101  handleSwitch(value, cases);
2102 }
2103 
2104 void ByteCodeExecutor::executeSwitchOperandCount() {
2105  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
2106  Operation *op = read<Operation *>();
2107  auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2108 
2109  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
2110  handleSwitch(op->getNumOperands(), cases);
2111 }
2112 
2113 void ByteCodeExecutor::executeSwitchOperationName() {
2114  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
2115  OperationName value = read<Operation *>()->getName();
2116  size_t caseCount = read();
2117 
2118  // The operation names are stored in-line, so to print them out for
2119  // debugging purposes we need to read the array before executing the
2120  // switch so that we can display all of the possible values.
2121  LLVM_DEBUG({
2122  const ByteCodeField *prevCodeIt = curCodeIt;
2123  llvm::dbgs() << " * Value: " << value << "\n"
2124  << " * Cases: ";
2125  llvm::interleaveComma(
2126  llvm::map_range(llvm::seq<size_t>(0, caseCount),
2127  [&](size_t) { return read<OperationName>(); }),
2128  llvm::dbgs());
2129  llvm::dbgs() << "\n";
2130  curCodeIt = prevCodeIt;
2131  });
2132 
2133  // Try to find the switch value within any of the cases.
2134  for (size_t i = 0; i != caseCount; ++i) {
2135  if (read<OperationName>() == value) {
2136  curCodeIt += (caseCount - i - 1);
2137  return selectJump(i + 1);
2138  }
2139  }
2140  selectJump(size_t(0));
2141 }
2142 
2143 void ByteCodeExecutor::executeSwitchResultCount() {
2144  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
2145  Operation *op = read<Operation *>();
2146  auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2147 
2148  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
2149  handleSwitch(op->getNumResults(), cases);
2150 }
2151 
2152 void ByteCodeExecutor::executeSwitchType() {
2153  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
2154  Type value = read<Type>();
2155  auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2156  handleSwitch(value, cases);
2157 }
2158 
2159 void ByteCodeExecutor::executeSwitchTypes() {
2160  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
2161  TypeRange *value = read<TypeRange *>();
2162  auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2163  if (!value) {
2164  LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
2165  return selectJump(size_t(0));
2166  }
2167  handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
2168  return value == caseValue.getAsValueRange<TypeAttr>();
2169  });
2170 }
2171 
2172 LogicalResult
2173 ByteCodeExecutor::execute(PatternRewriter &rewriter,
2175  std::optional<Location> mainRewriteLoc) {
2176  while (true) {
2177  // Print the location of the operation being executed.
2178  LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n");
2179 
2180  OpCode opCode = static_cast<OpCode>(read());
2181  switch (opCode) {
2182  case ApplyConstraint:
2183  executeApplyConstraint(rewriter);
2184  break;
2185  case ApplyRewrite:
2186  if (failed(executeApplyRewrite(rewriter)))
2187  return failure();
2188  break;
2189  case AreEqual:
2190  executeAreEqual();
2191  break;
2192  case AreRangesEqual:
2193  executeAreRangesEqual();
2194  break;
2195  case Branch:
2196  executeBranch();
2197  break;
2198  case CheckOperandCount:
2199  executeCheckOperandCount();
2200  break;
2201  case CheckOperationName:
2202  executeCheckOperationName();
2203  break;
2204  case CheckResultCount:
2205  executeCheckResultCount();
2206  break;
2207  case CheckTypes:
2208  executeCheckTypes();
2209  break;
2210  case Continue:
2211  executeContinue();
2212  break;
2213  case CreateConstantTypeRange:
2214  executeCreateConstantTypeRange();
2215  break;
2216  case CreateOperation:
2217  executeCreateOperation(rewriter, *mainRewriteLoc);
2218  break;
2219  case CreateDynamicTypeRange:
2220  executeDynamicCreateRange<Type>("Type");
2221  break;
2222  case CreateDynamicValueRange:
2223  executeDynamicCreateRange<Value>("Value");
2224  break;
2225  case EraseOp:
2226  executeEraseOp(rewriter);
2227  break;
2228  case ExtractOp:
2229  executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
2230  break;
2231  case ExtractType:
2232  executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2233  break;
2234  case ExtractValue:
2235  executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2236  break;
2237  case Finalize:
2238  executeFinalize();
2239  LLVM_DEBUG(llvm::dbgs() << "\n");
2240  return success();
2241  case ForEach:
2242  executeForEach();
2243  break;
2244  case GetAttribute:
2245  executeGetAttribute();
2246  break;
2247  case GetAttributeType:
2248  executeGetAttributeType();
2249  break;
2250  case GetDefiningOp:
2251  executeGetDefiningOp();
2252  break;
2253  case GetOperand0:
2254  case GetOperand1:
2255  case GetOperand2:
2256  case GetOperand3: {
2257  unsigned index = opCode - GetOperand0;
2258  LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
2259  executeGetOperand(index);
2260  break;
2261  }
2262  case GetOperandN:
2263  LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
2264  executeGetOperand(read<uint32_t>());
2265  break;
2266  case GetOperands:
2267  executeGetOperands();
2268  break;
2269  case GetResult0:
2270  case GetResult1:
2271  case GetResult2:
2272  case GetResult3: {
2273  unsigned index = opCode - GetResult0;
2274  LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
2275  executeGetResult(index);
2276  break;
2277  }
2278  case GetResultN:
2279  LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
2280  executeGetResult(read<uint32_t>());
2281  break;
2282  case GetResults:
2283  executeGetResults();
2284  break;
2285  case GetUsers:
2286  executeGetUsers();
2287  break;
2288  case GetValueType:
2289  executeGetValueType();
2290  break;
2291  case GetValueRangeTypes:
2292  executeGetValueRangeTypes();
2293  break;
2294  case IsNotNull:
2295  executeIsNotNull();
2296  break;
2297  case RecordMatch:
2298  assert(matches &&
2299  "expected matches to be provided when executing the matcher");
2300  executeRecordMatch(rewriter, *matches);
2301  break;
2302  case ReplaceOp:
2303  executeReplaceOp(rewriter);
2304  break;
2305  case SwitchAttribute:
2306  executeSwitchAttribute();
2307  break;
2308  case SwitchOperandCount:
2309  executeSwitchOperandCount();
2310  break;
2311  case SwitchOperationName:
2312  executeSwitchOperationName();
2313  break;
2314  case SwitchResultCount:
2315  executeSwitchResultCount();
2316  break;
2317  case SwitchType:
2318  executeSwitchType();
2319  break;
2320  case SwitchTypes:
2321  executeSwitchTypes();
2322  break;
2323  }
2324  LLVM_DEBUG(llvm::dbgs() << "\n");
2325  }
2326 }
2327 
2328 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
2330  PDLByteCodeMutableState &state) const {
2331  // The first memory slot is always the root operation.
2332  state.memory[0] = op;
2333 
2334  // The matcher function always starts at code address 0.
2335  ByteCodeExecutor executor(
2336  matcherByteCode.data(), state.memory, state.opRangeMemory,
2337  state.typeRangeMemory, state.allocatedTypeRangeMemory,
2338  state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2339  uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
2340  constraintFunctions, rewriteFunctions);
2341  LogicalResult executeResult = executor.execute(rewriter, &matches);
2342  (void)executeResult;
2343  assert(succeeded(executeResult) && "unexpected matcher execution failure");
2344 
2345  // Order the found matches by benefit.
2346  std::stable_sort(matches.begin(), matches.end(),
2347  [](const MatchResult &lhs, const MatchResult &rhs) {
2348  return lhs.benefit > rhs.benefit;
2349  });
2350 }
2351 
2352 LogicalResult PDLByteCode::rewrite(PatternRewriter &rewriter,
2353  const MatchResult &match,
2354  PDLByteCodeMutableState &state) const {
2355  auto *configSet = match.pattern->getConfigSet();
2356  if (configSet)
2357  configSet->notifyRewriteBegin(rewriter);
2358 
2359  // The arguments of the rewrite function are stored at the start of the
2360  // memory buffer.
2361  llvm::copy(match.values, state.memory.begin());
2362 
2363  ByteCodeExecutor executor(
2364  &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
2365  state.opRangeMemory, state.typeRangeMemory,
2366  state.allocatedTypeRangeMemory, state.valueRangeMemory,
2367  state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2368  rewriterByteCode, state.currentPatternBenefits, patterns,
2369  constraintFunctions, rewriteFunctions);
2370  LogicalResult result =
2371  executor.execute(rewriter, /*matches=*/nullptr, match.location);
2372 
2373  if (configSet)
2374  configSet->notifyRewriteEnd(rewriter);
2375 
2376  // If the rewrite failed, check if the pattern rewriter can recover. If it
2377  // can, we can signal to the pattern applicator to keep trying patterns. If it
2378  // doesn't, we need to bail. Bailing here should be fine, given that we have
2379  // no means to propagate such a failure to the user, and it also indicates a
2380  // bug in the user code (i.e. failable rewrites should not be used with
2381  // pattern rewriters that don't support it).
2382  if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) {
2383  LLVM_DEBUG(llvm::dbgs() << " and rollback is not supported - aborting");
2384  llvm::report_fatal_error(
2385  "Native PDL Rewrite failed, but the pattern "
2386  "rewriter doesn't support recovery. Failable pattern rewrites should "
2387  "not be used with pattern rewriters that do not support them.");
2388  }
2389  return result;
2390 }
static void * executeGetOperandsResults(RangeT values, Operation *op, unsigned index, ByteCodeField rangeIndex, StringRef attrSizedSegments, MutableArrayRef< ValueRange > valueRangeMemory)
This function is the internal implementation of GetResults and GetOperands that provides support for ...
Definition: ByteCode.cpp:1854
static constexpr ByteCodeField kInferTypesMarker
A marker used to indicate if an operation should infer types.
Definition: ByteCode.cpp:173
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
union mlir::linalg::@1193::ArityGroupAndKind::Kind kind
static const mlir::GenInfo * generator
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void processValue(Value value, LiveMap &liveMap)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
const void * getAsOpaquePointer() const
Get an opaque pointer to the attribute.
Definition: Attributes.h:73
This class represents an argument of a Block.
Definition: Value.h:295
Block represents an ordered list of Operations.
Definition: Block.h:33
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:29
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition: Block.cpp:38
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
Definition: Builders.cpp:29
This class represents liveness information on block level.
Definition: Liveness.h:99
const ValueSetT & in() const
Returns all values that are live at the beginning of the block (unordered).
Definition: Liveness.h:110
Operation * getEndOperation(Value value, Operation *startOperation) const
Gets the end operation for the given value using the start operation provided (must be referenced in ...
Definition: Liveness.cpp:375
Represents an analysis for computing liveness information from a given top-level operation.
Definition: Liveness.h:47
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
This is a value defined by a result of an operation.
Definition: Value.h:433
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait.
Definition: Operation.h:750
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:550
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:534
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
bool isImpossibleToMatch() const
Definition: PatternMatch.h:44
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the corresponding pattern isImpossibleToMatch() then this aborts.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
virtual bool canRecoverFromRewriteFailure() const
A hook used to indicate if the pattern rewriter can recover from failure during the rewrite stage of ...
Definition: PatternMatch.h:758
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:247
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class implements the successor iterators for Block.
Definition: BlockSupport.h:73
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
const void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Types.h:164
auto walk(WalkFns &&...walkFns)
Walk this type and all attributes/types nested within using the provided walk functions.
Definition: Types.h:215
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
user_iterator user_begin() const
Definition: Value.h:202
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Value.h:219
user_iterator user_end() const
Definition: Value.h:203
user_range getUsers() const
Definition: Value.h:204
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit)
Set the new benefit for a bytecode pattern.
Definition: ByteCode.h:236
void cleanupAfterMatchAndRewrite()
Cleanup any allocated state after a full match/rewrite has been completed.
Definition: ByteCode.h:235
void match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl< MatchResult > &matches, PDLByteCodeMutableState &state) const
Definition: ByteCode.h:249
LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const
Definition: ByteCode.h:252
AttrTypeReplacer.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:136
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult size