MLIR  22.0.0git
ByteCode.cpp
Go to the documentation of this file.
1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ByteCode.h"
14 #include "mlir/Analysis/Liveness.h"
17 #include "mlir/IR/BuiltinOps.h"
19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/DebugLog.h"
24 #include "llvm/Support/Format.h"
25 #include "llvm/Support/FormatVariadic.h"
26 #include "llvm/Support/InterleavedRange.h"
27 #include <numeric>
28 #include <optional>
29 
30 #define DEBUG_TYPE "pdl-bytecode"
31 
32 using namespace mlir;
33 using namespace mlir::detail;
34 
35 //===----------------------------------------------------------------------===//
36 // PDLByteCodePattern
37 //===----------------------------------------------------------------------===//
38 
39 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
40  PDLPatternConfigSet *configSet,
41  ByteCodeAddr rewriterAddr) {
42  PatternBenefit benefit = matchOp.getBenefit();
43  MLIRContext *ctx = matchOp.getContext();
44 
45  // Collect the set of generated operations.
46  SmallVector<StringRef, 8> generatedOps;
47  if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
48  generatedOps =
49  llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
50 
51  // Check to see if this is pattern matches a specific operation type.
52  if (std::optional<StringRef> rootKind = matchOp.getRootKind())
53  return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx,
54  generatedOps);
55  return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(),
56  benefit, ctx, generatedOps);
57 }
58 
59 //===----------------------------------------------------------------------===//
60 // PDLByteCodeMutableState
61 //===----------------------------------------------------------------------===//
62 
63 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
64 /// to the position of the pattern within the range returned by
65 /// `PDLByteCode::getPatterns`.
66 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
67  PatternBenefit benefit) {
68  currentPatternBenefits[patternIndex] = benefit;
69 }
70 
71 /// Cleanup any allocated state after a full match/rewrite has been completed.
72 /// This method should be called irregardless of whether the match+rewrite was a
73 /// success or not.
75  allocatedTypeRangeMemory.clear();
76  allocatedValueRangeMemory.clear();
77 }
78 
79 //===----------------------------------------------------------------------===//
80 // Bytecode OpCodes
81 //===----------------------------------------------------------------------===//
82 
83 namespace {
84 enum OpCode : ByteCodeField {
85  /// Apply an externally registered constraint.
86  ApplyConstraint,
87  /// Apply an externally registered rewrite.
88  ApplyRewrite,
89  /// Check if two generic values are equal.
90  AreEqual,
91  /// Check if two ranges are equal.
92  AreRangesEqual,
93  /// Unconditional branch.
94  Branch,
95  /// Compare the operand count of an operation with a constant.
96  CheckOperandCount,
97  /// Compare the name of an operation with a constant.
98  CheckOperationName,
99  /// Compare the result count of an operation with a constant.
100  CheckResultCount,
101  /// Compare a range of types to a constant range of types.
102  CheckTypes,
103  /// Continue to the next iteration of a loop.
104  Continue,
105  /// Create a type range from a list of constant types.
106  CreateConstantTypeRange,
107  /// Create an operation.
108  CreateOperation,
109  /// Create a type range from a list of dynamic types.
110  CreateDynamicTypeRange,
111  /// Create a value range.
112  CreateDynamicValueRange,
113  /// Erase an operation.
114  EraseOp,
115  /// Extract the op from a range at the specified index.
116  ExtractOp,
117  /// Extract the type from a range at the specified index.
118  ExtractType,
119  /// Extract the value from a range at the specified index.
120  ExtractValue,
121  /// Terminate a matcher or rewrite sequence.
122  Finalize,
123  /// Iterate over a range of values.
124  ForEach,
125  /// Get a specific attribute of an operation.
126  GetAttribute,
127  /// Get the type of an attribute.
128  GetAttributeType,
129  /// Get the defining operation of a value.
130  GetDefiningOp,
131  /// Get a specific operand of an operation.
132  GetOperand0,
133  GetOperand1,
134  GetOperand2,
135  GetOperand3,
136  GetOperandN,
137  /// Get a specific operand group of an operation.
138  GetOperands,
139  /// Get a specific result of an operation.
140  GetResult0,
141  GetResult1,
142  GetResult2,
143  GetResult3,
144  GetResultN,
145  /// Get a specific result group of an operation.
146  GetResults,
147  /// Get the users of a value or a range of values.
148  GetUsers,
149  /// Get the type of a value.
150  GetValueType,
151  /// Get the types of a value range.
152  GetValueRangeTypes,
153  /// Check if a generic value is not null.
154  IsNotNull,
155  /// Record a successful pattern match.
156  RecordMatch,
157  /// Replace an operation.
158  ReplaceOp,
159  /// Compare an attribute with a set of constants.
160  SwitchAttribute,
161  /// Compare the operand count of an operation with a set of constants.
162  SwitchOperandCount,
163  /// Compare the name of an operation with a set of constants.
164  SwitchOperationName,
165  /// Compare the result count of an operation with a set of constants.
166  SwitchResultCount,
167  /// Compare a type with a set of constants.
168  SwitchType,
169  /// Compare a range of types with a set of constants.
170  SwitchTypes,
171 };
172 } // namespace
173 
174 /// A marker used to indicate if an operation should infer types.
175 static constexpr ByteCodeField kInferTypesMarker =
177 
178 //===----------------------------------------------------------------------===//
179 // ByteCode Generation
180 //===----------------------------------------------------------------------===//
181 
182 //===----------------------------------------------------------------------===//
183 // Generator
184 //===----------------------------------------------------------------------===//
185 
186 namespace {
187 struct ByteCodeLiveRange;
188 struct ByteCodeWriter;
189 
/// Detection alias that is well-formed only when an instance of `T` exposes a
/// `getAsOpaquePointer()` member, i.e. when `T` can be converted to an opaque
/// pointer.
template <typename T, typename... Rest>
using has_pointer_traits =
    decltype(std::declval<T>().getAsOpaquePointer());
193 
194 /// This class represents the main generator for the pattern bytecode.
195 class Generator {
196 public:
197  Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
198  SmallVectorImpl<ByteCodeField> &matcherByteCode,
199  SmallVectorImpl<ByteCodeField> &rewriterByteCode,
201  ByteCodeField &maxValueMemoryIndex,
202  ByteCodeField &maxOpRangeMemoryIndex,
203  ByteCodeField &maxTypeRangeMemoryIndex,
204  ByteCodeField &maxValueRangeMemoryIndex,
205  ByteCodeField &maxLoopLevel,
206  llvm::StringMap<PDLConstraintFunction> &constraintFns,
207  llvm::StringMap<PDLRewriteFunction> &rewriteFns,
209  : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
210  rewriterByteCode(rewriterByteCode), patterns(patterns),
211  maxValueMemoryIndex(maxValueMemoryIndex),
212  maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
213  maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
214  maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
215  maxLoopLevel(maxLoopLevel), configMap(configMap) {
216  for (const auto &it : llvm::enumerate(constraintFns))
217  constraintToMemIndex.try_emplace(it.value().first(), it.index());
218  for (const auto &it : llvm::enumerate(rewriteFns))
219  externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
220  }
221 
222  /// Generate the bytecode for the given PDL interpreter module.
223  void generate(ModuleOp module);
224 
225  /// Return the memory index to use for the given value.
226  ByteCodeField &getMemIndex(Value value) {
227  assert(valueToMemIndex.count(value) &&
228  "expected memory index to be assigned");
229  return valueToMemIndex[value];
230  }
231 
232  /// Return the range memory index used to store the given range value.
233  ByteCodeField &getRangeStorageIndex(Value value) {
234  assert(valueToRangeIndex.count(value) &&
235  "expected range index to be assigned");
236  return valueToRangeIndex[value];
237  }
238 
239  /// Return an index to use when referring to the given data that is uniqued in
240  /// the MLIR context.
241  template <typename T>
242  std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
243  getMemIndex(T val) {
244  const void *opaqueVal = val.getAsOpaquePointer();
245 
246  // Get or insert a reference to this value.
247  auto it = uniquedDataToMemIndex.try_emplace(
248  opaqueVal, maxValueMemoryIndex + uniquedData.size());
249  if (it.second)
250  uniquedData.push_back(opaqueVal);
251  return it.first->second;
252  }
253 
254 private:
255  /// Allocate memory indices for the results of operations within the matcher
256  /// and rewriters.
257  void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
258  ModuleOp rewriterModule);
259 
260  /// Generate the bytecode for the given operation.
261  void generate(Region *region, ByteCodeWriter &writer);
262  void generate(Operation *op, ByteCodeWriter &writer);
263  void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
264  void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
265  void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
266  void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
267  void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
268  void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
269  void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
270  void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
271  void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
272  void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
273  void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
274  void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
275  void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
276  void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer);
277  void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
278  void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
279  void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
280  void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
281  void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
282  void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
283  void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
284  void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
285  void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
286  void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
287  void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
288  void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
289  void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
290  void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
291  void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
292  void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
293  void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
294  void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
295  void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
296  void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
297  void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
298  void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
299  void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
300  void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
301 
302  /// Mapping from value to its corresponding memory index.
303  DenseMap<Value, ByteCodeField> valueToMemIndex;
304 
305  /// Mapping from a range value to its corresponding range storage index.
306  DenseMap<Value, ByteCodeField> valueToRangeIndex;
307 
308  /// Mapping from the name of an externally registered rewrite to its index in
309  /// the bytecode registry.
310  llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
311 
312  /// Mapping from the name of an externally registered constraint to its index
313  /// in the bytecode registry.
314  llvm::StringMap<ByteCodeField> constraintToMemIndex;
315 
316  /// Mapping from rewriter function name to the bytecode address of the
317  /// rewriter function in byte.
318  llvm::StringMap<ByteCodeAddr> rewriterToAddr;
319 
320  /// Mapping from a uniqued storage object to its memory index within
321  /// `uniquedData`.
322  DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
323 
324  /// The current level of the foreach loop.
325  ByteCodeField curLoopLevel = 0;
326 
327  /// The current MLIR context.
328  MLIRContext *ctx;
329 
330  /// Mapping from block to its address.
332 
333  /// Data of the ByteCode class to be populated.
334  std::vector<const void *> &uniquedData;
335  SmallVectorImpl<ByteCodeField> &matcherByteCode;
336  SmallVectorImpl<ByteCodeField> &rewriterByteCode;
338  ByteCodeField &maxValueMemoryIndex;
339  ByteCodeField &maxOpRangeMemoryIndex;
340  ByteCodeField &maxTypeRangeMemoryIndex;
341  ByteCodeField &maxValueRangeMemoryIndex;
342  ByteCodeField &maxLoopLevel;
343 
344  /// A map of pattern configurations.
346 };
347 
348 /// This class provides utilities for writing a bytecode stream.
349 struct ByteCodeWriter {
350  ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
351  : bytecode(bytecode), generator(generator) {}
352 
353  /// Append a field to the bytecode.
354  void append(ByteCodeField field) { bytecode.push_back(field); }
355  void append(OpCode opCode) { bytecode.push_back(opCode); }
356 
357  /// Append an address to the bytecode.
358  void append(ByteCodeAddr field) {
359  static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
360  "unexpected ByteCode address size");
361 
362  ByteCodeField fieldParts[2];
363  std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
364  bytecode.append({fieldParts[0], fieldParts[1]});
365  }
366 
367  /// Append a single successor to the bytecode, the exact address will need to
368  /// be resolved later.
369  void append(Block *successor) {
370  // Add back a reference to the successor so that the address can be resolved
371  // later.
372  unresolvedSuccessorRefs[successor].push_back(bytecode.size());
373  append(ByteCodeAddr(0));
374  }
375 
376  /// Append a successor range to the bytecode, the exact address will need to
377  /// be resolved later.
378  void append(SuccessorRange successors) {
379  for (Block *successor : successors)
380  append(successor);
381  }
382 
383  /// Append a range of values that will be read as generic PDLValues.
384  void appendPDLValueList(OperandRange values) {
385  bytecode.push_back(values.size());
386  for (Value value : values)
387  appendPDLValue(value);
388  }
389 
390  /// Append a value as a PDLValue.
391  void appendPDLValue(Value value) {
392  appendPDLValueKind(value);
393  append(value);
394  }
395 
396  /// Append the PDLValue::Kind of the given value.
397  void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); }
398 
399  /// Append the PDLValue::Kind of the given type.
400  void appendPDLValueKind(Type type) {
403  .Case<pdl::AttributeType>(
404  [](Type) { return PDLValue::Kind::Attribute; })
405  .Case<pdl::OperationType>(
406  [](Type) { return PDLValue::Kind::Operation; })
407  .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
408  if (isa<pdl::TypeType>(rangeTy.getElementType()))
409  return PDLValue::Kind::TypeRange;
410  return PDLValue::Kind::ValueRange;
411  })
412  .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
413  .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
414  bytecode.push_back(static_cast<ByteCodeField>(kind));
415  }
416 
417  /// Append a value that will be stored in a memory slot and not inline within
418  /// the bytecode.
419  template <typename T>
420  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
421  std::is_pointer<T>::value>
422  append(T value) {
423  bytecode.push_back(generator.getMemIndex(value));
424  }
425 
426  /// Append a range of values.
427  template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
428  std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
429  append(T range) {
430  bytecode.push_back(llvm::size(range));
431  for (auto it : range)
432  append(it);
433  }
434 
435  /// Append a variadic number of fields to the bytecode.
436  template <typename FieldTy, typename Field2Ty, typename... FieldTys>
437  void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
438  append(field);
439  append(field2, fields...);
440  }
441 
442  /// Appends a value as a pointer, stored inline within the bytecode.
443  template <typename T>
444  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
445  appendInline(T value) {
446  constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField);
447  const void *pointer = value.getAsOpaquePointer();
448  ByteCodeField fieldParts[numParts];
449  std::memcpy(fieldParts, &pointer, sizeof(const void *));
450  bytecode.append(fieldParts, fieldParts + numParts);
451  }
452 
453  /// Successor references in the bytecode that have yet to be resolved.
454  DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
455 
456  /// The underlying bytecode buffer.
458 
459  /// The main generator producing PDL.
460  Generator &generator;
461 };
462 
463 /// This class represents a live range of PDL Interpreter values, containing
464 /// information about when values are live within a match/rewrite.
465 struct ByteCodeLiveRange {
466  using Set = llvm::IntervalMap<uint64_t, char, 16>;
467  using Allocator = Set::Allocator;
468 
469  ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
470 
471  /// Union this live range with the one provided.
472  void unionWith(const ByteCodeLiveRange &rhs) {
473  for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
474  ++it)
475  liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0);
476  }
477 
478  /// Returns true if this range overlaps with the one provided.
479  bool overlaps(const ByteCodeLiveRange &rhs) const {
480  return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
481  .valid();
482  }
483 
484  /// A map representing the ranges of the match/rewrite that a value is live in
485  /// the interpreter.
486  ///
487  /// We use std::unique_ptr here, because IntervalMap does not provide a
488  /// correct copy or move constructor. We can eliminate the pointer once
489  /// https://reviews.llvm.org/D113240 lands.
490  std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
491 
492  /// The operation range storage index for this range.
493  std::optional<unsigned> opRangeIndex;
494 
495  /// The type range storage index for this range.
496  std::optional<unsigned> typeRangeIndex;
497 
498  /// The value range storage index for this range.
499  std::optional<unsigned> valueRangeIndex;
500 };
501 } // namespace
502 
503 void Generator::generate(ModuleOp module) {
504  auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
505  pdl_interp::PDLInterpDialect::getMatcherFunctionName());
506  ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
507  pdl_interp::PDLInterpDialect::getRewriterModuleName());
508  assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
509 
510  // Allocate memory indices for the results of operations within the matcher
511  // and rewriters.
512  allocateMemoryIndices(matcherFunc, rewriterModule);
513 
514  // Generate code for the rewriter functions.
515  ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
516  for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
517  rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
518  for (Operation &op : rewriterFunc.getOps())
519  generate(&op, rewriterByteCodeWriter);
520  }
521  assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
522  "unexpected branches in rewriter function");
523 
524  // Generate code for the matcher function.
525  ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
526  generate(&matcherFunc.getBody(), matcherByteCodeWriter);
527 
528  // Resolve successor references in the matcher.
529  for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
530  ByteCodeAddr addr = blockToAddr[it.first];
531  for (unsigned offsetToFix : it.second)
532  std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
533  }
534 }
535 
536 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
537  ModuleOp rewriterModule) {
538  // Rewriters use simplistic allocation scheme that simply assigns an index to
539  // each result.
540  for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
541  ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
542  auto processRewriterValue = [&](Value val) {
543  valueToMemIndex.try_emplace(val, index++);
544  if (pdl::RangeType rangeType = dyn_cast<pdl::RangeType>(val.getType())) {
545  Type elementTy = rangeType.getElementType();
546  if (isa<pdl::TypeType>(elementTy))
547  valueToRangeIndex.try_emplace(val, typeRangeIndex++);
548  else if (isa<pdl::ValueType>(elementTy))
549  valueToRangeIndex.try_emplace(val, valueRangeIndex++);
550  }
551  };
552 
553  for (BlockArgument arg : rewriterFunc.getArguments())
554  processRewriterValue(arg);
555  rewriterFunc.getBody().walk([&](Operation *op) {
556  for (Value result : op->getResults())
557  processRewriterValue(result);
558  });
559  if (index > maxValueMemoryIndex)
560  maxValueMemoryIndex = index;
561  if (typeRangeIndex > maxTypeRangeMemoryIndex)
562  maxTypeRangeMemoryIndex = typeRangeIndex;
563  if (valueRangeIndex > maxValueRangeMemoryIndex)
564  maxValueRangeMemoryIndex = valueRangeIndex;
565  }
566 
567  // The matcher function uses a more sophisticated numbering that tries to
568  // minimize the number of memory indices assigned. This is done by determining
569  // a live range of the values within the matcher, then the allocation is just
570  // finding the minimal number of overlapping live ranges. This is essentially
571  // a simplified form of register allocation where we don't necessarily have a
572  // limited number of registers, but we still want to minimize the number used.
573  DenseMap<Operation *, unsigned> opToFirstIndex;
574  DenseMap<Operation *, unsigned> opToLastIndex;
575 
576  // A custom walk that marks the first and the last index of each operation.
577  // The entry marks the beginning of the liveness range for this operation,
578  // followed by nested operations, followed by the end of the liveness range.
579  unsigned index = 0;
580  llvm::unique_function<void(Operation *)> walk = [&](Operation *op) {
581  opToFirstIndex.try_emplace(op, index++);
582  for (Region &region : op->getRegions())
583  for (Block &block : region.getBlocks())
584  for (Operation &nested : block)
585  walk(&nested);
586  opToLastIndex.try_emplace(op, index++);
587  };
588  walk(matcherFunc);
589 
590  // Liveness info for each of the defs within the matcher.
591  ByteCodeLiveRange::Allocator allocator;
592  DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
593 
594  // Assign the root operation being matched to slot 0.
595  BlockArgument rootOpArg = matcherFunc.getArgument(0);
596  valueToMemIndex[rootOpArg] = 0;
597 
598  // Walk each of the blocks, computing the def interval that the value is used.
599  Liveness matcherLiveness(matcherFunc);
600  matcherFunc->walk([&](Block *block) {
601  const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
602  assert(info && "expected liveness info for block");
603  auto processValue = [&](Value value, Operation *firstUseOrDef) {
604  // We don't need to process the root op argument, this value is always
605  // assigned to the first memory slot.
606  if (value == rootOpArg)
607  return;
608 
609  // Set indices for the range of this block that the value is used.
610  auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
611  defRangeIt->second.liveness->insert(
612  opToFirstIndex[firstUseOrDef],
613  opToLastIndex[info->getEndOperation(value, firstUseOrDef)],
614  /*dummyValue*/ 0);
615 
616  // Check to see if this value is a range type.
617  if (auto rangeTy = dyn_cast<pdl::RangeType>(value.getType())) {
618  Type eleType = rangeTy.getElementType();
619  if (isa<pdl::OperationType>(eleType))
620  defRangeIt->second.opRangeIndex = 0;
621  else if (isa<pdl::TypeType>(eleType))
622  defRangeIt->second.typeRangeIndex = 0;
623  else if (isa<pdl::ValueType>(eleType))
624  defRangeIt->second.valueRangeIndex = 0;
625  }
626  };
627 
628  // Process the live-ins of this block.
629  for (Value liveIn : info->in()) {
630  // Only process the value if it has been defined in the current region.
631  // Other values that span across pdl_interp.foreach will be added higher
632  // up. This ensures that the we keep them alive for the entire duration
633  // of the loop.
634  if (liveIn.getParentRegion() == block->getParent())
635  processValue(liveIn, &block->front());
636  }
637 
638  // Process the block arguments for the entry block (those are not live-in).
639  if (block->isEntryBlock()) {
640  for (Value argument : block->getArguments())
641  processValue(argument, &block->front());
642  }
643 
644  // Process any new defs within this block.
645  for (Operation &op : *block)
646  for (Value result : op.getResults())
647  processValue(result, &op);
648  });
649 
650  // Greedily allocate memory slots using the computed def live ranges.
651  std::vector<ByteCodeLiveRange> allocatedIndices;
652 
653  // The number of memory indices currently allocated (and its next value).
654  // Recall that the root gets allocated memory index 0.
655  ByteCodeField numIndices = 1;
656 
657  // The number of memory ranges of various types (and their next values).
658  ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
659 
660  for (auto &defIt : valueDefRanges) {
661  ByteCodeField &memIndex = valueToMemIndex[defIt.first];
662  ByteCodeLiveRange &defRange = defIt.second;
663 
664  // Try to allocate to an existing index.
665  for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) {
666  ByteCodeLiveRange &existingRange = existingIndexIt.value();
667  if (!defRange.overlaps(existingRange)) {
668  existingRange.unionWith(defRange);
669  memIndex = existingIndexIt.index() + 1;
670 
671  if (defRange.opRangeIndex) {
672  if (!existingRange.opRangeIndex)
673  existingRange.opRangeIndex = numOpRanges++;
674  valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
675  } else if (defRange.typeRangeIndex) {
676  if (!existingRange.typeRangeIndex)
677  existingRange.typeRangeIndex = numTypeRanges++;
678  valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
679  } else if (defRange.valueRangeIndex) {
680  if (!existingRange.valueRangeIndex)
681  existingRange.valueRangeIndex = numValueRanges++;
682  valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
683  }
684  break;
685  }
686  }
687 
688  // If no existing index could be used, add a new one.
689  if (memIndex == 0) {
690  allocatedIndices.emplace_back(allocator);
691  ByteCodeLiveRange &newRange = allocatedIndices.back();
692  newRange.unionWith(defRange);
693 
694  // Allocate an index for op/type/value ranges.
695  if (defRange.opRangeIndex) {
696  newRange.opRangeIndex = numOpRanges;
697  valueToRangeIndex[defIt.first] = numOpRanges++;
698  } else if (defRange.typeRangeIndex) {
699  newRange.typeRangeIndex = numTypeRanges;
700  valueToRangeIndex[defIt.first] = numTypeRanges++;
701  } else if (defRange.valueRangeIndex) {
702  newRange.valueRangeIndex = numValueRanges;
703  valueToRangeIndex[defIt.first] = numValueRanges++;
704  }
705 
706  memIndex = allocatedIndices.size();
707  ++numIndices;
708  }
709  }
710 
711  // Print the index usage and ensure that we did not run out of index space.
712  LDBG() << "Allocated " << allocatedIndices.size() << " indices "
713  << "(down from initial " << valueDefRanges.size() << ").";
714  assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
715  "Ran out of memory for allocated indices");
716 
717  // Update the max number of indices.
718  if (numIndices > maxValueMemoryIndex)
719  maxValueMemoryIndex = numIndices;
720  if (numOpRanges > maxOpRangeMemoryIndex)
721  maxOpRangeMemoryIndex = numOpRanges;
722  if (numTypeRanges > maxTypeRangeMemoryIndex)
723  maxTypeRangeMemoryIndex = numTypeRanges;
724  if (numValueRanges > maxValueRangeMemoryIndex)
725  maxValueRangeMemoryIndex = numValueRanges;
726 }
727 
728 void Generator::generate(Region *region, ByteCodeWriter &writer) {
729  llvm::ReversePostOrderTraversal<Region *> rpot(region);
730  for (Block *block : rpot) {
731  // Keep track of where this block begins within the matcher function.
732  blockToAddr.try_emplace(block, matcherByteCode.size());
733  for (Operation &op : *block)
734  generate(&op, writer);
735  }
736 }
737 
738 void Generator::generate(Operation *op, ByteCodeWriter &writer) {
739  LDBG() << "Generating bytecode for operation: " << op->getName();
740  LLVM_DEBUG({
741  // The following list must contain all the operations that do not
742  // produce any bytecode.
743  if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
744  writer.appendInline(op->getLoc());
745  });
747  .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
748  pdl_interp::AreEqualOp, pdl_interp::BranchOp,
749  pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
750  pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
751  pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
752  pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
753  pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp,
754  pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
755  pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
756  pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
757  pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
758  pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
759  pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
760  pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
761  pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
762  pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
763  pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
764  pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
765  pdl_interp::SwitchResultCountOp>(
766  [&](auto interpOp) { this->generate(interpOp, writer); })
767  .DefaultUnreachable("unknown `pdl_interp` operation");
768 }
769 
770 void Generator::generate(pdl_interp::ApplyConstraintOp op,
771  ByteCodeWriter &writer) {
772  // Constraints that should return a value have to be registered as rewrites.
773  // If a constraint and a rewrite of similar name are registered the
774  // constraint takes precedence
775  writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
776  writer.appendPDLValueList(op.getArgs());
777  writer.append(ByteCodeField(op.getIsNegated()));
778  ResultRange results = op.getResults();
779  writer.append(ByteCodeField(results.size()));
780  for (Value result : results) {
781  // We record the expected kind of the result, so that we can provide extra
782  // verification of the native rewrite function and handle the failure case
783  // of constraints accordingly.
784  writer.appendPDLValueKind(result);
785 
786  // Range results also need to append the range storage index.
787  if (isa<pdl::RangeType>(result.getType()))
788  writer.append(getRangeStorageIndex(result));
789  writer.append(result);
790  }
791  writer.append(op.getSuccessors());
792 }
793 void Generator::generate(pdl_interp::ApplyRewriteOp op,
794  ByteCodeWriter &writer) {
795  assert(externalRewriterToMemIndex.count(op.getName()) &&
796  "expected index for rewrite function");
797  writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
798  writer.appendPDLValueList(op.getArgs());
799 
800  ResultRange results = op.getResults();
801  writer.append(ByteCodeField(results.size()));
802  for (Value result : results) {
803  // We record the expected kind of the result, so that we
804  // can provide extra verification of the native rewrite function.
805  writer.appendPDLValueKind(result);
806 
807  // Range results also need to append the range storage index.
808  if (isa<pdl::RangeType>(result.getType()))
809  writer.append(getRangeStorageIndex(result));
810  writer.append(result);
811  }
812 }
813 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
814  Value lhs = op.getLhs();
815  if (isa<pdl::RangeType>(lhs.getType())) {
816  writer.append(OpCode::AreRangesEqual);
817  writer.appendPDLValueKind(lhs);
818  writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
819  return;
820  }
821 
822  writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
823 }
824 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
825  writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
826 }
827 void Generator::generate(pdl_interp::CheckAttributeOp op,
828  ByteCodeWriter &writer) {
829  writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
830  op.getSuccessors());
831 }
832 void Generator::generate(pdl_interp::CheckOperandCountOp op,
833  ByteCodeWriter &writer) {
834  writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
835  static_cast<ByteCodeField>(op.getCompareAtLeast()),
836  op.getSuccessors());
837 }
838 void Generator::generate(pdl_interp::CheckOperationNameOp op,
839  ByteCodeWriter &writer) {
840  writer.append(OpCode::CheckOperationName, op.getInputOp(),
841  OperationName(op.getName(), ctx), op.getSuccessors());
842 }
843 void Generator::generate(pdl_interp::CheckResultCountOp op,
844  ByteCodeWriter &writer) {
845  writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
846  static_cast<ByteCodeField>(op.getCompareAtLeast()),
847  op.getSuccessors());
848 }
849 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
850  writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
851  op.getSuccessors());
852 }
853 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
854  writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
855  op.getSuccessors());
856 }
857 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
858  assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
859  writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
860 }
861 void Generator::generate(pdl_interp::CreateAttributeOp op,
862  ByteCodeWriter &writer) {
863  // Simply repoint the memory index of the result to the constant.
864  getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
865 }
866 void Generator::generate(pdl_interp::CreateOperationOp op,
867  ByteCodeWriter &writer) {
868  writer.append(OpCode::CreateOperation, op.getResultOp(),
869  OperationName(op.getName(), ctx));
870  writer.appendPDLValueList(op.getInputOperands());
871 
872  // Add the attributes.
873  OperandRange attributes = op.getInputAttributes();
874  writer.append(static_cast<ByteCodeField>(attributes.size()));
875  for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
876  writer.append(std::get<0>(it), std::get<1>(it));
877 
878  // Add the result types. If the operation has inferred results, we use a
879  // marker "size" value. Otherwise, we add the list of explicit result types.
880  if (op.getInferredResultTypes())
881  writer.append(kInferTypesMarker);
882  else
883  writer.appendPDLValueList(op.getInputResultTypes());
884 }
885 void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
886  // Append the correct opcode for the range type.
887  TypeSwitch<Type>(op.getType().getElementType())
888  .Case(
889  [&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); })
890  .Case([&](pdl::ValueType) {
891  writer.append(OpCode::CreateDynamicValueRange);
892  });
893 
894  writer.append(op.getResult(), getRangeStorageIndex(op.getResult()));
895  writer.appendPDLValueList(op->getOperands());
896 }
897 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
898  // Simply repoint the memory index of the result to the constant.
899  getMemIndex(op.getResult()) = getMemIndex(op.getValue());
900 }
901 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
902  writer.append(OpCode::CreateConstantTypeRange, op.getResult(),
903  getRangeStorageIndex(op.getResult()), op.getValue());
904 }
905 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
906  writer.append(OpCode::EraseOp, op.getInputOp());
907 }
908 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
909  OpCode opCode =
910  TypeSwitch<Type, OpCode>(op.getResult().getType())
911  .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
912  .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
913  .Case([](pdl::TypeType) { return OpCode::ExtractType; })
914  .DefaultUnreachable("unsupported element type");
915  writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
916 }
917 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
918  writer.append(OpCode::Finalize);
919 }
920 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
921  BlockArgument arg = op.getLoopVariable();
922  writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
923  writer.appendPDLValueKind(arg.getType());
924  writer.append(curLoopLevel, op.getSuccessor());
925  ++curLoopLevel;
926  if (curLoopLevel > maxLoopLevel)
927  maxLoopLevel = curLoopLevel;
928  generate(&op.getRegion(), writer);
929  --curLoopLevel;
930 }
931 void Generator::generate(pdl_interp::GetAttributeOp op,
932  ByteCodeWriter &writer) {
933  writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
934  op.getNameAttr());
935 }
936 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
937  ByteCodeWriter &writer) {
938  writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
939 }
940 void Generator::generate(pdl_interp::GetDefiningOpOp op,
941  ByteCodeWriter &writer) {
942  writer.append(OpCode::GetDefiningOp, op.getInputOp());
943  writer.appendPDLValue(op.getValue());
944 }
945 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
946  uint32_t index = op.getIndex();
947  if (index < 4)
948  writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
949  else
950  writer.append(OpCode::GetOperandN, index);
951  writer.append(op.getInputOp(), op.getValue());
952 }
953 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
954  Value result = op.getValue();
955  std::optional<uint32_t> index = op.getIndex();
956  writer.append(OpCode::GetOperands,
957  index.value_or(std::numeric_limits<uint32_t>::max()),
958  op.getInputOp());
959  if (isa<pdl::RangeType>(result.getType()))
960  writer.append(getRangeStorageIndex(result));
961  else
963  writer.append(result);
964 }
965 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
966  uint32_t index = op.getIndex();
967  if (index < 4)
968  writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
969  else
970  writer.append(OpCode::GetResultN, index);
971  writer.append(op.getInputOp(), op.getValue());
972 }
973 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
974  Value result = op.getValue();
975  std::optional<uint32_t> index = op.getIndex();
976  writer.append(OpCode::GetResults,
977  index.value_or(std::numeric_limits<uint32_t>::max()),
978  op.getInputOp());
979  if (isa<pdl::RangeType>(result.getType()))
980  writer.append(getRangeStorageIndex(result));
981  else
983  writer.append(result);
984 }
985 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
986  Value operations = op.getOperations();
987  ByteCodeField rangeIndex = getRangeStorageIndex(operations);
988  writer.append(OpCode::GetUsers, operations, rangeIndex);
989  writer.appendPDLValue(op.getValue());
990 }
991 void Generator::generate(pdl_interp::GetValueTypeOp op,
992  ByteCodeWriter &writer) {
993  if (isa<pdl::RangeType>(op.getType())) {
994  Value result = op.getResult();
995  writer.append(OpCode::GetValueRangeTypes, result,
996  getRangeStorageIndex(result), op.getValue());
997  } else {
998  writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
999  }
1000 }
1001 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
1002  writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
1003 }
1004 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
1005  ByteCodeField patternIndex = patterns.size();
1006  patterns.emplace_back(PDLByteCodePattern::create(
1007  op, configMap.lookup(op),
1008  rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
1009  writer.append(OpCode::RecordMatch, patternIndex,
1010  SuccessorRange(op.getOperation()), op.getMatchedOps());
1011  writer.appendPDLValueList(op.getInputs());
1012 }
1013 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
1014  writer.append(OpCode::ReplaceOp, op.getInputOp());
1015  writer.appendPDLValueList(op.getReplValues());
1016 }
1017 void Generator::generate(pdl_interp::SwitchAttributeOp op,
1018  ByteCodeWriter &writer) {
1019  writer.append(OpCode::SwitchAttribute, op.getAttribute(),
1020  op.getCaseValuesAttr(), op.getSuccessors());
1021 }
1022 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
1023  ByteCodeWriter &writer) {
1024  writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
1025  op.getCaseValuesAttr(), op.getSuccessors());
1026 }
1027 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
1028  ByteCodeWriter &writer) {
1029  auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
1030  return OperationName(cast<StringAttr>(attr).getValue(), ctx);
1031  });
1032  writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
1033  op.getSuccessors());
1034 }
1035 void Generator::generate(pdl_interp::SwitchResultCountOp op,
1036  ByteCodeWriter &writer) {
1037  writer.append(OpCode::SwitchResultCount, op.getInputOp(),
1038  op.getCaseValuesAttr(), op.getSuccessors());
1039 }
1040 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1041  writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1042  op.getSuccessors());
1043 }
1044 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1045  writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
1046  op.getSuccessors());
1047 }
1048 
1049 //===----------------------------------------------------------------------===//
1050 // PDLByteCode
1051 //===----------------------------------------------------------------------===//
1052 
1053 PDLByteCode::PDLByteCode(
1054  ModuleOp module, SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
1056  llvm::StringMap<PDLConstraintFunction> constraintFns,
1057  llvm::StringMap<PDLRewriteFunction> rewriteFns)
1058  : configs(std::move(configs)) {
1059  Generator generator(module.getContext(), uniquedData, matcherByteCode,
1060  rewriterByteCode, patterns, maxValueMemoryIndex,
1061  maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1062  maxLoopLevel, constraintFns, rewriteFns, configMap);
1063  generator.generate(module);
1064 
1065  // Initialize the external functions.
1066  for (auto &it : constraintFns)
1067  constraintFunctions.push_back(std::move(it.second));
1068  for (auto &it : rewriteFns)
1069  rewriteFunctions.push_back(std::move(it.second));
1070 }
1071 
1072 /// Initialize the given state such that it can be used to execute the current
1073 /// bytecode.
1074 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
1075  state.memory.resize(maxValueMemoryIndex, nullptr);
1076  state.opRangeMemory.resize(maxOpRangeCount);
1077  state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
1078  state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
1079  state.loopIndex.resize(maxLoopLevel, 0);
1080  state.currentPatternBenefits.reserve(patterns.size());
1081  for (const PDLByteCodePattern &pattern : patterns)
1082  state.currentPatternBenefits.push_back(pattern.getBenefit());
1083 }
1084 
1085 //===----------------------------------------------------------------------===//
1086 // ByteCode Execution
1087 //===----------------------------------------------------------------------===//
1088 
1089 namespace {
1090 /// This class is an instantiation of the PDLResultList that provides access to
1091 /// the returned results. This API is not on `PDLResultList` to avoid
1092 /// overexposing access to information specific solely to the ByteCode.
1093 class ByteCodeRewriteResultList : public PDLResultList {
1094 public:
1095  ByteCodeRewriteResultList(unsigned maxNumResults)
1096  : PDLResultList(maxNumResults) {}
1097 
1098  /// Return the list of PDL results.
1099  MutableArrayRef<PDLValue> getResults() { return results; }
1100 
1101  /// Return the type ranges allocated by this list.
1102  MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
1103  return allocatedTypeRanges;
1104  }
1105 
1106  /// Return the value ranges allocated by this list.
1107  MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
1108  return allocatedValueRanges;
1109  }
1110 };
1111 
1112 /// This class provides support for executing a bytecode stream.
1113 class ByteCodeExecutor {
1114 public:
1115  ByteCodeExecutor(
1116  const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
1117  MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory,
1118  MutableArrayRef<TypeRange> typeRangeMemory,
1119  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
1120  MutableArrayRef<ValueRange> valueRangeMemory,
1121  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
1122  MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory,
1124  ArrayRef<PatternBenefit> currentPatternBenefits,
1126  ArrayRef<PDLConstraintFunction> constraintFunctions,
1127  ArrayRef<PDLRewriteFunction> rewriteFunctions)
1128  : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
1129  typeRangeMemory(typeRangeMemory),
1130  allocatedTypeRangeMemory(allocatedTypeRangeMemory),
1131  valueRangeMemory(valueRangeMemory),
1132  allocatedValueRangeMemory(allocatedValueRangeMemory),
1133  loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
1134  currentPatternBenefits(currentPatternBenefits), patterns(patterns),
1135  constraintFunctions(constraintFunctions),
1136  rewriteFunctions(rewriteFunctions) {}
1137 
1138  /// Start executing the code at the current bytecode index. `matches` is an
1139  /// optional field provided when this function is executed in a matching
1140  /// context.
1141  LogicalResult
1142  execute(PatternRewriter &rewriter,
1143  SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
1144  std::optional<Location> mainRewriteLoc = {});
1145 
1146 private:
1147  /// Internal implementation of executing each of the bytecode commands.
1148  void executeApplyConstraint(PatternRewriter &rewriter);
1149  LogicalResult executeApplyRewrite(PatternRewriter &rewriter);
1150  void executeAreEqual();
1151  void executeAreRangesEqual();
1152  void executeBranch();
1153  void executeCheckOperandCount();
1154  void executeCheckOperationName();
1155  void executeCheckResultCount();
1156  void executeCheckTypes();
1157  void executeContinue();
1158  void executeCreateConstantTypeRange();
1159  void executeCreateOperation(PatternRewriter &rewriter,
1160  Location mainRewriteLoc);
1161  template <typename T>
1162  void executeDynamicCreateRange(StringRef type);
1163  void executeEraseOp(PatternRewriter &rewriter);
1164  template <typename T, typename Range, PDLValue::Kind kind>
1165  void executeExtract();
1166  void executeFinalize();
1167  void executeForEach();
1168  void executeGetAttribute();
1169  void executeGetAttributeType();
1170  void executeGetDefiningOp();
1171  void executeGetOperand(unsigned index);
1172  void executeGetOperands();
1173  void executeGetResult(unsigned index);
1174  void executeGetResults();
1175  void executeGetUsers();
1176  void executeGetValueType();
1177  void executeGetValueRangeTypes();
1178  void executeIsNotNull();
1179  void executeRecordMatch(PatternRewriter &rewriter,
1181  void executeReplaceOp(PatternRewriter &rewriter);
1182  void executeSwitchAttribute();
1183  void executeSwitchOperandCount();
1184  void executeSwitchOperationName();
1185  void executeSwitchResultCount();
1186  void executeSwitchType();
1187  void executeSwitchTypes();
1188  void processNativeFunResults(ByteCodeRewriteResultList &results,
1189  unsigned numResults,
1190  LogicalResult &rewriteResult);
1191 
1192  /// Pushes a code iterator to the stack.
1193  void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }
1194 
1195  /// Pops a code iterator from the stack, returning true on success.
1196  void popCodeIt() {
1197  assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
1198  curCodeIt = resumeCodeIt.pop_back_val();
1199  }
1200 
1201  /// Return the bytecode iterator at the start of the current op code.
1202  const ByteCodeField *getPrevCodeIt() const {
1203  LLVM_DEBUG({
1204  // Account for the op code and the Location stored inline.
1205  return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField);
1206  });
1207 
1208  // Account for the op code only.
1209  return curCodeIt - 1;
1210  }
1211 
1212  /// Read a value from the bytecode buffer, optionally skipping a certain
1213  /// number of prefix values. These methods always update the buffer to point
1214  /// to the next field after the read data.
1215  template <typename T = ByteCodeField>
1216  T read(size_t skipN = 0) {
1217  curCodeIt += skipN;
1218  return readImpl<T>();
1219  }
1220  ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
1221 
1222  /// Read a list of values from the bytecode buffer.
1223  template <typename ValueT, typename T>
1224  void readList(SmallVectorImpl<T> &list) {
1225  list.clear();
1226  for (unsigned i = 0, e = read(); i != e; ++i)
1227  list.push_back(read<ValueT>());
1228  }
1229 
1230  /// Read a list of values from the bytecode buffer. The values may be encoded
1231  /// either as a single element or a range of elements.
1232  void readList(SmallVectorImpl<Type> &list) {
1233  for (unsigned i = 0, e = read(); i != e; ++i) {
1234  if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1235  list.push_back(read<Type>());
1236  } else {
1237  TypeRange *values = read<TypeRange *>();
1238  list.append(values->begin(), values->end());
1239  }
1240  }
1241  }
1242  void readList(SmallVectorImpl<Value> &list) {
1243  for (unsigned i = 0, e = read(); i != e; ++i) {
1244  if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1245  list.push_back(read<Value>());
1246  } else {
1247  ValueRange *values = read<ValueRange *>();
1248  list.append(values->begin(), values->end());
1249  }
1250  }
1251  }
1252 
1253  /// Read a value stored inline as a pointer.
1254  template <typename T>
1255  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
1256  readInline() {
1257  const void *pointer;
1258  std::memcpy(&pointer, curCodeIt, sizeof(const void *));
1259  curCodeIt += sizeof(const void *) / sizeof(ByteCodeField);
1260  return T::getFromOpaquePointer(pointer);
1261  }
1262 
1263  void skip(size_t skipN) { curCodeIt += skipN; }
1264 
1265  /// Jump to a specific successor based on a predicate value.
1266  void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
1267  /// Jump to a specific successor based on a destination index.
1268  void selectJump(size_t destIndex) {
1269  curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
1270  }
1271 
1272  /// Handle a switch operation with the provided value and cases.
1273  template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
1274  void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
1275  LDBG() << "Switch operation:\n * Value: " << value
1276  << "\n * Cases: " << llvm::interleaved(cases);
1277 
1278  // Check to see if the attribute value is within the case list. Jump to
1279  // the correct successor index based on the result.
1280  for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
1281  if (cmp(*it, value))
1282  return selectJump(size_t((it - cases.begin()) + 1));
1283  selectJump(size_t(0));
1284  }
1285 
1286  /// Store a pointer to memory.
1287  void storeToMemory(unsigned index, const void *value) {
1288  memory[index] = value;
1289  }
1290 
1291  /// Store a value to memory as an opaque pointer.
1292  template <typename T>
1293  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
1294  storeToMemory(unsigned index, T value) {
1295  memory[index] = value.getAsOpaquePointer();
1296  }
1297 
1298  /// Internal implementation of reading various data types from the bytecode
1299  /// stream.
1300  template <typename T>
1301  const void *readFromMemory() {
1302  size_t index = *curCodeIt++;
1303 
1304  // If this type is an SSA value, it can only be stored in non-const memory.
1305  if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
1306  Value>::value ||
1307  index < memory.size())
1308  return memory[index];
1309 
1310  // Otherwise, if this index is not inbounds it is uniqued.
1311  return uniquedMemory[index - memory.size()];
1312  }
1313  template <typename T>
1314  std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
1315  return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
1316  }
1317  template <typename T>
1318  std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
1319  T>
1320  readImpl() {
1321  return T(T::getFromOpaquePointer(readFromMemory<T>()));
1322  }
1323  template <typename T>
1324  std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
1325  switch (read<PDLValue::Kind>()) {
1326  case PDLValue::Kind::Attribute:
1327  return read<Attribute>();
1328  case PDLValue::Kind::Operation:
1329  return read<Operation *>();
1330  case PDLValue::Kind::Type:
1331  return read<Type>();
1332  case PDLValue::Kind::Value:
1333  return read<Value>();
1334  case PDLValue::Kind::TypeRange:
1335  return read<TypeRange *>();
1336  case PDLValue::Kind::ValueRange:
1337  return read<ValueRange *>();
1338  }
1339  llvm_unreachable("unhandled PDLValue::Kind");
1340  }
1341  template <typename T>
1342  std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
1343  static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
1344  "unexpected ByteCode address size");
1345  ByteCodeAddr result;
1346  std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
1347  curCodeIt += 2;
1348  return result;
1349  }
1350  template <typename T>
1351  std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
1352  return *curCodeIt++;
1353  }
1354  template <typename T>
1355  std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
1356  return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
1357  }
1358 
1359  /// Assign the given range to the given memory index. This allocates a new
1360  /// range object if necessary.
1361  template <typename RangeT, typename T = llvm::detail::ValueOfRange<RangeT>>
1362  void assignRangeToMemory(RangeT &&range, unsigned memIndex,
1363  unsigned rangeIndex) {
1364  // Utility functor used to type-erase the assignment.
1365  auto assignRange = [&](auto &allocatedRangeMemory, auto &rangeMemory) {
1366  // If the input range is empty, we don't need to allocate anything.
1367  if (range.empty()) {
1368  rangeMemory[rangeIndex] = {};
1369  } else {
1370  // Allocate a buffer for this type range.
1371  llvm::OwningArrayRef<T> storage(llvm::size(range));
1372  llvm::copy(range, storage.begin());
1373 
1374  // Assign this to the range slot and use the range as the value for the
1375  // memory index.
1376  allocatedRangeMemory.emplace_back(std::move(storage));
1377  rangeMemory[rangeIndex] = allocatedRangeMemory.back();
1378  }
1379  memory[memIndex] = &rangeMemory[rangeIndex];
1380  };
1381 
1382  // Dispatch based on the concrete range type.
1383  if constexpr (std::is_same_v<T, Type>) {
1384  return assignRange(allocatedTypeRangeMemory, typeRangeMemory);
1385  } else if constexpr (std::is_same_v<T, Value>) {
1386  return assignRange(allocatedValueRangeMemory, valueRangeMemory);
1387  } else {
1388  llvm_unreachable("unhandled range type");
1389  }
1390  }
1391 
1392  /// The underlying bytecode buffer.
1393  const ByteCodeField *curCodeIt;
1394 
1395  /// The stack of bytecode positions at which to resume operation.
1397 
1398  /// The current execution memory.
1400  MutableArrayRef<OwningOpRange> opRangeMemory;
1401  MutableArrayRef<TypeRange> typeRangeMemory;
1402  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
1403  MutableArrayRef<ValueRange> valueRangeMemory;
1404  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
1405 
1406  /// The current loop indices.
1407  MutableArrayRef<unsigned> loopIndex;
1408 
1409  /// References to ByteCode data necessary for execution.
1410  ArrayRef<const void *> uniquedMemory;
1412  ArrayRef<PatternBenefit> currentPatternBenefits;
1414  ArrayRef<PDLConstraintFunction> constraintFunctions;
1415  ArrayRef<PDLRewriteFunction> rewriteFunctions;
1416 };
1417 } // namespace
1418 
1419 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1420  LDBG() << "Executing ApplyConstraint:";
1421  ByteCodeField fun_idx = read();
1423  readList<PDLValue>(args);
1424 
1425  LDBG() << " * Arguments: " << llvm::interleaved(args);
1426 
1427  ByteCodeField isNegated = read();
1428  LDBG() << " * isNegated: " << isNegated;
1429 
1430  ByteCodeField numResults = read();
1431  const PDLRewriteFunction &constraintFn = constraintFunctions[fun_idx];
1432  ByteCodeRewriteResultList results(numResults);
1433  LogicalResult rewriteResult = constraintFn(rewriter, results, args);
1434  [[maybe_unused]] ArrayRef<PDLValue> constraintResults = results.getResults();
1435  if (succeeded(rewriteResult)) {
1436  LDBG() << " * Constraint succeeded, results: "
1437  << llvm::interleaved(constraintResults);
1438  } else {
1439  LDBG() << " * Constraint failed";
1440  }
1441  assert((failed(rewriteResult) || constraintResults.size() == numResults) &&
1442  "native PDL rewrite function succeeded but returned "
1443  "unexpected number of results");
1444  processNativeFunResults(results, numResults, rewriteResult);
1445 
1446  // Depending on the constraint jump to the proper destination.
1447  selectJump(isNegated != succeeded(rewriteResult));
1448 }
1449 
1450 LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1451  LDBG() << "Executing ApplyRewrite:";
1452  const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1454  readList<PDLValue>(args);
1455 
1456  LDBG() << " * Arguments: " << llvm::interleaved(args);
1457 
1458  // Execute the rewrite function.
1459  ByteCodeField numResults = read();
1460  ByteCodeRewriteResultList results(numResults);
1461  LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
1462 
1463  assert(results.getResults().size() == numResults &&
1464  "native PDL rewrite function returned unexpected number of results");
1465 
1466  processNativeFunResults(results, numResults, rewriteResult);
1467 
1468  if (failed(rewriteResult)) {
1469  LDBG() << " - Failed";
1470  return failure();
1471  }
1472  return success();
1473 }
1474 
1475 void ByteCodeExecutor::processNativeFunResults(
1476  ByteCodeRewriteResultList &results, unsigned numResults,
1477  LogicalResult &rewriteResult) {
1478  if (failed(rewriteResult)) {
1479  // Skip the according number of values on the buffer on failure and exit
1480  // early as there are no results to process.
1481  for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1482  const PDLValue::Kind resultKind = read<PDLValue::Kind>();
1483  if (resultKind == PDLValue::Kind::TypeRange ||
1484  resultKind == PDLValue::Kind::ValueRange) {
1485  skip(2);
1486  } else {
1487  skip(1);
1488  }
1489  }
1490  return;
1491  }
1492 
1493  // Store the results in the bytecode memory
1494  for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1495  PDLValue::Kind resultKind = read<PDLValue::Kind>();
1496  (void)resultKind;
1497  PDLValue result = results.getResults()[resultIdx];
1498  LDBG() << " * Result: " << result;
1499  assert(result.getKind() == resultKind &&
1500  "native PDL rewrite function returned an unexpected type of "
1501  "result");
1502  // If the result is a range, we need to copy it over to the bytecodes
1503  // range memory.
1504  if (std::optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
1505  unsigned rangeIndex = read();
1506  typeRangeMemory[rangeIndex] = *typeRange;
1507  memory[read()] = &typeRangeMemory[rangeIndex];
1508  } else if (std::optional<ValueRange> valueRange =
1509  result.dyn_cast<ValueRange>()) {
1510  unsigned rangeIndex = read();
1511  valueRangeMemory[rangeIndex] = *valueRange;
1512  memory[read()] = &valueRangeMemory[rangeIndex];
1513  } else {
1514  memory[read()] = result.getAsOpaquePointer();
1515  }
1516  }
1517 
1518  // Copy over any underlying storage allocated for result ranges.
1519  for (auto &it : results.getAllocatedTypeRanges())
1520  allocatedTypeRangeMemory.push_back(std::move(it));
1521  for (auto &it : results.getAllocatedValueRanges())
1522  allocatedValueRangeMemory.push_back(std::move(it));
1523 }
1524 
1525 void ByteCodeExecutor::executeAreEqual() {
1526  LDBG() << "Executing AreEqual:";
1527  const void *lhs = read<const void *>();
1528  const void *rhs = read<const void *>();
1529 
1530  LDBG() << " * " << lhs << " == " << rhs;
1531  selectJump(lhs == rhs);
1532 }
1533 
1534 void ByteCodeExecutor::executeAreRangesEqual() {
1535  LDBG() << "Executing AreRangesEqual:";
1536  PDLValue::Kind valueKind = read<PDLValue::Kind>();
1537  const void *lhs = read<const void *>();
1538  const void *rhs = read<const void *>();
1539 
1540  switch (valueKind) {
1541  case PDLValue::Kind::TypeRange: {
1542  const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1543  const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1544  LDBG() << " * " << lhs << " == " << rhs;
1545  selectJump(*lhsRange == *rhsRange);
1546  break;
1547  }
1548  case PDLValue::Kind::ValueRange: {
1549  const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1550  const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1551  LDBG() << " * " << lhs << " == " << rhs;
1552  selectJump(*lhsRange == *rhsRange);
1553  break;
1554  }
1555  default:
1556  llvm_unreachable("unexpected `AreRangesEqual` value kind");
1557  }
1558 }
1559 
1560 void ByteCodeExecutor::executeBranch() {
1561  LDBG() << "Executing Branch";
1562  curCodeIt = &code[read<ByteCodeAddr>()];
1563 }
1564 
1565 void ByteCodeExecutor::executeCheckOperandCount() {
1566  LDBG() << "Executing CheckOperandCount:";
1567  Operation *op = read<Operation *>();
1568  uint32_t expectedCount = read<uint32_t>();
1569  bool compareAtLeast = read();
1570 
1571  LDBG() << " * Found: " << op->getNumOperands()
1572  << "\n * Expected: " << expectedCount
1573  << "\n * Comparator: " << (compareAtLeast ? ">=" : "==");
1574  if (compareAtLeast)
1575  selectJump(op->getNumOperands() >= expectedCount);
1576  else
1577  selectJump(op->getNumOperands() == expectedCount);
1578 }
1579 
1580 void ByteCodeExecutor::executeCheckOperationName() {
1581  LDBG() << "Executing CheckOperationName:";
1582  Operation *op = read<Operation *>();
1583  OperationName expectedName = read<OperationName>();
1584 
1585  LDBG() << " * Found: \"" << op->getName() << "\"\n * Expected: \""
1586  << expectedName << "\"";
1587  selectJump(op->getName() == expectedName);
1588 }
1589 
1590 void ByteCodeExecutor::executeCheckResultCount() {
1591  LDBG() << "Executing CheckResultCount:";
1592  Operation *op = read<Operation *>();
1593  uint32_t expectedCount = read<uint32_t>();
1594  bool compareAtLeast = read();
1595 
1596  LDBG() << " * Found: " << op->getNumResults()
1597  << "\n * Expected: " << expectedCount
1598  << "\n * Comparator: " << (compareAtLeast ? ">=" : "==");
1599  if (compareAtLeast)
1600  selectJump(op->getNumResults() >= expectedCount);
1601  else
1602  selectJump(op->getNumResults() == expectedCount);
1603 }
1604 
1605 void ByteCodeExecutor::executeCheckTypes() {
1606  LDBG() << "Executing AreEqual:";
1607  TypeRange *lhs = read<TypeRange *>();
1608  Attribute rhs = read<Attribute>();
1609  LDBG() << " * " << lhs << " == " << rhs;
1610 
1611  selectJump(*lhs == cast<ArrayAttr>(rhs).getAsValueRange<TypeAttr>());
1612 }
1613 
1614 void ByteCodeExecutor::executeContinue() {
1615  ByteCodeField level = read();
1616  LDBG() << "Executing Continue\n * Level: " << level;
1617  ++loopIndex[level];
1618  popCodeIt();
1619 }
1620 
1621 void ByteCodeExecutor::executeCreateConstantTypeRange() {
1622  LDBG() << "Executing CreateConstantTypeRange:";
1623  unsigned memIndex = read();
1624  unsigned rangeIndex = read();
1625  ArrayAttr typesAttr = cast<ArrayAttr>(read<Attribute>());
1626 
1627  LDBG() << " * Types: " << typesAttr;
1628  assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
1629  rangeIndex);
1630 }
1631 
1632 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1633  Location mainRewriteLoc) {
1634  LDBG() << "Executing CreateOperation:";
1635 
1636  unsigned memIndex = read();
1637  OperationState state(mainRewriteLoc, read<OperationName>());
1638  readList(state.operands);
1639  for (unsigned i = 0, e = read(); i != e; ++i) {
1640  StringAttr name = read<StringAttr>();
1641  if (Attribute attr = read<Attribute>())
1642  state.addAttribute(name, attr);
1643  }
1644 
1645  // Read in the result types. If the "size" is the sentinel value, this
1646  // indicates that the result types should be inferred.
1647  unsigned numResults = read();
1648  if (numResults == kInferTypesMarker) {
1649  InferTypeOpInterface::Concept *inferInterface =
1650  state.name.getInterface<InferTypeOpInterface>();
1651  assert(inferInterface &&
1652  "expected operation to provide InferTypeOpInterface");
1653 
1654  // TODO: Handle failure.
1655  if (failed(inferInterface->inferReturnTypes(
1656  state.getContext(), state.location, state.operands,
1657  state.attributes.getDictionary(state.getContext()),
1658  state.getRawProperties(), state.regions, state.types)))
1659  return;
1660  } else {
1661  // Otherwise, this is a fixed number of results.
1662  for (unsigned i = 0; i != numResults; ++i) {
1663  if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1664  state.types.push_back(read<Type>());
1665  } else {
1666  TypeRange *resultTypes = read<TypeRange *>();
1667  state.types.append(resultTypes->begin(), resultTypes->end());
1668  }
1669  }
1670  }
1671 
1672  Operation *resultOp = rewriter.create(state);
1673  memory[memIndex] = resultOp;
1674 
1675  LDBG() << " * Attributes: "
1676  << state.attributes.getDictionary(state.getContext())
1677  << "\n * Operands: " << llvm::interleaved(state.operands)
1678  << "\n * Result Types: " << llvm::interleaved(state.types)
1679  << "\n * Result: " << *resultOp;
1680 }
1681 
1682 template <typename T>
1683 void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
1684  LDBG() << "Executing CreateDynamic" << type << "Range:";
1685  unsigned memIndex = read();
1686  unsigned rangeIndex = read();
1687  SmallVector<T> values;
1688  readList(values);
1689 
1690  LDBG() << " * " << type << "s: " << llvm::interleaved(values);
1691 
1692  assignRangeToMemory(values, memIndex, rangeIndex);
1693 }
1694 
1695 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1696  LDBG() << "Executing EraseOp:";
1697  Operation *op = read<Operation *>();
1698 
1699  LDBG() << " * Operation: " << *op;
1700  rewriter.eraseOp(op);
1701 }
1702 
1703 template <typename T, typename Range, PDLValue::Kind kind>
1704 void ByteCodeExecutor::executeExtract() {
1705  LDBG() << "Executing Extract" << kind << ":";
1706  Range *range = read<Range *>();
1707  unsigned index = read<uint32_t>();
1708  unsigned memIndex = read();
1709 
1710  if (!range) {
1711  memory[memIndex] = nullptr;
1712  return;
1713  }
1714 
1715  T result = index < range->size() ? (*range)[index] : T();
1716  LDBG() << " * " << kind << "s(" << range->size() << ")";
1717  LDBG() << " * Index: " << index;
1718  LDBG() << " * Result: " << result;
1719  storeToMemory(memIndex, result);
1720 }
1721 
1722 void ByteCodeExecutor::executeFinalize() { LDBG() << "Executing Finalize"; }
1723 
1724 void ByteCodeExecutor::executeForEach() {
1725  LDBG() << "Executing ForEach:";
1726  const ByteCodeField *prevCodeIt = getPrevCodeIt();
1727  unsigned rangeIndex = read();
1728  unsigned memIndex = read();
1729  const void *value = nullptr;
1730 
1731  switch (read<PDLValue::Kind>()) {
1732  case PDLValue::Kind::Operation: {
1733  unsigned &index = loopIndex[read()];
1734  ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
1735  assert(index <= array.size() && "iterated past the end");
1736  if (index < array.size()) {
1737  LDBG() << " * Result: " << array[index];
1738  value = array[index];
1739  break;
1740  }
1741 
1742  LDBG() << " * Done";
1743  index = 0;
1744  selectJump(size_t(0));
1745  return;
1746  }
1747  default:
1748  llvm_unreachable("unexpected `ForEach` value kind");
1749  }
1750 
1751  // Store the iterate value and the stack address.
1752  memory[memIndex] = value;
1753  pushCodeIt(prevCodeIt);
1754 
1755  // Skip over the successor (we will enter the body of the loop).
1756  read<ByteCodeAddr>();
1757 }
1758 
1759 void ByteCodeExecutor::executeGetAttribute() {
1760  LDBG() << "Executing GetAttribute:";
1761  unsigned memIndex = read();
1762  Operation *op = read<Operation *>();
1763  StringAttr attrName = read<StringAttr>();
1764  Attribute attr = op->getAttr(attrName);
1765 
1766  LDBG() << " * Operation: " << *op << "\n * Attribute: " << attrName
1767  << "\n * Result: " << attr;
1768  memory[memIndex] = attr.getAsOpaquePointer();
1769 }
1770 
1771 void ByteCodeExecutor::executeGetAttributeType() {
1772  LDBG() << "Executing GetAttributeType:";
1773  unsigned memIndex = read();
1774  Attribute attr = read<Attribute>();
1775  Type type;
1776  if (auto typedAttr = dyn_cast<TypedAttr>(attr))
1777  type = typedAttr.getType();
1778 
1779  LDBG() << " * Attribute: " << attr << "\n * Result: " << type;
1780  memory[memIndex] = type.getAsOpaquePointer();
1781 }
1782 
1783 void ByteCodeExecutor::executeGetDefiningOp() {
1784  LDBG() << "Executing GetDefiningOp:";
1785  unsigned memIndex = read();
1786  Operation *op = nullptr;
1787  if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1788  Value value = read<Value>();
1789  if (value)
1790  op = value.getDefiningOp();
1791  LDBG() << " * Value: " << value;
1792  } else {
1793  ValueRange *values = read<ValueRange *>();
1794  if (values && !values->empty()) {
1795  op = values->front().getDefiningOp();
1796  }
1797  LDBG() << " * Values: " << values;
1798  }
1799 
1800  LDBG() << " * Result: " << op;
1801  memory[memIndex] = op;
1802 }
1803 
1804 void ByteCodeExecutor::executeGetOperand(unsigned index) {
1805  Operation *op = read<Operation *>();
1806  unsigned memIndex = read();
1807  Value operand =
1808  index < op->getNumOperands() ? op->getOperand(index) : Value();
1809 
1810  LDBG() << " * Operation: " << *op << "\n * Index: " << index
1811  << "\n * Result: " << operand;
1812  memory[memIndex] = operand.getAsOpaquePointer();
1813 }
1814 
1815 /// This function is the internal implementation of `GetResults` and
1816 /// `GetOperands` that provides support for extracting a value range from the
1817 /// given operation.
1818 template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1819 static void *
1820 executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1821  ByteCodeField rangeIndex, StringRef attrSizedSegments,
1822  MutableArrayRef<ValueRange> valueRangeMemory) {
1823  // Check for the sentinel index that signals that all values should be
1824  // returned.
1825  if (index == std::numeric_limits<uint32_t>::max()) {
1826  LDBG() << " * Getting all values";
1827  // `values` is already the full value range.
1828 
1829  // Otherwise, check to see if this operation uses AttrSizedSegments.
1830  } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1831  LDBG() << " * Extracting values from `" << attrSizedSegments << "`";
1832 
1833  auto segmentAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrSizedSegments);
1834  if (!segmentAttr || segmentAttr.asArrayRef().size() <= index)
1835  return nullptr;
1836 
1837  ArrayRef<int32_t> segments = segmentAttr;
1838  unsigned startIndex = llvm::sum_of(segments.take_front(index));
1839  values = values.slice(startIndex, *std::next(segments.begin(), index));
1840 
1841  LDBG() << " * Extracting range[" << startIndex << ", "
1842  << *std::next(segments.begin(), index) << "]";
1843 
1844  // Otherwise, assume this is the last operand group of the operation.
1845  // FIXME: We currently don't support operations with
1846  // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1847  // have a way to detect it's presence.
1848  } else if (values.size() >= index) {
1849  LDBG() << " * Treating values as trailing variadic range";
1850  values = values.drop_front(index);
1851 
1852  // If we couldn't detect a way to compute the values, bail out.
1853  } else {
1854  return nullptr;
1855  }
1856 
1857  // If the range index is valid, we are returning a range.
1858  if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1859  valueRangeMemory[rangeIndex] = values;
1860  return &valueRangeMemory[rangeIndex];
1861  }
1862 
1863  // If a range index wasn't provided, the range is required to be non-variadic.
1864  return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1865 }
1866 
1867 void ByteCodeExecutor::executeGetOperands() {
1868  LDBG() << "Executing GetOperands:";
1869  unsigned index = read<uint32_t>();
1870  Operation *op = read<Operation *>();
1871  ByteCodeField rangeIndex = read();
1872 
1873  void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1874  op->getOperands(), op, index, rangeIndex, "operandSegmentSizes",
1875  valueRangeMemory);
1876  if (!result)
1877  LDBG() << " * Invalid operand range";
1878  memory[read()] = result;
1879 }
1880 
1881 void ByteCodeExecutor::executeGetResult(unsigned index) {
1882  Operation *op = read<Operation *>();
1883  unsigned memIndex = read();
1884  OpResult result =
1885  index < op->getNumResults() ? op->getResult(index) : OpResult();
1886 
1887  LDBG() << " * Operation: " << *op << "\n * Index: " << index
1888  << "\n * Result: " << result;
1889  memory[memIndex] = result.getAsOpaquePointer();
1890 }
1891 
1892 void ByteCodeExecutor::executeGetResults() {
1893  LDBG() << "Executing GetResults:";
1894  unsigned index = read<uint32_t>();
1895  Operation *op = read<Operation *>();
1896  ByteCodeField rangeIndex = read();
1897 
1898  void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1899  op->getResults(), op, index, rangeIndex, "resultSegmentSizes",
1900  valueRangeMemory);
1901  if (!result)
1902  LDBG() << " * Invalid result range";
1903  memory[read()] = result;
1904 }
1905 
1906 void ByteCodeExecutor::executeGetUsers() {
1907  LDBG() << "Executing GetUsers:";
1908  unsigned memIndex = read();
1909  unsigned rangeIndex = read();
1910  OwningOpRange &range = opRangeMemory[rangeIndex];
1911  memory[memIndex] = &range;
1912 
1913  range = OwningOpRange();
1914  if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1915  // Read the value.
1916  Value value = read<Value>();
1917  if (!value)
1918  return;
1919  LDBG() << " * Value: " << value;
1920 
1921  // Extract the users of a single value.
1922  range = OwningOpRange(std::distance(value.user_begin(), value.user_end()));
1923  llvm::copy(value.getUsers(), range.begin());
1924  } else {
1925  // Read a range of values.
1926  ValueRange *values = read<ValueRange *>();
1927  if (!values)
1928  return;
1929  LDBG() << " * Values (" << values->size()
1930  << "): " << llvm::interleaved(*values);
1931 
1932  // Extract all the users of a range of values.
1934  for (Value value : *values)
1935  users.append(value.user_begin(), value.user_end());
1936  range = OwningOpRange(users.size());
1937  llvm::copy(users, range.begin());
1938  }
1939 
1940  LDBG() << " * Result: " << range.size() << " operations";
1941 }
1942 
1943 void ByteCodeExecutor::executeGetValueType() {
1944  LDBG() << "Executing GetValueType:";
1945  unsigned memIndex = read();
1946  Value value = read<Value>();
1947  Type type = value ? value.getType() : Type();
1948 
1949  LDBG() << " * Value: " << value << "\n * Result: " << type;
1950  memory[memIndex] = type.getAsOpaquePointer();
1951 }
1952 
1953 void ByteCodeExecutor::executeGetValueRangeTypes() {
1954  LDBG() << "Executing GetValueRangeTypes:";
1955  unsigned memIndex = read();
1956  unsigned rangeIndex = read();
1957  ValueRange *values = read<ValueRange *>();
1958  if (!values) {
1959  LDBG() << " * Values: <NULL>";
1960  memory[memIndex] = nullptr;
1961  return;
1962  }
1963 
1964  LDBG() << " * Values (" << values->size()
1965  << "): " << llvm::interleaved(*values)
1966  << "\n * Result: " << llvm::interleaved(values->getType());
1967  typeRangeMemory[rangeIndex] = values->getType();
1968  memory[memIndex] = &typeRangeMemory[rangeIndex];
1969 }
1970 
1971 void ByteCodeExecutor::executeIsNotNull() {
1972  LDBG() << "Executing IsNotNull:";
1973  const void *value = read<const void *>();
1974 
1975  LDBG() << " * Value: " << value;
1976  selectJump(value != nullptr);
1977 }
1978 
1979 void ByteCodeExecutor::executeRecordMatch(
1980  PatternRewriter &rewriter,
1982  LDBG() << "Executing RecordMatch:";
1983  unsigned patternIndex = read();
1984  PatternBenefit benefit = currentPatternBenefits[patternIndex];
1985  const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1986 
1987  // If the benefit of the pattern is impossible, skip the processing of the
1988  // rest of the pattern.
1989  if (benefit.isImpossibleToMatch()) {
1990  LDBG() << " * Benefit: Impossible To Match";
1991  curCodeIt = dest;
1992  return;
1993  }
1994 
1995  // Create a fused location containing the locations of each of the
1996  // operations used in the match. This will be used as the location for
1997  // created operations during the rewrite that don't already have an
1998  // explicit location set.
1999  unsigned numMatchLocs = read();
2000  SmallVector<Location, 4> matchLocs;
2001  matchLocs.reserve(numMatchLocs);
2002  for (unsigned i = 0; i != numMatchLocs; ++i)
2003  matchLocs.push_back(read<Operation *>()->getLoc());
2004  Location matchLoc = rewriter.getFusedLoc(matchLocs);
2005 
2006  LDBG() << " * Benefit: " << benefit.getBenefit();
2007  LDBG() << " * Location: " << matchLoc;
2008  matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
2009  PDLByteCode::MatchResult &match = matches.back();
2010 
2011  // Record all of the inputs to the match. If any of the inputs are ranges, we
2012  // will also need to remap the range pointer to memory stored in the match
2013  // state.
2014  unsigned numInputs = read();
2015  match.values.reserve(numInputs);
2016  match.typeRangeValues.reserve(numInputs);
2017  match.valueRangeValues.reserve(numInputs);
2018  for (unsigned i = 0; i < numInputs; ++i) {
2019  switch (read<PDLValue::Kind>()) {
2020  case PDLValue::Kind::TypeRange:
2021  match.typeRangeValues.push_back(*read<TypeRange *>());
2022  match.values.push_back(&match.typeRangeValues.back());
2023  break;
2024  case PDLValue::Kind::ValueRange:
2025  match.valueRangeValues.push_back(*read<ValueRange *>());
2026  match.values.push_back(&match.valueRangeValues.back());
2027  break;
2028  default:
2029  match.values.push_back(read<const void *>());
2030  break;
2031  }
2032  }
2033  curCodeIt = dest;
2034 }
2035 
2036 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
2037  LDBG() << "Executing ReplaceOp:";
2038  Operation *op = read<Operation *>();
2040  readList(args);
2041 
2042  LDBG() << " * Operation: " << *op
2043  << "\n * Values: " << llvm::interleaved(args);
2044  rewriter.replaceOp(op, args);
2045 }
2046 
2047 void ByteCodeExecutor::executeSwitchAttribute() {
2048  LDBG() << "Executing SwitchAttribute:";
2049  Attribute value = read<Attribute>();
2050  ArrayAttr cases = read<ArrayAttr>();
2051  handleSwitch(value, cases);
2052 }
2053 
2054 void ByteCodeExecutor::executeSwitchOperandCount() {
2055  LDBG() << "Executing SwitchOperandCount:";
2056  Operation *op = read<Operation *>();
2057  auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2058 
2059  LDBG() << " * Operation: " << *op;
2060  handleSwitch(op->getNumOperands(), cases);
2061 }
2062 
2063 void ByteCodeExecutor::executeSwitchOperationName() {
2064  LDBG() << "Executing SwitchOperationName:";
2065  OperationName value = read<Operation *>()->getName();
2066  size_t caseCount = read();
2067 
2068  // The operation names are stored in-line, so to print them out for
2069  // debugging purposes we need to read the array before executing the
2070  // switch so that we can display all of the possible values.
2071  LLVM_DEBUG({
2072  const ByteCodeField *prevCodeIt = curCodeIt;
2073  LDBG() << " * Value: " << value << "\n * Cases: "
2074  << llvm::interleaved(
2075  llvm::map_range(llvm::seq<size_t>(0, caseCount), [&](size_t) {
2076  return read<OperationName>();
2077  }));
2078  curCodeIt = prevCodeIt;
2079  });
2080 
2081  // Try to find the switch value within any of the cases.
2082  for (size_t i = 0; i != caseCount; ++i) {
2083  if (read<OperationName>() == value) {
2084  curCodeIt += (caseCount - i - 1);
2085  return selectJump(i + 1);
2086  }
2087  }
2088  selectJump(size_t(0));
2089 }
2090 
2091 void ByteCodeExecutor::executeSwitchResultCount() {
2092  LDBG() << "Executing SwitchResultCount:";
2093  Operation *op = read<Operation *>();
2094  auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2095 
2096  LDBG() << " * Operation: " << *op;
2097  handleSwitch(op->getNumResults(), cases);
2098 }
2099 
2100 void ByteCodeExecutor::executeSwitchType() {
2101  LDBG() << "Executing SwitchType:";
2102  Type value = read<Type>();
2103  auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2104  handleSwitch(value, cases);
2105 }
2106 
2107 void ByteCodeExecutor::executeSwitchTypes() {
2108  LDBG() << "Executing SwitchTypes:";
2109  TypeRange *value = read<TypeRange *>();
2110  auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2111  if (!value) {
2112  LDBG() << "Types: <NULL>";
2113  return selectJump(size_t(0));
2114  }
2115  handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
2116  return value == caseValue.getAsValueRange<TypeAttr>();
2117  });
2118 }
2119 
2120 LogicalResult
2121 ByteCodeExecutor::execute(PatternRewriter &rewriter,
2123  std::optional<Location> mainRewriteLoc) {
2124  while (true) {
2125  // Print the location of the operation being executed.
2126  LDBG() << readInline<Location>();
2127 
2128  OpCode opCode = static_cast<OpCode>(read());
2129  switch (opCode) {
2130  case ApplyConstraint:
2131  executeApplyConstraint(rewriter);
2132  break;
2133  case ApplyRewrite:
2134  if (failed(executeApplyRewrite(rewriter)))
2135  return failure();
2136  break;
2137  case AreEqual:
2138  executeAreEqual();
2139  break;
2140  case AreRangesEqual:
2141  executeAreRangesEqual();
2142  break;
2143  case Branch:
2144  executeBranch();
2145  break;
2146  case CheckOperandCount:
2147  executeCheckOperandCount();
2148  break;
2149  case CheckOperationName:
2150  executeCheckOperationName();
2151  break;
2152  case CheckResultCount:
2153  executeCheckResultCount();
2154  break;
2155  case CheckTypes:
2156  executeCheckTypes();
2157  break;
2158  case Continue:
2159  executeContinue();
2160  break;
2161  case CreateConstantTypeRange:
2162  executeCreateConstantTypeRange();
2163  break;
2164  case CreateOperation:
2165  executeCreateOperation(rewriter, *mainRewriteLoc);
2166  break;
2167  case CreateDynamicTypeRange:
2168  executeDynamicCreateRange<Type>("Type");
2169  break;
2170  case CreateDynamicValueRange:
2171  executeDynamicCreateRange<Value>("Value");
2172  break;
2173  case EraseOp:
2174  executeEraseOp(rewriter);
2175  break;
2176  case ExtractOp:
2177  executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
2178  break;
2179  case ExtractType:
2180  executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2181  break;
2182  case ExtractValue:
2183  executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2184  break;
2185  case Finalize:
2186  executeFinalize();
2187  LDBG() << "";
2188  return success();
2189  case ForEach:
2190  executeForEach();
2191  break;
2192  case GetAttribute:
2193  executeGetAttribute();
2194  break;
2195  case GetAttributeType:
2196  executeGetAttributeType();
2197  break;
2198  case GetDefiningOp:
2199  executeGetDefiningOp();
2200  break;
2201  case GetOperand0:
2202  case GetOperand1:
2203  case GetOperand2:
2204  case GetOperand3: {
2205  unsigned index = opCode - GetOperand0;
2206  LDBG() << "Executing GetOperand" << index << ":";
2207  executeGetOperand(index);
2208  break;
2209  }
2210  case GetOperandN:
2211  LDBG() << "Executing GetOperandN:";
2212  executeGetOperand(read<uint32_t>());
2213  break;
2214  case GetOperands:
2215  executeGetOperands();
2216  break;
2217  case GetResult0:
2218  case GetResult1:
2219  case GetResult2:
2220  case GetResult3: {
2221  unsigned index = opCode - GetResult0;
2222  LDBG() << "Executing GetResult" << index << ":";
2223  executeGetResult(index);
2224  break;
2225  }
2226  case GetResultN:
2227  LDBG() << "Executing GetResultN:";
2228  executeGetResult(read<uint32_t>());
2229  break;
2230  case GetResults:
2231  executeGetResults();
2232  break;
2233  case GetUsers:
2234  executeGetUsers();
2235  break;
2236  case GetValueType:
2237  executeGetValueType();
2238  break;
2239  case GetValueRangeTypes:
2240  executeGetValueRangeTypes();
2241  break;
2242  case IsNotNull:
2243  executeIsNotNull();
2244  break;
2245  case RecordMatch:
2246  assert(matches &&
2247  "expected matches to be provided when executing the matcher");
2248  executeRecordMatch(rewriter, *matches);
2249  break;
2250  case ReplaceOp:
2251  executeReplaceOp(rewriter);
2252  break;
2253  case SwitchAttribute:
2254  executeSwitchAttribute();
2255  break;
2256  case SwitchOperandCount:
2257  executeSwitchOperandCount();
2258  break;
2259  case SwitchOperationName:
2260  executeSwitchOperationName();
2261  break;
2262  case SwitchResultCount:
2263  executeSwitchResultCount();
2264  break;
2265  case SwitchType:
2266  executeSwitchType();
2267  break;
2268  case SwitchTypes:
2269  executeSwitchTypes();
2270  break;
2271  }
2272  LDBG() << "";
2273  }
2274 }
2275 
2276 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
2278  PDLByteCodeMutableState &state) const {
2279  // The first memory slot is always the root operation.
2280  state.memory[0] = op;
2281 
2282  // The matcher function always starts at code address 0.
2283  ByteCodeExecutor executor(
2284  matcherByteCode.data(), state.memory, state.opRangeMemory,
2285  state.typeRangeMemory, state.allocatedTypeRangeMemory,
2286  state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2287  uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
2288  constraintFunctions, rewriteFunctions);
2289  LogicalResult executeResult = executor.execute(rewriter, &matches);
2290  (void)executeResult;
2291  assert(succeeded(executeResult) && "unexpected matcher execution failure");
2292 
2293  // Order the found matches by benefit.
2294  llvm::stable_sort(matches,
2295  [](const MatchResult &lhs, const MatchResult &rhs) {
2296  return lhs.benefit > rhs.benefit;
2297  });
2298 }
2299 
2300 LogicalResult PDLByteCode::rewrite(PatternRewriter &rewriter,
2301  const MatchResult &match,
2302  PDLByteCodeMutableState &state) const {
2303  auto *configSet = match.pattern->getConfigSet();
2304  if (configSet)
2305  configSet->notifyRewriteBegin(rewriter);
2306 
2307  // The arguments of the rewrite function are stored at the start of the
2308  // memory buffer.
2309  llvm::copy(match.values, state.memory.begin());
2310 
2311  ByteCodeExecutor executor(
2312  &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
2313  state.opRangeMemory, state.typeRangeMemory,
2314  state.allocatedTypeRangeMemory, state.valueRangeMemory,
2315  state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2316  rewriterByteCode, state.currentPatternBenefits, patterns,
2317  constraintFunctions, rewriteFunctions);
2318  LogicalResult result =
2319  executor.execute(rewriter, /*matches=*/nullptr, match.location);
2320 
2321  if (configSet)
2322  configSet->notifyRewriteEnd(rewriter);
2323 
2324  // If the rewrite failed, check if the pattern rewriter can recover. If it
2325  // can, we can signal to the pattern applicator to keep trying patterns. If it
2326  // doesn't, we need to bail. Bailing here should be fine, given that we have
2327  // no means to propagate such a failure to the user, and it also indicates a
2328  // bug in the user code (i.e. failable rewrites should not be used with
2329  // pattern rewriters that don't support it).
2330  if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) {
2331  LDBG() << " and rollback is not supported - aborting";
2332  llvm::report_fatal_error(
2333  "Native PDL Rewrite failed, but the pattern "
2334  "rewriter doesn't support recovery. Failable pattern rewrites should "
2335  "not be used with pattern rewriters that do not support them.");
2336  }
2337  return result;
2338 }
static void * executeGetOperandsResults(RangeT values, Operation *op, unsigned index, ByteCodeField rangeIndex, StringRef attrSizedSegments, MutableArrayRef< ValueRange > valueRangeMemory)
This function is the internal implementation of GetResults and GetOperands that provides support for ...
Definition: ByteCode.cpp:1820
static constexpr ByteCodeField kInferTypesMarker
A marker used to indicate if an operation should infer types.
Definition: ByteCode.cpp:175
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
union mlir::linalg::@1253::ArityGroupAndKind::Kind kind
static const mlir::GenInfo * generator
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void processValue(Value value, LiveMap &liveMap)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
const void * getAsOpaquePointer() const
Get an opaque pointer to the attribute.
Definition: Attributes.h:73
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:27
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition: Block.cpp:36
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
Definition: Builders.cpp:27
This class represents liveness information on block level.
Definition: Liveness.h:99
const ValueSetT & in() const
Returns all values that are live at the beginning of the block (unordered).
Definition: Liveness.h:110
Operation * getEndOperation(Value value, Operation *startOperation) const
Gets the end operation for the given value using the start operation provided (must be referenced in ...
Definition: Liveness.cpp:375
Represents an analysis for computing liveness information from a given top-level operation.
Definition: Liveness.h:47
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:457
This is a value defined by a result of an operation.
Definition: Value.h:447
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:550
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:534
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
bool isImpossibleToMatch() const
Definition: PatternMatch.h:44
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the corresponding pattern isImpossibleToMatch() then this aborts.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:793
virtual bool canRecoverFromRewriteFailure() const
A hook used to indicate if the pattern rewriter can recover from failure during the rewrite stage of a pattern.
Definition: PatternMatch.h:802
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:247
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class implements the successor iterators for Block.
Definition: BlockSupport.h:73
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
const void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Types.h:164
auto walk(WalkFns &&...walkFns)
Walk this type and all attributes/types nested within using the provided walk functions.
Definition: Types.h:215
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
user_iterator user_begin() const
Definition: Value.h:216
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Value.h:233
user_iterator user_end() const
Definition: Value.h:217
user_range getUsers() const
Definition: Value.h:218
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit)
Set the new benefit for a bytecode pattern.
Definition: ByteCode.h:236
void cleanupAfterMatchAndRewrite()
Cleanup any allocated state after a full match/rewrite has been completed.
Definition: ByteCode.h:235
void match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl< MatchResult > &matches, PDLByteCodeMutableState &state) const
Definition: ByteCode.h:249
LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const
Definition: ByteCode.h:252
AttrTypeReplacer.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:102
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or static.
OpFoldResult size