MLIR  21.0.0git
ByteCode.cpp
Go to the documentation of this file.
1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ByteCode.h"
14 #include "mlir/Analysis/Liveness.h"
17 #include "mlir/IR/BuiltinOps.h"
19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/Format.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include <numeric>
26 #include <optional>
27 
28 #define DEBUG_TYPE "pdl-bytecode"
29 
30 using namespace mlir;
31 using namespace mlir::detail;
32 
33 //===----------------------------------------------------------------------===//
34 // PDLByteCodePattern
35 //===----------------------------------------------------------------------===//
36 
37 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
38  PDLPatternConfigSet *configSet,
39  ByteCodeAddr rewriterAddr) {
40  PatternBenefit benefit = matchOp.getBenefit();
41  MLIRContext *ctx = matchOp.getContext();
42 
43  // Collect the set of generated operations.
44  SmallVector<StringRef, 8> generatedOps;
45  if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
46  generatedOps =
47  llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
48 
49  // Check to see if this is pattern matches a specific operation type.
50  if (std::optional<StringRef> rootKind = matchOp.getRootKind())
51  return PDLByteCodePattern(rewriterAddr, configSet, *rootKind, benefit, ctx,
52  generatedOps);
53  return PDLByteCodePattern(rewriterAddr, configSet, MatchAnyOpTypeTag(),
54  benefit, ctx, generatedOps);
55 }
56 
57 //===----------------------------------------------------------------------===//
58 // PDLByteCodeMutableState
59 //===----------------------------------------------------------------------===//
60 
61 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
62 /// to the position of the pattern within the range returned by
63 /// `PDLByteCode::getPatterns`.
64 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
65  PatternBenefit benefit) {
66  currentPatternBenefits[patternIndex] = benefit;
67 }
68 
69 /// Clean up any allocated state after a full match/rewrite has been completed.
70 /// This method should be called regardless of whether the match+rewrite was a
71 /// success or not. It empties the type-range and value-range memory that was
/// allocated during the match/rewrite.
// NOTE(review): the function signature that should open this body is not
// present in this listing -- confirm against the original source.
73  allocatedTypeRangeMemory.clear();
74  allocatedValueRangeMemory.clear();
75 }
76 
77 //===----------------------------------------------------------------------===//
78 // Bytecode OpCodes
79 //===----------------------------------------------------------------------===//
80 
81 namespace {
/// The opcodes of the pattern bytecode. The underlying type is
/// `ByteCodeField`, and `ByteCodeWriter::append(OpCode)` pushes an opcode into
/// the stream as a single field.
82 enum OpCode : ByteCodeField {
83  /// Apply an externally registered constraint.
84  ApplyConstraint,
85  /// Apply an externally registered rewrite.
86  ApplyRewrite,
87  /// Check if two generic values are equal.
88  AreEqual,
89  /// Check if two ranges are equal.
90  AreRangesEqual,
91  /// Unconditional branch.
92  Branch,
93  /// Compare the operand count of an operation with a constant.
94  CheckOperandCount,
95  /// Compare the name of an operation with a constant.
96  CheckOperationName,
97  /// Compare the result count of an operation with a constant.
98  CheckResultCount,
99  /// Compare a range of types to a constant range of types.
100  CheckTypes,
101  /// Continue to the next iteration of a loop.
102  Continue,
103  /// Create a type range from a list of constant types.
104  CreateConstantTypeRange,
105  /// Create an operation.
106  CreateOperation,
107  /// Create a type range from a list of dynamic types.
108  CreateDynamicTypeRange,
109  /// Create a value range.
110  CreateDynamicValueRange,
111  /// Erase an operation.
112  EraseOp,
113  /// Extract the op from a range at the specified index.
114  ExtractOp,
115  /// Extract the type from a range at the specified index.
116  ExtractType,
117  /// Extract the value from a range at the specified index.
118  ExtractValue,
119  /// Terminate a matcher or rewrite sequence.
120  Finalize,
121  /// Iterate over a range of values.
122  ForEach,
123  /// Get a specific attribute of an operation.
124  GetAttribute,
125  /// Get the type of an attribute.
126  GetAttributeType,
127  /// Get the defining operation of a value.
128  GetDefiningOp,
129  /// Get a specific operand of an operation.
 // NOTE(review): presumably the 0-3 variants encode the operand index in the
 // opcode itself while the `N` variant reads it from the stream -- confirm
 // against the code that emits and interprets these opcodes.
130  GetOperand0,
131  GetOperand1,
132  GetOperand2,
133  GetOperand3,
134  GetOperandN,
135  /// Get a specific operand group of an operation.
136  GetOperands,
137  /// Get a specific result of an operation.
 // NOTE(review): presumably mirrors the GetOperand0-3/N scheme above --
 // confirm.
138  GetResult0,
139  GetResult1,
140  GetResult2,
141  GetResult3,
142  GetResultN,
143  /// Get a specific result group of an operation.
144  GetResults,
145  /// Get the users of a value or a range of values.
146  GetUsers,
147  /// Get the type of a value.
148  GetValueType,
149  /// Get the types of a value range.
150  GetValueRangeTypes,
151  /// Check if a generic value is not null.
152  IsNotNull,
153  /// Record a successful pattern match.
154  RecordMatch,
155  /// Replace an operation.
156  ReplaceOp,
157  /// Compare an attribute with a set of constants.
158  SwitchAttribute,
159  /// Compare the operand count of an operation with a set of constants.
160  SwitchOperandCount,
161  /// Compare the name of an operation with a set of constants.
162  SwitchOperationName,
163  /// Compare the result count of an operation with a set of constants.
164  SwitchResultCount,
165  /// Compare a type with a set of constants.
166  SwitchType,
167  /// Compare a range of types with a set of constants.
168  SwitchTypes,
169 };
170 } // namespace
171 
172 /// A marker used to indicate if an operation should infer types.
// NOTE(review): the initializer of this constant is not present in this
// listing (the declaration ends at `=`) -- restore it from the original
// source before relying on this declaration.
173 static constexpr ByteCodeField kInferTypesMarker =
175 
176 //===----------------------------------------------------------------------===//
177 // ByteCode Generation
178 //===----------------------------------------------------------------------===//
179 
180 //===----------------------------------------------------------------------===//
181 // Generator
182 //===----------------------------------------------------------------------===//
183 
184 namespace {
185 struct ByteCodeLiveRange;
186 struct ByteCodeWriter;
187 
188 /// Check if the given class `T` can be converted to an opaque pointer.
/// The alias is well-formed only when an object of type `T` provides a
/// `getAsOpaquePointer()` member function; it is used with `llvm::is_detected`
/// to select between the `ByteCodeWriter` append overloads. The trailing
/// `Args` pack is not used by the alias itself.
189 template <typename T, typename... Args>
190 using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
191 
192 /// This class represents the main generator for the pattern bytecode.
/// It holds references to the output buffers and counters it populates
/// (uniqued data, matcher/rewriter bytecode, maximum memory indices) together
/// with the lookup tables used while emitting bytecode: value -> memory index,
/// range value -> range storage index, and external constraint/rewrite name ->
/// registry index.
193 class Generator {
194 public:
  /// Bind the generator to the output storage it fills in and assign each
  /// registered constraint/rewrite function an index equal to its position
  /// when enumerating the corresponding string map.
  // NOTE(review): the initializer list below uses `patterns` and `configMap`,
  // but no matching constructor parameters are present in this listing --
  // confirm against the original source.
195  Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
196  SmallVectorImpl<ByteCodeField> &matcherByteCode,
197  SmallVectorImpl<ByteCodeField> &rewriterByteCode,
199  ByteCodeField &maxValueMemoryIndex,
200  ByteCodeField &maxOpRangeMemoryIndex,
201  ByteCodeField &maxTypeRangeMemoryIndex,
202  ByteCodeField &maxValueRangeMemoryIndex,
203  ByteCodeField &maxLoopLevel,
204  llvm::StringMap<PDLConstraintFunction> &constraintFns,
205  llvm::StringMap<PDLRewriteFunction> &rewriteFns,
207  : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
208  rewriterByteCode(rewriterByteCode), patterns(patterns),
209  maxValueMemoryIndex(maxValueMemoryIndex),
210  maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
211  maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
212  maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
213  maxLoopLevel(maxLoopLevel), configMap(configMap) {
214  for (const auto &it : llvm::enumerate(constraintFns))
215  constraintToMemIndex.try_emplace(it.value().first(), it.index());
216  for (const auto &it : llvm::enumerate(rewriteFns))
217  externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
218  }
219 
220  /// Generate the bytecode for the given PDL interpreter module.
221  void generate(ModuleOp module);
222 
223  /// Return the memory index to use for the given value.
  /// Asserts that an index has already been assigned to `value`.
224  ByteCodeField &getMemIndex(Value value) {
225  assert(valueToMemIndex.count(value) &&
226  "expected memory index to be assigned");
227  return valueToMemIndex[value];
228  }
229 
230  /// Return the range memory index used to store the given range value.
  /// Asserts that a range index has already been assigned to `value`.
231  ByteCodeField &getRangeStorageIndex(Value value) {
232  assert(valueToRangeIndex.count(value) &&
233  "expected range index to be assigned");
234  return valueToRangeIndex[value];
235  }
236 
237  /// Return an index to use when referring to the given data that is uniqued in
238  /// the MLIR context.
  /// This overload is selected for any `T` that is not convertible to `Value`.
  /// The first request for a given opaque pointer appends it to `uniquedData`
  /// and assigns it the index `maxValueMemoryIndex + uniquedData.size()`,
  /// computed from the values those hold at that time; later requests return
  /// the same index.
239  template <typename T>
240  std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
241  getMemIndex(T val) {
242  const void *opaqueVal = val.getAsOpaquePointer();
243 
244  // Get or insert a reference to this value.
245  auto it = uniquedDataToMemIndex.try_emplace(
246  opaqueVal, maxValueMemoryIndex + uniquedData.size());
247  if (it.second)
248  uniquedData.push_back(opaqueVal);
249  return it.first->second;
250  }
251 
252 private:
253  /// Allocate memory indices for the results of operations within the matcher
254  /// and rewriters.
255  void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
256  ModuleOp rewriterModule);
257 
258  /// Generate the bytecode for the given operation.
  /// One overload exists per supported `pdl_interp` operation, plus overloads
  /// taking a whole region or a generic `Operation *`.
259  void generate(Region *region, ByteCodeWriter &writer);
260  void generate(Operation *op, ByteCodeWriter &writer);
261  void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
262  void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
263  void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
264  void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
265  void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
266  void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
267  void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
268  void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
269  void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
270  void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
271  void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
272  void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
273  void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
274  void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer);
275  void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
276  void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
277  void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
278  void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
279  void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
280  void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
281  void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
282  void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
283  void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
284  void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
285  void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
286  void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
287  void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
288  void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
289  void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
290  void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
291  void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
292  void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
293  void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
294  void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
295  void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
296  void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
297  void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
298  void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
299 
300  /// Mapping from value to its corresponding memory index.
301  DenseMap<Value, ByteCodeField> valueToMemIndex;
302 
303  /// Mapping from a range value to its corresponding range storage index.
304  DenseMap<Value, ByteCodeField> valueToRangeIndex;
305 
306  /// Mapping from the name of an externally registered rewrite to its index in
307  /// the bytecode registry.
308  llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
309 
310  /// Mapping from the name of an externally registered constraint to its index
311  /// in the bytecode registry.
312  llvm::StringMap<ByteCodeField> constraintToMemIndex;
313 
314  /// Mapping from rewriter function name to the bytecode address of the
315  /// rewriter function in the rewriter bytecode.
316  llvm::StringMap<ByteCodeAddr> rewriterToAddr;
317 
318  /// Mapping from a uniqued storage object to its memory index within
319  /// `uniquedData`.
320  DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
321 
322  /// The current level of the foreach loop.
323  ByteCodeField curLoopLevel = 0;
324 
325  /// The current MLIR context.
326  MLIRContext *ctx;
327 
328  /// Mapping from block to its address.
  // NOTE(review): the declaration this comment documents (`blockToAddr`, used
  // by the generate methods) is not present in this listing -- confirm.
330 
331  /// Data of the ByteCode class to be populated.
  // NOTE(review): a `patterns` member is initialized by the constructor but
  // its declaration is not present in this listing -- confirm.
332  std::vector<const void *> &uniquedData;
333  SmallVectorImpl<ByteCodeField> &matcherByteCode;
334  SmallVectorImpl<ByteCodeField> &rewriterByteCode;
336  ByteCodeField &maxValueMemoryIndex;
337  ByteCodeField &maxOpRangeMemoryIndex;
338  ByteCodeField &maxTypeRangeMemoryIndex;
339  ByteCodeField &maxValueRangeMemoryIndex;
340  ByteCodeField &maxLoopLevel;
341 
342  /// A map of pattern configurations.
  // NOTE(review): the `configMap` declaration this comment documents is not
  // present in this listing -- confirm.
344 };
345 
346 /// This class provides utilities for writing a bytecode stream.
/// All `append*` methods push fields onto the bytecode buffer passed to the
/// constructor. Values that live in memory slots are written as the index
/// obtained from the owning `Generator`; block successors are written as
/// placeholder addresses and recorded in `unresolvedSuccessorRefs`.
347 struct ByteCodeWriter {
348  ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
349  : bytecode(bytecode), generator(generator) {}
350 
351  /// Append a field to the bytecode.
352  void append(ByteCodeField field) { bytecode.push_back(field); }
353  void append(OpCode opCode) { bytecode.push_back(opCode); }
354 
355  /// Append an address to the bytecode.
  /// An address occupies exactly two bytecode fields (checked statically); its
  /// object representation is copied into those fields with `memcpy`.
356  void append(ByteCodeAddr field) {
357  static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
358  "unexpected ByteCode address size");
359 
360  ByteCodeField fieldParts[2];
361  std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
362  bytecode.append({fieldParts[0], fieldParts[1]});
363  }
364 
365  /// Append a single successor to the bytecode, the exact address will need to
366  /// be resolved later.
367  void append(Block *successor) {
368  // Add back a reference to the successor so that the address can be resolved
369  // later.
  // The recorded value is the buffer offset at which the zero placeholder
  // address written below begins.
370  unresolvedSuccessorRefs[successor].push_back(bytecode.size());
371  append(ByteCodeAddr(0));
372  }
373 
374  /// Append a successor range to the bytecode, the exact address will need to
375  /// be resolved later.
376  void append(SuccessorRange successors) {
377  for (Block *successor : successors)
378  append(successor);
379  }
380 
381  /// Append a range of values that will be read as generic PDLValues.
  /// The element count is written first, followed by each value as a
  /// (kind, memory index) pair.
382  void appendPDLValueList(OperandRange values) {
383  bytecode.push_back(values.size());
384  for (Value value : values)
385  appendPDLValue(value);
386  }
387 
388  /// Append a value as a PDLValue.
  /// Writes the value's `PDLValue::Kind` followed by its memory index.
389  void appendPDLValue(Value value) {
390  appendPDLValueKind(value);
391  append(value);
392  }
393 
394  /// Append the PDLValue::Kind of the given value.
395  void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); }
396 
397  /// Append the PDLValue::Kind of the given type.
  /// A `pdl::RangeType` maps to `TypeRange` when its element type is
  /// `pdl::TypeType` and to `ValueRange` otherwise.
398  void appendPDLValueKind(Type type) {
  // NOTE(review): the start of this statement (the declaration of `kind` and
  // the head of the `.Case` chain) is not present in this listing -- confirm
  // against the original source.
401  .Case<pdl::AttributeType>(
402  [](Type) { return PDLValue::Kind::Attribute; })
403  .Case<pdl::OperationType>(
404  [](Type) { return PDLValue::Kind::Operation; })
405  .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
406  if (isa<pdl::TypeType>(rangeTy.getElementType()))
407  return PDLValue::Kind::TypeRange;
408  return PDLValue::Kind::ValueRange;
409  })
410  .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
411  .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
412  bytecode.push_back(static_cast<ByteCodeField>(kind));
413  }
414 
415  /// Append a value that will be stored in a memory slot and not inline within
416  /// the bytecode.
  /// Selected for pointer types and for types providing
  /// `getAsOpaquePointer()`; writes the index returned by
  /// `Generator::getMemIndex`.
417  template <typename T>
418  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
419  std::is_pointer<T>::value>
420  append(T value) {
421  bytecode.push_back(generator.getMemIndex(value));
422  }
423 
424  /// Append a range of values.
  /// Selected for types without `getAsOpaquePointer()`; writes the range size
  /// followed by each element via the matching `append` overload.
425  template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
426  std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
427  append(T range) {
428  bytecode.push_back(llvm::size(range));
429  for (auto it : range)
430  append(it);
431  }
432 
433  /// Append a variadic number of fields to the bytecode.
  /// Fields are appended left to right, one `append` call per field.
434  template <typename FieldTy, typename Field2Ty, typename... FieldTys>
435  void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
436  append(field);
437  append(field2, fields...);
438  }
439 
440  /// Appends a value as a pointer, stored inline within the bytecode.
  /// The opaque pointer's object representation is copied into
  /// `sizeof(const void *) / sizeof(ByteCodeField)` consecutive fields.
441  template <typename T>
442  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
443  appendInline(T value) {
444  constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField);
445  const void *pointer = value.getAsOpaquePointer();
446  ByteCodeField fieldParts[numParts];
447  std::memcpy(fieldParts, &pointer, sizeof(const void *));
448  bytecode.append(fieldParts, fieldParts + numParts);
449  }
450 
451  /// Successor references in the bytecode that have yet to be resolved.
  /// Maps each successor block to the buffer offsets of its placeholder
  /// addresses.
452  DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
453 
454  /// The underlying bytecode buffer.
  // NOTE(review): the `bytecode` member declaration this comment documents is
  // not present in this listing -- confirm against the original source.
