20 using namespace tensor;
/// External model of transform::FindPayloadReplacementOpInterface for
/// tensor.extract_slice. It casts the payload op to tensor::ExtractSliceOp
/// and returns a one-element list holding the slice's source value.
// NOTE(review): this is an extraction with gaps. The method signature
// (original line 30) and the closing lines (35+) are not shown here.
27 struct ExtractSliceOpReplacementInterface
28 :
public transform::FindPayloadReplacementOpInterface::ExternalModel<
29 ExtractSliceOpReplacementInterface, tensor::ExtractSliceOp> {
31 auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
// NOTE(review): the elided lines 32-33 presumably return early unless
// isCastLikeExtractSliceOp(extractSliceOp) holds. Confirm against the
// full source.
34 return {extractSliceOp.getSource()};
/// External model of transform::FindPayloadReplacementOpInterface for
/// tensor.insert_slice. It casts the payload op to tensor::InsertSliceOp
/// and returns a one-element list holding the inserted source value.
// NOTE(review): this is an extraction with gaps. The method signature
// (original line 41) and the closing lines (46+) are not shown here.
38 struct InsertSliceOpReplacementInterface
39 :
public transform::FindPayloadReplacementOpInterface::ExternalModel<
40 InsertSliceOpReplacementInterface, tensor::InsertSliceOp> {
42 auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
// NOTE(review): the elided lines 43-44 presumably return early unless
// isCastLikeInsertSliceOp(insertSliceOp) holds. Confirm against the
// full source.
45 return {insertSliceOp.getSource()};
/// External model of transform::FindPayloadReplacementOpInterface for
/// tensor.reshape. It casts the payload op to tensor::ReshapeOp and returns
/// a one-element list holding its source value. No guard appears between
/// the cast and the return.
// NOTE(review): extraction gaps. The method signature (original line 52)
// and the closing lines (55-57) are not shown here.
49 struct ReshapeOpReplacementInterface
50 :
public transform::FindPayloadReplacementOpInterface::ExternalModel<
51 ReshapeOpReplacementInterface, tensor::ReshapeOp> {
53 auto reshapeOp = cast<tensor::ReshapeOp>(op);
54 return {reshapeOp.getSource()};
/// Shared external model of transform::FindPayloadReplacementOpInterface for
/// reassociative reshape ops. It is attached below to CollapseShapeOp and
/// ExpandShapeOp. It casts the payload op to ConcreteOp and returns a
/// one-element list holding the value returned by getSrc().
// NOTE(review): extraction gaps. The method signature (original line 62)
// and the closing lines (65+) are not shown here.
58 template <
typename ConcreteOp>
59 struct ReassociativeReshapeOpReplacementInterface
60 :
public transform::FindPayloadReplacementOpInterface::ExternalModel<
61 ReassociativeReshapeOpReplacementInterface<ConcreteOp>, ConcreteOp> {
63 auto reshapeOp = cast<ConcreteOp>(op);
64 return {reshapeOp.getSrc()};
// Attach the FindPayloadReplacementOpInterface external models defined above
// to the tensor ops they describe, using the context passed as `*ctx`.
// NOTE(review): the enclosing function header (original lines ~65-71) is not
// part of this extract. Presumably this is the body of the registry extension
// installed by registerFindPayloadReplacementOpInterfaceExternalModels.
// Confirm against the full source.
72 CollapseShapeOp::attachInterface<
73 ReassociativeReshapeOpReplacementInterface<CollapseShapeOp>>(*ctx);
74 ExpandShapeOp::attachInterface<
75 ReassociativeReshapeOpReplacementInterface<ExpandShapeOp>>(*ctx);
76 ExtractSliceOp::attachInterface<ExtractSliceOpReplacementInterface>(*ctx);
77 InsertSliceOp::attachInterface<InsertSliceOpReplacementInterface>(*ctx);
78 ReshapeOp::attachInterface<ReshapeOpReplacementInterface>(*ctx);
// populatePatterns definitions for the Apply*PatternsOp transform ops.
// NOTE(review): only the first line of each definition survives in this
// extract, and the bodies are absent. Judging by the populate* declarations
// listed at the end of this listing, each body presumably forwards its
// RewritePatternSet to the matching tensor populate*Patterns helper. Confirm
// against the full source.
86 void transform::ApplyDecomposeTensorConcatPatternsOp::populatePatterns(
91 void transform::ApplyDropRedundantInsertSliceRankExpansionPatternsOp::
96 void transform::ApplyFoldTensorEmptyPatternsOp::populatePatterns(
101 void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
106 void transform::ApplyFoldTensorSubsetOpsPatternsOp::populatePatterns(
111 void transform::ApplyFoldTensorSubsetOpsIntoVectorTransfersPatternsOp::
116 void transform::ApplyMergeConsecutiveInsertExtractSlicePatternsOp::
121 void transform::ApplyReassociativeReshapeFoldingPatternsOp::populatePatterns(
126 void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
// Fragment of the apply method of the transform op that makes a target
// "loop-independent". The enclosing function header (original lines before
// 142) is not in this extract.
// For each of the getNumLoops() enclosing loops, look up the i-th enclosing
// loop. The lookup itself (original lines 143-145) is elided. If none is
// found, a diagnostic "could not find <i>-th enclosing loop" is emitted with
// a note pointing at the target op.
142 for (uint64_t i = 0, e = getNumLoops(); i < e; ++i) {
146 <<
"could not find " << i
147 <<
"-th enclosing loop";
148 diag.attachNote(target->
getLoc()) <<
"target op";
// Record the induction variable of the found loop, which must be an scf.for
// because cast<> is used.
151 ivs.push_back(cast<scf::ForOp>(nextOp).getInductionVar());
// Dispatch on the target kind. Only tensor.pad and tensor.empty have
// branches here, and their bodies (original 157, 159-160) are elided.
// NOTE(review): they presumably build `replacement` via buildIndependentOp
// using the collected induction variables. Confirm.
// Any other op reaches the "unsupported target op" diagnostic.
156 if (
auto padOp = dyn_cast<tensor::PadOp>(target)) {
158 }
else if (
auto emptyOp = dyn_cast<tensor::EmptyOp>(target)) {
162 <<
"unsupported target op";
163 diag.attachNote(target->
getLoc()) <<
"target op";
// A failed rewrite is reported as a silenceable error, with a note pointing
// at the target op.
166 if (
failed(replacement)) {
168 emitSilenceableError() <<
"could not make target op loop-independent";
169 diag.attachNote(target->
getLoc()) <<
"target op";
// On success, replace all uses of the target with the new value and return
// that value's defining op as the transform result.
172 rewriter.
replaceOp(target, *replacement);
173 results.
push_back(replacement->getDefiningOp());
/// Transform dialect extension for the tensor transform ops. It calls
/// declareGeneratedDialect for the affine and tensor dialects. These are
/// presumably dialects whose ops the transforms may create; confirm. It then
/// registers the op list produced by the TensorTransformOps.cpp.inc include.
// NOTE(review): the base-class line (original 183), the constructor header
// (185-188) and the closing lines are not part of this extract.
182 class TensorTransformDialectExtension
184 TensorTransformDialectExtension> {
189 declareGeneratedDialect<affine::AffineDialect>();
190 declareGeneratedDialect<tensor::TensorDialect>();
192 registerTransformOps<
194 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.cpp.inc"
200 #define GET_OP_CLASSES
201 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.cpp.inc"
static std::string diag(const llvm::Value &value)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
void addExtensions()
Add the given extensions to the registry.
This class provides support for representing a failure result, or a valid value of type T.
MLIRContext is the top-level object for a collection of MLIR operations.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns, bool foldSingleUseOnly=false)
Populates patterns with patterns that fold tensor.empty with tensor.
void registerTransformDialectExtension(DialectRegistry &registry)
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
bool isCastLikeInsertSliceOp(InsertSliceOp op)
A tensor.insert_slice is a cast-like operation if it merely rank-extends the source tensor or inserts...
void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that decompose tensor.concat into tensor.empty of a tensor of the co...
void populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into consumer load/store ops into patterns.
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold tensor.expand_shape and tensor.collapse_shape into other o...
void registerFindPayloadReplacementOpInterfaceExternalModels(DialectRegistry &registry)
bool isCastLikeExtractSliceOp(ExtractSliceOp op)
A tensor.extract_slice is a cast-like operation if it merely rank-reduces unit dimensions of the sour...
void populateDropRedundantInsertSliceRankExpansionPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that drop redundant tensor.insert_slice rank expansions.
FailureOr< Value > buildIndependentOp(OpBuilder &b, tensor::PadOp padOp, ValueRange independencies)
Build a new tensor::PadOp with low/high padding that is independent of all given independencies.
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
void populateRewriteAsConstantPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that replace tensor ops (such as tensor.generate) with constants whe...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.