Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions src/Lean/Meta/Sym/Arith/Ring/DenoteExpr.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Meta.Sym.Arith.Ring.Functions
public section
namespace Lean.Meta.Sym.Arith.Ring
/-!
Helper functions for converting reified terms back into their denotations.
-/

variable [Monad M] [MonadError M] [MonadLiftT MetaM M] [MonadCanon M] [MonadRing M]

def denoteNum (k : Int) : M Expr := do
let ring ← getRing
let n := mkRawNatLit k.natAbs
let ofNatInst ← if let some inst ← MonadCanon.synthInstance? (mkApp2 (mkConst ``OfNat [ring.u]) ring.type n) then
pure inst
else
pure <| mkApp3 (mkConst ``Grind.Semiring.ofNat [ring.u]) ring.type ring.semiringInst n
let n := mkApp3 (mkConst ``OfNat.ofNat [ring.u]) ring.type n ofNatInst
if k < 0 then
return mkApp (← getNegFn) n
else
return n

def denotePower (pw : Power) : M Expr := do
let x := (← getRing).vars[pw.x]!
if pw.k == 1 then
return x
else
return mkApp2 (← getPowFn) x (toExpr pw.k)

def denoteMon (m : Mon) : M Expr := do
match m with
| .unit => denoteNum 1
| .mult pw m => go m (← denotePower pw)
where
go (m : Mon) (acc : Expr) : M Expr := do
match m with
| .unit => return acc
| .mult pw m => go m (mkApp2 (← getMulFn) acc (← denotePower pw))

def denotePoly (p : Poly) : M Expr := do
match p with
| .num k => denoteNum k
| .add k m p => go p (← denoteTerm k m)
where
denoteTerm (k : Int) (m : Mon) : M Expr := do
if k == 1 then
denoteMon m
else
return mkApp2 (← getMulFn) (← denoteNum k) (← denoteMon m)

go (p : Poly) (acc : Expr) : M Expr := do
match p with
| .num 0 => return acc
| .num k => return mkApp2 (← getAddFn) acc (← denoteNum k)
| .add k m p => go p (mkApp2 (← getAddFn) acc (← denoteTerm k m))

@[specialize]
private def denoteExprCore (getVar : Nat → Expr) (e : RingExpr) : M Expr := do
go e
where
go : RingExpr → M Expr
| .num k => denoteNum k
| .natCast k => return mkApp (← getNatCastFn) (mkNatLit k)
| .intCast k => return mkApp (← getIntCastFn) (mkIntLit k)
| .var x => return getVar x
| .add a b => return mkApp2 (← getAddFn) (← go a) (← go b)
| .sub a b => return mkApp2 (← getSubFn) (← go a) (← go b)
| .mul a b => return mkApp2 (← getMulFn) (← go a) (← go b)
| .pow a k => return mkApp2 (← getPowFn) (← go a) (toExpr k)
| .neg a => return mkApp (← getNegFn) (← go a)

def denoteRingExpr (e : RingExpr) : M Expr := do
let ring ← getRing
denoteExprCore (fun x => ring.vars[x]!) e

def denoteRingExpr' (vars : Array Expr) (e : RingExpr) : M Expr := do
denoteExprCore (fun x => vars[x]!) e

end Lean.Meta.Sym.Arith.Ring
99 changes: 99 additions & 0 deletions src/Lean/Meta/Sym/Arith/Ring/Detect.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Meta.Sym.Arith.Ring.SymExt
import Lean.Meta.SynthInstance
import Lean.Meta.DecLevel
public section
namespace Lean.Meta.Sym.Arith.Ring

/--
Wrapper around `Meta.synthInstance?` that catches `isDefEqStuck` exceptions
(which can occur when instance arguments contain metavariables or are not in normal form).
-/
private def synthInstance? (type : Expr) : SymM (Option Expr) :=
catchInternalId isDefEqStuckExceptionId
(Meta.synthInstance? type)
(fun _ => pure none)

/--
Evaluate a `Nat`-typed expression to a concrete `Nat` value.
Handles `OfNat.ofNat`, `HPow.hPow`, `HAdd.hAdd`, `HMul.hMul`, `HSub.hSub`,
`Nat.zero`, `Nat.succ`, and raw `Nat` literals.
This is simpler than `Grind.Arith.evalNat?` (no threshold checks, no `Int` support)
but sufficient for evaluating `IsCharP` characteristic values like `2^8`.
-/
private partial def evalNat? (e : Expr) : Option Nat :=
match_expr e with
| OfNat.ofNat _ n _ =>
match n with
| .lit (.natVal k) => some k
| _ => evalNat? n
| Nat.zero => some 0
| Nat.succ a => (· + 1) <$> evalNat? a
| HAdd.hAdd _ _ _ _ a b => (· + ·) <$> evalNat? a <*> evalNat? b
| HMul.hMul _ _ _ _ a b => (· * ·) <$> evalNat? a <*> evalNat? b
| HSub.hSub _ _ _ _ a b => (· - ·) <$> evalNat? a <*> evalNat? b
| HPow.hPow _ _ _ _ a b => (· ^ ·) <$> evalNat? a <*> evalNat? b
| _ =>
match e with
| .lit (.natVal k) => some k
| _ => none

/--
Detect whether `type` has a `Grind.CommRing` instance.
Returns the shared ring id if found. The `CommRing` object is stored in
`arithRingExt` and is shared between `Sym.simp` and `grind`.
Results are cached in `arithRingExt.typeIdOf`.
-/
def detectCommRing? (type : Expr) : SymM (Option Nat) := do
let s ← arithRingExt.getState
if let some id? := s.typeIdOf.find? { expr := type } then
return id?
let some ring ← go? | do
arithRingExt.modifyState fun st => { st with typeIdOf := st.typeIdOf.insert { expr := type } none }
return none
let id := s.rings.size
let ring := { ring with toRing.id := id }
arithRingExt.modifyState fun st => { st with
rings := st.rings.push ring
typeIdOf := st.typeIdOf.insert { expr := type } (some id)
}
return some id
where
go? : SymM (Option CommRing) := do
let u ← getDecLevel type
let commRing := mkApp (mkConst ``Grind.CommRing [u]) type
let some commRingInst ← synthInstance? commRing | return none
let ringInst := mkApp2 (mkConst ``Grind.CommRing.toRing [u]) type commRingInst
let semiringInst := mkApp2 (mkConst ``Grind.Ring.toSemiring [u]) type ringInst
let commSemiringInst := mkApp2 (mkConst ``Grind.CommRing.toCommSemiring [u]) type semiringInst
let charInst? ← getIsCharInst? u type semiringInst
let noZeroDivInst? ← getNoZeroDivInst? u type
let fieldInst? ← synthInstance? <| mkApp (mkConst ``Grind.Field [u]) type
return some {
id := 0, type, u, semiringInst, ringInst, commSemiringInst,
commRingInst, charInst?, noZeroDivInst?, fieldInst?,
semiringId? := none,
}

getIsCharInst? (u : Level) (type : Expr) (semiringInst : Expr) : SymM (Option (Expr × Nat)) := do
withNewMCtxDepth do
let n ← mkFreshExprMVar (mkConst ``Nat)
let charType := mkApp3 (mkConst ``Grind.IsCharP [u]) type semiringInst n
let some charInst ← synthInstance? charType | return none
let n ← instantiateMVars n
let some nVal := evalNat? n | return none
return some (charInst, nVal)

getNoZeroDivInst? (u : Level) (type : Expr) : SymM (Option Expr) := do
let natModuleType := mkApp (mkConst ``Grind.NatModule [u]) type
let some natModuleInst ← synthInstance? natModuleType | return none
let noZeroDivType := mkApp2 (mkConst ``Grind.NoNatZeroDivisors [u]) type natModuleInst
synthInstance? noZeroDivType

end Lean.Meta.Sym.Arith.Ring
122 changes: 122 additions & 0 deletions src/Lean/Meta/Sym/Arith/Ring/Functions.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Meta.Sym.Arith.Ring.MonadRing
public import Lean.Meta.Basic
public section
namespace Lean.Meta.Sym.Arith.Ring
variable [MonadLiftT MetaM m] [MonadError m] [Monad m] [MonadCanon m]

section
variable [MonadRing m]

def checkInst (declName : Name) (inst inst' : Expr) : MetaM Unit := do
unless (← isDefEqI inst inst') do
throwError "error while initializing `grind ring` operators:\ninstance for `{declName}` {indentExpr inst}\nis not definitionally equal to the expected one {indentExpr inst'}\nwhen only reducible definitions and instances are reduced"

def mkUnaryFn (type : Expr) (u : Level) (instDeclName : Name) (declName : Name) (expectedInst : Expr) : m Expr := do
let inst ← MonadCanon.synthInstance <| mkApp (mkConst instDeclName [u]) type
checkInst declName inst expectedInst
canonExpr <| mkApp2 (mkConst declName [u]) type inst

def mkBinHomoFn (type : Expr) (u : Level) (instDeclName : Name) (declName : Name) (expectedInst : Expr) : m Expr := do
let inst ← MonadCanon.synthInstance <| mkApp3 (mkConst instDeclName [u, u, u]) type type type
checkInst declName inst expectedInst
canonExpr <| mkApp4 (mkConst declName [u, u, u]) type type type inst

def mkPowFn (u : Level) (type : Expr) (semiringInst : Expr) : m Expr := do
let inst ← MonadCanon.synthInstance <| mkApp3 (mkConst ``HPow [u, 0, u]) type Nat.mkType type
let inst' := mkApp2 (mkConst ``Grind.Semiring.npow [u]) type semiringInst
checkInst ``HPow.hPow inst inst'
canonExpr <| mkApp4 (mkConst ``HPow.hPow [u, 0, u]) type Nat.mkType type inst

def mkNatCastFn (u : Level) (type : Expr) (semiringInst : Expr) : m Expr := do
let inst' := mkApp2 (mkConst ``Grind.Semiring.natCast [u]) type semiringInst
let instType := mkApp (mkConst ``NatCast [u]) type
-- Note that `Semiring.natCast` is not registered as a global instance
-- (to avoid introducing unwanted coercions)
-- so merely having a `Semiring α` instance
-- does not guarantee that an `NatCast α` will be available.
-- When both are present we verify that they are defeq,
-- and otherwise fall back to the field of the `Semiring α` instance that we already have.
let inst ← match (← MonadCanon.synthInstance? instType) with
| none => pure inst'
| some inst => checkInst ``NatCast.natCast inst inst'; pure inst
canonExpr <| mkApp2 (mkConst ``NatCast.natCast [u]) type inst

def getAddFn : m Expr := do
let ring ← getRing
if let some addFn := ring.addFn? then return addFn
let expectedInst := mkApp2 (mkConst ``instHAdd [ring.u]) ring.type <| mkApp2 (mkConst ``Grind.Semiring.toAdd [ring.u]) ring.type ring.semiringInst
let addFn ← mkBinHomoFn ring.type ring.u ``HAdd ``HAdd.hAdd expectedInst
modifyRing fun s => { s with addFn? := some addFn }
return addFn

def getSubFn : m Expr := do
let ring ← getRing
if let some subFn := ring.subFn? then return subFn
let expectedInst := mkApp2 (mkConst ``instHSub [ring.u]) ring.type <| mkApp2 (mkConst ``Grind.Ring.toSub [ring.u]) ring.type ring.ringInst
let subFn ← mkBinHomoFn ring.type ring.u ``HSub ``HSub.hSub expectedInst
modifyRing fun s => { s with subFn? := some subFn }
return subFn

def getMulFn : m Expr := do
let ring ← getRing
if let some mulFn := ring.mulFn? then return mulFn
let expectedInst := mkApp2 (mkConst ``instHMul [ring.u]) ring.type <| mkApp2 (mkConst ``Grind.Semiring.toMul [ring.u]) ring.type ring.semiringInst
let mulFn ← mkBinHomoFn ring.type ring.u ``HMul ``HMul.hMul expectedInst
modifyRing fun s => { s with mulFn? := some mulFn }
return mulFn

def getNegFn : m Expr := do
let ring ← getRing
if let some negFn := ring.negFn? then return negFn
let expectedInst := mkApp2 (mkConst ``Grind.Ring.toNeg [ring.u]) ring.type ring.ringInst
let negFn ← mkUnaryFn ring.type ring.u ``Neg ``Neg.neg expectedInst
modifyRing fun s => { s with negFn? := some negFn }
return negFn

def getPowFn : m Expr := do
let ring ← getRing
if let some powFn := ring.powFn? then return powFn
let powFn ← mkPowFn ring.u ring.type ring.semiringInst
modifyRing fun s => { s with powFn? := some powFn }
return powFn

def getIntCastFn : m Expr := do
let ring ← getRing
if let some intCastFn := ring.intCastFn? then return intCastFn
let inst' := mkApp2 (mkConst ``Grind.Ring.intCast [ring.u]) ring.type ring.ringInst
let instType := mkApp (mkConst ``IntCast [ring.u]) ring.type
-- Note that `Ring.intCast` is not registered as a global instance
-- (to avoid introducing unwanted coercions)
-- so merely having a `Ring α` instance
-- does not guarantee that an `IntCast α` will be available.
-- When both are present we verify that they are defeq,
-- and otherwise fall back to the field of the `Ring α` instance that we already have.
let inst ← match (← MonadCanon.synthInstance? instType) with
| none => pure inst'
| some inst => checkInst ``Int.cast inst inst'; pure inst
let intCastFn ← canonExpr <| mkApp2 (mkConst ``IntCast.intCast [ring.u]) ring.type inst
modifyRing fun s => { s with intCastFn? := some intCastFn }
return intCastFn

def getNatCastFn : m Expr := do
let ring ← getRing
if let some natCastFn := ring.natCastFn? then return natCastFn
let natCastFn ← mkNatCastFn ring.u ring.type ring.semiringInst
modifyRing fun s => { s with natCastFn? := some natCastFn }
return natCastFn

def mkOne (u : Level) (type : Expr) (semiringInst : Expr) : m Expr := do
let n := mkRawNatLit 1
let ofNatInst := mkApp3 (mkConst ``Grind.Semiring.ofNat [u]) type semiringInst n
canonExpr <| mkApp3 (mkConst ``OfNat.ofNat [u]) type n ofNatInst

end

end Lean.Meta.Sym.Arith.Ring
36 changes: 36 additions & 0 deletions src/Lean/Meta/Sym/Arith/Ring/MonadCanon.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Meta.Sym.Arith.Ring.Types
public import Lean.Exception
public section
namespace Lean.Meta.Sym.Arith.Ring

class MonadCanon (m : Type → Type) where
/--
Canonicalize an expression (e.g., hash-cons via `shareCommon`).
In `SymM`-based monads, this is typically `shareCommon (← canon e)`.
-/
canonExpr : Expr → m Expr
/--
Synthesize a type class instance, returning `none` on failure.
-/
synthInstance? : Expr → m (Option Expr)

export MonadCanon (canonExpr)

@[always_inline]
instance (m n) [MonadLift m n] [MonadCanon m] : MonadCanon n where
canonExpr e := liftM (canonExpr e : m Expr)
synthInstance? e := liftM (MonadCanon.synthInstance? e : m (Option Expr))

def MonadCanon.synthInstance [Monad m] [MonadError m] [MonadCanon m] (type : Expr) : m Expr := do
let some inst ← synthInstance? type
| throwError "failed to find instance{indentExpr type}"
return inst

end Lean.Meta.Sym.Arith.Ring
23 changes: 23 additions & 0 deletions src/Lean/Meta/Sym/Arith/Ring/MonadRing.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Meta.Sym.Arith.Ring.MonadCanon
public section
namespace Lean.Meta.Sym.Arith.Ring

class MonadRing (m : Type → Type) where
getRing : m Ring
modifyRing : (Ring → Ring) → m Unit

export MonadRing (getRing modifyRing)

@[always_inline]
instance (m n) [MonadLift m n] [MonadRing m] : MonadRing n where
getRing := liftM (getRing : m Ring)
modifyRing f := liftM (modifyRing f : m Unit)

end Lean.Meta.Sym.Arith.Ring
23 changes: 23 additions & 0 deletions src/Lean/Meta/Sym/Arith/Ring/MonadSemiring.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Meta.Sym.Arith.Ring.MonadCanon
public section
namespace Lean.Meta.Sym.Arith.Ring

class MonadSemiring (m : Type → Type) where
getSemiring : m Semiring
modifySemiring : (Semiring → Semiring) → m Unit

export MonadSemiring (getSemiring modifySemiring)

@[always_inline]
instance (m n) [MonadLift m n] [MonadSemiring m] : MonadSemiring n where
getSemiring := liftM (getSemiring : m Semiring)
modifySemiring f := liftM (modifySemiring f : m Unit)

end Lean.Meta.Sym.Arith.Ring
Loading
Loading