//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <optional>

using namespace mlir;
using namespace mlir::linalg;

/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
                                int64_t dim) {
  auto type = cast<ShapedType>(v.getType());
  if (!type.isDynamicDim(dim))
    return builder.getIndexAttr(type.getDimSize(dim));

  return getAsOpFoldResult(
      TypeSwitch<Type, Value>(v.getType())
          .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
            return builder.create<tensor::DimOp>(loc, v, dim);
          })
          .Case<MemRefType>([&](MemRefType t) -> Value {
            return builder.create<memref::DimOp>(loc, v, dim);
          }));
}

/// Returns a memref.subview or a tensor.extract_slice based on the type of the
/// `source`.
static Operation *getSlice(OpBuilder &b, Location loc, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides) {
  return TypeSwitch<Type, Operation *>(source.getType())
      .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * {
        return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
                                                strides);
      })
      .Case<MemRefType>([&](MemRefType type) -> Operation * {
        return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
                                           strides);
      })
      .Default([&](Type t) -> Operation * { return nullptr; });
}
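
// Illustrative sketch (operand names, types, and mixed static/dynamic
// offsets/sizes/strides are hypothetical): for a tensor source this builds
//   %slice = tensor.extract_slice %src[%off][%sz][%st]
//       : tensor<?xf32> to tensor<?xf32>
// and for a memref source the analogous view (result type abbreviated):
//   %view = memref.subview %src[%off][%sz][%st]
//       : memref<?xf32> to memref<?xf32, strided<[?], offset: ?>>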

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

Value linalg::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                                int64_t dim) {
  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc,
                                       Value source, int64_t dim) {
  auto shapedType = llvm::cast<ShapedType>(source.getType());
  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
    return createOrFoldDimOp(b, loc, source, dim);
  return b.getIndexAttr(shapedType.getDimSize(dim));
}

//===----------------------------------------------------------------------===//
// Support for named Linalg ops defined in ods-gen.
//===----------------------------------------------------------------------===//

using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &b,
                                                Block &block,
                                                ArrayRef<NamedAttribute> attrs)>;

/// Fills the region of a structured operation using the provided
/// `regionBuilder`. The method is used by both named structured ops created by
/// ods-gen and by manually defined C++ ops. It is called by both builders and
/// parsers and creates a block with arguments corresponding to the elemental
/// types of `inputTypes` and `outputTypes`.
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                   TypeRange inputTypes, TypeRange outputTypes,
                                   ArrayRef<NamedAttribute> attrs,
                                   RegionBuilderFn regionBuilder) {
  SmallVector<Type, 8> argTypes;
  SmallVector<Location, 8> argLocs;
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {
      argTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);

      // TODO: Pass in a proper location here.
      argLocs.push_back(opBuilder.getUnknownLoc());
    }
  }

  // RAII.
  OpBuilder::InsertionGuard guard(opBuilder);
  Block *body =
      opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);

  opBuilder.setInsertionPointToStart(body);
  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
  regionBuilder(b, *body, attrs);

  // indexing_maps is an auto-generated method.

  // iterator_types is an auto-generated method.
}

/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
static void buildStructuredOp(OpBuilder &b, OperationState &state,
                              std::optional<TypeRange> resultTensorTypes,
                              ValueRange inputs, ValueRange outputs,
                              ArrayRef<NamedAttribute> attributes,
                              RegionBuilderFn regionBuilder) {
  // Derive the result types if needed.
  SmallVector<Type> derivedResultTypes =
      resultTensorTypes.value_or(TypeRange());
  if (!resultTensorTypes)
    copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
            llvm::IsaPred<RankedTensorType>);

  state.addOperands(inputs);
  state.addOperands(outputs);
  state.addTypes(derivedResultTypes);

  state.addAttributes(attributes);
  state.addAttribute(
      "operandSegmentSizes",
      b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
                              static_cast<int32_t>(outputs.size())}));

  // Create and fill the region of the structured operation.
  Region &region = *state.addRegion();
  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
                         state.attributes.getAttrs(), regionBuilder);
}

static void buildMatmulOp(OpBuilder &b, OperationState &state,
                          std::optional<TypeRange> resultTensorTypes,
                          ValueRange inputs, ValueRange outputs,
                          ArrayRef<NamedAttribute> attributes,
                          RegionBuilderFn regionBuilder,
                          ArrayRef<AffineMap> indexingMaps) {
  // Initialize the indexing_maps attribute for MatmulOp.
  SmallVector<Attribute, 3> indexingMapsAttrVal;
  indexingMapsAttrVal = llvm::map_to_vector(
      MatmulOp::getDefaultIndexingMaps(b.getContext()),
      [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}
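
// For reference, the default matmul indexing maps built above correspond to
// the standard C(m, n) += A(m, k) * B(k, n) access pattern:
//   affine_map<(m, n, k) -> (m, k)>  // A
//   affine_map<(m, n, k) -> (k, n)>  // B
//   affine_map<(m, n, k) -> (m, n)>  // C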

static void buildBatchMatmulOp(OpBuilder &b, OperationState &state,
                               std::optional<TypeRange> resultTensorTypes,
                               ValueRange inputs, ValueRange outputs,
                               ArrayRef<NamedAttribute> attributes,
                               RegionBuilderFn regionBuilder,
                               ArrayRef<AffineMap> indexingMaps) {
  // Initialize the indexing_maps attribute for BatchMatmulOp.
  SmallVector<Attribute, 4> indexingMapsAttrVal;
  indexingMapsAttrVal =
      llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
        return AffineMapAttr::get(map);
      });
  state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
  return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
                           attributes, regionBuilder);
}

/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputTypes,
                             bool addOperandSegmentSizes = true) {
  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
      outputsOperands;

  if (succeeded(parser.parseOptionalLess())) {
    if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
      return failure();
  }
  attrsLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("ins"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }

  if (succeeded(parser.parseOptionalKeyword("outs"))) {
    outputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
        parser.parseColonTypeList(outputTypes) || parser.parseRParen())
      return failure();
  }

  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
                             result.operands))
    return failure();

  if (addOperandSegmentSizes) {
    // This is a bit complex because we're trying to be backward compatible
    // with operation syntax that mixes the inherent attributes and the
    // discardable ones in the same dictionary. If the properties are used, we
    // append the operandSegmentSizes there directly. Otherwise we append it to
    // the discardable attributes dictionary where it is handled by the generic
    // Operation::create(...) method.
    if (result.propertiesAttr) {
      NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
      attrs.append("operandSegmentSizes",
                   parser.getBuilder().getDenseI32ArrayAttr(
                       {static_cast<int32_t>(inputsOperands.size()),
                        static_cast<int32_t>(outputsOperands.size())}));
      result.propertiesAttr = attrs.getDictionary(parser.getContext());
    } else {
      result.addAttribute("operandSegmentSizes",
                          parser.getBuilder().getDenseI32ArrayAttr(
                              {static_cast<int32_t>(inputsOperands.size()),
                               static_cast<int32_t>(outputsOperands.size())}));
    }
  }
  if (!result.propertiesAttr) {
    std::optional<RegisteredOperationName> info =
        result.name.getRegisteredInfo();
    if (info) {
      if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
            return parser.emitError(attrsLoc)
                   << "'" << result.name.getStringRef() << "' op ";
          })))
        return failure();
    }
  }
  return success();
}

static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
                                         ValueRange outputs) {
  if (!inputs.empty())
    p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
  if (!outputs.empty())
    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
}
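
// For example, an op with two inputs and one init prints its operand lists as
// (values and types here are illustrative):
//   ... ins(%a, %b : tensor<4x8xf32>, tensor<8x16xf32>)
//       outs(%c : tensor<4x16xf32>) ...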

//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//

static ParseResult parseNamedStructuredOpRegion(
    OpAsmParser &parser, Region &region, unsigned numRegionArgs,
    TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
    RegionBuilderFn regionBuilder) {
  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
    return parser.emitError(
        parser.getCurrentLocation(),
        llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      numRegionArgs, inputTypes.size() + outputTypes.size()));
  }

  OpBuilder opBuilder(parser.getContext());
  fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
                         regionBuilder);
  return success();
}

static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes) {
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  return success();
}

static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result,
                                          unsigned numRegionArgs,
                                          RegionBuilderFn regionBuilder) {
  // TODO: Enable when ods-gen supports captures.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // TODO: consider merging results parsing into region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
                                   outputTypes, result.attributes.getAttrs(),
                                   regionBuilder))
    return failure();
  result.addRegion(std::move(region));

  return success();
}

static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes) {
  if (resultTypes.empty())
    return;
  p.printOptionalArrowTypeList(resultTypes);
}

static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
                                   ValueRange inputs, ValueRange outputs,
                                   ArrayRef<StringRef> elidedAttrs = {}) {
  p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);

  // Printing is shared with generic ops, except for the region and
  // attributes.
  printCommonStructuredOpParts(p, inputs, outputs);

  // Results printing.
  printNamedStructuredOpResults(p, op->getResultTypes());

  // Region is elided.
}

//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated code.
// Helpers to build the unary, binary, and type conversion functions defined by
// the DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that uses
// this class.
//
// Implementations of the math functions must be polymorphic over numeric types,
// internally performing necessary casts. If the function application makes no
// sense, then the only recourse is to assert and return nullptr. This can be
// extended later if it becomes possible to fail construction of the region. The
// invariant should be enforced at a higher level.
//
// TODO: These helpers are currently type polymorphic over the class of integer
// and floating point types, but they will not internally cast within bit
// widths of a class (mixed precision such as i8->i32) or across classes
// (i.e. mixed float and integer). Many such combinations are ambiguous or need
// to be handled with care and work is being considered to extend the op
// language to make such cases explicit. In the meantime, violating this will
// fail verification, which is deemed acceptable.
//===----------------------------------------------------------------------===//

namespace {

class RegionBuilderHelper {
public:
  RegionBuilderHelper(OpBuilder &builder, Block &block)
      : builder(builder), block(block) {}

  // Build the unary functions defined by OpDSL.
  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
    if (!isFloatingPoint(arg))
      llvm_unreachable("unsupported non numeric type");
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (unaryFn) {
    case UnaryFn::exp:
      return builder.create<math::ExpOp>(arg.getLoc(), arg);
    case UnaryFn::log:
      return builder.create<math::LogOp>(arg.getLoc(), arg);
    case UnaryFn::abs:
      return builder.create<math::AbsFOp>(arg.getLoc(), arg);
    case UnaryFn::ceil:
      return builder.create<math::CeilOp>(arg.getLoc(), arg);
    case UnaryFn::floor:
      return builder.create<math::FloorOp>(arg.getLoc(), arg);
    case UnaryFn::negf:
      return builder.create<arith::NegFOp>(arg.getLoc(), arg);
    case UnaryFn::reciprocal: {
      Attribute oneAttr = builder.getOneAttr(arg.getType());
      auto one = builder.create<arith::ConstantOp>(arg.getLoc(),
                                                   ::cast<TypedAttr>(oneAttr));
      return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
    }
    case UnaryFn::round:
      return builder.create<math::RoundOp>(arg.getLoc(), arg);
    case UnaryFn::sqrt:
      return builder.create<math::SqrtOp>(arg.getLoc(), arg);
    case UnaryFn::rsqrt:
      return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
    case UnaryFn::square:
      return builder.create<arith::MulFOp>(arg.getLoc(), arg, arg);
    case UnaryFn::tanh:
      return builder.create<math::TanhOp>(arg.getLoc(), arg);
    case UnaryFn::erf:
      return builder.create<math::ErfOp>(arg.getLoc(), arg);
    }
    llvm_unreachable("unsupported unary function");
  }

  // Build the binary functions defined by OpDSL.
  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                   arg1.getType().getIntOrFloatBitWidth() == 1;
    if (!allComplex && !allFloatingPoint && !allInteger)
      llvm_unreachable("unsupported non numeric type");
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (binaryFn) {
    case BinaryFn::add:
      if (allComplex)
        return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::sub:
      if (allComplex)
        return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: sub with bools");
      return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::mul:
      if (allComplex)
        return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div:
      if (allComplex)
        return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: div with bools");
      return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div_unsigned:
      if (!allInteger || allBool)
        llvm_unreachable("unsupported operation: unsigned div not on uint");
      return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::powf:
      assert(allFloatingPoint);
      return builder.create<math::PowFOp>(arg0.getLoc(), arg0, arg1);
    }
    llvm_unreachable("unsupported binary function");
  }

  // Build the ternary functions defined by OpDSL.
  Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1,
                       Value arg2) {
    bool headBool =
        isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
    bool tailFloatingPoint =
        isFloatingPoint(arg0) && isFloatingPoint(arg1) && isFloatingPoint(arg2);
    bool tailInteger = isInteger(arg0) && isInteger(arg1) && isInteger(arg2);
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (ternaryFn) {
    case TernaryFn::select:
      if (!headBool && !(tailFloatingPoint || tailInteger))
        llvm_unreachable("unsupported non numeric type");
      return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
    }
    llvm_unreachable("unsupported ternary function");
  }

  // Build the type functions defined by OpDSL.
  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
    switch (typeFn) {
    case TypeFn::cast_signed:
      return cast(toType, operand, false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, true);
    }
    llvm_unreachable("unsupported type conversion function");
  }

  void yieldOutputs(ValueRange values) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    builder.create<YieldOp>(loc, values);
  }

  Value constant(const std::string &value) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
  }

  Value index(int64_t dim) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(builder.getContext(), width);
  }

  Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
  Type getFloat64Type() { return Float64Type::get(builder.getContext()); }

private:
  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    auto loc = operand.getLoc();
    return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
  }

  bool isComplex(Value value) {
    return llvm::isa<ComplexType>(value.getType());
  }
  bool isFloatingPoint(Value value) {
    return llvm::isa<FloatType>(value.getType());
  }
  bool isInteger(Value value) {
    return llvm::isa<IntegerType>(value.getType());
  }

  OpBuilder &builder;
  Block &block;
};

} // namespace

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

namespace {

struct EraseSelfCopy : OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getInputs() != copyOp.getOutputs())
      return rewriter.notifyMatchFailure(copyOp, "not a self copy");
    if (copyOp.hasPureBufferSemantics())
      rewriter.eraseOp(copyOp);
    else
      rewriter.replaceOp(copyOp, copyOp.getInputs());

    return success();
  }
};
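
// A sketch of the IR this pattern matches (names are illustrative):
//   linalg.copy ins(%buf : memref<8xf32>) outs(%buf : memref<8xf32>)
// Copying a value onto itself is a no-op, so the op is erased in the buffer
// case and replaced by its input in the tensor case.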

} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseSelfCopy>(context);
}

//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//

namespace {

/// Fold linalg.fill -> tensor.expand/collapse_shape chain.
///
/// For such op chains, we can create new linalg.fill ops with the result
/// type of the tensor.expand/collapse_shape op.
template <typename TensorReshapeOp>
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
    if (!oldFill)
      return failure();

    Location loc = oldFill.getLoc();
    TensorReshapeOp newInit;
    if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) {
      newInit = rewriter.create<TensorReshapeOp>(
          loc, reshapeOp.getResultType(), oldFill.output(),
          reshapeOp.getReassociation(), reshapeOp.getOutputShape(),
          reshapeOp.getStaticOutputShape());
    } else {
      newInit = rewriter.create<TensorReshapeOp>(loc, reshapeOp.getResultType(),
                                                 oldFill.output(),
                                                 reshapeOp.getReassociation());
    }
    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
                                        ValueRange{newInit});
    return success();
  }
};
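
// A sketch of the rewrite (names and shapes are illustrative):
//   %0 = linalg.fill ins(%cst : f32) outs(%init : tensor<6x4xf32>)
//       -> tensor<6x4xf32>
//   %1 = tensor.collapse_shape %0 [[0, 1]]
//       : tensor<6x4xf32> into tensor<24xf32>
// becomes
//   %0 = tensor.collapse_shape %init [[0, 1]]
//       : tensor<6x4xf32> into tensor<24xf32>
//   %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<24xf32>)
//       -> tensor<24xf32>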

/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
/// filling value are the same.
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern<tensor::PadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue || fillOp.value() != padValue)
      return failure();

    ReifiedRankedShapedTypeDims reifiedShape;
    if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
      return rewriter.notifyMatchFailure(
          padOp, "failed to reify tensor.pad op result shape");

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        padOp.getLoc(), reifiedShape.front(),
        padOp.getResultType().getElementType());
    Value replacement =
        rewriter
            .create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
                            ValueRange{emptyTensor})
            .getResult(0);
    if (replacement.getType() != padOp.getResultType()) {
      replacement = rewriter.create<tensor::CastOp>(
          fillOp.getLoc(), padOp.getResultType(), replacement);
    }
    rewriter.replaceOp(padOp, replacement);
    return success();
  }
};
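
// A sketch of the rewrite (names and shapes are illustrative):
//   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<4xf32>)
//       -> tensor<4xf32>
//   %pad = tensor.pad %fill low[2] high[2] {
//   ^bb0(%i: index):
//     tensor.yield %cst : f32
//   } : tensor<4xf32> to tensor<8xf32>
// becomes a fill of a fresh tensor.empty with the padded shape:
//   %empty = tensor.empty() : tensor<8xf32>
//   %pad = linalg.fill ins(%cst : f32) outs(%empty : tensor<8xf32>)
//       -> tensor<8xf32>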

/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
/// filling value are the same.
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!srcPadOp)
      return failure();

    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
      return failure();

    // Walk back the tensor.insert_slice chain and find the first destination
    // value at the start of the chain.
    Value firstDest = insertOp.getDest();
    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
        return failure();

      // Make sure the ranges of values accessed are disjoint. Without this, we
      // cannot fold tensor.pad away.
      bool disjoint = false;
      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
        // If the dimension has dynamic offset/size, we cannot guarantee
        // disjointness. So just skip it.
        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
          continue;

        // Get the range start and end, inclusively for both.
        int64_t prevStart = prevOp.getStaticOffset(i);
        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
                                          prevOp.getStaticStride(i);
        int64_t nextStart = insertOp.getStaticOffset(i);
        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
                                          insertOp.getStaticStride(i);
        if (prevEnd < nextStart || nextEnd < prevStart) {
          disjoint = true;
          break;
        }
      }

      if (!disjoint)
        break;
      firstDest = prevOp.getDest();
    }

    // Check whether the first destination is a fill op. For overlapped cases,
    // this also cannot be true.
    auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
    if (!dstFillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = srcPadOp.getConstantPaddingValue();
    if (!padValue || dstFillOp.value() != padValue)
      return failure();

    SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
    SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();

    Location loc = insertOp.getLoc();
    MLIRContext *context = getContext();

    AffineExpr sym0, sym1;
    bindSymbols(context, sym0, sym1);
    auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);

    // Calculate the new offsets for the insert. They should be the old offsets
    // plus low padding sizes.
    SmallVector<OpFoldResult, 4> newOffsets;
    for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
      newOffsets.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
    }

    RankedTensorType srcPadType = srcPadOp.getSourceType();
    SmallVector<OpFoldResult, 4> newSizes;
    for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
      if (srcPadType.isDynamicDim(i)) {
        newSizes.push_back(
            rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
                .getResult());
      } else {
        newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
        newSizes, insertOp.getMixedStrides());
    return success();
  }
};

/// Fold tensor.extract(linalg.fill(<input>)) into <input>.
struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
public:
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // See if the tensor input of the tensor.extract op is the result of a
    // linalg.fill op.
    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // Get the scalar input operand of the linalg.fill op.
    Value extractedScalar = fillOp.getInputs()[0];

    // Replace the tensor.extract op with the scalar value used to fill the
    // tensor.
    rewriter.replaceOp(extractOp, extractedScalar);
    return success();
  }
};
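
// A sketch of the fold (names are illustrative):
//   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<4x4xf32>)
//       -> tensor<4x4xf32>
//   %e = tensor.extract %fill[%i, %j] : tensor<4x4xf32>
// Every element of the fill result is %cst, so %e is replaced by %cst.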

/// Folds pack(fill) into a single fill op if
/// 1. The pack op does not have padding value, or
/// 2. The filled value and padding value are the same.
static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
                                                linalg::PackOp packOp) {
  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
  if (!fillOp)
    return failure();

  if (auto paddingValue = packOp.getPaddingValue())
    if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
      return failure();

  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();

  return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
                                         packOp.getDest());
}
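
// A sketch of the fold (names, tiles, and shapes are illustrative):
//   %fill = linalg.fill ins(%cst : f32) outs(%src : tensor<16x16xf32>)
//       -> tensor<16x16xf32>
//   %pack = linalg.pack %fill inner_dims_pos = [0, 1] inner_tiles = [8, 8]
//       into %dest : tensor<16x16xf32> -> tensor<2x2x8x8xf32>
// becomes a single fill of the pack destination:
//   %pack = linalg.fill ins(%cst : f32) outs(%dest : tensor<2x2x8x8xf32>)
//       -> tensor<2x2x8x8xf32>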

/// Wrapper pattern that applies the foldFillPackIntoFillOp method.
struct FoldFillWithPack : public OpRewritePattern<linalg::PackOp> {
public:
  FoldFillWithPack(MLIRContext *context)
      : OpRewritePattern<linalg::PackOp>(context) {}

  LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
    if (failed(fillOp))
      return failure();
    rewriter.replaceOp(packOp, fillOp.value().result());
    return success();
  }
};

/// Fold fill with copy.
struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
  using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
                                          fillOp.getInputs(),
                                          copyOp.getOutputs());
      return success();
    }
    if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
                                                  fillOp.getOutputs());
      return success();
    }
    return failure();
  }
};
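
// A sketch of the first case (names are illustrative): copying the result of
// a fill is the same as filling the copy's init directly, so
//   %fill = linalg.fill ins(%cst : f32) outs(%a : tensor<8xf32>)
//       -> tensor<8xf32>
//   %copy = linalg.copy ins(%fill : tensor<8xf32>) outs(%b : tensor<8xf32>)
//       -> tensor<8xf32>
// becomes
//   %copy = linalg.fill ins(%cst : f32) outs(%b : tensor<8xf32>)
//       -> tensor<8xf32>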

/// Fold fill with transpose.
struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(
          transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
          transposeOp.getDpsInitOperand(0)->get());
      return success();
    }
    return failure();
  }
};
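
// A sketch of the fold (names and shapes are illustrative): transposing a
// constant fill yields the same fill on the transposed init, so
//   %fill = linalg.fill ins(%cst : f32) outs(%a : tensor<2x4xf32>)
//       -> tensor<2x4xf32>
//   %t = linalg.transpose ins(%fill : tensor<2x4xf32>)
//       outs(%b : tensor<4x2xf32>) permutation = [1, 0]
// becomes
//   %t = linalg.fill ins(%cst : f32) outs(%b : tensor<4x2xf32>)
//       -> tensor<4x2xf32>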

/// Fold a concat with all elements being fills of the same value
/// into a fill of the concat result shape.
struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> {
  using OpRewritePattern<tensor::ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    auto concatOperands = concatOp.getInputs();
    if (concatOperands.empty()) {
      return failure();
    }

    auto firstFillOp = concatOperands.front().getDefiningOp<linalg::FillOp>();
    if (!firstFillOp) {
      return failure();
    }
    // Prefetch the fill value.
    OpFoldResult firstFillVal =
        getAsOpFoldResult(firstFillOp.getDpsInputOperand(0)->get());
    // Collect all the outs values for the fill operations.
    SmallVector<Value> allOuts;
    allOuts.push_back(firstFillOp.getDpsInitOperand(0)->get());

    auto isDefinedByCompatibleFillOp = [&](Value v) -> bool {
      auto fillOp = v.getDefiningOp<linalg::FillOp>();
      if (!fillOp) {
        return false;
      }

      OpFoldResult fillVal =
          getAsOpFoldResult(fillOp.getDpsInputOperand(0)->get());
      if (fillVal != firstFillVal)
        return false;

      allOuts.push_back(fillOp.getDpsInitOperand(0)->get());
      return true;
    };
    if (!llvm::all_of(concatOperands.drop_front(),
                      isDefinedByCompatibleFillOp)) {
      return rewriter.notifyMatchFailure(
          concatOp, "not all operands are defined by a compatible fill op");
    }

    Value outsConcat = rewriter.create<tensor::ConcatOp>(
        concatOp.getLoc(), concatOp.getDim(), allOuts);
    rewriter.replaceOpWithNewOp<linalg::FillOp>(
        concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat);
    return success();
  }
};
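
// A sketch of the rewrite (names and shapes are illustrative):
//   %f0 = linalg.fill ins(%cst : f32) outs(%a : tensor<4xf32>) -> tensor<4xf32>
//   %f1 = linalg.fill ins(%cst : f32) outs(%b : tensor<8xf32>) -> tensor<8xf32>
//   %c = tensor.concat dim(0) %f0, %f1
//       : (tensor<4xf32>, tensor<8xf32>) -> tensor<12xf32>
// becomes a concat of the inits followed by a single fill:
//   %outs = tensor.concat dim(0) %a, %b
//       : (tensor<4xf32>, tensor<8xf32>) -> tensor<12xf32>
//   %c = linalg.fill ins(%cst : f32) outs(%outs : tensor<12xf32>)
//       -> tensor<12xf32>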

} // namespace

void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldConcatsOfFill, FoldFillWithCopy, FoldFillWithTensorExtract,
              FoldFillWithPack, FoldFillWithPad,
              FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
              FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
              FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}

//===----------------------------------------------------------------------===//
// GenericOp
//===----------------------------------------------------------------------===//

static void buildGenericRegion(
    OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
    ValueRange outputs,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
  SmallVector<Type, 4> blockArgTypes;
  SmallVector<Location, 4> blockArgLocs;
  for (ValueRange container : {inputs, outputs}) {
    for (Value v : container) {
      Type t = v.getType();
      blockArgTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
      blockArgLocs.push_back(v.getLoc());
    }
  }

  OpBuilder::InsertionGuard guard(builder);
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuild(builder, loc, bodyBlock->getArguments());
}

void GenericOp::getAsmBlockArgumentNames(Region &region,
                                         OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "out");
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  result.addAttributes(attributes);
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, outputs, bodyBuild);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs,
        builder.getAffineMapArrayAttr(indexingMaps),
        builder.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes,
            [&](utils::IteratorType iter) -> mlir::Attribute {
              return IteratorTypeAttr::get(builder.getContext(), iter);
            }))),
        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert_range(genericAttrNames);
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames =
          llvm::to_vector(llvm::map_range(
              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
                return StringAttr::get(getContext(), stringifyIteratorType(t));
              }));

      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
          ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      genericAttrs.push_back(attr);
    }
  }
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes.
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());

  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, getResultTensors().getTypes());
}

ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits that must check into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert array of strings into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of its outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}

static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    LinalgOp linalgOp) {
  for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(
        MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
        /*effectOnFullRegion=*/true, SideEffects::DefaultResource::get());
  }

  for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
    if (!llvm::isa<MemRefType>(operand.get().getType()))
      continue;
    if (linalgOp.payloadUsesValueFromOperand(&operand)) {
      effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
                           /*effectOnFullRegion=*/true,
                           SideEffects::DefaultResource::get());
    }
    effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
                         /*effectOnFullRegion=*/true,
                         SideEffects::DefaultResource::get());
  }
}

void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

static Speculation::Speculatability
getGenericSpeculatabilityImpl(LinalgOp linalgOp) {
  // Operands with value semantics are speculatable, while operands with memory
  // semantics are not.
  if (!linalgOp.hasPureTensorSemantics())
    return Speculation::NotSpeculatable;
  // The body of the op can still have speculation in its region.
  return Speculation::Speculatable;
}

Speculation::Speculatability GenericOp::getSpeculatability() {
  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}

LogicalResult GenericOp::verify() { return success(); }

namespace {

/// Remove any linalg operation (on tensors) that is just copying
/// the values from inputs to the results. Requirements are:
/// 1) All iterator types are parallel.
/// 2) The body contains just a yield operation with the yielded values being
///    the arguments corresponding to the operands.
template <typename OpTy>
struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    // All indexing maps must be equal. It follows that they are permutations.
    if (!llvm::all_equal(linalgOp.getIndexingMapsArray()))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // In the buffer case, we need to check exact buffer equality.
    if (linalgOp.hasPureBufferSemantics()) {
      if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
          linalgOp.getDpsInputOperand(0)->get() ==
              linalgOp.getDpsInitOperand(0)->get()) {
        rewriter.eraseOp(linalgOp);
        return success();
      }
      return failure();
    }

    // Mixed semantics is not supported yet.
    if (!linalgOp.hasPureTensorSemantics())
      return failure();

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = linalgOp->getOperand(argumentNumber);
      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // input dimension can be turned into a static output dimension.
      Type returnType = returnedArg.getType();
      if (returnType != resultType) {
        // Distinguish between sparse conversion or dense tensor casting.
        // TODO: unify the two ops?
        if (sparse_tensor::getSparseTensorEncoding(returnType) ||
            sparse_tensor::getSparseTensorEncoding(resultType))
          returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        else {
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 resultType))
            return failure();
          returnedArg = rewriter.create<tensor::CastOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        }
      }
      returnedArgs.push_back(returnedArg);
    }

    if (returnedArgs.size() != linalgOp->getNumResults())
      return failure();
    rewriter.replaceOp(linalgOp, returnedArgs);
    return success();
  }
};
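
// A sketch of the tensor case this pattern erases (maps and names are
// illustrative):
//   #id = affine_map<(d0) -> (d0)>
//   %r = linalg.generic {indexing_maps = [#id, #id],
//                        iterator_types = ["parallel"]}
//       ins(%in : tensor<8xf32>) outs(%out : tensor<8xf32>) {
//   ^bb0(%a: f32, %b: f32):
//     linalg.yield %a : f32
//   } -> tensor<8xf32>
// All uses of %r are replaced by %in directly (possibly via tensor.cast when
// the static types differ).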

} // namespace

void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
}

LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// MapOp
//===----------------------------------------------------------------------===//

static ParseResult parseDstStyleOp(
    OpAsmParser &parser, OperationState &result,
    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
        nullptr) {
  // Parse `ins` and `outs`.
  SmallVector<Type, 4> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
                                   /*addOperandSegmentSizes=*/false))
    return failure();

  // Add result types.
  for (Type outputType : outputTypes) {
    if (llvm::isa<RankedTensorType>(outputType))
      result.addTypes(outputType);
  }

  // Parse required attributes.
  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

void MapOp::getAsmBlockArgumentNames(Region &region,
                                     OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
}

void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "mapped");
}

void MapOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, init);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, /*outputs=*/{}, bodyBuild);
}

static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
                                 const OperationName &payloadOpName,
                                 const NamedAttrList &payloadOpAttrs,
                                 ArrayRef<Value> operands,
                                 bool initFirst = false) {
  OpBuilder b(parser.getContext());
  Region *body = result.addRegion();
  Block &block = body->emplaceBlock();
  b.setInsertionPointToStart(&block);
  for (auto &operand : operands) {
    block.addArgument(
        llvm::cast<ShapedType>(operand.getType()).getElementType(),
        b.getUnknownLoc());
  }
  SmallVector<Value> payloadOpOperands;
  // If the initFirst flag is enabled, we consider init as the first position
  // of the payload operands.
  if (initFirst) {
    payloadOpOperands.push_back(block.getArguments().back());
    for (const auto &arg : block.getArguments().drop_back())
      payloadOpOperands.push_back(arg);
  } else {
    payloadOpOperands = {block.getArguments().begin(),
                         block.getArguments().end()};
  }

  Operation *payloadOp = b.create(
      result.location, b.getStringAttr(payloadOpName.getStringRef()),
      payloadOpOperands,
      TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
                    .getElementType()},
      payloadOpAttrs);
  b.create<YieldOp>(result.location, payloadOp->getResults());
}

ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  if (parseDstStyleOp(parser, result))
    return failure();

  if (payloadOpName.has_value()) {
    if (!result.operands.empty())
      addBodyWithPayloadOp(parser, result, payloadOpName.value(),
                           payloadOpAttrs,
                           ArrayRef(result.operands).drop_back());
    else
      result.addRegion();
  } else {
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }
    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }
  return success();
}

1481 // Retrieve the operation from the body, if it is the only one (except
1482 // yield) and if it gets the same amount of arguments as the body does.
1483 // If initFirst flag is enabled, we check that init takes the first position in
1484 // operands of payload.
1485 static Operation *findPayloadOp(Block *body, bool initFirst = false) {
1486  if (body->getOperations().size() != 2)
1487  return nullptr;
1488  Operation &payload = body->getOperations().front();
1489  assert(isa<YieldOp>(body->getOperations().back()));
1490 
1491  if (payload.getNumOperands() == 0 ||
1492  payload.getNumOperands() != body->getNumArguments())
1493  return nullptr;
1494  if (initFirst) {
1495  // check init
1496  if (payload.getOperands().back() != body->getArgument(0))
1497  return nullptr;
1498  // check rest
1499  for (const auto &[operand, bbArg] :
1500  llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
1501  if (bbArg != operand)
1502  return nullptr;
1503  }
1504  } else {
1505  for (const auto &[operand, bbArg] :
1506  llvm::zip(payload.getOperands(), body->getArguments())) {
1507  if (bbArg != operand)
1508  return nullptr;
1509  }
1510  }
1511  return &payload;
1512 }
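
// For illustration, a body built by addBodyWithPayloadOp matches this pattern.
// E.g. (names and types assumed, not taken from a test):
//
//   linalg.map { arith.addf } ins(%lhs, %rhs : tensor<8xf32>, tensor<8xf32>)
//                             outs(%init : tensor<8xf32>)
//
// is the short form of
//
//   linalg.map ins(%lhs, %rhs : tensor<8xf32>, tensor<8xf32>)
//              outs(%init : tensor<8xf32>)
//     (%in0: f32, %in1: f32) {
//       %0 = arith.addf %in0, %in1 : f32
//       linalg.yield %0 : f32
//     }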
1513 
1514 void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
1515  SmallVector<StringRef> elidedAttrs;
1516  std::string attrToElide;
1517  p << " { " << payloadOp->getName().getStringRef();
1518  for (const auto &attr : payloadOp->getAttrs()) {
1519  auto fastAttr =
1520  llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
1521  if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
1522  attrToElide = attr.getName().str();
1523  elidedAttrs.push_back(attrToElide);
1524  break;
1525  }
1526  }
1527  p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
1528  p << " }";
1529 }
1530 
1531 void MapOp::print(OpAsmPrinter &p) {
1532  Block *mapper = getBody();
1533  Operation *payloadOp = findPayloadOp(mapper);
1534  if (payloadOp) {
1535  printShortForm(p, payloadOp);
1536  }
1537 
1538  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1539  p.printOptionalAttrDict((*this)->getAttrs());
1540 
1541  if (!payloadOp) {
1542  // Print region if the payload op was not detected.
1543  p.increaseIndent();
1544  p.printNewline();
1545  p << "(";
1546  llvm::interleaveComma(mapper->getArguments(), p,
1547  [&](auto arg) { p.printRegionArgument(arg); });
1548  p << ") ";
1549 
1550  p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
1551  p.decreaseIndent();
1552  }
1553 }
1554 
1555 LogicalResult MapOp::verify() {
1556  auto *bodyBlock = getBody();
1557  auto blockArgs = bodyBlock->getArguments();
1558 
1559  // Check that the number of `inputs` matches the arity of the `mapper` region.
1560  if (getInputs().size() != blockArgs.size())
1561  return emitOpError() << "expects number of operands to match the arity of "
1562  "mapper, but got: "
1563  << getInputs().size() << " and " << blockArgs.size();
1564 
1565  // The parameters of mapper should all match the element type of inputs.
1566  for (const auto &[bbArgType, inputArg] :
1567  llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
1568  auto inputElemType =
1569  llvm::cast<ShapedType>(inputArg.getType()).getElementType();
1570  if (bbArgType != inputElemType) {
1571  return emitOpError() << "expected element type of input " << inputElemType
1572  << " to match bbArg type " << bbArgType;
1573  }
1574  }
1575 
1576  // The shape of each input must match the shape of the output.
1577  auto outputShape = getInit().getType().getShape();
1578  for (Type inputArgType : TypeRange{getInputs()}) {
1579  auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
1580  if (inputElemShape != outputShape) {
1581  return emitOpError() << "expected shape of input (" << inputElemShape
1582  << ") to match shape of output (" << outputShape
1583  << ")";
1584  }
1585  }
1586 
1587  return success();
1588 }
1589 
1590 SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
1591  int64_t rank = getInit().getType().getRank();
1592  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1593 }
1594 
1595 ArrayAttr MapOp::getIndexingMaps() {
1596  Builder builder(getContext());
1597  int64_t rank = getInit().getType().getRank();
1598  int64_t numIndexingMaps = getOperands().size();
1599  return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
1600  numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
1601 }
1602 
1603 void MapOp::getEffects(
1604  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1605  &effects) {
1606  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1607 }
1608 
1609 Speculation::Speculatability MapOp::getSpeculatability() {
1610  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1611 }
1612 
1613 //===----------------------------------------------------------------------===//
1614 // ReduceOp
1615 //===----------------------------------------------------------------------===//
1616 
1617 void ReduceOp::getAsmBlockArgumentNames(Region &region,
1618  OpAsmSetValueNameFn setNameFn) {
1619  for (Value v : getRegionInputArgs())
1620  setNameFn(v, "in");
1621  for (Value v : getRegionOutputArgs())
1622  setNameFn(v, "init");
1623 }
1624 
1625 void ReduceOp::getAsmResultNames(
1626  function_ref<void(Value, StringRef)> setNameFn) {
1627  if (!getResults().empty())
1628  setNameFn(getResults().front(), "reduced");
1629 }
1630 
1631 void ReduceOp::build(
1632  OpBuilder &builder, OperationState &result, ValueRange inputs,
1633  ValueRange inits, ArrayRef<int64_t> dimensions,
1634  function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
1635  ArrayRef<NamedAttribute> attributes) {
1636  build(builder, result, TypeRange{}, inputs, inits, dimensions);
1637  result.addAttributes(attributes);
1638 
1639  // Add output types for `RankedTensorType` output arguments.
1640  for (Value init : inits) {
1641  Type initType = init.getType();
1642  if (llvm::isa<RankedTensorType>(initType))
1643  result.addTypes(initType);
1644  }
1645 
1646  if (bodyBuild)
1647  buildGenericRegion(builder, result.location, *result.regions.front(),
1648  inputs, inits, bodyBuild);
1649 }
1650 
1651 SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
1652  int64_t inputRank =
1653  llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1654  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
1655  utils::IteratorType::parallel);
1656  for (int64_t reductionDim : getDimensions())
1657  iteratorTypes[reductionDim] = utils::IteratorType::reduction;
1658  return iteratorTypes;
1659 }
1660 
1661 ArrayAttr ReduceOp::getIndexingMaps() {
1662  int64_t inputRank =
1663  llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
1664  SmallVector<AffineMap> affineMaps(
1665  getNumDpsInputs(),
1666  AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
1667  AffineMap resultMap =
1668  AffineMap::getMultiDimIdentityMap(inputRank, getContext())
1669  .dropResults(getDimensions());
1670  for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
1671  affineMaps.push_back(resultMap);
1672  return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
1673 }
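
// For example, for a rank-3 input reduced over dimensions = [1], each input
// gets the identity map (d0, d1, d2) -> (d0, d1, d2) and each init gets the
// result map (d0, d1, d2) -> (d0, d2).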
1674 
1675 void ReduceOp::getEffects(
1676  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1677  &effects) {
1678  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1679 }
1680 
1681 Speculation::Speculatability ReduceOp::getSpeculatability() {
1682  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1683 }
1684 
1685 static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
1686  NamedAttrList &attributes,
1687  StringRef attributeName) {
1688  if (parser.parseKeyword(attributeName) || parser.parseEqual())
1689  return failure();
1690 
1691  attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
1692  return success();
1693 }
1694 
1695 ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1696  std::optional<OperationName> payloadOpName;
1697  NamedAttrList payloadOpAttrs;
1698  if (succeeded(parser.parseOptionalLBrace())) {
1699  FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1700  if (failed(operationName))
1701  return failure();
1702  if (parser.parseOptionalAttrDict(payloadOpAttrs))
1703  return failure();
1704  payloadOpName = operationName.value();
1705  if (parser.parseRBrace())
1706  return failure();
1707  }
1708 
1709  if (parseDstStyleOp(
1710  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1711  return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1712  }))
1713  return failure();
1714 
1715  if (payloadOpName.has_value()) {
1716  addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
1717  ArrayRef(result.operands), /*initFirst=*/true);
1718  } else {
1719  SmallVector<OpAsmParser::Argument> regionArgs;
1720  if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1721  /*allowType=*/true, /*allowAttrs=*/true)) {
1722  return failure();
1723  }
1724 
1725  Region *body = result.addRegion();
1726  if (parser.parseRegion(*body, regionArgs))
1727  return failure();
1728  }
1729 
1730  return success();
1731 }
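
// For illustration (names and types assumed), with initFirst = true the short
// form
//
//   linalg.reduce { arith.addf } ins(%in : tensor<16x32xf32>)
//                                outs(%init : tensor<16xf32>) dimensions = [1]
//
// gets a combiner whose payload op takes the init block argument first:
//
//   (%in: f32, %init: f32) {
//     %0 = arith.addf %init, %in : f32
//     linalg.yield %0 : f32
//   }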
1732 
1733 static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
1734  ArrayRef<int64_t> attributeValue) {
1735  p << ' ' << attributeName << " = [" << attributeValue << "] ";
1736 }
1737 
1738 void ReduceOp::print(OpAsmPrinter &p) {
1739  Block *mapper = getBody();
1740  Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
1741  if (payloadOp) {
1742  printShortForm(p, payloadOp);
1743  }
1744 
1745  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1746  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
1747  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1748  if (!payloadOp) {
1749  // Print region if the payload op was not detected.
1750  p.increaseIndent();
1751  p.printNewline();
1752  p << "(";
1753  llvm::interleaveComma(mapper->getArguments(), p,
1754  [&](auto arg) { p.printRegionArgument(arg); });
1755  p << ") ";
1756 
1757  p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
1758  p.decreaseIndent();
1759  }
1760 }
1761 
1762 LogicalResult ReduceOp::verify() {
1763  ArrayRef<int64_t> dimensionsRef = getDimensions();
1764 
1765  for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
1766  if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
1767  llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
1768  return emitOpError() << "expects all inputs to have the same shapes. "
1769  "Shape at input-index "
1770  << i
1771  << " is not equal to the shape at input-index 0.";
1772  }
1773  }
1774  for (int64_t i = 1; i < getNumDpsInits(); ++i) {
1775  if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
1776  llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
1777  return emitOpError() << "expects all outputs to have the same shapes. "
1778  "Shape at output-index "
1779  << i
1780  << " is not equal to the shape at output-index 0.";
1781  }
1782  }
1783  auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
1784  auto initType = llvm::cast<ShapedType>(getInits()[0].getType());
1785 
1786  DenseSet<int64_t> dimensionsToReduce;
1787  for (int64_t dimension : dimensionsRef) {
1788  if (dimension < 0 || dimension >= inputType.getRank()) {
1789  return emitOpError()
1790  << "dimensions for reduction should be in the range [0, "
1791  << inputType.getRank() - 1 << "].";
1792  }
1793  dimensionsToReduce.insert(dimension);
1794  }
1795 
1796  auto inputDims = inputType.getShape();
1797  auto initDims = initType.getShape();
1798 
1799  // Input dimensions that will be left after the reduction.
1800  SmallVector<int64_t> reducedInputDims;
1801  for (const auto &en : llvm::enumerate(inputDims)) {
1802  if (!dimensionsToReduce.count(en.index()))
1803  reducedInputDims.push_back(en.value());
1804  }
1805 
1806  if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
1807  return emitOpError() << "number of dimensions after reduction "
1808  << reducedInputDims.size()
1809  << " doesn't match the init rank "
1810  << initType.getRank();
1811  }
1812 
1813  if (reducedInputDims != initDims)
1814  return emitOpError() << "init dimensions [" << initDims
1815  << "] doesn't match input dimensions after reduction ["
1816  << reducedInputDims << "]";
1817 
1818  Block *block = getBody();
1819  if (block->getNumArguments() != this->getNumOperands())
1820  return emitOpError()
1821  << "mismatching number of operands and block arguments";
1822 
1823  // Check that the first block arguments match the element type of the inputs.
1824  for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
1825  Type inputElementType =
1826  llvm::cast<ShapedType>(input.getType()).getElementType();
1827  if (inputElementType != bbArg.getType())
1828  return emitOpError()
1829  << "input element type " << inputElementType
1830  << " does not match corresponding block argument type "
1831  << bbArg.getType();
1832  }
1833 
1834  // Check that the last block arguments match the element type of the outputs.
1835  for (auto [output, bbArg] : llvm::zip(
1836  getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
1837  auto outputElementType =
1838  llvm::cast<ShapedType>(output.getType()).getElementType();
1839  if (outputElementType != bbArg.getType())
1840  return emitOpError()
1841  << "output element type " << outputElementType
1842  << " does not match corresponding block argument type "
1843  << bbArg.getType();
1844  }
1845  return success();
1846 }
1847 
1848 //===----------------------------------------------------------------------===//
1849 // TransposeOp
1850 //===----------------------------------------------------------------------===//
1851 
1852 static void buildIdentityRegion(OpBuilder &builder, Location loc,
1853  Region &region, ValueRange inputs,
1854  ValueRange outputs) {
1855  buildGenericRegion(builder, loc, region, inputs, outputs,
1856  [](OpBuilder &b, Location loc, ValueRange args) {
1857  if (!args.empty())
1858  b.create<linalg::YieldOp>(loc, args[0]);
1859  });
1860 }
1861 
1862 void TransposeOp::build(::mlir::OpBuilder &builder,
1863  ::mlir::OperationState &result, Value input, Value init,
1864  DenseI64ArrayAttr permutation,
1865  ArrayRef<NamedAttribute> attributes) {
1866  result.addOperands(input);
1867  result.addOperands(init);
1868  result.addAttribute(getPermutationAttrName(result.name), permutation);
1869  result.addAttributes(attributes);
1870 
1871  // Add output types for `RankedTensorType` output arguments.
1872  Type initType = init.getType();
1873  if (llvm::isa<RankedTensorType>(initType))
1874  result.addTypes(initType);
1875 
1876  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
1877  init);
1878 }
1879 
1880 void TransposeOp::build(::mlir::OpBuilder &builder,
1881  ::mlir::OperationState &result, Value input, Value init,
1882  ArrayRef<int64_t> permutation,
1883  ArrayRef<NamedAttribute> attributes) {
1884  build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
1885  attributes);
1886 }
1887 
1888 ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
1889  if (failed(parseDstStyleOp(
1890  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1891  return parseDenseI64ArrayAttr(parser, attributes, "permutation");
1892  })))
1893  return failure();
1894 
1895  OpBuilder builder(parser.getContext());
1896  buildIdentityRegion(builder, result.location, *result.addRegion(),
1897  /*inputs=*/result.operands,
1898  /*outputs=*/{});
1899  return success();
1900 }
1901 
1902 void TransposeOp::getAsmResultNames(
1903  function_ref<void(Value, StringRef)> setNameFn) {
1904  if (!getResults().empty())
1905  setNameFn(getResults().front(), "transposed");
1906 }
1907 
1909  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1910  printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
1911  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
1912 }
1913 
1914 LogicalResult TransposeOp::verify() {
1915  ArrayRef<int64_t> permutationRef = getPermutation();
1916 
1917  if (!isPermutationVector(permutationRef))
1918  return emitOpError("permutation is not valid");
1919 
1920  auto inputType = getInput().getType();
1921  auto initType = getInit().getType();
1922 
1923  int64_t rank = inputType.getRank();
1924 
1925  if (rank != initType.getRank())
1926  return emitOpError() << "input rank " << rank
1927  << " does not match init rank " << initType.getRank();
1928 
1929  if (rank != static_cast<int64_t>(permutationRef.size()))
1930  return emitOpError() << "size of permutation " << permutationRef.size()
1931  << " does not match the argument rank " << rank;
1932 
1933  auto inputDims = inputType.getShape();
1934  auto initDims = initType.getShape();
1935 
1936  for (int64_t i = 0; i < rank; ++i) {
1937  int64_t inputDim = inputDims[permutationRef[i]];
1938  int64_t initDim = initDims[i];
1939 
1940  if (inputDim != initDim) {
1941  return emitOpError() << "dim(result, " << i << ") = " << initDim
1942  << " doesn't match dim(input, permutation[" << i
1943  << "]) = " << inputDim;
1944  }
1945  }
1946 
1947  return success();
1948 }
1949 
1950 SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
1951  int64_t rank = getInit().getType().getRank();
1952  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
1953 }
1954 
1955 ArrayAttr TransposeOp::getIndexingMaps() {
1956  Builder builder(getContext());
1957  int64_t rank = getInit().getType().getRank();
1958  return builder.getAffineMapArrayAttr(
1959  {inversePermutation(AffineMap::getPermutationMap(
1960  llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
1961  builder.getMultiDimIdentityMap(rank)});
1962 }
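
// For example, for permutation = [1, 0] the maps are
// [(d0, d1) -> (d1, d0), (d0, d1) -> (d0, d1)]: the inverse permutation map
// for the input and the identity map for the init.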
1963 
1964 void TransposeOp::getEffects(
1965  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1966  &effects) {
1967  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
1968 }
1969 
1970 Speculation::Speculatability TransposeOp::getSpeculatability() {
1971  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
1972 }
1973 
1974 LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
1975  SmallVectorImpl<OpFoldResult> &result) {
1976  // Only the tensor type is supported.
1977  if (!isa<TensorType>(getInput().getType()))
1978  return failure();
1979 
1980  // Rank-0 transpose (empty permutation).
1981  if (getPermutation().size() == 0) {
1982  result.push_back(getInput());
1983  return success();
1984  }
1985  // Identity permutation.
1986  if (isIdentityPermutation(getPermutation())) {
1987  result.push_back(getInput());
1988  return success();
1989  }
1990 
1991  return failure();
1992 }
1993 
1994 /// Fold transpose with transpose.
1995 struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
1996  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
1997 
1998  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
1999  PatternRewriter &rewriter) const override {
2000  auto defTransposeOp = transposeOp.getInput().getDefiningOp<TransposeOp>();
2001  if (!defTransposeOp)
2002  return failure();
2003  ArrayRef<int64_t> defPerms = defTransposeOp.getPermutation();
2004  ArrayRef<int64_t> perms = transposeOp.getPermutation();
2005  SmallVector<int64_t> foldedPerms;
2006  foldedPerms.reserve(perms.size());
2007  for (int64_t perm : perms)
2008  foldedPerms.push_back(defPerms[perm]);
2009 
2010  rewriter.replaceOpWithNewOp<TransposeOp>(
2011  transposeOp, defTransposeOp.getInput(), transposeOp.getInit(),
2012  foldedPerms);
2013  return success();
2014  }
2015 };
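
// For example, a transpose with permutation = [1, 2, 0] feeding a transpose
// with permutation = [2, 0, 1] composes to foldedPerms[i] = defPerms[perms[i]]
// = [0, 1, 2], i.e. the pair folds to a single identity transpose.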
2016 
2017 /// This pattern canonicalizes a transpose by swapping the order of
2018 /// broadcast and transpose:
2019 /// transpose(broadcast(input)) -> broadcast(transpose(input))
2020 struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
2021  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
2022 
2023  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
2024  PatternRewriter &rewriter) const override {
2025  Value input = transposeOp.getInput();
2026  BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
2027  if (!input.hasOneUse() || !broadcastOp)
2028  return failure();
2029 
2030  ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
2031  ArrayRef<int64_t> perms = transposeOp.getPermutation();
2032 
2033  // Get new perms and new dimensions.
2034  SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
2035  SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
2036  SmallVector<int64_t> resultDimensions;
2037  unsigned dimensionSize = dimensions.size();
2038  for (unsigned i = 0; i < dimensionSize; ++i)
2039  resultDimensions.push_back(invertPerm[dimensions[i]]);
2040 
2041  // Create transpose result.
2042  Value broadcastInput = broadcastOp.getInput();
2043  Location loc = transposeOp.getLoc();
2044  MLIRContext *ctx = transposeOp.getContext();
2045  SmallVector<OpFoldResult> dims;
2046  auto broadcastInputTy =
2047  mlir::cast<RankedTensorType>(broadcastInput.getType());
2048  unsigned inputRank = broadcastInputTy.getRank();
2049  for (unsigned i = 0; i < inputRank; ++i) {
2050  if (broadcastInputTy.isDynamicDim(i)) {
2051  dims.push_back(rewriter.create<tensor::DimOp>(loc, broadcastInput, i)
2052  ->getResult(0));
2053  } else {
2054  dims.push_back(IntegerAttr::get(IndexType::get(ctx),
2055  broadcastInputTy.getDimSize(i)));
2056  }
2057  }
2058  SmallVector<OpFoldResult> transposeResultShapes =
2059  applyPermutation(dims, resultPerms);
2060  Value transposeInit = rewriter.create<tensor::EmptyOp>(
2061  transposeOp.getLoc(), transposeResultShapes,
2062  broadcastInputTy.getElementType());
2063 
2064  // Create broadcast(transpose(input)).
2065  Value transposeResult =
2066  rewriter
2067  .create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
2068  resultPerms)
2069  ->getResult(0);
2070  rewriter.replaceOpWithNewOp<BroadcastOp>(
2071  transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
2072  return success();
2073  }
2074 };
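
// For example (shapes assumed): a broadcast of tensor<2x4xf32> into
// tensor<2x3x4xf32> with dimensions = [1], followed by a transpose with
// permutation = [1, 0, 2], is rewritten into a transpose of the 2x4 input
// with permutation = [0, 1] followed by a broadcast into tensor<3x2x4xf32>
// with dimensions = [0].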
2075 
2076 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2077  MLIRContext *context) {
2078  results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
2079 }
2080 
2081 //===----------------------------------------------------------------------===//
2082 // BroadcastOp
2083 //===----------------------------------------------------------------------===//
2084 
2085 void BroadcastOp::build(::mlir::OpBuilder &builder,
2086  ::mlir::OperationState &result, Value input, Value init,
2087  DenseI64ArrayAttr dimensions,
2088  ArrayRef<NamedAttribute> attributes) {
2089  result.addOperands(input);
2090  result.addOperands(init);
2091  result.addAttribute(getDimensionsAttrName(result.name), dimensions);
2092  result.addAttributes(attributes);
2093 
2094  // Add output types for `RankedTensorType` output arguments.
2095  Type initType = init.getType();
2096  if (llvm::isa<RankedTensorType>(initType))
2097  result.addTypes(initType);
2098 
2099  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
2100  init);
2101 }
2102 
2103 void BroadcastOp::build(::mlir::OpBuilder &builder,
2104  ::mlir::OperationState &result, Value input, Value init,
2105  ArrayRef<int64_t> dimensions,
2106  ArrayRef<NamedAttribute> attributes) {
2107  build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
2108  attributes);
2109 }
2110 
2111 ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
2112  if (failed(parseDstStyleOp(
2113  parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
2114  return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
2115  })))
2116  return failure();
2117 
2118  OpBuilder builder(parser.getContext());
2119  buildIdentityRegion(builder, result.location, *result.addRegion(),
2120  /*inputs=*/result.operands,
2121  /*outputs=*/{});
2122  return success();
2123 }
2124 
2125 void BroadcastOp::getAsmResultNames(
2126  function_ref<void(Value, StringRef)> setNameFn) {
2127  if (!getResults().empty())
2128  setNameFn(getResults().front(), "broadcasted");
2129 }
2130 
2131 void BroadcastOp::print(OpAsmPrinter &p) {
2132  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
2133  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
2134  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
2135 }
2136 
2137 LogicalResult BroadcastOp::verify() {
2138  ArrayRef<int64_t> dimensionsRef = getDimensions();
2139 
2140  auto inputType = getInput().getType();
2141  auto initType = getInit().getType();
2142 
2143  int64_t inputRank = inputType.getRank();
2144  int64_t initRank = initType.getRank();
2145 
2146  auto inputShape = inputType.getShape();
2147  auto initShape = initType.getShape();
2148 
2149  if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
2150  return emitOpError() << "input rank plus added dimensions does not "
2151  "match init rank. input rank: "
2152  << inputRank
2153  << ", dimensions size: " << dimensionsRef.size()
2154  << ", init rank: " << initRank;
2155 
2156  for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
2157  if (dim < 0 || dim >= initRank)
2158  return emitOpError() << "dimension " << idx
2159  << " is out of range. expected range: [0, "
2160  << initRank - 1 << "], got: " << dim;
2161  }
2162 
2163  // Mapping from input dims to init dims.
2164  SmallVector<int64_t> dimMap;
2165  for (auto dim : llvm::seq<int64_t>(0, initRank)) {
2166  if (!llvm::is_contained(dimensionsRef, dim))
2167  dimMap.push_back(dim);
2168  }
2169 
2170  for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
2171  // This dimension is mapped from the input. Init and input dims should
2172  // match.
2173  if (inputShape[inputDimIdx] != initShape[initDimIdx])
2174  return emitOpError() << "input dim " << inputDimIdx
2175  << " should match init dim " << initDimIdx
2176  << ". input: " << inputShape[inputDimIdx]
2177  << ", init: " << initShape[initDimIdx];
2178  }
2179 
2180  return success();
2181 }
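
// For example (shapes assumed), dimensions = [1] inserts one dimension at
// index 1 of the init, so input dim 0 maps to init dim 0 and input dim 1
// maps to init dim 2:
//
//   linalg.broadcast ins(%in : tensor<2x4xf32>)
//                    outs(%init : tensor<2x3x4xf32>) dimensions = [1]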
2182 
2183 SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
2184  int64_t rank = getInit().getType().getRank();
2185  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
2186 }
2187 
2188 ArrayAttr BroadcastOp::getIndexingMaps() {
2189  Builder builder(getContext());
2190  int64_t rank = getInit().getType().getRank();
2191  return builder.getAffineMapArrayAttr(
2192  {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
2193  builder.getMultiDimIdentityMap(rank)});
2194 }
2195 
2196 void BroadcastOp::getEffects(
2197  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2198  &effects) {
2199  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
2200 }
2201 
2202 Speculation::Speculatability BroadcastOp::getSpeculatability() {
2203  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
2204 }
2205 
2206 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2207  MLIRContext *context) {
2208  results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
2209 }
2210 
2211 //===----------------------------------------------------------------------===//
2212 // YieldOp
2213 //===----------------------------------------------------------------------===//
2214 
2215 void linalg::YieldOp::print(OpAsmPrinter &p) {
2216  if (getNumOperands() > 0)
2217  p << ' ' << getOperands();
2218  p.printOptionalAttrDict((*this)->getAttrs());
2219  if (getNumOperands() > 0)
2220  p << " : " << getOperandTypes();
2221 }
2222 
2223 ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
2224  SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
2225  SmallVector<Type, 2> types;
2226  SMLoc loc = parser.getCurrentLocation();
2227  return failure(parser.parseOperandList(opInfo) ||
2228  parser.parseOptionalAttrDict(result.attributes) ||
2229  (!opInfo.empty() && parser.parseColonTypeList(types)) ||
2230  parser.resolveOperands(opInfo, types, loc, result.operands));
2231 }
2232 
2233 // Check that the operand number and types match the element types of the
2234 // LinalgOp interface's shaped operands.
2235 static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
2236  if (op.getNumOperands() != linalgOp.getNumDpsInits())
2237  return op.emitOpError("expected number of yield values (")
2238  << op.getNumOperands()
2239  << ") to match the number of inits / outs operands of the enclosing "
2240  << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";
2241 
2242  for (OpOperand &opOperand : op->getOpOperands()) {
2243  OpOperand *outputOperand =
2244  linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
2245  Type elementType = outputOperand->get().getType();
2246  if (isa<MemRefType, RankedTensorType>(elementType))
2247  elementType = getElementTypeOrSelf(outputOperand->get().getType());
2248  if (opOperand.get().getType() != elementType)
2249  return op.emitOpError("type of yield operand ")
2250  << (opOperand.getOperandNumber() + 1) << " ("
2251  << opOperand.get().getType() << ") doesn't match "
2252  << "the element type of the enclosing linalg.generic op ("
2253  << elementType << ")";
2254  }
2255  return success();
2256 }
2257 
2258 LogicalResult linalg::YieldOp::verify() {
2259  auto *parentOp = (*this)->getParentOp();
2260  if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
2261  return emitOpError("expected single non-empty parent region");
2262 
2263  if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
2264  return verifyYield(*this, linalgOp);
2265 
2266  return emitOpError("expected parent op with LinalgOp interface");
2267 }
2268 
2269 //===----------------------------------------------------------------------===//
2270 // IndexOp
2271 //===----------------------------------------------------------------------===//
2272 
2273 LogicalResult IndexOp::verify() {
2274  auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
2275  if (!linalgOp)
2276  return emitOpError("expected parent op with LinalgOp interface");
2277  if (linalgOp.getNumLoops() <= getDim())
2278  return emitOpError("expected dim (")
2279  << getDim() << ") to be lower than the number of loops ("
2280  << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
2281  return success();
2282 }
2283 
2284 /////// Operations corresponding to library calls defined with Tablegen ////////
2285 
2286 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
2287 
2288 #define GET_OP_CLASSES
2289 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
2290 
2291 #define GET_OP_CLASSES
2292 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
2293 #define GET_OP_CLASSES
2294 #include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.cpp.inc"
2295 
2296 AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
2297  unsigned rank,
2298  MLIRContext *context) {
2299  if (maybeMap)
2300  return *maybeMap;
2301  if (rank == 0)
2302  return AffineMap::get(context);
2303  return AffineMap::getMultiDimIdentityMap(rank, context);
2304 }
2305 
2306 SmallVector<AffineExpr, 4>
2307 mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
2308  MLIRContext *context) {
2309  SmallVector<AffineExpr, 4> res;
2310  res.reserve(num);
2311  for (unsigned i = 0; i < num; ++i)
2312  res.push_back(getAffineDimExpr(startIdx++, context));
2313  return res;
2314 }
2315 
2316 SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
2317  ArrayRef<AffineExpr> b) {
2318  auto rangeA = llvm::make_range(a.begin(), a.end());
2319  auto rangeB = llvm::make_range(b.begin(), b.end());
2320  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
2321  return llvm::to_vector<4>(concatRanges);
2322 }
2323 
2324 static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
2325  if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
2326  ss << "view";
2327  for (auto size : memref.getShape())
2328  if (size < 0)
2329  ss << "sx";
2330  else
2331  ss << size << "x";
2332  if (failed(appendMangledType(ss, memref.getElementType())))
2333  return failure();
2334  if (auto as = memref.getMemorySpace()) {
2335  if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
2336  ss << "as" << attr.getInt();
2337  else
2338  return failure();
2339  }
2340  return success();
2341  }
2342  if (auto vec = llvm::dyn_cast<VectorType>(t)) {
2343  ss << "vector";
2344  llvm::interleave(
2345  vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
2346  if (failed(appendMangledType(ss, vec.getElementType())))
2347  return failure();
2348  return success();
2349  }
2350  if (t.isSignlessIntOrIndexOrFloat()) {
2351  ss << t;
2352  return success();
2353  }
2354  return failure();
2355 }
2356 
2357 std::string mlir::linalg::generateLibraryCallName(Operation *op) {
2358  assert(isa<LinalgOp>(op));
2359  std::string name(op->getName().getStringRef().str());
2360  std::string fun = "";
2361  for (NamedAttribute kv : op->getAttrs()) {
2362  if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
2363  fun = stringifyEnum(ufa.getValue()).str() + "_";
2364  } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
2365  fun = stringifyEnum(bfa.getValue()).str() + "_";
2366  }
2367  }
2368  name.reserve(128);
2369  std::replace(name.begin(), name.end(), '.', '_');
2370  llvm::raw_string_ostream ss(name);
2371  ss << "_" << fun;
2372  for (Type t : op->getOperandTypes()) {
2373  if (failed(appendMangledType(ss, t)))
2374  return std::string();
2375  ss << "_";
2376  }
2377  name.pop_back();
2378  return name;
2379 }
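
// For example (assuming this mangling scheme), linalg.matmul on three
// memref<?x?xf32> operands produces
// "linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32": '.' becomes '_',
// dynamic sizes mangle to "sx", and each operand type is separated by '_'.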
2380 
2381 //===----------------------------------------------------------------------===//
2382 // Canonicalizers and Folders.
2383 //===----------------------------------------------------------------------===//
2384 
2385 namespace {
2386 struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
2387  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2388 
2389  LogicalResult matchAndRewrite(LinalgOp op,
2390  PatternRewriter &rewriter) const override {
2391  for (OpOperand &opOperand : op->getOpOperands()) {
2392  // Linalg "inputs" may be either tensor or memref type.
2393  // tensor<0xelt_type> is a convention that may not always mean
2394  // "0 iterations". Only erase in cases where we see memref<...x0x...>.
2395  auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
2396  if (!mt)
2397  continue;
2398  if (llvm::is_contained(op.getShape(&opOperand), 0)) {
2399  rewriter.eraseOp(op);
2400  return success();
2401  }
2402  }
2403  return failure();
2404  }
2405 };
2406 
2407 /// Fold LinalgOps with a `tensor.cast` consumer if the `tensor.cast` has a
2408 /// result that is more static than the linalg op.
2409 struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
2410  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
2411 
2412  LogicalResult matchAndRewrite(tensor::CastOp castOp,
2413  PatternRewriter &rewriter) const override {
2414  if (!tensor::canFoldIntoProducerOp(castOp))
2415  return failure();
2416 
2417  auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
2418  if (!linalgOp)
2419  return failure();
2420 
2421  // The cast can be in a conditionally reachable region, in which case folding
2422  // will generate invalid code. Only conservatively fold ops in the same block
2423  // for now.
2424  if (castOp->getBlock() != linalgOp->getBlock())
2425  return failure();
2426 
2427  OpBuilder::InsertionGuard guard(rewriter);
2428  rewriter.setInsertionPoint(linalgOp);
2429 
2430  Location loc = linalgOp.getLoc();
2431  OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
2432  unsigned resultNumber = resultValue.getResultNumber();
2433  auto resultType =
2434  llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
2435  // Replace the `outs` for the result with a `tensor.cast`. This cast is now
2436  // going from a more dynamic shape to a less dynamic shape. If the producer
2437  // for this cast, i.e. producer of the out operand, is also an operation
2438  // that folds with tensor.cast consumer (like this pattern), the cast will
2439  // continue to propagate as far up the stack as it can go.
2440  OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
2441  Value newOperand =
2442  rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
2443  SmallVector<Value> newOperands = linalgOp.getDpsInputs();
2444  SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
2445  linalgOp.getDpsInits().end());
2446  outputOperands[resultNumber] = newOperand;
2447  newOperands.append(outputOperands.begin(), outputOperands.end());
2448 
2449  SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
2450  linalgOp->result_type_end());
2451  resultTypes[resultNumber] = resultType;
2452  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2453 
2454  // Create a tensor.cast operation back to the original type.
2455  Value castBack = rewriter.create<tensor::CastOp>(
2456  loc, resultValue.getType(), newOp->getResult(resultNumber));
2457 
2458  SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
2459  results[resultNumber] = castBack;
2460  rewriter.replaceOp(linalgOp, results);
2461  rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
2462  return success();
2463  }
2464 };
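
// For illustration (IR sketch, names assumed): given
//
//   %0 = linalg.generic ... -> tensor<?x?xf32>
//   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
//
// the pattern casts the corresponding init to tensor<4x8xf32>, clones the
// linalg op with the static result type, and casts that result back to
// tensor<?x?xf32> for the remaining uses of %0.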
2465 
2466 /// For each operand in `operands`, this function maps the static sizes of its
2467 /// dimensions to their affine dim expressions.
2468 static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
2469  llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
2470  for (OpOperand &opOperand : operands) {
2471  if (linalgOp.isScalar(&opOperand))
2472  continue;
2473  Value src = opOperand.get();
2474  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2475  auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);
2476 
2477  // Get the `sourceShape` of the `sourceType`. If the operand is a result of
2478  // `tensor.cast` operation and source of the cast operation has a static
2479  // shape, then assign it to the `sourceShape`.
2480  auto *parentOp = src.getDefiningOp();
2481  ArrayRef<int64_t> sourceShape = sourceType.getShape();
2482  if (parentOp) {
2483  if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
2484  Value castSource = castOp.getSource();
2485  auto castSourceType =
2486  llvm::dyn_cast<RankedTensorType>(castSource.getType());
2487  if (castSourceType && castSourceType.hasStaticShape())
2488  sourceShape = castSourceType.getShape();
2489  }
2490  }
2491 
2492  // If the source shape's dimension has a static shape, map the affine dim
2493  // expression to the known static size.
2494  for (unsigned i = 0; i < sourceShape.size(); i++) {
2495  if (sourceType.isDynamicDim(i))
2496  continue;
2497  if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
2498  affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
2499  }
2500  }
2501 }
2502 
2503 /// Creates a new operand w.r.t. `opOperand` of `linalgOp` with the static
2504 /// sizes mapped in `affineExprToSize`. New operands are appended to
2505 /// `newOperands` and their result types are stored in `resultTypes`. If
2506 /// `opOperand` requires no change, then `changeNeeded` is left as is and the
2507 /// same operand is added to the `newOperands` list.
2508 static void createNewOperandWithStaticSizes(
2509  Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
2510  llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
2511  SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
2512  bool &changeNeeded) {
2513  Value src = opOperand->get();
2514  newOperands.push_back(src);
2515  if (linalgOp.isScalar(opOperand))
2516  return;
2517  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
2518  Type resultType = sourceType;
2519  if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
2520  resultTypes.push_back(resultType);
2521  return;
2522  }
2523  ArrayRef<int64_t> sourceShape = sourceType.getShape();
2524  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
2525  SmallVector<int64_t> newShape;
2526  // If the operand is updated with a new shape, `newOperandNeeded` will be
2527  // true.
2528  bool newOperandNeeded = false;
2529  for (unsigned i = 0; i < sourceShape.size(); i++) {
2530  int64_t dimShape = sourceShape[i];
2531  AffineExpr dimExpr = sourceMap.getResult(i);
2532  if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
2533  newShape.push_back(dimShape);
2534  continue;
2535  }
2536  // The dimension has a dynamic shape and the corresponding affine dim
2537  // expression is present in the map, so assign the mapped size for that
2538  // affine dim expression to the dimension.
2539  newShape.push_back(affineExprToSize[dimExpr]);
2540  newOperandNeeded = true;
2541  }
2542  resultType = RankedTensorType::get(newShape, sourceType.getElementType(),
2543  sourceType.getEncoding());
2544  if (newOperandNeeded) {
2545  changeNeeded = true;
2546  // Get the new operand value given its size and element type by
2547  // casting it.
2548  Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
2549  unsigned index = opOperand->getOperandNumber();
2550  newOperands[index] = newOperand;
2551  }
2552  if (linalgOp.isDpsInit(opOperand))
2553  resultTypes.push_back(resultType);
2554 }
2555 
2556 /// Static shapes for the operands can be inferred if any one of the operands
2557 /// has a static shape. This can be done by referring to the affine dim
2558 /// expressions for the operand.
2559 struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
2560  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
2561 
2562  LogicalResult matchAndRewrite(LinalgOp linalgOp,
2563  PatternRewriter &rewriter) const override {
2564  if (!linalgOp.hasPureTensorSemantics())
2565  return failure();
2566 
2567  // Maps must be projected permutations.
2568  if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
2569  return !map.isProjectedPermutation();
2570  }))
2571  return failure();
2572 
2573  // Maps affine dim expressions to the static size of that dimension.
2574  llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
2575  Location loc = linalgOp.getLoc();
2576 
2577  // For each affine dim expression, check if the size is known. If it is,
2578  // add it to the map.
2579  populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);
2580 
2581  SmallVector<Value> newOperands;
2582  SmallVector<Type> resultTypes;
2583 
2584  // `changeNeeded` is `false` if the operands of `linalgOp` require no
2585  // change in their types.
2586  bool changeNeeded = false;
2587  newOperands.reserve(linalgOp->getNumOperands());
2588  resultTypes.reserve(linalgOp.getNumDpsInits());
2589 
2590  // Iterate over all the operands and update the static sizes.
2591  for (OpOperand &opOperand : linalgOp->getOpOperands()) {
2592  createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
2593  affineExprToSize, linalgOp, newOperands,
2594  resultTypes, changeNeeded);
2595  }
2596 
2597  // If the generic op has all the required static information, no
2598  // canonicalization needed.
2599  if (!changeNeeded)
2600  return failure();
2601 
2602  // Clone op.
2603  Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
2604  SmallVector<Value> replacements;
2605  replacements.reserve(newOp->getNumResults());
2606  for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
2607  Value newResult = std::get<1>(it);
2608  Value oldResult = std::get<0>(it);
2609  Type newType = newResult.getType();
2610  Type oldType = oldResult.getType();
2611  replacements.push_back(
2612  (newType != oldType)
2613  ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
2614  : newResult);
2615  }
2616  rewriter.replaceOp(linalgOp, replacements);
2617  return success();
2618  }
2619 };
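
// For illustration (IR sketch, names assumed): in
//
//   %0 = linalg.add ins(%a, %b : tensor<4x8xf32>, tensor<?x?xf32>)
//                   outs(%c : tensor<?x?xf32>) -> tensor<?x?xf32>
//
// the identity indexing maps bind d0 = 4 and d1 = 8, so %b and %c are cast
// to tensor<4x8xf32>, the op is cloned with the static result type, and the
// result is cast back to tensor<?x?xf32> for existing uses.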
2620 
2621 } // namespace
2622 
2623 // All named ops' canonicalizers and folders are auto-generated in the
2624 // .cpp.inc.
2625 
2626 //===----------------------------------------------------------------------===//
2627 // SoftmaxOp
2628 //===----------------------------------------------------------------------===//
2629 
2630 LogicalResult SoftmaxOp::verify() {
2631  ShapedType inputType = getInputOperandType();
2632  ShapedType outputType = getOutputOperandType();
2633 
2634  ArrayRef<int64_t> inputShape = inputType.getShape();
2635  ArrayRef<int64_t> outputShape = outputType.getShape();
2636  if (failed(verifyCompatibleShape(inputShape, outputShape)))
2637  return emitOpError("incompatible output shape");
2638 
2639  int64_t inputRank = getInputOperandRank();
2640  int64_t dimension = getDimension();
2641  if ((dimension < 0) || (dimension >= inputRank))
2642  return emitOpError("incorrect dimension specified");
2643 
2644  return success();
2645 }
2646 
2647 SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
2648  int64_t operandRank = getInputOperandRank();
2649  SmallVector<Range> loopBounds(operandRank);
2650  Location loc = getLoc();
2651  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
2652  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
2653  Value source = getInput();
2654  for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
2655  loopBounds[dim].offset = zero;
2656  loopBounds[dim].size = getDimValue(builder, loc, source, dim);
2657  loopBounds[dim].stride = one;
2658  }
2659  return loopBounds;
2660 }
2661 
2662 SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
2663  SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
2664  utils::IteratorType::parallel);
2665  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
2666  return iteratorTypes;
2667 }
2668 
2669 FailureOr<TilingResult>
2670 SoftmaxOp::getTiledImplementation(OpBuilder &builder,
2671  ArrayRef<OpFoldResult> offsets,
2672  ArrayRef<OpFoldResult> sizes) {
2673  int64_t rank = getInputOperandRank();
2674  auto oneAttr = builder.getI64IntegerAttr(1);
2675  SmallVector<OpFoldResult> strides(rank, oneAttr);
2676  SmallVector<Value> tiledOperands;
2677  Operation *inputSlice =
2678  getSlice(builder, getLoc(), getInput(), offsets, sizes, strides);
2679  if (!inputSlice) {
2680  return emitOpError("failed to compute input slice");
2681  }
2682  tiledOperands.emplace_back(inputSlice->getResult(0));
2683  Operation *outputSlice =
2684  getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides);
2685  if (!outputSlice) {
2686  return emitOpError("failed to compute output slice");
2687  }
2688  tiledOperands.emplace_back(outputSlice->getResult(0));
2689 
2690  SmallVector<Type, 4> resultTypes;
2691  if (hasPureTensorSemantics())
2692  resultTypes.push_back(tiledOperands[1].getType());
2693  Operation *tiledOp =
2694  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
2695 
2696  return TilingResult{
2697  {tiledOp},
2698  SmallVector<Value>(tiledOp->getResults()),
2699  llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
2700 }
2701 
2702 LogicalResult SoftmaxOp::getResultTilePosition(
2703  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2704  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2705  SmallVector<OpFoldResult> &resultSizes) {
2706  if (resultNumber == 0) {
2707  resultOffsets.assign(offsets.begin(), offsets.end());
2708  resultSizes.assign(sizes.begin(), sizes.end());
2709  return success();
2710  }
2711  return failure();
2712 }
2713 
2714 // cast(dynamic) -> static.
2715 LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
2716  return memref::foldMemRefCast(*this);
2717 }
2718 
2719 LogicalResult
2720 SoftmaxOp::reifyResultShapes(OpBuilder &b,
2721  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
2722  SmallVector<OpFoldResult> shapes;
2723  Location loc = getOperation()->getLoc();
2724  IRRewriter rewriter(b);
2725  auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
2726  auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
2727  for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
2728  if (!outputShapedType.isDynamicDim(dim)) {
2729  // Static dim: Return IntegerAttr.
2730  shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
2731  } else {
2732  // Dynamic dim: Return Value.
2733  OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
2734  shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
2735  }
2736  }
2737  reifiedReturnShapes.emplace_back(std::move(shapes));
2738  return success();
2739 }
2740 
2741 void SoftmaxOp::getEffects(
2742  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2743  &effects) {
2744  for (auto [index, operand] : llvm::enumerate(getDpsInputs())) {
2745  if (!llvm::isa<MemRefType>(operand.getType()))
2746  continue;
2747  effects.emplace_back(MemoryEffects::Read::get(),
2748  &getOperation()->getOpOperand(index), /*stage=*/0,
2749  /*effectOnFullRegion=*/true,
2750  SideEffects::DefaultResource::get());
2751  }
2752 
2753  for (OpOperand &operand : getDpsInitsMutable()) {
2754  if (!llvm::isa<MemRefType>(operand.get().getType()))
2755  continue;
2756  effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
2757  /*effectOnFullRegion=*/true,
2758  SideEffects::DefaultResource::get());
2759  effects.emplace_back(MemoryEffects::Write::get(), &operand, /*stage=*/0,
2760  /*effectOnFullRegion=*/true,
2761  SideEffects::DefaultResource::get());
2762  }
2763 }
2764 
2765 // Helper functions for softmax decomposition.
2766 // @{
2767 
2768 // Helper function to produce the iterator types (reduction or parallel) and
2769 // affine maps for the iterators used in the decomposition of softmax.
2770 // This method creates:
2771 // If allParallel == true:
2772 // - iterator type: {parallel, ..., parallel}
2773 // - affine maps:
2774 // -- identity with inputRank dimensions.
2775 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2776 // where N == inputRank - 1.
2777 //
2778 // If allParallel == false:
2779 // - iterator type at dim(i) == parallel for i != \p dim and
2780 // dim(dim) == reduction.
2781 // - affine map:
2782 // -- identity with inputRank dimensions.
2783 // -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
2784 // where N == inputRank - 1.
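//
// For instance, with inputRank == 3 and dim == 1 this returns the maps
// (d0, d1, d2) -> (d0, d1, d2) and (d0, d1, d2) -> (d0, d2), with iterator
// types {parallel, reduction, parallel} when allParallel == false.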
2785 static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
2786 computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
2787  int64_t dim, bool allParallel = false) {
2788  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
2789  utils::IteratorType::parallel);
2790  if (!allParallel)
2791  iteratorTypes[dim] = utils::IteratorType::reduction;
2792  MLIRContext *ctxt = builder.getContext();
2793  auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
2794  SmallVector<AffineExpr, 2> affineExprs;
2795  for (int i = 0; i < inputRank; i++) {
2796  if (i != dim)
2797  affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
2798  }
2799  auto reductionMap =
2800  AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
2801  SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
2802  return std::make_tuple(iteratorTypes, indexingMaps);
2803 }
2804 
2805 // Helper function to produce a linalg.generic that computes a reduction on
2806 // dimension \p dim with the operation type \p T.
2807 template <typename T>
2808 static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
2809  int64_t dim) {
2810  auto inputType = cast<ShapedType>(input.getType());
2811  ArrayRef<int64_t> inputShape = inputType.getShape();
2812  int64_t inputRank = inputShape.size();
2813  auto [iteratorTypes, indexingMaps] =
2814  computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
2815  assert(indexingMaps.size() == 2 &&
2816  "We should have two maps: 1 for the input, 1 for the output");
2817  assert(indexingMaps[0].isIdentity() && "input map should be identity");
2818 
2819  auto genericOp = builder.create<linalg::GenericOp>(
2820  loc, output.getType(), input, output, indexingMaps, iteratorTypes,
2821  [&](OpBuilder &b, Location loc, ValueRange args) {
2822  Value result = b.create<T>(loc, args[0], args[1]);
2823  b.create<linalg::YieldOp>(loc, result);
2824  });
2825  return genericOp.getResult(0);
2826 }
2827 
2828 /// Produce a linalg generic that computes the second step of the softmax
2829 /// decomposition: res = exp(input - max), where \p max is the max of \p input
2830 /// on dimension \p dim.
2831 static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
2832  Value max, Value output, int64_t dim) {
2833  auto inputType = cast<ShapedType>(input.getType());
2834  ArrayRef<int64_t> inputShape = inputType.getShape();
2835  int64_t inputRank = inputShape.size();
2836  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2837  builder, inputRank, dim, /*allParallel=*/true);
2838  assert(indexingMaps.size() == 2 && "We should have one map for each input");
2839  assert(indexingMaps[0].isIdentity() && "input map should be identity");
2840  // Add the affine map for the output argument.
2841  indexingMaps.push_back(indexingMaps[0]);
2842  auto genericOp = builder.create<linalg::GenericOp>(
2843  loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
2844  iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
2845  Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
2846  Value result = b.create<math::ExpOp>(loc, diff);
2847  b.create<linalg::YieldOp>(loc, result);
2848  });
2849  return genericOp.getResult(0);
2850 }
2851 
2852 /// Produce a linalg generic that computes the final step of the softmax
2853 /// decomposition.
2854 /// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
2855 /// yield n / d
2856 /// }
2857 static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
2858  Value denominator, Value output, int64_t dim) {
2859  auto inputType = cast<ShapedType>(numerator.getType());
2860  ArrayRef<int64_t> inputShape = inputType.getShape();
2861  int64_t inputRank = inputShape.size();
2862  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
2863  builder, inputRank, dim, /*allParallel=*/true);
2864  assert(indexingMaps.size() == 2 &&
2865  "We should have one map for each input (2)");
2866  assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
2867  // Add the affine map for the output tensor.
2868  indexingMaps.push_back(indexingMaps[0]);
2869  auto genericOp = builder.create<linalg::GenericOp>(
2870  loc, numerator.getType(), ValueRange{numerator, denominator}, output,
2871  indexingMaps, iteratorTypes,
2872  [&](OpBuilder &b, Location loc, ValueRange args) {
2873  Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
2874  b.create<linalg::YieldOp>(loc, result);
2875  });
2876  return genericOp.getResult(0);
2877 }
2878 // @} End helper functions for softmax decomposition.
2879 
2880 /// Given an N-dimensional tensor x, this method converts
2881 /// softmax(x) to the following sequence of operations:
2882 ///
2883 /// 1. Compute the max of x along dimension d. This results
2884 /// in an N-1 dimensional tensor m.
2885 /// m = max(x, dim = d)
2886 ///
2887 /// 2. Subtract a broadcasted m from x and exponentiate. This results in
2888 /// an N-dimensional tensor z.
2889 /// z = exp(x - m)
2890 ///
2891 /// 3. Compute the sum of z along dimension d. This results in
2892 /// an N-1 dimensional tensor l.
2893 /// l = sum(z, dim = d)
2894 ///
2895 /// 4. Divide z and l. This gives the N-dimensional softmax.
2896 /// softmax = z / l
2897 ///
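/// For example (shapes assumed), for a tensor<2x16xf32> input with d = 1 this
/// builds a maxnumf reduction into a tensor<2xf32> accumulator, an
/// all-parallel subtract-and-exponentiate generic, an addf reduction into the
/// same tensor<2xf32> accumulator re-initialized by a second linalg.fill, and
/// an all-parallel divide generic; the reduction accumulators come from a
/// linalg.fill on a tensor.empty.
///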
2898 FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
2899  OpBuilder::InsertionGuard guard(b);
2900  b.setInsertionPoint(*this);
2901  Location loc = getLoc();
2902  Value input = getInput();
2903  ShapedType inputType = getInputOperandType();
2904  Type elementType = inputType.getElementType();
2905  int64_t reductionDim = getDimension();
2906  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
2907  Value output = getOutput();
2908  dims.erase(dims.begin() + reductionDim);
2909  // Step 1: Compute max along dim.
2910  Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
2911  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf,
2912  elementType, b, loc,
2913  /*useOnlyFiniteValue=*/true);
2914  Value neutralForMaxFInit =
2915  b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
2916  .result();
2917  Value max =
2918  reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim);
2919 
2920  // Step 2: Subtract max from input and exponentiate.
2921  Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);
2922 
2923  // Step 3: Compute sum along dim.
2924  Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
2925  b, loc, /*useOnlyFiniteValue=*/true);
2926  Value zeroInit =
2927  b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
2928  Value denominator =
2929  reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);
2930 
2931  // Step 4: Compute softmax.
2932  Value result =
2933  buildDivOp(b, loc, numerator, denominator, output, reductionDim);
2934  return SmallVector<Value>{result};
2935 }
2936 
2937 //===----------------------------------------------------------------------===//
2938 // WinogradFilterTransformOp
2939 //===----------------------------------------------------------------------===//
2940 
2941 LogicalResult WinogradFilterTransformOp::verify() {
2942  auto filterType = cast<ShapedType>(getFilter().getType());
2943  ArrayRef<int64_t> filterShape = filterType.getShape();
2944  int64_t filterH = filterShape[getFilterHDim()];
2945  int64_t filterW = filterShape[getFilterWDim()];
2946  int64_t r = getR();
2947  int64_t m = getM();
2948 
2949  if (filterH != r && filterH != 1)
2950  return emitOpError("expect filter height either equals to r or 1");
2951  if (filterW != r && filterW != 1)
2952  return emitOpError("expect filter width either equals to r or 1");
2953  if (filterH == 1 && filterW == 1)
2954  return emitOpError("expect either filter height or width equals to r");
2955 
2956  SmallVector<int64_t> expectedOutputShape;
2957  expectedOutputShape.push_back(filterH == r ? m + r - 1 : 1);
2958  expectedOutputShape.push_back(filterW == r ? m + r - 1 : 1);
2959  expectedOutputShape.push_back(filterShape[getFilterCDim()]);
2960  expectedOutputShape.push_back(filterShape[getFilterFDim()]);
2961 
2962  auto outputType = cast<ShapedType>(getOutput().getType());
2963  ArrayRef<int64_t> outputShape = outputType.getShape();
2964  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
2965  return emitOpError("the output shape is not expected");
2966  }
2967  return success();
2968 }
2969 
2970 SmallVector<Range>
2971 WinogradFilterTransformOp::getIterationDomain(OpBuilder &builder) {
2972  Location loc = getLoc();
2973  IntegerAttr zeroAttr = builder.getIndexAttr(0);
2974  IntegerAttr oneAttr = builder.getIndexAttr(1);
2975  Value filter = getFilter();
2976  int64_t filterRank = getFilterOperandRank();
2977  SmallVector<Range> loopBounds(filterRank);
2978  for (unsigned dim = 0; dim < filterRank; ++dim) {
2979  loopBounds[dim].offset = zeroAttr;
2980  loopBounds[dim].size = getDimValue(builder, loc, filter, dim);
2981  loopBounds[dim].stride = oneAttr;
2982  }
2983  return loopBounds;
2984 }
2985 
2986 SmallVector<utils::IteratorType>
2987 WinogradFilterTransformOp::getLoopIteratorTypes() {
2988  int64_t filterRank = getFilterOperandRank();
2989  SmallVector<utils::IteratorType> iteratorTypes(filterRank,
2990  utils::IteratorType::parallel);
2991  return iteratorTypes;
2992 }
2993 
2994 LogicalResult WinogradFilterTransformOp::getResultTilePosition(
2995  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2996  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2997  SmallVector<OpFoldResult> &resultSizes) {
2998  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
2999  ShapedType filterType = getFilterOperandType();
3000  ArrayRef<int64_t> filterShape = filterType.getShape();
3001  int64_t filterH = filterShape[getFilterHDim()];
3002  int64_t filterW = filterShape[getFilterWDim()];
3003  int64_t m = getM();
3004  int64_t r = getR();
3005  int64_t alpha = m + r - 1;
3006  int64_t alphaH = filterH != 1 ? alpha : 1;
3007  int64_t alphaW = filterW != 1 ? alpha : 1;
3008  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3009  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3010 
3011  resultOffsets.append(
3012  {zeroAttr, zeroAttr, offsets[getFilterCDim()], offsets[getFilterFDim()]});
3013  resultSizes.append(
3014  {alphaHAttr, alphaWAttr, sizes[getFilterCDim()], sizes[getFilterFDim()]});
3015 
3016  return success();
3017 }
3018 
3019 /// Implements tiling for winograd_filter_transform.
3020 /// The input of winograd_filter_transform is (F, KH, KW, C).
3021 /// The output of winograd_filter_transform is (alphaH, alphaW, C, F).
3022 /// Users can specify the tile sizes of F and C.
3023 /// `offsets` are the values for the offsets of F, KH, KW, C for one tile.
3024 /// `sizes` are the values for the sizes of F, KH, KW, C for one tile.
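///
/// For illustration (an editorial sketch with assumed shapes): with m = 4 and
/// r = 3, alpha = m + r - 1 = 6, so a filter with (F, KH, KW, C) =
/// (2, 3, 3, 5) transforms into an output of shape (6, 6, 5, 2):
///
///   %t = linalg.winograd_filter_transform m(4) r(3)
///          ins(%filter : tensor<2x3x3x5xf32>)
///          outs(%init : tensor<6x6x5x2xf32>) -> tensor<6x6x5x2xf32>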
3025 FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation(
3026  OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3027  ArrayRef<OpFoldResult> sizes) {
3028  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3029  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3030  ShapedType filterType = getFilterOperandType();
3031  ArrayRef<int64_t> filterShape = filterType.getShape();
3032  int64_t filterH = filterShape[getFilterHDim()];
3033  int64_t filterW = filterShape[getFilterWDim()];
3034  IntegerAttr filterHAttr = builder.getI64IntegerAttr(filterH);
3035  IntegerAttr filterWAttr = builder.getI64IntegerAttr(filterW);
3036  SmallVector<Value> tiledOperands;
3037  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3038 
3039  sliceOffsets.append(
3040  {offsets[getFilterFDim()], zeroAttr, zeroAttr, offsets[getFilterCDim()]});
3041  sliceSizes.append({sizes[getFilterFDim()], filterHAttr, filterWAttr,
3042  sizes[getFilterCDim()]});
3043  int64_t filterRank = getFilterOperandRank();
3044  SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr);
3045  Location loc = getLoc();
3046  auto filterSlice = builder.create<tensor::ExtractSliceOp>(
3047  loc, getFilter(), sliceOffsets, sliceSizes, filterStrides);
3048  tiledOperands.emplace_back(filterSlice);
3049 
3050  SmallVector<OpFoldResult> resultOffsets, resultSizes;
3051  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3052  resultSizes)))
3053  return failure();
3054 
3055  int64_t outputRank = getOutputOperandRank();
3056  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3057  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
3058  loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3059  tiledOperands.emplace_back(outputSlice);
3060 
3061  SmallVector<Type> resultTypes;
3062  resultTypes.push_back(tiledOperands[1].getType());
3063  Operation *tiledOp =
3064  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3065 
3066  return TilingResult{
3067  {tiledOp},
3068  SmallVector<Value>(tiledOp->getResults()),
3069  llvm::to_vector(ArrayRef<Operation *>{filterSlice, outputSlice})};
3070 }
3071 
3072 //===----------------------------------------------------------------------===//
3073 // WinogradInputTransformOp
3074 //===----------------------------------------------------------------------===//
3075 
3076 LogicalResult WinogradInputTransformOp::verify() {
3077  auto inputType = cast<ShapedType>(getInput().getType());
3078  ArrayRef<int64_t> inputShape = inputType.getShape();
3079  int64_t inputH = inputShape[getInputHDim()];
3080  int64_t inputW = inputShape[getInputWDim()];
3081  int m = getM();
3082  int r = getR();
3083  int64_t tileSize = m + r - 1;
3084 
3085  auto outputType = cast<ShapedType>(getOutput().getType());
3086  ArrayRef<int64_t> outputShape = outputType.getShape();
3087  bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
3088  bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
3089 
3090  SmallVector<int64_t> expectedOutputShape(6, inputH);
3091  if (ShapedType::isDynamic(inputH)) {
3092  expectedOutputShape[getOutputAlphaHDim()] = tileSize;
3093  expectedOutputShape[getOutputTileHDim()] = ShapedType::kDynamic;
3094  } else {
3095  expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
3096  expectedOutputShape[getOutputTileHDim()] =
3097  leftTransform ? (inputH - (r - 1)) / m : inputH;
3098  }
3099  if (ShapedType::isDynamic(inputW)) {
3100  expectedOutputShape[getOutputAlphaWDim()] = tileSize;
3101  expectedOutputShape[getOutputTileWDim()] = ShapedType::kDynamic;
3102  } else {
3103  expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
3104  expectedOutputShape[getOutputTileWDim()] =
3105  rightTransform ? (inputW - (r - 1)) / m : inputW;
3106  }
3107  expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
3108  expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
3109 
3110  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3111  return emitOpError("unexpected output shape");
3112  }
3113  return success();
3114 }
3115 
3116 SmallVector<Range>
3117 WinogradInputTransformOp::getIterationDomain(OpBuilder &builder) {
3118  Location loc = getLoc();
3119  IntegerAttr zeroAttr = builder.getIndexAttr(0);
3120  IntegerAttr oneAttr = builder.getIndexAttr(1);
3121  Value output = getOutput();
3122  int64_t outputRank = getOutputOperandRank();
3123  SmallVector<Range> loopBounds(outputRank);
3124  for (unsigned dim = 0; dim < outputRank; ++dim) {
3125  loopBounds[dim].offset = zeroAttr;
3126  // alphaH, alphaW, tileH, tileW, N, C
3127  loopBounds[dim].size = getDimValue(builder, loc, output, dim);
3128  loopBounds[dim].stride = oneAttr;
3129  }
3130  return loopBounds;
3131 }
3132 
3133 SmallVector<utils::IteratorType>
3134 WinogradInputTransformOp::getLoopIteratorTypes() {
3135  int64_t outputRank = getOutputOperandRank();
3136  SmallVector<utils::IteratorType> iteratorTypes(outputRank,
3137  utils::IteratorType::parallel);
3138  return iteratorTypes;
3139 }
3140 
3141 LogicalResult WinogradInputTransformOp::getResultTilePosition(
3142  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3143  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3144  SmallVector<OpFoldResult> &resultSizes) {
3145  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3146  ShapedType outputType = getOutputOperandType();
3147  ArrayRef<int64_t> outputShape = outputType.getShape();
3148  int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
3149  int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
3150 
3151  int64_t m = getM();
3152  int64_t r = getR();
3153  int64_t alpha = m + r - 1;
3154  int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
3155  int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
3156 
3157  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3158  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3159 
3160  resultOffsets.append({zeroAttr, zeroAttr, offsets[getOutputTileHDim()],
3161  offsets[getOutputTileWDim()], offsets[getOutputNDim()],
3162  offsets[getOutputCDim()]});
3163  resultSizes.append({alphaHAttr, alphaWAttr, sizes[getOutputTileHDim()],
3164  sizes[getOutputTileWDim()], sizes[getOutputNDim()],
3165  sizes[getOutputCDim()]});
3166 
3167  return success();
3168 }
3169 
3170 /// Implements tiling for winograd_input_transform.
3171 /// The input of winograd_input_transform is (N, H, W, C).
3172 /// The output of winograd_input_transform is (alphaH, alphaW, tileH, tileW,
3173 /// N, C). Users can specify the tile sizes of tileH, tileW, N, and C.
3174 /// `offsets` are the offsets of tileH, tileW, N, C for one tile. `sizes` are
3175 /// the sizes of tileH, tileW, N, C for one tile.
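///
/// For illustration (an editorial sketch with assumed shapes): with m = 4 and
/// r = 3, an input with (N, H, W, C) = (2, 10, 10, 5) yields
/// tileH = tileW = (10 - (3 - 1)) / 4 = 2 and alpha = 6:
///
///   %t = linalg.winograd_input_transform m(4) r(3)
///          ins(%in : tensor<2x10x10x5xf32>)
///          outs(%init : tensor<6x6x2x2x2x5xf32>) -> tensor<6x6x2x2x2x5xf32>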
3176 FailureOr<TilingResult>
3177 WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
3178  ArrayRef<OpFoldResult> offsets,
3179  ArrayRef<OpFoldResult> sizes) {
3180  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3181  int64_t m = getM();
3182  int64_t r = getR();
3183 
3184  ShapedType outputType = getOutputOperandType();
3185  ArrayRef<int64_t> outputShape = outputType.getShape();
3186  int64_t alphaH = outputShape[getOutputAlphaHDim()];
3187  int64_t alphaW = outputShape[getOutputAlphaWDim()];
3188 
3189  Location loc = getLoc();
3190  MLIRContext *context = builder.getContext();
3191  auto identityAffineMap =
3192  AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3193  auto offsetAffineMap =
3194  AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3195  Value mappedOffsetH = affine::makeComposedAffineApply(
3196  builder, loc, (alphaH != 1 ? offsetAffineMap : identityAffineMap),
3197  offsets[getOutputTileHDim()]);
3198  Value mappedOffsetW = affine::makeComposedAffineApply(
3199  builder, loc, (alphaW != 1 ? offsetAffineMap : identityAffineMap),
3200  offsets[getOutputTileWDim()]);
3201  auto sizeAffineMap = AffineMap::get(
3202  1, 0, {builder.getAffineDimExpr(0) * m + (r - 1)}, context);
3203  Value mappedSizeH = affine::makeComposedAffineApply(
3204  builder, loc, sizeAffineMap, sizes[getOutputTileHDim()]);
3205  Value mappedSizeW = affine::makeComposedAffineApply(
3206  builder, loc, sizeAffineMap, sizes[getOutputTileWDim()]);
3207 
3208  SmallVector<Value> tiledOperands;
3209  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3210 
3211  OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3212  OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3213  sliceOffsets.append(
3214  {offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
3215  OpFoldResult sizeH =
3216  alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3217  OpFoldResult sizeW =
3218  alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3219  sliceSizes.append(
3220  {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
3221  int64_t inputRank = getInputOperandRank();
3222  SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr);
3223  auto inputSlice = builder.create<tensor::ExtractSliceOp>(
3224  loc, getInput(), sliceOffsets, sliceSizes, inputStrides);
3225  tiledOperands.emplace_back(inputSlice);
3226 
3227  SmallVector<OpFoldResult> resultOffsets, resultSizes;
3228  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3229  resultSizes)))
3230  return failure();
3231 
3232  int64_t outputRank = getOutputOperandRank();
3233  SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr);
3234  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
3235  loc, getOutput(), resultOffsets, resultSizes, outputStrides);
3236  tiledOperands.emplace_back(outputSlice);
3237 
3238  SmallVector<Type> resultTypes;
3239  resultTypes.push_back(tiledOperands[1].getType());
3240  Operation *tiledOp =
3241  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3242 
3243  return TilingResult{
3244  {tiledOp},
3245  SmallVector<Value>(tiledOp->getResults()),
3246  llvm::to_vector(ArrayRef<Operation *>{inputSlice, outputSlice})};
3247 }
3248 
3249 //===----------------------------------------------------------------------===//
3250 // WinogradOutputTransformOp
3251 //===----------------------------------------------------------------------===//
3252 
3253 LogicalResult WinogradOutputTransformOp::verify() {
3254  auto valueType = cast<ShapedType>(getValue().getType());
3255  ArrayRef<int64_t> valueShape = valueType.getShape();
3256  int64_t valueH = valueShape[getValueAlphaHDim()];
3257  int64_t valueW = valueShape[getValueAlphaWDim()];
3258  int64_t valueTileH = valueShape[getValueTileHDim()];
3259  int64_t valueTileW = valueShape[getValueTileWDim()];
3260  int m = getM();
3261  int r = getR();
3262  bool leftTransform = valueH != 1;
3263  bool rightTransform = valueW != 1;
3264 
3265  int64_t outputRank = getOutputOperandRank();
3266  SmallVector<int64_t> expectedOutputShape(outputRank, valueH);
3267  if (ShapedType::isDynamic(valueH) || ShapedType::isDynamic(valueTileH)) {
3268  expectedOutputShape[getOutputHDim()] = ShapedType::kDynamic;
3269  } else {
3270  if (valueH != (leftTransform ? m + r - 1 : 1))
3271  return emitOpError("expected input height to equal the input tile size");
3272  expectedOutputShape[getOutputHDim()] = (leftTransform ? m : 1) * valueTileH;
3273  }
3274  if (ShapedType::isDynamic(valueW) || ShapedType::isDynamic(valueTileW)) {
3275  expectedOutputShape[getOutputWDim()] = ShapedType::kDynamic;
3276  } else {
3277  if (valueW != (rightTransform ? m + r - 1 : 1))
3278  return emitOpError("expected input width to equal the input tile size");
3279  expectedOutputShape[getOutputWDim()] =
3280  (rightTransform ? m : 1) * valueTileW;
3281  }
3282  expectedOutputShape[getOutputNDim()] = valueShape[getValueNDim()];
3283  expectedOutputShape[getOutputFDim()] = valueShape[getValueFDim()];
3284 
3285  auto outputType = cast<ShapedType>(getOutput().getType());
3286  ArrayRef<int64_t> outputShape = outputType.getShape();
3287  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
3288  return emitOpError("unexpected output shape");
3289  }
3290  return success();
3291 }
3292 
3293 SmallVector<Range>
3294 WinogradOutputTransformOp::getIterationDomain(OpBuilder &builder) {
3295  Location loc = getLoc();
3296  IntegerAttr zeroAttr = builder.getIndexAttr(0);
3297  IntegerAttr oneAttr = builder.getIndexAttr(1);
3298  Value value = getValue();
3299  int64_t valueRank = getValueOperandRank();
3300  SmallVector<Range> loopBounds(valueRank);
3301  for (unsigned dim = 0; dim < valueRank; ++dim) {
3302  loopBounds[dim].offset = zeroAttr;
3303  // alphaH, alphaW, tileH, tileW, N, F
3304  loopBounds[dim].size = getDimValue(builder, loc, value, dim);
3305  loopBounds[dim].stride = oneAttr;
3306  }
3307  return loopBounds;
3308 }
3309 
3310 SmallVector<utils::IteratorType>
3311 WinogradOutputTransformOp::getLoopIteratorTypes() {
3312  int64_t valueRank = getValueOperandRank();
3313  SmallVector<utils::IteratorType> iteratorTypes(valueRank,
3314  utils::IteratorType::parallel);
3315  return iteratorTypes;
3316 }
3317 
3318 LogicalResult WinogradOutputTransformOp::getResultTilePosition(
3319  OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
3320  ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3321  SmallVector<OpFoldResult> &resultSizes) {
3322  int64_t m = getM();
3323 
3324  Location loc = getLoc();
3325  MLIRContext *context = builder.getContext();
3326  auto identityAffineMap =
3327  AffineMap::get(1, 0, {builder.getAffineDimExpr(0)}, context);
3328  auto affineMap =
3329  AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
3330 
3331  ShapedType valueType = getValueOperandType();
3332  ArrayRef<int64_t> valueShape = valueType.getShape();
3333  int64_t valueH = valueShape[0];
3334  int64_t valueW = valueShape[1];
3335  Value mappedOffsetH = affine::makeComposedAffineApply(
3336  builder, loc, (valueH != 1 ? affineMap : identityAffineMap),
3337  offsets[getValueTileHDim()]);
3338  Value mappedOffsetW = affine::makeComposedAffineApply(
3339  builder, loc, (valueW != 1 ? affineMap : identityAffineMap),
3340  offsets[getValueTileWDim()]);
3341  Value mappedSizeH = affine::makeComposedAffineApply(
3342  builder, loc, affineMap, sizes[getValueTileHDim()]);
3343  Value mappedSizeW = affine::makeComposedAffineApply(
3344  builder, loc, affineMap, sizes[getValueTileWDim()]);
3345 
3346  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3347  OpFoldResult offsetH = OpFoldResult(mappedOffsetH);
3348  OpFoldResult offsetW = OpFoldResult(mappedOffsetW);
3349  OpFoldResult sizeH =
3350  valueH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
3351  OpFoldResult sizeW =
3352  valueW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
3353 
3354  resultOffsets.append(
3355  {offsets[getValueNDim()], offsetH, offsetW, offsets[getValueFDim()]});
3356  resultSizes.append(
3357  {sizes[getValueNDim()], sizeH, sizeW, sizes[getValueFDim()]});
3358  return success();
3359 }
3360 
3361 /// Implements tiling for winograd_output_transform.
3362 /// The input of winograd_output_transform is (alphaH, alphaW, tileH, tileW,
3363 /// N, F). The output of winograd_output_transform is (N, H, W, F). Users can
3364 /// specify the tile sizes of tileH, tileW, N, and F. `offsets` are the values
3365 /// for the offsets of tileH, tileW, N, F for one tile. `sizes` are the values
3366 /// for the sizes of tileH, tileW, N, F for one tile.
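///
/// For illustration (an editorial sketch with assumed shapes): with m = 4 and
/// r = 3, a value of shape (alphaH, alphaW, tileH, tileW, N, F) =
/// (6, 6, 2, 2, 2, 2) folds back to an output with H = W = m * 2 = 8:
///
///   %o = linalg.winograd_output_transform m(4) r(3)
///          ins(%v : tensor<6x6x2x2x2x2xf32>)
///          outs(%init : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>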
3367 FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
3368  OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3369  ArrayRef<OpFoldResult> sizes) {
3370  IntegerAttr oneAttr = builder.getI64IntegerAttr(1);
3371  IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
3372  Location loc = getLoc();
3373  SmallVector<Value> tiledOperands;
3374  SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3375 
3376  ShapedType valueType = getValueOperandType();
3377  ArrayRef<int64_t> valueShape = valueType.getShape();
3378  int64_t alphaH = valueShape[getValueAlphaHDim()];
3379  int64_t alphaW = valueShape[getValueAlphaWDim()];
3380  IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
3381  IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
3382 
3383  sliceOffsets.append({zeroAttr, zeroAttr, offsets[getValueTileHDim()],
3384  offsets[getValueTileWDim()], offsets[getValueNDim()],
3385  offsets[getValueFDim()]});
3386  sliceSizes.append({alphaHAttr, alphaWAttr, sizes[getValueTileHDim()],
3387  sizes[getValueTileWDim()], sizes[getValueNDim()],
3388  sizes[getValueFDim()]});
3389  int64_t valueRank = getValueOperandRank();
3390  SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr);
3391  auto valueSlice = builder.create<tensor::ExtractSliceOp>(
3392  loc, getValue(), sliceOffsets, sliceSizes, sliceStrides);
3393  tiledOperands.emplace_back(valueSlice);
3394 
3395  SmallVector<OpFoldResult> resultOffsets, resultSizes;
3396  if (failed(getResultTilePosition(builder, 1, offsets, sizes, resultOffsets,
3397  resultSizes)))
3398  return failure();
3399 
3400  int64_t outputRank = getOutputOperandRank();
3401  SmallVector<OpFoldResult> strides(outputRank, oneAttr);
3402  auto outputSlice = builder.create<tensor::ExtractSliceOp>(
3403  loc, getOutput(), resultOffsets, resultSizes, strides);
3404  tiledOperands.emplace_back(outputSlice);
3405 
3406  SmallVector<Type> resultTypes;
3407  resultTypes.push_back(tiledOperands[1].getType());
3408  Operation *tiledOp =
3409  mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
3410 
3411  return TilingResult{
3412  {tiledOp},
3413  SmallVector<Value>(tiledOp->getResults()),
3414  llvm::to_vector(ArrayRef<Operation *>{valueSlice, outputSlice})};
3415 }
3416 
3417 //===----------------------------------------------------------------------===//
3418 // LinalgDialect
3419 // TODO: Merge with the LinalgDialect block at the bottom
3420 //===----------------------------------------------------------------------===//
3421 
3422 // Returns true if the result expressions of `subMap` are a subset of those of `fullMap`.
3423 static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap) {
3424  auto explicitRange = subMap.getResults();
3425  auto defaultRange = fullMap.getResults();
3426  DenseSet<AffineExpr> explicitSet(explicitRange.begin(), explicitRange.end());
3427  DenseSet<AffineExpr> defaultSet(defaultRange.begin(), defaultRange.end());
3428  llvm::set_union(explicitSet, defaultSet);
3429  return explicitSet == defaultSet;
3430 }
3431 
3432 /// Checks if the user-defined map is a valid broadcast map. Broadcast
3433 /// indexing maps are defined relative to the default indexing maps of the
3434 /// given op, which makes the check simple: just compare the number of result
3435 /// dims.
3436 /// Returns true if `explicitMap` is broadcast with respect to the
3437 /// `defaultMap`.
3438 static bool isBroadcasted(AffineMap explicitMap, AffineMap defaultMap) {
3439  return explicitMap.getNumResults() < defaultMap.getNumResults();
3440 }
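
// For example (editorial illustration, not from the original source): given
// the default matmul LHS map (d0, d1, d2) -> (d0, d2), a user-supplied map
// (d0, d1, d2) -> (d2) has fewer results and is therefore treated as a
// broadcast of the LHS along d0.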
3441 
3442 /// Verifies the broadcast and transpose semantics specified by the explicit
3443 /// indexing map for the MatmulOp \p op for each operand specified by \p
3444 /// opIndex.
3445 static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp,
3446  unsigned opIndex) {
3447  SmallVector<AffineMap, 3> opIndexingMaps = matmulOp.getIndexingMapsArray();
3448  SmallVector<AffineMap, 3> defaultIndexingMaps =
3449  matmulOp.getDefaultIndexingMaps(matmulOp->getContext());
3450 
3451  auto opIndexingMap = opIndexingMaps[opIndex];
3452  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3453  // Check general validity of indexing map results.
3454  if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3455  return matmulOp->emitOpError()
3456  << "Unexpected dim expression in map result.";
3457 
3458  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3459  if (!matmulOp.isValidLhsRhsBroadcastMap(opIndexingMap)) {
3460  return matmulOp->emitOpError()
3461  << "Invalid broadcast requested, should be (d2).";
3462  }
3463  return success();
3464  }
3465  return success();
3466 }
3467 
3468 // Check general validity of input indexing map.
3469 static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp,
3470  AffineMap opIndexingMap,
3471  AffineMap defaultIndexingMap, bool isLHS) {
3472  // Check the result dims are valid.
3473  if (!areResultExprsSubsetOf(opIndexingMap, defaultIndexingMap))
3474  return batchMatmulOp->emitOpError()
3475  << "Unexpected result dim expression (outside the set of default "
3476  "result dims).";
3477 
3478  // Check for valid number of result dims of input maps.
3479  if (opIndexingMap.getNumResults() > 3)
3480  return batchMatmulOp->emitOpError()
3481  << "number of result dim expressions exceeds 3.";
3482 
3483  auto hasValidBatchDim = [](AffineMap map) {
3484  AffineExpr batchDim = map.getResult(0);
3485  return batchDim.isFunctionOfDim(0);
3486  };
3487 
3488  // Check if the requested broadcast is valid.
3489  if (isBroadcasted(opIndexingMap, defaultIndexingMap)) {
3490  if (!batchMatmulOp.isValidLhsRhsBroadcastMap(opIndexingMap, isLHS))
3491  return batchMatmulOp->emitOpError() << "Invalid broadcast requested.";
3492  } else if (!hasValidBatchDim(opIndexingMap)) {
3493  return batchMatmulOp->emitOpError()
3494  << "Invalid batch dimension expression.";
3495  }
3496  return success();
3497 }
3498 
3499 /// This function checks if the given AffineMap for the output of a
3500 /// BatchMatmulOp has exactly 3 result dimensions and if the output map result
3501 /// dimensions are valid.
3502 static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp,
3503  AffineMap opIndexingMap) {
3504  if (opIndexingMap.getNumResults() != 3)
3505  return batchMatmulOp->emitOpError()
3506  << "expects 3 dims, but got (" << opIndexingMap.getNumResults()
3507  << ").";
3508 
3509  auto areValidOutputResultDim = [](AffineMap outputMap) {
3510  return outputMap.getResult(0).isFunctionOfDim(0) &&
3511  outputMap.getResult(1).isFunctionOfDim(1) &&
3512  outputMap.getResult(2).isFunctionOfDim(2);
3513  };
3514 
3515  if (!areValidOutputResultDim(opIndexingMap))
3516  return batchMatmulOp->emitOpError()
3517  << "Invalid output map result dimension.";
3518 
3519  return success();
3520 }
3521 
3522 /// Verifies the broadcast and transpose semantic specified by the explicit
3523 /// indexing map for the BatchMatmulOp op for each operand specified by opIndex.
3524 static LogicalResult
3525 verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp,
3526  unsigned opIndex) {
3527  SmallVector<AffineMap, 3> opIndexingMaps =
3528  batchMatmulOp.getIndexingMapsArray();
3529  SmallVector<AffineMap, 3> defaultIndexingMaps =
3530  batchMatmulOp.getDefaultIndexingMaps(batchMatmulOp->getContext());
3531 
3532  if (opIndexingMaps.size() != 3)
3533  return batchMatmulOp->emitOpError()
3534  << "Indexing_map attribute must have 3 affine maps.";
3535 
3536  auto opIndexingMap = opIndexingMaps[opIndex];
3537  auto defaultIndexingMap = defaultIndexingMaps[opIndex];
3538 
3539  if (opIndex == 2 && failed(verifyOutputMap(batchMatmulOp, opIndexingMap)))
3540  return failure();
3541 
3542  if (failed(verifyInputMaps(batchMatmulOp, opIndexingMap, defaultIndexingMap,
3543  opIndex == 0)))
3544  return failure();
3545 
3546  return success();
3547 }
3548 
3549 namespace mlir {
3550 namespace linalg {
3551 
3552 //===----------------------------------------------------------------------===//
3553 // MatmulOp
3554 //===----------------------------------------------------------------------===//
3555 
3556 /// Returns a list of AffineMaps with the typical matmul indexing characteristics.
3557 SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
3558  AffineExpr d0, d1, d2;
3559  SmallVector<AffineMap> indexingMaps;
3560  bindDims(context, d0, d1, d2);
3561  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d2}, context));
3562  indexingMaps.push_back(AffineMap::get(3, 0, {d2, d1}, context));
3563  indexingMaps.push_back(AffineMap::get(3, 0, {d0, d1}, context));
3564  return indexingMaps;
3565 }
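
// For illustration (an editorial sketch): these defaults encode the access
// pattern C(d0, d1) += A(d0, d2) * B(d2, d1) and print as:
//   affine_map<(d0, d1, d2) -> (d0, d2)>  // LHS
//   affine_map<(d0, d1, d2) -> (d2, d1)>  // RHS
//   affine_map<(d0, d1, d2) -> (d0, d1)>  // output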
3566 
3567 SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
3568  return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
3569  utils::IteratorType::parallel,
3570  utils::IteratorType::reduction};
3571 }
3572 
3573 unsigned MatmulOp::getNumRegionArgs() { return 3; }
3574 
3575 std::string MatmulOp::getLibraryCallName() {
3576  return generateLibraryCallName(getOperation());
3577 }
3578 
3579 bool MatmulOp::hasDynamicIndexingMaps() { return true; }
3580 
3581 /// Check if the op has broadcast and/or transpose semantic. Returns true if
3582 /// the user defined indexing maps are not equal to default map.
3583 bool MatmulOp::hasUserDefinedMaps() {
3584  SmallVector<AffineMap, 3> defaultMaps =
3585  getDefaultIndexingMaps(this->getContext());
3586  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3587  return defaultMaps != explicitMaps;
3588 }
3589 
3590 /// Implements the block region builder for the MatmulOp. This is called by
3591 /// 'fillStructuredOpRegion'.
3592 void MatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3593  ArrayRef<NamedAttribute> attrs) {
3594  assert(block.getNumArguments() == 3 &&
3595  "MatmulOp regionBuilder expects 3 args");
3596  RegionBuilderHelper helper(b, block);
3597  SmallVector<Value> yields;
3598 
3599  TypeFn castVal = TypeFn::cast_signed;
3600  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
3601  return attr.getName() == "cast";
3602  });
3603  if (castIter != attrs.end()) {
3604  if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3605  castVal = attr.getValue();
3606  }
3607 
3608  Value value1 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3609  block.getArgument(0));
3610  Value value2 = helper.buildTypeFn(castVal, block.getArgument(2).getType(),
3611  block.getArgument(1));
3612  Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
3613  Value value4 =
3614  helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
3615  yields.push_back(value4);
3616  helper.yieldOutputs(yields);
3617 }
3618 
3619 /// Returns true if the given broadcast map \p bcastMap is valid for this op.
3620 bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) {
3621  assert(bcastMap.getNumResults() == 1 && "Expected single result dim expr.");
3622  AffineExpr exp = bcastMap.getResult(0);
3623  // The map is invalid if the common (reduction) dimension of the matmul is not found.
3624  return exp.isFunctionOfDim(bcastMap.getNumDims() - 1);
3625 }
3626 
3627 FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) {
3628  if (parser.parseOptionalKeyword("indexing_maps"))
3629  return ArrayAttr{
3630  nullptr}; // Success in case indexing_maps was not provided.
3631 
3632  ArrayAttr arrayAttr;
3633  if (parser.parseEqual() || parser.parseAttribute(arrayAttr))
3634  return failure();
3635 
3636  if (llvm::any_of(arrayAttr,
3637  [](auto elt) { return !dyn_cast<AffineMapAttr>(elt); }))
3638  return parser.emitError(parser.getCurrentLocation())
3639  << "element of indexing_maps array is not an affine_map";
3640 
3641  return arrayAttr;
3642 }
3643 
3644 ParseResult MatmulOp::parse(OpAsmParser &parser, OperationState &result) {
3645  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
3646  if (failed(indexingMapsAttr))
3647  return failure();
3648 
3649  if (*indexingMapsAttr == nullptr) {
3650  auto indexingMapAttrs = llvm::map_to_vector(
3651  MatmulOp::getDefaultIndexingMaps(parser.getContext()),
3652  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
3653  indexingMapsAttr = parser.getBuilder().getArrayAttr(indexingMapAttrs);
3654  }
3655 
3656  result.addAttribute("indexing_maps", *indexingMapsAttr);
3657  return parseNamedStructuredOp(parser, result, MatmulOp::getNumRegionArgs(),
3658  MatmulOp::getRegionBuilder());
3659 }
3660 
3661 void MatmulOp::print(OpAsmPrinter &p) {
3662  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
3663  MatmulOp::getDefaultIndexingMaps(getContext()),
3664  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
3665  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
3666  p << " indexing_maps = [";
3667  llvm::interleaveComma(getIndexingMaps(), p,
3668  [&](Attribute attr) { p.printAttribute(attr); });
3669  p << "]";
3670  }
3671 
3672  SmallVector<StringRef, 3> elidedAttrs = {
3673  "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
3674  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
3675  elidedAttrs);
3676 }
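
// For illustration (an editorial sketch; operand names and shapes are
// assumed): a matmul with a transposed LHS carries non-default maps, which
// the printer above emits explicitly:
//
//   linalg.matmul indexing_maps = [affine_map<(d0, d1, d2) -> (d2, d0)>,
//                                  affine_map<(d0, d1, d2) -> (d2, d1)>,
//                                  affine_map<(d0, d1, d2) -> (d0, d1)>]
//       ins(%a, %b : tensor<3x5xf32>, tensor<3x7xf32>)
//       outs(%c : tensor<5x7xf32>) -> tensor<5x7xf32>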
3677 
3678 /// Verify the user defined indexing maps.
3679 LogicalResult MatmulOp::verify() {
3680  // Verification of pure matmul is handled by verifyStructuredOpInterface().
3681  if (!hasUserDefinedMaps())
3682  return success();
3683 
3684  for (unsigned opIndex = 0; opIndex < 2; opIndex++) {
3685  if (failed(verifyExtendedMatmulSemantic(*this, opIndex)))
3686  return failure();
3687  }
3688  return success();
3689 }
3690 
3691 LogicalResult MatmulOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3692  return memref::foldMemRefCast(*this);
3693 }
3694 
3695 void MatmulOp::getEffects(
3696  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3697  &effects) {
3698  if (hasPureTensorSemantics())
3699  return;
3700  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
3701 }
3702 
3703 Speculation::Speculatability MatmulOp::getSpeculatability() {
3704  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
3705 }
3706 
3707 //===----------------------------------------------------------------------===//
3708 // ContractOp
3709 //===----------------------------------------------------------------------===//
3710 
3711 SmallVector<utils::IteratorType> ContractOp::getIteratorTypesArray() {
3712  AffineMap outAffineMap = getIndexingMapsArray().pop_back_val();
3713  // On well-formed IR, indexing_maps is non-empty, contained affine_maps'
3714  // domains are all the same, and each implements a projected permutation.
3715  // Each iteration space dim must occur for at least one operand and either
3716  // takes part in a contraction/reduction or else has parallel iteration type.
3717  // We have that a dim is a contraction/reduction dim if and only if the dim
3718  // does not occur for the output operand. We use this fact for fast inference:
3719  // NB: In case we allow dims to occur solely for one input, the above still
3720  // holds: per the einsum semantics, these are reduction dims as well.
3721  SmallVector<bool> dimsInOutput(outAffineMap.getNumDims(), false);
3722  for (auto result : outAffineMap.getResults()) {
3723  auto dimExpr = dyn_cast<AffineDimExpr>(result);
3724  assert(dimExpr && "affine_map is a projected permutation");
3725  dimsInOutput[dimExpr.getPosition()] = true;
3726  }
3727 
3728  SmallVector<utils::IteratorType> iteratorTypes;
3729  for (auto dimOccursInOutput : dimsInOutput)
3730  iteratorTypes.push_back(dimOccursInOutput ? utils::IteratorType::parallel
3731  : utils::IteratorType::reduction);
3732 
3733  return iteratorTypes;
3734 }
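
// For example (editorial illustration): with indexing maps
// {(d0, d1, d2) -> (d0, d2), (d0, d1, d2) -> (d2, d1),
//  (d0, d1, d2) -> (d0, d1)}, dims d0 and d1 occur in the output map and are
// parallel, while d2 does not occur in it and is inferred as a reduction dim.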
3735 
3736 unsigned ContractOp::getNumRegionArgs() { return 3; }
3737 
3738 /// Implements the block region builder, which is called by 'fillStructuredOpRegion'.
3739 void ContractOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3740  ArrayRef<NamedAttribute> attrs) {
3741  assert(block.getNumArguments() == 3 &&
3742  "ContractOp regionBuilder expects 3 args");
3743  RegionBuilderHelper helper(b, block);
3744 
3745  TypeFn castSignedness = TypeFn::cast_signed;
3746  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
3747  return attr.getName() == "cast";
3748  });
3749  if (castIter != attrs.end()) {
3750  if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3751  castSignedness = attr.getValue();
3752  }
3753 
3754  // TODO: Support fields with operators besides mult & add.
3755  Type outType = block.getArgument(2).getType();
3756  Value lhsAtOutType =
3757  helper.buildTypeFn(castSignedness, outType, block.getArgument(0));
3758  Value rhsAtOutType =
3759  helper.buildTypeFn(castSignedness, outType, block.getArgument(1));
3760  Value productAtOutType =
3761  helper.buildBinaryFn(BinaryFn::mul, lhsAtOutType, rhsAtOutType);
3762  Value result = helper.buildBinaryFn(BinaryFn::add, block.getArgument(2),
3763  productAtOutType);
3764  helper.yieldOutputs({result});
3765 }
3766 
3767 ParseResult ContractOp::parse(OpAsmParser &parser, OperationState &result) {
3768  FailureOr<ArrayAttr> indexingMapsAttr = parseIndexingMapsAttr(parser);
3769  if (failed(indexingMapsAttr) || *indexingMapsAttr == nullptr)
3770  return parser.emitError(parser.getCurrentLocation(),
3771  "expected 'indexing_maps' attribute");
3772  result.addAttribute("indexing_maps", *indexingMapsAttr);
3773 
3774  return parseNamedStructuredOp(parser, result, getNumRegionArgs(),
3775  regionBuilder);
3776 }
3777 
3778 void ContractOp::print(OpAsmPrinter &p) {
3779  p << " indexing_maps = [";
3780  llvm::interleaveComma(getIndexingMaps(), p,
3781  [&](Attribute attr) { p.printAttribute(attr); });
3782  p << "]";
3783  printNamedStructuredOp(
3784  p, getOperation(), getInputs(), getOutputs(),
3785  /*elidedAttrs=*/{"indexing_maps", "operandSegmentSizes"});
3786 }
3787 
3788 LogicalResult ContractOp::verify() {
3789  int iterationSpaceDims = -1;
3790  // Map iter space dims to #occurrences in inputs' and output's affine_maps:
3791  // e.g., inOccurrences[0] will hold #times that dim (with index) 0 is used to
3792  // access an input operand (so occurrence count can be at most 2) and
3793  // outOccurrences[1] will indicate whether dim 1 occurred in the output, etc.
3794  SmallVector<size_t> inOccurrences;
3795  SmallVector<size_t> outOccurrences;
3796 
3797  // A helper so that for each operand's affine_map and type we check that ...
3798  auto checkAffineMapAndType = [&](AffineMap affineMap, Type operandType,
3799  bool isInput) -> LogicalResult {
3800  // ... the affine_map is a projected permutation;
3801  if (!affineMap.isProjectedPermutation())
3802  return emitError("provided affine_map is not a projected permutation");
3803 
3804  // ... the rank of the affine_map's results and corresponding type match;
3805  if (auto shapedType = dyn_cast<ShapedType>(operandType)) {
3806  if (affineMap.getNumResults() != shapedType.getRank())
3807  return emitError("ranks of shaped operand and results of corresponding "
3808  "affine_map differ");
3809  } else if (affineMap.getNumResults() != 0) {
3810  return emitError("affine_map specifies shaped access while operand has "
3811  "non-shaped type");
3812  }
3813 
3814  // ... the rank of the affine_map's domain is the same as those seen prior;
3815  if (iterationSpaceDims == -1) {
3816  iterationSpaceDims = affineMap.getNumDims();
3817  inOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
3818  outOccurrences = SmallVector<size_t>(iterationSpaceDims, 0);
3819  } else if (iterationSpaceDims != (int)affineMap.getNumDims()) {
3820  return emitError("iteration spaces of provided affine_maps differ");
3821  }
3822 
3823  // ... update counts of dims used to access either an input or the output.
3824  for (AffineExpr affineExpr : affineMap.getResults()) {
3825  auto affineDimExpr = dyn_cast<AffineDimExpr>(affineExpr);
3826  if (!affineDimExpr)
3827  llvm_unreachable("affine_map is a projected permutation");
3828 
3829  if (isInput)
3830  inOccurrences[affineDimExpr.getPosition()] += 1;
3831  else
3832  outOccurrences[affineDimExpr.getPosition()] += 1;
3833  }
3834 
3835  return success();
3836  };
3837 
3838  for (auto &&[affineMap, operandType, isInput] :
3839  llvm::zip(getIndexingMapsArray(), getOperandTypes(),
3840  SmallVector<bool>{true, true, false})) {
3841  if (failed(checkAffineMapAndType(affineMap, operandType, isInput)))
3842  return failure(); // NB: checkAffineMapAndType will emit relevant error.
3843  }
3844 
3845  bool hasContractingDim = false;
3846  for (size_t dimIndex = 0; dimIndex < (size_t)iterationSpaceDims; dimIndex++) {
3847  size_t inOccCount = inOccurrences[dimIndex];
3848  size_t outOccCount = outOccurrences[dimIndex];
3849 
3850  // We have a contracting dim if and only if ...
3851  hasContractingDim |= inOccCount == 2 && outOccCount == 0;
3852 
3853  if (inOccCount == 0 && outOccCount == 0)
3854  return emitError() << "iteration space dim at index " << dimIndex
3855  << " not used to access any operand";
3856 
3857  // NB: We disallow a dim which occurs for only one input operand and not
3858  // for the output. In terms of einsum semantics such dims have a
3859  // sensible meaning - namely an additional reduction per each such dim.
3860  // By contrast, the ContractionOpInterface does not know about this
3861  // iter type - cf. inferContractionDims' supported dim kinds. Similarly,
3862  // while vector.contract's verifier accepts dims of this kind many of
3863  // its lowerings give up on encountering these dims.
3864  // TODO: Remove following once we have comprehensive support for input-only
3865  // reduction dims, at both the linalg- and vector-dialect levels.
3866  if (inOccCount == 1 && outOccCount != 1)
3867  return emitError()
3868  << "iteration space dim at index " << dimIndex
3869  << " is neither a contracting dim nor of parallel iteration type";
3870  }
3871 
3872  if (!hasContractingDim)
3873  return emitError("'indexing_maps' do not specify a contracting dimension");
3874 
3875  return success();
3876 }
3877 
3878 LogicalResult ContractOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
3879  return memref::foldMemRefCast(*this);
3880 }
3881 
3882 void ContractOp::getEffects(
3883  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3884  &effects) {
3885  if (hasPureTensorSemantics())
3886  return;
3887  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
3888 }
3889 
3890 Speculation::Speculatability ContractOp::getSpeculatability() {
3891  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
3892 }
3893 
3894 //===----------------------------------------------------------------------===//
3895 // Implementation of BatchMatmulOp
3896 //===----------------------------------------------------------------------===//
3897 SmallVector<AffineMap>
3898 BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
3899  AffineExpr d0, d1, d2, d3;
3900  SmallVector<AffineMap> indexingMaps;
3901  bindDims(context, d0, d1, d2, d3);
3902  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d3}, context));
3903  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d3, d2}, context));
3904  indexingMaps.push_back(AffineMap::get(4, 0, {d0, d1, d2}, context));
3905  return indexingMaps;
3906 }
3907 
3908 SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
3909  return SmallVector<utils::IteratorType>{
3910  utils::IteratorType::parallel, utils::IteratorType::parallel,
3911  utils::IteratorType::parallel, utils::IteratorType::reduction};
3912 }
3913 
3914 unsigned BatchMatmulOp::getNumRegionArgs() { return 3; }
3915 
3916 std::string BatchMatmulOp::getLibraryCallName() {
3917  return generateLibraryCallName(getOperation());
3918 }
3919 
3920 /// Check if the op has broadcast and/or transpose semantic. Returns true if
3921 /// the user defined indexing maps are not equal to default map.
3922 bool BatchMatmulOp::hasUserDefinedMaps() {
3923  SmallVector<AffineMap, 3> defaultMaps =
3924  getDefaultIndexingMaps(this->getContext());
3925  SmallVector<AffineMap, 3> explicitMaps = getIndexingMapsArray();
3926  return defaultMaps != explicitMaps;
3927 }
3928 
3929 /// Returns true if the given broadcast map bcastMap is valid for this op.
3930 bool BatchMatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS) {
3931  assert(bcastMap.getNumResults() < 3 &&
3932  "Expected less than 3 result dim expr.");
3933  bool isValid = false;
3934  enum Indices { batchPos, mPos, nPos, kPos };
3935  if (bcastMap.getNumResults() == 1) {
3936  AffineExpr exp = bcastMap.getResult(0);
3937  isValid = exp.isFunctionOfDim(kPos);
3938  } else if (bcastMap.getNumResults() == 2) {
3939  AffineExpr exp0 = bcastMap.getResult(0);
3940  AffineExpr exp1 = bcastMap.getResult(1);
3941  isValid = isLHS
3942  ? (exp0.isFunctionOfDim(mPos) && exp1.isFunctionOfDim(kPos))
3943  : (exp0.isFunctionOfDim(kPos) && exp1.isFunctionOfDim(nPos));
3944  }
3945  return isValid;
3946 }
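
// For example (editorial illustration): for the LHS, the map
// (d0, d1, d2, d3) -> (d3) broadcasts both the batch and M dims, and
// (d0, d1, d2, d3) -> (d1, d3) broadcasts only the batch dim, whereas
// (d0, d1, d2, d3) -> (d0, d3) is rejected because its first result must be
// a function of the M dim (d1), not the batch dim.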
3947 
3948 void BatchMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
3949  ArrayRef<NamedAttribute> attrs) {
3950  assert(block.getNumArguments() == 3 &&
3951  "BatchMatmulOp regionBuilder expects 3 args");
3952  RegionBuilderHelper helper(b, block);
3953  SmallVector<Value> yields;
3954 
3955  TypeFn castVal = TypeFn::cast_signed;
3956  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
3957  return attr.getName() == "cast";
3958  });
3959  if (castIter != attrs.end()) {
3960  if (auto attr = llvm::dyn_cast<TypeFnAttr>(castIter->getValue()))
3961  castVal = attr.getValue();
3962  }
3963 
3964  auto toType = block.getArgument(2).getType();
3965  Value castValA = helper.buildTypeFn(castVal, toType, block.getArgument(0));
3966  Value castValB = helper.buildTypeFn(castVal, toType, block.getArgument(1));
3967  Value mulVal = helper.buildBinaryFn(BinaryFn::mul, castValA, castValB);
3968  Value addVal =
3969  helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), mulVal);
3970  yields.push_back(addVal);
3971  helper.yieldOutputs(yields);
3972 }
3973 
3974 ParseResult BatchMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
3975  SmallVector<Attribute, 3> indexingMapsAttr;
3976  Attribute mapAttr;
3977  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
3978  if (parser.parseEqual())
3979  return failure();
3980 
3981  if (parser.parseLSquare())
3982  return failure();
3983 
3984  do {
3985  if (parser.parseAttribute(mapAttr))
3986  return failure();
3987  if (!isa<AffineMapAttr>(mapAttr)) {
3988  return parser.emitError(parser.getCurrentLocation(),
3989  "expected affine map attribute");
3990  }
3991  indexingMapsAttr.push_back(mapAttr);
3992 
3993  if (parser.parseOptionalComma())
3994  break;
3995  } while (true);
3996 
3997  if (parser.parseRSquare())
3998  return failure();
3999  }
4000  // Initialize indexingMaps, if not supplied explicitly.
4001  if (indexingMapsAttr.empty()) {
4002  indexingMapsAttr = llvm::map_to_vector(
4003  BatchMatmulOp::getDefaultIndexingMaps(parser.getContext()),
4004  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4005  }
4006  result.addAttribute("indexing_maps",
4007  parser.getBuilder().getArrayAttr(indexingMapsAttr));
4008 
4009  return ::parseNamedStructuredOp(parser, result,
4010  BatchMatmulOp::getNumRegionArgs(),
4011  BatchMatmulOp::getRegionBuilder());
4012 }
4013 
4014 void BatchMatmulOp::print(OpAsmPrinter &p) {
4015  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
4016  BatchMatmulOp::getDefaultIndexingMaps(getContext()),
4017  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4018  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
4019  p << " indexing_maps = [";
4020  llvm::interleaveComma(getIndexingMaps(), p,
4021  [&](Attribute attr) { p.printAttribute(attr); });
4022  p << "]";
4023  }
4024 
4025  SmallVector<StringRef, 3> elidedAttrs = {
4026  "operandSegmentSizes", "linalg.memoized_indexing_maps", "indexing_maps"};
4027  ::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4028  elidedAttrs);
4029 }
4030 
4031 /// Verify the user defined indexing maps.
4032 LogicalResult BatchMatmulOp::verify() {
4033  // Verification of pure batch_matmul is handled by
4034  // verifyStructuredOpInterface().
4035  if (!hasUserDefinedMaps())
4036  return success();
4037 
4038  for (unsigned opIndex = 0; opIndex < 3; opIndex++) {
4039  if (failed(verifyExtendedBatchMatmulSemantic(*this, opIndex)))
4040  return failure();
4041  }
4042  return success();
4043 }
4044 
4045 LogicalResult BatchMatmulOp::fold(FoldAdaptor,
4046  SmallVectorImpl<OpFoldResult> &) {
4047  return memref::foldMemRefCast(*this);
4048 }
4049 
4050 void BatchMatmulOp::getEffects(
4051  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4052  &effects) {
4053  if (hasPureTensorSemantics())
4054  return;
4055  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4056 }
4057 
4058 Speculation::Speculatability BatchMatmulOp::getSpeculatability() {
4059  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4060 }
4061 
4062 //===----------------------------------------------------------------------===//
4063 // ElementwiseOp
4064 //===----------------------------------------------------------------------===//
4065 //
4066 namespace {
4067 struct ArityGroupAndKind {
4068  // The enum class {Unary, Binary, Ternary, ..}
4069  ElementwiseArityGroup arityGroup;
4070 
4071  // The kind (e.g. `exp` or `add`) belonging to the arity group.
4072  union Kind {
4073  UnaryFn unaryFn;
4074  BinaryFn binaryFn;
4075  TernaryFn ternaryFn;
4076  } kind;
4077 };
4078 
4079 unsigned getArityGroupAsUInt(ElementwiseArityGroup arityGroup) {
4080  return static_cast<unsigned>(arityGroup);
4081 }
4082 } // namespace
4083 
4084 static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind) {
4085  constexpr int lastUnary = static_cast<int>(ElementwiseCaseLimits::LastUnary);
4086  constexpr int lastBinary =
4087  static_cast<int>(ElementwiseCaseLimits::LastBinary);
4088  constexpr int lastTernary =
4089  static_cast<int>(ElementwiseCaseLimits::LastTernary);
4090 
4091  int val = static_cast<int>(kind);
4092  ArityGroupAndKind result;
4093 
4094  if (val < lastUnary) {
4095  result.arityGroup = ElementwiseArityGroup::Unary;
4096  result.kind.unaryFn = static_cast<UnaryFn>(val);
4097  return result;
4098  }
4099  if (val < lastBinary) {
4100  result.arityGroup = ElementwiseArityGroup::Binary;
4101  result.kind.binaryFn = static_cast<BinaryFn>(val - lastUnary);
4102  return result;
4103  }
4104  if (val >= lastTernary) {
4105  llvm_unreachable("unhandled ElementwiseFn");
4106  }
4107  result.arityGroup = ElementwiseArityGroup::Ternary;
4108  result.kind.ternaryFn = static_cast<TernaryFn>(val - lastBinary);
4109  return result;
4110 }
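
// For example (editorial illustration, assuming the enum lays out all unary
// kinds before all binary kinds): a binary kind such as `add` satisfies
// lastUnary <= val < lastBinary, so its BinaryFn is recovered as
// static_cast<BinaryFn>(val - lastUnary).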
4111 
4112 SmallVector<utils::IteratorType> ElementwiseOp::getIteratorTypesArray() {
4113  auto rank = getResultRank();
4114  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
4115 }
4116 
4117 SmallVector<AffineMap>
4118 ElementwiseOp::getDefaultIndexingMaps(unsigned numMaps, unsigned numDims,
4119  MLIRContext *context) {
4120  auto map = AffineMap::getMultiDimIdentityMap(numDims, context);
4121  return SmallVector<AffineMap>(numMaps, map);
4122 }
4123 
4124 ParseResult ElementwiseOp::parse(OpAsmParser &parser, OperationState &result) {
4125  // Expect e.g. `kind = #linalg.elemwise_kind<add>`
4126  Attribute attr;
4127  mlir::linalg::ElementwiseKind elemwiseKindVal;
4128  if (parser.parseKeyword("kind") || parser.parseEqual())
4129  return failure();
4130 
4131  if (succeeded(parser.parseAttribute(attr))) {
4132  auto elemwiseKindAttr = dyn_cast<ElementwiseKindAttr>(attr);
4133  if (!elemwiseKindAttr)
4134  return parser.emitError(parser.getCurrentLocation(),
4135  "expected ElementwiseKind attribute");
4136  elemwiseKindVal = elemwiseKindAttr.getValue();
4137  } else {
4138  return parser.emitError(parser.getCurrentLocation(),
4139  "expected operation 'kind' attribute");
4140  }
4141  result.addAttribute(
4142  "kind", ElementwiseKindAttr::get(parser.getContext(), elemwiseKindVal));
4143 
4144  // Parse optional `indexing_maps`
4145  SmallVector<Attribute, 3> indexingMapsAttr;
4146  Attribute mapAttr;
4147  if (succeeded(parser.parseOptionalKeyword("indexing_maps"))) {
4148  if (parser.parseEqual())
4149  return failure();
4150  if (parser.parseLSquare())
4151  return failure();
4152  do {
4153  if (parser.parseAttribute(mapAttr))
4154  return failure();
4155  if (!isa<AffineMapAttr>(mapAttr))
4156  return parser.emitError(parser.getCurrentLocation(),
4157  "expected affine map attribute");
4158  indexingMapsAttr.push_back(mapAttr);
4159  if (parser.parseOptionalComma())
4160  break;
4161  } while (true);
4162  if (parser.parseRSquare())
4163  return failure();
4164  }
4165  // At this stage of parsing, the only way to infer the number of region
4166  // args is through the op kind, as the input/output tensors are not parsed yet.
4167  auto arityGroupAndKind = getArityGroupAndKind(elemwiseKindVal);
4168  int numRegionArgs =
4169  getArityGroupAsUInt(arityGroupAndKind.arityGroup) + 1 /*output*/;
4170  if (parseNamedStructuredOp(parser, result, numRegionArgs,
4171  ElementwiseOp::getRegionBuilder())) {
4172  return parser.emitError(parser.getCurrentLocation(),
4173  "unable to parse elemwise op");
4174  }
4175 
4176  // Initialize indexingMaps, if not supplied explicitly.
4177  if (indexingMapsAttr.empty()) {
4178  // We need to infer the numDims of the indexing maps from the output
4179  // type which is already parsed by now.
4180  auto resultType = result.operands[result.operands.size() - 1].getType();
4181  auto shapedType = llvm::dyn_cast<ShapedType>(resultType);
4182  if (!shapedType)
4183  return parser.emitError(parser.getCurrentLocation(),
4184  "return type needs to be shaped type");
4185  auto numDims = shapedType.getRank();
4186  indexingMapsAttr = llvm::map_to_vector(
4187  ElementwiseOp::getDefaultIndexingMaps(numRegionArgs, numDims,
4188  parser.getContext()),
4189  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4190  }
4191 
4192  result.addAttribute("indexing_maps",
4193  parser.getBuilder().getArrayAttr(indexingMapsAttr));
4194  return success();
4195 }
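
// For illustration (an editorial sketch; operand names and shapes are
// assumed): the parsed `kind` alone fixes the region arity, so a binary kind
// such as `add` expects two inputs and one output:
//
//   linalg.elementwise kind=#linalg.elementwise_kind<add>
//       ins(%a, %b : tensor<4x8xf32>, tensor<4x8xf32>)
//       outs(%c : tensor<4x8xf32>) -> tensor<4x8xf32>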
4196 
4197 void ElementwiseOp::print(OpAsmPrinter &p) {
4198  p << " kind=";
4199  p.printAttribute(getKindAttr());
4200  SmallVector<StringRef, 3> elidedAttrs = {"operandSegmentSizes", "kind",
4201  "indexing_maps"};
4202  unsigned arity =
4203  getArityGroupAsUInt(getArityGroupAndKind(getKind()).arityGroup);
4204  unsigned numDims = getResultRank();
4205 
4206  SmallVector<Attribute, 3> indexingMaps = llvm::map_to_vector(
4207  ElementwiseOp::getDefaultIndexingMaps(arity + 1 /*output*/, numDims,
4208  getContext()),
4209  [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
4210 
4211  if (!llvm::equal(getIndexingMaps(), indexingMaps)) {
4212  p << " indexing_maps = [";
4213  llvm::interleaveComma(getIndexingMaps(), p,
4214  [&](Attribute attr) { p.printAttribute(attr); });
4215  p << "]";
4216  }
4217 
4218  printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs(),
4219  elidedAttrs);
4220 }
4221 
4222 LogicalResult ElementwiseOp::verify() {
4223  // All necessary checks are done either by
4224  // - EnumAttr (e.g. unknown operation kind)
4225  // - verifyStructuredOpInterface (incorrect map, sizes).
4226  return success();
4227 }
4228 
4229 /// Implements the block region builder for the ElementwiseOp. This is called by
4230 /// 'fillStructuredOpRegion'.
4231 void ElementwiseOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
4232  ArrayRef<NamedAttribute> attrs) {
4233  ElementwiseKind elemwiseKind;
4234  for (auto attr : attrs) {
4235  if (attr.getName() == b.getStringAttr("kind")) {
4236  auto kindAttr = dyn_cast<ElementwiseKindAttr>(attr.getValue());
4237  assert(kindAttr && "op kind attribute incorrectly set");
4238  elemwiseKind = kindAttr.getValue();
4239  break;
4240  }
4241  }
4242 
4243  ArityGroupAndKind groupAndKind = getArityGroupAndKind(elemwiseKind);
4244  auto arityGroup = groupAndKind.arityGroup;
4245  auto kind = groupAndKind.kind;
4246  assert(block.getNumArguments() ==
4247  getArityGroupAsUInt(arityGroup) + 1 /*output*/
4248  && "Elementwise regionBuilder number of block args mismatch");
4249 
4250  RegionBuilderHelper helper(b, block);
4251  SmallVector<Value> yields;
4252  Value result;
4253 
4254  if (arityGroup == ElementwiseArityGroup::Unary) {
4255  result = helper.buildUnaryFn(kind.unaryFn, block.getArgument(0));
4256 
4257  } else if (arityGroup == ElementwiseArityGroup::Binary) {
4258  result = helper.buildBinaryFn(kind.binaryFn, block.getArgument(0),
4259  block.getArgument(1));
4260 
4261  } else if (arityGroup == ElementwiseArityGroup::Ternary) {
4262  result = helper.buildTernaryFn(kind.ternaryFn, block.getArgument(0),
4263  block.getArgument(1), block.getArgument(2));
4264 
4265  } else
4266  assert(false && "found unhandled category in elemwise");
4267 
4268  yields.push_back(result);
4269  helper.yieldOutputs(yields);
4270 }
4271 
4272 LogicalResult ElementwiseOp::fold(FoldAdaptor,
4273  SmallVectorImpl<OpFoldResult> &) {
4274  return memref::foldMemRefCast(*this);
4275 }
4276 
4277 void ElementwiseOp::getEffects(
4278  SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4279  &effects) {
4280  if (hasPureTensorSemantics())
4281  return;
4282  getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
4283 }
4284 
4285 Speculation::Speculatability ElementwiseOp::getSpeculatability() {
4286  return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
4287 }
4288 
4289 //===----------------------------------------------------------------------===//
4290 // PackOp/UnPackOp Common
4291 //===----------------------------------------------------------------------===//
4292 // Given the (potentially) updated packed type, `newPackedTy`, generates an
4293 // updated mixed-tile-sizes attribute. A tile size is updated only
4294 // when:
4295 // * a dim from newPackedTy is static, and
4296 // * the corresponding size from mixedTiles is still dynamic.
4297 // Otherwise, the original tile size is preserved.
4298 // Note - packed-type-dim and mixed-tile-size should always match!
4299 static SmallVector<OpFoldResult>
4300 getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
4301  SmallVector<OpFoldResult> mixedTiles) {
4302  SmallVector<OpFoldResult> newMixedTileSizes;
4303  for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
4304  .getShape()
4305  .take_back(mixedTiles.size()),
4306  mixedTiles)) {
4307  int64_t shape = std::get<0>(it);
4308  if (shape == ShapedType::kDynamic) {
4309  newMixedTileSizes.push_back(std::get<1>(it));
4310  continue;
4311  }
4312 
4313  // If the current result dim is static, update the dynamic mixed-size
4314  // (provided the original value is dynamic).
4315  OpFoldResult tile = std::get<1>(it);
4316  if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
4317  // Already a constant
4318  newMixedTileSizes.push_back(tile);
4319  } else {
4320  assert(getConstantIntValue(tile).value() == shape &&
4321  "tile size and dim size don't match!");
4322  newMixedTileSizes.push_back(
4323  (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
4324  }
4325  }
4326 
4327  return newMixedTileSizes;
4328 }
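
// For example (editorial illustration): if the trailing (tile) dims of
// `newPackedTy` are [8, ?] and `mixedTiles` is [%c8, %sz], the result is
// [8 : index, %sz]: the now-static dim replaces the dynamic SSA size %c8,
// while the still-dynamic pair is preserved unchanged.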
4329 
4330 template <typename OpTy>
4331 static LogicalResult
4332 reifyResultShapesImpl(OpTy op, OpBuilder &builder,
4333  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4334  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4335  "applies to only pack or unpack operations");
4336  int64_t destRank = op.getDestRank();
4337  reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(destRank));
4338  reifiedReturnShapes[0] =
4339  tensor::getMixedSizes(builder, op.getLoc(), op.getDest());
4340  return success();
4341 }
4342 
4343 template <typename OpTy>
4344 static DenseMap<int64_t, OpFoldResult> getDimAndTileMappingImpl(OpTy op) {
4345  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4346  "applies to only pack or unpack operations");
4347  DenseMap<int64_t, OpFoldResult> dimAndTileMapping;
4348  ArrayRef<int64_t> dimsToTile = op.getInnerDimsPos();
4349  SmallVector<OpFoldResult> tiles = op.getMixedTiles();
4350  assert(tiles.size() == dimsToTile.size() &&
4351  "tiles must match indices of dimension to block");
4352  // bind the dimension `i` with the tile factor.
4353  for (auto i : llvm::seq<int64_t>(0, dimsToTile.size()))
4354  dimAndTileMapping[dimsToTile[i]] = tiles[i];
4355  return dimAndTileMapping;
4356 }
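// For illustration (hypothetical values, not from this file): with
// inner_dims_pos = [0, 2] and mixed tiles [8, %t], the returned map is
// {0 -> 8, 2 -> %t}.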
4357 
4358 template <typename OpTy>
4359 static SmallVector<OpFoldResult> getMixedTilesImpl(OpTy op) {
4360  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4361  "applies to only pack or unpack operations");
4362  Builder builder(op);
4363  SmallVector<OpFoldResult> mixedInnerTiles;
4364  unsigned dynamicValIndex = 0;
4365  for (int64_t staticTile : op.getStaticInnerTiles()) {
4366  if (!ShapedType::isDynamic(staticTile))
4367  mixedInnerTiles.push_back(builder.getI64IntegerAttr(staticTile));
4368  else
4369  mixedInnerTiles.push_back(op.getInnerTiles()[dynamicValIndex++]);
4370  }
4371  return mixedInnerTiles;
4372 }
4373 
4374 template <typename OpTy>
4375 static SmallVector<int64_t> getStaticTilesImpl(OpTy op) {
4376  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4377  "applies to only pack or unpack operations");
4378  SmallVector<Value> dynamicTiles;
4379  SmallVector<int64_t> staticTiles;
4380  dispatchIndexOpFoldResults(op.getMixedTiles(), dynamicTiles, staticTiles);
4381  return staticTiles;
4382 }
4383 
4384 /// Returns true if `dimsPos` is invalid. It is invalid when:
4385 /// a) It contains duplicates.
4386 /// b) At least one dimension is out of bounds (a valid `dimPos` is >= 0 and < `rank`).
4387 /// c) The number of elements in `dimsPos` is greater than `rank`.
4388 static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
4389  size_t rank) {
4390  size_t dimsPosSize = dimsPos.size();
4391  if (dimsPosSize > rank)
4392  return true;
4393  DenseSet<int64_t> uniqued(llvm::from_range, dimsPos);
4394  if (dimsPosSize != uniqued.size())
4395  return true;
4396  return llvm::any_of(dimsPos, [rank](int64_t dimPos) {
4397  return dimPos < 0 || dimPos >= static_cast<int64_t>(rank);
4398  });
4399 }
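// For illustration (hypothetical values, not from this file): for rank = 2,
// [0, 0] is invalid (duplicate), [0, 2] is invalid (2 is out of bounds), and
// [1, 0] is valid.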
4400 
4401 /// Returns true if each dimension of `sourceShape` is no larger than the
4402 /// corresponding dimension of `limitShape`; dynamic dimensions always pass.
4403 static bool areAllInBound(ArrayRef<int64_t> sourceShape,
4404  ArrayRef<int64_t> limitShape) {
4405  assert(
4406  sourceShape.size() == limitShape.size() &&
4407  "expected source shape rank, and limit of the shape to have same rank");
4408  return llvm::all_of(
4409  llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
4410  int64_t sourceExtent = std::get<0>(it);
4411  int64_t limit = std::get<1>(it);
4412  return ShapedType::isDynamic(sourceExtent) ||
4413  ShapedType::isDynamic(limit) || sourceExtent <= limit;
4414  });
4415 }
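// For illustration (hypothetical values, not from this file): sourceShape
// [4, ?] is in bound w.r.t. limitShape [8, 2], since 4 <= 8 and a dynamic
// extent always passes.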
4416 
4417 template <typename OpTy>
4418 static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
4419  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4420  "applies to only pack or unpack operations");
4421  Operation *op = packOrUnPack.getOperation();
4422 
4423  // Return true if we have a zero-value tile.
4424  auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
4425  return llvm::any_of(
4426  tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
4427  };
4428 
4429  // Verify tiles. Do not allow zero tiles.
4430  SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
4431  if (hasZeros(mixedTiles))
4432  return op->emitError("invalid zero tile factor");
4433 
4434  // Verify inner_dims_pos and outer_dims_perm.
4435  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
4436  ? packOrUnPack.getSourceType()
4437  : packOrUnPack.getDestType();
4438  size_t unpackedRank = unpackedType.getRank();
4439  ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
4440  ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
4442  return op->emitError("invalid inner_dims_pos vector");
4443  if (isInvalidPackingPosSpecification(outerDimPerm, unpackedRank))
4444  return op->emitError("invalid outer_dims_perm vector");
4445  if (!outerDimPerm.empty() && outerDimPerm.size() != unpackedRank)
4446  return op->emitError("outer_dims_perm must be a permutation or empty");
4447 
4448  // Tiling factors must be less than or equal to the input rank for pack (or
4449  // output rank for unpack), and must match the number of `inner_dims_pos`.
4450  if (mixedTiles.size() > unpackedRank) {
4451  return op->emitError("tiling factors must be less than or equal to the "
4452  "input rank for pack or output rank for unpack");
4453  }
4454  if (mixedTiles.size() != innerDimsPos.size()) {
4455  return op->emitError(
4456  "tiling factors must equal the number of dimensions to tile");
4457  }
4458 
4459  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4460  ? packOrUnPack.getDestType()
4461  : packOrUnPack.getSourceType();
4462  size_t packedRank = packedType.getRank();
4463  // Require output rank to match input rank + number of blocking factors.
4464  size_t expectedPackedRank = unpackedRank + mixedTiles.size();
4465  if (expectedPackedRank != packedRank) {
4466  return op->emitError(
4467  "packed rank != (unpacked rank + num tiling factors), got ")
4468  << packedRank << " != " << expectedPackedRank;
4469  }
4470 
4471  // Verify that the result shape is at least as large as the minimum shape
4472  // expected by the pack operation, and that the output shape
4473  // represents full tiles.
4474  RankedTensorType expectedPackedType = PackOp::inferPackedType(
4475  unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
4476  if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
4477  return op->emitError("the shape of output is not large enough to hold the "
4478  "packed data. Expected at least ")
4479  << expectedPackedType << ", got " << packedType;
4480  }
4481  if (!llvm::all_of(
4482  llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
4483  mixedTiles),
4484  [](std::tuple<int64_t, OpFoldResult> it) {
4485  int64_t shape = std::get<0>(it);
4486  if (Attribute attr =
4487  llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
4488  IntegerAttr intAttr = dyn_cast_or_null<IntegerAttr>(attr);
4489  int64_t staticTileSize = intAttr.getValue().getSExtValue();
4490  return shape == staticTileSize;
4491  }
4492  return ShapedType::isDynamic(shape);
4493  })) {
4494  return op->emitError("mismatch in inner tile sizes specified and shaped of "
4495  "tiled dimension in the packed type");
4496  }
4497  return success();
4498 }
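// A well-formed example that passes this verifier (hypothetical shapes, not
// from this file):
//
// ```mlir
// %pack = linalg.pack %src inner_dims_pos = [0, 1] inner_tiles = [8, 16]
//     into %dst : tensor<32x32xf32> -> tensor<4x2x8x16xf32>
// ```
//
// The packed rank 4 equals the unpacked rank 2 plus 2 tiling factors, and
// the trailing dims [8, 16] match the inner tiles.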
4499 
4500 namespace {
4501 /// Subset of PackOp/UnPackOp fields used to compute the result of applying
4502 /// various permutations to the op.
4503 // TODO: Add linalg.transpose + pack/unpack folding patterns that just reuse
4504 // these. These may or may not become true foldings / canonicalizations
4505 // depending on how aggressive we want to be in automatically folding
4506 // transposes.
4507 struct PackOrUnPackTransposeResult {
4508  SmallVector<int64_t> innerDimsPos;
4509  SmallVector<OpFoldResult> innerTiles;
4510  SmallVector<int64_t> outerDimsPerm;
4511 };
4512 } // namespace
4513 
4514 template <typename OpTy>
4515 static PackOrUnPackTransposeResult
4516 commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp,
4517  ArrayRef<int64_t> innerPermutation,
4518  ArrayRef<int64_t> outerPermutation) {
4519  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4520  "applies to only pack or unpack operations");
4521  assert((!innerPermutation.empty() || !outerPermutation.empty()) &&
4522  "some permutation must be non-empty");
4523  PackOrUnPackTransposeResult metadata;
4524  metadata.innerDimsPos =
4525  SmallVector<int64_t>(packOrUnPackOp.getInnerDimsPos());
4526  metadata.innerTiles =
4527  SmallVector<OpFoldResult>(packOrUnPackOp.getMixedTiles());
4528  int64_t numOuterDims = std::is_same<OpTy, PackOp>::value
4529  ? packOrUnPackOp.getSourceRank()
4530  : packOrUnPackOp.getDestRank();
4531  metadata.outerDimsPerm =
4532  packOrUnPackOp.getOuterDimsPerm().empty()
4533  ? llvm::to_vector(llvm::seq<int64_t>(0, numOuterDims))
4534  : SmallVector<int64_t>(packOrUnPackOp.getOuterDimsPerm());
4535  if (!innerPermutation.empty()) {
4536  assert(innerPermutation.size() == metadata.innerDimsPos.size() &&
4537  isPermutationVector(innerPermutation) &&
4538  "invalid inner permutation");
4539  applyPermutationToVector(metadata.innerDimsPos, innerPermutation);
4540  applyPermutationToVector(metadata.innerTiles, innerPermutation);
4541  }
4542  if (!outerPermutation.empty()) {
4543  assert(outerPermutation.size() == metadata.outerDimsPerm.size() &&
4544  isPermutationVector(outerPermutation) &&
4545  "invalid outer permutation");
4546  applyPermutationToVector(metadata.outerDimsPerm, outerPermutation);
4547  }
4548  return metadata;
4549 }
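// For illustration (hypothetical values, not from this file): with
// innerDimsPos = [0, 1], innerTiles = [8, 16] and innerPermutation = [1, 0],
// the returned metadata has innerDimsPos = [1, 0] and innerTiles = [16, 8].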
4550 
4551 //===----------------------------------------------------------------------===//
4552 // PackOp
4553 //===----------------------------------------------------------------------===//
4554 
4555 void PackOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
4556  setNameFn(getResult(), "pack");
4557 }
4558 
4559 void PackOp::build(OpBuilder &builder, OperationState &state, Value source,
4560  Value dest, ArrayRef<int64_t> innerDimsPos,
4561  ArrayRef<OpFoldResult> innerTiles,
4562  std::optional<Value> paddingValue,
4563  ArrayRef<int64_t> outerDimsPerm) {
4564  assert(innerDimsPos.size() == innerTiles.size() &&
4565  "number of tile sizes specified must match the specified number of "
4566  "original dimensions to be tiled");
4567  SmallVector<int64_t> staticTileSizes;
4568  SmallVector<Value> dynamicTileSizes;
4569  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
4570  build(builder, state, dest.getType(), source, dest,
4571  paddingValue ? *paddingValue : nullptr,
4572  outerDimsPerm.empty() ? nullptr
4573  : builder.getDenseI64ArrayAttr(outerDimsPerm),
4574  builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
4575  builder.getDenseI64ArrayAttr(staticTileSizes));
4576 }
4577 
4578 LogicalResult
4579 PackOp::reifyResultShapes(OpBuilder &builder,
4580  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
4581  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
4582 }
4583 
4584 DenseMap<int64_t, OpFoldResult> PackOp::getDimAndTileMapping() {
4585  return getDimAndTileMappingImpl(*this);
4586 }
4587 
4588 SmallVector<OpFoldResult> PackOp::getMixedTiles() {
4589  return getMixedTilesImpl(*this);
4590 }
4591 
4592 SmallVector<int64_t> PackOp::getStaticTiles() {
4593  return getStaticTilesImpl(*this);
4594 }
4595 
4596 ArrayRef<int64_t> PackOp::getAllOuterDims() {
4597  ShapedType inputType = getSourceType();
4598  int64_t inputRank = inputType.getRank();
4599  return getDestType().getShape().take_front(inputRank);
4600 }
4601 
4602 SmallVector<int64_t> PackOp::getTiledOuterDims() {
4603  auto innerDimsPos = getInnerDimsPos();
4604  auto packedShape = getDestType().getShape();
4605  SmallVector<int64_t> res;
4606 
4607  for (auto index : innerDimsPos)
4608  res.push_back(packedShape[index]);
4609 
4610  return res;
4611 }
4612 
4613 bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
4614  ArrayRef<int64_t> innerDimsPos,
4615  ArrayRef<int64_t> outputShape,
4616  ArrayRef<int64_t> outerDimsPerm,
4617  ArrayRef<OpFoldResult> innerTiles) {
4618  SmallVector<int64_t> outputTileSizes(
4619  outputShape.take_front(inputShape.size()));
4620  if (!outerDimsPerm.empty()) {
4621  assert(outerDimsPerm.size() == outputTileSizes.size() &&
4622  "expected output and outer_dims_perm to have same size");
4623  applyPermutationToVector(outputTileSizes,
4624  invertPermutationVector(outerDimsPerm));
4625  }
4626  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
4627  if (ShapedType::isDynamic(inputShape[pos]))
4628  continue;
4629  std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
4630 
4631  if (!constantTile) {
4632  if (!ShapedType::isDynamic(outputTileSizes[pos]) &&
4633  (inputShape[pos] % outputTileSizes[pos] != 0))
4634  return true;
4635  } else if (inputShape[pos] % (*constantTile) != 0) {
4636  return true;
4637  }
4638  }
4639  return false;
4640 }
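// For illustration (hypothetical values, not from this file): packing a
// dimension of size 31 with tile size 8 leaves a partial tile (31 % 8 != 0),
// so a padding value is required; size 32 with tile 8 does not.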
4641 
4642 LogicalResult PackOp::verify() {
4643  if (failed(commonVerifierPackAndUnPackOp(*this)))
4644  return failure();
4645 
4646  // Verify padding value, and bail out if the tile does not divide the
4647  // dimension fully. In the case of dynamic tile factors or dimensions, having
4648  // a partial tile is undefined behavior.
4649  auto paddingValue = getPaddingValue();
4650  if (paddingValue &&
4651  paddingValue.getType() != getSourceType().getElementType()) {
4652  return emitOpError("expected padding_value has ")
4653  << getSourceType().getElementType()
4654  << " but got: " << paddingValue.getType();
4655  }
4656 
4657  if (!paddingValue &&
4658  requirePaddingValue(getSourceType().getShape(), getInnerDimsPos(),
4659  getDestType().getShape(), getOuterDimsPerm(),
4660  getMixedTiles())) {
4661  return emitOpError(
4662  "invalid tile factor or output size provided. Only full tiles are "
4663  "supported when padding_value is not set");
4664  }
4665  return success();
4666 }
4667 
4668 /// Converts OpFoldResults to int64_t shape entries, unconditionally mapping all
4669 /// Values to kDynamic, even if they are arith.constant values.
4670 static SmallVector<int64_t>
4671 asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
4672  SmallVector<int64_t> result;
4673  for (auto o : ofrs) {
4674  // Have to do this first, as getConstantIntValue special-cases constants.
4675  if (llvm::dyn_cast_if_present<Value>(o))
4676  result.push_back(ShapedType::kDynamic);
4677  else
4678  result.push_back(getConstantIntValue(o).value_or(ShapedType::kDynamic));
4679  }
4680  return result;
4681 }
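// For illustration (hypothetical values, not from this file): [%v, 8, %c4]
// maps to [kDynamic, 8, kDynamic], even when %c4 is produced by an
// arith.constant.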
4682 
4683 /// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
4684 /// the packed type. Having a shared helper helps implement these two methods in
4685 /// a way that ensures that they agree on which dimensions are dynamic.
4686 static SmallVector<int64_t> getPackOpResultTypeShape(
4687  ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
4688  ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
4689  SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
4690  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
4691  if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
4692  continue;
4693  if (ShapedType::isDynamic(innerTileSizes[tiledDim.index()])) {
4694  resultShape[tiledDim.value()] = ShapedType::kDynamic;
4695  continue;
4696  }
4697  resultShape[tiledDim.value()] = llvm::divideCeilSigned(
4698  resultShape[tiledDim.value()], innerTileSizes[tiledDim.index()]);
4699  }
4700 
4701  // Swap tile loops if outer_dims_perm is available.
4702  if (!outerDimsPerm.empty())
4703  applyPermutationToVector(resultShape, outerDimsPerm);
4704 
4705  // Append the inner tile dimensions.
4706  resultShape.append(innerTileSizes.begin(), innerTileSizes.end());
4707  return resultShape;
4708 }
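// For illustration (hypothetical values, not from this file): sourceShape =
// [31, 16] with innerDimsPos = [0], innerTileSizes = [8] and no
// outer_dims_perm gives ceilDiv(31, 8) = 4 for the tiled dim, then the tile
// is appended: [4, 16, 8].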
4709 
4710 SmallVector<OpFoldResult> PackOp::getResultShape(
4711  OpBuilder &builder, Location loc, ArrayRef<OpFoldResult> sourceDims,
4712  ArrayRef<OpFoldResult> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
4713  ArrayRef<int64_t> outerDimsPerm) {
4714  SmallVector<OpFoldResult> resultDims = llvm::to_vector(sourceDims);
4715 
4716  AffineExpr s0, s1;
4717  bindSymbols(builder.getContext(), s0, s1);
4718  AffineExpr ceilDivExpr = s0.ceilDiv(s1);
4719  for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
4720  resultDims[tiledDim.value()] = affine::makeComposedFoldedAffineApply(
4721  builder, loc, ceilDivExpr,
4722  {resultDims[tiledDim.value()], innerTileSizes[tiledDim.index()]});
4723  }
4724  if (!outerDimsPerm.empty())
4725  applyPermutationToVector(resultDims, outerDimsPerm);
4726  resultDims.append(innerTileSizes.begin(), innerTileSizes.end());
4727 
4728  SmallVector<int64_t> resultTypeShape =
4729  getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(sourceDims),
4730  asShapeWithAnyValueAsDynamic(innerTileSizes),
4731  innerDimsPos, outerDimsPerm);
4732 
4733  // Fix up `resultDims` to ensure that they are Values if and only if the
4734  // result type shape says it's a dynamic dim. This is needed as callers may
4735  // use dispatchIndexOpFoldResults on the result, and rely on the exact
4736  // number of dynamic dims returned by that.
4737  for (unsigned i = 0; i < resultDims.size(); ++i) {
4738  if (!ShapedType::isDynamic(resultTypeShape[i]))
4739  continue;
4740  resultDims[i] =
4741  getValueOrCreateConstantIndexOp(builder, loc, resultDims[i]);
4742  }
4743 
4744  return resultDims;
4745 }
4746 
4747 /// Get the expected packed type based on source type, tile factors, position of
4748 /// the inner tiles and permutation of the outer tiled loop.
4749 RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
4750  ArrayRef<int64_t> innerTileSizes,
4751  ArrayRef<int64_t> innerDimsPos,
4752  ArrayRef<int64_t> outerDimsPerm) {
4753  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
4754  sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
4755  return RankedTensorType::get(resultShape, sourceType.getElementType());
4756 }
4757 
4758 Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
4759  ArrayRef<OpFoldResult> innerTileSizes,
4760  ArrayRef<int64_t> innerDimsPos,
4761  ArrayRef<int64_t> outerDimsPerm) {
4762  AffineExpr dim0, dim1;
4763  bindDims(b.getContext(), dim0, dim1);
4764  auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
4765  return affine::makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1),
4766  {v1, v2});
4767  };
4768 
4769  SmallVector<OpFoldResult> mixedSizes;
4770  for (auto [index, value] : llvm::enumerate(
4771  llvm::cast<RankedTensorType>(source.getType()).getShape())) {
4772  if (ShapedType::isDynamic(value))
4773  mixedSizes.push_back(
4774  b.create<tensor::DimOp>(loc, source, index).getResult());
4775  else
4776  mixedSizes.push_back(b.getIndexAttr(value));
4777  }
4778  for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
4779  int64_t dimPos = std::get<0>(it);
4780  OpFoldResult tileSize = std::get<1>(it);
4781  mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
4782  }
4783  if (!outerDimsPerm.empty())
4784  applyPermutationToVector<OpFoldResult>(mixedSizes, outerDimsPerm);
4785 
4786  mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
4787  auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType();
4788  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
4789 }
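// For illustration (hypothetical values, not from this file): for a source
// tensor<30x16xf32> with inner_dims_pos = [0] and tile size 8, this builds
// tensor.empty() : tensor<4x16x8xf32>, since ceilDiv(30, 8) = 4.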
4790 
4791 PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
4792  ArrayRef<int64_t> innerPermutation,
4793  ArrayRef<int64_t> outerPermutation) {
4794  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
4795  *this, innerPermutation, outerPermutation);
4796  Value transposedDest =
4797  createDestinationTensor(b, loc, getSource(), metadata.innerTiles,
4798  metadata.innerDimsPos, metadata.outerDimsPerm);
4799  return b.create<PackOp>(loc, getSource(), transposedDest,
4800  metadata.innerDimsPos, metadata.innerTiles,
4801  getPaddingValue(), metadata.outerDimsPerm);
4802 }
4803 
4804 /// Returns true if the tiles and the tiled dims are constant.
4805 template <typename OpTy>
4806 static bool areTilesAndTiledDimsAllConstant(OpTy op) {
4807  static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
4808  "applies to only pack or unpack operations");
4809  ShapedType packedType = (std::is_same<OpTy, PackOp>::value)
4810  ? op.getDestType()
4811  : op.getSourceType();
4812  SmallVector<OpFoldResult> mixedTiles = op.getMixedTiles();
4813  for (auto [dimDest, tile] : llvm::zip(
4814  packedType.getShape().take_back(mixedTiles.size()), mixedTiles)) {
4815  std::optional<int64_t> constTileSize = getConstantIntValue(tile);
4816  if (!constTileSize || ShapedType::isDynamic(dimDest))
4817  return false;
4818  }
4819  return true;
4820 }
4821 
4822 Speculation::Speculatability PackOp::getSpeculatability() {
4823  if (getPaddingValue())
4824  return Speculation::Speculatable;
4825 
4826  // The verifier already rejects operations where we can statically prove
4827  // that the tile sizes do not divide the dimension evenly; thus, it only
4828  // remains to check that the tiles and the tiled inner dims are constant.
4829  if (!areTilesAndTiledDimsAllConstant(*this))
4830  return Speculation::NotSpeculatable;
4831 
4832  return Speculation::Speculatable;
4833 }
4834 
4835 // Return true if `inner_dims_pos` and `outer_dims_perm` target the same
4836 // dimensions for pack and unpack.
4837 static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp) {
4838  if (packOp.getInnerDimsPos() != unPackOp.getInnerDimsPos())
4839  return false;
4840  if (packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm())
4841  return true;
4842  // Outer dims permutation is optional.
4843  // To compare an unbalanced pack-unpack pair, treat a missing permutation
4844  // as equal to the identity permutation.
4845  return isIdentityPermutation(packOp.getOuterDimsPerm()) &&
4846  isIdentityPermutation(unPackOp.getOuterDimsPerm());
4847 }
4848 
4849 // Return true if pack and unpack have the same tiles.
4850 // Same SSA values or same integer constants.
4851 static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp) {
4852  auto packTiles = packOp.getMixedTiles();
4853  auto unPackTiles = unPackOp.getMixedTiles();
4854  if (packTiles.size() != unPackTiles.size())
4855  return false;
4856  for (size_t i = 0, e = packTiles.size(); i < e; i++) {
4857  if (!isEqualConstantIntOrValue(packTiles[i], unPackTiles[i]))
4858  return false;
4859  }
4860  return true;
4861 }
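// For illustration (hypothetical values, not from this file): tiles
// [%c8, 16] and [%c8, 16] compare equal (same SSA value, equal constant),
// while [8] vs. [%v] cannot be proven equal and so compare as different.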
4862 
4863 /// Returns true if the pack op does not need a padding value.
4864 static bool paddingIsNotNeeded(PackOp op) {
4865  auto srcType = op.getSourceType();
4866  if (llvm::any_of(op.getInnerDimsPos(),
4867  [&](int64_t pos) { return srcType.isDynamicDim(pos); }))
4868  return false;
4869  if (ShapedType::isDynamicShape(op.getStaticInnerTiles()))
4870  return false;
4871  return !PackOp::requirePaddingValue(
4872  srcType.getShape(), op.getInnerDimsPos(), op.getDestType().getShape(),
4873  op.getOuterDimsPerm(), op.getMixedTiles());
4874 }
4875 
4876 /// Returns true if the `srcShape` or `destShape` is different from the one in
4877 /// `packOp` and populates each with the inferred static shape.
4878 static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
4879  SmallVectorImpl<int64_t> &destShape) {
4880  bool changeNeeded = false;
4881  srcShape.assign(packOp.getSourceType().getShape().begin(),
4882  packOp.getSourceType().getShape().end());
4883  destShape.assign(packOp.getDestType().getShape().begin(),
4884  packOp.getDestType().getShape().end());
4885  llvm::SmallSetVector<int64_t, 4> innerDims;
4886  innerDims.insert_range(packOp.getInnerDimsPos());
4887  SmallVector<int64_t> inverseOuterDimsPerm;
4888  if (!packOp.getOuterDimsPerm().empty())
4889  inverseOuterDimsPerm = invertPermutationVector(packOp.getOuterDimsPerm());
4890  int srcRank = packOp.getSourceRank();
4891  for (auto i : llvm::seq<int64_t>(0, srcRank)) {
4892  if (innerDims.contains(i))
4893  continue;
4894  int64_t srcPos = i;
4895  int64_t destPos = i;
4896  if (!inverseOuterDimsPerm.empty())
4897  destPos = inverseOuterDimsPerm[srcPos];
4898  if (ShapedType::isDynamic(srcShape[srcPos]) ==
4899  ShapedType::isDynamic(destShape[destPos])) {
4900  continue;
4901  }
4902  int64_t size = srcShape[srcPos];
4903  if (ShapedType::isDynamic(size))
4904  size = destShape[destPos];
4905  srcShape[srcPos] = size;
4906  destShape[destPos] = size;
4907  changeNeeded = true;
4908  }
4909  return changeNeeded;
4910 }
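// For illustration (hypothetical values, not from this file): with
// inner_dims_pos = [1], source tensor<?x32xf32> and dest tensor<16x4x8xf32>,
// the untiled dim 0 is statically known to be 16 on the dest side, so the
// source shape is inferred as [16, 32] and a tensor.cast can be inserted.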
4911 
4912 LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
4913  // Fold pack(unpack(x)) to x.
4914  if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
4915  if (unPackOp.getSourceType() != packOp.getDestType())
4916  return failure();
4917  if (packOp.getPaddingValue() ||
4918  !hasSameInnerOuterAttribute(packOp, unPackOp) ||
4919  !haveSameTiles(packOp, unPackOp))
4920  return failure();
4921  rewriter.replaceOp(packOp, unPackOp.getSource());
4922  return success();
4923  }
4924 
4925  // Fold optional PaddingValue operand away if padding is not needed.
4926  if (packOp.getPaddingValue() && paddingIsNotNeeded(packOp)) {
4927  rewriter.startOpModification(packOp);
4928  packOp.getPaddingValueMutable().clear();
4929  rewriter.finalizeOpModification(packOp);
4930  return success();
4931  }
4932 
4933  // Insert tensor.cast ops if static shape inference is available.
4934  SmallVector<int64_t> srcShape, destShape;
4935  if (inferStaticShape(packOp, srcShape, destShape)) {
4936  Location loc = packOp.getLoc();
4937  Value source = packOp.getSource();
4938  if (srcShape != packOp.getSourceType().getShape()) {
4939  auto newSrcType = packOp.getSourceType().clone(srcShape);
4940  source =
4941  rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
4942  }
4943  Value dest = packOp.getDest();
4944  RankedTensorType originalResultType = packOp.getDestType();
4945  bool needUpdateDestType = (destShape != originalResultType.getShape());
4946  if (needUpdateDestType) {
4947  auto newDestType = packOp.getDestType().clone(destShape);
4948  dest =
4949  rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
4950  }
4951  rewriter.modifyOpInPlace(packOp, [&] {
4952  packOp.getSourceMutable().assign(source);
4953  packOp.getDestMutable().assign(dest);
4954  packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
4955  });
4956  // Insert a cast if needed
4957  if (needUpdateDestType) {
4958  rewriter.setInsertionPointAfter(packOp);
4959  auto castOp =
4960  rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
4961  rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
4962  }
4963  return success();
4964  }
4965 
4966  return failure();
4967 }
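// For illustration, the pack(unpack(x)) fold above rewrites (hypothetical
// shapes, not from this file):
//
// ```mlir
// %u = linalg.unpack %x inner_dims_pos = [0] inner_tiles = [8]
//     into %e : tensor<4x8xf32> -> tensor<32xf32>
// %p = linalg.pack %u inner_dims_pos = [0] inner_tiles = [8]
//     into %d : tensor<32xf32> -> tensor<4x8xf32>
// ```
//
// into just %x, provided both ops agree on inner/outer dims and tiles and no
// padding value is present.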
4968 
4969 template <typename PackOrUnpackOp>
4970 static bool isLikePadUnPad(PackOrUnpackOp packOp,
4971  RankedTensorType packedTensorType) {
4972  static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
4973  std::is_same<PackOrUnpackOp, UnPackOp>::value,
4974  "Function meant for pack/unpack");
4975  // This is a pad if packing only adds ones and we don't transpose dimensions.
4976 
4977  // Check that we are not transposing any dimensions.
4978  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
4979  int64_t numPackedDims = innerDimsPos.size();
4980  auto orderedDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, numPackedDims));
4981  if (orderedDims != innerDimsPos) {
4982  // Dimensions don't happen in order.
4983  return false;
4984  }
4985 
4986  ArrayRef<int64_t> packedShape = packedTensorType.getShape();
4987  int64_t packedRank = packedTensorType.getRank();
4988  // At this point we know that we are taking numPackedDims outer
4989  // dimensions and pushing them all the way in as the innermost dimensions.
4990  // What's left on the outermost dimensions is, in this order:
4991  // - the factor of the packed dimensions, then
4992  // - the untouched dimensions.
4993  // This shifting inward of dimensions is a no-op (as opposed to a transpose)
4994  // if all the dimensions that bubble outward are ones.
4995  // Therefore check that all the dimensions but the numPackedDims innermost
4996  // ones are ones.
4997  return llvm::all_of(
4998  llvm::seq<int64_t>(0, packedRank - numPackedDims),
4999  [&packedShape](int64_t i) { return packedShape[i] == 1; });
5000 }
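// For illustration (hypothetical shapes, not from this file): packing
// tensor<5x9xf32> with inner_dims_pos = [0, 1] and tiles [8, 16] into
// tensor<1x1x8x16xf32> is like a pad: all outer dims are 1 and no dims are
// transposed.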
5001 
5002 bool PackOp::isLikePad() {
5003  auto packedTensorType =
5004  llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
5005  return isLikePadUnPad(*this, packedTensorType);
5006 }
5007 
5008 OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
5009  std::optional<Attribute> paddingValue;
5010  if (auto pad = adaptor.getPaddingValue())
5011  paddingValue = pad;
5012  if (OpFoldResult reshapedSource = reshapeConstantSource(
5013  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
5014  getDestType(), paddingValue))
5015  return reshapedSource;
5016  return {};
5017 }
5018 
5019 /// Folds a tensor.cast op into a consuming PackOp if the
5020 /// `tensor.cast` has a source that is more static than the consuming op.
5021 ///
5022 /// Example:
5023 /// ```mlir
5024 /// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
5025 /// %2 = tensor.pack %1 ... : tensor<?x?xf32> ...
5026 /// ```
5027 ///
5028 /// folds into:
5029 ///
5030 /// ```mlir
5031 /// %2 = tensor.pack %0 ... : tensor<8x16xf32> ...
5032 /// ```
5033 struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
5034  using OpRewritePattern<PackOp>::OpRewritePattern;
5035 
5036  LogicalResult matchAndRewrite(PackOp op,
5037  PatternRewriter &rewriter) const override {
5038  if (!tensor::hasFoldableTensorCastOperand(op))
5039  return failure();
5040 
5041  SmallVector<Type> newResultTypes(op->getResultTypes());
5042  SmallVector<Value> newOperands =
5043  tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
5044 
5045  // Get the updated mixed-tile-sizes attribute.
5046  SmallVector<OpFoldResult> newMixedTileSizes =
5047  getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());
5048 
5049  // Clone op.
5050  // TODO: Strictly speaking, discardable attributes should be _discarded_ at
5051  // this point. However, in practice, we use them for things that we'd like
5052  // to preserve. Implement a better abstraction.
5053  PackOp newOp = rewriter.create<PackOp>(
5054  op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
5055  newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
5056  newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
5057 
5058  // Replace op.
5059  Value oldResult = op.getResult();
5060  Value newResult = newOp.getResult();
5061  Value replacement = (newResult.getType() != oldResult.getType())
5062  ? rewriter.create<tensor::CastOp>(
5063  op->getLoc(), oldResult.getType(), newResult)
5064  : newResult;
5065 
5066  rewriter.replaceOp(op, {replacement});
5067 
5068  return success();
5069  }
5070 };
5071 
5072 //===----------------------------------------------------------------------===//
5073 // UnPackOp
5074 //===----------------------------------------------------------------------===//
5075 
5076 void UnPackOp::getAsmResultNames(
5077  function_ref<void(Value, StringRef)> setNameFn) {
5078  setNameFn(getResult(), "unpack");
5079 }
5080 
5081 LogicalResult
5082 UnPackOp::reifyResultShapes(OpBuilder &builder,
5083  ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
5084  return reifyResultShapesImpl(*this, builder, reifiedReturnShapes);
5085 }
5086 
5087 DenseMap<int64_t, OpFoldResult> UnPackOp::getDimAndTileMapping() {
5088  return getDimAndTileMappingImpl(*this);
5089 }
5090 
5091 SmallVector<OpFoldResult> UnPackOp::getMixedTiles() {
5092  return getMixedTilesImpl(*this);
5093 }
5094 
5095 SmallVector<int64_t> UnPackOp::getStaticTiles() {
5096  return getStaticTilesImpl(*this);
5097 }
5098 
5099 ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
5100  ShapedType destType = getDestType();
5101  int64_t destRank = destType.getRank();
5102  return getSourceType().getShape().take_front(destRank);
5103 }
5104 
5105 SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
5106  auto innerDimsPos = getInnerDimsPos();
5107  auto packedShape = getSourceType().getShape();
5108  SmallVector<int64_t> res;
5109 
5110  for (auto index : innerDimsPos)
5111  res.push_back(packedShape[index]);
5112 
5113  return res;
5114 }
5115 
5116 LogicalResult UnPackOp::verify() {
5117  return commonVerifierPackAndUnPackOp(*this);
5118 }
5119 
5120 Speculation::Speculatability UnPackOp::getSpeculatability() {
5121  // See PackOp::getSpeculatability.
5122  if (!areTilesAndTiledDimsAllConstant(*this))
5123  return Speculation::NotSpeculatable;
5124 
5125  return Speculation::Speculatable;
5126 }
5127 
5128 void UnPackOp::build(OpBuilder &builder, OperationState &state, Value source,
5129  Value dest, ArrayRef<int64_t> innerDimsPos,
5130  ArrayRef<OpFoldResult> innerTiles,
5131  ArrayRef<int64_t> outerDimsPerm) {
5132  assert(innerDimsPos.size() == innerTiles.size() &&
5133  "number of tile sizes specified must match the specified number of "
5134  "original dimensions to be tiled");
5135  SmallVector<int64_t> staticTileSizes;
5136  SmallVector<Value> dynamicTileSizes;
5137  dispatchIndexOpFoldResults(innerTiles, dynamicTileSizes, staticTileSizes);
5138  build(builder, state, dest.getType(), source, dest,
5139  outerDimsPerm.empty() ? nullptr
5140  : builder.getDenseI64ArrayAttr(outerDimsPerm),
5141  builder.getDenseI64ArrayAttr(innerDimsPos), dynamicTileSizes,
5142  builder.getDenseI64ArrayAttr(staticTileSizes));
5143 }
5144 
5145 Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
5146  Value source,
5147  ArrayRef<OpFoldResult> innerTileSizes,
5148  ArrayRef<int64_t> innerDimsPos,
5149  ArrayRef<int64_t> outerDimsPerm) {
5150  AffineExpr sym0, sym1;
5151  bindSymbols(b.getContext(), sym0, sym1);
5152  auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
5153  return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
5154  };
5155 
5156  SmallVector<OpFoldResult> mixedSizes;
5157  auto srcType = llvm::cast<RankedTensorType>(source.getType());
5158  for (auto i :
5159  llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
5160  if (srcType.isDynamicDim(i))
5161  mixedSizes.push_back(b.create<tensor::DimOp>(loc, source, i).getResult());
5162  else
5163  mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i)));
5164  }
5165  if (!outerDimsPerm.empty()) {
5166  applyPermutationToVector<OpFoldResult>(
5167  mixedSizes, invertPermutationVector(outerDimsPerm));
5168  }
5169 
5170  for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
5171  mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
5172 
5173  auto elemType = srcType.getElementType();
5174  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
5175 }
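// For illustration (hypothetical values, not from this file): un-packing a
// source tensor<4x2x8x16xf32> with inner_dims_pos = [0, 1] builds
// tensor.empty() : tensor<32x32xf32>, since each outer dim is multiplied by
// its tile (4 * 8 = 32, 2 * 16 = 32).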
5176 
5177 UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
5178  Value transposedSource,
5179  ArrayRef<int64_t> innerPermutation,
5180  ArrayRef<int64_t> outerPermutation) {
5181  PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp(
5182  *this, innerPermutation, outerPermutation);
5183  return b.create<UnPackOp>(loc, transposedSource, getDest(),
5184  metadata.innerDimsPos, metadata.innerTiles,
5185  metadata.outerDimsPerm);
5186 }
5187 
5188 /// Returns true if the `srcShape` or `destShape` is different from the one in
5189 /// `op` and populates each with the inferred static shape.
5190 static bool inferStaticShape(UnPackOp op, SmallVectorImpl<int64_t> &srcShape,
5191  SmallVectorImpl<int64_t> &destShape) {
5192  bool changeNeeded = false;
5193  srcShape.assign(op.getSourceType().getShape().begin(),
5194  op.getSourceType().getShape().end());
5195  destShape.assign(op.getDestType().getShape().begin(),
5196  op.getDestType().getShape().end());
5197  llvm::SmallSetVector<int64_t, 4> innerDims;
5198  innerDims.insert_range(op.getInnerDimsPos());
5199  SmallVector<int64_t> inverseOuterDimsPerm;
5200  if (!op.getOuterDimsPerm().empty())
5201  inverseOuterDimsPerm = invertPermutationVector(op.getOuterDimsPerm());
5202  int destRank = op.getDestRank();
5203  for (auto i : llvm::seq<int64_t>(0, destRank)) {
5204  if (innerDims.contains(i))
5205  continue;
5206  int64_t srcPos = i;
5207  int64_t destPos = i;
5208  if (!inverseOuterDimsPerm.empty())
5209  srcPos = inverseOuterDimsPerm[destPos];
5210  if (ShapedType::isDynamic(srcShape[srcPos]) ==
5211  ShapedType::isDynamic(destShape[destPos])) {
5212  continue;
5213  }
5214  int64_t size = srcShape[srcPos];
5215  if (ShapedType::isDynamic(size))
5216  size = destShape[destPos];
5217  srcShape[srcPos] = size;
5218  destShape[destPos] = size;
5219  changeNeeded = true;
5220  }
5221  return changeNeeded;
5222 }
5223 
5224 LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
5225  PatternRewriter &rewriter) {
5226  /// unpack(pack(x)) -> x
5227  if (PackOp packOp = unPackOp.getSource().getDefiningOp<PackOp>()) {
5228  if (packOp.getSourceType() != unPackOp.getDestType())
5229  return failure();
5230  if (packOp.getPaddingValue() ||
5231  !hasSameInnerOuterAttribute(packOp, unPackOp) ||
5232  !haveSameTiles(packOp, unPackOp))
5233  return failure();
5234  rewriter.replaceOp(unPackOp, packOp.getSource());
5235  return success();
5236  }
5237  /// unpack(destinationStyleOp(x)) -> unpack(x)
5238  if (auto dstStyleOp =
5239  unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
5240  auto destValue = cast<OpResult>(unPackOp.getDest());
5241  Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
5242  rewriter.modifyOpInPlace(unPackOp,
5243  [&]() { unPackOp.setDpsInitOperand(0, newDest); });
5244  return success();
5245  }
5246 
5247  // Insert tensor.cast ops if static shape inference is available.
5248  SmallVector<int64_t> srcShape, destShape;
5249  if (inferStaticShape(unPackOp, srcShape, destShape)) {
5250  Location loc = unPackOp.getLoc();
5251  Value source = unPackOp.getSource();
5252  if (srcShape != unPackOp.getSourceType().getShape()) {
5253  auto newSrcType = unPackOp.getSourceType().clone(srcShape);
5254  source = rewriter.create<tensor::CastOp>(loc, newSrcType,
5255  unPackOp.getSource());
5256  }
5257  Value dest = unPackOp.getDest();
5258  if (destShape != unPackOp.getDestType().getShape()) {
5259  auto newDestType = unPackOp.getDestType().clone(destShape);
5260  dest =
5261  rewriter.create<tensor::CastOp>(loc, newDestType, unPackOp.getDest());
5262  }
5263  Value newOp = rewriter.create<UnPackOp>(
5264  loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(),
5265  unPackOp.getOuterDimsPerm());
5266  rewriter.replaceOpWithNewOp<tensor::CastOp>(
5267  unPackOp, unPackOp.getResult().getType(), newOp);
5268  return success();
5269  }
5270 
5271  return failure();
5272 }
5273 
5274 bool UnPackOp::isLikeUnPad() {
5275  RankedTensorType packedTensorType = getSourceType();
5276  return isLikePadUnPad(*this, packedTensorType);
5277 }
5278 
5279 OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
5280  if (OpFoldResult reshapedSource = reshapeConstantSource(
5281  llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
5282  getResult().getType()))
5283  return reshapedSource;
5284  return {};
5285 }
5286 
5287 /// Folds a tensor.cast op into a consuming UnPackOp if the
5288 /// `tensor.cast` has a source that is more static than the consuming op.
5289 ///
5290 /// Example:
5291 /// ```mlir
5292 /// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
5293 /// %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
5294 /// ```
5295 ///
5296 /// folds into:
5297 ///
5298 /// ```mlir
5299 /// %2 = tensor.unpack %0 ... tensor<1x1x8x1xi32> -> tensor<7x?xi32>
5300 /// ```
5301 struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
5302  using OpRewritePattern<UnPackOp>::OpRewritePattern;
5303 
5304  LogicalResult matchAndRewrite(UnPackOp op,
5305  PatternRewriter &rewriter) const override {
5306  if (!tensor::hasFoldableTensorCastOperand(op))
5307  return failure();
5308 
5309  SmallVector<Type> newResultTypes(op->getResultTypes());
5310  SmallVector<Value> newOperands =
5311  tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
5312  Value sourceTensor = newOperands[0];
5313 
5314  // Get the updated mixed-tile-sizes attribute.
5315  SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
5316  rewriter, sourceTensor.getType(), op.getMixedTiles());
5317 
5318  // Clone op.
5319  // TODO: Strictly speaking, discardable attributes should be _discarded_ at
5320  // this point. However, in practice, we use them for things that we'd like
5321  // to preserve. Implement a better abstraction.
5322  UnPackOp newOp = rewriter.create<UnPackOp>(
5323  op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(),
5324  newMixedTileSizes, op.getOuterDimsPerm());
5325  newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
5326 
5327  // Replace op.
5328  Value oldResult = op.getResult();
5329  Value newResult = newOp.getResult();
5330  Value replacement = (newResult.getType() != oldResult.getType())
5331  ? rewriter.create<tensor::CastOp>(
5332  op->getLoc(), oldResult.getType(), newResult)
5333  : newResult;
5334 
5335  rewriter.replaceOp(op, {replacement});
5336 
5337  return success();
5338  }
5339 };
5340 
5341 } // namespace linalg
5342 } // namespace mlir
5343 
5344 //===----------------------------------------------------------------------===//
5345 // LinalgDialect
5346 //===----------------------------------------------------------------------===//
5347 
5348 void LinalgDialect::getCanonicalizationPatterns(
5349  RewritePatternSet &results) const {
5350  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp, FoldTensorCastPackOp,
5351  FoldTensorCastUnPackOp, InferStaticShapeOfOperands>(getContext());
5352 }
5353 
5354 Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
5355  Attribute value, Type type,
5356  Location loc) {
5357  return arith::ConstantOp::materialize(builder, value, type, loc);
5358 }
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult verifyExtendedMatmulSemantic(MatmulOp matmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic sepecified by the explicit indexing map for the MatmulO...
Definition: LinalgOps.cpp:3445
static void buildIdentityRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:1852
static ParseResult parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region, unsigned numRegionArgs, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Definition: LinalgOps.cpp:310
static std::tuple< SmallVector< utils::IteratorType >, SmallVector< AffineMap > > computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank, int64_t dim, bool allParallel=false)
Definition: LinalgOps.cpp:2786
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, Value denominator, Value output, int64_t dim)
Produce a linalg generic that computes the final step of the softmax decomposition.
Definition: LinalgOps.cpp:2857
static bool areResultExprsSubsetOf(AffineMap subMap, AffineMap fullMap)
Definition: LinalgOps.cpp:3423
SmallVector< int64_t > outerDimsPerm
Definition: LinalgOps.cpp:4510
union mlir::linalg::@1183::ArityGroupAndKind::Kind kind
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region, TypeRange inputTypes, TypeRange outputTypes, ArrayRef< NamedAttribute > attrs, RegionBuilderFn regionBuilder)
Fills the region of a structured operation using the provided regionBuilder.
Definition: LinalgOps.cpp:126
static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t)
Definition: LinalgOps.cpp:2324
SmallVector< OpFoldResult > innerTiles
Definition: LinalgOps.cpp:4509
static bool isBroadcasted(AffineMap explictMap, AffineMap defaultMap)
Check if the user defined map is valid broadcast map.
Definition: LinalgOps.cpp:3438
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs, ValueRange outputs)
Definition: LinalgOps.cpp:298
static LogicalResult verifyOutputMap(BatchMatmulOp batchMatmulOp, AffineMap opIndexingMap)
This function checks if the given AffineMap for the output of a BatchMatmulOp has exactly 3 result di...
Definition: LinalgOps.cpp:3502
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser, NamedAttrList &attributes, StringRef attributeName)
Definition: LinalgOps.cpp:1685
static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, ArrayRef< int64_t > attributeValue)
Definition: LinalgOps.cpp:1733
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, Value max, Value output, int64_t dim)
Produce a linalg generic that computes the second step of the softmax decomposition: res = exp(input ...
Definition: LinalgOps.cpp:2831
static void buildMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
Definition: LinalgOps.cpp:188
static LogicalResult verifyExtendedBatchMatmulSemantic(BatchMatmulOp batchMatmulOp, unsigned opIndex)
Verifies the broadcast and transpose semantic specified by the explicit indexing map for the BatchMat...
Definition: LinalgOps.cpp:3525
static Operation * findPayloadOp(Block *body, bool initFirst=false)
Definition: LinalgOps.cpp:1485
static void buildStructuredOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder)
Creates a structured operation given inputs, outputs, and attributes.
Definition: LinalgOps.cpp:160
TernaryFn ternaryFn
Definition: LinalgOps.cpp:4075
static ParseResult parseDstStyleOp(OpAsmParser &parser, OperationState &result, function_ref< ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn=nullptr)
Definition: LinalgOps.cpp:1354
static void buildBatchMatmulOp(OpBuilder &b, OperationState &state, std::optional< TypeRange > resultTensorTypes, ValueRange inputs, ValueRange outputs, ArrayRef< NamedAttribute > attributes, RegionBuilderFn regionBuilder, ArrayRef< AffineMap > indexingMaps)
Definition: LinalgOps.cpp:204
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, int64_t dim)
Definition: LinalgOps.cpp:2808
ElementwiseArityGroup arityGroup
Definition: LinalgOps.cpp:4069
static Speculation::Speculatability getGenericSpeculatabilityImpl(LinalgOp linalgOp)
Definition: LinalgOps.cpp:1245
static void getGenericEffectsImpl(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect >> &effects, LinalgOp linalgOp)
Definition: LinalgOps.cpp:1212
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
Definition: LinalgOps.cpp:2235
SmallVector< int64_t > innerDimsPos
Definition: LinalgOps.cpp:4508
static ParseResult parseNamedStructuredOp(OpAsmParser &parser, OperationState &result, unsigned numRegionArgs, RegionBuilderFn regionBuilder)
Definition: LinalgOps.cpp:336
static void buildGenericRegion(OpBuilder &builder, Location loc, Region &region, ValueRange inputs, ValueRange outputs, function_ref< void(OpBuilder &, Location, ValueRange)> bodyBuild)
Definition: LinalgOps.cpp:987
static ParseResult parseNamedStructuredOpResults(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes)
Definition: LinalgOps.cpp:329
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim)
Return a memref.dim or tensor.dim for the shape of v at dim.
Definition: LinalgOps.cpp:59
UnaryFn unaryFn
Definition: LinalgOps.cpp:4073
static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, const OperationName &payloadOpName, const NamedAttrList &payloadOpAttrs, ArrayRef< Value > operands, bool initFirst=false)
Definition: LinalgOps.cpp:1408
void printShortForm(OpAsmPrinter &p, Operation *payloadOp)
Definition: LinalgOps.cpp:1514
static void printNamedStructuredOpResults(OpAsmPrinter &p, TypeRange resultTypes)
Definition: LinalgOps.cpp:366
static LogicalResult verifyInputMaps(BatchMatmulOp batchMatmulOp, AffineMap opIndexingMap, AffineMap defaultIndexingMap, bool isLHS)
Definition: LinalgOps.cpp:3469
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op, ValueRange inputs, ValueRange outputs, ArrayRef< StringRef > elidedAttrs={})
Definition: LinalgOps.cpp:373
BinaryFn binaryFn
Definition: LinalgOps.cpp:4074
static ParseResult parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, SmallVectorImpl< Type > &inputTypes, SmallVectorImpl< Type > &outputTypes, bool addOperandSegmentSizes=true)
Common parsing used for both named structured ops created by ods-gen and by manually defined C++ ops.
Definition: LinalgOps.cpp:224
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:187
static LogicalResult getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult, TilingInterface op, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, SmallVector< OpFoldResult > &resultOffset, SmallVector< OpFoldResult > &resultSize, const scf::SCFTilingOptions &options)
static FailureOr< TilingResult > getTiledImplementation(RewriterBase &rewriter, TilingInterface op, ValueRange regionIterArg, ArrayRef< OpFoldResult > offsets, ArrayRef< OpFoldResult > sizes, const scf::SCFTilingOptions &options)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
Base type for affine expression.
Definition: AffineExpr.h:68
bool isFunctionOfDim(unsigned position) const
Return true if the affine expression involves AffineDimExpr position.
Definition: AffineExpr.cpp:316
AffineExpr ceilDiv(uint64_t v) const
Definition: AffineExpr.cpp:968
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition: AffineMap.h:46
AffineMap dropResults(ArrayRef< int64_t > positions) const
Definition: AffineMap.h:299
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
bool isProjectedPermutation(bool allowZeroInResults=false) const
Returns true if the AffineMap represents a subset (i.e.
Definition: AffineMap.cpp:618
unsigned getNumDims() const
Definition: AffineMap.cpp:394
ArrayRef< AffineExpr > getResults() const
Definition: AffineMap.cpp:407
unsigned getNumResults() const
Definition: AffineMap.cpp:402
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:411
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:264
@ Paren
Parens surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual ParseResult parseOptionalLBrace()=0
Parse a { token if present.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
virtual void printAttribute(Attribute attr)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
unsigned getNumArguments()
Definition: Block.h:128
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:155
OpListType & getOperations()
Definition: Block.h:137
BlockArgListType getArguments()
Definition: Block.h:87
Operation & front()
Definition: Block.h:153
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:159
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:163
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:383
IntegerAttr getI64IntegerAttr(int64_t value)
Definition: Builders.cpp:108
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:258
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:360
MLIRContext * getContext() const
Definition: Builders.h:56
Location getUnknownLoc()
Definition: Builders.cpp:27
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:262
IndexType getIndexType()
Definition: Builders.cpp:51
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:314
IRValueT get() const
Return the current value being used by this operand.
Definition: UseDefLists.h:160
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:784
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void assign(const_iterator inStart, const_iterator inEnd)
Replaces the attributes with new list of attributes.
ArrayRef< NamedAttribute > getAttrs() const
Return all of the attributes on this operation.
DictionaryAttr getDictionary(MLIRContext *context) const
Return a dictionary attribute for the underlying dictionary.
Attribute get(StringAttr name) const
Return the specified attribute if present, null otherwise.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
Attribute set(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
NamedAttribute represents a combination of a name and an Attribute value.
Definition: Attributes.h:207
StringAttr getName() const
Return the name of the attribute.
Definition: Attributes.cpp:55
Attribute getValue() const
Return the value of the attribute.
Definition: Attributes.h:222
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual FailureOr< OperationName > parseCustomOperationName()=0
Parse the name of an operation, in the custom form.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void increaseIndent()=0
Increase indentation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void decreaseIndent()=0
Decrease indentation.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:518
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition: Builders.h:410
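A minimal sketch of the insertion-point discipline above, assuming builder, loc, and a Region &region are in scope:
{
  // The guard restores the saved insertion point when it goes out of scope.
  OpBuilder::InsertionGuard guard(builder);
  // createBlock() also moves the insertion point to the end of the new block.
  Block *block = builder.createBlock(&region);
  builder.setInsertionPointToStart(block);
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  (void)zero;
}
// The original insertion point is active again here.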
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
This class represents an operand of an operation.
Definition: Value.h:267
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition: Value.cpp:216
This is a value defined by a result of an operation.
Definition: Value.h:457
unsigned getResultNumber() const
Returns the number of this result.
Definition: Value.h:469
StringRef getStringRef() const
Return the name of this operation. This always succeeds.
std::optional< RegisteredOperationName > getRegisteredInfo() const
If this operation is registered, returns the registered information, std::nullopt otherwise.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
result_iterator result_begin()
Definition: Operation.h:413
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
unsigned getNumOperands()
Definition: Operation.h:346
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:268
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_iterator result_end()
Definition: Operation.h:414
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns a range over the underlying operand Values.
Definition: Operation.h:378
result_range getResults()
Definition: Operation.h:415
void setDiscardableAttrs(DictionaryAttr newAttrs)
Set the discardable attribute dictionary on this operation.
Definition: Operation.h:523
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
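A short sketch exercising the Operation accessors above, assuming op is an Operation *:
for (Value operand : op->getOperands())
  llvm::outs() << "operand type: " << operand.getType() << "\n";
for (OpResult result : op->getResults())
  llvm::outs() << "result #" << result.getResultNumber() << "\n";
// emitError() reports through any attached diagnostic handlers.
if (op->getNumResults() == 0 && op->getNumOperands() == 0)
  op->emitError("expected at least one operand or result");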
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:803
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
iterator end()
Definition: Region.h:56
Block & emplaceBlock()
Definition: Region.h:46
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:865
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:412
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
Definition: PatternMatch.h:736
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
virtual void finalizeOpModification(Operation *op)
This method is used to signal the end of an in-place modification of the given operation.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void replaceAllUsesExcept(Value from, Value to, Operation *exceptedUser)
Find uses of from and replace them with to except if the user is exceptedUser.
Definition: PatternMatch.h:720
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:648
virtual void startOpModification(Operation *op)
This method is used to notify the rewriter that an in-place operation modification is about to happen.
Definition: PatternMatch.h:632
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification.
Definition: PatternMatch.h:554
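A sketch of a rewrite pattern wired through the hooks above. FoldIdentityTranspose is a hypothetical pattern name; it assumes an identity linalg.transpose simply forwards its input:
struct FoldIdentityTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(linalg::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    if (!isIdentityPermutation(op.getPermutation()))
      return rewriter.notifyMatchFailure(op, "not an identity permutation");
    // Replace all uses of the transpose result with the untransposed input.
    rewriter.replaceOp(op, op.getInput());
    return success();
  }
};
// Registration: patterns.add<FoldIdentityTranspose>(context);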
This class represents a specific instance of an effect.
static DerivedEffect * get()
Returns a unique instance for the derived effect class.
static DefaultResource * get()
Returns a unique instance for the given effect class.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
bool isSignlessIntOrIndexOrFloat() const
Return true if this is a signless integer, index, or float type.
Definition: Types.cpp:104
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
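A small sketch of the Value accessors above, assuming v is a Value:
if (Operation *def = v.getDefiningOp()) {
  // v is an OpResult; block arguments have no defining op.
  llvm::outs() << "defined by " << def->getName().getStringRef() << " at "
               << v.getLoc() << "\n";
  if (v.hasOneUse())
    llvm::outs() << "single use\n";
}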
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type below.
static Attribute parse(AsmParser &parser, Type type)
Parse the short form [42, 100, -1] without any type prefix.
constexpr auto RecursivelySpeculatable
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interface.
constexpr auto Speculatable
constexpr auto NotSpeculatable
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying those operands.
Definition: AffineOps.cpp:1158
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Constructs an AffineApplyOp that applies map to operands after composing the map with the maps of any other AffineApplyOp supplying the operands.
Definition: AffineOps.cpp:1208
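A sketch of composing and folding an affine expression over OpFoldResult operands, e.g. a ceil-division that counts tiles; size and tile are assumed to be OpFoldResults in scope:
AffineExpr s, t;
bindDims(b.getContext(), s, t);
// numTiles = ceildiv(size, tile), written as (s + t - 1) floordiv t.
AffineMap map =
    AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, (s + t - 1).floorDiv(t));
// Folds to an IntegerAttr when both operands are static; otherwise
// materializes a (composed) affine.apply.
OpFoldResult numTiles =
    makeComposedFoldedAffineApply(b, loc, map, {size, tile});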
Value getIdentityValue(AtomicRMWKind op, Type resultType, OpBuilder &builder, Location loc, bool useOnlyFiniteValue=false)
Returns the identity value associated with an AtomicRMWKind op.
Definition: ArithOps.cpp:2651
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
static LogicalResult reifyResultShapesImpl(OpTy op, OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Definition: LinalgOps.cpp:4332
static bool inferStaticShape(PackOp packOp, SmallVectorImpl< int64_t > &srcShape, SmallVectorImpl< int64_t > &destShape)
Returns true if the srcShape or destShape is different from the ones in packOp and populates each with the inferred static shape.
Definition: LinalgOps.cpp:4878
static bool isLikePadUnPad(PackOrUnpackOp packOp, RankedTensorType packedTensorType)
Definition: LinalgOps.cpp:4970
static SmallVector< int64_t > asShapeWithAnyValueAsDynamic(ArrayRef< OpFoldResult > ofrs)
Converts OpFoldResults to int64_t shape entries, unconditionally mapping all Values to kDynamic, even if they are arith.constant values.
Definition: LinalgOps.cpp:4671
static bool isInvalidPackingPosSpecification(ArrayRef< int64_t > dimsPos, size_t rank)
Returns true if dimsPos is invalid.
Definition: LinalgOps.cpp:4388
static bool areAllInBound(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > limitShape)
Returns true if each dimension of sourceShape is smaller than or equal to the corresponding dimension of limitShape.
Definition: LinalgOps.cpp:4403
static SmallVector< int64_t > getStaticTilesImpl(OpTy op)
Definition: LinalgOps.cpp:4375
static SmallVector< int64_t > getPackOpResultTypeShape(ArrayRef< int64_t > sourceShape, ArrayRef< int64_t > innerTileSizes, ArrayRef< int64_t > innerDimsPos, ArrayRef< int64_t > outerDimsPerm)
Helper for PackOp::{getResultShape,inferPackedType}.
Definition: LinalgOps.cpp:4686
SmallVector< AffineExpr, 4 > concat(ArrayRef< AffineExpr > a, ArrayRef< AffineExpr > b)
Return the vector that is the concatenation of a and b.
Definition: LinalgOps.cpp:2316
static ArityGroupAndKind getArityGroupAndKind(ElementwiseKind kind)
Definition: LinalgOps.cpp:4084
static PackOrUnPackTransposeResult commonPermutationOfPackAndUnPackOp(OpTy packOrUnPackOp, ArrayRef< int64_t > innerPermutation, ArrayRef< int64_t > outerPermutation)
Definition: LinalgOps.cpp:4516
OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:106
std::string generateLibraryCallName(Operation *op)
Returns the mangled library call name used to disambiguate between different overloads at the C level.
Definition: LinalgOps.cpp:2357
static bool paddingIsNotNeeded(PackOp op)
Returns true if the pack op does not need a padding value.
Definition: LinalgOps.cpp:4864
AffineMap extractOrIdentityMap(std::optional< AffineMap > maybeMap, unsigned rank, MLIRContext *context)
Returns maybeMap.get() if maybeMap is set, otherwise returns the symbol-less identity map of rank.
Definition: LinalgOps.cpp:2296
SmallVector< AffineExpr, 4 > makeAffineDimExprs(unsigned num, unsigned &startIdx, MLIRContext *context)
Returns num AffineDimExpr dimensions at positions [startIdx, startIdx + num) and increments startIdx to startIdx + num.
Definition: LinalgOps.cpp:2307
static DenseMap< int64_t, OpFoldResult > getDimAndTileMappingImpl(OpTy op)
Definition: LinalgOps.cpp:4344
static SmallVector< OpFoldResult > getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, SmallVector< OpFoldResult > mixedTiles)
Definition: LinalgOps.cpp:4300
static bool haveSameTiles(PackOp packOp, UnPackOp unPackOp)
Definition: LinalgOps.cpp:4851
Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim)
Create one memref::DimOp or tensor::DimOp depending on the type of val.
Definition: LinalgOps.cpp:97
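A sketch of the dim helpers above, assuming val is a tensor or memref value and i a dimension index:
// Static sizes fold to an IndexAttr; only dynamic sizes create a DimOp.
OpFoldResult size = createFoldedDimOp(b, loc, val, /*dim=*/i);
// Materialize a Value when one is required (e.g. as a loop bound).
Value sizeValue = getValueOrCreateConstantIndexOp(b, loc, size);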
static bool hasSameInnerOuterAttribute(PackOp packOp, UnPackOp unPackOp)
Definition: LinalgOps.cpp:4837
bool areTilesAndTiledDimsAllConstant(OpTy op)
Returns true if the tiles and the tiled dims are constant.
Definition: LinalgOps.cpp:4806
static SmallVector< OpFoldResult > getMixedTilesImpl(OpTy op)
Definition: LinalgOps.cpp:4359
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack)
Definition: LinalgOps.cpp:4418
FailureOr< ArrayAttr > parseIndexingMapsAttr(OpAsmParser &parser)
Definition: LinalgOps.cpp:3627
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition: MemRefOps.cpp:45
Kind
An enumeration of the kinds of predicates.
Definition: Predicate.h:44
DynamicAPInt floor(const Fraction &f)
Definition: Fraction.h:77
DynamicAPInt ceil(const Fraction &f)
Definition: Fraction.h:79
DynamicAPInt round(const Fraction &f)
Definition: Fraction.h:136
Fraction abs(const Fraction &f)
Definition: Fraction.h:107
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
SparseTensorEncodingAttr getSparseTensorEncoding(Type type)
Convenience method to get a sparse encoding attribute from a type.
uint64_t getM(LevelType lt)
Definition: Enums.h:443
bool hasFoldableTensorCastOperand(Operation *op)
Return true if any of the operands of op is a CastOp that can be folded into its consumer.
Definition: TensorOps.cpp:359
bool canFoldIntoProducerOp(CastOp castOp)
Determines whether the tensor::CastOp casts to a more static version of the source tensor.
Definition: TensorOps.cpp:352
SmallVector< Value > getUpdatedOperandsAfterCastOpFolding(DestinationStyleOpInterface op, SmallVector< Type > &newResTy)
Assuming that op contains at least one operand that is a foldable CastOp, returns the updated operands with the casts folded away and populates newResTy with the corresponding result types.
Definition: TensorOps.cpp:368
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition: TensorOps.cpp:70
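A sketch of the "someop(tensor.cast) -> someop" folding shape built from the cast-folding helpers above, assuming op implements DestinationStyleOpInterface:
if (tensor::hasFoldableTensorCastOperand(op)) {
  SmallVector<Type> newResultTypes;
  SmallVector<Value> newOperands =
      tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
  // A pattern would now clone `op` with the new operands/result types and
  // cast the results back to the original types where users require them.
}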
Include the generated interface declarations.
bool isConstantIntValue(OpFoldResult ofr, int64_t value)
Return true if ofr is a constant integer equal to value.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition: Utils.cpp:239
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2)
Return true if ofr1 and ofr2 are the same integer constant attribute values or the same SSA value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. size(exprs)).
Definition: AffineExpr.h:348
SmallVector< T > applyPermutation(ArrayRef< T > input, ArrayRef< int64_t > permutation)
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
AffineMap inversePermutation(AffineMap map)
Returns a map of codomain to domain dimensions such that the first codomain dimension for a particular domain dimension is selected.
Definition: AffineMap.cpp:791
Attribute parseAttribute(llvm::StringRef attrStr, MLIRContext *context, Type type={}, size_t *numRead=nullptr, bool isKnownNullTerminated=false)
This parses a single MLIR attribute into the given MLIR context if it is valid.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
void bindSymbols(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to SymbolExpr at positions: [0 .. size(exprs)).
Definition: AffineExpr.h:362
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:112
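A sketch of the usual OpFoldResult handling split, assuming ofr, b, and loc are in scope:
if (std::optional<int64_t> cst = getConstantIntValue(ofr)) {
  // Statically known: no IR needs to be created.
  llvm::outs() << "constant extent: " << *cst << "\n";
} else {
  // Dynamic: materialize an index Value at the current insertion point.
  Value v = getValueOrCreateConstantIndexOp(b, loc, ofr);
  (void)v;
}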
Operation * clone(OpBuilder &b, Operation *op, TypeRange newResultTypes, ValueRange newOperands)
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes and sinking them under each of the targets.
Definition: Utils.cpp:1297
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
LogicalResult verifyCompatibleShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2)
Returns success if the given two shapes are compatible.
SetVector< Operation * > getSlice(Operation *op, const BackwardSliceOptions &backwardSliceOptions={}, const ForwardSliceOptions &forwardSliceOptions={})
Iteratively computes backward slices and forward slices until a fixed point is reached.
void applyPermutationToVector(SmallVector< T, N > &inVec, ArrayRef< int64_t > permutation)
Apply the permutation defined by permutation to inVec.
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:621
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:419
SmallVector< int64_t > dropDims(ArrayRef< int64_t > inputPerm, ArrayRef< int64_t > dropPositions)
Returns a permutation vector that drops the input dims in dropPositions from inputPerm.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.
SmallVector< int64_t > invertPermutationVector(ArrayRef< int64_t > permutation)
Helper method to invert a permutation.
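A small worked example of the permutation helpers above:
SmallVector<int64_t> perm = {2, 0, 1};
assert(isPermutationVector(perm) && "expected a valid permutation");
// inverse[perm[i]] == i, so inverse == {1, 2, 0}.
SmallVector<int64_t> inverse = invertPermutationVector(perm);
SmallVector<StringRef> labels = {"a", "b", "c"};
// Element i of the result is labels[perm[i]]: {"c", "a", "b"}.
applyPermutationToVector(labels, perm);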
Fold transpose with transpose.
Definition: LinalgOps.cpp:1995
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:1998
This pattern canonicalizes transpose by swapping the order of broadcast and transpose: transpose(broadcast(input)) -> broadcast(transpose(input)).
Definition: LinalgOps.cpp:2020
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:2023
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:379
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:368
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
Attribute propertiesAttr
This Attribute is used to opaquely construct the properties of the operation.
Region * addRegion()
Create a region that should be attached to the operation.
Container for result values of tiling.
Folds a tensor.cast op into a consuming PackOp if the tensor.cast has a source that is more static than the consumer.
Definition: LinalgOps.cpp:5033
LogicalResult matchAndRewrite(PackOp op, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:5036
Folds a tensor.cast op into a consuming UnPackOp if the tensor.cast has a source that is more static than the consumer.
Definition: LinalgOps.cpp:5301
LogicalResult matchAndRewrite(UnPackOp op, PatternRewriter &rewriter) const override
Definition: LinalgOps.cpp:5304