MLIR  20.0.0git
LinalgOps.cpp
Go to the documentation of this file.
1 //===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Linalg operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
29 #include "mlir/IR/AffineMap.h"
32 #include "mlir/IR/Matchers.h"
35 #include "mlir/IR/PatternMatch.h"
37 
38 #include "llvm/ADT/DenseMap.h"
39 #include "llvm/ADT/SmallSet.h"
40 #include "llvm/ADT/StringSet.h"
41 #include "llvm/ADT/TypeSwitch.h"
42 #include "llvm/Support/FormatVariadic.h"
43 #include "llvm/Support/MathExtras.h"
44 #include "llvm/Support/raw_ostream.h"
45 #include <optional>
46 
47 using namespace mlir;
48 using namespace mlir::linalg;
49 
50 /// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
52  int64_t dim) {
53  auto type = cast<ShapedType>(v.getType());
54  if (!type.isDynamicDim(dim))
55  return builder.getIndexAttr(type.getDimSize(dim));
56 
57  return getAsOpFoldResult(
59  .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
60  return builder.create<tensor::DimOp>(loc, v, dim);
61  })
62  .Case<MemRefType>([&](MemRefType t) -> Value {
63  return builder.create<memref::DimOp>(loc, v, dim);
64  }));
65 }
66 
67 /// Returns a memref.subview or a tensor.extract_slice based on the type of the
68 /// `source`.
69 static Value getSlice(OpBuilder &b, Location loc, Value source,
70  ArrayRef<OpFoldResult> offsets,
72  ArrayRef<OpFoldResult> strides) {
73  return TypeSwitch<Type, Value>(source.getType())
74  .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
75  return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
76  strides);
77  })
78  .Case<MemRefType>([&](MemRefType type) -> Value {
79  return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
80  strides);
81  })
82  .Default([&](Type t) { return nullptr; });
83 }
84 
85 //===----------------------------------------------------------------------===//
86 // Helper functions
87 //===----------------------------------------------------------------------===//
88 
90  int64_t dim) {
91  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
92  return b.createOrFold<memref::DimOp>(loc, source, dim);
93  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
94  return b.createOrFold<tensor::DimOp>(loc, source, dim);
95  llvm_unreachable("Expected MemRefType or TensorType");
96 }
97 
99  int64_t dim) {
100  auto shapedType = llvm::cast<ShapedType>(source.getType());
101  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
102  return createOrFoldDimOp(b, loc, source, dim);
103  return b.getIndexAttr(shapedType.getDimSize(dim));
104 }
105 
106 //===----------------------------------------------------------------------===//
107 // Support for named Linalg ops defined in ods-gen.
108 //===----------------------------------------------------------------------===//
109 
112 
113 /// Fills the region of a structured operation using the provided
114 /// `regionBuilder`. The method is used by both named structured ops created by
115 /// ods-gen and by manually defined C++ ops. It is called by both builders and
116 /// parsers and creates a block with arguments corresponding to the elemental
117 /// types of `inputTypes` and `outputTypes`. All output types are asserted to be
118 /// ShapedType.
119 static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
120  TypeRange inputTypes, TypeRange outputTypes,
122  RegionBuilderFn regionBuilder) {
123  assert(llvm::all_of(outputTypes, llvm::IsaPred<ShapedType>));
124 
125  SmallVector<Type, 8> argTypes;
126  SmallVector<Location, 8> argLocs;
127  for (auto containers : {inputTypes, outputTypes}) {
128  for (auto t : containers) {
129  argTypes.push_back(
130  isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
131 
132  // TODO: Pass in a proper location here.
133  argLocs.push_back(opBuilder.getUnknownLoc());
134  }
135  }
136 
137  // RAII.
138  OpBuilder::InsertionGuard guard(opBuilder);
139  Block *body =
140  opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);
141 
142  opBuilder.setInsertionPointToStart(body);
143  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
144  regionBuilder(b, *body, attrs);
145 
146  // indexing_maps is an auto-generated method.
147 
148  // iterator_types is an auto-generated method.
149 }
150 
151 /// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
152 /// The result types are derived automatically if `resultTensorTypes` is none.
153 /// The body of the operation is filled using `regionBuilder`. All ods-gen
154 /// created structured operations use the method to implement their builders.
156  std::optional<TypeRange> resultTensorTypes,
157  ValueRange inputs, ValueRange outputs,
158  ArrayRef<NamedAttribute> attributes,
159  RegionBuilderFn regionBuilder) {
160  // Derive the result types if needed.
161  SmallVector<Type> derivedResultTypes =
162  resultTensorTypes.value_or(TypeRange());
163  if (!resultTensorTypes)
164  copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
165  llvm::IsaPred<RankedTensorType>);
166 
167  state.addOperands(inputs);
168  state.addOperands(outputs);
169  state.addTypes(derivedResultTypes);
170  state.addAttributes(attributes);
171  state.addAttribute(
172  "operandSegmentSizes",
173  b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
174  static_cast<int32_t>(outputs.size())}));
175 
176  // Create and fill the region of the structured operation.
177  Region &region = *state.addRegion();
178  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
179  state.attributes.getAttrs(), regionBuilder);
180 }
181 
182 /// Common parsing used for both named structured ops created by ods-gen and by
183 /// manually defined C++ ops. Does not handle regions.
184 static ParseResult
186  SmallVectorImpl<Type> &inputTypes,
187  SmallVectorImpl<Type> &outputTypes,
188  bool addOperandSegmentSizes = true) {
189  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
191  outputsOperands;
192 
193  if (succeeded(parser.parseOptionalLess())) {
194  if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
195  return failure();
196  }
197  attrsLoc = parser.getCurrentLocation();
198  if (parser.parseOptionalAttrDict(result.attributes))
199  return failure();
200 
201  if (succeeded(parser.parseOptionalKeyword("ins"))) {
202  if (parser.parseLParen())
203  return failure();
204 
205  inputsOperandsLoc = parser.getCurrentLocation();
206  if (parser.parseOperandList(inputsOperands) ||
207  parser.parseColonTypeList(inputTypes) || parser.parseRParen())
208  return failure();
209  }
210 
211  if (succeeded(parser.parseOptionalKeyword("outs"))) {
212  outputsOperandsLoc = parser.getCurrentLocation();
213  if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
214  parser.parseColonTypeList(outputTypes) || parser.parseRParen())
215  return failure();
216  }
217 
218  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
219  result.operands) ||
220  parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
221  result.operands))
222  return failure();
223 
224  if (addOperandSegmentSizes) {
225  // This is a bit complex because we're trying to be backward compatible with
226  // operation syntax that mix the inherent attributes and the discardable
227  // ones in the same dictionary. If the properties are used, we append the
228  // operandSegmentSizes there directly. Otherwise we append it to the
229  // discardable attributes dictionary where it is handled by the generic
230  // Operation::create(...) method.
231  if (result.propertiesAttr) {
232  NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
233  attrs.append("operandSegmentSizes",
235  {static_cast<int32_t>(inputsOperands.size()),
236  static_cast<int32_t>(outputsOperands.size())}));
237  result.propertiesAttr = attrs.getDictionary(parser.getContext());
238  } else {
239  result.addAttribute("operandSegmentSizes",
241  {static_cast<int32_t>(inputsOperands.size()),
242  static_cast<int32_t>(outputsOperands.size())}));
243  }
244  }
245  if (!result.propertiesAttr) {
246  std::optional<RegisteredOperationName> info =
247  result.name.getRegisteredInfo();
248  if (info) {
249  if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
250  return parser.emitError(attrsLoc)
251  << "'" << result.name.getStringRef() << "' op ";
252  })))
253  return failure();
254  }
255  }
256  return success();
257 }
258 
260  ValueRange outputs) {
261  if (!inputs.empty())
262  p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
263  if (!outputs.empty())
264  p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
265 }
266 
267 //===----------------------------------------------------------------------===//
268 // Specific parsing and printing for named structured ops created by ods-gen.
269 //===----------------------------------------------------------------------===//
270 
271 static ParseResult parseNamedStructuredOpRegion(
272  OpAsmParser &parser, Region &region, unsigned numRegionArgs,
273  TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
274  RegionBuilderFn regionBuilder) {
275  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
276  return parser.emitError(
277  parser.getCurrentLocation(),
278  llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
279  "region expects {0} args, got {1}",
280  numRegionArgs, inputTypes.size() + outputTypes.size()));
281  }
282 
283  OpBuilder opBuilder(parser.getContext());
284  fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
285  regionBuilder);
286  return success();
287 }
288 
289 static ParseResult
291  SmallVectorImpl<Type> &resultTypes) {
292  if (parser.parseOptionalArrowTypeList(resultTypes))
293  return failure();
294  return success();
295 }
296 
297 static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
298  OperationState &result,
299  unsigned numRegionArgs,
300  RegionBuilderFn regionBuilder) {
301  // TODO: Enable when ods-gen supports captures.
302  SmallVector<Type, 1> inputTypes, outputTypes;
303  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
304  return failure();
305 
306  // TODO: consider merging results parsing into region parsing.
307  // Need to wait for declarative assembly resolution to decide.
308  SmallVector<Type, 1> outputTensorsTypes;
309  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
310  return failure();
311  result.addTypes(outputTensorsTypes);
312 
313  std::unique_ptr<Region> region = std::make_unique<Region>();
314  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
315  outputTypes, result.attributes.getAttrs(),
316  regionBuilder))
317  return failure();
318  result.addRegion(std::move(region));
319 
320  return success();
321 }
322 
324  TypeRange resultTypes) {
325  if (resultTypes.empty())
326  return;
327  p.printOptionalArrowTypeList(resultTypes);
328 }
329 
331  ValueRange inputs, ValueRange outputs) {
333  op->getAttrs(),
334  /*elidedAttrs=*/{"operandSegmentSizes",
335  // See generated code in
336  // LinalgNamedStructuredOps.yamlgen.cpp.inc
337  "linalg.memoized_indexing_maps"});
338 
339  // Printing is shared with generic ops, except for the region and
340  // attributes.
341  printCommonStructuredOpParts(p, inputs, outputs);
342 
343  // Results printing.
345 
346  // Region is elided.
347 }
348 
349 //===----------------------------------------------------------------------===//
350 // Region builder helper.
351 // TODO: Move this to a utility library.
352 // The public methods on this class are referenced directly from generated code.
353 // Helper build the unary, binary, and type conversion functions defined by the
354 // DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that uses this
355 // class.
356 //
357 // Implementations of the math functions must be polymorphic over numeric types,
358 // internally performing necessary casts. If the function application makes no
359 // sense, then the only recourse is to assert and return nullptr. This can be
360 // extended later if it becomes possible to fail construction of the region. The
361 // invariant should be enforced at a higher level.
362 //
363 // TODO: These helpers are currently type polymorphic over the class of integer
364 // and floating point types, but they will not internally cast within bit
365 // widths of a class (mixed precision such as i8->i32) or across classes
366 // (i.e. mixed float and integer). Many such combinations are ambiguous or need
367 // to be handled with care and work is being considered to extend the op
368 // language to make such cases explicit. In the mean-time, violating this will
369 // fail verification, which is deemed acceptable.
370 //===----------------------------------------------------------------------===//
371 
372 namespace {
373 
374 class RegionBuilderHelper {
375 public:
376  RegionBuilderHelper(OpBuilder &builder, Block &block)
377  : builder(builder), block(block) {}
378 
379  // Build the unary functions defined by OpDSL.
380  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
381  if (!isFloatingPoint(arg))
382  llvm_unreachable("unsupported non numeric type");
383  OpBuilder::InsertionGuard g(builder);
384  builder.setInsertionPointToEnd(&block);
385  switch (unaryFn) {
386  case UnaryFn::exp:
387  return builder.create<math::ExpOp>(arg.getLoc(), arg);
388  case UnaryFn::log:
389  return builder.create<math::LogOp>(arg.getLoc(), arg);
390  case UnaryFn::abs:
391  return builder.create<math::AbsFOp>(arg.getLoc(), arg);
392  case UnaryFn::ceil:
393  return builder.create<math::CeilOp>(arg.getLoc(), arg);
394  case UnaryFn::floor:
395  return builder.create<math::FloorOp>(arg.getLoc(), arg);
396  case UnaryFn::negf:
397  return builder.create<arith::NegFOp>(arg.getLoc(), arg);
398  case UnaryFn::reciprocal: {
399  Attribute oneAttr = builder.getOneAttr(arg.getType());
400  auto one = builder.create<arith::ConstantOp>(arg.getLoc(),
401  ::cast<TypedAttr>(oneAttr));
402  return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
403  }
404  case UnaryFn::round:
405  return builder.create<math::RoundOp>(arg.getLoc(), arg);
406  case UnaryFn::sqrt:
407  return builder.create<math::SqrtOp>(arg.getLoc(), arg);
408  case UnaryFn::rsqrt:
409  return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
410  case UnaryFn::square:
411  return builder.create<arith::MulFOp>(arg.getLoc(), arg, arg);
412  case UnaryFn::tanh:
413  return builder.create<math::TanhOp>(arg.getLoc(), arg);
414  case UnaryFn::erf:
415  return builder.create<math::ErfOp>(arg.getLoc(), arg);
416  }
417  llvm_unreachable("unsupported unary function");
418  }
419 
420  // Build the binary functions defined by OpDSL.
421  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
422  bool allComplex = isComplex(arg0) && isComplex(arg1);
423  bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
424  bool allInteger = isInteger(arg0) && isInteger(arg1);
425  bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
426  arg1.getType().getIntOrFloatBitWidth() == 1;
427  if (!allComplex && !allFloatingPoint && !allInteger)
428  llvm_unreachable("unsupported non numeric type");
429  OpBuilder::InsertionGuard g(builder);
430  builder.setInsertionPointToEnd(&block);
431  switch (binaryFn) {
432  case BinaryFn::add:
433  if (allComplex)
434  return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
435  if (allFloatingPoint)
436  return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
437  if (allBool)
438  return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
439  return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
440  case BinaryFn::sub:
441  if (allComplex)
442  return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
443  if (allFloatingPoint)
444  return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
445  if (allBool)
446  llvm_unreachable("unsupported operation: sub with bools");
447  return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
448  case BinaryFn::mul:
449  if (allComplex)
450  return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
451  if (allFloatingPoint)
452  return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
453  if (allBool)
454  return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
455  return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
456  case BinaryFn::div:
457  if (allComplex)
458  return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
459  if (allFloatingPoint)
460  return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
461  if (allBool)
462  llvm_unreachable("unsupported operation: div with bools");
463  return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
464  case BinaryFn::div_unsigned:
465  if (!allInteger || allBool)
466  llvm_unreachable("unsupported operation: unsigned div not on uint");
467  return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
468  case BinaryFn::max_signed:
469  assert(!allComplex);
470  if (allFloatingPoint)
471  return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
472  return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
473  case BinaryFn::min_signed:
474  assert(!allComplex);
475  if (allFloatingPoint)
476  return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
477  return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
478  case BinaryFn::max_unsigned:
479  assert(!allComplex);
480  if (allFloatingPoint)
481  return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
482  return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
483  case BinaryFn::min_unsigned:
484  assert(!allComplex);
485  if (allFloatingPoint)
486  return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
487  return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
488  case BinaryFn::powf:
489  assert(allFloatingPoint);
490  return builder.create<math::PowFOp>(arg0.getLoc(), arg0, arg1);
491  }
492  llvm_unreachable("unsupported binary function");
493  }
494 
495  // Build the ternary functions defined by OpDSL.
496  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1,
497  Value arg2) {
498  bool headBool =
499  isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
500  bool tailFloatingPoint =
501  isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
502  bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg1);
503  OpBuilder::InsertionGuard g(builder);
504  builder.setInsertionPointToEnd(&block);
505  switch (ternaryFn) {
506  case TernaryFn::select:
507  if (!headBool && !(tailFloatingPoint || tailInteger))
508  llvm_unreachable("unsupported non numeric type");
509  return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
510  }
511  llvm_unreachable("unsupported ternary function");
512  }
513 
514  // Build the type functions defined by OpDSL.
515  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
516  switch (typeFn) {
517  case TypeFn::cast_signed:
518  return cast(toType, operand, false);
519  case TypeFn::cast_unsigned:
520  return cast(toType, operand, true);
521  }
522  llvm_unreachable("unsupported type conversion function");
523  }
524 
525  void yieldOutputs(ValueRange values) {
526  OpBuilder::InsertionGuard g(builder);
527  builder.setInsertionPointToEnd(&block);
528  Location loc = builder.getUnknownLoc();
529  builder.create<YieldOp>(loc, values);
530  }
531 
532  Value constant(const std::string &value) {
533  OpBuilder::InsertionGuard g(builder);
534  builder.setInsertionPointToEnd(&block);
535  Location loc = builder.getUnknownLoc();
536  Attribute valueAttr = parseAttribute(value, builder.getContext());
537  return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
538  }
539 
540  Value index(int64_t dim) {
541  OpBuilder::InsertionGuard g(builder);
542  builder.setInsertionPointToEnd(&block);
543  return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
544  }
545 
546  Type getIntegerType(unsigned width) {
547  return IntegerType::get(builder.getContext(), width);
548  }
549 
550  Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
551  Type getFloat64Type() { return Float64Type::get(builder.getContext()); }
552 
553 private:
554  // Generates operations to cast the given operand to a specified type.
555  // If the cast cannot be performed, a warning will be issued and the
556  // operand returned as-is (which will presumably yield a verification
557  // issue downstream).
558  Value cast(Type toType, Value operand, bool isUnsignedCast) {
559  OpBuilder::InsertionGuard g(builder);
560  builder.setInsertionPointToEnd(&block);
561  auto loc = operand.getLoc();
562  return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
563  }
564 
565  bool isComplex(Value value) {
566  return llvm::isa<ComplexType>(value.getType());
567  }
568  bool isFloatingPoint(Value value) {
569  return llvm::isa<FloatType>(value.getType());
570  }
571  bool isInteger(Value value) {
572  return llvm::isa<IntegerType>(value.getType());
573  }
574 
575  OpBuilder &builder;
576  Block &block;
577 };
578 
579 } // namespace
580 
581 //===----------------------------------------------------------------------===//
582 // CopyOp
583 //===----------------------------------------------------------------------===//
584 
585 namespace {
586 
587 struct EraseSelfCopy : OpRewritePattern<CopyOp> {
589  LogicalResult matchAndRewrite(CopyOp copyOp,
590  PatternRewriter &rewriter) const override {
591  if (copyOp.getInputs() != copyOp.getOutputs())
592  return rewriter.notifyMatchFailure(copyOp, "not a self copy");
593  if (copyOp.hasPureBufferSemantics())
594  rewriter.eraseOp(copyOp);
595  else
596  rewriter.replaceOp(copyOp, copyOp.getInputs());
597 
598  return success();
599  }
600 };
601 
602 } // namespace
603 
/// Registers CopyOp canonicalizations: currently only self-copy erasure.
void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseSelfCopy>(context);
}
608 
609 //===----------------------------------------------------------------------===//
610 // FillOp
611 //===----------------------------------------------------------------------===//
612 
613 namespace {
614 
615 /// Fold linalg.fill -> tensor.expand/collapse_shape chain.
616 ///
617 /// For such op chains, we can create new linalg.fill ops with the result
618 /// type of the tensor.expand/collapse_shape op.
619 template <typename TensorReshapeOp>
620 struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
622  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
623  PatternRewriter &rewriter) const override {
624  auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
625  if (!oldFill)
626  return failure();
627 
628  Location loc = oldFill.getLoc();
629  TensorReshapeOp newInit;
630  if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
631 
632  newInit = rewriter.create<TensorReshapeOp>(
633  loc, reshapeOp.getResultType(), oldFill.output(),
634  reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
635  reshapeOp.getStaticOutputShape());
636  } else {
637  newInit = rewriter.create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
638  oldFill.output(),
639  reshapeOp.getReassociation());
640  }
641  rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
642  ValueRange{newInit});
643  return success();
644  }
645 };
646 
647 /// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
648 /// filling value are the same.
649 struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
651 
652  LogicalResult matchAndRewrite(tensor::PadOp padOp,
653  PatternRewriter &rewriter) const override {
654  auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
655  if (!fillOp)
656  return failure();
657 
658  // We can only fold if the padding value is the same as the original
659  // filling value.
660  Value padValue = padOp.getConstantPaddingValue();
661  if (!padValue || fillOp.value() != padValue)
662  return failure();
663 
664  ReifiedRankedShapedTypeDims reifiedShape;
665  if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
666  return rewriter.notifyMatchFailure(
667  padOp, "failed to reify tensor.pad op result shape");
668 
669  auto emptyTensor = rewriter.create<tensor::EmptyOp>(
670  padOp.getLoc(), reifiedShape.front(),
671  padOp.getResultType().getElementType());
672  Value replacement =
673  rewriter
674  .create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
675  ValueRange{emptyTensor})
676  .getResult(0);
677  if (replacement.getType() != padOp.getResultType()) {
678  replacement = rewriter.create<tensor::CastOp>(
679  fillOp.getLoc(), padOp.getResultType(), replacement);
680  }
681  rewriter.replaceOp(padOp, replacement);
682  return success();
683  }
684 };
685 
686 /// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
687 /// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
688 /// filling value are the same.
689 struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
691 
692  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
693  PatternRewriter &rewriter) const override {
694  auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
695  if (!srcPadOp)
696  return failure();
697 
698  if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
699  return failure();
700 
701  // Walk back the tensor.insert_slice chain and find the first destination
702  // value at the start of the chain.
703  Value firstDest = insertOp.getDest();
704  while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
705  if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
706  return failure();
707 
708  // Make sure the range of values accessed are disjoint. Without this, we
709  // cannot fold tensor.pad away.
710  bool disjoint = false;
711  for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
712  // If the dimension has dynamic offset/size, we cannot guarantee
713  // disjoint. So just skip it.
714  if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
715  insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
716  prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
717  continue;
718 
719  // Get the range start and end, inclusively for both.
720  int64_t prevStart = prevOp.getStaticOffset(i);
721  int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
722  prevOp.getStaticStride(i);
723  int64_t nextStart = insertOp.getStaticOffset(i);
724  int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
725  insertOp.getStaticStride(i);
726  if (prevEnd < nextStart || nextEnd < prevStart) {
727  disjoint = true;
728  break;
729  }
730  }
731 
732  if (!disjoint)
733  break;
734  firstDest = prevOp.getDest();
735  }
736 
737  // Check whether the first destination is a fill op. For overlapped cases,
738  // this also cannot be true.
739  auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
740  if (!dstFillOp)
741  return failure();
742 
743  // We can only fold if the padding value is the same as the original
744  // filling value.
745  Value padValue = srcPadOp.getConstantPaddingValue();
746  if (!padValue || dstFillOp.value() != padValue)
747  return failure();
748 
749  SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
750  SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();
751 
752  Location loc = insertOp.getLoc();
753  MLIRContext *context = getContext();
754 
755  AffineExpr sym0, sym1;
756  bindSymbols(context, sym0, sym1);
757  auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);
758 
759  // Calculate the new offsets for the insert. It should be the old offsets
760  // plus low padding sizes.
761  SmallVector<OpFoldResult, 4> newOffsets;
762  for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
763  newOffsets.push_back(affine::makeComposedFoldedAffineApply(
764  rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
765  }
766 
767  RankedTensorType srcPadType = srcPadOp.getSourceType();
769  for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
770  if (srcPadType.isDynamicDim(i)) {
771  newSizes.push_back(
772  rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
773  .getResult());
774  } else {
775  newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
776  }
777  }
778 
779  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
780  insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
781  newSizes, insertOp.getMixedStrides());
782  return success();
783  }
784 };
785 
786 /// Fold tensor.extract(linalg.fill(<input>)) into <input>
787 struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
788 public:
790 
791  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
792  PatternRewriter &rewriter) const override {
793  // See if tensor input of tensor.extract op is the result of a linalg.fill
794  // op.
795  auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
796  if (!fillOp)
797  return failure();
798 
799  // Get scalar input operand of linalg.fill op.
800  Value extractedScalar = fillOp.getInputs()[0];
801 
802  // Replace tensor.extract op with scalar value used to fill the tensor.
803  rewriter.replaceOp(extractOp, extractedScalar);
804  return success();
805  }
806 };
807 
808 /// Folds pack(fill) into a single fill op if
809 /// 1. The pack op does not have padding value, or
810 /// 2. The filled value and padding value are the same.
811 static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
812  tensor::PackOp packOp) {
813  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
814  if (!fillOp)
815  return failure();
816 
817  if (auto paddingValue = packOp.getPaddingValue())
818  if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
819  return failure();
820 
821  Value packOpDest = packOp.getDest();
822  if (!packOpDest.hasOneUse())
823  return failure();
824 
825  return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
826  packOp.getDest());
827 }
828 
829 /// Wrapper pattern that applies foldFillPackIntoFillOp method.
830 struct FoldFillWithPack : public OpRewritePattern<tensor::PackOp> {
831 public:
832  FoldFillWithPack(MLIRContext *context)
833  : OpRewritePattern<tensor::PackOp>(context) {}
834 
835  LogicalResult matchAndRewrite(tensor::PackOp packOp,
836  PatternRewriter &rewriter) const override {
837  auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
838  if (failed(fillOp))
839  return failure();
840  rewriter.replaceOp(packOp, fillOp.value().result());
841  return success();
842  }
843 };
844 
/// Fold fill with copy.
struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
  // NOTE(review): an inherited-constructor using-declaration appears to have
  // been elided from this excerpt; confirm against the original source.

  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    // copy(fill(v), out) ==> fill(v, out): copying a filled tensor is the
    // same as filling the copy destination directly.
    if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
                                          fillOp.getInputs(),
                                          copyOp.getOutputs());
      return success();
    }
    // copy(in, fill(v)) ==> copy(in, fill's init): the copy fully overwrites
    // the fill result, so the fill is dead and the copy can target its init.
    if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
                                                  fillOp.getOutputs());
      return success();
    }
    return failure();
  }
};
865 
/// Fold fill with transpose.
struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
  // NOTE(review): an inherited-constructor using-declaration appears to have
  // been elided from this excerpt; confirm against the original source.

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    // transpose(fill(v)) ==> fill(v): transposing a constant-filled tensor
    // yields an identically filled tensor, so fill the transpose init.
    if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(
          transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
          transposeOp.getDpsInitOperand(0)->get());
      return success();
    }
    return failure();
  }
};
881 
882 } // namespace
883 
884 void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
885  MLIRContext *context) {
886  results
887  .add<FoldFillWithCopy, FoldFillWithTensorExtract, FoldFillWithPack,
888  FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
889  FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
890  FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
891 }
892 
893 //===----------------------------------------------------------------------===//
894 // GenericOp
895 //===----------------------------------------------------------------------===//
896 
897 static void buildGenericRegion(
898  OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
899  ValueRange outputs,
900  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
901  SmallVector<Type, 4> blockArgTypes;
902  SmallVector<Location, 4> blockArgLocs;
903  for (ValueRange container : {inputs, outputs}) {
904  for (Value v : container) {
905  Type t = v.getType();
906  blockArgTypes.push_back(
907  isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
908  blockArgLocs.push_back(v.getLoc());
909  }
910  }
911 
912  OpBuilder::InsertionGuard guard(builder);
913  Block *bodyBlock =
914  builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
915  bodyBuild(builder, loc, bodyBlock->getArguments());
916 }
917 
918 void GenericOp::getAsmBlockArgumentNames(Region &region,
919  OpAsmSetValueNameFn setNameFn) {
920  for (Value v : getRegionInputArgs())
921  setNameFn(v, "in");
922  for (Value v : getRegionOutputArgs())
923  setNameFn(v, "out");
924 }
925 
926 void GenericOp::build(
927  OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
928  ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
929  ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
930  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
931  ArrayRef<NamedAttribute> attributes) {
932  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
933  iteratorTypes, doc, libraryCall);
934  result.addAttributes(attributes);
935  if (bodyBuild)
936  buildGenericRegion(builder, result.location, *result.regions.front(),
937  inputs, outputs, bodyBuild);
938 }
939 
940 void GenericOp::build(
941  OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
942  ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
943  ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
944  StringRef libraryCall,
945  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
946  ArrayRef<NamedAttribute> attributes) {
947  build(builder, result, resultTensorTypes, inputs, outputs,
948  builder.getAffineMapArrayAttr(indexingMaps),
949  builder.getArrayAttr(llvm::to_vector(llvm::map_range(
950  iteratorTypes,
951  [&](utils::IteratorType iter) -> mlir::Attribute {
952  return IteratorTypeAttr::get(builder.getContext(), iter);
953  }))),
954  doc.empty() ? StringAttr() : builder.getStringAttr(doc),
955  libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
956  bodyBuild, attributes);
957 }
958 
959 void GenericOp::build(
960  OpBuilder &builder, OperationState &result, ValueRange inputs,
961  ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
962  ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
963  StringRef libraryCall,
964  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
965  ArrayRef<NamedAttribute> attributes) {
966  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
967  iteratorTypes, doc, libraryCall, bodyBuild, attributes);
968 }
969 
970 void GenericOp::build(
971  OpBuilder &builder, OperationState &result, ValueRange inputs,
972  ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
973  ArrayRef<utils::IteratorType> iteratorTypes,
974  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
975  ArrayRef<NamedAttribute> attributes) {
976  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
977  /*doc=*/"",
978  /*libraryCall=*/"", bodyBuild, attributes);
979 }
980 
981 void GenericOp::build(
982  OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
983  ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
984  ArrayRef<utils::IteratorType> iteratorTypes,
985  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
986  ArrayRef<NamedAttribute> attributes) {
987  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
988  iteratorTypes,
989  /*doc=*/"",
990  /*libraryCall=*/"", bodyBuild, attributes);
991 }
992 
// NOTE(review): the `void GenericOp::print(OpAsmPrinter &p) {` signature line
// appears to have been elided from this excerpt; the body below assumes it.
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames =
          llvm::to_vector(llvm::map_range(
              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
                return StringAttr::get(getContext(), stringifyIteratorType(t));
              }));

      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
          ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      // Keep only the core linalg trait attributes in the leading dictionary.
      genericAttrs.push_back(attr);
    }
  }
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());

  // operandSegmentSizes is implied by the ins/outs syntax; elide it as well.
  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  // Only print an `attrs = {...}` clause when non-trait attributes remain.
  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, getResultTensors().getTypes());
}
1055 
ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits that must check into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  // Flatten the parsed dictionary into the op's attribute list.
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert array of string into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  // Each string entry must symbolize to a known iterator type.
  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of its outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}
1121 
// NOTE(review): the opening of the signature (presumably
// `static void getGenericEffectsImpl(` with the effects-vector parameter
// type) and the trailing `SideEffects::DefaultResource::get());` argument
// lines appear to have been elided from this excerpt.
        &effects,
    LinalgOp linalgOp) {
  // Memory effects only apply to memref operands; tensor operands are
  // value-semantic and carry no effects.
  for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    // Inputs are only read.
    effects.emplace_back(
        MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
        /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
  }

  for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
      continue;
    // An init whose value is used inside the payload is read as well.
    if (linalgOp.payloadUsesValueFromOperand(&operand)) {
      effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
                           /*effectOnFullRegion=*/true,
    }
    // Every memref init is written.
    effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
                         /*effectOnFullRegion=*/true,
  }
}
1147 
void GenericOp::getEffects(
    // NOTE(review): the effects-vector parameter type (presumably
    // `SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>`)
    // appears to have been elided from this excerpt.
    &effects) {
  // Delegate to the shared linalg memory-effect computation.
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
1153 
1154 LogicalResult GenericOp::verify() { return success(); }
1155 
1156 namespace {
1157 
/// Remove any linalg operation (on tensors) that are just copying
/// the values from inputs to the results. Requirements are
/// 1) All iterator types are parallel
/// 2) The body contains just a yield operation with the yielded values being
/// the arguments corresponding to the operands.
template <typename OpTy>
struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
  // NOTE(review): an inherited-constructor using-declaration appears to have
  // been elided from this excerpt; confirm against the original source.

  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    // Check all indexing maps are identity.
    if (llvm::any_of(linalgOp.getIndexingMapsArray(),
                     [](AffineMap map) { return !map.isIdentity(); }))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // In the buffer case, we need to check exact buffer equality.
    if (linalgOp.hasPureBufferSemantics()) {
      if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
          linalgOp.getDpsInputOperand(0)->get() ==
              linalgOp.getDpsInitOperand(0)->get()) {
        rewriter.eraseOp(linalgOp);
        return success();
      }
      return failure();
    }

    // Mixed semantics is not supported yet.
    if (!linalgOp.hasPureTensorSemantics())
      return failure();

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = linalgOp->getOperand(argumentNumber);
      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // input dimension can be turned into a static output dimension.
      Type returnType = returnedArg.getType();
      if (returnType != resultType) {
        // Distinguish between sparse conversion or dense tensor casting.
        // TODO: unify the two ops?
        if (sparse_tensor::getSparseTensorEncoding(returnType) ||
            // NOTE(review): the second half of this condition (presumably the
            // sparse-encoding check on `resultType`) appears elided here.
          returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        else {
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 resultType))
            return failure();
          returnedArg = rewriter.create<tensor::CastOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        }
      }
      returnedArgs.push_back(returnedArg);
    }

    if (returnedArgs.size() != linalgOp->getNumResults())
      return failure();
    rewriter.replaceOp(linalgOp, returnedArgs);
    return success();
  }
};
1235 
1236 } // namespace
1237 
1238 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
1239  MLIRContext *context) {
1240  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
1241 }
1242 
1243 LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
1244  return memref::foldMemRefCast(*this);
1245 }
1246 
1247 //===----------------------------------------------------------------------===//
1248 // MapOp
1249 //===----------------------------------------------------------------------===//
1250 
1251 static ParseResult parseDstStyleOp(
1252  OpAsmParser &parser, OperationState &result,
1253  function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
1254  nullptr) {
1255  // Parse `ins` and `outs`.
1256  SmallVector<Type, 4> inputTypes, outputTypes;
1257  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
1258  /*addOperandSegmentSizes=*/false))
1259  return failure();
1260 
1261  // Add result types.
1262  for (Type outputType : outputTypes) {
1263  if (llvm::isa<RankedTensorType>(outputType))
1264  result.addTypes(outputType);
1265  }
1266 
1267  // Parse required attributes.
1268  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
1269  return failure();
1270 
1271  // Parse optional attributes.
1272  if (parser.parseOptionalAttrDict(result.attributes))
1273  return failure();
1274  return success();
1275 }
1276 
1277 void MapOp::getAsmBlockArgumentNames(Region &region,
1278  OpAsmSetValueNameFn setNameFn) {
1279  for (Value v : getRegionInputArgs())
1280  setNameFn(v, "in");
1281 }
1282 
1283 void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1284  if (!getResults().empty())
1285  setNameFn(getResults().front(), "mapped");
1286 }
1287 
1288 void MapOp::build(
1289  OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
1290  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1291  ArrayRef<NamedAttribute> attributes) {
1292  build(builder, result, TypeRange{}, inputs, init);
1293  result.addAttributes(attributes);
1294 
1295  // Add output types for `RankedTensorType` output arguments.
1296  Type initType = init.getType();
1297  if (llvm::isa<RankedTensorType>(initType))
1298  result.addTypes(initType);
1299 
1300  if (bodyBuild)
1301  buildGenericRegion(builder, result.location, *result.regions.front(),
1302  inputs, /*outputs=*/{}, bodyBuild);
1303 }
1304 
1305 static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
1306  const OperationName &payloadOpName,
1307  const NamedAttrList &payloadOpAttrs,
1308  ArrayRef<Value> operands,
1309  bool initFirst = false) {
1310  OpBuilder b(parser.getContext());
1311  Region *body = result.addRegion();
1312  Block &block = body->emplaceBlock();
1313  b.setInsertionPointToStart(&block);
1314  SmallVector<Value> bbArgs;
1315  for (auto &operand : operands) {
1316  block.addArgument(
1317  llvm::cast<ShapedType>(operand.getType()).getElementType(),
1318  b.getUnknownLoc());
1319  }
1320  SmallVector<Value> payloadOpOperands;
1321  // If initFirst flag is enabled, we consider init as the first position of
1322  // payload operands.
1323  if (initFirst) {
1324  payloadOpOperands.push_back(block.getArguments().back());
1325  for (const auto &arg : block.getArguments().drop_back())
1326  payloadOpOperands.push_back(arg);
1327  } else {
1328  payloadOpOperands = {block.getArguments().begin(),
1329  block.getArguments().end()};
1330  }
1331 
1332  Operation *payloadOp = b.create(
1333  result.location, b.getStringAttr(payloadOpName.getStringRef()),
1334  payloadOpOperands,
1335  TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1336  .getElementType()},
1337  payloadOpAttrs);
1338  b.create<YieldOp>(result.location, payloadOp->getResults());
1339 }
1340 
ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
  // Optional short form: `{ payload.op {attrs} }` before the ins/outs lists.
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  // Shared destination-style operand/type parsing.
  if (parseDstStyleOp(parser, result))
    return failure();

  if (payloadOpName.has_value()) {
    // Short form: synthesize the body from the payload op. The init (last
    // operand) does not feed the mapper, hence drop_back.
    if (!result.operands.empty())
      addBodyWithPayloadOp(parser, result, payloadOpName.value(),
                           payloadOpAttrs,
                           ArrayRef(result.operands).drop_back());
    else
      result.addRegion();
  } else {
    // Long form: parse the explicit region with its block arguments.
    // NOTE(review): a `SmallVector<OpAsmParser::Argument> regionArgs;`
    // declaration appears to have been elided from this excerpt.
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }
    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }
  return success();
}
1377 
1378 // Retrieve the operation from the body, if it is the only one (except
1379 // yield) and if it gets the same amount of arguments as the body does.
1380 // If initFirst flag is enabled, we check that init takes the first position in
1381 // operands of payload.
1382 static Operation *findPayloadOp(Block *body, bool initFirst = false) {
1383  if (body->getOperations().size() != 2)
1384  return nullptr;
1385  Operation &payload = body->getOperations().front();
1386  assert(isa<YieldOp>(body->getOperations().back()));
1387 
1388  if (payload.getNumOperands() == 0 ||
1389  payload.getNumOperands() != body->getNumArguments())
1390  return nullptr;
1391  if (initFirst) {
1392  // check init
1393  if (payload.getOperands().back() != body->getArgument(0))
1394  return nullptr;
1395  // check rest
1396  for (const auto &[operand, bbArg] :
1397  llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
1398  if (bbArg != operand)
1399  return nullptr;
1400  }
1401  } else {
1402  for (const auto &[operand, bbArg] :
1403  llvm::zip(payload.getOperands(), body->getArguments())) {
1404  if (bbArg != operand)
1405  return nullptr;
1406  }
1407  }
1408  return &payload;
1409 }
1410 
1411 void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
1412  SmallVector<StringRef> elidedAttrs;
1413  std::string attrToElide;
1414  p << " { " << payloadOp->getName().getStringRef();
1415  for (const auto &attr : payloadOp->getAttrs()) {
1416  auto fastAttr =
1417  llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1418  if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1419  attrToElide = attr.getName().str();
1420  elidedAttrs.push_back(attrToElide);
1421  break;
1422  }
1423  }
1424  p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
1425  p << " }";
1426 }
1427 
1428 void MapOp::print(OpAsmPrinter &p) {
1429  Block *mapper = getBody();
1430  Operation *payloadOp = findPayloadOp(mapper);
1431  if (payloadOp) {
1432  printShortForm(p, payloadOp);
1433  }
1434 
1435  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1436  p.printOptionalAttrDict((*this)->getAttrs());
1437 
1438  if (!payloadOp) {
1439  // Print region if the payload op was not detected.
1440  p.increaseIndent();
1441  p.printNewline();
1442  p << "(";
1443  llvm::interleaveComma(mapper->getArguments(), p,
1444  [&](auto arg) { p.printRegionArgument(arg); });
1445  p << ") ";
1446 
1447  p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
1448  p.decreaseIndent();
1449  }
1450 }
1451 
LogicalResult MapOp::verify() {
  auto *bodyBlock = getBody();
  auto blockArgs = bodyBlock->getArguments();

  // Checks if the number of `inputs` match the arity of the `mapper` region.
  if (getInputs().size() != blockArgs.size())
    return emitOpError() << "expects number of operands to match the arity of "
                            "mapper, but got: "
                         << getInputs().size() << " and " << blockArgs.size();

  // The parameters of mapper should all match the element type of inputs.
  for (const auto &[bbArgType, inputArg] :
       llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
    auto inputElemType =
        llvm::cast<ShapedType>(inputArg.getType()).getElementType();
    if (bbArgType != inputElemType) {
      return emitOpError() << "expected element type of input " << inputElemType
                           << " to match bbArg type " << bbArgType;
    }
  }

  // The shape of each input must match the shape of the output: map is a
  // purely elementwise operation.
  auto outputShape = getInit().getType().getShape();
  for (Type inputArgType : TypeRange{getInputs()}) {
    auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
    if (inputElemShape != outputShape) {
      return emitOpError() << "expected shape of input (" << inputElemShape
                           << ") to match shape of output (" << outputShape
                           << ")";
    }
  }

  return success();
}
1486 
1487 SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1488  int64_t rank = getInit().getType().getRank();
1489  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1490 }
1491 
ArrayAttr MapOp::getIndexingMaps() {
  // Every operand (inputs and init) uses the identity map over the init's
  // rank, since map is elementwise.
  Builder builder(getContext());
  int64_t rank = getInit().getType().getRank();
  int64_t numIndexingMaps = getOperands().size();
  // NOTE(review): the line opening the return expression (presumably
  // `return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(`) appears
  // to have been elided from this excerpt.
      numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
}
1499 
void MapOp::getEffects(
    // NOTE(review): the effects-vector parameter type (presumably
    // `SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>`)
    // appears to have been elided from this excerpt.
    &effects) {
  // Delegate to the shared linalg memory-effect computation.
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
1505 
1506 //===----------------------------------------------------------------------===//
1507 // ReduceOp
1508 //===----------------------------------------------------------------------===//
1509 
1510 void ReduceOp::getAsmBlockArgumentNames(Region &region,
1511  OpAsmSetValueNameFn setNameFn) {
1512  for (Value v : getRegionInputArgs())
1513  setNameFn(v, "in");
1514  for (Value v : getRegionOutputArgs())
1515  setNameFn(v, "init");
1516 }
1517 
1518 void ReduceOp::getAsmResultNames(
1519  function_ref<void(Value, StringRef)> setNameFn) {
1520  if (!getResults().empty())
1521  setNameFn(getResults().front(), "reduced");
1522 }
1523 
1524 void ReduceOp::build(
1525  OpBuilder &builder, OperationState &result, ValueRange inputs,
1526  ValueRange inits, ArrayRef<int64_t> dimensions,
1527  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1528  ArrayRef<NamedAttribute> attributes) {
1529  build(builder, result, TypeRange{}, inputs, inits, dimensions);
1530  result.addAttributes(attributes);
1531 
1532  // Add output types for `RankedTensorType` output arguments.
1533  for (Value init : inits) {
1534  Type initType = init.getType();
1535  if (llvm::isa<RankedTensorType>(initType))
1536  result.addTypes(initType);
1537  }
1538 
1539  if (bodyBuild)
1540  buildGenericRegion(builder, result.location, *result.regions.front(),
1541  inputs, inits, bodyBuild);
1542 }
1543 
1544 SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1545  int64_t inputRank =
1546  llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1547  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1548  utils::IteratorType::parallel);
1549  for (int64_t reductionDim : getDimensions())
1550  iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1551  return iteratorTypes;
1552 }
1553 
ArrayAttr ReduceOp::getIndexingMaps() {
  int64_t inputRank =
      llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
  // Inputs use full identity maps; inits use the identity map with the
  // reduced dimensions dropped.
  SmallVector<AffineMap> affineMaps(
      getNumDpsInputs(),
  // NOTE(review): the identity-map construction lines (presumably
  // `AffineMap::getMultiDimIdentityMap(inputRank, getContext())`) appear to
  // have been elided from this excerpt in two places.
  AffineMap resultMap =
          .dropResults(getDimensions());
  for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
    affineMaps.push_back(resultMap);
  return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
}
1567 
void ReduceOp::getEffects(
    // NOTE(review): the effects-vector parameter type (presumably
    // `SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>`)
    // appears to have been elided from this excerpt.
    &effects) {
  // Delegate to the shared linalg memory-effect computation.
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}
1573 
1574 static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
1575  NamedAttrList &attributes,
1576  StringRef attributeName) {
1577  if (parser.parseKeyword(attributeName) || parser.parseEqual())
1578  return failure();
1579 
1580  attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
1581  return success();
1582 }
1583 
ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
  // Optional short form: `{ payload.op {attrs} }` before the ins/outs lists.
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  // Shared dst-style parsing plus the mandatory `dimensions = [...]`.
  if (parseDstStyleOp(
          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
          }))
    return failure();

  if (payloadOpName.has_value()) {
    // Reduce payloads take the init first, then the inputs.
    addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
                         ArrayRef(result.operands), /*initFirst=*/true);
  } else {
    // Long form: parse the explicit combiner region with its arguments.
    // NOTE(review): a `SmallVector<OpAsmParser::Argument> regionArgs;`
    // declaration appears to have been elided from this excerpt.
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }

    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }

  return success();
}
1621 
1622 static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
1623  ArrayRef<int64_t> attributeValue) {
1624  p << ' ' << attributeName << " = [" << attributeValue << "] ";
1625 }
1626 
1627 void ReduceOp::print(OpAsmPrinter &p) {
1628  Block *mapper = getBody();
1629  Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
1630  if (payloadOp) {
1631  printShortForm(p, payloadOp);
1632  }
1633 
1634  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1635  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
1636  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1637  if (!payloadOp) {
1638  // Print region if the payload op was not detected.
1639  p.increaseIndent();
1640  p.printNewline();
1641  p << "(";
1642  llvm::interleaveComma(mapper->getArguments(), p,
1643  [&](auto arg) { p.printRegionArgument(arg); });
1644  p << ") ";
1645 
1646  p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
1647  p.decreaseIndent();
1648  }
1649 }
1650 
LogicalResult ReduceOp::verify() {
  ArrayRef<int64_t> dimensionsRef = getDimensions();

  // All inputs are iterated jointly, so their shapes must agree.
  for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
    if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
      return emitOpError() << "expects all inputs to have the same shapes. "
                              "Shape at input-index "
                           << i
                           << " is not equal to the shape at input-index 0.";
    }
  }
  // Similarly, all outputs must agree on shape.
  for (int64_t i = 1; i < getNumDpsInits(); ++i) {
    if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
      return emitOpError() << "expects all outputs to have the same shapes. "
                              "Shape at output-index "
                           << i
                           << " is not equal to the shape at output-index 0.";
    }
  }
  auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
  auto initType = llvm::cast<ShapedType>(getInits()[0].getType());

  // Reduction dimensions must be valid indices into the input rank.
  DenseSet<int64_t> dimensionsToReduce;
  for (int64_t dimension : dimensionsRef) {
    if (dimension < 0 || dimension >= inputType.getRank()) {
      return emitOpError()
             << "dimensions for reduction should be in the range [0, "
             << inputType.getRank() - 1 << "].";
    }
    dimensionsToReduce.insert(dimension);
  }

  auto inputDims = inputType.getShape();
  auto initDims = initType.getShape();

  // Input dimensions that will be left after the reduction.
  SmallVector<int64_t> reducedInputDims;
  for (const auto &en : llvm::enumerate(inputDims)) {
    if (!dimensionsToReduce.count(en.index()))
      reducedInputDims.push_back(en.value());
  }

  // The init must have exactly the surviving (non-reduced) dimensions.
  if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
    return emitOpError() << "number of dimensions after reduction "
                         << reducedInputDims.size()
                         << " doesn't match the init rank "
                         << initType.getRank();
  }

  if (reducedInputDims != initDims)
    return emitOpError() << "init dimensions [" << initDims
                         << "] doesn't match input dimensions after reduction ["
                         << reducedInputDims << "]";

  // The combiner takes one block argument per input plus one per init.
  Block *block = getBody();
  if (block->getNumArguments() != this->getNumOperands())
    return emitOpError()
           << "mismatching number of operands and block arguments";

  // Check that the first block arguments match the element type of the inputs.
  for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
    Type inputElementType =
        llvm::cast<ShapedType>(input.getType()).getElementType();
    if (inputElementType != bbArg.getType())
      return emitOpError()
             << "input element type " << inputElementType
             << " does not match corresponding block argument type "
             << bbArg.getType();
  }

  // Check that the last block arguments match the element type of the outputs.
  for (auto [output, bbArg] : llvm::zip(
           getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
    auto outputElementType =
        llvm::cast<ShapedType>(output.getType()).getElementType();
    if (outputElementType != bbArg.getType())
      return emitOpError()
             << "output element type " << outputElementType
             << " does not match corresponding block argument type "
             << bbArg.getType();
  }
  return success();
}
1736 
1737 //===----------------------------------------------------------------------===//
1738 // TransposeOp
1739 //===----------------------------------------------------------------------===//
1740 
1741 static void buildIdentityRegion(OpBuilder &builder, Location loc,
1742  Region &region, ValueRange inputs,
1743  ValueRange outputs) {
1744  buildGenericRegion(builder, loc, region, inputs, outputs,
1745  [](OpBuilder &b, Location loc, ValueRange args) {
1746  if (!args.empty())
1747  b.create<linalg::YieldOp>(loc, args[0]);
1748  });
1749 }
1750 
1751 void TransposeOp::build(::mlir::OpBuilder &builder,
1752  ::mlir::OperationState &result, Value input, Value init,
1753  DenseI64ArrayAttr permutation,
1754  ArrayRef<NamedAttribute> attributes) {
1755  result.addOperands(input);
1756  result.addOperands(init);
1757  result.addAttribute(getPermutationAttrName(result.name), permutation);
1758  result.addAttributes(attributes);
1759 
1760  // Add output types for `RankedTensorType` output arguments.
1761  Type initType = init.getType();
1762  if (llvm::isa<RankedTensorType>(initType))
1763  result.addTypes(initType);
1764 
1765  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
1766  init);
1767 }
1768 
1769 void TransposeOp::build(::mlir::OpBuilder &builder,
1770  ::mlir::OperationState &result, Value input, Value init,
1771  ArrayRef<int64_t> permutation,
1772  ArrayRef<NamedAttribute> attributes) {
1773  build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
1774  attributes);
1775 }
1776 
1777 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
1778  if (failed(parseDstStyleOp(
1779  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1780  return parseDenseI64ArrayAttr(parser, attributes, "permutation");
1781  })))
1782  return failure();
1783 
1784  OpBuilder builder(parser.getContext());
1785  buildIdentityRegion(builder, result.location, *result.addRegion(),
1786  /*inputs=*/result.operands,
1787  /*outputs=*/{});
1788  return success();
1789 }
1790 
1791 void TransposeOp::getAsmResultNames(
1792  function_ref<void(Value, StringRef)> setNameFn) {
1793  if (!getResults().empty())
1794  setNameFn(getResults().front(), "transposed");
1795 }
1796 
1798  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1799  printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
1800  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
1801 }
1802 
1803 LogicalResult TransposeOp::verify() {
1804  ArrayRef<int64_t> permutationRef = getPermutation();
1805 
1806  if (!isPermutationVector(permutationRef))
1807  return emitOpError("permutation is not valid");
1808 
1809  auto inputType = getInput().getType();
1810  auto initType = getInit().getType();
1811 
1812  int64_t rank = inputType.getRank();
1813 
1814  if (rank != initType.getRank())
1815  return emitOpError() << "input rank " << rank
1816  << " does not match init rank " << initType.getRank();
1817 
1818  if (rank != static_cast<int64_t>(permutationRef.size()))
1819  return emitOpError() << "size of permutation " << permutationRef.size()
1820  << " does not match the argument rank " << rank;
1821 
1822  auto inputDims = inputType.getShape();
1823  auto initDims = initType.getShape();
1824 
1825  for (int64_t i = 0; i < rank; ++i) {
1826  int64_t inputDim = inputDims[permutationRef[i]];
1827  int64_t initDim = initDims[i];
1828 
1829  if (inputDim != initDim) {
1830  return emitOpError() << "dim(result, " << i << ") = " << initDim
1831  << " doesn't match dim(input, permutation[" << i
1832  << "]) = " << inputDim;
1833  }
1834  }
1835 
1836  return success();
1837 }
1838 
1839 SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
1840  int64_t rank = getInit().getType().getRank();
1841  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1842 }
1843 
1844 ArrayAttr TransposeOp::getIndexingMaps() {
1845  Builder builder(getContext());
1846  int64_t rank = getInit().getType().getRank();
1847  return builder.getAffineMapArrayAttr(
1849  llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
1850  builder.getMultiDimIdentityMap(rank)});
1851 }
1852 
1853 void TransposeOp::getEffects(
1855  &effects) {
1856  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1857 }
1858 
1859 LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
1861  // Single dimension transpose.
1862  if (getPermutation().size() == 0) {
1863  result.push_back(getInput());
1864  return success();
1865  }
1866  // Identity permutation.
1867  if (isIdentityPermutation(getPermutation())) {
1868  result.push_back(getInput());
1869  return success();
1870  }
1871 
1872  return failure();
1873 }
1874 
1875 /// Fold transpose with transpose.
1876 struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
1878 
1879  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
1880  PatternRewriter &rewriter) const override {
1881  auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
1882  if (!defTransposeOp)
1883  return failure();
1884  ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
1885  ArrayRef<int64_t> perms = transposeOp.getPermutation();
1886  SmallVector<int64_t> foldedPerms;
1887  foldedPerms.reserve(perms.size());
1888  for (int64_t perm : perms)
1889  foldedPerms.push_back(defPerms[perm]);
1890 
1891  rewriter.replaceOpWithNewOp<TransposeOp>(
1892  transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
1893  foldedPerms);
1894  return success();
1895  }
1896 };
1897 
1898 /// This pattern canonicalize transpose by swapping the order of
1899 /// broadcast and transpose:
1900 /// transpose(broadcast(input)) -> broadcast(transpose(input))
1901 struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
1903 
1904  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
1905  PatternRewriter &rewriter) const override {
1906  Value input = transposeOp.getInput();
1907  BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
1908  if (!input.hasOneUse() || !broadcastOp)
1909  return failure();
1910 
1911  ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
1912  ArrayRef<int64_t> perms = transposeOp.getPermutation();
1913 
1914  // Get new perms and new dimensions.
1915  SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
1916  SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
1917  SmallVector<int64_t> resultDimensions;
1918  unsigned dimensionSize = dimensions.size();
1919  for (unsigned i = 0; i < dimensionSize; ++i)
1920  resultDimensions.push_back(invertPerm[dimensions[i]]);
1921 
1922  // Create transpose result.
1923  Value broadcastInput = broadcastOp.getInput();
1924  Location loc = transposeOp.getLoc();
1925  MLIRContext *ctx = transposeOp.getContext();
1927  auto broadcastInputTy =
1928  mlir::cast<RankedTensorType>(broadcastInput.getType());
1929  unsigned inputRank = broadcastInputTy.getRank();
1930  for (unsigned i = 0; i < inputRank; ++i) {
1931  if (broadcastInputTy.isDynamicDim(i)) {
1932  dims.push_back(rewriter.create<tensor::DimOp>(loc, broadcastInput, i)
1933  ->getResult(0));
1934  } else {
1935  dims.push_back(IntegerAttr::get(IndexType::get(ctx),
1936  broadcastInputTy.getDimSize(i)));
1937  }
1938  }
1939  SmallVector<OpFoldResult> transposeResultShapes =
1940  applyPermutation(dims, resultPerms);
1941  Value transposeInit = rewriter.create<tensor::EmptyOp>(
1942  transposeOp.getLoc(), transposeResultShapes,
1943  broadcastInputTy.getElementType());
1944 
1945  // Create broadcast(transpose(input)).
1946  Value transposeResult =
1947  rewriter
1948  .create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
1949  resultPerms)
1950  ->getResult(0);
1951  rewriter.replaceOpWithNewOp<BroadcastOp>(
1952  transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
1953  return success();
1954  }
1955 };
1956 
1957 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1958  MLIRContext *context) {
1960 }
1961 
1962 //===----------------------------------------------------------------------===//
1963 // BroadcastOp
1964 //===----------------------------------------------------------------------===//
1965 
1966 void BroadcastOp::build(::mlir::OpBuilder &builder,
1967  ::mlir::OperationState &result, Value input, Value init,
1968  DenseI64ArrayAttr dimensions,
1969  ArrayRef<NamedAttribute> attributes) {
1970  result.addOperands(input);
1971  result.addOperands(init);
1972  result.addAttribute(getDimensionsAttrName(result.name), dimensions);
1973  result.addAttributes(attributes);
1974 
1975  // Add output types for `RankedTensorType` output arguments.
1976  Type initType = init.getType();
1977  if (llvm::isa<RankedTensorType>(initType))
1978  result.addTypes(initType);
1979 
1980  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
1981  init);
1982 }
1983 
1984 void BroadcastOp::build(::mlir::OpBuilder &builder,
1985  ::mlir::OperationState &result, Value input, Value init,
1986  ArrayRef<int64_t> dimensions,
1987  ArrayRef<NamedAttribute> attributes) {
1988  build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
1989  attributes);
1990 }
1991 
1992 ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
1993  if (failed(parseDstStyleOp(
1994  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1995  return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1996  })))
1997  return failure();
1998 
1999  OpBuilder builder(parser.getContext());
2000  buildIdentityRegion(builder, result.location, *result.addRegion(),
2001  /*inputs=*/result.operands,
2002  /*outputs=*/{});
2003  return success();
2004 }
2005 
2006 void BroadcastOp::getAsmResultNames(
2007  function_ref<void(Value, StringRef)> setNameFn) {
2008  if (!getResults().empty())
2009  setNameFn(getResults().front(), "broadcasted");
2010 }
2011 
2013  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2014  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
2015  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
2016 }
2017 
/// Verifies that `dimensions` is a valid set of added init dimensions and
/// that every non-broadcasted init dim matches the corresponding input dim.
LogicalResult BroadcastOp::verify() {
  ArrayRef<int64_t> dimensionsRef = getDimensions();

  auto inputType = getInput().getType();
  auto initType = getInit().getType();

  int64_t inputRank = inputType.getRank();
  int64_t initRank = initType.getRank();

  auto inputShape = inputType.getShape();
  auto initShape = initType.getShape();

  // Each entry in `dimensions` introduces exactly one new dimension, so the
  // ranks must balance.
  if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
    return emitOpError() << "input rank plus added dimensions does not "
                            "match init rank. input rank: "
                         << inputRank
                         << ", dimensions size: " << dimensionsRef.size()
                         << ", init rank: " << initRank;

  // Every broadcasted dimension must index a valid init dim.
  for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
    if (dim < 0 || dim >= initRank)
      return emitOpError() << "dimension " << idx
                           << " is out of range. expected range: [0, "
                           << initRank - 1 << "], got: " << dim;
  }

  // Mapping from input dims to init dims: the init dims NOT listed in
  // `dimensions`, in order.
  SmallVector<int64_t> dimMap;
  for (auto dim : llvm::seq<int64_t>(0, initRank)) {
    if (!llvm::is_contained(dimensionsRef, dim))
      dimMap.push_back(dim);
  }

  for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
    // This dimension is mapped from the input; init and input extents must
    // match.
    if (inputShape[inputDimIdx] != initShape[initDimIdx])
      return emitOpError() << "input dim " << inputDimIdx
                           << " should match init dim " << initDimIdx
                           << ". input: " << inputShape[inputDimIdx]
                           << ", init: " << initShape[initDimIdx];
  }

  return success();
}
2063 
2064 SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2065  int64_t rank = getInit().getType().getRank();
2066  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2067 }
2068 
2069 ArrayAttr BroadcastOp::getIndexingMaps() {
2070  Builder builder(getContext());
2071  int64_t rank = getInit().getType().getRank();
2072  return builder.getAffineMapArrayAttr(
2073  {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
2074  builder.getMultiDimIdentityMap(rank)});
2075 }
2076 
2077 void BroadcastOp::getEffects(
2079  &effects) {
2080  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2081 }
2082 
2083 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2084  MLIRContext *context) {
2085  results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
2086 }
2087 
2088 //===----------------------------------------------------------------------===//
2089 // YieldOp
2090 //===----------------------------------------------------------------------===//
2091 
2093  if (getNumOperands() > 0)
2094  p << ' ' << getOperands();
2095  p.printOptionalAttrDict((*this)->getAttrs());
2096  if (getNumOperands() > 0)
2097  p << " : " << getOperandTypes();
2098 }
2099 
2100 ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
2102  SmallVector<Type, 2> types;
2103  SMLoc loc = parser.getCurrentLocation();
2104  return failure(parser.parseOperandList(opInfo) ||
2105  parser.parseOptionalAttrDict(result.attributes) ||
2106  (!opInfo.empty() && parser.parseColonTypeList(types)) ||
2107  parser.resolveOperands(opInfo, types, loc, result.operands));
2108 }
2109 
// Check the operand number and types must match the element types of the
// LinalgOp interface's shaped operands.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
  // Exactly one yielded value per init/out operand of the enclosing op.
  if (op.getNumOperands() != linalgOp.getNumDpsInits())
    return op.emitOpError("expected number of yield values (")
           << op.getNumOperands()
           << ") to match the number of inits / outs operands of the enclosing "
           << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";

  for (OpOperand &opOperand : op->getOpOperands()) {
    // Pair each yielded value with the init operand at the same position.
    OpOperand *outputOperand =
        linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
    // Shaped inits compare against their element type; scalar inits compare
    // against the type itself.
    Type elementType = outputOperand->get().getType();
    if (isa<MemRefType, RankedTensorType>(elementType))
      elementType = getElementTypeOrSelf(outputOperand->get().getType());
    if (opOperand.get().getType() != elementType)
      return op.emitOpError("type of yield operand ")
             << (opOperand.getOperandNumber() + 1) << " ("
             << opOperand.get().getType() << ") doesn't match "
             << "the element type of the enclosing linalg.generic op ("
             << elementType << ")";
  }
  return success();
}
2134 
2135 LogicalResult linalg::YieldOp::verify() {
2136  auto *parentOp = (*this)->getParentOp();
2137  if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2138  return emitOpError("expected single non-empty parent region");
2139 
2140  if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2141  return verifyYield(*this, linalgOp);
2142 
2143  return emitOpError("expected parent op with LinalgOp interface");
2144 }
2145 
2146 //===----------------------------------------------------------------------===//
2147 // IndexOp
2148 //===----------------------------------------------------------------------===//
2149 
2150 LogicalResult IndexOp::verify() {
2151  auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2152  if (!linalgOp)
2153  return emitOpError("expected parent op with LinalgOp interface");
2154  if (linalgOp.getNumLoops() <= getDim())
2155  return emitOpError("expected dim (")
2156  << getDim() << ") to be lower than the number of loops ("
2157  << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
2158  return success();
2159 }
2160 
2161 /////// Operations corresponding to library calls defined with Tablegen ////////
2162 
2163 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2164 
2165 #define GET_OP_CLASSES
2166 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2167 
2168 #define GET_OP_CLASSES
2169 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2170 
2171 AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
2172  unsigned rank,
2173  MLIRContext *context) {
2174  if (maybeMap)
2175  return *maybeMap;
2176  if (rank == 0)
2177  return AffineMap::get(context);
2178  return AffineMap::getMultiDimIdentityMap(rank, context);
2179 }
2180 
2182 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
2183  MLIRContext *context) {
2185  res.reserve(num);
2186  for (unsigned i = 0; i < num; ++i)
2187  res.push_back(getAffineDimExpr(startIdx++, context));
2188  return res;
2189 }
2190 
2193  auto rangeA = llvm::make_range(a.begin(), a.end());
2194  auto rangeB = llvm::make_range(b.begin(), b.end());
2195  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2196  return llvm::to_vector<4>(concatRanges);
2197 }
2198 
/// Appends a mangled encoding of type `t` to `ss`:
///  - memref: "view" + per-dim <size or 's' for dynamic> joined by "x" + the
///    mangled element type, plus "as<N>" for an integer memory space (any
///    non-integer memory-space attribute makes the mangling fail);
///  - vector: "vector" + sizes joined by "x" + the mangled element type;
///  - signless int/index/float scalars: the type printed directly.
/// Any other type fails.
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
  if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
    ss << "view";
    for (auto size : memref.getShape())
      // Dynamic extents (negative sentinel) are mangled as "s".
      if (size < 0)
        ss << "sx";
      else
        ss << size << "x";
    // Recurse into the element type.
    if (failed(appendMangledType(ss, memref.getElementType())))
      return failure();
    if (auto as = memref.getMemorySpace()) {
      if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
        ss << "as" << attr.getInt();
      else
        return failure();
    }
    return success();
  }
  if (auto vec = llvm::dyn_cast<VectorType>(t)) {
    ss << "vector";
    llvm::interleave(
        vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
    if (failed(appendMangledType(ss, vec.getElementType())))
      return failure();
    return success();
  }
  if (t.isSignlessIntOrIndexOrFloat()) {
    ss << t;
    return success();
  }
  return failure();
}
2231 
2233  assert(isa<LinalgOp>(op));
2234  std::string name(op->getName().getStringRef().str());
2235  std::string fun = "";
2236  for (NamedAttribute kv : op->getAttrs()) {
2237  if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2238  fun = stringifyEnum(ufa.getValue()).str() + "_";
2239  } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2240  fun = stringifyEnum(bfa.getValue()).str() + "_";
2241  }
2242  }
2243  name.reserve(128);
2244  std::replace(name.begin(), name.end(), '.', '_');
2245  llvm::raw_string_ostream ss(name);
2246  ss << "_" << fun;
2247  for (Type t : op->getOperandTypes()) {
2248  if (failed(appendMangledType(ss, t)))
2249  return std::string();
2250  ss << "_";
2251  }
2252  std::string res = ss.str();
2253  res.pop_back();
2254  return res;
2255 }
2256 
2257 //===----------------------------------------------------------------------===//
2258 // Canonicalizers and Folders.
2259 //===----------------------------------------------------------------------===//
2260 
2261 namespace {
2262 struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
2264 
2265  LogicalResult matchAndRewrite(LinalgOp op,
2266  PatternRewriter &rewriter) const override {
2267  for (OpOperand &opOperand : op->getOpOperands()) {
2268  // Linalg "inputs" may be either tensor or memref type.
2269  // tensor<0xelt_type> is a convention that may not always mean
2270  // "0 iterations". Only erase in cases we see memref<...x0x...>.
2271  auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2272  if (!mt)
2273  continue;
2274  if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2275  rewriter.eraseOp(op);
2276  return success();
2277  }
2278  }
2279  return failure();
2280  }
2281 };
2282 
2283 /// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has
2284 /// result that is more static than the linalg op.
2285 struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
2287 
2288  LogicalResult matchAndRewrite(tensor::CastOp castOp,
2289  PatternRewriter &rewriter) const override {
2290  if (!tensor::canFoldIntoProducerOp(castOp))
2291  return failure();
2292 
2293  auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2294  if (!linalgOp)
2295  return failure();
2296 
2297  // Cast can be in conditionally reachable region, if which case folding will
2298  // generate invalid code. Only conservatively fold ops in same block for
2299  // now.
2300  if (castOp->getBlock() != linalgOp->getBlock())
2301  return failure();
2302 
2303  OpBuilder::InsertionGuard guard(rewriter);
2304  rewriter.setInsertionPoint(linalgOp);
2305 
2306  Location loc = linalgOp.getLoc();
2307  OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2308  unsigned resultNumber = resultValue.getResultNumber();
2309  auto resultType =
2310  llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2311  // Replace the `outs` for the result with a `tensor.cast`. This cast is now
2312  // going from a more dynamic shape to a less dynamic shape. If the producer
2313  // for this cast, i.e. producer of the out operand, is also an operation
2314  // that folds with tensor.cast consumer (like this pattern), the cast will
2315  // continue to propagate as far up the stack as it can go.
2316  OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2317  Value newOperand =
2318  rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
2319  SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2320  SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2321  linalgOp.getDpsInits().end());
2322  outputOperands[resultNumber] = newOperand;
2323  newOperands.append(outputOperands.begin(), outputOperands.end());
2324 
2325  SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2326  linalgOp->result_type_end());
2327  resultTypes[resultNumber] = resultType;
2328  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2329 
2330  // Create a tensor.cast operation back to the original type.
2331  Value castBack = rewriter.create<tensor::CastOp>(
2332  loc, resultValue.getType(), newOp->getResult(resultNumber));
2333 
2334  SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
2335  results[resultNumber] = castBack;
2336  rewriter.replaceOp(linalgOp, results);
2337  rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
2338  return success();
2339  }
2340 };
2341 
/// For each of the operand in `operands` this function maps the static sizes of
/// dimensions to their affine dim expressions.
static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
                        llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
  for (OpOperand &opOperand : operands) {
    // Scalars carry no shape information.
    if (linalgOp.isScalar(&opOperand))
      continue;
    Value src = opOperand.get();
    auto sourceType = llvm::cast<RankedTensorType>(src.getType());
    auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);

    // Get the `sourceShape` of the `sourceType`. If the operand is a result of
    // `tensor.cast` operation and source of the cast operation has a static
    // shape, then assign it to the `sourceShape`.
    auto *parentOp = src.getDefiningOp();
    ArrayRef<int64_t> sourceShape = sourceType.getShape();
    if (parentOp) {
      if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
        Value castSource = castOp.getSource();
        auto castSourceType =
            llvm::dyn_cast<RankedTensorType>(castSource.getType());
        if (castSourceType && castSourceType.hasStaticShape())
          sourceShape = castSourceType.getShape();
      }
    }

    // If the source shape's dimension has a static shape, map the affine dim
    // expression to the known static size.
    for (unsigned i = 0; i < sourceShape.size(); i++) {
      if (sourceType.isDynamicDim(i))
        continue;
      // Only plain dim expressions can be resolved this way; first insert
      // wins (try_emplace never overwrites an existing entry).
      if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
        affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
    }
  }
}
2378 
/// Creates new operand w.r.t 'opOperand' of `linalgOp` with static sizes
/// mapped in `affineExprToSize`. New operands are created in `newOperands` and
/// their result types is stored in `resultTypes`. If `opOperand` requires no
/// change then `changeNeeded` is false and same operand is added in the
/// `newOperands` list.
static void createNewOperandWithStaticSizes(
    Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
    llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
    SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
    bool &changeNeeded) {
  Value src = opOperand->get();
  // Tentatively keep the original operand; it is overwritten below if a
  // refining cast is needed.
  newOperands.push_back(src);
  if (linalgOp.isScalar(opOperand))
    return;
  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
  Type resultType = sourceType;
  // Fully-static init operands keep their type as the op's result type.
  if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
    resultTypes.push_back(resultType);
    return;
  }
  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
  SmallVector<int64_t> newShape;
  // If operand is updated with new shape, `newOperandNeeded` will be
  // true.
  bool newOperandNeeded = false;
  for (unsigned i = 0; i < sourceShape.size(); i++) {
    int64_t dimShape = sourceShape[i];
    AffineExpr dimExpr = sourceMap.getResult(i);
    if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
      newShape.push_back(dimShape);
      continue;
    }
    // Dimension has a dynamic shape and corresponding affine dim
    // expression is present in the map. So assign the size for the
    // given affine dim expression to the dimension.
    newShape.push_back(affineExprToSize[dimExpr]);
    newOperandNeeded = true;
  }
  resultType = RankedTensorType::get(newShape, sourceType.getElementType());
  if (newOperandNeeded) {
    changeNeeded = true;
    // Get the new operand value given its size and element type by
    // casting it.
    Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
    unsigned index = opOperand->getOperandNumber();
    newOperands[index] = newOperand;
  }
  // Init operands determine the op's result types.
  if (linalgOp.isDpsInit(opOperand))
    resultTypes.push_back(resultType);
}
2430 
2431 /// Static shapes for the operands can be inferred if any one of the operands
2432 /// have a static shape. This can be done by referring to the affine dim
2433 /// expressions for the operand.
2434 struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
2436 
2437  LogicalResult matchAndRewrite(LinalgOp linalgOp,
2438  PatternRewriter &rewriter) const override {
2439  if (!linalgOp.hasPureTensorSemantics())
2440  return failure();
2441 
2442  // Maps must be projected permutations.
2443  if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2444  return !map.isProjectedPermutation();
2445  }))
2446  return failure();
2447 
2448  // Maps affine dim expressions to the static size of that dimension.
2449  llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2450  Location loc = linalgOp.getLoc();
2451 
2452  // For each of the affine dim expression, check if the size is known. If
2453  // known add that in the map.
2454  populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2455 
2456  SmallVector<Value> newOperands;
2457  SmallVector<Type> resultTypes;
2458 
2459  // `changeNeeded` is `false` if the operands of `linalgOp` require no
2460  // change in their types.
2461  bool changeNeeded = false;
2462  newOperands.reserve(linalgOp->getNumOperands());
2463  resultTypes.reserve(linalgOp.getNumDpsInits());
2464 
2465  // Iterate over all the operands and update the static sizes.
2466  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2467  createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2468  affineExprToSize, linalgOp, newOperands,
2469  resultTypes, changeNeeded);
2470  }
2471 
2472  // If the generic op has all the required static information, no
2473  // canonicalization needed.
2474  if (!changeNeeded)
2475  return failure();
2476 
2477  // Clone op.
2478  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2479  SmallVector<Value> replacements;
2480  replacements.reserve(newOp->getNumResults());
2481  for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2482  Value newResult = std::get<1>(it);
2483  Value oldResult = std::get<0>(it);
2484  Type newType = newResult.getType();
2485  Type oldType = oldResult.getType();
2486  replacements.push_back(
2487  (newType != oldType)
2488  ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
2489  : newResult);
2490  }
2491  rewriter.replaceOp(linalgOp, replacements);
2492  return success();
2493  }
2494 };
2495 
2496 } // namespace
2497 
2498 // All named ops canonicalizers and folders are auto-generated in the
2499 // .cpp.inc.
2500 
2501 //===----------------------------------------------------------------------===//
2502 // SoftmaxOp
2503 //===----------------------------------------------------------------------===//
2504 
2505 LogicalResult SoftmaxOp::verify() {
2506  ShapedType inputType = getInputOperandType();
2507  ShapedType outputType = getOutputOperandType();
2508 
2509  ArrayRef<int64_t> inputShape = inputType.getShape();
2510  ArrayRef<int64_t> outputShape = outputType.getShape();
2511  if (failed(verifyCompatibleShape(inputShape, outputShape)))
2512  return emitOpError("incompatible output shape");
2513 
2514  int64_t inputRank = getInputOperandRank();
2515  int64_t dimension = getDimension();
2516  if ((dimension < 0) || (dimension >= inputRank))
2517  return emitOpError("incorrect dimension specified");
2518 
2519  return success();
2520 }
2521 
2522 SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2523  int64_t operandRank = getInputOperandRank();
2524  SmallVector<Range> loopBounds(operandRank);
2525  Location loc = getLoc();
2526  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
2527  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
2528  Value source = getInput();
2529  for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
2530  loopBounds[dim].offset = zero;
2531  loopBounds[dim].size = getDimValue(builder, loc, source, dim);
2532  loopBounds[dim].stride = one;
2533  }
2534  return loopBounds;
2535 }
2536 
2537 SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2538  SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2539  utils::IteratorType::parallel);
2540  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2541  return iteratorTypes;
2542 }
2543 
2544 FailureOr<TilingResult>
2545 SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2546  ArrayRef<OpFoldResult> offsets,
2547  ArrayRef<OpFoldResult> sizes) {
2548  int64_t rank = getInputOperandRank();
2549  auto oneAttr = builder.getI64IntegerAttr(1);
2550  SmallVector<OpFoldResult> strides(rank, oneAttr);
2551  SmallVector<Value> tiledOperands;
2552  tiledOperands.emplace_back(
2553  getSlice(builder, getLoc(), getInput(), offsets, sizes, strides));
2554  tiledOperands.emplace_back(
2555  getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides));
2556 
2557  SmallVector<Type, 4> resultTypes;
2558  if (hasPureTensorSemantics())
2559  resultTypes.push_back(tiledOperands[1].getType());
2560  Operation *tiledOp =
2561  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2562 
2563  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
2564 }
2565 
2566 LogicalResult SoftmaxOp::getResultTilePosition(
2567  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2568  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2569  SmallVector<OpFoldResult> &resultSizes) {
2570  if (resultNumber == 0) {
2571  resultOffsets.assign(offsets.begin(), offsets.end());
2572  resultSizes.assign(sizes.begin(), sizes.end());
2573  return success();
2574  }
2575  return failure();
2576 }
2577 
// Fold memref.cast producers into this op's operands when the cast only
// refines the type: cast(dynamic) -> static.
LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}
2582 
2583 LogicalResult
2585  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2587  Location loc = getOperation()->getLoc();
2588  IRRewriter rewriter(b);
2589  auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2590  auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2591  for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2592  if (!outputShapedType.isDynamicDim(dim)) {
2593  // Static dim: Return IntegerAttr.
2594  shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
2595  } else {
2596  // Dynamic dim: Return Value.
2597  OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
2598  shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
2599  }
2600  }
2601  reifiedReturnShapes.emplace_back(std::move(shapes));
2602  return success();
2603 }
2604 
2605 void SoftmaxOp::getEffects(
2607  &effects) {
2608  for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
2609  if (!llvm::isa<MemRefType>(operand.getType()))
2610  continue;
2611  effects.emplace_back(MemoryEffects::Read::get(),
2612  &getOperation()->getOpOperand(index), /*stage=*/0,
2613  /*effectOnFullRegion=*/true,
2615  }
2616 
2617  for (OpOperand &operand : getDpsInitsMutable()) {
2618  if (!llvm::isa<MemRefType>(operand.get().getType()))
2619  continue;
2620  effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
2621  /*effectOnFullRegion=*/true,
2623  effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
2624  /*effectOnFullRegion=*/true,
2626  }
2627 }
2628 
2629 // Helper functions for softmax decomposition.
2630 // @{
2631 
2632 // Helper function to produce the iterator types (reduction or parallel) and
2633 // affine maps for the iterators used in the decomposition of softmax.
2634 // This method creates:
2635 // If allParallel == true:
2636 // - iterator type: {parallel, ..., parallel}
2637 // - affine maps:
2638 // -- identity with inputRank dimensions.
2639 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2640 // where N == inputRank.
2641 //
2642 // If allParallel == false:
2643 // - iterator type at dim(i) == parallel for i != \p dim and
2644 // dim(dim) == reduction.
2645 // - affine map:
2646 // -- identity with inputRank dimensions.
2647 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2648 // where N == inputRank.
2649 static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
2651  int64_t dim, bool allParallel = false) {
2652  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
2653  utils::IteratorType::parallel);
2654  if (!allParallel)
2655  iteratorTypes[dim] = utils::IteratorType::reduction;
2656  MLIRContext *ctxt = builder.getContext();
2657  auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
2658  SmallVector<AffineExpr, 2> affineExprs;
2659  for (int i = 0; i < inputRank; i++) {
2660  if (i != dim)
2661  affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
2662  }
2663  auto reductionMap =
2664  AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
2665  SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
2666  return std::make_tuple(iteratorTypes, indexingMaps);
2667 }
2668 
2669 // Helper function to produce a linalg.generic that computes a reduction on
2670 // dimension \p dim with the operation type \p T.
2671 template <typename T>
2672 static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
2673  int64_t dim) {
2674  auto inputType = cast<ShapedType>(input.getType());
2675  ArrayRef<int64_t> inputShape = inputType.getShape();
2676  int64_t inputRank = inputShape.size();
2677  auto [iteratorTypes, indexingMaps] =
2678  computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
2679  assert(indexingMaps.size() == 2 &&
2680  "We should have two maps: 1 for the input, 1 for the output");
2681  assert(indexingMaps[0].isIdentity() && "input map should be identity");
2682 
2683  auto genericOp = builder.create<linalg::GenericOp>(
2684  loc, output.getType(), input, output, indexingMaps, iteratorTypes,
2685  [&](OpBuilder &b, Location loc, ValueRange args) {
2686  Value result = b.create<T>(loc, args[0], args[1]);
2687  b.create<linalg::YieldOp>(loc, result);
2688  });
2689  return genericOp.getResult(0);
2690 }
2691 
2692 /// Produce a linalg generic that computes the second step of the softmax
2693 /// decomposition: res = exp(input - max), where \p max is the max of \p input
2694 /// on dimension \p dim.
2695 static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
2696  Value max, Value output, int64_t dim) {
2697  auto inputType = cast<ShapedType>(input.getType());
2698  ArrayRef<int64_t> inputShape = inputType.getShape();
2699  int64_t inputRank = inputShape.size();
2700  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2701  builder, inputRank, dim, /*allParallel=*/true);
2702  assert(indexingMaps.size() == 2 && "We should have one map for each input");
2703  assert(indexingMaps[0].isIdentity() && "input map should be identity");
2704  // Add the affine map for the output argument.
2705  indexingMaps.push_back(indexingMaps[0]);
2706  auto genericOp = builder.create<linalg::GenericOp>(
2707  loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
2708  iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
2709  Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
2710  Value result = b.create<math::ExpOp>(loc, diff);
2711  b.create<linalg::YieldOp>(loc, result);
2712  });
2713  return genericOp.getResult(0);
2714 }
2715 
2716 /// Produce a linalg generic that computes the final step of the softmax
2717 /// decomposition.
2718 /// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
2719 /// yield n / d
2720 /// }
2721 static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
2722  Value denominator, Value output, int64_t dim) {
2723  auto inputType = cast<ShapedType>(numerator.getType());
2724  ArrayRef<int64_t> inputShape = inputType.getShape();
2725  int64_t inputRank = inputShape.size();
2726  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2727  builder, inputRank, dim, /*allParallel=*/true);
2728  assert(indexingMaps.size() == 2 &&
2729  "We should have one map for each input (2)");
2730  assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
2731  // Add the affine map for the output tensor.
2732  indexingMaps.push_back(indexingMaps[0]);
2733  auto genericOp = builder.create<linalg::GenericOp>(
2734  loc, numerator.getType(), ValueRange{numerator, denominator}, output,
2735  indexingMaps, iteratorTypes,
2736  [&](OpBuilder &b, Location loc, ValueRange args) {
2737  Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
2738  b.create<linalg::YieldOp>(loc, result);
2739  });
2740  return genericOp.getResult(0);
2741 }
2742 // @} End helper functions for softmax decomposition.
2743 
2744 /// Given an N-dimensional tensor x, this method converts
2745 /// softmax(x) to the following sequence of operations:
2746 ///
2747 /// 1. Compute the max of x along dimension d. This results
2748 /// in a N-1 dimensional tensor m.
2749 /// m = max(x, dim = d)
2750 ///
2751 /// 2. Subtract a broadcasted m from x and exponentiate. This results in
2752 /// a N dimensional tensor z.
2753 /// z = exp(x - m)
2754 ///
2755 /// 3. Compute the sum of z along dimension d. This results in
2756 /// a N-1 dimensional tensor l.
2757 /// l = sum(z, dim = d)
2758 ///
2759 /// 4. Divide z and l. This gives the N-dimensional softmax.
2760 /// softmax = z / l
2761 ///
FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
  OpBuilder::InsertionGuard guard(b);
  // Build the decomposition right where the softmax op sits so its result can
  // be substituted in place.
  b.setInsertionPoint(*this);
  Location loc = getLoc();
  Value input = getInput();
  ShapedType inputType = getInputOperandType();
  Type elementType = inputType.getElementType();
  int64_t reductionDim = getDimension();
  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
  Value output = getOutput();
  // Drop the reduced dimension: the intermediate max/sum tensors are (N-1)-D.
  dims.erase(dims.begin() + reductionDim);
  // Step 1: Compute max along dim.
  Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
  // NOTE(review): the identity is requested for `maximumf` while the
  // reduction below uses arith::MaxNumFOp (maxnumf semantics). With
  // useOnlyFiniteValue=true this appears behaviorally equivalent, but the
  // mismatched kinds look inconsistent — confirm intent upstream.
  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
                                                 elementType, b, loc,
                                                 /*useOnlyFiniteValue=*/true);
  Value neutralForMaxFInit =
      b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
          .result();
  Value max =
      reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);

  // Step 2: Subtract max from input and exponentiate.
  Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);

  // Step 3: Compute sum along dim.
  Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
                                       b, loc, /*useOnlyFiniteValue=*/true);
  Value zeroInit =
      b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
  Value denominator =
      reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);

  // Step 4: Compute softmax.
  Value result =
      buildDivOp(b, loc, numerator, denominator, output, reductionDim);
  return SmallVector<Value>{result};
}
2800 
2801 //===----------------------------------------------------------------------===//
2802 // WinogradFilterTransformOp
2803 //===----------------------------------------------------------------------===//
2804 
2805 LogicalResult WinogradFilterTransformOp::verify() {
2806  auto filterType = cast<ShapedType>(getFilter().getType());
2807  ArrayRef<int64_t> filterShape = filterType.getShape();
2808  int64_t filterH = filterShape[1];
2809  int64_t filterW = filterShape[2];
2810  int64_t r = getR();
2811  int64_t m = getM();
2812 
2813  if (filterH != r && filterH != 1)
2814  return emitOpError("expect filter height either equals to r or 1");
2815  if (filterW != r && filterW != 1)
2816  return emitOpError("expect filter width either equals to r or 1");
2817  if (filterH == 1 && filterW == 1)
2818  return emitOpError("expect either filter height or width equals to r");
2819 
2820  SmallVector<int64_t> expectedOutputShape;
2821  expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
2822  expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
2823  expectedOutputShape.push_back(filterShape[3]);
2824  expectedOutputShape.push_back(filterShape[0]);
2825 
2826  auto outputType = cast<ShapedType>(getOutput().getType());
2827  ArrayRef<int64_t> outputShape = outputType.getShape();
2828  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
2829  return emitOpError("the output shape is not expected");
2830  }
2831  return success();
2832 }
2833 
2834 //===----------------------------------------------------------------------===//
2835 // WinogradInputTransformOp
2836 //===----------------------------------------------------------------------===//
2837 
// Verifies the input/output shapes of the Winograd input transform.
LogicalResult WinogradInputTransformOp::verify() {
  auto inputType = cast<ShapedType>(getInput().getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputH = inputShape[1];
  int64_t inputW = inputShape[2];
  int m = getM();
  int r = getR();
  // Extent of a transformed tile axis: m + r - 1.
  int64_t tileSize = m + r - 1;
  // A unit-sized spatial axis is left untransformed (1-D Winograd variant).
  bool leftTransform = inputH != 1;
  bool rightTransform = inputW != 1;

  SmallVector<int64_t> expectedOutputShape(6, inputH);
  // Dims 0/1: tile extent per axis; dims 2/3: number of tiles per axis
  // ((inputDim - (r - 1)) / m when static, dynamic otherwise).
  if (ShapedType::isDynamic(inputH)) {
    expectedOutputShape[0] = tileSize;
    expectedOutputShape[2] = ShapedType::kDynamic;
  } else {
    expectedOutputShape[0] = leftTransform ? tileSize : 1;
    expectedOutputShape[2] = leftTransform ? (inputH - (r - 1)) / m : 1;
  }
  if (ShapedType::isDynamic(inputW)) {
    expectedOutputShape[1] = tileSize;
    expectedOutputShape[3] = ShapedType::kDynamic;
  } else {
    expectedOutputShape[1] = rightTransform ? tileSize : 1;
    expectedOutputShape[3] = rightTransform ? (inputW - (r - 1)) / m : 1;
  }
  // Dims 4/5 carry over input dims 0 and 3 unchanged.
  expectedOutputShape[4] = inputShape[0];
  expectedOutputShape[5] = inputShape[3];

  auto outputType = cast<ShapedType>(getOutput().getType());
  ArrayRef<int64_t> outputShape = outputType.getShape();
  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
    return emitOpError("the output shape is not expected");
  }
  return success();
}
2874 
2875 //===----------------------------------------------------------------------===//
2876 // WinogradOutputTransformOp
2877 //===----------------------------------------------------------------------===//
2878 
// Verifies the value/output shapes of the Winograd output transform.
LogicalResult WinogradOutputTransformOp::verify() {
  auto valueType = cast<ShapedType>(getValue().getType());
  ArrayRef<int64_t> valueShape = valueType.getShape();
  int64_t valueH = valueShape[0];
  int64_t valueW = valueShape[1];
  int64_t valueTileH = valueShape[2];
  int64_t valueTileW = valueShape[3];
  int m = getM();
  int r = getR();
  // A unit-sized tile axis is left untransformed (1-D Winograd variant).
  bool leftTransform = valueH != 1;
  bool rightTransform = valueW != 1;

  SmallVector<int64_t> expectedOutputShape(4, valueH);
  // Dims 1/2: each transformed axis must have extent m + r - 1 and expands to
  // m output elements per tile; untransformed axes stay 1.
  if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
    expectedOutputShape[1] = ShapedType::kDynamic;
  } else {
    if (valueH != (leftTransform ? m + r - 1 : 1))
      return emitOpError("expect input height equals to input tile size");
    expectedOutputShape[1] = (leftTransform ? m : 1) * valueTileH;
  }
  if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
    expectedOutputShape[2] = ShapedType::kDynamic;
  } else {
    if (valueW != (rightTransform ? m + r - 1 : 1))
      return emitOpError("expect input width equals to input tile size");
    expectedOutputShape[2] = (rightTransform ? m : 1) * valueTileW;
  }
  // Dims 0/3 carry over value dims 4 and 5 unchanged.
  expectedOutputShape[0] = valueShape[4];
  expectedOutputShape[3] = valueShape[5];

  auto outputType = cast<ShapedType>(getOutput().getType());
  ArrayRef<int64_t> outputShape = outputType.getShape();
  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
    return emitOpError("the output shape is not expected");
  }
  return success();
}
2916 
2917 //===----------------------------------------------------------------------===//
2918 // LinalgDialect
2919 //===----------------------------------------------------------------------===//
2920 
// Registers the dialect-wide canonicalization patterns that apply to any
// Linalg structured op (named-op-specific ones are auto-generated).
void LinalgDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
              InferStaticShapeOfOperands>(getContext());
}
2926 
// Dialect hook to materialize a single constant operation from the given
// attribute/type; delegates to arith's constant materialization.
// NOTE(review): the first signature line was dropped by the doc extraction
// and is reconstructed here — verify against the original source.
Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:1741
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Definition: LinalgOps.cpp:271
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
Definition: LinalgOps.cpp:2650
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
Definition: LinalgOps.cpp:2721
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
Definition: LinalgOps.cpp:119
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
Definition: LinalgOps.cpp:2199
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:259
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:330
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
Definition: LinalgOps.cpp:1574
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
Definition: LinalgOps.cpp:1622
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
Definition: LinalgOps.cpp:2695
static Operation * findPayloadOp(Block *body, bool initFirst=false)
Definition: LinalgOps.cpp:1382
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
Definition: LinalgOps.cpp:155
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
Definition: LinalgOps.cpp:1251
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
Definition: LinalgOps.cpp:2672
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect >> &effects, LinalgOp linalgOp)
Definition: LinalgOps.cpp:1122
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
Definition: LinalgOps.cpp:2112
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
Definition: LinalgOps.cpp:297
static void buildGenericRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
Definition: LinalgOps.cpp:897
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
Definition: LinalgOps.cpp:290
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
Definition: LinalgOps.cpp:51
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false)
Definition: LinalgOps.cpp:1305
void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
Definition: LinalgOps.cpp:1411
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
Definition: LinalgOps.cpp:323
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
Definition: LinalgOps.cpp:185
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
AffineMap dropResults(ArrayRef< int64_t > positions) const
Definition: AffineMap.h:299
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:411
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:264
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:31
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
unsigned getNumArguments()
Definition: Block.h:126
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:152
OpListType & getOperations()
Definition: Block.h:135
BlockArgListType getArguments()
Definition: Block.h:85
Operation & front()
Definition: Block.h:151
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:128
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:183
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:187
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:398
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:132
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:273
MLIRContext * getContext() const
Definition: Builders.h:55
Location getUnknownLoc()
Definition: Builders.cpp:27
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:277
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:329
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:766
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:351
This class helps build Operations.
Definition: Builders.h:210
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:434
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:401
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:441
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:523
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:468
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
std::optional< RegisteredOperationName > getRegisteredInfo() const
If this operation is registered, returns the registered information, std::nullopt otherwise.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
result_iterator result_begin()
Definition: Operation.h:408
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
unsigned getNumOperands()
Definition: Operation.h:341
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_iterator result_end()
Definition: Operation.h:409
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator end()
Definition: Region.h:56
Block & emplaceBlock()
Definition: Region.h:46
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:126
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
Definition: Types.cpp:108
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type below.
static Attribute parse(AsmParser &parser, Type type)
Parse the short form [42, 100, -1] without any type prefix.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any other AffineApplyOp supplying the operands.
Definition: AffineOps.cpp:1192
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2585
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
Definition: LinalgOps.cpp:2191
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:98
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level.
Definition: LinalgOps.cpp:2232
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
Definition: LinalgOps.cpp:2171
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx to startIdx + num.
Definition: LinalgOps.cpp:2182
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:89
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:44
DynamicAPInt floor(const Fraction &f)
Definition: Fraction.h:76
DynamicAPInt ceil(const Fraction &f)
Definition: Fraction.h:78
DynamicAPInt round(const Fraction &f)
Definition: Fraction.h:135
Fraction abs(const Fraction &f)
Definition: Fraction.h:106
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
uint64_t getM(LevelType lt)
Definition: Enums.h:443
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
Definition: TensorOps.cpp:349
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:65
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition: Utils.cpp:239
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:768
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. N-1].
Definition: AffineExpr.h:362
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction methods.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:607
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:421
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drop the input dims in dropPositions from inputPerm.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Fold transpose with transpose.
Definition: LinalgOps.cpp:1876
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:1879
This pattern canonicalizes transpose by swapping the order of broadcast and transpose: transpose(broadcast(input)) -> broadcast(transpose(input)).
Definition: LinalgOps.cpp:1901
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:1904
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of an operation interface instead of a raw Operation.
Definition: PatternMatch.h:373
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
Region * addRegion()
Create a region that should be attached to the operation.
Container for result values of tiling.