19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/Format.h"
24 #include "llvm/Support/FormatVariadic.h"
28 #define DEBUG_TYPE "pdl-bytecode"
45 if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
47 llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
50 if (std::optional<StringRef> rootKind = matchOp.getRootKind())
54 benefit, ctx, generatedOps);
66 currentPatternBenefits[patternIndex] = benefit;
73 allocatedTypeRangeMemory.clear();
74 allocatedValueRangeMemory.clear();
104 CreateConstantTypeRange,
108 CreateDynamicTypeRange,
110 CreateDynamicValueRange,
184 struct ByteCodeLiveRange;
185 struct ByteCodeWriter;
// Detection alias: names the type of `T::getAsOpaquePointer()`, so it is only
// well-formed for types that expose that member. It is used with
// `llvm::is_detected` to select overloads for opaque-pointer-convertible
// types. The `Args` pack is not referenced by the alias itself.
188 template <
typename T,
typename... Args>
189 using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
194 Generator(
MLIRContext *ctx, std::vector<const void *> &uniquedData,
203 llvm::StringMap<PDLConstraintFunction> &constraintFns,
204 llvm::StringMap<PDLRewriteFunction> &rewriteFns,
206 : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
207 rewriterByteCode(rewriterByteCode), patterns(patterns),
208 maxValueMemoryIndex(maxValueMemoryIndex),
209 maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
210 maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
211 maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
212 maxLoopLevel(maxLoopLevel), configMap(configMap) {
214 constraintToMemIndex.try_emplace(it.value().first(), it.index());
216 externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
220 void generate(ModuleOp module);
224 assert(valueToMemIndex.count(value) &&
225 "expected memory index to be assigned");
226 return valueToMemIndex[value];
231 assert(valueToRangeIndex.count(value) &&
232 "expected range index to be assigned");
233 return valueToRangeIndex[value];
238 template <
typename T>
239 std::enable_if_t<!std::is_convertible<T, Value>::value,
ByteCodeField &>
241 const void *opaqueVal = val.getAsOpaquePointer();
244 auto it = uniquedDataToMemIndex.try_emplace(
245 opaqueVal, maxValueMemoryIndex + uniquedData.size());
247 uniquedData.push_back(opaqueVal);
248 return it.first->second;
254 void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
255 ModuleOp rewriterModule);
258 void generate(
Region *region, ByteCodeWriter &writer);
259 void generate(
Operation *op, ByteCodeWriter &writer);
260 void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
261 void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
262 void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
263 void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
264 void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
265 void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
266 void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
267 void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
268 void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
269 void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
270 void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
271 void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
272 void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
273 void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer);
274 void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
275 void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
276 void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
277 void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
278 void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
279 void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
280 void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
281 void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
282 void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
283 void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
284 void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
285 void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
286 void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
287 void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
288 void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
289 void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
290 void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
291 void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
292 void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
293 void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
294 void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
295 void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
296 void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
297 void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
307 llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
311 llvm::StringMap<ByteCodeField> constraintToMemIndex;
315 llvm::StringMap<ByteCodeAddr> rewriterToAddr;
331 std::vector<const void *> &uniquedData;
346 struct ByteCodeWriter {
// Append a single raw field to the end of `bytecode`.
351 void append(
ByteCodeField field) { bytecode.push_back(field); }
// Append an opcode to the end of `bytecode`.
352 void append(OpCode opCode) { bytecode.push_back(opCode); }
357 "unexpected ByteCode address size");
361 bytecode.append({fieldParts[0], fieldParts[1]});
366 void append(
Block *successor) {
369 unresolvedSuccessorRefs[successor].push_back(bytecode.size());
376 for (
Block *successor : successors)
382 bytecode.push_back(values.size());
383 for (
Value value : values)
384 appendPDLValue(value);
388 void appendPDLValue(
Value value) {
389 appendPDLValueKind(value);
// Convenience overload: append the PDL value kind for `value` by forwarding
// its type to the `Type` overload of appendPDLValueKind.
394 void appendPDLValueKind(
Value value) { appendPDLValueKind(value.
getType()); }
397 void appendPDLValueKind(
Type type) {
400 .Case<pdl::AttributeType>(
402 .Case<pdl::OperationType>(
404 .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
405 if (rangeTy.getElementType().isa<pdl::TypeType>())
416 template <
typename T>
417 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
418 std::is_pointer<T>::value>
420 bytecode.push_back(
generator.getMemIndex(value));
424 template <
typename T,
typename IteratorT = llvm::detail::IterOfRange<T>>
425 std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
427 bytecode.push_back(llvm::size(range));
428 for (
auto it : range)
433 template <
typename FieldTy,
typename Field2Ty,
typename... FieldTys>
434 void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
436 append(field2, fields...);
440 template <
typename T>
441 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
442 appendInline(T value) {
443 constexpr
size_t numParts =
sizeof(
const void *) /
sizeof(
ByteCodeField);
444 const void *pointer = value.getAsOpaquePointer();
446 std::memcpy(fieldParts, &pointer,
sizeof(
const void *));
447 bytecode.append(fieldParts, fieldParts + numParts);
462 struct ByteCodeLiveRange {
463 using Set = llvm::IntervalMap<uint64_t, char, 16>;
464 using Allocator = Set::Allocator;
466 ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
469 void unionWith(
const ByteCodeLiveRange &rhs) {
470 for (
auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
472 liveness->insert(it.start(), it.stop(), 0);
476 bool overlaps(
const ByteCodeLiveRange &rhs)
const {
477 return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
487 std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
490 std::optional<unsigned> opRangeIndex;
493 std::optional<unsigned> typeRangeIndex;
496 std::optional<unsigned> valueRangeIndex;
500 void Generator::generate(ModuleOp module) {
501 auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
502 pdl_interp::PDLInterpDialect::getMatcherFunctionName());
503 ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
504 pdl_interp::PDLInterpDialect::getRewriterModuleName());
505 assert(matcherFunc && rewriterModule &&
"invalid PDL Interpreter module");
509 allocateMemoryIndices(matcherFunc, rewriterModule);
512 ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *
this);
513 for (
auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
514 rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
515 for (
Operation &op : rewriterFunc.getOps())
516 generate(&op, rewriterByteCodeWriter);
518 assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
519 "unexpected branches in rewriter function");
522 ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *
this);
523 generate(&matcherFunc.getBody(), matcherByteCodeWriter);
526 for (
auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
528 for (
unsigned offsetToFix : it.second)
529 std::memcpy(&matcherByteCode[offsetToFix], &addr,
sizeof(
ByteCodeAddr));
533 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
534 ModuleOp rewriterModule) {
537 for (
auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
538 ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
539 auto processRewriterValue = [&](
Value val) {
540 valueToMemIndex.try_emplace(val, index++);
541 if (pdl::RangeType rangeType = val.getType().dyn_cast<pdl::RangeType>()) {
542 Type elementTy = rangeType.getElementType();
543 if (elementTy.
isa<pdl::TypeType>())
544 valueToRangeIndex.try_emplace(val, typeRangeIndex++);
545 else if (elementTy.
isa<pdl::ValueType>())
546 valueToRangeIndex.try_emplace(val, valueRangeIndex++);
551 processRewriterValue(arg);
554 processRewriterValue(result);
556 if (index > maxValueMemoryIndex)
557 maxValueMemoryIndex = index;
558 if (typeRangeIndex > maxTypeRangeMemoryIndex)
559 maxTypeRangeMemoryIndex = typeRangeIndex;
560 if (valueRangeIndex > maxValueRangeMemoryIndex)
561 maxValueRangeMemoryIndex = valueRangeIndex;
578 opToFirstIndex.try_emplace(op, index++);
580 for (
Block &block : region.getBlocks())
583 opToLastIndex.try_emplace(op, index++);
588 ByteCodeLiveRange::Allocator allocator;
593 valueToMemIndex[rootOpArg] = 0;
596 Liveness matcherLiveness(matcherFunc);
597 matcherFunc->walk([&](
Block *block) {
599 assert(info &&
"expected liveness info for block");
603 if (value == rootOpArg)
607 auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
608 defRangeIt->second.liveness->insert(
609 opToFirstIndex[firstUseOrDef],
615 Type eleType = rangeTy.getElementType();
616 if (eleType.
isa<pdl::OperationType>())
617 defRangeIt->second.opRangeIndex = 0;
618 else if (eleType.
isa<pdl::TypeType>())
619 defRangeIt->second.typeRangeIndex = 0;
620 else if (eleType.
isa<pdl::ValueType>())
621 defRangeIt->second.valueRangeIndex = 0;
626 for (
Value liveIn : info->
in()) {
631 if (liveIn.getParentRegion() == block->
getParent())
648 std::vector<ByteCodeLiveRange> allocatedIndices;
655 ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
657 for (
auto &defIt : valueDefRanges) {
659 ByteCodeLiveRange &defRange = defIt.second;
662 for (
const auto &existingIndexIt :
llvm::enumerate(allocatedIndices)) {
663 ByteCodeLiveRange &existingRange = existingIndexIt.value();
664 if (!defRange.overlaps(existingRange)) {
665 existingRange.unionWith(defRange);
666 memIndex = existingIndexIt.index() + 1;
668 if (defRange.opRangeIndex) {
669 if (!existingRange.opRangeIndex)
670 existingRange.opRangeIndex = numOpRanges++;
671 valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
672 }
else if (defRange.typeRangeIndex) {
673 if (!existingRange.typeRangeIndex)
674 existingRange.typeRangeIndex = numTypeRanges++;
675 valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
676 }
else if (defRange.valueRangeIndex) {
677 if (!existingRange.valueRangeIndex)
678 existingRange.valueRangeIndex = numValueRanges++;
679 valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
687 allocatedIndices.emplace_back(allocator);
688 ByteCodeLiveRange &newRange = allocatedIndices.back();
689 newRange.unionWith(defRange);
692 if (defRange.opRangeIndex) {
693 newRange.opRangeIndex = numOpRanges;
694 valueToRangeIndex[defIt.first] = numOpRanges++;
695 }
else if (defRange.typeRangeIndex) {
696 newRange.typeRangeIndex = numTypeRanges;
697 valueToRangeIndex[defIt.first] = numTypeRanges++;
698 }
else if (defRange.valueRangeIndex) {
699 newRange.valueRangeIndex = numValueRanges;
700 valueToRangeIndex[defIt.first] = numValueRanges++;
703 memIndex = allocatedIndices.size();
710 llvm::dbgs() <<
"Allocated " << allocatedIndices.size() <<
" indices "
711 <<
"(down from initial " << valueDefRanges.size() <<
").\n";
714 "Ran out of memory for allocated indices");
717 if (numIndices > maxValueMemoryIndex)
718 maxValueMemoryIndex = numIndices;
719 if (numOpRanges > maxOpRangeMemoryIndex)
720 maxOpRangeMemoryIndex = numOpRanges;
721 if (numTypeRanges > maxTypeRangeMemoryIndex)
722 maxTypeRangeMemoryIndex = numTypeRanges;
723 if (numValueRanges > maxValueRangeMemoryIndex)
724 maxValueRangeMemoryIndex = numValueRanges;
727 void Generator::generate(
Region *region, ByteCodeWriter &writer) {
728 llvm::ReversePostOrderTraversal<Region *> rpot(region);
729 for (
Block *block : rpot) {
731 blockToAddr.try_emplace(block, matcherByteCode.size());
733 generate(&op, writer);
737 void Generator::generate(
Operation *op, ByteCodeWriter &writer) {
741 if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
742 writer.appendInline(op->
getLoc());
745 .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
746 pdl_interp::AreEqualOp, pdl_interp::BranchOp,
747 pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
748 pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
749 pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
750 pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
751 pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp,
752 pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
753 pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
754 pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
755 pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
756 pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
757 pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
758 pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
759 pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
760 pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
761 pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
762 pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
763 pdl_interp::SwitchResultCountOp>(
764 [&](
auto interpOp) { this->generate(interpOp, writer); })
766 llvm_unreachable(
"unknown `pdl_interp` operation");
// Emit bytecode for `pdl_interp.apply_constraint`.
// Layout written to `writer`, in order:
//   ApplyConstraint opcode, index registered for the constraint name,
//   the argument list (via appendPDLValueList), then the op's successors.
770 void Generator::generate(pdl_interp::ApplyConstraintOp op,
771 ByteCodeWriter &writer) {
// The constraint name must already have an entry in `constraintToMemIndex`.
772 assert(constraintToMemIndex.count(op.getName()) &&
773 "expected index for constraint function");
774 writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
775 writer.appendPDLValueList(op.getArgs());
776 writer.append(op.getSuccessors());
778 void Generator::generate(pdl_interp::ApplyRewriteOp op,
779 ByteCodeWriter &writer) {
780 assert(externalRewriterToMemIndex.count(op.getName()) &&
781 "expected index for rewrite function");
782 writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
783 writer.appendPDLValueList(op.getArgs());
787 for (
Value result : results) {
791 writer.appendPDLValueKind(result);
795 if (result.getType().isa<pdl::RangeType>())
796 writer.append(getRangeStorageIndex(result));
797 writer.append(result);
800 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
801 Value lhs = op.getLhs();
803 writer.append(OpCode::AreRangesEqual);
804 writer.appendPDLValueKind(lhs);
805 writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
809 writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
811 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
814 void Generator::generate(pdl_interp::CheckAttributeOp op,
815 ByteCodeWriter &writer) {
816 writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
819 void Generator::generate(pdl_interp::CheckOperandCountOp op,
820 ByteCodeWriter &writer) {
821 writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
825 void Generator::generate(pdl_interp::CheckOperationNameOp op,
826 ByteCodeWriter &writer) {
827 writer.append(OpCode::CheckOperationName, op.getInputOp(),
830 void Generator::generate(pdl_interp::CheckResultCountOp op,
831 ByteCodeWriter &writer) {
832 writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
836 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
837 writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
840 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
841 writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
// Emit bytecode for `pdl_interp.continue`: the Continue opcode followed by
// `curLoopLevel - 1` encoded as a ByteCodeField.
844 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
// A continue is only valid while `curLoopLevel` is non-zero.
845 assert(curLoopLevel > 0 &&
"encountered pdl_interp.continue at top level");
846 writer.append(OpCode::Continue,
ByteCodeField(curLoopLevel - 1));
// Handle `pdl_interp.create_attribute`. The visible body writes nothing to
// `writer`; it only makes the result's memory index equal to the memory index
// obtained for the op's constant value.
848 void Generator::generate(pdl_interp::CreateAttributeOp op,
849 ByteCodeWriter &writer) {
// Re-point the result at the slot of the constant value.
851 getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
853 void Generator::generate(pdl_interp::CreateOperationOp op,
854 ByteCodeWriter &writer) {
855 writer.append(OpCode::CreateOperation, op.getResultOp(),
857 writer.appendPDLValueList(op.getInputOperands());
861 writer.append(
static_cast<ByteCodeField>(attributes.size()));
862 for (
auto it : llvm::zip(op.getInputAttributeNames(), attributes))
863 writer.append(std::get<0>(it), std::get<1>(it));
867 if (op.getInferredResultTypes())
870 writer.appendPDLValueList(op.getInputResultTypes());
872 void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
876 [&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); })
877 .Case([&](pdl::ValueType) {
878 writer.append(OpCode::CreateDynamicValueRange);
881 writer.append(op.getResult(), getRangeStorageIndex(op.getResult()));
882 writer.appendPDLValueList(op->getOperands());
// Handle `pdl_interp.create_type`. The visible body writes nothing to
// `writer`; it only makes the result's memory index equal to the memory index
// obtained for the op's constant value.
884 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
// Re-point the result at the slot of the constant value.
886 getMemIndex(op.getResult()) = getMemIndex(op.getValue());
// Emit bytecode for `pdl_interp.create_types`, in order:
//   CreateConstantTypeRange opcode, the result value, the range storage index
//   assigned to the result, and the op's constant value.
888 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
889 writer.append(OpCode::CreateConstantTypeRange, op.getResult(),
890 getRangeStorageIndex(op.getResult()), op.getValue());
// Emit bytecode for `pdl_interp.erase`: the EraseOp opcode followed by the
// op's input operation.
892 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
893 writer.append(OpCode::EraseOp, op.getInputOp());
895 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
898 .Case([](pdl::OperationType) {
return OpCode::ExtractOp; })
899 .Case([](pdl::ValueType) {
return OpCode::ExtractValue; })
900 .Case([](pdl::TypeType) {
return OpCode::ExtractType; })
901 .Default([](
Type) -> OpCode {
902 llvm_unreachable(
"unsupported element type");
904 writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
// Emit bytecode for `pdl_interp.finalize`: just the Finalize opcode, with no
// operands.
906 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
907 writer.append(OpCode::Finalize);
909 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
911 writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
912 writer.appendPDLValueKind(arg.
getType());
913 writer.append(curLoopLevel, op.getSuccessor());
915 if (curLoopLevel > maxLoopLevel)
916 maxLoopLevel = curLoopLevel;
917 generate(&op.getRegion(), writer);
920 void Generator::generate(pdl_interp::GetAttributeOp op,
921 ByteCodeWriter &writer) {
922 writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
// Emit bytecode for `pdl_interp.get_attribute_type`, in order:
//   GetAttributeType opcode, the result value, then the op's input value.
925 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
926 ByteCodeWriter &writer) {
927 writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
// Emit bytecode for `pdl_interp.get_defining_op`, in order:
//   GetDefiningOp opcode, the op's `getInputOp()` value, then the op's
//   `getValue()` operand written through appendPDLValue (which records the
//   value's PDL kind as well).
929 void Generator::generate(pdl_interp::GetDefiningOpOp op,
930 ByteCodeWriter &writer) {
931 writer.append(OpCode::GetDefiningOp, op.getInputOp());
932 writer.appendPDLValue(op.getValue());
934 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
935 uint32_t index = op.getIndex();
937 writer.append(
static_cast<OpCode
>(OpCode::GetOperand0 + index));
939 writer.append(OpCode::GetOperandN, index);
940 writer.append(op.getInputOp(), op.getValue());
942 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
943 Value result = op.getValue();
944 std::optional<uint32_t> index = op.getIndex();
945 writer.append(OpCode::GetOperands,
949 writer.append(getRangeStorageIndex(result));
952 writer.append(result);
954 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
955 uint32_t index = op.getIndex();
957 writer.append(
static_cast<OpCode
>(OpCode::GetResult0 + index));
959 writer.append(OpCode::GetResultN, index);
960 writer.append(op.getInputOp(), op.getValue());
962 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
963 Value result = op.getValue();
964 std::optional<uint32_t> index = op.getIndex();
965 writer.append(OpCode::GetResults,
969 writer.append(getRangeStorageIndex(result));
972 writer.append(result);
974 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
975 Value operations = op.getOperations();
977 writer.append(OpCode::GetUsers, operations, rangeIndex);
978 writer.appendPDLValue(op.getValue());
980 void Generator::generate(pdl_interp::GetValueTypeOp op,
981 ByteCodeWriter &writer) {
982 if (op.getType().isa<pdl::RangeType>()) {
983 Value result = op.getResult();
984 writer.append(OpCode::GetValueRangeTypes, result,
985 getRangeStorageIndex(result), op.getValue());
987 writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
// Emit bytecode for `pdl_interp.is_not_null`, in order:
//   IsNotNull opcode, the value being tested, then the op's successors.
990 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
991 writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
993 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
996 op, configMap.lookup(op),
997 rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
998 writer.append(OpCode::RecordMatch, patternIndex,
1000 writer.appendPDLValueList(op.getInputs());
// Emit bytecode for `pdl_interp.replace`, in order:
//   ReplaceOp opcode, the operation being replaced, then the replacement
//   values written as a PDL value list.
1002 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
1003 writer.append(OpCode::ReplaceOp, op.getInputOp());
1004 writer.appendPDLValueList(op.getReplValues());
// Emit bytecode for `pdl_interp.switch_attribute`, in order:
//   SwitchAttribute opcode, the attribute being switched on, the case-values
//   attribute, then the op's successors.
1006 void Generator::generate(pdl_interp::SwitchAttributeOp op,
1007 ByteCodeWriter &writer) {
1008 writer.append(OpCode::SwitchAttribute, op.getAttribute(),
1009 op.getCaseValuesAttr(), op.getSuccessors());
// Emit bytecode for `pdl_interp.switch_operand_count`, in order:
//   SwitchOperandCount opcode, the operation being inspected, the case-values
//   attribute, then the op's successors.
1011 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
1012 ByteCodeWriter &writer) {
1013 writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
1014 op.getCaseValuesAttr(), op.getSuccessors());
1016 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
1017 ByteCodeWriter &writer) {
1018 auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](
Attribute attr) {
1019 return OperationName(attr.cast<StringAttr>().getValue(), ctx);
1021 writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
1022 op.getSuccessors());
// Emit bytecode for `pdl_interp.switch_result_count`, in order:
//   SwitchResultCount opcode, the operation being inspected, the case-values
//   attribute, then the op's successors.
1024 void Generator::generate(pdl_interp::SwitchResultCountOp op,
1025 ByteCodeWriter &writer) {
1026 writer.append(OpCode::SwitchResultCount, op.getInputOp(),
1027 op.getCaseValuesAttr(), op.getSuccessors());
// Emit bytecode for `pdl_interp.switch_type`, in order:
//   SwitchType opcode, the value being switched on, the case-values attribute,
//   then the op's successors.
1029 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1030 writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1031 op.getSuccessors());
// Emit bytecode for `pdl_interp.switch_types`, in order:
//   SwitchTypes opcode, the value being switched on, the case-values
//   attribute, then the op's successors.
1033 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1034 writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
1035 op.getSuccessors());
1043 ModuleOp module,
SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
1045 llvm::StringMap<PDLConstraintFunction> constraintFns,
1046 llvm::StringMap<PDLRewriteFunction> rewriteFns)
1047 : configs(std::move(configs)) {
1048 Generator
generator(module.getContext(), uniquedData, matcherByteCode,
1049 rewriterByteCode, patterns, maxValueMemoryIndex,
1050 maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1051 maxLoopLevel, constraintFns, rewriteFns, configMap);
1055 for (
auto &it : constraintFns)
1056 constraintFunctions.push_back(std::move(it.second));
1057 for (
auto &it : rewriteFns)
1058 rewriteFunctions.push_back(std::move(it.second));
1064 state.memory.resize(maxValueMemoryIndex,
nullptr);
1065 state.opRangeMemory.resize(maxOpRangeCount);
1066 state.typeRangeMemory.resize(maxTypeRangeCount,
TypeRange());
1067 state.valueRangeMemory.resize(maxValueRangeCount,
ValueRange());
1068 state.loopIndex.resize(maxLoopLevel, 0);
1069 state.currentPatternBenefits.reserve(patterns.size());
1071 state.currentPatternBenefits.push_back(pattern.getBenefit());
1079 class ByteCodeExecutor {
1085 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
1087 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
1094 : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
1095 typeRangeMemory(typeRangeMemory),
1096 allocatedTypeRangeMemory(allocatedTypeRangeMemory),
1097 valueRangeMemory(valueRangeMemory),
1098 allocatedValueRangeMemory(allocatedValueRangeMemory),
1099 loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
1100 currentPatternBenefits(currentPatternBenefits), patterns(patterns),
1101 constraintFunctions(constraintFunctions),
1102 rewriteFunctions(rewriteFunctions) {}
1110 std::optional<Location> mainRewriteLoc = {});
1116 void executeAreEqual();
1117 void executeAreRangesEqual();
1118 void executeBranch();
1119 void executeCheckOperandCount();
1120 void executeCheckOperationName();
1121 void executeCheckResultCount();
1122 void executeCheckTypes();
1123 void executeContinue();
1124 void executeCreateConstantTypeRange();
1127 template <
typename T>
1128 void executeDynamicCreateRange(StringRef type);
1130 template <
typename T,
typename Range, PDLValue::Kind kind>
1131 void executeExtract();
1132 void executeFinalize();
1133 void executeForEach();
1134 void executeGetAttribute();
1135 void executeGetAttributeType();
1136 void executeGetDefiningOp();
1137 void executeGetOperand(
unsigned index);
1138 void executeGetOperands();
1139 void executeGetResult(
unsigned index);
1140 void executeGetResults();
1141 void executeGetUsers();
1142 void executeGetValueType();
1143 void executeGetValueRangeTypes();
1144 void executeIsNotNull();
1148 void executeSwitchAttribute();
1149 void executeSwitchOperandCount();
1150 void executeSwitchOperationName();
1151 void executeSwitchResultCount();
1152 void executeSwitchType();
1153 void executeSwitchTypes();
// Push a bytecode position onto the `resumeCodeIt` stack.
1156 void pushCodeIt(
const ByteCodeField *it) { resumeCodeIt.push_back(it); }
1160 assert(!resumeCodeIt.empty() &&
"attempt to pop code off empty stack");
1161 curCodeIt = resumeCodeIt.back();
1162 resumeCodeIt.pop_back();
1169 return curCodeIt - 1 -
sizeof(
const void *) /
sizeof(
ByteCodeField);
1173 return curCodeIt - 1;
1179 template <
typename T = ByteCodeField>
1180 T read(
size_t skipN = 0) {
1182 return readImpl<T>();
// Non-template convenience overload: forwards to `read<ByteCodeField>(skipN)`
// so callers can read a raw field without spelling the template argument.
1184 ByteCodeField read(
size_t skipN = 0) {
return read<ByteCodeField>(skipN); }
1187 template <
typename ValueT,
typename T>
1190 for (
unsigned i = 0, e = read(); i != e; ++i)
1191 list.push_back(read<ValueT>());
1197 for (
unsigned i = 0, e = read(); i != e; ++i) {
1199 list.push_back(read<Type>());
1201 TypeRange *values = read<TypeRange *>();
1202 list.append(values->begin(), values->end());
1207 for (
unsigned i = 0, e = read(); i != e; ++i) {
1209 list.push_back(read<Value>());
1212 list.append(values->begin(), values->end());
1218 template <
typename T>
1219 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
1221 const void *pointer;
1222 std::memcpy(&pointer, curCodeIt,
sizeof(
const void *));
1224 return T::getFromOpaquePointer(pointer);
// Jump based on a predicate: destination index 0 when `isTrue` is true,
// destination index 1 otherwise. Delegates to the `size_t` overload.
1228 void selectJump(
bool isTrue) { selectJump(
size_t(isTrue ? 0 : 1)); }
// Jump to the destination selected by `destIndex`: reads a ByteCodeAddr after
// skipping `destIndex * 2` fields and repositions `curCodeIt` at that offset
// within `code`.
// NOTE(review): the `* 2` presumably reflects an address occupying two
// ByteCodeFields — confirm against the ByteCodeAddr writer.
1230 void selectJump(
size_t destIndex) {
1231 curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
1235 template <
typename T,
typename RangeT,
typename Comparator = std::equal_to<T>>
1236 void handleSwitch(
const T &value, RangeT &&cases, Comparator cmp = {}) {
1238 llvm::dbgs() <<
" * Value: " << value <<
"\n"
1240 llvm::interleaveComma(cases, llvm::dbgs());
1241 llvm::dbgs() <<
"\n";
1246 for (
auto it = cases.begin(), e = cases.end(); it != e; ++it)
1247 if (cmp(*it, value))
1248 return selectJump(
size_t((it - cases.begin()) + 1));
1249 selectJump(
size_t(0));
// Store an opaque pointer into slot `index` of `memory`. No bounds check is
// performed in the visible code.
1253 void storeToMemory(
unsigned index,
const void *value) {
1254 memory[index] = value;
// Overload for types exposing `getAsOpaquePointer()` (selected through
// `has_pointer_traits`): stores the value's opaque pointer into slot `index`
// of `memory`.
1258 template <
typename T>
1259 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
1260 storeToMemory(
unsigned index, T value) {
1261 memory[index] = value.getAsOpaquePointer();
1266 template <
typename T>
1267 const void *readFromMemory() {
1268 size_t index = *curCodeIt++;
1273 index < memory.size())
1274 return memory[index];
1277 return uniquedMemory[index - memory.size()];
// readImpl overload enabled when `T` is a raw pointer type: fetches the opaque
// `const void *` via readFromMemory<T>(), strips constness, and reinterprets
// it as `T`.
1279 template <
typename T>
1280 std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
1281 return reinterpret_cast<T
>(
const_cast<void *
>(readFromMemory<T>()));
1283 template <
typename T>
1284 std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
1287 return T(T::getFromOpaquePointer(readFromMemory<T>()));
1289 template <
typename T>
1290 std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
1291 switch (read<PDLValue::Kind>()) {
1293 return read<Attribute>();
1295 return read<Operation *>();
1297 return read<Type>();
1299 return read<Value>();
1301 return read<TypeRange *>();
1303 return read<ValueRange *>();
1305 llvm_unreachable(
"unhandled PDLValue::Kind");
1307 template <
typename T>
1308 std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
1310 "unexpected ByteCode address size");
// readImpl overload enabled when `T` is exactly ByteCodeField: returns the
// field at `curCodeIt` and advances the iterator by one.
1316 template <
typename T>
1317 std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
1318 return *curCodeIt++;
1320 template <
typename T>
1321 std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
1327 template <
typename RangeT,
typename T = llvm::detail::ValueOfRange<RangeT>>
1328 void assignRangeToMemory(RangeT &&range,
unsigned memIndex,
1329 unsigned rangeIndex) {
1331 auto assignRange = [&](
auto &allocatedRangeMemory,
auto &rangeMemory) {
1333 if (range.empty()) {
1334 rangeMemory[rangeIndex] = {};
1337 llvm::OwningArrayRef<T> storage(llvm::size(range));
1342 allocatedRangeMemory.emplace_back(std::move(storage));
1343 rangeMemory[rangeIndex] = allocatedRangeMemory.back();
1345 memory[memIndex] = &rangeMemory[rangeIndex];
1349 if constexpr (std::is_same_v<T, Type>) {
1350 return assignRange(allocatedTypeRangeMemory, typeRangeMemory);
1351 }
else if constexpr (std::is_same_v<T, Value>) {
1352 return assignRange(allocatedValueRangeMemory, valueRangeMemory);
1354 llvm_unreachable(
"unhandled range type");
1368 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
1370 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
1389 ByteCodeRewriteResultList(
unsigned maxNumResults)
1397 return allocatedTypeRanges;
1402 return allocatedValueRanges;
1407 void ByteCodeExecutor::executeApplyConstraint(
PatternRewriter &rewriter) {
1408 LLVM_DEBUG(llvm::dbgs() <<
"Executing ApplyConstraint:\n");
1411 readList<PDLValue>(args);
1414 llvm::dbgs() <<
" * Arguments: ";
1415 llvm::interleaveComma(args, llvm::dbgs());
1419 selectJump(
succeeded(constraintFn(rewriter, args)));
1423 LLVM_DEBUG(llvm::dbgs() <<
"Executing ApplyRewrite:\n");
1426 readList<PDLValue>(args);
1429 llvm::dbgs() <<
" * Arguments: ";
1430 llvm::interleaveComma(args, llvm::dbgs());
1435 ByteCodeRewriteResultList results(numResults);
1436 LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
1438 assert(results.getResults().size() == numResults &&
1439 "native PDL rewrite function returned unexpected number of results");
1442 for (
PDLValue &result : results.getResults()) {
1443 LLVM_DEBUG(llvm::dbgs() <<
" * Result: " << result <<
"\n");
1447 assert(result.getKind() == read<PDLValue::Kind>() &&
1448 "native PDL rewrite function returned an unexpected type of result");
1453 if (std::optional<TypeRange> typeRange = result.dyn_cast<
TypeRange>()) {
1454 unsigned rangeIndex = read();
1455 typeRangeMemory[rangeIndex] = *typeRange;
1456 memory[read()] = &typeRangeMemory[rangeIndex];
1457 }
else if (std::optional<ValueRange> valueRange =
1459 unsigned rangeIndex = read();
1460 valueRangeMemory[rangeIndex] = *valueRange;
1461 memory[read()] = &valueRangeMemory[rangeIndex];
1463 memory[read()] = result.getAsOpaquePointer();
1468 for (
auto &it : results.getAllocatedTypeRanges())
1469 allocatedTypeRangeMemory.push_back(std::move(it));
1470 for (
auto &it : results.getAllocatedValueRanges())
1471 allocatedValueRangeMemory.push_back(std::move(it));
1474 if (
failed(rewriteResult)) {
1475 LLVM_DEBUG(llvm::dbgs() <<
" - Failed");
// Handler for the AreEqual opcode.
// Reads two operands as opaque pointers and forwards the result of comparing
// them to selectJump.
1481 void ByteCodeExecutor::executeAreEqual() {
1482 LLVM_DEBUG(llvm::dbgs() <<
"Executing AreEqual:\n");
// Operands are consumed in stream order: lhs first, then rhs.
1483 const void *lhs = read<const void *>();
1484 const void *rhs = read<const void *>();
// The trace prints the raw pointer values that are about to be compared.
1486 LLVM_DEBUG(llvm::dbgs() <<
" * " << lhs <<
" == " << rhs <<
"\n");
// Equality here is pointer identity on the opaque values.
1487 selectJump(lhs == rhs);
1490 void ByteCodeExecutor::executeAreRangesEqual() {
1491 LLVM_DEBUG(llvm::dbgs() <<
"Executing AreRangesEqual:\n");
1493 const void *lhs = read<const void *>();
1494 const void *rhs = read<const void *>();
1496 switch (valueKind) {
1500 LLVM_DEBUG(llvm::dbgs() <<
" * " << lhs <<
" == " << rhs <<
"\n\n");
1501 selectJump(*lhsRange == *rhsRange);
1505 const auto *lhsRange =
reinterpret_cast<const ValueRange *
>(lhs);
1506 const auto *rhsRange =
reinterpret_cast<const ValueRange *
>(rhs);
1507 LLVM_DEBUG(llvm::dbgs() <<
" * " << lhs <<
" == " << rhs <<
"\n\n");
1508 selectJump(*lhsRange == *rhsRange);
1512 llvm_unreachable(
"unexpected `AreRangesEqual` value kind");
// Handler for the Branch opcode: an unconditional jump.
1516 void ByteCodeExecutor::executeBranch() {
1517 LLVM_DEBUG(llvm::dbgs() <<
"Executing Branch\n");
// The ByteCodeAddr operand indexes into `code`; curCodeIt is repositioned
// to point at that element.
1518 curCodeIt = &code[read<ByteCodeAddr>()];
1521 void ByteCodeExecutor::executeCheckOperandCount() {
1522 LLVM_DEBUG(llvm::dbgs() <<
"Executing CheckOperandCount:\n");
1524 uint32_t expectedCount = read<uint32_t>();
1525 bool compareAtLeast = read();
1527 LLVM_DEBUG(llvm::dbgs() <<
" * Found: " << op->
getNumOperands() <<
"\n"
1528 <<
" * Expected: " << expectedCount <<
"\n"
1529 <<
" * Comparator: "
1530 << (compareAtLeast ?
">=" :
"==") <<
"\n");
1537 void ByteCodeExecutor::executeCheckOperationName() {
1538 LLVM_DEBUG(llvm::dbgs() <<
"Executing CheckOperationName:\n");
1542 LLVM_DEBUG(llvm::dbgs() <<
" * Found: \"" << op->
getName() <<
"\"\n"
1543 <<
" * Expected: \"" << expectedName <<
"\"\n");
1544 selectJump(op->
getName() == expectedName);
1547 void ByteCodeExecutor::executeCheckResultCount() {
1548 LLVM_DEBUG(llvm::dbgs() <<
"Executing CheckResultCount:\n");
1550 uint32_t expectedCount = read<uint32_t>();
1551 bool compareAtLeast = read();
1553 LLVM_DEBUG(llvm::dbgs() <<
" * Found: " << op->
getNumResults() <<
"\n"
1554 <<
" * Expected: " << expectedCount <<
"\n"
1555 <<
" * Comparator: "
1556 << (compareAtLeast ?
">=" :
"==") <<
"\n");
// Handler for the CheckTypes opcode.
// Compares the range behind `lhs` with the TypeAttr values of `rhs` (cast to
// ArrayAttr) and forwards the result to selectJump.
// NOTE(review): the declarations of `lhs` and `rhs` (original lines
// 1565-1566) are absent from this listing; confirm their types upstream.
1563 void ByteCodeExecutor::executeCheckTypes() {
// The banner names this handler's own opcode, matching the other handlers.
1564 LLVM_DEBUG(llvm::dbgs() <<
"Executing CheckTypes:\n");
1567 LLVM_DEBUG(llvm::dbgs() <<
" * " << lhs <<
" == " << rhs <<
"\n\n");
// `lhs` is dereferenced; `rhs` is viewed as a range of the types wrapped by
// its TypeAttr elements.
1569 selectJump(*lhs == rhs.
cast<ArrayAttr>().getAsValueRange<TypeAttr>());
1572 void ByteCodeExecutor::executeContinue() {
1574 LLVM_DEBUG(llvm::dbgs() <<
"Executing Continue\n"
1575 <<
" * Level: " << level <<
"\n");
1580 void ByteCodeExecutor::executeCreateConstantTypeRange() {
1581 LLVM_DEBUG(llvm::dbgs() <<
"Executing CreateConstantTypeRange:\n");
1582 unsigned memIndex = read();
1583 unsigned rangeIndex = read();
1584 ArrayAttr typesAttr = read<Attribute>().cast<ArrayAttr>();
1586 LLVM_DEBUG(llvm::dbgs() <<
" * Types: " << typesAttr <<
"\n\n");
1587 assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
1591 void ByteCodeExecutor::executeCreateOperation(
PatternRewriter &rewriter,
1593 LLVM_DEBUG(llvm::dbgs() <<
"Executing CreateOperation:\n");
1595 unsigned memIndex = read();
1597 readList(state.operands);
1598 for (
unsigned i = 0, e = read(); i != e; ++i) {
1599 StringAttr name = read<StringAttr>();
1601 state.addAttribute(name, attr);
1606 unsigned numResults = read();
1608 InferTypeOpInterface::Concept *inferInterface =
1609 state.name.getInterface<InferTypeOpInterface>();
1610 assert(inferInterface &&
1611 "expected operation to provide InferTypeOpInterface");
1614 if (
failed(inferInterface->inferReturnTypes(
1615 state.getContext(), state.location, state.operands,
1616 state.attributes.getDictionary(state.getContext()), state.regions,
1621 for (
unsigned i = 0; i != numResults; ++i) {
1623 state.types.push_back(read<Type>());
1625 TypeRange *resultTypes = read<TypeRange *>();
1626 state.types.append(resultTypes->begin(), resultTypes->end());
1632 memory[memIndex] = resultOp;
1635 llvm::dbgs() <<
" * Attributes: "
1636 << state.attributes.getDictionary(state.getContext())
1637 <<
"\n * Operands: ";
1638 llvm::interleaveComma(state.operands, llvm::dbgs());
1639 llvm::dbgs() <<
"\n * Result Types: ";
1640 llvm::interleaveComma(state.types, llvm::dbgs());
1641 llvm::dbgs() <<
"\n * Result: " << *resultOp <<
"\n";
1645 template <
typename T>
1646 void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
1647 LLVM_DEBUG(llvm::dbgs() <<
"Executing CreateDynamic" << type <<
"Range:\n");
1648 unsigned memIndex = read();
1649 unsigned rangeIndex = read();
1654 llvm::dbgs() <<
"\n * " << type <<
"s: ";
1655 llvm::interleaveComma(values, llvm::dbgs());
1656 llvm::dbgs() <<
"\n";
1659 assignRangeToMemory(values, memIndex, rangeIndex);
1663 LLVM_DEBUG(llvm::dbgs() <<
"Executing EraseOp:\n");
1666 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n");
1670 template <
typename T,
typename Range, PDLValue::Kind kind>
1671 void ByteCodeExecutor::executeExtract() {
1672 LLVM_DEBUG(llvm::dbgs() <<
"Executing Extract" << kind <<
":\n");
1673 Range *range = read<Range *>();
1674 unsigned index = read<uint32_t>();
1675 unsigned memIndex = read();
1678 memory[memIndex] =
nullptr;
1682 T result = index < range->
size() ? (*range)[index] : T();
1683 LLVM_DEBUG(llvm::dbgs() <<
" * " << kind <<
"s(" << range->
size() <<
")\n"
1684 <<
" * Index: " << index <<
"\n"
1685 <<
" * Result: " << result <<
"\n");
1686 storeToMemory(memIndex, result);
// Handler for the Finalize opcode.
// The body shown here only emits a debug trace line.
1689 void ByteCodeExecutor::executeFinalize() {
1690 LLVM_DEBUG(llvm::dbgs() <<
"Executing Finalize\n");
1693 void ByteCodeExecutor::executeForEach() {
1694 LLVM_DEBUG(llvm::dbgs() <<
"Executing ForEach:\n");
1696 unsigned rangeIndex = read();
1697 unsigned memIndex = read();
1698 const void *value =
nullptr;
1700 switch (read<PDLValue::Kind>()) {
1702 unsigned &index = loopIndex[read()];
1704 assert(index <= array.size() &&
"iterated past the end");
1705 if (index < array.size()) {
1706 LLVM_DEBUG(llvm::dbgs() <<
" * Result: " << array[index] <<
"\n");
1707 value = array[index];
1711 LLVM_DEBUG(llvm::dbgs() <<
" * Done\n");
1713 selectJump(
size_t(0));
1717 llvm_unreachable(
"unexpected `ForEach` value kind");
1721 memory[memIndex] = value;
1722 pushCodeIt(prevCodeIt);
1725 read<ByteCodeAddr>();
1728 void ByteCodeExecutor::executeGetAttribute() {
1729 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetAttribute:\n");
1730 unsigned memIndex = read();
1732 StringAttr attrName = read<StringAttr>();
1735 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n"
1736 <<
" * Attribute: " << attrName <<
"\n"
1737 <<
" * Result: " << attr <<
"\n");
1741 void ByteCodeExecutor::executeGetAttributeType() {
1742 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetAttributeType:\n");
1743 unsigned memIndex = read();
1746 if (
auto typedAttr = attr.
dyn_cast<TypedAttr>())
1747 type = typedAttr.getType();
1749 LLVM_DEBUG(llvm::dbgs() <<
" * Attribute: " << attr <<
"\n"
1750 <<
" * Result: " << type <<
"\n");
1754 void ByteCodeExecutor::executeGetDefiningOp() {
1755 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetDefiningOp:\n");
1756 unsigned memIndex = read();
1759 Value value = read<Value>();
1762 LLVM_DEBUG(llvm::dbgs() <<
" * Value: " << value <<
"\n");
1765 if (values && !values->empty()) {
1766 op = values->front().getDefiningOp();
1768 LLVM_DEBUG(llvm::dbgs() <<
" * Values: " << values <<
"\n");
1771 LLVM_DEBUG(llvm::dbgs() <<
" * Result: " << op <<
"\n");
1772 memory[memIndex] = op;
1775 void ByteCodeExecutor::executeGetOperand(
unsigned index) {
1777 unsigned memIndex = read();
1781 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n"
1782 <<
" * Index: " << index <<
"\n"
1783 <<
" * Result: " << operand <<
"\n");
1790 template <
template <
typename>
class AttrSizedSegmentsT,
typename RangeT>
1798 LLVM_DEBUG(llvm::dbgs() <<
" * Getting all values\n");
1802 }
else if (op->
hasTrait<AttrSizedSegmentsT>()) {
1803 LLVM_DEBUG(llvm::dbgs()
1804 <<
" * Extracting values from `" << attrSizedSegments <<
"`\n");
1807 if (!segmentAttr || segmentAttr.asArrayRef().size() <= index)
1811 unsigned startIndex =
1812 std::accumulate(segments.begin(), segments.begin() + index, 0);
1813 values = values.slice(startIndex, *std::next(segments.begin(), index));
1815 LLVM_DEBUG(llvm::dbgs() <<
" * Extracting range[" << startIndex <<
", "
1816 << *std::next(segments.begin(), index) <<
"]\n");
1822 }
else if (values.size() >= index) {
1823 LLVM_DEBUG(llvm::dbgs()
1824 <<
" * Treating values as trailing variadic range\n");
1825 values = values.drop_front(index);
1834 valueRangeMemory[rangeIndex] = values;
1835 return &valueRangeMemory[rangeIndex];
1839 return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1842 void ByteCodeExecutor::executeGetOperands() {
1843 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetOperands:\n");
1844 unsigned index = read<uint32_t>();
1848 void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1849 op->
getOperands(), op, index, rangeIndex,
"operand_segment_sizes",
1852 LLVM_DEBUG(llvm::dbgs() <<
" * Invalid operand range\n");
1853 memory[read()] = result;
1856 void ByteCodeExecutor::executeGetResult(
unsigned index) {
1858 unsigned memIndex = read();
1862 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n"
1863 <<
" * Index: " << index <<
"\n"
1864 <<
" * Result: " << result <<
"\n");
1868 void ByteCodeExecutor::executeGetResults() {
1869 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetResults:\n");
1870 unsigned index = read<uint32_t>();
1874 void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1875 op->
getResults(), op, index, rangeIndex,
"result_segment_sizes",
1878 LLVM_DEBUG(llvm::dbgs() <<
" * Invalid result range\n");
1879 memory[read()] = result;
// Handler for the GetUsers opcode.
// Reads a destination memory index and a range index, stores the address of
// `range` into memory[memIndex], and reports range.size() operations once the
// range has been populated.
// NOTE(review): several original lines (1886, 1888-1891, 1893-1894,
// 1896-1905, 1909-1912, 1914-1918) are absent from this listing. That includes
// the declaration of `range` and the code that fills it; presumably `range`
// is the op-range slot selected by rangeIndex - confirm against upstream.
1882 void ByteCodeExecutor::executeGetUsers() {
1883 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetUsers:\n");
1884 unsigned memIndex = read();
1885 unsigned rangeIndex = read();
// The memory slot holds a pointer to the range object, not a copy of it.
1887 memory[memIndex] = &range;
// Single-value form: the operand is read as one Value.
1892 Value value = read<Value>();
1895 LLVM_DEBUG(llvm::dbgs() <<
" * Value: " << value <<
"\n");
// Range form: `values` is dereferenced as a collection of Value.
1906 llvm::dbgs() <<
" * Values (" << values->size() <<
"): ";
1907 llvm::interleaveComma(*values, llvm::dbgs());
1908 llvm::dbgs() <<
"\n";
1913 for (
Value value : *values)
1919 LLVM_DEBUG(llvm::dbgs() <<
" * Result: " << range.size() <<
" operations\n");
1922 void ByteCodeExecutor::executeGetValueType() {
1923 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetValueType:\n");
1924 unsigned memIndex = read();
1925 Value value = read<Value>();
1928 LLVM_DEBUG(llvm::dbgs() <<
" * Value: " << value <<
"\n"
1929 <<
" * Result: " << type <<
"\n");
1933 void ByteCodeExecutor::executeGetValueRangeTypes() {
1934 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetValueRangeTypes:\n");
1935 unsigned memIndex = read();
1936 unsigned rangeIndex = read();
1939 LLVM_DEBUG(llvm::dbgs() <<
" * Values: <NULL>\n\n");
1940 memory[memIndex] =
nullptr;
1945 llvm::dbgs() <<
" * Values (" << values->size() <<
"): ";
1946 llvm::interleaveComma(*values, llvm::dbgs());
1947 llvm::dbgs() <<
"\n * Result: ";
1948 llvm::interleaveComma(values->
getType(), llvm::dbgs());
1949 llvm::dbgs() <<
"\n";
1951 typeRangeMemory[rangeIndex] = values->
getType();
1952 memory[memIndex] = &typeRangeMemory[rangeIndex];
// Handler for the IsNotNull opcode.
// Reads one operand as an opaque pointer and forwards whether it is non-null
// to selectJump.
1955 void ByteCodeExecutor::executeIsNotNull() {
1956 LLVM_DEBUG(llvm::dbgs() <<
"Executing IsNotNull:\n");
1957 const void *value = read<const void *>();
// The trace prints the raw pointer being tested.
1959 LLVM_DEBUG(llvm::dbgs() <<
" * Value: " << value <<
"\n");
1960 selectJump(value !=
nullptr);
1963 void ByteCodeExecutor::executeRecordMatch(
1966 LLVM_DEBUG(llvm::dbgs() <<
"Executing RecordMatch:\n");
1967 unsigned patternIndex = read();
1974 LLVM_DEBUG(llvm::dbgs() <<
" * Benefit: Impossible To Match\n");
1983 unsigned numMatchLocs = read();
1985 matchLocs.reserve(numMatchLocs);
1986 for (
unsigned i = 0; i != numMatchLocs; ++i)
1987 matchLocs.push_back(read<Operation *>()->getLoc());
1990 LLVM_DEBUG(llvm::dbgs() <<
" * Benefit: " << benefit.
getBenefit() <<
"\n"
1991 <<
" * Location: " << matchLoc <<
"\n");
1992 matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
1998 unsigned numInputs = read();
1999 match.
values.reserve(numInputs);
2002 for (
unsigned i = 0; i < numInputs; ++i) {
2003 switch (read<PDLValue::Kind>()) {
2013 match.
values.push_back(read<const void *>());
2021 LLVM_DEBUG(llvm::dbgs() <<
"Executing ReplaceOp:\n");
2027 llvm::dbgs() <<
" * Operation: " << *op <<
"\n"
2029 llvm::interleaveComma(args, llvm::dbgs());
2030 llvm::dbgs() <<
"\n";
2035 void ByteCodeExecutor::executeSwitchAttribute() {
2036 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchAttribute:\n");
2038 ArrayAttr cases = read<ArrayAttr>();
2039 handleSwitch(value, cases);
2042 void ByteCodeExecutor::executeSwitchOperandCount() {
2043 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchOperandCount:\n");
2045 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2047 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n");
2051 void ByteCodeExecutor::executeSwitchOperationName() {
2052 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchOperationName:\n");
2054 size_t caseCount = read();
2061 llvm::dbgs() <<
" * Value: " << value <<
"\n"
2063 llvm::interleaveComma(
2064 llvm::map_range(llvm::seq<size_t>(0, caseCount),
2065 [&](
size_t) {
return read<OperationName>(); }),
2067 llvm::dbgs() <<
"\n";
2068 curCodeIt = prevCodeIt;
2072 for (
size_t i = 0; i != caseCount; ++i) {
2073 if (read<OperationName>() == value) {
2074 curCodeIt += (caseCount - i - 1);
2075 return selectJump(i + 1);
2078 selectJump(
size_t(0));
2081 void ByteCodeExecutor::executeSwitchResultCount() {
2082 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchResultCount:\n");
2084 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2086 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n");
// Handler for the SwitchType opcode.
// Reads the Type to switch on, then the case list, and delegates the choice
// of successor to handleSwitch.
2090 void ByteCodeExecutor::executeSwitchType() {
2091 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchType:\n");
2092 Type value = read<Type>();
// The cases are stored as an ArrayAttr and viewed as the values wrapped by
// its TypeAttr elements.
2093 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2094 handleSwitch(value, cases);
2097 void ByteCodeExecutor::executeSwitchTypes() {
2098 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchTypes:\n");
2100 auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2102 LLVM_DEBUG(llvm::dbgs() <<
"Types: <NULL>\n");
2103 return selectJump(
size_t(0));
2105 handleSwitch(*value, cases, [](ArrayAttr caseValue,
const TypeRange &value) {
2106 return value == caseValue.getAsValueRange<TypeAttr>();
2113 std::optional<Location> mainRewriteLoc) {
2116 LLVM_DEBUG(llvm::dbgs() << readInline<Location>() <<
"\n");
2118 OpCode opCode =
static_cast<OpCode
>(read());
2120 case ApplyConstraint:
2121 executeApplyConstraint(rewriter);
2124 if (
failed(executeApplyRewrite(rewriter)))
2130 case AreRangesEqual:
2131 executeAreRangesEqual();
2136 case CheckOperandCount:
2137 executeCheckOperandCount();
2139 case CheckOperationName:
2140 executeCheckOperationName();
2142 case CheckResultCount:
2143 executeCheckResultCount();
2146 executeCheckTypes();
2151 case CreateConstantTypeRange:
2152 executeCreateConstantTypeRange();
2154 case CreateOperation:
2155 executeCreateOperation(rewriter, *mainRewriteLoc);
2157 case CreateDynamicTypeRange:
2158 executeDynamicCreateRange<Type>(
"Type");
2160 case CreateDynamicValueRange:
2161 executeDynamicCreateRange<Value>(
"Value");
2164 executeEraseOp(rewriter);
2167 executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
2170 executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2173 executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2177 LLVM_DEBUG(llvm::dbgs() <<
"\n");
2183 executeGetAttribute();
2185 case GetAttributeType:
2186 executeGetAttributeType();
2189 executeGetDefiningOp();
2195 unsigned index = opCode - GetOperand0;
2196 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetOperand" << index <<
":\n");
2197 executeGetOperand(index);
2201 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetOperandN:\n");
2202 executeGetOperand(read<uint32_t>());
2205 executeGetOperands();
2211 unsigned index = opCode - GetResult0;
2212 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetResult" << index <<
":\n");
2213 executeGetResult(index);
2217 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetResultN:\n");
2218 executeGetResult(read<uint32_t>());
2221 executeGetResults();
2227 executeGetValueType();
2229 case GetValueRangeTypes:
2230 executeGetValueRangeTypes();
2237 "expected matches to be provided when executing the matcher");
2238 executeRecordMatch(rewriter, *matches);
2241 executeReplaceOp(rewriter);
2243 case SwitchAttribute:
2244 executeSwitchAttribute();
2246 case SwitchOperandCount:
2247 executeSwitchOperandCount();
2249 case SwitchOperationName:
2250 executeSwitchOperationName();
2252 case SwitchResultCount:
2253 executeSwitchResultCount();
2256 executeSwitchType();
2259 executeSwitchTypes();
2262 LLVM_DEBUG(llvm::dbgs() <<
"\n");
2270 state.memory[0] = op;
2273 ByteCodeExecutor executor(
2274 matcherByteCode.data(), state.memory, state.opRangeMemory,
2275 state.typeRangeMemory, state.allocatedTypeRangeMemory,
2276 state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2277 uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
2278 constraintFunctions, rewriteFunctions);
2279 LogicalResult executeResult = executor.execute(rewriter, &matches);
2280 (void)executeResult;
2281 assert(
succeeded(executeResult) &&
"unexpected matcher execution failure");
2284 std::stable_sort(matches.begin(), matches.end(),
2286 return lhs.benefit > rhs.benefit;
2293 auto *configSet =
match.pattern->getConfigSet();
2295 configSet->notifyRewriteBegin(rewriter);
2301 ByteCodeExecutor executor(
2302 &rewriterByteCode[
match.pattern->getRewriterAddr()], state.memory,
2303 state.opRangeMemory, state.typeRangeMemory,
2304 state.allocatedTypeRangeMemory, state.valueRangeMemory,
2305 state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2306 rewriterByteCode, state.currentPatternBenefits, patterns,
2307 constraintFunctions, rewriteFunctions);
2309 executor.execute(rewriter,
nullptr,
match.location);
2312 configSet->notifyRewriteEnd(rewriter);
2321 LLVM_DEBUG(llvm::dbgs() <<
" and rollback is not supported - aborting");
2322 llvm::report_fatal_error(
2323 "Native PDL Rewrite failed, but the pattern "
2324 "rewriter doesn't support recovery. Failable pattern rewrites should "
2325 "not be used with pattern rewriters that do not support them.");
static void * executeGetOperandsResults(RangeT values, Operation *op, unsigned index, ByteCodeField rangeIndex, StringRef attrSizedSegments, MutableArrayRef< ValueRange > valueRangeMemory)
This function is the internal implementation of GetResults and GetOperands that provides support for ...
static constexpr ByteCodeField kInferTypesMarker
A marker used to indicate if an operation should infer types.
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static const mlir::GenInfo * generator
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void processValue(Value value, LiveMap &liveMap)
Attributes are known-constant values of operations.
const void * getAsOpaquePointer() const
Get an opaque pointer to the attribute.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
BlockArgListType getArguments()
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
This class represents liveness information on block level.
const ValueSetT & in() const
Returns all values that are live at the beginning of the block (unordered).
Operation * getEndOperation(Value value, Operation *startOperation) const
Gets the end operation for the given value using the start operation provided (must be referenced in ...
Represents an analysis for computing liveness information from a given top-level operation.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This is a value defined by a result of an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
AttrClass getAttrOfType(StringAttr name)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator over the underlying Values.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a set of configurations for a specific pattern.
The class represents a list of PDL results, returned by a native rewrite method.
Storage type of byte-code interpreter values.
Kind
The underlying kind of a PDL value.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
bool isImpossibleToMatch() const
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the corresponding pattern isImpossibleToMatch() then this aborts.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual bool canRecoverFromRewriteFailure() const
A hook used to indicate if the pattern rewriter can recover from failure during the rewrite stage of ...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class implements the result iterators for the Operation class.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class implements the successor iterators for Block.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
const void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
auto walk(WalkFns &&...walkFns)
Walk this type and all attributes/types nested within using the provided walk functions.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_iterator user_begin() const
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
user_iterator user_end() const
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This class contains the mutable state of a bytecode instance.
void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit)
Set the new benefit for a bytecode pattern.
void cleanupAfterMatchAndRewrite()
Cleanup any allocated state after a match/rewrite has been completed.
All of the data pertaining to a specific pattern within the bytecode.
static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp, PDLPatternConfigSet *configSet, ByteCodeAddr rewriterAddr)
void match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl< MatchResult > &matches, PDLByteCodeMutableState &state) const
Run the pattern matcher on the given root operation, collecting the matched patterns in matches.
PDLByteCode(ModuleOp module, SmallVector< std::unique_ptr< PDLPatternConfigSet >> configs, const DenseMap< Operation *, PDLPatternConfigSet * > &configMap, llvm::StringMap< PDLConstraintFunction > constraintFns, llvm::StringMap< PDLRewriteFunction > rewriteFns)
Create a ByteCode instance from the given module containing operations in the PDL interpreter dialect...
void initializeMutableState(PDLByteCodeMutableState &state) const
Initialize the given state such that it can be used to execute the current bytecode.
LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const
Run the rewriter of the given pattern that was previously matched in match.
Detect if any of the given parameter types has a sub-element handler.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
uint16_t ByteCodeField
Use generic bytecode types.
llvm::OwningArrayRef< Operation * > OwningOpRange
This header declares functions that assist transformations in the MemRef dialect.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
std::function< LogicalResult(PatternRewriter &, ArrayRef< PDLValue >)> PDLConstraintFunction
A generic PDL pattern constraint function.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
std::function< LogicalResult(PatternRewriter &, PDLResultList &, ArrayRef< PDLValue >)> PDLRewriteFunction
A native PDL rewrite function.
This class represents an efficient way to signal success or failure.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
This class acts as a special tag that makes the desire to match "any" operation type explicit.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
Each successful match returns a MatchResult, which contains information necessary to execute the rewr...
SmallVector< TypeRange, 0 > typeRangeValues
Memory used for the range input values.
SmallVector< ValueRange, 0 > valueRangeValues
SmallVector< const void * > values
Memory values defined in the matcher that are passed to the rewriter.