19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/Format.h"
24 #include "llvm/Support/FormatVariadic.h"
28 #define DEBUG_TYPE "pdl-bytecode"
38 PDLPatternConfigSet *configSet,
39 ByteCodeAddr rewriterAddr) {
45 if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
47 llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
50 if (std::optional<StringRef> rootKind = matchOp.getRootKind())
54 benefit, ctx, generatedOps);
66 currentPatternBenefits[patternIndex] = benefit;
73 allocatedTypeRangeMemory.clear();
74 allocatedValueRangeMemory.clear();
/// Opcode identifiers written into the bytecode stream. The underlying type is
/// `ByteCodeField`, so an opcode is stored as a single bytecode field.
/// NOTE(review): only a few enumerators are visible in this excerpt.
82 enum OpCode : ByteCodeField {
104 CreateConstantTypeRange,
108 CreateDynamicTypeRange,
110 CreateDynamicValueRange,
// Forward declarations of helper structs that are defined further down in this
// file.
184 struct ByteCodeLiveRange;
185 struct ByteCodeWriter;
/// Detection alias: the aliased type is only well-formed when an object of type
/// `T` has a callable `getAsOpaquePointer()` member. It is used together with
/// `llvm::is_detected` to select overloads for opaque-pointer-like types.
188 template <
typename T,
typename... Args>
189 using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
/// Constructor. Stores the context and binds the caller-provided storage and
/// maximum-index counters (the corresponding members are references), then
/// records an index for each constraint and external rewrite function name via
/// `try_emplace(name, index)`.
/// NOTE(review): several parameter/initializer lines and the loop headers
/// around the two `try_emplace` calls are not visible in this excerpt.
194 Generator(
MLIRContext *ctx, std::vector<const void *> &uniquedData,
198 ByteCodeField &maxValueMemoryIndex,
199 ByteCodeField &maxOpRangeMemoryIndex,
200 ByteCodeField &maxTypeRangeMemoryIndex,
201 ByteCodeField &maxValueRangeMemoryIndex,
202 ByteCodeField &maxLoopLevel,
203 llvm::StringMap<PDLConstraintFunction> &constraintFns,
204 llvm::StringMap<PDLRewriteFunction> &rewriteFns,
206 : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
208 maxValueMemoryIndex(maxValueMemoryIndex),
209 maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
210 maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
211 maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
212 maxLoopLevel(maxLoopLevel), configMap(configMap) {
// Key: constraint function name; value: its position in the enumeration.
214 constraintToMemIndex.try_emplace(it.value().first(), it.index());
// Key: external rewrite function name; value: its position in the enumeration.
216 externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
/// Entry point: generates the bytecode for the given module (defined out of
/// line below).
220 void generate(ModuleOp module);
/// Returns a mutable reference to the memory index assigned to `value`.
/// Asserts that `value` already has an entry in `valueToMemIndex`.
223 ByteCodeField &getMemIndex(
Value value) {
224 assert(valueToMemIndex.count(value) &&
225 "expected memory index to be assigned");
226 return valueToMemIndex[value];
/// Returns a mutable reference to the range storage index assigned to `value`.
/// Asserts that `value` already has an entry in `valueToRangeIndex`.
230 ByteCodeField &getRangeStorageIndex(
Value value) {
231 assert(valueToRangeIndex.count(value) &&
232 "expected range index to be assigned");
233 return valueToRangeIndex[value];
238 template <
typename T>
239 std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
241 const void *opaqueVal = val.getAsOpaquePointer();
244 auto it = uniquedDataToMemIndex.try_emplace(
245 opaqueVal, maxValueMemoryIndex + uniquedData.size());
247 uniquedData.push_back(opaqueVal);
248 return it.first->second;
/// Assigns memory indices (and range storage indices) to the values used by
/// the matcher function and by the functions of the rewriter module. Defined
/// out of line below.
254 void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
255 ModuleOp rewriterModule);
/// Bytecode emission overloads: one for a region, one for a generic operation
/// (which dispatches on the concrete op type), and one per `pdl_interp`
/// operation. Each receives the writer that the emitted fields are appended to.
258 void generate(
Region *region, ByteCodeWriter &writer);
259 void generate(
Operation *op, ByteCodeWriter &writer);
260 void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
261 void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
262 void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
263 void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
264 void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
265 void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
266 void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
267 void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
268 void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
269 void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
270 void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
271 void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
272 void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
273 void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer);
274 void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
275 void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
276 void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
277 void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
278 void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
279 void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
280 void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
281 void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
282 void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
283 void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
284 void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
285 void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
286 void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
287 void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
288 void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
289 void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
290 void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
291 void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
292 void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
293 void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
294 void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
295 void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
296 void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
297 void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
/// Maps the name of an external rewrite function to the index recorded for it
/// in the constructor.
307 llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
/// Maps the name of a constraint function to the index recorded for it in the
/// constructor.
311 llvm::StringMap<ByteCodeField> constraintToMemIndex;
/// Maps a rewriter function name to a bytecode address.
315 llvm::StringMap<ByteCodeAddr> rewriterToAddr;
/// Loop level used while generating `ForEach`/`Continue`; starts at zero.
/// NOTE(review): presumably the current loop nesting depth -- the statements
/// that modify it are not visible in this excerpt.
322 ByteCodeField curLoopLevel = 0;
/// References to storage and counters owned by the caller of the constructor.
331 std::vector<const void *> &uniquedData;
335 ByteCodeField &maxValueMemoryIndex;
336 ByteCodeField &maxOpRangeMemoryIndex;
337 ByteCodeField &maxTypeRangeMemoryIndex;
338 ByteCodeField &maxValueRangeMemoryIndex;
339 ByteCodeField &maxLoopLevel;
/// Helper that appends fields of various kinds to a bytecode buffer.
/// NOTE(review): the constructor, data members and several method headers of
/// this struct are not visible in this excerpt.
346 struct ByteCodeWriter {
/// Append a single raw field / a single opcode.
351 void append(ByteCodeField field) { bytecode.push_back(field); }
352 void append(OpCode opCode) { bytecode.push_back(opCode); }
/// Append an address. The static_assert requires an address to be exactly two
/// fields wide; the address bytes are copied into two fields and appended.
355 void append(ByteCodeAddr field) {
356 static_assert((
sizeof(ByteCodeAddr) /
sizeof(ByteCodeField)) == 2,
357 "unexpected ByteCode address size");
359 ByteCodeField fieldParts[2];
360 std::memcpy(fieldParts, &field,
sizeof(ByteCodeAddr));
361 bytecode.append({fieldParts[0], fieldParts[1]});
/// Append a reference to a successor block: the current bytecode offset is
/// recorded under `successor` in `unresolvedSuccessorRefs`, and a zero address
/// is emitted as a placeholder.
366 void append(
Block *successor) {
369 unresolvedSuccessorRefs[successor].push_back(bytecode.size());
370 append(ByteCodeAddr(0));
// Iterates over a list of successors (header and loop body not visible here).
376 for (
Block *successor : successors)
// Emits the number of values, then each value as a PDL value.
382 bytecode.push_back(values.size());
383 for (
Value value : values)
384 appendPDLValue(value);
/// Append a PDL value; the visible part emits the value's kind.
388 void appendPDLValue(
Value value) {
389 appendPDLValueKind(value);
/// Append the PDLValue kind that corresponds to the type of `value`.
394 void appendPDLValueKind(
Value value) { appendPDLValueKind(value.
getType()); }
/// Map a PDL type to a `PDLValue::Kind` and append it as a `ByteCodeField`.
/// A range type maps to `TypeRange` when its element type is `pdl::TypeType`
/// and to `ValueRange` otherwise.
397 void appendPDLValueKind(
Type type) {
400 .Case<pdl::AttributeType>(
401 [](
Type) {
return PDLValue::Kind::Attribute; })
402 .Case<pdl::OperationType>(
403 [](
Type) {
return PDLValue::Kind::Operation; })
404 .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
405 if (isa<pdl::TypeType>(rangeTy.getElementType()))
406 return PDLValue::Kind::TypeRange;
407 return PDLValue::Kind::ValueRange;
409 .Case<pdl::TypeType>([](
Type) {
return PDLValue::Kind::Type; })
410 .Case<pdl::ValueType>([](
Type) {
return PDLValue::Kind::Value; });
411 bytecode.push_back(
static_cast<ByteCodeField
>(kind));
/// Enabled for types that provide `getAsOpaquePointer()` or are raw pointers:
/// appends the memory index that the generator reports for the value.
416 template <
typename T>
417 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
418 std::is_pointer<T>::value>
420 bytecode.push_back(
generator.getMemIndex(value));
/// Enabled for types without `getAsOpaquePointer()` (treated as ranges):
/// appends the number of elements, then iterates over the elements.
424 template <
typename T,
typename IteratorT = llvm::detail::IterOfRange<T>>
425 std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
427 bytecode.push_back(llvm::size(range));
428 for (
auto it : range)
/// Variadic convenience overload; the remaining fields are forwarded to
/// `append` recursively.
433 template <
typename FieldTy,
typename Field2Ty,
typename... FieldTys>
434 void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
436 append(field2, fields...);
/// Append the opaque pointer of `value` directly into the bytecode stream
/// (rather than a memory index). The pointer occupies
/// `sizeof(const void *) / sizeof(ByteCodeField)` fields.
440 template <
typename T>
441 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
442 appendInline(T value) {
443 constexpr
size_t numParts =
sizeof(
const void *) /
sizeof(ByteCodeField);
444 const void *pointer = value.getAsOpaquePointer();
445 ByteCodeField fieldParts[numParts];
446 std::memcpy(fieldParts, &pointer,
sizeof(
const void *));
447 bytecode.append(fieldParts, fieldParts + numParts);
/// Live range of a value, stored as a set of intervals, plus optional range
/// storage indices for op/type/value ranges.
462 struct ByteCodeLiveRange {
463 using Set = llvm::IntervalMap<uint64_t, char, 16>;
464 using Allocator = Set::Allocator;
/// Creates an empty live range whose interval map uses `alloc`.
466 ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
/// Inserts every interval of `rhs` into this range. The mapped value is
/// always 0; only the interval bounds carry information.
469 void unionWith(
const ByteCodeLiveRange &rhs) {
470 for (
auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
472 liveness->insert(it.start(), it.stop(), 0);
/// Checks for overlap between the intervals of this range and those of `rhs`.
476 bool overlaps(
const ByteCodeLiveRange &rhs)
const {
477 return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
/// The interval set; heap-allocated and owned by this object.
487 std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
/// Operation range storage index, if one has been assigned.
490 std::optional<unsigned> opRangeIndex;
/// Type range storage index, if one has been assigned.
493 std::optional<unsigned> typeRangeIndex;
/// Value range storage index, if one has been assigned.
496 std::optional<unsigned> valueRangeIndex;
/// Generates bytecode for `module`: looks up the matcher function and the
/// rewriter module, allocates memory indices, emits the rewriter functions,
/// emits the matcher, and finally patches the matcher's successor references.
500 void Generator::generate(ModuleOp module) {
501 auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
502 pdl_interp::PDLInterpDialect::getMatcherFunctionName());
503 ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
504 pdl_interp::PDLInterpDialect::getRewriterModuleName());
505 assert(matcherFunc && rewriterModule &&
"invalid PDL Interpreter module");
// Assign memory indices before any bytecode is emitted.
509 allocateMemoryIndices(matcherFunc, rewriterModule);
// Emit each rewriter function, recording the bytecode offset it starts at
// under the function's name.
512 ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *
this);
513 for (
auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
514 rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
515 for (
Operation &op : rewriterFunc.getOps())
516 generate(&op, rewriterByteCodeWriter);
// Rewriter code must not have produced any pending successor references.
518 assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
519 "unexpected branches in rewriter function");
// Emit the matcher body.
522 ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *
this);
523 generate(&matcherFunc.getBody(), matcherByteCodeWriter);
// Overwrite each recorded placeholder offset with the address of its block.
526 for (
auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
527 ByteCodeAddr addr = blockToAddr[it.first];
528 for (
unsigned offsetToFix : it.second)
529 std::memcpy(&matcherByteCode[offsetToFix], &addr,
sizeof(ByteCodeAddr));
/// Assigns memory indices and range storage indices to values of the rewriter
/// functions and of the matcher function, and raises the `max*` counters to
/// cover the indices handed out.
/// NOTE(review): many lines of this function are not visible in this excerpt;
/// the comments below only describe the visible statements.
533 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
534 ModuleOp rewriterModule) {
// Rewriter functions: counters restart at zero for every function, and each
// processed value receives the next index.
537 for (
auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
538 ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
539 auto processRewriterValue = [&](
Value val) {
540 valueToMemIndex.try_emplace(val, index++);
// Range-typed values additionally get a type-range or value-range index.
541 if (pdl::RangeType rangeType = dyn_cast<pdl::RangeType>(val.getType())) {
542 Type elementTy = rangeType.getElementType();
543 if (isa<pdl::TypeType>(elementTy))
544 valueToRangeIndex.try_emplace(val, typeRangeIndex++);
545 else if (isa<pdl::ValueType>(elementTy))
546 valueToRangeIndex.try_emplace(val, valueRangeIndex++);
551 processRewriterValue(arg);
554 processRewriterValue(result);
// Keep the maxima large enough for this function's counts.
556 if (index > maxValueMemoryIndex)
557 maxValueMemoryIndex = index;
558 if (typeRangeIndex > maxTypeRangeMemoryIndex)
559 maxTypeRangeMemoryIndex = typeRangeIndex;
560 if (valueRangeIndex > maxValueRangeMemoryIndex)
561 maxValueRangeMemoryIndex = valueRangeIndex;
// Records an increasing index for an operation on entry (`opToFirstIndex`) and
// after its blocks have been visited (`opToLastIndex`).
578 opToFirstIndex.try_emplace(op, index++);
580 for (
Block &block : region.getBlocks())
583 opToLastIndex.try_emplace(op, index++);
588 ByteCodeLiveRange::Allocator allocator;
// The root operation argument always uses memory index 0.
593 valueToMemIndex[rootOpArg] = 0;
// Build live ranges for matcher values from the liveness analysis.
596 Liveness matcherLiveness(matcherFunc);
597 matcherFunc->walk([&](
Block *block) {
599 assert(info &&
"expected liveness info for block");
603 if (value == rootOpArg)
607 auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
608 defRangeIt->second.liveness->insert(
609 opToFirstIndex[firstUseOrDef],
// Mark which kind of range storage the value needs; the stored 0 acts only
// as a "needs an index" marker and is replaced during allocation below.
614 if (
auto rangeTy = dyn_cast<pdl::RangeType>(value.
getType())) {
615 Type eleType = rangeTy.getElementType();
616 if (isa<pdl::OperationType>(eleType))
617 defRangeIt->second.opRangeIndex = 0;
618 else if (isa<pdl::TypeType>(eleType))
619 defRangeIt->second.typeRangeIndex = 0;
620 else if (isa<pdl::ValueType>(eleType))
621 defRangeIt->second.valueRangeIndex = 0;
626 for (
Value liveIn : info->
in()) {
631 if (liveIn.getParentRegion() == block->
getParent())
// Live ranges of the indices allocated so far; entry i corresponds to memory
// index i + 1 (index 0 is the root operation).
648 std::vector<ByteCodeLiveRange> allocatedIndices;
652 ByteCodeField numIndices = 1;
655 ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
657 for (
auto &defIt : valueDefRanges) {
658 ByteCodeField &memIndex = valueToMemIndex[defIt.first];
659 ByteCodeLiveRange &defRange = defIt.second;
// Try to share an already allocated index whose live range does not overlap.
662 for (
const auto &existingIndexIt :
llvm::enumerate(allocatedIndices)) {
663 ByteCodeLiveRange &existingRange = existingIndexIt.value();
664 if (!defRange.overlaps(existingRange)) {
665 existingRange.unionWith(defRange);
666 memIndex = existingIndexIt.index() + 1;
// Reuse the slot's range storage index of the matching kind, creating one
// if the slot does not have it yet.
668 if (defRange.opRangeIndex) {
669 if (!existingRange.opRangeIndex)
670 existingRange.opRangeIndex = numOpRanges++;
671 valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
672 }
else if (defRange.typeRangeIndex) {
673 if (!existingRange.typeRangeIndex)
674 existingRange.typeRangeIndex = numTypeRanges++;
675 valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
676 }
else if (defRange.valueRangeIndex) {
677 if (!existingRange.valueRangeIndex)
678 existingRange.valueRangeIndex = numValueRanges++;
679 valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
// Allocate a fresh index (and, if needed, a fresh range storage index).
687 allocatedIndices.emplace_back(allocator);
688 ByteCodeLiveRange &newRange = allocatedIndices.back();
689 newRange.unionWith(defRange);
692 if (defRange.opRangeIndex) {
693 newRange.opRangeIndex = numOpRanges;
694 valueToRangeIndex[defIt.first] = numOpRanges++;
695 }
else if (defRange.typeRangeIndex) {
696 newRange.typeRangeIndex = numTypeRanges;
697 valueToRangeIndex[defIt.first] = numTypeRanges++;
698 }
else if (defRange.valueRangeIndex) {
699 newRange.valueRangeIndex = numValueRanges;
700 valueToRangeIndex[defIt.first] = numValueRanges++;
703 memIndex = allocatedIndices.size();
// Diagnostic output of how many indices were needed.
710 llvm::dbgs() <<
"Allocated " << allocatedIndices.size() <<
" indices "
711 <<
"(down from initial " << valueDefRanges.size() <<
").\n";
714 "Ran out of memory for allocated indices");
// Raise the maxima to cover the matcher's requirements.
717 if (numIndices > maxValueMemoryIndex)
718 maxValueMemoryIndex = numIndices;
719 if (numOpRanges > maxOpRangeMemoryIndex)
720 maxOpRangeMemoryIndex = numOpRanges;
721 if (numTypeRanges > maxTypeRangeMemoryIndex)
722 maxTypeRangeMemoryIndex = numTypeRanges;
723 if (numValueRanges > maxValueRangeMemoryIndex)
724 maxValueRangeMemoryIndex = numValueRanges;
/// Emits bytecode for every block of `region`, visiting blocks in reverse
/// post-order. The matcher bytecode offset at which each block starts is
/// recorded in `blockToAddr` before its operations are emitted.
727 void Generator::generate(
Region *region, ByteCodeWriter &writer) {
728 llvm::ReversePostOrderTraversal<Region *> rpot(region);
729 for (
Block *block : rpot) {
731 blockToAddr.try_emplace(block, matcherByteCode.size());
733 generate(&op, writer);
/// Emits bytecode for a single operation by dispatching to the overload for
/// its concrete `pdl_interp` op type; any other operation is unreachable.
737 void Generator::generate(
Operation *op, ByteCodeWriter &writer) {
// The operation's location is appended inline, except for CreateAttributeOp
// and CreateTypeOp.
// NOTE(review): any enclosing condition/macro around this statement is not
// visible in this excerpt.
741 if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
742 writer.appendInline(op->
getLoc());
745 .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
746 pdl_interp::AreEqualOp, pdl_interp::BranchOp,
747 pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
748 pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
749 pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
750 pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
751 pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp,
752 pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
753 pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
754 pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
755 pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
756 pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
757 pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
758 pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
759 pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
760 pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
761 pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
762 pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
763 pdl_interp::SwitchResultCountOp>(
764 [&](
auto interpOp) { this->generate(interpOp, writer); })
766 llvm_unreachable(
"unknown `pdl_interp` operation");
/// Emits `ApplyConstraint`: the constraint's index, the argument list, the
/// negation flag, the result count, per-result information (kind, range
/// storage index for range-typed results, and the result itself), and finally
/// the successors.
770 void Generator::generate(pdl_interp::ApplyConstraintOp op,
771 ByteCodeWriter &writer) {
775 writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
776 writer.appendPDLValueList(op.getArgs());
777 writer.append(ByteCodeField(op.getIsNegated()));
779 writer.append(ByteCodeField(results.size()));
780 for (
Value result : results) {
784 writer.appendPDLValueKind(result);
787 if (isa<pdl::RangeType>(result.getType()))
788 writer.append(getRangeStorageIndex(result));
789 writer.append(result);
791 writer.append(op.getSuccessors());
/// Emits `ApplyRewrite`: the external rewriter's index (asserted to exist),
/// the argument list, the result count, and per-result information (kind,
/// range storage index for range-typed results, and the result itself).
793 void Generator::generate(pdl_interp::ApplyRewriteOp op,
794 ByteCodeWriter &writer) {
795 assert(externalRewriterToMemIndex.count(op.getName()) &&
796 "expected index for rewrite function");
797 writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
798 writer.appendPDLValueList(op.getArgs());
801 writer.append(ByteCodeField(results.size()));
802 for (
Value result : results) {
805 writer.appendPDLValueKind(result);
808 if (isa<pdl::RangeType>(result.getType()))
809 writer.append(getRangeStorageIndex(result));
810 writer.append(result);
/// Emits an equality check. Range-typed operands use `AreRangesEqual`, which
/// additionally carries the value kind; other operands use `AreEqual`.
/// NOTE(review): the control flow separating the two emissions is partly
/// elided in this excerpt.
813 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
814 Value lhs = op.getLhs();
815 if (isa<pdl::RangeType>(lhs.
getType())) {
816 writer.append(OpCode::AreRangesEqual);
817 writer.appendPDLValueKind(lhs);
818 writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
822 writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
/// Emits bytecode for `pdl_interp::BranchOp`.
/// NOTE(review): the body is not visible in this excerpt.
824 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
/// Emits an attribute check as an `AreEqual` comparison between the attribute
/// operand and the op's constant value.
827 void Generator::generate(pdl_interp::CheckAttributeOp op,
828 ByteCodeWriter &writer) {
829 writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
/// Emits `CheckOperandCount` with the input operation, the expected count and
/// the "compare at least" flag encoded as a bytecode field.
832 void Generator::generate(pdl_interp::CheckOperandCountOp op,
833 ByteCodeWriter &writer) {
834 writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
835 static_cast<ByteCodeField
>(op.getCompareAtLeast()),
/// Emits `CheckOperationName` for the input operation.
/// NOTE(review): the remaining operands are not visible in this excerpt.
838 void Generator::generate(pdl_interp::CheckOperationNameOp op,
839 ByteCodeWriter &writer) {
840 writer.append(OpCode::CheckOperationName, op.getInputOp(),
/// Emits `CheckResultCount` with the input operation, the expected count and
/// the "compare at least" flag encoded as a bytecode field.
843 void Generator::generate(pdl_interp::CheckResultCountOp op,
844 ByteCodeWriter &writer) {
845 writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
846 static_cast<ByteCodeField
>(op.getCompareAtLeast()),
/// Emits a type check as an `AreEqual` comparison between the value operand
/// and the op's type.
849 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
850 writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
/// Emits `CheckTypes` with the value operand and the op's types.
853 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
854 writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
/// Emits `Continue` followed by `curLoopLevel - 1`. Asserts that the current
/// loop level is non-zero, i.e. the op is not at top level.
857 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
858 assert(curLoopLevel > 0 &&
"encountered pdl_interp.continue at top level");
859 writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
/// Handles `CreateAttributeOp` by assigning the memory index of the op's
/// constant value to the op's result, so both refer to the same slot. The
/// visible code does not append anything to `writer`.
861 void Generator::generate(pdl_interp::CreateAttributeOp op,
862 ByteCodeWriter &writer) {
864 getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
/// Emits `CreateOperation`: the result operation, the operand list, the
/// attribute count followed by (name, value) pairs, and the result types.
/// NOTE(review): the handling of inferred result types is only partly visible
/// in this excerpt.
866 void Generator::generate(pdl_interp::CreateOperationOp op,
867 ByteCodeWriter &writer) {
868 writer.append(OpCode::CreateOperation, op.getResultOp(),
870 writer.appendPDLValueList(op.getInputOperands());
// Attributes are emitted as a count followed by name/value pairs.
874 writer.append(
static_cast<ByteCodeField
>(attributes.size()));
875 for (
auto it : llvm::zip(op.getInputAttributeNames(), attributes))
876 writer.append(std::get<0>(it), std::get<1>(it));
880 if (op.getInferredResultTypes())
883 writer.appendPDLValueList(op.getInputResultTypes());
/// Emits a dynamic range creation. The opcode depends on the element type
/// (`CreateDynamicTypeRange` or `CreateDynamicValueRange`); it is followed by
/// the result, its range storage index, and the operand list.
885 void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
889 [&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); })
890 .Case([&](pdl::ValueType) {
891 writer.append(OpCode::CreateDynamicValueRange);
894 writer.append(op.getResult(), getRangeStorageIndex(op.getResult()));
895 writer.appendPDLValueList(op->getOperands());
/// Handles `CreateTypeOp` by assigning the memory index of the op's constant
/// type to the op's result. The visible code does not append anything to
/// `writer`.
897 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
899 getMemIndex(op.getResult()) = getMemIndex(op.getValue());
/// Emits `CreateConstantTypeRange` with the result, its range storage index,
/// and the op's constant value.
901 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
902 writer.append(OpCode::CreateConstantTypeRange, op.getResult(),
903 getRangeStorageIndex(op.getResult()), op.getValue());
/// Emits `EraseOp` with the operation to erase.
905 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
906 writer.append(OpCode::EraseOp, op.getInputOp());
/// Emits an extract. The opcode is chosen by type (`ExtractOp`,
/// `ExtractValue` or `ExtractType`; anything else is unreachable) and is
/// followed by the range, the index and the result.
908 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
911 .Case([](pdl::OperationType) {
return OpCode::ExtractOp; })
912 .Case([](pdl::ValueType) {
return OpCode::ExtractValue; })
913 .Case([](pdl::TypeType) {
return OpCode::ExtractType; })
914 .Default([](
Type) -> OpCode {
915 llvm_unreachable(
"unsupported element type");
917 writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
/// Emits the `Finalize` opcode; it takes no operands.
919 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
920 writer.append(OpCode::Finalize);
/// Emits `ForEach`: the range storage index of the iterated values, the loop
/// argument and its kind, the current loop level and the successor. It then
/// raises `maxLoopLevel` if needed and emits the loop body region.
/// NOTE(review): the definition of `arg` and any updates to `curLoopLevel` are
/// not visible in this excerpt.
922 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
924 writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
925 writer.appendPDLValueKind(arg.
getType());
926 writer.append(curLoopLevel, op.getSuccessor());
928 if (curLoopLevel > maxLoopLevel)
929 maxLoopLevel = curLoopLevel;
930 generate(&op.getRegion(), writer);
/// Emits `GetAttribute` with the result attribute and the input operation.
/// NOTE(review): trailing operands are not visible in this excerpt.
933 void Generator::generate(pdl_interp::GetAttributeOp op,
934 ByteCodeWriter &writer) {
935 writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
/// Emits `GetAttributeType` with the result followed by the input value.
938 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
939 ByteCodeWriter &writer) {
940 writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
/// Emits `GetDefiningOp` with the result operation, followed by the input
/// value written as a PDL value.
942 void Generator::generate(pdl_interp::GetDefiningOpOp op,
943 ByteCodeWriter &writer) {
944 writer.append(OpCode::GetDefiningOp, op.getInputOp());
945 writer.appendPDLValue(op.getValue());
/// Emits an operand access: either a specialized opcode computed as
/// `GetOperand0 + index`, or the generic `GetOperandN` followed by the index.
/// Both forms are followed by the input operation and the result value.
/// NOTE(review): the condition selecting between the two forms is not visible
/// in this excerpt.
947 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
948 uint32_t index = op.getIndex();
950 writer.append(
static_cast<OpCode
>(OpCode::GetOperand0 + index));
952 writer.append(OpCode::GetOperandN, index);
953 writer.append(op.getInputOp(), op.getValue());
/// Emits `GetOperands`. When the result is range-typed its range storage index
/// is emitted; the result itself is emitted last.
/// NOTE(review): the operands following the opcode and the non-range branch
/// are not visible in this excerpt.
955 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
956 Value result = op.getValue();
957 std::optional<uint32_t> index = op.getIndex();
958 writer.append(OpCode::GetOperands,
961 if (isa<pdl::RangeType>(result.
getType()))
962 writer.append(getRangeStorageIndex(result));
965 writer.append(result);
/// Emits a result access: either a specialized opcode computed as
/// `GetResult0 + index`, or the generic `GetResultN` followed by the index.
/// Both forms are followed by the input operation and the result value.
/// NOTE(review): the condition selecting between the two forms is not visible
/// in this excerpt.
967 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
968 uint32_t index = op.getIndex();
970 writer.append(
static_cast<OpCode
>(OpCode::GetResult0 + index));
972 writer.append(OpCode::GetResultN, index);
973 writer.append(op.getInputOp(), op.getValue());
/// Emits `GetResults`. When the result is range-typed its range storage index
/// is emitted; the result itself is emitted last.
/// NOTE(review): the operands following the opcode and the non-range branch
/// are not visible in this excerpt.
975 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
976 Value result = op.getValue();
977 std::optional<uint32_t> index = op.getIndex();
978 writer.append(OpCode::GetResults,
981 if (isa<pdl::RangeType>(result.
getType()))
982 writer.append(getRangeStorageIndex(result));
985 writer.append(result);
/// Emits `GetUsers` with the resulting operation range and its range storage
/// index, followed by the input value written as a PDL value.
987 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
988 Value operations = op.getOperations();
989 ByteCodeField rangeIndex = getRangeStorageIndex(operations);
990 writer.append(OpCode::GetUsers, operations, rangeIndex);
991 writer.appendPDLValue(op.getValue());
/// Emits a value-type query. A range-typed result uses `GetValueRangeTypes`
/// and carries the result's range storage index; otherwise `GetValueType` is
/// emitted with the result and the input value.
993 void Generator::generate(pdl_interp::GetValueTypeOp op,
994 ByteCodeWriter &writer) {
995 if (isa<pdl::RangeType>(op.getType())) {
996 Value result = op.getResult();
997 writer.append(OpCode::GetValueRangeTypes, result,
998 getRangeStorageIndex(result), op.getValue());
1000 writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
/// Emits `IsNotNull` with the tested value and the op's successors.
1003 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
1004 writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
/// Registers a new bytecode pattern for the match and emits `RecordMatch`.
/// The pattern's index is its position in `patterns`; it is created from the
/// op, the config looked up for the op, and the address recorded for the
/// referenced rewriter. The op's inputs are emitted as a PDL value list.
1006 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
1007 ByteCodeField patternIndex =
patterns.size();
1008 patterns.emplace_back(PDLByteCodePattern::create(
1009 op, configMap.lookup(op),
1010 rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
1011 writer.append(OpCode::RecordMatch, patternIndex,
1013 writer.appendPDLValueList(op.getInputs());
/// Emits `ReplaceOp` with the operation to replace, followed by the list of
/// replacement values.
1015 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
1016 writer.append(OpCode::ReplaceOp, op.getInputOp());
1017 writer.appendPDLValueList(op.getReplValues());
/// Emits `SwitchAttribute` with the attribute operand, the case values
/// attribute, and the successors.
1019 void Generator::generate(pdl_interp::SwitchAttributeOp op,
1020 ByteCodeWriter &writer) {
1021 writer.append(OpCode::SwitchAttribute, op.getAttribute(),
1022 op.getCaseValuesAttr(), op.getSuccessors());
/// Emits `SwitchOperandCount` with the input operation, the case values
/// attribute, and the successors.
1024 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
1025 ByteCodeWriter &writer) {
1026 writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
1027 op.getCaseValuesAttr(), op.getSuccessors());
/// Emits `SwitchOperationName`. Each case attribute is cast to a `StringAttr`
/// and converted to an `OperationName` in this generator's context before the
/// cases are appended, followed by the successors.
1029 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
1030 ByteCodeWriter &writer) {
1031 auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](
Attribute attr) {
1032 return OperationName(cast<StringAttr>(attr).getValue(), ctx);
1034 writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
1035 op.getSuccessors());
/// Emits `SwitchResultCount` with the input operation, the case values
/// attribute, and the successors.
1037 void Generator::generate(pdl_interp::SwitchResultCountOp op,
1038 ByteCodeWriter &writer) {
1039 writer.append(OpCode::SwitchResultCount, op.getInputOp(),
1040 op.getCaseValuesAttr(), op.getSuccessors());
/// Emits `SwitchType` with the value operand, the case values attribute, and
/// the successors.
1042 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1043 writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1044 op.getSuccessors());
/// Emits `SwitchTypes` with the value operand, the case values attribute, and
/// the successors.
1046 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1047 writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
1048 op.getSuccessors());
/// Constructs the bytecode for `module`. Takes ownership of `configs`, sets up
/// a `Generator` bound to this object's storage members, and then moves the
/// constraint and rewrite functions out of the string maps into vectors.
/// NOTE(review): some parameter lines and the statement that runs the
/// generator are not visible in this excerpt.
1055 PDLByteCode::PDLByteCode(
1056 ModuleOp module,
SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
1058 llvm::StringMap<PDLConstraintFunction> constraintFns,
1059 llvm::StringMap<PDLRewriteFunction> rewriteFns)
1060 : configs(std::move(configs)) {
1061 Generator
generator(module.getContext(), uniquedData, matcherByteCode,
1062 rewriterByteCode,
patterns, maxValueMemoryIndex,
1063 maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1064 maxLoopLevel, constraintFns, rewriteFns, configMap);
// The maps are by-value parameters, so their contents can be moved out.
1068 for (
auto &it : constraintFns)
1069 constraintFunctions.push_back(std::move(it.second));
1070 for (
auto &it : rewriteFns)
1071 rewriteFunctions.push_back(std::move(it.second));
1077 state.memory.resize(maxValueMemoryIndex,
nullptr);
1078 state.opRangeMemory.resize(maxOpRangeCount);
1079 state.typeRangeMemory.resize(maxTypeRangeCount,
TypeRange());
1080 state.valueRangeMemory.resize(maxValueRangeCount,
ValueRange());
1081 state.loopIndex.resize(maxLoopLevel, 0);
1082 state.currentPatternBenefits.reserve(
patterns.size());
1084 state.currentPatternBenefits.push_back(pattern.getBenefit());
/// `PDLResultList` subclass used by the bytecode executor. The visible code
/// returns `allocatedTypeRanges` and `allocatedValueRanges`.
/// NOTE(review): the headers of the two accessors are not visible in this
/// excerpt.
1094 class ByteCodeRewriteResultList :
public PDLResultList {
/// Forwards `maxNumResults` to the base class.
1096 ByteCodeRewriteResultList(
unsigned maxNumResults)
1097 : PDLResultList(maxNumResults) {}
1104 return allocatedTypeRanges;
1109 return allocatedValueRanges;
1114 class ByteCodeExecutor {
1120 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
1122 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
1129 : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
1130 typeRangeMemory(typeRangeMemory),
1131 allocatedTypeRangeMemory(allocatedTypeRangeMemory),
1132 valueRangeMemory(valueRangeMemory),
1133 allocatedValueRangeMemory(allocatedValueRangeMemory),
1134 loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
1136 constraintFunctions(constraintFunctions),
1137 rewriteFunctions(rewriteFunctions) {}
1145 std::optional<Location> mainRewriteLoc = {});
/// Declarations of the executor's per-operation handlers; the names mirror
/// the generator's opcodes. Definitions are not part of this excerpt.
1151 void executeAreEqual();
1152 void executeAreRangesEqual();
1153 void executeBranch();
1154 void executeCheckOperandCount();
1155 void executeCheckOperationName();
1156 void executeCheckResultCount();
1157 void executeCheckTypes();
1158 void executeContinue();
1159 void executeCreateConstantTypeRange();
/// Handler templated on the element type `T` of the range being created.
1162 template <
typename T>
1163 void executeDynamicCreateRange(StringRef type);
/// Handler templated on the element type, the range type and the PDL kind.
1165 template <
typename T,
typename Range, PDLValue::Kind kind>
1166 void executeExtract();
1167 void executeFinalize();
1168 void executeForEach();
1169 void executeGetAttribute();
1170 void executeGetAttributeType();
1171 void executeGetDefiningOp();
1172 void executeGetOperand(
unsigned index);
1173 void executeGetOperands();
1174 void executeGetResult(
unsigned index);
1175 void executeGetResults();
1176 void executeGetUsers();
1177 void executeGetValueType();
1178 void executeGetValueRangeTypes();
1179 void executeIsNotNull();
1183 void executeSwitchAttribute();
1184 void executeSwitchOperandCount();
1185 void executeSwitchOperationName();
1186 void executeSwitchResultCount();
1187 void executeSwitchType();
1188 void executeSwitchTypes();
/// Processes the results of a native function call; `rewriteResult` is passed
/// by reference.
1189 void processNativeFunResults(ByteCodeRewriteResultList &results,
1190 unsigned numResults,
1191 LogicalResult &rewriteResult);
/// Pushes a code position onto the `resumeCodeIt` stack.
1194 void pushCodeIt(
const ByteCodeField *it) { resumeCodeIt.push_back(it); }
1198 assert(!resumeCodeIt.empty() &&
"attempt to pop code off empty stack");
1199 curCodeIt = resumeCodeIt.back();
1200 resumeCodeIt.pop_back();
/// Returns a code position before `curCodeIt`: either one field back, or one
/// field plus the width of an inline pointer
/// (`sizeof(const void *) / sizeof(ByteCodeField)` fields) back.
/// NOTE(review): the condition selecting between the two returns is not
/// visible in this excerpt.
1204 const ByteCodeField *getPrevCodeIt()
const {
1207 return curCodeIt - 1 -
sizeof(
const void *) /
sizeof(ByteCodeField);
1211 return curCodeIt - 1;
/// Reads a value of type `T` from the bytecode via `readImpl<T>()`. `T`
/// defaults to `ByteCodeField` and `skipN` defaults to 0.
/// NOTE(review): the statement that uses `skipN` is not visible in this
/// excerpt.
1217 template <
typename T = ByteCodeField>
1218 T read(
size_t skipN = 0) {
1220 return readImpl<T>();
/// Non-template convenience overload: reads a single `ByteCodeField`.
1222 ByteCodeField read(
size_t skipN = 0) {
return read<ByteCodeField>(skipN); }
1225 template <
typename ValueT,
typename T>
1228 for (
unsigned i = 0, e = read(); i != e; ++i)
1229 list.push_back(read<ValueT>());
1235 for (
unsigned i = 0, e = read(); i != e; ++i) {
1236 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1237 list.push_back(read<Type>());
1239 TypeRange *values = read<TypeRange *>();
1240 list.append(values->begin(), values->end());
1245 for (
unsigned i = 0, e = read(); i != e; ++i) {
1246 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1247 list.push_back(read<Value>());
1250 list.append(values->begin(), values->end());
1256 template <
typename T>
1257 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
1259 const void *pointer;
1260 std::memcpy(&pointer, curCodeIt,
sizeof(
const void *));
1261 curCodeIt +=
sizeof(
const void *) /
sizeof(ByteCodeField);
1262 return T::getFromOpaquePointer(pointer);
/// Advances the current code position by `skipN` fields.
1265 void skip(
size_t skipN) { curCodeIt += skipN; }
/// Boolean form of `selectJump`: selects destination index 0 when `isTrue` is
/// true and index 1 otherwise.
1268 void selectJump(
bool isTrue) { selectJump(
size_t(isTrue ? 0 : 1)); }
/// Jumps to the destination with the given index: reads a `ByteCodeAddr`,
/// passing `destIndex * 2` as the skip count (an address is two fields wide),
/// and sets the current code position to that offset in `code`.
1270 void selectJump(
size_t destIndex) {
1271 curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
/// Generic switch dispatch. Compares `value` against each entry of `cases`
/// using `cmp` (default `std::equal_to<T>`). The first matching case at
/// position i jumps to destination i + 1; if no case matches, destination 0
/// is taken.
1275 template <
typename T,
typename RangeT,
typename Comparator = std::equal_to<T>>
1276 void handleSwitch(
const T &value, RangeT &&cases, Comparator cmp = {}) {
// Diagnostic output of the switch value and the cases.
1278 llvm::dbgs() <<
" * Value: " << value <<
"\n"
1280 llvm::interleaveComma(cases, llvm::dbgs());
1281 llvm::dbgs() <<
"\n";
1286 for (
auto it = cases.begin(), e = cases.end(); it != e; ++it)
1287 if (cmp(*it, value))
1288 return selectJump(
size_t((it - cases.begin()) + 1));
1289 selectJump(
size_t(0));
/// Stores an opaque pointer into the memory slot `index`.
1293 void storeToMemory(
unsigned index,
const void *value) {
1294 memory[index] = value;
/// Stores a value that provides `getAsOpaquePointer()` into the memory slot
/// `index`, using its opaque pointer representation.
1298 template <
typename T>
1299 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
1300 storeToMemory(
unsigned index, T value) {
1301 memory[index] = value.getAsOpaquePointer();
// Reads a memory index from the bytecode stream and returns the pointer it
// designates.
1306 template <
typename T>
1307 const void *readFromMemory() {
1308 size_t index = *curCodeIt++;
// Indices below memory.size() address `memory`.
// NOTE(review): the start of this condition is elided in the listing.
1313 index < memory.size())
1314 return memory[index];
// Remaining indices address `uniquedMemory`, offset by memory.size().
1317 return uniquedMemory[index - memory.size()];
// readImpl overload set, selected by enable_if on T.
//
// Pointer types: the slot fetched by readFromMemory is cast to T.
1319 template <
typename T>
1320 std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
1321 return reinterpret_cast<T
>(
const_cast<void *
>(readFromMemory<T>()));
// Class types other than PDLValue: rebuilt from the opaque pointer in memory.
1323 template <
typename T>
1324 std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
1327 return T(T::getFromOpaquePointer(readFromMemory<T>()));
// PDLValue: a PDLValue::Kind tag is read first, then the payload is read as
// the matching concrete type.
1329 template <
typename T>
1330 std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
1331 switch (read<PDLValue::Kind>()) {
1332 case PDLValue::Kind::Attribute:
1333 return read<Attribute>();
1334 case PDLValue::Kind::Operation:
1335 return read<Operation *>();
1336 case PDLValue::Kind::Type:
1337 return read<Type>();
1338 case PDLValue::Kind::Value:
1339 return read<Value>();
1340 case PDLValue::Kind::TypeRange:
1341 return read<TypeRange *>();
1342 case PDLValue::Kind::ValueRange:
1343 return read<ValueRange *>();
1345 llvm_unreachable(
"unhandled PDLValue::Kind");
// ByteCodeAddr: copied out of the stream with memcpy. The static_assert pins
// the address size to exactly two ByteCodeFields.
// NOTE(review): the cursor advance and return are elided in this listing.
1347 template <
typename T>
1348 std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
1349 static_assert((
sizeof(ByteCodeAddr) /
sizeof(ByteCodeField)) == 2,
1350 "unexpected ByteCode address size");
1351 ByteCodeAddr result;
1352 std::memcpy(&result, curCodeIt,
sizeof(ByteCodeAddr));
// ByteCodeField: the raw field under the cursor; the cursor is post-incremented.
1356 template <
typename T>
1357 std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
1358 return *curCodeIt++;
// PDLValue::Kind overload; its body is elided in this listing.
1360 template <
typename T>
1361 std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
// Stores `range` in range slot `rangeIndex` and makes memory slot `memIndex`
// point at that slot. T is the element type of the range (Type or Value).
1367 template <
typename RangeT,
typename T = llvm::detail::ValueOfRange<RangeT>>
1368 void assignRangeToMemory(RangeT &&range,
unsigned memIndex,
1369 unsigned rangeIndex) {
// Generic worker, parameterised over the owning storage list and the range
// slot array so the Type and Value paths share one implementation.
1371 auto assignRange = [&](
auto &allocatedRangeMemory,
auto &rangeMemory) {
// An empty range needs no backing storage.
1373 if (range.empty()) {
1374 rangeMemory[rangeIndex] = {};
// Otherwise allocate owning storage sized to the range, hand ownership to
// `allocatedRangeMemory`, and point the range slot at it.
// NOTE(review): the lines that fill `storage` are elided in this listing.
1377 llvm::OwningArrayRef<T> storage(llvm::size(range));
1382 allocatedRangeMemory.emplace_back(std::move(storage));
1383 rangeMemory[rangeIndex] = allocatedRangeMemory.back();
1385 memory[memIndex] = &rangeMemory[rangeIndex];
// Compile-time selection of the Type or Value storage.
1389 if constexpr (std::is_same_v<T, Type>) {
1390 return assignRange(allocatedTypeRangeMemory, typeRangeMemory);
1391 }
else if constexpr (std::is_same_v<T, Value>) {
1392 return assignRange(allocatedValueRangeMemory, valueRangeMemory);
1394 llvm_unreachable(
"unhandled range type");
1399 const ByteCodeField *curCodeIt;
1408 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
1410 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
// Executes the ApplyConstraint opcode: invokes a native constraint function
// and jumps according to its (optionally negated) result.
// Operands are read from the stream in this order: function index, argument
// list, negation flag, result count.
1425 void ByteCodeExecutor::executeApplyConstraint(
PatternRewriter &rewriter) {
1426 LLVM_DEBUG(llvm::dbgs() <<
"Executing ApplyConstraint:\n");
1427 ByteCodeField fun_idx = read();
1429 readList<PDLValue>(args);
1432 llvm::dbgs() <<
" * Arguments: ";
1433 llvm::interleaveComma(args, llvm::dbgs());
1434 llvm::dbgs() <<
"\n";
1437 ByteCodeField isNegated = read();
// Only the negation flag is printed here. The arguments were already dumped
// above; printing them again left an unterminated line in the debug output.
1439 llvm::dbgs() <<
" * isNegated: " << isNegated <<
"\n";
1443 ByteCodeField numResults = read();
1444 const PDLRewriteFunction &constraintFn = constraintFunctions[fun_idx];
1445 ByteCodeRewriteResultList results(numResults);
1446 LogicalResult rewriteResult = constraintFn(rewriter, results, args);
1449 if (succeeded(rewriteResult)) {
1450 llvm::dbgs() <<
" * Constraint succeeded\n";
1451 llvm::dbgs() <<
" * Results: ";
1452 llvm::interleaveComma(constraintResults, llvm::dbgs());
1453 llvm::dbgs() <<
"\n";
1455 llvm::dbgs() <<
" * Constraint failed\n";
// A successful constraint must produce exactly `numResults` results.
1458 assert((failed(rewriteResult) || constraintResults.size() == numResults) &&
1459 "native PDL rewrite function succeeded but returned "
1460 "unexpected number of results");
1461 processNativeFunResults(results, numResults, rewriteResult);
// selectJump(true) takes successor 0: that is chosen when the negation flag
// and the success state differ.
1464 selectJump(isNegated != succeeded(rewriteResult));
// Executes the ApplyRewrite opcode: looks up the native rewrite function by
// the index read from the stream, calls it with the decoded argument list,
// and records its results via processNativeFunResults.
1467 LogicalResult ByteCodeExecutor::executeApplyRewrite(
PatternRewriter &rewriter) {
1468 LLVM_DEBUG(llvm::dbgs() <<
"Executing ApplyRewrite:\n");
1469 const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1471 readList<PDLValue>(args);
1474 llvm::dbgs() <<
" * Arguments: ";
1475 llvm::interleaveComma(args, llvm::dbgs());
1479 ByteCodeField numResults = read();
1480 ByteCodeRewriteResultList results(numResults);
1481 LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
// The result count is asserted unconditionally here, unlike the
// ApplyConstraint path, which only checks it on success.
1483 assert(results.getResults().size() == numResults &&
1484 "native PDL rewrite function returned unexpected number of results");
1486 processNativeFunResults(results, numResults, rewriteResult);
// NOTE(review): the return statements are elided in this listing.
1488 if (failed(rewriteResult)) {
1489 LLVM_DEBUG(llvm::dbgs() <<
" - Failed");
// Stores the results of a native constraint/rewrite function into executor
// memory, reading the destination slots for each result from the bytecode
// stream, then takes ownership of any range storage the function allocated.
1495 void ByteCodeExecutor::processNativeFunResults(
1496 ByteCodeRewriteResultList &results,
unsigned numResults,
1497 LogicalResult &rewriteResult) {
1500 for (
unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
// On failure, range-kind results are treated separately from the others.
// NOTE(review): the bodies of these branches and the declaration of
// `resultKind` are elided in this listing.
1505 if (failed(rewriteResult)) {
1506 if (resultKind == PDLValue::Kind::TypeRange ||
1507 resultKind == PDLValue::Kind::ValueRange) {
1514 PDLValue result = results.getResults()[resultIdx];
1515 LLVM_DEBUG(llvm::dbgs() <<
" * Result: " << result <<
"\n");
1516 assert(result.getKind() == resultKind &&
1517 "native PDL rewrite function returned an unexpected type of "
// Range results: the range slot index is read first, then the memory slot
// index. The memory slot receives the address of the range slot.
1521 if (std::optional<TypeRange> typeRange = result.dyn_cast<
TypeRange>()) {
1522 unsigned rangeIndex = read();
1523 typeRangeMemory[rangeIndex] = *typeRange;
1524 memory[read()] = &typeRangeMemory[rangeIndex];
1525 }
else if (std::optional<ValueRange> valueRange =
1527 unsigned rangeIndex = read();
1528 valueRangeMemory[rangeIndex] = *valueRange;
1529 memory[read()] = &valueRangeMemory[rangeIndex];
// Every other result kind is stored as its opaque pointer.
1531 memory[read()] = result.getAsOpaquePointer();
// Move range storage allocated by the native function into the executor's lists.
1536 for (
auto &it : results.getAllocatedTypeRanges())
1537 allocatedTypeRangeMemory.push_back(std::move(it));
1538 for (
auto &it : results.getAllocatedValueRanges())
1539 allocatedValueRangeMemory.push_back(std::move(it));
// Executes the AreEqual opcode: reads two pointers and jumps on pointer
// equality (selectJump(true) takes successor 0).
1542 void ByteCodeExecutor::executeAreEqual() {
1543 LLVM_DEBUG(llvm::dbgs() <<
"Executing AreEqual:\n");
1544 const void *lhs = read<const void *>();
1545 const void *rhs = read<const void *>();
1547 LLVM_DEBUG(llvm::dbgs() <<
" * " << lhs <<
" == " << rhs <<
"\n");
1548 selectJump(lhs == rhs);
// Executes the AreRangesEqual opcode: reads two range pointers and compares
// the ranges they point to (not the pointers), dispatching on the range kind.
// NOTE(review): the line that reads `valueKind` is elided in this listing.
1551 void ByteCodeExecutor::executeAreRangesEqual() {
1552 LLVM_DEBUG(llvm::dbgs() <<
"Executing AreRangesEqual:\n");
1554 const void *lhs = read<const void *>();
1555 const void *rhs = read<const void *>();
1557 switch (valueKind) {
// NOTE(review): the TypeRange casts of lhs/rhs are elided here; presumably
// they mirror the ValueRange casts below.
1558 case PDLValue::Kind::TypeRange: {
1561 LLVM_DEBUG(llvm::dbgs() <<
" * " << lhs <<
" == " << rhs <<
"\n\n");
1562 selectJump(*lhsRange == *rhsRange);
1565 case PDLValue::Kind::ValueRange: {
1566 const auto *lhsRange =
reinterpret_cast<const ValueRange *
>(lhs);
1567 const auto *rhsRange =
reinterpret_cast<const ValueRange *
>(rhs);
1568 LLVM_DEBUG(llvm::dbgs() <<
" * " << lhs <<
" == " << rhs <<
"\n\n");
1569 selectJump(*lhsRange == *rhsRange);
// Any other kind is a bytecode generation error.
1573 llvm_unreachable(
"unexpected `AreRangesEqual` value kind");
// Executes the Branch opcode: unconditional jump to the address read from
// the stream.
1577 void ByteCodeExecutor::executeBranch() {
1578 LLVM_DEBUG(llvm::dbgs() <<
"Executing Branch\n");
1579 curCodeIt = &code[read<ByteCodeAddr>()];
// Executes the CheckOperandCount opcode: compares an operation's operand
// count against an expected count, either exactly or as a lower bound
// depending on the `compareAtLeast` flag (see the debug comparator text).
// NOTE(review): the read of `op` and the final comparison/jump are elided in
// this listing.
1582 void ByteCodeExecutor::executeCheckOperandCount() {
1583 LLVM_DEBUG(llvm::dbgs() <<
"Executing CheckOperandCount:\n");
1585 uint32_t expectedCount = read<uint32_t>();
1586 bool compareAtLeast = read();
1588 LLVM_DEBUG(llvm::dbgs() <<
" * Found: " << op->
getNumOperands() <<
"\n"
1589 <<
" * Expected: " << expectedCount <<
"\n"
1590 <<
" * Comparator: "
1591 << (compareAtLeast ?
">=" :
"==") <<
"\n");
// Executes the CheckOperationName opcode: jumps on whether the operation's
// name equals the expected name.
// NOTE(review): the reads of `op` and `expectedName` are elided in this listing.
1598 void ByteCodeExecutor::executeCheckOperationName() {
1599 LLVM_DEBUG(llvm::dbgs() <<
"Executing CheckOperationName:\n");
1603 LLVM_DEBUG(llvm::dbgs() <<
" * Found: \"" << op->
getName() <<
"\"\n"
1604 <<
" * Expected: \"" << expectedName <<
"\"\n");
1605 selectJump(op->
getName() == expectedName);
// Executes the CheckResultCount opcode: result-count counterpart of
// CheckOperandCount, with the same `compareAtLeast` flag.
// NOTE(review): the read of `op` and the final comparison/jump are elided in
// this listing.
1608 void ByteCodeExecutor::executeCheckResultCount() {
1609 LLVM_DEBUG(llvm::dbgs() <<
"Executing CheckResultCount:\n");
1611 uint32_t expectedCount = read<uint32_t>();
1612 bool compareAtLeast = read();
1614 LLVM_DEBUG(llvm::dbgs() <<
" * Found: " << op->
getNumResults() <<
"\n"
1615 <<
" * Expected: " << expectedCount <<
"\n"
1616 <<
" * Comparator: "
1617 << (compareAtLeast ?
">=" :
"==") <<
"\n");
// Executes the CheckTypes opcode: jumps on whether the range `lhs` points to
// equals the types held by the ArrayAttr `rhs` (viewed as TypeAttr values).
// NOTE(review): the reads of `lhs` and `rhs` are elided in this listing.
1624 void ByteCodeExecutor::executeCheckTypes() {
// The banner names this opcode; it previously said "AreEqual", which made
// debug traces attribute the step to the wrong instruction.
1625 LLVM_DEBUG(llvm::dbgs() <<
"Executing CheckTypes:\n");
1628 LLVM_DEBUG(llvm::dbgs() <<
" * " << lhs <<
" == " << rhs <<
"\n\n");
1630 selectJump(*lhs == cast<ArrayAttr>(rhs).getAsValueRange<TypeAttr>());
// Executes the Continue opcode: reads a loop level from the stream.
// NOTE(review): the statements that act on `level` are elided in this listing.
1633 void ByteCodeExecutor::executeContinue() {
1634 ByteCodeField level = read();
1635 LLVM_DEBUG(llvm::dbgs() <<
"Executing Continue\n"
1636 <<
" * Level: " << level <<
"\n");
// Executes the CreateConstantTypeRange opcode: materialises the types held
// by a constant ArrayAttr as a range via assignRangeToMemory.
// Operands are read in this order: memory slot, range slot, types attribute.
1641 void ByteCodeExecutor::executeCreateConstantTypeRange() {
1642 LLVM_DEBUG(llvm::dbgs() <<
"Executing CreateConstantTypeRange:\n");
1643 unsigned memIndex = read();
1644 unsigned rangeIndex = read();
1645 ArrayAttr typesAttr = cast<ArrayAttr>(read<Attribute>());
1647 LLVM_DEBUG(llvm::dbgs() <<
" * Types: " << typesAttr <<
"\n\n");
1648 assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
// Executes the CreateOperation opcode: builds an operation state from
// operands, attributes and result types decoded from the stream, and stores
// the created operation in memory slot `memIndex`.
// NOTE(review): the declarations of `state`, `attr` and `resultOp`, and the
// condition selecting the inference path, are elided in this listing.
1652 void ByteCodeExecutor::executeCreateOperation(
PatternRewriter &rewriter,
1654 LLVM_DEBUG(llvm::dbgs() <<
"Executing CreateOperation:\n");
1656 unsigned memIndex = read();
1658 readList(state.operands);
// Attributes are a count followed by entries, each starting with a name.
1659 for (
unsigned i = 0, e = read(); i != e; ++i) {
1660 StringAttr name = read<StringAttr>();
1662 state.addAttribute(name, attr);
1667 unsigned numResults = read();
// Inference path: result types come from the op's InferTypeOpInterface,
// which is asserted to exist.
1669 InferTypeOpInterface::Concept *inferInterface =
1670 state.name.getInterface<InferTypeOpInterface>();
1671 assert(inferInterface &&
1672 "expected operation to provide InferTypeOpInterface");
1675 if (failed(inferInterface->inferReturnTypes(
1676 state.getContext(), state.location, state.operands,
1677 state.attributes.getDictionary(state.getContext()),
1678 state.getRawProperties(), state.regions, state.types)))
// Explicit path: each result is either a single Type or a TypeRange whose
// elements are appended, so one entry may add several result types.
1682 for (
unsigned i = 0; i != numResults; ++i) {
1683 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1684 state.types.push_back(read<Type>());
1686 TypeRange *resultTypes = read<TypeRange *>();
1687 state.types.append(resultTypes->begin(), resultTypes->end());
1693 memory[memIndex] = resultOp;
1696 llvm::dbgs() <<
" * Attributes: "
1697 << state.attributes.getDictionary(state.getContext())
1698 <<
"\n * Operands: ";
1699 llvm::interleaveComma(state.operands, llvm::dbgs());
1700 llvm::dbgs() <<
"\n * Result Types: ";
1701 llvm::interleaveComma(state.types, llvm::dbgs());
1702 llvm::dbgs() <<
"\n * Result: " << *resultOp <<
"\n";
// Shared implementation of CreateDynamicTypeRange / CreateDynamicValueRange.
// `type` is used only to build the debug text ("Type" or "Value" at the call
// sites in the dispatcher).
// NOTE(review): the declaration and population of `values` are elided in
// this listing.
1706 template <
typename T>
1707 void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
1708 LLVM_DEBUG(llvm::dbgs() <<
"Executing CreateDynamic" << type <<
"Range:\n");
1709 unsigned memIndex = read();
1710 unsigned rangeIndex = read();
1715 llvm::dbgs() <<
"\n * " << type <<
"s: ";
1716 llvm::interleaveComma(values, llvm::dbgs());
1717 llvm::dbgs() <<
"\n";
1720 assignRangeToMemory(values, memIndex, rangeIndex);
// Body of the EraseOp handler.
// NOTE(review): the signature, the read of `op` and the erase call itself
// are elided in this listing; only the debug output is visible.
1724 LLVM_DEBUG(llvm::dbgs() <<
"Executing EraseOp:\n");
1727 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n");
// Shared implementation of the Extract* opcodes: reads a range pointer, an
// element index and a destination memory slot, then stores the element at
// that index. `kind` appears only in the debug text in the visible lines.
1731 template <
typename T,
typename Range, PDLValue::Kind kind>
1732 void ByteCodeExecutor::executeExtract() {
1733 LLVM_DEBUG(llvm::dbgs() <<
"Executing Extract" << kind <<
":\n");
1734 Range *range = read<Range *>();
1735 unsigned index = read<uint32_t>();
1736 unsigned memIndex = read();
// NOTE(review): the guard around this null store is elided; presumably it
// handles a null `range` - confirm.
1739 memory[memIndex] =
nullptr;
// An out-of-bounds index yields a default-constructed T rather than an error.
1743 T result = index < range->
size() ? (*range)[index] : T();
1744 LLVM_DEBUG(llvm::dbgs() <<
" * " << kind <<
"s(" << range->
size() <<
")\n"
1745 <<
" * Index: " << index <<
"\n"
1746 <<
" * Result: " << result <<
"\n");
1747 storeToMemory(memIndex, result);
// Executes the Finalize opcode; the visible body only emits a debug message.
1750 void ByteCodeExecutor::executeFinalize() {
1751 LLVM_DEBUG(llvm::dbgs() <<
"Executing Finalize\n");
// Executes the ForEach opcode: picks the next element of a range using the
// per-loop counter in `loopIndex`, stores it in memory slot `memIndex`, and
// pushes the address of this instruction via pushCodeIt.
// NOTE(review): the declaration of `array` is elided in this listing.
1754 void ByteCodeExecutor::executeForEach() {
1755 LLVM_DEBUG(llvm::dbgs() <<
"Executing ForEach:\n");
// Captured before any operand is read; passed to pushCodeIt below.
1756 const ByteCodeField *prevCodeIt = getPrevCodeIt();
1757 unsigned rangeIndex = read();
1758 unsigned memIndex = read();
1759 const void *value =
nullptr;
1761 switch (read<PDLValue::Kind>()) {
1762 case PDLValue::Kind::Operation: {
// `index` is a reference into loopIndex.
1763 unsigned &index = loopIndex[read()];
1765 assert(index <= array.size() &&
"iterated past the end");
1766 if (index < array.size()) {
1767 LLVM_DEBUG(llvm::dbgs() <<
" * Result: " << array[index] <<
"\n");
1768 value = array[index];
// Range exhausted: take successor 0.
1772 LLVM_DEBUG(llvm::dbgs() <<
" * Done\n");
1774 selectJump(
size_t(0));
1778 llvm_unreachable(
"unexpected `ForEach` value kind");
1782 memory[memIndex] = value;
1783 pushCodeIt(prevCodeIt);
// The address operand is read and its value discarded.
1786 read<ByteCodeAddr>();
// Executes the GetAttribute opcode: reads a destination slot and an
// attribute name, and looks the attribute up on an operation.
// NOTE(review): the read of `op`, the lookup producing `attr` and the store
// to memory are elided in this listing.
1789 void ByteCodeExecutor::executeGetAttribute() {
1790 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetAttribute:\n");
1791 unsigned memIndex = read();
1793 StringAttr attrName = read<StringAttr>();
1796 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n"
1797 <<
" * Attribute: " << attrName <<
"\n"
1798 <<
" * Result: " << attr <<
"\n");
// Executes the GetAttributeType opcode: `type` is set from the attribute
// only when the attribute is a TypedAttr.
// NOTE(review): the declarations of `attr` and `type` and the store to
// memory are elided in this listing.
1802 void ByteCodeExecutor::executeGetAttributeType() {
1803 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetAttributeType:\n");
1804 unsigned memIndex = read();
1807 if (
auto typedAttr = dyn_cast<TypedAttr>(attr))
1808 type = typedAttr.getType();
1810 LLVM_DEBUG(llvm::dbgs() <<
" * Attribute: " << attr <<
"\n"
1811 <<
" * Result: " << type <<
"\n");
// Executes the GetDefiningOp opcode: stores the defining operation of a
// value, or of the first value of a range, in memory slot `memIndex`.
// NOTE(review): the declarations of `op` and `values` are elided in this listing.
1815 void ByteCodeExecutor::executeGetDefiningOp() {
1816 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetDefiningOp:\n");
1817 unsigned memIndex = read();
// The operand is prefixed with its kind: a single Value or a value range.
1819 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1820 Value value = read<Value>();
1823 LLVM_DEBUG(llvm::dbgs() <<
" * Value: " << value <<
"\n");
// Range case: only a non-null, non-empty range contributes, via its front.
1826 if (values && !values->empty()) {
1827 op = values->front().getDefiningOp();
1829 LLVM_DEBUG(llvm::dbgs() <<
" * Values: " << values <<
"\n");
1832 LLVM_DEBUG(llvm::dbgs() <<
" * Result: " << op <<
"\n");
1833 memory[memIndex] = op;
// Executes a GetOperand opcode for operand position `index`, which is
// supplied by the dispatcher.
// NOTE(review): the read of `op`, the computation of `operand` and the store
// to memory are elided in this listing.
1836 void ByteCodeExecutor::executeGetOperand(
unsigned index) {
1838 unsigned memIndex = read();
1842 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n"
1843 <<
" * Index: " << index <<
"\n"
1844 <<
" * Result: " << operand <<
"\n");
// Shared implementation behind GetOperands and GetResults. Narrows `values`
// to the group selected by `index` and returns either a pointer to a range
// slot or a single value's opaque pointer.
// NOTE(review): several conditions and the declarations of `segmentAttr` and
// `segments` are elided in this listing.
1851 template <
template <
typename>
class AttrSizedSegmentsT,
typename RangeT>
1854 ByteCodeField rangeIndex, StringRef attrSizedSegments,
// Case 1: the whole range is requested (condition elided).
1859 LLVM_DEBUG(llvm::dbgs() <<
" * Getting all values\n");
// Case 2: the op carries the AttrSizedSegments trait, so group boundaries
// come from the segment-sizes attribute named by `attrSizedSegments`.
1863 }
else if (op->
hasTrait<AttrSizedSegmentsT>()) {
1864 LLVM_DEBUG(llvm::dbgs()
1865 <<
" * Extracting values from `" << attrSizedSegments <<
"`\n");
// A missing attribute, or one with no entry for `index`, is rejected.
1868 if (!segmentAttr || segmentAttr.asArrayRef().size() <= index)
// The group starts at the sum of the preceding segment sizes and has the
// length recorded for segment `index`.
1872 unsigned startIndex =
1873 std::accumulate(segments.begin(), segments.begin() + index, 0);
1874 values = values.slice(startIndex, *std::next(segments.begin(), index));
1876 LLVM_DEBUG(llvm::dbgs() <<
" * Extracting range[" << startIndex <<
", "
1877 << *std::next(segments.begin(), index) <<
"]\n");
// Case 3: no segment information; everything from `index` onward is treated
// as one trailing variadic group.
1883 }
else if (values.size() >= index) {
1884 LLVM_DEBUG(llvm::dbgs()
1885 <<
" * Treating values as trailing variadic range\n");
1886 values = values.drop_front(index);
// Range result: stored in the range slot and returned by address.
1895 valueRangeMemory[rangeIndex] = values;
1896 return &valueRangeMemory[rangeIndex];
// Single-value result: null unless exactly one value remains.
1900 return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
// Executes the GetOperands opcode by delegating to executeGetOperandsResults
// with the operand list and the "operandSegmentSizes" attribute name.
// NOTE(review): the read of `op`, the trailing call arguments and the null
// check guarding the debug message are elided in this listing.
1903 void ByteCodeExecutor::executeGetOperands() {
1904 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetOperands:\n");
1905 unsigned index = read<uint32_t>();
1907 ByteCodeField rangeIndex = read();
1909 void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1910 op->
getOperands(), op, index, rangeIndex,
"operandSegmentSizes",
1913 LLVM_DEBUG(llvm::dbgs() <<
" * Invalid operand range\n");
// The destination slot is read last, after the result is computed.
1914 memory[read()] = result;
// Executes a GetResult opcode for result position `index`, which is supplied
// by the dispatcher.
// NOTE(review): the read of `op`, the computation of `result` and the store
// to memory are elided in this listing.
1917 void ByteCodeExecutor::executeGetResult(
unsigned index) {
1919 unsigned memIndex = read();
1923 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n"
1924 <<
" * Index: " << index <<
"\n"
1925 <<
" * Result: " << result <<
"\n");
// Executes the GetResults opcode by delegating to executeGetOperandsResults
// with the result list and the "resultSegmentSizes" attribute name.
// NOTE(review): the read of `op`, the trailing call arguments and the null
// check guarding the debug message are elided in this listing.
1929 void ByteCodeExecutor::executeGetResults() {
1930 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetResults:\n");
1931 unsigned index = read<uint32_t>();
1933 ByteCodeField rangeIndex = read();
1935 void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1936 op->
getResults(), op, index, rangeIndex,
"resultSegmentSizes",
1939 LLVM_DEBUG(llvm::dbgs() <<
" * Invalid result range\n");
// The destination slot is read last, after the result is computed.
1940 memory[read()] = result;
// Executes the GetUsers opcode: fills the op-range slot `rangeIndex` with
// the users of a value (or of every value in a range) and makes memory slot
// `memIndex` refer to that slot.
// NOTE(review): the declarations of `values` and `users` and the copy of the
// users into `range` are elided in this listing.
1943 void ByteCodeExecutor::executeGetUsers() {
1944 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetUsers:\n");
1945 unsigned memIndex = read();
1946 unsigned rangeIndex = read();
1947 OwningOpRange &range = opRangeMemory[rangeIndex];
// NOTE(review): the glyph on the next line looks like the text `&range;`
// decoded as an HTML entity by the extraction - confirm against the real file.
1948 memory[memIndex] = ⦥
// Start from an empty range.
1950 range = OwningOpRange();
1951 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1953 Value value = read<Value>();
1956 LLVM_DEBUG(llvm::dbgs() <<
" * Value: " << value <<
"\n");
1967 llvm::dbgs() <<
" * Values (" << values->size() <<
"): ";
1968 llvm::interleaveComma(*values, llvm::dbgs());
1969 llvm::dbgs() <<
"\n";
1974 for (
Value value : *values)
// The result range is sized to the number of collected users.
1976 range = OwningOpRange(users.size());
1980 LLVM_DEBUG(llvm::dbgs() <<
" * Result: " << range.size() <<
" operations\n");
// Executes the GetValueType opcode: reads a destination slot and a Value.
// NOTE(review): the computation of `type` and the store to memory are elided
// in this listing.
1983 void ByteCodeExecutor::executeGetValueType() {
1984 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetValueType:\n");
1985 unsigned memIndex = read();
1986 Value value = read<Value>();
1989 LLVM_DEBUG(llvm::dbgs() <<
" * Value: " << value <<
"\n"
1990 <<
" * Result: " << type <<
"\n");
// Executes the GetValueRangeTypes opcode: stores the types of a value range
// in type-range slot `rangeIndex` and points memory slot `memIndex` at it.
// NOTE(review): the read of `values` and the null check are elided in this
// listing.
1994 void ByteCodeExecutor::executeGetValueRangeTypes() {
1995 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetValueRangeTypes:\n");
1996 unsigned memIndex = read();
1997 unsigned rangeIndex = read();
// Null input range: the destination slot is set to null.
2000 LLVM_DEBUG(llvm::dbgs() <<
" * Values: <NULL>\n\n");
2001 memory[memIndex] =
nullptr;
2006 llvm::dbgs() <<
" * Values (" << values->size() <<
"): ";
2007 llvm::interleaveComma(*values, llvm::dbgs());
2008 llvm::dbgs() <<
"\n * Result: ";
2009 llvm::interleaveComma(values->
getType(), llvm::dbgs());
2010 llvm::dbgs() <<
"\n";
2012 typeRangeMemory[rangeIndex] = values->
getType();
2013 memory[memIndex] = &typeRangeMemory[rangeIndex];
// Executes the IsNotNull opcode: jumps on whether the pointer read from
// memory is non-null (selectJump(true) takes successor 0).
2016 void ByteCodeExecutor::executeIsNotNull() {
2017 LLVM_DEBUG(llvm::dbgs() <<
"Executing IsNotNull:\n");
2018 const void *value = read<const void *>();
2020 LLVM_DEBUG(llvm::dbgs() <<
" * Value: " << value <<
"\n");
2021 selectJump(value !=
nullptr);
// Executes the RecordMatch opcode: appends a MatchResult for pattern
// `patternIndex` to `matches`, capturing the values the rewriter will need.
// NOTE(review): the parameter list, the benefit lookup, the declarations of
// `matchLocs`/`matchLoc`/`match` and the use of `dest` are elided in this
// listing.
2024 void ByteCodeExecutor::executeRecordMatch(
2027 LLVM_DEBUG(llvm::dbgs() <<
"Executing RecordMatch:\n");
2028 unsigned patternIndex = read();
2030 const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
// Branch for a pattern whose benefit is "impossible to match" (the guarding
// condition is elided).
2035 LLVM_DEBUG(llvm::dbgs() <<
" * Benefit: Impossible To Match\n");
// Collect the locations of the matched operations.
2044 unsigned numMatchLocs = read();
2046 matchLocs.reserve(numMatchLocs);
2047 for (
unsigned i = 0; i != numMatchLocs; ++i)
2048 matchLocs.push_back(read<Operation *>()->getLoc());
2051 LLVM_DEBUG(llvm::dbgs() <<
" * Benefit: " << benefit.
getBenefit() <<
"\n"
2052 <<
" * Location: " << matchLoc <<
"\n");
2053 matches.emplace_back(matchLoc,
patterns[patternIndex], benefit);
// Capture the rewriter inputs. Ranges are copied into the match, and
// `match.values` stores pointers to those copies; the reserve calls happen
// before any pointer is taken, presumably so the vectors never reallocate
// under those pointers - confirm.
2059 unsigned numInputs = read();
2060 match.values.reserve(numInputs);
2061 match.typeRangeValues.reserve(numInputs);
2062 match.valueRangeValues.reserve(numInputs);
2063 for (
unsigned i = 0; i < numInputs; ++i) {
2064 switch (read<PDLValue::Kind>()) {
2065 case PDLValue::Kind::TypeRange:
2066 match.typeRangeValues.push_back(*read<TypeRange *>());
2067 match.values.push_back(&match.typeRangeValues.back());
2069 case PDLValue::Kind::ValueRange:
2070 match.valueRangeValues.push_back(*read<ValueRange *>());
2071 match.values.push_back(&match.valueRangeValues.back());
// All remaining kinds are stored as the raw pointer.
2074 match.values.push_back(read<const void *>());
// Body of the ReplaceOp handler.
// NOTE(review): the signature, the reads of `op` and `args` and the
// replacement call are elided in this listing; only debug output is visible.
2082 LLVM_DEBUG(llvm::dbgs() <<
"Executing ReplaceOp:\n");
2088 llvm::dbgs() <<
" * Operation: " << *op <<
"\n"
2090 llvm::interleaveComma(args, llvm::dbgs());
2091 llvm::dbgs() <<
"\n";
// Executes the SwitchAttribute opcode: switches on an attribute against the
// elements of an ArrayAttr of cases via handleSwitch.
// NOTE(review): the read of `value` is elided in this listing.
2096 void ByteCodeExecutor::executeSwitchAttribute() {
2097 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchAttribute:\n");
2099 ArrayAttr cases = read<ArrayAttr>();
2100 handleSwitch(value, cases);
// Executes the SwitchOperandCount opcode: the cases are the uint32_t values
// of a DenseIntOrFPElementsAttr.
// NOTE(review): the read of `op` and the handleSwitch call are elided in
// this listing.
2103 void ByteCodeExecutor::executeSwitchOperandCount() {
2104 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchOperandCount:\n");
2106 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2108 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n");
// Executes the SwitchOperationName opcode. Unlike the other switches, the
// case names are stored inline in the stream (a count followed by that many
// OperationName entries), so they are read one at a time here instead of
// going through handleSwitch.
// NOTE(review): the read of `value` is elided in this listing.
2112 void ByteCodeExecutor::executeSwitchOperationName() {
2113 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchOperationName:\n");
2115 size_t caseCount = read();
// Debug dump: reading the names moves the cursor, so it is saved first and
// restored afterwards.
2121 const ByteCodeField *prevCodeIt = curCodeIt;
2122 llvm::dbgs() <<
" * Value: " << value <<
"\n"
2124 llvm::interleaveComma(
2125 llvm::map_range(llvm::seq<size_t>(0, caseCount),
2126 [&](
size_t) {
return read<OperationName>(); }),
2128 llvm::dbgs() <<
"\n";
2129 curCodeIt = prevCodeIt;
// On a match at case i, skip the remaining (caseCount - i - 1) case entries
// before jumping to successor i + 1.
2133 for (
size_t i = 0; i != caseCount; ++i) {
2134 if (read<OperationName>() == value) {
2135 curCodeIt += (caseCount - i - 1);
2136 return selectJump(i + 1);
// No case matched: successor 0.
2139 selectJump(
size_t(0));
// Executes the SwitchResultCount opcode: the cases are the uint32_t values
// of a DenseIntOrFPElementsAttr.
// NOTE(review): the read of `op` and the handleSwitch call are elided in
// this listing.
2142 void ByteCodeExecutor::executeSwitchResultCount() {
2143 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchResultCount:\n");
2145 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2147 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n");
// Executes the SwitchType opcode: switches on a Type against the TypeAttr
// values of an ArrayAttr of cases via handleSwitch.
2151 void ByteCodeExecutor::executeSwitchType() {
2152 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchType:\n");
2153 Type value = read<Type>();
2154 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2155 handleSwitch(value, cases);
// Executes the SwitchTypes opcode: switches on a type range; each case is an
// ArrayAttr whose TypeAttr values are compared against the range.
// NOTE(review): the read of `value` and the null check are elided in this
// listing.
2158 void ByteCodeExecutor::executeSwitchTypes() {
2159 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchTypes:\n");
2161 auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
// Null range: successor 0.
2163 LLVM_DEBUG(llvm::dbgs() <<
"Types: <NULL>\n");
2164 return selectJump(
size_t(0));
// handleSwitch invokes cmp(case, value), hence the (ArrayAttr, TypeRange)
// parameter order of this comparator.
2166 handleSwitch(*value, cases, [](ArrayAttr caseValue,
const TypeRange &value) {
2167 return value == caseValue.getAsValueRange<TypeAttr>();
// Main interpreter loop: reads one opcode at a time and dispatches to the
// matching execute* handler.
// NOTE(review): the start of the signature, the loop construct, most `case`
// labels and all `break`/`return` statements are elided in this listing.
2174 std::optional<Location> mainRewriteLoc) {
// Each instruction is preceded by an inline Location, printed in debug builds.
2177 LLVM_DEBUG(llvm::dbgs() << readInline<Location>() <<
"\n");
2179 OpCode opCode =
static_cast<OpCode
>(read());
2181 case ApplyConstraint:
2182 executeApplyConstraint(rewriter);
// ApplyRewrite is the only handler whose failure is checked in the visible
// lines.
2185 if (failed(executeApplyRewrite(rewriter)))
2191 case AreRangesEqual:
2192 executeAreRangesEqual();
2197 case CheckOperandCount:
2198 executeCheckOperandCount();
2200 case CheckOperationName:
2201 executeCheckOperationName();
2203 case CheckResultCount:
2204 executeCheckResultCount();
2207 executeCheckTypes();
2212 case CreateConstantTypeRange:
2213 executeCreateConstantTypeRange();
2215 case CreateOperation:
// Dereferences the optional rewrite location.
2216 executeCreateOperation(rewriter, *mainRewriteLoc);
2218 case CreateDynamicTypeRange:
2219 executeDynamicCreateRange<Type>(
"Type");
2221 case CreateDynamicValueRange:
2222 executeDynamicCreateRange<Value>(
"Value");
2225 executeEraseOp(rewriter);
2228 executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
2231 executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2234 executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2238 LLVM_DEBUG(llvm::dbgs() <<
"\n");
2244 executeGetAttribute();
2246 case GetAttributeType:
2247 executeGetAttributeType();
2250 executeGetDefiningOp();
// Fixed-index GetOperand opcodes encode the index in the opcode itself, as
// an offset from GetOperand0.
2256 unsigned index = opCode - GetOperand0;
2257 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetOperand" << index <<
":\n");
2258 executeGetOperand(index);
// GetOperandN reads the index from the stream instead.
2262 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetOperandN:\n");
2263 executeGetOperand(read<uint32_t>());
2266 executeGetOperands();
// Same encoding scheme for results, as an offset from GetResult0.
2272 unsigned index = opCode - GetResult0;
2273 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetResult" << index <<
":\n");
2274 executeGetResult(index);
2278 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetResultN:\n");
2279 executeGetResult(read<uint32_t>());
2282 executeGetResults();
2288 executeGetValueType();
2290 case GetValueRangeTypes:
2291 executeGetValueRangeTypes();
// RecordMatch needs the `matches` output list; the assert message below
// shows it must be provided when running the matcher.
2298 "expected matches to be provided when executing the matcher");
2299 executeRecordMatch(rewriter, *matches);
2302 executeReplaceOp(rewriter);
2304 case SwitchAttribute:
2305 executeSwitchAttribute();
2307 case SwitchOperandCount:
2308 executeSwitchOperandCount();
2310 case SwitchOperationName:
2311 executeSwitchOperationName();
2313 case SwitchResultCount:
2314 executeSwitchResultCount();
2317 executeSwitchType();
2320 executeSwitchTypes();
2323 LLVM_DEBUG(llvm::dbgs() <<
"\n");
// Body of the matcher entry point: runs the matcher bytecode on `op` and
// leaves `matches` sorted by descending benefit.
// NOTE(review): the signature is elided in this listing.
// The root operation is placed in memory slot 0.
2331 state.memory[0] = op;
// The executor starts at the beginning of the matcher bytecode.
2334 ByteCodeExecutor executor(
2335 matcherByteCode.data(), state.memory, state.opRangeMemory,
2336 state.typeRangeMemory, state.allocatedTypeRangeMemory,
2337 state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2338 uniquedData, matcherByteCode, state.currentPatternBenefits,
patterns,
2339 constraintFunctions, rewriteFunctions);
2340 LogicalResult executeResult = executor.execute(rewriter, &matches);
// The void cast keeps `executeResult` referenced when assert is compiled out.
2341 (void)executeResult;
2342 assert(succeeded(executeResult) &&
"unexpected matcher execution failure");
// stable_sort keeps the recorded order among matches of equal benefit.
2345 std::stable_sort(matches.begin(), matches.end(),
2346 [](
const MatchResult &lhs,
const MatchResult &rhs) {
2347 return lhs.benefit > rhs.benefit;
// Body of the rewrite entry point: runs the rewriter bytecode for `match`,
// bracketed by the pattern's config-set begin/end notifications.
// NOTE(review): the signature, the null checks on `configSet`, the copy of
// match values into memory and the failure condition are elided in this
// listing.
2352 const MatchResult &match,
2354 auto *configSet =
match.pattern->getConfigSet();
2356 configSet->notifyRewriteBegin(rewriter);
// Execution starts at this pattern's rewriter address within rewriterByteCode.
2362 ByteCodeExecutor executor(
2363 &rewriterByteCode[
match.pattern->getRewriterAddr()], state.memory,
2364 state.opRangeMemory, state.typeRangeMemory,
2365 state.allocatedTypeRangeMemory, state.valueRangeMemory,
2366 state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2367 rewriterByteCode, state.currentPatternBenefits,
patterns,
2368 constraintFunctions, rewriteFunctions);
// No match list is passed; the match location is passed as the main
// rewrite location.
2369 LogicalResult result =
2370 executor.execute(rewriter,
nullptr,
match.location);
2373 configSet->notifyRewriteEnd(rewriter);
// Fatal-error path for a failed native rewrite under a rewriter that cannot
// roll back (per the message text).
2382 LLVM_DEBUG(llvm::dbgs() <<
" and rollback is not supported - aborting");
2383 llvm::report_fatal_error(
2384 "Native PDL Rewrite failed, but the pattern "
2385 "rewriter doesn't support recovery. Failable pattern rewrites should "
2386 "not be used with pattern rewriters that do not support them.");
static void * executeGetOperandsResults(RangeT values, Operation *op, unsigned index, ByteCodeField rangeIndex, StringRef attrSizedSegments, MutableArrayRef< ValueRange > valueRangeMemory)
This function is the internal implementation of GetResults and GetOperands that provides support for ...
static constexpr ByteCodeField kInferTypesMarker
A marker used to indicate if an operation should infer types.
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static const mlir::GenInfo * generator
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void processValue(Value value, LiveMap &liveMap)
Attributes are known-constant values of operations.
const void * getAsOpaquePointer() const
Get an opaque pointer to the attribute.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
BlockArgListType getArguments()
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
This class represents liveness information on block level.
const ValueSetT & in() const
Returns all values that are live at the beginning of the block (unordered).
Operation * getEndOperation(Value value, Operation *startOperation) const
Gets the end operation for the given value using the start operation provided (must be referenced in ...
Represents an analysis for computing liveness information from a given top-level operation.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This is a value defined by a result of an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
AttrClass getAttrOfType(StringAttr name)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
bool isImpossibleToMatch() const
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the corresponding pattern isImpossibleToMatch() then this aborts.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual bool canRecoverFromRewriteFailure() const
A hook used to indicate if the pattern rewriter can recover from failure during the rewrite stage of ...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class implements the result iterators for the Operation class.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class implements the successor iterators for Block.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
const void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
auto walk(WalkFns &&...walkFns)
Walk this type and all attributes/types nested within using the provided walk functions.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_iterator user_begin() const
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
user_iterator user_end() const
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit)
Set the new benefit for a bytecode pattern.
void cleanupAfterMatchAndRewrite()
Cleanup any allocated state after a full match/rewrite has been completed.
void match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl< MatchResult > &matches, PDLByteCodeMutableState &state) const
LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Kind
An enumeration of the kinds of predicates.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...