456 
457  /// The main generator producing PDL.
458  Generator &generator;
459 };
460 
461 /// This class represents a live range of PDL Interpreter values, containing
462 /// information about when values are live within a match/rewrite.
463 struct ByteCodeLiveRange {
464  using Set = llvm::IntervalMap<uint64_t, char, 16>;
465  using Allocator = Set::Allocator;
466 
467  ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
468 
469  /// Union this live range with the one provided.
470  void unionWith(const ByteCodeLiveRange &rhs) {
471  for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
472  ++it)
473  liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0);
474  }
475 
476  /// Returns true if this range overlaps with the one provided.
477  bool overlaps(const ByteCodeLiveRange &rhs) const {
478  return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
479  .valid();
480  }
481 
482  /// A map representing the ranges of the match/rewrite that a value is live in
483  /// the interpreter.
484  ///
485  /// We use std::unique_ptr here, because IntervalMap does not provide a
486  /// correct copy or move constructor. We can eliminate the pointer once
487  /// https://reviews.llvm.org/D113240 lands.
488  std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
489 
490  /// The operation range storage index for this range.
491  std::optional<unsigned> opRangeIndex;
492 
493  /// The type range storage index for this range.
494  std::optional<unsigned> typeRangeIndex;
495 
496  /// The value range storage index for this range.
497  std::optional<unsigned> valueRangeIndex;
498 };
499 } // namespace
500 
501 void Generator::generate(ModuleOp module) {
502  auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
503  pdl_interp::PDLInterpDialect::getMatcherFunctionName());
504  ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
505  pdl_interp::PDLInterpDialect::getRewriterModuleName());
506  assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
507 
508  // Allocate memory indices for the results of operations within the matcher
509  // and rewriters.
510  allocateMemoryIndices(matcherFunc, rewriterModule);
511 
512  // Generate code for the rewriter functions.
513  ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
514  for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
515  rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
516  for (Operation &op : rewriterFunc.getOps())
517  generate(&op, rewriterByteCodeWriter);
518  }
519  assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
520  "unexpected branches in rewriter function");
521 
522  // Generate code for the matcher function.
523  ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
524  generate(&matcherFunc.getBody(), matcherByteCodeWriter);
525 
526  // Resolve successor references in the matcher.
527  for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
528  ByteCodeAddr addr = blockToAddr[it.first];
529  for (unsigned offsetToFix : it.second)
530  std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
531  }
532 }
533 
/// Assign a memory index (and, for range-typed values, a range storage index)
/// to every value of the rewriter functions and of the matcher function, and
/// raise the `max*MemoryIndex` counters to cover the indices handed out.
///
/// Rewriter functions number their block arguments and op results
/// sequentially, restarting from 0 for each function. The matcher function
/// computes live ranges for its values and greedily shares slots between
/// values whose live ranges do not overlap; slot 0 is reserved for the root
/// operation argument.
534 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
535  ModuleOp rewriterModule) {
536  // Rewriters use simplistic allocation scheme that simply assigns an index to
537  // each result.
538  for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
539  ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
540  auto processRewriterValue = [&](Value val) {
541  valueToMemIndex.try_emplace(val, index++);
542  if (pdl::RangeType rangeType = dyn_cast<pdl::RangeType>(val.getType())) {
543  Type elementTy = rangeType.getElementType();
544  if (isa<pdl::TypeType>(elementTy))
545  valueToRangeIndex.try_emplace(val, typeRangeIndex++);
546  else if (isa<pdl::ValueType>(elementTy))
547  valueToRangeIndex.try_emplace(val, valueRangeIndex++);
548  }
549  };
550 
551  for (BlockArgument arg : rewriterFunc.getArguments())
552  processRewriterValue(arg);
553  rewriterFunc.getBody().walk([&](Operation *op) {
554  for (Value result : op->getResults())
555  processRewriterValue(result);
556  });
  // Keep the running maxima large enough for this rewriter function.
557  if (index > maxValueMemoryIndex)
558  maxValueMemoryIndex = index;
559  if (typeRangeIndex > maxTypeRangeMemoryIndex)
560  maxTypeRangeMemoryIndex = typeRangeIndex;
561  if (valueRangeIndex > maxValueRangeMemoryIndex)
562  maxValueRangeMemoryIndex = valueRangeIndex;
563  }
564 
565  // The matcher function uses a more sophisticated numbering that tries to
566  // minimize the number of memory indices assigned. This is done by determining
567  // a live range of the values within the matcher, then the allocation is just
568  // finding the minimal number of overlapping live ranges. This is essentially
569  // a simplified form of register allocation where we don't necessarily have a
570  // limited number of registers, but we still want to minimize the number used.
571  DenseMap<Operation *, unsigned> opToFirstIndex;
572  DenseMap<Operation *, unsigned> opToLastIndex;
573 
574  // A custom walk that marks the first and the last index of each operation.
575  // The entry marks the beginning of the liveness range for this operation,
576  // followed by nested operations, followed by the end of the liveness range.
577  unsigned index = 0;
578  llvm::unique_function<void(Operation *)> walk = [&](Operation *op) {
579  opToFirstIndex.try_emplace(op, index++);
580  for (Region &region : op->getRegions())
581  for (Block &block : region.getBlocks())
582  for (Operation &nested : block)
583  walk(&nested);
584  opToLastIndex.try_emplace(op, index++);
585  };
586  walk(matcherFunc);
587 
588  // Liveness info for each of the defs within the matcher.
589  ByteCodeLiveRange::Allocator allocator;
590  DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
591 
592  // Assign the root operation being matched to slot 0.
593  BlockArgument rootOpArg = matcherFunc.getArgument(0);
594  valueToMemIndex[rootOpArg] = 0;
595 
596  // Walk each of the blocks, computing the def interval that the value is used.
597  Liveness matcherLiveness(matcherFunc);
598  matcherFunc->walk([&](Block *block) {
599  const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
600  assert(info && "expected liveness info for block");
601  auto processValue = [&](Value value, Operation *firstUseOrDef) {
602  // We don't need to process the root op argument, this value is always
603  // assigned to the first memory slot.
604  if (value == rootOpArg)
605  return;
606 
607  // Set indices for the range of this block that the value is used.
608  auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
609  defRangeIt->second.liveness->insert(
610  opToFirstIndex[firstUseOrDef],
611  opToLastIndex[info->getEndOperation(value, firstUseOrDef)],
612  /*dummyValue*/ 0);
613 
614  // Check to see if this value is a range type.
  // The index written here is only a marker that the value needs range
  // storage of that kind; the real storage index is assigned during the
  // greedy allocation below.
615  if (auto rangeTy = dyn_cast<pdl::RangeType>(value.getType())) {
616  Type eleType = rangeTy.getElementType();
617  if (isa<pdl::OperationType>(eleType))
618  defRangeIt->second.opRangeIndex = 0;
619  else if (isa<pdl::TypeType>(eleType))
620  defRangeIt->second.typeRangeIndex = 0;
621  else if (isa<pdl::ValueType>(eleType))
622  defRangeIt->second.valueRangeIndex = 0;
623  }
624  };
625 
626  // Process the live-ins of this block.
627  for (Value liveIn : info->in()) {
628  // Only process the value if it has been defined in the current region.
629  // Other values that span across pdl_interp.foreach will be added higher
630  // up. This ensures that we keep them alive for the entire duration
631  // of the loop.
632  if (liveIn.getParentRegion() == block->getParent())
633  processValue(liveIn, &block->front());
634  }
635 
636  // Process the block arguments for the entry block (those are not live-in).
637  if (block->isEntryBlock()) {
638  for (Value argument : block->getArguments())
639  processValue(argument, &block->front());
640  }
641 
642  // Process any new defs within this block.
643  for (Operation &op : *block)
644  for (Value result : op.getResults())
645  processValue(result, &op);
646  });
647 
648  // Greedily allocate memory slots using the computed def live ranges.
649  std::vector<ByteCodeLiveRange> allocatedIndices;
650 
651  // The number of memory indices currently allocated (and its next value).
652  // Recall that the root gets allocated memory index 0.
653  ByteCodeField numIndices = 1;
654 
655  // The number of memory ranges of various types (and their next values).
656  ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
657 
658  for (auto &defIt : valueDefRanges) {
659  ByteCodeField &memIndex = valueToMemIndex[defIt.first];
660  ByteCodeLiveRange &defRange = defIt.second;
661 
662  // Try to allocate to an existing index.
  // Entry `i` of `allocatedIndices` corresponds to memory slot `i + 1`,
  // because slot 0 belongs to the root operation.
663  for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) {
664  ByteCodeLiveRange &existingRange = existingIndexIt.value();
665  if (!defRange.overlaps(existingRange)) {
666  existingRange.unionWith(defRange);
667  memIndex = existingIndexIt.index() + 1;
668 
  // Reuse the slot's range storage of the matching kind, allocating it the
  // first time a value of that kind lands in this slot.
669  if (defRange.opRangeIndex) {
670  if (!existingRange.opRangeIndex)
671  existingRange.opRangeIndex = numOpRanges++;
672  valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
673  } else if (defRange.typeRangeIndex) {
674  if (!existingRange.typeRangeIndex)
675  existingRange.typeRangeIndex = numTypeRanges++;
676  valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
677  } else if (defRange.valueRangeIndex) {
678  if (!existingRange.valueRangeIndex)
679  existingRange.valueRangeIndex = numValueRanges++;
680  valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
681  }
682  break;
683  }
684  }
685 
686  // If no existing index could be used, add a new one.
  // A zero `memIndex` serves as the "not yet assigned" sentinel: the map
  // entry obtained above starts at zero, and reused slots are numbered from 1.
