MLIR  21.0.0git
LinalgMatchOps.cpp
Go to the documentation of this file.
1 //===- LinalgMatchOps.cpp - Implementation of Linalg match ops ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
18 #include "llvm/Support/Debug.h"
19 #include "llvm/Support/FormatVariadic.h"
20 #include "llvm/Support/InterleavedRange.h"
21 
22 using namespace mlir;
23 
24 #define DEBUG_TYPE "linalg-transforms"
25 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
26 
27 //===----------------------------------------------------------------------===//
28 // StructuredMatchOp
29 //===----------------------------------------------------------------------===//
30 
31 DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation(
32  Operation *current, transform::TransformResults &results,
34  // First, check if the payload operation is a structured Linalg operation.
35  if (!isa<linalg::LinalgOp>(current)) {
36  if (getFailurePropagationMode().value_or(
37  FailurePropagationMode::Propagate) ==
38  FailurePropagationMode::Propagate) {
39  return emitSilenceableError() << "expected a Linalg op";
40  }
41  // If errors are suppressed, succeed and set all results to empty lists.
42  LLVM_DEBUG(DBGS() << "optional nested matcher expected a Linalg op");
43  results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
45  }
46 
47  // Bind `current` to the block argument.
48  auto scope = state.make_region_scope(getBodyRegion());
49  if (failed(state.mapBlockArgument(getBody()->getArgument(0),
50  MappedValue(current)))) {
52  }
53 
54  for (Operation &nested : getBody()->without_terminator()) {
56  state.applyTransform(cast<TransformOpInterface>(nested));
57  if (diag.isDefiniteFailure())
58  return diag;
59  if (diag.succeeded())
60  continue;
61 
62  // If propagating errors, do this immediately.
63  assert(diag.isSilenceableFailure());
64  if (getFailurePropagationMode().value_or(
65  FailurePropagationMode::Propagate) ==
66  FailurePropagationMode::Propagate) {
67  return diag;
68  }
69 
70  // If suppressing errors, print the message into the debug stream before
71  // silencing it. Then set all results value that are already known.
72  // Results come from the terminator operands, which may be defined in the
73  // (single) block of this operation or above it. When they are defined
74  // above, they are known to be mapped at this point per SSA dominance.
75  // When they are defined in this block, we additionally check if we have
76  // already applied the operation that defines them. If not, the
77  // corresponding results will be set to empty lists.
78  LLVM_DEBUG(DBGS() << "optional nested matcher failed: " << diag.getMessage()
79  << "\n");
80  (void)diag.silence();
81  SmallVector<OpOperand *> undefinedOperands;
82  for (OpOperand &terminatorOperand :
83  getBody()->getTerminator()->getOpOperands()) {
84  Operation *definingOp = terminatorOperand.get().getDefiningOp();
85  if (!definingOp)
86  continue;
87  if (definingOp->getBlock() != getBody())
88  continue;
89  if (definingOp->isBeforeInBlock(&nested))
90  continue;
91 
92  undefinedOperands.push_back(&terminatorOperand);
93  }
94 
96  auto filtered = llvm::make_filter_range(
97  getBody()->getTerminator()->getOpOperands(), [&](OpOperand &opOperand) {
98  return !llvm::is_contained(undefinedOperands, &opOperand);
99  });
100  SmallVector<Value> definedOperands = llvm::to_vector(llvm::map_range(
101  filtered, [](OpOperand &opOperand) { return opOperand.get(); }));
102  detail::prepareValueMappings(mappings, definedOperands, state);
103  for (auto &&[operand, mapping] : llvm::zip_equal(filtered, mappings)) {
104  results.setMappedValues(getResults()[operand.getOperandNumber()],
105  mapping);
106  }
107  results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
109  }
110 
111  // Set the results.
112  detail::forwardTerminatorOperands(getBody(), state, results);
114 }
115 
116 void transform::MatchStructuredOp::getEffects(
118  onlyReadsHandle(getCurrentMutable(), effects);
119  onlyReadsPayload(effects);
120  producesHandle(getOperation()->getOpResults(), effects);
121 }
122 
123 LogicalResult transform::MatchStructuredOp::verify() {
124  if (getBody()->getNumArguments() != 1)
125  return emitOpError() << "expected one body argument";
126  if (!isa<TransformHandleTypeInterface>(getBody()->getArgument(0).getType())) {
127  return emitOpError() << "expected body argument to implement "
128  "TransformHandleTypeInterface";
129  }
130  for (Operation &nested : getBody()->without_terminator()) {
131  if (isa<MatchOpInterface>(nested))
132  continue;
134  emitOpError()
135  << "expects nested operations to implement MatchOpInterface";
136  diag.attachNote(nested.getLoc()) << "offending operation";
137  return diag;
138  }
139  return success();
140 }
141 
142 //===----------------------------------------------------------------------===//
143 // StructuredOpPredicateOpTrait
144 //===----------------------------------------------------------------------===//
145 
147  Operation *op, Value structuredOpHandle) {
148  if (!isa_and_nonnull<MatchStructuredOp>(op->getParentOp())) {
149  return op->emitOpError() << "expects parent op to be '"
150  << MatchStructuredOp::getOperationName() << "'";
151  }
152 
153  // Bail out here, let the verifier of the parent complain.
154  Operation *parent = op->getParentOp();
155  if (parent->getNumRegions() < 1 || parent->getRegion(0).empty() ||
156  parent->getRegion(0).front().getNumArguments() < 1)
157  return success();
158 
159  if (structuredOpHandle != parent->getRegion(0).front().getArgument(0)) {
160  return op->emitOpError()
161  << "expected predicate to apply to the surrounding structured op";
162  }
163  return success();
164 }
165 
166 //===----------------------------------------------------------------------===//
167 // MatchStructuredBodyOp
168 //===----------------------------------------------------------------------===//
169 
170 DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation(
171  Operation *current, transform::TransformResults &results,
172  transform::TransformState &state) {
173  auto linalgOp = cast<linalg::LinalgOp>(current);
174  if (std::optional<uint64_t> position = getReductionPosition()) {
175  SmallVector<Operation *> combinerOps;
176  if (!matchReduction(linalgOp.getRegionOutputArgs(), *position,
177  combinerOps)) {
178  return emitSilenceableError() << "could not match reduction";
179  }
180  if (combinerOps.size() != 1) {
181  return emitSilenceableError() << "reduction combiner is not a single op";
182  }
184  }
185  if (getPassthrough()) {
186  Block &body = linalgOp->getRegion(0).front();
187  if (body.getTerminator()->getOperands() != linalgOp.getRegionInputArgs()) {
188  return emitSilenceableError() << "not a passthrough";
189  }
191  }
192  if (getElementwise()) {
193  if (!isElementwise(linalgOp))
194  return emitSilenceableError() << "not elementwise";
196  }
197  if (std::optional<ArrayAttr> contractionOps = getContraction()) {
198  Block &body = linalgOp->getRegion(0).front();
199  std::string message;
200  llvm::raw_string_ostream os(message);
201  bool result = linalg::detail::isContractionBody(
202  body,
203  [&](Operation *elem, Operation *red) {
204  return elem->getName().getStringRef() ==
205  cast<StringAttr>((*contractionOps)[0]).getValue() &&
206  red->getName().getStringRef() ==
207  cast<StringAttr>((*contractionOps)[1]).getValue();
208  },
209  os);
210  if (result)
212  return emitSilenceableError() << "contraction: " << message;
213  }
214  return emitDefiniteFailure() << "unknown body condition";
215 }
216 
218  int64_t numOptions = getReductionPosition().has_value() + getPassthrough() +
219  getElementwise() + getContraction().has_value();
220 
221  if (numOptions > 1) {
222  StringAttr attributeNames[] = {
223  getReductionPositionAttrName(), getPassthroughAttrName(),
224  getElementwiseAttrName(), getContractionAttrName()};
225  return emitOpError() << "only one of {" << llvm::interleaved(attributeNames)
226  << "} is allowed";
227  }
228 
229  if (std::optional<ArrayAttr> contractionAttr = getContraction()) {
230  if (contractionAttr->size() != 2) {
231  return emitOpError() << "expects " << getContractionAttrName()
232  << " to contain two elements";
233  }
234  }
235  return success();
236 }
237 
238 //===----------------------------------------------------------------------===//
239 // MatchStructuredClassifyContractionDimsOp
240 //===----------------------------------------------------------------------===//
241 
243 transform::MatchStructuredClassifyContractionDimsOp::matchOperation(
244  Operation *current, transform::TransformResults &results,
245  transform::TransformState &state) {
246  FailureOr<linalg::ContractionDimensions> contractionDims =
247  linalg::inferContractionDims(cast<linalg::LinalgOp>(current));
248  if (failed(contractionDims))
249  return emitSilenceableError() << "could not infer contraction dimensions";
250 
251  MLIRContext *context = current->getContext();
252  Builder builder(context);
253  auto makeI64Attrs = [&](ArrayRef<unsigned> values) {
254  return llvm::to_vector(
255  llvm::map_range(values, [&](unsigned value) -> Attribute {
256  return builder.getI64IntegerAttr(value);
257  }));
258  };
259  results.setParams(cast<OpResult>(getBatch()),
260  makeI64Attrs(contractionDims->batch));
261  results.setParams(cast<OpResult>(getM()), makeI64Attrs(contractionDims->m));
262  results.setParams(cast<OpResult>(getN()), makeI64Attrs(contractionDims->n));
263  results.setParams(cast<OpResult>(getK()), makeI64Attrs(contractionDims->k));
265 }
266 
267 //===----------------------------------------------------------------------===//
268 // MatchStructuredClassifyConvolutionDimsOp
269 //===----------------------------------------------------------------------===//
270 
272 transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
273  Operation *current, transform::TransformResults &results,
274  transform::TransformState &state) {
275  FailureOr<linalg::ConvolutionDimensions> convolutionDims =
276  linalg::inferConvolutionDims(cast<linalg::LinalgOp>(current));
277  if (failed(convolutionDims))
278  return emitSilenceableError() << "could not infer convolution dimensions";
279 
280  MLIRContext *context = current->getContext();
281  Builder builder(context);
282  auto makeI64Attrs = [&](ArrayRef<unsigned> values) {
283  return llvm::to_vector(
284  llvm::map_range(values, [&](unsigned value) -> Attribute {
285  return builder.getI64IntegerAttr(value);
286  }));
287  };
288  results.setParams(cast<OpResult>(getBatch()),
289  makeI64Attrs(convolutionDims->batch));
290  results.setParams(cast<OpResult>(getOutputImage()),
291  makeI64Attrs(convolutionDims->outputImage));
292  results.setParams(cast<OpResult>(getOutputChannel()),
293  makeI64Attrs(convolutionDims->outputChannel));
294  results.setParams(cast<OpResult>(getFilterLoop()),
295  makeI64Attrs(convolutionDims->filterLoop));
296  results.setParams(cast<OpResult>(getInputChannel()),
297  makeI64Attrs(convolutionDims->inputChannel));
298  results.setParams(cast<OpResult>(getDepth()),
299  makeI64Attrs(convolutionDims->depth));
300 
301  auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
302  return llvm::to_vector(
303  llvm::map_range(values, [&](int64_t value) -> Attribute {
304  return builder.getI64IntegerAttr(value);
305  }));
306  };
307  results.setParams(cast<OpResult>(getStrides()),
308  makeI64AttrsFromI64(convolutionDims->strides));
309  results.setParams(cast<OpResult>(getDilations()),
310  makeI64AttrsFromI64(convolutionDims->dilations));
312 }
313 
314 //===----------------------------------------------------------------------===//
315 // Utilities for structured match predicates.
316 //===----------------------------------------------------------------------===//
317 
318 /// Checks if all values from `list` are also contained in `reference`. Returns
319 /// a silenceable error with the given message at the given location when it is
320 /// not the case. The error message must contain the "{0}" placeholder that
321 /// will be substituted with the value from `list` that is not contained in
322 /// `reference`.
324  ArrayRef<int64_t> list,
325  Location loc,
326  const char *message) {
327  for (int64_t value : list) {
328  if (llvm::any_of(reference, [&](unsigned ref) {
329  return static_cast<int64_t>(ref) == value;
330  })) {
331  continue;
332  }
333  return emitSilenceableFailure(loc) << llvm::formatv(message, value);
334  }
336 }
337 
338 //===----------------------------------------------------------------------===//
339 // MatchStructuredDimOp
340 //===----------------------------------------------------------------------===//
341 
342 DiagnosedSilenceableFailure transform::MatchStructuredDimOp::matchOperation(
343  Operation *current, transform::TransformResults &results,
344  transform::TransformState &state) {
345  auto linalgOp = cast<linalg::LinalgOp>(current);
346  SmallVector<int64_t> dimensions;
347  DiagnosedSilenceableFailure diag = getDimensionsFor(linalgOp, dimensions);
348  if (!diag.succeeded())
349  return diag;
350 
351  // If asked to check for the kind of dimension, perform the check.
352  if (getParallel() || getReduction()) {
353  SmallVector<unsigned> reference;
354  if (getParallel())
355  linalgOp.getParallelDims(reference);
356  else if (getReduction())
357  linalgOp.getReductionDims(reference);
358 
360  containsAll(reference, dimensions, getLoc(),
361  getParallel() ? "expects dimension #{0} to be parallel"
362  : "expects dimension #{0} to be reduction");
363  if (!diag.succeeded())
364  return diag;
365  }
366 
367  // If not capturing, we are done here.
368  if (!getResult())
369  return diag;
370 
371  SmallVector<int64_t, 4> ranges = linalgOp.getStaticLoopRanges();
372  Builder builder(current);
373  SmallVector<Attribute> captured = llvm::to_vector(
374  llvm::map_range(dimensions, [&](int64_t dim) -> Attribute {
375  return builder.getI64IntegerAttr(ranges[dim]);
376  }));
377  results.setParams(cast<OpResult>(getResult()), captured);
379 }
380 
381 DiagnosedSilenceableFailure transform::MatchStructuredDimOp::getDimensionsFor(
382  linalg::LinalgOp op, SmallVectorImpl<int64_t> &dims) {
384  expandTargetSpecification(getLoc(), getIsAll(), getIsInverted(),
385  getRawDimList(), op.getNumLoops(), dims);
386  if (diag.isSilenceableFailure()) {
387  diag.attachNote(op->getLoc())
388  << "while considering dimensions of this payload operation";
389  }
390  return diag;
391 }
392 
394  if (getParallel() && getReduction()) {
395  return emitOpError() << "cannot request the same dimension to be both "
396  "parallel and reduction";
397  }
398  return verifyTransformMatchDimsOp(getOperation(), getRawDimList(),
399  getIsInverted(), getIsAll());
400 }
401 
402 //===----------------------------------------------------------------------===//
403 // MatchStructuredElementalBitwidthOp
404 //===----------------------------------------------------------------------===//
405 
407 transform::MatchStructuredElementalBitwidthOp::matchValue(
408  Value current, transform::TransformResults &results,
409  transform::TransformState &state) {
410  auto setupResult = [&](int64_t bitwidth) {
411  Attribute attr = Builder(current.getContext()).getI64IntegerAttr(bitwidth);
412  results.setParams(cast<OpResult>(getResult()), {attr});
414  };
415 
416  Type type = current.getType();
417  if (type.isIntOrFloat())
418  return setupResult(type.getIntOrFloatBitWidth());
419 
420  if (auto shapedType = dyn_cast<ShapedType>(type)) {
421  if (shapedType.getElementType().isIntOrFloat())
422  return setupResult(shapedType.getElementTypeBitWidth());
423  }
424  return emitSilenceableError()
425  << "unsupported type for bitwidth extraction: " << type;
426 }
427 
428 //===----------------------------------------------------------------------===//
429 // MatchStructuredInputOp
430 //===----------------------------------------------------------------------===//
431 
432 DiagnosedSilenceableFailure transform::MatchStructuredInputOp::matchOperation(
433  Operation *current, transform::TransformResults &results,
434  transform::TransformState &state) {
435  auto linalgOp = cast<linalg::LinalgOp>(current);
436  SmallVector<int64_t> positions;
437  DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions);
438  if (!diag.succeeded())
439  return diag;
440 
441  SmallVector<MappedValue> operandMapping;
442  operandMapping.reserve(positions.size());
443  for (int64_t position : positions) {
444  AffineMap indexingMap =
445  linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(position));
446  if (getPermutation() && !indexingMap.isPermutation()) {
447  return emitSilenceableError() << "the indexing map for input #"
448  << position << " is not a permutation";
449  }
450  if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
451  return emitSilenceableError()
452  << "the indexing map for input #" << position
453  << " is not a projected permutation";
454  }
455 
456  // If capture not requested, skip it.
457  if (!getResult())
458  continue;
459 
460  if (isa<AffineMapParamType>(getResult().getType())) {
461  operandMapping.emplace_back(AffineMapAttr::get(indexingMap));
462  continue;
463  }
464 
465  Value operand = linalgOp.getDpsInputOperand(position)->get();
466  if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
467  operandMapping.emplace_back(operand);
468  continue;
469  }
470 
471  Operation *operandProducer = operand.getDefiningOp();
472  if (!operandProducer) {
473  return emitSilenceableError()
474  << "input #" << position << " is not produced by an operation";
475  }
476  operandMapping.emplace_back(operandProducer);
477  }
478  if (getResult())
479  results.setMappedValues(cast<OpResult>(getResult()), operandMapping);
481 }
482 
483 DiagnosedSilenceableFailure transform::MatchStructuredInputOp::getPositionsFor(
484  linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
486  getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
487  op.getNumDpsInputs(), positions);
488  if (diag.isSilenceableFailure()) {
489  diag.attachNote(op->getLoc())
490  << "while considering DPS inputs of this payload operation";
491  }
492  return diag;
493 }
494 
495 /// Verifies a matcher op for structured input or output, specifically the
496 /// attributes specifying the operand positions.
497 template <typename OpTy>
498 LogicalResult verifyStructuredOperandOp(OpTy op) {
499  if (op.getPermutation() && op.getProjectedPermutation()) {
500  return op.emitOpError()
501  << op.getPermutationAttrName() << " and "
502  << op.getProjectedPermutationAttrName() << " are mutually exclusive";
503  }
504  if (op.getRawPositionList().size() > 1 && op.getResult()) {
505  return op.emitOpError()
506  << "cannot bind multiple inputs/inits to the same value";
507  }
508 
509  return success();
510 }
511 
513  if (failed(verifyStructuredOperandOp(*this)))
514  return failure();
515  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
516  getIsInverted(), getIsAll());
517 }
518 
519 //===----------------------------------------------------------------------===//
520 // MatchStructuredInitOp
521 //===----------------------------------------------------------------------===//
522 
523 DiagnosedSilenceableFailure transform::MatchStructuredInitOp::matchOperation(
524  Operation *current, transform::TransformResults &results,
525  transform::TransformState &state) {
526  auto linalgOp = cast<linalg::LinalgOp>(current);
527  SmallVector<int64_t> positions;
528  DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions);
529  if (!diag.succeeded())
530  return diag;
531 
532  SmallVector<MappedValue> operandMapping;
533  operandMapping.reserve(positions.size());
534  for (int64_t position : positions) {
535  AffineMap indexingMap =
536  linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(position));
537  if (getPermutation() && !indexingMap.isPermutation()) {
538  return emitSilenceableError() << "the indexing map for output(init) #"
539  << position << " is not a permutation";
540  }
541  if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
542  return emitSilenceableError() << "the indexing map for output(init) #"
543  << position << " is not a permutation";
544  }
545 
546  // If capture not requested, skip it.
547  if (!getResult())
548  continue;
549 
550  if (isa<AffineMapParamType>(getResult().getType())) {
551  operandMapping.emplace_back(AffineMapAttr::get(indexingMap));
552  continue;
553  }
554 
555  Value operand = linalgOp.getDpsInitOperand(position)->get();
556  if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
557  operandMapping.emplace_back(operand);
558  continue;
559  }
560 
561  Operation *operandProducer = operand.getDefiningOp();
562  if (!operandProducer) {
563  return emitSilenceableError() << "output(init) #" << position
564  << " is not produced by an operation";
565  }
566  operandMapping.emplace_back(operandProducer);
567  }
568  if (getResult())
569  results.setMappedValues(cast<OpResult>(getResult()), operandMapping);
571 }
572 
573 DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor(
574  linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
576  getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
577  op.getNumDpsInits(), positions);
578  if (diag.isSilenceableFailure()) {
579  diag.attachNote(op->getLoc())
580  << "while considering DPS inits (outputs) of this payload operation";
581  }
582  return diag;
583 }
584 
586  if (failed(verifyStructuredOperandOp(*this)))
587  return failure();
588  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
589  getIsInverted(), getIsAll());
590 }
591 
592 //===----------------------------------------------------------------------===//
593 // MatchStructuredNumInputsOp
594 //===----------------------------------------------------------------------===//
595 
597 transform::MatchStructuredNumInputsOp::matchOperation(
598  Operation *current, transform::TransformResults &results,
599  transform::TransformState &state) {
600  auto linalgOp = cast<linalg::LinalgOp>(current);
601  Attribute attr =
602  Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInputs());
603  results.setParams(cast<OpResult>(getResult()), {attr});
605 }
606 
607 //===----------------------------------------------------------------------===//
608 // MatchStructuredNumInitsOp
609 //===----------------------------------------------------------------------===//
610 
612 transform::MatchStructuredNumInitsOp::matchOperation(
613  Operation *current, transform::TransformResults &results,
614  transform::TransformState &state) {
615  auto linalgOp = cast<linalg::LinalgOp>(current);
616  Attribute attr =
617  Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInits());
618  results.setParams(cast<OpResult>(getResult()), {attr});
620 }
621 
622 //===----------------------------------------------------------------------===//
623 // MatchStructuredRankOp
624 //===----------------------------------------------------------------------===//
625 
626 DiagnosedSilenceableFailure transform::MatchStructuredRankOp::matchOperation(
627  Operation *current, transform::TransformResults &results,
628  transform::TransformState &state) {
629  auto linalgOp = cast<linalg::LinalgOp>(current);
630  int64_t numLoops = linalgOp.getNumLoops();
631  Attribute attr = Builder(linalgOp->getContext()).getI64IntegerAttr(numLoops);
632  results.setParams(cast<OpResult>(getRank()), {attr});
634 }
635 
636 //===----------------------------------------------------------------------===//
637 // MatchStructuredResultOp
638 //===----------------------------------------------------------------------===//
639 
640 DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation(
642  transform::TransformState &state) {
643  auto linalgOp = cast<linalg::LinalgOp>(op);
644  int64_t position;
645  DiagnosedSilenceableFailure diag = getPositionFor(linalgOp, position);
646  if (!diag.succeeded())
647  return diag;
648 
649  Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position));
650  if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
651  results.setValues(cast<OpResult>(getResult()), {result});
653  }
654 
655  if (result.getUsers().empty()) {
656  return emitSilenceableError()
657  << "no users of the result #" << getPosition();
658  }
659  Operation *firstUser = *result.getUsers().begin();
660  if (getAny()) {
661  results.set(cast<OpResult>(getResult()), {firstUser});
663  }
664  if (getSingle()) {
665  if (!llvm::hasSingleElement(result.getUsers())) {
666  return emitSilenceableError()
667  << "more than one result user with single user requested";
668  }
669  results.set(cast<OpResult>(getResult()), {firstUser});
671  }
672 
673  return emitDefiniteFailure() << "unknown sub-predicate";
674 }
675 
677 transform::MatchStructuredResultOp::getPositionFor(linalg::LinalgOp op,
678  int64_t &position) {
679  auto rawPosition = static_cast<int64_t>(getPosition());
680  position = rawPosition < 0 ? op.getNumDpsInits() + rawPosition : rawPosition;
681  if (position >= op.getNumDpsInits() || position < 0) {
682  return emitSilenceableError()
683  << "position " << rawPosition
684  << " overflows the number of results(ints) of the payload operation";
685  }
687 }
688 
690  if ((getAny() || getSingle()) ^
691  isa<TransformHandleTypeInterface>(getResult().getType())) {
692  return emitOpError() << "expects either the any/single keyword or the type "
693  "value handle result type";
694  }
695  if (getAny() && getSingle()) {
696  return emitOpError() << "'any' and 'single' are mutually exclusive";
697  }
698  return success();
699 }
700 
701 //===----------------------------------------------------------------------===//
702 // MatchStructuredYieldOp
703 //===----------------------------------------------------------------------===//
704 
705 void transform::MatchStructuredYieldOp::getEffects(
707  onlyReadsHandle(getHandlesMutable(), effects);
708  onlyReadsPayload(effects);
709 }
710 
711 void transform::MatchStructuredYieldOp::build(OpBuilder &builder,
712  OperationState &state) {
713  build(builder, state, ValueRange());
714 }
715 
716 #define GET_OP_CLASSES
717 #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
static DiagnosedSilenceableFailure containsAll(ArrayRef< unsigned > reference, ArrayRef< int64_t > list, Location loc, const char *message)
Checks if all values from list are also contained in reference.
LogicalResult verifyStructuredOperandOp(OpTy op)
Verifies a matcher op for structured input or output, specifically the attributes specifying the oper...
#define DBGS()
static std::string diag(const llvm::Value &value)
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:611
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Definition: AffineMap.cpp:641
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
Operation & front()
Definition: Block.h:153
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:107
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
This class represents an operand of an operation.
Definition: Value.h:257
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:385
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:686
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:108
Type getType() const
Return the type of this value.
Definition: Value.h:105
user_range getUsers() const
Definition: Value.h:218
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setRemainingToEmpty(TransformOpInterface transform)
Sets the currently unset results to empty lists of the kind expected by the corresponding results of ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
The state maintained across applications of various ops implementing the TransformOpInterface.
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition: Utils.cpp:220
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
uint64_t getN(LevelType lt)
Definition: Enums.h:442
uint64_t getM(LevelType lt)
Definition: Enums.h:443
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
LogicalResult verifyStructuredOpPredicateOpTrait(Operation *op, Value structuredOpHandle)
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef< int64_t > raw, bool inverted, bool all)
Checks if the positional specification defined is valid and reports errors otherwise.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
DiagnosedSilenceableFailure expandTargetSpecification(Location loc, bool isAll, bool isInverted, ArrayRef< int64_t > rawList, int64_t maxNumber, SmallVectorImpl< int64_t > &result)
Populates result with the positional identifiers relative to maxNumber.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
llvm::PointerUnion< Operation *, Param, Value > MappedValue
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This represents an operation in an abstracted form, suitable for use with the builder APIs.