MLIR  20.0.0git
ByteCode.cpp
Go to the documentation of this file.
1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "ByteCode.h"
#include "mlir/Analysis/Liveness.h"
#include "mlir/IR/BuiltinOps.h"
#include "llvm/ADT/IntervalMap.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Format.h"
#include "llvm/Support/FormatVariadic.h"
#include <memory>
#include <numeric>
#include <optional>
27 
28 #define DEBUG_TYPE "pdl-bytecode"
29 
30 using namespace mlir;
31 using namespace mlir::detail;
32 
33 //===----------------------------------------------------------------------===//
34 // PDLByteCodePattern
35 //===----------------------------------------------------------------------===//
36 
37 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
38  PDLPatternConfigSet *configSet,
39  ByteCodeAddr rewriterAddr) {
40  PatternBenefit benefit = matchOp.getBenefit();
41  MLIRContext *ctx = matchOp.getContext();
42 
43  // Collect the set of generated operations.
44  SmallVector<StringRef, 8> generatedOps;
45  if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
46  generatedOps =
47  llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
48 
49  // Check to see if this is pattern matches a specific operation type.
50  if (std::optional<StringRef> rootKind = matchOp.getRootKind())
51  return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx,
52  generatedOps);
53  return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(),
54  benefit, ctx, generatedOps);
55 }
56 
57 //===----------------------------------------------------------------------===//
58 // PDLByteCodeMutableState
59 //===----------------------------------------------------------------------===//
60 
61 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
62 /// to the position of the pattern within the range returned by
63 /// `PDLByteCode::getPatterns`.
64 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
65  PatternBenefit benefit) {
66  currentPatternBenefits[patternIndex] = benefit;
67 }
68 
/// Clean up any allocated state after a full match/rewrite has been completed.
/// This method should be called regardless of whether the match+rewrite
/// succeeded or not.
73  allocatedTypeRangeMemory.clear();
74  allocatedValueRangeMemory.clear();
75 }
76 
77 //===----------------------------------------------------------------------===//
78 // Bytecode OpCodes
79 //===----------------------------------------------------------------------===//
80 
namespace {
/// The operation codes of the PDL bytecode. Each instruction starts with one of
/// these values, written into the bytecode stream as a single `ByteCodeField`
/// (see `ByteCodeWriter::append(OpCode)`), followed by the instruction's
/// operand fields.
enum OpCode : ByteCodeField {
  /// Apply an externally registered constraint.
  ApplyConstraint,
  /// Apply an externally registered rewrite.
  ApplyRewrite,
  /// Check if two generic values are equal.
  AreEqual,
  /// Check if two ranges are equal.
  AreRangesEqual,
  /// Unconditional branch.
  Branch,
  /// Compare the operand count of an operation with a constant.
  CheckOperandCount,
  /// Compare the name of an operation with a constant.
  CheckOperationName,
  /// Compare the result count of an operation with a constant.
  CheckResultCount,
  /// Compare a range of types to a constant range of types.
  CheckTypes,
  /// Continue to the next iteration of a loop.
  Continue,
  /// Create a type range from a list of constant types.
  CreateConstantTypeRange,
  /// Create an operation.
  CreateOperation,
  /// Create a type range from a list of dynamic types.
  CreateDynamicTypeRange,
  /// Create a value range.
  CreateDynamicValueRange,
  /// Erase an operation.
  EraseOp,
  /// Extract the op from a range at the specified index.
  ExtractOp,
  /// Extract the type from a range at the specified index.
  ExtractType,
  /// Extract the value from a range at the specified index.
  ExtractValue,
  /// Terminate a matcher or rewrite sequence.
  Finalize,
  /// Iterate over a range of values.
  ForEach,
  /// Get a specific attribute of an operation.
  GetAttribute,
  /// Get the type of an attribute.
  GetAttributeType,
  /// Get the defining operation of a value.
  GetDefiningOp,
  /// Get a specific operand of an operation. The `0`-`3` variants presumably
  /// encode the operand index in the opcode itself, while `GetOperandN`
  /// handles any other index (NOTE(review): confirm against the
  /// GetOperandOp generator).
  GetOperand0,
  GetOperand1,
  GetOperand2,
  GetOperand3,
  GetOperandN,
  /// Get a specific operand group of an operation.
  GetOperands,
  /// Get a specific result of an operation. The indexed variants mirror the
  /// GetOperand* opcodes above (presumably; confirm in the generator).
  GetResult0,
  GetResult1,
  GetResult2,
  GetResult3,
  GetResultN,
  /// Get a specific result group of an operation.
  GetResults,
  /// Get the users of a value or a range of values.
  GetUsers,
  /// Get the type of a value.
  GetValueType,
  /// Get the types of a value range.
  GetValueRangeTypes,
  /// Check if a generic value is not null.
  IsNotNull,
  /// Record a successful pattern match.
  RecordMatch,
  /// Replace an operation.
  ReplaceOp,
  /// Compare an attribute with a set of constants.
  SwitchAttribute,
  /// Compare the operand count of an operation with a set of constants.
  SwitchOperandCount,
  /// Compare the name of an operation with a set of constants.
  SwitchOperationName,
  /// Compare the result count of an operation with a set of constants.
  SwitchResultCount,
  /// Compare a type with a set of constants.
  SwitchType,
  /// Compare a range of types with a set of constants.
  SwitchTypes,
};
} // namespace
171 
172 /// A marker used to indicate if an operation should infer types.
173 static constexpr ByteCodeField kInferTypesMarker =
175 
176 //===----------------------------------------------------------------------===//
177 // ByteCode Generation
178 //===----------------------------------------------------------------------===//
179 
180 //===----------------------------------------------------------------------===//
181 // Generator
182 
183 namespace {
184 struct ByteCodeLiveRange;
185 struct ByteCodeWriter;
186 
/// Detection alias for use with `llvm::is_detected`: it names the type of
/// `T::getAsOpaquePointer()` and is only well-formed when `T` provides that
/// member, i.e. when a `T` can be converted to an opaque pointer. The trailing
/// `Args` pack is not used by the alias.
template <typename T, typename... Args>
using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
190 
191 /// This class represents the main generator for the pattern bytecode.
192 class Generator {
193 public:
194  Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
195  SmallVectorImpl<ByteCodeField> &matcherByteCode,
196  SmallVectorImpl<ByteCodeField> &rewriterByteCode,
198  ByteCodeField &maxValueMemoryIndex,
199  ByteCodeField &maxOpRangeMemoryIndex,
200  ByteCodeField &maxTypeRangeMemoryIndex,
201  ByteCodeField &maxValueRangeMemoryIndex,
202  ByteCodeField &maxLoopLevel,
203  llvm::StringMap<PDLConstraintFunction> &constraintFns,
204  llvm::StringMap<PDLRewriteFunction> &rewriteFns,
206  : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
207  rewriterByteCode(rewriterByteCode), patterns(patterns),
208  maxValueMemoryIndex(maxValueMemoryIndex),
209  maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
210  maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
211  maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
212  maxLoopLevel(maxLoopLevel), configMap(configMap) {
213  for (const auto &it : llvm::enumerate(constraintFns))
214  constraintToMemIndex.try_emplace(it.value().first(), it.index());
215  for (const auto &it : llvm::enumerate(rewriteFns))
216  externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
217  }
218 
219  /// Generate the bytecode for the given PDL interpreter module.
220  void generate(ModuleOp module);
221 
222  /// Return the memory index to use for the given value.
223  ByteCodeField &getMemIndex(Value value) {
224  assert(valueToMemIndex.count(value) &&
225  "expected memory index to be assigned");
226  return valueToMemIndex[value];
227  }
228 
229  /// Return the range memory index used to store the given range value.
230  ByteCodeField &getRangeStorageIndex(Value value) {
231  assert(valueToRangeIndex.count(value) &&
232  "expected range index to be assigned");
233  return valueToRangeIndex[value];
234  }
235 
236  /// Return an index to use when referring to the given data that is uniqued in
237  /// the MLIR context.
238  template <typename T>
239  std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
240  getMemIndex(T val) {
241  const void *opaqueVal = val.getAsOpaquePointer();
242 
243  // Get or insert a reference to this value.
244  auto it = uniquedDataToMemIndex.try_emplace(
245  opaqueVal, maxValueMemoryIndex + uniquedData.size());
246  if (it.second)
247  uniquedData.push_back(opaqueVal);
248  return it.first->second;
249  }
250 
251 private:
252  /// Allocate memory indices for the results of operations within the matcher
253  /// and rewriters.
254  void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
255  ModuleOp rewriterModule);
256 
257  /// Generate the bytecode for the given operation.
258  void generate(Region *region, ByteCodeWriter &writer);
259  void generate(Operation *op, ByteCodeWriter &writer);
260  void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
261  void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
262  void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
263  void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
264  void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
265  void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
266  void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
267  void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
268  void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
269  void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
270  void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
271  void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
272  void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
273  void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer);
274  void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
275  void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
276  void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
277  void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
278  void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
279  void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
280  void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
281  void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
282  void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
283  void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
284  void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
285  void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
286  void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
287  void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
288  void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
289  void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
290  void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
291  void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
292  void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
293  void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
294  void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
295  void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
296  void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
297  void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
298 
299  /// Mapping from value to its corresponding memory index.
300  DenseMap<Value, ByteCodeField> valueToMemIndex;
301 
302  /// Mapping from a range value to its corresponding range storage index.
303  DenseMap<Value, ByteCodeField> valueToRangeIndex;
304 
305  /// Mapping from the name of an externally registered rewrite to its index in
306  /// the bytecode registry.
307  llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
308 
309  /// Mapping from the name of an externally registered constraint to its index
310  /// in the bytecode registry.
311  llvm::StringMap<ByteCodeField> constraintToMemIndex;
312 
313  /// Mapping from rewriter function name to the bytecode address of the
314  /// rewriter function in byte.
315  llvm::StringMap<ByteCodeAddr> rewriterToAddr;
316 
317  /// Mapping from a uniqued storage object to its memory index within
318  /// `uniquedData`.
319  DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
320 
321  /// The current level of the foreach loop.
322  ByteCodeField curLoopLevel = 0;
323 
324  /// The current MLIR context.
325  MLIRContext *ctx;
326 
327  /// Mapping from block to its address.
329 
330  /// Data of the ByteCode class to be populated.
331  std::vector<const void *> &uniquedData;
332  SmallVectorImpl<ByteCodeField> &matcherByteCode;
333  SmallVectorImpl<ByteCodeField> &rewriterByteCode;
335  ByteCodeField &maxValueMemoryIndex;
336  ByteCodeField &maxOpRangeMemoryIndex;
337  ByteCodeField &maxTypeRangeMemoryIndex;
338  ByteCodeField &maxValueRangeMemoryIndex;
339  ByteCodeField &maxLoopLevel;
340 
341  /// A map of pattern configurations.
343 };
344 
345 /// This class provides utilities for writing a bytecode stream.
346 struct ByteCodeWriter {
347  ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
348  : bytecode(bytecode), generator(generator) {}
349 
350  /// Append a field to the bytecode.
351  void append(ByteCodeField field) { bytecode.push_back(field); }
352  void append(OpCode opCode) { bytecode.push_back(opCode); }
353 
354  /// Append an address to the bytecode.
355  void append(ByteCodeAddr field) {
356  static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
357  "unexpected ByteCode address size");
358 
359  ByteCodeField fieldParts[2];
360  std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
361  bytecode.append({fieldParts[0], fieldParts[1]});
362  }
363 
364  /// Append a single successor to the bytecode, the exact address will need to
365  /// be resolved later.
366  void append(Block *successor) {
367  // Add back a reference to the successor so that the address can be resolved
368  // later.
369  unresolvedSuccessorRefs[successor].push_back(bytecode.size());
370  append(ByteCodeAddr(0));
371  }
372 
373  /// Append a successor range to the bytecode, the exact address will need to
374  /// be resolved later.
375  void append(SuccessorRange successors) {
376  for (Block *successor : successors)
377  append(successor);
378  }
379 
380  /// Append a range of values that will be read as generic PDLValues.
381  void appendPDLValueList(OperandRange values) {
382  bytecode.push_back(values.size());
383  for (Value value : values)
384  appendPDLValue(value);
385  }
386 
387  /// Append a value as a PDLValue.
388  void appendPDLValue(Value value) {
389  appendPDLValueKind(value);
390  append(value);
391  }
392 
393  /// Append the PDLValue::Kind of the given value.
394  void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); }
395 
396  /// Append the PDLValue::Kind of the given type.
397  void appendPDLValueKind(Type type) {
398  PDLValue::Kind kind =
400  .Case<pdl::AttributeType>(
401  [](Type) { return PDLValue::Kind::Attribute; })
402  .Case<pdl::OperationType>(
403  [](Type) { return PDLValue::Kind::Operation; })
404  .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
405  if (isa<pdl::TypeType>(rangeTy.getElementType()))
406  return PDLValue::Kind::TypeRange;
407  return PDLValue::Kind::ValueRange;
408  })
409  .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
410  .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
411  bytecode.push_back(static_cast<ByteCodeField>(kind));
412  }
413 
414  /// Append a value that will be stored in a memory slot and not inline within
415  /// the bytecode.
416  template <typename T>
417  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
418  std::is_pointer<T>::value>
419  append(T value) {
420  bytecode.push_back(generator.getMemIndex(value));
421  }
422 
423  /// Append a range of values.
424  template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
425  std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
426  append(T range) {
427  bytecode.push_back(llvm::size(range));
428  for (auto it : range)
429  append(it);
430  }
431 
432  /// Append a variadic number of fields to the bytecode.
433  template <typename FieldTy, typename Field2Ty, typename... FieldTys>
434  void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
435  append(field);
436  append(field2, fields...);
437  }
438 
439  /// Appends a value as a pointer, stored inline within the bytecode.
440  template <typename T>
441  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
442  appendInline(T value) {
443  constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField);
444  const void *pointer = value.getAsOpaquePointer();
445  ByteCodeField fieldParts[numParts];
446  std::memcpy(fieldParts, &pointer, sizeof(const void *));
447  bytecode.append(fieldParts, fieldParts + numParts);
448  }
449 
450  /// Successor references in the bytecode that have yet to be resolved.
451  DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
452 
453  /// The underlying bytecode buffer.
455 
456  /// The main generator producing PDL.
457  Generator &generator;
458 };
459 
460 /// This class represents a live range of PDL Interpreter values, containing
461 /// information about when values are live within a match/rewrite.
462 struct ByteCodeLiveRange {
463  using Set = llvm::IntervalMap<uint64_t, char, 16>;
464  using Allocator = Set::Allocator;
465 
466  ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
467 
468  /// Union this live range with the one provided.
469  void unionWith(const ByteCodeLiveRange &rhs) {
470  for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
471  ++it)
472  liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0);
473  }
474 
475  /// Returns true if this range overlaps with the one provided.
476  bool overlaps(const ByteCodeLiveRange &rhs) const {
477  return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
478  .valid();
479  }
480 
481  /// A map representing the ranges of the match/rewrite that a value is live in
482  /// the interpreter.
483  ///
484  /// We use std::unique_ptr here, because IntervalMap does not provide a
485  /// correct copy or move constructor. We can eliminate the pointer once
486  /// https://reviews.llvm.org/D113240 lands.
487  std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
488 
489  /// The operation range storage index for this range.
490  std::optional<unsigned> opRangeIndex;
491 
492  /// The type range storage index for this range.
493  std::optional<unsigned> typeRangeIndex;
494 
495  /// The value range storage index for this range.
496  std::optional<unsigned> valueRangeIndex;
497 };
498 } // namespace
499 
500 void Generator::generate(ModuleOp module) {
501  auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
502  pdl_interp::PDLInterpDialect::getMatcherFunctionName());
503  ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
504  pdl_interp::PDLInterpDialect::getRewriterModuleName());
505  assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
506 
507  // Allocate memory indices for the results of operations within the matcher
508  // and rewriters.
509  allocateMemoryIndices(matcherFunc, rewriterModule);
510 
511  // Generate code for the rewriter functions.
512  ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
513  for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
514  rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
515  for (Operation &op : rewriterFunc.getOps())
516  generate(&op, rewriterByteCodeWriter);
517  }
518  assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
519  "unexpected branches in rewriter function");
520 
521  // Generate code for the matcher function.
522  ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
523  generate(&matcherFunc.getBody(), matcherByteCodeWriter);
524 
525  // Resolve successor references in the matcher.
526  for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
527  ByteCodeAddr addr = blockToAddr[it.first];
528  for (unsigned offsetToFix : it.second)
529  std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
530  }
531 }
532 
533 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
534  ModuleOp rewriterModule) {
535  // Rewriters use simplistic allocation scheme that simply assigns an index to
536  // each result.
537  for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
538  ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
539  auto processRewriterValue = [&](Value val) {
540  valueToMemIndex.try_emplace(val, index++);
541  if (pdl::RangeType rangeType = dyn_cast<pdl::RangeType>(val.getType())) {
542  Type elementTy = rangeType.getElementType();
543  if (isa<pdl::TypeType>(elementTy))
544  valueToRangeIndex.try_emplace(val, typeRangeIndex++);
545  else if (isa<pdl::ValueType>(elementTy))
546  valueToRangeIndex.try_emplace(val, valueRangeIndex++);
547  }
548  };
549 
550  for (BlockArgument arg : rewriterFunc.getArguments())
551  processRewriterValue(arg);
552  rewriterFunc.getBody().walk([&](Operation *op) {
553  for (Value result : op->getResults())
554  processRewriterValue(result);
555  });
556  if (index > maxValueMemoryIndex)
557  maxValueMemoryIndex = index;
558  if (typeRangeIndex > maxTypeRangeMemoryIndex)
559  maxTypeRangeMemoryIndex = typeRangeIndex;
560  if (valueRangeIndex > maxValueRangeMemoryIndex)
561  maxValueRangeMemoryIndex = valueRangeIndex;
562  }
563 
564  // The matcher function uses a more sophisticated numbering that tries to
565  // minimize the number of memory indices assigned. This is done by determining
566  // a live range of the values within the matcher, then the allocation is just
567  // finding the minimal number of overlapping live ranges. This is essentially
568  // a simplified form of register allocation where we don't necessarily have a
569  // limited number of registers, but we still want to minimize the number used.
570  DenseMap<Operation *, unsigned> opToFirstIndex;
571  DenseMap<Operation *, unsigned> opToLastIndex;
572 
573  // A custom walk that marks the first and the last index of each operation.
574  // The entry marks the beginning of the liveness range for this operation,
575  // followed by nested operations, followed by the end of the liveness range.
576  unsigned index = 0;
577  llvm::unique_function<void(Operation *)> walk = [&](Operation *op) {
578  opToFirstIndex.try_emplace(op, index++);
579  for (Region &region : op->getRegions())
580  for (Block &block : region.getBlocks())
581  for (Operation &nested : block)
582  walk(&nested);
583  opToLastIndex.try_emplace(op, index++);
584  };
585  walk(matcherFunc);
586 
587  // Liveness info for each of the defs within the matcher.
588  ByteCodeLiveRange::Allocator allocator;
589  DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
590 
591  // Assign the root operation being matched to slot 0.
592  BlockArgument rootOpArg = matcherFunc.getArgument(0);
593  valueToMemIndex[rootOpArg] = 0;
594 
595  // Walk each of the blocks, computing the def interval that the value is used.
596  Liveness matcherLiveness(matcherFunc);
597  matcherFunc->walk([&](Block *block) {
598  const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
599  assert(info && "expected liveness info for block");
600  auto processValue = [&](Value value, Operation *firstUseOrDef) {
601  // We don't need to process the root op argument, this value is always
602  // assigned to the first memory slot.
603  if (value == rootOpArg)
604  return;
605 
606  // Set indices for the range of this block that the value is used.
607  auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
608  defRangeIt->second.liveness->insert(
609  opToFirstIndex[firstUseOrDef],
610  opToLastIndex[info->getEndOperation(value, firstUseOrDef)],
611  /*dummyValue*/ 0);
612 
613  // Check to see if this value is a range type.
614  if (auto rangeTy = dyn_cast<pdl::RangeType>(value.getType())) {
615  Type eleType = rangeTy.getElementType();
616  if (isa<pdl::OperationType>(eleType))
617  defRangeIt->second.opRangeIndex = 0;
618  else if (isa<pdl::TypeType>(eleType))
619  defRangeIt->second.typeRangeIndex = 0;
620  else if (isa<pdl::ValueType>(eleType))
621  defRangeIt->second.valueRangeIndex = 0;
622  }
623  };
624 
625  // Process the live-ins of this block.
626  for (Value liveIn : info->in()) {
627  // Only process the value if it has been defined in the current region.
628  // Other values that span across pdl_interp.foreach will be added higher
629  // up. This ensures that the we keep them alive for the entire duration
630  // of the loop.
631  if (liveIn.getParentRegion() == block->getParent())
632  processValue(liveIn, &block->front());
633  }
634 
635  // Process the block arguments for the entry block (those are not live-in).
636  if (block->isEntryBlock()) {
637  for (Value argument : block->getArguments())
638  processValue(argument, &block->front());
639  }
640 
641  // Process any new defs within this block.
642  for (Operation &op : *block)
643  for (Value result : op.getResults())
644  processValue(result, &op);
645  });
646 
647  // Greedily allocate memory slots using the computed def live ranges.
648  std::vector<ByteCodeLiveRange> allocatedIndices;
649 
650  // The number of memory indices currently allocated (and its next value).
651  // Recall that the root gets allocated memory index 0.
652  ByteCodeField numIndices = 1;
653 
654  // The number of memory ranges of various types (and their next values).
655  ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
656 
657  for (auto &defIt : valueDefRanges) {
658  ByteCodeField &memIndex = valueToMemIndex[defIt.first];
659  ByteCodeLiveRange &defRange = defIt.second;
660 
661  // Try to allocate to an existing index.
662  for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) {
663  ByteCodeLiveRange &existingRange = existingIndexIt.value();
664  if (!defRange.overlaps(existingRange)) {
665  existingRange.unionWith(defRange);
666  memIndex = existingIndexIt.index() + 1;
667 
668  if (defRange.opRangeIndex) {
669  if (!existingRange.opRangeIndex)
670  existingRange.opRangeIndex = numOpRanges++;
671  valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
672  } else if (defRange.typeRangeIndex) {
673  if (!existingRange.typeRangeIndex)
674  existingRange.typeRangeIndex = numTypeRanges++;
675  valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
676  } else if (defRange.valueRangeIndex) {
677  if (!existingRange.valueRangeIndex)
678  existingRange.valueRangeIndex = numValueRanges++;
679  valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
680  }
681  break;
682  }
683  }
684 
685  // If no existing index could be used, add a new one.
686  if (memIndex == 0) {
687  allocatedIndices.emplace_back(allocator);
688  ByteCodeLiveRange &newRange = allocatedIndices.back();
689  newRange.unionWith(defRange);
690 
691  // Allocate an index for op/type/value ranges.
692  if (defRange.opRangeIndex) {
693  newRange.opRangeIndex = numOpRanges;
694  valueToRangeIndex[defIt.first] = numOpRanges++;
695  } else if (defRange.typeRangeIndex) {
696  newRange.typeRangeIndex = numTypeRanges;
697  valueToRangeIndex[defIt.first] = numTypeRanges++;
698  } else if (defRange.valueRangeIndex) {
699  newRange.valueRangeIndex = numValueRanges;
700  valueToRangeIndex[defIt.first] = numValueRanges++;
701  }
702 
703  memIndex = allocatedIndices.size();
704  ++numIndices;
705  }
706  }
707 
708  // Print the index usage and ensure that we did not run out of index space.
709  LLVM_DEBUG({
710  llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices "
711  << "(down from initial " << valueDefRanges.size() << ").\n";
712  });
713  assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
714  "Ran out of memory for allocated indices");
715 
716  // Update the max number of indices.
717  if (numIndices > maxValueMemoryIndex)
718  maxValueMemoryIndex = numIndices;
719  if (numOpRanges > maxOpRangeMemoryIndex)
720  maxOpRangeMemoryIndex = numOpRanges;
721  if (numTypeRanges > maxTypeRangeMemoryIndex)
722  maxTypeRangeMemoryIndex = numTypeRanges;
723  if (numValueRanges > maxValueRangeMemoryIndex)
724  maxValueRangeMemoryIndex = numValueRanges;
725 }
726 
727 void Generator::generate(Region *region, ByteCodeWriter &writer) {
728  llvm::ReversePostOrderTraversal<Region *> rpot(region);
729  for (Block *block : rpot) {
730  // Keep track of where this block begins within the matcher function.
731  blockToAddr.try_emplace(block, matcherByteCode.size());
732  for (Operation &op : *block)
733  generate(&op, writer);
734  }
735 }
736 
737 void Generator::generate(Operation *op, ByteCodeWriter &writer) {
738  LLVM_DEBUG({
739  // The following list must contain all the operations that do not
740  // produce any bytecode.
741  if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
742  writer.appendInline(op->getLoc());
743  });
745  .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
746  pdl_interp::AreEqualOp, pdl_interp::BranchOp,
747  pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
748  pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
749  pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
750  pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
751  pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp,
752  pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
753  pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
754  pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
755  pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
756  pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
757  pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
758  pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
759  pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
760  pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
761  pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
762  pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
763  pdl_interp::SwitchResultCountOp>(
764  [&](auto interpOp) { this->generate(interpOp, writer); })
765  .Default([](Operation *) {
766  llvm_unreachable("unknown `pdl_interp` operation");
767  });
768 }
769 
770 void Generator::generate(pdl_interp::ApplyConstraintOp op,
771  ByteCodeWriter &writer) {
772  // Constraints that should return a value have to be registered as rewrites.
773  // If a constraint and a rewrite of similar name are registered the
774  // constraint takes precedence
775  writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
776  writer.appendPDLValueList(op.getArgs());
777  writer.append(ByteCodeField(op.getIsNegated()));
778  ResultRange results = op.getResults();
779  writer.append(ByteCodeField(results.size()));
780  for (Value result : results) {
781  // We record the expected kind of the result, so that we can provide extra
782  // verification of the native rewrite function and handle the failure case
783  // of constraints accordingly.
784  writer.appendPDLValueKind(result);
785 
786  // Range results also need to append the range storage index.
787  if (isa<pdl::RangeType>(result.getType()))
788  writer.append(getRangeStorageIndex(result));
789  writer.append(result);
790  }
791  writer.append(op.getSuccessors());
792 }
793 void Generator::generate(pdl_interp::ApplyRewriteOp op,
794  ByteCodeWriter &writer) {
795  assert(externalRewriterToMemIndex.count(op.getName()) &&
796  "expected index for rewrite function");
797  writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
798  writer.appendPDLValueList(op.getArgs());
799 
800  ResultRange results = op.getResults();
801  writer.append(ByteCodeField(results.size()));
802  for (Value result : results) {
803  // We record the expected kind of the result, so that we
804  // can provide extra verification of the native rewrite function.
805  writer.appendPDLValueKind(result);
806 
807  // Range results also need to append the range storage index.
808  if (isa<pdl::RangeType>(result.getType()))
809  writer.append(getRangeStorageIndex(result));
810  writer.append(result);
811  }
812 }
813 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
814  Value lhs = op.getLhs();
815  if (isa<pdl::RangeType>(lhs.getType())) {
816  writer.append(OpCode::AreRangesEqual);
817  writer.appendPDLValueKind(lhs);
818  writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
819  return;
820  }
821 
822  writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
823 }
824 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
825  writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
826 }
827 void Generator::generate(pdl_interp::CheckAttributeOp op,
828  ByteCodeWriter &writer) {
829  writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
830  op.getSuccessors());
831 }
832 void Generator::generate(pdl_interp::CheckOperandCountOp op,
833  ByteCodeWriter &writer) {
834  writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
835  static_cast<ByteCodeField>(op.getCompareAtLeast()),
836  op.getSuccessors());
837 }
838 void Generator::generate(pdl_interp::CheckOperationNameOp op,
839  ByteCodeWriter &writer) {
840  writer.append(OpCode::CheckOperationName, op.getInputOp(),
841  OperationName(op.getName(), ctx), op.getSuccessors());
842 }
843 void Generator::generate(pdl_interp::CheckResultCountOp op,
844  ByteCodeWriter &writer) {
845  writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
846  static_cast<ByteCodeField>(op.getCompareAtLeast()),
847  op.getSuccessors());
848 }
849 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
850  writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
851  op.getSuccessors());
852 }
853 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
854  writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
855  op.getSuccessors());
856 }
857 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
858  assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
859  writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
860 }
861 void Generator::generate(pdl_interp::CreateAttributeOp op,
862  ByteCodeWriter &writer) {
863  // Simply repoint the memory index of the result to the constant.
864  getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
865 }
866 void Generator::generate(pdl_interp::CreateOperationOp op,
867  ByteCodeWriter &writer) {
868  writer.append(OpCode::CreateOperation, op.getResultOp(),
869  OperationName(op.getName(), ctx));
870  writer.appendPDLValueList(op.getInputOperands());
871 
872  // Add the attributes.
873  OperandRange attributes = op.getInputAttributes();
874  writer.append(static_cast<ByteCodeField>(attributes.size()));
875  for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
876  writer.append(std::get<0>(it), std::get<1>(it));
877 
878  // Add the result types. If the operation has inferred results, we use a
879  // marker "size" value. Otherwise, we add the list of explicit result types.
880  if (op.getInferredResultTypes())
881  writer.append(kInferTypesMarker);
882  else
883  writer.appendPDLValueList(op.getInputResultTypes());
884 }
885 void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
886  // Append the correct opcode for the range type.
887  TypeSwitch<Type>(op.getType().getElementType())
888  .Case(
889  [&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); })
890  .Case([&](pdl::ValueType) {
891  writer.append(OpCode::CreateDynamicValueRange);
892  });
893 
894  writer.append(op.getResult(), getRangeStorageIndex(op.getResult()));
895  writer.appendPDLValueList(op->getOperands());
896 }
897 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
898  // Simply repoint the memory index of the result to the constant.
899  getMemIndex(op.getResult()) = getMemIndex(op.getValue());
900 }
901 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
902  writer.append(OpCode::CreateConstantTypeRange, op.getResult(),
903  getRangeStorageIndex(op.getResult()), op.getValue());
904 }
905 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
906  writer.append(OpCode::EraseOp, op.getInputOp());
907 }
908 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
909  OpCode opCode =
910  TypeSwitch<Type, OpCode>(op.getResult().getType())
911  .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
912  .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
913  .Case([](pdl::TypeType) { return OpCode::ExtractType; })
914  .Default([](Type) -> OpCode {
915  llvm_unreachable("unsupported element type");
916  });
917  writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
918 }
919 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
920  writer.append(OpCode::Finalize);
921 }
922 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
923  BlockArgument arg = op.getLoopVariable();
924  writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
925  writer.appendPDLValueKind(arg.getType());
926  writer.append(curLoopLevel, op.getSuccessor());
927  ++curLoopLevel;
928  if (curLoopLevel > maxLoopLevel)
929  maxLoopLevel = curLoopLevel;
930  generate(&op.getRegion(), writer);
931  --curLoopLevel;
932 }
933 void Generator::generate(pdl_interp::GetAttributeOp op,
934  ByteCodeWriter &writer) {
935  writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
936  op.getNameAttr());
937 }
938 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
939  ByteCodeWriter &writer) {
940  writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
941 }
942 void Generator::generate(pdl_interp::GetDefiningOpOp op,
943  ByteCodeWriter &writer) {
944  writer.append(OpCode::GetDefiningOp, op.getInputOp());
945  writer.appendPDLValue(op.getValue());
946 }
947 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
948  uint32_t index = op.getIndex();
949  if (index < 4)
950  writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
951  else
952  writer.append(OpCode::GetOperandN, index);
953  writer.append(op.getInputOp(), op.getValue());
954 }
955 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
956  Value result = op.getValue();
957  std::optional<uint32_t> index = op.getIndex();
958  writer.append(OpCode::GetOperands,
959  index.value_or(std::numeric_limits<uint32_t>::max()),
960  op.getInputOp());
961  if (isa<pdl::RangeType>(result.getType()))
962  writer.append(getRangeStorageIndex(result));
963  else
965  writer.append(result);
966 }
967 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
968  uint32_t index = op.getIndex();
969  if (index < 4)
970  writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
971  else
972  writer.append(OpCode::GetResultN, index);
973  writer.append(op.getInputOp(), op.getValue());
974 }
975 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
976  Value result = op.getValue();
977  std::optional<uint32_t> index = op.getIndex();
978  writer.append(OpCode::GetResults,
979  index.value_or(std::numeric_limits<uint32_t>::max()),
980  op.getInputOp());
981  if (isa<pdl::RangeType>(result.getType()))
982  writer.append(getRangeStorageIndex(result));
983  else
985  writer.append(result);
986 }
987 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
988  Value operations = op.getOperations();
989  ByteCodeField rangeIndex = getRangeStorageIndex(operations);
990  writer.append(OpCode::GetUsers, operations, rangeIndex);
991  writer.appendPDLValue(op.getValue());
992 }
993 void Generator::generate(pdl_interp::GetValueTypeOp op,
994  ByteCodeWriter &writer) {
995  if (isa<pdl::RangeType>(op.getType())) {
996  Value result = op.getResult();
997  writer.append(OpCode::GetValueRangeTypes, result,
998  getRangeStorageIndex(result), op.getValue());
999  } else {
1000  writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
1001  }
1002 }
1003 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
1004  writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
1005 }
1006 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
1007  ByteCodeField patternIndex = patterns.size();
1008  patterns.emplace_back(PDLByteCodePattern::create(
1009  op, configMap.lookup(op),
1010  rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
1011  writer.append(OpCode::RecordMatch, patternIndex,
1012  SuccessorRange(op.getOperation()), op.getMatchedOps());
1013  writer.appendPDLValueList(op.getInputs());
1014 }
1015 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
1016  writer.append(OpCode::ReplaceOp, op.getInputOp());
1017  writer.appendPDLValueList(op.getReplValues());
1018 }
1019 void Generator::generate(pdl_interp::SwitchAttributeOp op,
1020  ByteCodeWriter &writer) {
1021  writer.append(OpCode::SwitchAttribute, op.getAttribute(),
1022  op.getCaseValuesAttr(), op.getSuccessors());
1023 }
1024 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
1025  ByteCodeWriter &writer) {
1026  writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
1027  op.getCaseValuesAttr(), op.getSuccessors());
1028 }
1029 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
1030  ByteCodeWriter &writer) {
1031  auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
1032  return OperationName(cast<StringAttr>(attr).getValue(), ctx);
1033  });
1034  writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
1035  op.getSuccessors());
1036 }
1037 void Generator::generate(pdl_interp::SwitchResultCountOp op,
1038  ByteCodeWriter &writer) {
1039  writer.append(OpCode::SwitchResultCount, op.getInputOp(),
1040  op.getCaseValuesAttr(), op.getSuccessors());
1041 }
1042 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1043  writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1044  op.getSuccessors());
1045 }
1046 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1047  writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
1048  op.getSuccessors());
1049 }
1050 
1051 //===----------------------------------------------------------------------===//
1052 // PDLByteCode
1053 //===----------------------------------------------------------------------===//
1054 
1055 PDLByteCode::PDLByteCode(
1056  ModuleOp module, SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
1058  llvm::StringMap<PDLConstraintFunction> constraintFns,
1059  llvm::StringMap<PDLRewriteFunction> rewriteFns)
1060  : configs(std::move(configs)) {
1061  Generator generator(module.getContext(), uniquedData, matcherByteCode,
1062  rewriterByteCode, patterns, maxValueMemoryIndex,
1063  maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1064  maxLoopLevel, constraintFns, rewriteFns, configMap);
1065  generator.generate(module);
1066 
1067  // Initialize the external functions.
1068  for (auto &it : constraintFns)
1069  constraintFunctions.push_back(std::move(it.second));
1070  for (auto &it : rewriteFns)
1071  rewriteFunctions.push_back(std::move(it.second));
1072 }
1073 
1074 /// Initialize the given state such that it can be used to execute the current
1075 /// bytecode.
1076 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
1077  state.memory.resize(maxValueMemoryIndex, nullptr);
1078  state.opRangeMemory.resize(maxOpRangeCount);
1079  state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
1080  state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
1081  state.loopIndex.resize(maxLoopLevel, 0);
1082  state.currentPatternBenefits.reserve(patterns.size());
1083  for (const PDLByteCodePattern &pattern : patterns)
1084  state.currentPatternBenefits.push_back(pattern.getBenefit());
1085 }
1086 
1087 //===----------------------------------------------------------------------===//
1088 // ByteCode Execution
1089 
1090 namespace {
1091 /// This class is an instantiation of the PDLResultList that provides access to
1092 /// the returned results. This API is not on `PDLResultList` to avoid
1093 /// overexposing access to information specific solely to the ByteCode.
1094 class ByteCodeRewriteResultList : public PDLResultList {
1095 public:
1096  ByteCodeRewriteResultList(unsigned maxNumResults)
1097  : PDLResultList(maxNumResults) {}
1098 
1099  /// Return the list of PDL results.
1100  MutableArrayRef<PDLValue> getResults() { return results; }
1101 
1102  /// Return the type ranges allocated by this list.
1103  MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
1104  return allocatedTypeRanges;
1105  }
1106 
1107  /// Return the value ranges allocated by this list.
1108  MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
1109  return allocatedValueRanges;
1110  }
1111 };
1112 
1113 /// This class provides support for executing a bytecode stream.
1114 class ByteCodeExecutor {
1115 public:
1116  ByteCodeExecutor(
1117  const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
1118  MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory,
1119  MutableArrayRef<TypeRange> typeRangeMemory,
1120  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
1121  MutableArrayRef<ValueRange> valueRangeMemory,
1122  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
1123  MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory,
1125  ArrayRef<PatternBenefit> currentPatternBenefits,
1127  ArrayRef<PDLConstraintFunction> constraintFunctions,
1128  ArrayRef<PDLRewriteFunction> rewriteFunctions)
1129  : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
1130  typeRangeMemory(typeRangeMemory),
1131  allocatedTypeRangeMemory(allocatedTypeRangeMemory),
1132  valueRangeMemory(valueRangeMemory),
1133  allocatedValueRangeMemory(allocatedValueRangeMemory),
1134  loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
1135  currentPatternBenefits(currentPatternBenefits), patterns(patterns),
1136  constraintFunctions(constraintFunctions),
1137  rewriteFunctions(rewriteFunctions) {}
1138 
1139  /// Start executing the code at the current bytecode index. `matches` is an
1140  /// optional field provided when this function is executed in a matching
1141  /// context.
1142  LogicalResult
1143  execute(PatternRewriter &rewriter,
1144  SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
1145  std::optional<Location> mainRewriteLoc = {});
1146 
1147 private:
1148  /// Internal implementation of executing each of the bytecode commands.
1149  void executeApplyConstraint(PatternRewriter &rewriter);
1150  LogicalResult executeApplyRewrite(PatternRewriter &rewriter);
1151  void executeAreEqual();
1152  void executeAreRangesEqual();
1153  void executeBranch();
1154  void executeCheckOperandCount();
1155  void executeCheckOperationName();
1156  void executeCheckResultCount();
1157  void executeCheckTypes();
1158  void executeContinue();
1159  void executeCreateConstantTypeRange();
1160  void executeCreateOperation(PatternRewriter &rewriter,
1161  Location mainRewriteLoc);
1162  template <typename T>
1163  void executeDynamicCreateRange(StringRef type);
1164  void executeEraseOp(PatternRewriter &rewriter);
1165  template <typename T, typename Range, PDLValue::Kind kind>
1166  void executeExtract();
1167  void executeFinalize();
1168  void executeForEach();
1169  void executeGetAttribute();
1170  void executeGetAttributeType();
1171  void executeGetDefiningOp();
1172  void executeGetOperand(unsigned index);
1173  void executeGetOperands();
1174  void executeGetResult(unsigned index);
1175  void executeGetResults();
1176  void executeGetUsers();
1177  void executeGetValueType();
1178  void executeGetValueRangeTypes();
1179  void executeIsNotNull();
1180  void executeRecordMatch(PatternRewriter &rewriter,
1182  void executeReplaceOp(PatternRewriter &rewriter);
1183  void executeSwitchAttribute();
1184  void executeSwitchOperandCount();
1185  void executeSwitchOperationName();
1186  void executeSwitchResultCount();
1187  void executeSwitchType();
1188  void executeSwitchTypes();
1189  void processNativeFunResults(ByteCodeRewriteResultList &results,
1190  unsigned numResults,
1191  LogicalResult &rewriteResult);
1192 
1193  /// Pushes a code iterator to the stack.
1194  void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }
1195 
1196  /// Pops a code iterator from the stack, returning true on success.
1197  void popCodeIt() {
1198  assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
1199  curCodeIt = resumeCodeIt.back();
1200  resumeCodeIt.pop_back();
1201  }
1202 
1203  /// Return the bytecode iterator at the start of the current op code.
1204  const ByteCodeField *getPrevCodeIt() const {
1205  LLVM_DEBUG({
1206  // Account for the op code and the Location stored inline.
1207  return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField);
1208  });
1209 
1210  // Account for the op code only.
1211  return curCodeIt - 1;
1212  }
1213 
1214  /// Read a value from the bytecode buffer, optionally skipping a certain
1215  /// number of prefix values. These methods always update the buffer to point
1216  /// to the next field after the read data.
1217  template <typename T = ByteCodeField>
1218  T read(size_t skipN = 0) {
1219  curCodeIt += skipN;
1220  return readImpl<T>();
1221  }
1222  ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
1223 
1224  /// Read a list of values from the bytecode buffer.
1225  template <typename ValueT, typename T>
1226  void readList(SmallVectorImpl<T> &list) {
1227  list.clear();
1228  for (unsigned i = 0, e = read(); i != e; ++i)
1229  list.push_back(read<ValueT>());
1230  }
1231 
1232  /// Read a list of values from the bytecode buffer. The values may be encoded
1233  /// either as a single element or a range of elements.
1234  void readList(SmallVectorImpl<Type> &list) {
1235  for (unsigned i = 0, e = read(); i != e; ++i) {
1236  if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1237  list.push_back(read<Type>());
1238  } else {
1239  TypeRange *values = read<TypeRange *>();
1240  list.append(values->begin(), values->end());
1241  }
1242  }
1243  }
1244  void readList(SmallVectorImpl<Value> &list) {
1245  for (unsigned i = 0, e = read(); i != e; ++i) {
1246  if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1247  list.push_back(read<Value>());
1248  } else {
1249  ValueRange *values = read<ValueRange *>();
1250  list.append(values->begin(), values->end());
1251  }
1252  }
1253  }
1254 
1255  /// Read a value stored inline as a pointer.
1256  template <typename T>
1257  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
1258  readInline() {
1259  const void *pointer;
1260  std::memcpy(&pointer, curCodeIt, sizeof(const void *));
1261  curCodeIt += sizeof(const void *) / sizeof(ByteCodeField);
1262  return T::getFromOpaquePointer(pointer);
1263  }
1264 
1265  void skip(size_t skipN) { curCodeIt += skipN; }
1266 
1267  /// Jump to a specific successor based on a predicate value.
1268  void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
1269  /// Jump to a specific successor based on a destination index.
1270  void selectJump(size_t destIndex) {
1271  curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
1272  }
1273 
1274  /// Handle a switch operation with the provided value and cases.
1275  template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
1276  void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
1277  LLVM_DEBUG({
1278  llvm::dbgs() << " * Value: " << value << "\n"
1279  << " * Cases: ";
1280  llvm::interleaveComma(cases, llvm::dbgs());
1281  llvm::dbgs() << "\n";
1282  });
1283 
1284  // Check to see if the attribute value is within the case list. Jump to
1285  // the correct successor index based on the result.
1286  for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
1287  if (cmp(*it, value))
1288  return selectJump(size_t((it - cases.begin()) + 1));
1289  selectJump(size_t(0));
1290  }
1291 
1292  /// Store a pointer to memory.
1293  void storeToMemory(unsigned index, const void *value) {
1294  memory[index] = value;
1295  }
1296 
1297  /// Store a value to memory as an opaque pointer.
1298  template <typename T>
1299  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
1300  storeToMemory(unsigned index, T value) {
1301  memory[index] = value.getAsOpaquePointer();
1302  }
1303 
1304  /// Internal implementation of reading various data types from the bytecode
1305  /// stream.
1306  template <typename T>
1307  const void *readFromMemory() {
1308  size_t index = *curCodeIt++;
1309 
1310  // If this type is an SSA value, it can only be stored in non-const memory.
1311  if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
1312  Value>::value ||
1313  index < memory.size())
1314  return memory[index];
1315 
1316  // Otherwise, if this index is not inbounds it is uniqued.
1317  return uniquedMemory[index - memory.size()];
1318  }
1319  template <typename T>
1320  std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
1321  return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
1322  }
1323  template <typename T>
1324  std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
1325  T>
1326  readImpl() {
1327  return T(T::getFromOpaquePointer(readFromMemory<T>()));
1328  }
1329  template <typename T>
1330  std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
1331  switch (read<PDLValue::Kind>()) {
1332  case PDLValue::Kind::Attribute:
1333  return read<Attribute>();
1334  case PDLValue::Kind::Operation:
1335  return read<Operation *>();
1336  case PDLValue::Kind::Type:
1337  return read<Type>();
1338  case PDLValue::Kind::Value:
1339  return read<Value>();
1340  case PDLValue::Kind::TypeRange:
1341  return read<TypeRange *>();
1342  case PDLValue::Kind::ValueRange:
1343  return read<ValueRange *>();
1344  }
1345  llvm_unreachable("unhandled PDLValue::Kind");
1346  }
1347  template <typename T>
1348  std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
1349  static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
1350  "unexpected ByteCode address size");
1351  ByteCodeAddr result;
1352  std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
1353  curCodeIt += 2;
1354  return result;
1355  }
1356  template <typename T>
1357  std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
1358  return *curCodeIt++;
1359  }
1360  template <typename T>
1361  std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
1362  return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
1363  }
1364 
1365  /// Assign the given range to the given memory index. This allocates a new
1366  /// range object if necessary.
1367  template <typename RangeT, typename T = llvm::detail::ValueOfRange<RangeT>>
1368  void assignRangeToMemory(RangeT &&range, unsigned memIndex,
1369  unsigned rangeIndex) {
1370  // Utility functor used to type-erase the assignment.
1371  auto assignRange = [&](auto &allocatedRangeMemory, auto &rangeMemory) {
1372  // If the input range is empty, we don't need to allocate anything.
1373  if (range.empty()) {
1374  rangeMemory[rangeIndex] = {};
1375  } else {
1376  // Allocate a buffer for this type range.
1377  llvm::OwningArrayRef<T> storage(llvm::size(range));
1378  llvm::copy(range, storage.begin());
1379 
1380  // Assign this to the range slot and use the range as the value for the
1381  // memory index.
1382  allocatedRangeMemory.emplace_back(std::move(storage));
1383  rangeMemory[rangeIndex] = allocatedRangeMemory.back();
1384  }
1385  memory[memIndex] = &rangeMemory[rangeIndex];
1386  };
1387 
1388  // Dispatch based on the concrete range type.
1389  if constexpr (std::is_same_v<T, Type>) {
1390  return assignRange(allocatedTypeRangeMemory, typeRangeMemory);
1391  } else if constexpr (std::is_same_v<T, Value>) {
1392  return assignRange(allocatedValueRangeMemory, valueRangeMemory);
1393  } else {
1394  llvm_unreachable("unhandled range type");
1395  }
1396  }
1397 
1398  /// The underlying bytecode buffer.
1399  const ByteCodeField *curCodeIt;
1400 
1401  /// The stack of bytecode positions at which to resume operation.
1403 
1404  /// The current execution memory.
1406  MutableArrayRef<OwningOpRange> opRangeMemory;
1407  MutableArrayRef<TypeRange> typeRangeMemory;
1408  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
1409  MutableArrayRef<ValueRange> valueRangeMemory;
1410  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
1411 
1412  /// The current loop indices.
1413  MutableArrayRef<unsigned> loopIndex;
1414 
1415  /// References to ByteCode data necessary for execution.
1416  ArrayRef<const void *> uniquedMemory;
1418  ArrayRef<PatternBenefit> currentPatternBenefits;
1420  ArrayRef<PDLConstraintFunction> constraintFunctions;
1421  ArrayRef<PDLRewriteFunction> rewriteFunctions;
1422 };
1423 } // namespace
1424 
1425 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1426  LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
1427  ByteCodeField fun_idx = read();
1429  readList<PDLValue>(args);
1430 
1431  LLVM_DEBUG({
1432  llvm::dbgs() << " * Arguments: ";
1433  llvm::interleaveComma(args, llvm::dbgs());
1434  llvm::dbgs() << "\n";
1435  });
1436 
1437  ByteCodeField isNegated = read();
1438  LLVM_DEBUG({
1439  llvm::dbgs() << " * isNegated: " << isNegated << "\n";
1440  llvm::interleaveComma(args, llvm::dbgs());
1441  });
1442 
1443  ByteCodeField numResults = read();
1444  const PDLRewriteFunction &constraintFn = constraintFunctions[fun_idx];
1445  ByteCodeRewriteResultList results(numResults);
1446  LogicalResult rewriteResult = constraintFn(rewriter, results, args);
1447  [[maybe_unused]] ArrayRef<PDLValue> constraintResults = results.getResults();
1448  LLVM_DEBUG({
1449  if (succeeded(rewriteResult)) {
1450  llvm::dbgs() << " * Constraint succeeded\n";
1451  llvm::dbgs() << " * Results: ";
1452  llvm::interleaveComma(constraintResults, llvm::dbgs());
1453  llvm::dbgs() << "\n";
1454  } else {
1455  llvm::dbgs() << " * Constraint failed\n";
1456  }
1457  });
1458  assert((failed(rewriteResult) || constraintResults.size() == numResults) &&
1459  "native PDL rewrite function succeeded but returned "
1460  "unexpected number of results");
1461  processNativeFunResults(results, numResults, rewriteResult);
1462 
1463  // Depending on the constraint jump to the proper destination.
1464  selectJump(isNegated != succeeded(rewriteResult));
1465 }
1466 
1467 LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1468  LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
1469  const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1471  readList<PDLValue>(args);
1472 
1473  LLVM_DEBUG({
1474  llvm::dbgs() << " * Arguments: ";
1475  llvm::interleaveComma(args, llvm::dbgs());
1476  });
1477 
1478  // Execute the rewrite function.
1479  ByteCodeField numResults = read();
1480  ByteCodeRewriteResultList results(numResults);
1481  LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
1482 
1483  assert(results.getResults().size() == numResults &&
1484  "native PDL rewrite function returned unexpected number of results");
1485 
1486  processNativeFunResults(results, numResults, rewriteResult);
1487 
1488  if (failed(rewriteResult)) {
1489  LLVM_DEBUG(llvm::dbgs() << " - Failed");
1490  return failure();
1491  }
1492  return success();
1493 }
1494 
1495 void ByteCodeExecutor::processNativeFunResults(
1496  ByteCodeRewriteResultList &results, unsigned numResults,
1497  LogicalResult &rewriteResult) {
1498  // Store the results in the bytecode memory or handle missing results on
1499  // failure.
1500  for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1501  PDLValue::Kind resultKind = read<PDLValue::Kind>();
1502 
1503  // Skip the according number of values on the buffer on failure and exit
1504  // early as there are no results to process.
1505  if (failed(rewriteResult)) {
1506  if (resultKind == PDLValue::Kind::TypeRange ||
1507  resultKind == PDLValue::Kind::ValueRange) {
1508  skip(2);
1509  } else {
1510  skip(1);
1511  }
1512  return;
1513  }
1514  PDLValue result = results.getResults()[resultIdx];
1515  LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
1516  assert(result.getKind() == resultKind &&
1517  "native PDL rewrite function returned an unexpected type of "
1518  "result");
1519  // If the result is a range, we need to copy it over to the bytecodes
1520  // range memory.
1521  if (std::optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
1522  unsigned rangeIndex = read();
1523  typeRangeMemory[rangeIndex] = *typeRange;
1524  memory[read()] = &typeRangeMemory[rangeIndex];
1525  } else if (std::optional<ValueRange> valueRange =
1526  result.dyn_cast<ValueRange>()) {
1527  unsigned rangeIndex = read();
1528  valueRangeMemory[rangeIndex] = *valueRange;
1529  memory[read()] = &valueRangeMemory[rangeIndex];
1530  } else {
1531  memory[read()] = result.getAsOpaquePointer();
1532  }
1533  }
1534 
1535  // Copy over any underlying storage allocated for result ranges.
1536  for (auto &it : results.getAllocatedTypeRanges())
1537  allocatedTypeRangeMemory.push_back(std::move(it));
1538  for (auto &it : results.getAllocatedValueRanges())
1539  allocatedValueRangeMemory.push_back(std::move(it));
1540 }
1541 
1542 void ByteCodeExecutor::executeAreEqual() {
1543  LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1544  const void *lhs = read<const void *>();
1545  const void *rhs = read<const void *>();
1546 
1547  LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n");
1548  selectJump(lhs == rhs);
1549 }
1550 
1551 void ByteCodeExecutor::executeAreRangesEqual() {
1552  LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
1553  PDLValue::Kind valueKind = read<PDLValue::Kind>();
1554  const void *lhs = read<const void *>();
1555  const void *rhs = read<const void *>();
1556 
1557  switch (valueKind) {
1558  case PDLValue::Kind::TypeRange: {
1559  const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1560  const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1561  LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1562  selectJump(*lhsRange == *rhsRange);
1563  break;
1564  }
1565  case PDLValue::Kind::ValueRange: {
1566  const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1567  const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1568  LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1569  selectJump(*lhsRange == *rhsRange);
1570  break;
1571  }
1572  default:
1573  llvm_unreachable("unexpected `AreRangesEqual` value kind");
1574  }
1575 }
1576 
1577 void ByteCodeExecutor::executeBranch() {
1578  LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
1579  curCodeIt = &code[read<ByteCodeAddr>()];
1580 }
1581 
1582 void ByteCodeExecutor::executeCheckOperandCount() {
1583  LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
1584  Operation *op = read<Operation *>();
1585  uint32_t expectedCount = read<uint32_t>();
1586  bool compareAtLeast = read();
1587 
1588  LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n"
1589  << " * Expected: " << expectedCount << "\n"
1590  << " * Comparator: "
1591  << (compareAtLeast ? ">=" : "==") << "\n");
1592  if (compareAtLeast)
1593  selectJump(op->getNumOperands() >= expectedCount);
1594  else
1595  selectJump(op->getNumOperands() == expectedCount);
1596 }
1597 
1598 void ByteCodeExecutor::executeCheckOperationName() {
1599  LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
1600  Operation *op = read<Operation *>();
1601  OperationName expectedName = read<OperationName>();
1602 
1603  LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n"
1604  << " * Expected: \"" << expectedName << "\"\n");
1605  selectJump(op->getName() == expectedName);
1606 }
1607 
1608 void ByteCodeExecutor::executeCheckResultCount() {
1609  LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
1610  Operation *op = read<Operation *>();
1611  uint32_t expectedCount = read<uint32_t>();
1612  bool compareAtLeast = read();
1613 
1614  LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n"
1615  << " * Expected: " << expectedCount << "\n"
1616  << " * Comparator: "
1617  << (compareAtLeast ? ">=" : "==") << "\n");
1618  if (compareAtLeast)
1619  selectJump(op->getNumResults() >= expectedCount);
1620  else
1621  selectJump(op->getNumResults() == expectedCount);
1622 }
1623 
1624 void ByteCodeExecutor::executeCheckTypes() {
1625  LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1626  TypeRange *lhs = read<TypeRange *>();
1627  Attribute rhs = read<Attribute>();
1628  LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1629 
1630  selectJump(*lhs == cast<ArrayAttr>(rhs).getAsValueRange<TypeAttr>());
1631 }
1632 
1633 void ByteCodeExecutor::executeContinue() {
1634  ByteCodeField level = read();
1635  LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n"
1636  << " * Level: " << level << "\n");
1637  ++loopIndex[level];
1638  popCodeIt();
1639 }
1640 
1641 void ByteCodeExecutor::executeCreateConstantTypeRange() {
1642  LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n");
1643  unsigned memIndex = read();
1644  unsigned rangeIndex = read();
1645  ArrayAttr typesAttr = cast<ArrayAttr>(read<Attribute>());
1646 
1647  LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n");
1648  assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
1649  rangeIndex);
1650 }
1651 
/// Builds a new operation from operands encoded in the bytecode and stores
/// the created operation into memory.
///
/// Operands are read in this order: destination memory slot, operation name,
/// operand list, attribute (name, value) pairs, then the result type spec.
/// The result type spec is either the `kInferTypesMarker` sentinel or an
/// explicit list of `Type` / `TypeRange *` entries.
/// `mainRewriteLoc` is used as the location of the created operation.
1652 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1653  Location mainRewriteLoc) {
1654  LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
1655 
1656  unsigned memIndex = read();
1657  OperationState state(mainRewriteLoc, read<OperationName>());
1658  readList(state.operands);
// Attributes whose value is null in memory are skipped rather than added.
1659  for (unsigned i = 0, e = read(); i != e; ++i) {
1660  StringAttr name = read<StringAttr>();
1661  if (Attribute attr = read<Attribute>())
1662  state.addAttribute(name, attr);
1663  }
1664 
1665  // Read in the result types. If the "size" is the sentinel value, this
1666  // indicates that the result types should be inferred.
1667  unsigned numResults = read();
1668  if (numResults == kInferTypesMarker) {
1669  InferTypeOpInterface::Concept *inferInterface =
1670  state.name.getInterface<InferTypeOpInterface>();
1671  assert(inferInterface &&
1672  "expected operation to provide InferTypeOpInterface");
1673 
1674  // TODO: Handle failure.
// NOTE(review): on inference failure no operation is created and
// memory[memIndex] is left untouched (it keeps whatever it held before).
1675  if (failed(inferInterface->inferReturnTypes(
1676  state.getContext(), state.location, state.operands,
1677  state.attributes.getDictionary(state.getContext()),
1678  state.getRawProperties(), state.regions, state.types)))
1679  return;
1680  } else {
1681  // Otherwise, this is a fixed number of results.
// Each entry is tagged with its kind: a single Type, or otherwise a pointer
// to a TypeRange whose elements are all appended.
1682  for (unsigned i = 0; i != numResults; ++i) {
1683  if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1684  state.types.push_back(read<Type>());
1685  } else {
1686  TypeRange *resultTypes = read<TypeRange *>();
1687  state.types.append(resultTypes->begin(), resultTypes->end());
1688  }
1689  }
1690  }
1691 
1692  Operation *resultOp = rewriter.create(state);
1693  memory[memIndex] = resultOp;
1694 
1695  LLVM_DEBUG({
1696  llvm::dbgs() << " * Attributes: "
1697  << state.attributes.getDictionary(state.getContext())
1698  << "\n * Operands: ";
1699  llvm::interleaveComma(state.operands, llvm::dbgs());
1700  llvm::dbgs() << "\n * Result Types: ";
1701  llvm::interleaveComma(state.types, llvm::dbgs());
1702  llvm::dbgs() << "\n * Result: " << *resultOp << "\n";
1703  });
1704 }
1705 
1706 template <typename T>
1707 void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
1708  LLVM_DEBUG(llvm::dbgs() << "Executing CreateDynamic" << type << "Range:\n");
1709  unsigned memIndex = read();
1710  unsigned rangeIndex = read();
1711  SmallVector<T> values;
1712  readList(values);
1713 
1714  LLVM_DEBUG({
1715  llvm::dbgs() << "\n * " << type << "s: ";
1716  llvm::interleaveComma(values, llvm::dbgs());
1717  llvm::dbgs() << "\n";
1718  });
1719 
1720  assignRangeToMemory(values, memIndex, rangeIndex);
1721 }
1722 
1723 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1724  LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1725  Operation *op = read<Operation *>();
1726 
1727  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
1728  rewriter.eraseOp(op);
1729 }
1730 
1731 template <typename T, typename Range, PDLValue::Kind kind>
1732 void ByteCodeExecutor::executeExtract() {
1733  LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n");
1734  Range *range = read<Range *>();
1735  unsigned index = read<uint32_t>();
1736  unsigned memIndex = read();
1737 
1738  if (!range) {
1739  memory[memIndex] = nullptr;
1740  return;
1741  }
1742 
1743  T result = index < range->size() ? (*range)[index] : T();
1744  LLVM_DEBUG(llvm::dbgs() << " * " << kind << "s(" << range->size() << ")\n"
1745  << " * Index: " << index << "\n"
1746  << " * Result: " << result << "\n");
1747  storeToMemory(memIndex, result);
1748 }
1749 
1750 void ByteCodeExecutor::executeFinalize() {
1751  LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
1752 }
1753 
/// Executes one iteration step of a `ForEach` loop over a range in range
/// memory. Only operation ranges are supported.
///
/// While elements remain, the current element is stored into memory and the
/// saved code position is pushed so that a later `Continue` (which pops it)
/// re-executes this instruction. The inline successor address is then
/// skipped, entering the loop body. Once the range is exhausted, the loop
/// counter is reset to 0 and execution jumps to successor 0.
1754 void ByteCodeExecutor::executeForEach() {
1755  LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
// Captured before any operands are read; presumably the address of this
// ForEach instruction, so `Continue` can return to it. TODO confirm.
1756  const ByteCodeField *prevCodeIt = getPrevCodeIt();
1757  unsigned rangeIndex = read();
1758  unsigned memIndex = read();
1759  const void *value = nullptr;
1760 
1761  switch (read<PDLValue::Kind>()) {
1762  case PDLValue::Kind::Operation: {
// The loop counter for this nesting level is shared with `Continue`, which
// increments it.
1763  unsigned &index = loopIndex[read()];
1764  ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
1765  assert(index <= array.size() && "iterated past the end");
1766  if (index < array.size()) {
1767  LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n");
1768  value = array[index];
1769  break;
1770  }
1771 
// Range exhausted: reset the counter so the loop can run again later.
1772  LLVM_DEBUG(llvm::dbgs() << " * Done\n");
1773  index = 0;
1774  selectJump(size_t(0));
1775  return;
1776  }
1777  default:
1778  llvm_unreachable("unexpected `ForEach` value kind");
1779  }
1780 
1781  // Store the iterate value and the stack address.
1782  memory[memIndex] = value;
1783  pushCodeIt(prevCodeIt);
1784 
1785  // Skip over the successor (we will enter the body of the loop).
1786  read<ByteCodeAddr>();
1787 }
1788 
1789 void ByteCodeExecutor::executeGetAttribute() {
1790  LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1791  unsigned memIndex = read();
1792  Operation *op = read<Operation *>();
1793  StringAttr attrName = read<StringAttr>();
1794  Attribute attr = op->getAttr(attrName);
1795 
1796  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1797  << " * Attribute: " << attrName << "\n"
1798  << " * Result: " << attr << "\n");
1799  memory[memIndex] = attr.getAsOpaquePointer();
1800 }
1801 
1802 void ByteCodeExecutor::executeGetAttributeType() {
1803  LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1804  unsigned memIndex = read();
1805  Attribute attr = read<Attribute>();
1806  Type type;
1807  if (auto typedAttr = dyn_cast<TypedAttr>(attr))
1808  type = typedAttr.getType();
1809 
1810  LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n"
1811  << " * Result: " << type << "\n");
1812  memory[memIndex] = type.getAsOpaquePointer();
1813 }
1814 
1815 void ByteCodeExecutor::executeGetDefiningOp() {
1816  LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1817  unsigned memIndex = read();
1818  Operation *op = nullptr;
1819  if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1820  Value value = read<Value>();
1821  if (value)
1822  op = value.getDefiningOp();
1823  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
1824  } else {
1825  ValueRange *values = read<ValueRange *>();
1826  if (values && !values->empty()) {
1827  op = values->front().getDefiningOp();
1828  }
1829  LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n");
1830  }
1831 
1832  LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n");
1833  memory[memIndex] = op;
1834 }
1835 
1836 void ByteCodeExecutor::executeGetOperand(unsigned index) {
1837  Operation *op = read<Operation *>();
1838  unsigned memIndex = read();
1839  Value operand =
1840  index < op->getNumOperands() ? op->getOperand(index) : Value();
1841 
1842  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1843  << " * Index: " << index << "\n"
1844  << " * Result: " << operand << "\n");
1845  memory[memIndex] = operand.getAsOpaquePointer();
1846 }
1847 
1848 /// This function is the internal implementation of `GetResults` and
1849 /// `GetOperands` that provides support for extracting a value range from the
1850 /// given operation.
///
/// `index` selects an operand/result group; `uint32_t` max means "all values".
/// `rangeIndex` equal to `ByteCodeField` max means a single value is expected
/// rather than a range. Returns nullptr when the group cannot be computed, or
/// when a single value was requested but the group does not hold exactly one.
1851 template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1852 static void *
1853 executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1854  ByteCodeField rangeIndex, StringRef attrSizedSegments,
1855  MutableArrayRef<ValueRange> valueRangeMemory) {
1856  // Check for the sentinel index that signals that all values should be
1857  // returned.
1858  if (index == std::numeric_limits<uint32_t>::max()) {
1859  LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n");
1860  // `values` is already the full value range.
1861 
1862  // Otherwise, check to see if this operation uses AttrSizedSegments.
1863  } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1864  LLVM_DEBUG(llvm::dbgs()
1865  << " * Extracting values from `" << attrSizedSegments << "`\n");
1866 
1867  auto segmentAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrSizedSegments);
1868  if (!segmentAttr || segmentAttr.asArrayRef().size() <= index)
1869  return nullptr;
1870 
// The group starts after the sizes of all preceding segments.
1871  ArrayRef<int32_t> segments = segmentAttr;
1872  unsigned startIndex =
1873  std::accumulate(segments.begin(), segments.begin() + index, 0);
1874  values = values.slice(startIndex, *std::next(segments.begin(), index));
1875 
1876  LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", "
1877  << *std::next(segments.begin(), index) << "]\n");
1878 
1879  // Otherwise, assume this is the last operand group of the operation.
1880  // FIXME: We currently don't support operations with
1881  // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1882  // have a way to detect its presence.
1883  } else if (values.size() >= index) {
1884  LLVM_DEBUG(llvm::dbgs()
1885  << " * Treating values as trailing variadic range\n");
1886  values = values.drop_front(index);
1887 
1888  // If we couldn't detect a way to compute the values, bail out.
1889  } else {
1890  return nullptr;
1891  }
1892 
1893  // If the range index is valid, we are returning a range.
1894  if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1895  valueRangeMemory[rangeIndex] = values;
1896  return &valueRangeMemory[rangeIndex];
1897  }
1898 
1899  // If a range index wasn't provided, the range is required to be non-variadic.
1900  return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1901 }
1902 
1903 void ByteCodeExecutor::executeGetOperands() {
1904  LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
1905  unsigned index = read<uint32_t>();
1906  Operation *op = read<Operation *>();
1907  ByteCodeField rangeIndex = read();
1908 
1909  void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1910  op->getOperands(), op, index, rangeIndex, "operandSegmentSizes",
1911  valueRangeMemory);
1912  if (!result)
1913  LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n");
1914  memory[read()] = result;
1915 }
1916 
1917 void ByteCodeExecutor::executeGetResult(unsigned index) {
1918  Operation *op = read<Operation *>();
1919  unsigned memIndex = read();
1920  OpResult result =
1921  index < op->getNumResults() ? op->getResult(index) : OpResult();
1922 
1923  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1924  << " * Index: " << index << "\n"
1925  << " * Result: " << result << "\n");
1926  memory[memIndex] = result.getAsOpaquePointer();
1927 }
1928 
1929 void ByteCodeExecutor::executeGetResults() {
1930  LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
1931  unsigned index = read<uint32_t>();
1932  Operation *op = read<Operation *>();
1933  ByteCodeField rangeIndex = read();
1934 
1935  void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1936  op->getResults(), op, index, rangeIndex, "resultSegmentSizes",
1937  valueRangeMemory);
1938  if (!result)
1939  LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n");
1940  memory[read()] = result;
1941 }
1942 
/// Collects the users of a value, or of every value in a value range, into an
/// owned operation range and stores a pointer to that range into memory.
///
/// The memory slot always points at the (reset) range, even when the input is
/// null. In that case the range is simply left empty.
1943 void ByteCodeExecutor::executeGetUsers() {
1944  LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n");
1945  unsigned memIndex = read();
1946  unsigned rangeIndex = read();
1947  OwningOpRange &range = opRangeMemory[rangeIndex];
1948  memory[memIndex] = &range;
1949 
// Clear any users left over from a previous execution of this slot.
1950  range = OwningOpRange();
1951  if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1952  // Read the value.
1953  Value value = read<Value>();
1954  if (!value)
1955  return;
1956  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
1957 
1958  // Extract the users of a single value.
1959  range = OwningOpRange(std::distance(value.user_begin(), value.user_end()));
1960  llvm::copy(value.getUsers(), range.begin());
1961  } else {
1962  // Read a range of values.
1963  ValueRange *values = read<ValueRange *>();
1964  if (!values)
1965  return;
1966  LLVM_DEBUG({
1967  llvm::dbgs() << " * Values (" << values->size() << "): ";
1968  llvm::interleaveComma(*values, llvm::dbgs());
1969  llvm::dbgs() << "\n";
1970  });
1971 
1972  // Extract all the users of a range of values.
// NOTE(review): the declaration of the `users` buffer (original line 1973)
// is missing from this listing. An operation that uses several values of the
// range is appended once per value; duplicates are not filtered here.
1974  for (Value value : *values)
1975  users.append(value.user_begin(), value.user_end());
1976  range = OwningOpRange(users.size());
1977  llvm::copy(users, range.begin());
1978  }
1979 
1980  LLVM_DEBUG(llvm::dbgs() << " * Result: " << range.size() << " operations\n");
1981 }
1982 
1983 void ByteCodeExecutor::executeGetValueType() {
1984  LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1985  unsigned memIndex = read();
1986  Value value = read<Value>();
1987  Type type = value ? value.getType() : Type();
1988 
1989  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"
1990  << " * Result: " << type << "\n");
1991  memory[memIndex] = type.getAsOpaquePointer();
1992 }
1993 
1994 void ByteCodeExecutor::executeGetValueRangeTypes() {
1995  LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
1996  unsigned memIndex = read();
1997  unsigned rangeIndex = read();
1998  ValueRange *values = read<ValueRange *>();
1999  if (!values) {
2000  LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n");
2001  memory[memIndex] = nullptr;
2002  return;
2003  }
2004 
2005  LLVM_DEBUG({
2006  llvm::dbgs() << " * Values (" << values->size() << "): ";
2007  llvm::interleaveComma(*values, llvm::dbgs());
2008  llvm::dbgs() << "\n * Result: ";
2009  llvm::interleaveComma(values->getType(), llvm::dbgs());
2010  llvm::dbgs() << "\n";
2011  });
2012  typeRangeMemory[rangeIndex] = values->getType();
2013  memory[memIndex] = &typeRangeMemory[rangeIndex];
2014 }
2015 
2016 void ByteCodeExecutor::executeIsNotNull() {
2017  LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
2018  const void *value = read<const void *>();
2019 
2020  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
2021  selectJump(value != nullptr);
2022 }
2023 
/// Records a successful match of pattern `patternIndex` into `matches` and
/// continues execution at the encoded destination address.
///
/// Patterns whose benefit is impossible-to-match are skipped without
/// recording. Otherwise a fused location of the matched operations is built
/// and every match input is captured. Range inputs are copied into storage
/// owned by the match result.
/// NOTE(review): one signature line (original line 2026) is missing from this
/// listing.
2024 void ByteCodeExecutor::executeRecordMatch(
2025  PatternRewriter &rewriter,
2027  LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
2028  unsigned patternIndex = read();
2029  PatternBenefit benefit = currentPatternBenefits[patternIndex];
2030  const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
2031 
2032  // If the benefit of the pattern is impossible, skip the processing of the
2033  // rest of the pattern.
2034  if (benefit.isImpossibleToMatch()) {
2035  LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n");
2036  curCodeIt = dest;
2037  return;
2038  }
2039 
2040  // Create a fused location containing the locations of each of the
2041  // operations used in the match. This will be used as the location for
2042  // created operations during the rewrite that don't already have an
2043  // explicit location set.
2044  unsigned numMatchLocs = read();
2045  SmallVector<Location, 4> matchLocs;
2046  matchLocs.reserve(numMatchLocs);
2047  for (unsigned i = 0; i != numMatchLocs; ++i)
2048  matchLocs.push_back(read<Operation *>()->getLoc());
2049  Location matchLoc = rewriter.getFusedLoc(matchLocs);
2050 
2051  LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n"
2052  << " * Location: " << matchLoc << "\n");
2053  matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
2054  PDLByteCode::MatchResult &match = matches.back();
2055 
2056  // Record all of the inputs to the match. If any of the inputs are ranges, we
2057  // will also need to remap the range pointer to memory stored in the match
2058  // state.
// The reserve calls matter: `match.values` stores pointers to the back
// elements of the range vectors. Reserving `numInputs` up front prevents a
// reallocation from invalidating those pointers while inputs are recorded.
2059  unsigned numInputs = read();
2060  match.values.reserve(numInputs);
2061  match.typeRangeValues.reserve(numInputs);
2062  match.valueRangeValues.reserve(numInputs);
2063  for (unsigned i = 0; i < numInputs; ++i) {
2064  switch (read<PDLValue::Kind>()) {
2065  case PDLValue::Kind::TypeRange:
2066  match.typeRangeValues.push_back(*read<TypeRange *>());
2067  match.values.push_back(&match.typeRangeValues.back());
2068  break;
2069  case PDLValue::Kind::ValueRange:
2070  match.valueRangeValues.push_back(*read<ValueRange *>());
2071  match.values.push_back(&match.valueRangeValues.back());
2072  break;
2073  default:
2074  match.values.push_back(read<const void *>());
2075  break;
2076  }
2077  }
2078  curCodeIt = dest;
2079 }
2080 
/// Replaces an operation with a list of values read from memory, via the
/// pattern rewriter.
2081 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
2082  LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
2083  Operation *op = read<Operation *>();
// NOTE(review): the declaration of `args` (original line 2084) is missing
// from this listing.
2085  readList(args);
2086 
2087  LLVM_DEBUG({
2088  llvm::dbgs() << " * Operation: " << *op << "\n"
2089  << " * Values: ";
2090  llvm::interleaveComma(args, llvm::dbgs());
2091  llvm::dbgs() << "\n";
2092  });
2093  rewriter.replaceOp(op, args);
2094 }
2095 
2096 void ByteCodeExecutor::executeSwitchAttribute() {
2097  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
2098  Attribute value = read<Attribute>();
2099  ArrayAttr cases = read<ArrayAttr>();
2100  handleSwitch(value, cases);
2101 }
2102 
2103 void ByteCodeExecutor::executeSwitchOperandCount() {
2104  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
2105  Operation *op = read<Operation *>();
2106  auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2107 
2108  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
2109  handleSwitch(op->getNumOperands(), cases);
2110 }
2111 
/// Multi-way branch on an operation's name against `caseCount` operation
/// names stored inline in the bytecode.
///
/// A match on case `i` selects successor `i + 1`. No match selects successor
/// 0 (the default).
2112 void ByteCodeExecutor::executeSwitchOperationName() {
2113  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
2114  OperationName value = read<Operation *>()->getName();
2115  size_t caseCount = read();
2116 
2117  // The operation names are stored in-line, so to print them out for
2118  // debugging purposes we need to read the array before executing the
2119  // switch so that we can display all of the possible values.
2120  LLVM_DEBUG({
2121  const ByteCodeField *prevCodeIt = curCodeIt;
2122  llvm::dbgs() << " * Value: " << value << "\n"
2123  << " * Cases: ";
2124  llvm::interleaveComma(
2125  llvm::map_range(llvm::seq<size_t>(0, caseCount),
2126  [&](size_t) { return read<OperationName>(); }),
2127  llvm::dbgs());
2128  llvm::dbgs() << "\n";
2129  curCodeIt = prevCodeIt;
2130  });
2131 
2132  // Try to find the switch value within any of the cases.
2133  for (size_t i = 0; i != caseCount; ++i) {
2134  if (read<OperationName>() == value) {
// Skip the remaining inline case names so curCodeIt lands on the successor
// list. This assumes each inline OperationName occupies exactly one
// ByteCodeField. TODO confirm against read<OperationName>().
2135  curCodeIt += (caseCount - i - 1);
2136  return selectJump(i + 1);
2137  }
2138  }
2139  selectJump(size_t(0));
2140 }
2141 
2142 void ByteCodeExecutor::executeSwitchResultCount() {
2143  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
2144  Operation *op = read<Operation *>();
2145  auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2146 
2147  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
2148  handleSwitch(op->getNumResults(), cases);
2149 }
2150 
2151 void ByteCodeExecutor::executeSwitchType() {
2152  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
2153  Type value = read<Type>();
2154  auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2155  handleSwitch(value, cases);
2156 }
2157 
2158 void ByteCodeExecutor::executeSwitchTypes() {
2159  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
2160  TypeRange *value = read<TypeRange *>();
2161  auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2162  if (!value) {
2163  LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
2164  return selectJump(size_t(0));
2165  }
2166  handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
2167  return value == caseValue.getAsValueRange<TypeAttr>();
2168  });
2169 }
2170 
/// Main interpreter loop: reads one opcode at a time and dispatches to the
/// matching `execute*` handler.
///
/// Returns success when a `Finalize` opcode is reached. Returns failure only
/// when `executeApplyRewrite` reports failure. `mainRewriteLoc` is
/// dereferenced by `CreateOperation` and `matches` by `RecordMatch`.
/// NOTE(review): presumably only rewriter bytecode contains CreateOperation
/// and only matcher bytecode contains RecordMatch; `PDLByteCode::match` does
/// not pass a location. One signature line (original line 2173) is missing
/// from this listing.
2171 LogicalResult
2172 ByteCodeExecutor::execute(PatternRewriter &rewriter,
2174  std::optional<Location> mainRewriteLoc) {
2175  while (true) {
2176  // Print the location of the operation being executed.
2177  LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n");
2178 
2179  OpCode opCode = static_cast<OpCode>(read());
2180  switch (opCode) {
2181  case ApplyConstraint:
2182  executeApplyConstraint(rewriter);
2183  break;
2184  case ApplyRewrite:
2185  if (failed(executeApplyRewrite(rewriter)))
2186  return failure();
2187  break;
2188  case AreEqual:
2189  executeAreEqual();
2190  break;
2191  case AreRangesEqual:
2192  executeAreRangesEqual();
2193  break;
2194  case Branch:
2195  executeBranch();
2196  break;
2197  case CheckOperandCount:
2198  executeCheckOperandCount();
2199  break;
2200  case CheckOperationName:
2201  executeCheckOperationName();
2202  break;
2203  case CheckResultCount:
2204  executeCheckResultCount();
2205  break;
2206  case CheckTypes:
2207  executeCheckTypes();
2208  break;
2209  case Continue:
2210  executeContinue();
2211  break;
2212  case CreateConstantTypeRange:
2213  executeCreateConstantTypeRange();
2214  break;
// Dereferences the optional location; it must be provided for this opcode.
2215  case CreateOperation:
2216  executeCreateOperation(rewriter, *mainRewriteLoc);
2217  break;
2218  case CreateDynamicTypeRange:
2219  executeDynamicCreateRange<Type>("Type");
2220  break;
2221  case CreateDynamicValueRange:
2222  executeDynamicCreateRange<Value>("Value");
2223  break;
2224  case EraseOp:
2225  executeEraseOp(rewriter);
2226  break;
2227  case ExtractOp:
2228  executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
2229  break;
2230  case ExtractType:
2231  executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2232  break;
2233  case ExtractValue:
2234  executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2235  break;
// The only successful exit from the loop.
2236  case Finalize:
2237  executeFinalize();
2238  LLVM_DEBUG(llvm::dbgs() << "\n");
2239  return success();
2240  case ForEach:
2241  executeForEach();
2242  break;
2243  case GetAttribute:
2244  executeGetAttribute();
2245  break;
2246  case GetAttributeType:
2247  executeGetAttributeType();
2248  break;
2249  case GetDefiningOp:
2250  executeGetDefiningOp();
2251  break;
// Fixed-index variants encode the index in the opcode itself; the *N variants
// read it from the bytecode instead.
2252  case GetOperand0:
2253  case GetOperand1:
2254  case GetOperand2:
2255  case GetOperand3: {
2256  unsigned index = opCode - GetOperand0;
2257  LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
2258  executeGetOperand(index);
2259  break;
2260  }
2261  case GetOperandN:
2262  LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
2263  executeGetOperand(read<uint32_t>());
2264  break;
2265  case GetOperands:
2266  executeGetOperands();
2267  break;
2268  case GetResult0:
2269  case GetResult1:
2270  case GetResult2:
2271  case GetResult3: {
2272  unsigned index = opCode - GetResult0;
2273  LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
2274  executeGetResult(index);
2275  break;
2276  }
2277  case GetResultN:
2278  LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
2279  executeGetResult(read<uint32_t>());
2280  break;
2281  case GetResults:
2282  executeGetResults();
2283  break;
2284  case GetUsers:
2285  executeGetUsers();
2286  break;
2287  case GetValueType:
2288  executeGetValueType();
2289  break;
2290  case GetValueRangeTypes:
2291  executeGetValueRangeTypes();
2292  break;
2293  case IsNotNull:
2294  executeIsNotNull();
2295  break;
2296  case RecordMatch:
2297  assert(matches &&
2298  "expected matches to be provided when executing the matcher");
2299  executeRecordMatch(rewriter, *matches);
2300  break;
2301  case ReplaceOp:
2302  executeReplaceOp(rewriter);
2303  break;
2304  case SwitchAttribute:
2305  executeSwitchAttribute();
2306  break;
2307  case SwitchOperandCount:
2308  executeSwitchOperandCount();
2309  break;
2310  case SwitchOperationName:
2311  executeSwitchOperationName();
2312  break;
2313  case SwitchResultCount:
2314  executeSwitchResultCount();
2315  break;
2316  case SwitchType:
2317  executeSwitchType();
2318  break;
2319  case SwitchTypes:
2320  executeSwitchTypes();
2321  break;
2322  }
2323  LLVM_DEBUG(llvm::dbgs() << "\n");
2324  }
2325 }
2326 
/// Runs the matcher bytecode with `op` as the root operation and collects the
/// recorded matches. The matches are then stable-sorted by descending
/// benefit, so equal-benefit matches keep their recording order.
/// NOTE(review): one signature line (original line 2328) is missing from this
/// listing.
2327 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
2329  PDLByteCodeMutableState &state) const {
2330  // The first memory slot is always the root operation.
2331  state.memory[0] = op;
2332 
2333  // The matcher function always starts at code address 0.
2334  ByteCodeExecutor executor(
2335  matcherByteCode.data(), state.memory, state.opRangeMemory,
2336  state.typeRangeMemory, state.allocatedTypeRangeMemory,
2337  state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2338  uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
2339  constraintFunctions, rewriteFunctions);
// The (void) cast avoids an unused-variable warning when asserts are
// compiled out.
2340  LogicalResult executeResult = executor.execute(rewriter, &matches);
2341  (void)executeResult;
2342  assert(succeeded(executeResult) && "unexpected matcher execution failure");
2343 
2344  // Order the found matches by benefit.
2345  std::stable_sort(matches.begin(), matches.end(),
2346  [](const MatchResult &lhs, const MatchResult &rhs) {
2347  return lhs.benefit > rhs.benefit;
2348  });
2349 }
2350 
2351 LogicalResult PDLByteCode::rewrite(PatternRewriter &rewriter,
2352  const MatchResult &match,
2353  PDLByteCodeMutableState &state) const {
2354  auto *configSet = match.pattern->getConfigSet();
2355  if (configSet)
2356  configSet->notifyRewriteBegin(rewriter);
2357 
2358  // The arguments of the rewrite function are stored at the start of the
2359  // memory buffer.
2360  llvm::copy(match.values, state.memory.begin());
2361 
2362  ByteCodeExecutor executor(
2363  &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
2364  state.opRangeMemory, state.typeRangeMemory,
2365  state.allocatedTypeRangeMemory, state.valueRangeMemory,
2366  state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2367  rewriterByteCode, state.currentPatternBenefits, patterns,
2368  constraintFunctions, rewriteFunctions);
2369  LogicalResult result =
2370  executor.execute(rewriter, /*matches=*/nullptr, match.location);
2371 
2372  if (configSet)
2373  configSet->notifyRewriteEnd(rewriter);
2374 
2375  // If the rewrite failed, check if the pattern rewriter can recover. If it
2376  // can, we can signal to the pattern applicator to keep trying patterns. If it
2377  // doesn't, we need to bail. Bailing here should be fine, given that we have
2378  // no means to propagate such a failure to the user, and it also indicates a
2379  // bug in the user code (i.e. failable rewrites should not be used with
2380  // pattern rewriters that don't support it).
2381  if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) {
2382  LLVM_DEBUG(llvm::dbgs() << " and rollback is not supported - aborting");
2383  llvm::report_fatal_error(
2384  "Native PDL Rewrite failed, but the pattern "
2385  "rewriter doesn't support recovery. Failable pattern rewrites should "
2386  "not be used with pattern rewriters that do not support them.");
2387  }
2388  return result;
2389 }
static void * executeGetOperandsResults(RangeT values, Operation *op, unsigned index, ByteCodeField rangeIndex, StringRef attrSizedSegments, MutableArrayRef< ValueRange > valueRangeMemory)
This function is the internal implementation of GetResults and GetOperands that provides support for ...
Definition: ByteCode.cpp:1853
static constexpr ByteCodeField kInferTypesMarker
A marker used to indicate if an operation should infer types.
Definition: ByteCode.cpp:173
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static const mlir::GenInfo * generator
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void processValue(Value value, LiveMap &liveMap)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
const void * getAsOpaquePointer() const
Get an opaque pointer to the attribute.
Definition: Attributes.h:91
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:33
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:29
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition: Block.cpp:38
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
Definition: Builders.cpp:29
This class represents liveness information on block level.
Definition: Liveness.h:99
const ValueSetT & in() const
Returns all values that are live at the beginning of the block (unordered).
Definition: Liveness.h:110
Operation * getEndOperation(Value value, Operation *startOperation) const
Gets the end operation for the given value using the start operation provided (must be referenced in ...
Definition: Liveness.cpp:379
Represents an analysis for computing liveness information from a given top-level operation.
Definition: Liveness.h:47
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This is a value defined by a result of an operation.
Definition: Value.h:457
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:42
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:550
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:534
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
bool isImpossibleToMatch() const
Definition: PatternMatch.h:44
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
virtual bool canRecoverFromRewriteFailure() const
A hook used to indicate if the pattern rewriter can recover from failure during the rewrite stage of ...
Definition: PatternMatch.h:800
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:242
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class implements the successor iterators for Block.
Definition: BlockSupport.h:73
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
const void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Types.h:188
auto walk(WalkFns &&...walkFns)
Walk this type and all attibutes/types nested within using the provided walk functions.
Definition: Types.h:239
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
user_iterator user_begin() const
Definition: Value.h:226
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Value.h:243
user_iterator user_end() const
Definition: Value.h:227
user_range getUsers() const
Definition: Value.h:228
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit)
Set the new benefit for a bytecode pattern.
Definition: ByteCode.h:236
void cleanupAfterMatchAndRewrite()
Cleanup any allocated state after a full match/rewrite has been completed.
Definition: ByteCode.h:235
void match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl< MatchResult > &matches, PDLByteCodeMutableState &state) const
Definition: ByteCode.h:249
LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const
Definition: ByteCode.h:252
AttrTypeReplacer.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:136
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult size