22 using namespace tensor;
 
   29 struct ExtractSliceOpReplacementInterface
 
   30     : 
public transform::FindPayloadReplacementOpInterface::ExternalModel<
 
   31           ExtractSliceOpReplacementInterface, tensor::ExtractSliceOp> {
 
   33     auto extractSliceOp = cast<tensor::ExtractSliceOp>(op);
 
   36     return {extractSliceOp.getSource()};
 
   40 struct InsertSliceOpReplacementInterface
 
   41     : 
public transform::FindPayloadReplacementOpInterface::ExternalModel<
 
   42           InsertSliceOpReplacementInterface, tensor::InsertSliceOp> {
 
   44     auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
 
   47     return {insertSliceOp.getSource()};
 
   51 struct ReshapeOpReplacementInterface
 
   52     : 
public transform::FindPayloadReplacementOpInterface::ExternalModel<
 
   53           ReshapeOpReplacementInterface, tensor::ReshapeOp> {
 
   55     auto reshapeOp = cast<tensor::ReshapeOp>(op);
 
   56     return {reshapeOp.getSource()};
 
   60 template <
typename ConcreteOp>
 
   61 struct ReassociativeReshapeOpReplacementInterface
 
   62     : 
public transform::FindPayloadReplacementOpInterface::ExternalModel<
 
   63           ReassociativeReshapeOpReplacementInterface<ConcreteOp>, ConcreteOp> {
 
   65     auto reshapeOp = cast<ConcreteOp>(op);
 
   66     return {reshapeOp.getSrc()};
 
   74     CollapseShapeOp::attachInterface<
 
   75         ReassociativeReshapeOpReplacementInterface<CollapseShapeOp>>(*ctx);
 
   76     ExpandShapeOp::attachInterface<
 
   77         ReassociativeReshapeOpReplacementInterface<ExpandShapeOp>>(*ctx);
 
   78     ExtractSliceOp::attachInterface<ExtractSliceOpReplacementInterface>(*ctx);
 
   79     InsertSliceOp::attachInterface<InsertSliceOpReplacementInterface>(*ctx);
 
   80     ReshapeOp::attachInterface<ReshapeOpReplacementInterface>(*ctx);
 
   88 void transform::ApplyDecomposeTensorConcatPatternsOp::populatePatterns(
 
   93 void transform::ApplyDropRedundantInsertSliceRankExpansionPatternsOp::
 
   98 void transform::ApplyFoldTensorEmptyPatternsOp::populatePatterns(
 
  103 void transform::ApplyFoldTensorSubsetOpsPatternsOp::populatePatterns(
 
  108 void transform::ApplyFoldTensorSubsetOpsIntoVectorTransfersPatternsOp::
 
  113 void transform::ApplyMergeConsecutiveInsertExtractSlicePatternsOp::
 
  118 void transform::ApplyReassociativeReshapeFoldingPatternsOp::populatePatterns(
 
  123 void transform::ApplyBubbleUpExtractSlicePatternsOp::populatePatterns(
 
  128 void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns(
 
  131     Operation *producer = fusedOperand->get().getDefiningOp();
 
  132     return producer && producer->
hasOneUse();
 
  150 void transform::TypeConversionCastShapeDynamicDimsOp::
 
  152   bool ignoreDynamicInfo = getIgnoreDynamicInfo();
 
  157     if (inputs.size() != 1) {
 
  160     Value input = inputs[0];
 
  161     if (!ignoreDynamicInfo &&
 
  165     if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) {
 
  168     return tensor::CastOp::create(builder, loc, resultType, input).getResult();
 
  173     if (inputs.size() != 1) {
 
  176     Value input = inputs[0];
 
  177     if (!tensor::CastOp::areCastCompatible(input.getType(), resultType)) {
 
  180     return tensor::CastOp::create(builder, loc, resultType, input).getResult();
 
  195   for (uint64_t i = 0, e = getNumLoops(); i < e; ++i) {
 
  199                                          << 
"could not find " << i
 
  200                                          << 
"-th enclosing loop";
 
  201       diag.attachNote(target->
getLoc()) << 
"target op";
 
  204     ivs.push_back(cast<scf::ForOp>(nextOp).getInductionVar());
 
  208   FailureOr<Value> replacement = failure();
 
  209   if (
auto padOp = dyn_cast<tensor::PadOp>(target)) {
 
  211   } 
else if (
auto emptyOp = dyn_cast<tensor::EmptyOp>(target)) {
 
  215                                        << 
"unsupported target op";
 
  216     diag.attachNote(target->
getLoc()) << 
"target op";
 
  219   if (
failed(replacement)) {
 
  221         emitSilenceableError() << 
"could not make target op loop-independent";
 
  222     diag.attachNote(target->
getLoc()) << 
"target op";
 
  225   rewriter.
replaceOp(target, *replacement);
 
  226   results.
push_back(replacement->getDefiningOp());
 
  235 class TensorTransformDialectExtension
 
  237           TensorTransformDialectExtension> {
 
  244     declareGeneratedDialect<affine::AffineDialect>();
 
  245     declareGeneratedDialect<tensor::TensorDialect>();
 
  247     registerTransformOps<
 
  249 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.cpp.inc" 
  255 #define GET_OP_CLASSES 
  256 #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.cpp.inc" 
static std::string diag(const llvm::Value &value)
#define MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CLASS_NAME)
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
void addExtensions()
Add the given extensions to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
bool hasOneUse()
Returns true if this operation has exactly one use.
Location getLoc()
The source location the operation was defined or derived from.
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void addSourceMaterialization(FnT &&callback)
All of the following materializations require function objects that are convertible to the following ...
void addTargetMaterialization(FnT &&callback)
This method registers a materialization that will be called when converting a value to a target type ...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns, bool foldSingleUseOnly=false)
Populates patterns with patterns that fold tensor.empty with its consumers.
void registerTransformDialectExtension(DialectRegistry &registry)
void populateMergeConsecutiveInsertExtractSlicePatterns(RewritePatternSet &patterns)
Collects patterns to merge consecutive tensor.insert_slice/extract_slice into one.
bool isCastLikeInsertSliceOp(InsertSliceOp op)
A tensor.insert_slice is a cast-like operation if it merely rank-extends the source tensor or inserts...
void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that decompose tensor.concat into tensor.empty of a tensor of the co...
void populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into consumer load/store ops into patterns.
void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that fold tensor.expand_shape and tensor.collapse_shape into other o...
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns)
Appends patterns that are used to bubble up tensor.extract slice op above its producer.
void registerFindPayloadReplacementOpInterfaceExternalModels(DialectRegistry &registry)
bool isCastLikeExtractSliceOp(ExtractSliceOp op)
A tensor.extract_slice is a cast-like operation if it merely rank-reduces unit dimensions of the sour...
void populateDropRedundantInsertSliceRankExpansionPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that drop redundant tensor.insert_slice rank expansions.
FailureOr< Value > buildIndependentOp(OpBuilder &b, tensor::PadOp padOp, ValueRange independencies)
Build a new tensor::PadOp with low/high padding that is independent of all given independencies.
bool preservesStaticInformation(Type source, Type target)
Returns true if target is a ranked tensor type that preserves static information available in the sou...
std::function< bool(OpOperand *)> ControlFoldFn
void populateRewriteAsConstantPatterns(RewritePatternSet &patterns, const ControlFoldFn &controlFn)
Populates patterns with patterns that replace tensor ops (such as tensor.generate) with constants whe...
void populateFoldTensorSubsetIntoVectorTransferPatterns(RewritePatternSet &patterns)
Appends patterns for folding tensor subset ops into vector transfer ops.
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns