19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Support/DebugLog.h"
24 #include "llvm/Support/Format.h"
25 #include "llvm/Support/FormatVariadic.h"
26 #include "llvm/Support/InterleavedRange.h"
30 #define DEBUG_TYPE "pdl-bytecode"
40 PDLPatternConfigSet *configSet,
41 ByteCodeAddr rewriterAddr) {
47 if (ArrayAttr generatedOpsAttr = matchOp.getGeneratedOpsAttr())
49 llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
52 if (std::optional<StringRef> rootKind = matchOp.getRootKind())
56 benefit, ctx, generatedOps);
68 currentPatternBenefits[patternIndex] = benefit;
75 allocatedTypeRangeMemory.clear();
76 allocatedValueRangeMemory.clear();
/// Opcodes written into the bytecode stream. The underlying type is
/// ByteCodeField, so an opcode is stored as a single field.
84 enum OpCode : ByteCodeField {
106 CreateConstantTypeRange,
110 CreateDynamicTypeRange,
112 CreateDynamicValueRange,
187 struct ByteCodeLiveRange;
188 struct ByteCodeWriter;
/// Detection helper: this alias is well-formed only when an object of type
/// `T` provides a `getAsOpaquePointer()` member. It is consumed through
/// `llvm::is_detected<has_pointer_traits, T>` by the overloads in this file.
/// The `Args` pack is not used by the alias itself.
191 template <
typename T,
typename... Args>
192 using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
/// Construct a generator that writes into caller-owned storage: the uniqued
/// data vector, the bytecode buffers and the `max*` counters are all bound by
/// reference, so updates made during generation are visible to the caller.
197 Generator(
MLIRContext *ctx, std::vector<const void *> &uniquedData,
201 ByteCodeField &maxValueMemoryIndex,
202 ByteCodeField &maxOpRangeMemoryIndex,
203 ByteCodeField &maxTypeRangeMemoryIndex,
204 ByteCodeField &maxValueRangeMemoryIndex,
205 ByteCodeField &maxLoopLevel,
206 llvm::StringMap<PDLConstraintFunction> &constraintFns,
207 llvm::StringMap<PDLRewriteFunction> &rewriteFns,
209 : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
211 maxValueMemoryIndex(maxValueMemoryIndex),
212 maxOpRangeMemoryIndex(maxOpRangeMemoryIndex),
213 maxTypeRangeMemoryIndex(maxTypeRangeMemoryIndex),
214 maxValueRangeMemoryIndex(maxValueRangeMemoryIndex),
215 maxLoopLevel(maxLoopLevel), configMap(configMap) {
// Record, for each function name, the position at which it was enumerated.
// NOTE(review): presumably these enumerate `constraintFns` and `rewriteFns`
// respectively -- confirm against the enclosing loops.
217 constraintToMemIndex.try_emplace(it.value().first(), it.index());
219 externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
/// Generate the matcher and rewriter bytecode for the given module.
223 void generate(ModuleOp module);
/// Return a mutable reference to the memory index recorded for `value` in
/// `valueToMemIndex`. The assert states the precondition that an index has
/// already been assigned to this value.
226 ByteCodeField &getMemIndex(
Value value) {
227 assert(valueToMemIndex.count(value) &&
228 "expected memory index to be assigned");
229 return valueToMemIndex[value];
/// Return a mutable reference to the range storage index recorded for `value`
/// in `valueToRangeIndex`. The assert states the precondition that a range
/// index has already been assigned to this value.
233 ByteCodeField &getRangeStorageIndex(
Value value) {
234 assert(valueToRangeIndex.count(value) &&
235 "expected range index to be assigned");
236 return valueToRangeIndex[value];
/// Overload enabled for types that are not convertible to `Value`: the entity
/// is identified by its opaque pointer and looked up in
/// `uniquedDataToMemIndex`.
241 template <
typename T>
242 std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
244 const void *opaqueVal = val.getAsOpaquePointer();
// Index proposed for a pointer that is not in the map yet: it sits past the
// value memory indices, offset by the number of uniqued entries so far.
// try_emplace leaves an existing entry untouched.
247 auto it = uniquedDataToMemIndex.try_emplace(
248 opaqueVal, maxValueMemoryIndex + uniquedData.size());
250 uniquedData.push_back(opaqueVal);
// Return the index stored in the map for this pointer.
251 return it.first->second;
/// Assign memory indices to the values of the matcher function and of the
/// functions contained in the rewriter module.
257 void allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
258 ModuleOp rewriterModule);
/// Bytecode emitters. The `Region` overload walks the blocks of a region, the
/// `Operation` overload dispatches on the concrete operation type, and each
/// remaining overload emits the encoding of one `pdl_interp` operation into
/// `writer`.
261 void generate(
Region *region, ByteCodeWriter &writer);
262 void generate(
Operation *op, ByteCodeWriter &writer);
263 void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
264 void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
265 void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
266 void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
267 void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
268 void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
269 void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
270 void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
271 void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
272 void generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer);
273 void generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer);
274 void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
275 void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
276 void generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer);
277 void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
278 void generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer);
279 void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
280 void generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer);
281 void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
282 void generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer);
283 void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
284 void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
285 void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
286 void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
287 void generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer);
288 void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
289 void generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer);
290 void generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer);
291 void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
292 void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
293 void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
294 void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
295 void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
296 void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
297 void generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer);
298 void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
299 void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
300 void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
/// Index assigned to each external rewrite function name; filled in by the
/// constructor and emitted as the operand of `ApplyRewrite`.
310 llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
/// Index assigned to each constraint function name; filled in by the
/// constructor and emitted as the operand of `ApplyConstraint`.
314 llvm::StringMap<ByteCodeField> constraintToMemIndex;
/// Offset in the rewriter bytecode at which each rewriter function starts,
/// keyed by the function name.
318 llvm::StringMap<ByteCodeAddr> rewriterToAddr;
/// Loop level used when emitting `ForEach` and `Continue`; starts at zero.
325 ByteCodeField curLoopLevel = 0;
/// Opaque pointers registered through the non-`Value` `getMemIndex` overload.
334 std::vector<const void *> &uniquedData;
/// Caller-owned maxima. They are raised (never lowered) while memory indices
/// are allocated and, for `maxLoopLevel`, while `ForEach` is generated.
338 ByteCodeField &maxValueMemoryIndex;
339 ByteCodeField &maxOpRangeMemoryIndex;
340 ByteCodeField &maxTypeRangeMemoryIndex;
341 ByteCodeField &maxValueRangeMemoryIndex;
342 ByteCodeField &maxLoopLevel;
/// Helper that appends encoded fields to a bytecode buffer on behalf of a
/// `Generator`, which it queries for memory indices.
349 struct ByteCodeWriter {
/// Append one raw field to the end of the bytecode.
354 void append(ByteCodeField field) {
  bytecode.push_back(field);
}
/// Append an opcode; it is stored as a single field.
355 void append(OpCode opCode) {
  const ByteCodeField encodedOpCode = opCode;
  bytecode.push_back(encodedOpCode);
}
/// Append an address. The static_assert pins an address at exactly two
/// fields wide, so the address is copied bytewise into two fields that are
/// appended back to back.
358 void append(ByteCodeAddr field) {
359 static_assert((
sizeof(ByteCodeAddr) /
sizeof(ByteCodeField)) == 2,
360 "unexpected ByteCode address size");
// memcpy preserves the in-memory byte layout of the address.
362 ByteCodeField fieldParts[2];
363 std::memcpy(fieldParts, &field,
sizeof(ByteCodeAddr));
364 bytecode.append({fieldParts[0], fieldParts[1]});
/// Append a reference to a successor block. The current bytecode offset is
/// recorded in `unresolvedSuccessorRefs` for that block, and a zero address
/// is written at that offset as a placeholder.
369 void append(
Block *successor) {
372 unresolvedSuccessorRefs[successor].push_back(bytecode.size());
373 append(ByteCodeAddr(0));
379 for (
Block *successor : successors)
385 bytecode.push_back(values.size());
386 for (
Value value : values)
387 appendPDLValue(value);
/// Append a PDL value; its encoding begins with the kind of the value.
391 void appendPDLValue(
Value value) {
392 appendPDLValueKind(value);
/// Append the PDLValue kind of `value`, which is derived from its type.
397 void appendPDLValueKind(
    Value value) {
  Type valueType = value.getType();
  appendPDLValueKind(valueType);
}
/// Append the PDLValue kind that corresponds to a PDL type, as one field.
/// Attribute, operation, type and value types map to the kind of the same
/// name. A range type maps to `TypeRange` when its element type is
/// `pdl::TypeType` and to `ValueRange` otherwise.
400 void appendPDLValueKind(
Type type) {
403 .Case<pdl::AttributeType>(
404 [](
Type) {
return PDLValue::Kind::Attribute; })
405 .Case<pdl::OperationType>(
406 [](
Type) {
return PDLValue::Kind::Operation; })
407 .Case<pdl::RangeType>([](pdl::RangeType rangeTy) {
408 if (isa<pdl::TypeType>(rangeTy.getElementType()))
409 return PDLValue::Kind::TypeRange;
410 return PDLValue::Kind::ValueRange;
412 .Case<pdl::TypeType>([](
Type) {
return PDLValue::Kind::Type; })
413 .Case<pdl::ValueType>([](
Type) {
return PDLValue::Kind::Value; });
414 bytecode.push_back(
static_cast<ByteCodeField
>(
kind));
/// Overload enabled for types that have opaque-pointer traits or are raw
/// pointers: the value is encoded as the memory index that the generator
/// holds for it.
419 template <
typename T>
420 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
421 std::is_pointer<T>::value>
423 bytecode.push_back(
generator.getMemIndex(value));
/// Overload enabled for types without opaque-pointer traits, which are
/// treated as ranges: the number of elements is written first, then the
/// elements are visited in order.
427 template <
typename T,
typename IteratorT = llvm::detail::IterOfRange<T>>
428 std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
430 bytecode.push_back(llvm::size(range));
431 for (
auto it : range)
/// Variadic convenience overload taking two or more fields; the second and
/// all following fields are handled by a recursive call to `append`.
436 template <
typename FieldTy,
typename Field2Ty,
typename... FieldTys>
437 void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
439 append(field2, fields...);
/// Append a value with opaque-pointer traits inline: the opaque pointer
/// itself is copied bytewise into as many fields as a pointer occupies and
/// those fields are appended directly, instead of a memory index.
443 template <
typename T>
444 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
445 appendInline(T value) {
// Number of fields needed to hold one pointer.
446 constexpr
size_t numParts =
sizeof(
const void *) /
sizeof(ByteCodeField);
447 const void *pointer = value.getAsOpaquePointer();
448 ByteCodeField fieldParts[numParts];
449 std::memcpy(fieldParts, &pointer,
sizeof(
const void *));
450 bytecode.append(fieldParts, fieldParts + numParts);
/// A live range stored as a set of intervals, together with the optional
/// range storage indices associated with it.
465 struct ByteCodeLiveRange {
// The interval map is used as a set: the mapped `char` carries no
// information (unionWith always inserts 0).
466 using Set = llvm::IntervalMap<uint64_t, char, 16>;
467 using Allocator = Set::Allocator;
/// Create an empty live range whose interval map uses `alloc`.
469 ByteCodeLiveRange(Allocator &alloc) : liveness(new Set(alloc)) {}
/// Insert every interval of `rhs` into this live range.
472 void unionWith(
const ByteCodeLiveRange &rhs) {
473 for (
auto it = rhs.liveness->begin(), e = rhs.liveness->end(); it != e;
475 liveness->insert(it.start(), it.stop(), 0);
/// Compare this live range against `rhs` using `IntervalMapOverlaps`.
479 bool overlaps(
const ByteCodeLiveRange &rhs)
const {
480 return llvm::IntervalMapOverlaps<Set, Set>(*liveness, *rhs.liveness)
/// The intervals that make up this live range.
490 std::unique_ptr<llvm::IntervalMap<uint64_t, char, 16>> liveness;
/// Operation range storage index for this range, if one was assigned.
493 std::optional<unsigned> opRangeIndex;
/// Type range storage index for this range, if one was assigned.
496 std::optional<unsigned> typeRangeIndex;
/// Value range storage index for this range, if one was assigned.
499 std::optional<unsigned> valueRangeIndex;
/// Generate bytecode for a PDL interpreter module. The matcher function and
/// the rewriter module are looked up by the names the dialect provides, memory
/// indices are allocated, the rewriter functions are emitted first and the
/// matcher last, and finally the successor placeholders in the matcher
/// bytecode are patched with real block addresses.
503 void Generator::generate(ModuleOp module) {
504 auto matcherFunc = module.lookupSymbol<pdl_interp::FuncOp>(
505 pdl_interp::PDLInterpDialect::getMatcherFunctionName());
506 ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
507 pdl_interp::PDLInterpDialect::getRewriterModuleName());
508 assert(matcherFunc && rewriterModule &&
"invalid PDL Interpreter module");
512 allocateMemoryIndices(matcherFunc, rewriterModule);
// Rewriters: remember where each function starts in the rewriter bytecode so
// that patterns can refer to it by address.
515 ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *
this);
516 for (
auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
517 rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
518 for (
Operation &op : rewriterFunc.getOps())
519 generate(&op, rewriterByteCodeWriter);
// Rewriter code must not have produced any successor reference.
521 assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
522 "unexpected branches in rewriter function");
// Matcher: emitted from the body region of the matcher function.
525 ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *
this);
526 generate(&matcherFunc.getBody(), matcherByteCodeWriter);
// Overwrite each recorded placeholder with the address of its block.
529 for (
auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
530 ByteCodeAddr addr = blockToAddr[it.first];
531 for (
unsigned offsetToFix : it.second)
532 std::memcpy(&matcherByteCode[offsetToFix], &addr,
sizeof(ByteCodeAddr));
/// Assign memory indices for the rewriter functions and for the matcher.
/// Rewriter values are numbered sequentially within each function. Matcher
/// values get live ranges, and a value whose live range does not overlap an
/// already allocated one shares that index. The caller-owned maxima are
/// raised to cover the counts computed here.
536 void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc,
537 ModuleOp rewriterModule) {
// Rewriter functions: the counters restart at zero for every function.
540 for (
auto rewriterFunc : rewriterModule.getOps<pdl_interp::FuncOp>()) {
541 ByteCodeField index = 0, typeRangeIndex = 0, valueRangeIndex = 0;
// Give `val` the next memory index and, for type/value ranges, the next
// range storage index of the matching kind.
542 auto processRewriterValue = [&](
Value val) {
543 valueToMemIndex.try_emplace(val, index++);
544 if (pdl::RangeType rangeType = dyn_cast<pdl::RangeType>(val.getType())) {
545 Type elementTy = rangeType.getElementType();
546 if (isa<pdl::TypeType>(elementTy))
547 valueToRangeIndex.try_emplace(val, typeRangeIndex++);
548 else if (isa<pdl::ValueType>(elementTy))
549 valueToRangeIndex.try_emplace(val, valueRangeIndex++);
554 processRewriterValue(arg);
557 processRewriterValue(result);
// Raise the maxima to cover this function's usage.
559 if (index > maxValueMemoryIndex)
560 maxValueMemoryIndex = index;
561 if (typeRangeIndex > maxTypeRangeMemoryIndex)
562 maxTypeRangeMemoryIndex = typeRangeIndex;
563 if (valueRangeIndex > maxValueRangeMemoryIndex)
564 maxValueRangeMemoryIndex = valueRangeIndex;
// Number the operations: each gets a first and a last position from the same
// increasing counter.
581 opToFirstIndex.try_emplace(op, index++);
583 for (
Block &block : region.getBlocks())
586 opToLastIndex.try_emplace(op, index++);
// Allocator used by every live-range interval map created below.
591 ByteCodeLiveRange::Allocator allocator;
// The root operation argument is given memory index 0.
596 valueToMemIndex[rootOpArg] = 0;
// Build the live ranges of the matcher values from the liveness analysis.
599 Liveness matcherLiveness(matcherFunc);
600 matcherFunc->walk([&](
Block *block) {
602 assert(info &&
"expected liveness info for block");
606 if (value == rootOpArg)
610 auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
611 defRangeIt->second.liveness->insert(
612 opToFirstIndex[firstUseOrDef],
// Mark which kind of range storage the value needs. The 0 written here only
// flags the optional as set; the real index is chosen during allocation.
617 if (
auto rangeTy = dyn_cast<pdl::RangeType>(value.
getType())) {
618 Type eleType = rangeTy.getElementType();
619 if (isa<pdl::OperationType>(eleType))
620 defRangeIt->second.opRangeIndex = 0;
621 else if (isa<pdl::TypeType>(eleType))
622 defRangeIt->second.typeRangeIndex = 0;
623 else if (isa<pdl::ValueType>(eleType))
624 defRangeIt->second.valueRangeIndex = 0;
629 for (
Value liveIn : info->
in()) {
634 if (liveIn.getParentRegion() == block->
getParent())
// One merged live range per allocated memory index.
651 std::vector<ByteCodeLiveRange> allocatedIndices;
655 ByteCodeField numIndices = 1;
658 ByteCodeField numOpRanges = 0, numTypeRanges = 0, numValueRanges = 0;
660 for (
auto &defIt : valueDefRanges) {
661 ByteCodeField &memIndex = valueToMemIndex[defIt.first];
662 ByteCodeLiveRange &defRange = defIt.second;
// Look for an allocated index whose live range this definition can join.
665 for (
const auto &existingIndexIt :
llvm::enumerate(allocatedIndices)) {
666 ByteCodeLiveRange &existingRange = existingIndexIt.value();
// No overlap: merge into the existing range and reuse its index. Stored
// indices are the vector position plus one; 0 was given to the root
// argument above.
667 if (!defRange.overlaps(existingRange)) {
668 existingRange.unionWith(defRange);
669 memIndex = existingIndexIt.index() + 1;
// Reuse the range storage of the existing index, creating it on first need.
671 if (defRange.opRangeIndex) {
672 if (!existingRange.opRangeIndex)
673 existingRange.opRangeIndex = numOpRanges++;
674 valueToRangeIndex[defIt.first] = *existingRange.opRangeIndex;
675 }
else if (defRange.typeRangeIndex) {
676 if (!existingRange.typeRangeIndex)
677 existingRange.typeRangeIndex = numTypeRanges++;
678 valueToRangeIndex[defIt.first] = *existingRange.typeRangeIndex;
679 }
else if (defRange.valueRangeIndex) {
680 if (!existingRange.valueRangeIndex)
681 existingRange.valueRangeIndex = numValueRanges++;
682 valueToRangeIndex[defIt.first] = *existingRange.valueRangeIndex;
// Fresh index: append a new live range seeded with this definition.
690 allocatedIndices.emplace_back(allocator);
691 ByteCodeLiveRange &newRange = allocatedIndices.back();
692 newRange.unionWith(defRange);
695 if (defRange.opRangeIndex) {
696 newRange.opRangeIndex = numOpRanges;
697 valueToRangeIndex[defIt.first] = numOpRanges++;
698 }
else if (defRange.typeRangeIndex) {
699 newRange.typeRangeIndex = numTypeRanges;
700 valueToRangeIndex[defIt.first] = numTypeRanges++;
701 }
else if (defRange.valueRangeIndex) {
702 newRange.valueRangeIndex = numValueRanges;
703 valueToRangeIndex[defIt.first] = numValueRanges++;
706 memIndex = allocatedIndices.size();
// Report how many indices were needed compared with one per definition.
712 LDBG() <<
"Allocated " << allocatedIndices.size() <<
" indices "
713 <<
"(down from initial " << valueDefRanges.size() <<
").";
715 "Ran out of memory for allocated indices");
// Raise the maxima to cover the matcher's usage.
718 if (numIndices > maxValueMemoryIndex)
719 maxValueMemoryIndex = numIndices;
720 if (numOpRanges > maxOpRangeMemoryIndex)
721 maxOpRangeMemoryIndex = numOpRanges;
722 if (numTypeRanges > maxTypeRangeMemoryIndex)
723 maxTypeRangeMemoryIndex = numTypeRanges;
724 if (numValueRanges > maxValueRangeMemoryIndex)
725 maxValueRangeMemoryIndex = numValueRanges;
/// Generate bytecode for a region. Blocks are visited in reverse post-order;
/// the offset in the matcher bytecode at which each block begins is recorded
/// in `blockToAddr` before its operations are generated.
728 void Generator::generate(
Region *region, ByteCodeWriter &writer) {
729 llvm::ReversePostOrderTraversal<Region *> rpot(region);
730 for (
Block *block : rpot) {
732 blockToAddr.try_emplace(block, matcherByteCode.size());
734 generate(&op, writer);
/// Generate bytecode for one operation. The operation's location is appended
/// inline first, except for `CreateAttributeOp` and `CreateTypeOp`. The
/// operation is then dispatched on its concrete type to the matching
/// `generate` overload; any other operation kind is treated as unreachable.
738 void Generator::generate(
Operation *op, ByteCodeWriter &writer) {
739 LDBG() <<
"Generating bytecode for operation: " << op->
getName();
743 if (!isa<pdl_interp::CreateAttributeOp, pdl_interp::CreateTypeOp>(op))
744 writer.appendInline(op->
getLoc());
747 .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
748 pdl_interp::AreEqualOp, pdl_interp::BranchOp,
749 pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
750 pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
751 pdl_interp::CheckTypeOp, pdl_interp::CheckTypesOp,
752 pdl_interp::ContinueOp, pdl_interp::CreateAttributeOp,
753 pdl_interp::CreateOperationOp, pdl_interp::CreateRangeOp,
754 pdl_interp::CreateTypeOp, pdl_interp::CreateTypesOp,
755 pdl_interp::EraseOp, pdl_interp::ExtractOp, pdl_interp::FinalizeOp,
756 pdl_interp::ForEachOp, pdl_interp::GetAttributeOp,
757 pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
758 pdl_interp::GetOperandOp, pdl_interp::GetOperandsOp,
759 pdl_interp::GetResultOp, pdl_interp::GetResultsOp,
760 pdl_interp::GetUsersOp, pdl_interp::GetValueTypeOp,
761 pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
762 pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
763 pdl_interp::SwitchTypeOp, pdl_interp::SwitchTypesOp,
764 pdl_interp::SwitchOperandCountOp, pdl_interp::SwitchOperationNameOp,
765 pdl_interp::SwitchResultCountOp>(
766 [&](
auto interpOp) { this->generate(interpOp, writer); })
767 .DefaultUnreachable(
"unknown `pdl_interp` operation");
/// Emit `ApplyConstraint`. Encoding order: opcode, index of the constraint
/// function, argument list, negation flag, number of results, then for every
/// result its kind, its range storage index (range-typed results only) and
/// the result itself; the successors come last.
770 void Generator::generate(pdl_interp::ApplyConstraintOp op,
771 ByteCodeWriter &writer) {
775 writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.getName()]);
776 writer.appendPDLValueList(op.getArgs());
777 writer.append(ByteCodeField(op.getIsNegated()));
779 writer.append(ByteCodeField(results.size()));
780 for (
Value result : results) {
784 writer.appendPDLValueKind(result);
787 if (isa<pdl::RangeType>(result.getType()))
788 writer.append(getRangeStorageIndex(result));
789 writer.append(result);
791 writer.append(op.getSuccessors());
/// Emit `ApplyRewrite`. Encoding order: opcode, index of the external rewrite
/// function (which must have been registered), argument list, number of
/// results, then for every result its kind, its range storage index
/// (range-typed results only) and the result itself.
793 void Generator::generate(pdl_interp::ApplyRewriteOp op,
794 ByteCodeWriter &writer) {
795 assert(externalRewriterToMemIndex.count(op.getName()) &&
796 "expected index for rewrite function");
797 writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.getName()]);
798 writer.appendPDLValueList(op.getArgs());
801 writer.append(ByteCodeField(results.size()));
802 for (
Value result : results) {
805 writer.appendPDLValueKind(result);
808 if (isa<pdl::RangeType>(result.getType()))
809 writer.append(getRangeStorageIndex(result));
810 writer.append(result);
/// Emit an equality check. When the left-hand side has a range type the
/// `AreRangesEqual` form is written: opcode, kind of the operands, both
/// operands and the successors. The `AreEqual` form writes the opcode, both
/// operands and the successors.
813 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
814 Value lhs = op.getLhs();
815 if (isa<pdl::RangeType>(lhs.
getType())) {
816 writer.append(OpCode::AreRangesEqual);
817 writer.appendPDLValueKind(lhs);
818 writer.append(op.getLhs(), op.getRhs(), op.getSuccessors());
822 writer.append(OpCode::AreEqual, lhs, op.getRhs(), op.getSuccessors());
/// Emit the bytecode for a `pdl_interp::BranchOp`.
824 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
/// An attribute check reuses the generic `AreEqual` opcode, comparing the
/// attribute with the operation's constant value.
827 void Generator::generate(pdl_interp::CheckAttributeOp op,
828 ByteCodeWriter &writer) {
829 writer.append(OpCode::AreEqual, op.getAttribute(), op.getConstantValue(),
/// Emit `CheckOperandCount`: opcode, input operation, expected count and the
/// "compare at least" flag encoded as one field.
832 void Generator::generate(pdl_interp::CheckOperandCountOp op,
833 ByteCodeWriter &writer) {
834 writer.append(OpCode::CheckOperandCount, op.getInputOp(), op.getCount(),
835 static_cast<ByteCodeField
>(op.getCompareAtLeast()),
/// Emit `CheckOperationName`, starting with the opcode and the input
/// operation.
838 void Generator::generate(pdl_interp::CheckOperationNameOp op,
839 ByteCodeWriter &writer) {
840 writer.append(OpCode::CheckOperationName, op.getInputOp(),
/// Emit `CheckResultCount`: opcode, input operation, expected count and the
/// "compare at least" flag encoded as one field.
843 void Generator::generate(pdl_interp::CheckResultCountOp op,
844 ByteCodeWriter &writer) {
845 writer.append(OpCode::CheckResultCount, op.getInputOp(), op.getCount(),
846 static_cast<ByteCodeField
>(op.getCompareAtLeast()),
/// A type check reuses the generic `AreEqual` opcode, comparing the value
/// with the operation's type.
849 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
850 writer.append(OpCode::AreEqual, op.getValue(), op.getType(),
/// Emit `CheckTypes`, starting with the opcode, the checked value and the
/// expected types.
853 void Generator::generate(pdl_interp::CheckTypesOp op, ByteCodeWriter &writer) {
854 writer.append(OpCode::CheckTypes, op.getValue(), op.getTypes(),
/// Emit `Continue`. It is only valid while the loop level is positive; the
/// operand is the current loop level minus one.
857 void Generator::generate(pdl_interp::ContinueOp op, ByteCodeWriter &writer) {
858 assert(curLoopLevel > 0 &&
"encountered pdl_interp.continue at top level");
859 writer.append(OpCode::Continue, ByteCodeField(curLoopLevel - 1));
/// The result is given the same memory index as the operation's attribute
/// value.
861 void Generator::generate(pdl_interp::CreateAttributeOp op,
862 ByteCodeWriter &writer) {
864 getMemIndex(op.getAttribute()) = getMemIndex(op.getValue());
/// Emit `CreateOperation`. The encoding starts with the opcode and the result
/// operation, and includes the operand list, the attribute count followed by
/// (name, value) pairs, and the result type list.
866 void Generator::generate(pdl_interp::CreateOperationOp op,
867 ByteCodeWriter &writer) {
868 writer.append(OpCode::CreateOperation, op.getResultOp(),
870 writer.appendPDLValueList(op.getInputOperands());
874 writer.append(
static_cast<ByteCodeField
>(attributes.size()));
// Attribute names and values are written pairwise.
875 for (
auto it : llvm::zip(op.getInputAttributeNames(), attributes))
876 writer.append(std::get<0>(it), std::get<1>(it));
880 if (op.getInferredResultTypes())
883 writer.appendPDLValueList(op.getInputResultTypes());
/// Emit a dynamic range creation. The opcode depends on the element type
/// (`CreateDynamicTypeRange` or `CreateDynamicValueRange`); it is followed by
/// the result, its range storage index and the operand list.
885 void Generator::generate(pdl_interp::CreateRangeOp op, ByteCodeWriter &writer) {
889 [&](pdl::TypeType) { writer.append(OpCode::CreateDynamicTypeRange); })
890 .Case([&](pdl::ValueType) {
891 writer.append(OpCode::CreateDynamicValueRange);
894 writer.append(op.getResult(), getRangeStorageIndex(op.getResult()));
895 writer.appendPDLValueList(op->getOperands());
/// The result is given the same memory index as the operation's type value.
897 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
899 getMemIndex(op.getResult()) = getMemIndex(op.getValue());
/// Emit `CreateConstantTypeRange`: opcode, result, range storage index of the
/// result and the constant value.
901 void Generator::generate(pdl_interp::CreateTypesOp op, ByteCodeWriter &writer) {
902 writer.append(OpCode::CreateConstantTypeRange, op.getResult(),
903 getRangeStorageIndex(op.getResult()), op.getValue());
/// Emit `EraseOp`: opcode followed by the operation to erase.
905 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
906 writer.append(OpCode::EraseOp, op.getInputOp());
/// Emit an element extraction. The opcode is `ExtractOp`, `ExtractValue` or
/// `ExtractType` depending on the element type (any other type is
/// unreachable); it is followed by the range, the index and the result.
908 void Generator::generate(pdl_interp::ExtractOp op, ByteCodeWriter &writer) {
911 .Case([](pdl::OperationType) {
return OpCode::ExtractOp; })
912 .Case([](pdl::ValueType) {
return OpCode::ExtractValue; })
913 .Case([](pdl::TypeType) {
return OpCode::ExtractType; })
914 .DefaultUnreachable(
"unsupported element type");
915 writer.append(opCode, op.getRange(), op.getIndex(), op.getResult());
/// Emit `Finalize`; the opcode takes no operands.
917 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
918 writer.append(OpCode::Finalize);
/// Emit `ForEach`: opcode, range storage index of the iterated values, the
/// loop argument and its kind, the current loop level and the successor. The
/// maximum loop level is raised if needed and the loop body region is
/// generated.
920 void Generator::generate(pdl_interp::ForEachOp op, ByteCodeWriter &writer) {
922 writer.append(OpCode::ForEach, getRangeStorageIndex(op.getValues()), arg);
923 writer.appendPDLValueKind(arg.
getType());
924 writer.append(curLoopLevel, op.getSuccessor());
926 if (curLoopLevel > maxLoopLevel)
927 maxLoopLevel = curLoopLevel;
928 generate(&op.getRegion(), writer);
/// Emit `GetAttribute`, starting with the opcode, the result attribute and
/// the input operation.
931 void Generator::generate(pdl_interp::GetAttributeOp op,
932 ByteCodeWriter &writer) {
933 writer.append(OpCode::GetAttribute, op.getAttribute(), op.getInputOp(),
/// Emit `GetAttributeType`: opcode, result and the queried attribute value.
936 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
937 ByteCodeWriter &writer) {
938 writer.append(OpCode::GetAttributeType, op.getResult(), op.getValue());
/// Emit `GetDefiningOp`: opcode, result operation, then the queried value
/// written as a PDL value (kind first).
940 void Generator::generate(pdl_interp::GetDefiningOpOp op,
941 ByteCodeWriter &writer) {
942 writer.append(OpCode::GetDefiningOp, op.getInputOp());
943 writer.appendPDLValue(op.getValue());
/// Emit an operand access. The opcode is either a specialised one computed
/// as `GetOperand0 + index`, or the generic `GetOperandN` followed by the
/// index. The input operation and the result value follow.
945 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
946 uint32_t index = op.getIndex();
948 writer.append(
static_cast<OpCode
>(OpCode::GetOperand0 + index));
950 writer.append(OpCode::GetOperandN, index);
951 writer.append(op.getInputOp(), op.getValue());
/// Emit `GetOperands`. When the result has a range type its range storage
/// index is written; the result itself is written last.
953 void Generator::generate(pdl_interp::GetOperandsOp op, ByteCodeWriter &writer) {
954 Value result = op.getValue();
955 std::optional<uint32_t> index = op.getIndex();
956 writer.append(OpCode::GetOperands,
959 if (isa<pdl::RangeType>(result.
getType()))
960 writer.append(getRangeStorageIndex(result));
963 writer.append(result);
/// Emit a result access. The opcode is either a specialised one computed as
/// `GetResult0 + index`, or the generic `GetResultN` followed by the index.
/// The input operation and the result value follow.
965 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
966 uint32_t index = op.getIndex();
968 writer.append(
static_cast<OpCode
>(OpCode::GetResult0 + index));
970 writer.append(OpCode::GetResultN, index);
971 writer.append(op.getInputOp(), op.getValue());
/// Emit `GetResults`. When the result has a range type its range storage
/// index is written; the result itself is written last.
973 void Generator::generate(pdl_interp::GetResultsOp op, ByteCodeWriter &writer) {
974 Value result = op.getValue();
975 std::optional<uint32_t> index = op.getIndex();
976 writer.append(OpCode::GetResults,
979 if (isa<pdl::RangeType>(result.
getType()))
980 writer.append(getRangeStorageIndex(result));
983 writer.append(result);
/// Emit `GetUsers`: opcode, the resulting operation range and its range
/// storage index, then the queried value written as a PDL value (kind first).
985 void Generator::generate(pdl_interp::GetUsersOp op, ByteCodeWriter &writer) {
986 Value operations = op.getOperations();
987 ByteCodeField rangeIndex = getRangeStorageIndex(operations);
988 writer.append(OpCode::GetUsers, operations, rangeIndex);
989 writer.appendPDLValue(op.getValue());
/// Emit a value type query. For a range-typed result the
/// `GetValueRangeTypes` form is written: opcode, result, its range storage
/// index and the queried value. The `GetValueType` form writes the opcode,
/// the result and the queried value.
991 void Generator::generate(pdl_interp::GetValueTypeOp op,
992 ByteCodeWriter &writer) {
993 if (isa<pdl::RangeType>(op.getType())) {
994 Value result = op.getResult();
995 writer.append(OpCode::GetValueRangeTypes, result,
996 getRangeStorageIndex(result), op.getValue());
998 writer.append(OpCode::GetValueType, op.getResult(), op.getValue());
/// Emit `IsNotNull`: opcode, the tested value and the successors.
1001 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
1002 writer.append(OpCode::IsNotNull, op.getValue(), op.getSuccessors());
/// Emit `RecordMatch`. A new pattern is created from the operation, its
/// configuration set and the address of its rewriter, and appended to
/// `patterns`. The position of that pattern is what the bytecode refers to;
/// the input values are written as a PDL value list.
1004 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
// Index the new pattern will have once it is appended.
1005 ByteCodeField patternIndex =
patterns.size();
1006 patterns.emplace_back(PDLByteCodePattern::create(
1007 op, configMap.lookup(op),
1008 rewriterToAddr[op.getRewriter().getLeafReference().getValue()]));
1009 writer.append(OpCode::RecordMatch, patternIndex,
1011 writer.appendPDLValueList(op.getInputs());
/// Emit `ReplaceOp`: opcode, the operation being replaced and the list of
/// replacement values.
1013 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
1014 writer.append(OpCode::ReplaceOp, op.getInputOp());
1015 writer.appendPDLValueList(op.getReplValues());
/// Emit `SwitchAttribute`: opcode, the switched attribute, the case values
/// attribute and the successors.
1017 void Generator::generate(pdl_interp::SwitchAttributeOp op,
1018 ByteCodeWriter &writer) {
1019 writer.append(OpCode::SwitchAttribute, op.getAttribute(),
1020 op.getCaseValuesAttr(), op.getSuccessors());
/// Emit `SwitchOperandCount`: opcode, the input operation, the case values
/// attribute and the successors.
1022 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
1023 ByteCodeWriter &writer) {
1024 writer.append(OpCode::SwitchOperandCount, op.getInputOp(),
1025 op.getCaseValuesAttr(), op.getSuccessors());
/// Emit `SwitchOperationName`: opcode, the input operation, the cases and
/// the successors. Each case attribute is a string attribute that is turned
/// into an `OperationName` in this generator's context before being written.
1027 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
1028 ByteCodeWriter &writer) {
1029 auto cases = llvm::map_range(op.getCaseValuesAttr(), [&](
Attribute attr) {
1030 return OperationName(cast<StringAttr>(attr).getValue(), ctx);
1032 writer.append(OpCode::SwitchOperationName, op.getInputOp(), cases,
1033 op.getSuccessors());
/// Emit `SwitchResultCount`: opcode, the input operation, the case values
/// attribute and the successors.
1035 void Generator::generate(pdl_interp::SwitchResultCountOp op,
1036 ByteCodeWriter &writer) {
1037 writer.append(OpCode::SwitchResultCount, op.getInputOp(),
1038 op.getCaseValuesAttr(), op.getSuccessors());
/// Emit `SwitchType`: opcode, the switched value, the case values attribute
/// and the successors.
1040 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
1041 writer.append(OpCode::SwitchType, op.getValue(), op.getCaseValuesAttr(),
1042 op.getSuccessors());
/// Emit `SwitchTypes`: opcode, the switched value, the case values attribute
/// and the successors.
1044 void Generator::generate(pdl_interp::SwitchTypesOp op, ByteCodeWriter &writer) {
1045 writer.append(OpCode::SwitchTypes, op.getValue(), op.getCaseValuesAttr(),
1046 op.getSuccessors());
/// Build the bytecode for `module`. Ownership of `configs` is taken. A
/// `Generator` is set up to write directly into this object's members, and
/// the constraint and rewrite functions are then moved out of their string
/// maps into vectors, in the iteration order of those maps.
1053 PDLByteCode::PDLByteCode(
1054 ModuleOp module,
SmallVector<std::unique_ptr<PDLPatternConfigSet>> configs,
1056 llvm::StringMap<PDLConstraintFunction> constraintFns,
1057 llvm::StringMap<PDLRewriteFunction> rewriteFns)
1058 : configs(std::move(configs)) {
1059 Generator
generator(module.getContext(), uniquedData, matcherByteCode,
1060 rewriterByteCode,
patterns, maxValueMemoryIndex,
1061 maxOpRangeCount, maxTypeRangeCount, maxValueRangeCount,
1062 maxLoopLevel, constraintFns, rewriteFns, configMap);
// The maps are by-value parameters, so their entries can be moved from.
1066 for (
auto &it : constraintFns)
1067 constraintFunctions.push_back(std::move(it.second));
1068 for (
auto &it : rewriteFns)
1069 rewriteFunctions.push_back(std::move(it.second));
1075 state.memory.resize(maxValueMemoryIndex,
nullptr);
1076 state.opRangeMemory.resize(maxOpRangeCount);
1077 state.typeRangeMemory.resize(maxTypeRangeCount,
TypeRange());
1078 state.valueRangeMemory.resize(maxValueRangeCount,
ValueRange());
1079 state.loopIndex.resize(maxLoopLevel, 0);
1080 state.currentPatternBenefits.reserve(
patterns.size());
1082 state.currentPatternBenefits.push_back(pattern.getBenefit());
/// A `PDLResultList` specialisation used by the bytecode; construction
/// forwards the maximum number of results to the base class.
1093 class ByteCodeRewriteResultList :
public PDLResultList {
1095 ByteCodeRewriteResultList(
unsigned maxNumResults)
1096 : PDLResultList(maxNumResults) {}
1103 return allocatedTypeRanges;
1108 return allocatedValueRanges;
/// Executes bytecode: it holds the current code position together with the
/// memory tables, and provides one `execute*` handler per opcode family.
1113 class ByteCodeExecutor {
1119 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory,
1121 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory,
1128 : curCodeIt(curCodeIt), memory(memory), opRangeMemory(opRangeMemory),
1129 typeRangeMemory(typeRangeMemory),
1130 allocatedTypeRangeMemory(allocatedTypeRangeMemory),
1131 valueRangeMemory(valueRangeMemory),
1132 allocatedValueRangeMemory(allocatedValueRangeMemory),
1133 loopIndex(loopIndex), uniquedMemory(uniquedMemory), code(code),
1135 constraintFunctions(constraintFunctions),
1136 rewriteFunctions(rewriteFunctions) {}
1144 std::optional<Location> mainRewriteLoc = {});
/// Handlers for the individual opcodes; each is named after the opcode (or
/// opcode family) it executes.
1150 void executeAreEqual();
1151 void executeAreRangesEqual();
1152 void executeBranch();
1153 void executeCheckOperandCount();
1154 void executeCheckOperationName();
1155 void executeCheckResultCount();
1156 void executeCheckTypes();
1157 void executeContinue();
1158 void executeCreateConstantTypeRange();
1161 template <
typename T>
1162 void executeDynamicCreateRange(StringRef type);
1164 template <
typename T,
typename Range, PDLValue::Kind kind>
1165 void executeExtract();
1166 void executeFinalize();
1167 void executeForEach();
1168 void executeGetAttribute();
1169 void executeGetAttributeType();
1170 void executeGetDefiningOp();
1171 void executeGetOperand(
unsigned index);
1172 void executeGetOperands();
1173 void executeGetResult(
unsigned index);
1174 void executeGetResults();
1175 void executeGetUsers();
1176 void executeGetValueType();
1177 void executeGetValueRangeTypes();
1178 void executeIsNotNull();
1182 void executeSwitchAttribute();
1183 void executeSwitchOperandCount();
1184 void executeSwitchOperationName();
1185 void executeSwitchResultCount();
1186 void executeSwitchType();
1187 void executeSwitchTypes();
/// Post-processing of the results of a native function call.
1188 void processNativeFunResults(ByteCodeRewriteResultList &results,
1189 unsigned numResults,
1190 LogicalResult &rewriteResult);
/// Save `it` on the resume stack so that execution can later continue from
/// that position.
1193 void pushCodeIt(
    const ByteCodeField *it) {
  resumeCodeIt.push_back(it);
}
1197 assert(!resumeCodeIt.empty() &&
"attempt to pop code off empty stack");
1198 curCodeIt = resumeCodeIt.pop_back_val();
/// Return a position before the current code iterator. Two results are
/// possible: one field back, or one field plus the width of an inline
/// pointer back.
/// NOTE(review): the wider step presumably accounts for the location that is
/// appended inline ahead of an opcode -- confirm against the guarding
/// condition.
1202 const ByteCodeField *getPrevCodeIt()
const {
1205 return curCodeIt - 1 -
sizeof(
const void *) /
sizeof(ByteCodeField);
1209 return curCodeIt - 1;
/// Read a value of type `T` from the bytecode; decoding is delegated to the
/// `readImpl` overload selected for `T`. `T` defaults to `ByteCodeField` and
/// `skipN` defaults to zero.
1215 template <
typename T = ByteCodeField>
1216 T read(
size_t skipN = 0) {
1218 return readImpl<T>();
/// Non-template convenience overload: reads one `ByteCodeField`, forwarding
/// `skipN` to the templated reader.
1220 ByteCodeField read(
    size_t skipN = 0) {
  const ByteCodeField field = read<ByteCodeField>(skipN);
  return field;
}
1223 template <
typename ValueT,
typename T>
1226 for (
unsigned i = 0, e = read(); i != e; ++i)
1227 list.push_back(read<ValueT>());
1233 for (
unsigned i = 0, e = read(); i != e; ++i) {
1234 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1235 list.push_back(read<Type>());
1237 TypeRange *values = read<TypeRange *>();
1238 list.append(values->begin(), values->end());
1243 for (
unsigned i = 0, e = read(); i != e; ++i) {
1244 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1245 list.push_back(read<Value>());
1248 list.append(values->begin(), values->end());
/// Read a value with opaque-pointer traits that is stored inline in the
/// bytecode: the pointer bytes are copied out of the code stream, the code
/// iterator is advanced by the width of a pointer, and the value is rebuilt
/// with `T::getFromOpaquePointer`.
1254 template <
typename T>
1255 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value, T>
1257 const void *pointer;
1258 std::memcpy(&pointer, curCodeIt,
sizeof(
const void *));
1259 curCodeIt +=
sizeof(
const void *) /
sizeof(ByteCodeField);
1260 return T::getFromOpaquePointer(pointer);
/// Advance the code iterator by `skipN` fields.
1263 void skip(
    size_t skipN) {
  curCodeIt = curCodeIt + skipN;
}
/// Two-way jump: a true condition selects destination 0, a false condition
/// selects destination 1.
1266 void selectJump(
    bool isTrue) {
  const size_t destIndex = isTrue ? 0 : 1;
  selectJump(destIndex);
}
/// Jump to the destination with the given index. Destination addresses are
/// stored consecutively and each address spans two fields, so `destIndex * 2`
/// fields are skipped before the address is read; execution continues at that
/// address in `code`.
1268 void selectJump(
size_t destIndex) {
1269 curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
/// Shared implementation of the switch opcodes. The cases are scanned in
/// order with `cmp`; the first matching case at position `i` jumps to
/// destination `i + 1`. When no case matches, destination 0 is taken.
1273 template <
typename T,
typename RangeT,
typename Comparator = std::equal_to<T>>
1274 void handleSwitch(
const T &value, RangeT &&cases, Comparator cmp = {}) {
1275 LDBG() <<
"Switch operation:\n * Value: " << value
1276 <<
"\n * Cases: " << llvm::interleaved(cases);
1280 for (
auto it = cases.begin(), e = cases.end(); it != e; ++it)
1281 if (cmp(*it, value))
1282 return selectJump(
size_t((it - cases.begin()) + 1));
1283 selectJump(
size_t(0));
/// Store an opaque pointer at `index` in the value memory.
1287 void storeToMemory(
unsigned index,
const void *value) {
1288 memory[index] = value;
/// Store a value with opaque-pointer traits at `index` in the value memory;
/// what is stored is the value's opaque pointer.
1292 template <
typename T>
1293 std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value>
1294 storeToMemory(
unsigned index, T value) {
1295 memory[index] = value.getAsOpaquePointer();
1300 template <
typename T>
1301 const void *readFromMemory() {
1302 size_t index = *curCodeIt++;
1307 index < memory.size())
1308 return memory[index];
1311 return uniquedMemory[index - memory.size()];
1313 template <
typename T>
1314 std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
1315 return reinterpret_cast<T
>(
const_cast<void *
>(readFromMemory<T>()));
1317 template <
typename T>
1318 std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
1321 return T(T::getFromOpaquePointer(readFromMemory<T>()));
1323 template <
typename T>
1324 std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
1325 switch (read<PDLValue::Kind>()) {
1326 case PDLValue::Kind::Attribute:
1327 return read<Attribute>();
1328 case PDLValue::Kind::Operation:
1329 return read<Operation *>();
1330 case PDLValue::Kind::Type:
1331 return read<Type>();
1332 case PDLValue::Kind::Value:
1333 return read<Value>();
1334 case PDLValue::Kind::TypeRange:
1335 return read<TypeRange *>();
1336 case PDLValue::Kind::ValueRange:
1337 return read<ValueRange *>();
1339 llvm_unreachable(
"unhandled PDLValue::Kind");
1341 template <
typename T>
1342 std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
1343 static_assert((
sizeof(ByteCodeAddr) /
sizeof(ByteCodeField)) == 2,
1344 "unexpected ByteCode address size");
1345 ByteCodeAddr result;
1346 std::memcpy(&result, curCodeIt,
sizeof(ByteCodeAddr));
1350 template <
typename T>
1351 std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
1352 return *curCodeIt++;
1354 template <
typename T>
1355 std::enable_if_t<std::is_same<T, PDLValue::Kind>::value, T> readImpl() {
1361 template <
typename RangeT,
typename T = llvm::detail::ValueOfRange<RangeT>>
1362 void assignRangeToMemory(RangeT &&range,
unsigned memIndex,
1363 unsigned rangeIndex) {
1365 auto assignRange = [&](
auto &allocatedRangeMemory,
auto &rangeMemory) {
1367 if (range.empty()) {
1368 rangeMemory[rangeIndex] = {};
1371 llvm::OwningArrayRef<T> storage(llvm::size(range));
1376 allocatedRangeMemory.emplace_back(std::move(storage));
1377 rangeMemory[rangeIndex] = allocatedRangeMemory.back();
1379 memory[memIndex] = &rangeMemory[rangeIndex];
1383 if constexpr (std::is_same_v<T, Type>) {
1384 return assignRange(allocatedTypeRangeMemory, typeRangeMemory);
1385 }
else if constexpr (std::is_same_v<T, Value>) {
1386 return assignRange(allocatedValueRangeMemory, valueRangeMemory);
1388 llvm_unreachable(
"unhandled range type");
1393 const ByteCodeField *curCodeIt;
1402 std::vector<llvm::OwningArrayRef<Type>> &allocatedTypeRangeMemory;
1404 std::vector<llvm::OwningArrayRef<Value>> &allocatedValueRangeMemory;
1419 void ByteCodeExecutor::executeApplyConstraint(
PatternRewriter &rewriter) {
1420 LDBG() <<
"Executing ApplyConstraint:";
1421 ByteCodeField fun_idx = read();
1423 readList<PDLValue>(args);
1425 LDBG() <<
" * Arguments: " << llvm::interleaved(args);
1427 ByteCodeField isNegated = read();
1428 LDBG() <<
" * isNegated: " << isNegated;
1430 ByteCodeField numResults = read();
1431 const PDLRewriteFunction &constraintFn = constraintFunctions[fun_idx];
1432 ByteCodeRewriteResultList results(numResults);
1433 LogicalResult rewriteResult = constraintFn(rewriter, results, args);
1435 if (succeeded(rewriteResult)) {
1436 LDBG() <<
" * Constraint succeeded, results: "
1437 << llvm::interleaved(constraintResults);
1439 LDBG() <<
" * Constraint failed";
1441 assert((
failed(rewriteResult) || constraintResults.size() == numResults) &&
1442 "native PDL rewrite function succeeded but returned "
1443 "unexpected number of results");
1444 processNativeFunResults(results, numResults, rewriteResult);
1447 selectJump(isNegated != succeeded(rewriteResult));
1450 LogicalResult ByteCodeExecutor::executeApplyRewrite(
PatternRewriter &rewriter) {
1451 LDBG() <<
"Executing ApplyRewrite:";
1452 const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
1454 readList<PDLValue>(args);
1456 LDBG() <<
" * Arguments: " << llvm::interleaved(args);
1459 ByteCodeField numResults = read();
1460 ByteCodeRewriteResultList results(numResults);
1461 LogicalResult rewriteResult = rewriteFn(rewriter, results, args);
1463 assert(results.getResults().size() == numResults &&
1464 "native PDL rewrite function returned unexpected number of results");
1466 processNativeFunResults(results, numResults, rewriteResult);
1468 if (
failed(rewriteResult)) {
1469 LDBG() <<
" - Failed";
1475 void ByteCodeExecutor::processNativeFunResults(
1476 ByteCodeRewriteResultList &results,
unsigned numResults,
1477 LogicalResult &rewriteResult) {
1478 if (
failed(rewriteResult)) {
1481 for (
unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1483 if (resultKind == PDLValue::Kind::TypeRange ||
1484 resultKind == PDLValue::Kind::ValueRange) {
1494 for (
unsigned resultIdx = 0; resultIdx < numResults; resultIdx++) {
1497 PDLValue result = results.getResults()[resultIdx];
1498 LDBG() <<
" * Result: " << result;
1499 assert(result.getKind() == resultKind &&
1500 "native PDL rewrite function returned an unexpected type of "
1504 if (std::optional<TypeRange> typeRange = result.dyn_cast<
TypeRange>()) {
1505 unsigned rangeIndex = read();
1506 typeRangeMemory[rangeIndex] = *typeRange;
1507 memory[read()] = &typeRangeMemory[rangeIndex];
1508 }
else if (std::optional<ValueRange> valueRange =
1510 unsigned rangeIndex = read();
1511 valueRangeMemory[rangeIndex] = *valueRange;
1512 memory[read()] = &valueRangeMemory[rangeIndex];
1514 memory[read()] = result.getAsOpaquePointer();
1519 for (
auto &it : results.getAllocatedTypeRanges())
1520 allocatedTypeRangeMemory.push_back(std::move(it));
1521 for (
auto &it : results.getAllocatedValueRanges())
1522 allocatedValueRangeMemory.push_back(std::move(it));
1525 void ByteCodeExecutor::executeAreEqual() {
1526 LDBG() <<
"Executing AreEqual:";
1527 const void *lhs = read<const void *>();
1528 const void *rhs = read<const void *>();
1530 LDBG() <<
" * " << lhs <<
" == " << rhs;
1531 selectJump(lhs == rhs);
1534 void ByteCodeExecutor::executeAreRangesEqual() {
1535 LDBG() <<
"Executing AreRangesEqual:";
1537 const void *lhs = read<const void *>();
1538 const void *rhs = read<const void *>();
1540 switch (valueKind) {
1541 case PDLValue::Kind::TypeRange: {
1544 LDBG() <<
" * " << lhs <<
" == " << rhs;
1545 selectJump(*lhsRange == *rhsRange);
1548 case PDLValue::Kind::ValueRange: {
1549 const auto *lhsRange =
reinterpret_cast<const ValueRange *
>(lhs);
1550 const auto *rhsRange =
reinterpret_cast<const ValueRange *
>(rhs);
1551 LDBG() <<
" * " << lhs <<
" == " << rhs;
1552 selectJump(*lhsRange == *rhsRange);
1556 llvm_unreachable(
"unexpected `AreRangesEqual` value kind");
1560 void ByteCodeExecutor::executeBranch() {
1561 LDBG() <<
"Executing Branch";
1562 curCodeIt = &code[read<ByteCodeAddr>()];
1565 void ByteCodeExecutor::executeCheckOperandCount() {
1566 LDBG() <<
"Executing CheckOperandCount:";
1568 uint32_t expectedCount = read<uint32_t>();
1569 bool compareAtLeast = read();
1572 <<
"\n * Expected: " << expectedCount
1573 <<
"\n * Comparator: " << (compareAtLeast ?
">=" :
"==");
1580 void ByteCodeExecutor::executeCheckOperationName() {
1581 LDBG() <<
"Executing CheckOperationName:";
1585 LDBG() <<
" * Found: \"" << op->
getName() <<
"\"\n * Expected: \""
1586 << expectedName <<
"\"";
1587 selectJump(op->
getName() == expectedName);
1590 void ByteCodeExecutor::executeCheckResultCount() {
1591 LDBG() <<
"Executing CheckResultCount:";
1593 uint32_t expectedCount = read<uint32_t>();
1594 bool compareAtLeast = read();
1597 <<
"\n * Expected: " << expectedCount
1598 <<
"\n * Comparator: " << (compareAtLeast ?
">=" :
"==");
1605 void ByteCodeExecutor::executeCheckTypes() {
1606 LDBG() <<
"Executing AreEqual:";
1609 LDBG() <<
" * " << lhs <<
" == " << rhs;
1611 selectJump(*lhs == cast<ArrayAttr>(rhs).getAsValueRange<TypeAttr>());
1614 void ByteCodeExecutor::executeContinue() {
1615 ByteCodeField level = read();
1616 LDBG() <<
"Executing Continue\n * Level: " << level;
1621 void ByteCodeExecutor::executeCreateConstantTypeRange() {
1622 LDBG() <<
"Executing CreateConstantTypeRange:";
1623 unsigned memIndex = read();
1624 unsigned rangeIndex = read();
1625 ArrayAttr typesAttr = cast<ArrayAttr>(read<Attribute>());
1627 LDBG() <<
" * Types: " << typesAttr;
1628 assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex,
1632 void ByteCodeExecutor::executeCreateOperation(
PatternRewriter &rewriter,
1634 LDBG() <<
"Executing CreateOperation:";
1636 unsigned memIndex = read();
1638 readList(state.operands);
1639 for (
unsigned i = 0, e = read(); i != e; ++i) {
1640 StringAttr name = read<StringAttr>();
1642 state.addAttribute(name, attr);
1647 unsigned numResults = read();
1649 InferTypeOpInterface::Concept *inferInterface =
1650 state.name.getInterface<InferTypeOpInterface>();
1651 assert(inferInterface &&
1652 "expected operation to provide InferTypeOpInterface");
1655 if (
failed(inferInterface->inferReturnTypes(
1656 state.getContext(), state.location, state.operands,
1657 state.attributes.getDictionary(state.getContext()),
1658 state.getRawProperties(), state.regions, state.types)))
1662 for (
unsigned i = 0; i != numResults; ++i) {
1663 if (read<PDLValue::Kind>() == PDLValue::Kind::Type) {
1664 state.types.push_back(read<Type>());
1666 TypeRange *resultTypes = read<TypeRange *>();
1667 state.types.append(resultTypes->begin(), resultTypes->end());
1673 memory[memIndex] = resultOp;
1675 LDBG() <<
" * Attributes: "
1676 << state.attributes.getDictionary(state.getContext())
1677 <<
"\n * Operands: " << llvm::interleaved(state.operands)
1678 <<
"\n * Result Types: " << llvm::interleaved(state.types)
1679 <<
"\n * Result: " << *resultOp;
1682 template <
typename T>
1683 void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) {
1684 LDBG() <<
"Executing CreateDynamic" << type <<
"Range:";
1685 unsigned memIndex = read();
1686 unsigned rangeIndex = read();
1690 LDBG() <<
" * " << type <<
"s: " << llvm::interleaved(values);
1692 assignRangeToMemory(values, memIndex, rangeIndex);
1696 LDBG() <<
"Executing EraseOp:";
1699 LDBG() <<
" * Operation: " << *op;
1703 template <
typename T,
typename Range, PDLValue::Kind kind>
1704 void ByteCodeExecutor::executeExtract() {
1705 LDBG() <<
"Executing Extract" <<
kind <<
":";
1706 Range *range = read<Range *>();
1707 unsigned index = read<uint32_t>();
1708 unsigned memIndex = read();
1711 memory[memIndex] =
nullptr;
1715 T result = index < range->
size() ? (*range)[index] : T();
1716 LDBG() <<
" * " <<
kind <<
"s(" << range->
size() <<
")";
1717 LDBG() <<
" * Index: " << index;
1718 LDBG() <<
" * Result: " << result;
1719 storeToMemory(memIndex, result);
/// Emits the "Executing Finalize" debug log line; the body shown here does no
/// other work.
1722 void ByteCodeExecutor::executeFinalize() { LDBG() <<
"Executing Finalize"; }
1724 void ByteCodeExecutor::executeForEach() {
1725 LDBG() <<
"Executing ForEach:";
1726 const ByteCodeField *prevCodeIt = getPrevCodeIt();
1727 unsigned rangeIndex = read();
1728 unsigned memIndex = read();
1729 const void *value =
nullptr;
1731 switch (read<PDLValue::Kind>()) {
1732 case PDLValue::Kind::Operation: {
1733 unsigned &index = loopIndex[read()];
1735 assert(index <= array.size() &&
"iterated past the end");
1736 if (index < array.size()) {
1737 LDBG() <<
" * Result: " << array[index];
1738 value = array[index];
1742 LDBG() <<
" * Done";
1744 selectJump(
size_t(0));
1748 llvm_unreachable(
"unexpected `ForEach` value kind");
1752 memory[memIndex] = value;
1753 pushCodeIt(prevCodeIt);
1756 read<ByteCodeAddr>();
1759 void ByteCodeExecutor::executeGetAttribute() {
1760 LDBG() <<
"Executing GetAttribute:";
1761 unsigned memIndex = read();
1763 StringAttr attrName = read<StringAttr>();
1766 LDBG() <<
" * Operation: " << *op <<
"\n * Attribute: " << attrName
1767 <<
"\n * Result: " << attr;
1771 void ByteCodeExecutor::executeGetAttributeType() {
1772 LDBG() <<
"Executing GetAttributeType:";
1773 unsigned memIndex = read();
1776 if (
auto typedAttr = dyn_cast<TypedAttr>(attr))
1777 type = typedAttr.getType();
1779 LDBG() <<
" * Attribute: " << attr <<
"\n * Result: " << type;
1783 void ByteCodeExecutor::executeGetDefiningOp() {
1784 LDBG() <<
"Executing GetDefiningOp:";
1785 unsigned memIndex = read();
1787 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1788 Value value = read<Value>();
1791 LDBG() <<
" * Value: " << value;
1794 if (values && !values->empty()) {
1795 op = values->front().getDefiningOp();
1797 LDBG() <<
" * Values: " << values;
1800 LDBG() <<
" * Result: " << op;
1801 memory[memIndex] = op;
1804 void ByteCodeExecutor::executeGetOperand(
unsigned index) {
1806 unsigned memIndex = read();
1810 LDBG() <<
" * Operation: " << *op <<
"\n * Index: " << index
1811 <<
"\n * Result: " << operand;
1818 template <
template <
typename>
class AttrSizedSegmentsT,
typename RangeT>
1821 ByteCodeField rangeIndex, StringRef attrSizedSegments,
1826 LDBG() <<
" * Getting all values";
1830 }
else if (op->
hasTrait<AttrSizedSegmentsT>()) {
1831 LDBG() <<
" * Extracting values from `" << attrSizedSegments <<
"`";
1834 if (!segmentAttr || segmentAttr.asArrayRef().size() <= index)
1838 unsigned startIndex = llvm::sum_of(segments.take_front(index));
1839 values = values.slice(startIndex, *std::next(segments.begin(), index));
1841 LDBG() <<
" * Extracting range[" << startIndex <<
", "
1842 << *std::next(segments.begin(), index) <<
"]";
1848 }
else if (values.size() >= index) {
1849 LDBG() <<
" * Treating values as trailing variadic range";
1850 values = values.drop_front(index);
1859 valueRangeMemory[rangeIndex] = values;
1860 return &valueRangeMemory[rangeIndex];
1864 return values.size() != 1 ? nullptr : values.front().getAsOpaquePointer();
1867 void ByteCodeExecutor::executeGetOperands() {
1868 LDBG() <<
"Executing GetOperands:";
1869 unsigned index = read<uint32_t>();
1871 ByteCodeField rangeIndex = read();
1873 void *result = executeGetOperandsResults<OpTrait::AttrSizedOperandSegments>(
1874 op->
getOperands(), op, index, rangeIndex,
"operandSegmentSizes",
1877 LDBG() <<
" * Invalid operand range";
1878 memory[read()] = result;
1881 void ByteCodeExecutor::executeGetResult(
unsigned index) {
1883 unsigned memIndex = read();
1887 LDBG() <<
" * Operation: " << *op <<
"\n * Index: " << index
1888 <<
"\n * Result: " << result;
1892 void ByteCodeExecutor::executeGetResults() {
1893 LDBG() <<
"Executing GetResults:";
1894 unsigned index = read<uint32_t>();
1896 ByteCodeField rangeIndex = read();
1898 void *result = executeGetOperandsResults<OpTrait::AttrSizedResultSegments>(
1899 op->
getResults(), op, index, rangeIndex,
"resultSegmentSizes",
1902 LDBG() <<
" * Invalid result range";
1903 memory[read()] = result;
1906 void ByteCodeExecutor::executeGetUsers() {
1907 LDBG() <<
"Executing GetUsers:";
1908 unsigned memIndex = read();
1909 unsigned rangeIndex = read();
1910 OwningOpRange &range = opRangeMemory[rangeIndex];
1911 memory[memIndex] = &range;
1913 range = OwningOpRange();
1914 if (read<PDLValue::Kind>() == PDLValue::Kind::Value) {
1916 Value value = read<Value>();
1919 LDBG() <<
" * Value: " << value;
1929 LDBG() <<
" * Values (" << values->size()
1930 <<
"): " << llvm::interleaved(*values);
1934 for (
Value value : *values)
1936 range = OwningOpRange(users.size());
1940 LDBG() <<
" * Result: " << range.size() <<
" operations";
1943 void ByteCodeExecutor::executeGetValueType() {
1944 LDBG() <<
"Executing GetValueType:";
1945 unsigned memIndex = read();
1946 Value value = read<Value>();
1949 LDBG() <<
" * Value: " << value <<
"\n * Result: " << type;
1953 void ByteCodeExecutor::executeGetValueRangeTypes() {
1954 LDBG() <<
"Executing GetValueRangeTypes:";
1955 unsigned memIndex = read();
1956 unsigned rangeIndex = read();
1959 LDBG() <<
" * Values: <NULL>";
1960 memory[memIndex] =
nullptr;
1964 LDBG() <<
" * Values (" << values->size()
1965 <<
"): " << llvm::interleaved(*values)
1966 <<
"\n * Result: " << llvm::interleaved(values->
getType());
1967 typeRangeMemory[rangeIndex] = values->
getType();
1968 memory[memIndex] = &typeRangeMemory[rangeIndex];
1971 void ByteCodeExecutor::executeIsNotNull() {
1972 LDBG() <<
"Executing IsNotNull:";
1973 const void *value = read<const void *>();
1975 LDBG() <<
" * Value: " << value;
1976 selectJump(value !=
nullptr);
1979 void ByteCodeExecutor::executeRecordMatch(
1982 LDBG() <<
"Executing RecordMatch:";
1983 unsigned patternIndex = read();
1985 const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1990 LDBG() <<
" * Benefit: Impossible To Match";
1999 unsigned numMatchLocs = read();
2001 matchLocs.reserve(numMatchLocs);
2002 for (
unsigned i = 0; i != numMatchLocs; ++i)
2003 matchLocs.push_back(read<Operation *>()->getLoc());
2006 LDBG() <<
" * Benefit: " << benefit.
getBenefit();
2007 LDBG() <<
" * Location: " << matchLoc;
2008 matches.emplace_back(matchLoc,
patterns[patternIndex], benefit);
2014 unsigned numInputs = read();
2015 match.values.reserve(numInputs);
2016 match.typeRangeValues.reserve(numInputs);
2017 match.valueRangeValues.reserve(numInputs);
2018 for (
unsigned i = 0; i < numInputs; ++i) {
2019 switch (read<PDLValue::Kind>()) {
2020 case PDLValue::Kind::TypeRange:
2021 match.typeRangeValues.push_back(*read<TypeRange *>());
2022 match.values.push_back(&match.typeRangeValues.back());
2024 case PDLValue::Kind::ValueRange:
2025 match.valueRangeValues.push_back(*read<ValueRange *>());
2026 match.values.push_back(&match.valueRangeValues.back());
2029 match.values.push_back(read<const void *>());
2037 LDBG() <<
"Executing ReplaceOp:";
2042 LDBG() <<
" * Operation: " << *op
2043 <<
"\n * Values: " << llvm::interleaved(args);
2047 void ByteCodeExecutor::executeSwitchAttribute() {
2048 LDBG() <<
"Executing SwitchAttribute:";
2050 ArrayAttr cases = read<ArrayAttr>();
2051 handleSwitch(value, cases);
2054 void ByteCodeExecutor::executeSwitchOperandCount() {
2055 LDBG() <<
"Executing SwitchOperandCount:";
2057 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2059 LDBG() <<
" * Operation: " << *op;
2063 void ByteCodeExecutor::executeSwitchOperationName() {
2064 LDBG() <<
"Executing SwitchOperationName:";
2066 size_t caseCount = read();
2072 const ByteCodeField *prevCodeIt = curCodeIt;
2073 LDBG() <<
" * Value: " << value <<
"\n * Cases: "
2074 << llvm::interleaved(
2075 llvm::map_range(llvm::seq<size_t>(0, caseCount), [&](
size_t) {
2076 return read<OperationName>();
2078 curCodeIt = prevCodeIt;
2082 for (
size_t i = 0; i != caseCount; ++i) {
2083 if (read<OperationName>() == value) {
2084 curCodeIt += (caseCount - i - 1);
2085 return selectJump(i + 1);
2088 selectJump(
size_t(0));
2091 void ByteCodeExecutor::executeSwitchResultCount() {
2092 LDBG() <<
"Executing SwitchResultCount:";
2094 auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
2096 LDBG() <<
" * Operation: " << *op;
2100 void ByteCodeExecutor::executeSwitchType() {
2101 LDBG() <<
"Executing SwitchType:";
2102 Type value = read<Type>();
2103 auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
2104 handleSwitch(value, cases);
2107 void ByteCodeExecutor::executeSwitchTypes() {
2108 LDBG() <<
"Executing SwitchTypes:";
2110 auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>();
2112 LDBG() <<
"Types: <NULL>";
2113 return selectJump(
size_t(0));
2115 handleSwitch(*value, cases, [](ArrayAttr caseValue,
const TypeRange &value) {
2116 return value == caseValue.getAsValueRange<TypeAttr>();
2123 std::optional<Location> mainRewriteLoc) {
2126 LDBG() << readInline<Location>();
2128 OpCode opCode =
static_cast<OpCode
>(read());
2130 case ApplyConstraint:
2131 executeApplyConstraint(rewriter);
2134 if (
failed(executeApplyRewrite(rewriter)))
2140 case AreRangesEqual:
2141 executeAreRangesEqual();
2146 case CheckOperandCount:
2147 executeCheckOperandCount();
2149 case CheckOperationName:
2150 executeCheckOperationName();
2152 case CheckResultCount:
2153 executeCheckResultCount();
2156 executeCheckTypes();
2161 case CreateConstantTypeRange:
2162 executeCreateConstantTypeRange();
2164 case CreateOperation:
2165 executeCreateOperation(rewriter, *mainRewriteLoc);
2167 case CreateDynamicTypeRange:
2168 executeDynamicCreateRange<Type>(
"Type");
2170 case CreateDynamicValueRange:
2171 executeDynamicCreateRange<Value>(
"Value");
2174 executeEraseOp(rewriter);
2177 executeExtract<Operation *, OwningOpRange, PDLValue::Kind::Operation>();
2180 executeExtract<Type, TypeRange, PDLValue::Kind::Type>();
2183 executeExtract<Value, ValueRange, PDLValue::Kind::Value>();
2193 executeGetAttribute();
2195 case GetAttributeType:
2196 executeGetAttributeType();
2199 executeGetDefiningOp();
2205 unsigned index = opCode - GetOperand0;
2206 LDBG() <<
"Executing GetOperand" << index <<
":";
2207 executeGetOperand(index);
2211 LDBG() <<
"Executing GetOperandN:";
2212 executeGetOperand(read<uint32_t>());
2215 executeGetOperands();
2221 unsigned index = opCode - GetResult0;
2222 LDBG() <<
"Executing GetResult" << index <<
":";
2223 executeGetResult(index);
2227 LDBG() <<
"Executing GetResultN:";
2228 executeGetResult(read<uint32_t>());
2231 executeGetResults();
2237 executeGetValueType();
2239 case GetValueRangeTypes:
2240 executeGetValueRangeTypes();
2247 "expected matches to be provided when executing the matcher");
2248 executeRecordMatch(rewriter, *matches);
2251 executeReplaceOp(rewriter);
2253 case SwitchAttribute:
2254 executeSwitchAttribute();
2256 case SwitchOperandCount:
2257 executeSwitchOperandCount();
2259 case SwitchOperationName:
2260 executeSwitchOperationName();
2262 case SwitchResultCount:
2263 executeSwitchResultCount();
2266 executeSwitchType();
2269 executeSwitchTypes();
2280 state.memory[0] = op;
2283 ByteCodeExecutor executor(
2284 matcherByteCode.data(), state.memory, state.opRangeMemory,
2285 state.typeRangeMemory, state.allocatedTypeRangeMemory,
2286 state.valueRangeMemory, state.allocatedValueRangeMemory, state.loopIndex,
2287 uniquedData, matcherByteCode, state.currentPatternBenefits,
patterns,
2288 constraintFunctions, rewriteFunctions);
2289 LogicalResult executeResult = executor.execute(rewriter, &matches);
2290 (void)executeResult;
2291 assert(succeeded(executeResult) &&
"unexpected matcher execution failure");
2294 llvm::stable_sort(matches,
2295 [](
const MatchResult &lhs,
const MatchResult &rhs) {
2296 return lhs.benefit > rhs.benefit;
2301 const MatchResult &match,
2303 auto *configSet =
match.pattern->getConfigSet();
2305 configSet->notifyRewriteBegin(rewriter);
2311 ByteCodeExecutor executor(
2312 &rewriterByteCode[
match.pattern->getRewriterAddr()], state.memory,
2313 state.opRangeMemory, state.typeRangeMemory,
2314 state.allocatedTypeRangeMemory, state.valueRangeMemory,
2315 state.allocatedValueRangeMemory, state.loopIndex, uniquedData,
2316 rewriterByteCode, state.currentPatternBenefits,
patterns,
2317 constraintFunctions, rewriteFunctions);
2318 LogicalResult result =
2319 executor.execute(rewriter,
nullptr,
match.location);
2322 configSet->notifyRewriteEnd(rewriter);
2331 LDBG() <<
" and rollback is not supported - aborting";
2332 llvm::report_fatal_error(
2333 "Native PDL Rewrite failed, but the pattern "
2334 "rewriter doesn't support recovery. Failable pattern rewrites should "
2335 "not be used with pattern rewriters that do not support them.");
static void * executeGetOperandsResults(RangeT values, Operation *op, unsigned index, ByteCodeField rangeIndex, StringRef attrSizedSegments, MutableArrayRef< ValueRange > valueRangeMemory)
This function is the internal implementation of GetResults and GetOperands that provides support for ...
static constexpr ByteCodeField kInferTypesMarker
A marker used to indicate if an operation should infer types.
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
union mlir::linalg::@1253::ArityGroupAndKind::Kind kind
static const mlir::GenInfo * generator
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void processValue(Value value, LiveMap &liveMap)
Attributes are known-constant values of operations.
const void * getAsOpaquePointer() const
Get an opaque pointer to the attribute.
This class represents an argument of a Block.
Block represents an ordered list of Operations.
Region * getParent() const
Provide a 'getParent' method for ilist_node_with_parent methods.
BlockArgListType getArguments()
bool isEntryBlock()
Return if this block is the entry block in the parent region.
Location getFusedLoc(ArrayRef< Location > locs, Attribute metadata=Attribute())
This class represents liveness information on block level.
const ValueSetT & in() const
Returns all values that are live at the beginning of the block (unordered).
Operation * getEndOperation(Value value, Operation *startOperation) const
Gets the end operation for the given value using the start operation provided (must be referenced in ...
Represents an analysis for computing liveness information from a given top-level operation.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This is a value defined by a result of an operation.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Value getOperand(unsigned idx)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
AttrClass getAttrOfType(StringAttr name)
Attribute getAttr(StringAttr name)
Return the specified attribute if present, null otherwise.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
unsigned getNumResults()
Return the number of results held by this operation.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
bool isImpossibleToMatch() const
unsigned short getBenefit() const
If the corresponding pattern can match, return its benefit.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual bool canRecoverFromRewriteFailure() const
A hook used to indicate if the pattern rewriter can recover from failure during the rewrite stage of ...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class implements the result iterators for the Operation class.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class implements the successor iterators for Block.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
const void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
auto walk(WalkFns &&...walkFns)
Walk this type and all attributes/types nested within using the provided walk functions.
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
user_iterator user_begin() const
void * getAsOpaquePointer() const
Methods for supporting PointerLikeTypeTraits.
user_iterator user_end() const
user_range getUsers() const
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit)
Set the new benefit for a bytecode pattern.
void cleanupAfterMatchAndRewrite()
Cleanup any allocated state after a full match/rewrite has been completed.
void match(Operation *op, PatternRewriter &rewriter, SmallVectorImpl< MatchResult > &matches, PDLByteCodeMutableState &state) const
LogicalResult rewrite(PatternRewriter &rewriter, const MatchResult &match, PDLByteCodeMutableState &state) const
void walk(Operation *op, function_ref< void(Region *)> callback, WalkOrder order)
Walk all of the regions, blocks, or operations nested under (and including) the given operation.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Kind
An enumeration of the kinds of predicates.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...