MLIR  20.0.0git
LinalgOps.cpp
Go to the documentation of this file.
1 //===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Linalg operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
29 #include "mlir/IR/AffineMap.h"
32 #include "mlir/IR/Matchers.h"
35 #include "mlir/IR/PatternMatch.h"
38 
39 #include "llvm/ADT/DenseMap.h"
40 #include "llvm/ADT/SmallSet.h"
41 #include "llvm/ADT/StringSet.h"
42 #include "llvm/ADT/TypeSwitch.h"
43 #include "llvm/Support/FormatVariadic.h"
44 #include "llvm/Support/MathExtras.h"
45 #include "llvm/Support/raw_ostream.h"
46 #include <optional>
47 
48 using namespace mlir;
49 using namespace mlir::linalg;
50 
51 /// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
53  int64_t dim) {
54  auto type = cast<ShapedType>(v.getType());
55  if (!type.isDynamicDim(dim))
56  return builder.getIndexAttr(type.getDimSize(dim));
57 
58  return getAsOpFoldResult(
60  .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
61  return builder.create<tensor::DimOp>(loc, v, dim);
62  })
63  .Case<MemRefType>([&](MemRefType t) -> Value {
64  return builder.create<memref::DimOp>(loc, v, dim);
65  }));
66 }
67 
68 /// Returns a memref.subview or a tensor.extract_slice based on the type of the
69 /// `source`.
70 static Value getSlice(OpBuilder &b, Location loc, Value source,
71  ArrayRef<OpFoldResult> offsets,
73  ArrayRef<OpFoldResult> strides) {
74  return TypeSwitch<Type, Value>(source.getType())
75  .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
76  return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
77  strides);
78  })
79  .Case<MemRefType>([&](MemRefType type) -> Value {
80  return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
81  strides);
82  })
83  .Default([&](Type t) { return nullptr; });
84 }
85 
86 //===----------------------------------------------------------------------===//
87 // Helper functions
88 //===----------------------------------------------------------------------===//
89 
91  int64_t dim) {
92  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
93  return b.createOrFold<memref::DimOp>(loc, source, dim);
94  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
95  return b.createOrFold<tensor::DimOp>(loc, source, dim);
96  llvm_unreachable("Expected MemRefType or TensorType");
97 }
98 
100  int64_t dim) {
101  auto shapedType = llvm::cast<ShapedType>(source.getType());
102  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
103  return createOrFoldDimOp(b, loc, source, dim);
104  return b.getIndexAttr(shapedType.getDimSize(dim));
105 }
106 
107 //===----------------------------------------------------------------------===//
108 // Support for named Linalg ops defined in ods-gen.
109 //===----------------------------------------------------------------------===//
110 
113 
114 /// Fills the region of a structured operation using the provided
115 /// `regionBuilder`. The method is used by both named structured ops created by
116 /// ods-gen and by manually defined C++ ops. It is called by both builders and
117 /// parsers and creates a block with arguments corresponding to the elemental
118 /// types of `inputTypes` and `outputTypes`. All output types are asserted to be
119 /// ShapedType.
120 static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
121  TypeRange inputTypes, TypeRange outputTypes,
123  RegionBuilderFn regionBuilder) {
124  assert(llvm::all_of(outputTypes, llvm::IsaPred<ShapedType>));
125 
126  SmallVector<Type, 8> argTypes;
127  SmallVector<Location, 8> argLocs;
128  for (auto containers : {inputTypes, outputTypes}) {
129  for (auto t : containers) {
130  argTypes.push_back(
131  isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
132 
133  // TODO: Pass in a proper location here.
134  argLocs.push_back(opBuilder.getUnknownLoc());
135  }
136  }
137 
138  // RAII.
139  OpBuilder::InsertionGuard guard(opBuilder);
140  Block *body =
141  opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);
142 
143  opBuilder.setInsertionPointToStart(body);
144  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
145  regionBuilder(b, *body, attrs);
146 
147  // indexing_maps is an auto-generated method.
148 
149  // iterator_types is an auto-generated method.
150 }
151 
152 /// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
153 /// The result types are derived automatically if `resultTensorTypes` is none.
154 /// The body of the operation is filled using `regionBuilder`. All ods-gen
155 /// created structured operations use the method to implement their builders.
157  std::optional<TypeRange> resultTensorTypes,
158  ValueRange inputs, ValueRange outputs,
159  ArrayRef<NamedAttribute> attributes,
160  RegionBuilderFn regionBuilder) {
161  // Derive the result types if needed.
162  SmallVector<Type> derivedResultTypes =
163  resultTensorTypes.value_or(TypeRange());
164  if (!resultTensorTypes)
165  copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
166  llvm::IsaPred<RankedTensorType>);
167 
168  state.addOperands(inputs);
169  state.addOperands(outputs);
170  state.addTypes(derivedResultTypes);
171  state.addAttributes(attributes);
172  state.addAttribute(
173  "operandSegmentSizes",
174  b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
175  static_cast<int32_t>(outputs.size())}));
176 
177  // Create and fill the region of the structured operation.
178  Region &region = *state.addRegion();
179  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
180  state.attributes.getAttrs(), regionBuilder);
181 }
182 
183 /// Common parsing used for both named structured ops created by ods-gen and by
184 /// manually defined C++ ops. Does not handle regions.
185 static ParseResult
187  SmallVectorImpl<Type> &inputTypes,
188  SmallVectorImpl<Type> &outputTypes,
189  bool addOperandSegmentSizes = true) {
190  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
192  outputsOperands;
193 
194  if (succeeded(parser.parseOptionalLess())) {
195  if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
196  return failure();
197  }
198  attrsLoc = parser.getCurrentLocation();
199  if (parser.parseOptionalAttrDict(result.attributes))
200  return failure();
201 
202  if (succeeded(parser.parseOptionalKeyword("ins"))) {
203  if (parser.parseLParen())
204  return failure();
205 
206  inputsOperandsLoc = parser.getCurrentLocation();
207  if (parser.parseOperandList(inputsOperands) ||
208  parser.parseColonTypeList(inputTypes) || parser.parseRParen())
209  return failure();
210  }
211 
212  if (succeeded(parser.parseOptionalKeyword("outs"))) {
213  outputsOperandsLoc = parser.getCurrentLocation();
214  if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
215  parser.parseColonTypeList(outputTypes) || parser.parseRParen())
216  return failure();
217  }
218 
219  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
220  result.operands) ||
221  parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
222  result.operands))
223  return failure();
224 
225  if (addOperandSegmentSizes) {
226  // This is a bit complex because we're trying to be backward compatible with
227  // operation syntax that mix the inherent attributes and the discardable
228  // ones in the same dictionary. If the properties are used, we append the
229  // operandSegmentSizes there directly. Otherwise we append it to the
230  // discardable attributes dictionary where it is handled by the generic
231  // Operation::create(...) method.
232  if (result.propertiesAttr) {
233  NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
234  attrs.append("operandSegmentSizes",
236  {static_cast<int32_t>(inputsOperands.size()),
237  static_cast<int32_t>(outputsOperands.size())}));
238  result.propertiesAttr = attrs.getDictionary(parser.getContext());
239  } else {
240  result.addAttribute("operandSegmentSizes",
242  {static_cast<int32_t>(inputsOperands.size()),
243  static_cast<int32_t>(outputsOperands.size())}));
244  }
245  }
246  if (!result.propertiesAttr) {
247  std::optional<RegisteredOperationName> info =
248  result.name.getRegisteredInfo();
249  if (info) {
250  if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
251  return parser.emitError(attrsLoc)
252  << "'" << result.name.getStringRef() << "' op ";
253  })))
254  return failure();
255  }
256  }
257  return success();
258 }
259 
261  ValueRange outputs) {
262  if (!inputs.empty())
263  p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
264  if (!outputs.empty())
265  p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
266 }
267 
268 //===----------------------------------------------------------------------===//
269 // Specific parsing and printing for named structured ops created by ods-gen.
270 //===----------------------------------------------------------------------===//
271 
272 static ParseResult parseNamedStructuredOpRegion(
273  OpAsmParser &parser, Region &region, unsigned numRegionArgs,
274  TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
275  RegionBuilderFn regionBuilder) {
276  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
277  return parser.emitError(
278  parser.getCurrentLocation(),
279  llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
280  "region expects {0} args, got {1}",
281  numRegionArgs, inputTypes.size() + outputTypes.size()));
282  }
283 
284  OpBuilder opBuilder(parser.getContext());
285  fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
286  regionBuilder);
287  return success();
288 }
289 
290 static ParseResult
292  SmallVectorImpl<Type> &resultTypes) {
293  if (parser.parseOptionalArrowTypeList(resultTypes))
294  return failure();
295  return success();
296 }
297 
298 static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
299  OperationState &result,
300  unsigned numRegionArgs,
301  RegionBuilderFn regionBuilder) {
302  // TODO: Enable when ods-gen supports captures.
303  SmallVector<Type, 1> inputTypes, outputTypes;
304  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
305  return failure();
306 
307  // TODO: consider merging results parsing into region parsing.
308  // Need to wait for declarative assembly resolution to decide.
309  SmallVector<Type, 1> outputTensorsTypes;
310  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
311  return failure();
312  result.addTypes(outputTensorsTypes);
313 
314  std::unique_ptr<Region> region = std::make_unique<Region>();
315  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
316  outputTypes, result.attributes.getAttrs(),
317  regionBuilder))
318  return failure();
319  result.addRegion(std::move(region));
320 
321  return success();
322 }
323 
325  TypeRange resultTypes) {
326  if (resultTypes.empty())
327  return;
328  p.printOptionalArrowTypeList(resultTypes);
329 }
330 
332  ValueRange inputs, ValueRange outputs) {
334  op->getAttrs(),
335  /*elidedAttrs=*/{"operandSegmentSizes",
336  // See generated code in
337  // LinalgNamedStructuredOps.yamlgen.cpp.inc
338  "linalg.memoized_indexing_maps"});
339 
340  // Printing is shared with generic ops, except for the region and
341  // attributes.
342  printCommonStructuredOpParts(p, inputs, outputs);
343 
344  // Results printing.
346 
347  // Region is elided.
348 }
349 
350 //===----------------------------------------------------------------------===//
351 // Region builder helper.
352 // TODO: Move this to a utility library.
353 // The public methods on this class are referenced directly from generated code.
354 // Helper build the unary, binary, and type conversion functions defined by the
355 // DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that uses this
356 // class.
357 //
358 // Implementations of the math functions must be polymorphic over numeric types,
359 // internally performing necessary casts. If the function application makes no
360 // sense, then the only recourse is to assert and return nullptr. This can be
361 // extended later if it becomes possible to fail construction of the region. The
362 // invariant should be enforced at a higher level.
363 //
364 // TODO: These helpers are currently type polymorphic over the class of integer
365 // and floating point types, but they will not internally cast within bit
366 // widths of a class (mixed precision such as i8->i32) or across classes
367 // (i.e. mixed float and integer). Many such combinations are ambiguous or need
368 // to be handled with care and work is being considered to extend the op
369 // language to make such cases explicit. In the mean-time, violating this will
370 // fail verification, which is deemed acceptable.
371 //===----------------------------------------------------------------------===//
372 
373 namespace {
374 
375 class RegionBuilderHelper {
376 public:
377  RegionBuilderHelper(OpBuilder &builder, Block &block)
378  : builder(builder), block(block) {}
379 
380  // Build the unary functions defined by OpDSL.
381  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
382  if (!isFloatingPoint(arg))
383  llvm_unreachable("unsupported non numeric type");
384  OpBuilder::InsertionGuard g(builder);
385  builder.setInsertionPointToEnd(&block);
386  switch (unaryFn) {
387  case UnaryFn::exp:
388  return builder.create<math::ExpOp>(arg.getLoc(), arg);
389  case UnaryFn::log:
390  return builder.create<math::LogOp>(arg.getLoc(), arg);
391  case UnaryFn::abs:
392  return builder.create<math::AbsFOp>(arg.getLoc(), arg);
393  case UnaryFn::ceil:
394  return builder.create<math::CeilOp>(arg.getLoc(), arg);
395  case UnaryFn::floor:
396  return builder.create<math::FloorOp>(arg.getLoc(), arg);
397  case UnaryFn::negf:
398  return builder.create<arith::NegFOp>(arg.getLoc(), arg);
399  case UnaryFn::reciprocal: {
400  Attribute oneAttr = builder.getOneAttr(arg.getType());
401  auto one = builder.create<arith::ConstantOp>(arg.getLoc(),
402  ::cast<TypedAttr>(oneAttr));
403  return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
404  }
405  case UnaryFn::round:
406  return builder.create<math::RoundOp>(arg.getLoc(), arg);
407  case UnaryFn::sqrt:
408  return builder.create<math::SqrtOp>(arg.getLoc(), arg);
409  case UnaryFn::rsqrt:
410  return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
411  case UnaryFn::square:
412  return builder.create<arith::MulFOp>(arg.getLoc(), arg, arg);
413  case UnaryFn::tanh:
414  return builder.create<math::TanhOp>(arg.getLoc(), arg);
415  case UnaryFn::erf:
416  return builder.create<math::ErfOp>(arg.getLoc(), arg);
417  }
418  llvm_unreachable("unsupported unary function");
419  }
420 
421  // Build the binary functions defined by OpDSL.
422  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
423  bool allComplex = isComplex(arg0) && isComplex(arg1);
424  bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
425  bool allInteger = isInteger(arg0) && isInteger(arg1);
426  bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
427  arg1.getType().getIntOrFloatBitWidth() == 1;
428  if (!allComplex && !allFloatingPoint && !allInteger)
429  llvm_unreachable("unsupported non numeric type");
430  OpBuilder::InsertionGuard g(builder);
431  builder.setInsertionPointToEnd(&block);
432  switch (binaryFn) {
433  case BinaryFn::add:
434  if (allComplex)
435  return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
436  if (allFloatingPoint)
437  return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
438  if (allBool)
439  return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
440  return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
441  case BinaryFn::sub:
442  if (allComplex)
443  return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
444  if (allFloatingPoint)
445  return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
446  if (allBool)
447  llvm_unreachable("unsupported operation: sub with bools");
448  return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
449  case BinaryFn::mul:
450  if (allComplex)
451  return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
452  if (allFloatingPoint)
453  return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
454  if (allBool)
455  return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
456  return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
457  case BinaryFn::div:
458  if (allComplex)
459  return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
460  if (allFloatingPoint)
461  return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
462  if (allBool)
463  llvm_unreachable("unsupported operation: div with bools");
464  return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
465  case BinaryFn::div_unsigned:
466  if (!allInteger || allBool)
467  llvm_unreachable("unsupported operation: unsigned div not on uint");
468  return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
469  case BinaryFn::max_signed:
470  assert(!allComplex);
471  if (allFloatingPoint)
472  return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
473  return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
474  case BinaryFn::min_signed:
475  assert(!allComplex);
476  if (allFloatingPoint)
477  return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
478  return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
479  case BinaryFn::max_unsigned:
480  assert(!allComplex);
481  if (allFloatingPoint)
482  return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
483  return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
484  case BinaryFn::min_unsigned:
485  assert(!allComplex);
486  if (allFloatingPoint)
487  return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
488  return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
489  case BinaryFn::powf:
490  assert(allFloatingPoint);
491  return builder.create<math::PowFOp>(arg0.getLoc(), arg0, arg1);
492  }
493  llvm_unreachable("unsupported binary function");
494  }
495 
496  // Build the ternary functions defined by OpDSL.
497  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1,
498  Value arg2) {
499  bool headBool =
500  isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
501  bool tailFloatingPoint =
502  isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
503  bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg1);
504  OpBuilder::InsertionGuard g(builder);
505  builder.setInsertionPointToEnd(&block);
506  switch (ternaryFn) {
507  case TernaryFn::select:
508  if (!headBool && !(tailFloatingPoint || tailInteger))
509  llvm_unreachable("unsupported non numeric type");
510  return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
511  }
512  llvm_unreachable("unsupported ternary function");
513  }
514 
515  // Build the type functions defined by OpDSL.
516  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
517  switch (typeFn) {
518  case TypeFn::cast_signed:
519  return cast(toType, operand, false);
520  case TypeFn::cast_unsigned:
521  return cast(toType, operand, true);
522  }
523  llvm_unreachable("unsupported type conversion function");
524  }
525 
526  void yieldOutputs(ValueRange values) {
527  OpBuilder::InsertionGuard g(builder);
528  builder.setInsertionPointToEnd(&block);
529  Location loc = builder.getUnknownLoc();
530  builder.create<YieldOp>(loc, values);
531  }
532 
533  Value constant(const std::string &value) {
534  OpBuilder::InsertionGuard g(builder);
535  builder.setInsertionPointToEnd(&block);
536  Location loc = builder.getUnknownLoc();
537  Attribute valueAttr = parseAttribute(value, builder.getContext());
538  return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
539  }
540 
541  Value index(int64_t dim) {
542  OpBuilder::InsertionGuard g(builder);
543  builder.setInsertionPointToEnd(&block);
544  return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
545  }
546 
547  Type getIntegerType(unsigned width) {
548  return IntegerType::get(builder.getContext(), width);
549  }
550 
551  Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
552  Type getFloat64Type() { return Float64Type::get(builder.getContext()); }
553 
554 private:
555  // Generates operations to cast the given operand to a specified type.
556  // If the cast cannot be performed, a warning will be issued and the
557  // operand returned as-is (which will presumably yield a verification
558  // issue downstream).
559  Value cast(Type toType, Value operand, bool isUnsignedCast) {
560  OpBuilder::InsertionGuard g(builder);
561  builder.setInsertionPointToEnd(&block);
562  auto loc = operand.getLoc();
563  return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
564  }
565 
566  bool isComplex(Value value) {
567  return llvm::isa<ComplexType>(value.getType());
568  }
569  bool isFloatingPoint(Value value) {
570  return llvm::isa<FloatType>(value.getType());
571  }
572  bool isInteger(Value value) {
573  return llvm::isa<IntegerType>(value.getType());
574  }
575 
576  OpBuilder &builder;
577  Block &block;
578 };
579 
580 } // namespace
581 
582 //===----------------------------------------------------------------------===//
583 // CopyOp
584 //===----------------------------------------------------------------------===//
585 
586 namespace {
587 
588 struct EraseSelfCopy : OpRewritePattern<CopyOp> {
590  LogicalResult matchAndRewrite(CopyOp copyOp,
591  PatternRewriter &rewriter) const override {
592  if (copyOp.getInputs() != copyOp.getOutputs())
593  return rewriter.notifyMatchFailure(copyOp, "not a self copy");
594  if (copyOp.hasPureBufferSemantics())
595  rewriter.eraseOp(copyOp);
596  else
597  rewriter.replaceOp(copyOp, copyOp.getInputs());
598 
599  return success();
600  }
601 };
602 
603 } // namespace
604 
// Register CopyOp canonicalizations: currently only the removal of copies
// whose input and output operand lists are identical (see EraseSelfCopy).
void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseSelfCopy>(context);
}
609 
610 //===----------------------------------------------------------------------===//
611 // FillOp
612 //===----------------------------------------------------------------------===//
613 
614 namespace {
615 
616 /// Fold linalg.fill -> tensor.expand/collapse_shape chain.
617 ///
618 /// For such op chains, we can create new linalg.fill ops with the result
619 /// type of the tensor.expand/collapse_shape op.
620 template <typename TensorReshapeOp>
621 struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
623  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
624  PatternRewriter &rewriter) const override {
625  auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
626  if (!oldFill)
627  return failure();
628 
629  Location loc = oldFill.getLoc();
630  TensorReshapeOp newInit;
631  if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
632 
633  newInit = rewriter.create<TensorReshapeOp>(
634  loc, reshapeOp.getResultType(), oldFill.output(),
635  reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
636  reshapeOp.getStaticOutputShape());
637  } else {
638  newInit = rewriter.create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
639  oldFill.output(),
640  reshapeOp.getReassociation());
641  }
642  rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
643  ValueRange{newInit});
644  return success();
645  }
646 };
647 
648 /// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
649 /// filling value are the same.
650 struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
652 
653  LogicalResult matchAndRewrite(tensor::PadOp padOp,
654  PatternRewriter &rewriter) const override {
655  auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
656  if (!fillOp)
657  return failure();
658 
659  // We can only fold if the padding value is the same as the original
660  // filling value.
661  Value padValue = padOp.getConstantPaddingValue();
662  if (!padValue || fillOp.value() != padValue)
663  return failure();
664 
665  ReifiedRankedShapedTypeDims reifiedShape;
666  if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
667  return rewriter.notifyMatchFailure(
668  padOp, "failed to reify tensor.pad op result shape");
669 
670  auto emptyTensor = rewriter.create<tensor::EmptyOp>(
671  padOp.getLoc(), reifiedShape.front(),
672  padOp.getResultType().getElementType());
673  Value replacement =
674  rewriter
675  .create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
676  ValueRange{emptyTensor})
677  .getResult(0);
678  if (replacement.getType() != padOp.getResultType()) {
679  replacement = rewriter.create<tensor::CastOp>(
680  fillOp.getLoc(), padOp.getResultType(), replacement);
681  }
682  rewriter.replaceOp(padOp, replacement);
683  return success();
684  }
685 };
686 
687 /// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
688 /// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
689 /// filling value are the same.
690 struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
692 
693  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
694  PatternRewriter &rewriter) const override {
695  auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
696  if (!srcPadOp)
697  return failure();
698 
699  if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
700  return failure();
701 
702  // Walk back the tensor.insert_slice chain and find the first destination
703  // value at the start of the chain.
704  Value firstDest = insertOp.getDest();
705  while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
706  if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
707  return failure();
708 
709  // Make sure the range of values accessed are disjoint. Without this, we
710  // cannot fold tensor.pad away.
711  bool disjoint = false;
712  for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
713  // If the dimension has dynamic offset/size, we cannot guarantee
714  // disjoint. So just skip it.
715  if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
716  insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
717  prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
718  continue;
719 
720  // Get the range start and end, inclusively for both.
721  int64_t prevStart = prevOp.getStaticOffset(i);
722  int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
723  prevOp.getStaticStride(i);
724  int64_t nextStart = insertOp.getStaticOffset(i);
725  int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
726  insertOp.getStaticStride(i);
727  if (prevEnd < nextStart || nextEnd < prevStart) {
728  disjoint = true;
729  break;
730  }
731  }
732 
733  if (!disjoint)
734  break;
735  firstDest = prevOp.getDest();
736  }
737 
738  // Check whether the first destination is a fill op. For overlapped cases,
739  // this also cannot be true.
740  auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
741  if (!dstFillOp)
742  return failure();
743 
744  // We can only fold if the padding value is the same as the original
745  // filling value.
746  Value padValue = srcPadOp.getConstantPaddingValue();
747  if (!padValue || dstFillOp.value() != padValue)
748  return failure();
749 
750  SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
751  SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();
752 
753  Location loc = insertOp.getLoc();
754  MLIRContext *context = getContext();
755 
756  AffineExpr sym0, sym1;
757  bindSymbols(context, sym0, sym1);
758  auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);
759 
760  // Calculate the new offsets for the insert. It should be the old offsets
761  // plus low padding sizes.
762  SmallVector<OpFoldResult, 4> newOffsets;
763  for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
764  newOffsets.push_back(affine::makeComposedFoldedAffineApply(
765  rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
766  }
767 
768  RankedTensorType srcPadType = srcPadOp.getSourceType();
770  for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
771  if (srcPadType.isDynamicDim(i)) {
772  newSizes.push_back(
773  rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
774  .getResult());
775  } else {
776  newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
777  }
778  }
779 
780  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
781  insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
782  newSizes, insertOp.getMixedStrides());
783  return success();
784  }
785 };
786 
787 /// Fold tensor.extract(linalg.fill(<input>)) into <input>
788 struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
789 public:
791 
792  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
793  PatternRewriter &rewriter) const override {
794  // See if tensor input of tensor.extract op is the result of a linalg.fill
795  // op.
796  auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
797  if (!fillOp)
798  return failure();
799 
800  // Get scalar input operand of linalg.fill op.
801  Value extractedScalar = fillOp.getInputs()[0];
802 
803  // Replace tensor.extract op with scalar value used to fill the tensor.
804  rewriter.replaceOp(extractOp, extractedScalar);
805  return success();
806  }
807 };
808 
809 /// Folds pack(fill) into a single fill op if
810 /// 1. The pack op does not have padding value, or
811 /// 2. The filled value and padding value are the same.
812 static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
813  tensor::PackOp packOp) {
814  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
815  if (!fillOp)
816  return failure();
817 
818  if (auto paddingValue = packOp.getPaddingValue())
819  if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
820  return failure();
821 
822  Value packOpDest = packOp.getDest();
823  if (!packOpDest.hasOneUse())
824  return failure();
825 
826  return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
827  packOp.getDest());
828 }
829 
830 /// Wrapper pattern that applies foldFillPackIntoFillOp method.
831 struct FoldFillWithPack : public OpRewritePattern<tensor::PackOp> {
832 public:
833  FoldFillWithPack(MLIRContext *context)
834  : OpRewritePattern<tensor::PackOp>(context) {}
835 
836  LogicalResult matchAndRewrite(tensor::PackOp packOp,
837  PatternRewriter &rewriter) const override {
838  auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
839  if (failed(fillOp))
840  return failure();
841  rewriter.replaceOp(packOp, fillOp.value().result());
842  return success();
843  }
844 };
845 
846 /// Fold fill with copy.
847 struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
849 
850  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
851  PatternRewriter &rewriter) const override {
852  if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
853  rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
854  fillOp.getInputs(),
855  copyOp.getOutputs());
856  return success();
857  }
858  if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
859  rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
860  fillOp.getOutputs());
861  return success();
862  }
863  return failure();
864  }
865 };
866 
867 /// Fold fill with transpose.
868 struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
870 
871  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
872  PatternRewriter &rewriter) const override {
873  if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
874  rewriter.replaceOpWithNewOp<FillOp>(
875  transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
876  transposeOp.getDpsInitOperand(0)->get());
877  return success();
878  }
879  return failure();
880  }
881 };
882 
883 /// Fold a concat with all elements being fills of the same value
884 /// into a fill of the concat result shape.
885 struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
887 
888  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
889  PatternRewriter &rewriter) const override {
890  auto concatOperands = concatOp.getInputs();
891  if (concatOperands.empty()) {
892  return failure();
893  }
894 
895  auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
896  if (!firstFillOp) {
897  return failure();
898  }
899  // Prefetch the fill value.
900  OpFoldResult firstFillVal =
901  getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
902  // Collect all the outs values for the fill operations.
903  SmallVector<Value> allOuts;
904  allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());
905 
906  auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
907  auto fillOp = v.getDefiningOp<linalg::FillOp>();
908  if (!fillOp) {
909  return false;
910  }
911 
912  OpFoldResult fillVal =
913  getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
914  if (fillVal != firstFillVal)
915  return false;
916 
917  allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
918  return true;
919  };
920  if (!llvm::all_of(concatOperands.drop_front(),
921  isDefinedByCompatibleFillOp)) {
922  return rewriter.notifyMatchFailure(
923  concatOp, "not all operands are defined by a compatible fill op");
924  }
925 
926  Value outsConcat = rewriter.create<tensor::ConcatOp>(
927  concatOp.getLoc(), concatOp.getDim(), allOuts);
928  rewriter.replaceOpWithNewOp<linalg::FillOp>(
929  concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
930  return success();
931  }
932 };
933 
934 } // namespace
935 
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  // Register all fill-folding patterns defined above.
  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
              FoldFillWithPack, FoldFillWithPad,
              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
              FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}
944 
945 //===----------------------------------------------------------------------===//
946 // GenericOp
947 //===----------------------------------------------------------------------===//
948 
949 static void buildGenericRegion(
950  OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
951  ValueRange outputs,
952  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
953  SmallVector<Type, 4> blockArgTypes;
954  SmallVector<Location, 4> blockArgLocs;
955  for (ValueRange container : {inputs, outputs}) {
956  for (Value v : container) {
957  Type t = v.getType();
958  blockArgTypes.push_back(
959  isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
960  blockArgLocs.push_back(v.getLoc());
961  }
962  }
963 
964  OpBuilder::InsertionGuard guard(builder);
965  Block *bodyBlock =
966  builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
967  bodyBuild(builder, loc, bodyBlock->getArguments());
968 }
969 
970 void GenericOp::getAsmBlockArgumentNames(Region &region,
971  OpAsmSetValueNameFn setNameFn) {
972  for (Value v : getRegionInputArgs())
973  setNameFn(v, "in");
974  for (Value v : getRegionOutputArgs())
975  setNameFn(v, "out");
976 }
977 
978 void GenericOp::build(
979  OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
980  ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
981  ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
982  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
983  ArrayRef<NamedAttribute> attributes) {
984  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
985  iteratorTypes, doc, libraryCall);
986  result.addAttributes(attributes);
987  if (bodyBuild)
988  buildGenericRegion(builder, result.location, *result.regions.front(),
989  inputs, outputs, bodyBuild);
990 }
991 
992 void GenericOp::build(
993  OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
994  ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
995  ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
996  StringRef libraryCall,
997  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
998  ArrayRef<NamedAttribute> attributes) {
999  build(builder, result, resultTensorTypes, inputs, outputs,
1000  builder.getAffineMapArrayAttr(indexingMaps),
1001  builder.getArrayAttr(llvm::to_vector(llvm::map_range(
1002  iteratorTypes,
1003  [&](utils::IteratorType iter) -> mlir::Attribute {
1004  return IteratorTypeAttr::get(builder.getContext(), iter);
1005  }))),
1006  doc.empty() ? StringAttr() : builder.getStringAttr(doc),
1007  libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
1008  bodyBuild, attributes);
1009 }
1010 
/// Builder that forwards with an empty result-type range.
void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}
1021 
/// Builder that forwards with empty doc string and library call name.
void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}
1032 
/// Builder with result types but empty doc string and library call name.
void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}
1044 
void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = linalgTraitAttrNames();

  // Only the core linalg trait attributes are printed in the leading
  // dictionary; everything else goes into the trailing `attrs = {...}`.
  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames =
          llvm::to_vector(llvm::map_range(
              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
                return StringAttr::get(getContext(), stringifyIteratorType(t));
              }));

      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
          ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      genericAttrs.push_back(attr);
    }
  }
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());

  // operandSegmentSizes is implied by the ins/outs syntax, so elide it too.
  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  // Detect whether any attribute remains that is not covered by the core
  // trait set; only then print the trailing attribute dictionary.
  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, getResultTensors().getTypes());
}
1107 
ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits that must check into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  // Remember where the dictionary started so diagnostics can point at it.
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert array of string into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  // Overwrite the string-based attribute with the enum-based one.
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of its outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}
1173 
1176  &effects,
1177  LinalgOp linalgOp) {
1178  for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
1179  if (!llvm::isa<MemRefType>(operand.getType()))
1180  continue;
1181  effects.emplace_back(
1182  MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
1183  /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
1184  }
1185 
1186  for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
1187  if (!llvm::isa<MemRefType>(operand.get().getType()))
1188  continue;
1189  if (linalgOp.payloadUsesValueFromOperand(&operand)) {
1190  effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
1191  /*effectOnFullRegion=*/true,
1193  }
1194  effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
1195  /*effectOnFullRegion=*/true,
1197  }
1198 }
1199 
1200 void GenericOp::getEffects(
1202  &effects) {
1203  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1204 }
1205 
1207 getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
1208  // Operands with value semantics are speculatable, while operands with memory
1209  // semantics are not.
1210  if (!linalgOp.hasPureTensorSemantics())
1212  // The body of the op can still have speculation in its region.
1214 }
1215 
1216 Speculation::Speculatability GenericOp::getSpeculatability() {
1217  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1218 }
1219 
/// GenericOp has no op-specific verification; it always succeeds here.
LogicalResult GenericOp::verify() { return success(); }
1221 
1222 namespace {
1223 
1224 /// Remove any linalg operation (on tensors) that are just copying
1225 /// the values from inputs to the results. Requirements are
1226 /// 1) All iterator types are parallel
1227 /// 2) The body contains just a yield operation with the yielded values being
1228 /// the arguments corresponding to the operands.
1229 template <typename OpTy>
1230 struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
1232 
1233  LogicalResult matchAndRewrite(OpTy linalgOp,
1234  PatternRewriter &rewriter) const override {
1235  // All indexing maps must be equal. It follows that they are permutations.
1236  if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
1237  return failure();
1238 
1239  // Check that the body of the linalg operation is just a linalg.yield
1240  // operation.
1241  Block &body = linalgOp->getRegion(0).front();
1242  if (!llvm::hasSingleElement(body))
1243  return failure();
1244  auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
1245  if (!yieldOp)
1246  return failure();
1247 
1248  // In the buffer case, we need to check exact buffer equality.
1249  if (linalgOp.hasPureBufferSemantics()) {
1250  if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
1251  linalgOp.getDpsInputOperand(0)->get() ==
1252  linalgOp.getDpsInitOperand(0)->get()) {
1253  rewriter.eraseOp(linalgOp);
1254  return success();
1255  }
1256  return failure();
1257  }
1258 
1259  // Mixed semantics is not supported yet.
1260  if (!linalgOp.hasPureTensorSemantics())
1261  return failure();
1262 
1263  // Get the argument number of the returned values. That is the operand
1264  // number to use for replacing uses of this operation.
1265  SmallVector<Value> returnedArgs;
1266  for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
1267  auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
1268  if (!yieldArg || yieldArg.getOwner() != &body)
1269  return failure();
1270  unsigned argumentNumber = yieldArg.getArgNumber();
1271  Value returnedArg = linalgOp->getOperand(argumentNumber);
1272  Type resultType = linalgOp->getResult(yieldVal.index()).getType();
1273  // The input can have a different type than the result, e.g. a dynamic
1274  // input dimension can be turned into a static output dimension.
1275  Type returnType = returnedArg.getType();
1276  if (returnType != resultType) {
1277  // Distinguish between sparse conversion or dense tensor casting.
1278  // TODO: unify the two ops?
1279  if (sparse_tensor::getSparseTensorEncoding(returnType) ||
1281  returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
1282  linalgOp.getLoc(), resultType, returnedArg);
1283  else {
1284  if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
1285  resultType))
1286  return failure();
1287  returnedArg = rewriter.create<tensor::CastOp>(
1288  linalgOp.getLoc(), resultType, returnedArg);
1289  }
1290  }
1291  returnedArgs.push_back(returnedArg);
1292  }
1293 
1294  if (returnedArgs.size() != linalgOp->getNumResults())
1295  return failure();
1296  rewriter.replaceOp(linalgOp, returnedArgs);
1297  return success();
1298  }
1299 };
1300 
1301 } // namespace
1302 
void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Identity generics (pure copies from inputs to results) fold away.
  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
}
1307 
/// Folding only absorbs memref.cast producers into the op's operands.
LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}
1311 
1312 //===----------------------------------------------------------------------===//
1313 // MapOp
1314 //===----------------------------------------------------------------------===//
1315 
1316 static ParseResult parseDstStyleOp(
1317  OpAsmParser &parser, OperationState &result,
1318  function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
1319  nullptr) {
1320  // Parse `ins` and `outs`.
1321  SmallVector<Type, 4> inputTypes, outputTypes;
1322  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
1323  /*addOperandSegmentSizes=*/false))
1324  return failure();
1325 
1326  // Add result types.
1327  for (Type outputType : outputTypes) {
1328  if (llvm::isa<RankedTensorType>(outputType))
1329  result.addTypes(outputType);
1330  }
1331 
1332  // Parse required attributes.
1333  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
1334  return failure();
1335 
1336  // Parse optional attributes.
1337  if (parser.parseOptionalAttrDict(result.attributes))
1338  return failure();
1339  return success();
1340 }
1341 
1342 void MapOp::getAsmBlockArgumentNames(Region &region,
1343  OpAsmSetValueNameFn setNameFn) {
1344  for (Value v : getRegionInputArgs())
1345  setNameFn(v, "in");
1346 }
1347 
1348 void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
1349  if (!getResults().empty())
1350  setNameFn(getResults().front(), "mapped");
1351 }
1352 
1353 void MapOp::build(
1354  OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
1355  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1356  ArrayRef<NamedAttribute> attributes) {
1357  build(builder, result, TypeRange{}, inputs, init);
1358  result.addAttributes(attributes);
1359 
1360  // Add output types for `RankedTensorType` output arguments.
1361  Type initType = init.getType();
1362  if (llvm::isa<RankedTensorType>(initType))
1363  result.addTypes(initType);
1364 
1365  if (bodyBuild)
1366  buildGenericRegion(builder, result.location, *result.regions.front(),
1367  inputs, /*outputs=*/{}, bodyBuild);
1368 }
1369 
1370 static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
1371  const OperationName &payloadOpName,
1372  const NamedAttrList &payloadOpAttrs,
1373  ArrayRef<Value> operands,
1374  bool initFirst = false) {
1375  OpBuilder b(parser.getContext());
1376  Region *body = result.addRegion();
1377  Block &block = body->emplaceBlock();
1378  b.setInsertionPointToStart(&block);
1379  SmallVector<Value> bbArgs;
1380  for (auto &operand : operands) {
1381  block.addArgument(
1382  llvm::cast<ShapedType>(operand.getType()).getElementType(),
1383  b.getUnknownLoc());
1384  }
1385  SmallVector<Value> payloadOpOperands;
1386  // If initFirst flag is enabled, we consider init as the first position of
1387  // payload operands.
1388  if (initFirst) {
1389  payloadOpOperands.push_back(block.getArguments().back());
1390  for (const auto &arg : block.getArguments().drop_back())
1391  payloadOpOperands.push_back(arg);
1392  } else {
1393  payloadOpOperands = {block.getArguments().begin(),
1394  block.getArguments().end()};
1395  }
1396 
1397  Operation *payloadOp = b.create(
1398  result.location, b.getStringAttr(payloadOpName.getStringRef()),
1399  payloadOpOperands,
1400  TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
1401  .getElementType()},
1402  payloadOpAttrs);
1403  b.create<YieldOp>(result.location, payloadOp->getResults());
1404 }
1405 
1406 ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
1407  std::optional<OperationName> payloadOpName;
1408  NamedAttrList payloadOpAttrs;
1409  if (succeeded(parser.parseOptionalLBrace())) {
1410  FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1411  if (failed(operationName))
1412  return failure();
1413  if (parser.parseOptionalAttrDict(payloadOpAttrs))
1414  return failure();
1415  payloadOpName = operationName.value();
1416  if (parser.parseRBrace())
1417  return failure();
1418  }
1419 
1420  if (parseDstStyleOp(parser, result))
1421  return failure();
1422 
1423  if (payloadOpName.has_value()) {
1424  if (!result.operands.empty())
1425  addBodyWithPayloadOp(parser, result, payloadOpName.value(),
1426  payloadOpAttrs,
1427  ArrayRef(result.operands).drop_back());
1428  else
1429  result.addRegion();
1430  } else {
1432  if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1433  /*allowType=*/true, /*allowAttrs=*/true)) {
1434  return failure();
1435  }
1436  Region *body = result.addRegion();
1437  if (parser.parseRegion(*body, regionArgs))
1438  return failure();
1439  }
1440  return success();
1441 }
1442 
1443 // Retrieve the operation from the body, if it is the only one (except
1444 // yield) and if it gets the same amount of arguments as the body does.
1445 // If initFirst flag is enabled, we check that init takes the first position in
1446 // operands of payload.
1447 static Operation *findPayloadOp(Block *body, bool initFirst = false) {
1448  if (body->getOperations().size() != 2)
1449  return nullptr;
1450  Operation &payload = body->getOperations().front();
1451  assert(isa<YieldOp>(body->getOperations().back()));
1452 
1453  if (payload.getNumOperands() == 0 ||
1454  payload.getNumOperands() != body->getNumArguments())
1455  return nullptr;
1456  if (initFirst) {
1457  // check init
1458  if (payload.getOperands().back() != body->getArgument(0))
1459  return nullptr;
1460  // check rest
1461  for (const auto &[operand, bbArg] :
1462  llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
1463  if (bbArg != operand)
1464  return nullptr;
1465  }
1466  } else {
1467  for (const auto &[operand, bbArg] :
1468  llvm::zip(payload.getOperands(), body->getArguments())) {
1469  if (bbArg != operand)
1470  return nullptr;
1471  }
1472  }
1473  return &payload;
1474 }
1475 
1476 void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
1477  SmallVector<StringRef> elidedAttrs;
1478  std::string attrToElide;
1479  p << " { " << payloadOp->getName().getStringRef();
1480  for (const auto &attr : payloadOp->getAttrs()) {
1481  auto fastAttr =
1482  llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1483  if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1484  attrToElide = attr.getName().str();
1485  elidedAttrs.push_back(attrToElide);
1486  break;
1487  }
1488  }
1489  p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
1490  p << " }";
1491 }
1492 
void MapOp::print(OpAsmPrinter &p) {
  // Use the compact `{ payload-op }` form when the mapper is a single
  // payload op fed directly by the block arguments.
  Block *mapper = getBody();
  Operation *payloadOp = findPayloadOp(mapper);
  if (payloadOp) {
    printShortForm(p, payloadOp);
  }

  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  p.printOptionalAttrDict((*this)->getAttrs());

  if (!payloadOp) {
    // Print region if the payload op was not detected.
    p.increaseIndent();
    p.printNewline();
    p << "(";
    llvm::interleaveComma(mapper->getArguments(), p,
                          [&](auto arg) { p.printRegionArgument(arg); });
    p << ") ";

    p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
    p.decreaseIndent();
  }
}
1516 
/// Verify the mapper region: arity, element types, and matching shapes of
/// inputs against the init.
LogicalResult MapOp::verify() {
  auto *bodyBlock = getBody();
  auto blockArgs = bodyBlock->getArguments();

  // Checks if the number of `inputs` match the arity of the `mapper` region.
  if (getInputs().size() != blockArgs.size())
    return emitOpError() << "expects number of operands to match the arity of "
                            "mapper, but got: "
                         << getInputs().size() << " and " << blockArgs.size();

  // The parameters of mapper should all match the element type of inputs.
  for (const auto &[bbArgType, inputArg] :
       llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
    auto inputElemType =
        llvm::cast<ShapedType>(inputArg.getType()).getElementType();
    if (bbArgType != inputElemType) {
      return emitOpError() << "expected element type of input " << inputElemType
                           << " to match bbArg type " << bbArgType;
    }
  }

  // The shape of each input must match the shape of the output.
  auto outputShape = getInit().getType().getShape();
  for (Type inputArgType : TypeRange{getInputs()}) {
    auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
    if (inputElemShape != outputShape) {
      return emitOpError() << "expected shape of input (" << inputElemShape
                           << ") to match shape of output (" << outputShape
                           << ")";
    }
  }

  return success();
}
1551 
1552 SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1553  int64_t rank = getInit().getType().getRank();
1554  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1555 }
1556 
1557 ArrayAttr MapOp::getIndexingMaps() {
1558  Builder builder(getContext());
1559  int64_t rank = getInit().getType().getRank();
1560  int64_t numIndexingMaps = getOperands().size();
1562  numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
1563 }
1564 
1565 void MapOp::getEffects(
1567  &effects) {
1568  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1569 }
1570 
1571 Speculation::Speculatability MapOp::getSpeculatability() {
1572  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1573 }
1574 
1575 //===----------------------------------------------------------------------===//
1576 // ReduceOp
1577 //===----------------------------------------------------------------------===//
1578 
1579 void ReduceOp::getAsmBlockArgumentNames(Region &region,
1580  OpAsmSetValueNameFn setNameFn) {
1581  for (Value v : getRegionInputArgs())
1582  setNameFn(v, "in");
1583  for (Value v : getRegionOutputArgs())
1584  setNameFn(v, "init");
1585 }
1586 
1587 void ReduceOp::getAsmResultNames(
1588  function_ref<void(Value, StringRef)> setNameFn) {
1589  if (!getResults().empty())
1590  setNameFn(getResults().front(), "reduced");
1591 }
1592 
1593 void ReduceOp::build(
1594  OpBuilder &builder, OperationState &result, ValueRange inputs,
1595  ValueRange inits, ArrayRef<int64_t> dimensions,
1596  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1597  ArrayRef<NamedAttribute> attributes) {
1598  build(builder, result, TypeRange{}, inputs, inits, dimensions);
1599  result.addAttributes(attributes);
1600 
1601  // Add output types for `RankedTensorType` output arguments.
1602  for (Value init : inits) {
1603  Type initType = init.getType();
1604  if (llvm::isa<RankedTensorType>(initType))
1605  result.addTypes(initType);
1606  }
1607 
1608  if (bodyBuild)
1609  buildGenericRegion(builder, result.location, *result.regions.front(),
1610  inputs, inits, bodyBuild);
1611 }
1612 
1613 SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1614  int64_t inputRank =
1615  llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1616  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1617  utils::IteratorType::parallel);
1618  for (int64_t reductionDim : getDimensions())
1619  iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1620  return iteratorTypes;
1621 }
1622 
1623 ArrayAttr ReduceOp::getIndexingMaps() {
1624  int64_t inputRank =
1625  llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1626  SmallVector<AffineMap> affineMaps(
1627  getNumDpsInputs(),
1629  AffineMap resultMap =
1631  .dropResults(getDimensions());
1632  for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1633  affineMaps.push_back(resultMap);
1634  return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
1635 }
1636 
1637 void ReduceOp::getEffects(
1639  &effects) {
1640  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1641 }
1642 
1643 Speculation::Speculatability ReduceOp::getSpeculatability() {
1644  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1645 }
1646 
/// Parse `attributeName = <dense-i64-array>` and store the attribute under
/// `attributeName` in `attributes`.
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
                                          NamedAttrList &attributes,
                                          StringRef attributeName) {
  if (parser.parseKeyword(attributeName) || parser.parseEqual())
    return failure();

  attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
  return success();
}
1656 
1657 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1658  std::optional<OperationName> payloadOpName;
1659  NamedAttrList payloadOpAttrs;
1660  if (succeeded(parser.parseOptionalLBrace())) {
1661  FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1662  if (failed(operationName))
1663  return failure();
1664  if (parser.parseOptionalAttrDict(payloadOpAttrs))
1665  return failure();
1666  payloadOpName = operationName.value();
1667  if (parser.parseRBrace())
1668  return failure();
1669  }
1670 
1671  if (parseDstStyleOp(
1672  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1673  return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1674  }))
1675  return failure();
1676 
1677  if (payloadOpName.has_value()) {
1678  addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
1679  ArrayRef(result.operands), /*initFirst=*/true);
1680  } else {
1682  if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1683  /*allowType=*/true, /*allowAttrs=*/true)) {
1684  return failure();
1685  }
1686 
1687  Region *body = result.addRegion();
1688  if (parser.parseRegion(*body, regionArgs))
1689  return failure();
1690  }
1691 
1692  return success();
1693 }
1694 
/// Print ` attributeName = [v0, v1, ...] ` (with surrounding spaces).
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
                                   ArrayRef<int64_t> attributeValue) {
  p << ' ' << attributeName << " = [" << attributeValue << "] ";
}
1699 
void ReduceOp::print(OpAsmPrinter &p) {
  // Use the compact `{ payload-op }` form when the combiner is a single
  // payload op whose first operand is the init block argument.
  Block *mapper = getBody();
  Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
  if (payloadOp) {
    printShortForm(p, payloadOp);
  }

  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
  if (!payloadOp) {
    // Print region if the payload op was not detected.
    p.increaseIndent();
    p.printNewline();
    p << "(";
    llvm::interleaveComma(mapper->getArguments(), p,
                          [&](auto arg) { p.printRegionArgument(arg); });
    p << ") ";

    p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
    p.decreaseIndent();
  }
}
1723 
1724 LogicalResult ReduceOp::verify() {
// Verifier for linalg.reduce. Checks, in order: all inputs share a shape, all
// inits share a shape, reduction dimensions are in range, the init shape is
// exactly the input shape with the reduced dimensions removed, and the block
// arguments match operand element types. The diagnostic text below is relied
// on by tests, so it must not change.
1725  ArrayRef<int64_t> dimensionsRef = getDimensions();
1726 
1727  for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1728  if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
1729  llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
1730  return emitOpError() << "expects all inputs to have the same shapes. "
1731  "Shape at input-index "
1732  << i
1733  << " is not equal to the shape at input-index 0.";
1734  }
1735  }
1736  for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1737  if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
1738  llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
1739  return emitOpError() << "expects all outputs to have the same shapes. "
1740  "Shape at output-index "
1741  << i
1742  << " is not equal to the shape at output-index 0.";
1743  }
1744  }
// All inputs/inits share a shape, so index 0 is representative from here on.
1745  auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
1746  auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
1747 
1748  DenseSet<int64_t> dimensionsToReduce;
1749  for (int64_t dimension : dimensionsRef) {
1750  if (dimension < 0 || dimension >= inputType.getRank()) {
1751  return emitOpError()
1752  << "dimensions for reduction should be in the range [0, "
1753  << inputType.getRank() - 1 << "].";
1754  }
1755  dimensionsToReduce.insert(dimension);
1756  }
1757 
1758  auto inputDims = inputType.getShape();
1759  auto initDims = initType.getShape();
1760 
1761  // Input dimensions that will be left after the reduction.
1762  SmallVector<int64_t> reducedInputDims;
1763  for (const auto &en : llvm::enumerate(inputDims)) {
1764  if (!dimensionsToReduce.count(en.index()))
1765  reducedInputDims.push_back(en.value());
1766  }
1767 
1768  if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
1769  return emitOpError() << "number of dimensions after reduction "
1770  << reducedInputDims.size()
1771  << " doesn't match the init rank "
1772  << initType.getRank();
1773  }
1774 
1775  if (reducedInputDims != initDims)
1776  return emitOpError() << "init dimensions [" << initDims
1777  << "] doesn't match input dimensions after reduction ["
1778  << reducedInputDims << "]";
1779 
1780  Block *block = getBody();
1781  if (block->getNumArguments() != this->getNumOperands())
1782  return emitOpError()
1783  << "mismatching number of operands and block arguments";
1784 
1785  // Check that the first block arguments match the element type of the inputs.
1786  for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
1787  Type inputElementType =
1788  llvm::cast<ShapedType>(input.getType()).getElementType();
1789  if (inputElementType != bbArg.getType())
1790  return emitOpError()
1791  << "input element type " << inputElementType
1792  << " does not match corresponding block argument type "
1793  << bbArg.getType();
1794  }
1795 
1796  // Check that the last block arguments match the element type of the outputs.
1797  for (auto [output, bbArg] : llvm::zip(
1798  getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
1799  auto outputElementType =
1800  llvm::cast<ShapedType>(output.getType()).getElementType();
1801  if (outputElementType != bbArg.getType())
1802  return emitOpError()
1803  << "output element type " << outputElementType
1804  << " does not match corresponding block argument type "
1805  << bbArg.getType();
1806  }
1807  return success();
1808 }
1809 
1810 //===----------------------------------------------------------------------===//
1811 // TransposeOp
1812 //===----------------------------------------------------------------------===//
1813 
1814 static void buildIdentityRegion(OpBuilder &builder, Location loc,
1815  Region &region, ValueRange inputs,
1816  ValueRange outputs) {
1817  buildGenericRegion(builder, loc, region, inputs, outputs,
1818  [](OpBuilder &b, Location loc, ValueRange args) {
1819  if (!args.empty())
1820  b.create<linalg::YieldOp>(loc, args[0]);
1821  });
1822 }
1823 
1824 void TransposeOp::build(::mlir::OpBuilder &builder,
1825  ::mlir::OperationState &result, Value input, Value init,
1826  DenseI64ArrayAttr permutation,
1827  ArrayRef<NamedAttribute> attributes) {
1828  result.addOperands(input);
1829  result.addOperands(init);
1830  result.addAttribute(getPermutationAttrName(result.name), permutation);
1831  result.addAttributes(attributes);
1832 
1833  // Add output types for `RankedTensorType` output arguments.
1834  Type initType = init.getType();
1835  if (llvm::isa<RankedTensorType>(initType))
1836  result.addTypes(initType);
1837 
1838  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
1839  init);
1840 }
1841 
1842 void TransposeOp::build(::mlir::OpBuilder &builder,
1843  ::mlir::OperationState &result, Value input, Value init,
1844  ArrayRef<int64_t> permutation,
1845  ArrayRef<NamedAttribute> attributes) {
1846  build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
1847  attributes);
1848 }
1849 
1850 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
1851  if (failed(parseDstStyleOp(
1852  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1853  return parseDenseI64ArrayAttr(parser, attributes, "permutation");
1854  })))
1855  return failure();
1856 
1857  OpBuilder builder(parser.getContext());
1858  buildIdentityRegion(builder, result.location, *result.addRegion(),
1859  /*inputs=*/result.operands,
1860  /*outputs=*/{});
1861  return success();
1862 }
1863 
1864 void TransposeOp::getAsmResultNames(
1865  function_ref<void(Value, StringRef)> setNameFn) {
1866  if (!getResults().empty())
1867  setNameFn(getResults().front(), "transposed");
1868 }
1869 
// NOTE(review): the function header (original line 1870, presumably
// `void TransposeOp::print(OpAsmPrinter &p) {`) is missing from this
// rendering of the listing. The body prints the destination-style operand
// list, the `permutation` attribute, and the remaining attributes (eliding
// `permutation`, which was already printed).
1871  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1872  printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
1873  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
1874 }
1875 
1876 LogicalResult TransposeOp::verify() {
// Verifier for linalg.transpose: the permutation must be a valid permutation
// vector whose size equals the (matching) input/init rank, and each init dim
// must equal the input dim it is permuted from. Diagnostic text is relied on
// by tests and must not change.
1877  ArrayRef<int64_t> permutationRef = getPermutation();
1878 
1879  if (!isPermutationVector(permutationRef))
1880  return emitOpError("permutation is not valid");
1881 
1882  auto inputType = getInput().getType();
1883  auto initType = getInit().getType();
1884 
1885  int64_t rank = inputType.getRank();
1886 
1887  if (rank != initType.getRank())
1888  return emitOpError() << "input rank " << rank
1889  << " does not match init rank " << initType.getRank();
1890 
1891  if (rank != static_cast<int64_t>(permutationRef.size()))
1892  return emitOpError() << "size of permutation " << permutationRef.size()
1893  << " does not match the argument rank " << rank;
1894 
1895  auto inputDims = inputType.getShape();
1896  auto initDims = initType.getShape();
1897 
// dim(init, i) must equal dim(input, permutation[i]).
1898  for (int64_t i = 0; i < rank; ++i) {
1899  int64_t inputDim = inputDims[permutationRef[i]];
1900  int64_t initDim = initDims[i];
1901 
1902  if (inputDim != initDim) {
1903  return emitOpError() << "dim(result, " << i << ") = " << initDim
1904  << " doesn't match dim(input, permutation[" << i
1905  << "]) = " << inputDim;
1906  }
1907  }
1908 
1909  return success();
1910 }
1911 
1912 SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
1913  int64_t rank = getInit().getType().getRank();
1914  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1915 }
1916 
1917 ArrayAttr TransposeOp::getIndexingMaps() {
// Returns the indexing maps {input map, output map}: the output uses the
// multi-dim identity; the input map is derived from `permutation`.
1918  Builder builder(getContext());
1919  int64_t rank = getInit().getType().getRank();
1920  return builder.getAffineMapArrayAttr(
// NOTE(review): the line opening the initializer list (original line 1921,
// the expression that builds the permuted input map via
// AffineMap::getPermutationMap or similar) is missing from this rendering
// of the listing — confirm against upstream before editing.
1922  llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
1923  builder.getMultiDimIdentityMap(rank)});
1924 }
1925 
1926 void TransposeOp::getEffects(
// NOTE(review): the parameter line (original line 1927, presumably
// `SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>`) is
// missing from this rendering of the listing.
1928  &effects) {
// Memory effects are delegated to the generic LinalgOp implementation.
1929  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1930 }
1931 
1932 Speculation::Speculatability TransposeOp::getSpeculatability() {
1933  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1934 }
1935 
1936 LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
// Folds a tensor-semantics transpose to its input when the permutation is
// empty (rank-0/rank-1 style no-op) or the identity.
// NOTE(review): the second parameter line (original line 1937, presumably
// `SmallVectorImpl<OpFoldResult> &result) {`) is missing from this rendering
// of the listing.
1938  // Only the tensor type is supported.
1939  if (!isa<TensorType>(getInput().getType()))
1940  return failure();
1941 
1942  // Single dimension transpose.
1943  if (getPermutation().size() == 0) {
1944  result.push_back(getInput());
1945  return success();
1946  }
1947  // Identity permutation.
1948  if (isIdentityPermutation(getPermutation())) {
1949  result.push_back(getInput());
1950  return success();
1951  }
1952 
1953  return failure();
1954 }
1955 
1956 /// Fold transpose with transpose.
1957 struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
// Rewrites transpose(transpose(x, p1), p2) into a single transpose of x with
// the composed permutation foldedPerms[i] = p1[p2[i]].
// NOTE(review): the `using OpRewritePattern...::OpRewritePattern;` line
// (original line 1958) is missing from this rendering of the listing.
1959 
1960  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
1961  PatternRewriter &rewriter) const override {
1962  auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
1963  if (!defTransposeOp)
1964  return failure();
1965  ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
1966  ArrayRef<int64_t> perms = transposeOp.getPermutation();
1967  SmallVector<int64_t> foldedPerms;
1968  foldedPerms.reserve(perms.size());
1969  for (int64_t perm : perms)
1970  foldedPerms.push_back(defPerms[perm]);
1971 
// Reuse the outer transpose's init: the final shape is unchanged.
1972  rewriter.replaceOpWithNewOp<TransposeOp>(
1973  transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
1974  foldedPerms);
1975  return success();
1976  }
1977 };
1978 
1979 /// This pattern canonicalize transpose by swapping the order of
1980 /// broadcast and transpose:
1981 /// transpose(broadcast(input)) -> broadcast(transpose(input))
1982 struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
// NOTE(review): the `using OpRewritePattern...::OpRewritePattern;` line
// (original line 1983) is missing from this rendering of the listing.
1984 
1985  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
1986  PatternRewriter &rewriter) const override {
1987  Value input = transposeOp.getInput();
1988  BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
// Only fire when the broadcast result feeds this transpose exclusively, so
// the broadcast can be removed.
1989  if (!input.hasOneUse() || !broadcastOp)
1990  return failure();
1991 
1992  ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
1993  ArrayRef<int64_t> perms = transposeOp.getPermutation();
1994 
1995  // Get new perms and new dimensions.
// The inner transpose permutes only the non-broadcast dims; the new
// broadcast dims are the old ones mapped through the inverse permutation.
1996  SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
1997  SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
1998  SmallVector<int64_t> resultDimensions;
1999  unsigned dimensionSize = dimensions.size();
2000  for (unsigned i = 0; i < dimensionSize; ++i)
2001  resultDimensions.push_back(invertPerm[dimensions[i]]);
2002 
2003  // Create transpose result.
2004  Value broadcastInput = broadcastOp.getInput();
2005  Location loc = transposeOp.getLoc();
2006  MLIRContext *ctx = transposeOp.getContext();
// NOTE(review): the declaration of `dims` (original line 2007, presumably
// `SmallVector<OpFoldResult> dims;`) is missing from this rendering of the
// listing.
// Collect the broadcast input's sizes as OpFoldResults: tensor.dim for
// dynamic dims, index attributes for static ones.
2008  auto broadcastInputTy =
2009  mlir::cast<RankedTensorType>(broadcastInput.getType());
2010  unsigned inputRank = broadcastInputTy.getRank();
2011  for (unsigned i = 0; i < inputRank; ++i) {
2012  if (broadcastInputTy.isDynamicDim(i)) {
2013  dims.push_back(rewriter.create<tensor::DimOp>(loc, broadcastInput, i)
2014  ->getResult(0));
2015  } else {
2016  dims.push_back(IntegerAttr::get(IndexType::get(ctx),
2017  broadcastInputTy.getDimSize(i)));
2018  }
2019  }
2020  SmallVector<OpFoldResult> transposeResultShapes =
2021  applyPermutation(dims, resultPerms);
2022  Value transposeInit = rewriter.create<tensor::EmptyOp>(
2023  transposeOp.getLoc(), transposeResultShapes,
2024  broadcastInputTy.getElementType());
2025 
2026  // Create broadcast(transpose(input)).
2027  Value transposeResult =
2028  rewriter
2029  .create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
2030  resultPerms)
2031  ->getResult(0);
2032  rewriter.replaceOpWithNewOp<BroadcastOp>(
2033  transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2034  return success();
2035  }
2036 };
2037 
2038 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2039  MLIRContext *context) {
// NOTE(review): the pattern-registration statement (original line 2040,
// presumably `results.add<FoldTransposeWithTranspose,
// SwapTransposeWithBroadcast>(context);`) is missing from this rendering of
// the listing — the function body is not actually empty upstream.
2041 }
2042 
2043 //===----------------------------------------------------------------------===//
2044 // BroadcastOp
2045 //===----------------------------------------------------------------------===//
2046 
2047 void BroadcastOp::build(::mlir::OpBuilder &builder,
2048  ::mlir::OperationState &result, Value input, Value init,
2049  DenseI64ArrayAttr dimensions,
2050  ArrayRef<NamedAttribute> attributes) {
2051  result.addOperands(input);
2052  result.addOperands(init);
2053  result.addAttribute(getDimensionsAttrName(result.name), dimensions);
2054  result.addAttributes(attributes);
2055 
2056  // Add output types for `RankedTensorType` output arguments.
2057  Type initType = init.getType();
2058  if (llvm::isa<RankedTensorType>(initType))
2059  result.addTypes(initType);
2060 
2061  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
2062  init);
2063 }
2064 
2065 void BroadcastOp::build(::mlir::OpBuilder &builder,
2066  ::mlir::OperationState &result, Value input, Value init,
2067  ArrayRef<int64_t> dimensions,
2068  ArrayRef<NamedAttribute> attributes) {
2069  build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
2070  attributes);
2071 }
2072 
2073 ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
2074  if (failed(parseDstStyleOp(
2075  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2076  return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
2077  })))
2078  return failure();
2079 
2080  OpBuilder builder(parser.getContext());
2081  buildIdentityRegion(builder, result.location, *result.addRegion(),
2082  /*inputs=*/result.operands,
2083  /*outputs=*/{});
2084  return success();
2085 }
2086 
2087 void BroadcastOp::getAsmResultNames(
2088  function_ref<void(Value, StringRef)> setNameFn) {
2089  if (!getResults().empty())
2090  setNameFn(getResults().front(), "broadcasted");
2091 }
2092 
// NOTE(review): the function header (original line 2093, presumably
// `void BroadcastOp::print(OpAsmPrinter &p) {`) is missing from this
// rendering of the listing. The body prints the destination-style operand
// list, the `dimensions` attribute, and the remaining attributes (eliding
// `dimensions`, which was already printed).
2094  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2095  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
2096  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
2097 }
2098 
2099 LogicalResult BroadcastOp::verify() {
// Verifier for linalg.broadcast: input rank + number of broadcast dimensions
// must equal init rank, every broadcast dimension must be a valid init dim,
// and each non-broadcast init dim must match the corresponding input dim.
// Diagnostic text is relied on by tests and must not change.
2100  ArrayRef<int64_t> dimensionsRef = getDimensions();
2101 
2102  auto inputType = getInput().getType();
2103  auto initType = getInit().getType();
2104 
2105  int64_t inputRank = inputType.getRank();
2106  int64_t initRank = initType.getRank();
2107 
2108  auto inputShape = inputType.getShape();
2109  auto initShape = initType.getShape();
2110 
2111  if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
2112  return emitOpError() << "input rank plus added dimensions does not "
2113  "match init rank. input rank: "
2114  << inputRank
2115  << ", dimensions size: " << dimensionsRef.size()
2116  << ", init rank: " << initRank;
2117 
2118  for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2119  if (dim < 0 || dim >= initRank)
2120  return emitOpError() << "dimension " << idx
2121  << " is out of range. expected range: [0, "
2122  << initRank - 1 << "], got: " << dim;
2123  }
2124 
2125  // Mapping from input dims to init dims.
// Init dims not named in `dimensions` come from the input, in order.
2126  SmallVector<int64_t> dimMap;
2127  for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2128  if (!llvm::is_contained(dimensionsRef, dim))
2129  dimMap.push_back(dim);
2130  }
2131 
2132  for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2133  // This dimensions is mapped from the input. Init and input dims should
2134  // match.
2135  if (inputShape[inputDimIdx] != initShape[initDimIdx])
2136  return emitOpError() << "input dim " << inputDimIdx
2137  << " should match init dim " << initDimIdx
2138  << ". input: " << inputShape[inputDimIdx]
2139  << ", init: " << initShape[initDimIdx];
2140  }
2141 
2142  return success();
2143 }
2144 
2145 SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2146  int64_t rank = getInit().getType().getRank();
2147  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2148 }
2149 
2150 ArrayAttr BroadcastOp::getIndexingMaps() {
2151  Builder builder(getContext());
2152  int64_t rank = getInit().getType().getRank();
2153  return builder.getAffineMapArrayAttr(
2154  {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
2155  builder.getMultiDimIdentityMap(rank)});
2156 }
2157 
2158 void BroadcastOp::getEffects(
// NOTE(review): the parameter line (original line 2159, presumably
// `SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>`) is
// missing from this rendering of the listing.
2160  &effects) {
// Memory effects are delegated to the generic LinalgOp implementation.
2161  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2162 }
2163 
2164 Speculation::Speculatability BroadcastOp::getSpeculatability() {
2165  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2166 }
2167 
2168 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2169  MLIRContext *context) {
// A broadcast whose init already matches its input is a no-op; register the
// shared identity-erasure pattern.
2170  results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
2171 }
2172 
2173 //===----------------------------------------------------------------------===//
2174 // YieldOp
2175 //===----------------------------------------------------------------------===//
2176 
// NOTE(review): the function header (original line 2177, presumably
// `void linalg::YieldOp::print(OpAsmPrinter &p) {`) is missing from this
// rendering of the listing. The body prints the optional operand list, the
// attribute dictionary, and the operand types (only when operands exist).
2178  if (getNumOperands() > 0)
2179  p << ' ' << getOperands();
2180  p.printOptionalAttrDict((*this)->getAttrs());
2181  if (getNumOperands() > 0)
2182  p << " : " << getOperandTypes();
2183 }
2184 
2185 ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
// Parses `[operands] [attr-dict] [: types]`; the type list is only required
// when at least one operand was parsed.
// NOTE(review): the declaration of `opInfo` (original line 2186, presumably
// `SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;`) is missing from
// this rendering of the listing.
2187  SmallVector<Type, 2> types;
2188  SMLoc loc = parser.getCurrentLocation();
2189  return failure(parser.parseOperandList(opInfo) ||
2190  parser.parseOptionalAttrDict(result.attributes) ||
2191  (!opInfo.empty() && parser.parseColonTypeList(types)) ||
2192  parser.resolveOperands(opInfo, types, loc, result.operands));
2193 }
2194 
2195 // Check that the number and types of the yield operands match the element
2196 // types of the LinalgOp interface's shaped init operands.
2197 static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
// One yield value per DPS init; each must have the init's element type (or
// the init's own type when the init is a scalar). Diagnostic text is relied
// on by tests and must not change.
2198  if (op.getNumOperands() != linalgOp.getNumDpsInits())
2199  return op.emitOpError("expected number of yield values (")
2200  << op.getNumOperands()
2201  << ") to match the number of inits / outs operands of the enclosing "
2202  << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";
2203 
2204  for (OpOperand &opOperand : op->getOpOperands()) {
2205  OpOperand *outputOperand =
2206  linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2207  Type elementType = outputOperand->get().getType();
// Shaped inits compare against their element type; scalar inits compare
// against the type itself.
2208  if (isa<MemRefType, RankedTensorType>(elementType))
2209  elementType = getElementTypeOrSelf(outputOperand->get().getType());
2210  if (opOperand.get().getType() != elementType)
// Operand numbers in the diagnostic are 1-based.
2211  return op.emitOpError("type of yield operand ")
2212  << (opOperand.getOperandNumber() + 1) << " ("
2213  << opOperand.get().getType() << ") doesn't match "
2214  << "the element type of the enclosing linalg.generic op ("
2215  << elementType << ")";
2216  }
2217  return success();
2218 }
2219 
2220 LogicalResult linalg::YieldOp::verify() {
2221  auto *parentOp = (*this)->getParentOp();
2222  if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2223  return emitOpError("expected single non-empty parent region");
2224 
2225  if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2226  return verifyYield(*this, linalgOp);
2227 
2228  return emitOpError("expected parent op with LinalgOp interface");
2229 }
2230 
2231 //===----------------------------------------------------------------------===//
2232 // IndexOp
2233 //===----------------------------------------------------------------------===//
2234 
2235 LogicalResult IndexOp::verify() {
2236  auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2237  if (!linalgOp)
2238  return emitOpError("expected parent op with LinalgOp interface");
2239  if (linalgOp.getNumLoops() <= getDim())
2240  return emitOpError("expected dim (")
2241  << getDim() << ") to be lower than the number of loops ("
2242  << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
2243  return success();
2244 }
2245 
2246 /////// Operations corresponding to library calls defined with Tablegen ////////
2247 
2248 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2249 
2250 #define GET_OP_CLASSES
2251 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2252 
2253 #define GET_OP_CLASSES
2254 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2255 
2256 AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
2257  unsigned rank,
2258  MLIRContext *context) {
2259  if (maybeMap)
2260  return *maybeMap;
2261  if (rank == 0)
2262  return AffineMap::get(context);
2263  return AffineMap::getMultiDimIdentityMap(rank, context);
2264 }
2265 
// Builds `num` consecutive affine dim expressions starting at `startIdx`,
// advancing `startIdx` past them (in/out parameter).
// NOTE(review): the return-type line (original line 2266) and the local
// declaration of `res` (original line 2269, presumably
// `SmallVector<AffineExpr, 4> res;`) are missing from this rendering of the
// listing.
2267 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
2268  MLIRContext *context) {
2270  res.reserve(num);
2271  for (unsigned i = 0; i < num; ++i)
2272  res.push_back(getAffineDimExpr(startIdx++, context));
2273  return res;
2274 }
2275 
// NOTE(review): the function header (original lines 2276-2277, presumably
// `SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
// ArrayRef<AffineExpr> b) {`) is missing from this rendering of the listing.
// The body concatenates the two expression ranges into a single vector.
2278  auto rangeA = llvm::make_range(a.begin(), a.end());
2279  auto rangeB = llvm::make_range(b.begin(), b.end());
2280  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2281  return llvm::to_vector<4>(concatRanges);
2282 }
2283 
2284 static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
// Appends a mangled spelling of `t` to `ss` for library-call names:
// memrefs as `view<dims>x<elt>[as<space>]`, vectors as `vector<dims><elt>`,
// and signless int/index/float types via their own printer. Fails on any
// other type. The exact output format is contract; do not change it.
2285  if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
2286  ss << "view";
// Dynamic extents are represented by negative sentinels and mangle to "sx".
2287  for (auto size : memref.getShape())
2288  if (size < 0)
2289  ss << "sx";
2290  else
2291  ss << size << "x";
// Recurse on the element type (e.g. nested vector element).
2292  if (failed(appendMangledType(ss, memref.getElementType())))
2293  return failure();
// Only integer memory-space attributes can be mangled.
2294  if (auto as = memref.getMemorySpace()) {
2295  if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
2296  ss << "as" << attr.getInt();
2297  else
2298  return failure();
2299  }
2300  return success();
2301  }
2302  if (auto vec = llvm::dyn_cast<VectorType>(t)) {
2303  ss << "vector";
2304  llvm::interleave(
2305  vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
2306  if (failed(appendMangledType(ss, vec.getElementType())))
2307  return failure();
2308  return success();
2309  }
2310  if (t.isSignlessIntOrIndexOrFloat()) {
2311  ss << t;
2312  return success();
2313  }
2314  return failure();
2315 }
2316 
// NOTE(review): the function header (original line 2317, presumably
// `std::string mlir::linalg::generateLibraryCallName(Operation *op) {`) is
// missing from this rendering of the listing. Builds a library-call name
// from the op name (dots replaced by underscores), an optional unary/binary
// payload function name, and the mangled operand types; returns an empty
// string when any operand type cannot be mangled.
2318  assert(isa<LinalgOp>(op));
2319  std::string name(op->getName().getStringRef().str());
2320  std::string fun = "";
// A UnaryFn/BinaryFn attribute (e.g. elementwise payload kind) contributes
// its stringified name; the last such attribute wins.
2321  for (NamedAttribute kv : op->getAttrs()) {
2322  if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2323  fun = stringifyEnum(ufa.getValue()).str() + "_";
2324  } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2325  fun = stringifyEnum(bfa.getValue()).str() + "_";
2326  }
2327  }
2328  name.reserve(128);
2329  std::replace(name.begin(), name.end(), '.', '_');
2330  llvm::raw_string_ostream ss(name);
2331  ss << "_" << fun;
2332  for (Type t : op->getOperandTypes()) {
2333  if (failed(appendMangledType(ss, t)))
2334  return std::string();
2335  ss << "_";
2336  }
// Drop the trailing '_' appended after the last operand type.
2337  std::string res = ss.str();
2338  res.pop_back();
2339  return res;
2340 }
2341 
2342 //===----------------------------------------------------------------------===//
2343 // Canonicalizers and Folders.
2344 //===----------------------------------------------------------------------===//
2345 
2346 namespace {
2347 struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
// Erases LinalgOps that provably execute zero iterations because a memref
// operand has a static zero-sized dimension.
// NOTE(review): the `using OpInterfaceRewritePattern...;` line (original
// line 2348) is missing from this rendering of the listing.
2349 
2350  LogicalResult matchAndRewrite(LinalgOp op,
2351  PatternRewriter &rewriter) const override {
2352  for (OpOperand &opOperand : op->getOpOperands()) {
2353  // Linalg "inputs" may be either tensor or memref type.
2354  // tensor<0xelt_type> is a convention that may not always mean
2355  // "0 iterations". Only erase in cases we see memref<...x0x...>.
2356  auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2357  if (!mt)
2358  continue;
2359  if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2360  rewriter.eraseOp(op);
2361  return success();
2362  }
2363  }
2364  return failure();
2365  }
2366 };
2367 
2368 /// Fold LinalgOps with `tensor.cast` consumer if the `tensor.cast` has
2369 /// result that is more static than the linalg op.
2370 struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
// NOTE(review): the `using OpRewritePattern...;` line (original line 2371)
// is missing from this rendering of the listing.
2372 
2373  LogicalResult matchAndRewrite(tensor::CastOp castOp,
2374  PatternRewriter &rewriter) const override {
2375  if (!tensor::canFoldIntoProducerOp(castOp))
2376  return failure();
2377 
2378  auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2379  if (!linalgOp)
2380  return failure();
2381 
2382  // Cast can be in conditionally reachable region, in which case folding
2383  // will generate invalid code. Only conservatively fold ops in same block
2384  // for now.
2385  if (castOp->getBlock() != linalgOp->getBlock())
2386  return failure();
2387 
2388  OpBuilder::InsertionGuard guard(rewriter);
2389  rewriter.setInsertionPoint(linalgOp);
2390 
2391  Location loc = linalgOp.getLoc();
2392  OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2393  unsigned resultNumber = resultValue.getResultNumber();
2394  auto resultType =
2395  llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2396  // Replace the `outs` for the result with a `tensor.cast`. This cast is now
2397  // going from a more dynamic shape to a less dynamic shape. If the producer
2398  // for this cast, i.e. producer of the out operand, is also an operation
2399  // that folds with tensor.cast consumer (like this pattern), the cast will
2400  // continue to propagate as far up the stack as it can go.
2401  OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2402  Value newOperand =
2403  rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
2404  SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2405  SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2406  linalgOp.getDpsInits().end());
2407  outputOperands[resultNumber] = newOperand;
2408  newOperands.append(outputOperands.begin(), outputOperands.end());
2409 
2410  SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2411  linalgOp->result_type_end());
2412  resultTypes[resultNumber] = resultType;
// Re-create the linalg op with the more-static result type.
2413  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2414 
2415  // Create a tensor.cast operation back to the original type.
// The cast back preserves the types seen by the old op's other users.
2416  Value castBack = rewriter.create<tensor::CastOp>(
2417  loc, resultValue.getType(), newOp->getResult(resultNumber));
2418 
2419  SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
2420  results[resultNumber] = castBack;
2421  rewriter.replaceOp(linalgOp, results);
2422  rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
2423  return success();
2424  }
2425 };
2426 
2427 /// For each operand in `operands`, this function maps the static sizes of
2428 /// its dimensions to their affine dim expressions.
2429 static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2430  llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
// For every shaped operand, records into `affineExprToSize` the static size
// associated with each AffineDimExpr of its indexing map. Sizes seen through
// a defining static-shaped tensor.cast take precedence over the operand's
// (possibly dynamic) declared shape.
2431  for (OpOperand &opOperand : operands) {
2432  if (linalgOp.isScalar(&opOperand))
2433  continue;
2434  Value src = opOperand.get();
2435  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2436  auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2437 
2438  // Get the `sourceShape` of the `sourceType`. If the operand is a result of
2439  // `tensor.cast` operation and source of the cast operation has a static
2440  // shape, then assign it to the `sourceShape`.
2441  auto *parentOp = src.getDefiningOp();
2442  ArrayRef<int64_t> sourceShape = sourceType.getShape();
2443  if (parentOp) {
2444  if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2445  Value castSource = castOp.getSource();
2446  auto castSourceType =
2447  llvm::dyn_cast<RankedTensorType>(castSource.getType());
2448  if (castSourceType && castSourceType.hasStaticShape())
2449  sourceShape = castSourceType.getShape();
2450  }
2451  }
2452 
2453  // If the source shape's dimension has a static shape, map the affine dim
2454  // expression to the known static size.
2455  for (unsigned i = 0; i < sourceShape.size(); i++) {
2456  if (sourceType.isDynamicDim(i))
2457  continue;
// try_emplace keeps the first recorded size if a dim expr repeats.
2458  if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2459  affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2460  }
2461  }
2462 }
2463 
2464 /// Creates a new operand w.r.t. `opOperand` of `linalgOp`, using the static
2465 /// sizes mapped in `affineExprToSize`. New operands are appended to
2466 /// `newOperands` and their result types are stored in `resultTypes`. If
2467 /// `opOperand` requires no change, `changeNeeded` is left untouched and the
2468 /// same operand is added to the `newOperands` list.
2469 static void createNewOperandWithStaticSizes(
2470  Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
2471  llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
2472  SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
2473  bool &changeNeeded) {
// Appends `opOperand` (possibly wrapped in a tensor.cast to a more static
// type) to `newOperands`; appends the chosen type to `resultTypes` for DPS
// init operands; sets `changeNeeded` only when a cast was introduced.
2474  Value src = opOperand->get();
2475  newOperands.push_back(src);
2476  if (linalgOp.isScalar(opOperand))
2477  return;
2478  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2479  Type resultType = sourceType;
// Already fully static init: nothing to refine, keep its type as-is.
2480  if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2481  resultTypes.push_back(resultType);
2482  return;
2483  }
2484  ArrayRef<int64_t> sourceShape = sourceType.getShape();
2485  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2486  SmallVector<int64_t> newShape;
2487  // If operand is updated with new shape, `newOperandNeeded` will be
2488  // true.
2489  bool newOperandNeeded = false;
2490  for (unsigned i = 0; i < sourceShape.size(); i++) {
2491  int64_t dimShape = sourceShape[i];
2492  AffineExpr dimExpr = sourceMap.getResult(i);
2493  if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2494  newShape.push_back(dimShape);
2495  continue;
2496  }
2497  // Dimension has a dynamic shape and corresponding affine dim
2498  // expression is present in the map. So assign the size for the
2499  // given affine dim expression to the dimension.
2500  newShape.push_back(affineExprToSize[dimExpr]);
2501  newOperandNeeded = true;
2502  }
2503  resultType = RankedTensorType::get(newShape, sourceType.getElementType());
2504  if (newOperandNeeded) {
2505  changeNeeded = true;
2506  // Get the new operand value given its size and element type by
2507  // casting it.
2508  Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
2509  unsigned index = opOperand->getOperandNumber();
// Overwrite the placeholder pushed at the top of this function.
2510  newOperands[index] = newOperand;
2511  }
2512  if (linalgOp.isDpsInit(opOperand))
2513  resultTypes.push_back(resultType);
2514 }
2515 
2516 /// Static shapes for the operands can be inferred if any one of the operands
2517 /// have a static shape. This can be done by referring to the affine dim
2518 /// expressions for the operand.
2519 struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
// NOTE(review): the `using OpInterfaceRewritePattern...;` line (original
// line 2520) is missing from this rendering of the listing.
2521 
2522  LogicalResult matchAndRewrite(LinalgOp linalgOp,
2523  PatternRewriter &rewriter) const override {
// Only pure-tensor ops with projected-permutation maps can be refined this
// way (each map result is a plain dim expression).
2524  if (!linalgOp.hasPureTensorSemantics())
2525  return failure();
2526 
2527  // Maps must be projected permutations.
2528  if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2529  return !map.isProjectedPermutation();
2530  }))
2531  return failure();
2532 
2533  // Maps affine dim expressions to the static size of that dimension.
2534  llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2535  Location loc = linalgOp.getLoc();
2536 
2537  // For each of the affine dim expression, check if the size is known. If
2538  // known add that in the map.
2539  populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2540 
2541  SmallVector<Value> newOperands;
2542  SmallVector<Type> resultTypes;
2543 
2544  // `changeNeeded` is `false` if the operands of `linalgOp` require no
2545  // change in their types.
2546  bool changeNeeded = false;
2547  newOperands.reserve(linalgOp->getNumOperands());
2548  resultTypes.reserve(linalgOp.getNumDpsInits());
2549 
2550  // Iterate over all the operands and update the static sizes.
2551  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2552  createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2553  affineExprToSize, linalgOp, newOperands,
2554  resultTypes, changeNeeded);
2555  }
2556 
2557  // If the generic op has all the required static information, no
2558  // canonicalization needed.
2559  if (!changeNeeded)
2560  return failure();
2561 
2562  // Clone op.
2563  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2564  SmallVector<Value> replacements;
2565  replacements.reserve(newOp->getNumResults());
// Cast refined results back to the original types so existing users keep
// type-correct values.
2566  for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2567  Value newResult = std::get<1>(it);
2568  Value oldResult = std::get<0>(it);
2569  Type newType = newResult.getType();
2570  Type oldType = oldResult.getType();
2571  replacements.push_back(
2572  (newType != oldType)
2573  ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
2574  : newResult);
2575  }
2576  rewriter.replaceOp(linalgOp, replacements);
2577  return success();
2578  }
2579 };
2580 
2581 } // namespace
2582 
2583 // All named ops canonicalizers and folders are auto-generated in the
2584 // .cpp.inc.
2585 
2586 //===----------------------------------------------------------------------===//
2587 // SoftmaxOp
2588 //===----------------------------------------------------------------------===//
2589 
2590 LogicalResult SoftmaxOp::verify() {
2591  ShapedType inputType = getInputOperandType();
2592  ShapedType outputType = getOutputOperandType();
2593 
2594  ArrayRef<int64_t> inputShape = inputType.getShape();
2595  ArrayRef<int64_t> outputShape = outputType.getShape();
2596  if (failed(verifyCompatibleShape(inputShape, outputShape)))
2597  return emitOpError("incompatible output shape");
2598 
2599  int64_t inputRank = getInputOperandRank();
2600  int64_t dimension = getDimension();
2601  if ((dimension < 0) || (dimension >= inputRank))
2602  return emitOpError("incorrect dimension specified");
2603 
2604  return success();
2605 }
2606 
2607 SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2608  int64_t operandRank = getInputOperandRank();
2609  SmallVector<Range> loopBounds(operandRank);
2610  Location loc = getLoc();
2611  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
2612  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
2613  Value source = getInput();
2614  for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
2615  loopBounds[dim].offset = zero;
2616  loopBounds[dim].size = getDimValue(builder, loc, source, dim);
2617  loopBounds[dim].stride = one;
2618  }
2619  return loopBounds;
2620 }
2621 
2622 SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2623  SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2624  utils::IteratorType::parallel);
2625  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2626  return iteratorTypes;
2627 }
2628 
2629 FailureOr<TilingResult>
2630 SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2631  ArrayRef<OpFoldResult> offsets,
2632  ArrayRef<OpFoldResult> sizes) {
2633  int64_t rank = getInputOperandRank();
2634  auto oneAttr = builder.getI64IntegerAttr(1);
2635  SmallVector<OpFoldResult> strides(rank, oneAttr);
2636  SmallVector<Value> tiledOperands;
2637  tiledOperands.emplace_back(
2638  getSlice(builder, getLoc(), getInput(), offsets, sizes, strides));
2639  tiledOperands.emplace_back(
2640  getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides));
2641 
2642  SmallVector<Type, 4> resultTypes;
2643  if (hasPureTensorSemantics())
2644  resultTypes.push_back(tiledOperands[1].getType());
2645  Operation *tiledOp =
2646  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2647 
2648  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
2649 }
2650 
2651 LogicalResult SoftmaxOp::getResultTilePosition(
2652  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2653  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2654  SmallVector<OpFoldResult> &resultSizes) {
2655  if (resultNumber == 0) {
2656  resultOffsets.assign(offsets.begin(), offsets.end());
2657  resultSizes.assign(sizes.begin(), sizes.end());
2658  return success();
2659  }
2660  return failure();
2661 }
2662 
2663 // cast(dynamic) -> static.
// Folds memref.cast producers into the op's operands in place; it never
// produces replacement values, so the result list stays empty.
2664 LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2665  return memref::foldMemRefCast(*this);
2666 }
2667 
// Reify the result shape one dimension at a time: static dims become
// IndexAttrs taken from the input type, dynamic dims become Values derived
// from a dim op on the input.
2668 LogicalResult
2670  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2672  Location loc = getOperation()->getLoc();
2673  IRRewriter rewriter(b);
2674  auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2675  auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2676  for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2677  if (!outputShapedType.isDynamicDim(dim)) {
2678  // Static dim: Return IntegerAttr.
2679  shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
2680  } else {
2681  // Dynamic dim: Return Value.
2682  OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
2683  shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
2684  }
2685  }
2686  reifiedReturnShapes.emplace_back(std::move(shapes));
2687  return success();
2688 }
2689 
2690 void SoftmaxOp::getEffects(
2692  &effects) {
 // Only memref operands carry memory effects; tensor operands are pure
 // values and are skipped.
2693  for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
2694  if (!llvm::isa<MemRefType>(operand.getType()))
2695  continue;
2696  effects.emplace_back(MemoryEffects::Read::get(),
2697  &getOperation()->getOpOperand(index), /*stage=*/0,
2698  /*effectOnFullRegion=*/true,
2700  }
2701 
 // Init (output) operands on memrefs are both read and written over their
 // full region.
2702  for (OpOperand &operand : getDpsInitsMutable()) {
2703  if (!llvm::isa<MemRefType>(operand.get().getType()))
2704  continue;
2705  effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
2706  /*effectOnFullRegion=*/true,
2708  effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
2709  /*effectOnFullRegion=*/true,
2711  }
2712 }
2713 
2714 // Helper functions for softmax decomposition.
2715 // @{
2716 
2717 // Helper function to produce the iterator types (reduction or parallel) and
2718 // affine maps for the iterators used in the decomposition of softmax.
2719 // This method creates:
2720 // If allParallel == true:
2721 // - iterator type: {parallel, ..., parallel}
2722 // - affine maps:
2723 // -- identity with inputRank dimensions.
2724 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2725 // where N == inputRank.
2726 //
2727 // If allParallel == false:
2728 // - iterator type at dim(i) == parallel for i != \p dim and
2729 // dim(dim) == reduction.
2730 // - affine map:
2731 // -- identity with inputRank dimensions.
2732 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2733 // where N == inputRank.
2734 static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
2736  int64_t dim, bool allParallel = false) {
2737  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
2738  utils::IteratorType::parallel);
2739  if (!allParallel)
2740  iteratorTypes[dim] = utils::IteratorType::reduction;
2741  MLIRContext *ctxt = builder.getContext();
2742  auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt)
2743  SmallVector<AffineExpr, 2> affineExprs;
 // The second map drops dimension `dim`, indexing the rank-reduced operand.
2744  for (int i = 0; i < inputRank; i++) {
2745  if (i != dim)
2746  affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
2747  }
2748  auto reductionMap =
2749  AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
2750  SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
2751  return std::make_tuple(iteratorTypes, indexingMaps);
2752 }
2753 
2754 // Helper function to produce a linalg.generic that computes a reduction on
2755 // dimension \p dim with the operation type \p T.
2756 template <typename T>
2757 static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
2758  int64_t dim) {
2759  auto inputType = cast<ShapedType>(input.getType());
2760  ArrayRef<int64_t> inputShape = inputType.getShape();
2761  int64_t inputRank = inputShape.size();
2762  auto [iteratorTypes, indexingMaps] =
2763  computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
2764  assert(indexingMaps.size() == 2 &&
2765  "We should have two maps: 1 for the input, 1 for the output");
2766  assert(indexingMaps[0].isIdentity() && "input map should be identity");
2767 
2768  auto genericOp = builder.create<linalg::GenericOp>(
2769  loc, output.getType(), input, output, indexingMaps, iteratorTypes,
2770  [&](OpBuilder &b, Location loc, ValueRange args) {
2771  Value result = b.create<T>(loc, args[0], args[1]);
2772  b.create<linalg::YieldOp>(loc, result);
2773  });
2774  return genericOp.getResult(0);
2775 }
2776 
2777 /// Produce a linalg generic that computes the second step of the softmax
2778 /// decomposition: res = exp(input - max), where \p max is the max of \p input
2779 /// on dimension \p dim.
2780 static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
2781  Value max, Value output, int64_t dim) {
2782  auto inputType = cast<ShapedType>(input.getType());
2783  ArrayRef<int64_t> inputShape = inputType.getShape();
2784  int64_t inputRank = inputShape.size();
2785  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2786  builder, inputRank, dim, /*allParallel=*/true);
2787  assert(indexingMaps.size() == 2 && "We should have one map for each input");
2788  assert(indexingMaps[0].isIdentity() && "input map should be identity");
2789  // Add the affine map for the output argument.
2790  indexingMaps.push_back(indexingMaps[0]);
2791  auto genericOp = builder.create<linalg::GenericOp>(
2792  loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
2793  iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
2794  Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
2795  Value result = b.create<math::ExpOp>(loc, diff);
2796  b.create<linalg::YieldOp>(loc, result);
2797  });
2798  return genericOp.getResult(0);
2799 }
2800 
2801 /// Produce a linalg generic that computes the final step of the softmax
2802 /// decomposition.
2803 /// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
2804 /// yield n / d
2805 /// }
2806 static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
2807  Value denominator, Value output, int64_t dim) {
2808  auto inputType = cast<ShapedType>(numerator.getType());
2809  ArrayRef<int64_t> inputShape = inputType.getShape();
2810  int64_t inputRank = inputShape.size();
2811  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2812  builder, inputRank, dim, /*allParallel=*/true);
2813  assert(indexingMaps.size() == 2 &&
2814  "We should have one map for each input (2)");
2815  assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
2816  // Add the affine map for the output tensor.
2817  indexingMaps.push_back(indexingMaps[0]);
2818  auto genericOp = builder.create<linalg::GenericOp>(
2819  loc, numerator.getType(), ValueRange{numerator, denominator}, output,
2820  indexingMaps, iteratorTypes,
2821  [&](OpBuilder &b, Location loc, ValueRange args) {
2822  Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
2823  b.create<linalg::YieldOp>(loc, result);
2824  });
2825  return genericOp.getResult(0);
2826 }
2827 // @} End helper functions for softmax decomposition.
2828 
2829 /// Given an N-dimensional tensor x, this method converts
2830 /// softmax(x) to the following sequence of operations:
2831 ///
2832 /// 1. Compute the max of x along dimension d. This results
2833 /// in a N-1 dimensional tensor m.
2834 /// m = max(x, dim = d)
2835 ///
2836 /// 2. Subtract a broadcasted m from x and exponentiate. This results in
2837 /// a N dimensional tensor z.
2838 /// z = exp(x - m)
2839 ///
2840 /// 3. Compute the sum of z along dimension d. This results in
2841 /// a N-1 dimensional tensor l.
2842 /// l = sum(z, dim = d)
2843 ///
2844 /// 4. Divide z and l. This gives the N-dimensional softmax.
2845 /// softmax = z / l
2846 ///
2847 FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
2848  OpBuilder::InsertionGuard guard(b);
2849  b.setInsertionPoint(*this);
2850  Location loc = getLoc();
2851  Value input = getInput();
2852  ShapedType inputType = getInputOperandType();
2853  Type elementType = inputType.getElementType();
2854  int64_t reductionDim = getDimension();
2855  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
2856  Value output = getOutput();
2857  dims.erase(dims.begin() + reductionDim);
2858  // Step 1: Compute max along dim.
2859  Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
2860  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
2861  elementType, b, loc,
2862  /*useOnlyFiniteValue=*/true);
2863  Value neutralForMaxFInit =
2864  b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
2865  .result();
2866  Value max =
2867  reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
2868 
2869  // Step 2: Subtract max from input and exponentiate.
2870  Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
2871 
2872  // Step 3: Compute sum along dim.
2873  Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
2874  b, loc, /*useOnlyFiniteValue=*/true);
2875  Value zeroInit =
2876  b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
2877  Value denominator =
2878  reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
2879 
2880  // Step 4: Compute softmax.
2881  Value result =
2882  buildDivOp(b, loc, numerator, denominator, output, reductionDim);
2883  return SmallVector<Value>{result};
2884 }
2885 
2886 //===----------------------------------------------------------------------===//
2887 // WinogradFilterTransformOp
2888 //===----------------------------------------------------------------------===//
2889 
2890 LogicalResult WinogradFilterTransformOp::verify() {
2891  auto filterType = cast<ShapedType>(getFilter().getType());
2892  ArrayRef<int64_t> filterShape = filterType.getShape();
2893  int64_t filterH = filterShape[getFilterHDim()];
2894  int64_t filterW = filterShape[getFilterWDim()];
2895  int64_t r = getR();
2896  int64_t m = getM();
2897 
2898  if (filterH != r && filterH != 1)
2899  return emitOpError("expect filter height either equals to r or 1");
2900  if (filterW != r && filterW != 1)
2901  return emitOpError("expect filter width either equals to r or 1");
2902  if (filterH == 1 && filterW == 1)
2903  return emitOpError("expect either filter height or width equals to r");
2904 
2905  SmallVector<int64_t> expectedOutputShape;
2906  expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
2907  expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
2908  expectedOutputShape.push_back(filterShape[getFilterCDim()]);
2909  expectedOutputShape.push_back(filterShape[getFilterFDim()]);
2910 
2911  auto outputType = cast<ShapedType>(getOutput().getType());
2912  ArrayRef<int64_t> outputShape = outputType.getShape();
2913  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
2914  return emitOpError("the output shape is not expected");
2915  }
2916  return success();
2917 }
2918 
2920 WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
 // One loop per filter dimension: [0, extent) with stride 1.
2921  Location loc = getLoc();
2922  IntegerAttr zeroAttr = builder.getIndexAttr(0);
2923  IntegerAttr oneAttr = builder.getIndexAttr(1);
2924  Value filter = getFilter();
2925  int64_t filterRank = getFilterOperandRank();
2926  SmallVector<Range> loopBounds(filterRank);
2927  for (unsigned dim = 0; dim < filterRank; ++dim) {
2928  loopBounds[dim].offset = zeroAttr;
2929  loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
2930  loopBounds[dim].stride = oneAttr;
2931  }
2932  return loopBounds;
2933 }
2934 
2936 WinogradFilterTransformOp::getLoopIteratorTypes() {
 // The filter transform has no reductions; every loop is parallel.
2937  int64_t filterRank = getFilterOperandRank();
2938  SmallVector<utils::IteratorType> iteratorTypes(filterRank,
2939  utils::IteratorType::parallel);
2940  return iteratorTypes;
2941 }
2942 
2943 LogicalResult WinogradFilterTransformOp::getResultTilePosition(
2944  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2945  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2946  SmallVector<OpFoldResult> &resultSizes) {
2947  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
2948  ShapedType filterType = getFilterOperandType();
2949  ArrayRef<int64_t> filterShape = filterType.getShape();
2950  int64_t filterH = filterShape[getFilterHDim()];
2951  int64_t filterW = filterShape[getFilterWDim()];
2952  int64_t m = getM();
2953  int64_t r = getR();
2954  int64_t alpha = m + r - 1;
2955  int64_t alphaH = filterH != 1 ? alpha : 1;
2956  int64_t alphaW = filterW != 1 ? alpha : 1;
2957  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
2958  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
2959 
2960  resultOffsets.append(
2961  {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
2962  resultSizes.append(
2963  {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
2964 
2965  return success();
2966 }
2967 
2968 /// Implement tiling for winograd_filter_transform
2969 /// The input of winograd_filter_transform is (F, KH, KW, C).
2970 /// The output of winograd_filter_transform is (alphaH, alphaW, C, F)
2971 /// Users can specify the tile sizes of F and C.
2972 /// `offsets` are the values for the offsets of F, KH, KW, C for one tile.
2973 /// `sizes` are the values for the sizes of F, KH, KW, C for one tile.
2974 FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
2975  OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
2976  ArrayRef<OpFoldResult> sizes) {
2977  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
2978  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
2979  ShapedType filterType = getFilterOperandType();
2980  ArrayRef<int64_t> filterShape = filterType.getShape();
2981  int64_t filterH = filterShape[getFilterHDim()];
2982  int64_t filterW = filterShape[getFilterWDim()];
2983  IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
2984  IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
2985  SmallVector<Value> tiledOperands;
2986  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
2987 
 // The filter slice keeps the full KH/KW extents; only F and C are tiled.
2988  sliceOffsets.append(
2989  {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
2990  sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
2991  sizes[getFilterCDim()]});
2992  int64_t filterRank = getFilterOperandRank();
2993  SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
2994  Location loc = getLoc();
2995  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
2996  loc, getFilter(), sliceOffsets, sliceSizes, filterStrides));
2997 
 // NOTE(review): resultNumber is passed as 1 even though the op has a single
 // result; getResultTilePosition ignores the argument — confirm 0 was meant.
2998  SmallVector<OpFoldResult> resultOffsets, resultSizes;
2999  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3000  resultSizes)))
3001  return failure();
3002 
3003  int64_t outputRank = getOutputOperandRank();
3004  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3005  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
3006  loc, getOutput(), resultOffsets, resultSizes, outputStrides));
3007 
 // The cloned op's result type is the output slice's type.
3008  SmallVector<Type> resultTypes;
3009  resultTypes.push_back(tiledOperands[1].getType());
3010  Operation *tiledOp =
3011  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3012 
3013  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
3014 }
3015 
3016 //===----------------------------------------------------------------------===//
3017 // WinogradInputTransformOp
3018 //===----------------------------------------------------------------------===//
3019 
3020 LogicalResult WinogradInputTransformOp::verify() {
3021  auto inputType = cast<ShapedType>(getInput().getType());
3022  ArrayRef<int64_t> inputShape = inputType.getShape();
3023  int64_t inputH = inputShape[getInputHDim()];
3024  int64_t inputW = inputShape[getInputWDim()];
3025  int m = getM();
3026  int r = getR();
3027  int64_t tileSize = m + r - 1;
 // An input dimension of extent 1 means no transform is applied along it.
3028  bool leftTransform = inputH != 1;
3029  bool rightTransform = inputW != 1;
3030 
 // Expected 6-D output layout: (alphaH, alphaW, tileH, tileW, N, C); the
 // tile counts are (extent - (r - 1)) / m when the extent is static.
3031  SmallVector<int64_t> expectedOutputShape(6, inputH);
3032  if (ShapedType::isDynamic(inputH)) {
3033  expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3034  expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3035  } else {
3036  expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3037  expectedOutputShape[getOutputTileHDim()] =
3038  leftTransform ? (inputH - (r - 1)) / m : 1;
3039  }
3040  if (ShapedType::isDynamic(inputW)) {
3041  expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3042  expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3043  } else {
3044  expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3045  expectedOutputShape[getOutputTileWDim()] =
3046  rightTransform ? (inputW - (r - 1)) / m : 1;
3047  }
3048  expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3049  expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3050 
3051  auto outputType = cast<ShapedType>(getOutput().getType());
3052  ArrayRef<int64_t> outputShape = outputType.getShape();
3053  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3054  return emitOpError("the output shape is not expected");
3055  }
3056  return success();
3057 }
3058 
3060 WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
 // The iteration domain is defined over the 6-D output: one loop per output
 // dimension, [0, extent) with stride 1.
3061  Location loc = getLoc();
3062  IntegerAttr zeroAttr = builder.getIndexAttr(0);
3063  IntegerAttr oneAttr = builder.getIndexAttr(1);
3064  Value output = getOutput();
3065  int64_t outputRank = getOutputOperandRank();
3066  SmallVector<Range> loopBounds(outputRank);
3067  for (unsigned dim = 0; dim < outputRank; ++dim) {
3068  loopBounds[dim].offset = zeroAttr;
3069  // alphaH, alphaW, tileH, tileW, N, C
3070  loopBounds[dim].size = getDimValue(builder, loc, output, dim);
3071  loopBounds[dim].stride = oneAttr;
3072  }
3073  return loopBounds;
3074 }
3075 
3077 WinogradInputTransformOp::getLoopIteratorTypes() {
 // No reductions; every output dimension is computed in parallel.
3078  int64_t outputRank = getOutputOperandRank();
3079  SmallVector<utils::IteratorType> iteratorTypes(outputRank,
3080  utils::IteratorType::parallel);
3081  return iteratorTypes;
3082 }
3083 
3084 LogicalResult WinogradInputTransformOp::getResultTilePosition(
3085  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3086  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3087  SmallVector<OpFoldResult> &resultSizes) {
3088  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3089  ShapedType inputType = getInputOperandType();
3090  ArrayRef<int64_t> inputShape = inputType.getShape();
3091  int64_t inputH = inputShape[getInputHDim()];
3092  int64_t inputW = inputShape[getInputWDim()];
3093  int64_t m = getM();
3094  int64_t r = getR();
3095  int64_t alpha = m + r - 1;
3096  int64_t alphaH = inputH != 1 ? alpha : 1;
3097  int64_t alphaW = inputW != 1 ? alpha : 1;
3098  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3099  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3100 
3101  resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3102  offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3103  offsets[getOutputCDim()]});
3104  resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3105  sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3106  sizes[getOutputCDim()]});
3107 
3108  return success();
3109 }
3110 
3111 /// Implement tiling for winograd_input_transform
3112 /// The input of winograd_input_transform is (N, H, W, C).
3113 /// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW, N,
3114 /// C) Users can specify the tile sizes of tileH, tileW, N, and C. `offsets` are
3115 /// the values for the offsets of tileH, tileW, N, C for one tile. `sizes` are
3116 /// the values for the sizes of tileH, tileW, N, C for one tile.
3117 FailureOr<TilingResult>
3118 WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
3119  ArrayRef<OpFoldResult> offsets,
3120  ArrayRef<OpFoldResult> sizes) {
3121  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3122  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3123  ShapedType inputType = getInputOperandType();
3124  ArrayRef<int64_t> inputShape = inputType.getShape();
3125  int64_t inputH = inputShape[getInputHDim()];
3126  int64_t inputW = inputShape[getInputWDim()];
3127  int64_t m = getM();
3128  int64_t r = getR();
3129 
3130  Location loc = getLoc();
3131  MLIRContext *context = builder.getContext();
 // Input offset = tileIndex * m; input size = tileCount * m + (r - 1):
 // adjacent Winograd tiles overlap by r - 1 rows/columns.
3132  auto offsetAffineMap =
3133  AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3134  Value mappedOffsetH = affine::makeComposedAffineApply(
3135  builder, loc, offsetAffineMap, offsets[getOutputTileHDim()]);
3136  Value mappedOffsetW = affine::makeComposedAffineApply(
3137  builder, loc, offsetAffineMap, offsets[getOutputTileWDim()]);
3138  auto sizeAffineMap = AffineMap::get(
3139  1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
3140  Value mappedSizeH = affine::makeComposedAffineApply(
3141  builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3142  Value mappedSizeW = affine::makeComposedAffineApply(
3143  builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3144 
3145  SmallVector<Value> tiledOperands;
3146  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3147 
 // A spatial input dimension of extent 1 is untransformed: offset 0, size 1.
3148  OpFoldResult offsetH =
3149  inputH != 1 ? OpFoldResult(mappedOffsetH) : OpFoldResult(zeroAttr);
3150  OpFoldResult offsetW =
3151  inputW != 1 ? OpFoldResult(mappedOffsetW) : OpFoldResult(zeroAttr);
3152  sliceOffsets.append(
3153  {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
3154  OpFoldResult sizeH =
3155  inputH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3156  OpFoldResult sizeW =
3157  inputW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3158  sliceSizes.append(
3159  {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3160  int64_t inputRank = getInputOperandRank();
3161  SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
3162  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
3163  loc, getInput(), sliceOffsets, sliceSizes, inputStrides));
3164 
 // NOTE(review): resultNumber is passed as 1 even though the op has a single
 // result; getResultTilePosition ignores the argument — confirm 0 was meant.
3165  SmallVector<OpFoldResult> resultOffsets, resultSizes;
3166  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3167  resultSizes)))
3168  return failure();
3169 
3170  int64_t outputRank = getOutputOperandRank();
3171  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3172  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
3173  loc, getOutput(), resultOffsets, resultSizes, outputStrides));
3174 
 // The cloned op's result type is the output slice's type.
3175  SmallVector<Type> resultTypes;
3176  resultTypes.push_back(tiledOperands[1].getType());
3177  Operation *tiledOp =
3178  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3179 
3180  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
3181 }
3182 
3183 //===----------------------------------------------------------------------===//
3184 // WinogradOutputTransformOp
3185 //===----------------------------------------------------------------------===//
3186 
3187 LogicalResult WinogradOutputTransformOp::verify() {
3188  auto valueType = cast<ShapedType>(getValue().getType());
3189  ArrayRef<int64_t> valueShape = valueType.getShape();
3190  int64_t valueH = valueShape[getValueAlphaHDim()];
3191  int64_t valueW = valueShape[getValueAlphaWDim()];
3192  int64_t valueTileH = valueShape[getValueTileHDim()];
3193  int64_t valueTileW = valueShape[getValueTileWDim()];
3194  int m = getM();
3195  int r = getR();
 // An alpha dimension of extent 1 means no transform was applied along it.
3196  bool leftTransform = valueH != 1;
3197  bool rightTransform = valueW != 1;
3198 
 // Expected output layout: (N, H, W, F) with H = m * tileH and
 // W = m * tileW (or 1 along an untransformed dimension).
3199  int64_t outputRank = getOutputOperandRank();
3200  SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
3201  if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3202  expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
3203  } else {
3204  if (valueH != (leftTransform ? m + r - 1 : 1))
3205  return emitOpError("expect input height equals to input tile size");
3206  expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
3207  }
3208  if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3209  expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3210  } else {
3211  if (valueW != (rightTransform ? m + r - 1 : 1))
3212  return emitOpError("expect input width equals to input tile size");
3213  expectedOutputShape[getOutputWDim()] =
3214  (rightTransform ? m : 1) * valueTileW;
3215  }
3216  expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3217  expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3218 
3219  auto outputType = cast<ShapedType>(getOutput().getType());
3220  ArrayRef<int64_t> outputShape = outputType.getShape();
3221  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3222  return emitOpError("the output shape is not expected");
3223  }
3224  return success();
3225 }
3226 
3228 WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
 // The iteration domain is defined over the 6-D value operand: one loop per
 // dimension, [0, extent) with stride 1.
3229  Location loc = getLoc();
3230  IntegerAttr zeroAttr = builder.getIndexAttr(0);
3231  IntegerAttr oneAttr = builder.getIndexAttr(1);
3232  Value value = getValue();
3233  int64_t valueRank = getValueOperandRank();
3234  SmallVector<Range> loopBounds(valueRank);
3235  for (unsigned dim = 0; dim < valueRank; ++dim) {
3236  loopBounds[dim].offset = zeroAttr;
3237  // alphaH, alphaW, tileH, tileW, N, F
3238  loopBounds[dim].size = getDimValue(builder, loc, value, dim);
3239  loopBounds[dim].stride = oneAttr;
3240  }
3241  return loopBounds;
3242 }
3243 
3245 WinogradOutputTransformOp::getLoopIteratorTypes() {
 // No reductions; every dimension is computed in parallel.
3246  int64_t valueRank = getValueOperandRank();
3247  SmallVector<utils::IteratorType> iteratorTypes(valueRank,
3248  utils::IteratorType::parallel);
3249  return iteratorTypes;
3250 }
3251 
3252 LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3253  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3254  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3255  SmallVector<OpFoldResult> &resultSizes) {
3256  int64_t m = getM();
3257 
3258  Location loc = getLoc();
3259  MLIRContext *context = builder.getContext();
 // Output H/W offsets and sizes scale the tile indices by m: each tile
 // expands to an m x m block in the (N, H, W, F) output.
3260  auto affineMap =
3261  AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3262 
3263  Value mappedOffsetH = affine::makeComposedAffineApply(
3264  builder, loc, affineMap, offsets[getValueTileHDim()]);
3265  Value mappedOffsetW = affine::makeComposedAffineApply(
3266  builder, loc, affineMap, offsets[getValueTileWDim()]);
3267  Value mappedSizeH = affine::makeComposedAffineApply(
3268  builder, loc, affineMap, sizes[getValueTileHDim()]);
3269  Value mappedSizeW = affine::makeComposedAffineApply(
3270  builder, loc, affineMap, sizes[getValueTileWDim()]);
3271 
3272  ShapedType valueType = getValueOperandType();
3273  ArrayRef<int64_t> valueShape = valueType.getShape();
 // NOTE(review): raw indices 0/1 are used for the alpha dimensions here,
 // while sibling methods use getValueAlphaHDim()/getValueAlphaWDim() —
 // confirm those accessors always return 0 and 1.
3274  int64_t valueH = valueShape[0];
3275  int64_t valueW = valueShape[1];
3276  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3277  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
 // An alpha extent of 1 marks an untransformed dimension: offset 0, size 1.
3278  OpFoldResult offsetH =
3279  valueH != 1 ? OpFoldResult(mappedOffsetH) : OpFoldResult(zeroAttr);
3280  OpFoldResult offsetW =
3281  valueW != 1 ? OpFoldResult(mappedOffsetW) : OpFoldResult(zeroAttr);
3282  OpFoldResult sizeH =
3283  valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3284  OpFoldResult sizeW =
3285  valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3286 
3287  resultOffsets.append(
3288  {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3289  resultSizes.append(
3290  {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3291  return success();
3292 }
3293 
3294 /// Implement tiling for winograd_output_transform
3295 /// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW, N,
3296 /// F). The output of winograd_output_transform is (N, H, W, F) Users can
3297 /// specify the tile sizes of tileH, tileW, N, and F. `offsets` are the values
3298 /// for the offsets of tileH, tileW, N, F for one tile. `sizes` are the values
3299 /// for the sizes of tileH, tileW, N, F for one tile.
3300 FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
3301  OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3302  ArrayRef<OpFoldResult> sizes) {
3303  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3304  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3305  Location loc = getLoc();
3306  SmallVector<Value> tiledOperands;
3307  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3308 
3309  ShapedType valueType = getValueOperandType();
3310  ArrayRef<int64_t> valueShape = valueType.getShape();
3311  int64_t alphaH = valueShape[getValueAlphaHDim()];
3312  int64_t alphaW = valueShape[getValueAlphaWDim()];
3313  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3314  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3315 
 // The value slice always spans the full alphaH/alphaW extents; only the
 // tileH, tileW, N, F coordinates are tiled.
3316  sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3317  offsets[getValueTileWDim()], offsets[getValueNDim()],
3318  offsets[getValueFDim()]});
3319  sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3320  sizes[getValueTileWDim()], sizes[getValueNDim()],
3321  sizes[getValueFDim()]});
3322  int64_t valueRank = getValueOperandRank();
3323  SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
3324  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
3325  loc, getValue(), sliceOffsets, sliceSizes, sliceStrides));
3326 
 // NOTE(review): resultNumber is passed as 1 even though the op has a single
 // result; getResultTilePosition ignores the argument — confirm 0 was meant.
3327  SmallVector<OpFoldResult> resultOffsets, resultSizes;
3328  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3329  resultSizes)))
3330  return failure();
3331 
3332  int64_t outputRank = getOutputOperandRank();
3333  SmallVector<OpFoldResult> strides(outputRank, oneAttr);
3334  tiledOperands.emplace_back(builder.create<tensor::ExtractSliceOp>(
3335  loc, getOutput(), resultOffsets, resultSizes, strides));
3336 
 // The cloned op's result type is the output slice's type.
3337  SmallVector<Type> resultTypes;
3338  resultTypes.push_back(tiledOperands[1].getType());
3339  Operation *tiledOp =
3340  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3341 
3342  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
3343 }
3344 
3345 //===----------------------------------------------------------------------===//
3346 // LinalgDialect
3347 //===----------------------------------------------------------------------===//
3348 
// Dialect-wide canonicalization patterns that apply to every Linalg op.
3349 void LinalgDialect::getCanonicalizationPatterns(
3350  RewritePatternSet &results) const {
3351  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
3352  InferStaticShapeOfOperands>(getContext());
3353 }
3354 
3356  Attribute value, Type type,
3357  Location loc) {
 // Linalg defines no constant op of its own; delegate to arith.
3358  return arith::ConstantOp::materialize(builder, value, type, loc);
3359 }
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:1814
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Definition: LinalgOps.cpp:272
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
Definition: LinalgOps.cpp:2735
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
Definition: LinalgOps.cpp:2806
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
Definition: LinalgOps.cpp:120
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
Definition: LinalgOps.cpp:2284
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:260
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:331
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
Definition: LinalgOps.cpp:1647
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
Definition: LinalgOps.cpp:1695
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
Definition: LinalgOps.cpp:2780
static Operation * findPayloadOp(Block *body, bool initFirst=false)
Definition: LinalgOps.cpp:1447
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
Definition: LinalgOps.cpp:156
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
Definition: LinalgOps.cpp:1316
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
Definition: LinalgOps.cpp:2757
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
Definition: LinalgOps.cpp:1207
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect >> &effects, LinalgOp linalgOp)
Definition: LinalgOps.cpp:1174
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
Definition: LinalgOps.cpp:2197
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
Definition: LinalgOps.cpp:298
static void buildGenericRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
Definition: LinalgOps.cpp:949
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
Definition: LinalgOps.cpp:291
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
Definition: LinalgOps.cpp:52
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false)
Definition: LinalgOps.cpp:1370
void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
Definition: LinalgOps.cpp:1476
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
Definition: LinalgOps.cpp:324
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
Definition: LinalgOps.cpp:186
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
AffineMap dropResults(ArrayRef< int64_t > positions) const
Definition: AffineMap.h:299
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:411
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:264
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:31
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
unsigned getNumArguments()
Definition: Block.h:126
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:152
OpListType & getOperations()
Definition: Block.h:135
BlockArgListType getArguments()
Definition: Block.h:85
Operation & front()
Definition: Block.h:151
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:136
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:191
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:195
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:406
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:140
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:281
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:383
MLIRContext * getContext() const
Definition: Builders.h:55
Location getUnknownLoc()
Definition: Builders.cpp:27
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:285
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:337
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:766
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttrList is an array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
Attribute set(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:353
This class helps build Operations.
Definition: Builders.h:212
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:436
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:403
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:449
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:525
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:476
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
std::optional< RegisteredOperationName > getRegisteredInfo() const
If this operation is registered, returns the registered information, std::nullopt otherwise.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
result_iterator result_begin()
Definition: Operation.h:408
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
unsigned getNumOperands()
Definition: Operation.h:341
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:507
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_iterator result_end()
Definition: Operation.h:409
MutableArrayRef< OpOperand > getOpOperands()
Definition: Operation.h:378
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
result_range getResults()
Definition: Operation.h:410
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator end()
Definition: Region.h:56
Block & emplaceBlock()
Definition: Region.h:46
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:128
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
Definition: Types.cpp:110
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
static Attribute parse(AsmParser &parser, Type type)
Parse the short form [42, 100, -1] without any type prefix.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1142
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any...
Definition: AffineOps.cpp:1192
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2589
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
Definition: LinalgOps.cpp:2276
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:99
std::string generateLibraryCallName(Operation *op)
Returns the name mangled library call name to disambiguate between different overloads at the C level...
Definition: LinalgOps.cpp:2317
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
Definition: LinalgOps.cpp:2256
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx ...
Definition: LinalgOps.cpp:2267
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:90
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:44
DynamicAPInt floor(const Fraction &f)
Definition: Fraction.h:77
DynamicAPInt ceil(const Fraction &f)
Definition: Fraction.h:79
DynamicAPInt round(const Fraction &f)
Definition: Fraction.h:136
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
uint64_t getM(LevelType lt)
Definition: Enums.h:443
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
Definition: TensorOps.cpp:348
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:65
Include the generated interface declarations.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition: Utils.cpp:239
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particula...
Definition: AffineMap.cpp:768
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute to an MLIR context if it was valid.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .
Definition: AffineExpr.h:362
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:607
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:426
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drop the input dims in dropPositions from inputPerm.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to apply to inverse a permutation.
Fold transpose with transpose.
Definition: LinalgOps.cpp:1957
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:1960
This pattern canonicalize transpose by swapping the order of broadcast and transpose: transpose(broad...
Definition: LinalgOps.cpp:1982
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:1985
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:373
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
Region * addRegion()
Create a region that should be attached to the operation.
Container for result values of tiling.