MLIR  22.0.0git
ByteCode.cpp
Go to the documentation of this file.
1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ByteCode.h"
14 #include "mlir/Analysis/Liveness.h"
17 #include "mlir/IR/BuiltinOps.h"
19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/DebugLog.h"
24 #include "llvm/Support/Format.h"
25 #include "llvm/Support/FormatVariadic.h"
26 #include "llvm/Support/InterleavedRange.h"
27 #include <numeric>
28 #include <optional>
29 
30 #define DEBUG_TYPE "pdl-bytecode"
31 
32 using namespace mlir;
33 using namespace mlir::detail;
34 
35 //===----------------------------------------------------------------------===//
36 // PDLByteCodePattern
37 //===----------------------------------------------------------------------===//
38 
39 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
40  PDLPatternConfigSet *configSet,
41  ByteCodeAddr rewriterAddr) {
42  PatternBenefit benefit = matchOp.getBenefit();
43  MLIRContext *ctx = matchOp.getContext();
44 
45  // Collect the set of generated operations.
46  SmallVector<StringRef, 8> generatedOps;
47  if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
48  generatedOps =
49  llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
50 
51  // Check to see if this is pattern matches a specific operation type.
52  if (std::optional<StringRef> rootKind = matchOp.getRootKind())
53  return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx,
54  generatedOps);
55  return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(),
56  benefit, ctx, generatedOps);
57 }
58 
59 //===----------------------------------------------------------------------===//
60 // PDLByteCodeMutableState
61 //===----------------------------------------------------------------------===//
62 
63 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
64 /// to the position of the pattern within the range returned by
65 /// `PDLByteCode::getPatterns`.
66 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
67  PatternBenefit benefit) {
68  currentPatternBenefits[patternIndex] = benefit;
69 }
70 
71 /// Cleanup any allocated state after a full match/rewrite has been completed.
72 /// This method should be called regardless of whether the match+rewrite was a
73 /// success or not.
75  allocatedTypeRangeMemory.clear();
76  allocatedValueRangeMemory.clear();
77 }
78 
79 //===----------------------------------------------------------------------===//
80 // Bytecode OpCodes
81 //===----------------------------------------------------------------------===//
82 
83 namespace {
84 enum OpCode : ByteCodeField {
85  /// Apply an externally registered constraint.
86  ApplyConstraint,
87  /// Apply an externally registered rewrite.
88  ApplyRewrite,
89  /// Check if two generic values are equal.
90  AreEqual,
91  /// Check if two ranges are equal.
92  AreRangesEqual,
93  /// Unconditional branch.
94  Branch,
95  /// Compare the operand count of an operation with a constant.
96  CheckOperandCount,
97  /// Compare the name of an operation with a constant.
98  CheckOperationName,
99  /// Compare the result count of an operation with a constant.
100  CheckResultCount,
101  /// Compare a range of types to a constant range of types.
102  CheckTypes,
103  /// Continue to the next iteration of a loop.
104  Continue,
105  /// Create a type range from a list of constant types.
106  CreateConstantTypeRange,
107  /// Create an operation.
108  CreateOperation,
109  /// Create a type range from a list of dynamic types.
110  CreateDynamicTypeRange,
111  /// Create a value range.
112  CreateDynamicValueRange,
113  /// Erase an operation.
114  EraseOp,
115  /// Extract the op from a range at the specified index.
116  ExtractOp,
117  /// Extract the type from a range at the specified index.
118  ExtractType,
119  /// Extract the value from a range at the specified index.
120  ExtractValue,
121  /// Terminate a matcher or rewrite sequence.
122  Finalize,
123  /// Iterate over a range of values.
124  ForEach,
125  /// Get a specific attribute of an operation.
126  GetAttribute,
127  /// Get the type of an attribute.
128  GetAttributeType,
129  /// Get the defining operation of a value.
130  GetDefiningOp,
131  /// Get a specific operand of an operation.
132  GetOperand0,
133  GetOperand1,
134  GetOperand2,
135  GetOperand3,
136  GetOperandN,
137  /// Get a specific operand group of an operation.
138  GetOperands,
139  /// Get a specific result of an operation.
140  GetResult0,
141  GetResult1,
142  GetResult2,
143  GetResult3,
144  GetResultN,
145  /// Get a specific result group of an operation.
146  GetResults,
147  /// Get the users of a value or a range of values.
148  GetUsers,
149  /// Get the type of a value.
150  GetValueType,
151  /// Get the types of a value range.
152  GetValueRangeTypes,
153  /// Check if a generic value is not null.
154  IsNotNull,
155  /// Record a successful pattern match.
156  RecordMatch,
157  /// Replace an operation.
158  ReplaceOp,
159  /// Compare an attribute with a set of constants.
160  SwitchAttribute,
161  /// Compare the operand count of an operation with a set of constants.
162  SwitchOperandCount,
163  /// Compare the name of an operation with a set of constants.
164  SwitchOperationName,
165  /// Compare the result count of an operation with a set of constants.
166  SwitchResultCount,
167  /// Compare a type with a set of constants.
168  SwitchType,
169  /// Compare a range of types with a set of constants.
170  SwitchTypes,
171 };
172 } // namespace
173 
174 /// A marker used to indicate if an operation should infer types.
175 static constexpr ByteCodeField kInferTypesMarker =
177 
178 //===----------------------------------------------------------------------===//
179 // ByteCode Generation
180 //===----------------------------------------------------------------------===//
181 
182 //===----------------------------------------------------------------------===//
183 // Generator
184 //===----------------------------------------------------------------------===//
185 
186 namespace {
187 struct ByteCodeLiveRange;
188 struct ByteCodeWriter;
189 
/// Detection helper: names the type returned by `T::getAsOpaquePointer()`.
/// The alias is only well-formed when a value of type `T` provides such a
/// member, which makes it usable as a detection expression to tell whether `T`
/// can be converted to an opaque pointer. `Args` is accepted but not used.
template <typename T, typename... Args>
using has_pointer_traits =
    decltype(std::declval<T>().getAsOpaquePointer());
193 
194 /// This class represents the main generator for the pattern bytecode.
195 class Generator {
196 public:
197  Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
198  SmallVectorImpl<ByteCodeField> &matcherByteCode,
199  SmallVectorImpl<ByteCodeField> &rewriterByteCode,
201  ByteCodeField &maxValueMemoryIndex,
202  ByteCodeField &maxOpRangeMemoryIndex,
203  ByteCodeField &maxTypeRangeMemoryIndex,
204  ByteCodeField &maxValueRangeMemoryIndex,
205  ByteCodeField &maxLoopLevel,
206  llvm::StringMap<PDLConstraintFunction> &constraintFns,
207  llvm::StringMap<PDLRewriteFunction> &rewriteFns,
209  : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
210  rewriterByteCode(rewriterByteCode), patterns(patterns),
211  maxValueMemoryIndex(maxValueMemoryIndex),
212  maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
213  maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
214  maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
215  maxLoopLevel(maxLoopLevel), configMap(configMap) {
216  for (const auto &it : llvm::enumerate(constraintFns))
217  constraintToMemIndex.try_emplace(it.value().first(), it.index());
218  for (const auto &it : llvm::enumerate(rewriteFns))
219  externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
220  }
221 
222  /// Generate the bytecode for the given PDL interpreter module.
223  void generate(ModuleOp module);
224 
225  /// Return the memory index to use for the given value.
226  ByteCodeField &getMemIndex(Value value) {
227  assert(valueToMemIndex.count(value) &&
228  "expected memory index to be assigned");
229  return valueToMemIndex[value];
230  }
231 
232  /// Return the range memory index used to store the given range value.
233  ByteCodeField &getRangeStorageIndex(Value value) {
234  assert(valueToRangeIndex.count(value) &&
235  "expected range index to be assigned");
236  return valueToRangeIndex[value];
237  }
238 
239  /// Return an index to use when referring to the given data that is uniqued in
240  /// the MLIR context.
241  template <typename T>
242  std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
243  getMemIndex(T val) {
244  const void *opaqueVal = val.getAsOpaquePointer();
245 
246  // Get or insert a reference to this value.
247  auto it = uniquedDataToMemIndex.try_emplace(
248  opaqueVal, maxValueMemoryIndex + uniquedData.size());
249  if (it.second)
250  uniquedData.push_back(opaqueVal);
251  return it.first->second;
252  }
253 
254 private:
255  /// Allocate memory indices for the results of operations within the matcher
256  /// and rewriters.
257  void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
258  ModuleOp rewriterModule);
259 
260  /// Generate the bytecode for the given operation.
261  void generate(Region *region, ByteCodeWriter &writer);
262  void generate(Operation *op, ByteCodeWriter &writer);
263  void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
264  void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
265  void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
266  void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
267  void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
268  void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
269  void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
270  void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
271  void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
272  void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
273  void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
274  void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
275  void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
276  void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer);
277  void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
278  void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
279  void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
280  void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
281  void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
282  void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
283  void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
284  void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
285  void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
286  void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
287  void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
288  void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
289  void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
290  void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
291  void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
292  void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
293  void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
294  void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
295  void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
296  void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
297  void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
298  void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
299  void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
300  void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
301 
302  /// Mapping from value to its corresponding memory index.
303  DenseMap<Value, ByteCodeField> valueToMemIndex;
304 
305  /// Mapping from a range value to its corresponding range storage index.
306  DenseMap<Value, ByteCodeField> valueToRangeIndex;
307 
308  /// Mapping from the name of an externally registered rewrite to its index in
309  /// the bytecode registry.
310  llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
311 
312  /// Mapping from the name of an externally registered constraint to its index
313  /// in the bytecode registry.
314  llvm::StringMap<ByteCodeField> constraintToMemIndex;
315 
316  /// Mapping from rewriter function name to the address at which that
317  /// function starts within the rewriter bytecode.
318  llvm::StringMap<ByteCodeAddr> rewriterToAddr;
319 
320  /// Mapping from a uniqued storage object to its memory index within
321  /// `uniquedData`.
322  DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
323 
324  /// The current level of the foreach loop.
325  ByteCodeField curLoopLevel = 0;
326 
327  /// The current MLIR context.
328  MLIRContext *ctx;
329 
330  /// Mapping from block to its address.
332 
333  /// Data of the ByteCode class to be populated.
334  std::vector<const void *> &uniquedData;
335  SmallVectorImpl<ByteCodeField> &matcherByteCode;
336  SmallVectorImpl<ByteCodeField> &rewriterByteCode;
338  ByteCodeField &maxValueMemoryIndex;
339  ByteCodeField &maxOpRangeMemoryIndex;
340  ByteCodeField &maxTypeRangeMemoryIndex;
341  ByteCodeField &maxValueRangeMemoryIndex;
342  ByteCodeField &maxLoopLevel;
343 
344  /// A map of pattern configurations.
346 };
347 
348 /// This class provides utilities for writing a bytecode stream.
349 struct ByteCodeWriter {
350  ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
351  : bytecode(bytecode), generator(generator) {}
352 
353  /// Append a field to the bytecode.
354  void append(ByteCodeField field) { bytecode.push_back(field); }
355  void append(OpCode opCode) { bytecode.push_back(opCode); }
356 
357  /// Append an address to the bytecode.
358  void append(ByteCodeAddr field) {
359  static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
360  "unexpected ByteCode address size");
361 
362  ByteCodeField fieldParts[2];
363  std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
364  bytecode.append({fieldParts[0], fieldParts[1]});
365  }
366 
367  /// Append a single successor to the bytecode, the exact address will need to
368  /// be resolved later.
369  void append(Block *successor) {
370  // Add back a reference to the successor so that the address can be resolved
371  // later.
372  unresolvedSuccessorRefs[successor].push_back(bytecode.size());
373  append(ByteCodeAddr(0));
374  }
375 
376  /// Append a successor range to the bytecode, the exact address will need to
377  /// be resolved later.
378  void append(SuccessorRange successors) {
379  for (Block *successor : successors)
380  append(successor);
381  }
382 
383  /// Append a range of values that will be read as generic PDLValues.
384  void appendPDLValueList(OperandRange values) {
385  bytecode.push_back(values.size());
386  for (Value value : values)
387  appendPDLValue(value);
388  }
389 
390  /// Append a value as a PDLValue.
391  void appendPDLValue(Value value) {
392  appendPDLValueKind(value);
393  append(value);
394  }
395 
396  /// Append the PDLValue::Kind of the given value.
397  void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); }
398 
399  /// Append the PDLValue::Kind of the given type.
400  void appendPDLValueKind(Type type) {
403  .Case<pdl::AttributeType>(
404  [](Type) { return PDLValue::Kind::Attribute; })
405  .Case<pdl::OperationType>(
406  [](Type) { return PDLValue::Kind::Operation; })
407  .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
408  if (isa<pdl::TypeType>(rangeTy.getElementType()))
409  return PDLValue::Kind::TypeRange;
410  return PDLValue::Kind::ValueRange;
411  })
412  .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
413  .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
414  bytecode.push_back(static_cast<ByteCodeField>(kind));
415  }
416 
417  /// Append a value that will be stored in a memory slot and not inline within
418  /// the bytecode.
419  template <typename T>
420  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
421  std::is_pointer<T>::value>
422  append(T value) {
423  bytecode.push_back(generator.getMemIndex(value));
424  }
425 
426  /// Append a range of values.
427  template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
428  std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
429  append(T range) {
430  bytecode.push_back(llvm::size(range));
431  for (auto it : range)
432  append(it);
433  }
434 
/// Append two or more fields to the bytecode, left to right. Each argument is
/// written through whichever single-argument `append` overload matches it.
template <typename FieldTy, typename Field2Ty, typename... FieldTys>
void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
  append(field);
  append(field2);
  // The comma fold appends the remaining fields in order; it is a no-op when
  // the pack is empty.
  (append(fields), ...);
}
441 
442  /// Appends a value as a pointer, stored inline within the bytecode.
443  template <typename T>
444  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
445  appendInline(T value) {
446  constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField);
447  const void *pointer = value.getAsOpaquePointer();
448  ByteCodeField fieldParts[numParts];
449  std::memcpy(fieldParts, &pointer, sizeof(const void *));
450  bytecode.append(fieldParts, fieldParts + numParts);
451  }
452 
453  /// Successor references in the bytecode that have yet to be resolved.
454  DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
455 
456  /// The underlying bytecode buffer.
458 
459  /// The main generator for the pattern bytecode.
460  Generator &generator;
461 };
462 
463 /// This class represents a live range of PDL Interpreter values, containing
464 /// information about when values are live within a match/rewrite.
465 struct ByteCodeLiveRange {
466  using Set = llvm::IntervalMap<uint64_t, char, 16>;
467  using Allocator = Set::Allocator;
468 
469  ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
470 
471  /// Union this live range with the one provided.
472  void unionWith(const ByteCodeLiveRange &rhs) {
473  for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
474  ++it)
475  liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0);
476  }
477 
478  /// Returns true if this range overlaps with the one provided.
479  bool overlaps(const ByteCodeLiveRange &rhs) const {
480  return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
481  .valid();
482  }
483 
484  /// A map representing the ranges of the match/rewrite that a value is live in
485  /// the interpreter.
486  ///
487  /// We use std::unique_ptr here, because IntervalMap does not provide a
488  /// correct copy or move constructor. We can eliminate the pointer once
489  /// https://reviews.llvm.org/D113240 lands.
490  std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
491 
492  /// The operation range storage index for this range.
493  std::optional<unsigned> opRangeIndex;
494 
495  /// The type range storage index for this range.
496  std::optional<unsigned> typeRangeIndex;
497 
498  /// The value range storage index for this range.
499  std::optional<unsigned> valueRangeIndex;
500 };
501 } // namespace
502 
503 void Generator::generate(ModuleOp module) {
504  auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
505  pdl_interp::PDLInterpDialect::getMatcherFunctionName());
506  ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
507  pdl_interp::PDLInterpDialect::getRewriterModuleName());
508  assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
509 
510  // Allocate memory indices for the results of operations within the matcher
511  // and rewriters.
512  allocateMemoryIndices(matcherFunc, rewriterModule);
513 
514  // Generate code for the rewriter functions.
515  ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
516  for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
517  rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
518  for (Operation &op : rewriterFunc.getOps())
519  generate(&op, rewriterByteCodeWriter);
520  }
521  assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
522  "unexpected branches in rewriter function");
523 
524  // Generate code for the matcher function.
525  ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
526  generate(&matcherFunc.getBody(), matcherByteCodeWriter);
527 
528  // Resolve successor references in the matcher.
529  for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
530  ByteCodeAddr addr = blockToAddr[it.first];
531  for (unsigned offsetToFix : it.second)
532  std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
533  }
534 }
535 
536 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
537  ModuleOp rewriterModule) {
538  // Rewriters use simplistic allocation scheme that simply assigns an index to
539  // each result.
540  for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
541  ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
542  auto processRewriterValue = [&](Value val) {
543  valueToMemIndex.try_emplace(val, index++);
544  if (pdl::RangeType rangeType = dyn_cast<pdl::RangeType>(val.getType())) {
545  Type elementTy = rangeType.getElementType();
546  if (isa<pdl::TypeType>(elementTy))
547  valueToRangeIndex.try_emplace(val, typeRangeIndex++);
548  else if (isa<pdl::ValueType>(elementTy))
549  valueToRangeIndex.try_emplace(val, valueRangeIndex++);
550  }
551  };
552 
553  for (BlockArgument arg : rewriterFunc.getArguments())
554  processRewriterValue(arg);
555  rewriterFunc.getBody().walk([&](Operation *op) {
556  for (Value result : op->getResults())
557  processRewriterValue(result);
558  });
559  if (index > maxValueMemoryIndex)
560  maxValueMemoryIndex = index;
561  if (typeRangeIndex > maxTypeRangeMemoryIndex)
562  maxTypeRangeMemoryIndex = typeRangeIndex;
563  if (valueRangeIndex > maxValueRangeMemoryIndex)
564  maxValueRangeMemoryIndex = valueRangeIndex;
565  }
566 
567  // The matcher function uses a more sophisticated numbering that tries to
568  // minimize the number of memory indices assigned. This is done by determining
569  // a live range of the values within the matcher, then the allocation is just
570  // finding the minimal number of overlapping live ranges. This is essentially
571  // a simplified form of register allocation where we don't necessarily have a
572  // limited number of registers, but we still want to minimize the number used.
573  DenseMap<Operation *, unsigned> opToFirstIndex;
574  DenseMap<Operation *, unsigned> opToLastIndex;
575 
576  // A custom walk that marks the first and the last index of each operation.
577  // The entry marks the beginning of the liveness range for this operation,
578  // followed by nested operations, followed by the end of the liveness range.
579  unsigned index = 0;
580  llvm::unique_function<void(Operation *)> walk = [&](Operation *op) {
581  opToFirstIndex.try_emplace(op, index++);
582  for (Region &region : op->getRegions())
583  for (Block &block : region.getBlocks())
584  for (Operation &nested : block)
585  walk(&nested);
586  opToLastIndex.try_emplace(op, index++);
587  };
588  walk(matcherFunc);
589 
590  // Liveness info for each of the defs within the matcher.
591  ByteCodeLiveRange::Allocator allocator;
592  DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
593 
594  // Assign the root operation being matched to slot 0.
595  BlockArgument rootOpArg = matcherFunc.getArgument(0);
596  valueToMemIndex[rootOpArg] = 0;
597 
598  // Walk each of the blocks, computing the def interval that the value is used.
599  Liveness matcherLiveness(matcherFunc);
600  matcherFunc->walk([&](Block *block) {
601  const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
602  assert(info && "expected liveness info for block");
603  auto processValue = [&](Value value, Operation *firstUseOrDef) {
604  // We don't need to process the root op argument, this value is always
605  // assigned to the first memory slot.
606  if (value == rootOpArg)
607  return;
608 
609  // Set indices for the range of this block that the value is used.
610  auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
611  defRangeIt->second.liveness->insert(
612  opToFirstIndex[firstUseOrDef],
613  opToLastIndex[info->getEndOperation(value, firstUseOrDef)],
614  /*dummyValue*/ 0);
615 
616  // Check to see if this value is a range type.
617  if (auto rangeTy = dyn_cast<pdl::RangeType>(value.getType())) {
618  Type eleType = rangeTy.getElementType();
619  if (isa<pdl::OperationType>(eleType))
620  defRangeIt->second.opRangeIndex = 0;
621  else if (isa<pdl::TypeType>(eleType))
622  defRangeIt->second.typeRangeIndex = 0;
623  else if (isa<pdl::ValueType>(eleType))
624  defRangeIt->second.valueRangeIndex = 0;
625  }
626  };
627 
628  // Process the live-ins of this block.
629  for (Value liveIn : info->in()) {
630  // Only process the value if it has been defined in the current region.
631  // Other values that span across pdl_interp.foreach will be added higher
632  // up. This ensures that the we keep them alive for the entire duration
633  // of the loop.
634  if (liveIn.getParentRegion() == block->getParent())
635  processValue(liveIn, &block->front());
636  }
637 
638  // Process the block arguments for the entry block (those are not live-in).
639  if (block->isEntryBlock()) {
640  for (Value argument : block->getArguments())
641  processValue(argument, &block->front());
642  }
643 
644  // Process any new defs within this block.
645  for (Operation &op : *block)
646  for (Value result : op.getResults())
647  processValue(result, &op);
648  });
649 
650  // Greedily allocate memory slots using the computed def live ranges.
651  std::vector<ByteCodeLiveRange> allocatedIndices;
652 
653  // The number of memory indices currently allocated (and its next value).
654  // Recall that the root gets allocated memory index 0.
655  ByteCodeField numIndices = 1;
656 
657  // The number of memory ranges of various types (and their next values).
658  ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
659 
660  for (auto &defIt : valueDefRanges) {
661  ByteCodeField &memIndex = valueToMemIndex[defIt.first];
662  ByteCodeLiveRange &defRange = defIt.second;
663 
664  // Try to allocate to an existing index.
665  for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) {
666  ByteCodeLiveRange &existingRange = existingIndexIt.value();
667  if (!defRange.overlaps(existingRange)) {
668  existingRange.unionWith(defRange);
669  memIndex = existingIndexIt.index() + 1;
670 
671  if (defRange.opRangeIndex) {
672  if (!existingRange.opRangeIndex)
673  existingRange.opRangeIndex = numOpRanges++;
674  valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
675  } else if (defRange.typeRangeIndex) {
676  if (!existingRange.typeRangeIndex)
677  existingRange.typeRangeIndex = numTypeRanges++;
678  valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
679  } else if (defRange.valueRangeIndex) {
680  if (!existingRange.valueRangeIndex)
681  existingRange.valueRangeIndex = numValueRanges++;
682  valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
683  }
684  break;
685  }
686  }
687 
688  // If no existing index could be used, add a new one.
689  if (memIndex == 0) {
690  allocatedIndices.emplace_back(allocator);
691  ByteCodeLiveRange &newRange = allocatedIndices.back();
692  newRange.unionWith(defRange);
693 
694  // Allocate an index for op/type/value ranges.
695  if (defRange.opRangeIndex) {
696  newRange.opRangeIndex = numOpRanges;
697  valueToRangeIndex[defIt.first] = numOpRanges++;
698  } else if (defRange.typeRangeIndex) {
699  newRange.typeRangeIndex = numTypeRanges;
700  valueToRangeIndex[defIt.first] = numTypeRanges++;
701  } else if (defRange.valueRangeIndex) {
702  newRange.valueRangeIndex = numValueRanges;
703  valueToRangeIndex[defIt.first] = numValueRanges++;
704  }
705 
706  memIndex = allocatedIndices.size();
707  ++numIndices;
708  }
709  }
710 
711  // Print the index usage and ensure that we did not run out of index space.
712  LDBG() << "Allocated " << allocatedIndices.size() << " indices "
713  << "(down from initial " << valueDefRanges.size() << ").";
714  assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
715  "Ran out of memory for allocated indices");
716 
717  // Update the max number of indices.
718  if (numIndices > maxValueMemoryIndex)
719  maxValueMemoryIndex = numIndices;
720  if (numOpRanges > maxOpRangeMemoryIndex)
721  maxOpRangeMemoryIndex = numOpRanges;
722  if (numTypeRanges > maxTypeRangeMemoryIndex)
723  maxTypeRangeMemoryIndex = numTypeRanges;
724  if (numValueRanges > maxValueRangeMemoryIndex)
725  maxValueRangeMemoryIndex = numValueRanges;
726 }
727 
728 void Generator::generate(Region *region, ByteCodeWriter &writer) {
729  llvm::ReversePostOrderTraversal<Region *> rpot(region);
730  for (Block *block : rpot) {
731  // Keep track of where this block begins within the matcher function.
732  blockToAddr.try_emplace(block, matcherByteCode.size());
733  for (Operation &op : *block)
734  generate(&op, writer);
735  }
736 }
737 
738 void Generator::generate(Operation *op, ByteCodeWriter &writer) {
739  LDBG() << "Generating bytecode for operation: " << op->getName();
740  LLVM_DEBUG({
741  // The following list must contain all the operations that do not
742  // produce any bytecode.
743  if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
744  writer.appendInline(op->getLoc());
745  });
747  .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
748  pdl_interp::AreEqualOp, pdl_interp::BranchOp,
749  pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
750  pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
751  pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
752  pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
753  pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp,
754  pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
755  pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
756  pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
757  pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
758  pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
759  pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
760  pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
761  pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
762  pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
763  pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
764  pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
765  pdl_interp::SwitchResultCountOp>(
766  [&](auto interpOp) { this->generate(interpOp, writer); })
767  .Default([](Operation *) {
768  llvm_unreachable("unknown `pdl_interp` operation");
769  });
770 }
771 
772 void Generator::generate(pdl_interp::ApplyConstraintOp op,
773  ByteCodeWriter &writer) {
774  // Constraints that should return a value have to be registered as rewrites.
775  // If a constraint and a rewrite of similar name are registered the
776  // constraint takes precedence
777  writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
778  writer.appendPDLValueList(op.getArgs());
779  writer.append(ByteCodeField(op.getIsNegated()));
780  ResultRange results = op.getResults();
781  writer.append(ByteCodeField(results.size()));
782  for (Value result : results) {
783  // We record the expected kind of the result, so that we can provide extra
784  // verification of the native rewrite function and handle the failure case
785  // of constraints accordingly.
786  writer.appendPDLValueKind(result);
787 
788  // Range results also need to append the range storage index.
789  if (isa<pdl::RangeType>(result.getType()))
790  writer.append(getRangeStorageIndex(result));
791  writer.append(result);
792  }
793  writer.append(op.getSuccessors());
794 }
795 void Generator::generate(pdl_interp::ApplyRewriteOp op,
796  ByteCodeWriter &writer) {
797  assert(externalRewriterToMemIndex.count(op.getName()) &&
798  "expected index for rewrite function");
799  writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
800  writer.appendPDLValueList(op.getArgs());
801 
802  ResultRange results = op.getResults();
803  writer.append(ByteCodeField(results.size()));
804  for (Value result : results) {
805  // We record the expected kind of the result, so that we
806  // can provide extra verification of the native rewrite function.
807  writer.appendPDLValueKind(result);
808 
809  // Range results also need to append the range storage index.
810  if (isa<pdl::RangeType>(result.getType()))
811  writer.append(getRangeStorageIndex(result));
812  writer.append(result);
813  }
814 }
815 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
816  Value lhs = op.getLhs();
817  if (isa<pdl::RangeType>(lhs.getType())) {
818  writer.append(OpCode::AreRangesEqual);
819  writer.appendPDLValueKind(lhs);
820  writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
821  return;
822  }
823 
824  writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
825 }
826 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
827  writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
828 }
829 void Generator::generate(pdl_interp::CheckAttributeOp op,
830  ByteCodeWriter &writer) {
831  writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
832  op.getSuccessors());
833 }
834 void Generator::generate(pdl_interp::CheckOperandCountOp op,
835  ByteCodeWriter &writer) {
836  writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
837  static_cast<ByteCodeField>(op.getCompareAtLeast()),
838  op.getSuccessors());
839 }
840 void Generator::generate(pdl_interp::CheckOperationNameOp op,
841  ByteCodeWriter &writer) {
842  writer.append(OpCode::CheckOperationName, op.getInputOp(),
843  OperationName(op.getName(), ctx), op.getSuccessors());
844 }
845 void Generator::generate(pdl_interp::CheckResultCountOp op,
846  ByteCodeWriter &writer) {
847  writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
848  static_cast<ByteCodeField>(op.getCompareAtLeast()),
849  op.getSuccessors());
850 }
851 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
852  writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
853  op.getSuccessors());
854 }
855 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
856  writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
857  op.getSuccessors());
858 }
859 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
860  assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
861  writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
862 }
863 void Generator::generate(pdl_interp::CreateAttributeOp op,
864  ByteCodeWriter &writer) {
865  // Simply repoint the memory index of the result to the constant.
866  getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
867 }
868 void Generator::generate(pdl_interp::CreateOperationOp op,
869  ByteCodeWriter &writer) {
870  writer.append(OpCode::CreateOperation, op.getResultOp(),
871  OperationName(op.getName(), ctx));
872  writer.appendPDLValueList(op.getInputOperands());
873 
874  // Add the attributes.
875  OperandRange attributes = op.getInputAttributes();
876  writer.append(static_cast<ByteCodeField>(attributes.size()));
877  for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
878  writer.append(std::get<0>(it), std::get<1>(it));
879 
880  // Add the result types. If the operation has inferred results, we use a
881  // marker "size" value. Otherwise, we add the list of explicit result types.
882  if (op.getInferredResultTypes())
883  writer.append(kInferTypesMarker);
884  else
885  writer.appendPDLValueList(op.getInputResultTypes());
886 }
887 void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
888  // Append the correct opcode for the range type.
889  TypeSwitch<Type>(op.getType().getElementType())
890  .Case(
891  [&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); })
892  .Case([&](pdl::ValueType) {
893  writer.append(OpCode::CreateDynamicValueRange);
894  });
895 
896  writer.append(op.getResult(), getRangeStorageIndex(op.getResult()));
897  writer.appendPDLValueList(op->getOperands());
898 }
899 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
900  // Simply repoint the memory index of the result to the constant.
901  getMemIndex(op.getResult()) = getMemIndex(op.getValue());
902 }
903 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
904  writer.append(OpCode::CreateConstantTypeRange, op.getResult(),
905  getRangeStorageIndex(op.getResult()), op.getValue());
906 }
907 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
908  writer.append(OpCode::EraseOp, op.getInputOp());
909 }
910 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
911  OpCode opCode =
912  TypeSwitch<Type, OpCode>(op.getResult().getType())
913  .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
914  .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
915  .Case([](pdl::TypeType) { return OpCode::ExtractType; })
916  .Default([](Type) -> OpCode {
917  llvm_unreachable("unsupported element type");
918  });
919  writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
920 }
921 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
922  writer.append(OpCode::Finalize);
923 }
924 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
925  BlockArgument arg = op.getLoopVariable();
926  writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
927  writer.appendPDLValueKind(arg.getType());
928  writer.append(curLoopLevel, op.getSuccessor());
929  ++curLoopLevel;
930  if (curLoopLevel > maxLoopLevel)
931  maxLoopLevel = curLoopLevel;
932  generate(&op.getRegion(), writer);
933  --curLoopLevel;
934 }
935 void Generator::generate(pdl_interp::GetAttributeOp op,
936  ByteCodeWriter &writer) {
937  writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
938  op.getNameAttr());
939 }
940 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
941  ByteCodeWriter &writer) {
942  writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
943 }
944 void Generator::generate(pdl_interp::GetDefiningOpOp op,
945  ByteCodeWriter &writer) {
946  writer.append(OpCode::GetDefiningOp, op.getInputOp());
947  writer.appendPDLValue(op.getValue());
948 }
949 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
950  uint32_t index = op.getIndex();
951  if (index < 4)
952  writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
953  else
954  writer.append(OpCode::GetOperandN, index);
955  writer.append(op.getInputOp(), op.getValue());
956 }
957 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
958  Value result = op.getValue();
959  std::optional<uint32_t> index = op.getIndex();
960  writer.append(OpCode::GetOperands,
961  index.value_or(std::numeric_limits<uint32_t>::max()),
962  op.getInputOp());
963  if (isa<pdl::RangeType>(result.getType()))
964  writer.append(getRangeStorageIndex(result));
965  else
967  writer.append(result);
968 }
969 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
970  uint32_t index = op.getIndex();
971  if (index < 4)
972  writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
973  else
974  writer.append(OpCode::GetResultN, index);
975  writer.append(op.getInputOp(), op.getValue());
976 }
977 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
978  Value result = op.getValue();
979  std::optional<uint32_t> index = op.getIndex();
980  writer.append(OpCode::GetResults,
981  index.value_or(std::numeric_limits<uint32_t>::max()),
982  op.getInputOp());
983  if (isa<pdl::RangeType>(result.getType()))
984  writer.append(getRangeStorageIndex(result));
985  else
987  writer.append(result);
988 }
989 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
990  Value operations = op.getOperations();
991  ByteCodeField rangeIndex = getRangeStorageIndex(operations);
992  writer.append(OpCode::GetUsers, operations, rangeIndex);
993  writer.appendPDLValue(op.getValue());
994 }
995 void Generator::generate(pdl_interp::GetValueTypeOp op,
996  ByteCodeWriter &writer) {
997  if (isa<pdl::RangeType>(op.getType())) {
998  Value result = op.getResult();
999  writer.append(OpCode::GetValueRangeTypes, result,
1000  getRangeStorageIndex(result), op.getValue());
1001  } else {
1002  writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
1003  }
1004 }
1005 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
1006  writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
1007 }
1008 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
1009  ByteCodeField patternIndex = patterns.size();
1010  patterns.emplace_back(PDLByteCodePattern::create(
1011  op, configMap.lookup(op),
1012  rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
1013  writer.append(OpCode::RecordMatch, patternIndex,
1014  SuccessorRange(op.getOperation()), op.getMatchedOps());
1015  writer.appendPDLValueList(op.getInputs());
1016 }
1017 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
1018  writer.append(OpCode::ReplaceOp, op.getInputOp());
1019  writer.appendPDLValueList(op.getReplValues());
1020 }
1021 void Generator::generate(pdl_interp::SwitchAttributeOp op,
1022  ByteCodeWriter &writer) {
1023  writer.append(OpCode::SwitchAttribute, op.getAttribute(),
1024  op.getCaseValuesAttr(), op.getSuccessors());
1025 }
1026 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
1027  ByteCodeWriter &writer) {
1028  writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
1029  op.getCaseValuesAttr(), op.getSuccessors());
1030 }
1031 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
1032  ByteCodeWriter &writer) {
1033  auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
1034  return OperationName(cast<StringAttr>(attr).getValue(), ctx);
1035  });
1036  writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
1037  op.getSuccessors());
1038 }
1039 void Generator::generate(pdl_interp::SwitchResultCountOp op,
1040  ByteCodeWriter &writer) {
1041  writer.append(OpCode::SwitchResultCount, op.getInputOp(),
1042  op.getCaseValuesAttr(), op.getSuccessors());
1043 }
1044 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1045  writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1046  op.getSuccessors());
1047 }
1048 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1049  writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
1050  op.getSuccessors());
1051 }
1052 
1053 //===----------------------------------------------------------------------===//
1054 // PDLByteCode
1055 //===----------------------------------------------------------------------===//
1056 
1057 PDLByteCode::PDLByteCode(
1058  ModuleOp module, SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
1060  llvm::StringMap<PDLConstraintFunction> constraintFns,
1061  llvm::StringMap<PDLRewriteFunction> rewriteFns)
1062  : configs(std::move(configs)) {
1063  Generator generator(module.getContext(), uniquedData, matcherByteCode,
1064  rewriterByteCode, patterns, maxValueMemoryIndex,
1065  maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1066  maxLoopLevel, constraintFns, rewriteFns, configMap);
1067  generator.generate(module);
1068 
1069  // Initialize the external functions.
1070  for (auto &it : constraintFns)
1071  constraintFunctions.push_back(std::move(it.second));
1072  for (auto &it : rewriteFns)
1073  rewriteFunctions.push_back(std::move(it.second));
1074 }
1075 
1076 /// Initialize the given state such that it can be used to execute the current
1077 /// bytecode.
1078 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
1079  state.memory.resize(maxValueMemoryIndex, nullptr);
1080  state.opRangeMemory.resize(maxOpRangeCount);
1081  state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
1082  state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
1083  state.loopIndex.resize(maxLoopLevel, 0);
1084  state.currentPatternBenefits.reserve(patterns.size());
1085  for (const PDLByteCodePattern &pattern : patterns)
1086  state.currentPatternBenefits.push_back(pattern.getBenefit());
1087 }
1088 
1089 //===----------------------------------------------------------------------===//
1090 // ByteCode Execution
1091 //===----------------------------------------------------------------------===//
1092 
1093 namespace {
1094 /// This class is an instantiation of the PDLResultList that provides access to
1095 /// the returned results. This API is not on `PDLResultList` to avoid
1096 /// overexposing access to information specific solely to the ByteCode.
1097 class ByteCodeRewriteResultList : public PDLResultList {
1098 public:
1099  ByteCodeRewriteResultList(unsigned maxNumResults)
1100  : PDLResultList(maxNumResults) {}
1101 
1102  /// Return the list of PDL results.
1103  MutableArrayRef<PDLValue> getResults() { return results; }
1104 
1105  /// Return the type ranges allocated by this list.
1106  MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
1107  return allocatedTypeRanges;
1108  }
1109 
1110  /// Return the value ranges allocated by this list.
1111  MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
1112  return allocatedValueRanges;
1113  }
1114 };
1115 
1116 /// This class provides support for executing a bytecode stream.
1117 class ByteCodeExecutor {
1118 public:
1119  ByteCodeExecutor(
1120  const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
1121  MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory,
1122  MutableArrayRef<TypeRange> typeRangeMemory,
1123  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
1124  MutableArrayRef<ValueRange> valueRangeMemory,
1125  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
1126  MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory,
1128  ArrayRef<PatternBenefit> currentPatternBenefits,
1130  ArrayRef<PDLConstraintFunction> constraintFunctions,
1131  ArrayRef<PDLRewriteFunction> rewriteFunctions)
1132  : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
1133  typeRangeMemory(typeRangeMemory),
1134  allocatedTypeRangeMemory(allocatedTypeRangeMemory),
1135  valueRangeMemory(valueRangeMemory),
1136  allocatedValueRangeMemory(allocatedValueRangeMemory),
1137  loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
1138  currentPatternBenefits(currentPatternBenefits), patterns(patterns),
1139  constraintFunctions(constraintFunctions),
1140  rewriteFunctions(rewriteFunctions) {}
1141 
1142  /// Start executing the code at the current bytecode index. `matches` is an
1143  /// optional field provided when this function is executed in a matching
1144  /// context.
1145  LogicalResult
1146  execute(PatternRewriter &rewriter,
1147  SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
1148  std::optional<Location> mainRewriteLoc = {});
1149 
1150 private:
1151  /// Internal implementation of executing each of the bytecode commands.
1152  void executeApplyConstraint(PatternRewriter &rewriter);
1153  LogicalResult executeApplyRewrite(PatternRewriter &rewriter);
1154  void executeAreEqual();
1155  void executeAreRangesEqual();
1156  void executeBranch();
1157  void executeCheckOperandCount();
1158  void executeCheckOperationName();
1159  void executeCheckResultCount();
1160  void executeCheckTypes();
1161  void executeContinue();
1162  void executeCreateConstantTypeRange();
1163  void executeCreateOperation(PatternRewriter &rewriter,
1164  Location mainRewriteLoc);
1165  template <typename T>
1166  void executeDynamicCreateRange(StringRef type);
1167  void executeEraseOp(PatternRewriter &rewriter);
1168  template <typename T, typename Range, PDLValue::Kind kind>
1169  void executeExtract();
1170  void executeFinalize();
1171  void executeForEach();
1172  void executeGetAttribute();
1173  void executeGetAttributeType();
1174  void executeGetDefiningOp();
1175  void executeGetOperand(unsigned index);
1176  void executeGetOperands();
1177  void executeGetResult(unsigned index);
1178  void executeGetResults();
1179  void executeGetUsers();
1180  void executeGetValueType();
1181  void executeGetValueRangeTypes();
1182  void executeIsNotNull();
1183  void executeRecordMatch(PatternRewriter &rewriter,
1185  void executeReplaceOp(PatternRewriter &rewriter);
1186  void executeSwitchAttribute();
1187  void executeSwitchOperandCount();
1188  void executeSwitchOperationName();
1189  void executeSwitchResultCount();
1190  void executeSwitchType();
1191  void executeSwitchTypes();
1192  void processNativeFunResults(ByteCodeRewriteResultList &results,
1193  unsigned numResults,
1194  LogicalResult &rewriteResult);
1195 
1196  /// Pushes a code iterator to the stack.
1197  void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }
1198 
1199  /// Pops a code iterator from the stack, returning true on success.
1200  void popCodeIt() {
1201  assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
1202  curCodeIt = resumeCodeIt.pop_back_val();
1203  }
1204 
1205  /// Return the bytecode iterator at the start of the current op code.
1206  const ByteCodeField *getPrevCodeIt() const {
1207  LLVM_DEBUG({
1208  // Account for the op code and the Location stored inline.
1209  return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField);
1210  });
1211 
1212  // Account for the op code only.
1213  return curCodeIt - 1;
1214  }
1215 
1216  /// Read a value from the bytecode buffer, optionally skipping a certain
1217  /// number of prefix values. These methods always update the buffer to point
1218  /// to the next field after the read data.
1219  template <typename T = ByteCodeField>
1220  T read(size_t skipN = 0) {
1221  curCodeIt += skipN;
1222  return readImpl<T>();
1223  }
1224  ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
1225 
1226  /// Read a list of values from the bytecode buffer.
1227  template <typename ValueT, typename T>
1228  void readList(SmallVectorImpl<T> &list) {
1229  list.clear();
1230  for (unsigned i = 0, e = read(); i != e; ++i)
1231  list.push_back(read<ValueT>());
1232  }
1233 
1234  /// Read a list of values from the bytecode buffer. The values may be encoded
1235  /// either as a single element or a range of elements.
1236  void readList(SmallVectorImpl<Type> &list) {
1237  for (unsigned i = 0, e = read(); i != e; ++i) {
1238  if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1239  list.push_back(read<Type>());
1240  } else {
1241  TypeRange *values = read<TypeRange *>();
1242  list.append(values->begin(), values->end());
1243  }
1244  }
1245  }
1246  void readList(SmallVectorImpl<Value> &list) {
1247  for (unsigned i = 0, e = read(); i != e; ++i) {
1248  if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1249  list.push_back(read<Value>());
1250  } else {
1251  ValueRange *values = read<ValueRange *>();
1252  list.append(values->begin(), values->end());
1253  }
1254  }
1255  }
1256 
1257  /// Read a value stored inline as a pointer.
1258  template <typename T>
1259  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
1260  readInline() {
1261  const void *pointer;
1262  std::memcpy(&pointer, curCodeIt, sizeof(const void *));
1263  curCodeIt += sizeof(const void *) / sizeof(ByteCodeField);
1264  return T::getFromOpaquePointer(pointer);
1265  }
1266 
1267  void skip(size_t skipN) { curCodeIt += skipN; }
1268 
1269  /// Jump to a specific successor based on a predicate value.
1270  void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
1271  /// Jump to a specific successor based on a destination index.
1272  void selectJump(size_t destIndex) {
1273  curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
1274  }
1275 
1276  /// Handle a switch operation with the provided value and cases.
1277  template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
1278  void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
1279  LDBG() << "Switch operation:\n * Value: " << value
1280  << "\n * Cases: " << llvm::interleaved(cases);
1281 
1282  // Check to see if the attribute value is within the case list. Jump to
1283  // the correct successor index based on the result.
1284  for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
1285  if (cmp(*it, value))
1286  return selectJump(size_t((it - cases.begin()) + 1));
1287  selectJump(size_t(0));
1288  }
1289 
1290  /// Store a pointer to memory.
1291  void storeToMemory(unsigned index, const void *value) {
1292  memory[index] = value;
1293  }
1294 
1295  /// Store a value to memory as an opaque pointer.
1296  template <typename T>
1297  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
1298  storeToMemory(unsigned index, T value) {
1299  memory[index] = value.getAsOpaquePointer();
1300  }
1301 
1302  /// Internal implementation of reading various data types from the bytecode
1303  /// stream.
1304  template <typename T>
1305  const void *readFromMemory() {
1306  size_t index = *curCodeIt++;
1307 
1308  // If this type is an SSA value, it can only be stored in non-const memory.
1309  if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
1310  Value>::value ||
1311  index < memory.size())
1312  return memory[index];
1313 
1314  // Otherwise, if this index is not inbounds it is uniqued.
1315  return uniquedMemory[index - memory.size()];
1316  }
1317  template <typename T>
1318  std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
1319  return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
1320  }
1321  template <typename T>
1322  std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
1323  T>
1324  readImpl() {
1325  return T(T::getFromOpaquePointer(readFromMemory<T>()));
1326  }
1327  template <typename T>
1328  std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
1329  switch (read<PDLValue::Kind>()) {
1330  case PDLValue::Kind::Attribute:
1331  return read<Attribute>();
1332  case PDLValue::Kind::Operation:
1333  return read<Operation *>();
1334  case PDLValue::Kind::Type:
1335  return read<Type>();
1336  case PDLValue::Kind::Value:
1337  return read<Value>();
1338  case PDLValue::Kind::TypeRange:
1339  return read<TypeRange *>();
1340  case PDLValue::Kind::ValueRange:
1341  return read<ValueRange *>();
1342  }
1343  llvm_unreachable("unhandled PDLValue::Kind");
1344  }
1345  template <typename T>
1346  std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
1347  static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
1348  "unexpected ByteCode address size");
1349  ByteCodeAddr result;
1350  std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
1351  curCodeIt += 2;
1352  return result;
1353  }
1354  template <typename T>
1355  std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
1356  return *curCodeIt++;
1357  }
1358  template <typename T>
1359  std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
1360  return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
1361  }
1362 
1363  /// Assign the given range to the given memory index. This allocates a new
1364  /// range object if necessary.
1365  template <typename RangeT, typename T = llvm::detail::ValueOfRange<RangeT>>
1366  void assignRangeToMemory(RangeT &&range, unsigned memIndex,
1367  unsigned rangeIndex) {
1368  // Utility functor used to type-erase the assignment.
1369  auto assignRange = [&](auto &allocatedRangeMemory, auto &rangeMemory) {
1370  // If the input range is empty, we don't need to allocate anything.
1371  if (range.empty()) {
1372  rangeMemory[rangeIndex] = {};
1373  } else {
1374  // Allocate a buffer for this type range.
1375  llvm::OwningArrayRef<T> storage(llvm::size(range));
1376  llvm::copy(range, storage.begin());
1377 
1378  // Assign this to the range slot and use the range as the value for the
1379  // memory index.
1380  allocatedRangeMemory.emplace_back(std::move(storage));
1381  rangeMemory[rangeIndex] = allocatedRangeMemory.back();
1382  }
1383  memory[memIndex] = &rangeMemory[rangeIndex];
1384  };
1385 
1386  // Dispatch based on the concrete range type.
1387  if constexpr (std::is_same_v<T, Type>) {
1388  return assignRange(allocatedTypeRangeMemory, typeRangeMemory);
1389  } else if constexpr (std::is_same_v<T, Value>) {
1390  return assignRange(allocatedValueRangeMemory, valueRangeMemory);
1391  } else {
1392  llvm_unreachable("unhandled range type");
1393  }
1394  }
1395 
1396  /// The underlying bytecode buffer.
1397  const ByteCodeField *curCodeIt;
1398 
1399  /// The stack of bytecode positions at which to resume operation.
1401 
1402  /// The current execution memory.
1404  MutableArrayRef<OwningOpRange> opRangeMemory;
1405  MutableArrayRef<TypeRange> typeRangeMemory;
1406  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
1407  MutableArrayRef<ValueRange> valueRangeMemory;
1408  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
1409 
1410  /// The current loop indices.
1411  MutableArrayRef<unsigned> loopIndex;
1412 
1413  /// References to ByteCode data necessary for execution.
1414  ArrayRef<const void *> uniquedMemory;
1416  ArrayRef<PatternBenefit> currentPatternBenefits;
1418  ArrayRef<PDLConstraintFunction> constraintFunctions;
1419  ArrayRef<PDLRewriteFunction> rewriteFunctions;
1420 };
1421 } // namespace
1422 
1423 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1424  LDBG() << "Executing ApplyConstraint:";
1425  ByteCodeField fun_idx = read();
1427  readList<PDLValue>(args);
1428 
1429  LDBG() << " * Arguments: " << llvm::interleaved(args);
1430 
1431  ByteCodeField isNegated = read();
1432  LDBG() << " * isNegated: " << isNegated;
1433 
1434  ByteCodeField numResults = read();
1435  const PDLRewriteFunction &constraintFn = constraintFunctions[fun_idx];
1436  ByteCodeRewriteResultList results(numResults);
1437  LogicalResult rewriteResult = constraintFn(rewriter, results, args);
1438  [[maybe_unused]] ArrayRef<PDLValue> constraintResults = results.getResults();
1439  if (succeeded(rewriteResult)) {
1440  LDBG() << " * Constraint succeeded, results: "
1441  << llvm::interleaved(constraintResults);
1442  } else {
1443  LDBG() << " * Constraint failed";
1444  }
1445  assert((failed(rewriteResult) || constraintResults.size() == numResults) &&
1446  "native PDL rewrite function succeeded but returned "
1447  "unexpected number of results");
1448  processNativeFunResults(results, numResults, rewriteResult);
1449 
1450  // Depending on the constraint jump to the proper destination.
1451  selectJump(isNegated != succeeded(rewriteResult));
1452 }
1453 
1454 LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1455  LDBG() << "Executing ApplyRewrite:";
1456  const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1458  readList<PDLValue>(args);
1459 
1460  LDBG() << " * Arguments: " << llvm::interleaved(args);
1461 
1462  // Execute the rewrite function.
1463  ByteCodeField numResults = read();
1464  ByteCodeRewriteResultList results(numResults);
1465  LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
1466 
1467  assert(results.getResults().size() == numResults &&
1468  "native PDL rewrite function returned unexpected number of results");
1469 
1470  processNativeFunResults(results, numResults, rewriteResult);
1471 
1472  if (failed(rewriteResult)) {
1473  LDBG() << " - Failed";
1474  return failure();
1475  }
1476  return success();
1477 }
1478 
1479 void ByteCodeExecutor::processNativeFunResults(
1480  ByteCodeRewriteResultList &results, unsigned numResults,
1481  LogicalResult &rewriteResult) {
1482  if (failed(rewriteResult)) {
1483  // Skip the according number of values on the buffer on failure and exit
1484  // early as there are no results to process.
1485  for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1486  const PDLValue::Kind resultKind = read<PDLValue::Kind>();
1487  if (resultKind == PDLValue::Kind::TypeRange ||
1488  resultKind == PDLValue::Kind::ValueRange) {
1489  skip(2);
1490  } else {
1491  skip(1);
1492  }
1493  }
1494  return;
1495  }
1496 
1497  // Store the results in the bytecode memory
1498  for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1499  PDLValue::Kind resultKind = read<PDLValue::Kind>();
1500  (void)resultKind;
1501  PDLValue result = results.getResults()[resultIdx];
1502  LDBG() << " * Result: " << result;
1503  assert(result.getKind() == resultKind &&
1504  "native PDL rewrite function returned an unexpected type of "
1505  "result");
1506  // If the result is a range, we need to copy it over to the bytecodes
1507  // range memory.
1508  if (std::optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
1509  unsigned rangeIndex = read();
1510  typeRangeMemory[rangeIndex] = *typeRange;
1511  memory[read()] = &typeRangeMemory[rangeIndex];
1512  } else if (std::optional<ValueRange> valueRange =
1513  result.dyn_cast<ValueRange>()) {
1514  unsigned rangeIndex = read();
1515  valueRangeMemory[rangeIndex] = *valueRange;
1516  memory[read()] = &valueRangeMemory[rangeIndex];
1517  } else {
1518  memory[read()] = result.getAsOpaquePointer();
1519  }
1520  }
1521 
1522  // Copy over any underlying storage allocated for result ranges.
1523  for (auto &it : results.getAllocatedTypeRanges())
1524  allocatedTypeRangeMemory.push_back(std::move(it));
1525  for (auto &it : results.getAllocatedValueRanges())
1526  allocatedValueRangeMemory.push_back(std::move(it));
1527 }
1528 
1529 void ByteCodeExecutor::executeAreEqual() {
1530  LDBG() << "Executing AreEqual:";
1531  const void *lhs = read<const void *>();
1532  const void *rhs = read<const void *>();
1533 
1534  LDBG() << " * " << lhs << " == " << rhs;
1535  selectJump(lhs == rhs);
1536 }
1537 
1538 void ByteCodeExecutor::executeAreRangesEqual() {
1539  LDBG() << "Executing AreRangesEqual:";
1540  PDLValue::Kind valueKind = read<PDLValue::Kind>();
1541  const void *lhs = read<const void *>();
1542  const void *rhs = read<const void *>();
1543 
1544  switch (valueKind) {
1545  case PDLValue::Kind::TypeRange: {
1546  const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1547  const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1548  LDBG() << " * " << lhs << " == " << rhs;
1549  selectJump(*lhsRange == *rhsRange);
1550  break;
1551  }
1552  case PDLValue::Kind::ValueRange: {
1553  const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1554  const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1555  LDBG() << " * " << lhs << " == " << rhs;
1556  selectJump(*lhsRange == *rhsRange);
1557  break;
1558  }
1559  default:
1560  llvm_unreachable("unexpected `AreRangesEqual` value kind");
1561  }
1562 }
1563 
1564 void ByteCodeExecutor::executeBranch() {
1565  LDBG() << "Executing Branch";
1566  curCodeIt = &code[read<ByteCodeAddr>()];
1567 }
1568 
1569 void ByteCodeExecutor::executeCheckOperandCount() {
1570  LDBG() << "Executing CheckOperandCount:";
1571  Operation *op = read<Operation *>();
1572  uint32_t expectedCount = read<uint32_t>();
1573  bool compareAtLeast = read();
1574 
1575  LDBG() << " * Found: " << op->getNumOperands()
1576  << "\n * Expected: " << expectedCount
1577  << "\n * Comparator: " << (compareAtLeast ? ">=" : "==");
1578  if (compareAtLeast)
1579  selectJump(op->getNumOperands() >= expectedCount);
1580  else
1581  selectJump(op->getNumOperands() == expectedCount);
1582 }
1583 
1584 void ByteCodeExecutor::executeCheckOperationName() {
1585  LDBG() << "Executing CheckOperationName:";
1586  Operation *op = read<Operation *>();
1587  OperationName expectedName = read<OperationName>();
1588 
1589  LDBG() << " * Found: \"" << op->getName() << "\"\n * Expected: \""
1590  << expectedName << "\"";
1591  selectJump(op->getName() == expectedName);
1592 }
1593 
1594 void ByteCodeExecutor::executeCheckResultCount() {
1595  LDBG() << "Executing CheckResultCount:";
1596  Operation *op = read<Operation *>();
1597  uint32_t expectedCount = read<uint32_t>();
1598  bool compareAtLeast = read();
1599 
1600  LDBG() << " * Found: " << op->getNumResults()
1601  << "\n * Expected: " << expectedCount
1602  << "\n * Comparator: " << (compareAtLeast ? ">=" : "==");
1603  if (compareAtLeast)
1604  selectJump(op->getNumResults() >= expectedCount);
1605  else
1606  selectJump(op->getNumResults() == expectedCount);
1607 }
1608 
1609 void ByteCodeExecutor::executeCheckTypes() {
1610  LDBG() << "Executing AreEqual:";
1611  TypeRange *lhs = read<TypeRange *>();
1612  Attribute rhs = read<Attribute>();
1613  LDBG() << " * " << lhs << " == " << rhs;
1614 
1615  selectJump(*lhs == cast<ArrayAttr>(rhs).getAsValueRange<TypeAttr>());
1616 }
1617 
1618 void ByteCodeExecutor::executeContinue() {
1619  ByteCodeField level = read();
1620  LDBG() << "Executing Continue\n * Level: " << level;
1621  ++loopIndex[level];
1622  popCodeIt();
1623 }
1624 
1625 void ByteCodeExecutor::executeCreateConstantTypeRange() {
1626  LDBG() << "Executing CreateConstantTypeRange:";
1627  unsigned memIndex = read();
1628  unsigned rangeIndex = read();
1629  ArrayAttr typesAttr = cast<ArrayAttr>(read<Attribute>());
1630 
1631  LDBG() << " * Types: " << typesAttr;
1632  assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
1633  rangeIndex);
1634 }
1635 
1636 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1637  Location mainRewriteLoc) {
1638  LDBG() << "Executing CreateOperation:";
1639 
1640  unsigned memIndex = read();
1641  OperationState state(mainRewriteLoc, read<OperationName>());
1642  readList(state.operands);
1643  for (unsigned i = 0, e = read(); i != e; ++i) {
1644  StringAttr name = read<StringAttr>();
1645  if (Attribute attr = read<Attribute>())
1646  state.addAttribute(name, attr);
1647  }
1648 
1649  // Read in the result types. If the "size" is the sentinel value, this
1650  // indicates that the result types should be inferred.
1651  unsigned numResults = read();
1652  if (numResults == kInferTypesMarker) {
1653  InferTypeOpInterface::Concept *inferInterface =
1654  state.name.getInterface<InferTypeOpInterface>();
1655  assert(inferInterface &&
1656  "expected operation to provide InferTypeOpInterface");
1657 
1658  // TODO: Handle failure.
1659  if (failed(inferInterface->inferReturnTypes(
1660  state.getContext(), state.location, state.operands,
1661  state.attributes.getDictionary(state.getContext()),
1662  state.getRawProperties(), state.regions, state.types)))
1663  return;
1664  } else {
1665  // Otherwise, this is a fixed number of results.
1666  for (unsigned i = 0; i != numResults; ++i) {
1667  if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1668  state.types.push_back(read<Type>());
1669  } else {
1670  TypeRange *resultTypes = read<TypeRange *>();
1671  state.types.append(resultTypes->begin(), resultTypes->end());
1672  }
1673  }
1674  }
1675 
1676  Operation *resultOp = rewriter.create(state);
1677  memory[memIndex] = resultOp;
1678 
1679  LDBG() << " * Attributes: "
1680  << state.attributes.getDictionary(state.getContext())
1681  << "\n * Operands: " << llvm::interleaved(state.operands)
1682  << "\n * Result Types: " << llvm::interleaved(state.types)
1683  << "\n * Result: " << *resultOp;
1684 }
1685 
1686 template <typename T>
1687 void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
1688  LDBG() << "Executing CreateDynamic" << type << "Range:";
1689  unsigned memIndex = read();
1690  unsigned rangeIndex = read();
1691  SmallVector<T> values;
1692  readList(values);
1693 
1694  LDBG() << " * " << type << "s: " << llvm::interleaved(values);
1695 
1696  assignRangeToMemory(values, memIndex, rangeIndex);
1697 }
1698 
1699 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1700  LDBG() << "Executing EraseOp:";
1701  Operation *op = read<Operation *>();
1702 
1703  LDBG() << " * Operation: " << *op;
1704  rewriter.eraseOp(op);
1705 }
1706 
1707 template <typename T, typename Range, PDLValue::Kind kind>
1708 void ByteCodeExecutor::executeExtract() {
1709  LDBG() << "Executing Extract" << kind << ":";
1710  Range *range = read<Range *>();
1711  unsigned index = read<uint32_t>();
1712  unsigned memIndex = read();
1713 
1714  if (!range) {
1715  memory[memIndex] = nullptr;
1716  return;
1717  }
1718 
1719  T result = index < range->size() ? (*range)[index] : T();
1720  LDBG() << " * " << kind << "s(" << range->size() << ")";
1721  LDBG() << " * Index: " << index;
1722  LDBG() << " * Result: " << result;
1723  storeToMemory(memIndex, result);
1724 }
1725 
1726 void ByteCodeExecutor::executeFinalize() { LDBG() << "Executing Finalize"; }
1727 
1728 void ByteCodeExecutor::executeForEach() {
1729  LDBG() << "Executing ForEach:";
1730  const ByteCodeField *prevCodeIt = getPrevCodeIt();
1731  unsigned rangeIndex = read();
1732  unsigned memIndex = read();
1733  const void *value = nullptr;
1734 
1735  switch (read<PDLValue::Kind>()) {
1736  case PDLValue::Kind::Operation: {
1737  unsigned &index = loopIndex[read()];
1738  ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
1739  assert(index <= array.size() && "iterated past the end");
1740  if (index < array.size()) {
1741  LDBG() << " * Result: " << array[index];
1742  value = array[index];
1743  break;
1744  }
1745 
1746  LDBG() << " * Done";
1747  index = 0;
1748  selectJump(size_t(0));
1749  return;
1750  }
1751  default:
1752  llvm_unreachable("unexpected `ForEach` value kind");
1753  }
1754 
1755  // Store the iterate value and the stack address.
1756  memory[memIndex] = value;
1757  pushCodeIt(prevCodeIt);
1758 
1759  // Skip over the successor (we will enter the body of the loop).
1760  read<ByteCodeAddr>();
1761 }
1762 
1763 void ByteCodeExecutor::executeGetAttribute() {
1764  LDBG() << "Executing GetAttribute:";
1765  unsigned memIndex = read();
1766  Operation *op = read<Operation *>();
1767  StringAttr attrName = read<StringAttr>();
1768  Attribute attr = op->getAttr(attrName);
1769 
1770  LDBG() << " * Operation: " << *op << "\n * Attribute: " << attrName
1771  << "\n * Result: " << attr;
1772  memory[memIndex] = attr.getAsOpaquePointer();
1773 }
1774 
1775 void ByteCodeExecutor::executeGetAttributeType() {
1776  LDBG() << "Executing GetAttributeType:";
1777  unsigned memIndex = read();
1778  Attribute attr = read<Attribute>();
1779  Type type;
1780  if (auto typedAttr = dyn_cast<TypedAttr>(attr))
1781  type = typedAttr.getType();
1782 
1783  LDBG() << " * Attribute: " << attr << "\n * Result: " << type;
1784  memory[memIndex] = type.getAsOpaquePointer();
1785 }
1786 
1787 void ByteCodeExecutor::executeGetDefiningOp() {
1788  LDBG() << "Executing GetDefiningOp:";
1789  unsigned memIndex = read();
1790  Operation *op = nullptr;
1791  if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1792  Value value = read<Value>();
1793  if (value)
1794  op = value.getDefiningOp();
1795  LDBG() << " * Value: " << value;
1796  } else {
1797  ValueRange *values = read<ValueRange *>();
1798  if (values && !values->empty()) {
1799  op = values->front().getDefiningOp();
1800  }
1801  LDBG() << " * Values: " << values;
1802  }
1803 
1804  LDBG() << " * Result: " << op;
1805  memory[memIndex] = op;
1806 }
1807 
1808 void ByteCodeExecutor::executeGetOperand(unsigned index) {
1809  Operation *op = read<Operation *>();
1810  unsigned memIndex = read();
1811  Value operand =
1812  index < op->getNumOperands() ? op->getOperand(index) : Value();
1813 
1814  LDBG() << " * Operation: " << *op << "\n * Index: " << index
1815  << "\n * Result: " << operand;
1816  memory[memIndex] = operand.getAsOpaquePointer();
1817 }
1818 
1819 /// This function is the internal implementation of `GetResults` and
1820 /// `GetOperands` that provides support for extracting a value range from the
1821 /// given operation.
1822 template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1823 static void *
1824 executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1825  ByteCodeField rangeIndex, StringRef attrSizedSegments,
1826  MutableArrayRef<ValueRange> valueRangeMemory) {
1827  // Check for the sentinel index that signals that all values should be
1828  // returned.
1829  if (index == std::numeric_limits<uint32_t>::max()) {
1830  LDBG() << " * Getting all values";
1831  // `values` is already the full value range.
1832 
1833  // Otherwise, check to see if this operation uses AttrSizedSegments.
1834  } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1835  LDBG() << " * Extracting values from `" << attrSizedSegments << "`";
1836 
1837  auto segmentAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrSizedSegments);
1838  if (!segmentAttr || segmentAttr.asArrayRef().size() <= index)
1839  return nullptr;
1840 
1841  ArrayRef<int32_t> segments = segmentAttr;
1842  unsigned startIndex =
1843  std::accumulate(segments.begin(), segments.begin() + index, 0);
1844  values = values.slice(startIndex, *std::next(segments.begin(), index));
1845 
1846  LDBG() << " * Extracting range[" << startIndex << ", "
1847  << *std::next(segments.begin(), index) << "]";
1848 
1849  // Otherwise, assume this is the last operand group of the operation.
1850  // FIXME: We currently don't support operations with
1851  // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1852  // have a way to detect it's presence.
1853  } else if (values.size() >= index) {
1854  LDBG() << " * Treating values as trailing variadic range";
1855  values = values.drop_front(index);
1856 
1857  // If we couldn't detect a way to compute the values, bail out.
1858  } else {
1859  return nullptr;
1860  }
1861 
1862  // If the range index is valid, we are returning a range.
1863  if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1864  valueRangeMemory[rangeIndex] = values;
1865  return &valueRangeMemory[rangeIndex];
1866  }
1867 
1868  // If a range index wasn't provided, the range is required to be non-variadic.
1869  return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1870 }
1871 
1872 void ByteCodeExecutor::executeGetOperands() {
1873  LDBG() << "Executing GetOperands:";
1874  unsigned index = read<uint32_t>();
1875  Operation *op = read<Operation *>();
1876  ByteCodeField rangeIndex = read();
1877 
1878  void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1879  op->getOperands(), op, index, rangeIndex, "operandSegmentSizes",
1880  valueRangeMemory);
1881  if (!result)
1882  LDBG() << " * Invalid operand range";
1883  memory[read()] = result;
1884 }
1885 
1886 void ByteCodeExecutor::executeGetResult(unsigned index) {
1887  Operation *op = read<Operation *>();
1888  unsigned memIndex = read();
1889  OpResult result =
1890  index < op->getNumResults() ? op->getResult(index) : OpResult();
1891 
1892  LDBG() << " * Operation: " << *op << "\n * Index: " << index
1893  << "\n * Result: " << result;
1894  memory[memIndex] = result.getAsOpaquePointer();
1895 }
1896 
1897 void ByteCodeExecutor::executeGetResults() {
1898  LDBG() << "Executing GetResults:";
1899  unsigned index = read<uint32_t>();
1900  Operation *op = read<Operation *>();
1901  ByteCodeField rangeIndex = read();
1902 
1903  void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1904  op->getResults(), op, index, rangeIndex, "resultSegmentSizes",
1905  valueRangeMemory);
1906  if (!result)
1907  LDBG() << " * Invalid result range";
1908  memory[read()] = result;
1909 }
1910 
1911 void ByteCodeExecutor::executeGetUsers() {
1912  LDBG() << "Executing GetUsers:";
1913  unsigned memIndex = read();
1914  unsigned rangeIndex = read();
1915  OwningOpRange &range = opRangeMemory[rangeIndex];
1916  memory[memIndex] = &range;
1917 
1918  range = OwningOpRange();
1919  if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1920  // Read the value.
1921  Value value = read<Value>();
1922  if (!value)
1923  return;
1924  LDBG() << " * Value: " << value;
1925 
1926  // Extract the users of a single value.
1927  range = OwningOpRange(std::distance(value.user_begin(), value.user_end()));
1928  llvm::copy(value.getUsers(), range.begin());
1929  } else {
1930  // Read a range of values.
1931  ValueRange *values = read<ValueRange *>();
1932  if (!values)
1933  return;
1934  LDBG() << " * Values (" << values->size()
1935  << "): " << llvm::interleaved(*values);
1936 
1937  // Extract all the users of a range of values.
1939  for (Value value : *values)
1940  users.append(value.user_begin(), value.user_end());
1941  range = OwningOpRange(users.size());
1942  llvm::copy(users, range.begin());
1943  }
1944 
1945  LDBG() << " * Result: " << range.size() << " operations";
1946 }
1947 
1948 void ByteCodeExecutor::executeGetValueType() {
1949  LDBG() << "Executing GetValueType:";
1950  unsigned memIndex = read();
1951  Value value = read<Value>();
1952  Type type = value ? value.getType() : Type();
1953 
1954  LDBG() << " * Value: " << value << "\n * Result: " << type;
1955  memory[memIndex] = type.getAsOpaquePointer();
1956 }
1957 
1958 void ByteCodeExecutor::executeGetValueRangeTypes() {
1959  LDBG() << "Executing GetValueRangeTypes:";
1960  unsigned memIndex = read();
1961  unsigned rangeIndex = read();
1962  ValueRange *values = read<ValueRange *>();
1963  if (!values) {
1964  LDBG() << " * Values: <NULL>";
1965  memory[memIndex] = nullptr;
1966  return;
1967  }
1968 
1969  LDBG() << " * Values (" << values->size()
1970  << "): " << llvm::interleaved(*values)
1971  << "\n * Result: " << llvm::interleaved(values->getType());
1972  typeRangeMemory[rangeIndex] = values->getType();
1973  memory[memIndex] = &typeRangeMemory[rangeIndex];
1974 }
1975 
1976 void ByteCodeExecutor::executeIsNotNull() {
1977  LDBG() << "Executing IsNotNull:";
1978  const void *value = read<const void *>();
1979 
1980  LDBG() << " * Value: " << value;
1981  selectJump(value != nullptr);
1982 }
1983 
1984 void ByteCodeExecutor::executeRecordMatch(
1985  PatternRewriter &rewriter,
1987  LDBG() << "Executing RecordMatch:";
1988  unsigned patternIndex = read();
1989  PatternBenefit benefit = currentPatternBenefits[patternIndex];
1990  const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1991 
1992  // If the benefit of the pattern is impossible, skip the processing of the
1993  // rest of the pattern.
1994  if (benefit.isImpossibleToMatch()) {
1995  LDBG() << " * Benefit: Impossible To Match";
1996  curCodeIt = dest;
1997  return;
1998  }
1999 
2000  // Create a fused location containing the locations of each of the
2001  // operations used in the match. This will be used as the location for
2002  // created operations during the rewrite that don't already have an
2003  // explicit location set.
2004  unsigned numMatchLocs = read();
2005  SmallVector<Location, 4> matchLocs;
2006  matchLocs.reserve(numMatchLocs);
2007  for (unsigned i = 0; i != numMatchLocs; ++i)
2008  matchLocs.push_back(read<Operation *>()->getLoc());
2009  Location matchLoc = rewriter.getFusedLoc(matchLocs);
2010 
2011  LDBG() << " * Benefit: " << benefit.getBenefit();
2012  LDBG() << " * Location: " << matchLoc;
2013  matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
2014  PDLByteCode::MatchResult &match = matches.back();
2015 
2016  // Record all of the inputs to the match. If any of the inputs are ranges, we
2017  // will also need to remap the range pointer to memory stored in the match
2018  // state.
2019  unsigned numInputs = read();
2020  match.values.reserve(numInputs);
2021  match.typeRangeValues.reserve(numInputs);
2022  match.valueRangeValues.reserve(numInputs);
2023  for (unsigned i = 0; i < numInputs; ++i) {
2024  switch (read<PDLValue::Kind>()) {
2025  case PDLValue::Kind::TypeRange:
2026  match.typeRangeValues.push_back(*read<TypeRange *>());
2027  match.values.push_back(&match.typeRangeValues.back());
2028  break;
2029  case PDLValue::Kind::ValueRange:
2030  match.valueRangeValues.push_back(*read<ValueRange *>());
2031  match.values.push_back(&match.valueRangeValues.back());
2032  break;
2033  default:
2034  match.values.push_back(read<const void *>());
2035  break;
2036  }
2037  }
2038  curCodeIt = dest;
2039 }
2040 
2041 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
2042  LDBG() << "Executing ReplaceOp:";
2043  Operation *op = read<Operation *>();
2045  readList(args);
2046 
2047  LDBG() << " * Operation: " << *op
2048  << "\n * Values: " << llvm::interleaved(args);
2049  rewriter.replaceOp(op, args);
2050 }
2051 
2052 void ByteCodeExecutor::executeSwitchAttribute() {
2053  LDBG() << "Executing SwitchAttribute:";
2054  Attribute value = read<Attribute>();
2055  ArrayAttr cases = read<ArrayAttr>();
2056  handleSwitch(value, cases);
2057 }
2058 
2059 void ByteCodeExecutor::executeSwitchOperandCount() {
2060  LDBG() << "Executing SwitchOperandCount:";
2061  Operation *op = read<Operation *>();
2062  auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2063 
2064  LDBG() << " * Operation: " << *op;
2065  handleSwitch(op->getNumOperands(), cases);
2066 }
2067 
2068 void ByteCodeExecutor::executeSwitchOperationName() {
2069  LDBG() << "Executing SwitchOperationName:";
2070  OperationName value = read<Operation *>()->getName();
2071  size_t caseCount = read();
2072 
2073  // The operation names are stored in-line, so to print them out for
2074  // debugging purposes we need to read the array before executing the
2075  // switch so that we can display all of the possible values.
2076  LLVM_DEBUG({
2077  const ByteCodeField *prevCodeIt = curCodeIt;
2078  LDBG() << " * Value: " << value << "\n * Cases: "
2079  << llvm::interleaved(
2080  llvm::map_range(llvm::seq<size_t>(0, caseCount), [&](size_t) {
2081  return read<OperationName>();
2082  }));
2083  curCodeIt = prevCodeIt;
2084  });
2085 
2086  // Try to find the switch value within any of the cases.
2087  for (size_t i = 0; i != caseCount; ++i) {
2088  if (read<OperationName>() == value) {
2089  curCodeIt += (caseCount - i - 1);
2090  return selectJump(i + 1);
2091  }
2092  }
2093  selectJump(size_t(0));
2094 }
2095 
2096 void ByteCodeExecutor::executeSwitchResultCount() {
2097  LDBG() << "Executing SwitchResultCount:";
2098  Operation *op = read<Operation *>();
2099  auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2100 
2101  LDBG() << " * Operation: " << *op;
2102  handleSwitch(op->getNumResults(), cases);
2103 }
2104 
2105 void ByteCodeExecutor::executeSwitchType() {
2106  LDBG() << "Executing SwitchType:";
2107  Type value = read<Type>();
2108  auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2109  handleSwitch(value, cases);
2110 }
2111 
2112 void ByteCodeExecutor::executeSwitchTypes() {
2113  LDBG() << "Executing SwitchTypes:";
2114  TypeRange *value = read<TypeRange *>();
2115  auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2116  if (!value) {
2117  LDBG() << "Types: <NULL>";
2118  return selectJump(size_t(0));
2119  }
2120  handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
2121  return value == caseValue.getAsValueRange<TypeAttr>();
2122  });
2123 }
2124 
2125 LogicalResult
2126 ByteCodeExecutor::execute(PatternRewriter &rewriter,
2128  std::optional<Location> mainRewriteLoc) {
2129  while (true) {
2130  // Print the location of the operation being executed.
2131  LDBG() << readInline<Location>();
2132 
2133  OpCode opCode = static_cast<OpCode>(read());
2134  switch (opCode) {
2135  case ApplyConstraint:
2136  executeApplyConstraint(rewriter);
2137  break;
2138  case ApplyRewrite:
2139  if (failed(executeApplyRewrite(rewriter)))
2140  return failure();
2141  break;
2142  case AreEqual:
2143  executeAreEqual();
2144  break;
2145  case AreRangesEqual:
2146  executeAreRangesEqual();
2147  break;
2148  case Branch:
2149  executeBranch();
2150  break;
2151  case CheckOperandCount:
2152  executeCheckOperandCount();
2153  break;
2154  case CheckOperationName:
2155  executeCheckOperationName();
2156  break;
2157  case CheckResultCount:
2158  executeCheckResultCount();
2159  break;
2160  case CheckTypes:
2161  executeCheckTypes();
2162  break;
2163  case Continue:
2164  executeContinue();
2165  break;
2166  case CreateConstantTypeRange:
2167  executeCreateConstantTypeRange();
2168  break;
2169  case CreateOperation:
2170  executeCreateOperation(rewriter, *mainRewriteLoc);
2171  break;
2172  case CreateDynamicTypeRange:
2173  executeDynamicCreateRange<Type>("Type");
2174  break;
2175  case CreateDynamicValueRange:
2176  executeDynamicCreateRange<Value>("Value");
2177  break;
2178  case EraseOp:
2179  executeEraseOp(rewriter);
2180  break;
2181  case ExtractOp:
2182  executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
2183  break;
2184  case ExtractType:
2185  executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2186  break;
2187  case ExtractValue:
2188  executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2189  break;
2190  case Finalize:
2191  executeFinalize();
2192  LDBG() << "";
2193  return success();
2194  case ForEach:
2195  executeForEach();
2196  break;
2197  case GetAttribute:
2198  executeGetAttribute();
2199  break;
2200  case GetAttributeType:
2201  executeGetAttributeType();
2202  break;
2203  case GetDefiningOp:
2204  executeGetDefiningOp();
2205  break;
2206  case GetOperand0:
2207  case GetOperand1:
2208  case GetOperand2:
2209  case GetOperand3: {
2210  unsigned index = opCode - GetOperand0;
2211  LDBG() << "Executing GetOperand" << index << ":";
2212  executeGetOperand(index);
2213  break;
2214  }
2215  case GetOperandN:
2216  LDBG() << "Executing GetOperandN:";
2217  executeGetOperand(read<uint32_t>());
2218  break;
2219  case GetOperands:
2220  executeGetOperands();
2221  break;
2222  case GetResult0:
2223  case GetResult1:
2224  case GetResult2:
2225  case GetResult3: {
2226  unsigned index = opCode - GetResult0;
2227  LDBG() << "Executing GetResult" << index << ":";
2228  executeGetResult(index);
2229  break;
2230  }
2231  case GetResultN:
2232  LDBG() << "Executing GetResultN:";
2233  executeGetResult(read<uint32_t>());
2234  break;
2235  case GetResults:
2236  executeGetResults();
2237  break;
2238  case GetUsers:
2239  executeGetUsers();
2240  break;
2241  case GetValueType:
2242  executeGetValueType();
2243  break;
2244  case GetValueRangeTypes:
2245  executeGetValueRangeTypes();
2246  break;
2247  case IsNotNull:
2248  executeIsNotNull();
2249  break;
2250  case RecordMatch:
2251  assert(matches &&
2252  "expected matches to be provided when executing the matcher");
2253  executeRecordMatch(rewriter, *matches);
2254  break;
2255  case ReplaceOp:
2256  executeReplaceOp(rewriter);
2257  break;
2258  case SwitchAttribute:
2259  executeSwitchAttribute();
2260  break;
2261  case SwitchOperandCount:
2262  executeSwitchOperandCount();
2263  break;
2264  case SwitchOperationName:
2265  executeSwitchOperationName();
2266  break;
2267  case SwitchResultCount:
2268  executeSwitchResultCount();
2269  break;
2270  case SwitchType:
2271  executeSwitchType();
2272  break;
2273  case SwitchTypes:
2274  executeSwitchTypes();
2275  break;
2276  }
2277  LDBG() << "";
2278  }
2279 }
2280 
2281 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
2283  PDLByteCodeMutableState &state) const {
2284  // The first memory slot is always the root operation.
2285  state.memory[0] = op;
2286 
2287  // The matcher function always starts at code address 0.
2288  ByteCodeExecutor executor(
2289  matcherByteCode.data(), state.memory, state.opRangeMemory,
2290  state.typeRangeMemory, state.allocatedTypeRangeMemory,
2291  state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2292  uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
2293  constraintFunctions, rewriteFunctions);
2294  LogicalResult executeResult = executor.execute(rewriter, &matches);
2295  (void)executeResult;
2296  assert(succeeded(executeResult) && "unexpected matcher execution failure");
2297 
2298  // Order the found matches by benefit.
2299  llvm::stable_sort(matches,
2300  [](const MatchResult &lhs, const MatchResult &rhs) {
2301  return lhs.benefit > rhs.benefit;
2302  });
2303 }
2304 
2305 LogicalResult PDLByteCode::rewrite(PatternRewriter &rewriter,
2306  const MatchResult &match,
2307  PDLByteCodeMutableState &state) const {
2308  auto *configSet = match.pattern->getConfigSet();
2309  if (configSet)
2310  configSet->notifyRewriteBegin(rewriter);
2311 
2312  // The arguments of the rewrite function are stored at the start of the
2313  // memory buffer.
2314  llvm::copy(match.values, state.memory.begin());
2315 
2316  ByteCodeExecutor executor(
2317  &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
2318  state.opRangeMemory, state.typeRangeMemory,
2319  state.allocatedTypeRangeMemory, state.valueRangeMemory,
2320  state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2321  rewriterByteCode, state.currentPatternBenefits, patterns,
2322  constraintFunctions, rewriteFunctions);
2323  LogicalResult result =
2324  executor.execute(rewriter, /*matches=*/nullptr, match.location);
2325 
2326  if (configSet)
2327  configSet->notifyRewriteEnd(rewriter);
2328 
2329  // If the rewrite failed, check if the pattern rewriter can recover. If it
2330  // can, we can signal to the pattern applicator to keep trying patterns. If it
2331  // doesn't, we need to bail. Bailing here should be fine, given that we have
2332  // no means to propagate such a failure to the user, and it also indicates a
2333  // bug in the user code (i.e. failable rewrites should not be used with
2334  // pattern rewriters that don't support it).
2335  if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) {
2336  LDBG() << " and rollback is not supported - aborting";
2337  llvm::report_fatal_error(
2338  "Native PDL Rewrite failed, but the pattern "
2339  "rewriter doesn't support recovery. Failable pattern rewrites should "
2340  "not be used with pattern rewriters that do not support them.");
2341  }
2342  return result;
2343 }
static void * executeGetOperandsResults(RangeT values, Operation *op, unsigned index, ByteCodeField rangeIndex, StringRef attrSizedSegments, MutableArrayRef< ValueRange > valueRangeMemory)
This function is the internal implementation of GetResults and GetOperands that provides support for ...
Definition: ByteCode.cpp:1824
static constexpr ByteCodeField kInferTypesMarker
A marker used to indicate if an operation should infer types.
Definition: ByteCode.cpp:175
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
union mlir::linalg::@1242::ArityGroupAndKind::Kind kind
static const mlir::GenInfo * generator
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void processValue(Value value, LiveMap &liveMap)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
const void * getAsOpaquePointer() const
Get an opaque pointer to the attribute.
Definition: Attributes.h:73
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:27
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition: Block.cpp:36
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
Definition: Builders.cpp:26
This class represents liveness information on block level.
Definition: Liveness.h:99
const ValueSetT & in() const
Returns all values that are live at the beginning of the block (unordered).
Definition: Liveness.h:110
Operation * getEndOperation(Value value, Operation *startOperation) const
Gets the end operation for the given value using the start operation provided (must be referenced in ...
Definition: Liveness.cpp:375
Represents an analysis for computing liveness information from a given top-level operation.
Definition: Liveness.h:47
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
This is a value defined by a result of an operation.
Definition: Value.h:447
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:550
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:534
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
bool isImpossibleToMatch() const
Definition: PatternMatch.h:44
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
virtual bool canRecoverFromRewriteFailure() const
A hook used to indicate if the pattern rewriter can recover from failure during the rewrite stage of ...
Definition: PatternMatch.h:792
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:247
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class implements the successor iterators for Block.
Definition: BlockSupport.h:73
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
const void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Types.h:164
auto walk(WalkFns &&...walkFns)
Walk this type and all attributes/types nested within using the provided walk functions.
Definition: Types.h:215
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
user_iterator user_begin() const
Definition: Value.h:216
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Value.h:233
user_iterator user_end() const
Definition: Value.h:217
user_range getUsers() const
Definition: Value.h:218
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit)
Set the new benefit for a bytecode pattern.
Definition: ByteCode.h:236
void cleanupAfterMatchAndRewrite()
Cleanup any allocated state after a full match/rewrite has been completed.
Definition: ByteCode.h:235
void match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl< MatchResult > &matches, PDLByteCodeMutableState &state) const
Definition: ByteCode.h:249
LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const
Definition: ByteCode.h:252
AttrTypeReplacer.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:102
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult size