MLIR  14.0.0git
ByteCode.cpp
Go to the documentation of this file.
1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ByteCode.h"
14 #include "mlir/Analysis/Liveness.h"
17 #include "mlir/IR/BuiltinOps.h"
19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/Format.h"
24 #include "llvm/Support/FormatVariadic.h"
25 #include <numeric>
26 
27 #define DEBUG_TYPE "pdl-bytecode"
28 
29 using namespace mlir;
30 using namespace mlir::detail;
31 
32 //===----------------------------------------------------------------------===//
33 // PDLByteCodePattern
34 //===----------------------------------------------------------------------===//
35 
36 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
37  ByteCodeAddr rewriterAddr) {
38  SmallVector<StringRef, 8> generatedOps;
39  if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr())
40  generatedOps =
41  llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
42 
43  PatternBenefit benefit = matchOp.benefit();
44  MLIRContext *ctx = matchOp.getContext();
45 
46  // Check to see if this is pattern matches a specific operation type.
47  if (Optional<StringRef> rootKind = matchOp.rootKind())
48  return PDLByteCodePattern(rewriterAddr, *rootKind, benefit, ctx,
49  generatedOps);
50  return PDLByteCodePattern(rewriterAddr, MatchAnyOpTypeTag(), benefit, ctx,
51  generatedOps);
52 }
53 
54 //===----------------------------------------------------------------------===//
55 // PDLByteCodeMutableState
56 //===----------------------------------------------------------------------===//
57 
58 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
59 /// to the position of the pattern within the range returned by
60 /// `PDLByteCode::getPatterns`.
62  PatternBenefit benefit) {
63  currentPatternBenefits[patternIndex] = benefit;
64 }
65 
66 /// Cleanup any allocated state after a full match/rewrite has been completed.
67 /// This method should be called regardless of whether the match+rewrite was a
68 /// success or not.
70  allocatedTypeRangeMemory.clear();
71  allocatedValueRangeMemory.clear();
72 }
73 
74 //===----------------------------------------------------------------------===//
75 // Bytecode OpCodes
76 //===----------------------------------------------------------------------===//
77 
78 namespace {
80  /// Apply an externally registered constraint.
81  ApplyConstraint,
82  /// Apply an externally registered rewrite.
83  ApplyRewrite,
84  /// Check if two generic values are equal.
85  AreEqual,
86  /// Check if two ranges are equal.
87  AreRangesEqual,
88  /// Unconditional branch.
89  Branch,
90  /// Compare the operand count of an operation with a constant.
91  CheckOperandCount,
92  /// Compare the name of an operation with a constant.
93  CheckOperationName,
94  /// Compare the result count of an operation with a constant.
95  CheckResultCount,
96  /// Compare a range of types to a constant range of types.
97  CheckTypes,
98  /// Continue to the next iteration of a loop.
99  Continue,
100  /// Create an operation.
101  CreateOperation,
102  /// Create a range of types.
103  CreateTypes,
104  /// Erase an operation.
105  EraseOp,
106  /// Extract the op from a range at the specified index.
107  ExtractOp,
108  /// Extract the type from a range at the specified index.
109  ExtractType,
110  /// Extract the value from a range at the specified index.
111  ExtractValue,
112  /// Terminate a matcher or rewrite sequence.
113  Finalize,
114  /// Iterate over a range of values.
115  ForEach,
116  /// Get a specific attribute of an operation.
117  GetAttribute,
118  /// Get the type of an attribute.
119  GetAttributeType,
120  /// Get the defining operation of a value.
121  GetDefiningOp,
122  /// Get a specific operand of an operation.
123  GetOperand0,
124  GetOperand1,
125  GetOperand2,
126  GetOperand3,
127  GetOperandN,
128  /// Get a specific operand group of an operation.
129  GetOperands,
130  /// Get a specific result of an operation.
131  GetResult0,
132  GetResult1,
133  GetResult2,
134  GetResult3,
135  GetResultN,
136  /// Get a specific result group of an operation.
137  GetResults,
138  /// Get the users of a value or a range of values.
139  GetUsers,
140  /// Get the type of a value.
141  GetValueType,
142  /// Get the types of a value range.
143  GetValueRangeTypes,
144  /// Check if a generic value is not null.
145  IsNotNull,
146  /// Record a successful pattern match.
147  RecordMatch,
148  /// Replace an operation.
149  ReplaceOp,
150  /// Compare an attribute with a set of constants.
151  SwitchAttribute,
152  /// Compare the operand count of an operation with a set of constants.
153  SwitchOperandCount,
154  /// Compare the name of an operation with a set of constants.
155  SwitchOperationName,
156  /// Compare the result count of an operation with a set of constants.
157  SwitchResultCount,
158  /// Compare a type with a set of constants.
159  SwitchType,
160  /// Compare a range of types with a set of constants.
161  SwitchTypes,
162 };
163 } // namespace
164 
165 //===----------------------------------------------------------------------===//
166 // ByteCode Generation
167 //===----------------------------------------------------------------------===//
168 
169 //===----------------------------------------------------------------------===//
170 // Generator
171 
172 namespace {
173 struct ByteCodeLiveRange;
174 struct ByteCodeWriter;
175 
176 /// Check if the given class `T` can be converted to an opaque pointer.
177 template <typename T, typename... Args>
178 using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
179 
180 /// This class represents the main generator for the pattern bytecode.
181 class Generator {
182 public:
183  Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
184  SmallVectorImpl<ByteCodeField> &matcherByteCode,
185  SmallVectorImpl<ByteCodeField> &rewriterByteCode,
187  ByteCodeField &maxValueMemoryIndex,
188  ByteCodeField &maxOpRangeMemoryIndex,
189  ByteCodeField &maxTypeRangeMemoryIndex,
190  ByteCodeField &maxValueRangeMemoryIndex,
191  ByteCodeField &maxLoopLevel,
192  llvm::StringMap<PDLConstraintFunction> &constraintFns,
193  llvm::StringMap<PDLRewriteFunction> &rewriteFns)
194  : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
195  rewriterByteCode(rewriterByteCode), patterns(patterns),
196  maxValueMemoryIndex(maxValueMemoryIndex),
197  maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
198  maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
199  maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
200  maxLoopLevel(maxLoopLevel) {
201  for (const auto &it : llvm::enumerate(constraintFns))
202  constraintToMemIndex.try_emplace(it.value().first(), it.index());
203  for (const auto &it : llvm::enumerate(rewriteFns))
204  externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
205  }
206 
207  /// Generate the bytecode for the given PDL interpreter module.
208  void generate(ModuleOp module);
209 
210  /// Return the memory index to use for the given value.
211  ByteCodeField &getMemIndex(Value value) {
212  assert(valueToMemIndex.count(value) &&
213  "expected memory index to be assigned");
214  return valueToMemIndex[value];
215  }
216 
217  /// Return the range memory index used to store the given range value.
218  ByteCodeField &getRangeStorageIndex(Value value) {
219  assert(valueToRangeIndex.count(value) &&
220  "expected range index to be assigned");
221  return valueToRangeIndex[value];
222  }
223 
224  /// Return an index to use when referring to the given data that is uniqued in
225  /// the MLIR context.
226  template <typename T>
228  getMemIndex(T val) {
229  const void *opaqueVal = val.getAsOpaquePointer();
230 
231  // Get or insert a reference to this value.
232  auto it = uniquedDataToMemIndex.try_emplace(
233  opaqueVal, maxValueMemoryIndex + uniquedData.size());
234  if (it.second)
235  uniquedData.push_back(opaqueVal);
236  return it.first->second;
237  }
238 
239 private:
240  /// Allocate memory indices for the results of operations within the matcher
241  /// and rewriters.
242  void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule);
243 
244  /// Generate the bytecode for the given operation.
245  void generate(Region *region, ByteCodeWriter &writer);
246  void generate(Operation *op, ByteCodeWriter &writer);
247  void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
248  void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
249  void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
250  void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
251  void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
252  void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
253  void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
254  void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
255  void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
256  void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
257  void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
258  void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
259  void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
260  void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
261  void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
262  void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
263  void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
264  void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
265  void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
266  void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
267  void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
268  void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
269  void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
270  void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
271  void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
272  void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
273  void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
274  void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
275  void generate(pdl_interp::InferredTypesOp op, ByteCodeWriter &writer);
276  void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
277  void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
278  void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
279  void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
280  void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
281  void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
282  void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
283  void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
284  void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
285 
286  /// Mapping from value to its corresponding memory index.
287  DenseMap<Value, ByteCodeField> valueToMemIndex;
288 
289  /// Mapping from a range value to its corresponding range storage index.
290  DenseMap<Value, ByteCodeField> valueToRangeIndex;
291 
292  /// Mapping from the name of an externally registered rewrite to its index in
293  /// the bytecode registry.
294  llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
295 
296  /// Mapping from the name of an externally registered constraint to its index
297  /// in the bytecode registry.
298  llvm::StringMap<ByteCodeField> constraintToMemIndex;
299 
300  /// Mapping from rewriter function name to the bytecode address of the
301 /// rewriter function in the rewriter bytecode.
302  llvm::StringMap<ByteCodeAddr> rewriterToAddr;
303 
304  /// Mapping from a uniqued storage object to its memory index within
305  /// `uniquedData`.
306  DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
307 
308  /// The current level of the foreach loop.
309  ByteCodeField curLoopLevel = 0;
310 
311  /// The current MLIR context.
312  MLIRContext *ctx;
313 
314  /// Mapping from block to its address.
316 
317  /// Data of the ByteCode class to be populated.
318  std::vector<const void *> &uniquedData;
319  SmallVectorImpl<ByteCodeField> &matcherByteCode;
320  SmallVectorImpl<ByteCodeField> &rewriterByteCode;
322  ByteCodeField &maxValueMemoryIndex;
323  ByteCodeField &maxOpRangeMemoryIndex;
324  ByteCodeField &maxTypeRangeMemoryIndex;
325  ByteCodeField &maxValueRangeMemoryIndex;
326  ByteCodeField &maxLoopLevel;
327 };
328 
329 /// This class provides utilities for writing a bytecode stream.
330 struct ByteCodeWriter {
331  ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
332  : bytecode(bytecode), generator(generator) {}
333 
334  /// Append a field to the bytecode.
335  void append(ByteCodeField field) { bytecode.push_back(field); }
336  void append(OpCode opCode) { bytecode.push_back(opCode); }
337 
338  /// Append an address to the bytecode.
339  void append(ByteCodeAddr field) {
340  static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
341  "unexpected ByteCode address size");
342 
343  ByteCodeField fieldParts[2];
344  std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
345  bytecode.append({fieldParts[0], fieldParts[1]});
346  }
347 
348  /// Append a single successor to the bytecode, the exact address will need to
349  /// be resolved later.
350  void append(Block *successor) {
351  // Add back a reference to the successor so that the address can be resolved
352  // later.
353  unresolvedSuccessorRefs[successor].push_back(bytecode.size());
354  append(ByteCodeAddr(0));
355  }
356 
357  /// Append a successor range to the bytecode, the exact address will need to
358  /// be resolved later.
359  void append(SuccessorRange successors) {
360  for (Block *successor : successors)
361  append(successor);
362  }
363 
364  /// Append a range of values that will be read as generic PDLValues.
365  void appendPDLValueList(OperandRange values) {
366  bytecode.push_back(values.size());
367  for (Value value : values)
368  appendPDLValue(value);
369  }
370 
371  /// Append a value as a PDLValue.
372  void appendPDLValue(Value value) {
373  appendPDLValueKind(value);
374  append(value);
375  }
376 
377  /// Append the PDLValue::Kind of the given value.
378  void appendPDLValueKind(Value value) { appendPDLValueKind(value.getType()); }
379 
380  /// Append the PDLValue::Kind of the given type.
381  void appendPDLValueKind(Type type) {
382  PDLValue::Kind kind =
384  .Case<pdl::AttributeType>(
385  [](Type) { return PDLValue::Kind::Attribute; })
386  .Case<pdl::OperationType>(
387  [](Type) { return PDLValue::Kind::Operation; })
388  .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
389  if (rangeTy.getElementType().isa<pdl::TypeType>())
392  })
393  .Case<pdl::TypeType>([](Type) { return PDLValue::Kind::Type; })
394  .Case<pdl::ValueType>([](Type) { return PDLValue::Kind::Value; });
395  bytecode.push_back(static_cast<ByteCodeField>(kind));
396  }
397 
398  /// Append a value that will be stored in a memory slot and not inline within
399  /// the bytecode.
400  template <typename T>
403  append(T value) {
404  bytecode.push_back(generator.getMemIndex(value));
405  }
406 
407  /// Append a range of values.
408  template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
410  append(T range) {
411  bytecode.push_back(llvm::size(range));
412  for (auto it : range)
413  append(it);
414  }
415 
416  /// Append a variadic number of fields to the bytecode.
417  template <typename FieldTy, typename Field2Ty, typename... FieldTys>
418  void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
419  append(field);
420  append(field2, fields...);
421  }
422 
423  /// Appends a value as a pointer, stored inline within the bytecode.
424  template <typename T>
426  appendInline(T value) {
427  constexpr size_t numParts = sizeof(const void *) / sizeof(ByteCodeField);
428  const void *pointer = value.getAsOpaquePointer();
429  ByteCodeField fieldParts[numParts];
430  std::memcpy(fieldParts, &pointer, sizeof(const void *));
431  bytecode.append(fieldParts, fieldParts + numParts);
432  }
433 
434  /// Successor references in the bytecode that have yet to be resolved.
435  DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
436 
437  /// The underlying bytecode buffer.
439 
440  /// The main generator producing PDL.
441  Generator &generator;
442 };
443 
444 /// This class represents a live range of PDL Interpreter values, containing
445 /// information about when values are live within a match/rewrite.
446 struct ByteCodeLiveRange {
447  using Set = llvm::IntervalMap<uint64_t, char, 16>;
448  using Allocator = Set::Allocator;
449 
450  ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
451 
452  /// Union this live range with the one provided.
453  void unionWith(const ByteCodeLiveRange &rhs) {
454  for (auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
455  ++it)
456  liveness->insert(it.start(), it.stop(), /*dummyValue*/ 0);
457  }
458 
459  /// Returns true if this range overlaps with the one provided.
460  bool overlaps(const ByteCodeLiveRange &rhs) const {
461  return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
462  .valid();
463  }
464 
465  /// A map representing the ranges of the match/rewrite that a value is live in
466  /// the interpreter.
467  ///
468  /// We use std::unique_ptr here, because IntervalMap does not provide a
469  /// correct copy or move constructor. We can eliminate the pointer once
470  /// https://reviews.llvm.org/D113240 lands.
471  std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
472 
473  /// The operation range storage index for this range.
474  Optional<unsigned> opRangeIndex;
475 
476  /// The type range storage index for this range.
477  Optional<unsigned> typeRangeIndex;
478 
479  /// The value range storage index for this range.
480  Optional<unsigned> valueRangeIndex;
481 };
482 } // namespace
483 
484 void Generator::generate(ModuleOp module) {
485  FuncOp matcherFunc = module.lookupSymbol<FuncOp>(
486  pdl_interp::PDLInterpDialect::getMatcherFunctionName());
487  ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
488  pdl_interp::PDLInterpDialect::getRewriterModuleName());
489  assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
490 
491  // Allocate memory indices for the results of operations within the matcher
492  // and rewriters.
493  allocateMemoryIndices(matcherFunc, rewriterModule);
494 
495  // Generate code for the rewriter functions.
496  ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
497  for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
498  rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
499  for (Operation &op : rewriterFunc.getOps())
500  generate(&op, rewriterByteCodeWriter);
501  }
502  assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
503  "unexpected branches in rewriter function");
504 
505  // Generate code for the matcher function.
506  ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
507  generate(&matcherFunc.getBody(), matcherByteCodeWriter);
508 
509  // Resolve successor references in the matcher.
510  for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
511  ByteCodeAddr addr = blockToAddr[it.first];
512  for (unsigned offsetToFix : it.second)
513  std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
514  }
515 }
516 
517 void Generator::allocateMemoryIndices(FuncOp matcherFunc,
518  ModuleOp rewriterModule) {
519  // Rewriters use simplistic allocation scheme that simply assigns an index to
520  // each result.
521  for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
522  ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
523  auto processRewriterValue = [&](Value val) {
524  valueToMemIndex.try_emplace(val, index++);
525  if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) {
526  Type elementTy = rangeType.getElementType();
527  if (elementTy.isa<pdl::TypeType>())
528  valueToRangeIndex.try_emplace(val, typeRangeIndex++);
529  else if (elementTy.isa<pdl::ValueType>())
530  valueToRangeIndex.try_emplace(val, valueRangeIndex++);
531  }
532  };
533 
534  for (BlockArgument arg : rewriterFunc.getArguments())
535  processRewriterValue(arg);
536  rewriterFunc.getBody().walk([&](Operation *op) {
537  for (Value result : op->getResults())
538  processRewriterValue(result);
539  });
540  if (index > maxValueMemoryIndex)
541  maxValueMemoryIndex = index;
542  if (typeRangeIndex > maxTypeRangeMemoryIndex)
543  maxTypeRangeMemoryIndex = typeRangeIndex;
544  if (valueRangeIndex > maxValueRangeMemoryIndex)
545  maxValueRangeMemoryIndex = valueRangeIndex;
546  }
547 
548  // The matcher function uses a more sophisticated numbering that tries to
549  // minimize the number of memory indices assigned. This is done by determining
550  // a live range of the values within the matcher, then the allocation is just
551  // finding the minimal number of overlapping live ranges. This is essentially
552  // a simplified form of register allocation where we don't necessarily have a
553  // limited number of registers, but we still want to minimize the number used.
554  DenseMap<Operation *, unsigned> opToFirstIndex;
555  DenseMap<Operation *, unsigned> opToLastIndex;
556 
557  // A custom walk that marks the first and the last index of each operation.
558  // The entry marks the beginning of the liveness range for this operation,
559  // followed by nested operations, followed by the end of the liveness range.
560  unsigned index = 0;
561  llvm::unique_function<void(Operation *)> walk = [&](Operation *op) {
562  opToFirstIndex.try_emplace(op, index++);
563  for (Region &region : op->getRegions())
564  for (Block &block : region.getBlocks())
565  for (Operation &nested : block)
566  walk(&nested);
567  opToLastIndex.try_emplace(op, index++);
568  };
569  walk(matcherFunc);
570 
571  // Liveness info for each of the defs within the matcher.
572  ByteCodeLiveRange::Allocator allocator;
573  DenseMap<Value, ByteCodeLiveRange> valueDefRanges;
574 
575  // Assign the root operation being matched to slot 0.
576  BlockArgument rootOpArg = matcherFunc.getArgument(0);
577  valueToMemIndex[rootOpArg] = 0;
578 
579  // Walk each of the blocks, computing the def interval that the value is used.
580  Liveness matcherLiveness(matcherFunc);
581  matcherFunc->walk([&](Block *block) {
582  const LivenessBlockInfo *info = matcherLiveness.getLiveness(block);
583  assert(info && "expected liveness info for block");
584  auto processValue = [&](Value value, Operation *firstUseOrDef) {
585  // We don't need to process the root op argument, this value is always
586  // assigned to the first memory slot.
587  if (value == rootOpArg)
588  return;
589 
590  // Set indices for the range of this block that the value is used.
591  auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
592  defRangeIt->second.liveness->insert(
593  opToFirstIndex[firstUseOrDef],
594  opToLastIndex[info->getEndOperation(value, firstUseOrDef)],
595  /*dummyValue*/ 0);
596 
597  // Check to see if this value is a range type.
598  if (auto rangeTy = value.getType().dyn_cast<pdl::RangeType>()) {
599  Type eleType = rangeTy.getElementType();
600  if (eleType.isa<pdl::OperationType>())
601  defRangeIt->second.opRangeIndex = 0;
602  else if (eleType.isa<pdl::TypeType>())
603  defRangeIt->second.typeRangeIndex = 0;
604  else if (eleType.isa<pdl::ValueType>())
605  defRangeIt->second.valueRangeIndex = 0;
606  }
607  };
608 
609  // Process the live-ins of this block.
610  for (Value liveIn : info->in()) {
611  // Only process the value if it has been defined in the current region.
612  // Other values that span across pdl_interp.foreach will be added higher
613  // up. This ensures that the we keep them alive for the entire duration
614  // of the loop.
615  if (liveIn.getParentRegion() == block->getParent())
616  processValue(liveIn, &block->front());
617  }
618 
619  // Process the block arguments for the entry block (those are not live-in).
620  if (block->isEntryBlock()) {
621  for (Value argument : block->getArguments())
622  processValue(argument, &block->front());
623  }
624 
625  // Process any new defs within this block.
626  for (Operation &op : *block)
627  for (Value result : op.getResults())
628  processValue(result, &op);
629  });
630 
631  // Greedily allocate memory slots using the computed def live ranges.
632  std::vector<ByteCodeLiveRange> allocatedIndices;
633 
634  // The number of memory indices currently allocated (and its next value).
635  // Recall that the root gets allocated memory index 0.
636  ByteCodeField numIndices = 1;
637 
638  // The number of memory ranges of various types (and their next values).
639  ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
640 
641  for (auto &defIt : valueDefRanges) {
642  ByteCodeField &memIndex = valueToMemIndex[defIt.first];
643  ByteCodeLiveRange &defRange = defIt.second;
644 
645  // Try to allocate to an existing index.
646  for (const auto &existingIndexIt : llvm::enumerate(allocatedIndices)) {
647  ByteCodeLiveRange &existingRange = existingIndexIt.value();
648  if (!defRange.overlaps(existingRange)) {
649  existingRange.unionWith(defRange);
650  memIndex = existingIndexIt.index() + 1;
651 
652  if (defRange.opRangeIndex) {
653  if (!existingRange.opRangeIndex)
654  existingRange.opRangeIndex = numOpRanges++;
655  valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
656  } else if (defRange.typeRangeIndex) {
657  if (!existingRange.typeRangeIndex)
658  existingRange.typeRangeIndex = numTypeRanges++;
659  valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
660  } else if (defRange.valueRangeIndex) {
661  if (!existingRange.valueRangeIndex)
662  existingRange.valueRangeIndex = numValueRanges++;
663  valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
664  }
665  break;
666  }
667  }
668 
669  // If no existing index could be used, add a new one.
670  if (memIndex == 0) {
671  allocatedIndices.emplace_back(allocator);
672  ByteCodeLiveRange &newRange = allocatedIndices.back();
673  newRange.unionWith(defRange);
674 
675  // Allocate an index for op/type/value ranges.
676  if (defRange.opRangeIndex) {
677  newRange.opRangeIndex = numOpRanges;
678  valueToRangeIndex[defIt.first] = numOpRanges++;
679  } else if (defRange.typeRangeIndex) {
680  newRange.typeRangeIndex = numTypeRanges;
681  valueToRangeIndex[defIt.first] = numTypeRanges++;
682  } else if (defRange.valueRangeIndex) {
683  newRange.valueRangeIndex = numValueRanges;
684  valueToRangeIndex[defIt.first] = numValueRanges++;
685  }
686 
687  memIndex = allocatedIndices.size();
688  ++numIndices;
689  }
690  }
691 
692  // Print the index usage and ensure that we did not run out of index space.
693  LLVM_DEBUG({
694  llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices "
695  << "(down from initial " << valueDefRanges.size() << ").\n";
696  });
697  assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() &&
698  "Ran out of memory for allocated indices");
699 
700  // Update the max number of indices.
701  if (numIndices > maxValueMemoryIndex)
702  maxValueMemoryIndex = numIndices;
703  if (numOpRanges > maxOpRangeMemoryIndex)
704  maxOpRangeMemoryIndex = numOpRanges;
705  if (numTypeRanges > maxTypeRangeMemoryIndex)
706  maxTypeRangeMemoryIndex = numTypeRanges;
707  if (numValueRanges > maxValueRangeMemoryIndex)
708  maxValueRangeMemoryIndex = numValueRanges;
709 }
710 
711 void Generator::generate(Region *region, ByteCodeWriter &writer) {
712  llvm::ReversePostOrderTraversal<Region *> rpot(region);
713  for (Block *block : rpot) {
714  // Keep track of where this block begins within the matcher function.
715  blockToAddr.try_emplace(block, matcherByteCode.size());
716  for (Operation &op : *block)
717  generate(&op, writer);
718  }
719 }
720 
721 void Generator::generate(Operation *op, ByteCodeWriter &writer) {
722  LLVM_DEBUG({
723  // The following list must contain all the operations that do not
724  // produce any bytecode.
725  if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp,
726  pdl_interp::InferredTypesOp>(op))
727  writer.appendInline(op->getLoc());
728  });
730  .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
731  pdl_interp::AreEqualOp, pdl_interp::BranchOp,
732  pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
733  pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
734  pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
735  pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
736  pdl_interp::CreateOperationOp, pdl_interp::CreateTypeOp,
737  pdl_interp::CreateTypesOp, pdl_interp::EraseOp,
738  pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
739  pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
740  pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
741  pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
742  pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
743  pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
744  pdl_interp::InferredTypesOp, pdl_interp::IsNotNullOp,
745  pdl_interp::RecordMatchOp, pdl_interp::ReplaceOp,
746  pdl_interp::SwitchAttributeOp, pdl_interp::SwitchTypeOp,
747  pdl_interp::SwitchTypesOp, pdl_interp::SwitchOperandCountOp,
748  pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
749  [&](auto interpOp) { this->generate(interpOp, writer); })
750  .Default([](Operation *) {
751  llvm_unreachable("unknown `pdl_interp` operation");
752  });
753 }
754 
755 void Generator::generate(pdl_interp::ApplyConstraintOp op,
756  ByteCodeWriter &writer) {
757  assert(constraintToMemIndex.count(op.name()) &&
758  "expected index for constraint function");
759  writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()],
760  op.constParamsAttr());
761  writer.appendPDLValueList(op.args());
762  writer.append(op.getSuccessors());
763 }
764 void Generator::generate(pdl_interp::ApplyRewriteOp op,
765  ByteCodeWriter &writer) {
766  assert(externalRewriterToMemIndex.count(op.name()) &&
767  "expected index for rewrite function");
768  writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()],
769  op.constParamsAttr());
770  writer.appendPDLValueList(op.args());
771 
772  ResultRange results = op.results();
773  writer.append(ByteCodeField(results.size()));
774  for (Value result : results) {
775  // In debug mode we also record the expected kind of the result, so that we
776  // can provide extra verification of the native rewrite function.
777 #ifndef NDEBUG
778  writer.appendPDLValueKind(result);
779 #endif
780 
781  // Range results also need to append the range storage index.
782  if (result.getType().isa<pdl::RangeType>())
783  writer.append(getRangeStorageIndex(result));
784  writer.append(result);
785  }
786 }
787 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
788  Value lhs = op.lhs();
789  if (lhs.getType().isa<pdl::RangeType>()) {
790  writer.append(OpCode::AreRangesEqual);
791  writer.appendPDLValueKind(lhs);
792  writer.append(op.lhs(), op.rhs(), op.getSuccessors());
793  return;
794  }
795 
796  writer.append(OpCode::AreEqual, lhs, op.rhs(), op.getSuccessors());
797 }
798 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
799  writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
800 }
801 void Generator::generate(pdl_interp::CheckAttributeOp op,
802  ByteCodeWriter &writer) {
803  writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(),
804  op.getSuccessors());
805 }
806 void Generator::generate(pdl_interp::CheckOperandCountOp op,
807  ByteCodeWriter &writer) {
808  writer.append(OpCode::CheckOperandCount, op.operation(), op.count(),
809  static_cast<ByteCodeField>(op.compareAtLeast()),
810  op.getSuccessors());
811 }
812 void Generator::generate(pdl_interp::CheckOperationNameOp op,
813  ByteCodeWriter &writer) {
814  writer.append(OpCode::CheckOperationName, op.operation(),
815  OperationName(op.name(), ctx), op.getSuccessors());
816 }
817 void Generator::generate(pdl_interp::CheckResultCountOp op,
818  ByteCodeWriter &writer) {
819  writer.append(OpCode::CheckResultCount, op.operation(), op.count(),
820  static_cast<ByteCodeField>(op.compareAtLeast()),
821  op.getSuccessors());
822 }
823 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
824  writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors());
825 }
826 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
827  writer.append(OpCode::CheckTypes, op.value(), op.types(), op.getSuccessors());
828 }
829 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
830  assert(curLoopLevel > 0 && "encountered pdl_interp.continue at top level");
831  writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
832 }
833 void Generator::generate(pdl_interp::CreateAttributeOp op,
834  ByteCodeWriter &writer) {
835  // Simply repoint the memory index of the result to the constant.
836  getMemIndex(op.attribute()) = getMemIndex(op.value());
837 }
838 void Generator::generate(pdl_interp::CreateOperationOp op,
839  ByteCodeWriter &writer) {
840  writer.append(OpCode::CreateOperation, op.operation(),
841  OperationName(op.name(), ctx));
842  writer.appendPDLValueList(op.operands());
843 
844  // Add the attributes.
845  OperandRange attributes = op.attributes();
846  writer.append(static_cast<ByteCodeField>(attributes.size()));
847  for (auto it : llvm::zip(op.attributeNames(), op.attributes()))
848  writer.append(std::get<0>(it), std::get<1>(it));
849  writer.appendPDLValueList(op.types());
850 }
851 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
852  // Simply repoint the memory index of the result to the constant.
853  getMemIndex(op.result()) = getMemIndex(op.value());
854 }
855 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
856  writer.append(OpCode::CreateTypes, op.result(),
857  getRangeStorageIndex(op.result()), op.value());
858 }
859 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
860  writer.append(OpCode::EraseOp, op.operation());
861 }
862 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
863  OpCode opCode =
864  TypeSwitch<Type, OpCode>(op.result().getType())
865  .Case([](pdl::OperationType) { return OpCode::ExtractOp; })
866  .Case([](pdl::ValueType) { return OpCode::ExtractValue; })
867  .Case([](pdl::TypeType) { return OpCode::ExtractType; })
868  .Default([](Type) -> OpCode {
869  llvm_unreachable("unsupported element type");
870  });
871  writer.append(opCode, op.range(), op.index(), op.result());
872 }
873 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
874  writer.append(OpCode::Finalize);
875 }
876 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
877  BlockArgument arg = op.getLoopVariable();
878  writer.append(OpCode::ForEach, getRangeStorageIndex(op.values()), arg);
879  writer.appendPDLValueKind(arg.getType());
880  writer.append(curLoopLevel, op.successor());
881  ++curLoopLevel;
882  if (curLoopLevel > maxLoopLevel)
883  maxLoopLevel = curLoopLevel;
884  generate(&op.region(), writer);
885  --curLoopLevel;
886 }
887 void Generator::generate(pdl_interp::GetAttributeOp op,
888  ByteCodeWriter &writer) {
889  writer.append(OpCode::GetAttribute, op.attribute(), op.operation(),
890  op.nameAttr());
891 }
892 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
893  ByteCodeWriter &writer) {
894  writer.append(OpCode::GetAttributeType, op.result(), op.value());
895 }
896 void Generator::generate(pdl_interp::GetDefiningOpOp op,
897  ByteCodeWriter &writer) {
898  writer.append(OpCode::GetDefiningOp, op.operation());
899  writer.appendPDLValue(op.value());
900 }
901 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
902  uint32_t index = op.index();
903  if (index < 4)
904  writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
905  else
906  writer.append(OpCode::GetOperandN, index);
907  writer.append(op.operation(), op.value());
908 }
909 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
910  Value result = op.value();
911  Optional<uint32_t> index = op.index();
912  writer.append(OpCode::GetOperands,
913  index.getValueOr(std::numeric_limits<uint32_t>::max()),
914  op.operation());
915  if (result.getType().isa<pdl::RangeType>())
916  writer.append(getRangeStorageIndex(result));
917  else
919  writer.append(result);
920 }
921 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
922  uint32_t index = op.index();
923  if (index < 4)
924  writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
925  else
926  writer.append(OpCode::GetResultN, index);
927  writer.append(op.operation(), op.value());
928 }
929 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
930  Value result = op.value();
931  Optional<uint32_t> index = op.index();
932  writer.append(OpCode::GetResults,
933  index.getValueOr(std::numeric_limits<uint32_t>::max()),
934  op.operation());
935  if (result.getType().isa<pdl::RangeType>())
936  writer.append(getRangeStorageIndex(result));
937  else
939  writer.append(result);
940 }
941 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
942  Value operations = op.operations();
943  ByteCodeField rangeIndex = getRangeStorageIndex(operations);
944  writer.append(OpCode::GetUsers, operations, rangeIndex);
945  writer.appendPDLValue(op.value());
946 }
947 void Generator::generate(pdl_interp::GetValueTypeOp op,
948  ByteCodeWriter &writer) {
949  if (op.getType().isa<pdl::RangeType>()) {
950  Value result = op.result();
951  writer.append(OpCode::GetValueRangeTypes, result,
952  getRangeStorageIndex(result), op.value());
953  } else {
954  writer.append(OpCode::GetValueType, op.result(), op.value());
955  }
956 }
957 
958 void Generator::generate(pdl_interp::InferredTypesOp op,
959  ByteCodeWriter &writer) {
960  // InferType maps to a null type as a marker for inferring result types.
961  getMemIndex(op.type()) = getMemIndex(Type());
962 }
963 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
964  writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors());
965 }
966 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
967  ByteCodeField patternIndex = patterns.size();
968  patterns.emplace_back(PDLByteCodePattern::create(
969  op, rewriterToAddr[op.rewriter().getLeafReference().getValue()]));
970  writer.append(OpCode::RecordMatch, patternIndex,
971  SuccessorRange(op.getOperation()), op.matchedOps());
972  writer.appendPDLValueList(op.inputs());
973 }
974 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
975  writer.append(OpCode::ReplaceOp, op.operation());
976  writer.appendPDLValueList(op.replValues());
977 }
978 void Generator::generate(pdl_interp::SwitchAttributeOp op,
979  ByteCodeWriter &writer) {
980  writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(),
981  op.getSuccessors());
982 }
983 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
984  ByteCodeWriter &writer) {
985  writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(),
986  op.getSuccessors());
987 }
988 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
989  ByteCodeWriter &writer) {
990  auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) {
991  return OperationName(attr.cast<StringAttr>().getValue(), ctx);
992  });
993  writer.append(OpCode::SwitchOperationName, op.operation(), cases,
994  op.getSuccessors());
995 }
996 void Generator::generate(pdl_interp::SwitchResultCountOp op,
997  ByteCodeWriter &writer) {
998  writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(),
999  op.getSuccessors());
1000 }
1001 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1002  writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(),
1003  op.getSuccessors());
1004 }
1005 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1006  writer.append(OpCode::SwitchTypes, op.value(), op.caseValuesAttr(),
1007  op.getSuccessors());
1008 }
1009 
1010 //===----------------------------------------------------------------------===//
1011 // PDLByteCode
1012 //===----------------------------------------------------------------------===//
1013 
1015  llvm::StringMap<PDLConstraintFunction> constraintFns,
1016  llvm::StringMap<PDLRewriteFunction> rewriteFns) {
  // NOTE(review): the opening line of this definition (listing line 1014) is
  // missing from this view. `module`, used below, is presumably its first
  // parameter. Confirm against the original source.
  //
  // Lower the pdl_interp module into matcher and rewriter bytecode. The
  // Generator is also handed the max* counters and `patterns`, presumably so
  // it can record storage requirements and discovered patterns into them.
1017  Generator generator(module.getContext(), uniquedData, matcherByteCode,
1018  rewriterByteCode, patterns, maxValueMemoryIndex,
1019  maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1020  maxLoopLevel, constraintFns, rewriteFns);
1021  generator.generate(module);
1022 
  // The bytecode refers to native functions by position, as in
  // `constraintFunctions[read()]` in the executor.
  // NOTE(review): this relies on the StringMap iteration order below matching
  // the indices Generator assigned (constraintToMemIndex /
  // externalRewriterToMemIndex). Confirm.
1023  // Initialize the external functions.
1024  for (auto &it : constraintFns)
1025  constraintFunctions.push_back(std::move(it.second));
1026  for (auto &it : rewriteFns)
1027  rewriteFunctions.push_back(std::move(it.second));
1028 }
1029 
1030 /// Initialize the given state such that it can be used to execute the current
1031 /// bytecode.
/// Each memory pool in `state` is resized to the matching max* count for this
/// bytecode. Loop counters start at zero, and `currentPatternBenefits` is
/// seeded with each pattern's static benefit, in pattern order.
/// NOTE(review): the signature line (listing line 1032) is missing from this
/// view. Confirm against the original source.
1033  state.memory.resize(maxValueMemoryIndex, nullptr);
1034  state.opRangeMemory.resize(maxOpRangeCount);
1035  state.typeRangeMemory.resize(maxTypeRangeCount, TypeRange());
1036  state.valueRangeMemory.resize(maxValueRangeCount, ValueRange());
1037  state.loopIndex.resize(maxLoopLevel, 0);
1038  state.currentPatternBenefits.reserve(patterns.size());
1039  for (const PDLByteCodePattern &pattern : patterns)
1040  state.currentPatternBenefits.push_back(pattern.getBenefit());
1041 }
1042 
1043 //===----------------------------------------------------------------------===//
1044 // ByteCode Execution
1045 
1046 namespace {
1047 /// This class provides support for executing a bytecode stream.
///
/// The executor does not own its working storage:
///   - value memory, range memory, loop counters and uniqued constants arrive
///     as non-owning MutableArrayRef/ArrayRef views;
///   - allocated range storage arrives as vectors held by reference.
/// Decoding advances `curCodeIt` field by field through the stream.
1048 class ByteCodeExecutor {
1049 public:
  // NOTE(review): listing lines 1058 and 1060 are missing from this view. The
  // initializer list below references `code` and `patterns`, so those lines
  // presumably declare the matching constructor parameters. Confirm against
  // the original source.
1050  ByteCodeExecutor(
1051  const ByteCodeField *curCodeIt, MutableArrayRef<const void *> memory,
1052  MutableArrayRef<llvm::OwningArrayRef<Operation *>> opRangeMemory,
1053  MutableArrayRef<TypeRange> typeRangeMemory,
1054  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
1055  MutableArrayRef<ValueRange> valueRangeMemory,
1056  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
1057  MutableArrayRef<unsigned> loopIndex, ArrayRef<const void *> uniquedMemory,
1059  ArrayRef<PatternBenefit> currentPatternBenefits,
1061  ArrayRef<PDLConstraintFunction> constraintFunctions,
1062  ArrayRef<PDLRewriteFunction> rewriteFunctions)
1063  : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
1064  typeRangeMemory(typeRangeMemory),
1065  allocatedTypeRangeMemory(allocatedTypeRangeMemory),
1066  valueRangeMemory(valueRangeMemory),
1067  allocatedValueRangeMemory(allocatedValueRangeMemory),
1068  loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
1069  currentPatternBenefits(currentPatternBenefits), patterns(patterns),
1070  constraintFunctions(constraintFunctions),
1071  rewriteFunctions(rewriteFunctions) {}
1072 
1073  /// Start executing the code at the current bytecode index. `matches` is an
1074  /// optional field provided when this function is executed in a matching
1075  /// context.
1076  void execute(PatternRewriter &rewriter,
1077  SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
1078  Optional<Location> mainRewriteLoc = {});
1079 
1080 private:
1081  /// Internal implementation of executing each of the bytecode commands.
1082  void executeApplyConstraint(PatternRewriter &rewriter);
1083  void executeApplyRewrite(PatternRewriter &rewriter);
1084  void executeAreEqual();
1085  void executeAreRangesEqual();
1086  void executeBranch();
1087  void executeCheckOperandCount();
1088  void executeCheckOperationName();
1089  void executeCheckResultCount();
1090  void executeCheckTypes();
1091  void executeContinue();
1092  void executeCreateOperation(PatternRewriter &rewriter,
1093  Location mainRewriteLoc);
1094  void executeCreateTypes();
1095  void executeEraseOp(PatternRewriter &rewriter);
1096  template <typename T, typename Range, PDLValue::Kind kind>
1097  void executeExtract();
1098  void executeFinalize();
1099  void executeForEach();
1100  void executeGetAttribute();
1101  void executeGetAttributeType();
1102  void executeGetDefiningOp();
1103  void executeGetOperand(unsigned index);
1104  void executeGetOperands();
1105  void executeGetResult(unsigned index);
1106  void executeGetResults();
1107  void executeGetUsers();
1108  void executeGetValueType();
1109  void executeGetValueRangeTypes();
1110  void executeIsNotNull();
  // NOTE(review): the remainder of the next declaration (listing line 1112) is
  // missing from this view. Confirm against the original source.
1111  void executeRecordMatch(PatternRewriter &rewriter,
1113  void executeReplaceOp(PatternRewriter &rewriter);
1114  void executeSwitchAttribute();
1115  void executeSwitchOperandCount();
1116  void executeSwitchOperationName();
1117  void executeSwitchResultCount();
1118  void executeSwitchType();
1119  void executeSwitchTypes();
1120 
1121  /// Pushes a code iterator to the stack.
1122  void pushCodeIt(const ByteCodeField *it) { resumeCodeIt.push_back(it); }
1123 
1124  /// Pops a code iterator from the stack and makes it the current position.
  /// The stack must be non-empty (asserted in debug builds).
1125  void popCodeIt() {
1126  assert(!resumeCodeIt.empty() && "attempt to pop code off empty stack");
1127  curCodeIt = resumeCodeIt.back();
1128  resumeCodeIt.pop_back();
1129  }
1130 
1131  /// Return the bytecode iterator at the start of the current op code.
  /// When LLVM_DEBUG is active, a Location pointer is stored inline after each
  /// op code, so the start lies further back.
1132  const ByteCodeField *getPrevCodeIt() const {
1133  LLVM_DEBUG({
1134  // Account for the op code and the Location stored inline.
1135  return curCodeIt - 1 - sizeof(const void *) / sizeof(ByteCodeField);
1136  });
1137 
1138  // Account for the op code only.
1139  return curCodeIt - 1;
1140  }
1141 
1142  /// Read a value from the bytecode buffer, optionally skipping a certain
1143  /// number of prefix values. These methods always update the buffer to point
1144  /// to the next field after the read data.
1145  template <typename T = ByteCodeField>
1146  T read(size_t skipN = 0) {
1147  curCodeIt += skipN;
1148  return readImpl<T>();
1149  }
1150  ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
1151 
1152  /// Read a list of values from the bytecode buffer.
  /// The list is cleared first and then filled with a count-prefixed sequence.
1153  template <typename ValueT, typename T>
1154  void readList(SmallVectorImpl<T> &list) {
1155  list.clear();
1156  for (unsigned i = 0, e = read(); i != e; ++i)
1157  list.push_back(read<ValueT>());
1158  }
1159 
1160  /// Read a list of values from the bytecode buffer. The values may be encoded
1161  /// as either Value or ValueRange elements.
  /// Unlike readList, this appends to `list` without clearing it first. Each
  /// element is prefixed by its PDLValue kind, and ranges are flattened.
1162  void readValueList(SmallVectorImpl<Value> &list) {
1163  for (unsigned i = 0, e = read(); i != e; ++i) {
1164  if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1165  list.push_back(read<Value>());
1166  } else {
1167  ValueRange *values = read<ValueRange *>();
1168  list.append(values->begin(), values->end());
1169  }
1170  }
1171  }
1172 
1173  /// Read a value stored inline as a pointer.
  /// NOTE(review): the return-type line of this template (listing line 1175)
  /// is missing from this view. Confirm against the original source.
1174  template <typename T>
1176  readInline() {
1177  const void *pointer;
1178  std::memcpy(&pointer, curCodeIt, sizeof(const void *));
1179  curCodeIt += sizeof(const void *) / sizeof(ByteCodeField);
1180  return T::getFromOpaquePointer(pointer);
1181  }
1182 
1183  /// Jump to a specific successor based on a predicate value.
  /// `true` selects successor 0 and `false` selects successor 1.
1184  void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
1185  /// Jump to a specific successor based on a destination index.
  /// Each successor address occupies two ByteCodeFields, hence the `* 2` skip.
1186  void selectJump(size_t destIndex) {
1187  curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
1188  }
1189 
1190  /// Handle a switch operation with the provided value and cases.
  /// A match on case `i` jumps to successor `i + 1`. If nothing matches, the
  /// jump goes to successor 0 (the default).
1191  template <typename T, typename RangeT, typename Comparator = std::equal_to<T>>
1192  void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) {
1193  LLVM_DEBUG({
1194  llvm::dbgs() << " * Value: " << value << "\n"
1195  << " * Cases: ";
1196  llvm::interleaveComma(cases, llvm::dbgs());
1197  llvm::dbgs() << "\n";
1198  });
1199 
1200  // Check to see if the attribute value is within the case list. Jump to
1201  // the correct successor index based on the result.
1202  for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
1203  if (cmp(*it, value))
1204  return selectJump(size_t((it - cases.begin()) + 1));
1205  selectJump(size_t(0));
1206  }
1207 
1208  /// Store a pointer to memory.
1209  void storeToMemory(unsigned index, const void *value) {
1210  memory[index] = value;
1211  }
1212 
1213  /// Store a value to memory as an opaque pointer.
  /// NOTE(review): the SFINAE return-type line of this template (listing line
  /// 1215) is missing from this view. Confirm against the original source.
1214  template <typename T>
1216  storeToMemory(unsigned index, T value) {
1217  memory[index] = value.getAsOpaquePointer();
1218  }
1219 
1220  /// Internal implementation of reading various data types from the bytecode
1221  /// stream.
  /// readFromMemory decodes a memory index:
  ///   - indices below `memory.size()` address per-execution memory;
  ///   - larger indices address `uniquedMemory` at `index - memory.size()`;
  ///   - Operation*, TypeRange*, ValueRange* and Value always come from
  ///     `memory`.
1222  template <typename T>
1223  const void *readFromMemory() {
1224  size_t index = *curCodeIt++;
1225 
1226  // If this type is an SSA value, it can only be stored in non-const memory.
1227  if (llvm::is_one_of<T, Operation *, TypeRange *, ValueRange *,
1228  Value>::value ||
1229  index < memory.size())
1230  return memory[index];
1231 
1232  // Otherwise, if this index is not inbounds it is uniqued.
1233  return uniquedMemory[index - memory.size()];
1234  }
  // NOTE(review): several readImpl declaration lines are missing from this
  // view: the overload headers (listing lines 1236, 1240, 1246, 1264, 1273,
  // 1277) and the PDLValue case labels (1248, 1250, 1256, 1258). Each overload
  // is presumably SFINAE-selected on T. Confirm against the original source.
1235  template <typename T>
1237  return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
1238  }
1239  template <typename T>
1241  T>
1242  readImpl() {
1243  return T(T::getFromOpaquePointer(readFromMemory<T>()));
1244  }
1245  template <typename T>
1247  switch (read<PDLValue::Kind>()) {
1249  return read<Attribute>();
1251  return read<Operation *>();
1252  case PDLValue::Kind::Type:
1253  return read<Type>();
1254  case PDLValue::Kind::Value:
1255  return read<Value>();
1257  return read<TypeRange *>();
1259  return read<ValueRange *>();
1260  }
1261  llvm_unreachable("unhandled PDLValue::Kind");
1262  }
1263  template <typename T>
1265  static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
1266  "unexpected ByteCode address size");
1267  ByteCodeAddr result;
1268  std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
1269  curCodeIt += 2;
1270  return result;
1271  }
1272  template <typename T>
1274  return *curCodeIt++;
1275  }
1276  template <typename T>
1278  return static_cast<PDLValue::Kind>(readImpl<ByteCodeField>());
1279  }
1280 
1281  /// The current position within the underlying bytecode buffer.
1282  const ByteCodeField *curCodeIt;
1283 
1284  /// The stack of bytecode positions at which to resume operation.
  // NOTE(review): the declaration of `resumeCodeIt` (listing line 1285) is
  // missing from this view. Confirm against the original source.
1286 
1287  /// The current execution memory.
  // NOTE(review): the declaration of `memory` (listing line 1288) is missing
  // from this view. Confirm against the original source.
1289  MutableArrayRef<OwningOpRange> opRangeMemory;
1290  MutableArrayRef<TypeRange> typeRangeMemory;
1291  std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
1292  MutableArrayRef<ValueRange> valueRangeMemory;
1293  std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
1294 
1295  /// The current loop indices.
1296  MutableArrayRef<unsigned> loopIndex;
1297 
1298  /// References to ByteCode data necessary for execution.
  // NOTE(review): the declarations of `code` and `patterns` (listing lines
  // 1300 and 1302) are missing from this view. Confirm against the original.
1299  ArrayRef<const void *> uniquedMemory;
1301  ArrayRef<PatternBenefit> currentPatternBenefits;
1303  ArrayRef<PDLConstraintFunction> constraintFunctions;
1304  ArrayRef<PDLRewriteFunction> rewriteFunctions;
1305 };
1306 
1307 /// This class is an instantiation of the PDLResultList that provides access to
1308 /// the returned results. This API is not on `PDLResultList` to avoid
1309 /// overexposing access to information specific solely to the ByteCode.
1310 class ByteCodeRewriteResultList : public PDLResultList {
1311 public:
1312  ByteCodeRewriteResultList(unsigned maxNumResults)
1313  : PDLResultList(maxNumResults) {}
1314 
1315  /// Return the list of PDL results.
1316  MutableArrayRef<PDLValue> getResults() { return results; }
1317 
1318  /// Return the type ranges allocated by this list.
1319  MutableArrayRef<llvm::OwningArrayRef<Type>> getAllocatedTypeRanges() {
1320  return allocatedTypeRanges;
1321  }
1322 
1323  /// Return the value ranges allocated by this list.
1324  MutableArrayRef<llvm::OwningArrayRef<Value>> getAllocatedValueRanges() {
1325  return allocatedValueRanges;
1326  }
1327 };
1328 } // namespace
1329 
1330 void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) {
1331  LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
1332  const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
1333  ArrayAttr constParams = read<ArrayAttr>();
1335  readList<PDLValue>(args);
1336 
1337  LLVM_DEBUG({
1338  llvm::dbgs() << " * Arguments: ";
1339  llvm::interleaveComma(args, llvm::dbgs());
1340  llvm::dbgs() << "\n * Parameters: " << constParams << "\n";
1341  });
1342 
1343  // Invoke the constraint and jump to the proper destination.
1344  selectJump(succeeded(constraintFn(args, constParams, rewriter)));
1345 }
1346 
1347 void ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) {
1348  LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
1349  const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1350  ArrayAttr constParams = read<ArrayAttr>();
1352  readList<PDLValue>(args);
1353 
1354  LLVM_DEBUG({
1355  llvm::dbgs() << " * Arguments: ";
1356  llvm::interleaveComma(args, llvm::dbgs());
1357  llvm::dbgs() << "\n * Parameters: " << constParams << "\n";
1358  });
1359 
1360  // Execute the rewrite function.
1361  ByteCodeField numResults = read();
1362  ByteCodeRewriteResultList results(numResults);
1363  rewriteFn(args, constParams, rewriter, results);
1364 
1365  assert(results.getResults().size() == numResults &&
1366  "native PDL rewrite function returned unexpected number of results");
1367 
1368  // Store the results in the bytecode memory.
1369  for (PDLValue &result : results.getResults()) {
1370  LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n");
1371 
1372 // In debug mode we also verify the expected kind of the result.
1373 #ifndef NDEBUG
1374  assert(result.getKind() == read<PDLValue::Kind>() &&
1375  "native PDL rewrite function returned an unexpected type of result");
1376 #endif
1377 
1378  // If the result is a range, we need to copy it over to the bytecodes
1379  // range memory.
1380  if (Optional<TypeRange> typeRange = result.dyn_cast<TypeRange>()) {
1381  unsigned rangeIndex = read();
1382  typeRangeMemory[rangeIndex] = *typeRange;
1383  memory[read()] = &typeRangeMemory[rangeIndex];
1384  } else if (Optional<ValueRange> valueRange =
1385  result.dyn_cast<ValueRange>()) {
1386  unsigned rangeIndex = read();
1387  valueRangeMemory[rangeIndex] = *valueRange;
1388  memory[read()] = &valueRangeMemory[rangeIndex];
1389  } else {
1390  memory[read()] = result.getAsOpaquePointer();
1391  }
1392  }
1393 
1394  // Copy over any underlying storage allocated for result ranges.
1395  for (auto &it : results.getAllocatedTypeRanges())
1396  allocatedTypeRangeMemory.push_back(std::move(it));
1397  for (auto &it : results.getAllocatedValueRanges())
1398  allocatedValueRangeMemory.push_back(std::move(it));
1399 }
1400 
1401 void ByteCodeExecutor::executeAreEqual() {
1402  LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1403  const void *lhs = read<const void *>();
1404  const void *rhs = read<const void *>();
1405 
1406  LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n");
1407  selectJump(lhs == rhs);
1408 }
1409 
1410 void ByteCodeExecutor::executeAreRangesEqual() {
1411  LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n");
1412  PDLValue::Kind valueKind = read<PDLValue::Kind>();
1413  const void *lhs = read<const void *>();
1414  const void *rhs = read<const void *>();
1415 
1416  switch (valueKind) {
1418  const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs);
1419  const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs);
1420  LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1421  selectJump(*lhsRange == *rhsRange);
1422  break;
1423  }
1425  const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs);
1426  const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs);
1427  LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1428  selectJump(*lhsRange == *rhsRange);
1429  break;
1430  }
1431  default:
1432  llvm_unreachable("unexpected `AreRangesEqual` value kind");
1433  }
1434 }
1435 
1436 void ByteCodeExecutor::executeBranch() {
1437  LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n");
1438  curCodeIt = &code[read<ByteCodeAddr>()];
1439 }
1440 
1441 void ByteCodeExecutor::executeCheckOperandCount() {
1442  LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
1443  Operation *op = read<Operation *>();
1444  uint32_t expectedCount = read<uint32_t>();
1445  bool compareAtLeast = read();
1446 
1447  LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n"
1448  << " * Expected: " << expectedCount << "\n"
1449  << " * Comparator: "
1450  << (compareAtLeast ? ">=" : "==") << "\n");
1451  if (compareAtLeast)
1452  selectJump(op->getNumOperands() >= expectedCount);
1453  else
1454  selectJump(op->getNumOperands() == expectedCount);
1455 }
1456 
1457 void ByteCodeExecutor::executeCheckOperationName() {
1458  LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
1459  Operation *op = read<Operation *>();
1460  OperationName expectedName = read<OperationName>();
1461 
1462  LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n"
1463  << " * Expected: \"" << expectedName << "\"\n");
1464  selectJump(op->getName() == expectedName);
1465 }
1466 
1467 void ByteCodeExecutor::executeCheckResultCount() {
1468  LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
1469  Operation *op = read<Operation *>();
1470  uint32_t expectedCount = read<uint32_t>();
1471  bool compareAtLeast = read();
1472 
1473  LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n"
1474  << " * Expected: " << expectedCount << "\n"
1475  << " * Comparator: "
1476  << (compareAtLeast ? ">=" : "==") << "\n");
1477  if (compareAtLeast)
1478  selectJump(op->getNumResults() >= expectedCount);
1479  else
1480  selectJump(op->getNumResults() == expectedCount);
1481 }
1482 
1483 void ByteCodeExecutor::executeCheckTypes() {
1484  LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
1485  TypeRange *lhs = read<TypeRange *>();
1486  Attribute rhs = read<Attribute>();
1487  LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n");
1488 
1489  selectJump(*lhs == rhs.cast<ArrayAttr>().getAsValueRange<TypeAttr>());
1490 }
1491 
1492 void ByteCodeExecutor::executeContinue() {
1493  ByteCodeField level = read();
1494  LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n"
1495  << " * Level: " << level << "\n");
1496  ++loopIndex[level];
1497  popCodeIt();
1498 }
1499 
1500 void ByteCodeExecutor::executeCreateTypes() {
1501  LLVM_DEBUG(llvm::dbgs() << "Executing CreateTypes:\n");
1502  unsigned memIndex = read();
1503  unsigned rangeIndex = read();
1504  ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
1505 
1506  LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n");
1507 
1508  // Allocate a buffer for this type range.
1509  llvm::OwningArrayRef<Type> storage(typesAttr.size());
1510  llvm::copy(typesAttr.getAsValueRange<TypeAttr>(), storage.begin());
1511  allocatedTypeRangeMemory.emplace_back(std::move(storage));
1512 
1513  // Assign this to the range slot and use the range as the value for the
1514  // memory index.
1515  typeRangeMemory[rangeIndex] = allocatedTypeRangeMemory.back();
1516  memory[memIndex] = &typeRangeMemory[rangeIndex];
1517 }
1518 
1519 void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter,
1520  Location mainRewriteLoc) {
1521  LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
1522 
1523  unsigned memIndex = read();
1524  OperationState state(mainRewriteLoc, read<OperationName>());
1525  readValueList(state.operands);
1526  for (unsigned i = 0, e = read(); i != e; ++i) {
1527  StringAttr name = read<StringAttr>();
1528  if (Attribute attr = read<Attribute>())
1529  state.addAttribute(name, attr);
1530  }
1531 
1532  for (unsigned i = 0, e = read(); i != e; ++i) {
1533  if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1534  state.types.push_back(read<Type>());
1535  continue;
1536  }
1537 
1538  // If we find a null range, this signals that the types are infered.
1539  if (TypeRange *resultTypes = read<TypeRange *>()) {
1540  state.types.append(resultTypes->begin(), resultTypes->end());
1541  continue;
1542  }
1543 
1544  // Handle the case where the operation has inferred types.
1545  InferTypeOpInterface::Concept *concept =
1546  state.name.getRegisteredInfo()->getInterface<InferTypeOpInterface>();
1547 
1548  // TODO: Handle failure.
1549  state.types.clear();
1550  if (failed(concept->inferReturnTypes(
1551  state.getContext(), state.location, state.operands,
1552  state.attributes.getDictionary(state.getContext()), state.regions,
1553  state.types)))
1554  return;
1555  break;
1556  }
1557 
1558  Operation *resultOp = rewriter.createOperation(state);
1559  memory[memIndex] = resultOp;
1560 
1561  LLVM_DEBUG({
1562  llvm::dbgs() << " * Attributes: "
1563  << state.attributes.getDictionary(state.getContext())
1564  << "\n * Operands: ";
1565  llvm::interleaveComma(state.operands, llvm::dbgs());
1566  llvm::dbgs() << "\n * Result Types: ";
1567  llvm::interleaveComma(state.types, llvm::dbgs());
1568  llvm::dbgs() << "\n * Result: " << *resultOp << "\n";
1569  });
1570 }
1571 
1572 void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) {
1573  LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1574  Operation *op = read<Operation *>();
1575 
1576  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
1577  rewriter.eraseOp(op);
1578 }
1579 
1580 template <typename T, typename Range, PDLValue::Kind kind>
1581 void ByteCodeExecutor::executeExtract() {
1582  LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n");
1583  Range *range = read<Range *>();
1584  unsigned index = read<uint32_t>();
1585  unsigned memIndex = read();
1586 
1587  if (!range) {
1588  memory[memIndex] = nullptr;
1589  return;
1590  }
1591 
1592  T result = index < range->size() ? (*range)[index] : T();
1593  LLVM_DEBUG(llvm::dbgs() << " * " << kind << "s(" << range->size() << ")\n"
1594  << " * Index: " << index << "\n"
1595  << " * Result: " << result << "\n");
1596  storeToMemory(memIndex, result);
1597 }
1598 
1599 void ByteCodeExecutor::executeFinalize() {
1600  LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n");
1601 }
1602 
1603 void ByteCodeExecutor::executeForEach() {
1604  LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n");
1605  const ByteCodeField *prevCodeIt = getPrevCodeIt();
1606  unsigned rangeIndex = read();
1607  unsigned memIndex = read();
1608  const void *value = nullptr;
1609 
1610  switch (read<PDLValue::Kind>()) {
1612  unsigned &index = loopIndex[read()];
1613  ArrayRef<Operation *> array = opRangeMemory[rangeIndex];
1614  assert(index <= array.size() && "iterated past the end");
1615  if (index < array.size()) {
1616  LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n");
1617  value = array[index];
1618  break;
1619  }
1620 
1621  LLVM_DEBUG(llvm::dbgs() << " * Done\n");
1622  index = 0;
1623  selectJump(size_t(0));
1624  return;
1625  }
1626  default:
1627  llvm_unreachable("unexpected `ForEach` value kind");
1628  }
1629 
1630  // Store the iterate value and the stack address.
1631  memory[memIndex] = value;
1632  pushCodeIt(prevCodeIt);
1633 
1634  // Skip over the successor (we will enter the body of the loop).
1635  read<ByteCodeAddr>();
1636 }
1637 
1638 void ByteCodeExecutor::executeGetAttribute() {
1639  LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1640  unsigned memIndex = read();
1641  Operation *op = read<Operation *>();
1642  StringAttr attrName = read<StringAttr>();
1643  Attribute attr = op->getAttr(attrName);
1644 
1645  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1646  << " * Attribute: " << attrName << "\n"
1647  << " * Result: " << attr << "\n");
1648  memory[memIndex] = attr.getAsOpaquePointer();
1649 }
1650 
1651 void ByteCodeExecutor::executeGetAttributeType() {
1652  LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1653  unsigned memIndex = read();
1654  Attribute attr = read<Attribute>();
1655  Type type = attr ? attr.getType() : Type();
1656 
1657  LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n"
1658  << " * Result: " << type << "\n");
1659  memory[memIndex] = type.getAsOpaquePointer();
1660 }
1661 
1662 void ByteCodeExecutor::executeGetDefiningOp() {
1663  LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1664  unsigned memIndex = read();
1665  Operation *op = nullptr;
1666  if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1667  Value value = read<Value>();
1668  if (value)
1669  op = value.getDefiningOp();
1670  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
1671  } else {
1672  ValueRange *values = read<ValueRange *>();
1673  if (values && !values->empty()) {
1674  op = values->front().getDefiningOp();
1675  }
1676  LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n");
1677  }
1678 
1679  LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n");
1680  memory[memIndex] = op;
1681 }
1682 
1683 void ByteCodeExecutor::executeGetOperand(unsigned index) {
1684  Operation *op = read<Operation *>();
1685  unsigned memIndex = read();
1686  Value operand =
1687  index < op->getNumOperands() ? op->getOperand(index) : Value();
1688 
1689  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1690  << " * Index: " << index << "\n"
1691  << " * Result: " << operand << "\n");
1692  memory[memIndex] = operand.getAsOpaquePointer();
1693 }
1694 
1695 /// This function is the internal implementation of `GetResults` and
1696 /// `GetOperands` that provides support for extracting a value range from the
1697 /// given operation.
1698 template <template <typename> class AttrSizedSegmentsT, typename RangeT>
1699 static void *
1700 executeGetOperandsResults(RangeT values, Operation *op, unsigned index,
1701  ByteCodeField rangeIndex, StringRef attrSizedSegments,
1702  MutableArrayRef<ValueRange> valueRangeMemory) {
1703  // Check for the sentinel index that signals that all values should be
1704  // returned.
1705  if (index == std::numeric_limits<uint32_t>::max()) {
1706  LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n");
1707  // `values` is already the full value range.
1708 
1709  // Otherwise, check to see if this operation uses AttrSizedSegments.
1710  } else if (op->hasTrait<AttrSizedSegmentsT>()) {
1711  LLVM_DEBUG(llvm::dbgs()
1712  << " * Extracting values from `" << attrSizedSegments << "`\n");
1713 
1714  auto segmentAttr = op->getAttrOfType<DenseElementsAttr>(attrSizedSegments);
1715  if (!segmentAttr || segmentAttr.getNumElements() <= index)
1716  return nullptr;
1717 
1718  auto segments = segmentAttr.getValues<int32_t>();
1719  unsigned startIndex =
1720  std::accumulate(segments.begin(), segments.begin() + index, 0);
1721  values = values.slice(startIndex, *std::next(segments.begin(), index));
1722 
1723  LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", "
1724  << *std::next(segments.begin(), index) << "]\n");
1725 
1726  // Otherwise, assume this is the last operand group of the operation.
1727  // FIXME: We currently don't support operations with
1728  // SameVariadicOperandSize/SameVariadicResultSize here given that we don't
1729  // have a way to detect it's presence.
1730  } else if (values.size() >= index) {
1731  LLVM_DEBUG(llvm::dbgs()
1732  << " * Treating values as trailing variadic range\n");
1733  values = values.drop_front(index);
1734 
1735  // If we couldn't detect a way to compute the values, bail out.
1736  } else {
1737  return nullptr;
1738  }
1739 
1740  // If the range index is valid, we are returning a range.
1741  if (rangeIndex != std::numeric_limits<ByteCodeField>::max()) {
1742  valueRangeMemory[rangeIndex] = values;
1743  return &valueRangeMemory[rangeIndex];
1744  }
1745 
1746  // If a range index wasn't provided, the range is required to be non-variadic.
1747  return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1748 }
1749 
1750 void ByteCodeExecutor::executeGetOperands() {
1751  LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n");
1752  unsigned index = read<uint32_t>();
1753  Operation *op = read<Operation *>();
1754  ByteCodeField rangeIndex = read();
1755 
1756  void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1757  op->getOperands(), op, index, rangeIndex, "operand_segment_sizes",
1758  valueRangeMemory);
1759  if (!result)
1760  LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n");
1761  memory[read()] = result;
1762 }
1763 
1764 void ByteCodeExecutor::executeGetResult(unsigned index) {
1765  Operation *op = read<Operation *>();
1766  unsigned memIndex = read();
1767  OpResult result =
1768  index < op->getNumResults() ? op->getResult(index) : OpResult();
1769 
1770  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"
1771  << " * Index: " << index << "\n"
1772  << " * Result: " << result << "\n");
1773  memory[memIndex] = result.getAsOpaquePointer();
1774 }
1775 
1776 void ByteCodeExecutor::executeGetResults() {
1777  LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n");
1778  unsigned index = read<uint32_t>();
1779  Operation *op = read<Operation *>();
1780  ByteCodeField rangeIndex = read();
1781 
1782  void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1783  op->getResults(), op, index, rangeIndex, "result_segment_sizes",
1784  valueRangeMemory);
1785  if (!result)
1786  LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n");
1787  memory[read()] = result;
1788 }
1789 
1790 void ByteCodeExecutor::executeGetUsers() {
1791  LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n");
1792  unsigned memIndex = read();
1793  unsigned rangeIndex = read();
1794  OwningOpRange &range = opRangeMemory[rangeIndex];
1795  memory[memIndex] = &range;
1796 
1797  range = OwningOpRange();
1798  if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1799  // Read the value.
1800  Value value = read<Value>();
1801  if (!value)
1802  return;
1803  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
1804 
1805  // Extract the users of a single value.
1806  range = OwningOpRange(std::distance(value.user_begin(), value.user_end()));
1807  llvm::copy(value.getUsers(), range.begin());
1808  } else {
1809  // Read a range of values.
1810  ValueRange *values = read<ValueRange *>();
1811  if (!values)
1812  return;
1813  LLVM_DEBUG({
1814  llvm::dbgs() << " * Values (" << values->size() << "): ";
1815  llvm::interleaveComma(*values, llvm::dbgs());
1816  llvm::dbgs() << "\n";
1817  });
1818 
1819  // Extract all the users of a range of values.
1821  for (Value value : *values)
1822  users.append(value.user_begin(), value.user_end());
1823  range = OwningOpRange(users.size());
1824  llvm::copy(users, range.begin());
1825  }
1826 
1827  LLVM_DEBUG(llvm::dbgs() << " * Result: " << range.size() << " operations\n");
1828 }
1829 
1830 void ByteCodeExecutor::executeGetValueType() {
1831  LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1832  unsigned memIndex = read();
1833  Value value = read<Value>();
1834  Type type = value ? value.getType() : Type();
1835 
1836  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"
1837  << " * Result: " << type << "\n");
1838  memory[memIndex] = type.getAsOpaquePointer();
1839 }
1840 
1841 void ByteCodeExecutor::executeGetValueRangeTypes() {
1842  LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n");
1843  unsigned memIndex = read();
1844  unsigned rangeIndex = read();
1845  ValueRange *values = read<ValueRange *>();
1846  if (!values) {
1847  LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n");
1848  memory[memIndex] = nullptr;
1849  return;
1850  }
1851 
1852  LLVM_DEBUG({
1853  llvm::dbgs() << " * Values (" << values->size() << "): ";
1854  llvm::interleaveComma(*values, llvm::dbgs());
1855  llvm::dbgs() << "\n * Result: ";
1856  llvm::interleaveComma(values->getType(), llvm::dbgs());
1857  llvm::dbgs() << "\n";
1858  });
1859  typeRangeMemory[rangeIndex] = values->getType();
1860  memory[memIndex] = &typeRangeMemory[rangeIndex];
1861 }
1862 
1863 void ByteCodeExecutor::executeIsNotNull() {
1864  LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1865  const void *value = read<const void *>();
1866 
1867  LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n");
1868  selectJump(value != nullptr);
1869 }
1870 
1871 void ByteCodeExecutor::executeRecordMatch(
1872  PatternRewriter &rewriter,
1874  LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
1875  unsigned patternIndex = read();
1876  PatternBenefit benefit = currentPatternBenefits[patternIndex];
1877  const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1878 
1879  // If the benefit of the pattern is impossible, skip the processing of the
1880  // rest of the pattern.
1881  if (benefit.isImpossibleToMatch()) {
1882  LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n");
1883  curCodeIt = dest;
1884  return;
1885  }
1886 
1887  // Create a fused location containing the locations of each of the
1888  // operations used in the match. This will be used as the location for
1889  // created operations during the rewrite that don't already have an
1890  // explicit location set.
1891  unsigned numMatchLocs = read();
1892  SmallVector<Location, 4> matchLocs;
1893  matchLocs.reserve(numMatchLocs);
1894  for (unsigned i = 0; i != numMatchLocs; ++i)
1895  matchLocs.push_back(read<Operation *>()->getLoc());
1896  Location matchLoc = rewriter.getFusedLoc(matchLocs);
1897 
1898  LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n"
1899  << " * Location: " << matchLoc << "\n");
1900  matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
1901  PDLByteCode::MatchResult &match = matches.back();
1902 
1903  // Record all of the inputs to the match. If any of the inputs are ranges, we
1904  // will also need to remap the range pointer to memory stored in the match
1905  // state.
1906  unsigned numInputs = read();
1907  match.values.reserve(numInputs);
1908  match.typeRangeValues.reserve(numInputs);
1909  match.valueRangeValues.reserve(numInputs);
1910  for (unsigned i = 0; i < numInputs; ++i) {
1911  switch (read<PDLValue::Kind>()) {
1913  match.typeRangeValues.push_back(*read<TypeRange *>());
1914  match.values.push_back(&match.typeRangeValues.back());
1915  break;
1917  match.valueRangeValues.push_back(*read<ValueRange *>());
1918  match.values.push_back(&match.valueRangeValues.back());
1919  break;
1920  default:
1921  match.values.push_back(read<const void *>());
1922  break;
1923  }
1924  }
1925  curCodeIt = dest;
1926 }
1927 
1928 void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) {
1929  LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
1930  Operation *op = read<Operation *>();
1932  readValueList(args);
1933 
1934  LLVM_DEBUG({
1935  llvm::dbgs() << " * Operation: " << *op << "\n"
1936  << " * Values: ";
1937  llvm::interleaveComma(args, llvm::dbgs());
1938  llvm::dbgs() << "\n";
1939  });
1940  rewriter.replaceOp(op, args);
1941 }
1942 
1943 void ByteCodeExecutor::executeSwitchAttribute() {
1944  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
1945  Attribute value = read<Attribute>();
1946  ArrayAttr cases = read<ArrayAttr>();
1947  handleSwitch(value, cases);
1948 }
1949 
1950 void ByteCodeExecutor::executeSwitchOperandCount() {
1951  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
1952  Operation *op = read<Operation *>();
1953  auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1954 
1955  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
1956  handleSwitch(op->getNumOperands(), cases);
1957 }
1958 
1959 void ByteCodeExecutor::executeSwitchOperationName() {
1960  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
1961  OperationName value = read<Operation *>()->getName();
1962  size_t caseCount = read();
1963 
1964  // The operation names are stored in-line, so to print them out for
1965  // debugging purposes we need to read the array before executing the
1966  // switch so that we can display all of the possible values.
1967  LLVM_DEBUG({
1968  const ByteCodeField *prevCodeIt = curCodeIt;
1969  llvm::dbgs() << " * Value: " << value << "\n"
1970  << " * Cases: ";
1971  llvm::interleaveComma(
1972  llvm::map_range(llvm::seq<size_t>(0, caseCount),
1973  [&](size_t) { return read<OperationName>(); }),
1974  llvm::dbgs());
1975  llvm::dbgs() << "\n";
1976  curCodeIt = prevCodeIt;
1977  });
1978 
1979  // Try to find the switch value within any of the cases.
1980  for (size_t i = 0; i != caseCount; ++i) {
1981  if (read<OperationName>() == value) {
1982  curCodeIt += (caseCount - i - 1);
1983  return selectJump(i + 1);
1984  }
1985  }
1986  selectJump(size_t(0));
1987 }
1988 
1989 void ByteCodeExecutor::executeSwitchResultCount() {
1990  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
1991  Operation *op = read<Operation *>();
1992  auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1993 
1994  LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n");
1995  handleSwitch(op->getNumResults(), cases);
1996 }
1997 
1998 void ByteCodeExecutor::executeSwitchType() {
1999  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
2000  Type value = read<Type>();
2001  auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2002  handleSwitch(value, cases);
2003 }
2004 
2005 void ByteCodeExecutor::executeSwitchTypes() {
2006  LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n");
2007  TypeRange *value = read<TypeRange *>();
2008  auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2009  if (!value) {
2010  LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n");
2011  return selectJump(size_t(0));
2012  }
2013  handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) {
2014  return value == caseValue.getAsValueRange<TypeAttr>();
2015  });
2016 }
2017 
2018 void ByteCodeExecutor::execute(
2019  PatternRewriter &rewriter,
2021  Optional<Location> mainRewriteLoc) {
2022  while (true) {
2023  // Print the location of the operation being executed.
2024  LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n");
2025 
2026  OpCode opCode = static_cast<OpCode>(read());
2027  switch (opCode) {
2028  case ApplyConstraint:
2029  executeApplyConstraint(rewriter);
2030  break;
2031  case ApplyRewrite:
2032  executeApplyRewrite(rewriter);
2033  break;
2034  case AreEqual:
2035  executeAreEqual();
2036  break;
2037  case AreRangesEqual:
2038  executeAreRangesEqual();
2039  break;
2040  case Branch:
2041  executeBranch();
2042  break;
2043  case CheckOperandCount:
2044  executeCheckOperandCount();
2045  break;
2046  case CheckOperationName:
2047  executeCheckOperationName();
2048  break;
2049  case CheckResultCount:
2050  executeCheckResultCount();
2051  break;
2052  case CheckTypes:
2053  executeCheckTypes();
2054  break;
2055  case Continue:
2056  executeContinue();
2057  break;
2058  case CreateOperation:
2059  executeCreateOperation(rewriter, *mainRewriteLoc);
2060  break;
2061  case CreateTypes:
2062  executeCreateTypes();
2063  break;
2064  case EraseOp:
2065  executeEraseOp(rewriter);
2066  break;
2067  case ExtractOp:
2068  executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
2069  break;
2070  case ExtractType:
2071  executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2072  break;
2073  case ExtractValue:
2074  executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2075  break;
2076  case Finalize:
2077  executeFinalize();
2078  LLVM_DEBUG(llvm::dbgs() << "\n");
2079  return;
2080  case ForEach:
2081  executeForEach();
2082  break;
2083  case GetAttribute:
2084  executeGetAttribute();
2085  break;
2086  case GetAttributeType:
2087  executeGetAttributeType();
2088  break;
2089  case GetDefiningOp:
2090  executeGetDefiningOp();
2091  break;
2092  case GetOperand0:
2093  case GetOperand1:
2094  case GetOperand2:
2095  case GetOperand3: {
2096  unsigned index = opCode - GetOperand0;
2097  LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n");
2098  executeGetOperand(index);
2099  break;
2100  }
2101  case GetOperandN:
2102  LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n");
2103  executeGetOperand(read<uint32_t>());
2104  break;
2105  case GetOperands:
2106  executeGetOperands();
2107  break;
2108  case GetResult0:
2109  case GetResult1:
2110  case GetResult2:
2111  case GetResult3: {
2112  unsigned index = opCode - GetResult0;
2113  LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n");
2114  executeGetResult(index);
2115  break;
2116  }
2117  case GetResultN:
2118  LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n");
2119  executeGetResult(read<uint32_t>());
2120  break;
2121  case GetResults:
2122  executeGetResults();
2123  break;
2124  case GetUsers:
2125  executeGetUsers();
2126  break;
2127  case GetValueType:
2128  executeGetValueType();
2129  break;
2130  case GetValueRangeTypes:
2131  executeGetValueRangeTypes();
2132  break;
2133  case IsNotNull:
2134  executeIsNotNull();
2135  break;
2136  case RecordMatch:
2137  assert(matches &&
2138  "expected matches to be provided when executing the matcher");
2139  executeRecordMatch(rewriter, *matches);
2140  break;
2141  case ReplaceOp:
2142  executeReplaceOp(rewriter);
2143  break;
2144  case SwitchAttribute:
2145  executeSwitchAttribute();
2146  break;
2147  case SwitchOperandCount:
2148  executeSwitchOperandCount();
2149  break;
2150  case SwitchOperationName:
2151  executeSwitchOperationName();
2152  break;
2153  case SwitchResultCount:
2154  executeSwitchResultCount();
2155  break;
2156  case SwitchType:
2157  executeSwitchType();
2158  break;
2159  case SwitchTypes:
2160  executeSwitchTypes();
2161  break;
2162  }
2163  LLVM_DEBUG(llvm::dbgs() << "\n");
2164  }
2165 }
2166 
2167 /// Run the pattern matcher on the given root operation, collecting the matched
2168 /// patterns in `matches`.
2171  PDLByteCodeMutableState &state) const {
2172  // The first memory slot is always the root operation.
2173  state.memory[0] = op;
2174 
2175  // The matcher function always starts at code address 0.
2176  ByteCodeExecutor executor(
2177  matcherByteCode.data(), state.memory, state.opRangeMemory,
2178  state.typeRangeMemory, state.allocatedTypeRangeMemory,
2179  state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2180  uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
2181  constraintFunctions, rewriteFunctions);
2182  executor.execute(rewriter, &matches);
2183 
2184  // Order the found matches by benefit.
2185  std::stable_sort(matches.begin(), matches.end(),
2186  [](const MatchResult &lhs, const MatchResult &rhs) {
2187  return lhs.benefit > rhs.benefit;
2188  });
2189 }
2190 
2191 /// Run the rewriter of the given pattern on the root operation `op`.
2192 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
2193  PDLByteCodeMutableState &state) const {
2194  // The arguments of the rewrite function are stored at the start of the
2195  // memory buffer.
2196  llvm::copy(match.values, state.memory.begin());
2197 
2198  ByteCodeExecutor executor(
2199  &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
2200  state.opRangeMemory, state.typeRangeMemory,
2201  state.allocatedTypeRangeMemory, state.valueRangeMemory,
2202  state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2203  rewriterByteCode, state.currentPatternBenefits, patterns,
2204  constraintFunctions, rewriteFunctions);
2205  executor.execute(rewriter, /*matches=*/nullptr, match.location);
2206 }
Include the generated interface declarations.
This class contains a list of basic blocks and a link to the parent operation it is attached to...
Definition: Region.h:26
static void * executeGetOperandsResults(RangeT values, Operation *op, unsigned index, ByteCodeField rangeIndex, StringRef attrSizedSegments, MutableArrayRef< ValueRange > valueRangeMemory)
This function is the internal implementation of GetResults and GetOperands that provides support for ...
Definition: ByteCode.cpp:1700
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:881
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
This is a value defined by a result of an operation.
Definition: Value.h:423
Block represents an ordered list of Operations.
Definition: Block.h:29
This class represents liveness information on block level.
Definition: Liveness.h:99
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Each successful match returns a MatchResult, which contains information necessary to execute the rewr...
Definition: ByteCode.h:124
Value getOperand(unsigned idx)
Definition: Operation.h:219
AttrClass getAttrOfType(StringAttr name)
Definition: Operation.h:327
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
OpCode
Definition: ByteCode.cpp:79
unsigned getNumOperands()
Definition: Operation.h:215
llvm::OwningArrayRef< Operation * > OwningOpRange
Definition: ByteCode.h:31
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
This class implements the result iterators for the Operation class.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
SmallVector< const void * > values
Memory values defined in the matcher that are passed to the rewriter.
Definition: ByteCode.h:136
Operation & front()
Definition: Block.h:144
PDLByteCode(ModuleOp module, llvm::StringMap< PDLConstraintFunction > constraintFns, llvm::StringMap< PDLRewriteFunction > rewriteFns)
Create a ByteCode instance from the given module containing operations in the PDL interpreter dialect...
Definition: ByteCode.cpp:1014
user_range getUsers() const
Definition: Value.h:212
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
Definition: Block.cpp:26
static constexpr const bool value
SmallVector< Value, 4 > operands
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
Definition: Value.h:225
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:48
Auxiliary range data structure to unpack the offset, size and stride operands into a list of triples...
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the.
All of the data pertaining to a specific pattern within the bytecode.
Definition: ByteCode.h:38
uint16_t ByteCodeField
Use generic bytecode types.
Definition: ByteCode.h:29
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
An attribute that represents a reference to a dense vector or tensor object.
static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp, ByteCodeAddr rewriterAddr)
Definition: ByteCode.cpp:36
iterator_range_impl< ElementIterator< T > > getValues() const
void rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const
Run the rewriter of the given pattern that was previously matched in match.
Definition: ByteCode.cpp:2192
SmallVector< ValueRange, 0 > valueRangeValues
Definition: ByteCode.h:139
U dyn_cast() const
Definition: Types.h:244
Storage type of byte-code interpreter values.
Definition: PatternMatch.h:398
Attributes are known-constant values of operations.
Definition: Attributes.h:24
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:32
user_iterator user_begin() const
Definition: Value.h:210
Operation * createOperation(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:380
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:470
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:206
std::function< void(ArrayRef< PDLValue >, ArrayAttr, PatternRewriter &, PDLResultList &)> PDLRewriteFunction
A native PDL rewrite function.
Definition: PatternMatch.h:598
SmallVector< TypeRange, 0 > typeRangeValues
Memory used for the range input values.
Definition: ByteCode.h:138
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:276
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
Definition: Builders.cpp:28
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:38
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Definition: Block.cpp:35
Optional< RegisteredOperationName > getRegisteredInfo() const
If this operation is registered, returns the registered information, None otherwise.
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:106
This represents an operation in an abstracted form, suitable for use with the builder APIs...
This class acts as a special tag that makes the desire to match "any" operation type explicit...
Definition: PatternMatch.h:157
BlockArgListType getArguments()
Definition: Block.h:76
Represents an analysis for computing liveness information from a given top-level operation.
Definition: Liveness.h:47
This class represents an argument of a Block.
Definition: Value.h:298
auto getType() const
const LivenessBlockInfo * getLiveness(Block *block) const
Gets liveness info (if any) for the block.
Definition: Liveness.cpp:224
Kind
The underlying kind of a PDL value.
Definition: PatternMatch.h:401
This class implements the successor iterators for Block.
Definition: BlockSupport.h:72
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
user_iterator user_end() const
Definition: Value.h:211
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
MLIRContext * getContext() const
Get the context held by this operation state.
NamedAttrList attributes
Type getType() const
Return the type of this attribute.
Definition: Attributes.h:64
SuccessorRange getSuccessors()
Definition: Operation.h:446
bool isImpossibleToMatch() const
Definition: PatternMatch.h:42
Operation * getEndOperation(Value value, Operation *startOperation) const
Gets the end operation for the given value using the start operation provided (must be referenced in ...
Definition: Liveness.cpp:362
Type getType() const
Return the type of this value.
Definition: Value.h:117
void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit)
Set the new benefit for a bytecode pattern.
Definition: ByteCode.cpp:61
void cleanupAfterMatchAndRewrite()
Cleanup any allocated state after a match/rewrite has been completed.
Definition: ByteCode.cpp:69
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation...
Definition: Visitors.cpp:24
Location location
The location of operations to be replaced.
Definition: ByteCode.h:134
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
The class represents a list of PDL results, returned by a native rewrite method.
Definition: PatternMatch.h:510
const PDLByteCodePattern * pattern
The originating pattern that was matched.
Definition: ByteCode.h:143
This class implements the operand iterators for the Operation class.
std::function< LogicalResult(ArrayRef< PDLValue >, ArrayAttr, PatternRewriter &)> PDLConstraintFunction
A generic PDL pattern constraint function.
Definition: PatternMatch.h:591
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:273
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
This class contains the mutable state of a bytecode instance.
Definition: ByteCode.h:63
void match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl< MatchResult > &matches, PDLByteCodeMutableState &state) const
Run the pattern matcher on the given root operation, collecting the matched patterns in matches...
Definition: ByteCode.cpp:2169
uint32_t ByteCodeAddr
Definition: ByteCode.h:30
PatternBenefit benefit
The current benefit of the pattern that was matched.
Definition: ByteCode.h:145
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:57
bool isa() const
Definition: Types.h:234
const ValueSetT & in() const
Returns all values that are live at the beginning of the block (unordered).
Definition: Liveness.h:110
result_range getResults()
Definition: Operation.h:284
This class provides an abstraction over the different types of ranges over Values.
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Definition: Operation.h:323
ByteCodeAddr getRewriterAddr() const
Return the bytecode address of the rewriter for this pattern.
Definition: ByteCode.h:44
static void processValue(Value value, LiveMap &liveMap)
void initializeMutableState(PDLByteCodeMutableState &state) const
Initialize the given state such that it can be used to execute the current bytecode.
Definition: ByteCode.cpp:1032
static Value max(ImplicitLocOpBuilder &builder, Value a, Value b)
SmallVector< Type, 4 > types
Types of the results of this operation.