19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/Format.h"
24 #include "llvm/Support/FormatVariadic.h"
28 #define DEBUG_TYPE "pdl-bytecode"
45 if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
47 llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
50 if (std::optional<StringRef> rootKind = matchOp.getRootKind())
54 benefit, ctx, generatedOps);
66 currentPatternBenefits[patternIndex] = benefit;
73 allocatedTypeRangeMemory.clear();
74 allocatedValueRangeMemory.clear();
104 CreateConstantTypeRange,
108 CreateDynamicTypeRange,
110 CreateDynamicValueRange,
184 struct ByteCodeLiveRange;
185 struct ByteCodeWriter;
/// Detection alias: this type is only well-formed when `T` provides a
/// `getAsOpaquePointer()` member. It is used together with
/// `llvm::is_detected` to enable overloads for types that can be converted to
/// an opaque pointer.
188 template <
typename T,
typename... Args>
189 using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
/// Construct a generator bound to the caller-provided context, output
/// containers and maximum-index counters. The named constraint and rewrite
/// functions are assigned indices that the generated bytecode refers to.
194 Generator(
MLIRContext *ctx, std::vector<const void *> &uniquedData,
203 llvm::StringMap<PDLConstraintFunction> &constraintFns,
204 llvm::StringMap<PDLRewriteFunction> &rewriteFns,
206 : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
207 rewriterByteCode(rewriterByteCode), patterns(patterns),
208 maxValueMemoryIndex(maxValueMemoryIndex),
209 maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
210 maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
211 maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
212 maxLoopLevel(maxLoopLevel), configMap(configMap) {
// Map each constraint function name to its enumeration index.
// NOTE(review): presumably `it` enumerates `constraintFns` -- confirm.
214 constraintToMemIndex.try_emplace(it.value().first(), it.index());
// Likewise map each external rewrite function name to its enumeration index.
// NOTE(review): presumably `it` enumerates `rewriteFns` -- confirm.
216 externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
/// Generate the bytecode for the given PDL interpreter module (defined out of
/// line).
220 void generate(ModuleOp module);
// Memory index lookup: `value` must already have an entry in
// `valueToMemIndex`; that entry is returned.
224 assert(valueToMemIndex.count(value) &&
225 "expected memory index to be assigned");
226 return valueToMemIndex[value];
// Range index lookup: `value` must already have an entry in
// `valueToRangeIndex`; that entry is returned.
231 assert(valueToRangeIndex.count(value) &&
232 "expected range index to be assigned");
233 return valueToRangeIndex[value];
/// Index lookup for values that are not convertible to `Value`. Such values
/// are identified by their opaque pointer and assigned an index that follows
/// the regular value memory indices.
238 template <
typename T>
239 std::enable_if_t<!std::is_convertible<T, Value>::value,
ByteCodeField &>
241 const void *opaqueVal = val.getAsOpaquePointer();
// Get or insert the entry for this pointer. A new entry receives the index
// `maxValueMemoryIndex + uniquedData.size()`.
244 auto it = uniquedDataToMemIndex.try_emplace(
245 opaqueVal, maxValueMemoryIndex + uniquedData.size());
// Record the pointer in `uniquedData`.
// NOTE(review): presumably guarded so it only runs when the entry was newly
// inserted -- confirm against the surrounding condition.
247 uniquedData.push_back(opaqueVal);
248 return it.first->second;
/// Assign memory indices to the values used by the matcher function and by
/// the functions of the rewriter module (defined out of line).
254 void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
255 ModuleOp rewriterModule);
/// Bytecode generation hooks. The `Region *` overload walks the blocks of a
/// region, the `Operation *` overload dispatches on the concrete
/// `pdl_interp` operation type, and the remaining overloads emit the bytecode
/// for one specific operation kind each.
258 void generate(
Region *region, ByteCodeWriter &writer);
259 void generate(
Operation *op, ByteCodeWriter &writer);
260 void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
261 void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
262 void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
263 void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
264 void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
265 void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
266 void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
267 void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
268 void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
269 void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
270 void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
271 void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
272 void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
273 void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer);
274 void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
275 void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
276 void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
277 void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
278 void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
279 void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
280 void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
281 void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
282 void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
283 void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
284 void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
285 void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
286 void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
287 void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
288 void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
289 void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
290 void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
291 void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
292 void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
293 void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
294 void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
295 void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
296 void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
297 void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
/// Maps the name of an external rewrite function to the index used to refer
/// to it from the bytecode.
307 llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
/// Maps the name of a constraint function to the index used to refer to it
/// from the bytecode.
311 llvm::StringMap<ByteCodeField> constraintToMemIndex;
/// Maps the name of a rewriter function to the position at which its code
/// starts within the rewriter bytecode.
315 llvm::StringMap<ByteCodeAddr> rewriterToAddr;
/// Caller-owned list of opaque pointers for uniqued data; new pointers are
/// appended as they are first encountered.
331 std::vector<const void *> &uniquedData;
/// Helper that appends encoded fields to a bytecode buffer on behalf of a
/// `Generator`. Memory indices are obtained from the generator, and
/// references to successor blocks are recorded so that their addresses can be
/// patched in later.
346 struct ByteCodeWriter {
/// Append a single raw field to the bytecode.
351 void append(
ByteCodeField field) { bytecode.push_back(field); }
/// Append an opcode to the bytecode.
352 void append(OpCode opCode) { bytecode.push_back(opCode); }
357 "unexpected ByteCode address size");
// An address is written as two consecutive fields.
361 bytecode.append({fieldParts[0], fieldParts[1]});
/// Append a reference to a successor block. The current bytecode offset is
/// remembered under `successor` in `unresolvedSuccessorRefs`.
366 void append(
Block *successor) {
369 unresolvedSuccessorRefs[successor].push_back(bytecode.size());
376 for (
Block *successor : successors)
// A value list is encoded as its size followed by each value.
382 bytecode.push_back(values.size());
383 for (
Value value : values)
384 appendPDLValue(value);
/// Append a PDL value; the kind of the value is written first.
388 void appendPDLValue(
Value value) {
389 appendPDLValueKind(value);
/// Append the PDL value kind corresponding to the type of `value`.
394 void appendPDLValueKind(
Value value) { appendPDLValueKind(value.
getType()); }
/// Append the PDL value kind corresponding to `type`, selected with a switch
/// over the PDL types. Range types are distinguished by their element type.
397 void appendPDLValueKind(
Type type) {
400 .Case<pdl::AttributeType>(
402 .Case<pdl::OperationType>(
404 .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
405 if (isa<pdl::TypeType>(rangeTy.getElementType()))
/// Overload for values that have opaque pointer traits or are raw pointers:
/// the generator's memory index for the value is appended.
416 template <
typename T>
417 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
418 std::is_pointer<T>::value>
420 bytecode.push_back(
generator.getMemIndex(value));
/// Overload for ranges (types without opaque pointer traits): the size of the
/// range is appended first, then its elements are visited.
424 template <
typename T,
typename IteratorT = llvm::detail::IterOfRange<T>>
425 std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
427 bytecode.push_back(llvm::size(range));
428 for (
auto it : range)
/// Variadic overload: handles the leading field and then recurses on the
/// remaining fields in order.
433 template <
typename FieldTy,
typename Field2Ty,
typename... FieldTys>
434 void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
436 append(field2, fields...);
/// Append the opaque pointer of `value` directly into the bytecode stream,
/// split into as many fields as are needed to hold one pointer.
440 template <
typename T>
441 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
442 appendInline(T value) {
// Number of bytecode fields occupied by one pointer.
443 constexpr
size_t numParts =
sizeof(
const void *) /
sizeof(
ByteCodeField);
444 const void *pointer = value.getAsOpaquePointer();
// Copy the pointer bytes into the field array and append all parts.
446 std::memcpy(fieldParts, &pointer,
sizeof(
const void *));
447 bytecode.append(fieldParts, fieldParts + numParts);
/// Live range of a value, stored as a set of intervals, together with the
/// optional range storage indices associated with the value.
462 struct ByteCodeLiveRange {
// The interval map is used as a set; the mapped `char` carries no meaning.
463 using Set = llvm::IntervalMap<uint64_t, char, 16>;
464 using Allocator = Set::Allocator;
/// Create an empty live range whose intervals are allocated from `alloc`.
466 ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
/// Add every interval of `rhs` to this live range.
469 void unionWith(
const ByteCodeLiveRange &rhs) {
470 for (
auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
// The mapped value is a placeholder (0).
472 liveness->insert(it.start(), it.stop(), 0);
/// Check whether any interval of this range overlaps an interval of `rhs`.
476 bool overlaps(
const ByteCodeLiveRange &rhs)
const {
477 return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
/// The intervals over which the value is live.
487 std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
/// Operation range storage index, if one is assigned.
490 std::optional<unsigned> opRangeIndex;
/// Type range storage index, if one is assigned.
493 std::optional<unsigned> typeRangeIndex;
/// Value range storage index, if one is assigned.
496 std::optional<unsigned> valueRangeIndex;
/// Generate the matcher and rewriter bytecode for a PDL interpreter module.
/// The module must contain the matcher function and the rewriter module.
500 void Generator::generate(ModuleOp module) {
// Look up the two required symbols by their well-known names.
501 auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
502 pdl_interp::PDLInterpDialect::getMatcherFunctionName());
503 ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
504 pdl_interp::PDLInterpDialect::getRewriterModuleName());
505 assert(matcherFunc && rewriterModule &&
"invalid PDL Interpreter module");
// Memory indices must be assigned before any bytecode is emitted.
509 allocateMemoryIndices(matcherFunc, rewriterModule);
// Emit the rewriter functions first, recording the start address of each one
// by name so it can be looked up later.
512 ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *
this);
513 for (
auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
514 rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
515 for (
Operation &op : rewriterFunc.getOps())
516 generate(&op, rewriterByteCodeWriter);
// Rewriter functions are expected to contain no successor references.
518 assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
519 "unexpected branches in rewriter function");
// Emit the matcher function body.
522 ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *
this);
523 generate(&matcherFunc.getBody(), matcherByteCodeWriter);
// Patch every recorded successor reference in the matcher bytecode with the
// resolved address.
526 for (
auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
528 for (
unsigned offsetToFix : it.second)
529 std::memcpy(&matcherByteCode[offsetToFix], &addr,
sizeof(
ByteCodeAddr));
/// Assign memory indices and range storage indices to the values of the
/// rewriter functions and of the matcher function. Rewriter values are
/// numbered sequentially per function. Matcher values share an index when
/// their live ranges do not overlap.
533 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
534 ModuleOp rewriterModule) {
// Rewriter functions: the counters restart at zero for every function.
537 for (
auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
538 ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
// Assign the next memory index to `val`, plus a range index when it is a
// range of types or a range of values.
539 auto processRewriterValue = [&](
Value val) {
540 valueToMemIndex.try_emplace(val, index++);
541 if (pdl::RangeType rangeType = dyn_cast<pdl::RangeType>(val.getType())) {
542 Type elementTy = rangeType.getElementType();
543 if (isa<pdl::TypeType>(elementTy))
544 valueToRangeIndex.try_emplace(val, typeRangeIndex++);
545 else if (isa<pdl::ValueType>(elementTy))
546 valueToRangeIndex.try_emplace(val, valueRangeIndex++);
551 processRewriterValue(arg);
554 processRewriterValue(result);
// Keep the maxima large enough for this rewriter function.
556 if (index > maxValueMemoryIndex)
557 maxValueMemoryIndex = index;
558 if (typeRangeIndex > maxTypeRangeMemoryIndex)
559 maxTypeRangeMemoryIndex = typeRangeIndex;
560 if (valueRangeIndex > maxValueRangeMemoryIndex)
561 maxValueRangeMemoryIndex = valueRangeIndex;
// Matcher function: number the operations. An operation gets a "first" index
// and a "last" index, with its nested blocks visited in between.
578 opToFirstIndex.try_emplace(op, index++);
580 for (
Block &block : region.getBlocks())
583 opToLastIndex.try_emplace(op, index++);
588 ByteCodeLiveRange::Allocator allocator;
// The root operation argument is pinned to memory index 0.
593 valueToMemIndex[rootOpArg] = 0;
// Build a live range for each value from the liveness analysis of the
// matcher function.
596 Liveness matcherLiveness(matcherFunc);
597 matcherFunc->walk([&](
Block *block) {
599 assert(info &&
"expected liveness info for block");
603 if (value == rootOpArg)
607 auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
608 defRangeIt->second.liveness->insert(
609 opToFirstIndex[firstUseOrDef],
// For range-typed values, mark which kind of range storage is required. The
// 0 stored here only marks the optional as set; the actual index is
// assigned during allocation further down.
614 if (
auto rangeTy = dyn_cast<pdl::RangeType>(value.
getType())) {
615 Type eleType = rangeTy.getElementType();
616 if (isa<pdl::OperationType>(eleType))
617 defRangeIt->second.opRangeIndex = 0;
618 else if (isa<pdl::TypeType>(eleType))
619 defRangeIt->second.typeRangeIndex = 0;
620 else if (isa<pdl::ValueType>(eleType))
621 defRangeIt->second.valueRangeIndex = 0;
626 for (
Value liveIn : info->
in()) {
631 if (liveIn.getParentRegion() == block->
getParent())
// Allocate memory indices: each entry of `allocatedIndices` is the combined
// live range of all values sharing that index.
648 std::vector<ByteCodeLiveRange> allocatedIndices;
655 ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
657 for (
auto &defIt : valueDefRanges) {
659 ByteCodeLiveRange &defRange = defIt.second;
// Try to reuse an already allocated index whose live range does not overlap
// the live range of this value.
662 for (
const auto &existingIndexIt :
llvm::enumerate(allocatedIndices)) {
663 ByteCodeLiveRange &existingRange = existingIndexIt.value();
664 if (!defRange.overlaps(existingRange)) {
665 existingRange.unionWith(defRange);
// Offset by one because index 0 is taken by the root operation argument.
666 memIndex = existingIndexIt.index() + 1;
// Reuse the range storage of the existing entry, creating it on first use.
668 if (defRange.opRangeIndex) {
669 if (!existingRange.opRangeIndex)
670 existingRange.opRangeIndex = numOpRanges++;
671 valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
672 }
else if (defRange.typeRangeIndex) {
673 if (!existingRange.typeRangeIndex)
674 existingRange.typeRangeIndex = numTypeRanges++;
675 valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
676 }
else if (defRange.valueRangeIndex) {
677 if (!existingRange.valueRangeIndex)
678 existingRange.valueRangeIndex = numValueRanges++;
679 valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
// Allocate a fresh index covering the live range of this value.
// NOTE(review): presumably reached only when no existing index could be
// reused -- confirm against the guarding condition.
687 allocatedIndices.emplace_back(allocator);
688 ByteCodeLiveRange &newRange = allocatedIndices.back();
689 newRange.unionWith(defRange);
// A new entry also gets new range storage when the value is a range.
692 if (defRange.opRangeIndex) {
693 newRange.opRangeIndex = numOpRanges;
694 valueToRangeIndex[defIt.first] = numOpRanges++;
695 }
else if (defRange.typeRangeIndex) {
696 newRange.typeRangeIndex = numTypeRanges;
697 valueToRangeIndex[defIt.first] = numTypeRanges++;
698 }
else if (defRange.valueRangeIndex) {
699 newRange.valueRangeIndex = numValueRanges;
700 valueToRangeIndex[defIt.first] = numValueRanges++;
// The new entry is the last one; its memory index equals the new size.
703 memIndex = allocatedIndices.size();
710 llvm::dbgs() <<
"Allocated " << allocatedIndices.size() <<
" indices "
711 <<
"(down from initial " << valueDefRanges.size() <<
").\n";
714 "Ran out of memory for allocated indices");
// Grow the maxima so they also cover the matcher allocation.
717 if (numIndices > maxValueMemoryIndex)
718 maxValueMemoryIndex = numIndices;
719 if (numOpRanges > maxOpRangeMemoryIndex)
720 maxOpRangeMemoryIndex = numOpRanges;
721 if (numTypeRanges > maxTypeRangeMemoryIndex)
722 maxTypeRangeMemoryIndex = numTypeRanges;
723 if (numValueRanges > maxValueRangeMemoryIndex)
724 maxValueRangeMemoryIndex = numValueRanges;
/// Generate bytecode for every block of `region`, visiting the blocks in
/// reverse post-order.
727 void Generator::generate(
Region *region, ByteCodeWriter &writer) {
728 llvm::ReversePostOrderTraversal<Region *> rpot(region);
729 for (
Block *block : rpot) {
// Record the address of the block: the current size of the matcher bytecode.
731 blockToAddr.try_emplace(block, matcherByteCode.size());
733 generate(&op, writer);
/// Generate bytecode for a single operation by dispatching to the overload
/// for its concrete `pdl_interp` operation type. Any other operation is
/// treated as unreachable.
737 void Generator::generate(
Operation *op, ByteCodeWriter &writer) {
// The location of the operation is written inline, except for the attribute
// and type creation operations.
741 if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
742 writer.appendInline(op->
getLoc());
745 .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
746 pdl_interp::AreEqualOp, pdl_interp::BranchOp,
747 pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
748 pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
749 pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
750 pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
751 pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp,
752 pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
753 pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
754 pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
755 pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
756 pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
757 pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
758 pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
759 pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
760 pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
761 pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
762 pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
763 pdl_interp::SwitchResultCountOp>(
764 [&](
auto interpOp) { this->generate(interpOp, writer); })
766 llvm_unreachable(
"unknown `pdl_interp` operation");
/// Emit `ApplyConstraint`: the opcode, the index of the named constraint
/// function, and the list of argument values. The constraint name must have
/// been registered with the generator.
770 void Generator::generate(pdl_interp::ApplyConstraintOp op,
771 ByteCodeWriter &writer) {
772 assert(constraintToMemIndex.count(op.
getName()) &&
773 "expected index for constraint function");
774 writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.
getName()]);
775 writer.appendPDLValueList(op.getArgs());
/// Emit `ApplyRewrite`: the opcode, the index of the named external rewrite
/// function, the argument values, and a description of each result. The
/// rewrite name must have been registered with the generator.
779 void Generator::generate(pdl_interp::ApplyRewriteOp op,
780 ByteCodeWriter &writer) {
781 assert(externalRewriterToMemIndex.count(op.
getName()) &&
782 "expected index for rewrite function");
783 writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.
getName()]);
784 writer.appendPDLValueList(op.getArgs());
// For each result: write its value kind, then the range storage index when
// the result is a range, then the memory index of the result.
788 for (
Value result : results) {
792 writer.appendPDLValueKind(result);
796 if (isa<pdl::RangeType>(result.getType()))
797 writer.append(getRangeStorageIndex(result));
798 writer.append(result);
/// Emit an equality check. Range operands use the `AreRangesEqual` opcode,
/// which is followed by the value kind of the operands; all other operands
/// use `AreEqual`.
801 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
802 Value lhs = op.getLhs();
803 if (isa<pdl::RangeType>(lhs.
getType())) {
804 writer.append(OpCode::AreRangesEqual);
805 writer.appendPDLValueKind(lhs);
810 writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.
getSuccessors());
/// Emit the bytecode for an unconditional branch.
812 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
/// An attribute check is emitted as `AreEqual` between the attribute and the
/// expected constant value.
815 void Generator::generate(pdl_interp::CheckAttributeOp op,
816 ByteCodeWriter &writer) {
817 writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
/// Emit `CheckOperandCount` with the input operation and the expected count.
820 void Generator::generate(pdl_interp::CheckOperandCountOp op,
821 ByteCodeWriter &writer) {
822 writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
/// Emit `CheckOperationName` for the input operation.
826 void Generator::generate(pdl_interp::CheckOperationNameOp op,
827 ByteCodeWriter &writer) {
828 writer.append(OpCode::CheckOperationName, op.getInputOp(),
/// Emit `CheckResultCount` with the input operation and the expected count.
831 void Generator::generate(pdl_interp::CheckResultCountOp op,
832 ByteCodeWriter &writer) {
833 writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
/// A single type check is emitted as `AreEqual` between the value and the
/// expected type.
837 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
838 writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
/// Emit `CheckTypes` with the value and the expected types.
841 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
842 writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
/// Emit `Continue` for the innermost enclosing loop. The operand is the
/// zero-based loop level, so this must be generated inside a loop.
845 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
846 assert(curLoopLevel > 0 &&
"encountered pdl_interp.continue at top level");
847 writer.append(OpCode::Continue,
ByteCodeField(curLoopLevel - 1));
/// Creating an attribute emits no opcode here: the result reuses the memory
/// index of the constant attribute value.
849 void Generator::generate(pdl_interp::CreateAttributeOp op,
850 ByteCodeWriter &writer) {
852 getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
/// Emit `CreateOperation`: the result operation, the operand values, the
/// attribute count followed by (name, value) pairs, and the result types.
854 void Generator::generate(pdl_interp::CreateOperationOp op,
855 ByteCodeWriter &writer) {
856 writer.append(OpCode::CreateOperation, op.getResultOp(),
858 writer.appendPDLValueList(op.getInputOperands());
// Attributes: the count first, then each name paired with its value.
862 writer.append(
static_cast<ByteCodeField>(attributes.size()));
863 for (
auto it : llvm::zip(op.getInputAttributeNames(), attributes))
864 writer.append(std::get<0>(it), std::get<1>(it));
// Result types: inferred result types are handled separately from an
// explicitly provided list of result types.
868 if (op.getInferredResultTypes())
871 writer.appendPDLValueList(op.getInputResultTypes());
/// Emit a dynamic range creation; the opcode depends on whether the elements
/// are types or values.
873 void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
877 [&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); })
878 .Case([&](pdl::ValueType) {
879 writer.append(OpCode::CreateDynamicValueRange);
/// Creating a type emits no opcode here: the result reuses the memory index
/// of the constant type value.
885 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
887 getMemIndex(op.
getResult()) = getMemIndex(op.getValue());
/// Emit `CreateConstantTypeRange`: the result, its range storage index, and
/// the constant value.
889 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
890 writer.append(OpCode::CreateConstantTypeRange, op.
getResult(),
891 getRangeStorageIndex(op.
getResult()), op.getValue());
/// Emit `EraseOp` for the input operation.
893 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
894 writer.append(OpCode::EraseOp, op.getInputOp());
/// Emit an extraction of one element from a range. The opcode is chosen from
/// the element kind (operation, value, or type); any other kind is
/// unsupported.
896 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
899 .Case([](pdl::OperationType) {
return OpCode::ExtractOp; })
900 .Case([](pdl::ValueType) {
return OpCode::ExtractValue; })
901 .Case([](pdl::TypeType) {
return OpCode::ExtractType; })
902 .Default([](
Type) -> OpCode {
903 llvm_unreachable(
"unsupported element type");
905 writer.append(opCode, op.getRange(), op.getIndex(), op.
getResult());
/// Emit `Finalize`.
907 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
908 writer.append(OpCode::Finalize);
/// Emit `ForEach`: the range storage index of the iterated values, the loop
/// argument, and the value kind of that argument. The deepest loop nesting
/// seen so far is tracked in `maxLoopLevel`.
910 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
912 writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
913 writer.appendPDLValueKind(arg.
getType());
916 if (curLoopLevel > maxLoopLevel)
917 maxLoopLevel = curLoopLevel;
/// Emit `GetAttribute` with the result attribute and the input operation.
921 void Generator::generate(pdl_interp::GetAttributeOp op,
922 ByteCodeWriter &writer) {
923 writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
/// Emit `GetAttributeType` with the result followed by the attribute value.
926 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
927 ByteCodeWriter &writer) {
928 writer.append(OpCode::GetAttributeType, op.
getResult(), op.getValue());
/// Emit `GetDefiningOp` with the result operation, followed by the input
/// value written as a PDL value (kind first).
930 void Generator::generate(pdl_interp::GetDefiningOpOp op,
931 ByteCodeWriter &writer) {
932 writer.append(OpCode::GetDefiningOp, op.getInputOp());
933 writer.appendPDLValue(op.getValue());
/// Emit an operand access. Two encodings exist: an opcode that embeds the
/// index (`GetOperand0 + index`), and `GetOperandN` followed by the index.
/// NOTE(review): the condition selecting between them should be confirmed.
935 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
936 uint32_t index = op.getIndex();
938 writer.append(
static_cast<OpCode
>(OpCode::GetOperand0 + index));
940 writer.append(OpCode::GetOperandN, index);
941 writer.append(op.getInputOp(), op.getValue());
/// Emit `GetOperands`. When the result is a range its range storage index is
/// written before the result itself.
943 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
944 Value result = op.getValue();
945 std::optional<uint32_t> index = op.getIndex();
946 writer.append(OpCode::GetOperands,
949 if (isa<pdl::RangeType>(result.
getType()))
950 writer.append(getRangeStorageIndex(result));
953 writer.append(result);
/// Emit a result access. Two encodings exist: an opcode that embeds the index
/// (`GetResult0 + index`), and `GetResultN` followed by the index.
/// NOTE(review): the condition selecting between them should be confirmed.
955 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
956 uint32_t index = op.getIndex();
958 writer.append(
static_cast<OpCode
>(OpCode::GetResult0 + index));
960 writer.append(OpCode::GetResultN, index);
961 writer.append(op.getInputOp(), op.getValue());
/// Emit `GetResults`. When the result is a range its range storage index is
/// written before the result itself.
963 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
964 Value result = op.getValue();
965 std::optional<uint32_t> index = op.getIndex();
966 writer.append(OpCode::GetResults,
969 if (isa<pdl::RangeType>(result.
getType()))
970 writer.append(getRangeStorageIndex(result));
973 writer.append(result);
/// Emit `GetUsers`: the resulting operations, a range index, and the input
/// value written as a PDL value (kind first).
975 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
976 Value operations = op.getOperations();
978 writer.append(OpCode::GetUsers, operations, rangeIndex);
979 writer.appendPDLValue(op.getValue());
/// Emit a type query for a value. A range result uses `GetValueRangeTypes`
/// and also carries the range storage index of the result; otherwise
/// `GetValueType` is used.
981 void Generator::generate(pdl_interp::GetValueTypeOp op,
982 ByteCodeWriter &writer) {
983 if (isa<pdl::RangeType>(op.getType())) {
985 writer.append(OpCode::GetValueRangeTypes, result,
986 getRangeStorageIndex(result), op.getValue());
988 writer.append(OpCode::GetValueType, op.
getResult(), op.getValue());
/// Emit `IsNotNull` with the tested value and the successor blocks.
991 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
992 writer.append(OpCode::IsNotNull, op.getValue(), op.
getSuccessors());
/// Emit `RecordMatch`. A pattern is created from the operation, its
/// configuration, and the address of the referenced rewriter function; the
/// opcode is followed by the pattern index and the input values.
994 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
// The rewriter address is looked up by the leaf name of the rewriter symbol.
997 op, configMap.lookup(op),
998 rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
999 writer.append(OpCode::RecordMatch, patternIndex,
1001 writer.appendPDLValueList(op.getInputs());
/// Emit `ReplaceOp` with the operation to replace and the replacement values.
1003 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
1004 writer.append(OpCode::ReplaceOp, op.getInputOp());
1005 writer.appendPDLValueList(op.getReplValues());
/// Emit `SwitchAttribute` on the given attribute.
1007 void Generator::generate(pdl_interp::SwitchAttributeOp op,
1008 ByteCodeWriter &writer) {
1009 writer.append(OpCode::SwitchAttribute, op.getAttribute(),
/// Emit `SwitchOperandCount` on the input operation.
1012 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
1013 ByteCodeWriter &writer) {
1014 writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
/// Emit `SwitchOperationName` on the input operation. The case values are
/// string attributes that are converted to `OperationName`s first.
1017 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
1018 ByteCodeWriter &writer) {
1019 auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](
Attribute attr) {
1020 return OperationName(cast<StringAttr>(attr).getValue(), ctx);
1022 writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
/// Emit `SwitchResultCount` on the input operation.
1025 void Generator::generate(pdl_interp::SwitchResultCountOp op,
1026 ByteCodeWriter &writer) {
1027 writer.append(OpCode::SwitchResultCount, op.getInputOp(),
/// Emit `SwitchType` on the value with the case values attribute.
1030 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1031 writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
/// Emit `SwitchTypes` on the value with the case values attribute.
1034 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1035 writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
// Constructor parameters: the module to compile, the pattern configurations
// (taken by value and moved into the member), and the named constraint and
// rewrite functions (taken by value so they can be moved from below).
1044 ModuleOp module,
SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
1046 llvm::StringMap<PDLConstraintFunction> constraintFns,
1047 llvm::StringMap<PDLRewriteFunction> rewriteFns)
1048 : configs(std::move(configs)) {
// The generator writes its results through references to the members
// passed here.
1049 Generator
generator(module.getContext(), uniquedData, matcherByteCode,
1050 rewriterByteCode, patterns, maxValueMemoryIndex,
1051 maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1052 maxLoopLevel, constraintFns, rewriteFns, configMap);
// Move the functions into flat lists in the iteration order of the maps.
// The generator assigned function indices from that same iteration order.
1056 for (
auto &it : constraintFns)
1057 constraintFunctions.push_back(std::move(it.second));
1058 for (
auto &it : rewriteFns)
1059 rewriteFunctions.push_back(std::move(it.second));
// Size the buffers of the mutable state from the maxima computed during
// generation.
1065 state.memory.resize(maxValueMemoryIndex,
nullptr);
1066 state.opRangeMemory.resize(maxOpRangeCount);
1067 state.typeRangeMemory.resize(maxTypeRangeCount,
TypeRange());
1068 state.valueRangeMemory.resize(maxValueRangeCount,
ValueRange());
1069 state.loopIndex.resize(maxLoopLevel, 0);
// Seed the current benefit of each pattern with its declared benefit.
1070 state.currentPatternBenefits.reserve(patterns.size());
1072 state.currentPatternBenefits.push_back(pattern.getBenefit());
/// Executes bytecode. The executor keeps the current code position together
/// with references to the memory buffers, the code, the patterns, and the
/// constraint and rewrite functions it was constructed with.
1080 class ByteCodeExecutor {
1086 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
1088 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
// The constructor only binds its arguments to the members.
1095 : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
1096 typeRangeMemory(typeRangeMemory),
1097 allocatedTypeRangeMemory(allocatedTypeRangeMemory),
1098 valueRangeMemory(valueRangeMemory),
1099 allocatedValueRangeMemory(allocatedValueRangeMemory),
1100 loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
1101 currentPatternBenefits(currentPatternBenefits), patterns(patterns),
1102 constraintFunctions(constraintFunctions),
1103 rewriteFunctions(rewriteFunctions) {}
1111 std::optional<Location> mainRewriteLoc = {});
/// Handlers for the individual bytecode operations. Each one reads its
/// operands from the code stream at the current position.
1117 void executeAreEqual();
1118 void executeAreRangesEqual();
1119 void executeBranch();
1120 void executeCheckOperandCount();
1121 void executeCheckOperationName();
1122 void executeCheckResultCount();
1123 void executeCheckTypes();
1124 void executeContinue();
1125 void executeCreateConstantTypeRange();
// `type` is a textual name for the element type.
// NOTE(review): presumably used for debug output only -- confirm.
1128 template <
typename T>
1129 void executeDynamicCreateRange(StringRef type);
// Extraction is parameterized on the element type, the range type, and the
// PDL value kind.
1131 template <
typename T,
typename Range, PDLValue::Kind kind>
1132 void executeExtract();
1133 void executeFinalize();
1134 void executeForEach();
1135 void executeGetAttribute();
1136 void executeGetAttributeType();
1137 void executeGetDefiningOp();
1138 void executeGetOperand(
unsigned index);
1139 void executeGetOperands();
1140 void executeGetResult(
unsigned index);
1141 void executeGetResults();
1142 void executeGetUsers();
1143 void executeGetValueType();
1144 void executeGetValueRangeTypes();
1145 void executeIsNotNull();
1149 void executeSwitchAttribute();
1150 void executeSwitchOperandCount();
1151 void executeSwitchOperationName();
1152 void executeSwitchResultCount();
1153 void executeSwitchType();
1154 void executeSwitchTypes();
/// Push a code position onto the resume stack.
1157 void pushCodeIt(
const ByteCodeField *it) { resumeCodeIt.push_back(it); }
// Pop the most recently pushed position and continue execution from there.
// The resume stack must not be empty.
1161 assert(!resumeCodeIt.empty() &&
"attempt to pop code off empty stack");
1162 curCodeIt = resumeCodeIt.back();
1163 resumeCodeIt.pop_back();
// Step back over one field plus the number of fields that hold one pointer.
// NOTE(review): presumably this skips the opcode and the inline location
// written before each operation -- confirm.
1170 return curCodeIt - 1 -
sizeof(
const void *) /
sizeof(
ByteCodeField);
// Step back over a single field.
1174 return curCodeIt - 1;
/// Read a value of type `T` from the code stream; the decoding is delegated
/// to the `readImpl` overload for `T`.
/// NOTE(review): `skipN` is presumably a number of fields to skip before
/// reading -- confirm.
1180 template <
typename T = ByteCodeField>
1181 T read(
size_t skipN = 0) {
1183 return readImpl<T>();
/// Non-template convenience overload that reads a raw bytecode field.
1185 ByteCodeField read(
size_t skipN = 0) {
return read<ByteCodeField>(skipN); }
/// Read a list encoded as a count followed by that many elements of type
/// `ValueT`.
1188 template <
typename ValueT,
typename T>
1191 for (
unsigned i = 0, e = read(); i != e; ++i)
1192 list.push_back(read<ValueT>());
// Read a list of types: an entry is either a single type, which is appended,
// or a type range, whose elements are appended in order.
1198 for (
unsigned i = 0, e = read(); i != e; ++i) {
1200 list.push_back(read<Type>());
1202 TypeRange *values = read<TypeRange *>();
1203 list.append(values->begin(), values->end());
// Read a list of values using the same single-or-range scheme.
1208 for (
unsigned i = 0, e = read(); i != e; ++i) {
1210 list.push_back(read<Value>());
1213 list.append(values->begin(), values->end());
/// Read an opaque pointer stored directly in the code stream and convert it
/// back to `T`.
1219 template <
typename T>
1220 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
1222 const void *pointer;
1223 std::memcpy(&pointer, curCodeIt,
sizeof(
const void *));
1225 return T::getFromOpaquePointer(pointer);
/// Jump to one of two destinations encoded next in the code stream: the
/// first when `isTrue` holds, the second otherwise.
1229 void selectJump(
bool isTrue) { selectJump(
size_t(isTrue ? 0 : 1)); }
/// Jump to the destination address at position `destIndex` of the address
/// list encoded next in the code stream. Each address occupies two fields,
/// hence the factor of two.
1231 void selectJump(
size_t destIndex) {
1232 curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
/// Handle a switch: compare `value` against each entry of `cases` using
/// `cmp` and jump to the destination of the first match. Destination 0 is
/// the default, so case `i` maps to destination `i + 1`.
1236 template <
typename T,
typename RangeT,
typename Comparator = std::equal_to<T>>
1237 void handleSwitch(
const T &value, RangeT &&cases, Comparator cmp = {}) {
1239 llvm::dbgs() <<
" * Value: " << value <<
"\n"
1241 llvm::interleaveComma(cases, llvm::dbgs());
1242 llvm::dbgs() <<
"\n";
// Take the first matching case, or fall back to the default destination.
1247 for (
auto it = cases.begin(), e = cases.end(); it != e; ++it)
1248 if (cmp(*it, value))
1249 return selectJump(
size_t((it - cases.begin()) + 1));
1250 selectJump(
size_t(0));
/// Store a raw pointer into the memory slot `index`.
1254 void storeToMemory(
unsigned index,
const void *value) {
1255 memory[index] = value;
/// Store a value with opaque pointer traits into the memory slot `index`.
1259 template <
typename T>
1260 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
1261 storeToMemory(
unsigned index, T value) {
1262 memory[index] = value.getAsOpaquePointer();
/// Read a memory index from the code stream and return the pointer stored
/// there. Indices inside `memory` are read from it; larger indices refer to
/// `uniquedMemory`, offset by the size of `memory`.
1267 template <
typename T>
1268 const void *readFromMemory() {
1269 size_t index = *curCodeIt++;
1274 index < memory.size())
1275 return memory[index];
1278 return uniquedMemory[index - memory.size()];
/// Pointer types: the pointer read from memory is cast to `T`.
1280 template <
typename T>
1281 std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
1282 return reinterpret_cast<T
>(
const_cast<void *
>(readFromMemory<T>()));
/// Class types other than `PDLValue`: rebuilt from the opaque pointer read
/// from memory.
1284 template <
typename T>
1285 std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
1288 return T(T::getFromOpaquePointer(readFromMemory<T>()));
/// `PDLValue`: the value kind is read first and selects the type that is
/// read next.
1290 template <
typename T>
1291 std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
1292 switch (read<PDLValue::Kind>()) {
1294 return read<Attribute>();
1296 return read<Operation *>();
1298 return read<Type>();
1300 return read<Value>();
1302 return read<TypeRange *>();
1304 return read<ValueRange *>();
1306 llvm_unreachable(
"unhandled PDLValue::Kind");
/// `ByteCodeAddr`: addresses have their own decoding.
1308 template <
typename T>
1309 std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
1311 "unexpected ByteCode address size");
/// `ByteCodeField`: return the next field and advance the code position.
1317 template <
typename T>
1318 std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
1319 return *curCodeIt++;
/// `PDLValue::Kind`: value kinds have their own decoding.
1321 template <
typename T>
1322 std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
/// Store `range` in the range storage slot `rangeIndex` and make the memory
/// slot `memIndex` point at that range slot. The element type `T` selects
/// between the type range storage and the value range storage; any other
/// element type is unhandled.
1328 template <
typename RangeT,
typename T = llvm::detail::ValueOfRange<RangeT>>
1329 void assignRangeToMemory(RangeT &&range,
unsigned memIndex,
1330 unsigned rangeIndex) {
// Shared implementation for both kinds of range storage.
1332 auto assignRange = [&](
auto &allocatedRangeMemory,
auto &rangeMemory) {
// An empty range is stored as a default-constructed range.
1334 if (range.empty()) {
1335 rangeMemory[rangeIndex] = {};
// A non-empty range gets owning storage sized to the range. That storage is
// kept alive in `allocatedRangeMemory` while the range slot refers to it.
1338 llvm::OwningArrayRef<T> storage(llvm::size(range));
1343 allocatedRangeMemory.emplace_back(std::move(storage));
1344 rangeMemory[rangeIndex] = allocatedRangeMemory.back();
1346 memory[memIndex] = &rangeMemory[rangeIndex];
// Pick the storage that matches the element type at compile time.
1350 if constexpr (std::is_same_v<T, Type>) {
1351 return assignRange(allocatedTypeRangeMemory, typeRangeMemory);
1352 }
else if constexpr (std::is_same_v<T, Value>) {
1353 return assignRange(allocatedValueRangeMemory, valueRangeMemory);
1355 llvm_unreachable(
"unhandled range type");
/// Owning storage for type ranges.
1369 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
/// Owning storage for value ranges.
1371 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
/// Construct a result list for a rewrite.
/// NOTE(review): `maxNumResults` is presumably the expected number of
/// results -- confirm how it is used.
1390 ByteCodeRewriteResultList(
unsigned maxNumResults)
// Accessor for the type range storage held by this list.
1398 return allocatedTypeRanges;
// Accessor for the value range storage held by this list.
1403 return allocatedValueRanges;
/// Execute `ApplyConstraint`: read the argument list, invoke the constraint
/// function, and jump depending on its result and the negation flag.
1408 void ByteCodeExecutor::executeApplyConstraint(
PatternRewriter &rewriter) {
1409 LLVM_DEBUG(llvm::dbgs() <<
"Executing ApplyConstraint:\n");
1412 readList<PDLValue>(args);
1415 llvm::dbgs() <<
" * Arguments: ";
1416 llvm::interleaveComma(args, llvm::dbgs());
1417 llvm::dbgs() <<
"\n";
1422 llvm::dbgs() <<
" * isNegated: " << isNegated <<
"\n";
1423 llvm::interleaveComma(args, llvm::dbgs());
// The first destination is taken when the success of the constraint differs
// from the negation flag.
1426 selectJump(isNegated !=
succeeded(constraintFn(rewriter, args)));
// Execute `ApplyRewrite`: read the arguments, invoke the rewrite function,
// and store each result into memory.
1430 LLVM_DEBUG(llvm::dbgs() <<
"Executing ApplyRewrite:\n");
1433 readList<PDLValue>(args);
1436 llvm::dbgs() <<
" * Arguments: ";
1437 llvm::interleaveComma(args, llvm::dbgs());
// Run the rewrite function; it fills in `results`.
1442 ByteCodeRewriteResultList results(numResults);
1443 LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
1445 assert(results.getResults().size() == numResults &&
1446 "native PDL rewrite function returned unexpected number of results");
// Store each result, checking that its kind matches the kind encoded in the
// bytecode.
1449 for (
PDLValue &result : results.getResults()) {
1450 LLVM_DEBUG(llvm::dbgs() <<
" * Result: " << result <<
"\n");
1454 assert(result.getKind() == read<PDLValue::Kind>() &&
1455 "native PDL rewrite function returned an unexpected type of result");
// A range result is copied into its range storage slot, and the memory slot
// then points at that storage.
1460 if (std::optional<TypeRange> typeRange = result.dyn_cast<
TypeRange>()) {
1461 unsigned rangeIndex = read();
1462 typeRangeMemory[rangeIndex] = *typeRange;
1463 memory[read()] = &typeRangeMemory[rangeIndex];
1464 }
else if (std::optional<ValueRange> valueRange =
1466 unsigned rangeIndex = read();
1467 valueRangeMemory[rangeIndex] = *valueRange;
1468 memory[read()] = &valueRangeMemory[rangeIndex];
// Any other result is stored directly as its opaque pointer.
1470 memory[read()] = result.getAsOpaquePointer();
// Move the range storage allocated by the rewrite into the executor.
1475 for (
auto &it : results.getAllocatedTypeRanges())
1476 allocatedTypeRangeMemory.push_back(std::move(it));
1477 for (
auto &it : results.getAllocatedValueRanges())
1478 allocatedValueRangeMemory.push_back(std::move(it));
// A failed rewrite is reported in the debug output.
1481 if (
failed(rewriteResult)) {
1482 LLVM_DEBUG(llvm::dbgs() <<
" - Failed");
/// Executes the AreEqual opcode: reads two opaque pointers from the bytecode
/// stream and selects the jump based on whether the pointers are identical.
1488 void ByteCodeExecutor::executeAreEqual() {
1489 LLVM_DEBUG(llvm::dbgs() <<
"Executing AreEqual:\n");
// Both operands are compared by address only; nothing is dereferenced.
1490 const void *lhs = read<const void *>();
1491 const void *rhs = read<const void *>();
1493 LLVM_DEBUG(llvm::dbgs() <<
" * " << lhs <<
" == " << rhs <<
"\n");
1494 selectJump(lhs == rhs);
/// Executes the AreRangesEqual opcode: reads two opaque range pointers,
/// switches on `valueKind`, and selects the jump based on `operator==` applied
/// to the dereferenced ranges.
/// NOTE(review): `valueKind`, the case labels and the first case's
/// lhsRange/rhsRange declarations are on lines not shown in this listing.
1497 void ByteCodeExecutor::executeAreRangesEqual() {
1498 LLVM_DEBUG(llvm::dbgs() <<
"Executing AreRangesEqual:\n");
1500 const void *lhs = read<const void *>();
1501 const void *rhs = read<const void *>();
1503 switch (valueKind) {
1507 LLVM_DEBUG(llvm::dbgs() <<
" * " << lhs <<
" == " << rhs <<
"\n\n");
1508 selectJump(*lhsRange == *rhsRange);
// In this case the opaque pointers are reinterpreted as `const ValueRange *`
// before the ranges themselves are compared.
1512 const auto *lhsRange =
reinterpret_cast<const ValueRange *
>(lhs);
1513 const auto *rhsRange =
reinterpret_cast<const ValueRange *
>(rhs);
1514 LLVM_DEBUG(llvm::dbgs() <<
" * " << lhs <<
" == " << rhs <<
"\n\n");
1515 selectJump(*lhsRange == *rhsRange);
// Any other value kind is treated as a programming error.
1519 llvm_unreachable(
"unexpected `AreRangesEqual` value kind");
/// Executes the Branch opcode: reads a ByteCodeAddr from the stream and moves
/// the code cursor to that offset within `code`.
1523 void ByteCodeExecutor::executeBranch() {
1524 LLVM_DEBUG(llvm::dbgs() <<
"Executing Branch\n");
// The address read is used as an index into `code`, not as an offset from the
// current cursor.
1525 curCodeIt = &code[read<ByteCodeAddr>()];
/// Executes the CheckOperandCount opcode: reads the expected operand count
/// (uint32_t) and a `compareAtLeast` flag, and logs the operation's actual
/// operand count next to the expectation.
/// NOTE(review): the definition of `op` and the final jump selection are on
/// lines not shown in this listing; the debug text implies `>=` is used when
/// compareAtLeast is set and `==` otherwise -- confirm against the full source.
1528 void ByteCodeExecutor::executeCheckOperandCount() {
1529 LLVM_DEBUG(llvm::dbgs() <<
"Executing CheckOperandCount:\n");
1531 uint32_t expectedCount = read<uint32_t>();
1532 bool compareAtLeast = read();
1534 LLVM_DEBUG(llvm::dbgs() <<
" * Found: " << op->
getNumOperands() <<
"\n"
1535 <<
" * Expected: " << expectedCount <<
"\n"
1536 <<
" * Comparator: "
1537 << (compareAtLeast ?
">=" :
"==") <<
"\n");
/// Executes the CheckOperationName opcode: selects the jump based on whether
/// the operation's name equals the expected name.
/// NOTE(review): `op` and `expectedName` are defined on lines not shown in
/// this listing.
1544 void ByteCodeExecutor::executeCheckOperationName() {
1545 LLVM_DEBUG(llvm::dbgs() <<
"Executing CheckOperationName:\n");
1549 LLVM_DEBUG(llvm::dbgs() <<
" * Found: \"" << op->
getName() <<
"\"\n"
1550 <<
" * Expected: \"" << expectedName <<
"\"\n");
1551 selectJump(op->
getName() == expectedName);
/// Executes the CheckResultCount opcode: reads the expected result count
/// (uint32_t) and a `compareAtLeast` flag, and logs the operation's actual
/// result count next to the expectation.
/// NOTE(review): the definition of `op` and the final jump selection are on
/// lines not shown in this listing; the debug text implies `>=` is used when
/// compareAtLeast is set and `==` otherwise -- confirm against the full source.
1554 void ByteCodeExecutor::executeCheckResultCount() {
1555 LLVM_DEBUG(llvm::dbgs() <<
"Executing CheckResultCount:\n");
1557 uint32_t expectedCount = read<uint32_t>();
1558 bool compareAtLeast = read();
1560 LLVM_DEBUG(llvm::dbgs() <<
" * Found: " << op->
getNumResults() <<
"\n"
1561 <<
" * Expected: " << expectedCount <<
"\n"
1562 <<
" * Comparator: "
1563 << (compareAtLeast ?
">=" :
"==") <<
"\n");
/// Executes the CheckTypes opcode: selects the jump based on whether the type
/// range pointed to by `lhs` equals the types held by the ArrayAttr `rhs`
/// (viewed as a range of TypeAttr values).
/// NOTE(review): `lhs` and `rhs` are read on lines not shown in this listing.
1570 void ByteCodeExecutor::executeCheckTypes() {
// The banner names this opcode; it previously read "Executing AreEqual",
// which made debug traces indistinguishable from executeAreEqual.
1571 LLVM_DEBUG(llvm::dbgs() <<
"Executing CheckTypes:\n");
1574 LLVM_DEBUG(llvm::dbgs() <<
" * " << lhs <<
" == " << rhs <<
"\n\n");
1576 selectJump(*lhs == cast<ArrayAttr>(rhs).getAsValueRange<TypeAttr>());
/// Executes the Continue opcode and logs the loop `level` it applies to.
/// NOTE(review): how `level` is obtained and what is done with it are on lines
/// not shown in this listing.
1579 void ByteCodeExecutor::executeContinue() {
1581 LLVM_DEBUG(llvm::dbgs() <<
"Executing Continue\n"
1582 <<
" * Level: " << level <<
"\n");
/// Executes the CreateConstantTypeRange opcode: reads a memory index, a range
/// index and an ArrayAttr of types, then hands the attribute's TypeAttr values
/// to assignRangeToMemory together with `memIndex`.
/// NOTE(review): the remaining arguments of the assignRangeToMemory call are
/// on a line not shown in this listing.
1587 void ByteCodeExecutor::executeCreateConstantTypeRange() {
1588 LLVM_DEBUG(llvm::dbgs() <<
"Executing CreateConstantTypeRange:\n");
// Operands are consumed in stream order: memory index, range index, types.
1589 unsigned memIndex = read();
1590 unsigned rangeIndex = read();
1591 ArrayAttr typesAttr = cast<ArrayAttr>(read<Attribute>());
1593 LLVM_DEBUG(llvm::dbgs() <<
" * Types: " << typesAttr <<
"\n\n");
1594 assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
1598 void ByteCodeExecutor::executeCreateOperation(
PatternRewriter &rewriter,
1600 LLVM_DEBUG(llvm::dbgs() <<
"Executing CreateOperation:\n");
1602 unsigned memIndex = read();
1604 readList(state.operands);
1605 for (
unsigned i = 0, e = read(); i != e; ++i) {
1606 StringAttr name = read<StringAttr>();
1608 state.addAttribute(name, attr);
1613 unsigned numResults = read();
1615 InferTypeOpInterface::Concept *inferInterface =
1616 state.name.getInterface<InferTypeOpInterface>();
1617 assert(inferInterface &&
1618 "expected operation to provide InferTypeOpInterface");
1621 if (
failed(inferInterface->inferReturnTypes(
1622 state.getContext(), state.location, state.operands,
1623 state.attributes.getDictionary(state.getContext()),
1624 state.getRawProperties(), state.regions, state.types)))
1628 for (
unsigned i = 0; i != numResults; ++i) {
1630 state.types.push_back(read<Type>());
1632 TypeRange *resultTypes = read<TypeRange *>();
1633 state.types.append(resultTypes->begin(), resultTypes->end());
1639 memory[memIndex] = resultOp;
1642 llvm::dbgs() <<
" * Attributes: "
1643 << state.attributes.getDictionary(state.getContext())
1644 <<
"\n * Operands: ";
1645 llvm::interleaveComma(state.operands, llvm::dbgs());
1646 llvm::dbgs() <<
"\n * Result Types: ";
1647 llvm::interleaveComma(state.types, llvm::dbgs());
1648 llvm::dbgs() <<
"\n * Result: " << *resultOp <<
"\n";
1652 template <
typename T>
1653 void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
1654 LLVM_DEBUG(llvm::dbgs() <<
"Executing CreateDynamic" << type <<
"Range:\n");
1655 unsigned memIndex = read();
1656 unsigned rangeIndex = read();
1661 llvm::dbgs() <<
"\n * " << type <<
"s: ";
1662 llvm::interleaveComma(values, llvm::dbgs());
1663 llvm::dbgs() <<
"\n";
1666 assignRangeToMemory(values, memIndex, rangeIndex);
1670 LLVM_DEBUG(llvm::dbgs() <<
"Executing EraseOp:\n");
1673 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n");
/// Executes an Extract opcode: reads a pointer to a `Range`, an element index
/// and a destination memory index, then stores the selected element.
///
/// Template parameters:
///   T     - element type produced by the extraction.
///   Range - range type the stream operand points to.
///   kind  - PDLValue kind, used only in the debug output shown here.
1677 template <
typename T,
typename Range, PDLValue::Kind kind>
1678 void ByteCodeExecutor::executeExtract() {
1679 LLVM_DEBUG(llvm::dbgs() <<
"Executing Extract" << kind <<
":\n");
1680 Range *range = read<Range *>();
1681 unsigned index = read<uint32_t>();
1682 unsigned memIndex = read();
// NOTE(review): the condition guarding this null store is on a line not shown
// in this listing -- presumably the null-range case; confirm.
1685 memory[memIndex] =
nullptr;
// An out-of-bounds index yields a default-constructed T rather than an error.
1689 T result = index < range->
size() ? (*range)[index] : T();
1690 LLVM_DEBUG(llvm::dbgs() <<
" * " << kind <<
"s(" << range->
size() <<
")\n"
1691 <<
" * Index: " << index <<
"\n"
1692 <<
" * Result: " << result <<
"\n");
1693 storeToMemory(memIndex, result);
/// Executes the Finalize opcode. The visible body only emits a debug banner;
/// no stream operands are read here.
1696 void ByteCodeExecutor::executeFinalize() {
1697 LLVM_DEBUG(llvm::dbgs() <<
"Executing Finalize\n");
1700 void ByteCodeExecutor::executeForEach() {
1701 LLVM_DEBUG(llvm::dbgs() <<
"Executing ForEach:\n");
1703 unsigned rangeIndex = read();
1704 unsigned memIndex = read();
1705 const void *value =
nullptr;
1707 switch (read<PDLValue::Kind>()) {
1709 unsigned &index = loopIndex[read()];
1711 assert(index <= array.size() &&
"iterated past the end");
1712 if (index < array.size()) {
1713 LLVM_DEBUG(llvm::dbgs() <<
" * Result: " << array[index] <<
"\n");
1714 value = array[index];
1718 LLVM_DEBUG(llvm::dbgs() <<
" * Done\n");
1720 selectJump(
size_t(0));
1724 llvm_unreachable(
"unexpected `ForEach` value kind");
1728 memory[memIndex] = value;
1729 pushCodeIt(prevCodeIt);
1732 read<ByteCodeAddr>();
1735 void ByteCodeExecutor::executeGetAttribute() {
1736 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetAttribute:\n");
1737 unsigned memIndex = read();
1739 StringAttr attrName = read<StringAttr>();
1742 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n"
1743 <<
" * Attribute: " << attrName <<
"\n"
1744 <<
" * Result: " << attr <<
"\n");
1748 void ByteCodeExecutor::executeGetAttributeType() {
1749 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetAttributeType:\n");
1750 unsigned memIndex = read();
1753 if (
auto typedAttr = dyn_cast<TypedAttr>(attr))
1754 type = typedAttr.getType();
1756 LLVM_DEBUG(llvm::dbgs() <<
" * Attribute: " << attr <<
"\n"
1757 <<
" * Result: " << type <<
"\n");
1761 void ByteCodeExecutor::executeGetDefiningOp() {
1762 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetDefiningOp:\n");
1763 unsigned memIndex = read();
1766 Value value = read<Value>();
1769 LLVM_DEBUG(llvm::dbgs() <<
" * Value: " << value <<
"\n");
1772 if (values && !values->empty()) {
1773 op = values->front().getDefiningOp();
1775 LLVM_DEBUG(llvm::dbgs() <<
" * Values: " << values <<
"\n");
1778 LLVM_DEBUG(llvm::dbgs() <<
" * Result: " << op <<
"\n");
1779 memory[memIndex] = op;
1782 void ByteCodeExecutor::executeGetOperand(
unsigned index) {
1784 unsigned memIndex = read();
1788 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n"
1789 <<
" * Index: " << index <<
"\n"
1790 <<
" * Result: " << operand <<
"\n");
1797 template <
template <
typename>
class AttrSizedSegmentsT,
typename RangeT>
1805 LLVM_DEBUG(llvm::dbgs() <<
" * Getting all values\n");
1809 }
else if (op->
hasTrait<AttrSizedSegmentsT>()) {
1810 LLVM_DEBUG(llvm::dbgs()
1811 <<
" * Extracting values from `" << attrSizedSegments <<
"`\n");
1814 if (!segmentAttr || segmentAttr.asArrayRef().size() <= index)
1818 unsigned startIndex =
1819 std::accumulate(segments.begin(), segments.begin() + index, 0);
1820 values = values.slice(startIndex, *std::next(segments.begin(), index));
1822 LLVM_DEBUG(llvm::dbgs() <<
" * Extracting range[" << startIndex <<
", "
1823 << *std::next(segments.begin(), index) <<
"]\n");
1829 }
else if (values.size() >= index) {
1830 LLVM_DEBUG(llvm::dbgs()
1831 <<
" * Treating values as trailing variadic range\n");
1832 values = values.drop_front(index);
1841 valueRangeMemory[rangeIndex] = values;
1842 return &valueRangeMemory[rangeIndex];
1846 return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1849 void ByteCodeExecutor::executeGetOperands() {
1850 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetOperands:\n");
1851 unsigned index = read<uint32_t>();
1855 void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1856 op->
getOperands(), op, index, rangeIndex,
"operandSegmentSizes",
1859 LLVM_DEBUG(llvm::dbgs() <<
" * Invalid operand range\n");
1860 memory[read()] = result;
1863 void ByteCodeExecutor::executeGetResult(
unsigned index) {
1865 unsigned memIndex = read();
1869 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n"
1870 <<
" * Index: " << index <<
"\n"
1871 <<
" * Result: " << result <<
"\n");
1875 void ByteCodeExecutor::executeGetResults() {
1876 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetResults:\n");
1877 unsigned index = read<uint32_t>();
1881 void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1882 op->
getResults(), op, index, rangeIndex,
"resultSegmentSizes",
1885 LLVM_DEBUG(llvm::dbgs() <<
" * Invalid result range\n");
1886 memory[read()] = result;
/// Executes the GetUsers opcode: reads a destination memory index and a range
/// index, publishes the address of `range` in `memory[memIndex]`, and logs the
/// number of operations held in `range`.
/// NOTE(review): the declaration of `range`, the choice between the single
/// value and value-range operand forms, and the loop body that fills `range`
/// are on lines not shown in this listing.
1889 void ByteCodeExecutor::executeGetUsers() {
1890 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetUsers:\n");
1891 unsigned memIndex = read();
1892 unsigned rangeIndex = read();
// The memory slot stores a pointer to the range, not a copy of it.
1894 memory[memIndex] = &range;
1899 Value value = read<Value>();
1902 LLVM_DEBUG(llvm::dbgs() <<
" * Value: " << value <<
"\n");
1913 llvm::dbgs() <<
" * Values (" << values->size() <<
"): ";
1914 llvm::interleaveComma(*values, llvm::dbgs());
1915 llvm::dbgs() <<
"\n";
1920 for (
Value value : *values)
1926 LLVM_DEBUG(llvm::dbgs() <<
" * Result: " << range.size() <<
" operations\n");
1929 void ByteCodeExecutor::executeGetValueType() {
1930 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetValueType:\n");
1931 unsigned memIndex = read();
1932 Value value = read<Value>();
1935 LLVM_DEBUG(llvm::dbgs() <<
" * Value: " << value <<
"\n"
1936 <<
" * Result: " << type <<
"\n");
/// Executes the GetValueRangeTypes opcode: reads a destination memory index
/// and a range index, stores the types of `values` into
/// `typeRangeMemory[rangeIndex]`, and publishes a pointer to that slot in
/// `memory[memIndex]`.
/// NOTE(review): the read of `values` and the null check guarding the
/// "<NULL>" path are on lines not shown in this listing.
1940 void ByteCodeExecutor::executeGetValueRangeTypes() {
1941 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetValueRangeTypes:\n");
1942 unsigned memIndex = read();
1943 unsigned rangeIndex = read();
// Null input path: the destination slot is set to nullptr.
1946 LLVM_DEBUG(llvm::dbgs() <<
" * Values: <NULL>\n\n");
1947 memory[memIndex] =
nullptr;
1952 llvm::dbgs() <<
" * Values (" << values->size() <<
"): ";
1953 llvm::interleaveComma(*values, llvm::dbgs());
1954 llvm::dbgs() <<
"\n * Result: ";
1955 llvm::interleaveComma(values->
getType(), llvm::dbgs());
1956 llvm::dbgs() <<
"\n";
// The type range is kept in typeRangeMemory so the pointer written to
// `memory` refers to storage owned by the executor state.
1958 typeRangeMemory[rangeIndex] = values->
getType();
1959 memory[memIndex] = &typeRangeMemory[rangeIndex];
/// Executes the IsNotNull opcode: reads one opaque pointer from the bytecode
/// stream and selects the jump based on whether it is non-null.
1962 void ByteCodeExecutor::executeIsNotNull() {
1963 LLVM_DEBUG(llvm::dbgs() <<
"Executing IsNotNull:\n");
1964 const void *value = read<const void *>();
1966 LLVM_DEBUG(llvm::dbgs() <<
" * Value: " << value <<
"\n");
1967 selectJump(value !=
nullptr);
1970 void ByteCodeExecutor::executeRecordMatch(
1973 LLVM_DEBUG(llvm::dbgs() <<
"Executing RecordMatch:\n");
1974 unsigned patternIndex = read();
1981 LLVM_DEBUG(llvm::dbgs() <<
" * Benefit: Impossible To Match\n");
1990 unsigned numMatchLocs = read();
1992 matchLocs.reserve(numMatchLocs);
1993 for (
unsigned i = 0; i != numMatchLocs; ++i)
1994 matchLocs.push_back(read<Operation *>()->getLoc());
1997 LLVM_DEBUG(llvm::dbgs() <<
" * Benefit: " << benefit.
getBenefit() <<
"\n"
1998 <<
" * Location: " << matchLoc <<
"\n");
1999 matches.emplace_back(matchLoc, patterns[patternIndex], benefit);
2005 unsigned numInputs = read();
2006 match.
values.reserve(numInputs);
2009 for (
unsigned i = 0; i < numInputs; ++i) {
2010 switch (read<PDLValue::Kind>()) {
2020 match.
values.push_back(read<const void *>());
2028 LLVM_DEBUG(llvm::dbgs() <<
"Executing ReplaceOp:\n");
2034 llvm::dbgs() <<
" * Operation: " << *op <<
"\n"
2036 llvm::interleaveComma(args, llvm::dbgs());
2037 llvm::dbgs() <<
"\n";
2042 void ByteCodeExecutor::executeSwitchAttribute() {
2043 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchAttribute:\n");
2045 ArrayAttr cases = read<ArrayAttr>();
2046 handleSwitch(value, cases);
2049 void ByteCodeExecutor::executeSwitchOperandCount() {
2050 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchOperandCount:\n");
2052 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2054 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n");
/// Executes the SwitchOperationName opcode: reads the case count, then reads
/// case names one at a time and compares each with `value`; on the first
/// match it selects jump `i + 1`, otherwise jump 0.
/// NOTE(review): `value` and `prevCodeIt` are set on lines not shown in this
/// listing; jump 0 is presumably the default destination -- confirm.
2058 void ByteCodeExecutor::executeSwitchOperationName() {
2059 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchOperationName:\n");
2061 size_t caseCount = read();
2068 llvm::dbgs() <<
" * Value: " << value <<
"\n"
2070 llvm::interleaveComma(
2071 llvm::map_range(llvm::seq<size_t>(0, caseCount),
2072 [&](
size_t) {
return read<OperationName>(); }),
2074 llvm::dbgs() <<
"\n";
// Printing the cases consumed them from the stream via read(); restore the
// cursor so the matching loop below sees them again.
2075 curCodeIt = prevCodeIt;
2079 for (
size_t i = 0; i != caseCount; ++i) {
2080 if (read<OperationName>() == value) {
// Advance past the (caseCount - i - 1) case entries that were not read.
2081 curCodeIt += (caseCount - i - 1);
2082 return selectJump(i + 1);
// No case matched.
2085 selectJump(
size_t(0));
2088 void ByteCodeExecutor::executeSwitchResultCount() {
2089 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchResultCount:\n");
2091 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2093 LLVM_DEBUG(llvm::dbgs() <<
" * Operation: " << *op <<
"\n");
/// Executes the SwitchType opcode: reads the Type to switch on and an
/// ArrayAttr of cases (viewed as TypeAttr values), then delegates the case
/// selection to handleSwitch.
2097 void ByteCodeExecutor::executeSwitchType() {
2098 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchType:\n");
2099 Type value = read<Type>();
2100 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2101 handleSwitch(value, cases);
2104 void ByteCodeExecutor::executeSwitchTypes() {
2105 LLVM_DEBUG(llvm::dbgs() <<
"Executing SwitchTypes:\n");
2107 auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2109 LLVM_DEBUG(llvm::dbgs() <<
"Types: <NULL>\n");
2110 return selectJump(
size_t(0));
2112 handleSwitch(*value, cases, [](ArrayAttr caseValue,
const TypeRange &value) {
2113 return value == caseValue.getAsValueRange<TypeAttr>();
2120 std::optional<Location> mainRewriteLoc) {
2123 LLVM_DEBUG(llvm::dbgs() << readInline<Location>() <<
"\n");
2125 OpCode opCode =
static_cast<OpCode
>(read());
2127 case ApplyConstraint:
2128 executeApplyConstraint(rewriter);
2131 if (
failed(executeApplyRewrite(rewriter)))
2137 case AreRangesEqual:
2138 executeAreRangesEqual();
2143 case CheckOperandCount:
2144 executeCheckOperandCount();
2146 case CheckOperationName:
2147 executeCheckOperationName();
2149 case CheckResultCount:
2150 executeCheckResultCount();
2153 executeCheckTypes();
2158 case CreateConstantTypeRange:
2159 executeCreateConstantTypeRange();
2161 case CreateOperation:
2162 executeCreateOperation(rewriter, *mainRewriteLoc);
2164 case CreateDynamicTypeRange:
2165 executeDynamicCreateRange<Type>(
"Type");
2167 case CreateDynamicValueRange:
2168 executeDynamicCreateRange<Value>(
"Value");
2171 executeEraseOp(rewriter);
2174 executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
2177 executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2180 executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2184 LLVM_DEBUG(llvm::dbgs() <<
"\n");
2190 executeGetAttribute();
2192 case GetAttributeType:
2193 executeGetAttributeType();
2196 executeGetDefiningOp();
2202 unsigned index = opCode - GetOperand0;
2203 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetOperand" << index <<
":\n");
2204 executeGetOperand(index);
2208 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetOperandN:\n");
2209 executeGetOperand(read<uint32_t>());
2212 executeGetOperands();
2218 unsigned index = opCode - GetResult0;
2219 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetResult" << index <<
":\n");
2220 executeGetResult(index);
2224 LLVM_DEBUG(llvm::dbgs() <<
"Executing GetResultN:\n");
2225 executeGetResult(read<uint32_t>());
2228 executeGetResults();
2234 executeGetValueType();
2236 case GetValueRangeTypes:
2237 executeGetValueRangeTypes();
2244 "expected matches to be provided when executing the matcher");
2245 executeRecordMatch(rewriter, *matches);
2248 executeReplaceOp(rewriter);
2250 case SwitchAttribute:
2251 executeSwitchAttribute();
2253 case SwitchOperandCount:
2254 executeSwitchOperandCount();
2256 case SwitchOperationName:
2257 executeSwitchOperationName();
2259 case SwitchResultCount:
2260 executeSwitchResultCount();
2263 executeSwitchType();
2266 executeSwitchTypes();
2269 LLVM_DEBUG(llvm::dbgs() <<
"\n");
2277 state.memory[0] = op;
2280 ByteCodeExecutor executor(
2281 matcherByteCode.data(), state.memory, state.opRangeMemory,
2282 state.typeRangeMemory, state.allocatedTypeRangeMemory,
2283 state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2284 uniquedData, matcherByteCode, state.currentPatternBenefits, patterns,
2285 constraintFunctions, rewriteFunctions);
2286 LogicalResult executeResult = executor.execute(rewriter, &matches);
2287 (void)executeResult;
2288 assert(
succeeded(executeResult) &&
"unexpected matcher execution failure");
2291 std::stable_sort(matches.begin(), matches.end(),
2293 return lhs.benefit > rhs.benefit;
2300 auto *configSet =
match.pattern->getConfigSet();
2302 configSet->notifyRewriteBegin(rewriter);
2308 ByteCodeExecutor executor(
2309 &rewriterByteCode[
match.pattern->getRewriterAddr()], state.memory,
2310 state.opRangeMemory, state.typeRangeMemory,
2311 state.allocatedTypeRangeMemory, state.valueRangeMemory,
2312 state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2313 rewriterByteCode, state.currentPatternBenefits, patterns,
2314 constraintFunctions, rewriteFunctions);
2316 executor.execute(rewriter,
nullptr,
match.location);
2319 configSet->notifyRewriteEnd(rewriter);
2328 LLVM_DEBUG(llvm::dbgs() <<
" and rollback is not supported - aborting");
2329 llvm::report_fatal_error(
2330 "Native PDL Rewrite failed, but the pattern "
2331 "rewriter doesn't support recovery. Failable pattern rewrites should "
2332 "not be used with pattern rewriters that do not support them.");
static void * executeGetOperandsResults(RangeT values, Operation *op, unsigned index, ByteCodeField rangeIndex, StringRef attrSizedSegments, MutableArrayRef< ValueRange > valueRangeMemory)
This function is the internal implementation of GetResults and GetOperands that provides support for extracting a value range from the given operation.
static constexpr ByteCodeField kInferTypesMarker
A marker used to indicate if an operation should infer types.
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static const mlir::GenInfo * generator
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void processValue(Value value, LiveMap &liveMap)
Attributes are known-constant values of operations.
const void * getAsOpaquePointer() const
Get an opaque pointer to the attribute.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
BlockArgListType getArguments()
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
This class represents liveness information on block level.
const ValueSetT & in() const
Returns all values that are live at the beginning of the block (unordered).
Operation * getEndOperation(Value value, Operation *startOperation) const
Gets the end operation for the given value using the start operation provided (must be referenced in ...
Represents an analysis for computing liveness information from a given top-level operation.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This is a value defined by a result of an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g. hasTrait<OperandsAreSignlessIntegerLike>().
AttrClass getAttrOfType(StringAttr name)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
Block * getSuccessor(unsigned index)
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
SuccessorRange getSuccessors()
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class contains a set of configurations for a specific pattern.
The class represents a list of PDL results, returned by a native rewrite method.
Storage type of byte-code interpreter values.
Kind
The underlying kind of a PDL value.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
bool isImpossibleToMatch() const
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit. If the corresponding pattern isImpossibleToMatch() then this aborts.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual bool canRecoverFromRewriteFailure() const
A hook used to indicate if the pattern rewriter can recover from failure during the rewrite stage of ...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class implements the result iterators for the Operation class.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class implements the successor iterators for Block.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
const void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
auto walk(WalkFns &&...walkFns)
Walk this type and all attributes/types nested within using the provided walk functions.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Type getType() const
Return the type of this value.
user_iterator user_begin() const
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
user_iterator user_end() const
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
This class contains the mutable state of a bytecode instance.
void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit)
Set the new benefit for a bytecode pattern.
void cleanupAfterMatchAndRewrite()
Cleanup any allocated state after a match/rewrite has been completed.
All of the data pertaining to a specific pattern within the bytecode.
static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp, PDLPatternConfigSet *configSet, ByteCodeAddr rewriterAddr)
void match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl< MatchResult > &matches, PDLByteCodeMutableState &state) const
Run the pattern matcher on the given root operation, collecting the matched patterns in matches.
PDLByteCode(ModuleOp module, SmallVector< std::unique_ptr< PDLPatternConfigSet >> configs, const DenseMap< Operation *, PDLPatternConfigSet * > &configMap, llvm::StringMap< PDLConstraintFunction > constraintFns, llvm::StringMap< PDLRewriteFunction > rewriteFns)
Create a ByteCode instance from the given module containing operations in the PDL interpreter dialect.
void initializeMutableState(PDLByteCodeMutableState &state) const
Initialize the given state such that it can be used to execute the current bytecode.
LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const
Run the rewriter of the given pattern that was previously matched in match.
Detect if any of the given parameter types has a sub-element handler.
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
uint16_t ByteCodeField
Use generic bytecode types.
llvm::OwningArrayRef< Operation * > OwningOpRange
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
std::function< LogicalResult(PatternRewriter &, ArrayRef< PDLValue >)> PDLConstraintFunction
A generic PDL pattern constraint function.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
std::function< LogicalResult(PatternRewriter &, PDLResultList &, ArrayRef< PDLValue >)> PDLRewriteFunction
A native PDL rewrite function.
This class represents an efficient way to signal success or failure.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
This class acts as a special tag that makes the desire to match "any" operation type explicit.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or static.
Each successful match returns a MatchResult, which contains information necessary to execute the rewr...
SmallVector< TypeRange, 0 > typeRangeValues
Memory used for the range input values.
SmallVector< ValueRange, 0 > valueRangeValues
SmallVector< const void * > values
Memory values defined in the matcher that are passed to the rewriter.