MLIR  16.0.0git
ByteCode.cpp
Go to the documentation of this file.
1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ByteCode.h"
14 #include "mlir/Analysis/Liveness.h"
17 #include "mlir/IR/BuiltinOps.h"
19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/Format.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include <numeric>
26 
27 #define DEBUG_TYPE "pdl-bytecode"
28 
29 using namespace mlir;
30 using namespace mlir::detail;
31 
32 //===----------------------------------------------------------------------===//
33 // PDLByteCodePattern
34 //===----------------------------------------------------------------------===//
35 
36 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
37  ByteCodeAddr rewriterAddr) {
38  SmallVector<StringRef, 8> generatedOps;
39  if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
40  generatedOps =
41  llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
42 
43  PatternBenefit benefit = matchOp.getBenefit();
44  MLIRContext *ctx = matchOp.getContext();
45 
46  // Check to see if this is pattern matches a specific operation type.
47  if (Optional<StringRef> rootKind = matchOp.getRootKind())
48  return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx,
49  generatedOps);
50  return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx,
51  generatedOps);
52 }
53 
54 //===----------------------------------------------------------------------===//
55 // PDLByteCodeMutableState
56 //===----------------------------------------------------------------------===//
57 
58 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
59 /// to the position of the pattern within the range returned by
60 /// `PDLByteCode::getPatterns`.
62  PatternBenefit benefit) {
63  currentPatternBenefits[patternIndex] = benefit;
64 }
65 
66 /// Cleanup any allocated state after a full match/rewrite has been completed.
67 /// This method should be called irregardless of whether the match+rewrite was a
68 /// success or not.
70  allocatedTypeRangeMemory.clear();
71  allocatedValueRangeMemory.clear();
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // Bytecode OpCodes
76 //===----------------------------------------------------------------------===//
77 
78 namespace {
80  /// Apply an externally registered constraint.
81  ApplyConstraint,
82  /// Apply an externally registered rewrite.
83  ApplyRewrite,
84  /// Check if two generic values are equal.
85  AreEqual,
86  /// Check if two ranges are equal.
87  AreRangesEqual,
88  /// Unconditional branch.
89  Branch,
90  /// Compare the operand count of an operation with a constant.
91  CheckOperandCount,
92  /// Compare the name of an operation with a constant.
93  CheckOperationName,
94  /// Compare the result count of an operation with a constant.
95  CheckResultCount,
96  /// Compare a range of types to a constant range of types.
97  CheckTypes,
98  /// Continue to the next iteration of a loop.
99  Continue,
100  /// Create an operation.
101  CreateOperation,
102  /// Create a range of types.
103  CreateTypes,
104  /// Erase an operation.
105  EraseOp,
106  /// Extract the op from a range at the specified index.
107  ExtractOp,
108  /// Extract the type from a range at the specified index.
109  ExtractType,
110  /// Extract the value from a range at the specified index.
111  ExtractValue,
112  /// Terminate a matcher or rewrite sequence.
113  Finalize,
114  /// Iterate over a range of values.
115  ForEach,
116  /// Get a specific attribute of an operation.
117  GetAttribute,
118  /// Get the type of an attribute.
119  GetAttributeType,
120  /// Get the defining operation of a value.
121  GetDefiningOp,
122  /// Get a specific operand of an operation.
123  GetOperand0,
124  GetOperand1,
125  GetOperand2,
126  GetOperand3,
127  GetOperandN,
128  /// Get a specific operand group of an operation.
129  GetOperands,
130  /// Get a specific result of an operation.
131  GetResult0,
132  GetResult1,
133  GetResult2,
134  GetResult3,
135  GetResultN,
136  /// Get a specific result group of an operation.
137  GetResults,
138  /// Get the users of a value or a range of values.
139  GetUsers,
140  /// Get the type of a value.
141  GetValueType,
142  /// Get the types of a value range.
143  GetValueRangeTypes,
144  /// Check if a generic value is not null.
145  IsNotNull,
146  /// Record a successful pattern match.
147  RecordMatch,
148  /// Replace an operation.
149  ReplaceOp,
150  /// Compare an attribute with a set of constants.
151  SwitchAttribute,
152  /// Compare the operand count of an operation with a set of constants.
153  SwitchOperandCount,
154  /// Compare the name of an operation with a set of constants.
155  SwitchOperationName,
156  /// Compare the result count of an operation with a set of constants.
157  SwitchResultCount,
158  /// Compare a type with a set of constants.
159  SwitchType,
160  /// Compare a range of types with a set of constants.
161  SwitchTypes,
162 };
163 } // namespace
164 
165 /// A marker used to indicate if an operation should infer types.
168 
169 //===----------------------------------------------------------------------===//
170 // ByteCode Generation
171 //===----------------------------------------------------------------------===//
172 
173 //===----------------------------------------------------------------------===//
174 // Generator
175 
176 namespace {
// Helper classes that are defined further down in this namespace.
struct ByteCodeLiveRange;
struct ByteCodeWriter;

/// Names the type returned by `getAsOpaquePointer()` on a `ValueT`. The alias
/// is ill-formed when `ValueT` has no such method, which lets SFINAE checks
/// detect whether `ValueT` can be converted to an opaque pointer. Any extra
/// template arguments are accepted and ignored.
template <typename ValueT, typename... ExtraArgsT>
using has_pointer_traits =
    decltype(std::declval<ValueT>().getAsOpaquePointer());
183 
184 /// This class represents the main generator for the pattern bytecode.
185 class Generator {
186 public:
187  Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
188  SmallVectorImpl<ByteCodeField> &matcherByteCode,
189  SmallVectorImpl<ByteCodeField> &rewriterByteCode,
191  ByteCodeField &maxValueMemoryIndex,
192  ByteCodeField &maxOpRangeMemoryIndex,
193  ByteCodeField &maxTypeRangeMemoryIndex,
194  ByteCodeField &maxValueRangeMemoryIndex,
195  ByteCodeField &maxLoopLevel,
196  llvm::StringMap<PDLConstraintFunction> &constraintFns,
197  llvm::StringMap<PDLRewriteFunction> &rewriteFns)
198  : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
199  rewriterByteCode(rewriterByteCode), patterns(patterns),
200  maxValueMemoryIndex(maxValueMemoryIndex),
201  maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
202  maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
203  maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
204  maxLoopLevel(maxLoopLevel) {
205  for (const auto &it : llvm::enumerate(constraintFns))
206  constraintToMemIndex.try_emplace(it.value().first(), it.index());
207  for (const auto &it : llvm::enumerate(rewriteFns))
208  externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
209  }
210 
211  /// Generate the bytecode for the given PDL interpreter module.
212  void generate(ModuleOp module);
213 
214  /// Return the memory index to use for the given value.
215  ByteCodeField &getMemIndex(Value value) {
216  assert(valueToMemIndex.count(value) &&
217  "expected memory index to be assigned");
218  return valueToMemIndex[value];
219  }
220 
221  /// Return the range memory index used to store the given range value.
222  ByteCodeField &getRangeStorageIndex(Value value) {
223  assert(valueToRangeIndex.count(value) &&
224  "expected range index to be assigned");
225  return valueToRangeIndex[value];
226  }
227 
228  /// Return an index to use when referring to the given data that is uniqued in
229  /// the MLIR context.
230  template <typename T>
232  getMemIndex(T val) {
233  const void *opaqueVal = val.getAsOpaquePointer();
234 
235  // Get or insert a reference to this value.
236  auto it = uniquedDataToMemIndex.try_emplace(
237  opaqueVal, maxValueMemoryIndex + uniquedData.size());
238  if (it.second)
239  uniquedData.push_back(opaqueVal);
240  return it.first->second;
241  }
242 
243 private:
244  /// Allocate memory indices for the results of operations within the matcher
245  /// and rewriters.
246  void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
247  ModuleOp rewriterModule);
248 
249  /// Generate the bytecode for the given operation.
250  void generate(Region *region, ByteCodeWriter &writer);
251  void generate(Operation *op, ByteCodeWriter &writer);
252  void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
253  void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
254  void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
255  void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
256  void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
257  void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
258  void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
259  void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
260  void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
261  void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
262  void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
263  void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
264  void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
265  void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
266  void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
267  void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
268  void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
269  void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
270  void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
271  void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
272  void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
273  void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
274  void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
275  void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
276  void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
277  void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
278  void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
279  void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
280  void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
281  void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
282  void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
283  void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
284  void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
285  void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
286  void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
287  void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
288  void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
289 
290  /// Mapping from value to its corresponding memory index.
291  DenseMap<Value, ByteCodeField> valueToMemIndex;
292 
293  /// Mapping from a range value to its corresponding range storage index.
294  DenseMap<Value, ByteCodeField> valueToRangeIndex;
295 
296  /// Mapping from the name of an externally registered rewrite to its index in
297  /// the bytecode registry.
298  llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
299 
300  /// Mapping from the name of an externally registered constraint to its index
301  /// in the bytecode registry.
302  llvm::StringMap<ByteCodeField> constraintToMemIndex;
303 
304  /// Mapping from rewriter function name to the bytecode address of the
305  /// rewriter function in byte.
306  llvm::StringMap<ByteCodeAddr> rewriterToAddr;
307 
308  /// Mapping from a uniqued storage object to its memory index within
309  /// `uniquedData`.
310  DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
311 
312  /// The current level of the foreach loop.
313  ByteCodeField curLoopLevel = 0;
314 
315  /// The current MLIR context.
316  MLIRContext *ctx;
317 
318  /// Mapping from block to its address.
320 
321  /// Data of the ByteCode class to be populated.
322  std::vector<const void *> &uniquedData;
323  SmallVectorImpl<ByteCodeField> &matcherByteCode;
324  SmallVectorImpl<ByteCodeField> &rewriterByteCode;
326  ByteCodeField &maxValueMemoryIndex;
327  ByteCodeField &maxOpRangeMemoryIndex;
328  ByteCodeField &maxTypeRangeMemoryIndex;
329  ByteCodeField &maxValueRangeMemoryIndex;
330  ByteCodeField &maxLoopLevel;
331 };
332 
333 /// This class provides utilities for writing a bytecode stream.
334 struct ByteCodeWriter {
335  ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
336  : bytecode(bytecode), generator(generator) {}
337 
338  /// Append a field to the bytecode.
339  void append(ByteCodeField field) { bytecode.push_back(field); }
340  void append(OpCode opCode) { bytecode.push_back(opCode); }
341 
342  /// Append an address to the bytecode.
343  void append(ByteCodeAddr field) {
344  static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
345  "unexpected ByteCode address size");
346 
347  ByteCodeField fieldParts[2];
348  std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
349  bytecode.append({fieldParts[0], fieldParts[1]});
350  }
351 
352  /// Append a single successor to the bytecode, the exact address will need to
353  /// be resolved later.
354  void append(Block *successor) {
355  // Add back a reference to the successor so that the address can be resolved
356  // later.
357  unresolvedSuccessorRefs[successor].push_back(bytecode.size());
358  append(ByteCodeAddr(0));
359  }
360 
361  /// Append a successor range to the bytecode, the exact address will need to
362  /// be resolved later.
363  void append(SuccessorRange successors) {
364  for (Block *successor : successors)
365  append(successor);
366  }
367 
368  /// Append a range of values that will be read as generic PDLValues.
369  void appendPDLValueList(OperandRange values) {
370  bytecode.push_back(values.size());
371  for (Value value : values)
372  appendPDLValue(value);
373  }
374 
375  /// Append a value as a PDLValue.
376  void appendPDLValue(Value value) {
377  appendPDLValueKind(value);
378  append(value);
379  }
380 
381  /// Append the PDLValue::Kind of the given value.
382  void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); }
383 
384  /// Append the PDLValue::Kind of the given type.
385  void appendPDLValueKind(Type type) {
386  PDLValue::Kind kind =
388  .Case<pdl::AttributeType>(
389  [](Type) { return PDLValue::Kind::Attribute; })
390  .Case<pdl::OperationType>(
391  [](Type) { return PDLValue::Kind::Operation; })
392  .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
393  if (rangeTy.getElementType().isa<pdl::TypeType>())
396  })
397  .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
398  .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
399  bytecode.push_back(static_cast<ByteCodeField>(kind));
400  }
401 
402  /// Append a value that will be stored in a memory slot and not inline within
403  /// the bytecode.
404  template <typename T>
407  append(T value) {
408  bytecode.push_back(generator.getMemIndex(value));
409  }
410 
411  /// Append a range of values.
412  template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
414  append(T range) {
415  bytecode.push_back(llvm::size(range));
416  for (auto it : range)
417  append(it);
418  }
419 
420  /// Append a variadic number of fields to the bytecode.
421  template <typename FieldTy, typename Field2Ty, typename... FieldTys>
422  void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
423  append(field);
424  append(field2, fields...);
425  }
426 
427  /// Appends a value as a pointer, stored inline within the bytecode.
428  template <typename T>
430  appendInline(T value) {
431  constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField);
432  const void *pointer = value.getAsOpaquePointer();
433  ByteCodeField fieldParts[numParts];
434  std::memcpy(fieldParts, &pointer, sizeof(const void *));
435  bytecode.append(fieldParts, fieldParts + numParts);
436  }
437 
438  /// Successor references in the bytecode that have yet to be resolved.
439  DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
440 
441  /// The underlying bytecode buffer.
443 
444  /// The main generator producing PDL.
445  Generator &generator;
446 };
447 
448 /// This class represents a live range of PDL Interpreter values, containing
449 /// information about when values are live within a match/rewrite.
450 struct ByteCodeLiveRange {
451  using Set = llvm::IntervalMap<uint64_t, char, 16>;
452  using Allocator = Set::Allocator;
453 
454  ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
455 
456  /// Union this live range with the one provided.
457  void unionWith(const ByteCodeLiveRange &rhs) {
458  for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
459  ++it)
460  liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0);
461  }
462 
463  /// Returns true if this range overlaps with the one provided.
464  bool overlaps(const ByteCodeLiveRange &rhs) const {
465  return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
466  .valid();
467  }
468 
469  /// A map representing the ranges of the match/rewrite that a value is live in
470  /// the interpreter.
471  ///
472  /// We use std::unique_ptr here, because IntervalMap does not provide a
473  /// correct copy or move constructor. We can eliminate the pointer once
474  /// https://reviews.llvm.org/D113240 lands.
475  std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
476 
477  /// The operation range storage index for this range.
478  Optional<unsigned> opRangeIndex;
479 
480  /// The type range storage index for this range.
481  Optional<unsigned> typeRangeIndex;
482 
483  /// The value range storage index for this range.
484  Optional<unsigned> valueRangeIndex;
485 };
486 } // namespace
487 
488 void Generator::generate(ModuleOp module) {
489  auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
490  pdl_interp::PDLInterpDialect::getMatcherFunctionName());
491  ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
492  pdl_interp::PDLInterpDialect::getRewriterModuleName());
493  assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
494 
495  // Allocate memory indices for the results of operations within the matcher
496  // and rewriters.
497  allocateMemoryIndices(matcherFunc, rewriterModule);
498 
499  // Generate code for the rewriter functions.
500  ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
501  for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
502  rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
503  for (Operation &op : rewriterFunc.getOps())
504  generate(&op, rewriterByteCodeWriter);
505  }
506  assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
507  "unexpected branches in rewriter function");
508 
509  // Generate code for the matcher function.
510  ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
511  generate(&matcherFunc.getBody(), matcherByteCodeWriter);
512 
513  // Resolve successor references in the matcher.
514  for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
515  ByteCodeAddr addr = blockToAddr[it.first];
516  for (unsigned offsetToFix : it.second)
517  std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
518  }
519 }
520 
/// Assign memory slots (and range storage slots) to the values of the
/// rewriter functions and the matcher function. The maxima of the
/// value/op-range/type-range/value-range counts are recorded in the
/// corresponding `max*MemoryIndex` members.
///
/// Rewriter values get one slot each, in the order block arguments and then
/// op results are visited. Matcher values are packed greedily: each value's
/// live interval is placed into the first existing slot whose accumulated
/// live range does not overlap it, or into a new slot otherwise. Memory
/// index 0 is reserved for the matcher's root operation argument.
521 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
522  ModuleOp rewriterModule) {
523  // Rewriters use a simplistic allocation scheme that simply assigns an index
524  // to each result.
525  for (auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
526  ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
  // NOTE(review): only type and value ranges receive range storage here (no
  // op-range index); presumably rewriters never define
  // `!pdl.range<operation>` values. Confirm.
527  auto processRewriterValue = [&](Value val) {
528  valueToMemIndex.try_emplace(val, index++);
529  if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) {
530  Type elementTy = rangeType.getElementType();
531  if (elementTy.isa<pdl::TypeType>())
532  valueToRangeIndex.try_emplace(val, typeRangeIndex++);
533  else if (elementTy.isa<pdl::ValueType>())
534  valueToRangeIndex.try_emplace(val, valueRangeIndex++);
535  }
536  };
537 
538  for (BlockArgument arg : rewriterFunc.getArguments())
539  processRewriterValue(arg);
540  rewriterFunc.getBody().walk([&](Operation *op) {
541  for (Value result : op->getResults())
542  processRewriterValue(result);
543  });
  // Each rewriter numbers from 0, so only the largest count is kept.
544  if (index > maxValueMemoryIndex)
545  maxValueMemoryIndex = index;
546  if (typeRangeIndex > maxTypeRangeMemoryIndex)
547  maxTypeRangeMemoryIndex = typeRangeIndex;
548  if (valueRangeIndex > maxValueRangeMemoryIndex)
549  maxValueRangeMemoryIndex = valueRangeIndex;
550  }
551 
552  // The matcher function uses a more sophisticated numbering that tries to
553  // minimize the number of memory indices assigned. This is done by determining
554  // a live range of the values within the matcher, then the allocation is just
555  // finding the minimal number of overlapping live ranges. This is essentially
556  // a simplified form of register allocation where we don't necessarily have a
557  // limited number of registers, but we still want to minimize the number used.
558  DenseMap<Operation *, unsigned> opToFirstIndex;
559  DenseMap<Operation *, unsigned> opToLastIndex;
560 
561  // A custom walk that marks the first and the last index of each operation.
562  // The entry marks the beginning of the liveness range for this operation,
563  // followed by nested operations, followed by the end of the liveness range.
  // As a result, an operation's [first, last] span encloses the spans of all
  // operations nested in its regions.
