MLIR  16.0.0git
LinalgTransformOps.cpp
Go to the documentation of this file.
1 //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
21 #include "mlir/Parser/Parser.h"
23 #include "llvm/ADT/StringSet.h"
24 
25 using namespace mlir;
26 using namespace mlir::linalg;
27 using namespace mlir::transform;
28 
29 /// Extracts a vector of unsigned from an array attribute. Asserts if the
30 /// attribute contains values other than intergers. May truncate.
31 static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
32  SmallVector<unsigned> result;
33  result.reserve(attr.size());
34  for (APInt value : attr.getAsValueRange<IntegerAttr>())
35  result.push_back(value.getZExtValue());
36  return result;
37 }
38 
39 namespace {
40 /// A simple pattern rewriter that implements no special logic.
41 class SimpleRewriter : public PatternRewriter {
42 public:
43  SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
44 };
45 } // namespace
46 
47 /// Attempts to apply the pattern specified as template argument to the given
48 /// operation. The pattern is expected to have a `returningMatchAndRewrite`
49 /// function that returns the "main" result or failure. Returns failure if the
50 /// pattern failed to apply. Extra arguments are forwarded to the pattern
51 /// constructor.
52 template <typename PatternTy, typename... Args>
53 static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
54  // Check if the given operation has the type expected by the pattern.
55  using OpTy = typename llvm::function_traits<
56  decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
57  auto op = dyn_cast<OpTy>(operation);
58  if (!op)
59  return failure();
60 
61  // Apply the pattern directly to the op.
62  PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
63  SimpleRewriter rewriter(operation->getContext());
64  rewriter.setInsertionPoint(operation);
65  auto result = pattern.returningMatchAndRewrite(op, rewriter);
66  if (failed(result))
67  return failure();
68  return cast<LinalgOp>(result->getOperation());
69 }
70 
71 //===----------------------------------------------------------------------===//
72 // DecomposeOp
73 //===----------------------------------------------------------------------===//
74 
76 transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
79  FailureOr<LinalgOp> windowed =
80  tryApply<DownscaleSizeOneWindowed2DConvolution>(target);
81  if (succeeded(windowed)) {
82  results.push_back(*windowed);
84  }
85  FailureOr<LinalgOp> depthwise =
86  tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target);
87  if (succeeded(depthwise)) {
88  results.push_back(*depthwise);
90  }
91  results.assign(1, nullptr);
92  return emitDefaultSilenceableFailure(target);
93 }
94 
95 //===----------------------------------------------------------------------===//
96 // FuseOp
97 //===----------------------------------------------------------------------===//
98 
99 /// Apply a tiling transformation to all payload ops and store both the
100 /// tiled operation as well as the created tile loops.
101 static LogicalResult
103  unsigned numLoops,
104  transform::TransformResults &transformResults,
105  function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
106  SmallVector<Operation *> tiledLinalgOps;
107  SmallVector<SmallVector<Operation *>> loopOps(numLoops);
108  for (unsigned int i = 0; i < numLoops; ++i)
109  loopOps[i].reserve(payloadOps.size());
110 
111  for (Operation *target : payloadOps) {
112  auto linalgOp = dyn_cast<linalg::LinalgOp>(target);
113  if (!linalgOp)
114  return transformOp->emitError("only LinalgOps are supported");
115 
116  FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp);
117  if (failed(tiled))
118  return failure();
119 
120  tiledLinalgOps.push_back(tiled->op);
121  if (tiled->loops.size() != numLoops)
122  // Not enough loops were generated. This usually means that the input size
123  // was smaller than the tiling size.
124  // TODO: LinalgTilingPattern should return failure().
125  return failure();
126  for (unsigned int i = 0; i < numLoops; ++i)
127  loopOps[i].push_back(tiled->loops[i]);
128  }
129 
130  transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
131  for (unsigned int i = 0; i < numLoops; ++i)
132  transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
133  return success();
134 }
135 
136 /// Parse a tiling-like operation that returns the tiled op as well as the
137 /// created tile loops. The function counts the non-zero tile sizes to compute
138 /// the number of results.
140  StringRef sizesAttrName) {
141  OpAsmParser::UnresolvedOperand targetOperand;
142  SMLoc opLoc = parser.getCurrentLocation();
143  if (parser.parseOperand(targetOperand) ||
144  parser.parseOptionalAttrDict(result.attributes))
145  return failure();
146  Attribute sizesAttr = result.attributes.get(sizesAttrName);
147  if (!sizesAttr)
148  return parser.emitError(opLoc)
149  << "expected '" << sizesAttrName << "' attribute";
150  auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>();
151  if (!sizesArrayAttr)
152  return parser.emitError(opLoc)
153  << "'" << sizesAttrName << "' attribute must be an array";
154  Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
155  size_t numExpectedLoops =
156  sizesArrayAttr.size() -
157  llvm::count(extractFromI64ArrayAttr(sizesArrayAttr), 0);
158  result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType));
159  if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
160  return failure();
161  return success();
162 }
163 
165 transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
167  LinalgTilingAndFusionOptions fusionOptions;
168  fusionOptions.tileSizes = extractFromI64ArrayAttr(getTileSizes());
169  fusionOptions.tileInterchange = extractFromI64ArrayAttr(getTileInterchange());
170 
172  getOperation(), state.getPayloadOps(getTarget()),
173  fusionOptions.tileSizes.size() - llvm::count(fusionOptions.tileSizes, 0),
174  transformResults, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> {
175  LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions);
176  SimpleRewriter rewriter(getContext());
177  rewriter.setInsertionPoint(linalgOp);
178  FailureOr<TileLoopNest> tileLoopNest =
179  pattern.returningMatchAndRewrite(linalgOp, rewriter);
180  if (failed(tileLoopNest))
181  return failure();
182 
183  TiledLinalgOp tiledLinalgOp;
184  tiledLinalgOp.op = tileLoopNest->getRootOp();
185  tiledLinalgOp.loops = {tileLoopNest->getLoopOps().begin(),
186  tileLoopNest->getLoopOps().end()};
187  return tiledLinalgOp;
188  });
189  return DiagnosedSilenceableFailure(result);
190 }
191 
192 ParseResult transform::FuseOp::parse(OpAsmParser &parser,
193  OperationState &result) {
194  return parseTileLikeOp(
195  parser, result,
196  transform::FuseOp::getTileSizesAttrName(result.name).getValue());
197 }
198 
200  p << ' ';
201  p << getTarget();
202  p.printOptionalAttrDict((*this)->getAttrs());
203 }
204 
206  SmallVector<int64_t> permutation =
207  extractFromI64ArrayAttr(getTileInterchange());
208  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
209  if (!std::is_permutation(sequence.begin(), sequence.end(),
210  permutation.begin(), permutation.end())) {
211  return emitOpError() << "expects interchange to be a permutation, found "
212  << getTileInterchange();
213  }
214  return success();
215 }
216 
217 //===----------------------------------------------------------------------===//
218 // FuseIntoContainingOp
219 //===----------------------------------------------------------------------===//
220 
222  Operation *containingOp,
223  RewriterBase &rewriter) {
224  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
225  if (!tileableProducer)
226  return failure();
227 
228  // Search the producer slices accessed within the containing operation.
229  // TODO: Generalize to more extract/insert/parallel_insert triples. Maybe
230  // evolve into an interface.
232  for (Operation *user : tileableProducer->getUsers()) {
233  auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
234  if (!sliceOp)
235  continue;
236  if (!containingOp->isProperAncestor(sliceOp))
237  continue;
238  sliceOps.push_back(sliceOp);
239  }
240 
241  // Check for a non-empty list of fusion opportunities.
242  if (sliceOps.empty())
243  return failure();
244 
245  SmallVector<Value> destinationOperands =
246  tileableProducer.getDestinationOperands(rewriter);
247 
248  // Try to fuse the producer in-place.
249  SmallVector<Operation *> fusedOps;
250  for (tensor::ExtractSliceOp sliceOp : sliceOps) {
251  OpBuilder::InsertionGuard guard(rewriter);
252  rewriter.setInsertionPoint(sliceOp);
253 
254  // Tile the producer.
255  FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
256  rewriter, /*resultNumber=*/0, destinationOperands,
257  sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(), true);
258  if (failed(tiledProducer))
259  return failure();
260  fusedOps.push_back(tiledProducer->getDefiningOp());
261  }
262 
263  // Replace the extract op.
264  for (const auto &en : enumerate(sliceOps))
265  rewriter.replaceOp(en.value(), fusedOps[en.index()]->getResult(0));
266  return fusedOps;
267 }
268 
270 cloneAndFuse(Operation *producerOp, Operation *containingOp,
271  RewriterBase &rewriter) {
272  // Gather all uses inside the containing op.
274  for (OpResult result : producerOp->getOpResults())
275  for (OpOperand &use : result.getUses())
276  if (containingOp->isProperAncestor(use.getOwner()))
277  uses.push_back(&use);
278 
279  // Check for a non-empty list of fusion opportunities.
280  if (uses.empty())
281  return failure();
282 
283  // Clone and fuse inside the containing op.
284  SmallVector<Operation *> fusedOps;
285  for (OpOperand *use : uses) {
286  unsigned resultNumber = use->get().cast<OpResult>().getResultNumber();
287  OpBuilder::InsertionGuard guard(rewriter);
288  rewriter.setInsertionPoint(use->getOwner());
289  Operation *cloned = rewriter.clone(*producerOp);
290  rewriter.updateRootInPlace(
291  use->getOwner(), [&] { use->set(cloned->getOpResult(resultNumber)); });
292  fusedOps.push_back(cloned);
293  }
294 
295  return fusedOps;
296 }
297 
299 transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
300  transform::TransformState &state) {
301  SmallVector<Operation *> fusedOps;
302  ArrayRef<Operation *> producerOps = state.getPayloadOps(getProducerOp());
303  // If nothing to fuse, propagate success.
304  if (producerOps.empty()) {
305  results.set(getResult().cast<OpResult>(), SmallVector<mlir::Operation *>{});
307  }
308  for (Operation *producerOp : producerOps) {
309  if (producerOp->getNumResults() != 1) {
310  Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note);
311  diag << "op with != 1 results not supported";
312  return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
313  }
314  }
315  ArrayRef<Operation *> containingOps = state.getPayloadOps(getContainingOp());
316  if (containingOps.size() != 1)
318  this->emitOpError("requires exactly one containing_op handle (got ")
319  << containingOps.size() << ")");
320  Operation *containingOp = containingOps.front();
321 
322  // Helper function to find the next producer that should be fused. Take any
323  // producer that has a use inside the containing op.
324  SmallVector<Operation *> remainingProducers(producerOps.begin(),
325  producerOps.end());
326  auto getNextProducer = [&]() -> FailureOr<Operation *> {
327  for (const auto &it : enumerate(remainingProducers)) {
328  Operation *producerOp = it.value();
329  bool hasUseInContainingOp =
330  any_of(producerOp->getUsers(), [&](Operation *op) {
331  return containingOp->isProperAncestor(op);
332  });
333  // TODO: When resolving the TODO below (no duplicate ops), take an op that
334  // has no use among the remaining producers. This is a topological
335  // sorting.
336  if (hasUseInContainingOp) {
337  remainingProducers.erase(remainingProducers.begin() + it.index());
338  return producerOp;
339  }
340  }
341  return failure();
342  };
343 
344  IRRewriter rewriter(getContext());
345  while (!remainingProducers.empty()) {
346  auto nextProducer = getNextProducer();
347  if (failed(nextProducer)) {
349  diag << "could not fuse ops into container";
350  return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
351  }
352 
353  Operation *producerOp = *nextProducer;
354  // TODO: If there are multiple uses of the producer in the containing op, we
355  // currently tile/clone the op multiple times (once per use). In some cases,
356  // we can tile/clone once and reuse the value for each use. Furthermore,
357  // producers should then be traversed according to a topological sorting.
358  auto tiled = tileAndFuse(producerOp, containingOp, rewriter);
359  if (succeeded(tiled))
360  fusedOps.append(*tiled);
361 
362  auto cloned = cloneAndFuse(producerOp, containingOp, rewriter);
363  if (succeeded(cloned))
364  fusedOps.append(*cloned);
365 
366  if (failed(tiled) && failed(cloned)) {
368  diag << "could not fuse into containing op";
369  return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
370  }
371  }
372 
373  results.set(getFusedOp().cast<OpResult>(), fusedOps);
375 }
376 
377 //===----------------------------------------------------------------------===//
378 // GeneralizeOp
379 //===----------------------------------------------------------------------===//
380 
382 transform::GeneralizeOp::applyToOne(linalg::LinalgOp target,
384  transform::TransformState &state) {
385  // Exit early if no transformation is needed.
386  if (isa<GenericOp>(target)) {
387  results.push_back(target);
389  }
390  FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target);
391  if (succeeded(generic)) {
392  results.push_back(generic->getOperation());
394  }
395  results.assign(1, nullptr);
396  return emitDefaultSilenceableFailure(target);
397 }
398 
399 //===----------------------------------------------------------------------===//
400 // InterchangeOp
401 //===----------------------------------------------------------------------===//
402 
404 transform::InterchangeOp::applyToOne(linalg::GenericOp target,
406  transform::TransformState &state) {
407  SmallVector<unsigned> interchangeVector =
408  extractUIntArray(getIteratorInterchange());
409  // Exit early if no transformation is needed.
410  if (interchangeVector.empty()) {
411  results.push_back(target);
413  }
414  SimpleRewriter rewriter(target->getContext());
416  interchangeGenericOp(rewriter, target, interchangeVector);
417  if (failed(res))
419  results.push_back(res->getOperation());
421 }
422 
424  SmallVector<unsigned> permutation =
425  extractUIntArray(getIteratorInterchange());
426  auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size()));
427  if (!std::is_permutation(sequence.begin(), sequence.end(),
428  permutation.begin(), permutation.end())) {
429  return emitOpError()
430  << "expects iterator_interchange to be a permutation, found "
431  << getIteratorInterchange();
432  }
433  return success();
434 }
435 
436 //===---------------------------------------------------------------------===//
437 // MatchOp
438 //===---------------------------------------------------------------------===//
439 
441 transform::MatchOp::apply(transform::TransformResults &results,
442  transform::TransformState &state) {
443  llvm::StringSet<> strs;
444  if (getOps().has_value())
445  strs.insert(getOps()->getAsValueRange<StringAttr>().begin(),
446  getOps()->getAsValueRange<StringAttr>().end());
447 
448  ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
449  if (payloadOps.size() != 1)
451  this->emitOpError("requires exactly one target handle"));
452 
454  auto matchFun = [&](Operation *op) {
455  if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
456  return WalkResult::advance();
457 
458  // Interfaces cannot be matched by name, just by ID.
459  // So we specifically encode the interfaces we care about for this op.
460  if (getInterface().has_value()) {
461  auto iface = getInterface().value();
462  if (iface == transform::MatchInterfaceEnum::LinalgOp &&
463  !isa<linalg::LinalgOp>(op))
464  return WalkResult::advance();
465  if (iface == transform::MatchInterfaceEnum::TilingInterface &&
466  isa<TilingInterface>(op))
467  return WalkResult::advance();
468  }
469 
470  // Check if all specified attributes match.
471  if (getOpAttrs().has_value()) {
472  DictionaryAttr opAttrs = getOpAttrs().value();
473  for (NamedAttribute attr : opAttrs) {
474  if (attr.getName() == getInterfaceAttrName() ||
475  attr.getName() == getOpsAttrName())
476  continue;
477  if (!op->hasAttr(attr.getName()))
478  return WalkResult::advance();
479  if (op->getAttr(attr.getName()) != attr.getValue())
480  return WalkResult::advance();
481  }
482  }
483 
484  // All constraints are satisfied.
485  res.push_back(op);
486  return WalkResult::advance();
487  };
488 
489  payloadOps.front()->walk(matchFun);
490  results.set(getResult().cast<OpResult>(), res);
492 }
493 
494 //===---------------------------------------------------------------------===//
495 // MultiTileSizesOp
496 //===---------------------------------------------------------------------===//
497 
498 DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
499  LinalgOp target, SmallVector<Operation *> &results, TransformState &state) {
500  OpBuilder builder(target.getContext());
501  builder.setInsertionPoint(target);
502  OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
503  OpFoldResult divisor = builder.getIndexAttr(getDivisor());
505  builder, target, getDimension(), targetSize, divisor);
506  if (failed(spec)) {
507  return emitSilenceableError() << "could not generate tile size computation";
508  }
509 
510  AffineExpr s0 = builder.getAffineSymbolExpr(0);
511  AffineExpr s1 = builder.getAffineSymbolExpr(1);
512  Operation *splitPoint =
513  makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
514  {spec->lowTileSize, spec->lowTripCount});
515  Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
516  Operation *highTileSize = spec->highTileSize.getDefiningOp();
517  assert(lowTileSize && highTileSize && splitPoint &&
518  "tile sizes are not produced by operations");
519  results.reserve(results.size() + 3);
520  results.push_back(lowTileSize);
521  results.push_back(highTileSize);
522  results.push_back(splitPoint);
524 }
525 
526 void transform::MultiTileSizesOp::getEffects(
528  onlyReadsHandle(getTarget(), effects);
529  producesHandle(getResults(), effects);
530  modifiesPayload(effects);
531 }
532 
533 //===---------------------------------------------------------------------===//
534 // PadOp
535 //===---------------------------------------------------------------------===//
536 
538 transform::PadOp::applyToOne(linalg::LinalgOp target,
540  transform::TransformState &state) {
541  // Convert the integer packing flags to booleans.
542  SmallVector<bool> packPaddings;
543  for (int64_t packPadding : extractFromI64ArrayAttr(getPackPaddings()))
544  packPaddings.push_back(static_cast<bool>(packPadding));
545 
546  // Convert the padding values to attributes.
547  SmallVector<Attribute> paddingValues;
548  for (auto const &it :
549  llvm::zip(getPaddingValues(), target->getOperandTypes())) {
550  auto attr = std::get<0>(it).dyn_cast<TypedAttr>();
551  if (!attr) {
552  emitOpError("expects padding values to be typed attributes");
554  }
555  Type elementType = getElementTypeOrSelf(std::get<1>(it));
556  // Try to parse string attributes to obtain an attribute of element type.
557  if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
558  paddingValues.push_back(
559  parseAttribute(attr.cast<StringAttr>(), elementType));
560  if (!paddingValues.back()) {
561  auto diag = this->emitOpError("expects a padding that parses to ")
562  << elementType << ", got " << std::get<0>(it);
563  diag.attachNote(target.getLoc()) << "when applied to this op";
565  }
566  continue;
567  }
568  // Otherwise, add the attribute directly.
569  if (attr.getType() != elementType) {
570  auto diag = this->emitOpError("expects a padding value of type ")
571  << elementType << ", got " << attr;
572  diag.attachNote(target.getLoc()) << "when applied to this op";
574  }
575  paddingValues.push_back(attr);
576  }
577 
578  // Extract the transpose vectors.
579  SmallVector<SmallVector<int64_t>> transposePaddings;
580  for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>())
581  transposePaddings.push_back(
582  extractFromI64ArrayAttr(transposeVector.cast<ArrayAttr>()));
583 
584  LinalgPaddingOptions paddingOptions;
585  paddingOptions.setPaddingValues(paddingValues);
586  paddingOptions.setPaddingDimensions(
587  extractFromI64ArrayAttr(getPaddingDimensions()));
588  paddingOptions.setPackPaddings(packPaddings);
589  paddingOptions.setHoistPaddings(extractFromI64ArrayAttr(getHoistPaddings()));
590  paddingOptions.setTransposePaddings(transposePaddings);
591 
592  FailureOr<LinalgOp> result =
593  tryApply<LinalgPaddingPattern>(target, paddingOptions);
594  if (succeeded(result)) {
595  results.push_back(result->getOperation());
597  }
598 
599  results.assign(1, nullptr);
600  return emitDefaultSilenceableFailure(target);
601 }
602 
604  SmallVector<int64_t> packPaddings =
605  extractFromI64ArrayAttr(getPackPaddings());
606  if (any_of(packPaddings, [](int64_t packPadding) {
607  return packPadding != 0 && packPadding != 1;
608  })) {
609  return emitOpError()
610  << "expects pack_paddings to contain booleans (0/1), found "
611  << getPackPaddings();
612  }
613 
614  SmallVector<int64_t> paddingDimensions =
615  extractFromI64ArrayAttr(getPaddingDimensions());
616  if (any_of(paddingDimensions,
617  [](int64_t paddingDimension) { return paddingDimension < 0; })) {
618  return emitOpError()
619  << "expects padding_dimensions to contain positive integers, found "
620  << getPaddingDimensions();
621  }
622 
623  SmallVector<int64_t> hoistPaddings =
624  extractFromI64ArrayAttr(getHoistPaddings());
625  if (any_of(hoistPaddings,
626  [](int64_t hoistPadding) { return hoistPadding < 0; })) {
627  return emitOpError()
628  << "expects hoist_paddings to contain positive integers, found "
629  << getHoistPaddings();
630  }
631 
632  ArrayAttr transposes = getTransposePaddings();
633  for (Attribute attr : transposes) {
635  auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
636  if (!std::is_permutation(sequence.begin(), sequence.end(),
637  transpose.begin(), transpose.end())) {
638  return emitOpError()
639  << "expects transpose_paddings to be a permutation, found "
640  << attr;
641  }
642  }
643  return success();
644 }
645 
646 //===----------------------------------------------------------------------===//
647 // PromoteOp
648 //===----------------------------------------------------------------------===//
649 
651 transform::PromoteOp::applyToOne(linalg::LinalgOp target,
653  transform::TransformState &state) {
654  LinalgPromotionOptions promotionOptions;
655  if (!getOperandsToPromote().empty())
656  promotionOptions = promotionOptions.setOperandsToPromote(
657  extractFromI64ArrayAttr(getOperandsToPromote()));
658  if (getUseFullTilesByDefault())
659  promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
660  getUseFullTilesByDefault());
661  if (getUseAlloca())
662  promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
663  if (!getUseFullTileBuffers().empty())
664  promotionOptions = promotionOptions.setUseFullTileBuffers(
665  llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
666  if (getAlignment().has_value())
667  promotionOptions = promotionOptions.setAlignment(*getAlignment());
668 
669  if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
670  return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
671 
672  SimpleRewriter rewriter(target->getContext());
673  rewriter.setInsertionPoint(target);
674  FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
675  if (failed(res))
676  return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
677  results.push_back(target);
679 }
680 
681 //===----------------------------------------------------------------------===//
682 // ScalarizeOp
683 //===----------------------------------------------------------------------===//
684 
686 transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
688  transform::TransformState &state) {
689  LinalgTilingOptions tilingOptions;
690  tilingOptions.scalarizeDynamicDims();
691  // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile
692  // sizes and asserts that it is not already set.
693  SmallVector<int64_t> emptyTileSizes;
694  LinalgTilingPattern pattern(getContext(), tilingOptions);
695  SimpleRewriter rewriter(getContext());
696  rewriter.setInsertionPoint(target);
697  FailureOr<TiledLinalgOp> result =
698  pattern.returningMatchAndRewrite(target, rewriter);
699  if (failed(result))
700  return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
701 
702  results.push_back(result->op);
704 }
705 
706 //===----------------------------------------------------------------------===//
707 // SplitOp
708 //===----------------------------------------------------------------------===//
709 
710 DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
711  TransformState &state) {
712  // Collect the dynamic split points if provided.
713  ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
714  SimpleRewriter rewriter(getContext());
715  SmallVector<OpFoldResult> splitPoints;
716  splitPoints.reserve(payload.size());
717  if (getDynamicSplitPoint()) {
719  splitPoints = llvm::to_vector(llvm::map_range(
720  state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) {
721  if (op->getNumResults() != 1 ||
722  !op->getResult(0).getType().isIndex()) {
723  diag = emitSilenceableError()
724  << "expected dynamic split point handle to point to a "
725  "single-result index-typed op";
726  diag.attachNote(op->getLoc()) << "dynamic split point";
727  }
728  return OpFoldResult(op->getResult(0));
729  }));
730  if (!diag.succeeded())
731  return diag;
732 
733  if (splitPoints.size() != payload.size()) {
734  emitError() << "expected the dynamic split point handle to point to as "
735  "many operations ("
736  << splitPoints.size() << ") as the target handle ("
737  << payload.size() << ")";
739  }
740  } else {
741  splitPoints.resize(payload.size(),
742  rewriter.getIndexAttr(getStaticSplitPoint()));
743  }
744 
745  // Split each target operation.
746  SmallVector<Operation *> first, second;
747  for (const auto &pair : llvm::zip(payload, splitPoints)) {
748  Operation *target = std::get<0>(pair);
749  auto linalgOp = dyn_cast<LinalgOp>(target);
750  if (!linalgOp) {
751  auto diag = emitSilenceableError() << "only applies to structured ops";
752  diag.attachNote(target->getLoc()) << "target op";
753  return diag;
754  }
755 
756  if (getDimension() >= linalgOp.getNumLoops()) {
757  auto diag = emitSilenceableError() << "dimension " << getDimension()
758  << " does not exist in target op";
759  diag.attachNote(target->getLoc()) << "target op";
760  return diag;
761  }
762 
763  rewriter.setInsertionPoint(linalgOp);
764  std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
765  rewriter, cast<TilingInterface>(linalgOp.getOperation()),
766  getDimension(), std::get<1>(pair));
767  }
768 
769  results.set(getFirst().cast<OpResult>(), first);
770  results.set(getSecond().cast<OpResult>(), second);
772 }
773 
774 void SplitOp::getEffects(
776  consumesHandle(getTarget(), effects);
777  if (getDynamicSplitPoint())
778  onlyReadsHandle(getDynamicSplitPoint(), effects);
779  producesHandle(getResults(), effects);
780  modifiesPayload(effects);
781 }
782 
783 ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
784  OpAsmParser::UnresolvedOperand target, dynamicSplitPoint;
785  IntegerAttr staticSplitPoint;
786  auto pdlOperationType =
787  pdl::OperationType::get(parser.getBuilder().getContext());
788  if (parser.parseOperand(target) ||
789  parser.resolveOperand(target, pdlOperationType, result.operands) ||
790  parser.parseKeyword("after"))
791  return failure();
792 
793  OptionalParseResult dynamicPointParseResult =
794  parser.parseOptionalOperand(dynamicSplitPoint);
795  if (!dynamicPointParseResult.has_value()) {
796  int64_t staticSplitPointValue;
797  if (failed(parser.parseInteger(staticSplitPointValue)))
798  return failure();
799 
800  staticSplitPoint =
801  parser.getBuilder().getI64IntegerAttr(staticSplitPointValue);
802  } else {
803  if (failed(*dynamicPointParseResult) ||
804  parser.resolveOperand(dynamicSplitPoint, pdlOperationType,
805  result.operands)) {
806  return failure();
807  }
808 
809  staticSplitPoint =
810  parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamicSize);
811  }
812 
813  result.addAttribute(
814  SplitOp::getStaticSplitPointAttrName(result.name).getValue(),
815  staticSplitPoint);
816  if (failed(parser.parseOptionalAttrDict(result.attributes)))
817  return failure();
818 
819  result.addTypes({pdlOperationType, pdlOperationType});
820  return success();
821 }
822 
823 void SplitOp::print(OpAsmPrinter &printer) {
824  printer << " " << getTarget() << " after ";
825  int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint());
826  if (staticSplitSize != ShapedType::kDynamicSize)
827  printer << staticSplitSize;
828  else
829  printer << getDynamicSplitPoint();
830  printer << " ";
831  printer.printOptionalAttrDict(getOperation()->getAttrs(),
832  {getStaticSplitPointAttrName()});
833 }
834 
836  if ((static_cast<int64_t>(getStaticSplitPoint()) !=
837  ShapedType::kDynamicSize) ^
838  (getDynamicSplitPoint() == nullptr)) {
839  return emitOpError()
840  << "expects either a dynamic or a static split point to be provided";
841  }
842  return success();
843 }
844 
845 //===----------------------------------------------------------------------===//
846 // SplitReductionOp
847 //===----------------------------------------------------------------------===//
848 
850 transform::SplitReductionOp::applyToOne(linalg::LinalgOp target,
852  transform::TransformState &state) {
853  ControlSplitReductionFn splitFn = [&](LinalgOp) {
854  return std::pair<int64_t, unsigned>(getSplitFactor(),
855  getInsertSplitDimension());
856  };
857  SimpleRewriter rewriter(getContext());
858  rewriter.setInsertionPoint(target);
859  FailureOr<SplitReductionResult> splitResult =
860  (getUseScalingAlgorithm())
861  ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
862  : splitReduction(rewriter, target, splitFn, getUseAlloc());
863  if (failed(splitResult))
864  return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
865 
866  results.push_back(splitResult->initOrAlloc);
867  results.push_back(splitResult->fillOp);
868  results.push_back(splitResult->splitLinalgOp);
869  results.push_back(splitResult->resultCombiningLinalgOp);
871 }
872 
873 //===----------------------------------------------------------------------===//
874 // TileOp
875 //===----------------------------------------------------------------------===//
876 
// Tiles every payload op associated with the target handle and records the
// tiled ops and the generated loops in `transformResults`.
// NOTE(review): listing lines 877, 891, 904, 913, 914, 951 and 962 are absent
// from this listing (presumably the return type, the `diag` declarations, the
// `tiled`/`loops` declarations and two return statements) -- confirm against
// the original file.
878 transform::TileOp::apply(TransformResults &transformResults,
879  TransformState &state) {
880  LinalgTilingOptions tilingOptions;
881  SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes());
882 
883  ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
// For each dynamic-size handle, fetch the payload ops it maps to and validate
// them before any tiling is attempted.
884  SmallVector<ArrayRef<Operation *>> dynamicSizeProducers;
885  dynamicSizeProducers.reserve(getDynamicSizes().size());
886  for (Value dynamicSizeProducerHandle : getDynamicSizes()) {
887  dynamicSizeProducers.push_back(
888  state.getPayloadOps(dynamicSizeProducerHandle));
889 
// Each handle must map to exactly as many ops as there are targets, because
// the size producers are later indexed by the target's position.
890  if (dynamicSizeProducers.back().size() != targets.size()) {
892  emitSilenceableError()
893  << "expected as many dynamic size-producing operations ("
894  << dynamicSizeProducers.back().size() << ") as target ops ("
895  << targets.size() << ")";
896  diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
897  return diag;
898  }
899 
// Each size producer must have exactly one result, of index type; the first
// op violating this triggers a silenceable error.
900  for (Operation *op : dynamicSizeProducers.back()) {
901  if (op->getNumResults() == 1 &&
902  op->getResult(0).getType().isa<IndexType>())
903  continue;
905  emitSilenceableError() << "expected sizes to be produced by ops "
906  "with a single index-type result";
907  diag.attachNote(op->getLoc()) << "size producer op";
908  diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
909  return diag;
910  }
911  }
912 
// One list of loop ops per loop result of this transform op.
915  loops.resize(getLoops().size());
916  for (auto &en : llvm::enumerate(targets)) {
// Non-linalg targets are rejected with a silenceable error.
917  auto linalgOp = dyn_cast<LinalgOp>(en.value());
918  if (!linalgOp) {
919  DiagnosedSilenceableFailure diag = emitSilenceableError()
920  << "only linalg ops are supported";
921  diag.attachNote(en.value()->getLoc()) << "target op";
922  return diag;
923  }
924 
// `index` is captured by value in the lambda below; everything else is
// captured by reference.
925  unsigned index = en.index();
// The tile-size callback is only installed when static sizes are present.
926  if (!tileSizes.empty()) {
927  tilingOptions.setTileSizeComputationFunction(
928  [&, index](OpBuilder &b, Operation *) {
929  SmallVector<Value, 4> sizes;
930  sizes.reserve(tileSizes.size());
931  unsigned dynamicIdx = 0;
// Attribute entries become arith constant-index ops; every other entry takes
// result 0 of the `index`-th producer of the next dynamic-size handle.
932  for (OpFoldResult ofr : getMixedSizes()) {
933  if (auto attr = ofr.dyn_cast<Attribute>()) {
934  sizes.push_back(b.create<arith::ConstantIndexOp>(
935  getLoc(), attr.cast<IntegerAttr>().getInt()));
936  } else {
937  sizes.push_back(
938  dynamicSizeProducers[dynamicIdx++][index]->getResult(0));
939  }
940  }
941  return sizes;
942  });
943  }
944 
// The tiling pattern is invoked directly on this target through
// returningMatchAndRewrite.
945  tilingOptions.setInterchange(extractUIntArray(getInterchange()));
946  LinalgTilingPattern pattern(getContext(), tilingOptions);
947  SimpleRewriter rewriter(linalgOp.getContext());
948  FailureOr<TiledLinalgOp> tiledOp =
949  pattern.returningMatchAndRewrite(linalgOp, rewriter);
// NOTE(review): the statement guarded by this `if` (listing line 951) is
// absent from this listing.
950  if (failed(tiledOp))
952 
// Record the tiled op; the i-th generated loop goes into the i-th loop list.
953  tiled.push_back(tiledOp->op);
954  for (const auto &en2 : llvm::enumerate(tiledOp->loops))
955  loops[en2.index()].push_back(en2.value());
956  }
957 
// Associate the collected ops with the tiled-op result and each loop result.
958  transformResults.set(getTiledLinalgOp().cast<OpResult>(), tiled);
959  for (const auto &en : llvm::enumerate(loops))
960  transformResults.set(getLoops()[en.index()].cast<OpResult>(), en.value());
961 
963 }
964 
// NOTE(review): the signature line (listing line 965) and the declaration of
// `results` (listing line 968) are absent from this listing; presumably this
// is the TileOp accessor that returns the mixed static/dynamic sizes --
// confirm against the original file.
//
// Walks the static sizes in order and builds one entry per size.
966  ValueRange dynamic = getDynamicSizes();
967  SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes());
969  results.reserve(tileSizes.size());
970  unsigned dynamicPos = 0;
971  Builder builder(getContext());
972  for (int64_t size : tileSizes) {
// ShapedType::kDynamicSize marks a slot filled by the next dynamic-size
// operand; any other value is wrapped into an index attribute.
973  if (size == ShapedType::kDynamicSize) {
974  results.push_back(dynamic[dynamicPos++]);
975  } else {
976  results.push_back(builder.getIndexAttr(size));
977  }
978  }
979  return results;
980 }
981 
// Custom parser: reads the target operand, the dynamic index list of sizes
// and an optional attribute dictionary, then sets the static-sizes attribute
// and the result types.
// NOTE(review): listing lines 984 and 985 (presumably the declarations of
// `target` and `dynamicSizes`) are absent from this listing.
982 ParseResult transform::TileOp::parse(OpAsmParser &parser,
983  OperationState &result) {
986  ArrayAttr staticSizes;
// The target and all dynamic sizes are resolved against the PDL operation
// type.
987  auto pdlOperationType = pdl::OperationType::get(parser.getContext());
988  if (parser.parseOperand(target) ||
989  parser.resolveOperand(target, pdlOperationType, result.operands) ||
990  parseDynamicIndexList(parser, dynamicSizes, staticSizes,
991  ShapedType::kDynamicSize) ||
992  parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) ||
993  parser.parseOptionalAttrDict(result.attributes))
994  return ParseResult::failure();
995 
996  result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
// Static sizes equal to 0 are not counted as loops. The op gets
// numExpectedLoops + 1 results, presumably the tiled-op handle plus one
// handle per loop -- confirm against the op definition.
997  size_t numExpectedLoops =
998  staticSizes.size() - llvm::count(extractFromI64ArrayAttr(staticSizes), 0);
999  result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType));
1000  return success();
1001 }
1002 
// Custom printer: emits a space and the target operand, then the dynamic
// index list of sizes (with ShapedType::kDynamicSize passed as the dynamic
// marker), then the attribute dictionary with the static-sizes attribute
// elided.
1003 void TileOp::print(OpAsmPrinter &p) {
1004  p << ' ' << getTarget();
1005  printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(),
1006  ShapedType::kDynamicSize);
1007  p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()});
1008 }
1009 
// Declares the op's effects: the target handle is consumed, the dynamic-size
// handles are only read, the tiled-op and loop handles are produced, and the
// payload is modified.
// NOTE(review): listing line 1011 (presumably the `effects` parameter) is
// absent from this listing.
1010 void transform::TileOp::getEffects(
1012  consumesHandle(getTarget(), effects);
1013  onlyReadsHandle(getDynamicSizes(), effects);
1014  producesHandle(getTiledLinalgOp(), effects);
1015  producesHandle(getLoops(), effects);
1016  modifiesPayload(effects);
1017 }
1018 
1019 //===----------------------------------------------------------------------===//
1020 // TileToForeachThreadOp
1021 //===----------------------------------------------------------------------===//
1022 
// Tiles `target`, replaces it with the results of the produced tile op and
// reports the tile op and the tiled op through `results`.
// NOTE(review): listing lines 1034, 1040 and 1047 are absent from this
// listing (presumably the declaration of `tilingResult`, the start of the
// tile-sizes call and the final return) -- confirm against the original file.
1023 DiagnosedSilenceableFailure transform::TileToForeachThreadOp::applyToOne(
1024  TilingInterface target, SmallVectorImpl<Operation *> &results,
1025  transform::TransformState &state) {
1026  IRRewriter rewriter(getContext());
1027  rewriter.setInsertionPoint(target);
// Use the thread-dim-mapping attribute when present, otherwise an empty
// mapping.
1028  auto maybeThreadDimMappingAttr = getThreadDimMapping();
1029  auto dimMapping =
1030  llvm::to_vector(maybeThreadDimMappingAttr
1031  ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr)
1032  : ArrayRef<int64_t>{});
1033 
// The two `if`s below are independent (no `else`), so both run when both
// attributes are set.
1035  if (Optional<ArrayAttr> numThreads = getNumThreads())
1036  tilingResult = linalg::tileToForeachThreadOp(
1037  rewriter, target, getAsOpFoldResult(*numThreads), dimMapping);
1038 
1039  if (Optional<ArrayAttr> tileSizes = getTileSizes())
1041  rewriter, target, getAsOpFoldResult(*tileSizes), dimMapping);
1042 
1043  if (failed(tilingResult))
1044  return emitDefaultSilenceableFailure(target);
// The original op is replaced by the results of the tile op; `results` is
// overwritten with {tileOp, tiledOp}, in that order.
1045  rewriter.replaceOp(target, tilingResult->tileOp->getResults());
1046  results.assign({tilingResult->tileOp, tilingResult->tiledOp});
1048 }
1049 
1050 //===----------------------------------------------------------------------===//
1051 // VectorizeOp
1052 //===----------------------------------------------------------------------===//
1053 
// Greedily applies a set of vectorization-related patterns to `target` and
// reports `target` itself as the result.
// NOTE(review): listing lines 1054, 1056, 1061, 1068-1071, 1076 and 1082 are
// absent from this listing (presumably the return type, the `results`
// parameter, return statements and some pattern-population calls) -- confirm
// against the original file.
1055 transform::VectorizeOp::applyToOne(Operation *target,
1057  transform::TransformState &state) {
// Targets without the IsIsolatedFromAbove trait are rejected with an op error
// that carries a note pointing at the offending target.
1058  if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
1059  auto diag = this->emitOpError("requires isolated-from-above targets");
1060  diag.attachNote(target->getLoc()) << "non-isolated target";
1062  }
1063 
1064  MLIRContext *ctx = getContext();
1065  RewritePatternSet patterns(ctx);
1066  patterns.add<LinalgVectorizationPattern>(ctx);
1067 
// NOTE(review): the next line is the tail of a call whose beginning is missing
// from this listing; only its `benefit` argument of 2 is visible.
1072  /*benefit=*/2);
// Canonicalization patterns of the vector transfer read/write ops are added
// to the same set.
1073  vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
1074  vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
// NOTE(review): the statement guarded by this `if` (listing line 1076) is
// absent from this listing.
1075  if (getVectorizePadding())
1077 
1078  if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns))))
1079  return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
1080 
1081  results.push_back(target);
1083 }
1084 
1085 //===----------------------------------------------------------------------===//
1086 // Transform op registration
1087 //===----------------------------------------------------------------------===//
1088 
1089 namespace {
1090 /// Registers new ops and declares PDL as dependent dialect since the additional
1091 /// ops are using PDL types for operands and results.
// NOTE(review): listing line 1093 (presumably the start of the base-class
// specifier) is absent from this listing.
1092 class LinalgTransformDialectExtension
1094  LinalgTransformDialectExtension> {
1095 public:
1096  using Base::Base;
1097 
// Declares the dialects this extension relies on and registers its ops.
1098  void init() {
1099  declareDependentDialect<pdl::PDLDialect>();
1100 
// Affine, Arithmetic, SCF and Vector are declared as generated dialects.
1101  declareGeneratedDialect<AffineDialect>();
1102  declareGeneratedDialect<arith::ArithmeticDialect>();
1103  declareGeneratedDialect<scf::SCFDialect>();
1104  declareGeneratedDialect<vector::VectorDialect>();
1105 
// The op list comes from the generated .cpp.inc file, included with
// GET_OP_LIST defined.
1106  registerTransformOps<
1107 #define GET_OP_LIST
1108 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
1109  >();
1110  }
1111 };
1112 } // namespace
1113 
1114 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
1115 
1116 #define GET_OP_CLASSES
1117 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
1118 
// NOTE(review): the first line of this function's signature (listing line
// 1119) is absent from this listing; presumably this is the public
// registration entry point -- confirm against the original file.
// Adds LinalgTransformDialectExtension to the given dialect registry.
1120  DialectRegistry &registry) {
1121  registry.addExtensions<LinalgTransformDialectExtension>();
1122 }
Diagnostic & attachNote(Optional< Location > loc=llvm::None)
Attaches a note to the last diagnostic.
Include the generated interface declarations.
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context)
This parses a single MLIR attribute to an MLIR context if it was valid.
LinalgPaddingOptions & setHoistPaddings(ArrayRef< int64_t > hp)
Definition: Transforms.h:560
SmallVector< OpFoldResult, 4 > getMixedSizes(ArrayAttr staticValues, ValueRange dynamicValues)
Return a vector of all the static and dynamic sizes.
void set(OpResult value, ArrayRef< Operation *> ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
static std::string diag(llvm::Value &v)
The result of a transform IR operation application.
MLIRContext * getContext() const
Definition: Builders.h:54
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
Linalg tiling pattern.
Definition: Transforms.h:677
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:600
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
This is a value defined by a result of an operation.
Definition: Value.h:425
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:345
This class represents a single result from folding an operation.
Definition: OpDefinition.h:239
Operation * clone(Operation &op, BlockAndValueMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:496
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value...
Definition: LogicalResult.h:72
LinalgPaddingOptions & setPaddingDimensions(ArrayRef< int64_t > pd)
Definition: Transforms.h:547
SmallVector< int64_t > tileInterchange
Tile interchange used to permute the tile loops.
Definition: Transforms.h:582
LinalgTilingOptions & setInterchange(ArrayRef< unsigned > interchange)
Definition: Transforms.h:623
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ValueRange operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:798
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value...
Definition: LogicalResult.h:68
ArrayRef< Operation * > getPayloadOps(Value value) const
Returns the list of ops that the given transform IR value corresponds to.
LinalgPaddingOptions & setPackPaddings(ArrayRef< bool > pp)
Definition: Transforms.h:554
This is the representation of an operand reference.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it...
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
void populateVectorReductionToContractPatterns(RewritePatternSet &patterns)
Collect patterns to convert reduction op to vector.contract and fold transpose/broadcast ops into the...
static FailureOr< LinalgOp > tryApply(Operation *operation, Args &&...args)
Attempts to apply the pattern specified as template argument to the given operation.
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:83
static constexpr const bool value
void erase()
Remove this operation from its parent block and delete it.
Definition: Operation.cpp:414
void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, ArrayAttr integers, int64_t dynVal)
Printer hook for custom directive in assemblyFormat.
SmallVector< Value, 4 > operands
Match and rewrite for the pattern: ``` alloc = ...
Definition: Transforms.h:1263
FailureOr< GenericOp > interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef< unsigned > interchangeVector)
Interchange the iterator_types and iterator_maps dimensions and adapts the index accesses of op...
Definition: Interchange.cpp:51
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:149
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:147
Match and rewrite for the pattern: ``` alloc = ...
Definition: Transforms.h:1290
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:408
This class contains all of the information necessary to report a diagnostic to the DiagnosticEngine...
Definition: Diagnostics.h:157
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void consumesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the operation on the given handle value: ...
LinalgPromotionOptions & setOperandsToPromote(ArrayRef< int64_t > operands)
Definition: Transforms.h:219
FailureOr< LinalgOp > splitReduction(PatternRewriter &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, const LinalgTransformationFilter &f, bool useAlloc=false)
Apply transformation to split the single linalg op reduction into a parallel and reduction dimension...
This class provides support for representing a failure result, or a valid value of type T...
Definition: LogicalResult.h:78
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
void addExtensions()
Add the given extensions to the registry.
FailureOr< SplitReductionResult > splitReductionByScaling(PatternRewriter &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
void onlyReadsHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:99
LinalgPaddingOptions & setPaddingValues(ArrayRef< Attribute > pv)
Definition: Transforms.h:541
LogicalResult promoteSubviewsPrecondition(Operation *op, LinalgPromotionOptions options)
Promote memref.subviews feeding linalg-on-buffers operations.
Definition: Promotion.cpp:367
Attributes are known-constant values of operations.
Definition: Attributes.h:24
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:528
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:233
SmallVector< int64_t > tileSizes
Tile sizes used to tile the root operation.
Definition: Transforms.h:576
Base type for affine expression.
Definition: AffineExpr.h:68
LinalgTilingOptions & scalarizeDynamicDims()
Tile all dynamic dimensions by 1.
Definition: Transforms.cpp:136
filter controls LinalgTransformMarker matching and update when specified.
Definition: Transforms.h:949
FailureOr< LinalgOp > promoteSubViews(OpBuilder &b, LinalgOp op, const LinalgPromotionOptions &options)
Promote the subViews into a new buffer allocated at the insertion point b.
Definition: Promotion.cpp:389
void addTypes(ArrayRef< Type > newTypes)
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
static WalkResult advance()
Definition: Visitors.h:51
void updateRootInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around a root update of an operation.
Definition: PatternMatch.h:499
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:154
This represents an operation in an abstracted form, suitable for use with the builder APIs...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true...
ParseResult parseDynamicIndexList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &values, ArrayAttr &integers, int64_t dynVal)
Parser hook for custom directive in assemblyFormat.
The state maintained across applications of various ops implementing the TransformOpInterface.
void registerTransformDialectExtension(DialectRegistry &registry)
ParseResult resolveOperands(ArrayRef< UnresolvedOperand > operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
result_range getOpResults()
Definition: Operation.h:337
FailureOr< TiledLinalgOp > returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const
matchAndRewrite implementation that returns the significant transformed pieces of IR...
Definition: Transforms.cpp:373
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:584
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:37
void populateVectorTransferPermutationMapLoweringPatterns(RewritePatternSet &patterns)
Collect a set of transfer read/write lowering patterns that simplify the permutation map (e...
NamedAttrList attributes
LinalgPromotionOptions & setUseAlloca(bool use)
Definition: Transforms.h:254
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values...
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:295
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
static FailureOr< SmallVector< Operation * > > cloneAndFuse(Operation *producerOp, Operation *containingOp, RewriterBase &rewriter)
This class is a general helper class for creating context-global objects like types, attributes, and affine expressions.
Definition: Builders.h:49
void modifiesPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
Populates effects with the memory effects indicating the access to payload IR resource.
virtual OptionalParseResult parseOptionalOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single operand if present.
LinalgPromotionOptions & setAlignment(unsigned align)
Definition: Transforms.h:248
This class provides the API for ops that are known to be isolated from above.
U dyn_cast() const
Definition: Attributes.h:127
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
FailureOr< MultiSizeSpecification > computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, OpFoldResult targetSize, OpFoldResult divisor, bool emitAssertions=true)
Emits the IR computing the multi-sized tiling specification with two tile sizes not exceeding targetS...
Definition: Tiling.cpp:114
Linalg tile and fuse tensor ops pattern.
Definition: Transforms.h:801
std::pair< TilingInterface, TilingInterface > splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, OpFoldResult splitPoint)
Split the given op into two parts along the given iteration space dimension at the specified splitPoi...
Definition: Split.cpp:68
Specialization of arith.constant op that returns an integer of index type.
Definition: Arithmetic.h:80
Perform standalone tiling of a single LinalgOp by tileSizes.
Definition: Transforms.h:158
OpResult getOpResult(unsigned idx)
Definition: Operation.h:338
LinalgPromotionOptions & setUseFullTileBuffers(ArrayRef< bool > useFullTiles)
Definition: Transforms.h:231
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:55
static SmallVector< unsigned > extractUIntArray(ArrayAttr attr)
Extracts a vector of unsigned from an array attribute.
This class represents an operand of an operation.
Definition: Value.h:251
Base class for extensions of the Transform dialect that supports injecting operations into the Transf...
static LogicalResult applyTilingToAll(Operation *transformOp, ArrayRef< Operation *> payloadOps, unsigned numLoops, transform::TransformResults &transformResults, function_ref< FailureOr< TiledLinalgOp >(LinalgOp)> applyFn)
Apply a tiling transformation to all payload ops and store both the tiled operation as well as the cr...
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:47
FailureOr< TileLoopNest > returningMatchAndRewrite(Operation *op, PatternRewriter &rewriter) const
matchAndRewrite implementation that returns the significant transformed pieces of IR...
Definition: Transforms.cpp:482
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:372
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
static DiagnosedSilenceableFailure silenceableFailure(Diagnostic &&diag)
Constructs a DiagnosedSilenceableFailure in the silenceable failure state, ready to emit the given di...
LinalgPromotionOptions & setUseFullTileBuffersByDefault(bool use)
Definition: Transforms.h:242
SmallVector< int64_t, 4 > extractFromI64ArrayAttr(Attribute attr)
Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result, StringRef sizesAttrName)
Parse a tiling-like operation that returns the tiled op as well as the created tile loops...
void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, PatternBenefit baseBenefit=1)
Populates patterns with patterns that vectorize tensor.pad.
std::function< std::pair< int64_t, unsigned >(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
Definition: Transforms.h:1398
SmallVector< Operation *, 8 > loops
Definition: Transforms.h:160
LinalgTilingOptions & setTileSizeComputationFunction(TileSizeComputationFunction fun)
Definition: Transforms.h:600
static LogicalResult failure(bool isFailure=true)
If isFailure is true a failure result is generated, otherwise a 'success' result is generated...
Definition: LogicalResult.h:36
void producesHandle(ValueRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:221
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:67
FailureOr< ForeachThreadTilingResult > tileToForeachThreadOp(RewriterBase &builder, TilingInterface op, ArrayRef< OpFoldResult > numThreads, ArrayRef< int64_t > threadDimMapping={})
Definition: Tiling.cpp:344
This class represents success/failure for parsing-like operations that find it important to chain tog...
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:650
static FailureOr< SmallVector< Operation * > > tileAndFuse(Operation *producerOp, Operation *containingOp, RewriterBase &rewriter)
LogicalResult applyPatternsAndFoldGreedily(MutableArrayRef< Region > regions, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig())
Rewrite the regions of the specified operation, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy work-list driven manner.
This class helps build Operations.
Definition: Builders.h:193
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:345
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:95
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
FailureOr< ForeachThreadTilingResult > tileToForeachThreadOpUsingTileSizes(RewriterBase &builder, TilingInterface op, ArrayRef< OpFoldResult > tileSizes, ArrayRef< int64_t > threadDimMapping={})
Same as tileToForeachThreadOp, but calculate the number of threads required using the given tileSizes...
Definition: Tiling.cpp:353
bool isProperAncestor(Operation *other)
Return true if this operation is a proper ancestor of the other operation.
Definition: Operation.cpp:172
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:398
LinalgPaddingOptions & setTransposePaddings(ArrayRef< SmallVector< int64_t >> tp)
Definition: Transforms.h:568