MLIR  21.0.0git
LinalgMatchOps.cpp
Go to the documentation of this file.
1 //===- LinalgMatchOps.cpp - Implementation of Linalg match ops ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/FormatVariadic.h"
21 #include "llvm/Support/InterleavedRange.h"
22 
23 using namespace mlir;
24 
25 #define DEBUG_TYPE "linalg-transforms"
26 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
27 
28 //===----------------------------------------------------------------------===//
29 // StructuredMatchOp
30 //===----------------------------------------------------------------------===//
31 
32 DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation(
33  Operation *current, transform::TransformResults &results,
35  // First, check if the payload operation is a structured Linalg operation.
36  if (!isa<linalg::LinalgOp>(current)) {
37  if (getFailurePropagationMode().value_or(
38  FailurePropagationMode::Propagate) ==
39  FailurePropagationMode::Propagate) {
40  return emitSilenceableError() << "expected a Linalg op";
41  }
42  // If errors are suppressed, succeed and set all results to empty lists.
43  LLVM_DEBUG(DBGS() << "optional nested matcher expected a Linalg op");
44  results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
46  }
47 
48  // Bind `current` to the block argument.
49  auto scope = state.make_region_scope(getBodyRegion());
50  if (failed(state.mapBlockArgument(getBody()->getArgument(0),
51  MappedValue(current)))) {
53  }
54 
55  for (Operation &nested : getBody()->without_terminator()) {
57  state.applyTransform(cast<TransformOpInterface>(nested));
58  if (diag.isDefiniteFailure())
59  return diag;
60  if (diag.succeeded())
61  continue;
62 
63  // If propagating errors, do this immediately.
64  assert(diag.isSilenceableFailure());
65  if (getFailurePropagationMode().value_or(
66  FailurePropagationMode::Propagate) ==
67  FailurePropagationMode::Propagate) {
68  return diag;
69  }
70 
71  // If suppressing errors, print the message into the debug stream before
72  // silencing it. Then set all results value that are already known.
73  // Results come from the terminator operands, which may be defined in the
74  // (single) block of this operation or above it. When they are defined
75  // above, they are known to be mapped at this point per SSA dominance.
76  // When they are defined in this block, we additionally check if we have
77  // already applied the operation that defines them. If not, the
78  // corresponding results will be set to empty lists.
79  LLVM_DEBUG(DBGS() << "optional nested matcher failed: " << diag.getMessage()
80  << "\n");
81  (void)diag.silence();
82  SmallVector<OpOperand *> undefinedOperands;
83  for (OpOperand &terminatorOperand :
84  getBody()->getTerminator()->getOpOperands()) {
85  Operation *definingOp = terminatorOperand.get().getDefiningOp();
86  if (!definingOp)
87  continue;
88  if (definingOp->getBlock() != getBody())
89  continue;
90  if (definingOp->isBeforeInBlock(&nested))
91  continue;
92 
93  undefinedOperands.push_back(&terminatorOperand);
94  }
95 
97  auto filtered = llvm::make_filter_range(
98  getBody()->getTerminator()->getOpOperands(), [&](OpOperand &opOperand) {
99  return !llvm::is_contained(undefinedOperands, &opOperand);
100  });
101  SmallVector<Value> definedOperands = llvm::to_vector(llvm::map_range(
102  filtered, [](OpOperand &opOperand) { return opOperand.get(); }));
103  detail::prepareValueMappings(mappings, definedOperands, state);
104  for (auto &&[operand, mapping] : llvm::zip_equal(filtered, mappings)) {
105  results.setMappedValues(getResults()[operand.getOperandNumber()],
106  mapping);
107  }
108  results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
110  }
111 
112  // Set the results.
113  detail::forwardTerminatorOperands(getBody(), state, results);
115 }
116 
117 void transform::MatchStructuredOp::getEffects(
119  onlyReadsHandle(getCurrentMutable(), effects);
120  onlyReadsPayload(effects);
121  producesHandle(getOperation()->getOpResults(), effects);
122 }
123 
124 LogicalResult transform::MatchStructuredOp::verify() {
125  if (getBody()->getNumArguments() != 1)
126  return emitOpError() << "expected one body argument";
127  if (!isa<TransformHandleTypeInterface>(getBody()->getArgument(0).getType())) {
128  return emitOpError() << "expected body argument to implement "
129  "TransformHandleTypeInterface";
130  }
131  for (Operation &nested : getBody()->without_terminator()) {
132  if (isa<MatchOpInterface>(nested))
133  continue;
135  emitOpError()
136  << "expects nested operations to implement MatchOpInterface";
137  diag.attachNote(nested.getLoc()) << "offending operation";
138  return diag;
139  }
140  return success();
141 }
142 
143 //===----------------------------------------------------------------------===//
144 // StructuredOpPredicateOpTrait
145 //===----------------------------------------------------------------------===//
146 
148  Operation *op, Value structuredOpHandle) {
149  if (!isa_and_nonnull<MatchStructuredOp>(op->getParentOp())) {
150  return op->emitOpError() << "expects parent op to be '"
151  << MatchStructuredOp::getOperationName() << "'";
152  }
153 
154  // Bail out here, let the verifier of the parent complain.
155  Operation *parent = op->getParentOp();
156  if (parent->getNumRegions() < 1 || parent->getRegion(0).empty() ||
157  parent->getRegion(0).front().getNumArguments() < 1)
158  return success();
159 
160  if (structuredOpHandle != parent->getRegion(0).front().getArgument(0)) {
161  return op->emitOpError()
162  << "expected predicate to apply to the surrounding structured op";
163  }
164  return success();
165 }
166 
167 //===----------------------------------------------------------------------===//
168 // MatchStructuredBodyOp
169 //===----------------------------------------------------------------------===//
170 
171 DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation(
172  Operation *current, transform::TransformResults &results,
173  transform::TransformState &state) {
174  auto linalgOp = cast<linalg::LinalgOp>(current);
175  if (std::optional<uint64_t> position = getReductionPosition()) {
176  SmallVector<Operation *> combinerOps;
177  if (!matchReduction(linalgOp.getRegionOutputArgs(), *position,
178  combinerOps)) {
179  return emitSilenceableError() << "could not match reduction";
180  }
181  if (combinerOps.size() != 1) {
182  return emitSilenceableError() << "reduction combiner is not a single op";
183  }
185  }
186  if (getPassthrough()) {
187  Block &body = linalgOp->getRegion(0).front();
188  if (body.getTerminator()->getOperands() != linalgOp.getRegionInputArgs()) {
189  return emitSilenceableError() << "not a passthrough";
190  }
192  }
193  if (getElementwise()) {
194  if (!isElementwise(linalgOp))
195  return emitSilenceableError() << "not elementwise";
197  }
198  if (std::optional<ArrayAttr> contractionOps = getContraction()) {
199  Block &body = linalgOp->getRegion(0).front();
200  std::string message;
201  llvm::raw_string_ostream os(message);
202  bool result = linalg::detail::isContractionBody(
203  body,
204  [&](Operation *elem, Operation *red) {
205  return elem->getName().getStringRef() ==
206  cast<StringAttr>((*contractionOps)[0]).getValue() &&
207  red->getName().getStringRef() ==
208  cast<StringAttr>((*contractionOps)[1]).getValue();
209  },
210  os);
211  if (result)
213  return emitSilenceableError() << "contraction: " << message;
214  }
215  return emitDefiniteFailure() << "unknown body condition";
216 }
217 
219  int64_t numOptions = getReductionPosition().has_value() + getPassthrough() +
220  getElementwise() + getContraction().has_value();
221 
222  if (numOptions > 1) {
223  StringAttr attributeNames[] = {
224  getReductionPositionAttrName(), getPassthroughAttrName(),
225  getElementwiseAttrName(), getContractionAttrName()};
226  return emitOpError() << "only one of {" << llvm::interleaved(attributeNames)
227  << "} is allowed";
228  }
229 
230  if (std::optional<ArrayAttr> contractionAttr = getContraction()) {
231  if (contractionAttr->size() != 2) {
232  return emitOpError() << "expects " << getContractionAttrName()
233  << " to contain two elements";
234  }
235  }
236  return success();
237 }
238 
239 //===----------------------------------------------------------------------===//
240 // MatchStructuredClassifyContractionDimsOp
241 //===----------------------------------------------------------------------===//
242 
244 transform::MatchStructuredClassifyContractionDimsOp::matchOperation(
245  Operation *current, transform::TransformResults &results,
246  transform::TransformState &state) {
247  FailureOr<linalg::ContractionDimensions> contractionDims =
248  linalg::inferContractionDims(cast<linalg::LinalgOp>(current));
249  if (failed(contractionDims))
250  return emitSilenceableError() << "could not infer contraction dimensions";
251 
252  MLIRContext *context = current->getContext();
253  Builder builder(context);
254  auto makeI64Attrs = [&](ArrayRef<unsigned> values) {
255  return llvm::to_vector(
256  llvm::map_range(values, [&](unsigned value) -> Attribute {
257  return builder.getI64IntegerAttr(value);
258  }));
259  };
260  results.setParams(cast<OpResult>(getBatch()),
261  makeI64Attrs(contractionDims->batch));
262  results.setParams(cast<OpResult>(getM()), makeI64Attrs(contractionDims->m));
263  results.setParams(cast<OpResult>(getN()), makeI64Attrs(contractionDims->n));
264  results.setParams(cast<OpResult>(getK()), makeI64Attrs(contractionDims->k));
266 }
267 
268 //===----------------------------------------------------------------------===//
269 // MatchStructuredClassifyConvolutionDimsOp
270 //===----------------------------------------------------------------------===//
271 
273 transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
274  Operation *current, transform::TransformResults &results,
275  transform::TransformState &state) {
276  FailureOr<linalg::ConvolutionDimensions> convolutionDims =
277  linalg::inferConvolutionDims(cast<linalg::LinalgOp>(current));
278  if (failed(convolutionDims))
279  return emitSilenceableError() << "could not infer convolution dimensions";
280 
281  MLIRContext *context = current->getContext();
282  Builder builder(context);
283  auto makeI64Attrs = [&](ArrayRef<unsigned> values) {
284  return llvm::to_vector(
285  llvm::map_range(values, [&](unsigned value) -> Attribute {
286  return builder.getI64IntegerAttr(value);
287  }));
288  };
289  results.setParams(cast<OpResult>(getBatch()),
290  makeI64Attrs(convolutionDims->batch));
291  results.setParams(cast<OpResult>(getOutputImage()),
292  makeI64Attrs(convolutionDims->outputImage));
293  results.setParams(cast<OpResult>(getOutputChannel()),
294  makeI64Attrs(convolutionDims->outputChannel));
295  results.setParams(cast<OpResult>(getFilterLoop()),
296  makeI64Attrs(convolutionDims->filterLoop));
297  results.setParams(cast<OpResult>(getInputChannel()),
298  makeI64Attrs(convolutionDims->inputChannel));
299  results.setParams(cast<OpResult>(getDepth()),
300  makeI64Attrs(convolutionDims->depth));
301 
302  auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
303  return llvm::to_vector(
304  llvm::map_range(values, [&](int64_t value) -> Attribute {
305  return builder.getI64IntegerAttr(value);
306  }));
307  };
308  results.setParams(cast<OpResult>(getStrides()),
309  makeI64AttrsFromI64(convolutionDims->strides));
310  results.setParams(cast<OpResult>(getDilations()),
311  makeI64AttrsFromI64(convolutionDims->dilations));
313 }
314 
315 //===----------------------------------------------------------------------===//
316 // Utilities for structured match predicates.
317 //===----------------------------------------------------------------------===//
318 
319 /// Checks if all values from `list` are also contained in `reference`. Returns
320 /// a silenceable error with the given message at the given location when it is
321 /// not the case. The error message must contain the "{0}" placeholder that
322 /// will be substituted with the value from `list` that is not contained in
323 /// `reference`.
325  ArrayRef<int64_t> list,
326  Location loc,
327  const char *message) {
328  for (int64_t value : list) {
329  if (llvm::any_of(reference, [&](unsigned ref) {
330  return static_cast<int64_t>(ref) == value;
331  })) {
332  continue;
333  }
334  return emitSilenceableFailure(loc) << llvm::formatv(message, value);
335  }
337 }
338 
339 //===----------------------------------------------------------------------===//
340 // MatchStructuredDimOp
341 //===----------------------------------------------------------------------===//
342 
343 DiagnosedSilenceableFailure transform::MatchStructuredDimOp::matchOperation(
344  Operation *current, transform::TransformResults &results,
345  transform::TransformState &state) {
346  auto linalgOp = cast<linalg::LinalgOp>(current);
347  SmallVector<int64_t> dimensions;
348  DiagnosedSilenceableFailure diag = getDimensionsFor(linalgOp, dimensions);
349  if (!diag.succeeded())
350  return diag;
351 
352  // If asked to check for the kind of dimension, perform the check.
353  if (getParallel() || getReduction()) {
354  SmallVector<unsigned> reference;
355  if (getParallel())
356  linalgOp.getParallelDims(reference);
357  else if (getReduction())
358  linalgOp.getReductionDims(reference);
359 
361  containsAll(reference, dimensions, getLoc(),
362  getParallel() ? "expects dimension #{0} to be parallel"
363  : "expects dimension #{0} to be reduction");
364  if (!diag.succeeded())
365  return diag;
366  }
367 
368  // If not capturing, we are done here.
369  if (!getResult())
370  return diag;
371 
372  SmallVector<int64_t, 4> ranges = linalgOp.getStaticLoopRanges();
373  Builder builder(current);
374  SmallVector<Attribute> captured = llvm::to_vector(
375  llvm::map_range(dimensions, [&](int64_t dim) -> Attribute {
376  return builder.getI64IntegerAttr(ranges[dim]);
377  }));
378  results.setParams(cast<OpResult>(getResult()), captured);
380 }
381 
382 DiagnosedSilenceableFailure transform::MatchStructuredDimOp::getDimensionsFor(
383  linalg::LinalgOp op, SmallVectorImpl<int64_t> &dims) {
385  expandTargetSpecification(getLoc(), getIsAll(), getIsInverted(),
386  getRawDimList(), op.getNumLoops(), dims);
387  if (diag.isSilenceableFailure()) {
388  diag.attachNote(op->getLoc())
389  << "while considering dimensions of this payload operation";
390  }
391  return diag;
392 }
393 
395  if (getParallel() && getReduction()) {
396  return emitOpError() << "cannot request the same dimension to be both "
397  "parallel and reduction";
398  }
399  return verifyTransformMatchDimsOp(getOperation(), getRawDimList(),
400  getIsInverted(), getIsAll());
401 }
402 
403 //===----------------------------------------------------------------------===//
404 // MatchStructuredElementalBitwidthOp
405 //===----------------------------------------------------------------------===//
406 
408 transform::MatchStructuredElementalBitwidthOp::matchValue(
409  Value current, transform::TransformResults &results,
410  transform::TransformState &state) {
411  auto setupResult = [&](int64_t bitwidth) {
412  Attribute attr = Builder(current.getContext()).getI64IntegerAttr(bitwidth);
413  results.setParams(cast<OpResult>(getResult()), {attr});
415  };
416 
417  Type type = current.getType();
418  if (type.isIntOrFloat())
419  return setupResult(type.getIntOrFloatBitWidth());
420 
421  if (auto shapedType = dyn_cast<ShapedType>(type)) {
422  if (shapedType.getElementType().isIntOrFloat())
423  return setupResult(shapedType.getElementTypeBitWidth());
424  }
425  return emitSilenceableError()
426  << "unsupported type for bitwidth extraction: " << type;
427 }
428 
429 //===----------------------------------------------------------------------===//
430 // MatchStructuredInputOp
431 //===----------------------------------------------------------------------===//
432 
433 DiagnosedSilenceableFailure transform::MatchStructuredInputOp::matchOperation(
434  Operation *current, transform::TransformResults &results,
435  transform::TransformState &state) {
436  auto linalgOp = cast<linalg::LinalgOp>(current);
437  SmallVector<int64_t> positions;
438  DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions);
439  if (!diag.succeeded())
440  return diag;
441 
442  SmallVector<MappedValue> operandMapping;
443  operandMapping.reserve(positions.size());
444  for (int64_t position : positions) {
445  AffineMap indexingMap =
446  linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(position));
447  if (getPermutation() && !indexingMap.isPermutation()) {
448  return emitSilenceableError() << "the indexing map for input #"
449  << position << " is not a permutation";
450  }
451  if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
452  return emitSilenceableError()
453  << "the indexing map for input #" << position
454  << " is not a projected permutation";
455  }
456 
457  // If capture not requested, skip it.
458  if (!getResult())
459  continue;
460 
461  if (isa<AffineMapParamType>(getResult().getType())) {
462  operandMapping.emplace_back(AffineMapAttr::get(indexingMap));
463  continue;
464  }
465 
466  Value operand = linalgOp.getDpsInputOperand(position)->get();
467  if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
468  operandMapping.emplace_back(operand);
469  continue;
470  }
471 
472  Operation *operandProducer = operand.getDefiningOp();
473  if (!operandProducer) {
474  return emitSilenceableError()
475  << "input #" << position << " is not produced by an operation";
476  }
477  operandMapping.emplace_back(operandProducer);
478  }
479  if (getResult())
480  results.setMappedValues(cast<OpResult>(getResult()), operandMapping);
482 }
483 
484 DiagnosedSilenceableFailure transform::MatchStructuredInputOp::getPositionsFor(
485  linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
487  getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
488  op.getNumDpsInputs(), positions);
489  if (diag.isSilenceableFailure()) {
490  diag.attachNote(op->getLoc())
491  << "while considering DPS inputs of this payload operation";
492  }
493  return diag;
494 }
495 
496 /// Verifies a matcher op for structured input or output, specifically the
497 /// attributes specifying the operand positions.
498 template <typename OpTy>
499 LogicalResult verifyStructuredOperandOp(OpTy op) {
500  if (op.getPermutation() && op.getProjectedPermutation()) {
501  return op.emitOpError()
502  << op.getPermutationAttrName() << " and "
503  << op.getProjectedPermutationAttrName() << " are mutually exclusive";
504  }
505  if (op.getRawPositionList().size() > 1 && op.getResult()) {
506  return op.emitOpError()
507  << "cannot bind multiple inputs/inits to the same value";
508  }
509 
510  return success();
511 }
512 
514  if (failed(verifyStructuredOperandOp(*this)))
515  return failure();
516  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
517  getIsInverted(), getIsAll());
518 }
519 
520 //===----------------------------------------------------------------------===//
521 // MatchStructuredInitOp
522 //===----------------------------------------------------------------------===//
523 
524 DiagnosedSilenceableFailure transform::MatchStructuredInitOp::matchOperation(
525  Operation *current, transform::TransformResults &results,
526  transform::TransformState &state) {
527  auto linalgOp = cast<linalg::LinalgOp>(current);
528  SmallVector<int64_t> positions;
529  DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions);
530  if (!diag.succeeded())
531  return diag;
532 
533  SmallVector<MappedValue> operandMapping;
534  operandMapping.reserve(positions.size());
535  for (int64_t position : positions) {
536  AffineMap indexingMap =
537  linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(position));
538  if (getPermutation() && !indexingMap.isPermutation()) {
539  return emitSilenceableError() << "the indexing map for output(init) #"
540  << position << " is not a permutation";
541  }
542  if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
543  return emitSilenceableError() << "the indexing map for output(init) #"
544  << position << " is not a permutation";
545  }
546 
547  // If capture not requested, skip it.
548  if (!getResult())
549  continue;
550 
551  if (isa<AffineMapParamType>(getResult().getType())) {
552  operandMapping.emplace_back(AffineMapAttr::get(indexingMap));
553  continue;
554  }
555 
556  Value operand = linalgOp.getDpsInitOperand(position)->get();
557  if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
558  operandMapping.emplace_back(operand);
559  continue;
560  }
561 
562  Operation *operandProducer = operand.getDefiningOp();
563  if (!operandProducer) {
564  return emitSilenceableError() << "output(init) #" << position
565  << " is not produced by an operation";
566  }
567  operandMapping.emplace_back(operandProducer);
568  }
569  if (getResult())
570  results.setMappedValues(cast<OpResult>(getResult()), operandMapping);
572 }
573 
574 DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor(
575  linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
577  getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
578  op.getNumDpsInits(), positions);
579  if (diag.isSilenceableFailure()) {
580  diag.attachNote(op->getLoc())
581  << "while considering DPS inits (outputs) of this payload operation";
582  }
583  return diag;
584 }
585 
587  if (failed(verifyStructuredOperandOp(*this)))
588  return failure();
589  return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
590  getIsInverted(), getIsAll());
591 }
592 
593 //===----------------------------------------------------------------------===//
594 // MatchStructuredNumInputsOp
595 //===----------------------------------------------------------------------===//
596 
598 transform::MatchStructuredNumInputsOp::matchOperation(
599  Operation *current, transform::TransformResults &results,
600  transform::TransformState &state) {
601  auto linalgOp = cast<linalg::LinalgOp>(current);
602  Attribute attr =
603  Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInputs());
604  results.setParams(cast<OpResult>(getResult()), {attr});
606 }
607 
608 //===----------------------------------------------------------------------===//
609 // MatchStructuredNumInitsOp
610 //===----------------------------------------------------------------------===//
611 
613 transform::MatchStructuredNumInitsOp::matchOperation(
614  Operation *current, transform::TransformResults &results,
615  transform::TransformState &state) {
616  auto linalgOp = cast<linalg::LinalgOp>(current);
617  Attribute attr =
618  Builder(current).getI64IntegerAttr(linalgOp.getNumDpsInits());
619  results.setParams(cast<OpResult>(getResult()), {attr});
621 }
622 
623 //===----------------------------------------------------------------------===//
624 // MatchStructuredRankOp
625 //===----------------------------------------------------------------------===//
626 
627 DiagnosedSilenceableFailure transform::MatchStructuredRankOp::matchOperation(
628  Operation *current, transform::TransformResults &results,
629  transform::TransformState &state) {
630  auto linalgOp = cast<linalg::LinalgOp>(current);
631  int64_t numLoops = linalgOp.getNumLoops();
632  Attribute attr = Builder(linalgOp->getContext()).getI64IntegerAttr(numLoops);
633  results.setParams(cast<OpResult>(getRank()), {attr});
635 }
636 
637 //===----------------------------------------------------------------------===//
638 // MatchStructuredResultOp
639 //===----------------------------------------------------------------------===//
640 
641 DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation(
643  transform::TransformState &state) {
644  auto linalgOp = cast<linalg::LinalgOp>(op);
645  int64_t position;
646  DiagnosedSilenceableFailure diag = getPositionFor(linalgOp, position);
647  if (!diag.succeeded())
648  return diag;
649 
650  Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position));
651  if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
652  results.setValues(cast<OpResult>(getResult()), {result});
654  }
655 
656  if (result.getUsers().empty()) {
657  return emitSilenceableError()
658  << "no users of the result #" << getPosition();
659  }
660  Operation *firstUser = *result.getUsers().begin();
661  if (getAny()) {
662  results.set(cast<OpResult>(getResult()), {firstUser});
664  }
665  if (getSingle()) {
666  if (!llvm::hasSingleElement(result.getUsers())) {
667  return emitSilenceableError()
668  << "more than one result user with single user requested";
669  }
670  results.set(cast<OpResult>(getResult()), {firstUser});
672  }
673 
674  return emitDefiniteFailure() << "unknown sub-predicate";
675 }
676 
678 transform::MatchStructuredResultOp::getPositionFor(linalg::LinalgOp op,
679  int64_t &position) {
680  auto rawPosition = static_cast<int64_t>(getPosition());
681  position = rawPosition < 0 ? op.getNumDpsInits() + rawPosition : rawPosition;
682  if (position >= op.getNumDpsInits() || position < 0) {
683  return emitSilenceableError()
684  << "position " << rawPosition
685  << " overflows the number of results(ints) of the payload operation";
686  }
688 }
689 
691  if ((getAny() || getSingle()) ^
692  isa<TransformHandleTypeInterface>(getResult().getType())) {
693  return emitOpError() << "expects either the any/single keyword or the type "
694  "value handle result type";
695  }
696  if (getAny() && getSingle()) {
697  return emitOpError() << "'any' and 'single' are mutually exclusive";
698  }
699  return success();
700 }
701 
702 //===----------------------------------------------------------------------===//
703 // MatchStructuredYieldOp
704 //===----------------------------------------------------------------------===//
705 
706 void transform::MatchStructuredYieldOp::getEffects(
708  onlyReadsHandle(getHandlesMutable(), effects);
709  onlyReadsPayload(effects);
710 }
711 
712 void transform::MatchStructuredYieldOp::build(OpBuilder &builder,
713  OperationState &state) {
714  build(builder, state, ValueRange());
715 }
716 
717 #define GET_OP_CLASSES
718 #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"
static DiagnosedSilenceableFailure containsAll(ArrayRef< unsigned > reference, ArrayRef< int64_t > list, Location loc, const char *message)
Checks if all values from list are also contained in reference.
LogicalResult verifyStructuredOperandOp(OpTy op)
Verifies a matcher op for structured input or output, specifically the attributes specifying the oper...
#define DBGS()
static std::string diag(const llvm::Value &value)
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:618
bool isPermutation() const
Returns true if the AffineMap represents a symbol-less permutation map.
Definition: AffineMap.cpp:648
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
Operation & front()
Definition: Block.h:153
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:108
The result of a transform IR operation application.
static DiagnosedSilenceableFailure success()
Constructs a DiagnosedSilenceableFailure in the success state.
static DiagnosedSilenceableFailure definiteFailure()
Constructs a DiagnosedSilenceableFailure in the failure state.
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
This class helps build Operations.
Definition: Builders.h:205
This class represents an operand of an operation.
Definition: Value.h:243
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool isBeforeInBlock(Operation *other)
Given an operation 'other' that is within the same parent block, return whether the current operation...
Definition: Operation.cpp:386
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition: Operation.h:674
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
Block * getBlock()
Returns the operation block that contains this operation.
Definition: Operation.h:213
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition: Operation.h:687
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition: Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition: Value.h:108
Type getType() const
Return the type of this value.
Definition: Value.h:105
user_range getUsers() const
Definition: Value.h:204
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Local mapping between values defined by a specific op implementing the TransformOpInterface and the p...
void setValues(OpResult handle, Range &&values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
void setParams(OpResult value, ArrayRef< TransformState::Param > params)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void set(OpResult value, Range &&ops)
Indicates that the result of the transform IR op at the given position corresponds to the given list ...
void setRemainingToEmpty(TransformOpInterface transform)
Sets the currently unset results to empty lists of the kind expected by the corresponding results of ...
void setMappedValues(OpResult handle, ArrayRef< MappedValue > values)
Indicates that the result of the transform IR op at the given position corresponds to the given range...
The state maintained across applications of various ops implementing the TransformOpInterface.
bool isContractionBody(Block &block, function_ref< bool(Operation *, Operation *)> isaPair, llvm::raw_ostream &errs=mlir::thread_safe_nulls())
Returns true if the block contains a contraction of the following form:
FailureOr< ConvolutionDimensions > inferConvolutionDims(LinalgOp linalgOp)
Find at least 1 parallel (output_image) and reduction (filter_loop) dimension candidates that form a ...
bool isElementwise(LinalgOp op)
Check if a LinalgOp is an element-wise operation.
Definition: Utils.cpp:223
FailureOr< ContractionDimensions > inferContractionDims(LinalgOp linalgOp)
Find at least 2 parallel (m and n) and 1 reduction (k) dimension candidates that form a matmul subcom...
uint64_t getN(LevelType lt)
Definition: Enums.h:442
uint64_t getM(LevelType lt)
Definition: Enums.h:443
void forwardTerminatorOperands(Block *block, transform::TransformState &state, transform::TransformResults &results)
Populates results with payload associations that match exactly those of the operands to block's termi...
LogicalResult verifyStructuredOpPredicateOpTrait(Operation *op, Value structuredOpHandle)
void prepareValueMappings(SmallVectorImpl< SmallVector< transform::MappedValue >> &mappings, ValueRange values, const transform::TransformState &state)
Populates mappings with mapped values associated with the given transform IR values in the given stat...
LogicalResult verifyTransformMatchDimsOp(Operation *op, ArrayRef< int64_t > raw, bool inverted, bool all)
Checks if the positional specification defined is valid and reports errors otherwise.
void onlyReadsPayload(SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
DiagnosedSilenceableFailure expandTargetSpecification(Location loc, bool isAll, bool isInverted, ArrayRef< int64_t > rawList, int64_t maxNumber, SmallVectorImpl< int64_t > &result)
Populates result with the positional identifiers relative to maxNumber.
void producesHandle(ResultRange handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
void onlyReadsHandle(MutableArrayRef< OpOperand > handles, SmallVectorImpl< MemoryEffects::EffectInstance > &effects)
llvm::PointerUnion< Operation *, Param, Value > MappedValue
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
DiagnosedSilenceableFailure emitSilenceableFailure(Location loc, const Twine &message={})
Emits a silenceable failure with the given message.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
DiagnosedDefiniteFailure emitDefiniteFailure(Location loc, const Twine &message={})
Emits a definite failure with the given message.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:424
This represents an operation in an abstracted form, suitable for use with the builder APIs.