564  unsigned index = 0;
565  llvm::unique_function<void(Operation *)> walk = [&](Operation *op) {
566  opToFirstIndex.try_emplace(op, index++);
567  for (Region &region : op->getRegions())
568  for (Block &block : region.getBlocks())
569  for (Operation &nested : block)
570  walk(&nested);
571  opToLastIndex.try_emplace(op, index++);
572  };
573  walk(matcherFunc);
574 
575  // Liveness info for each of the defs within the matcher.
576  ByteCodeLiveRange::Allocator allocator;
577  DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
578 
579  // Assign the root operation being matched to slot 0.
580  BlockArgument rootOpArg = matcherFunc.getArgument(0);
581  valueToMemIndex[rootOpArg] = 0;
582 
583  // Walk each of the blocks, computing the interval over which each value is
  // live.
584  Liveness matcherLiveness(matcherFunc);
585  matcherFunc->walk([&](Block *block) {
586  const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
587  assert(info && "expected liveness info for block");
588  auto processValue = [&](Value value, Operation *firstUseOrDef) {
589  // We don't need to process the root op argument, this value is always
590  // assigned to the first memory slot.
591  if (value == rootOpArg)
592  return;
593 
594  // Record the interval of this block over which the value is live, from
  // `firstUseOrDef` to the block's end operation for the value.
595  auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
596  defRangeIt->second.liveness->insert(
597  opToFirstIndex[firstUseOrDef],
598  opToLastIndex[info->getEndOperation(value, firstUseOrDef)],
599  /*dummyValue*/ 0);
600 
601  // Check to see if this value is a range type. Only the presence of the
  // index matters here; the real storage index is chosen during slot
  // allocation below.
602  if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) {
603  Type eleType = rangeTy.getElementType();
604  if (eleType.isa<pdl::OperationType>())
605  defRangeIt->second.opRangeIndex = 0;
606  else if (eleType.isa<pdl::TypeType>())
607  defRangeIt->second.typeRangeIndex = 0;
608  else if (eleType.isa<pdl::ValueType>())
609  defRangeIt->second.valueRangeIndex = 0;
610  }
611  };
612 
613  // Process the live-ins of this block.
614  for (Value liveIn : info->in()) {
615  // Only process the value if it has been defined in the current region.
616  // Other values that span across pdl_interp.foreach will be added higher
617  // up. This ensures that we keep them alive for the entire duration
618  // of the loop.
619  if (liveIn.getParentRegion() == block->getParent())
620  processValue(liveIn, &block->front());
621  }
622 
623  // Process the block arguments for the entry block (those are not live-in).
624  if (block->isEntryBlock()) {
625  for (Value argument : block->getArguments())
626  processValue(argument, &block->front());
627  }
628 
629  // Process any new defs within this block.
630  for (Operation &op : *block)
631  for (Value result : op.getResults())
632  processValue(result, &op);
633  });
634 
635  // Greedily allocate memory slots using the computed def live ranges. Entry i
  // of `allocatedIndices` holds the union of the live ranges assigned to
  // memory index i + 1 (index 0 is the root).
636  std::vector<ByteCodeLiveRange> allocatedIndices;
637 
638  // The number of memory indices currently allocated (and its next value).
639  // Recall that the root gets allocated memory index 0.
640  ByteCodeField numIndices = 1;
641 
642  // The number of memory ranges of various types (and their next values).
643  ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
644 
  // The resulting assignment depends on the iteration order of the DenseMap.
645  for (auto &defIt : valueDefRanges) {
  // `memIndex` is 0 for a value that has not been placed yet; the
  // `memIndex == 0` check below relies on this to detect that no existing
  // slot was reused.
646  ByteCodeField &memIndex = valueToMemIndex[defIt.first];
647  ByteCodeLiveRange &defRange = defIt.second;
648 
649  // Try to allocate to an existing index.
650  for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) {
651  ByteCodeLiveRange &existingRange = existingIndexIt.value();
652  if (!defRange.overlaps(existingRange)) {
653  existingRange.unionWith(defRange);
654  memIndex = existingIndexIt.index() + 1;
655 
  // Reuse the slot's range storage of the needed kind, creating one if
  // this slot has not needed that kind before.
656  if (defRange.opRangeIndex) {
657  if (!existingRange.opRangeIndex)
658  existingRange.opRangeIndex = numOpRanges++;
659  valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
660  } else if (defRange.typeRangeIndex) {
661  if (!existingRange.typeRangeIndex)
662  existingRange.typeRangeIndex = numTypeRanges++;
663  valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
664  } else if (defRange.valueRangeIndex) {
665  if (!existingRange.valueRangeIndex)
666  existingRange.valueRangeIndex = numValueRanges++;
667  valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
668  }
669  break;
670  }
671  }
672 
673  // If no existing index could be used, add a new one.
674  if (memIndex == 0) {
675  allocatedIndices.emplace_back(allocator);
676  ByteCodeLiveRange &newRange = allocatedIndices.back();
677  newRange.unionWith(defRange);
678 
679  // Allocate an index for op/type/value ranges.
680  if (defRange.opRangeIndex) {
681  newRange.opRangeIndex = numOpRanges;
682  valueToRangeIndex[defIt.first] = numOpRanges++;
683  } else if (defRange.typeRangeIndex) {
684  newRange.typeRangeIndex = numTypeRanges;
685  valueToRangeIndex[defIt.first] = numTypeRanges++;
686  } else if (defRange.valueRangeIndex) {
687  newRange.valueRangeIndex = numValueRanges;
688  valueToRangeIndex[defIt.first] = numValueRanges++;
689  }
690 
  // After the push, size() equals the new entry's position + 1.
691  memIndex = allocatedIndices.size();
692  ++numIndices;
693  }
694  }
695 
696  // Print the index usage and ensure that we did not run out of index space.
697  LLVM_DEBUG({
698  llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices "
699  << "(down from initial " << valueDefRanges.size() << ").\n";
700  });
  // NOTE(review): this runs after the indices were already stored in
  // ByteCodeField-typed slots. Also, `numIndices` (size + 1) would already
  // have wrapped if size == max, so `<` checked during allocation may be
  // the intended bound. Confirm.
701  assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
702  "Ran out of memory for allocated indices");
703 
704  // Update the max number of indices.
705  if (numIndices > maxValueMemoryIndex)
706  maxValueMemoryIndex = numIndices;
707  if (numOpRanges > maxOpRangeMemoryIndex)
708  maxOpRangeMemoryIndex = numOpRanges;
709  if (numTypeRanges > maxTypeRangeMemoryIndex)
710  maxTypeRangeMemoryIndex = numTypeRanges;
711  if (numValueRanges > maxValueRangeMemoryIndex)
712  maxValueRangeMemoryIndex = numValueRanges;
713 }
714 
715 void Generator::generate(Region *region, ByteCodeWriter &writer) {
716  llvm::ReversePostOrderTraversal<Region *> rpot(region);
717  for (Block *block : rpot) {
718  // Keep track of where this block begins within the matcher function.
719  blockToAddr.try_emplace(block, matcherByteCode.size());
720  for (Operation &op : *block)
721  generate(&op, writer);
722  }
723 }
724 
725 void Generator::generate(Operation *op, ByteCodeWriter &writer) {
726  LLVM_DEBUG({
727  // The following list must contain all the operations that do not
728  // produce any bytecode.
729  if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
730  writer.appendInline(op->getLoc());
731  });
733  .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
734  pdl_interp::AreEqualOp, pdl_interp::BranchOp,
735  pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
736  pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
737  pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
738  pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
739  pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
740  pdl_interp::CreateTypesOp, pdl_interp::EraseOp,
741  pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
742  pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
743  pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
744  pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
745  pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
746  pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
747  pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
748  pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
749  pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
750  pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
751  pdl_interp::SwitchResultCountOp>(
752  [&](auto interpOp) { this->generate(interpOp, writer); })
753  .Default([](Operation *) {
754  llvm_unreachable("unknown `pdl_interp` operation");
755  });
756 }
757 
758 void Generator::generate(pdl_interp::ApplyConstraintOp op,
759  ByteCodeWriter &writer) {
760  assert(constraintToMemIndex.count(op.getName()) &&
761  "expected index for constraint function");
762  writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
763  writer.appendPDLValueList(op.getArgs());
764  writer.append(op.getSuccessors());
765 }
766 void Generator::generate(pdl_interp::ApplyRewriteOp op,
767  ByteCodeWriter &writer) {
768  assert(externalRewriterToMemIndex.count(op.getName()) &&
769  "expected index for rewrite function");
770  writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
771  writer.appendPDLValueList(op.getArgs());
772 
773  ResultRange results = op.getResults();
774  writer.append(ByteCodeField(results.size()));
775  for (Value result : results) {
776  // In debug mode we also record the expected kind of the result, so that we
777  // can provide extra verification of the native rewrite function.
778 #ifndef NDEBUG
779  writer.appendPDLValueKind(result);
780 #endif
781 
782  // Range results also need to append the range storage index.
783  if (result.getType().isa<pdl::RangeType>())
784  writer.append(getRangeStorageIndex(result));
785  writer.append(result);
786  }
787 }
788 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
789  Value lhs = op.getLhs();
790  if (lhs.getType().isa<pdl::RangeType>()) {
791  writer.append(OpCode::AreRangesEqual);
792  writer.appendPDLValueKind(lhs);
793  writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
794  return;
795  }
796 
797  writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
798 }
799 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
800  writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
801 }
802 void Generator::generate(pdl_interp::CheckAttributeOp op,
803  ByteCodeWriter &writer) {
804  writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
805  op.getSuccessors());
806 }
807 void Generator::generate(pdl_interp::CheckOperandCountOp op,
808  ByteCodeWriter &writer) {
809  writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
810  static_cast<ByteCodeField>(op.getCompareAtLeast()),
811  op.getSuccessors());
812 }
813 void Generator::generate(pdl_interp::CheckOperationNameOp op,
814  ByteCodeWriter &writer) {
815  writer.append(OpCode::CheckOperationName, op.getInputOp(),
816  OperationName(op.getName(), ctx), op.getSuccessors());
817 }
818 void Generator::generate(pdl_interp::CheckResultCountOp op,
819  ByteCodeWriter &writer) {
820  writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
821  static_cast<ByteCodeField>(op.getCompareAtLeast()),
822  op.getSuccessors());
823 }
824 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
825  writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
826  op.getSuccessors());
827 }
828 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
829  writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
830  op.getSuccessors());
831 }
832 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
833  assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
834  writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
835 }
836 void Generator::generate(pdl_interp::CreateAttributeOp op,
837  ByteCodeWriter &writer) {
838  // Simply repoint the memory index of the result to the constant.
839  getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
840 }
841 void Generator::generate(pdl_interp::CreateOperationOp op,
842  ByteCodeWriter &writer) {
843  writer.append(OpCode::CreateOperation, op.getResultOp(),
844  OperationName(op.getName(), ctx));
845  writer.appendPDLValueList(op.getInputOperands());
846 
847  // Add the attributes.
848  OperandRange attributes = op.getInputAttributes();
849  writer.append(static_cast<ByteCodeField>(attributes.size()));
850  for (auto it : llvm::zip(op.getInputAttributeNames(), attributes))
851  writer.append(std::get<0>(it), std::get<1>(it));
852 
853  // Add the result types. If the operation has inferred results, we use a
854  // marker "size" value. Otherwise, we add the list of explicit result types.
855  if (op.getInferredResultTypes())
856  writer.append(kInferTypesMarker);
857  else
858  writer.appendPDLValueList(op.getInputResultTypes());
859 }
860 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
861  // Simply repoint the memory index of the result to the constant.
862  getMemIndex(op.getResult()) = getMemIndex(op.getValue());
863 }
864 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
865  writer.append(OpCode::CreateTypes, op.getResult(),
866  getRangeStorageIndex(op.getResult()), op.getValue());
867 }
868 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
869  writer.append(OpCode::EraseOp, op.getInputOp());
870 }
871 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
872  OpCode opCode =
873  TypeSwitch<Type, OpCode>(op.getResult().getType())
874  .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
875  .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
876  .Case([](pdl::TypeType) { return OpCode::ExtractType; })
877  .Default([](Type) -> OpCode {
878  llvm_unreachable("unsupported element type");
879  });
880  writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
881 }
882 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
883  writer.append(OpCode::Finalize);
884 }
885 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
886  BlockArgument arg = op.getLoopVariable();
887  writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
888  writer.appendPDLValueKind(arg.getType());
889  writer.append(curLoopLevel, op.getSuccessor());
890  ++curLoopLevel;
891  if (curLoopLevel > maxLoopLevel)
892  maxLoopLevel = curLoopLevel;
893  generate(&op.getRegion(), writer);
894  --curLoopLevel;
895 }
896 void Generator::generate(pdl_interp::GetAttributeOp op,
897  ByteCodeWriter &writer) {
898  writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
899  op.getNameAttr());
900 }
901 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
902  ByteCodeWriter &writer) {
903  writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
904 }
905 void Generator::generate(pdl_interp::GetDefiningOpOp op,
906  ByteCodeWriter &writer) {
907  writer.append(OpCode::GetDefiningOp, op.getInputOp());
908  writer.appendPDLValue(op.getValue());
909 }
910 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
911  uint32_t index = op.getIndex();
912  if (index < 4)
913  writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
914  else
915  writer.append(OpCode::GetOperandN, index);
916  writer.append(op.getInputOp(), op.getValue());
917 }
918 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
919  Value result = op.getValue();
920  Optional<uint32_t> index = op.getIndex();
921  writer.append(OpCode::GetOperands,
922  index.value_or(std::numeric_limits<uint32_t>::max()),
923  op.getInputOp());
924  if (result.getType().isa<pdl::RangeType>())
925  writer.append(getRangeStorageIndex(result));
926  else
928  writer.append(result);
929 }
930 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
931  uint32_t index = op.getIndex();
932  if (index < 4)
933  writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
934  else
935  writer.append(OpCode::GetResultN, index);
936  writer.append(op.getInputOp(), op.getValue());
937 }
938 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
939  Value result = op.getValue();
940  Optional<uint32_t> index = op.getIndex();
941  writer.append(OpCode::GetResults,
942  index.value_or(std::numeric_limits<uint32_t>::max()),
943  op.getInputOp());
944  if (result.getType().isa<pdl::RangeType>())
945  writer.append(getRangeStorageIndex(result));
946  else
948  writer.append(result);
949 }
950 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
951  Value operations = op.getOperations();
952  ByteCodeField rangeIndex = getRangeStorageIndex(operations);
953  writer.append(OpCode::GetUsers, operations, rangeIndex);
954  writer.appendPDLValue(op.getValue());
955 }
956 void Generator::generate(pdl_interp::GetValueTypeOp op,
957  ByteCodeWriter &writer) {
958  if (op.getType().isa<pdl::RangeType>()) {
959  Value result = op.getResult();
960  writer.append(OpCode::GetValueRangeTypes, result,
961  getRangeStorageIndex(result), op.getValue());
962  } else {
963  writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
964  }
965 }
966 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
967  writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
968 }
969 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
970  ByteCodeField patternIndex = patterns.size();
971  patterns.emplace_back(PDLByteCodePattern::create(
972  op, rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
973  writer.append(OpCode::RecordMatch, patternIndex,
974  SuccessorRange(op.getOperation()), op.getMatchedOps());
975  writer.appendPDLValueList(op.getInputs());
976 }
977 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
978  writer.append(OpCode::ReplaceOp, op.getInputOp());
979  writer.appendPDLValueList(op.getReplValues());
980 }
981 void Generator::generate(pdl_interp::SwitchAttributeOp op,
982  ByteCodeWriter &writer) {
983  writer.append(OpCode::SwitchAttribute, op.getAttribute(),
984  op.getCaseValuesAttr(), op.getSuccessors());
985 }
986 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
987  ByteCodeWriter &writer) {
988  writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
989  op.getCaseValuesAttr(), op.getSuccessors());
990 }
991 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
992  ByteCodeWriter &writer) {
993  auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](Attribute attr) {
994  return OperationName(attr.cast<StringAttr>().getValue(), ctx);
995  });
996  writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
997  op.getSuccessors());
998 }
999 void Generator::generate(pdl_interp::SwitchResultCountOp op,
1000  ByteCodeWriter &writer) {
1001  writer.append(OpCode::SwitchResultCount, op.getInputOp(),
1002  op.getCaseValuesAttr(), op.getSuccessors());
1003 }
1004 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1005  writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1006  op.getSuccessors());
1007 }
1008 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1009  writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
1010  op.getSuccessors());
1011 }
1012 
1013 //===----------------------------------------------------------------------===//
1014 // PDLByteCode
1015 //===----------------------------------------------------------------------===//
1016 
1018  llvm::StringMap<PDLConstraintFunction> constraintFns,
1019  llvm::StringMap<PDLRewriteFunction> rewriteFns) {
// Lower `module` to bytecode. The Generator is handed this object's uniqued
// data, matcher/rewriter bytecode buffers, pattern list and the sizing fields
// (maxValueMemoryIndex, max*RangeCount, maxLoopLevel).
// NOTE(review): assumes Generator holds these by reference and fills them in
// during generate(); confirm in the Generator class.
1020  Generator generator(module.getContext(), uniquedData, matcherByteCode,
1021  rewriterByteCode, patterns, maxValueMemoryIndex,
1022  maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1023  maxLoopLevel, constraintFns, rewriteFns);
1024  generator.generate(module);
1025 
1026  // Initialize the external functions.
// The callbacks are moved out of the maps into flat vectors in StringMap
// iteration order. NOTE(review): the bytecode refers to these by index, so
// this order presumably matches the one the Generator used when it assigned
// indices from the same maps; confirm.
1027  for (auto &it : constraintFns)
1028  constraintFunctions.push_back(std::move(it.second));
1029  for (auto &it : rewriteFns)
1030  rewriteFunctions.push_back(std::move(it.second));
1031 }
1032 
1033 /// Initialize the given state such that it can be used to execute the current
1034 /// bytecode.
// Size every memory pool to the maxima computed when the bytecode was
// generated. Value slots start out null and loop counters start at zero.
1036  state.memory.resize(maxValueMemoryIndex, nullptr);
1037  state.opRangeMemory.resize(maxOpRangeCount);
1038  state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
1039  state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
1040  state.loopIndex.resize(maxLoopLevel, 0);
// Seed the per-pattern benefits from each pattern's own benefit, in pattern
// order, so they can be indexed by pattern index.
1041  state.currentPatternBenefits.reserve(patterns.size());
1042  for (const PDLByteCodePattern &pattern : patterns)
1043  state.currentPatternBenefits.push_back(pattern.getBenefit());
1044 }
1045 
1046 //===----------------------------------------------------------------------===//
1047 // ByteCode Execution
1048 
1049 namespace {
1050 /// This class provides support for executing a bytecode stream.
/// The executor walks the bytecode from `curCodeIt`, reading operands inline
/// and storing results into the caller-provided memory pools. Array
/// parameters are non-owning views. The two allocated*RangeMemory vectors are
/// held by reference so range storage created during execution outlives the
/// executor.
1051 class ByteCodeExecutor {
1052 public:
1053  ByteCodeExecutor(
1054  const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
1055  MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory,
1056  MutableArrayRef<TypeRange> typeRangeMemory,
1057  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
1058  MutableArrayRef<ValueRange> valueRangeMemory,
1059  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
1060  MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory,
1062  ArrayRef<PatternBenefit> currentPatternBenefits,
1064  ArrayRef<PDLConstraintFunction> constraintFunctions,
1065  ArrayRef<PDLRewriteFunction> rewriteFunctions)
1066  : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
1067  typeRangeMemory(typeRangeMemory),
1068  allocatedTypeRangeMemory(allocatedTypeRangeMemory),
1069  valueRangeMemory(valueRangeMemory),
1070  allocatedValueRangeMemory(allocatedValueRangeMemory),
1071  loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
1072  currentPatternBenefits(currentPatternBenefits), patterns(patterns),
1073  constraintFunctions(constraintFunctions),
1074  rewriteFunctions(rewriteFunctions) {}
1075 
1076  /// Start executing the code at the current bytecode index. `matches` is an
1077  /// optional field provided when this function is executed in a matching
1078  /// context.
1079  void execute(PatternRewriter &rewriter,
1080  SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
1081  Optional<Location> mainRewriteLoc = {});
1082 
1083 private:
1084  /// Internal implementation of executing each of the bytecode commands.
  /// Each reads its own operands from the stream and leaves `curCodeIt`
  /// either past them or at the selected jump target.
1085  void executeApplyConstraint(PatternRewriter &rewriter);
1086  void executeApplyRewrite(PatternRewriter &rewriter);
1087  void executeAreEqual();
1088  void executeAreRangesEqual();
1089  void executeBranch();
1090  void executeCheckOperandCount();
1091  void executeCheckOperationName();
1092  void executeCheckResultCount();
1093  void executeCheckTypes();
1094  void executeContinue();
1095  void executeCreateOperation(PatternRewriter &rewriter,
1096  Location mainRewriteLoc);
1097  void executeCreateTypes();
1098  void executeEraseOp(PatternRewriter &rewriter);
1099  template <typename T, typename Range, PDLValue::Kind kind>
1100  void executeExtract();
1101  void executeFinalize();
1102  void executeForEach();
1103  void executeGetAttribute();
1104  void executeGetAttributeType();
1105  void executeGetDefiningOp();
1106  void executeGetOperand(unsigned index);
1107  void executeGetOperands();
1108  void executeGetResult(unsigned index);
1109  void executeGetResults();
1110  void executeGetUsers();
1111  void executeGetValueType();
1112  void executeGetValueRangeTypes();
1113  void executeIsNotNull();
1114  void executeRecordMatch(PatternRewriter &rewriter,
1116  void executeReplaceOp(PatternRewriter &rewriter);
1117  void executeSwitchAttribute();
1118  void executeSwitchOperandCount();
1119  void executeSwitchOperationName();
1120  void executeSwitchResultCount();
1121  void executeSwitchType();
1122  void executeSwitchTypes();
1123 
1124  /// Pushes a code iterator to the stack.
1125  void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }
1126 
1127  /// Pops a code iterator from the stack and makes it the current position.
1128  void popCodeIt() {
1129  assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
1130  curCodeIt = resumeCodeIt.back();
1131  resumeCodeIt.pop_back();
1132  }
1133 
1134  /// Return the bytecode iterator at the start of the current op code.
  /// NOTE: the early return inside LLVM_DEBUG only runs when debug output is
  /// enabled. It presumably mirrors a Location that is written inline before
  /// each op code in that mode; confirm against the generator.
1135  const ByteCodeField *getPrevCodeIt() const {
1136  LLVM_DEBUG({
1137  // Account for the op code and the Location stored inline.
1138  return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField);
1139  });
1140 
1141  // Account for the op code only.
1142  return curCodeIt - 1;
1143  }
1144 
1145  /// Read a value from the bytecode buffer, optionally skipping a certain
1146  /// number of prefix values. These methods always update the buffer to point
1147  /// to the next field after the read data.
1148  template <typename T = ByteCodeField>
1149  T read(size_t skipN = 0) {
1150  curCodeIt += skipN;
1151  return readImpl<T>();
1152  }
  /// Non-template overload so that a plain `read()` yields a ByteCodeField.
1153  ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
1154 
1155  /// Read a list of values from the bytecode buffer.
  /// The list is cleared first; the element count is read before the
  /// elements.
1156  template <typename ValueT, typename T>
1157  void readList(SmallVectorImpl<T> &list) {
1158  list.clear();
1159  for (unsigned i = 0, e = read(); i != e; ++i)
1160  list.push_back(read<ValueT>());
1161  }
1162 
1163  /// Read a list of values from the bytecode buffer. The values may be encoded
1164  /// as either Value or ValueRange elements.
  /// Unlike readList, this appends to `list` without clearing it, and ranges
  /// are flattened into individual values.
1165  void readValueList(SmallVectorImpl<Value> &list) {
1166  for (unsigned i = 0, e = read(); i != e; ++i) {
1167  if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1168  list.push_back(read<Value>());
1169  } else {
1170  ValueRange *values = read<ValueRange *>();
1171  list.append(values->begin(), values->end());
1172  }
1173  }
1174  }
1175 
1176  /// Read a value stored inline as a pointer.
  /// The pointer occupies sizeof(void *) / sizeof(ByteCodeField) fields and
  /// is copied with memcpy since the stream is not pointer-aligned.
1177  template <typename T>
1179  readInline() {
1180  const void *pointer;
1181  std::memcpy(&pointer, curCodeIt, sizeof(const void *));
1182  curCodeIt += sizeof(const void *) / sizeof(ByteCodeField);
1183  return T::getFromOpaquePointer(pointer);
1184  }
1185 
1186  /// Jump to a specific successor based on a predicate value.
  /// `true` selects successor 0 and `false` selects successor 1.
1187  void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
1188  /// Jump to a specific successor based on a destination index.
  /// Successor addresses are ByteCodeAddr values, two ByteCodeFields wide,
  /// hence the `destIndex * 2` skip.
1189  void selectJump(size_t destIndex) {
1190  curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
1191  }
1192 
1193  /// Handle a switch operation with the provided value and cases.
  /// Case `i` (zero-based) jumps to successor `i + 1`; successor 0 is taken
  /// when no case matches.
1194  template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
1195  void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
1196  LLVM_DEBUG({
1197  llvm::dbgs() << " * Value: " << value << "\n"
1198  << " * Cases: ";
1199  llvm::interleaveComma(cases, llvm::dbgs());
1200  llvm::dbgs() << "\n";
1201  });
1202 
1203  // Check to see if the attribute value is within the case list. Jump to
1204  // the correct successor index based on the result.
1205  for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
1206  if (cmp(*it, value))
1207  return selectJump(size_t((it - cases.begin()) + 1));
1208  selectJump(size_t(0));
1209  }
1210 
1211  /// Store a pointer to memory.
1212  void storeToMemory(unsigned index, const void *value) {
1213  memory[index] = value;
1214  }
1215 
1216  /// Store a value to memory as an opaque pointer.
1217  template <typename T>
1219  storeToMemory(unsigned index, T value) {
1220  memory[index] = value.getAsOpaquePointer();
1221  }
1222 
1223  /// Internal implementation of reading various data types from the bytecode
1224  /// stream.
  /// Indices below memory.size() address the mutable execution memory; larger
  /// indices address the uniqued constant data.
1225  template <typename T>
1226  const void *readFromMemory() {
1227  size_t index = *curCodeIt++;
1228 
1229  // If this type is an SSA value, it can only be stored in non-const memory.
1230  if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
1231  Value>::value ||
1232  index < memory.size())
1233  return memory[index];
1234 
1235  // Otherwise, if this index is not inbounds it is uniqued.
1236  return uniquedMemory[index - memory.size()];
1237  }
1238  template <typename T>
1240  return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
1241  }
1242  template <typename T>
1244  T>
1245  readImpl() {
1246  return T(T::getFromOpaquePointer(readFromMemory<T>()));
1247  }
1248  template <typename T>
1250  switch (read<PDLValue::Kind>()) {
1252  return read<Attribute>();
1254  return read<Operation *>();
1255  case PDLValue::Kind::Type:
1256  return read<Type>();
1257  case PDLValue::Kind::Value:
1258  return read<Value>();
1260  return read<TypeRange *>();
1262  return read<ValueRange *>();
1263  }
1264  llvm_unreachable("unhandled PDLValue::Kind");
1265  }
1266  template <typename T>
1268  static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
1269  "unexpected ByteCode address size");
1270  ByteCodeAddr result;
1271  std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
1272  curCodeIt += 2;
1273  return result;
1274  }
1275  template <typename T>
1277  return *curCodeIt++;
1278  }
1279  template <typename T>
1281  return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
1282  }
1283 
1284  /// The underlying bytecode buffer.
1285  const ByteCodeField *curCodeIt;
1286 
1287  /// The stack of bytecode positions at which to resume operation.
1289 
1290  /// The current execution memory.
1292  MutableArrayRef<OwningOpRange> opRangeMemory;
1293  MutableArrayRef<TypeRange> typeRangeMemory;
1294  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
1295  MutableArrayRef<ValueRange> valueRangeMemory;
1296  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
1297 
1298  /// The current loop indices.
1299  MutableArrayRef<unsigned> loopIndex;
1300 
1301  /// References to ByteCode data necessary for execution.
1302  ArrayRef<const void *> uniquedMemory;
1304  ArrayRef<PatternBenefit> currentPatternBenefits;
1306  ArrayRef<PDLConstraintFunction> constraintFunctions;
1307  ArrayRef<PDLRewriteFunction> rewriteFunctions;
1308 };
1309 
1310 /// This class is an instantiation of the PDLResultList that provides access to
1311 /// the returned results. This API is not on `PDLResultList` to avoid
1312 /// overexposing access to information specific solely to the ByteCode.
1313 class ByteCodeRewriteResultList : public PDLResultList {
1314 public:
1315  ByteCodeRewriteResultList(unsigned maxNumResults)
1316  : PDLResultList(maxNumResults) {}
1317 
1318  /// Return the list of PDL results.
1319  MutableArrayRef<PDLValue> getResults() { return results; }
1320 
1321  /// Return the type ranges allocated by this list.
1322  MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
1323  return allocatedTypeRanges;
1324  }
1325 
1326  /// Return the value ranges allocated by this list.
1327  MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
1328  return allocatedValueRanges;
1329  }
1330 };
1331 } // namespace
1332 
/// Execute an ApplyConstraint. The instruction holds the constraint function
/// index, a PDL value argument list and two successor addresses. The
/// constraint runs on the arguments; success jumps to successor 0 and failure
/// to successor 1.
1333 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1334  LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
1335  const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
1337  readList<PDLValue>(args);
1338 
1339  LLVM_DEBUG({
1340  llvm::dbgs() << " * Arguments: ";
1341  llvm::interleaveComma(args, llvm::dbgs());
1342  });
1343 
1344  // Invoke the constraint and jump to the proper destination.
1345  selectJump(succeeded(constraintFn(rewriter, args)));
1346 }
1347 
/// Execute an ApplyRewrite. The native rewrite function is called with the
/// decoded arguments, and each returned PDLValue is written to the memory
/// slot recorded by the generator. Range results are first copied into the
/// executor's range memory. Any range buffers allocated by the rewrite are
/// then moved into the executor-owned allocation lists so they stay alive.
1348 void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1349  LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
1350  const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1352  readList<PDLValue>(args);
1353 
1354  LLVM_DEBUG({
1355  llvm::dbgs() << " * Arguments: ";
1356  llvm::interleaveComma(args, llvm::dbgs());
1357  });
1358 
1359  // Execute the rewrite function.
1360  ByteCodeField numResults = read();
1361  ByteCodeRewriteResultList results(numResults);
1362  rewriteFn(rewriter, results, args);
1363 
1364  assert(results.getResults().size() == numResults &&
1365  "native PDL rewrite function returned unexpected number of results");
1366 
1367  // Store the results in the bytecode memory.
  // The per-result fields read below must mirror, in order, what the
  // generator emitted: [kind (debug builds only)], [range index (ranges)],
  // memory index.
1368  for (PDLValue &result : results.getResults()) {
1369  LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
1370 
1371 // In debug mode we also verify the expected kind of the result.
1372 #ifndef NDEBUG
1373  assert(result.getKind() == read<PDLValue::Kind>() &&
1374  "native PDL rewrite function returned an unexpected type of result");
1375 #endif
1376 
1377  // If the result is a range, we need to copy it over to the bytecodes
1378  // range memory.
1379  if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
1380  unsigned rangeIndex = read();
1381  typeRangeMemory[rangeIndex] = *typeRange;
1382  memory[read()] = &typeRangeMemory[rangeIndex];
1383  } else if (Optional<ValueRange> valueRange =
1384  result.dyn_cast<ValueRange>()) {
1385  unsigned rangeIndex = read();
1386  valueRangeMemory[rangeIndex] = *valueRange;
1387  memory[read()] = &valueRangeMemory[rangeIndex];
1388  } else {
1389  memory[read()] = result.getAsOpaquePointer();
1390  }
1391  }
1392 
1393  // Copy over any underlying storage allocated for result ranges.
1394  for (auto &it : results.getAllocatedTypeRanges())
1395  allocatedTypeRangeMemory.push_back(std::move(it));
1396  for (auto &it : results.getAllocatedValueRanges())
1397  allocatedValueRangeMemory.push_back(std::move(it));
1398 }
1399 
1400 void ByteCodeExecutor::executeAreEqual() {
1401  LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1402  const void *lhs = read<const void *>();
1403  const void *rhs = read<const void *>();
1404 
1405  LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n");
1406  selectJump(lhs == rhs);
1407 }
1408 
/// Execute an AreRangesEqual. The value kind selects how the two operand
/// pointers are interpreted. Unlike AreEqual, which compares the pointers
/// themselves, this compares the pointed-to ranges with operator== and jumps
/// to successor 0 on equality, successor 1 otherwise.
1409 void ByteCodeExecutor::executeAreRangesEqual() {
1410  LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
1411  PDLValue::Kind valueKind = read<PDLValue::Kind>();
1412  const void *lhs = read<const void *>();
1413  const void *rhs = read<const void *>();
1414 
1415  switch (valueKind) {
1417  const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1418  const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1419  LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1420  selectJump(*lhsRange == *rhsRange);
1421  break;
1422  }
1424  const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1425  const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1426  LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1427  selectJump(*lhsRange == *rhsRange);
1428  break;
1429  }
1430  default:
1431  llvm_unreachable("unexpected `AreRangesEqual` value kind");
1432  }
1433 }
1434 
1435 void ByteCodeExecutor::executeBranch() {
1436  LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
1437  curCodeIt = &code[read<ByteCodeAddr>()];
1438 }
1439 
1440 void ByteCodeExecutor::executeCheckOperandCount() {
1441  LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
1442  Operation *op = read<Operation *>();
1443  uint32_t expectedCount = read<uint32_t>();
1444  bool compareAtLeast = read();
1445 
1446  LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n"
1447  << " * Expected: " << expectedCount << "\n"
1448  << " * Comparator: "
1449  << (compareAtLeast ? ">=" : "==") << "\n");
1450  if (compareAtLeast)
1451  selectJump(op->getNumOperands() >= expectedCount);
1452  else
1453  selectJump(op->getNumOperands() == expectedCount);
1454 }
1455 
1456 void ByteCodeExecutor::executeCheckOperationName() {
1457  LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
1458  Operation *op = read<Operation *>();
1459  OperationName expectedName = read<OperationName>();
1460 
1461  LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n"
1462  << " * Expected: \"" << expectedName << "\"\n");
1463  selectJump(op->getName() == expectedName);
1464 }
1465 
1466 void ByteCodeExecutor::executeCheckResultCount() {
1467  LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
1468  Operation *op = read<Operation *>();
1469  uint32_t expectedCount = read<uint32_t>();
1470  bool compareAtLeast = read();
1471 
1472  LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n"
1473  << " * Expected: " << expectedCount << "\n"
1474  << " * Comparator: "
1475  << (compareAtLeast ? ">=" : "==") << "\n");
1476  if (compareAtLeast)
1477  selectJump(op->getNumResults() >= expectedCount);
1478  else
1479  selectJump(op->getNumResults() == expectedCount);
1480 }
1481 
1482 void ByteCodeExecutor::executeCheckTypes() {
1483  LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1484  TypeRange *lhs = read<TypeRange *>();
1485  Attribute rhs = read<Attribute>();
1486  LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1487 
1488  selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>());
1489 }
1490 
1491 void ByteCodeExecutor::executeContinue() {
1492  ByteCodeField level = read();
1493  LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n"
1494  << " * Level: " << level << "\n");
1495  ++loopIndex[level];
1496  popCodeIt();
1497 }
1498 
1499 void ByteCodeExecutor::executeCreateTypes() {
1500  LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n");
1501  unsigned memIndex = read();
1502  unsigned rangeIndex = read();
1503  ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
1504 
1505  LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n");
1506 
1507  // Allocate a buffer for this type range.
1508  llvm::OwningArrayRef<Type> storage(typesAttr.size());
1509  llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin());
1510  allocatedTypeRangeMemory.emplace_back(std::move(storage));
1511 
1512  // Assign this to the range slot and use the range as the value for the
1513  // memory index.
1514  typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back();
1515  memory[memIndex] = &typeRangeMemory[rangeIndex];
1516 }
1517 
/// Execute a CreateOperation. An OperationState is built at `mainRewriteLoc`
/// from the encoded name, operands, non-null attributes and result types, and
/// the new operation is created through the rewriter and stored in the
/// result memory slot.
/// NOTE(review): if result-type inference fails, this returns early without
/// writing memory[memIndex], so the slot keeps whatever it held before (see
/// the TODO below).
1518 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1519  Location mainRewriteLoc) {
1520  LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
1521 
1522  unsigned memIndex = read();
1523  OperationState state(mainRewriteLoc, read<OperationName>());
1524  readValueList(state.operands);
  // Attribute pairs whose value is null are skipped rather than added.
1525  for (unsigned i = 0, e = read(); i != e; ++i) {
1526  StringAttr name = read<StringAttr>();
1527  if (Attribute attr = read<Attribute>())
1528  state.addAttribute(name, attr);
1529  }
1530 
1531  // Read in the result types. If the "size" is the sentinel value, this
1532  // indicates that the result types should be inferred.
1533  unsigned numResults = read();
1534  if (numResults == kInferTypesMarker) {
  // NOTE(review): getRegisteredInfo() is dereferenced without a null check,
  // so this path assumes the operation name is registered; confirm.
1535  InferTypeOpInterface::Concept *inferInterface =
1536  state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>();
1537  assert(inferInterface &&
1538  "expected operation to provide InferTypeOpInterface");
1539 
1540  // TODO: Handle failure.
1541  if (failed(inferInterface->inferReturnTypes(
1542  state.getContext(), state.location, state.operands,
1543  state.attributes.getDictionary(state.getContext()), state.regions,
1544  state.types)))
1545  return;
1546  } else {
1547  // Otherwise, this is a fixed number of results.
  // Each entry is tagged with its kind: a single Type or a TypeRange that is
  // flattened into the result type list.
1548  for (unsigned i = 0; i != numResults; ++i) {
1549  if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1550  state.types.push_back(read<Type>());
1551  } else {
1552  TypeRange *resultTypes = read<TypeRange *>();
1553  state.types.append(resultTypes->begin(), resultTypes->end());
1554  }
1555  }
1556  }
1557 
1558  Operation *resultOp = rewriter.create(state);
1559  memory[memIndex] = resultOp;
1560 
1561  LLVM_DEBUG({
1562  llvm::dbgs() << " * Attributes: "
1563  << state.attributes.getDictionary(state.getContext())
1564  << "\n * Operands: ";
1565  llvm::interleaveComma(state.operands, llvm::dbgs());
1566  llvm::dbgs() << "\n * Result Types: ";
1567  llvm::interleaveComma(state.types, llvm::dbgs());
1568  llvm::dbgs() << "\n * Result: " << *resultOp << "\n";
1569  });
1570 }
1571 
1572 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1573  LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1574  Operation *op = read<Operation *>();
1575 
1576  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
1577  rewriter.eraseOp(op);
1578 }
1579 
1580 template <typename T, typename Range, PDLValue::Kind kind>
1581 void ByteCodeExecutor::executeExtract() {
1582  LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n");
1583  Range *range = read<Range *>();
1584  unsigned index = read<uint32_t>();
1585  unsigned memIndex = read();
1586 
1587  if (!range) {
1588  memory[memIndex] = nullptr;
1589  return;
1590  }
1591 
1592  T result = index < range->size() ? (*range)[index] : T();
1593  LLVM_DEBUG(llvm::dbgs() << " * " << kind << "s(" << range->size() << ")\n"
1594  << " * Index: " << index << "\n"
1595  << " * Result: " << result << "\n");
1596  storeToMemory(memIndex, result);
1597 }
1598 
1599 void ByteCodeExecutor::executeFinalize() {
1600  LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
1601 }
1602 
1603 void ByteCodeExecutor::executeForEach() {
1604  LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
1605  const ByteCodeField *prevCodeIt = getPrevCodeIt();
1606  unsigned rangeIndex = read();
1607  unsigned memIndex = read();
1608  const void *value = nullptr;
1609 
1610  switch (read<PDLValue::Kind>()) {
1612  unsigned &index = loopIndex[read()];
1613  ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
1614  assert(index <= array.size() && "iterated past the end");
1615  if (index < array.size()) {
1616  LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n");
1617  value = array[index];
1618  break;
1619  }
1620 
1621  LLVM_DEBUG(llvm::dbgs() << " * Done\n");
1622  index = 0;
1623  selectJump(size_t(0));
1624  return;
1625  }
1626  default:
1627  llvm_unreachable("unexpected `ForEach` value kind");
1628  }
1629 
1630  // Store the iterate value and the stack address.
1631  memory[memIndex] = value;
1632  pushCodeIt(prevCodeIt);
1633 
1634  // Skip over the successor (we will enter the body of the loop).
1635  read<ByteCodeAddr>();
1636 }
1637 
1638 void ByteCodeExecutor::executeGetAttribute() {
1639  LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1640  unsigned memIndex = read();
1641  Operation *op = read<Operation *>();
1642  StringAttr attrName = read<StringAttr>();
1643  Attribute attr = op->getAttr(attrName);
1644 
1645  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1646  << " * Attribute: " << attrName << "\n"
1647  << " * Result: " << attr << "\n");
1648  memory[memIndex] = attr.getAsOpaquePointer();
1649 }
1650 
1651 void ByteCodeExecutor::executeGetAttributeType() {
1652  LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1653  unsigned memIndex = read();
1654  Attribute attr = read<Attribute>();
1655  Type type;
1656  if (auto typedAttr = attr.dyn_cast<TypedAttr>())
1657  type = typedAttr.getType();
1658 
1659  LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n"
1660  << " * Result: " << type << "\n");
1661  memory[memIndex] = type.getAsOpaquePointer();
1662 }
1663 
1664 void ByteCodeExecutor::executeGetDefiningOp() {
1665  LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1666  unsigned memIndex = read();
1667  Operation *op = nullptr;
1668  if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1669  Value value = read<Value>();
1670  if (value)
1671  op = value.getDefiningOp();
1672  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
1673  } else {
1674  ValueRange *values = read<ValueRange *>();
1675  if (values && !values->empty()) {
1676  op = values->front().getDefiningOp();
1677  }
1678  LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n");
1679  }
1680 
1681  LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n");
1682  memory[memIndex] = op;
1683 }
1684 
1685 void ByteCodeExecutor::executeGetOperand(unsigned index) {
1686  Operation *op = read<Operation *>();
1687  unsigned memIndex = read();
1688  Value operand =
1689  index < op->getNumOperands() ? op->getOperand(index) : Value();
1690 
1691  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1692  << " * Index: " << index << "\n"
1693  << " * Result: " << operand << "\n");
1694  memory[memIndex] = operand.getAsOpaquePointer();
1695 }
1696 
1697 /// This function is the internal implementation of `GetResults` and
1698 /// `GetOperands` that provides support for extracting a value range from the
1699 /// given operation.
1700 template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1701 static void *
1702 executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1703  ByteCodeField rangeIndex, StringRef attrSizedSegments,
1704  MutableArrayRef<ValueRange> valueRangeMemory) {
1705  // Check for the sentinel index that signals that all values should be
1706  // returned.
1707  if (index == std::numeric_limits<uint32_t>::max()) {
1708  LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n");
1709  // `values` is already the full value range.
1710 
1711  // Otherwise, check to see if this operation uses AttrSizedSegments.
1712  } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1713  LLVM_DEBUG(llvm::dbgs()
1714  << " * Extracting values from `" << attrSizedSegments << "`\n");
1715 
1716  auto segmentAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrSizedSegments);
1717  if (!segmentAttr || segmentAttr.asArrayRef().size() <= index)
1718  return nullptr;
1719 
1720  ArrayRef<int32_t> segments = segmentAttr;
1721  unsigned startIndex =
1722  std::accumulate(segments.begin(), segments.begin() + index, 0);
1723  values = values.slice(startIndex, *std::next(segments.begin(), index));
1724 
1725  LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", "
1726  << *std::next(segments.begin(), index) << "]\n");
1727 
1728  // Otherwise, assume this is the last operand group of the operation.
1729  // FIXME: We currently don't support operations with
1730  // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1731  // have a way to detect it's presence.
1732  } else if (values.size() >= index) {
1733  LLVM_DEBUG(llvm::dbgs()
1734  << " * Treating values as trailing variadic range\n");
1735  values = values.drop_front(index);
1736 
1737  // If we couldn't detect a way to compute the values, bail out.
1738  } else {
1739  return nullptr;
1740  }
1741 
1742  // If the range index is valid, we are returning a range.
1743  if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1744  valueRangeMemory[rangeIndex] = values;
1745  return &valueRangeMemory[rangeIndex];
1746  }
1747 
1748  // If a range index wasn't provided, the range is required to be non-variadic.
1749  return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1750 }
1751 
1752 void ByteCodeExecutor::executeGetOperands() {
1753  LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
1754  unsigned index = read<uint32_t>();
1755  Operation *op = read<Operation *>();
1756  ByteCodeField rangeIndex = read();
1757 
1758  void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1759  op->getOperands(), op, index, rangeIndex, "operand_segment_sizes",
1760  valueRangeMemory);
1761  if (!result)
1762  LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n");
1763  memory[read()] = result;
1764 }
1765 
1766 void ByteCodeExecutor::executeGetResult(unsigned index) {
1767  Operation *op = read<Operation *>();
1768  unsigned memIndex = read();
1769  OpResult result =
1770  index < op->getNumResults() ? op->getResult(index) : OpResult();
1771 
1772  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1773  << " * Index: " << index << "\n"
1774  << " * Result: " << result << "\n");
1775  memory[memIndex] = result.getAsOpaquePointer();
1776 }
1777 
1778 void ByteCodeExecutor::executeGetResults() {
1779  LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
1780  unsigned index = read<uint32_t>();
1781  Operation *op = read<Operation *>();
1782  ByteCodeField rangeIndex = read();
1783 
1784  void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1785  op->getResults(), op, index, rangeIndex, "result_segment_sizes",
1786  valueRangeMemory);
1787  if (!result)
1788  LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n");
1789  memory[read()] = result;
1790 }
1791 
1792 void ByteCodeExecutor::executeGetUsers() {
1793  LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n");
1794  unsigned memIndex = read();
1795  unsigned rangeIndex = read();
1796  OwningOpRange &range = opRangeMemory[rangeIndex];
1797  memory[memIndex] = &range;
1798 
1799  range = OwningOpRange();
1800  if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1801  // Read the value.
1802  Value value = read<Value>();
1803  if (!value)
1804  return;
1805  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
1806 
1807  // Extract the users of a single value.
1808  range = OwningOpRange(std::distance(value.user_begin(), value.user_end()));
1809  llvm::copy(value.getUsers(), range.begin());
1810  } else {
1811  // Read a range of values.
1812  ValueRange *values = read<ValueRange *>();
1813  if (!values)
1814  return;
1815  LLVM_DEBUG({
1816  llvm::dbgs() << " * Values (" << values->size() << "): ";
1817  llvm::interleaveComma(*values, llvm::dbgs());
1818  llvm::dbgs() << "\n";
1819  });
1820 
1821  // Extract all the users of a range of values.
1823  for (Value value : *values)
1824  users.append(value.user_begin(), value.user_end());
1825  range = OwningOpRange(users.size());
1826  llvm::copy(users, range.begin());
1827  }
1828 
1829  LLVM_DEBUG(llvm::dbgs() << " * Result: " << range.size() << " operations\n");
1830 }
1831 
1832 void ByteCodeExecutor::executeGetValueType() {
1833  LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1834  unsigned memIndex = read();
1835  Value value = read<Value>();
1836  Type type = value ? value.getType() : Type();
1837 
1838  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"
1839  << " * Result: " << type << "\n");
1840  memory[memIndex] = type.getAsOpaquePointer();
1841 }
1842 
1843 void ByteCodeExecutor::executeGetValueRangeTypes() {
1844  LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
1845  unsigned memIndex = read();
1846  unsigned rangeIndex = read();
1847  ValueRange *values = read<ValueRange *>();
1848  if (!values) {
1849  LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n");
1850  memory[memIndex] = nullptr;
1851  return;
1852  }
1853 
1854  LLVM_DEBUG({
1855  llvm::dbgs() << " * Values (" << values->size() << "): ";
1856  llvm::interleaveComma(*values, llvm::dbgs());
1857  llvm::dbgs() << "\n * Result: ";
1858  llvm::interleaveComma(values->getType(), llvm::dbgs());
1859  llvm::dbgs() << "\n";
1860  });
1861  typeRangeMemory[rangeIndex] = values->getType();
1862  memory[memIndex] = &typeRangeMemory[rangeIndex];
1863 }
1864 
1865 void ByteCodeExecutor::executeIsNotNull() {
1866  LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1867  const void *value = read<const void *>();
1868 
1869  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
1870  selectJump(value != nullptr);
1871 }
1872 
1873 void ByteCodeExecutor::executeRecordMatch(
1874  PatternRewriter &rewriter,
1876  LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
1877  unsigned patternIndex = read();
1878  PatternBenefit benefit = currentPatternBenefits[patternIndex];
1879  const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1880 
1881  // If the benefit of the pattern is impossible, skip the processing of the
1882  // rest of the pattern.
1883  if (benefit.isImpossibleToMatch()) {
1884  LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n");
1885  curCodeIt = dest;
1886  return;
1887  }
1888 
1889  // Create a fused location containing the locations of each of the
1890  // operations used in the match. This will be used as the location for
1891  // created operations during the rewrite that don't already have an
1892  // explicit location set.
1893  unsigned numMatchLocs = read();
1894  SmallVector<Location, 4> matchLocs;
1895  matchLocs.reserve(numMatchLocs);
1896  for (unsigned i = 0; i != numMatchLocs; ++i)
1897  matchLocs.push_back(read<Operation *>()->getLoc());
1898  Location matchLoc = rewriter.getFusedLoc(matchLocs);
1899 
1900  LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n"
1901  << " * Location: " << matchLoc << "\n");
1902  matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
1903  PDLByteCode::MatchResult &match = matches.back();
1904 
1905  // Record all of the inputs to the match. If any of the inputs are ranges, we
1906  // will also need to remap the range pointer to memory stored in the match
1907  // state.
1908  unsigned numInputs = read();
1909  match.values.reserve(numInputs);
1910  match.typeRangeValues.reserve(numInputs);
1911  match.valueRangeValues.reserve(numInputs);
1912  for (unsigned i = 0; i < numInputs; ++i) {
1913  switch (read<PDLValue::Kind>()) {
1915  match.typeRangeValues.push_back(*read<TypeRange *>());
1916  match.values.push_back(&match.typeRangeValues.back());
1917  break;
1919  match.valueRangeValues.push_back(*read<ValueRange *>());
1920  match.values.push_back(&match.valueRangeValues.back());
1921  break;
1922  default:
1923  match.values.push_back(read<const void *>());
1924  break;
1925  }
1926  }
1927  curCodeIt = dest;
1928 }
1929 
1930 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
1931  LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
1932  Operation *op = read<Operation *>();
1934  readValueList(args);
1935 
1936  LLVM_DEBUG({
1937  llvm::dbgs() << " * Operation: " << *op << "\n"
1938  << " * Values: ";
1939  llvm::interleaveComma(args, llvm::dbgs());
1940  llvm::dbgs() << "\n";
1941  });
1942  rewriter.replaceOp(op, args);
1943 }
1944 
1945 void ByteCodeExecutor::executeSwitchAttribute() {
1946  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
1947  Attribute value = read<Attribute>();
1948  ArrayAttr cases = read<ArrayAttr>();
1949  handleSwitch(value, cases);
1950 }
1951 
1952 void ByteCodeExecutor::executeSwitchOperandCount() {
1953  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
1954  Operation *op = read<Operation *>();
1955  auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1956 
1957  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
1958  handleSwitch(op->getNumOperands(), cases);
1959 }
1960 
1961 void ByteCodeExecutor::executeSwitchOperationName() {
1962  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
1963  OperationName value = read<Operation *>()->getName();
1964  size_t caseCount = read();
1965 
1966  // The operation names are stored in-line, so to print them out for
1967  // debugging purposes we need to read the array before executing the
1968  // switch so that we can display all of the possible values.
1969  LLVM_DEBUG({
1970  const ByteCodeField *prevCodeIt = curCodeIt;
1971  llvm::dbgs() << " * Value: " << value << "\n"
1972  << " * Cases: ";
1973  llvm::interleaveComma(
1974  llvm::map_range(llvm::seq<size_t>(0, caseCount),
1975  [&](size_t) { return read<OperationName>(); }),
1976  llvm::dbgs());
1977  llvm::dbgs() << "\n";
1978  curCodeIt = prevCodeIt;
1979  });
1980 
1981  // Try to find the switch value within any of the cases.
1982  for (size_t i = 0; i != caseCount; ++i) {
1983  if (read<OperationName>() == value) {
1984  curCodeIt += (caseCount - i - 1);
1985  return selectJump(i + 1);
1986  }
1987  }
1988  selectJump(size_t(0));
1989 }
1990 
1991 void ByteCodeExecutor::executeSwitchResultCount() {
1992  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
1993  Operation *op = read<Operation *>();
1994  auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1995 
1996  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
1997  handleSwitch(op->getNumResults(), cases);
1998 }
1999 
2000 void ByteCodeExecutor::executeSwitchType() {
2001  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
2002  Type value = read<Type>();
2003  auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2004  handleSwitch(value, cases);
2005 }
2006 
2007 void ByteCodeExecutor::executeSwitchTypes() {
2008  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
2009  TypeRange *value = read<TypeRange *>();
2010  auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2011  if (!value) {
2012  LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
2013  return selectJump(size_t(0));
2014  }
2015  handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
2016  return value == caseValue.getAsValueRange<TypeAttr>();
2017  });
2018 }
2019 
2020 void ByteCodeExecutor::execute(
2021  PatternRewriter &rewriter,
2023  Optional<Location> mainRewriteLoc) {
2024  while (true) {
2025  // Print the location of the operation being executed.
2026  LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n");
2027 
2028  OpCode opCode = static_cast<OpCode>(read());
2029  switch (opCode) {
2030  case ApplyConstraint:
2031  executeApplyConstraint(rewriter);
2032  break;
2033  case ApplyRewrite:
2034  executeApplyRewrite(rewriter);
2035  break;
2036  case AreEqual:
2037  executeAreEqual();
2038  break;
2039  case AreRangesEqual:
2040  executeAreRangesEqual();
2041  break;
2042  case Branch:
2043  executeBranch();
2044  break;
2045  case CheckOperandCount:
2046  executeCheckOperandCount();
2047  break;
2048  case CheckOperationName:
2049  executeCheckOperationName();
2050  break;
2051  case CheckResultCount:
2052  executeCheckResultCount();
2053  break;
2054  case CheckTypes:
2055  executeCheckTypes();
2056  break;
2057  case Continue:
2058  executeContinue();
2059  break;
2060  case CreateOperation:
2061  executeCreateOperation(rewriter, *mainRewriteLoc);
2062  break;
2063  case CreateTypes:
2064  executeCreateTypes();
2065  break;
2066  case EraseOp:
2067  executeEraseOp(rewriter);
2068  break;
2069  case ExtractOp:
2070  executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
2071  break;
2072  case ExtractType:
2073  executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2074  break;
2075  case ExtractValue:
2076  executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2077  break;
2078  case Finalize:
2079  executeFinalize();
2080  LLVM_DEBUG(llvm::dbgs() << "\n");
2081  return;
2082  case ForEach:
2083  executeForEach();
2084  break;
2085  case GetAttribute:
2086  executeGetAttribute();
2087  break;
2088  case GetAttributeType:
2089  executeGetAttributeType();
2090  break;
2091  case GetDefiningOp:
2092  executeGetDefiningOp();
2093  break;
2094  case GetOperand0:
2095  case GetOperand1:
2096  case GetOperand2:
2097  case GetOperand3: {
2098  unsigned index = opCode - GetOperand0;
2099  LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
2100  executeGetOperand(index);
2101  break;
2102  }
2103  case GetOperandN:
2104  LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
2105  executeGetOperand(read<uint32_t>());
2106  break;
2107  case GetOperands:
2108  executeGetOperands();
2109  break;
2110  case GetResult0:
2111  case GetResult1:
2112  case GetResult2:
2113  case GetResult3: {
2114  unsigned index = opCode - GetResult0;
2115  LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
2116  executeGetResult(index);
2117  break;
2118  }
2119  case GetResultN:
2120  LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
2121  executeGetResult(read<uint32_t>());
2122  break;
2123  case GetResults:
2124  executeGetResults();
2125  break;
2126  case GetUsers:
2127  executeGetUsers();
2128  break;
2129  case GetValueType:
2130  executeGetValueType();
2131  break;
2132  case GetValueRangeTypes:
2133  executeGetValueRangeTypes();
2134  break;
2135  case IsNotNull:
2136  executeIsNotNull();
2137  break;
2138  case RecordMatch:
2139  assert(matches &&
2140  "expected matches to be provided when executing the matcher");
2141  executeRecordMatch(rewriter, *matches);
2142  break;
2143  case ReplaceOp:
2144  executeReplaceOp(rewriter);
2145  break;
2146  case SwitchAttribute:
2147  executeSwitchAttribute();
2148  break;
2149  case SwitchOperandCount:
2150  executeSwitchOperandCount();
2151  break;
2152  case SwitchOperationName:
2153  executeSwitchOperationName();
2154  break;
2155  case SwitchResultCount:
2156  executeSwitchResultCount();
2157  break;
2158  case SwitchType:
2159  executeSwitchType();
2160  break;
2161  case SwitchTypes:
2162  executeSwitchTypes();
2163  break;
2164  }
2165  LLVM_DEBUG(llvm::dbgs() << "\n");
2166  }
2167 }
2168 
2169 /// Run the pattern matcher on the given root operation, collecting the matched
2170 /// patterns in `matches`.
2173  PDLByteCodeMutableState &state) const {
2174  // The first memory slot is always the root operation.
2175  state.memory[0] = op;
2176 
2177  // The matcher function always starts at code address 0.
2178  ByteCodeExecutor executor(
2179  matcherByteCode.data(), state.memory, state.opRangeMemory,
2180  state.typeRangeMemory, state.allocatedTypeRangeMemory,
2181  state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2182  uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
2183  constraintFunctions, rewriteFunctions);
2184  executor.execute(rewriter, &matches);
2185 
2186  // Order the found matches by benefit.
2187  std::stable_sort(matches.begin(), matches.end(),
2188  [](const MatchResult &lhs, const MatchResult &rhs) {
2189  return lhs.benefit > rhs.benefit;
2190  });
2191 }
2192 
2193 /// Run the rewriter of the given pattern on the root operation `op`.
2194 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
2195  PDLByteCodeMutableState &state) const {
2196  // The arguments of the rewrite function are stored at the start of the
2197  // memory buffer.
2198  llvm::copy(match.values, state.memory.begin());
2199 
2200  ByteCodeExecutor executor(
2201  &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
2202  state.opRangeMemory, state.typeRangeMemory,
2203  state.allocatedTypeRangeMemory, state.valueRangeMemory,
2204  state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2205  rewriterByteCode, state.currentPatternBenefits, patterns,
2206  constraintFunctions, rewriteFunctions);
2207  executor.execute(rewriter, /*matches=*/nullptr, match.location);
2208 }
Include the generated interface declarations.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
static void * executeGetOperandsResults(RangeT values, Operation *op, unsigned index, ByteCodeField rangeIndex, StringRef attrSizedSegments, MutableArrayRef< ValueRange > valueRangeMemory)
This function is the internal implementation of GetResults and GetOperands that provides support for ...
Definition: ByteCode.cpp:1702
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:600
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition: Operation.h:480
This is a value defined by a result of an operation.
Definition: Value.h:425
Block represents an ordered list of Operations.
Definition: Block.h:29
This class represents liveness information on block level.
Definition: Liveness.h:99
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Each successful match returns a MatchResult, which contains information necessary to execute the rewr...
Definition: ByteCode.h:124
Value getOperand(unsigned idx)
Definition: Operation.h:267
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:375
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
OpCode
Definition: ByteCode.cpp:79
unsigned getNumOperands()
Definition: Operation.h:263
llvm::OwningArrayRef< Operation * > OwningOpRange
Definition: ByteCode.h:31
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
This class implements the result iterators for the Operation class.
Definition: ValueRange.h:228
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
SmallVector< const void * > values
Memory values defined in the matcher that are passed to the rewriter.
Definition: ByteCode.h:136
Operation & front()
Definition: Block.h:144
PDLByteCode(ModuleOp module, llvm::StringMap< PDLConstraintFunction > constraintFns, llvm::StringMap< PDLRewriteFunction > rewriteFns)
Create a ByteCode instance from the given module containing operations in the PDL interpreter dialect...
Definition: ByteCode.cpp:1017
user_range getUsers() const
Definition: Value.h:213
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
static constexpr const bool value
SmallVector< Value, 4 > operands
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Value.h:227
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
Auxiliary range data structure to unpack the offset, size and stride operands into a list of triples...
std::function< void(PatternRewriter &, PDLResultList &, ArrayRef< PDLValue >)> PDLRewriteFunction
A native PDL rewrite function.
Definition: PatternMatch.h:812
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
std::function< LogicalResult(PatternRewriter &, ArrayRef< PDLValue >)> PDLConstraintFunction
A generic PDL pattern constraint function.
Definition: PatternMatch.h:806
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:404
All of the data pertaining to a specific pattern within the bytecode.
Definition: ByteCode.h:38
uint16_t ByteCodeField
Use generic bytecode types.
Definition: ByteCode.h:29
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp, ByteCodeAddr rewriterAddr)
Definition: ByteCode.cpp:36
OpFoldResult size
void rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const
Run the rewriter of the given pattern that was previously matched in match.
Definition: ByteCode.cpp:2194
SmallVector< ValueRange, 0 > valueRangeValues
Definition: ByteCode.h:139
U dyn_cast() const
Definition: Types.h:270
Storage type of byte-code interpreter values.
Definition: PatternMatch.h:614
Attributes are known-constant values of operations.
Definition: Attributes.h:24
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:32
U dyn_cast() const
Definition: Value.h:100
user_iterator user_begin() const
Definition: Value.h:211
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:528
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:233
SmallVector< TypeRange, 0 > typeRangeValues
Memory used for the range input values.
Definition: ByteCode.h:138
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:324
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
Definition: Builders.cpp:28
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:32
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition: Block.cpp:35
Optional< RegisteredOperationName > getRegisteredInfo() const
If this operation is registered, returns the registered information, None otherwise.
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:154
This represents an operation in an abstracted form, suitable for use with the builder APIs...
This class acts as a special tag that makes the desire to match "any" operation type explicit...
Definition: PatternMatch.h:157
BlockArgListType getArguments()
Definition: Block.h:76
Represents an analysis for computing liveness information from a given top-level operation.
Definition: Liveness.h:47
This class represents an argument of a Block.
Definition: Value.h:300
Kind
The underlying kind of a PDL value.
Definition: PatternMatch.h:617
This class implements the successor iterators for Block.
Definition: BlockSupport.h:71
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
user_iterator user_end() const
Definition: Value.h:212
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
MLIRContext * getContext() const
Get the context held by this operation state.
NamedAttrList attributes
bool isImpossibleToMatch() const
Definition: PatternMatch.h:42
Operation * getEndOperation(Value value, Operation *startOperation) const
Gets the end operation for the given value using the start operation provided (must be referenced in ...
Definition: Liveness.cpp:375
Type getType() const
Return the type of this value.
Definition: Value.h:118
void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit)
Set the new benefit for a bytecode pattern.
Definition: ByteCode.cpp:61
void cleanupAfterMatchAndRewrite()
Cleanup any allocated state after a match/rewrite has been completed.
Definition: ByteCode.cpp:69
U dyn_cast() const
Definition: Attributes.h:127
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation...
Definition: Visitors.cpp:24
type_range getType() const
Definition: ValueRange.cpp:46
Location location
The location of operations to be replaced.
Definition: ByteCode.h:134
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
The class represents a list of PDL results, returned by a native rewrite method.
Definition: PatternMatch.h:726
const PDLByteCodePattern * pattern
The originating pattern that was matched.
Definition: ByteCode.h:143
This class implements the operand iterators for the Operation class.
Definition: ValueRange.h:40
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:321
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
This class contains the mutable state of a bytecode instance.
Definition: ByteCode.h:63
void match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl< MatchResult > &matches, PDLByteCodeMutableState &state) const
Run the pattern matcher on the given root operation, collecting the matched patterns in matches...
Definition: ByteCode.cpp:2171
uint32_t ByteCodeAddr
Definition: ByteCode.h:30
PatternBenefit benefit
The current benefit of the pattern that was matched.
Definition: ByteCode.h:145
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:50
bool isa() const
Definition: Types.h:254
const ValueSetT & in() const
Returns all values that are live at the beginning of the block (unordered).
Definition: Liveness.h:110
result_range getResults()
Definition: Operation.h:332
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:345
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:371
ByteCodeAddr getRewriterAddr() const
Return the bytecode address of the rewriter for this pattern.
Definition: ByteCode.h:44
static const mlir::GenInfo * generator
static void processValue(Value value, LiveMap &liveMap)
void initializeMutableState(PDLByteCodeMutableState &state) const
Initialize the given state such that it can be used to execute the current bytecode.
Definition: ByteCode.cpp:1035
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static constexpr ByteCodeField kInferTypesMarker
A marker used to indicate if an operation should infer types.
Definition: ByteCode.cpp:166
SmallVector< Type, 4 > types
Types of the results of this operation.