687  if (memIndex == 0) {
688  allocatedIndices.emplace_back(allocator);
689  ByteCodeLiveRange &newRange = allocatedIndices.back();
690  newRange.unionWith(defRange);
691 
692  // Allocate an index for op/type/value ranges.
693  if (defRange.opRangeIndex) {
694  newRange.opRangeIndex = numOpRanges;
695  valueToRangeIndex[defIt.first] = numOpRanges++;
696  } else if (defRange.typeRangeIndex) {
697  newRange.typeRangeIndex = numTypeRanges;
698  valueToRangeIndex[defIt.first] = numTypeRanges++;
699  } else if (defRange.valueRangeIndex) {
700  newRange.valueRangeIndex = numValueRanges;
701  valueToRangeIndex[defIt.first] = numValueRanges++;
702  }
703 
704  memIndex = allocatedIndices.size();
705  ++numIndices;
706  }
707  }
708 
709  // Print the index usage and ensure that we did not run out of index space.
710  LLVM_DEBUG({
711  llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices "
712  << "(down from initial " << valueDefRanges.size() << ").\n";
713  });
714  assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
715  "Ran out of memory for allocated indices");
716 
717  // Update the max number of indices.
718  if (numIndices > maxValueMemoryIndex)
719  maxValueMemoryIndex = numIndices;
720  if (numOpRanges > maxOpRangeMemoryIndex)
721  maxOpRangeMemoryIndex = numOpRanges;
722  if (numTypeRanges > maxTypeRangeMemoryIndex)
723  maxTypeRangeMemoryIndex = numTypeRanges;
724  if (numValueRanges > maxValueRangeMemoryIndex)
725  maxValueRangeMemoryIndex = numValueRanges;
726 }
727 
728 void Generator::generate(Region *region, ByteCodeWriter &writer) {
729  llvm::ReversePostOrderTraversal<Region *> rpot(region);
730  for (Block *block : rpot) {
731  // Keep track of where this block begins within the matcher function.
732  blockToAddr.try_emplace(block, matcherByteCode.size());
733  for (Operation &op : *block)
734  generate(&op, writer);
735  }
736 }
737 
/// Generate the bytecode for a generic operation by dispatching to the
/// `generate` overload for its concrete `pdl_interp` op type. Reaching the
/// default case (an operation that is none of the listed types) is treated as
/// unreachable.
738 void Generator::generate(Operation *op, ByteCodeWriter &writer) {
  // Inside LLVM_DEBUG, the location of every operation that produces bytecode
  // is additionally written inline ahead of that bytecode.
739  LLVM_DEBUG({
740  // The following list must contain all the operations that do not
741  // produce any bytecode.
742  if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
743  writer.appendInline(op->getLoc());
744  });
  // NOTE(review): the head of the dispatch expression that the `.Case` chain
  // below belongs to is not present in this listing -- confirm against the
  // original source.
746  .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
747  pdl_interp::AreEqualOp, pdl_interp::BranchOp,
748  pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
749  pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
750  pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
751  pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
752  pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp,
753  pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
754  pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
755  pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
756  pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
757  pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
758  pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
759  pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
760  pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
761  pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
762  pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
763  pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
764  pdl_interp::SwitchResultCountOp>(
765  [&](auto interpOp) { this->generate(interpOp, writer); })
766  .Default([](Operation *) {
767  llvm_unreachable("unknown `pdl_interp` operation");
768  });
769 }
770 
771 void Generator::generate(pdl_interp::ApplyConstraintOp op,
772  ByteCodeWriter &writer) {
773  // Constraints that should return a value have to be registered as rewrites.
774  // If a constraint and a rewrite of similar name are registered the
775  // constraint takes precedence
776  writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
777  writer.appendPDLValueList(op.getArgs());
778  writer.append(ByteCodeField(op.getIsNegated()));
779  ResultRange results = op.getResults();
780  writer.append(ByteCodeField(results.size()));
781  for (Value result : results) {
782  // We record the expected kind of the result, so that we can provide extra
783  // verification of the native rewrite function and handle the failure case
784  // of constraints accordingly.
785  writer.appendPDLValueKind(result);
786 
787  // Range results also need to append the range storage index.
788  if (isa<pdl::RangeType>(result.getType()))
789  writer.append(getRangeStorageIndex(result));
790  writer.append(result);
791  }
792  writer.append(op.getSuccessors());
793 }
794 void Generator::generate(pdl_interp::ApplyRewriteOp op,
795  ByteCodeWriter &writer) {
796  assert(externalRewriterToMemIndex.count(op.getName()) &&
797  "expected index for rewrite function");
798  writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
799  writer.appendPDLValueList(op.getArgs());
800 
801  ResultRange results = op.getResults();
802  writer.append(ByteCodeField(results.size()));
803  for (Value result : results) {
804  // We record the expected kind of the result, so that we
805  // can provide extra verification of the native rewrite function.
806  writer.appendPDLValueKind(result);
807 
808  // Range results also need to append the range storage index.
809  if (isa<pdl::RangeType>(result.getType()))
810  writer.append(getRangeStorageIndex(result));
811  writer.append(result);
812  }
813 }
814 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
815  Value lhs = op.getLhs();
816  if (isa<pdl::RangeType>(lhs.getType())) {
817  writer.append(OpCode::AreRangesEqual);
818  writer.appendPDLValueKind(lhs);
819  writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
820  return;
821  }
822 
823  writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
824 }
825 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
826  writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
827 }
828 void Generator::generate(pdl_interp::CheckAttributeOp op,
829  ByteCodeWriter &writer) {
830  writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
831  op.getSuccessors());
832 }
833 void Generator::generate(pdl_interp::CheckOperandCountOp op,
834  ByteCodeWriter &writer) {
835  writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
836  static_cast<ByteCodeField>(op.getCompareAtLeast()),
837  op.getSuccessors());
838 }
839 void Generator::generate(pdl_interp::CheckOperationNameOp op,
840  ByteCodeWriter &writer) {
841  writer.append(OpCode::CheckOperationName, op.getInputOp(),
842  OperationName(op.getName(), ctx), op.getSuccessors());
843 }
844 void Generator::generate(pdl_interp::CheckResultCountOp op,
845  ByteCodeWriter &writer) {
846  writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
847  static_cast<ByteCodeField>(op.getCompareAtLeast()),
848  op.getSuccessors());
849 }
850 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
851  writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
852  op.getSuccessors());
853 }
854 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
855  writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
856  op.getSuccessors());
857 }
858 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
859  assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
860  writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
861 }
862 void Generator::generate(pdl_interp::CreateAttributeOp op,
863  ByteCodeWriter &writer) {
864  // Simply repoint the memory index of the result to the constant.
865  getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
866 }
867 void Generator::generate(pdl_interp::CreateOperationOp op,
868  ByteCodeWriter &writer) {
869  writer.append(OpCode::CreateOperation, op.getResultOp(),
870  OperationName(op.getName(), ctx));
871  writer.appendPDLValueList(op.getInputOperands());
872 
873  // Add the attributes.
874  OperandRange attributes = op.getInputAttributes();
875  writer.append(static_cast<ByteCodeField>(attributes.size()));
876  for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
877  writer.append(std::get<0>(it), std::get<1>(it));
878 
879  // Add the result types. If the operation has inferred results, we use a
880  // marker "size" value. Otherwise, we add the list of explicit result types.
881  if (op.getInferredResultTypes())
882  writer.append(kInferTypesMarker);
883  else
884  writer.appendPDLValueList(op.getInputResultTypes());
885 }
886 void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
887  // Append the correct opcode for the range type.
888  TypeSwitch<Type>(op.getType().getElementType())
889  .Case(
890  [&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); })
891  .Case([&](pdl::ValueType) {
892  writer.append(OpCode::CreateDynamicValueRange);
893  });
894 
895  writer.append(op.getResult(), getRangeStorageIndex(op.getResult()));
896  writer.appendPDLValueList(op->getOperands());
897 }
898 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
899  // Simply repoint the memory index of the result to the constant.
900  getMemIndex(op.getResult()) = getMemIndex(op.getValue());
901 }
902 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
903  writer.append(OpCode::CreateConstantTypeRange, op.getResult(),
904  getRangeStorageIndex(op.getResult()), op.getValue());
905 }
906 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
907  writer.append(OpCode::EraseOp, op.getInputOp());
908 }
909 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
910  OpCode opCode =
911  TypeSwitch<Type, OpCode>(op.getResult().getType())
912  .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
913  .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
914  .Case([](pdl::TypeType) { return OpCode::ExtractType; })
915  .Default([](Type) -> OpCode {
916  llvm_unreachable("unsupported element type");
917  });
918  writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
919 }
920 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
921  writer.append(OpCode::Finalize);
922 }
923 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
924  BlockArgument arg = op.getLoopVariable();
925  writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
926  writer.appendPDLValueKind(arg.getType());
927  writer.append(curLoopLevel, op.getSuccessor());
928  ++curLoopLevel;
929  if (curLoopLevel > maxLoopLevel)
930  maxLoopLevel = curLoopLevel;
931  generate(&op.getRegion(), writer);
932  --curLoopLevel;
933 }
934 void Generator::generate(pdl_interp::GetAttributeOp op,
935  ByteCodeWriter &writer) {
936  writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
937  op.getNameAttr());
938 }
939 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
940  ByteCodeWriter &writer) {
941  writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
942 }
943 void Generator::generate(pdl_interp::GetDefiningOpOp op,
944  ByteCodeWriter &writer) {
945  writer.append(OpCode::GetDefiningOp, op.getInputOp());
946  writer.appendPDLValue(op.getValue());
947 }
948 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
949  uint32_t index = op.getIndex();
950  if (index < 4)
951  writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
952  else
953  writer.append(OpCode::GetOperandN, index);
954  writer.append(op.getInputOp(), op.getValue());
955 }
956 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
957  Value result = op.getValue();
958  std::optional<uint32_t> index = op.getIndex();
959  writer.append(OpCode::GetOperands,
960  index.value_or(std::numeric_limits<uint32_t>::max()),
961  op.getInputOp());
962  if (isa<pdl::RangeType>(result.getType()))
963  writer.append(getRangeStorageIndex(result));
964  else
966  writer.append(result);
967 }
968 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
969  uint32_t index = op.getIndex();
970  if (index < 4)
971  writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
972  else
973  writer.append(OpCode::GetResultN, index);
974  writer.append(op.getInputOp(), op.getValue());
975 }
976 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
977  Value result = op.getValue();
978  std::optional<uint32_t> index = op.getIndex();
979  writer.append(OpCode::GetResults,
980  index.value_or(std::numeric_limits<uint32_t>::max()),
981  op.getInputOp());
982  if (isa<pdl::RangeType>(result.getType()))
983  writer.append(getRangeStorageIndex(result));
984  else
986  writer.append(result);
987 }
988 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
989  Value operations = op.getOperations();
990  ByteCodeField rangeIndex = getRangeStorageIndex(operations);
991  writer.append(OpCode::GetUsers, operations, rangeIndex);
992  writer.appendPDLValue(op.getValue());
993 }
994 void Generator::generate(pdl_interp::GetValueTypeOp op,
995  ByteCodeWriter &writer) {
996  if (isa<pdl::RangeType>(op.getType())) {
997  Value result = op.getResult();
998  writer.append(OpCode::GetValueRangeTypes, result,
999  getRangeStorageIndex(result), op.getValue());
1000  } else {
1001  writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
1002  }
1003 }
1004 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
1005  writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
1006 }
1007 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
1008  ByteCodeField patternIndex = patterns.size();
1009  patterns.emplace_back(PDLByteCodePattern::create(
1010  op, configMap.lookup(op),
1011  rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
1012  writer.append(OpCode::RecordMatch, patternIndex,
1013  SuccessorRange(op.getOperation()), op.getMatchedOps());
1014  writer.appendPDLValueList(op.getInputs());
1015 }
1016 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
1017  writer.append(OpCode::ReplaceOp, op.getInputOp());
1018  writer.appendPDLValueList(op.getReplValues());
1019 }
1020 void Generator::generate(pdl_interp::SwitchAttributeOp op,
1021  ByteCodeWriter &writer) {
1022  writer.append(OpCode::SwitchAttribute, op.getAttribute(),
1023  op.getCaseValuesAttr(), op.getSuccessors());
1024 }
1025 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
1026  ByteCodeWriter &writer) {
1027  writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
1028  op.getCaseValuesAttr(), op.getSuccessors());
1029 }
1030 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
1031  ByteCodeWriter &writer) {
1032  auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
1033  return OperationName(cast<StringAttr>(attr).getValue(), ctx);
1034  });
1035  writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
1036  op.getSuccessors());
1037 }
1038 void Generator::generate(pdl_interp::SwitchResultCountOp op,
1039  ByteCodeWriter &writer) {
1040  writer.append(OpCode::SwitchResultCount, op.getInputOp(),
1041  op.getCaseValuesAttr(), op.getSuccessors());
1042 }
1043 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1044  writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1045  op.getSuccessors());
1046 }
1047 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1048  writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
1049  op.getSuccessors());
1050 }
1051 
1052 //===----------------------------------------------------------------------===//
1053 // PDLByteCode
1054 //===----------------------------------------------------------------------===//
1055 
1056 PDLByteCode::PDLByteCode(
1057  ModuleOp module, SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
1059  llvm::StringMap<PDLConstraintFunction> constraintFns,
1060  llvm::StringMap<PDLRewriteFunction> rewriteFns)
1061  : configs(std::move(configs)) {
1062  Generator generator(module.getContext(), uniquedData, matcherByteCode,
1063  rewriterByteCode, patterns, maxValueMemoryIndex,
1064  maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1065  maxLoopLevel, constraintFns, rewriteFns, configMap);
1066  generator.generate(module);
1067 
1068  // Initialize the external functions.
1069  for (auto &it : constraintFns)
1070  constraintFunctions.push_back(std::move(it.second));
1071  for (auto &it : rewriteFns)
1072  rewriteFunctions.push_back(std::move(it.second));
1073 }
1074 
1075 /// Initialize the given state such that it can be used to execute the current
1076 /// bytecode.
1077 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
1078  state.memory.resize(maxValueMemoryIndex, nullptr);
1079  state.opRangeMemory.resize(maxOpRangeCount);
1080  state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
1081  state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
1082  state.loopIndex.resize(maxLoopLevel, 0);
1083  state.currentPatternBenefits.reserve(patterns.size());
1084  for (const PDLByteCodePattern &pattern : patterns)
1085  state.currentPatternBenefits.push_back(pattern.getBenefit());
1086 }
1087 
1088 //===----------------------------------------------------------------------===//
1089 // ByteCode Execution
1090 //===----------------------------------------------------------------------===//
1091 
1092 namespace {
1093 /// This class is an instantiation of the PDLResultList that provides access to
1094 /// the returned results. This API is not on `PDLResultList` to avoid
1095 /// overexposing access to information specific solely to the ByteCode.
1096 class ByteCodeRewriteResultList : public PDLResultList {
1097 public:
1098  ByteCodeRewriteResultList(unsigned maxNumResults)
1099  : PDLResultList(maxNumResults) {}
1100 
1101  /// Return the list of PDL results.
1102  MutableArrayRef<PDLValue> getResults() { return results; }
1103 
1104  /// Return the type ranges allocated by this list.
1105  MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
1106  return allocatedTypeRanges;
1107  }
1108 
1109  /// Return the value ranges allocated by this list.
1110  MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
1111  return allocatedValueRanges;
1112  }
1113 };
1114 
1115 /// This class provides support for executing a bytecode stream.
1116 class ByteCodeExecutor {
1117 public:
1118  ByteCodeExecutor(
1119  const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
1120  MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory,
1121  MutableArrayRef<TypeRange> typeRangeMemory,
1122  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
1123  MutableArrayRef<ValueRange> valueRangeMemory,
1124  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
1125  MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory,
1127  ArrayRef<PatternBenefit> currentPatternBenefits,
1129  ArrayRef<PDLConstraintFunction> constraintFunctions,
1130  ArrayRef<PDLRewriteFunction> rewriteFunctions)
1131  : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
1132  typeRangeMemory(typeRangeMemory),
1133  allocatedTypeRangeMemory(allocatedTypeRangeMemory),
1134  valueRangeMemory(valueRangeMemory),
1135  allocatedValueRangeMemory(allocatedValueRangeMemory),
1136  loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
1137  currentPatternBenefits(currentPatternBenefits), patterns(patterns),
1138  constraintFunctions(constraintFunctions),
1139  rewriteFunctions(rewriteFunctions) {}
1140 
1141  /// Start executing the code at the current bytecode index. `matches` is an
1142  /// optional field provided when this function is executed in a matching
1143  /// context.
1144  LogicalResult
1145  execute(PatternRewriter &rewriter,
1146  SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
1147  std::optional<Location> mainRewriteLoc = {});
1148 
1149 private:
1150  /// Internal implementation of executing each of the bytecode commands.
1151  void executeApplyConstraint(PatternRewriter &rewriter);
1152  LogicalResult executeApplyRewrite(PatternRewriter &rewriter);
1153  void executeAreEqual();
1154  void executeAreRangesEqual();
1155  void executeBranch();
1156  void executeCheckOperandCount();
1157  void executeCheckOperationName();
1158  void executeCheckResultCount();
1159  void executeCheckTypes();
1160  void executeContinue();
1161  void executeCreateConstantTypeRange();
1162  void executeCreateOperation(PatternRewriter &rewriter,
1163  Location mainRewriteLoc);
1164  template <typename T>
1165  void executeDynamicCreateRange(StringRef type);
1166  void executeEraseOp(PatternRewriter &rewriter);
1167  template <typename T, typename Range, PDLValue::Kind kind>
1168  void executeExtract();
1169  void executeFinalize();
1170  void executeForEach();
1171  void executeGetAttribute();
1172  void executeGetAttributeType();
1173  void executeGetDefiningOp();
1174  void executeGetOperand(unsigned index);
1175  void executeGetOperands();
1176  void executeGetResult(unsigned index);
1177  void executeGetResults();
1178  void executeGetUsers();
1179  void executeGetValueType();
1180  void executeGetValueRangeTypes();
1181  void executeIsNotNull();
1182  void executeRecordMatch(PatternRewriter &rewriter,
1184  void executeReplaceOp(PatternRewriter &rewriter);
1185  void executeSwitchAttribute();
1186  void executeSwitchOperandCount();
1187  void executeSwitchOperationName();
1188  void executeSwitchResultCount();
1189  void executeSwitchType();
1190  void executeSwitchTypes();
1191  void processNativeFunResults(ByteCodeRewriteResultList &results,
1192  unsigned numResults,
1193  LogicalResult &rewriteResult);
1194 
1195  /// Pushes a code iterator to the stack.
1196  void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }
1197 
1198  /// Pops a code iterator from the stack, returning true on success.
1199  void popCodeIt() {
1200  assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
1201  curCodeIt = resumeCodeIt.pop_back_val();
1202  }
1203 
1204  /// Return the bytecode iterator at the start of the current op code.
1205  const ByteCodeField *getPrevCodeIt() const {
1206  LLVM_DEBUG({
1207  // Account for the op code and the Location stored inline.
1208  return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField);
1209  });
1210 
1211  // Account for the op code only.
1212  return curCodeIt - 1;
1213  }
1214 
1215  /// Read a value from the bytecode buffer, optionally skipping a certain
1216  /// number of prefix values. These methods always update the buffer to point
1217  /// to the next field after the read data.
1218  template <typename T = ByteCodeField>
1219  T read(size_t skipN = 0) {
1220  curCodeIt += skipN;
1221  return readImpl<T>();
1222  }
1223  ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
1224 
1225  /// Read a list of values from the bytecode buffer.
1226  template <typename ValueT, typename T>
1227  void readList(SmallVectorImpl<T> &list) {
1228  list.clear();
1229  for (unsigned i = 0, e = read(); i != e; ++i)
1230  list.push_back(read<ValueT>());
1231  }
1232 
1233  /// Read a list of values from the bytecode buffer. The values may be encoded
1234  /// either as a single element or a range of elements.
1235  void readList(SmallVectorImpl<Type> &list) {
1236  for (unsigned i = 0, e = read(); i != e; ++i) {
1237  if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1238  list.push_back(read<Type>());
1239  } else {
1240  TypeRange *values = read<TypeRange *>();
1241  list.append(values->begin(), values->end());
1242  }
1243  }
1244  }
1245  void readList(SmallVectorImpl<Value> &list) {
1246  for (unsigned i = 0, e = read(); i != e; ++i) {
1247  if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1248  list.push_back(read<Value>());
1249  } else {
1250  ValueRange *values = read<ValueRange *>();
1251  list.append(values->begin(), values->end());
1252  }
1253  }
1254  }
1255 
1256  /// Read a value stored inline as a pointer.
1257  template <typename T>
1258  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
1259  readInline() {
1260  const void *pointer;
1261  std::memcpy(&pointer, curCodeIt, sizeof(const void *));
1262  curCodeIt += sizeof(const void *) / sizeof(ByteCodeField);
1263  return T::getFromOpaquePointer(pointer);
1264  }
1265 
1266  void skip(size_t skipN) { curCodeIt += skipN; }
1267 
1268  /// Jump to a specific successor based on a predicate value.
1269  void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
1270  /// Jump to a specific successor based on a destination index.
1271  void selectJump(size_t destIndex) {
1272  curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
1273  }
1274 
1275  /// Handle a switch operation with the provided value and cases.
1276  template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
1277  void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
1278  LLVM_DEBUG({
1279  llvm::dbgs() << " * Value: " << value << "\n"
1280  << " * Cases: ";
1281  llvm::interleaveComma(cases, llvm::dbgs());
1282  llvm::dbgs() << "\n";
1283  });
1284 
1285  // Check to see if the attribute value is within the case list. Jump to
1286  // the correct successor index based on the result.
1287  for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
1288  if (cmp(*it, value))
1289  return selectJump(size_t((it - cases.begin()) + 1));
1290  selectJump(size_t(0));
1291  }
1292 
1293  /// Store a pointer to memory.
1294  void storeToMemory(unsigned index, const void *value) {
1295  memory[index] = value;
1296  }
1297 
1298  /// Store a value to memory as an opaque pointer.
1299  template <typename T>
1300  std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
1301  storeToMemory(unsigned index, T value) {
1302  memory[index] = value.getAsOpaquePointer();
1303  }
1304 
1305  /// Internal implementation of reading various data types from the bytecode
1306  /// stream.
1307  template <typename T>
1308  const void *readFromMemory() {
1309  size_t index = *curCodeIt++;
1310 
1311  // If this type is an SSA value, it can only be stored in non-const memory.
1312  if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
1313  Value>::value ||
1314  index < memory.size())
1315  return memory[index];
1316 
1317  // Otherwise, if this index is not inbounds it is uniqued.
1318  return uniquedMemory[index - memory.size()];
1319  }
1320  template <typename T>
1321  std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
1322  return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
1323  }
1324  template <typename T>
1325  std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
1326  T>
1327  readImpl() {
1328  return T(T::getFromOpaquePointer(readFromMemory<T>()));
1329  }
1330  template <typename T>
1331  std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
1332  switch (read<PDLValue::Kind>()) {
1333  case PDLValue::Kind::Attribute:
1334  return read<Attribute>();
1335  case PDLValue::Kind::Operation:
1336  return read<Operation *>();
1337  case PDLValue::Kind::Type:
1338  return read<Type>();
1339  case PDLValue::Kind::Value:
1340  return read<Value>();
1341  case PDLValue::Kind::TypeRange:
1342  return read<TypeRange *>();
1343  case PDLValue::Kind::ValueRange:
1344  return read<ValueRange *>();
1345  }
1346  llvm_unreachable("unhandled PDLValue::Kind");
1347  }
1348  template <typename T>
1349  std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
1350  static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
1351  "unexpected ByteCode address size");
1352  ByteCodeAddr result;
1353  std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
1354  curCodeIt += 2;
1355  return result;
1356  }
1357  template <typename T>
1358  std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
1359  return *curCodeIt++;
1360  }
1361  template <typename T>
1362  std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
1363  return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
1364  }
1365 
1366  /// Assign the given range to the given memory index. This allocates a new
1367  /// range object if necessary.
1368  template <typename RangeT, typename T = llvm::detail::ValueOfRange<RangeT>>
1369  void assignRangeToMemory(RangeT &&range, unsigned memIndex,
1370  unsigned rangeIndex) {
1371  // Utility functor used to type-erase the assignment.
1372  auto assignRange = [&](auto &allocatedRangeMemory, auto &rangeMemory) {
1373  // If the input range is empty, we don't need to allocate anything.
1374  if (range.empty()) {
1375  rangeMemory[rangeIndex] = {};
1376  } else {
1377  // Allocate a buffer for this type range.
1378  llvm::OwningArrayRef<T> storage(llvm::size(range));
1379  llvm::copy(range, storage.begin());
1380 
1381  // Assign this to the range slot and use the range as the value for the
1382  // memory index.
1383  allocatedRangeMemory.emplace_back(std::move(storage));
1384  rangeMemory[rangeIndex] = allocatedRangeMemory.back();
1385  }
1386  memory[memIndex] = &rangeMemory[rangeIndex];
1387  };
1388 
1389  // Dispatch based on the concrete range type.
1390  if constexpr (std::is_same_v<T, Type>) {
1391  return assignRange(allocatedTypeRangeMemory, typeRangeMemory);
1392  } else if constexpr (std::is_same_v<T, Value>) {
1393  return assignRange(allocatedValueRangeMemory, valueRangeMemory);
1394  } else {
1395  llvm_unreachable("unhandled range type");
1396  }
1397  }
1398 
1399  /// The underlying bytecode buffer.
1400  const ByteCodeField *curCodeIt;
1401 
1402  /// The stack of bytecode positions at which to resume operation.
1404 
1405  /// The current execution memory.
1407  MutableArrayRef<OwningOpRange> opRangeMemory;
1408  MutableArrayRef<TypeRange> typeRangeMemory;
1409  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
1410  MutableArrayRef<ValueRange> valueRangeMemory;
1411  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
1412 
1413  /// The current loop indices.
1414  MutableArrayRef<unsigned> loopIndex;
1415 
1416  /// References to ByteCode data necessary for execution.
1417  ArrayRef<const void *> uniquedMemory;
1419  ArrayRef<PatternBenefit> currentPatternBenefits;
1421  ArrayRef<PDLConstraintFunction> constraintFunctions;
1422  ArrayRef<PDLRewriteFunction> rewriteFunctions;
1423 };
1424 } // namespace
1425 
1426 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1427  LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
1428  ByteCodeField fun_idx = read();
1430  readList<PDLValue>(args);
1431 
1432  LLVM_DEBUG({
1433  llvm::dbgs() << " * Arguments: ";
1434  llvm::interleaveComma(args, llvm::dbgs());
1435  llvm::dbgs() << "\n";
1436  });
1437 
1438  ByteCodeField isNegated = read();
1439  LLVM_DEBUG({
1440  llvm::dbgs() << " * isNegated: " << isNegated << "\n";
1441  llvm::interleaveComma(args, llvm::dbgs());
1442  });
1443 
1444  ByteCodeField numResults = read();
1445  const PDLRewriteFunction &constraintFn = constraintFunctions[fun_idx];
1446  ByteCodeRewriteResultList results(numResults);
1447  LogicalResult rewriteResult = constraintFn(rewriter, results, args);
1448  [[maybe_unused]] ArrayRef<PDLValue> constraintResults = results.getResults();
1449  LLVM_DEBUG({
1450  if (succeeded(rewriteResult)) {
1451  llvm::dbgs() << " * Constraint succeeded\n";
1452  llvm::dbgs() << " * Results: ";
1453  llvm::interleaveComma(constraintResults, llvm::dbgs());
1454  llvm::dbgs() << "\n";
1455  } else {
1456  llvm::dbgs() << " * Constraint failed\n";
1457  }
1458  });
1459  assert((failed(rewriteResult) || constraintResults.size() == numResults) &&
1460  "native PDL rewrite function succeeded but returned "
1461  "unexpected number of results");
1462  processNativeFunResults(results, numResults, rewriteResult);
1463 
1464  // Depending on the constraint jump to the proper destination.
1465  selectJump(isNegated != succeeded(rewriteResult));
1466 }
1467 
1468 LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1469  LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
1470  const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1472  readList<PDLValue>(args);
1473 
1474  LLVM_DEBUG({
1475  llvm::dbgs() << " * Arguments: ";
1476  llvm::interleaveComma(args, llvm::dbgs());
1477  });
1478 
1479  // Execute the rewrite function.
1480  ByteCodeField numResults = read();
1481  ByteCodeRewriteResultList results(numResults);
1482  LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
1483 
1484  assert(results.getResults().size() == numResults &&
1485  "native PDL rewrite function returned unexpected number of results");
1486 
1487  processNativeFunResults(results, numResults, rewriteResult);
1488 
1489  if (failed(rewriteResult)) {
1490  LLVM_DEBUG(llvm::dbgs() << " - Failed");
1491  return failure();
1492  }
1493  return success();
1494 }
1495 
1496 void ByteCodeExecutor::processNativeFunResults(
1497  ByteCodeRewriteResultList &results, unsigned numResults,
1498  LogicalResult &rewriteResult) {
1499  if (failed(rewriteResult)) {
1500  // Skip the according number of values on the buffer on failure and exit
1501  // early as there are no results to process.
1502  for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1503  const PDLValue::Kind resultKind = read<PDLValue::Kind>();
1504  if (resultKind == PDLValue::Kind::TypeRange ||
1505  resultKind == PDLValue::Kind::ValueRange) {
1506  skip(2);
1507  } else {
1508  skip(1);
1509  }
1510  }
1511  return;
1512  }
1513 
1514  // Store the results in the bytecode memory
1515  for (unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1516  PDLValue::Kind resultKind = read<PDLValue::Kind>();
1517  (void)resultKind;
1518  PDLValue result = results.getResults()[resultIdx];
1519  LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
1520  assert(result.getKind() == resultKind &&
1521  "native PDL rewrite function returned an unexpected type of "
1522  "result");
1523  // If the result is a range, we need to copy it over to the bytecodes
1524  // range memory.
1525  if (std::optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
1526  unsigned rangeIndex = read();
1527  typeRangeMemory[rangeIndex] = *typeRange;
1528  memory[read()] = &typeRangeMemory[rangeIndex];
1529  } else if (std::optional<ValueRange> valueRange =
1530  result.dyn_cast<ValueRange>()) {
1531  unsigned rangeIndex = read();
1532  valueRangeMemory[rangeIndex] = *valueRange;
1533  memory[read()] = &valueRangeMemory[rangeIndex];
1534  } else {
1535  memory[read()] = result.getAsOpaquePointer();
1536  }
1537  }
1538 
1539  // Copy over any underlying storage allocated for result ranges.
1540  for (auto &it : results.getAllocatedTypeRanges())
1541  allocatedTypeRangeMemory.push_back(std::move(it));
1542  for (auto &it : results.getAllocatedValueRanges())
1543  allocatedValueRangeMemory.push_back(std::move(it));
1544 }
1545 
1546 void ByteCodeExecutor::executeAreEqual() {
1547  LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1548  const void *lhs = read<const void *>();
1549  const void *rhs = read<const void *>();
1550 
1551  LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n");
1552  selectJump(lhs == rhs);
1553 }
1554 
1555 void ByteCodeExecutor::executeAreRangesEqual() {
1556  LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
1557  PDLValue::Kind valueKind = read<PDLValue::Kind>();
1558  const void *lhs = read<const void *>();
1559  const void *rhs = read<const void *>();
1560 
1561  switch (valueKind) {
1562  case PDLValue::Kind::TypeRange: {
1563  const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1564  const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1565  LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1566  selectJump(*lhsRange == *rhsRange);
1567  break;
1568  }
1569  case PDLValue::Kind::ValueRange: {
1570  const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1571  const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1572  LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1573  selectJump(*lhsRange == *rhsRange);
1574  break;
1575  }
1576  default:
1577  llvm_unreachable("unexpected `AreRangesEqual` value kind");
1578  }
1579 }
1580 
1581 void ByteCodeExecutor::executeBranch() {
1582  LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
1583  curCodeIt = &code[read<ByteCodeAddr>()];
1584 }
1585 
1586 void ByteCodeExecutor::executeCheckOperandCount() {
1587  LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
1588  Operation *op = read<Operation *>();
1589  uint32_t expectedCount = read<uint32_t>();
1590  bool compareAtLeast = read();
1591 
1592  LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n"
1593  << " * Expected: " << expectedCount << "\n"
1594  << " * Comparator: "
1595  << (compareAtLeast ? ">=" : "==") << "\n");
1596  if (compareAtLeast)
1597  selectJump(op->getNumOperands() >= expectedCount);
1598  else
1599  selectJump(op->getNumOperands() == expectedCount);
1600 }
1601 
1602 void ByteCodeExecutor::executeCheckOperationName() {
1603  LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
1604  Operation *op = read<Operation *>();
1605  OperationName expectedName = read<OperationName>();
1606 
1607  LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n"
1608  << " * Expected: \"" << expectedName << "\"\n");
1609  selectJump(op->getName() == expectedName);
1610 }
1611 
1612 void ByteCodeExecutor::executeCheckResultCount() {
1613  LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
1614  Operation *op = read<Operation *>();
1615  uint32_t expectedCount = read<uint32_t>();
1616  bool compareAtLeast = read();
1617 
1618  LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n"
1619  << " * Expected: " << expectedCount << "\n"
1620  << " * Comparator: "
1621  << (compareAtLeast ? ">=" : "==") << "\n");
1622  if (compareAtLeast)
1623  selectJump(op->getNumResults() >= expectedCount);
1624  else
1625  selectJump(op->getNumResults() == expectedCount);
1626 }
1627 
1628 void ByteCodeExecutor::executeCheckTypes() {
1629  LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1630  TypeRange *lhs = read<TypeRange *>();
1631  Attribute rhs = read<Attribute>();
1632  LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1633 
1634  selectJump(*lhs == cast<ArrayAttr>(rhs).getAsValueRange<TypeAttr>());
1635 }
1636 
1637 void ByteCodeExecutor::executeContinue() {
1638  ByteCodeField level = read();
1639  LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n"
1640  << " * Level: " << level << "\n");
1641  ++loopIndex[level];
1642  popCodeIt();
1643 }
1644 
1645 void ByteCodeExecutor::executeCreateConstantTypeRange() {
1646  LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n");
1647  unsigned memIndex = read();
1648  unsigned rangeIndex = read();
1649  ArrayAttr typesAttr = cast<ArrayAttr>(read<Attribute>());
1650 
1651  LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n");
1652  assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
1653  rangeIndex);
1654 }
1655 
1656 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1657  Location mainRewriteLoc) {
1658  LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
1659 
1660  unsigned memIndex = read();
1661  OperationState state(mainRewriteLoc, read<OperationName>());
1662  readList(state.operands);
1663  for (unsigned i = 0, e = read(); i != e; ++i) {
1664  StringAttr name = read<StringAttr>();
1665  if (Attribute attr = read<Attribute>())
1666  state.addAttribute(name, attr);
1667  }
1668 
1669  // Read in the result types. If the "size" is the sentinel value, this
1670  // indicates that the result types should be inferred.
1671  unsigned numResults = read();
1672  if (numResults == kInferTypesMarker) {
1673  InferTypeOpInterface::Concept *inferInterface =
1674  state.name.getInterface<InferTypeOpInterface>();
1675  assert(inferInterface &&
1676  "expected operation to provide InferTypeOpInterface");
1677 
1678  // TODO: Handle failure.
1679  if (failed(inferInterface->inferReturnTypes(
1680  state.getContext(), state.location, state.operands,
1681  state.attributes.getDictionary(state.getContext()),
1682  state.getRawProperties(), state.regions, state.types)))
1683  return;
1684  } else {
1685  // Otherwise, this is a fixed number of results.
1686  for (unsigned i = 0; i != numResults; ++i) {
1687  if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1688  state.types.push_back(read<Type>());
1689  } else {
1690  TypeRange *resultTypes = read<TypeRange *>();
1691  state.types.append(resultTypes->begin(), resultTypes->end());
1692  }
1693  }
1694  }
1695 
1696  Operation *resultOp = rewriter.create(state);
1697  memory[memIndex] = resultOp;
1698 
1699  LLVM_DEBUG({
1700  llvm::dbgs() << " * Attributes: "
1701  << state.attributes.getDictionary(state.getContext())
1702  << "\n * Operands: ";
1703  llvm::interleaveComma(state.operands, llvm::dbgs());
1704  llvm::dbgs() << "\n * Result Types: ";
1705  llvm::interleaveComma(state.types, llvm::dbgs());
1706  llvm::dbgs() << "\n * Result: " << *resultOp << "\n";
1707  });
1708 }
1709 
1710 template <typename T>
1711 void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
1712  LLVM_DEBUG(llvm::dbgs() << "Executing CreateDynamic" << type << "Range:\n");
1713  unsigned memIndex = read();
1714  unsigned rangeIndex = read();
1715  SmallVector<T> values;
1716  readList(values);
1717 
1718  LLVM_DEBUG({
1719  llvm::dbgs() << "\n * " << type << "s: ";
1720  llvm::interleaveComma(values, llvm::dbgs());
1721  llvm::dbgs() << "\n";
1722  });
1723 
1724  assignRangeToMemory(values, memIndex, rangeIndex);
1725 }
1726 
1727 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1728  LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1729  Operation *op = read<Operation *>();
1730 
1731  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
1732  rewriter.eraseOp(op);
1733 }
1734 
1735 template <typename T, typename Range, PDLValue::Kind kind>
1736 void ByteCodeExecutor::executeExtract() {
1737  LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n");
1738  Range *range = read<Range *>();
1739  unsigned index = read<uint32_t>();
1740  unsigned memIndex = read();
1741 
1742  if (!range) {
1743  memory[memIndex] = nullptr;
1744  return;
1745  }
1746 
1747  T result = index < range->size() ? (*range)[index] : T();
1748  LLVM_DEBUG(llvm::dbgs() << " * " << kind << "s(" << range->size() << ")\n"
1749  << " * Index: " << index << "\n"
1750  << " * Result: " << result << "\n");
1751  storeToMemory(memIndex, result);
1752 }
1753 
1754 void ByteCodeExecutor::executeFinalize() {
1755  LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
1756 }
1757 
1758 void ByteCodeExecutor::executeForEach() {
1759  LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
1760  const ByteCodeField *prevCodeIt = getPrevCodeIt();
1761  unsigned rangeIndex = read();
1762  unsigned memIndex = read();
1763  const void *value = nullptr;
1764 
1765  switch (read<PDLValue::Kind>()) {
1766  case PDLValue::Kind::Operation: {
1767  unsigned &index = loopIndex[read()];
1768  ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
1769  assert(index <= array.size() && "iterated past the end");
1770  if (index < array.size()) {
1771  LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n");
1772  value = array[index];
1773  break;
1774  }
1775 
1776  LLVM_DEBUG(llvm::dbgs() << " * Done\n");
1777  index = 0;
1778  selectJump(size_t(0));
1779  return;
1780  }
1781  default:
1782  llvm_unreachable("unexpected `ForEach` value kind");
1783  }
1784 
1785  // Store the iterate value and the stack address.
1786  memory[memIndex] = value;
1787  pushCodeIt(prevCodeIt);
1788 
1789  // Skip over the successor (we will enter the body of the loop).
1790  read<ByteCodeAddr>();
1791 }
1792 
1793 void ByteCodeExecutor::executeGetAttribute() {
1794  LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1795  unsigned memIndex = read();
1796  Operation *op = read<Operation *>();
1797  StringAttr attrName = read<StringAttr>();
1798  Attribute attr = op->getAttr(attrName);
1799 
1800  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1801  << " * Attribute: " << attrName << "\n"
1802  << " * Result: " << attr << "\n");
1803  memory[memIndex] = attr.getAsOpaquePointer();
1804 }
1805 
1806 void ByteCodeExecutor::executeGetAttributeType() {
1807  LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1808  unsigned memIndex = read();
1809  Attribute attr = read<Attribute>();
1810  Type type;
1811  if (auto typedAttr = dyn_cast<TypedAttr>(attr))
1812  type = typedAttr.getType();
1813 
1814  LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n"
1815  << " * Result: " << type << "\n");
1816  memory[memIndex] = type.getAsOpaquePointer();
1817 }
1818 
1819 void ByteCodeExecutor::executeGetDefiningOp() {
1820  LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1821  unsigned memIndex = read();
1822  Operation *op = nullptr;
1823  if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1824  Value value = read<Value>();
1825  if (value)
1826  op = value.getDefiningOp();
1827  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
1828  } else {
1829  ValueRange *values = read<ValueRange *>();
1830  if (values && !values->empty()) {
1831  op = values->front().getDefiningOp();
1832  }
1833  LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n");
1834  }
1835 
1836  LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n");
1837  memory[memIndex] = op;
1838 }
1839 
1840 void ByteCodeExecutor::executeGetOperand(unsigned index) {
1841  Operation *op = read<Operation *>();
1842  unsigned memIndex = read();
1843  Value operand =
1844  index < op->getNumOperands() ? op->getOperand(index) : Value();
1845 
1846  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1847  << " * Index: " << index << "\n"
1848  << " * Result: " << operand << "\n");
1849  memory[memIndex] = operand.getAsOpaquePointer();
1850 }
1851 
1852 /// This function is the internal implementation of `GetResults` and
1853 /// `GetOperands` that provides support for extracting a value range from the
1854 /// given operation.
1855 template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1856 static void *
1857 executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1858  ByteCodeField rangeIndex, StringRef attrSizedSegments,
1859  MutableArrayRef<ValueRange> valueRangeMemory) {
1860  // Check for the sentinel index that signals that all values should be
1861  // returned.
1862  if (index == std::numeric_limits<uint32_t>::max()) {
1863  LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n");
1864  // `values` is already the full value range.
1865 
1866  // Otherwise, check to see if this operation uses AttrSizedSegments.
1867  } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1868  LLVM_DEBUG(llvm::dbgs()
1869  << " * Extracting values from `" << attrSizedSegments << "`\n");
1870 
1871  auto segmentAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrSizedSegments);
1872  if (!segmentAttr || segmentAttr.asArrayRef().size() <= index)
1873  return nullptr;
1874 
1875  ArrayRef<int32_t> segments = segmentAttr;
1876  unsigned startIndex =
1877  std::accumulate(segments.begin(), segments.begin() + index, 0);
1878  values = values.slice(startIndex, *std::next(segments.begin(), index));
1879 
1880  LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", "
1881  << *std::next(segments.begin(), index) << "]\n");
1882 
1883  // Otherwise, assume this is the last operand group of the operation.
1884  // FIXME: We currently don't support operations with
1885  // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1886  // have a way to detect it's presence.
1887  } else if (values.size() >= index) {
1888  LLVM_DEBUG(llvm::dbgs()
1889  << " * Treating values as trailing variadic range\n");
1890  values = values.drop_front(index);
1891 
1892  // If we couldn't detect a way to compute the values, bail out.
1893  } else {
1894  return nullptr;
1895  }
1896 
1897  // If the range index is valid, we are returning a range.
1898  if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1899  valueRangeMemory[rangeIndex] = values;
1900  return &valueRangeMemory[rangeIndex];
1901  }
1902 
1903  // If a range index wasn't provided, the range is required to be non-variadic.
1904  return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1905 }
1906 
1907 void ByteCodeExecutor::executeGetOperands() {
1908  LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
1909  unsigned index = read<uint32_t>();
1910  Operation *op = read<Operation *>();
1911  ByteCodeField rangeIndex = read();
1912 
1913  void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1914  op->getOperands(), op, index, rangeIndex, "operandSegmentSizes",
1915  valueRangeMemory);
1916  if (!result)
1917  LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n");
1918  memory[read()] = result;
1919 }
1920 
1921 void ByteCodeExecutor::executeGetResult(unsigned index) {
1922  Operation *op = read<Operation *>();
1923  unsigned memIndex = read();
1924  OpResult result =
1925  index < op->getNumResults() ? op->getResult(index) : OpResult();
1926 
1927  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1928  << " * Index: " << index << "\n"
1929  << " * Result: " << result << "\n");
1930  memory[memIndex] = result.getAsOpaquePointer();
1931 }
1932 
1933 void ByteCodeExecutor::executeGetResults() {
1934  LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
1935  unsigned index = read<uint32_t>();
1936  Operation *op = read<Operation *>();
1937  ByteCodeField rangeIndex = read();
1938 
1939  void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1940  op->getResults(), op, index, rangeIndex, "resultSegmentSizes",
1941  valueRangeMemory);
1942  if (!result)
1943  LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n");
1944  memory[read()] = result;
1945 }
1946 
1947 void ByteCodeExecutor::executeGetUsers() {
1948  LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n");
1949  unsigned memIndex = read();
1950  unsigned rangeIndex = read();
1951  OwningOpRange &range = opRangeMemory[rangeIndex];
1952  memory[memIndex] = &range;
1953 
1954  range = OwningOpRange();
1955  if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1956  // Read the value.
1957  Value value = read<Value>();
1958  if (!value)
1959  return;
1960  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
1961 
1962  // Extract the users of a single value.
1963  range = OwningOpRange(std::distance(value.user_begin(), value.user_end()));
1964  llvm::copy(value.getUsers(), range.begin());
1965  } else {
1966  // Read a range of values.
1967  ValueRange *values = read<ValueRange *>();
1968  if (!values)
1969  return;
1970  LLVM_DEBUG({
1971  llvm::dbgs() << " * Values (" << values->size() << "): ";
1972  llvm::interleaveComma(*values, llvm::dbgs());
1973  llvm::dbgs() << "\n";
1974  });
1975 
1976  // Extract all the users of a range of values.
1978  for (Value value : *values)
1979  users.append(value.user_begin(), value.user_end());
1980  range = OwningOpRange(users.size());
1981  llvm::copy(users, range.begin());
1982  }
1983 
1984  LLVM_DEBUG(llvm::dbgs() << " * Result: " << range.size() << " operations\n");
1985 }
1986 
1987 void ByteCodeExecutor::executeGetValueType() {
1988  LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1989  unsigned memIndex = read();
1990  Value value = read<Value>();
1991  Type type = value ? value.getType() : Type();
1992 
1993  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"
1994  << " * Result: " << type << "\n");
1995  memory[memIndex] = type.getAsOpaquePointer();
1996 }
1997 
1998 void ByteCodeExecutor::executeGetValueRangeTypes() {
1999  LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
2000  unsigned memIndex = read();
2001  unsigned rangeIndex = read();
2002  ValueRange *values = read<ValueRange *>();
2003  if (!values) {
2004  LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n");
2005  memory[memIndex] = nullptr;
2006  return;
2007  }
2008 
2009  LLVM_DEBUG({
2010  llvm::dbgs() << " * Values (" << values->size() << "): ";
2011  llvm::interleaveComma(*values, llvm::dbgs());
2012  llvm::dbgs() << "\n * Result: ";
2013  llvm::interleaveComma(values->getType(), llvm::dbgs());
2014  llvm::dbgs() << "\n";
2015  });
2016  typeRangeMemory[rangeIndex] = values->getType();
2017  memory[memIndex] = &typeRangeMemory[rangeIndex];
2018 }
2019 
2020 void ByteCodeExecutor::executeIsNotNull() {
2021  LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
2022  const void *value = read<const void *>();
2023 
2024  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
2025  selectJump(value != nullptr);
2026 }
2027 
2028 void ByteCodeExecutor::executeRecordMatch(
2029  PatternRewriter &rewriter,
2031  LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
2032  unsigned patternIndex = read();
2033  PatternBenefit benefit = currentPatternBenefits[patternIndex];
2034  const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
2035 
2036  // If the benefit of the pattern is impossible, skip the processing of the
2037  // rest of the pattern.
2038  if (benefit.isImpossibleToMatch()) {
2039  LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n");
2040  curCodeIt = dest;
2041  return;
2042  }
2043 
2044  // Create a fused location containing the locations of each of the
2045  // operations used in the match. This will be used as the location for
2046  // created operations during the rewrite that don't already have an
2047  // explicit location set.
2048  unsigned numMatchLocs = read();
2049  SmallVector<Location, 4> matchLocs;
2050  matchLocs.reserve(numMatchLocs);
2051  for (unsigned i = 0; i != numMatchLocs; ++i)
2052  matchLocs.push_back(read<Operation *>()->getLoc());
2053  Location matchLoc = rewriter.getFusedLoc(matchLocs);
2054 
2055  LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n"
2056  << " * Location: " << matchLoc << "\n");
2057  matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
2058  PDLByteCode::MatchResult &match = matches.back();
2059 
2060  // Record all of the inputs to the match. If any of the inputs are ranges, we
2061  // will also need to remap the range pointer to memory stored in the match
2062  // state.
2063  unsigned numInputs = read();
2064  match.values.reserve(numInputs);
2065  match.typeRangeValues.reserve(numInputs);
2066  match.valueRangeValues.reserve(numInputs);
2067  for (unsigned i = 0; i < numInputs; ++i) {
2068  switch (read<PDLValue::Kind>()) {
2069  case PDLValue::Kind::TypeRange:
2070  match.typeRangeValues.push_back(*read<TypeRange *>());
2071  match.values.push_back(&match.typeRangeValues.back());
2072  break;
2073  case PDLValue::Kind::ValueRange:
2074  match.valueRangeValues.push_back(*read<ValueRange *>());
2075  match.values.push_back(&match.valueRangeValues.back());
2076  break;
2077  default:
2078  match.values.push_back(read<const void *>());
2079  break;
2080  }
2081  }
2082  curCodeIt = dest;
2083 }
2084 
2085 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
2086  LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
2087  Operation *op = read<Operation *>();
2089  readList(args);
2090 
2091  LLVM_DEBUG({
2092  llvm::dbgs() << " * Operation: " << *op << "\n"
2093  << " * Values: ";
2094  llvm::interleaveComma(args, llvm::dbgs());
2095  llvm::dbgs() << "\n";
2096  });
2097  rewriter.replaceOp(op, args);
2098 }
2099 
2100 void ByteCodeExecutor::executeSwitchAttribute() {
2101  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
2102  Attribute value = read<Attribute>();
2103  ArrayAttr cases = read<ArrayAttr>();
2104  handleSwitch(value, cases);
2105 }
2106 
2107 void ByteCodeExecutor::executeSwitchOperandCount() {
2108  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
2109  Operation *op = read<Operation *>();
2110  auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2111 
2112  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
2113  handleSwitch(op->getNumOperands(), cases);
2114 }
2115 
2116 void ByteCodeExecutor::executeSwitchOperationName() {
2117  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
2118  OperationName value = read<Operation *>()->getName();
2119  size_t caseCount = read();
2120 
2121  // The operation names are stored in-line, so to print them out for
2122  // debugging purposes we need to read the array before executing the
2123  // switch so that we can display all of the possible values.
2124  LLVM_DEBUG({
2125  const ByteCodeField *prevCodeIt = curCodeIt;
2126  llvm::dbgs() << " * Value: " << value << "\n"
2127  << " * Cases: ";
2128  llvm::interleaveComma(
2129  llvm::map_range(llvm::seq<size_t>(0, caseCount),
2130  [&](size_t) { return read<OperationName>(); }),
2131  llvm::dbgs());
2132  llvm::dbgs() << "\n";
2133  curCodeIt = prevCodeIt;
2134  });
2135 
2136  // Try to find the switch value within any of the cases.
2137  for (size_t i = 0; i != caseCount; ++i) {
2138  if (read<OperationName>() == value) {
2139  curCodeIt += (caseCount - i - 1);
2140  return selectJump(i + 1);
2141  }
2142  }
2143  selectJump(size_t(0));
2144 }
2145 
2146 void ByteCodeExecutor::executeSwitchResultCount() {
2147  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
2148  Operation *op = read<Operation *>();
2149  auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2150 
2151  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
2152  handleSwitch(op->getNumResults(), cases);
2153 }
2154 
2155 void ByteCodeExecutor::executeSwitchType() {
2156  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
2157  Type value = read<Type>();
2158  auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2159  handleSwitch(value, cases);
2160 }
2161 
2162 void ByteCodeExecutor::executeSwitchTypes() {
2163  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
2164  TypeRange *value = read<TypeRange *>();
2165  auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2166  if (!value) {
2167  LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
2168  return selectJump(size_t(0));
2169  }
2170  handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
2171  return value == caseValue.getAsValueRange<TypeAttr>();
2172  });
2173 }
2174 
2175 LogicalResult
2176 ByteCodeExecutor::execute(PatternRewriter &rewriter,
2178  std::optional<Location> mainRewriteLoc) {
2179  while (true) {
2180  // Print the location of the operation being executed.
2181  LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n");
2182 
2183  OpCode opCode = static_cast<OpCode>(read());
2184  switch (opCode) {
2185  case ApplyConstraint:
2186  executeApplyConstraint(rewriter);
2187  break;
2188  case ApplyRewrite:
2189  if (failed(executeApplyRewrite(rewriter)))
2190  return failure();
2191  break;
2192  case AreEqual:
2193  executeAreEqual();
2194  break;
2195  case AreRangesEqual:
2196  executeAreRangesEqual();
2197  break;
2198  case Branch:
2199  executeBranch();
2200  break;
2201  case CheckOperandCount:
2202  executeCheckOperandCount();
2203  break;
2204  case CheckOperationName:
2205  executeCheckOperationName();
2206  break;
2207  case CheckResultCount:
2208  executeCheckResultCount();
2209  break;
2210  case CheckTypes:
2211  executeCheckTypes();
2212  break;
2213  case Continue:
2214  executeContinue();
2215  break;
2216  case CreateConstantTypeRange:
2217  executeCreateConstantTypeRange();
2218  break;
2219  case CreateOperation:
2220  executeCreateOperation(rewriter, *mainRewriteLoc);
2221  break;
2222  case CreateDynamicTypeRange:
2223  executeDynamicCreateRange<Type>("Type");
2224  break;
2225  case CreateDynamicValueRange:
2226  executeDynamicCreateRange<Value>("Value");
2227  break;
2228  case EraseOp:
2229  executeEraseOp(rewriter);
2230  break;
2231  case ExtractOp:
2232  executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
2233  break;
2234  case ExtractType:
2235  executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2236  break;
2237  case ExtractValue:
2238  executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2239  break;
2240  case Finalize:
2241  executeFinalize();
2242  LLVM_DEBUG(llvm::dbgs() << "\n");
2243  return success();
2244  case ForEach:
2245  executeForEach();
2246  break;
2247  case GetAttribute:
2248  executeGetAttribute();
2249  break;
2250  case GetAttributeType:
2251  executeGetAttributeType();
2252  break;
2253  case GetDefiningOp:
2254  executeGetDefiningOp();
2255  break;
2256  case GetOperand0:
2257  case GetOperand1:
2258  case GetOperand2:
2259  case GetOperand3: {
2260  unsigned index = opCode - GetOperand0;
2261  LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
2262  executeGetOperand(index);
2263  break;
2264  }
2265  case GetOperandN:
2266  LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
2267  executeGetOperand(read<uint32_t>());
2268  break;
2269  case GetOperands:
2270  executeGetOperands();
2271  break;
2272  case GetResult0:
2273  case GetResult1:
2274  case GetResult2:
2275  case GetResult3: {
2276  unsigned index = opCode - GetResult0;
2277  LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
2278  executeGetResult(index);
2279  break;
2280  }
2281  case GetResultN:
2282  LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
2283  executeGetResult(read<uint32_t>());
2284  break;
2285  case GetResults:
2286  executeGetResults();
2287  break;
2288  case GetUsers:
2289  executeGetUsers();
2290  break;
2291  case GetValueType:
2292  executeGetValueType();
2293  break;
2294  case GetValueRangeTypes:
2295  executeGetValueRangeTypes();
2296  break;
2297  case IsNotNull:
2298  executeIsNotNull();
2299  break;
2300  case RecordMatch:
2301  assert(matches &&
2302  "expected matches to be provided when executing the matcher");
2303  executeRecordMatch(rewriter, *matches);
2304  break;
2305  case ReplaceOp:
2306  executeReplaceOp(rewriter);
2307  break;
2308  case SwitchAttribute:
2309  executeSwitchAttribute();
2310  break;
2311  case SwitchOperandCount:
2312  executeSwitchOperandCount();
2313  break;
2314  case SwitchOperationName:
2315  executeSwitchOperationName();
2316  break;
2317  case SwitchResultCount:
2318  executeSwitchResultCount();
2319  break;
2320  case SwitchType:
2321  executeSwitchType();
2322  break;
2323  case SwitchTypes:
2324  executeSwitchTypes();
2325  break;
2326  }
2327  LLVM_DEBUG(llvm::dbgs() << "\n");
2328  }
2329 }
2330 
2331 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
2333  PDLByteCodeMutableState &state) const {
2334  // The first memory slot is always the root operation.
2335  state.memory[0] = op;
2336 
2337  // The matcher function always starts at code address 0.
2338  ByteCodeExecutor executor(
2339  matcherByteCode.data(), state.memory, state.opRangeMemory,
2340  state.typeRangeMemory, state.allocatedTypeRangeMemory,
2341  state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2342  uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
2343  constraintFunctions, rewriteFunctions);
2344  LogicalResult executeResult = executor.execute(rewriter, &matches);
2345  (void)executeResult;
2346  assert(succeeded(executeResult) && "unexpected matcher execution failure");
2347 
2348  // Order the found matches by benefit.
2349  llvm::stable_sort(matches,
2350  [](const MatchResult &lhs, const MatchResult &rhs) {
2351  return lhs.benefit > rhs.benefit;
2352  });
2353 }
2354 
2355 LogicalResult PDLByteCode::rewrite(PatternRewriter &rewriter,
2356  const MatchResult &match,
2357  PDLByteCodeMutableState &state) const {
2358  auto *configSet = match.pattern->getConfigSet();
2359  if (configSet)
2360  configSet->notifyRewriteBegin(rewriter);
2361 
2362  // The arguments of the rewrite function are stored at the start of the
2363  // memory buffer.
2364  llvm::copy(match.values, state.memory.begin());
2365 
2366  ByteCodeExecutor executor(
2367  &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
2368  state.opRangeMemory, state.typeRangeMemory,
2369  state.allocatedTypeRangeMemory, state.valueRangeMemory,
2370  state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2371  rewriterByteCode, state.currentPatternBenefits, patterns,
2372  constraintFunctions, rewriteFunctions);
2373  LogicalResult result =
2374  executor.execute(rewriter, /*matches=*/nullptr, match.location);
2375 
2376  if (configSet)
2377  configSet->notifyRewriteEnd(rewriter);
2378 
2379  // If the rewrite failed, check if the pattern rewriter can recover. If it
2380  // can, we can signal to the pattern applicator to keep trying patterns. If it
2381  // doesn't, we need to bail. Bailing here should be fine, given that we have
2382  // no means to propagate such a failure to the user, and it also indicates a
2383  // bug in the user code (i.e. failable rewrites should not be used with
2384  // pattern rewriters that don't support it).
2385  if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) {
2386  LLVM_DEBUG(llvm::dbgs() << " and rollback is not supported - aborting");
2387  llvm::report_fatal_error(
2388  "Native PDL Rewrite failed, but the pattern "
2389  "rewriter doesn't support recovery. Failable pattern rewrites should "
2390  "not be used with pattern rewriters that do not support them.");
2391  }
2392  return result;
2393 }
static void * executeGetOperandsResults(RangeT values, Operation *op, unsigned index, ByteCodeField rangeIndex, StringRef attrSizedSegments, MutableArrayRef< ValueRange > valueRangeMemory)
This function is the internal implementation of GetResults and GetOperands that provides support for ...
Definition: ByteCode.cpp:1857
static constexpr ByteCodeField kInferTypesMarker
A marker used to indicate if an operation should infer types.
Definition: ByteCode.cpp:173
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
union mlir::linalg::@1203::ArityGroupAndKind::Kind kind
static const mlir::GenInfo * generator
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void processValue(Value value, LiveMap &liveMap)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
const void * getAsOpaquePointer() const
Get an opaque pointer to the attribute.
Definition: Attributes.h:73
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:29
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition: Block.cpp:38
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
Definition: Builders.cpp:29
This class represents liveness information on block level.
Definition: Liveness.h:99
const ValueSetT & in() const
Returns all values that are live at the beginning of the block (unordered).
Definition: Liveness.h:110
Operation * getEndOperation(Value value, Operation *startOperation) const
Gets the end operation for the given value using the start operation provided (must be referenced in ...
Definition: Liveness.cpp:375
Represents an analysis for computing liveness information from a given top-level operation.
Definition: Liveness.h:47
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
This is a value defined by a result of an operation.
Definition: Value.h:447
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
Definition: Operation.h:749
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:550
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:534
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
unsigned getNumOperands()
Definition: Operation.h:346
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:677
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
bool isImpossibleToMatch() const
Definition: PatternMatch.h:44
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the corresponding pattern isImpossibleToMatch() then this aborts.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:749
virtual bool canRecoverFromRewriteFailure() const
A hook used to indicate if the pattern rewriter can recover from failure during the rewrite stage of ...
Definition: PatternMatch.h:758
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:247
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class implements the successor iterators for Block.
Definition: BlockSupport.h:73
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
const void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Types.h:164
auto walk(WalkFns &&...walkFns)
Walk this type and all attributes/types nested within using the provided walk functions.
Definition: Types.h:215
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
user_iterator user_begin() const
Definition: Value.h:216
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Value.h:233
user_iterator user_end() const
Definition: Value.h:217
user_range getUsers() const
Definition: Value.h:218
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit)
Set the new benefit for a bytecode pattern.
Definition: ByteCode.h:236
void cleanupAfterMatchAndRewrite()
Cleanup any allocated state after a full match/rewrite has been completed.
Definition: ByteCode.h:235
void match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl< MatchResult > &matches, PDLByteCodeMutableState &state) const
Definition: ByteCode.h:249
LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const
Definition: ByteCode.h:252
AttrTypeReplacer.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
Definition: Visitors.h:136
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult size