22 using namespace tensor;
29 struct ExtractSliceOpReplacementInterface
30 :
public transform::FindPayloadReplacementOpInterface::ExternalModel<
31 ExtractSliceOpReplacementInterface, tensor::ExtractSliceOp> {
33 auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
36 return {extractSliceOp.getSource()};
40 struct InsertSliceOpReplacementInterface
41 :
public transform::FindPayloadReplacementOpInterface::ExternalModel<
42 InsertSliceOpReplacementInterface, tensor::InsertSliceOp> {
44 auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
47 return {insertSliceOp.getSource()};
51 struct ReshapeOpReplacementInterface
52 :
public transform::FindPayloadReplacementOpInterface::ExternalModel<
53 ReshapeOpReplacementInterface, tensor::ReshapeOp> {
55 auto reshapeOp = cast<tensor::ReshapeOp>(op);
56 return {reshapeOp.getSource()};
60 template <
typename ConcreteOp>
61 struct ReassociativeReshapeOpReplacementInterface
62 :
public transform::FindPayloadReplacementOpInterface::ExternalModel<
63 ReassociativeReshapeOpReplacementInterface<ConcreteOp>, ConcreteOp> {
65 auto reshapeOp = cast<ConcreteOp>(op);
66 return {reshapeOp.getSrc()};
74 CollapseShapeOp::attachInterface<
75 ReassociativeReshapeOpReplacementInterface<CollapseShapeOp>>(*ctx);
76 ExpandShapeOp::attachInterface<
77 ReassociativeReshapeOpReplacementInterface<ExpandShapeOp>>(*ctx);
78 ExtractSliceOp::attachInterface<ExtractSliceOpReplacementInterface>(*ctx);
79 InsertSliceOp::attachInterface<InsertSliceOpReplacementInterface>(*ctx);
80 ReshapeOp::attachInterface<ReshapeOpReplacementInterface>(*ctx);
88 void transform::ApplyDecomposeTensorConcatPatternsOp::populatePatterns(
93 void transform::ApplyDropRedundantInsertSliceRankExpansionPatternsOp::
98 void transform::ApplyFoldTensorEmptyPatternsOp::populatePatterns(
103 void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns(
108 void transform::ApplyFoldTensorSubsetOpsPatternsOp::populatePatterns(
113 void transform::ApplyFoldTensorSubsetOpsIntoVectorTransfersPatternsOp::
118 void transform::ApplyMergeConsecutiveInsertExtractSlicePatternsOp::
123 void transform::ApplyReassociativeReshapeFoldingPatternsOp::populatePatterns(
128 void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
137 void transform::TypeConversionCastShapeDynamicDimsOp::
139 bool ignoreDynamicInfo = getIgnoreDynamicInfo();
143 Location loc) -> std::optional<Value> {
144 if (inputs.size() != 1) {
147 Value input = inputs[0];
148 if (!ignoreDynamicInfo &&
152 if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) {
155 return builder.
create<tensor::CastOp>(loc, resultType, input).
getResult();
159 Location loc) -> std::optional<Value> {
160 if (inputs.size() != 1) {
163 Value input = inputs[0];
164 if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) {
167 return builder.
create<tensor::CastOp>(loc, resultType, input).
getResult();
182 for (uint64_t i = 0, e = getNumLoops(); i < e; ++i) {
186 <<
"could not find " << i
187 <<
"-th enclosing loop";
188 diag.attachNote(target->
getLoc()) <<
"target op";
191 ivs.push_back(cast<scf::ForOp>(nextOp).getInductionVar());
196 if (
auto padOp = dyn_cast<tensor::PadOp>(target)) {
198 }
else if (
auto emptyOp = dyn_cast<tensor::EmptyOp>(target)) {
202 <<
"unsupported target op";
203 diag.attachNote(target->
getLoc()) <<
"target op";
206 if (
failed(replacement)) {
208 emitSilenceableError() <<
"could not make target op loop-independent";
209 diag.attachNote(target->
getLoc()) <<
"target op";
212 rewriter.
replaceOp(target, *replacement);
213 results.
push_back(replacement->getDefiningOp());
222 class TensorTransformDialectExtension
224 TensorTransformDialectExtension> {
229 declareGeneratedDialect<affine::AffineDialect>();
230 declareGeneratedDialect<tensor::TensorDialect>();
232 registerTransformOps<
234 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.cpp.inc"
240 #define GET_OP_CLASSES
241 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.cpp.inc"
static std::string diag(const llvm::Value &value)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
void addExtension(std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
void addExtensions()
Add the given extensions to the registry.
This class provides support for representing a failure result, or a valid value of type T.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Location getLoc()
The source location the operation was defined or derived from.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void addSourceMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a legal type to an illega...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting type from an illegal,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold operations like tensor.pad and tensor.extract_slice into t...
void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns, bool foldSingleUseOnly=false)
Populates patterns with patterns that fold tensor.empty with tensor.[extract_slice|expand_shape|collapse_shape].
void registerTransformDialectExtension(DialectRegistry &registry)
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
bool isCastLikeInsertSliceOp(InsertSliceOp op)
A tensor.insert_slice is a cast-like operation if it merely rank-extends the source tensor or inserts...
void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that decompose tensor.concat into tensor.empty of a tensor of the co...
void populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into consumer load/store ops into patterns.
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold tensor.expand_shape and tensor.collapse_shape into other o...
void registerFindPayloadReplacementOpInterfaceExternalModels(DialectRegistry &registry)
bool isCastLikeExtractSliceOp(ExtractSliceOp op)
A tensor.extract_slice is a cast-like operation if it merely rank-reduces unit dimensions of the sour...
void populateDropRedundantInsertSliceRankExpansionPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that drop redundant tensor.insert_slice rank expansions.
FailureOr< Value > buildIndependentOp(OpBuilder &b, tensor::PadOp padOp, ValueRange independencies)
Build a new tensor::PadOp with low/high padding that is independent of all given independencies.
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
void populateRewriteAsConstantPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that replace tensor ops (such as tensor.generate) with constants whe...
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.