MLIR  20.0.0git
LinalgMatchOps.cpp
Go to the documentation of this file.
1 //===- LinalgMatchOps.cpp - Implementation of Linalg match ops ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/FormatVariadic.h"
21 
22 using namespace mlir;
23 
24 #define DEBUG_TYPE "linalg-transforms"
25 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
26 
27 //===----------------------------------------------------------------------===//
28 // StructuredMatchOp
29 //===----------------------------------------------------------------------===//
30 
31 DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation(
32  Operation *current, transform::TransformResults &results,
34  // First, check if the payload operation is a structured Linalg operation.
35  if (!isa<linalg::LinalgOp>(current)) {
36  if (getFailurePropagationMode().value_or(
37  FailurePropagationMode::Propagate) ==
38  FailurePropagationMode::Propagate) {
39  return emitSilenceableError() << "expected a Linalg op";
40  }
41  // If errors are suppressed, succeed and set all results to empty lists.
42  LLVM_DEBUG(DBGS() << "optional nested matcher expected a Linalg op");
43  results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
45  }
46 
47  // Bind `current` to the block argument.
48  auto scope = state.make_region_scope(getBodyRegion());
49  if (failed(state.mapBlockArgument(getBody()->getArgument(0),
50  MappedValue(current)))) {
52  }
53 
54  for (Operation &nested : getBody()->without_terminator()) {
56  state.applyTransform(cast<TransformOpInterface>(nested));
57  if (diag.isDefiniteFailure())
58  return diag;
59  if (diag.succeeded())
60  continue;
61 
62  // If propagating errors, do this immediately.
63  assert(diag.isSilenceableFailure());
64  if (getFailurePropagationMode().value_or(
65  FailurePropagationMode::Propagate) ==
66  FailurePropagationMode::Propagate) {
67  return diag;
68  }
69 
70  // If suppressing errors, print the message into the debug stream before
71  // silencing it. Then set all result values that are already known.
72  // Results come from the terminator operands, which may be defined in the
73  // (single) block of this operation or above it. When they are defined
74  // above, they are known to be mapped at this point per SSA dominance.
75  // When they are defined in this block, we additionally check if we have
76  // already applied the operation that defines them. If not, the
77  // corresponding results will be set to empty lists.
78  LLVM_DEBUG(DBGS() << "optional nested matcher failed: " << diag.getMessage()
79  << "\n");
80  (void)diag.silence();
81  SmallVector<OpOperand *> undefinedOperands;
82  for (OpOperand &terminatorOperand :
83  getBody()->getTerminator()->getOpOperands()) {
84  Operation *definingOp = terminatorOperand.get().getDefiningOp();
85  if (!definingOp)
86  continue;
87  if (definingOp->getBlock() != getBody())
88  continue;
89  if (definingOp->isBeforeInBlock(&nested))
90  continue;
91 
92  undefinedOperands.push_back(&terminatorOperand);
93  }
94 
96  auto filtered = llvm::make_filter_range(
97  getBody()->getTerminator()->getOpOperands(), [&](OpOperand &opOperand) {
98  return !llvm::is_contained(undefinedOperands, &opOperand);
99  });
100  SmallVector<Value> definedOperands = llvm::to_vector(llvm::map_range(
101  filtered, [](OpOperand &opOperand) { return opOperand.get(); }));
102  detail::prepareValueMappings(mappings, definedOperands, state);
103  for (auto &&[operand, mapping] : llvm::zip_equal(filtered, mappings)) {
104  results.setMappedValues(getResults()[operand.getOperandNumber()],
105  mapping);
106  }
107  results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
109  }
110 
111  // Set the results.
112  detail::forwardTerminatorOperands(getBody(), state, results);
114 }
115 
116 void transform::MatchStructuredOp::getEffects(
118  onlyReadsHandle(getCurrentMutable(), effects);
119  onlyReadsPayload(effects);
120  producesHandle(getOperation()->getOpResults(), effects);
121 }
122 
123 LogicalResult transform::MatchStructuredOp::verify() {
124  if (getBody()->getNumArguments() != 1)
125  return emitOpError() << "expected one body argument";
126  if (!isa<TransformHandleTypeInterface>(getBody()->getArgument(0).getType())) {
127  return emitOpError() << "expected body argument to implement "
128  "TransformHandleTypeInterface";
129  }
130  for (Operation &nested : getBody()->without_terminator()) {
131  if (isa<MatchOpInterface>(nested))
132  continue;
134  emitOpError()
135  << "expects nested operations to implement MatchOpInterface";
136  diag.attachNote(nested.getLoc()) << "offending operation";
137  return diag;
138  }
139  return success();
140 }
141 
142 //===----------------------------------------------------------------------===//
143 // StructuredOpPredicateOpTrait
144 //===----------------------------------------------------------------------===//
145 
147  Operation *op, Value structuredOpHandle) {
148  if (!isa_and_nonnull<MatchStructuredOp>(op->getParentOp())) {
149  return op->emitOpError() << "expects parent op to be '"
150  << MatchStructuredOp::getOperationName() << "'";
151  }
152 
153  // Bail out here, let the verifier of the parent complain.
154  Operation *parent = op->getParentOp();
155  if (parent->getNumRegions() < 1 || parent->getRegion(0).empty() ||
156  parent->getRegion(0).front().getNumArguments() < 1)
157  return success();
158 
159  if (structuredOpHandle != parent->getRegion(0).front().getArgument(0)) {
160  return op->emitOpError()
161  << "expected predicate to apply to the surrounding structured op";
162  }
163  return success();
164 }
165 
166 //===----------------------------------------------------------------------===//
167 // MatchStructuredBodyOp
168 //===----------------------------------------------------------------------===//
169 
170 DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation(
171  Operation *current, transform::TransformResults &results,
172  transform::TransformState &state) {
173  auto linalgOp = cast<linalg::LinalgOp>(current);
174  if (std::optional<uint64_t> position = getReductionPosition()) {
175  SmallVector<Operation *> combinerOps;
176  if (!matchReduction(linalgOp.getRegionOutputArgs(), *position,
177  combinerOps)) {
178  return emitSilenceableError() << "could not match reduction";
179  }
180  if (combinerOps.size() != 1) {
181  return emitSilenceableError() << "reduction combiner is not a single op";
182  }
184  }
185  if (getPassthrough()) {
186  Block &body = linalgOp->getRegion(0).front();
187  if (body.getTerminator()->getOperands() != linalgOp.getRegionInputArgs()) {
188  return emitSilenceableError() << "not a passthrough";
189  }
191  }
192  if (getElementwise()) {
193  if (!isElementwise(linalgOp))
194  return emitSilenceableError() << "not elementwise";
196  }
197  if (std::optional<ArrayAttr> contractionOps = getContraction()) {
198  Block &body = linalgOp->getRegion(0).front();
199  std::string message;
200  llvm::raw_string_ostream os(message);
201  bool result = linalg::detail::isContractionBody(
202  body,
203  [&](Operation *elem, Operation *red) {
204  return elem->getName().getStringRef() ==
205  cast<StringAttr>((*contractionOps)[0]).getValue() &&
206  red->getName().getStringRef() ==
207  cast<StringAttr>((*contractionOps)[1]).getValue();
208  },
209  os);
210  if (result)
212  return emitSilenceableError() << "contraction: " << message;
213  }
214  return emitDefiniteFailure() << "unknown body condition";
215 }
216 
218  int64_t numOptions = getReductionPosition().has_value() + getPassthrough() +
219  getElementwise() + getContraction().has_value();
220 
221  if (numOptions > 1) {
222  std::string attributeNames;
223  llvm::raw_string_ostream os(attributeNames);
224  llvm::interleaveComma(ArrayRef<StringAttr>{getReductionPositionAttrName(),
225  getPassthroughAttrName(),
226  getElementwiseAttrName(),
227  getContractionAttrName()},
228  os);
229  return emitOpError() << "only one of {" << attributeNames << "} is allowed";
230  }
231 
232  if (std::optional<ArrayAttr> contractionAttr = getContraction()) {
233  if (contractionAttr->size() != 2) {
234  return emitOpError() << "expects " << getContractionAttrName()
235  << " to contain two elements";
236  }
237  }
238  return success();
239 }
240 
241 //===----------------------------------------------------------------------===//
242 // MatchStructuredClassifyContractionDimsOp
243 //===----------------------------------------------------------------------===//
244 
246 transform::MatchStructuredClassifyContractionDimsOp::matchOperation(
247  Operation *current, transform::TransformResults &results,
248  transform::TransformState &state) {
249  FailureOr<linalg::ContractionDimensions> contractionDims =
250  linalg::inferContractionDims(cast<linalg::LinalgOp>(current));
251  if (failed(contractionDims))
252  return emitSilenceableError() << "could not infer contraction dimensions";
253 
254  MLIRContext *context = current->getContext();
255  Builder builder(context);
256  auto makeI64Attrs = [&](ArrayRef<unsigned> values) {
257  return llvm::to_vector(
258  llvm::map_range(values, [&](unsigned value) -> Attribute {
259  return builder.getI64IntegerAttr(value);
260  }));
261  };
262  results.setParams(cast<OpResult>(getBatch()),
263  makeI64Attrs(contractionDims->batch));
264  results.setParams(cast<OpResult>(getM()), makeI64Attrs(contractionDims->m));
265  results.setParams(cast<OpResult>(getN()), makeI64Attrs(contractionDims->n));
266  results.setParams(cast<OpResult>(getK()), makeI64Attrs(contractionDims->k));
268 }
269 
270 //===----------------------------------------------------------------------===//
271 // MatchStructuredClassifyConvolutionDimsOp
272 //===----------------------------------------------------------------------===//
273 
275 transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
276  Operation *current, transform::TransformResults &results,
277  transform::TransformState &state) {
278  FailureOr<linalg::ConvolutionDimensions> convolutionDims =
279  linalg::inferConvolutionDims(cast<linalg::LinalgOp>(current));
280  if (failed(convolutionDims))
281  return emitSilenceableError() << "could not infer convolution dimensions";
282 
283  MLIRContext *context = current->getContext();
284  Builder builder(context);
285  auto makeI64Attrs = [&](ArrayRef<unsigned> values) {
286  return llvm::to_vector(
287  llvm::map_range(values, [&](unsigned value) -> Attribute {
288  return builder.getI64IntegerAttr(value);
289  }));
290  };
291  results.setParams(cast<OpResult>(getBatch()),
292  makeI64Attrs(convolutionDims->batch));
293  results.setParams(cast<OpResult>(getOutputImage()),
294  makeI64Attrs(convolutionDims->outputImage));
295  results.setParams(cast<OpResult>(getOutputChannel()),
296  makeI64Attrs(convolutionDims->outputChannel));
297  results.setParams(cast<OpResult>(getFilterLoop()),
298  makeI64Attrs(convolutionDims->filterLoop));
299  results.setParams(cast<OpResult>(getInputChannel()),
300  makeI64Attrs(convolutionDims->inputChannel));
301  results.setParams(cast<OpResult>(getDepth()),
302  makeI64Attrs(convolutionDims->depth));
303 
304  auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
305  return llvm::to_vector(
306  llvm::map_range(values, [&](int64_t value) -> Attribute {
307  return builder.getI64IntegerAttr(value);
308  }));
309  };
310  results.setParams(cast<OpResult>(getStrides()),
311  makeI64AttrsFromI64(convolutionDims->strides));
312  results.setParams(cast<OpResult>(getDilations()),
313  makeI64AttrsFromI64(convolutionDims->dilations));
315 }
316 
317 //===----------------------------------------------------------------------===//
318 // Utilities for structured match predicates.
319 //===----------------------------------------------------------------------===//
320 
321 /// Checks if all values from `list` are also contained in `reference`. Returns
322 /// a silenceable error with the given message at the given location when it is
323 /// not the case. The error message must contain the "{0}" placeholder that
324 /// will be substituted with the value from `list` that is not contained in
325 /// `reference`.
327  ArrayRef<int64_t> list,
328  Location loc,
329  const char *message) {
330  for (int64_t value : list) {
331  if (llvm::any_of(reference, [&](unsigned ref) {
332  return static_cast<int64_t>(ref) == value;
333  })) {
334  continue;
335  }
336  return emitSilenceableFailure(loc) << llvm::formatv(message, value);
337  }
339 }
340 
341 //===----------------------------------------------------------------------===//
342 // MatchStructuredDimOp
343 //===----------------------------------------------------------------------===//
344 
345 DiagnosedSilenceableFailure transform::MatchStructuredDimOp::matchOperation(
346  Operation *current, transform::TransformResults &results,
347  transform::TransformState &state) {
348  auto linalgOp = cast<linalg::LinalgOp>(current);
349  SmallVector<int64_t> dimensions;
350  DiagnosedSilenceableFailure diag = getDimensionsFor(linalgOp, dimensions);
351  if (!diag.succeeded())
352  return diag;
353 
354  // If asked to check for the kind of dimension, perform the check.
355  if (getParallel() || getReduction()) {
356  SmallVector<unsigned> reference;
357  if (getParallel())
358  linalgOp.getParallelDims(reference);
359  else if (getReduction())
360  linalgOp.getReductionDims(reference);
361 
363  containsAll(reference, dimensions, getLoc(),
364  getParallel() ? "expects dimension #{0} to be parallel"
365  : "expects dimension #{0} to be reduction");
366  if (!diag.succeeded())
367  return diag;
368  }
369 
370  // If not capturing, we are done here.
371  if (!getResult())
372  return diag;
373 
374  SmallVector<int64_t, 4> ranges = linalgOp.getStaticLoopRanges();
375  Builder builder(current);
376  SmallVector<Attribute> captured = llvm::to_vector(
377  llvm::map_range(dimensions, [&](int64_t dim) -> Attribute {
378  return builder.getI64IntegerAttr(ranges[dim]);
379  }));
380  results.setParams(cast<OpResult>(getResult()), captured);
382 }
383 
384 DiagnosedSilenceableFailure transform::MatchStructuredDimOp::getDimensionsFor(
385  linalg::LinalgOp op, SmallVectorImpl<int64_t> &dims) {
387  expandTargetSpecification(getLoc(), getIsAll(), getIsInverted(),
388  getRawDimList(), op.getNumLoops(), dims);
389  if (diag.isSilenceableFailure()) {
390  diag.attachNote(op->getLoc())
391  << "while considering dimensions of this payload operation";
392  }
393  return diag;
394 }
395 
397  if (getParallel() && getReduction()) {
398  return emitOpError() << "cannot request the same dimension to be both "
399  "parallel and reduction";
400  }
401  return verifyTransformMatchDimsOp(getOperation(), getRawDimList(),
402  getIsInverted(), getIsAll());
403 }
404 
405 //===----------------------------------------------------------------------===//
406 // MatchStructuredElementalBitwidthOp
407 //===----------------------------------------------------------------------===//
408 
410 transform::MatchStructuredElementalBitwidthOp::matchValue(
411  Value current, transform::TransformResults &results,
412  transform::TransformState &state) {
413  auto setupResult = [&](int64_t bitwidth) {
414  Attribute attr = Builder(current.getContext()).getI64IntegerAttr(bitwidth);
415  results.setParams(cast<OpResult>(getResult()), {attr});
417  };
418 
419  Type type = current.getType();
420  if (type.isIntOrFloat())
421  return setupResult(type.getIntOrFloatBitWidth());
422 
423  if (auto shapedType = dyn_cast<ShapedType>(type)) {
424  if (shapedType.getElementType().isIntOrFloat())
425  return setupResult(shapedType.getElementTypeBitWidth());
426  }
427  return emitSilenceableError()
428  << "unsupported type for bitwidth extraction: " << type;
429 }
430 
431 //===----------------------------------------------------------------------===//
432 // MatchStructuredInputOp
433 //===----------------------------------------------------------------------===//
434 
435 DiagnosedSilenceableFailure transform::MatchStructuredInputOp::matchOperation(
436  Operation *current, transform::TransformResults &results,
437  transform::TransformState &state) {
438  auto linalgOp = cast<linalg::LinalgOp>(current);
439  SmallVector<int64_t> positions;
440  DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions);
441  if (!diag.succeeded())
442  return diag;
443 
444  SmallVector<MappedValue> operandMapping;
445  operandMapping.reserve(positions.size());
446  for (int64_t position : positions) {
447  AffineMap indexingMap =
448  linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(position));
449  if (getPermutation() && !indexingMap.isPermutation()) {
450  return emitSilenceableError() << "the indexing map for input #"
451  << position << " is not a permutation";
452  }
453  if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
454  return emitSilenceableError()
455  << "the indexing map for input #" << position
456  << " is not a projected permutation";
457  }
458 
459  // If capture not requested, skip it.
460  if (!getResult())
461  continue;
462 
463  if (isa<AffineMapParamType>(getResult().getType())) {
464  operandMapping.emplace_back(AffineMapAttr::get(indexingMap));
465  continue;
466  }
467 
468  Value operand = linalgOp.getDpsInputOperand(position)->get();
469  if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
470  operandMapping.emplace_back(operand);
471  continue;
472  }
473 
474  Operation *operandProducer = operand.getDefiningOp();
475  if (!operandProducer) {
476  return emitSilenceableError()
477  << "input #" << position << " is not produced by an operation";
478  }
479  operandMapping.emplace_back(operandProducer);
480  }
481  if (getResult())
482  results.setMappedValues(cast<OpResult>(getResult()), operandMapping);
484 }
485 
486 DiagnosedSilenceableFailure transform::MatchStructuredInputOp::getPositionsFor(
487  linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
489  getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
490  op.getNumDpsInputs(), positions);
491  if (diag.isSilenceableFailure()) {
492  diag.attachNote(op->getLoc())
493  << "while considering DPS inputs of this payload operation";
494  }
495  return diag;
496 }
497 
498 /// Verifies a matcher op for structured input or output, specifically the
499 /// attributes specifying the operand positions.
500 template <typename OpTy>
501 LogicalResult verifyStructuredOperandOp(OpTy op) {
502  if (op.getPermutation() && op.getProjectedPermutation()) {
503  return op.emitOpError()
504  << op.getPermutationAttrName() << " and "
505  << op.getProjectedPermutationAttrName() << " are mutually exclusive";
506  }
507  if (op.getRawPositionList().size() > 1 && op.getResult()) {
508  return op.emitOpError()
509  << "cannot bind multiple inputs/inits to the same value";
510  }
511 
512  return success();
513 }
514 
516  if (failed(verifyStructuredOperandOp(*this)))
517  return failure();
518  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
519  getIsInverted(), getIsAll());
520 }
521 
522 //===----------------------------------------------------------------------===//
523 // MatchStructuredInitOp
524 //===----------------------------------------------------------------------===//
525 
526 DiagnosedSilenceableFailure transform::MatchStructuredInitOp::matchOperation(
527  Operation *current, transform::TransformResults &results,
528  transform::TransformState &state) {
529  auto linalgOp = cast<linalg::LinalgOp>(current);
530  SmallVector<int64_t> positions;
531  DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions);
532  if (!diag.succeeded())
533  return diag;
534 
535  SmallVector<MappedValue> operandMapping;
536  operandMapping.reserve(positions.size());
537  for (int64_t position : positions) {
538  AffineMap indexingMap =
539  linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(position));
540  if (getPermutation() && !indexingMap.isPermutation()) {
541  return emitSilenceableError() << "the indexing map for output(init) #"
542  << position << " is not a permutation";
543  }
544  if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
545  return emitSilenceableError() << "the indexing map for output(init) #"
546  << position << " is not a permutation";
547  }
548 
549  // If capture not requested, skip it.
550  if (!getResult())
551  continue;
552 
553  if (isa<AffineMapParamType>(getResult().getType())) {
554  operandMapping.emplace_back(AffineMapAttr::get(indexingMap));
555  continue;
556  }
557 
558  Value operand = linalgOp.getDpsInitOperand(position)->get();
559  if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
560  operandMapping.emplace_back(operand);
561  continue;
562  }
563 
564  Operation *operandProducer = operand.getDefiningOp();
565  if (!operandProducer) {
566  return emitSilenceableError() << "output(init) #" << position
567  << " is not produced by an operation";
568  }
569  operandMapping.emplace_back(operandProducer);
570  }
571  if (getResult())
572  results.setMappedValues(cast<OpResult>(getResult()), operandMapping);
574 }
575 
576 DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor(
577  linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
579  getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
580  op.getNumDpsInits(), positions);
581  if (diag.isSilenceableFailure()) {
582  diag.attachNote(op->getLoc())
583  << "while considering DPS inits (outputs) of this payload operation";
584  }
585  return diag;
586 }
587 
589  if (failed(verifyStructuredOperandOp(*this)))
590  return failure();
591  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
592  getIsInverted(), getIsAll());
593 }
594 
595 //===----------------------------------------------------------------------===//
596 // MatchStructuredNumInputsOp
597 //===----------------------------------------------------------------------===//
598 
600 transform::MatchStructuredNumInputsOp::matchOperation(
601  Operation *current, transform::TransformResults &results,
602  transform::TransformState &state) {
603  auto linalgOp = cast<linalg::LinalgOp>(current);
604  Attribute attr =
605  Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInputs());
606  results.setParams(cast<OpResult>(getResult()), {attr});
608 }
609 
610 //===----------------------------------------------------------------------===//
611 // MatchStructuredNumInitsOp
612 //===----------------------------------------------------------------------===//
613 
615 transform::MatchStructuredNumInitsOp::matchOperation(
616  Operation *current, transform::TransformResults &results,
617  transform::TransformState &state) {
618  auto linalgOp = cast<linalg::LinalgOp>(current);
619  Attribute attr =
620  Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInits());
621  results.setParams(cast<OpResult>(getResult()), {attr});
623 }
624 
625 //===----------------------------------------------------------------------===//
626 // MatchStructuredRankOp
627 //===----------------------------------------------------------------------===//
628 
629 DiagnosedSilenceableFailure transform::MatchStructuredRankOp::matchOperation(
630  Operation *current, transform::TransformResults &results,
631  transform::TransformState &state) {
632  auto linalgOp = cast<linalg::LinalgOp>(current);
633  int64_t numLoops = linalgOp.getNumLoops();
634  Attribute attr = Builder(linalgOp->getContext()).getI64IntegerAttr(numLoops);
635  results.setParams(cast<OpResult>(getRank()), {attr});
637 }
638 
639 //===----------------------------------------------------------------------===//
640 // MatchStructuredResultOp
641 //===----------------------------------------------------------------------===//
642 
643 DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation(
645  transform::TransformState &state) {
646  auto linalgOp = cast<linalg::LinalgOp>(op);
647  int64_t position;
648  DiagnosedSilenceableFailure diag = getPositionFor(linalgOp, position);
649  if (!diag.succeeded())
650  return diag;
651 
652  Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position));
653  if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
654  results.setValues(cast<OpResult>(getResult()), {result});
656  }
657 
658  if (result.getUsers().empty()) {
659  return emitSilenceableError()
660  << "no users of the result #" << getPosition();
661  }
662  Operation *firstUser = *result.getUsers().begin();
663  if (getAny()) {
664  results.set(cast<OpResult>(getResult()), {firstUser});
666  }
667  if (getSingle()) {
668  if (!llvm::hasSingleElement(result.getUsers())) {
669  return emitSilenceableError()
670  << "more than one result user with single user requested";
671  }
672  results.set(cast<OpResult>(getResult()), {firstUser});
674  }
675 
676  return emitDefiniteFailure() << "unknown sub-predicate";
677 }
678 
680 transform::MatchStructuredResultOp::getPositionFor(linalg::LinalgOp op,
681  int64_t &position) {
682  auto rawPosition = static_cast<int64_t>(getPosition());
683  position = rawPosition < 0 ? op.getNumDpsInits() + rawPosition : rawPosition;
684  if (position >= op.getNumDpsInits() || position < 0) {
685  return emitSilenceableError()
686  << "position " << rawPosition
687  << " overflows the number of results(ints) of the payload operation";
688  }
690 }
691 
693  if ((getAny() || getSingle()) ^
694  isa<TransformHandleTypeInterface>(getResult().getType())) {
695  return emitOpError() << "expects either the any/single keyword or the type "
696  "value handle result type";
697  }
698  if (getAny() && getSingle()) {
699  return emitOpError() << "'any' and 'single' are mutually exclusive";
700  }
701  return success();
702 }
703 
704 //===----------------------------------------------------------------------===//
705 // MatchStructuredYieldOp
706 //===----------------------------------------------------------------------===//
707 
708 void transform::MatchStructuredYieldOp::getEffects(
710  onlyReadsHandle(getHandlesMutable(), effects);
711  onlyReadsPayload(effects);
712 }
713 
714 void transform::MatchStructuredYieldOp::build(OpBuilder &builder,
715  OperationState &state) {
716  build(builder, state, ValueRange());
717 }
718 
719 #define GET_OP_CLASSES
720 #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
static DiagnosedSilenceableFailure containsAll(ArrayRef< unsigned > reference, ArrayRef< int64_t > list, Location loc, const char *message)
Checks if all values from list are also contained in reference.
LogicalResult verifyStructuredOperandOp(OpTy op)
Verifies a matcher op for structured input or output, specifically the attributes specifying the oper...
#define DBGS()
static std::string diag(const llvm::Value &value)
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:618
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Definition: AffineMap.cpp:648
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
Operation & front()
Definition: Block.h:153
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:152
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:216
This class represents an operand of an operation.
Definition: Value.h:267
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:669
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:682
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:127
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:132
Type getType() const
Return the type of this value.
Definition: Value.h:129
user_range getUsers() const
Definition: Value.h:228
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setRemainingToEmpty(TransformOpInterface transform)
Sets the currently unset results to empty lists of the kind expected by the corresponding results of ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
The state maintained across applications of various ops implementing the TransformOpInterface.
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition: Utils.cpp:169
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
uint64_t getN(LevelType lt)
Definition: Enums.h:442
uint64_t getM(LevelType lt)
Definition: Enums.h:443
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
LogicalResult verifyStructuredOpPredicateOpTrait(Operation *op, Value structuredOpHandle)
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef< int64_t > raw, bool inverted, bool all)
Checks if the positional specification defined is valid and reports errors otherwise.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
DiagnosedSilenceableFailure expandTargetSpecification(Location loc, bool isAll, bool isInverted, ArrayRef< int64_t > rawList, int64_t maxNumber, SmallVectorImpl< int64_t > &result)
Populates result with the positional identifiers relative to maxNumber.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
llvm::PointerUnion< Operation *, Param, Value > MappedValue
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:426
This represents an operation in an abstracted form, suitable for use with the builder APIs.